mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-31 00:01:40 +01:00 
			
		
		
		
	This updates all source files to use a new standard header for copyright and license declaration. Notably, copyright no longer includes a date, and we now use the standard SPDX-License-Identifier header. This commit was done almost entirely mechanically with perl, and then some minimal manual fixes. Updates #6865 Signed-off-by: Will Norris <will@tailscale.com>
		
			
				
	
	
		
			237 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			237 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| //go:build windows && cgo
 | |
| 
 | |
| package controlclient
 | |
| 
 | |
| import (
 | |
| 	"crypto"
 | |
| 	"crypto/x509"
 | |
| 	"crypto/x509/pkix"
 | |
| 	"errors"
 | |
| 	"reflect"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/tailscale/certstore"
 | |
| )
 | |
| 
 | |
// Names shared by the test cases in this file.
const (
	// testRootCommonName is the common name given to the second
	// certificate of chains that the tests expect to match.
	testRootCommonName = "testroot"

	// testRootSubject is the subject string the tests pass to
	// selectIdentityFromSlice.
	testRootSubject = "CN=testroot"
)
 | |
| 
 | |
// testIdentity is a test fake that serves a fixed, in-memory certificate
// chain. The tests in this file store it in []certstore.Identity values.
type testIdentity struct {
	// chain is returned as-is by CertificateChain; its first element is
	// what Certificate returns.
	chain []*x509.Certificate
}
 | |
| 
 | |
// makeChain builds a two-element certificate chain for tests.
//
// The first certificate carries the validity window [notBefore, notAfter];
// the second carries rootCommonName as its subject common name. Both are
// marked as using RSA public keys. No other fields are populated.
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
	cert := &x509.Certificate{
		NotBefore:          notBefore,
		NotAfter:           notAfter,
		PublicKeyAlgorithm: x509.RSA,
	}
	root := &x509.Certificate{
		Subject:            pkix.Name{CommonName: rootCommonName},
		PublicKeyAlgorithm: x509.RSA,
	}
	return []*x509.Certificate{cert, root}
}
 | |
| 
 | |
| func (t *testIdentity) Certificate() (*x509.Certificate, error) {
 | |
| 	return t.chain[0], nil
 | |
| }
 | |
| 
 | |
| func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
 | |
| 	return t.chain, nil
 | |
| }
 | |
| 
 | |
| func (t *testIdentity) Signer() (crypto.Signer, error) {
 | |
| 	return nil, errors.New("not implemented")
 | |
| }
 | |
| 
 | |
| func (t *testIdentity) Delete() error {
 | |
| 	return errors.New("not implemented")
 | |
| }
 | |
| 
 | |
| func (t *testIdentity) Close() {}
 | |
| 
 | |
| func TestSelectIdentityFromSlice(t *testing.T) {
 | |
| 	var times []time.Time
 | |
| 	for _, ts := range []string{
 | |
| 		"2000-01-01T00:00:00Z",
 | |
| 		"2001-01-01T00:00:00Z",
 | |
| 		"2002-01-01T00:00:00Z",
 | |
| 		"2003-01-01T00:00:00Z",
 | |
| 	} {
 | |
| 		tm, err := time.Parse(time.RFC3339, ts)
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		times = append(times, tm)
 | |
| 	}
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		name    string
 | |
| 		subject string
 | |
| 		ids     []certstore.Identity
 | |
| 		now     time.Time
 | |
| 		// wantIndex is an index into ids, or -1 for nil.
 | |
| 		wantIndex int
 | |
| 	}{
 | |
| 		{
 | |
| 			name:    "single unexpired identity",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[2]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[1],
 | |
| 			wantIndex: 0,
 | |
| 		},
 | |
| 		{
 | |
| 			name:    "single expired identity",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[1]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[2],
 | |
| 			wantIndex: -1,
 | |
| 		},
 | |
| 		{
 | |
| 			name:    "unrelated ids",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain("something", times[0], times[2]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[2]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain("else", times[0], times[2]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[1],
 | |
| 			wantIndex: 1,
 | |
| 		},
 | |
| 		{
 | |
| 			name:    "expired with unrelated ids",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain("something", times[0], times[3]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[1]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain("else", times[0], times[3]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[2],
 | |
| 			wantIndex: -1,
 | |
| 		},
 | |
| 		{
 | |
| 			name:    "one expired",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[1]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[1], times[3]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[2],
 | |
| 			wantIndex: 1,
 | |
| 		},
 | |
| 		{
 | |
| 			name:    "two certs both unexpired",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[3]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[1], times[3]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[2],
 | |
| 			wantIndex: 1,
 | |
| 		},
 | |
| 		{
 | |
| 			name:    "two unexpired one expired",
 | |
| 			subject: testRootSubject,
 | |
| 			ids: []certstore.Identity{
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[3]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[1], times[3]),
 | |
| 				},
 | |
| 				&testIdentity{
 | |
| 					chain: makeChain(testRootCommonName, times[0], times[1]),
 | |
| 				},
 | |
| 			},
 | |
| 			now:       times[2],
 | |
| 			wantIndex: 1,
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
 | |
| 
 | |
| 			if gotId == nil && gotChain != nil {
 | |
| 				t.Error("id is nil: got non-nil chain, want nil chain")
 | |
| 				return
 | |
| 			}
 | |
| 			if gotId != nil && gotChain == nil {
 | |
| 				t.Error("id is not nil: got nil chain, want non-nil chain")
 | |
| 				return
 | |
| 			}
 | |
| 			if tt.wantIndex == -1 {
 | |
| 				if gotId != nil {
 | |
| 					t.Error("got non-nil id, want nil id")
 | |
| 				}
 | |
| 				return
 | |
| 			}
 | |
| 			if gotId == nil {
 | |
| 				t.Error("got nil id, want non-nil id")
 | |
| 				return
 | |
| 			}
 | |
| 			if gotId != tt.ids[tt.wantIndex] {
 | |
| 				found := -1
 | |
| 				for i := range tt.ids {
 | |
| 					if tt.ids[i] == gotId {
 | |
| 						found = i
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 				if found == -1 {
 | |
| 					t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
 | |
| 				} else {
 | |
| 					t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
 | |
| 			if !ok {
 | |
| 				t.Error("got non-testIdentity, want testIdentity")
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			if !reflect.DeepEqual(tid.chain, gotChain) {
 | |
| 				t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 |