James Tucker c09c95ef67 types/key,wgengine/magicsock,control/controlclient,ipn: add debug disco key rotation
Adds the ability to rotate discovery keys on running clients, needed for
testing upcoming disco key distribution changes.

Introduces key.DiscoKey, an atomic container for a disco private key,
public key, and the public key's ShortString, replacing the prior
separate atomic fields.

magicsock.Conn has a new RotateDiscoKey method, and access to this is
provided via localapi and a CLI debug command.

Note that this implementation is primarily for testing as it stands, and
regular use should likely introduce an additional mechanism that allows
the old key to be used for some time, to provide a seamless key rotation
rather than one that invalidates all sessions.

Updates tailscale/corp#34037

Signed-off-by: James Tucker <james@tailscale.com>
2025-11-18 12:16:15 -08:00

184 lines
4.0 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package controlclient
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
"tailscale.com/hostinfo"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/eventbus/eventbustest"
)
// TestSetDiscoPublicKey verifies that SetDiscoPublicKey replaces the disco
// public key a Direct was constructed with.
func TestSetDiscoPublicKey(t *testing.T) {
	before := key.NewDisco().Public()
	c := &Direct{discoPubKey: before}

	// current returns c.discoPubKey, read while holding c.mu.
	current := func() key.DiscoPublic {
		c.mu.Lock()
		defer c.mu.Unlock()
		return c.discoPubKey
	}

	if got := current(); got != before {
		t.Fatalf("initial disco key mismatch: got %v, want %v", got, before)
	}

	after := key.NewDisco().Public()
	c.SetDiscoPublicKey(after)

	got := current()
	if got != after {
		t.Fatalf("disco key not updated: got %v, want %v", got, after)
	}
	if got == before {
		t.Fatal("disco key should have changed")
	}
}
// TestNewDirect checks that NewDirect stores the configured server URL and
// Hostinfo (without its NetInfo field). It also checks that SetNetInfo,
// SetHostinfo and newEndpoints report a change only when their input differs
// from what was previously set.
func TestNewDirect(t *testing.T) {
	hi := hostinfo.New()
	ni := tailcfg.NetInfo{LinkType: "wired"}
	hi.NetInfo = &ni
	bus := eventbustest.NewBus(t)

	k := key.NewMachine()
	dialer := tsdial.NewDialer(netmon.NewStatic())
	dialer.SetBus(bus)
	opts := Options{
		ServerURL: "https://example.com",
		Hostinfo:  hi,
		GetMachinePrivateKey: func() (key.MachinePrivate, error) {
			return k, nil
		},
		Dialer: dialer,
		Bus:    bus,
	}
	c, err := NewDirect(opts)
	if err != nil {
		t.Fatal(err)
	}

	if c.serverURL != opts.ServerURL {
		t.Errorf("c.serverURL got %v want %v", c.serverURL, opts.ServerURL)
	}

	// hi is stored without its NetInfo field.
	hiWithoutNi := *hi
	hiWithoutNi.NetInfo = nil
	if !hiWithoutNi.Equal(c.hostinfo) {
		// Report the value actually compared against; hi itself still
		// carries NetInfo and would make the diagnostic misleading.
		t.Errorf("c.hostinfo got %v want %v", c.hostinfo, &hiWithoutNi)
	}

	changed := c.SetNetInfo(&ni)
	if changed {
		t.Errorf("c.SetNetInfo(ni) want false got %v", changed)
	}
	// hi.NetInfo points at ni, so this assignment also changes hi.NetInfo.
	ni = tailcfg.NetInfo{LinkType: "wifi"}
	changed = c.SetNetInfo(&ni)
	if !changed {
		t.Errorf("c.SetNetInfo(ni) want true got %v", changed)
	}

	changed = c.SetHostinfo(hi)
	if changed {
		t.Errorf("c.SetHostinfo(hi) want false got %v", changed)
	}
	hi = hostinfo.New()
	hi.Hostname = "different host name"
	changed = c.SetHostinfo(hi)
	if !changed {
		t.Errorf("c.SetHostinfo(hi) want true got %v", changed)
	}

	endpoints := fakeEndpoints(1, 2, 3)
	changed = c.newEndpoints(endpoints)
	if !changed {
		t.Errorf("c.newEndpoints want true got %v", changed)
	}
	// Submitting the same endpoints again must not count as a change.
	changed = c.newEndpoints(endpoints)
	if changed {
		t.Errorf("c.newEndpoints want false got %v", changed)
	}
	endpoints = fakeEndpoints(4, 5, 6)
	changed = c.newEndpoints(endpoints)
	if !changed {
		t.Errorf("c.newEndpoints want true got %v", changed)
	}
}
// fakeEndpoints builds one tailcfg.Endpoint per port, in argument order. Each
// endpoint uses the zero netip.Addr combined with that port. With no ports it
// returns a nil slice.
func fakeEndpoints(ports ...uint16) (ret []tailcfg.Endpoint) {
	for i := range ports {
		ep := tailcfg.Endpoint{Addr: netip.AddrPortFrom(netip.Addr{}, ports[i])}
		ret = append(ret, ep)
	}
	return ret
}
func TestTsmpPing(t *testing.T) {
hi := hostinfo.New()
ni := tailcfg.NetInfo{LinkType: "wired"}
hi.NetInfo = &ni
bus := eventbustest.NewBus(t)
k := key.NewMachine()
dialer := tsdial.NewDialer(netmon.NewStatic())
dialer.SetBus(bus)
opts := Options{
ServerURL: "https://example.com",
Hostinfo: hi,
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
return k, nil
},
Dialer: dialer,
Bus: bus,
}
c, err := NewDirect(opts)
if err != nil {
t.Fatal(err)
}
pingRes := &tailcfg.PingResponse{
Type: "TSMP",
IP: "123.456.7890",
Err: "",
NodeName: "testnode",
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
body := new(ipnstate.PingResult)
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
t.Fatal(err)
}
if pingRes.IP != body.IP {
t.Fatalf("PingResult did not have the correct IP : got %v, expected : %v", body.IP, pingRes.IP)
}
w.WriteHeader(200)
}))
defer ts.Close()
now := time.Now()
pr := &tailcfg.PingRequest{
URL: ts.URL,
}
err = postPingResult(now, t.Logf, c.httpc, pr, pingRes)
if err != nil {
t.Fatal(err)
}
}