// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

package derpserver

import (
	"bufio"
	"cmp"
	"context"
	"crypto/x509"
	"encoding/asn1"
	"encoding/binary"
	"expvar"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"reflect"
	"strconv"
	"sync"
	"testing"
	"testing/synctest"
	"time"

	"github.com/axiomhq/hyperloglog"
	qt "github.com/frankban/quicktest"
	"go4.org/mem"
	"golang.org/x/time/rate"
	"tailscale.com/derp"
	"tailscale.com/derp/derpconst"
	"tailscale.com/types/key"
	"tailscale.com/types/logger"
	"tailscale.com/util/set"
)

const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

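// TestSetMeshKey verifies that Server.SetMeshKey accepts a well-formed hex
// mesh key and rejects malformed input.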
func TestSetMeshKey(t *testing.T) {
	for name, tt := range map[string]struct {
		key     string
		want    key.DERPMesh
		wantErr bool
	}{
		"clobber": {
			key:     testMeshKey,
			wantErr: false,
		},
		"invalid": {
			key:     "badf00d",
			wantErr: true,
		},
	} {
		t.Run(name, func(t *testing.T) {
			s := &Server{}

			err := s.SetMeshKey(tt.key)
			if tt.wantErr {
				if err == nil {
					t.Fatalf("expected err")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected err: %v", err)
			}

			want, err := key.ParseDERPMesh(tt.key)
			if err != nil {
				t.Fatal(err)
			}
			if !s.meshKey.Equal(want) {
				t.Fatalf("got %v, want %v", s.meshKey, want)
			}
		})
	}
}

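// TestIsMeshPeer verifies that Server.isMeshPeer only accepts clients whose
// mesh key matches the server's, and that a successful match does not allocate.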
func TestIsMeshPeer(t *testing.T) {
	s := &Server{}
	err := s.SetMeshKey(testMeshKey)
	if err != nil {
		t.Fatal(err)
	}
	for name, tt := range map[string]struct {
		want       bool
		meshKey    string
		wantAllocs float64
	}{
		"nil": {
			want:       false,
			wantAllocs: 0,
		},
		"mismatch": {
			meshKey:    "6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8",
			want:       false,
			wantAllocs: 1,
		},
		"match": {
			meshKey:    testMeshKey,
			want:       true,
			wantAllocs: 0,
		},
	} {
		t.Run(name, func(t *testing.T) {
			var got bool
			var mKey key.DERPMesh
			if tt.meshKey != "" {
				mKey, err = key.ParseDERPMesh(tt.meshKey)
				if err != nil {
					t.Fatalf("ParseDERPMesh(%q) failed: %v", tt.meshKey, err)
				}
			}

			info := derp.ClientInfo{
				MeshKey: mKey,
			}
			allocs := testing.AllocsPerRun(1, func() {
				got = s.isMeshPeer(&info)
			})
			if got != tt.want {
				t.Fatalf("got %t, want %t: info = %#v", got, tt.want, info)
			}

			if allocs != tt.wantAllocs && tt.want {
				t.Errorf("%f allocations, want %f", allocs, tt.wantAllocs)
			}
		})
	}
}

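// testFwd is a PacketForwarder stub used only for identity comparisons in the
// forwarder-registration tests; its methods are never called.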
type testFwd int

func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error {
	panic("not called in tests")
}
func (testFwd) String() string {
	panic("not called in tests")
}

func pubAll(b byte) (ret key.NodePublic) {
	var bs [32]byte
	for i := range bs {
		bs[i] = b
	}
	return key.NodePublicFromRaw32(mem.B(bs[:]))
}

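// TestForwarderRegistration exercises AddPacketForwarder/RemovePacketForwarder
// bookkeeping in clientsMesh, including multiForwarder creation and collapse.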
func TestForwarderRegistration(t *testing.T) {
	s := &Server{
		clients:     make(map[key.NodePublic]*clientSet),
		clientsMesh: map[key.NodePublic]PacketForwarder{},
	}
	want := func(want map[key.NodePublic]PacketForwarder) {
		t.Helper()
		if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
			t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
		}
	}
	wantCounter := func(c *expvar.Int, want int) {
		t.Helper()
		if got := c.Value(); got != int64(want) {
			t.Errorf("counter = %v; want %v", got, want)
		}
	}
	singleClient := func(c *sclient) *clientSet {
		cs := &clientSet{}
		cs.activeClient.Store(c)
		return cs
	}

	u1 := pubAll(1)
	u2 := pubAll(2)
	u3 := pubAll(3)

	s.AddPacketForwarder(u1, testFwd(1))
	s.AddPacketForwarder(u2, testFwd(2))
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(1),
		u2: testFwd(2),
	})

	// Verify a remove of non-registered forwarder is no-op.
	s.RemovePacketForwarder(u2, testFwd(999))
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(1),
		u2: testFwd(2),
	})

	// Verify a remove of non-registered user is no-op.
	s.RemovePacketForwarder(u3, testFwd(1))
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(1),
		u2: testFwd(2),
	})

	// Actual removal.
	s.RemovePacketForwarder(u2, testFwd(2))
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(1),
	})

	// Adding a dup for a user.
	wantCounter(&s.multiForwarderCreated, 0)
	s.AddPacketForwarder(u1, testFwd(100))
	s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
	want(map[key.NodePublic]PacketForwarder{
		u1: newMultiForwarder(testFwd(1), testFwd(100)),
	})
	wantCounter(&s.multiForwarderCreated, 1)

	// Removing a forwarder in a multi set that doesn't exist; does nothing.
	s.RemovePacketForwarder(u1, testFwd(55))
	want(map[key.NodePublic]PacketForwarder{
		u1: newMultiForwarder(testFwd(1), testFwd(100)),
	})

	// Removing a forwarder in a multi set that does exist should collapse it away
	// from being a multiForwarder.
	wantCounter(&s.multiForwarderDeleted, 0)
	s.RemovePacketForwarder(u1, testFwd(1))
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(100),
	})
	wantCounter(&s.multiForwarderDeleted, 1)

	// Removing an entry for a client that's still connected locally should result
	// in a nil forwarder.
	u1c := &sclient{
		key:  u1,
		logf: logger.Discard,
	}
	s.clients[u1] = singleClient(u1c)
	s.RemovePacketForwarder(u1, testFwd(100))
	want(map[key.NodePublic]PacketForwarder{
		u1: nil,
	})

	// But once that client disconnects, it should go away.
	s.unregisterClient(u1c)
	want(map[key.NodePublic]PacketForwarder{})

	// But if it already has a forwarder, it's not removed.
	s.AddPacketForwarder(u1, testFwd(2))
	s.unregisterClient(u1c)
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(2),
	})

	// Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard
	// that they're also connected to a peer of ours. That should transition the forwarder
	// from nil to the new one, not to a multiForwarder.
	s.clients[u1] = singleClient(u1c)
	s.clientsMesh[u1] = nil
	want(map[key.NodePublic]PacketForwarder{
		u1: nil,
	})
	s.AddPacketForwarder(u1, testFwd(3))
	want(map[key.NodePublic]PacketForwarder{
		u1: testFwd(3),
	})
}

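// channelFwd is a PacketForwarder that delivers forwarded packets to a channel,
// letting tests observe and count what was forwarded.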
type channelFwd struct {
	// id is to ensure that different instances that reference the
	// same channel are not equal, as they are used as keys in the
	// multiForwarder map.
	id int
	c  chan []byte
}

func (f channelFwd) String() string { return "" }
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
	f.c <- packet
	return nil
}

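// TestMultiForwarder forwards packets while concurrently adding and removing
// forwarders for the same key, verifying that no packets are lost.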
func TestMultiForwarder(t *testing.T) {
	received := 0
	var wg sync.WaitGroup
	ch := make(chan []byte)
	ctx, cancel := context.WithCancel(context.Background())

	s := &Server{
		clients:     make(map[key.NodePublic]*clientSet),
		clientsMesh: map[key.NodePublic]PacketForwarder{},
	}
	u := pubAll(1)
	s.AddPacketForwarder(u, channelFwd{1, ch})

	wg.Add(2)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-ch:
				received += 1
			case <-ctx.Done():
				return
			}
		}
	}()
	go func() {
		defer wg.Done()
		for {
			s.AddPacketForwarder(u, channelFwd{2, ch})
			s.AddPacketForwarder(u, channelFwd{3, ch})
			s.RemovePacketForwarder(u, channelFwd{2, ch})
			s.RemovePacketForwarder(u, channelFwd{1, ch})
			s.AddPacketForwarder(u, channelFwd{1, ch})
			s.RemovePacketForwarder(u, channelFwd{3, ch})
			if ctx.Err() != nil {
				return
			}
		}
	}()

	// Number of messages is chosen arbitrarily, just for this loop to
	// run long enough concurrently with {Add,Remove}PacketForwarder loop above.
	numMsgs := 5000
	var fwd PacketForwarder
	for i := range numMsgs {
		s.mu.Lock()
		fwd = s.clientsMesh[u]
		s.mu.Unlock()
		fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i)))
	}

	cancel()
	wg.Wait()
	if received != numMsgs {
		t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received)
	}
}

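// TestMetaCert checks the server's self-signed metadata certificate: its serial
// must be the DERP protocol version and its CommonName must embed the server's
// public key.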
func TestMetaCert(t *testing.T) {
	priv := key.NewNode()
	pub := priv.Public()
	s := New(priv, t.Logf)

	certBytes := s.MetaCert()
	cert, err := x509.ParseCertificate(certBytes)
	if err != nil {
		t.Fatal(err)
	}
	if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(derp.ProtocolVersion) {
		t.Errorf("serial = %v; want %v", cert.SerialNumber, derp.ProtocolVersion)
	}
	if g, w := cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix+pub.UntypedHexString(); g != w {
		t.Errorf("CommonName = %q; want %q", g, w)
	}
	if n := len(cert.Extensions); n != 1 {
		t.Fatalf("got %d extensions; want 1", n)
	}

	// oidExtensionBasicConstraints is the Basic Constraints ID copied
	// from the x509 package.
	oidExtensionBasicConstraints := asn1.ObjectIdentifier{2, 5, 29, 19}

	if id := cert.Extensions[0].Id; !id.Equal(oidExtensionBasicConstraints) {
		t.Errorf("extension ID = %v; want %v", id, oidExtensionBasicConstraints)
	}
}

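// TestServerDupClients covers duplicate connections sharing one node key under
// both dup policies (disableFighters and lastWriterIsActive).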
func TestServerDupClients(t *testing.T) {
	serverPriv := key.NewNode()
	var s *Server

	clientPriv := key.NewNode()
	clientPub := clientPriv.Public()

	var c1, c2, c3 *sclient
	var clientName map[*sclient]string

	// run starts a new test case and resets clients back to their zero values.
	run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) {
		s = New(serverPriv, t.Logf)
		s.dupPolicy = dupPolicy
		c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")}
		c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")}
		c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")}
		clientName = map[*sclient]string{
			c1: "c1",
			c2: "c2",
			c3: "c3",
		}
		t.Run(name, f)
	}
	runBothWays := func(name string, f func(t *testing.T)) {
		run(name+"_disablefighters", disableFighters, f)
		run(name+"_lastwriteractive", lastWriterIsActive, f)
	}
	wantSingleClient := func(t *testing.T, want *sclient) {
		t.Helper()
		got, ok := s.clients[want.key]
		if !ok {
			t.Error("no clients for key")
			return
		}
		if got.dup != nil {
			t.Errorf("unexpected dup set for single client")
		}
		cur := got.activeClient.Load()
		if cur != want {
			t.Errorf("active client = %q; want %q", clientName[cur], clientName[want])
		}
		if cur != nil {
			if cur.isDup.Load() {
				t.Errorf("unexpected isDup on singleClient")
			}
			if cur.isDisabled.Load() {
				t.Errorf("unexpected isDisabled on singleClient")
			}
		}
	}
	wantNoClient := func(t *testing.T) {
		t.Helper()
		_, ok := s.clients[clientPub]
		if !ok {
			// Good
			return
		}
		t.Errorf("got client; want empty")
	}
	wantDupSet := func(t *testing.T) *dupClientSet {
		t.Helper()
		cs, ok := s.clients[clientPub]
		if !ok {
			t.Fatal("no set for key; want dup set")
			return nil
		}
		if cs.dup != nil {
			return cs.dup
		}
		t.Fatalf("no dup set for key; want dup set")
		return nil
	}
	wantActive := func(t *testing.T, want *sclient) {
		t.Helper()
		set, ok := s.clients[clientPub]
		if !ok {
			t.Error("no set for key")
			return
		}
		got := set.activeClient.Load()
		if got != want {
			t.Errorf("active client = %q; want %q", clientName[got], clientName[want])
		}
	}
	checkDup := func(t *testing.T, c *sclient, want bool) {
		t.Helper()
		if got := c.isDup.Load(); got != want {
			t.Errorf("client %q isDup = %v; want %v", clientName[c], got, want)
		}
	}
	checkDisabled := func(t *testing.T, c *sclient, want bool) {
		t.Helper()
		if got := c.isDisabled.Load(); got != want {
			t.Errorf("client %q isDisabled = %v; want %v", clientName[c], got, want)
		}
	}
	wantDupConns := func(t *testing.T, want int) {
		t.Helper()
		if got := s.dupClientConns.Value(); got != int64(want) {
			t.Errorf("dupClientConns = %v; want %v", got, want)
		}
	}
	wantDupKeys := func(t *testing.T, want int) {
		t.Helper()
		if got := s.dupClientKeys.Value(); got != int64(want) {
			t.Errorf("dupClientKeys = %v; want %v", got, want)
		}
	}

	// Common case: a single client comes and goes, with no dups.
	runBothWays("one_comes_and_goes", func(t *testing.T) {
		wantNoClient(t)
		s.registerClient(c1)
		wantSingleClient(t, c1)
		s.unregisterClient(c1)
		wantNoClient(t)
	})

	// A still somewhat common case: a single client was
	// connected and then their wifi dies or laptop closes
	// or they switch networks and connect from a
	// different network. They have two connections but
	// it's not very bad. Only their new one is
	// active. The last one, being dead, doesn't send and
	// thus the new one doesn't get disabled.
	runBothWays("small_overlap_replacement", func(t *testing.T) {
		wantNoClient(t)
		s.registerClient(c1)
		wantSingleClient(t, c1)
		wantActive(t, c1)
		wantDupKeys(t, 0)
		wantDupKeys(t, 0)

		s.registerClient(c2) // wifi dies; c2 replacement connects
		wantDupSet(t)
		wantDupConns(t, 2)
		wantDupKeys(t, 1)
		checkDup(t, c1, true)
		checkDup(t, c2, true)
		checkDisabled(t, c1, false)
		checkDisabled(t, c2, false)
		wantActive(t, c2) // sends go to the replacement

		s.unregisterClient(c1) // c1 finally times out
		wantSingleClient(t, c2)
		checkDup(t, c2, false) // c2 is no longer a dup
		wantActive(t, c2)
		wantDupConns(t, 0)
		wantDupKeys(t, 0)
	})

	// Key cloning situation with concurrent clients, both trying
	// to write.
	run("concurrent_dups_get_disabled", disableFighters, func(t *testing.T) {
		wantNoClient(t)
		s.registerClient(c1)
		wantSingleClient(t, c1)
		wantActive(t, c1)
		s.registerClient(c2)
		wantDupSet(t)
		wantDupKeys(t, 1)
		wantDupConns(t, 2)
		wantActive(t, c2)
		checkDup(t, c1, true)
		checkDup(t, c2, true)
		checkDisabled(t, c1, false)
		checkDisabled(t, c2, false)

		s.noteClientActivity(c2)
		checkDisabled(t, c1, false)
		checkDisabled(t, c2, false)
		s.noteClientActivity(c1)
		checkDisabled(t, c1, true)
		checkDisabled(t, c2, true)
		wantActive(t, nil)

		s.registerClient(c3)
		wantActive(t, c3)
		checkDisabled(t, c3, false)
		wantDupKeys(t, 1)
		wantDupConns(t, 3)

		s.unregisterClient(c3)
		wantActive(t, nil)
		wantDupKeys(t, 1)
		wantDupConns(t, 2)

		s.unregisterClient(c2)
		wantSingleClient(t, c1)
		wantDupKeys(t, 0)
		wantDupConns(t, 0)
	})

	// Key cloning with an A->B->C->A series instead.
	run("concurrent_dups_three_parties", disableFighters, func(t *testing.T) {
		wantNoClient(t)
		s.registerClient(c1)
		s.registerClient(c2)
		s.registerClient(c3)
		s.noteClientActivity(c1)
		checkDisabled(t, c1, true)
		checkDisabled(t, c2, true)
		checkDisabled(t, c3, true)
		wantActive(t, nil)
	})

	run("activity_promotes_primary_when_nil", disableFighters, func(t *testing.T) {
		wantNoClient(t)

		// Last registered client is the active one...
		s.registerClient(c1)
		wantActive(t, c1)
		s.registerClient(c2)
		wantActive(t, c2)
		s.registerClient(c3)
		s.noteClientActivity(c2)
		wantActive(t, c3)

		// But if the last one goes away, the one with the
		// most recent activity wins.
		s.unregisterClient(c3)
		wantActive(t, c2)
	})

	run("concurrent_dups_three_parties_last_writer", lastWriterIsActive, func(t *testing.T) {
		wantNoClient(t)

		s.registerClient(c1)
		wantActive(t, c1)
		s.registerClient(c2)
		wantActive(t, c2)

		s.noteClientActivity(c1)
		checkDisabled(t, c1, false)
		checkDisabled(t, c2, false)
		wantActive(t, c1)

		s.noteClientActivity(c2)
		checkDisabled(t, c1, false)
		checkDisabled(t, c2, false)
		wantActive(t, c2)

		s.unregisterClient(c2)
		checkDisabled(t, c1, false)
		wantActive(t, c1)
	})
}

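// TestLimiter logs how rate.Limiter reservations behave past the burst; it
// makes no assertions.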
func TestLimiter(t *testing.T) {
	rl := rate.NewLimiter(rate.Every(time.Minute), 100)
	for i := range 200 {
		r := rl.Reserve()
		d := r.Delay()
		t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d)
	}
}

// BenchmarkConcurrentStreams exercises mutex contention on a
// single Server instance with multiple concurrent client flows.
func BenchmarkConcurrentStreams(b *testing.B) {
	serverPrivateKey := key.NewNode()
	s := New(serverPrivateKey, logger.Discard)
	defer s.Close()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		b.Fatal(err)
	}

	ctx := b.Context()

	acceptDone := make(chan struct{})
	go func() {
		defer close(acceptDone)
		for {
			connIn, err := ln.Accept()
			if err != nil {
				return
			}
			brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn))
			go s.Accept(ctx, connIn, brwServer, "test-client")
		}
	}()

	newClient := func(t testing.TB) *derp.Client {
		t.Helper()
		connOut, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			b.Fatal(err)
		}
		t.Cleanup(func() { connOut.Close() })

		k := key.NewNode()

		brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut))
		client, err := derp.NewClient(k, connOut, brw, logger.Discard)
		if err != nil {
			b.Fatalf("client: %v", err)
		}
		return client
	}

	b.RunParallel(func(pb *testing.PB) {
		c1, c2 := newClient(b), newClient(b)
		const packetSize = 100
		msg := make([]byte, packetSize)
		for pb.Next() {
			if err := c1.Send(c2.PublicKey(), msg); err != nil {
				b.Fatal(err)
			}
			_, err := c2.Recv()
			if err != nil {
				return
			}
		}
	})

	ln.Close()
	<-acceptDone
}

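// BenchmarkSendRecv measures send throughput through the server for a range of
// packet sizes.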
func BenchmarkSendRecv(b *testing.B) {
	for _, size := range []int{10, 100, 1000, 10000} {
		b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) })
	}
}

func benchmarkSendRecvSize(b *testing.B, packetSize int) {
	serverPrivateKey := key.NewNode()
	s := New(serverPrivateKey, logger.Discard)
	defer s.Close()

	k := key.NewNode()
	clientKey := k.Public()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		b.Fatal(err)
	}
	defer ln.Close()

	connOut, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		b.Fatal(err)
	}
	defer connOut.Close()

	connIn, err := ln.Accept()
	if err != nil {
		b.Fatal(err)
	}
	defer connIn.Close()

	brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go s.Accept(ctx, connIn, brwServer, "test-client")

	brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut))
	client, err := derp.NewClient(k, connOut, brw, logger.Discard)
	if err != nil {
		b.Fatalf("client: %v", err)
	}

	go func() {
		for {
			_, err := client.Recv()
			if err != nil {
				return
			}
		}
	}()

	msg := make([]byte, packetSize)
	b.SetBytes(int64(len(msg)))
	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		if err := client.Send(clientKey, msg); err != nil {
			b.Fatal(err)
		}
	}
}

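// TestParseSSOutput feeds sample `ss` output from testdata through
// parseSSOutput and expects a non-empty result.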
func TestParseSSOutput(t *testing.T) {
	contents, err := os.ReadFile("testdata/example_ss.txt")
	if err != nil {
		t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err)
	}
	seen := parseSSOutput(string(contents))
	if len(seen) == 0 {
		t.Errorf("parseSSOutput expected non-empty map")
	}
}

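// TestServeDebugTrafficUniqueSenders checks the HyperLogLog-backed estimate of
// unique senders for a client registered on the server.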
func TestServeDebugTrafficUniqueSenders(t *testing.T) {
	s := New(key.NewNode(), t.Logf)
	defer s.Close()

	clientKey := key.NewNode().Public()
	c := &sclient{
		key:               clientKey,
		s:                 s,
		logf:              logger.Discard,
		senderCardinality: hyperloglog.New(),
	}

	for range 5 {
		c.senderCardinality.Insert(key.NewNode().Public().AppendTo(nil))
	}

	s.mu.Lock()
	cs := &clientSet{}
	cs.activeClient.Store(c)
	s.clients[clientKey] = cs
	s.mu.Unlock()

	estimate := c.EstimatedUniqueSenders()
	t.Logf("Estimated unique senders: %d", estimate)
	if estimate < 4 || estimate > 6 {
		t.Errorf("EstimatedUniqueSenders() = %d, want ~5 (4-6 range)", estimate)
	}
}

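// TestGetPerClientSendQueueDepth verifies getPerClientSendQueueDepth's default
// and its TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH override.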
func TestGetPerClientSendQueueDepth(t *testing.T) {
	c := qt.New(t)
	envKey := "TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH"

	testCases := []struct {
		envVal string
		want   int
	}{
		// Empty case, envknob treats empty as missing also.
		{
			"", defaultPerClientSendQueueDepth,
		},
		{
			"64", 64,
		},
	}

	for _, tc := range testCases {
		t.Run(cmp.Or(tc.envVal, "empty"), func(t *testing.T) {
			t.Setenv(envKey, tc.envVal)
			val := getPerClientSendQueueDepth()
			c.Assert(val, qt.Equals, tc.want)
		})
	}
}

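// TestSenderCardinality verifies EstimatedUniqueSenders before initialization,
// with no senders, and that duplicate inserts do not inflate the estimate.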
func TestSenderCardinality(t *testing.T) {
	s := New(key.NewNode(), t.Logf)
	defer s.Close()

	c := &sclient{
		key:  key.NewNode().Public(),
		s:    s,
		logf: logger.WithPrefix(t.Logf, "test client: "),
	}

	if got := c.EstimatedUniqueSenders(); got != 0 {
		t.Errorf("EstimatedUniqueSenders() before init = %d, want 0", got)
	}

	c.senderCardinality = hyperloglog.New()

	if got := c.EstimatedUniqueSenders(); got != 0 {
		t.Errorf("EstimatedUniqueSenders() with no senders = %d, want 0", got)
	}

	senders := make([]key.NodePublic, 10)
	for i := range senders {
		senders[i] = key.NewNode().Public()
		c.senderCardinality.Insert(senders[i].AppendTo(nil))
	}

	estimate := c.EstimatedUniqueSenders()
	t.Logf("Estimated unique senders after 10 inserts: %d", estimate)

	if estimate < 8 || estimate > 12 {
		t.Errorf("EstimatedUniqueSenders() = %d, want ~10 (8-12 range)", estimate)
	}

	for i := range 5 {
		c.senderCardinality.Insert(senders[i].AppendTo(nil))
	}

	estimate2 := c.EstimatedUniqueSenders()
	t.Logf("Estimated unique senders after duplicates: %d", estimate2)

	if estimate2 < 8 || estimate2 > 12 {
		t.Errorf("EstimatedUniqueSenders() after duplicates = %d, want ~10 (8-12 range)", estimate2)
	}
}

func TestSenderCardinality100(t *testing.T) {
	s := New(key.NewNode(), t.Logf)
	defer s.Close()

	c := &sclient{
		key:               key.NewNode().Public(),
		s:                 s,
		logf:              logger.WithPrefix(t.Logf, "test client: "),
		senderCardinality: hyperloglog.New(),
	}

	numSenders := 100
	for range numSenders {
		c.senderCardinality.Insert(key.NewNode().Public().AppendTo(nil))
	}

	estimate := c.EstimatedUniqueSenders()
	t.Logf("Estimated unique senders for 100 actual senders: %d", estimate)

	if estimate < 85 || estimate > 115 {
		t.Errorf("EstimatedUniqueSenders() = %d, want ~100 (85-115 range)", estimate)
	}
}

func TestSenderCardinalityTracking(t *testing.T) {
	s := New(key.NewNode(), t.Logf)
	defer s.Close()

	c := &sclient{
		key:               key.NewNode().Public(),
		s:                 s,
		logf:              logger.WithPrefix(t.Logf, "test client: "),
		senderCardinality: hyperloglog.New(),
	}

	zeroKey := key.NodePublic{}
	if zeroKey != (key.NodePublic{}) {
		c.senderCardinality.Insert(zeroKey.AppendTo(nil))
	}

	if estimate := c.EstimatedUniqueSenders(); estimate != 0 {
		t.Errorf("EstimatedUniqueSenders() after zero key = %d, want 0", estimate)
	}

	sender1 := key.NewNode().Public()
	sender2 := key.NewNode().Public()

	if sender1 != (key.NodePublic{}) {
		c.senderCardinality.Insert(sender1.AppendTo(nil))
	}
	if sender2 != (key.NodePublic{}) {
		c.senderCardinality.Insert(sender2.AppendTo(nil))
	}

	estimate := c.EstimatedUniqueSenders()
	t.Logf("Estimated unique senders after 2 senders: %d", estimate)

	if estimate < 1 || estimate > 3 {
		t.Errorf("EstimatedUniqueSenders() = %d, want ~2 (1-3 range)", estimate)
	}
}

func BenchmarkHyperLogLogInsert(b *testing.B) {
	hll := hyperloglog.New()
	sender := key.NewNode().Public()
	senderBytes := sender.AppendTo(nil)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		hll.Insert(senderBytes)
	}
}

func BenchmarkHyperLogLogInsertUnique(b *testing.B) {
	hll := hyperloglog.New()

	b.ResetTimer()

	buf := make([]byte, 32)
	for i := 0; i < b.N; i++ {
		binary.LittleEndian.PutUint64(buf, uint64(i))
		hll.Insert(buf)
	}
}

func BenchmarkHyperLogLogEstimate(b *testing.B) {
	hll := hyperloglog.New()

	for range 100 {
		hll.Insert(key.NewNode().Public().AppendTo(nil))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = hll.Estimate()
	}
}

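// TestPerClientRateLimit exercises sclient.rateLimit against the parent/child
// token buckets: blocking past the burst, context cancelation, the mesh-peer
// exemption, and the zero-config default of no limiter.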
func TestPerClientRateLimit(t *testing.T) {
	t.Run("throttled", func(t *testing.T) {
		synctest.Test(t, func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			t.Cleanup(cancel)

			c := &sclient{
				ctx: ctx,
			}
			lim := &parentChildTokenBuckets{
				// Set parent limit to half of child to enable verification of
				// rate limiting across both layers with a single sclient.
				parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize)/2, minRateLimitTokenBucketSize),
				child:  rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize),
			}
			c.recvLim.Store(lim)
			wantTokens := func(t *testing.T, wantParentTokens, wantChildTokens float64) {
				t.Helper()
				if lim.parent.Tokens() != wantParentTokens {
					t.Fatalf("want parent tokens: %v got: %v", wantParentTokens, lim.parent.Tokens())
				}
				if lim.child.Tokens() != wantChildTokens {
					t.Fatalf("want child tokens: %v got: %v", wantChildTokens, lim.child.Tokens())
				}
			}

			// First call within burst should not block.
			c.rateLimit(minRateLimitTokenBucketSize)

			wantTokens(t, 0, 0)

			// Next call exceeds burst, should block until tokens replenish.
			done := make(chan error, 1)
			go func() {
				done <- c.rateLimit(minRateLimitTokenBucketSize)
			}()

			// After settling, the goroutine should be blocked (no result yet).
			synctest.Wait()
			select {
			case err := <-done:
				t.Fatalf("rateLimit should have blocked, but returned: %v", err)
			default:
			}

			// Advance time by 1 second, the goroutine should still be blocked
			// on the parent bucket (negative tokens).
			time.Sleep(1 * time.Second)
			synctest.Wait()
			select {
			case err := <-done:
				t.Fatalf("rateLimit should have blocked, but returned: %v", err)
			default:
			}

			// Verify the parent bucket fills at half the rate of the child.
			wantTokens(t, -(minRateLimitTokenBucketSize / 2), 0)

			// Advance time by another second, parent should have enough tokens
			// to unblock.
			time.Sleep(1 * time.Second)
			synctest.Wait()

			select {
			case err := <-done:
				if err != nil {
					t.Fatalf("rateLimit after time advance: %v", err)
				}
			default:
				t.Fatal("rateLimit should have unblocked after 1s")
			}

			wantTokens(t, 0, minRateLimitTokenBucketSize)
		})
	})

	t.Run("context_canceled", func(t *testing.T) {
		synctest.Test(t, func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())

			c := &sclient{
				ctx: ctx,
			}
			lim := &parentChildTokenBuckets{
				child:  rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize),
				parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize),
			}
			c.recvLim.Store(lim)

			// Exhaust burst.
			if err := c.rateLimit(minRateLimitTokenBucketSize); err != nil {
				t.Fatalf("rateLimit: %v", err)
			}

			done := make(chan error, 1)
			go func() {
				done <- c.rateLimit(minRateLimitTokenBucketSize)
			}()
			synctest.Wait()

			// Cancel the context; the blocked rateLimit should return an error.
			cancel()
			synctest.Wait()

			select {
			case err := <-done:
				if err == nil {
					t.Fatal("expected error from canceled context")
				}
			default:
				t.Fatal("rateLimit should have returned after context cancelation")
			}
		})
	})

	t.Run("mesh_peer_exempt", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		// Mesh peers have nil recvLim, so rate limiting is a no-op.
		c := &sclient{
			ctx:     ctx,
			canMesh: true,
		}

		if err := c.rateLimit(1000); err != nil {
			t.Fatalf("mesh peer rateLimit should be no-op: %v", err)
		}
	})

	t.Run("zero_config_no_limiter", func(t *testing.T) {
		s := New(key.NewNode(), logger.Discard)
		defer s.Close()
		if !reflect.DeepEqual(s.rateConfig, RateConfig{}) {
			t.Errorf("expected zero rate limit, got %+v", s.rateConfig)
		}
	})
}

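// verifyLimiter asserts that lim's child (per-client) and parent (global)
// buckets match the rates and bursts in wantRateConfig.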
func verifyLimiter(t *testing.T, lim *parentChildTokenBuckets, wantRateConfig RateConfig) {
	t.Helper()
	if got := lim.child.Limit(); got != rate.Limit(wantRateConfig.PerClientRateLimitBytesPerSec) {
		t.Errorf("client rate limit = %v; want %d", got, wantRateConfig.PerClientRateLimitBytesPerSec)
	}
	if got := lim.child.Burst(); got != int(wantRateConfig.PerClientRateBurstBytes) {
		t.Errorf("client burst = %v; want %d", got, wantRateConfig.PerClientRateBurstBytes)
	}
	if got := lim.parent.Limit(); got != rate.Limit(wantRateConfig.GlobalRateLimitBytesPerSec) {
		t.Errorf("global rate limit = %v, want %d", got, wantRateConfig.GlobalRateLimitBytesPerSec)
	}
	if got := lim.parent.Burst(); got != int(wantRateConfig.GlobalRateBurstBytes) {
		t.Errorf("global burst = %v, want %d", got, wantRateConfig.GlobalRateBurstBytes)
	}
}

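// TestUpdateRateLimits verifies that UpdateRateLimits installs, updates, and
// removes per-client limiters, skips mesh peers, and also updates dup clients.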
func TestUpdateRateLimits(t *testing.T) {
	const (
		testClientBurst1 = minRateLimitTokenBucketSize + 1
		testClientRate1  = minRateLimitTokenBucketSize + 2
		testClientBurst2 = minRateLimitTokenBucketSize + 3
		testClientRate2  = minRateLimitTokenBucketSize + 4
		testGlobalBurst1 = minRateLimitTokenBucketSize + 5
		testGlobalRate1  = minRateLimitTokenBucketSize + 6
		testGlobalBurst2 = minRateLimitTokenBucketSize + 7
		testGlobalRate2  = minRateLimitTokenBucketSize + 8
	)

	s := New(key.NewNode(), t.Logf)
	defer s.Close()

	// Create a non-mesh client with no initial limiter.
	clientKey := key.NewNode().Public()
	c := &sclient{
		key:     clientKey,
		s:       s,
		logf:    logger.Discard,
		canMesh: false,
	}
	cs := &clientSet{}
	cs.activeClient.Store(c)

	s.mu.Lock()
	s.clients[clientKey] = cs
	s.mu.Unlock()

	rc := RateConfig{
		PerClientRateLimitBytesPerSec: testClientRate1,
		PerClientRateBurstBytes:       testClientBurst1,
		GlobalRateLimitBytesPerSec:    testGlobalRate1,
		GlobalRateBurstBytes:          testGlobalBurst1,
	}
	s.UpdateRateLimits(rc)

	lim := c.recvLim.Load()
	if lim == nil {
		t.Fatal("expected non-nil limiter after update")
	}
	verifyLimiter(t, lim, rc)

	// Verify server fields updated.
	s.mu.Lock()
	if !reflect.DeepEqual(s.rateConfig, rc) {
		t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, rc)
	}
	s.mu.Unlock()

	// Update again with different nonzero values.
	rc = RateConfig{
		PerClientRateLimitBytesPerSec: testClientRate2,
		PerClientRateBurstBytes:       testClientBurst2,
		GlobalRateLimitBytesPerSec:    testGlobalRate2,
		GlobalRateBurstBytes:          testGlobalBurst2,
	}
	s.UpdateRateLimits(rc)
	lim = c.recvLim.Load()
	if lim == nil {
		t.Fatal("expected non-nil limiter")
	}
	verifyLimiter(t, lim, rc)

	// Disable rate limiting (set to 0).
	s.UpdateRateLimits(RateConfig{})

	if got := c.recvLim.Load(); got != nil {
		t.Errorf("expected nil limiter after disable, got limit=%v", got.child.Limit())
	}

	// Mesh peer should always have nil limiter regardless of update.
	meshKey := key.NewNode().Public()
	meshClient := &sclient{
		key:     meshKey,
		s:       s,
		logf:    logger.Discard,
		canMesh: true,
	}
	meshCS := &clientSet{}
	meshCS.activeClient.Store(meshClient)

	s.mu.Lock()
	s.clients[meshKey] = meshCS
	s.mu.Unlock()

	rc = RateConfig{
		PerClientRateLimitBytesPerSec: testClientRate2,
		PerClientRateBurstBytes:       testClientBurst2,
		GlobalRateLimitBytesPerSec:    testGlobalRate2,
		GlobalRateBurstBytes:          testGlobalBurst2,
	}
	s.UpdateRateLimits(rc)

	if got := meshClient.recvLim.Load(); got != nil {
		t.Errorf("mesh peer should have nil limiter, got limit=%v", got.child.Limit())
	}
	// Non-mesh client should be updated.
	lim = c.recvLim.Load()
	if lim == nil {
		t.Fatal("expected non-nil limiter for non-mesh client")
	}
	verifyLimiter(t, lim, rc)

	// Verify dup clients are also updated.
	dupKey := key.NewNode().Public()
	d1 := &sclient{key: dupKey, s: s, logf: logger.Discard}
	d2 := &sclient{key: dupKey, s: s, logf: logger.Discard}
	dupCS := &clientSet{}
	dupCS.activeClient.Store(d1)
	dupCS.dup = &dupClientSet{set: set.Of(d1, d2)}
	s.mu.Lock()
	s.clients[dupKey] = dupCS
	s.mu.Unlock()

	rc = RateConfig{
		GlobalRateLimitBytesPerSec:    testGlobalRate1,
		GlobalRateBurstBytes:          testGlobalBurst1,
		PerClientRateLimitBytesPerSec: testClientRate1,
		PerClientRateBurstBytes:       testClientBurst1,
	}
	s.UpdateRateLimits(rc)
	for i, d := range []*sclient{d1, d2} {
		dl := d.recvLim.Load()
		if dl == nil {
			t.Fatalf("dup client %d: expected non-nil limiter", i)
		}
		verifyLimiter(t, dl, rc)
	}
}

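// TestLoadRateConfig covers parsing of the JSON rate-config file, including
// partial configs and the error paths (empty path, missing file, invalid JSON).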
func TestLoadRateConfig(t *testing.T) {
	for _, tt := range []struct {
		name           string
		json           string
		wantRateConfig RateConfig
	}{
		{"all_set", `{"PerClientRateLimitBytesPerSec": 1, "PerClientRateBurstBytes": 2, "GlobalRateLimitBytesPerSec": 3, "GlobalRateBurstBytes": 4}`, RateConfig{
			PerClientRateLimitBytesPerSec: 1,
			PerClientRateBurstBytes:       2,
			GlobalRateLimitBytesPerSec:    3,
			GlobalRateBurstBytes:          4,
		}},
		{"rate_only", `{"PerClientRateLimitBytesPerSec": 1, "GlobalRateLimitBytesPerSec": 3}`, RateConfig{
			PerClientRateLimitBytesPerSec: 1,
			GlobalRateLimitBytesPerSec:    3,
		}},
		{"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0, "GlobalRateLimitBytesPerSec": 0, "GlobalRateBurstBytes": 0}`, RateConfig{}},
		{"empty_json", `{}`, RateConfig{}},
	} {
		t.Run(tt.name, func(t *testing.T) {
			f := filepath.Join(t.TempDir(), "rate.json")
			if err := os.WriteFile(f, []byte(tt.json), 0644); err != nil {
				t.Fatal(err)
			}
			rc, err := LoadRateConfig(f)
			if err != nil {
				t.Fatal(err)
			}
			if !reflect.DeepEqual(rc, tt.wantRateConfig) {
				t.Errorf("rate config = %v want %v", rc, tt.wantRateConfig)
			}
		})
	}

	for _, tt := range []struct {
		name    string
		path    string
		content string // written to loaded path if non-empty; path used as-is if empty
	}{
		{"empty_path", "", ""},
		{"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), ""},
		{"invalid_json", "", "not json"},
	} {
		t.Run(tt.name, func(t *testing.T) {
			path := tt.path
			if tt.content != "" {
				path = filepath.Join(t.TempDir(), "rate.json")
				if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
					t.Fatal(err)
				}
			}
			_, err := LoadRateConfig(path)
			if err == nil {
				t.Fatal("expected error")
			}
		})
	}
}

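// TestLoadAndApplyRateConfig verifies that LoadAndApplyRateConfig applies the
// loaded config to the server and connected clients, clamps bursts to the
// minimum bucket size, disables limiting on an empty config, and surfaces
// load errors.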
func TestLoadAndApplyRateConfig(t *testing.T) {
	writeConfig := func(t *testing.T, json string) string {
		t.Helper()
		f := filepath.Join(t.TempDir(), "rate.json")
		if err := os.WriteFile(f, []byte(json), 0644); err != nil {
			t.Fatal(err)
		}
		return f
	}

	t.Run("applies_and_updates_clients", func(t *testing.T) {
		s := New(key.NewNode(), t.Logf)
		defer s.Close()

		clientKey := key.NewNode().Public()
		c := &sclient{key: clientKey, s: s, logf: logger.Discard}
		cs := &clientSet{}
		cs.activeClient.Store(c)
		s.mu.Lock()
		s.clients[clientKey] = cs
		s.mu.Unlock()

		f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d, "GlobalRateLimitBytesPerSec": %d, "GlobalRateBurstBytes": %d}`,
			minRateLimitTokenBucketSize, minRateLimitTokenBucketSize+1, minRateLimitTokenBucketSize+2, minRateLimitTokenBucketSize+3))
		if err := s.LoadAndApplyRateConfig(f); err != nil {
			t.Fatalf("LoadAndApplyRateConfig: %v", err)
		}

		// Verify server fields.
		wantRateConfig := RateConfig{
			PerClientRateLimitBytesPerSec: minRateLimitTokenBucketSize,
			PerClientRateBurstBytes:       minRateLimitTokenBucketSize + 1,
			GlobalRateLimitBytesPerSec:    minRateLimitTokenBucketSize + 2,
			GlobalRateBurstBytes:          minRateLimitTokenBucketSize + 3,
		}
		s.mu.Lock()
		if !reflect.DeepEqual(s.rateConfig, wantRateConfig) {
			t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, wantRateConfig)
		}
		s.mu.Unlock()

		// Verify client limiter.
		lim := c.recvLim.Load()
		if lim == nil {
			t.Fatal("expected non-nil limiter")
		}
		verifyLimiter(t, lim, wantRateConfig)
	})

	t.Run("burst_is_at_least_minRateLimitTokenBucketSize", func(t *testing.T) {
		s := New(key.NewNode(), t.Logf)
		defer s.Close()

		f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10, "GlobalRateLimitBytesPerSec": 1250000, "GlobalRateBurstBytes": 10}`)
		if err := s.LoadAndApplyRateConfig(f); err != nil {
			t.Fatalf("LoadAndApplyRateConfig: %v", err)
		}

		s.mu.Lock()
		gotClientBurst := s.rateConfig.PerClientRateBurstBytes
		gotGlobalBurst := s.rateConfig.GlobalRateBurstBytes
		s.mu.Unlock()
		if gotClientBurst != minRateLimitTokenBucketSize {
			t.Errorf("client burst = %d; want %d", gotClientBurst, minRateLimitTokenBucketSize)
		}
		if gotGlobalBurst != minRateLimitTokenBucketSize {
			t.Errorf("global burst = %d; want %d", gotGlobalBurst, minRateLimitTokenBucketSize)
		}
	})

	t.Run("reload_disables_limiting", func(t *testing.T) {
		s := New(key.NewNode(), t.Logf)
		defer s.Close()

		f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000, "GlobalRateLimitBytesPerSec": 12500000, "GlobalRateBurstBytes": 25000000}`)
		if err := s.LoadAndApplyRateConfig(f); err != nil {
			t.Fatal(err)
		}
		s.mu.Lock()
		if reflect.DeepEqual(s.rateConfig, RateConfig{}) {
			t.Error("s.rateConfig is zero val; want nonzero rates")
		}
		s.mu.Unlock()

		if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil {
			t.Fatal(err)
		}
		if err := s.LoadAndApplyRateConfig(f); err != nil {
			t.Fatal(err)
		}

		s.mu.Lock()
		if !reflect.DeepEqual(s.rateConfig, RateConfig{}) {
			t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, RateConfig{})
		}
		s.mu.Unlock()
	})

	t.Run("propagates_errors", func(t *testing.T) {
		s := New(key.NewNode(), t.Logf)
		defer s.Close()

		if err := s.LoadAndApplyRateConfig(filepath.Join(t.TempDir(), "nonexistent.json")); err == nil {
			t.Fatal("expected error")
		}
	})
}

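// BenchmarkSenderCardinalityOverhead compares inserting the serialized sender
// key into a HyperLogLog sketch against serializing the key alone.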
func BenchmarkSenderCardinalityOverhead(b *testing.B) {
	hll := hyperloglog.New()
	sender := key.NewNode().Public()

	b.Run("WithTracking", func(b *testing.B) {
		b.ReportAllocs()
		for i := 0; i < b.N; i++ {
			if hll != nil {
				hll.Insert(sender.AppendTo(nil))
			}
		}
	})

	b.Run("WithoutTracking", func(b *testing.B) {
		b.ReportAllocs()
		for i := 0; i < b.N; i++ {
			_ = sender.AppendTo(nil)
		}
	})
}