diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index e1b45ba3f..947f4b005 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -66,8 +66,6 @@ import ( // verbosely log whenever DERP drops a packet. var verboseDropKeys = map[key.NodePublic]bool{} -var debugDisablePeerHashTrie = envknob.RegisterBool("TS_DEBUG_DERP_DISABLE_PEER_HASHTRIE") - // IdealNodeContextKey is the context key used to pass the IdealNodeHeader value // from the HTTP handler to the DERP server's Accept method. var IdealNodeContextKey = ctxkey.New("ideal-node", "") @@ -193,8 +191,16 @@ type Server struct { mu syncs.Mutex // guards the following fields closed bool netConns map[derp.Conn]chan struct{} // chan is closed when conn closes - clients map[key.NodePublic]*clientSet - watchers set.Set[*sclient] // mesh peers + // clients holds the set of clients connected locally to this server, + // keyed by their public key. Writes happen under Server.mu so they + // stay consistent with clientsMesh, watchers, dup tracking, and the + // numLocalClientKeys counter. Reads on the packet send hot path + // are performed lock-free; see lookupDest. + clients hashtriemap.HashTrieMap[key.NodePublic, *clientSet] + // numLocalClientKeys is the number of distinct keys in clients. + // HashTrieMap has no Len, so the count is tracked here. + numLocalClientKeys int + watchers set.Set[*sclient] // mesh peers // clientsMesh tracks all clients in the cluster, both locally // and to mesh peers. If the value is nil, that means the // peer is only local (and thus in the clients Map, but not @@ -209,11 +215,6 @@ type Server struct { // maps from netip.AddrPort to a client's public key keyOfAddr map[netip.AddrPort]key.NodePublic rateConfig RateConfig // per-client DERP frame rate limiting config - - // clientsAtomic mirrors clients for local active-client lookup without - // taking Server.mu. The authoritative clients map is still guarded by - // Server.mu; this mirror is only a fast path for handleFrameSendPacket. - clientsAtomic hashtriemap.HashTrieMap[key.NodePublic, *clientSet] } // clientSet represents 1 or more *sclients. @@ -372,7 +373,6 @@ func New(privateKey key.NodePrivate, logf logger.Logf) *Server { logf: logf, limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), packetsRecvByKind: metrics.LabelMap{Label: "kind"}, - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, netConns: map[derp.Conn]chan struct{}{}, memSys0: ms.Sys, @@ -577,7 +577,7 @@ func (s *Server) UpdateRateLimits(rc RateConfig) (applied RateConfig) { rc.PerClientRateBurstBytes = max(rc.PerClientRateBurstBytes, minRateLimitTokenBucketSize) } s.rateConfig = rc - for _, cs := range s.clients { + for _, cs := range s.clients.All() { cs.ForeachClient(func(c *sclient) { c.setRateLimit(rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) }) @@ -634,7 +634,7 @@ func (s *Server) isClosed() bool { func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - x, ok := s.clients[k] + x, ok := s.clients.Load(k) if !ok { return false } @@ -749,11 +749,12 @@ func (s *Server) registerClient(c *sclient) { c.setRateLimit(s.rateConfig.PerClientRateLimitBytesPerSec, s.rateConfig.PerClientRateBurstBytes) - cs, ok := s.clients[c.key] + cs, ok := s.clients.Load(c.key) if !ok { c.debugLogf("register single client") cs = &clientSet{} - s.clients[c.key] = cs + s.clients.Store(c.key, cs) + s.numLocalClientKeys++ } was := cs.activeClient.Load() if was == nil { @@ -785,7 +786,6 @@ func (s *Server) registerClient(c *sclient) { } cs.activeClient.Store(c) - s.clientsAtomic.Store(c.key, cs) if _, ok := s.clientsMesh[c.key]; !ok { s.clientsMesh[c.key] = nil // just for varz of total users in cluster @@ -820,7 +820,7 @@ func (s *Server) unregisterClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - set, ok := s.clients[c.key] + set, ok := s.clients.Load(c.key) if !ok { c.logf("[unexpected]; clients map is empty") return @@ -840,8 +840,9 @@ func (s *Server) unregisterClient(c *sclient) { } c.debugLogf("removed connection") set.activeClient.Store(nil) - delete(s.clients, c.key) - s.clientsAtomic.CompareAndDelete(c.key, set) + if s.clients.CompareAndDelete(c.key, set) { + s.numLocalClientKeys-- + } if v, ok := s.clientsMesh[c.key]; ok && v == nil { delete(s.clientsMesh, c.key) s.notePeerGoneFromRegionLocked(c.key) @@ -972,7 +973,7 @@ func (s *Server) addWatcher(c *sclient) { defer s.mu.Unlock() // Queue messages for each already-connected client. - for peer, clientSet := range s.clients { + for peer, clientSet := range s.clients.All() { ac := clientSet.activeClient.Load() if ac == nil { continue @@ -1210,7 +1211,7 @@ func (c *sclient) handleFrameClosePeer(ft derp.FrameType, fl uint32) error { s.mu.Lock() defer s.mu.Unlock() - if set, ok := s.clients[targetKey]; ok { + if set, ok := s.clients.Load(targetKey); ok { if set.Len() == 1 { c.logf("frameClosePeer closing peer %x", targetKey) } else { @@ -1240,15 +1241,10 @@ func (c *sclient) handleFrameForwardPacket(_ derp.FrameType, fl uint32) error { } s.packetsForwardedIn.Add(1) - var dstLen int - var dst *sclient - - s.mu.Lock() - if set, ok := s.clients[dstKey]; ok { - dstLen = set.Len() - dst = set.activeClient.Load() - } - s.mu.Unlock() + // Use the same lock-free fast path as the local send path. The mesh + // forwarder return is intentionally discarded: we never re-forward an + // already-forwarded packet. + dst, _, dstLen := c.lookupDest(dstKey) if dst == nil { reason := dropReasonUnknownDestOnFwd @@ -1274,30 +1270,25 @@ func (c *sclient) handleFrameForwardPacket(_ derp.FrameType, fl uint32) error { // count for dst. dstLen is only meaningful when the returned local client is // nil; when a local client is returned, dstLen is just non-zero. // -// It first tries clientsAtomic as a lock-free fast path for active local -// clients. Cache misses, inactive clientSets, duplicate-client accounting, and -// mesh forwarder lookups fall back to lookupDestUncached. +// The fast path reads Server.clients lock-free: if a *clientSet is present +// for dst and has an active client, we return that without taking Server.mu. +// Misses, inactive clientSets, duplicate-client accounting, and mesh +// forwarder lookups fall through to a slow path under Server.mu. At most +// one local client and PacketForwarder can be non-nil: local clients win +// over mesh forwarding, and mesh forwarding is considered only when there +// is no local clientSet. func (c *sclient) lookupDest(dst key.NodePublic) (_ *sclient, fwd PacketForwarder, dstLen int) { - if !debugDisablePeerHashTrie() { - if set, ok := c.s.clientsAtomic.Load(dst); ok { - if dst := set.activeClient.Load(); dst != nil { - return dst, nil, 1 - } + s := c.s + if set, ok := s.clients.Load(dst); ok { + if dst := set.activeClient.Load(); dst != nil { + return dst, nil, 1 } } - return c.lookupDestUncached(dst) -} - -// lookupDestUncached is the authoritative destination lookup. It takes -// Server.mu to read Server.clients and Server.clientsMesh. At most one local -// client and PacketForwarder can be non-nil: local clients win over mesh -// forwarding, and mesh forwarding is considered only when there is no local -// clientSet. -func (c *sclient) lookupDestUncached(dst key.NodePublic) (_ *sclient, fwd PacketForwarder, dstLen int) { - s := c.s + // Slow path: no active local client. Take Server.mu to read the + // duplicate-client count and clientsMesh consistently. s.mu.Lock() defer s.mu.Unlock() - if set, ok := s.clients[dst]; ok { + if set, ok := s.clients.Load(dst); ok { if dst := set.activeClient.Load(); dst != nil { return dst, nil, 1 } @@ -1650,7 +1641,7 @@ func (s *Server) noteClientActivity(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - cs, ok := s.clients[c.key] + cs, ok := s.clients.Load(c.key) if !ok { return } @@ -2316,7 +2307,7 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) return } - if _, isLocal := s.clients[dst]; isLocal { + if _, isLocal := s.clients.Load(dst); isLocal { s.clientsMesh[dst] = nil } else { delete(s.clientsMesh, dst) @@ -2415,8 +2406,8 @@ func (s *Server) ExpVar(rateLimitEnabled bool) expvar.Var { m.Set("gauge_current_home_connections", &s.curHomeClients) m.Set("gauge_current_notideal_connections", &s.curClientsNotIdeal) m.Set("gauge_clients_total", s.expVarFunc(func() any { return len(s.clientsMesh) })) - m.Set("gauge_clients_local", s.expVarFunc(func() any { return len(s.clients) })) - m.Set("gauge_clients_remote", s.expVarFunc(func() any { return len(s.clientsMesh) - len(s.clients) })) + m.Set("gauge_clients_local", s.expVarFunc(func() any { return s.numLocalClientKeys })) + m.Set("gauge_clients_remote", s.expVarFunc(func() any { return len(s.clientsMesh) - s.numLocalClientKeys })) m.Set("gauge_current_dup_client_keys", &s.dupClientKeys) m.Set("gauge_current_dup_client_conns", &s.dupClientConns) m.Set("counter_total_dup_client_conns", &s.dupClientConnTotal) @@ -2473,7 +2464,7 @@ func (s *Server) ConsistencyCheck() error { var nilMeshNotInClient int for k, f := range s.clientsMesh { if f == nil { - if _, ok := s.clients[k]; !ok { + if _, ok := s.clients.Load(k); !ok { nilMeshNotInClient++ } } @@ -2483,7 +2474,7 @@ func (s *Server) ConsistencyCheck() error { } var clientNotInMesh int - for k := range s.clients { + for k := range s.clients.All() { if _, ok := s.clientsMesh[k]; !ok { clientNotInMesh++ } @@ -2492,10 +2483,10 @@ func (s *Server) ConsistencyCheck() error { errs = append(errs, fmt.Sprintf("%d s.clients keys not in s.clientsMesh", clientNotInMesh)) } - if s.curClients.Value() != int64(len(s.clients)) { + if s.curClients.Value() != int64(s.numLocalClientKeys) { errs = append(errs, fmt.Sprintf("expvar connections = %d != clients map says of %d", s.curClients.Value(), - len(s.clients))) + s.numLocalClientKeys)) } if s.verifyClientsLocalTailscaled { @@ -2591,7 +2582,7 @@ func (s *Server) ServeDebugTraffic(w http.ResponseWriter, r *http.Request) { if prev.Sent < next.Sent || prev.Recv < next.Recv { if pkey, ok := s.keyOfAddr[k]; ok { next.Key = pkey - if cs, ok := s.clients[pkey]; ok { + if cs, ok := s.clients.Load(pkey); ok { if c := cs.activeClient.Load(); c != nil { next.UniqueSenders = c.EstimatedUniqueSenders() } diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index 9e9dc9802..13826819b 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -29,7 +29,6 @@ import ( "golang.org/x/time/rate" "tailscale.com/derp" "tailscale.com/derp/derpconst" - "tailscale.com/envknob" "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -150,7 +149,6 @@ func pubAll(b byte) (ret key.NodePublic) { func TestForwarderRegistration(t *testing.T) { s := &Server{ - clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } want := func(want map[key.NodePublic]PacketForwarder) { @@ -232,7 +230,7 @@ func TestForwarderRegistration(t *testing.T) { key: u1, logf: logger.Discard, } - s.clients[u1] = singleClient(u1c) + s.clients.Store(u1, singleClient(u1c)) s.RemovePacketForwarder(u1, testFwd(100)) want(map[key.NodePublic]PacketForwarder{ u1: nil, @@ -252,7 +250,7 @@ func TestForwarderRegistration(t *testing.T) { // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard // that they're also connected to a peer of ours. That shouldn't transition the forwarder // from nil to the new one, not a multiForwarder. - s.clients[u1] = singleClient(u1c) + s.clients.Store(u1, singleClient(u1c)) s.clientsMesh[u1] = nil want(map[key.NodePublic]PacketForwarder{ u1: nil, @@ -284,7 +282,6 @@ func TestMultiForwarder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) s := &Server{ - clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } u := pubAll(1) @@ -393,7 +390,7 @@ func TestServerDupClients(t *testing.T) { } wantSingleClient := func(t *testing.T, want *sclient) { t.Helper() - got, ok := s.clients[want.key] + got, ok := s.clients.Load(want.key) if !ok { t.Error("no clients for key") return @@ -416,7 +413,7 @@ func TestServerDupClients(t *testing.T) { } wantNoClient := func(t *testing.T) { t.Helper() - _, ok := s.clients[clientPub] + _, ok := s.clients.Load(clientPub) if !ok { // Good return @@ -425,7 +422,7 @@ func TestServerDupClients(t *testing.T) { } wantDupSet := func(t *testing.T) *dupClientSet { t.Helper() - cs, ok := s.clients[clientPub] + cs, ok := s.clients.Load(clientPub) if !ok { t.Fatal("no set for key; want dup set") return nil @@ -438,7 +435,7 @@ func TestServerDupClients(t *testing.T) { } wantActive := func(t *testing.T, want *sclient) { t.Helper() - set, ok := s.clients[clientPub] + set, ok := s.clients.Load(clientPub) if !ok { t.Error("no set for key") return @@ -779,7 +776,7 @@ func TestServeDebugTrafficUniqueSenders(t *testing.T) { s.mu.Lock() cs := &clientSet{} cs.activeClient.Store(c) - s.clients[clientKey] = cs + s.clients.Store(clientKey, cs) s.mu.Unlock() estimate := c.EstimatedUniqueSenders() @@ -1192,7 +1189,7 @@ func TestUpdateRateLimits(t *testing.T) { cs.activeClient.Store(c) s.mu.Lock() - s.clients[clientKey] = cs + s.clients.Store(clientKey, cs) s.mu.Unlock() rc := RateConfig{ @@ -1245,7 +1242,7 @@ func TestUpdateRateLimits(t *testing.T) { meshCS.activeClient.Store(meshClient) s.mu.Lock() - s.clients[meshKey] = meshCS + s.clients.Store(meshKey, meshCS) s.mu.Unlock() rc = RateConfig{ @@ -1272,7 +1269,7 @@ func TestUpdateRateLimits(t *testing.T) { dupCS.activeClient.Store(d1) dupCS.dup = &dupClientSet{set: set.Of(d1, d2)} s.mu.Lock() - s.clients[dupKey] = dupCS + s.clients.Store(dupKey, dupCS) s.mu.Unlock() rc = RateConfig{ @@ -1364,7 +1361,7 @@ func TestLoadAndApplyRateConfig(t *testing.T) { cs := &clientSet{} cs.activeClient.Store(c) s.mu.Lock() - s.clients[clientKey] = cs + s.clients.Store(clientKey, cs) s.mu.Unlock() f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d}`, @@ -1447,19 +1444,8 @@ func TestLoadAndApplyRateConfig(t *testing.T) { }) } -const peerHashTrieDisableEnv = "TS_DEBUG_DERP_DISABLE_PEER_HASHTRIE" - -func setPeerHashTrieDisabled(tb testing.TB, disabled bool) { - tb.Helper() - envknob.Setenv(peerHashTrieDisableEnv, fmt.Sprint(disabled)) - tb.Cleanup(func() { envknob.Setenv(peerHashTrieDisableEnv, "") }) -} - func TestLookupDestHashTrieFastPath(t *testing.T) { - setPeerHashTrieDisabled(t, false) - s := &Server{ - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, clock: tstime.StdClock{}, } @@ -1468,8 +1454,7 @@ func TestLookupDestHashTrieFastPath(t *testing.T) { dstClient := &sclient{key: dst} cs := &clientSet{} cs.activeClient.Store(dstClient) - s.clients[dst] = cs - s.clientsAtomic.Store(dst, cs) + s.clients.Store(dst, cs) c := &sclient{s: s, key: src} got, fwd, dstLen := c.lookupDest(dst) @@ -1488,10 +1473,7 @@ func TestLookupDestHashTrieFastPath(t *testing.T) { } func TestLookupDestHashTrieFallsBackForForwarder(t *testing.T) { - setPeerHashTrieDisabled(t, false) - s := &Server{ - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, clock: tstime.StdClock{}, } @@ -1506,11 +1488,8 @@ func TestLookupDestHashTrieFallsBackForForwarder(t *testing.T) { } } -func TestLookupDestHashTrieIgnoresInactiveStaleSet(t *testing.T) { - setPeerHashTrieDisabled(t, false) - +func TestLookupDestHashTrieIgnoresInactiveSet(t *testing.T) { s := &Server{ - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, clock: tstime.StdClock{}, } @@ -1518,24 +1497,28 @@ func TestLookupDestHashTrieIgnoresInactiveStaleSet(t *testing.T) { dst := pubAll(2) c := &sclient{s: s, key: src} - s.clientsAtomic.Store(dst, &clientSet{}) - - newClient := &sclient{key: dst} - newSet := &clientSet{} - newSet.activeClient.Store(newClient) - s.clients[dst] = newSet + // A clientSet with no activeClient (a transient state during + // register/unregister) must not be returned by the fast path. + cs := &clientSet{} + s.clients.Store(dst, cs) got, fwd, dstLen := c.lookupDest(dst) + if got != nil || fwd != nil || dstLen != 0 { + t.Fatalf("lookupDest with inactive set = (%v, %v, %d), want (nil, nil, 0)", got, fwd, dstLen) + } + + // Setting activeClient on the same in-map entry must make the next + // fast-path lookup observe it. + newClient := &sclient{key: dst} + cs.activeClient.Store(newClient) + got, fwd, dstLen = c.lookupDest(dst) if got != newClient || fwd != nil || dstLen != 1 { - t.Fatalf("lookupDest = (%v, %v, %d), want (%v, nil, 1)", got, fwd, dstLen, newClient) + t.Fatalf("lookupDest after activation = (%v, %v, %d), want (%v, nil, 1)", got, fwd, dstLen, newClient) } } func TestLookupDestHashTrieNoAlloc(t *testing.T) { - setPeerHashTrieDisabled(t, false) - s := &Server{ - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, clock: tstime.StdClock{}, } @@ -1546,8 +1529,7 @@ func TestLookupDestHashTrieNoAlloc(t *testing.T) { dstClients[i] = &sclient{key: dstKeys[i]} cs := &clientSet{} cs.activeClient.Store(dstClients[i]) - s.clients[dstKeys[i]] = cs - s.clientsAtomic.Store(dstKeys[i], cs) + s.clients.Store(dstKeys[i], cs) } c := &sclient{s: s, key: pubAll(1)} @@ -1568,7 +1550,6 @@ func TestLookupDestHashTrieNoAlloc(t *testing.T) { func BenchmarkLookupDestHashTrie(b *testing.B) { s := &Server{ - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, clock: tstime.StdClock{}, } @@ -1579,8 +1560,7 @@ func BenchmarkLookupDestHashTrie(b *testing.B) { dstClients[i] = &sclient{key: dstKeys[i]} cs := &clientSet{} cs.activeClient.Store(dstClients[i]) - s.clients[dstKeys[i]] = cs - s.clientsAtomic.Store(dstKeys[i], cs) + s.clients.Store(dstKeys[i], cs) } b.ReportAllocs()