mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-24 02:31:48 +01:00
tsnet: use ipnlocal.LocalBackend.SetCertsForTest
This replaces tsnet's independent and more limited cert-override logic. Signed-off-by: Harry Harpham <harry@tailscale.com>
This commit is contained in:
parent
376cd5fc61
commit
1584825c9a
@ -159,8 +159,6 @@ type Server struct {
|
||||
// that the control server will allow the node to adopt that tag.
|
||||
AdvertiseTags []string
|
||||
|
||||
getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
lb *ipnlocal.LocalBackend
|
||||
@ -1102,9 +1100,6 @@ func (s *Server) RegisterFallbackTCPHandler(cb FallbackTCPHandler) func() {
|
||||
// It calls GetCertificate on the localClient, passing in the ClientHelloInfo.
|
||||
// For testing, if s.getCertForTesting is set, it will call that instead.
|
||||
func (s *Server) getCert(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.getCertForTesting != nil {
|
||||
return s.getCertForTesting(hi)
|
||||
}
|
||||
lc, err := s.LocalClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -30,7 +30,6 @@ import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@ -139,9 +138,6 @@ func startControl(t *testing.T) (controlURL string, control *testcontrol.Server)
|
||||
}
|
||||
|
||||
type testCertIssuer struct {
|
||||
mu sync.Mutex
|
||||
certs map[string]*tls.Certificate
|
||||
|
||||
root *x509.Certificate
|
||||
rootKey *ecdsa.PrivateKey
|
||||
}
|
||||
@ -172,40 +168,35 @@ func newCertIssuer() *testCertIssuer {
|
||||
panic(err)
|
||||
}
|
||||
return &testCertIssuer{
|
||||
certs: make(map[string]*tls.Certificate),
|
||||
root: rootCA,
|
||||
rootKey: rootKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
tci.mu.Lock()
|
||||
defer tci.mu.Unlock()
|
||||
cert, ok := tci.certs[chi.ServerName]
|
||||
if ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (tci *testCertIssuer) makeCert(domain string) (certPEM, keyPEM []byte, err error) {
|
||||
certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
certTmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
DNSNames: []string{chi.ServerName},
|
||||
DNSNames: []string{domain},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, certTmpl, tci.root, &certPrivKey.PublicKey, tci.rootKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
cert = &tls.Certificate{
|
||||
Certificate: [][]byte{certDER, tci.root.Raw},
|
||||
PrivateKey: certPrivKey,
|
||||
}
|
||||
tci.certs[chi.ServerName] = cert
|
||||
return cert, nil
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: must.Get(x509.MarshalPKCS8PrivateKey(certPrivKey)),
|
||||
})
|
||||
return certPEM, keyPEM, nil
|
||||
}
|
||||
|
||||
func (tci *testCertIssuer) Pool() *x509.CertPool {
|
||||
@ -222,12 +213,11 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string)
|
||||
tmp := filepath.Join(t.TempDir(), hostname)
|
||||
os.MkdirAll(tmp, 0755)
|
||||
s := &Server{
|
||||
Dir: tmp,
|
||||
ControlURL: controlURL,
|
||||
Hostname: hostname,
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
getCertForTesting: testCertRoot.getCert,
|
||||
Dir: tmp,
|
||||
ControlURL: controlURL,
|
||||
Hostname: hostname,
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
}
|
||||
if *verboseNodes {
|
||||
s.Logf = t.Logf
|
||||
@ -238,6 +228,11 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nodeFQDN := hostname + "." + status.CurrentTailnet.MagicDNSSuffix
|
||||
certPEM, keyPEM, err := testCertRoot.makeCert(nodeFQDN)
|
||||
s.lb.SetCertsForTest(ipnlocal.TLSCertKeyPair{CertPEM: certPEM, KeyPEM: keyPEM})
|
||||
|
||||
return s, status.TailscaleIPs[0], status.Self.PublicKey
|
||||
}
|
||||
|
||||
@ -263,12 +258,11 @@ func TestDialBlocks(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "s2")
|
||||
os.MkdirAll(tmp, 0755)
|
||||
s2 := &Server{
|
||||
Dir: tmp,
|
||||
ControlURL: controlURL,
|
||||
Hostname: "s2",
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
getCertForTesting: testCertRoot.getCert,
|
||||
Dir: tmp,
|
||||
ControlURL: controlURL,
|
||||
Hostname: "s2",
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
}
|
||||
if *verboseNodes {
|
||||
s2.Logf = log.Printf
|
||||
@ -816,20 +810,12 @@ func TestListenService(t *testing.T) {
|
||||
control.UpdateNode(serviceHostNode)
|
||||
|
||||
// Configure a certificate for the Service domain (in production,
|
||||
// the local backend would use an ACME client to obtain a cert).
|
||||
// the local backend would use an ACME client to obtain a certPEM).
|
||||
// This is only used when serving over TLS.
|
||||
cert := must.Get(testCertRoot.getCert(&tls.ClientHelloInfo{
|
||||
ServerName: serviceFQDN,
|
||||
}))
|
||||
certPEM, keyPEM := must.Get2(testCertRoot.makeCert(serviceFQDN))
|
||||
serviceHost.lb.SetCertsForTest(ipnlocal.TLSCertKeyPair{
|
||||
CertPEM: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Certificate[0],
|
||||
}),
|
||||
KeyPEM: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: must.Get(x509.MarshalPKCS8PrivateKey(cert.PrivateKey)),
|
||||
}),
|
||||
CertPEM: certPEM,
|
||||
KeyPEM: keyPEM,
|
||||
})
|
||||
|
||||
// The service client must accept routes advertised by other nodes
|
||||
@ -841,6 +827,11 @@ func TestListenService(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
|
||||
// Force netmap updates to avoid race conditions. The nodes need to
|
||||
// see our control updates before we can start the test.
|
||||
serviceClient.lb.DebugForceNetmapUpdate()
|
||||
serviceHost.lb.DebugForceNetmapUpdate()
|
||||
|
||||
// == Done setting up mock state ==
|
||||
|
||||
// Start a Service listener.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user