diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index 195525228..30fc1bb2a 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -478,6 +478,27 @@ func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { return err == nil && ok } +var _ DiscoUpdateNotifier = mapRoutineState{} + +func (mrs mapRoutineState) MarkDiscoAsLearnedFromTSMP(pub key.NodePublic, disco key.DiscoPublic) { + c := mrs.c + c.mu.Lock() + goodState := c.loggedIn && c.inMapPoll + dun, ok := c.observer.(DiscoUpdateNotifier) + c.mu.Unlock() + + if !goodState || !ok { + return + } + + ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + defer cancel() + + c.observerQueue.RunSync(ctx, func() { + dun.MarkDiscoAsLearnedFromTSMP(pub, disco) + }) +} + // mapRoutine is responsible for keeping a read-only streaming connection to the // control server, and keeping the netmap up to date. func (c *Auto) mapRoutine() { diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index dc3ebd300..7bc0b76b6 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -228,6 +228,17 @@ type NetmapDeltaUpdater interface { UpdateNetmapDelta([]netmap.NodeMutation) (ok bool) } +// DiscoUpdateNotifier is implemented by users of the control client (such as +// local backend) to get notified of updates to disco keys received via TSMP. +// This interface only concerns itself with marking. Implementors must remove +// marks themselves. +// +// There is a mirror implementation in the wgengine watchdog, implemented there +// to avoid the dependency edge between controlClient and wgengine. +type DiscoUpdateNotifier interface { + MarkDiscoAsLearnedFromTSMP(key.NodePublic, key.DiscoPublic) +} + var nextControlClientID atomic.Int64 // NewDirect returns a new Direct client. 
@@ -367,7 +378,7 @@ func NewDirect(opts Options) (*Direct, error) { // mapSession has gone away, we want to fall back to pushing the key // further down the chain. if err := c.streamingMapSession.updateDiscoForNode( - peer.ID(), update.Key, time.Now(), false); err == nil || + peer.ID(), peer.Key(), update.Key, time.Now(), false); err == nil || !errors.Is(err, ErrChangeQueueClosed) { return } @@ -377,10 +388,7 @@ func NewDirect(opts Options) (*Direct, error) { // not have a mapSession (we are not connected to control) or because the // mapSession queue has closed. c.logf("controlclient direct: updating discoKey for %v via magicsock", update.Src) - discoKeyPub.Publish(events.PeerDiscoKeyUpdate{ - Src: update.Src, - Key: update.Key, - }) + discoKeyPub.Publish(events.PeerDiscoKeyUpdate(update)) }) return c, nil @@ -859,8 +867,10 @@ func (c *Direct) PollNetMap(ctx context.Context, nu NetmapUpdater) error { // update it observed. It is used by tests and [NetmapFromMapResponseForDebug]. // It will report only the first netmap seen. type rememberLastNetmapUpdater struct { - last *netmap.NetworkMap - done chan any + last *netmap.NetworkMap + lastTSMPKey key.NodePublic + lastTSMPDisco key.DiscoPublic + done chan any } func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { @@ -871,6 +881,12 @@ func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { } } +func (nu *rememberLastNetmapUpdater) MarkDiscoAsLearnedFromTSMP( + key key.NodePublic, disco key.DiscoPublic) { + nu.lastTSMPKey = key + nu.lastTSMPDisco = disco +} + // FetchNetMapForTest fetches the netmap once. 
func (c *Direct) FetchNetMapForTest(ctx context.Context) (*netmap.NetworkMap, error) { var nu rememberLastNetmapUpdater diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 1a0ab0037..b11c69f0f 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -37,6 +37,18 @@ import ( "tailscale.com/wgengine/filter" ) +type updateSource int + +const ( + sourceControl updateSource = iota + sourceTSMP +) + +type responseWithSource struct { + response *tailcfg.MapResponse + source updateSource +} + // mapSession holds the state over a long-polled "map" request to the // control plane. // @@ -98,7 +110,7 @@ type mapSession struct { lastTKAInfo *tailcfg.TKAInfo lastNetmapSummary string // from NetworkMap.VeryConcise cqmu sync.Mutex - changeQueue chan (*tailcfg.MapResponse) + changeQueue chan (responseWithSource) changeQueueClosed bool processQueue sync.WaitGroup } @@ -123,7 +135,7 @@ func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnob cancel: func() {}, onDebug: func(context.Context, *tailcfg.Debug) error { return nil }, onSelfNodeChanged: func(*netmap.NetworkMap) {}, - changeQueue: make(chan *tailcfg.MapResponse), + changeQueue: make(chan responseWithSource), changeQueueClosed: false, } ms.sessionAliveCtx, ms.sessionAliveCtxClose = context.WithCancel(context.Background()) @@ -142,7 +154,7 @@ func (ms *mapSession) run() { for { select { case change := <-ms.changeQueue: - ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change) + ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change.response, change.source) case <-ms.sessionAliveCtx.Done(): // Drain any remaining items in the queue before exiting. 
// Lock the queue during this time to avoid updates through other channels @@ -154,7 +166,7 @@ func (ms *mapSession) run() { for { select { case change := <-ms.changeQueue: - ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change) + ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change.response, change.source) default: // Queue is empty, close it and exit close(ms.changeQueue) @@ -190,7 +202,7 @@ func (ms *mapSession) Close() { var ErrChangeQueueClosed = errors.New("change queue closed") -func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.DiscoPublic, lastSeen time.Time, online bool) error { +func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.NodePublic, discoKey key.DiscoPublic, lastSeen time.Time, online bool) error { ms.cqmu.Lock() if ms.changeQueueClosed { @@ -199,13 +211,17 @@ func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.DiscoPublic, return ErrChangeQueueClosed } - resp := &tailcfg.MapResponse{ - PeersChangedPatch: []*tailcfg.PeerChange{{ - NodeID: id, - LastSeen: &lastSeen, - Online: &online, - DiscoKey: &key, - }}, + resp := responseWithSource{ + response: &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: id, + Key: &key, + LastSeen: &lastSeen, + Online: &online, + DiscoKey: &discoKey, + }}, + }, + source: sourceTSMP, } ms.changeQueue <- resp ms.cqmu.Unlock() @@ -221,7 +237,12 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t return ErrChangeQueueClosed } - ms.changeQueue <- resp + change := responseWithSource{ + response: resp, + source: sourceControl, + } + + ms.changeQueue <- change ms.cqmu.Unlock() return nil } @@ -234,7 +255,9 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t // // TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this // is [re]factoring progress enough. 
-func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error { +func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, + resp *tailcfg.MapResponse, source updateSource, +) error { if debug := resp.Debug; debug != nil { if err := ms.onDebug(ctx, debug); err != nil { return err } } @@ -284,6 +307,13 @@ func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *t ms.updateStateFromResponse(resp) + // If the update was learned via TSMP, the updated disco key needs to be marked in + // userspaceEngine as an update that should not reconfigure the wireguard + // connection. + if source == sourceTSMP { + ms.tryMarkDiscoAsLearnedFromTSMP(resp) + } + if ms.tryHandleIncrementally(resp) { ms.occasionallyPrintSummary(ms.lastNetmapSummary) return nil } @@ -312,6 +342,21 @@ func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *t return nil } +func (ms *mapSession) tryMarkDiscoAsLearnedFromTSMP(res *tailcfg.MapResponse) { + dun, ok := ms.netmapUpdater.(DiscoUpdateNotifier) + if !ok { + return + } + + // In reality we should never really have more than one change here over TSMP. + for _, change := range res.PeersChangedPatch { + if change == nil || change.DiscoKey == nil || change.Key == nil { + continue + } + dun.MarkDiscoAsLearnedFromTSMP(*change.Key, *change.DiscoKey) + } +} + // upgradeNode upgrades Node fields from the server into the modern forms // not using deprecated fields.
func upgradeNode(n *tailcfg.Node) { diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 154b9742e..520d0714b 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -678,6 +678,7 @@ func TestUpdateDiscoForNode(t *testing.T) { // Insert existing node node := tailcfg.Node{ ID: 1, + Key: key.NewNode().Public(), DiscoKey: oldKey.Public(), Online: &tt.initialOnline, LastSeen: &tt.initialLastSeen, @@ -690,7 +691,7 @@ func TestUpdateDiscoForNode(t *testing.T) { } newKey := key.NewDisco() - ms.updateDiscoForNode(node.ID, newKey.Public(), tt.updateLastSeen, tt.updateOnline) + ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), tt.updateLastSeen, tt.updateOnline) <-nu.done nm := ms.netmap() @@ -707,6 +708,82 @@ func TestUpdateDiscoForNode(t *testing.T) { } } +func TestUpdateDiscoForNodeCallback(t *testing.T) { + t.Run("key_wired_through_to_updater", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco() + ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), time.Now(), false) + <-nu.done + + if nu.lastTSMPKey != node.Key || nu.lastTSMPDisco != newKey.Public() { + t.Fatalf("expected [%s]=%s, got [%s]=%s", node.Key, newKey.Public(), + nu.lastTSMPKey, nu.lastTSMPDisco) + } + }) + t.Run("key_not_wired_through_to_updater", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: 
key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco().Public() + resp := &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: node.ID, + Key: &node.Key, + LastSeen: new(time.Now()), + Online: new(true), + DiscoKey: &newKey, + }}, + } + ms.HandleNonKeepAliveMapResponse(t.Context(), resp) + <-nu.done + + if !nu.lastTSMPKey.IsZero() || !nu.lastTSMPDisco.IsZero() { + t.Fatalf("expected zero keys, got [%s]=%s", + nu.lastTSMPKey, nu.lastTSMPDisco) + } + }) +} + func first[T any](s []T) T { if len(s) == 0 { var zero T diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 49e1f00c7..5b1984233 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1857,6 +1857,12 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c b.authReconfigLocked() } +func (b *LocalBackend) MarkDiscoAsLearnedFromTSMP(pub key.NodePublic, disco key.DiscoPublic) { + if e, ok := b.e.(controlclient.DiscoUpdateNotifier); ok { + e.MarkDiscoAsLearnedFromTSMP(pub, disco) + } +} + type preferencePolicyInfo struct { key pkey.Key get func(ipn.PrefsView) bool diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 5670541af..364c70c9c 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -121,7 +121,8 @@ type userspaceEngine struct { birdClient BIRDClient // or nil controlKnobs *controlknobs.Knobs // or nil - testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called + testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called + testDiscoChangedHook func(map[key.NodePublic]bool) // for tests; if non-nil, fires after assembling discoChanged map // isLocalAddr reports the whether an IP is assigned to 
the local // tunnel interface. It's used to reflect local packets @@ -167,6 +168,10 @@ type userspaceEngine struct { // networkLogger logs statistics about network connections. networkLogger netlog.Logger + // tsmpLearnedDisco tracks, per node key, whether a peer's disco key was learned via TSMP. + // wgLock must be held when using this map. + tsmpLearnedDisco map[key.NodePublic]key.DiscoPublic + // Lock ordering: magicsock.Conn.mu, wgLock, then mu. } @@ -1028,6 +1033,12 @@ func (e *userspaceEngine) ResetAndStop() (*Status, error) { } } +func (e *userspaceEngine) MarkDiscoAsLearnedFromTSMP(pub key.NodePublic, disco key.DiscoPublic) { + e.wgLock.Lock() + defer e.wgLock.Unlock() + mak.Set(&e.tsmpLearnedDisco, pub, disco) +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") } @@ -1119,14 +1130,31 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, if p.DiscoKey.IsZero() { continue } + + // If the key changed, mark the connection for reconfiguration. pub := p.PublicKey if old, ok := prevEP[pub]; ok && old != p.DiscoKey { + // If the disco key was learned via TSMP, we do not need to reset the + // wireguard config as the new key was received over an existing wireguard + // connection. + if discoTSMP, okTSMP := e.tsmpLearnedDisco[p.PublicKey]; okTSMP && + discoTSMP == p.DiscoKey { + delete(e.tsmpLearnedDisco, p.PublicKey) + e.logf("wgengine: Skipping reconfig (TSMP key): %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) + continue + } + + discoChanged[pub] = true e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) } } } + // For tests, which disco connections need to be changed. 
+ if e.testDiscoChangedHook != nil { + e.testDiscoChangedHook(discoChanged) + } + e.lastCfgFull = *cfg.Clone() // Tell magicsock about the new (or initial) private key @@ -1144,6 +1172,13 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, return err } + // Clean up the map of TSMP marks for peers that no longer exist in the config. + for nodeKey := range e.tsmpLearnedDisco { + if !peerSet.Contains(nodeKey) { + delete(e.tsmpLearnedDisco, nodeKey) + } + } + // Shutdown the network logger because the IDs changed. // Let it be started back up by subsequent logic. if buildfeatures.HasNetLog && netLogIDsChanged && e.networkLogger.Running() { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 18d870af1..b72fbe5b1 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -164,6 +164,79 @@ func TestUserspaceEngineReconfig(t *testing.T) { } } +func TestUserspaceEngineTSMPLearned(t *testing.T) { + bus := eventbustest.NewBus(t) + + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + ue := e.(*userspaceEngine) + + discoChangedChan := make(chan map[key.NodePublic]bool, 1) + ue.testDiscoChangedHook = func(m map[key.NodePublic]bool) { + discoChangedChan <- m + } + + routerCfg := &router.Config{} + + keyChanges := []struct { + tsmp bool + inMap bool + }{ + {tsmp: false, inMap: false}, + {tsmp: true, inMap: false}, + {tsmp: false, inMap: true}, + } + + nkHex := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + for _, change := range keyChanges { + oldDisco := key.NewDisco() + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: nkFromHex(nkHex), + DiscoKey: oldDisco.Public(), + }, + }), + } + nk, err := key.ParseNodePublicUntyped(mem.S(nkHex)) + if err != nil { + t.Fatal(err) + } + e.SetNetworkMap(nm) + + newDisco := key.NewDisco() + cfg = 
&wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + PublicKey: nk, + DiscoKey: newDisco.Public(), + }, + }, + } + + if change.tsmp { + ue.MarkDiscoAsLearnedFromTSMP(nk, newDisco.Public()) + } + err = e.Reconfig(cfg, routerCfg, &dns.Config{}) + if err != nil { + t.Fatal(err) + } + + changeMap := <-discoChangedChan + + if _, ok := changeMap[nk]; ok != change.inMap { + t.Fatalf("expect key %v in map %v to be %t, got %t", nk, changeMap, + change.inMap, ok) + } + } +} + func TestUserspaceEnginePortReconfig(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/2855") const defaultPort = 49983 diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index f12b1c19e..567c62294 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -242,3 +242,15 @@ func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { func (e *watchdogEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok bool) { return e.wrap.PeerByKey(pubKey) } + +func (e *watchdogEngine) MarkDiscoAsLearnedFromTSMP(pub key.NodePublic, disco key.DiscoPublic) { + // discoUpdateNotifier mirrors the implementation of [controlclient.DiscoUpdateNotifier]. + // It is implemented here to avoid the dependency edge to controlclient, but must be kept + // in sync with the original implementation. + type discoUpdateNotifier interface { + MarkDiscoAsLearnedFromTSMP(key.NodePublic, key.DiscoPublic) + } + if n, ok := e.wrap.(discoUpdateNotifier); ok { + n.MarkDiscoAsLearnedFromTSMP(pub, disco) + } +}