diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index c8359ef5f..aad5793f7 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -17,7 +17,6 @@ import ( "reflect" "runtime" "slices" - "sync" "sync/atomic" "time" @@ -91,13 +90,6 @@ type endpoint struct { endpointState map[netip.AddrPort]*endpointState // netip.AddrPort type for key (instead of [epAddr]) as [endpointState] is irrelevant for Geneve-encapsulated paths isCallMeMaybeEP map[netip.AddrPort]bool - // We save the previous discoKeys to ensure that control is not overwriting a - // newer received via TSMP. We need to store multiple previous keys as we - // could have the endpoint restart multiple times while not connected to control. - previousDiscoKeys map[string]bool - - initializeEndpoint sync.Once - // The following fields are related to the new "silent disco" // implementation that's a WIP as of 2022-10-20. // See #540 for background. @@ -369,8 +361,9 @@ func (de *endpoint) setProbeUDPLifetimeConfigLocked(desired *ProbeUDPLifetimeCon // endpointDisco is the current disco key and short string for an endpoint. This // structure is immutable. type endpointDisco struct { - key key.DiscoPublic // for discovery messages. - short string // ShortString of discoKey. + key key.DiscoPublic // for discovery messages. + short string // ShortString of discoKey. + viaTSMP bool // the key was learned via TSMP } type sentPing struct { @@ -1478,25 +1471,47 @@ func (de *endpoint) setLastPing(ipp netip.AddrPort, now mono.Time) { // control from overwriting a key set by TSMP in a case where the endpoint // represented by de is unable to contact control and has shared its disco key // via TSMP. If key is a previously held key, this method is a noop. de.mu must -// be held while calling. -func (de *endpoint) updateDiscoKeyLocked(key *key.DiscoPublic) { - de.initializeEndpoint.Do(func() { - de.previousDiscoKeys = make(map[string]bool) - }) - var epDisco *endpointDisco - if key != nil { - if _, ok := de.previousDiscoKeys[key.String()]; ok { - return - } - epDisco = &endpointDisco{ - key: *key, - short: key.ShortString(), - } - if de.disco.Load() != nil { - de.previousDiscoKeys[de.disco.Load().key.String()] = true - } +// be held while calling. The return value is true if we stored a different key +// but false if the key is the same or a previously used key. +func (de *endpoint) updateDiscoKeyLocked(key *key.DiscoPublic, viaTSMP bool) bool { + if key == nil { + de.disco.Store((*endpointDisco)(nil)) + return true } - de.disco.Store(epDisco) + + epDisco := de.disco.Load() + // Existing key is nil, set new key + if epDisco == nil { + de.disco.Store(&endpointDisco{ + key: *key, + short: key.ShortString(), + viaTSMP: viaTSMP, + }) + return true + } + + // Key is the same. If we had learned it via TSMP before and are now getting + // it via control, update the field, but do not change the key. + if epDisco.key.Compare(*key) == 0 { + if epDisco.viaTSMP && !viaTSMP { + epDisco.viaTSMP = false + de.disco.Store(epDisco) + } + return false + } + + // The new key is from control but the old one viaTSMP, do nothing. + if !viaTSMP && epDisco.viaTSMP { + return false + } + + // New key that needs to be stored. + de.disco.Store(&endpointDisco{ + key: *key, + short: key.ShortString(), + viaTSMP: viaTSMP, + }) + return true } // updateFromNode updates the endpoint based on a tailcfg.Node from a NetMap @@ -1525,12 +1540,14 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p if discoKey != n.DiscoKey() { de.c.logf("[v1] magicsock: disco: node %s changed from %s to %s", de.publicKey.ShortString(), discoKey, n.DiscoKey()) key := n.DiscoKey() - de.updateDiscoKeyLocked(&key) + keyChanged := de.updateDiscoKeyLocked(&key, false) de.debugUpdates.Add(EndpointChange{ When: time.Now(), What: "updateFromNode-resetLocked", }) - de.resetLocked() + if keyChanged { + de.resetLocked() + } } if n.HomeDERP() == 0 { if de.derpAddr.IsValid() { diff --git a/wgengine/magicsock/endpoint_test.go b/wgengine/magicsock/endpoint_test.go index c77043114..1fc386d09 100644 --- a/wgengine/magicsock/endpoint_test.go +++ b/wgengine/magicsock/endpoint_test.go @@ -455,10 +455,19 @@ func Test_endpoint_udpRelayEndpointReady(t *testing.T) { } func TestUpdateDiscoKey(t *testing.T) { - t.Run("SetKey", func(t *testing.T) { + t.Run("SetKeyNotTSMP", func(t *testing.T) { de := &endpoint{} newKey := key.NewDisco().Public() - de.updateDiscoKeyLocked(&newKey) + de.updateDiscoKeyLocked(&newKey, false) + if newKey.Compare(de.disco.Load().key) != 0 { + t.Errorf("disco keys not equal, expected %v, got %v", newKey, de.disco.Load().key) + } + }) + + t.Run("SetKeyTSMP", func(t *testing.T) { + de := &endpoint{} + newKey := key.NewDisco().Public() + de.updateDiscoKeyLocked(&newKey, true) if newKey.Compare(de.disco.Load().key) != 0 { t.Errorf("disco keys not equal, expected %v, got %v", newKey, de.disco.Load().key) } @@ -466,7 +475,7 @@ func TestUpdateDiscoKey(t *testing.T) { t.Run("SetNilKey", func(t *testing.T) { de := &endpoint{} - de.updateDiscoKeyLocked(nil) + de.updateDiscoKeyLocked(nil, false) if de.disco.Load() != nil { t.Errorf("disco keys not equal, expected %v, got %v", nil, de.disco.Load().key) } @@ -476,11 +485,25 @@ func TestUpdateDiscoKey(t *testing.T) { de := &endpoint{} oldKey := key.NewDisco().Public() newKey := key.NewDisco().Public() - de.updateDiscoKeyLocked(&oldKey) - de.updateDiscoKeyLocked(&newKey) - de.updateDiscoKeyLocked(&oldKey) // <- Should not change the key + de.updateDiscoKeyLocked(&oldKey, false) + de.updateDiscoKeyLocked(&newKey, true) + de.updateDiscoKeyLocked(&oldKey, false) // <- Should not change the key + if newKey.Compare(de.disco.Load().key) != 0 { t.Errorf("disco keys not equal, expected %v, got %v", newKey, de.disco.Load().key) } + + de.updateDiscoKeyLocked(&newKey, false) // <- Connected to control + if de.disco.Load().viaTSMP { + t.Errorf("disco keys is learned viaTSMP, expected false") + } + + newerKey := key.NewDisco().Public() + de.updateDiscoKeyLocked(&newerKey, true) // <- No longer connected to control + de.updateDiscoKeyLocked(&oldKey, false) + de.updateDiscoKeyLocked(&newKey, false) + if newerKey.Compare(de.disco.Load().key) != 0 { + t.Errorf("disco keys not equal, expected %v, got %v", newerKey, de.disco.Load().key) + } }) } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 563a484d2..75cad4052 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -3182,10 +3182,10 @@ func (c *Conn) updateNodes(update NodeViewsUpdate) (peersChanged bool) { ep.initFakeUDPAddr() ep.mu.Lock() if n.DiscoKey().IsZero() { - ep.updateDiscoKeyLocked(nil) + ep.updateDiscoKeyLocked(nil, false) } else { key := n.DiscoKey() - ep.updateDiscoKeyLocked(&key) + ep.updateDiscoKeyLocked(&key, false) } ep.mu.Unlock() @@ -4321,7 +4321,9 @@ func (c *Conn) HandleDiscoKeyAdvertisement(node tailcfg.NodeView, update packet. return } c.discoInfoForKnownPeerLocked(discoKey) - ep.updateDiscoKeyLocked(&discoKey) + ep.mu.Lock() + ep.updateDiscoKeyLocked(&discoKey, true) + ep.mu.Unlock() c.peerMap.upsertEndpoint(ep, oldDiscoKey) c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString()) metricTSMPDiscoKeyAdvertisementApplied.Add(1)