tsnet: use ipnlocal.LocalBackend.SetCertsForTest

This replaces tsnet's independent and more limited cert-override logic.

Signed-off-by: Harry Harpham <harry@tailscale.com>
This commit is contained in:
Harry Harpham 2025-12-22 17:27:17 -07:00
parent 376cd5fc61
commit 1584825c9a
No known key found for this signature in database
2 changed files with 37 additions and 51 deletions

View File

@ -159,8 +159,6 @@ type Server struct {
// that the control server will allow the node to adopt that tag.
AdvertiseTags []string
getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error)
initOnce sync.Once
initErr error
lb *ipnlocal.LocalBackend
@ -1102,9 +1100,6 @@ func (s *Server) RegisterFallbackTCPHandler(cb FallbackTCPHandler) func() {
// It calls GetCertificate on the localClient, passing in the ClientHelloInfo.
// For testing, if s.getCertForTesting is set, it will call that instead.
func (s *Server) getCert(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.getCertForTesting != nil {
return s.getCertForTesting(hi)
}
lc, err := s.LocalClient()
if err != nil {
return nil, err

View File

@ -30,7 +30,6 @@ import (
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -139,9 +138,6 @@ func startControl(t *testing.T) (controlURL string, control *testcontrol.Server)
}
type testCertIssuer struct {
mu sync.Mutex
certs map[string]*tls.Certificate
root *x509.Certificate
rootKey *ecdsa.PrivateKey
}
@ -172,40 +168,35 @@ func newCertIssuer() *testCertIssuer {
panic(err)
}
return &testCertIssuer{
certs: make(map[string]*tls.Certificate),
root: rootCA,
rootKey: rootKey,
}
}
func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
tci.mu.Lock()
defer tci.mu.Unlock()
cert, ok := tci.certs[chi.ServerName]
if ok {
return cert, nil
}
func (tci *testCertIssuer) makeCert(domain string) (certPEM, keyPEM []byte, err error) {
certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
return nil, nil, err
}
certTmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{chi.ServerName},
DNSNames: []string{domain},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, certTmpl, tci.root, &certPrivKey.PublicKey, tci.rootKey)
if err != nil {
return nil, err
return nil, nil, err
}
cert = &tls.Certificate{
Certificate: [][]byte{certDER, tci.root.Raw},
PrivateKey: certPrivKey,
}
tci.certs[chi.ServerName] = cert
return cert, nil
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: must.Get(x509.MarshalPKCS8PrivateKey(certPrivKey)),
})
return certPEM, keyPEM, nil
}
func (tci *testCertIssuer) Pool() *x509.CertPool {
@ -222,12 +213,11 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string)
tmp := filepath.Join(t.TempDir(), hostname)
os.MkdirAll(tmp, 0755)
s := &Server{
Dir: tmp,
ControlURL: controlURL,
Hostname: hostname,
Store: new(mem.Store),
Ephemeral: true,
getCertForTesting: testCertRoot.getCert,
Dir: tmp,
ControlURL: controlURL,
Hostname: hostname,
Store: new(mem.Store),
Ephemeral: true,
}
if *verboseNodes {
s.Logf = t.Logf
@ -238,6 +228,11 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string)
if err != nil {
t.Fatal(err)
}
nodeFQDN := hostname + "." + status.CurrentTailnet.MagicDNSSuffix
certPEM, keyPEM, err := testCertRoot.makeCert(nodeFQDN)
s.lb.SetCertsForTest(ipnlocal.TLSCertKeyPair{CertPEM: certPEM, KeyPEM: keyPEM})
return s, status.TailscaleIPs[0], status.Self.PublicKey
}
@ -263,12 +258,11 @@ func TestDialBlocks(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "s2")
os.MkdirAll(tmp, 0755)
s2 := &Server{
Dir: tmp,
ControlURL: controlURL,
Hostname: "s2",
Store: new(mem.Store),
Ephemeral: true,
getCertForTesting: testCertRoot.getCert,
Dir: tmp,
ControlURL: controlURL,
Hostname: "s2",
Store: new(mem.Store),
Ephemeral: true,
}
if *verboseNodes {
s2.Logf = log.Printf
@ -816,20 +810,12 @@ func TestListenService(t *testing.T) {
control.UpdateNode(serviceHostNode)
// Configure a certificate for the Service domain (in production,
// the local backend would use an ACME client to obtain a cert).
// the local backend would use an ACME client to obtain a certPEM).
// This is only used when serving over TLS.
cert := must.Get(testCertRoot.getCert(&tls.ClientHelloInfo{
ServerName: serviceFQDN,
}))
certPEM, keyPEM := must.Get2(testCertRoot.makeCert(serviceFQDN))
serviceHost.lb.SetCertsForTest(ipnlocal.TLSCertKeyPair{
CertPEM: pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Certificate[0],
}),
KeyPEM: pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: must.Get(x509.MarshalPKCS8PrivateKey(cert.PrivateKey)),
}),
CertPEM: certPEM,
KeyPEM: keyPEM,
})
// The service client must accept routes advertised by other nodes
@ -841,6 +827,11 @@ func TestListenService(t *testing.T) {
},
}))
// Force netmap updates to avoid race conditions. The nodes need to
// see our control updates before we can start the test.
serviceClient.lb.DebugForceNetmapUpdate()
serviceHost.lb.DebugForceNetmapUpdate()
// == Done setting up mock state ==
// Start a Service listener.