// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause package derpserver import ( "bufio" "cmp" "context" "crypto/x509" "encoding/asn1" "encoding/binary" "expvar" "fmt" "log" "net" "os" "path/filepath" "reflect" "strconv" "sync" "testing" "testing/synctest" "time" "github.com/axiomhq/hyperloglog" qt "github.com/frankban/quicktest" "go4.org/mem" "golang.org/x/time/rate" "tailscale.com/derp" "tailscale.com/derp/derpconst" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/set" ) const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" func TestSetMeshKey(t *testing.T) { for name, tt := range map[string]struct { key string want key.DERPMesh wantErr bool }{ "clobber": { key: testMeshKey, wantErr: false, }, "invalid": { key: "badf00d", wantErr: true, }, } { t.Run(name, func(t *testing.T) { s := &Server{} err := s.SetMeshKey(tt.key) if tt.wantErr { if err == nil { t.Fatalf("expected err") } return } if err != nil { t.Fatalf("unexpected err: %v", err) } want, err := key.ParseDERPMesh(tt.key) if err != nil { t.Fatal(err) } if !s.meshKey.Equal(want) { t.Fatalf("got %v, want %v", s.meshKey, want) } }) } } func TestIsMeshPeer(t *testing.T) { s := &Server{} err := s.SetMeshKey(testMeshKey) if err != nil { t.Fatal(err) } for name, tt := range map[string]struct { want bool meshKey string wantAllocs float64 }{ "nil": { want: false, wantAllocs: 0, }, "mismatch": { meshKey: "6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8", want: false, wantAllocs: 1, }, "match": { meshKey: testMeshKey, want: true, wantAllocs: 0, }, } { t.Run(name, func(t *testing.T) { var got bool var mKey key.DERPMesh if tt.meshKey != "" { mKey, err = key.ParseDERPMesh(tt.meshKey) if err != nil { t.Fatalf("ParseDERPMesh(%q) failed: %v", tt.meshKey, err) } } info := derp.ClientInfo{ MeshKey: mKey, } allocs := testing.AllocsPerRun(1, func() { got = s.isMeshPeer(&info) }) if got != tt.want { t.Fatalf("got %t, want %t: info = %#v", got, tt.want, info) } if allocs != tt.wantAllocs && tt.want { t.Errorf("%f allocations, want %f", allocs, tt.wantAllocs) } }) } } type testFwd int func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { panic("not called in tests") } func (testFwd) String() string { panic("not called in tests") } func pubAll(b byte) (ret key.NodePublic) { var bs [32]byte for i := range bs { bs[i] = b } return key.NodePublicFromRaw32(mem.B(bs[:])) } func TestForwarderRegistration(t *testing.T) { s := &Server{ clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } want := func(want map[key.NodePublic]PacketForwarder) { t.Helper() if got := s.clientsMesh; !reflect.DeepEqual(got, want) { t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) } } wantCounter := func(c *expvar.Int, want int) { t.Helper() if got := c.Value(); got != int64(want) { t.Errorf("counter = %v; want %v", got, want) } } singleClient := func(c *sclient) *clientSet { cs := &clientSet{} cs.activeClient.Store(c) return cs } u1 := pubAll(1) u2 := pubAll(2) u3 := pubAll(3) s.AddPacketForwarder(u1, testFwd(1)) s.AddPacketForwarder(u2, testFwd(2)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered forwarder is no-op. s.RemovePacketForwarder(u2, testFwd(999)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered user is no-op. s.RemovePacketForwarder(u3, testFwd(1)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Actual removal. s.RemovePacketForwarder(u2, testFwd(2)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), }) // Adding a dup for a user. wantCounter(&s.multiForwarderCreated, 0) s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path want(map[key.NodePublic]PacketForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)), }) wantCounter(&s.multiForwarderCreated, 1) // Removing a forwarder in a multi set that doesn't exist; does nothing. s.RemovePacketForwarder(u1, testFwd(55)) want(map[key.NodePublic]PacketForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)), }) // Removing a forwarder in a multi set that does exist should collapse it away // from being a multiForwarder. wantCounter(&s.multiForwarderDeleted, 0) s.RemovePacketForwarder(u1, testFwd(1)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(100), }) wantCounter(&s.multiForwarderDeleted, 1) // Removing an entry for a client that's still connected locally should result // in a nil forwarder. u1c := &sclient{ key: u1, logf: logger.Discard, } s.clients[u1] = singleClient(u1c) s.RemovePacketForwarder(u1, testFwd(100)) want(map[key.NodePublic]PacketForwarder{ u1: nil, }) // But once that client disconnects, it should go away. s.unregisterClient(u1c) want(map[key.NodePublic]PacketForwarder{}) // But if it already has a forwarder, it's not removed. s.AddPacketForwarder(u1, testFwd(2)) s.unregisterClient(u1c) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(2), }) // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard // that they're also connected to a peer of ours. That shouldn't transition the forwarder // from nil to the new one, not a multiForwarder. s.clients[u1] = singleClient(u1c) s.clientsMesh[u1] = nil want(map[key.NodePublic]PacketForwarder{ u1: nil, }) s.AddPacketForwarder(u1, testFwd(3)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(3), }) } type channelFwd struct { // id is to ensure that different instances that reference the // same channel are not equal, as they are used as keys in the // multiForwarder map. id int c chan []byte } func (f channelFwd) String() string { return "" } func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { f.c <- packet return nil } func TestMultiForwarder(t *testing.T) { received := 0 var wg sync.WaitGroup ch := make(chan []byte) ctx, cancel := context.WithCancel(context.Background()) s := &Server{ clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } u := pubAll(1) s.AddPacketForwarder(u, channelFwd{1, ch}) wg.Add(2) go func() { defer wg.Done() for { select { case <-ch: received += 1 case <-ctx.Done(): return } } }() go func() { defer wg.Done() for { s.AddPacketForwarder(u, channelFwd{2, ch}) s.AddPacketForwarder(u, channelFwd{3, ch}) s.RemovePacketForwarder(u, channelFwd{2, ch}) s.RemovePacketForwarder(u, channelFwd{1, ch}) s.AddPacketForwarder(u, channelFwd{1, ch}) s.RemovePacketForwarder(u, channelFwd{3, ch}) if ctx.Err() != nil { return } } }() // Number of messages is chosen arbitrarily, just for this loop to // run long enough concurrently with {Add,Remove}PacketForwarder loop above. numMsgs := 5000 var fwd PacketForwarder for i := range numMsgs { s.mu.Lock() fwd = s.clientsMesh[u] s.mu.Unlock() fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) } cancel() wg.Wait() if received != numMsgs { t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) } } func TestMetaCert(t *testing.T) { priv := key.NewNode() pub := priv.Public() s := New(priv, t.Logf) certBytes := s.MetaCert() cert, err := x509.ParseCertificate(certBytes) if err != nil { log.Fatal(err) } if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(derp.ProtocolVersion) { t.Errorf("serial = %v; want %v", cert.SerialNumber, derp.ProtocolVersion) } if g, w := cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix+pub.UntypedHexString(); g != w { t.Errorf("CommonName = %q; want %q", g, w) } if n := len(cert.Extensions); n != 1 { t.Fatalf("got %d extensions; want 1", n) } // oidExtensionBasicConstraints is the Basic Constraints ID copied // from the x509 package. oidExtensionBasicConstraints := asn1.ObjectIdentifier{2, 5, 29, 19} if id := cert.Extensions[0].Id; !id.Equal(oidExtensionBasicConstraints) { t.Errorf("extension ID = %v; want %v", id, oidExtensionBasicConstraints) } } func TestServerDupClients(t *testing.T) { serverPriv := key.NewNode() var s *Server clientPriv := key.NewNode() clientPub := clientPriv.Public() var c1, c2, c3 *sclient var clientName map[*sclient]string // run starts a new test case and resets clients back to their zero values. run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { s = New(serverPriv, t.Logf) s.dupPolicy = dupPolicy c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")} clientName = map[*sclient]string{ c1: "c1", c2: "c2", c3: "c3", } t.Run(name, f) } runBothWays := func(name string, f func(t *testing.T)) { run(name+"_disablefighters", disableFighters, f) run(name+"_lastwriteractive", lastWriterIsActive, f) } wantSingleClient := func(t *testing.T, want *sclient) { t.Helper() got, ok := s.clients[want.key] if !ok { t.Error("no clients for key") return } if got.dup != nil { t.Errorf("unexpected dup set for single client") } cur := got.activeClient.Load() if cur != want { t.Errorf("active client = %q; want %q", clientName[cur], clientName[want]) } if cur != nil { if cur.isDup.Load() { t.Errorf("unexpected isDup on singleClient") } if cur.isDisabled.Load() { t.Errorf("unexpected isDisabled on singleClient") } } } wantNoClient := func(t *testing.T) { t.Helper() _, ok := s.clients[clientPub] if !ok { // Good return } t.Errorf("got client; want empty") } wantDupSet := func(t *testing.T) *dupClientSet { t.Helper() cs, ok := s.clients[clientPub] if !ok { t.Fatal("no set for key; want dup set") return nil } if cs.dup != nil { return cs.dup } t.Fatalf("no dup set for key; want dup set") return nil } wantActive := func(t *testing.T, want *sclient) { t.Helper() set, ok := s.clients[clientPub] if !ok { t.Error("no set for key") return } got := set.activeClient.Load() if got != want { t.Errorf("active client = %q; want %q", clientName[got], clientName[want]) } } checkDup := func(t *testing.T, c *sclient, want bool) { t.Helper() if got := c.isDup.Load(); got != want { t.Errorf("client %q isDup = %v; want %v", clientName[c], got, want) } } checkDisabled := func(t *testing.T, c *sclient, want bool) { t.Helper() if got := c.isDisabled.Load(); got != want { t.Errorf("client %q isDisabled = %v; want %v", clientName[c], got, want) } } wantDupConns := func(t *testing.T, want int) { t.Helper() if got := s.dupClientConns.Value(); got != int64(want) { t.Errorf("dupClientConns = %v; want %v", got, want) } } wantDupKeys := func(t *testing.T, want int) { t.Helper() if got := s.dupClientKeys.Value(); got != int64(want) { t.Errorf("dupClientKeys = %v; want %v", got, want) } } // Common case: a single client comes and goes, with no dups. runBothWays("one_comes_and_goes", func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantSingleClient(t, c1) s.unregisterClient(c1) wantNoClient(t) }) // A still somewhat common case: a single client was // connected and then their wifi dies or laptop closes // or they switch networks and connect from a // different network. They have two connections but // it's not very bad. Only their new one is // active. The last one, being dead, doesn't send and // thus the new one doesn't get disabled. runBothWays("small_overlap_replacement", func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantSingleClient(t, c1) wantActive(t, c1) wantDupKeys(t, 0) wantDupKeys(t, 0) s.registerClient(c2) // wifi dies; c2 replacement connects wantDupSet(t) wantDupConns(t, 2) wantDupKeys(t, 1) checkDup(t, c1, true) checkDup(t, c2, true) checkDisabled(t, c1, false) checkDisabled(t, c2, false) wantActive(t, c2) // sends go to the replacement s.unregisterClient(c1) // c1 finally times out wantSingleClient(t, c2) checkDup(t, c2, false) // c2 is longer a dup wantActive(t, c2) wantDupConns(t, 0) wantDupKeys(t, 0) }) // Key cloning situation with concurrent clients, both trying // to write. run("concurrent_dups_get_disabled", disableFighters, func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantSingleClient(t, c1) wantActive(t, c1) s.registerClient(c2) wantDupSet(t) wantDupKeys(t, 1) wantDupConns(t, 2) wantActive(t, c2) checkDup(t, c1, true) checkDup(t, c2, true) checkDisabled(t, c1, false) checkDisabled(t, c2, false) s.noteClientActivity(c2) checkDisabled(t, c1, false) checkDisabled(t, c2, false) s.noteClientActivity(c1) checkDisabled(t, c1, true) checkDisabled(t, c2, true) wantActive(t, nil) s.registerClient(c3) wantActive(t, c3) checkDisabled(t, c3, false) wantDupKeys(t, 1) wantDupConns(t, 3) s.unregisterClient(c3) wantActive(t, nil) wantDupKeys(t, 1) wantDupConns(t, 2) s.unregisterClient(c2) wantSingleClient(t, c1) wantDupKeys(t, 0) wantDupConns(t, 0) }) // Key cloning with an A->B->C->A series instead. run("concurrent_dups_three_parties", disableFighters, func(t *testing.T) { wantNoClient(t) s.registerClient(c1) s.registerClient(c2) s.registerClient(c3) s.noteClientActivity(c1) checkDisabled(t, c1, true) checkDisabled(t, c2, true) checkDisabled(t, c3, true) wantActive(t, nil) }) run("activity_promotes_primary_when_nil", disableFighters, func(t *testing.T) { wantNoClient(t) // Last registered client is the active one... s.registerClient(c1) wantActive(t, c1) s.registerClient(c2) wantActive(t, c2) s.registerClient(c3) s.noteClientActivity(c2) wantActive(t, c3) // But if the last one goes away, the one with the // most recent activity wins. s.unregisterClient(c3) wantActive(t, c2) }) run("concurrent_dups_three_parties_last_writer", lastWriterIsActive, func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantActive(t, c1) s.registerClient(c2) wantActive(t, c2) s.noteClientActivity(c1) checkDisabled(t, c1, false) checkDisabled(t, c2, false) wantActive(t, c1) s.noteClientActivity(c2) checkDisabled(t, c1, false) checkDisabled(t, c2, false) wantActive(t, c2) s.unregisterClient(c2) checkDisabled(t, c1, false) wantActive(t, c1) }) } func TestLimiter(t *testing.T) { rl := rate.NewLimiter(rate.Every(time.Minute), 100) for i := range 200 { r := rl.Reserve() d := r.Delay() t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d) } } // BenchmarkConcurrentStreams exercises mutex contention on a // single Server instance with multiple concurrent client flows. func BenchmarkConcurrentStreams(b *testing.B) { serverPrivateKey := key.NewNode() s := New(serverPrivateKey, logger.Discard) defer s.Close() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatal(err) } ctx := b.Context() acceptDone := make(chan struct{}) go func() { defer close(acceptDone) for { connIn, err := ln.Accept() if err != nil { return } brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) go s.Accept(ctx, connIn, brwServer, "test-client") } }() newClient := func(t testing.TB) *derp.Client { t.Helper() connOut, err := net.Dial("tcp", ln.Addr().String()) if err != nil { b.Fatal(err) } t.Cleanup(func() { connOut.Close() }) k := key.NewNode() brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) client, err := derp.NewClient(k, connOut, brw, logger.Discard) if err != nil { b.Fatalf("client: %v", err) } return client } b.RunParallel(func(pb *testing.PB) { c1, c2 := newClient(b), newClient(b) const packetSize = 100 msg := make([]byte, packetSize) for pb.Next() { if err := c1.Send(c2.PublicKey(), msg); err != nil { b.Fatal(err) } _, err := c2.Recv() if err != nil { return } } }) ln.Close() <-acceptDone } func BenchmarkSendRecv(b *testing.B) { for _, size := range []int{10, 100, 1000, 10000} { b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) }) } } func benchmarkSendRecvSize(b *testing.B, packetSize int) { serverPrivateKey := key.NewNode() s := New(serverPrivateKey, logger.Discard) defer s.Close() k := key.NewNode() clientKey := k.Public() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatal(err) } defer ln.Close() connOut, err := net.Dial("tcp", ln.Addr().String()) if err != nil { b.Fatal(err) } defer connOut.Close() connIn, err := ln.Accept() if err != nil { b.Fatal(err) } defer connIn.Close() brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go s.Accept(ctx, connIn, brwServer, "test-client") brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) client, err := derp.NewClient(k, connOut, brw, logger.Discard) if err != nil { b.Fatalf("client: %v", err) } go func() { for { _, err := client.Recv() if err != nil { return } } }() msg := make([]byte, packetSize) b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for range b.N { if err := client.Send(clientKey, msg); err != nil { b.Fatal(err) } } } func TestParseSSOutput(t *testing.T) { contents, err := os.ReadFile("testdata/example_ss.txt") if err != nil { t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err) } seen := parseSSOutput(string(contents)) if len(seen) == 0 { t.Errorf("parseSSOutput expected non-empty map") } } func TestServeDebugTrafficUniqueSenders(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() clientKey := key.NewNode().Public() c := &sclient{ key: clientKey, s: s, logf: logger.Discard, senderCardinality: hyperloglog.New(), } for range 5 { c.senderCardinality.Insert(key.NewNode().Public().AppendTo(nil)) } s.mu.Lock() cs := &clientSet{} cs.activeClient.Store(c) s.clients[clientKey] = cs s.mu.Unlock() estimate := c.EstimatedUniqueSenders() t.Logf("Estimated unique senders: %d", estimate) if estimate < 4 || estimate > 6 { t.Errorf("EstimatedUniqueSenders() = %d, want ~5 (4-6 range)", estimate) } } func TestGetPerClientSendQueueDepth(t *testing.T) { c := qt.New(t) envKey := "TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH" testCases := []struct { envVal string want int }{ // Empty case, envknob treats empty as missing also. { "", defaultPerClientSendQueueDepth, }, { "64", 64, }, } for _, tc := range testCases { t.Run(cmp.Or(tc.envVal, "empty"), func(t *testing.T) { t.Setenv(envKey, tc.envVal) val := getPerClientSendQueueDepth() c.Assert(val, qt.Equals, tc.want) }) } } func TestSenderCardinality(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() c := &sclient{ key: key.NewNode().Public(), s: s, logf: logger.WithPrefix(t.Logf, "test client: "), } if got := c.EstimatedUniqueSenders(); got != 0 { t.Errorf("EstimatedUniqueSenders() before init = %d, want 0", got) } c.senderCardinality = hyperloglog.New() if got := c.EstimatedUniqueSenders(); got != 0 { t.Errorf("EstimatedUniqueSenders() with no senders = %d, want 0", got) } senders := make([]key.NodePublic, 10) for i := range senders { senders[i] = key.NewNode().Public() c.senderCardinality.Insert(senders[i].AppendTo(nil)) } estimate := c.EstimatedUniqueSenders() t.Logf("Estimated unique senders after 10 inserts: %d", estimate) if estimate < 8 || estimate > 12 { t.Errorf("EstimatedUniqueSenders() = %d, want ~10 (8-12 range)", estimate) } for i := range 5 { c.senderCardinality.Insert(senders[i].AppendTo(nil)) } estimate2 := c.EstimatedUniqueSenders() t.Logf("Estimated unique senders after duplicates: %d", estimate2) if estimate2 < 8 || estimate2 > 12 { t.Errorf("EstimatedUniqueSenders() after duplicates = %d, want ~10 (8-12 range)", estimate2) } } func TestSenderCardinality100(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() c := &sclient{ key: key.NewNode().Public(), s: s, logf: logger.WithPrefix(t.Logf, "test client: "), senderCardinality: hyperloglog.New(), } numSenders := 100 for range numSenders { c.senderCardinality.Insert(key.NewNode().Public().AppendTo(nil)) } estimate := c.EstimatedUniqueSenders() t.Logf("Estimated unique senders for 100 actual senders: %d", estimate) if estimate < 85 || estimate > 115 { t.Errorf("EstimatedUniqueSenders() = %d, want ~100 (85-115 range)", estimate) } } func TestSenderCardinalityTracking(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() c := &sclient{ key: key.NewNode().Public(), s: s, logf: logger.WithPrefix(t.Logf, "test client: "), senderCardinality: hyperloglog.New(), } zeroKey := key.NodePublic{} if zeroKey != (key.NodePublic{}) { c.senderCardinality.Insert(zeroKey.AppendTo(nil)) } if estimate := c.EstimatedUniqueSenders(); estimate != 0 { t.Errorf("EstimatedUniqueSenders() after zero key = %d, want 0", estimate) } sender1 := key.NewNode().Public() sender2 := key.NewNode().Public() if sender1 != (key.NodePublic{}) { c.senderCardinality.Insert(sender1.AppendTo(nil)) } if sender2 != (key.NodePublic{}) { c.senderCardinality.Insert(sender2.AppendTo(nil)) } estimate := c.EstimatedUniqueSenders() t.Logf("Estimated unique senders after 2 senders: %d", estimate) if estimate < 1 || estimate > 3 { t.Errorf("EstimatedUniqueSenders() = %d, want ~2 (1-3 range)", estimate) } } func BenchmarkHyperLogLogInsert(b *testing.B) { hll := hyperloglog.New() sender := key.NewNode().Public() senderBytes := sender.AppendTo(nil) b.ResetTimer() for i := 0; i < b.N; i++ { hll.Insert(senderBytes) } } func BenchmarkHyperLogLogInsertUnique(b *testing.B) { hll := hyperloglog.New() b.ResetTimer() buf := make([]byte, 32) for i := 0; i < b.N; i++ { binary.LittleEndian.PutUint64(buf, uint64(i)) hll.Insert(buf) } } func BenchmarkHyperLogLogEstimate(b *testing.B) { hll := hyperloglog.New() for range 100 { hll.Insert(key.NewNode().Public().AppendTo(nil)) } b.ResetTimer() for i := 0; i < b.N; i++ { _ = hll.Estimate() } } func TestPerClientRateLimit(t *testing.T) { t.Run("throttled", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) c := &sclient{ ctx: ctx, } lim := &parentChildTokenBuckets{ // Set parent limit to half of child to enable verification of // rate limiting across both layers with a single sclient. parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize)/2, minRateLimitTokenBucketSize), child: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize), } c.recvLim.Store(lim) wantTokens := func(t *testing.T, wantParentTokens, wantChildTokens float64) { t.Helper() if lim.parent.Tokens() != wantParentTokens { t.Fatalf("want parent tokens: %v got: %v", wantParentTokens, lim.parent.Tokens()) } if lim.child.Tokens() != wantChildTokens { t.Fatalf("want child tokens: %v got: %v", wantChildTokens, lim.child.Tokens()) } } // First call within burst should not block. c.rateLimit(minRateLimitTokenBucketSize) wantTokens(t, 0, 0) // Next call exceeds burst, should block until tokens replenish. done := make(chan error, 1) go func() { done <- c.rateLimit(minRateLimitTokenBucketSize) }() // After settling, the goroutine should be blocked (no result yet). synctest.Wait() select { case err := <-done: t.Fatalf("rateLimit should have blocked, but returned: %v", err) default: } // Advance time by 1 second, the goroutine should still be blocked // on the parent bucket (negative tokens). time.Sleep(1 * time.Second) synctest.Wait() select { case err := <-done: t.Fatalf("rateLimit should have blocked, but returned: %v", err) default: } // Verify the parent bucket fills at half the rate of the child. wantTokens(t, -(minRateLimitTokenBucketSize / 2), 0) // Advance time by another second, parent should have enough tokens // to unblock. time.Sleep(1 * time.Second) synctest.Wait() select { case err := <-done: if err != nil { t.Fatalf("rateLimit after time advance: %v", err) } default: t.Fatal("rateLimit should have unblocked after 1s") } wantTokens(t, 0, minRateLimitTokenBucketSize) }) }) t.Run("context_canceled", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) c := &sclient{ ctx: ctx, } lim := &parentChildTokenBuckets{ child: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize), parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize), } c.recvLim.Store(lim) // Exhaust burst. if err := c.rateLimit(minRateLimitTokenBucketSize); err != nil { t.Fatalf("rateLimit: %v", err) } done := make(chan error, 1) go func() { done <- c.rateLimit(minRateLimitTokenBucketSize) }() synctest.Wait() // Cancel the context; the blocked rateLimit should return an error. cancel() synctest.Wait() select { case err := <-done: if err == nil { t.Fatal("expected error from canceled context") } default: t.Fatal("rateLimit should have returned after context cancelation") } }) }) t.Run("mesh_peer_exempt", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) // Mesh peers have nil recvLim, so rate limiting is a no-op. c := &sclient{ ctx: ctx, canMesh: true, } if err := c.rateLimit(1000); err != nil { t.Fatalf("mesh peer rateLimit should be no-op: %v", err) } }) t.Run("zero_config_no_limiter", func(t *testing.T) { s := New(key.NewNode(), logger.Discard) defer s.Close() if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { t.Errorf("expected zero rate limit, got %+v", s.rateConfig) } }) } func verifyLimiter(t *testing.T, lim *parentChildTokenBuckets, wantRateConfig RateConfig) { t.Helper() if got := lim.child.Limit(); got != rate.Limit(wantRateConfig.PerClientRateLimitBytesPerSec) { t.Errorf("client rate limit = %v; want %d", got, wantRateConfig.PerClientRateLimitBytesPerSec) } if got := lim.child.Burst(); got != int(wantRateConfig.PerClientRateBurstBytes) { t.Errorf("client burst = %v; want %d", got, wantRateConfig.PerClientRateBurstBytes) } if got := lim.parent.Limit(); got != rate.Limit(wantRateConfig.GlobalRateLimitBytesPerSec) { t.Errorf("global rate limit = %v, want %d", got, wantRateConfig.GlobalRateLimitBytesPerSec) } if got := lim.parent.Burst(); got != int(wantRateConfig.GlobalRateBurstBytes) { t.Errorf("global burst = %v, want %d", got, wantRateConfig.GlobalRateBurstBytes) } } func TestUpdateRateLimits(t *testing.T) { const ( testClientBurst1 = minRateLimitTokenBucketSize + 1 testClientRate1 = minRateLimitTokenBucketSize + 2 testClientBurst2 = minRateLimitTokenBucketSize + 3 testClientRate2 = minRateLimitTokenBucketSize + 4 testGlobalBurst1 = minRateLimitTokenBucketSize + 5 testGlobalRate1 = minRateLimitTokenBucketSize + 6 testGlobalBurst2 = minRateLimitTokenBucketSize + 7 testGlobalRate2 = minRateLimitTokenBucketSize + 8 ) s := New(key.NewNode(), t.Logf) defer s.Close() // Create a non-mesh client with no initial limiter. clientKey := key.NewNode().Public() c := &sclient{ key: clientKey, s: s, logf: logger.Discard, canMesh: false, } cs := &clientSet{} cs.activeClient.Store(c) s.mu.Lock() s.clients[clientKey] = cs s.mu.Unlock() rc := RateConfig{ PerClientRateLimitBytesPerSec: testClientRate1, PerClientRateBurstBytes: testClientBurst1, GlobalRateLimitBytesPerSec: testGlobalRate1, GlobalRateBurstBytes: testGlobalBurst1, } s.UpdateRateLimits(rc) lim := c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter after update") } verifyLimiter(t, lim, rc) // Verify server fields updated. s.mu.Lock() if !reflect.DeepEqual(s.rateConfig, rc) { t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, rc) } s.mu.Unlock() // Update again with different nonzero values. rc = RateConfig{ PerClientRateLimitBytesPerSec: testClientRate2, PerClientRateBurstBytes: testClientBurst2, GlobalRateLimitBytesPerSec: testGlobalRate2, GlobalRateBurstBytes: testGlobalBurst2, } s.UpdateRateLimits(rc) lim = c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter") } verifyLimiter(t, lim, rc) // Disable rate limiting (set to 0). s.UpdateRateLimits(RateConfig{}) if got := c.recvLim.Load(); got != nil { t.Errorf("expected nil limiter after disable, got limit=%v", got.child.Limit()) } // Mesh peer should always have nil limiter regardless of update. meshKey := key.NewNode().Public() meshClient := &sclient{ key: meshKey, s: s, logf: logger.Discard, canMesh: true, } meshCS := &clientSet{} meshCS.activeClient.Store(meshClient) s.mu.Lock() s.clients[meshKey] = meshCS s.mu.Unlock() rc = RateConfig{ PerClientRateLimitBytesPerSec: testClientRate2, PerClientRateBurstBytes: testClientBurst2, GlobalRateLimitBytesPerSec: testGlobalRate2, GlobalRateBurstBytes: testGlobalBurst2, } s.UpdateRateLimits(rc) if got := meshClient.recvLim.Load(); got != nil { t.Errorf("mesh peer should have nil limiter, got limit=%v", got.child.Limit()) } // Non-mesh client should be updated. lim = c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter for non-mesh client") } verifyLimiter(t, lim, rc) // Verify dup clients are also updated. dupKey := key.NewNode().Public() d1 := &sclient{key: dupKey, s: s, logf: logger.Discard} d2 := &sclient{key: dupKey, s: s, logf: logger.Discard} dupCS := &clientSet{} dupCS.activeClient.Store(d1) dupCS.dup = &dupClientSet{set: set.Of(d1, d2)} s.mu.Lock() s.clients[dupKey] = dupCS s.mu.Unlock() rc = RateConfig{ GlobalRateLimitBytesPerSec: testGlobalRate1, GlobalRateBurstBytes: testGlobalBurst1, PerClientRateLimitBytesPerSec: testClientRate1, PerClientRateBurstBytes: testClientBurst1, } s.UpdateRateLimits(rc) for i, d := range []*sclient{d1, d2} { dl := d.recvLim.Load() if dl == nil { t.Fatalf("dup client %d: expected non-nil limiter", i) } verifyLimiter(t, dl, rc) } } func TestLoadRateConfig(t *testing.T) { for _, tt := range []struct { name string json string wantRateConfig RateConfig }{ {"all_set", `{"PerClientRateLimitBytesPerSec": 1, "PerClientRateBurstBytes": 2, "GlobalRateLimitBytesPerSec": 3, "GlobalRateBurstBytes": 4}`, RateConfig{ PerClientRateLimitBytesPerSec: 1, PerClientRateBurstBytes: 2, GlobalRateLimitBytesPerSec: 3, GlobalRateBurstBytes: 4, }}, {"rate_only", `{"PerClientRateLimitBytesPerSec": 1, "GlobalRateLimitBytesPerSec": 3}`, RateConfig{ PerClientRateLimitBytesPerSec: 1, GlobalRateLimitBytesPerSec: 3, }}, {"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0, "GlobalRateLimitBytesPerSec": 0, "GlobalRateBurstBytes": 0}`, RateConfig{}}, {"empty_json", `{}`, RateConfig{}}, } { t.Run(tt.name, func(t *testing.T) { f := filepath.Join(t.TempDir(), "rate.json") if err := os.WriteFile(f, []byte(tt.json), 0644); err != nil { t.Fatal(err) } rc, err := LoadRateConfig(f) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(rc, tt.wantRateConfig) { t.Errorf("rate config = %v want %v", rc, tt.wantRateConfig) } }) } for _, tt := range []struct { name string path string content string // written to loaded path if non-empty; path used as-is if empty }{ {"empty_path", "", ""}, {"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), ""}, {"invalid_json", "", "not json"}, } { t.Run(tt.name, func(t *testing.T) { path := tt.path if tt.content != "" { path = filepath.Join(t.TempDir(), "rate.json") if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil { t.Fatal(err) } } _, err := LoadRateConfig(path) if err == nil { t.Fatal("expected error") } }) } } func TestLoadAndApplyRateConfig(t *testing.T) { writeConfig := func(t *testing.T, json string) string { t.Helper() f := filepath.Join(t.TempDir(), "rate.json") if err := os.WriteFile(f, []byte(json), 0644); err != nil { t.Fatal(err) } return f } t.Run("applies_and_updates_clients", func(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() clientKey := key.NewNode().Public() c := &sclient{key: clientKey, s: s, logf: logger.Discard} cs := &clientSet{} cs.activeClient.Store(c) s.mu.Lock() s.clients[clientKey] = cs s.mu.Unlock() f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d, "GlobalRateLimitBytesPerSec": %d, "GlobalRateBurstBytes": %d}`, minRateLimitTokenBucketSize, minRateLimitTokenBucketSize+1, minRateLimitTokenBucketSize+2, minRateLimitTokenBucketSize+3)) if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatalf("LoadAndApplyRateConfig: %v", err) } // Verify server fields. wantRateConfig := RateConfig{ PerClientRateLimitBytesPerSec: minRateLimitTokenBucketSize, PerClientRateBurstBytes: minRateLimitTokenBucketSize + 1, GlobalRateLimitBytesPerSec: minRateLimitTokenBucketSize + 2, GlobalRateBurstBytes: minRateLimitTokenBucketSize + 3, } s.mu.Lock() if !reflect.DeepEqual(s.rateConfig, wantRateConfig) { t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, wantRateConfig) } s.mu.Unlock() // Verify client limiter. lim := c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter") } verifyLimiter(t, lim, wantRateConfig) }) t.Run("burst_is_at_least_minRateLimitTokenBucketSize", func(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10, "GlobalRateLimitBytesPerSec": 1250000, "GlobalRateBurstBytes": 10}`) if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatalf("LoadAndApplyRateConfig: %v", err) } s.mu.Lock() gotClientBurst := s.rateConfig.PerClientRateBurstBytes gotGlobalBurst := s.rateConfig.GlobalRateBurstBytes s.mu.Unlock() if gotClientBurst != minRateLimitTokenBucketSize { t.Errorf("client burst = %d; want %d", gotClientBurst, minRateLimitTokenBucketSize) } if gotGlobalBurst != minRateLimitTokenBucketSize { t.Errorf("global burst = %d; want %d", gotGlobalBurst, minRateLimitTokenBucketSize) } }) t.Run("reload_disables_limiting", func(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000, "GlobalRateLimitBytesPerSec": 12500000, "GlobalRateBurstBytes": 25000000}`) if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatal(err) } s.mu.Lock() if reflect.DeepEqual(s.rateConfig, RateConfig{}) { t.Error("s.rateConfig is zero val; want nonzero rates") } s.mu.Unlock() if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil { t.Fatal(err) } if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatal(err) } s.mu.Lock() if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, RateConfig{}) } s.mu.Unlock() }) t.Run("propagates_errors", func(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() if err := s.LoadAndApplyRateConfig(filepath.Join(t.TempDir(), "nonexistent.json")); err == nil { t.Fatal("expected error") } }) } func BenchmarkSenderCardinalityOverhead(b *testing.B) { hll := hyperloglog.New() sender := key.NewNode().Public() b.Run("WithTracking", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { if hll != nil { hll.Insert(sender.AppendTo(nil)) } } }) b.Run("WithoutTracking", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { _ = sender.AppendTo(nil) } }) }