diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 429aff361..7bca26a6d 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -87,8 +87,7 @@ var ( acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") - perClientRateLimit = flag.Uint("per-client-rate-limit", 0, "per-client receive rate limit in bytes/sec; 0 means unlimited. Mesh peers are exempt.") - perClientRateBurst = flag.Uint("per-client-rate-burst", 0, "per-client receive rate burst in bytes; 0 defaults to 2x the rate limit (only relevant when using nonzero --per-client-rate-limit)") + rateConfigPath = flag.String("rate-config", "", "path to JSON rate limit config file; reloaded on SIGHUP") // tcpKeepAlive is intentionally long, to reduce battery cost. There is an L7 keepalive on a higher frequency schedule. tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time") @@ -195,12 +194,11 @@ func main() { s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) s.SetTCPWriteTimeout(*tcpWriteTimeout) - if *perClientRateLimit > 0 { - burst := *perClientRateBurst - if burst < 1 { - burst = *perClientRateLimit * 2 + if *rateConfigPath != "" { + if err := s.LoadAndApplyRateConfig(*rateConfigPath); err != nil { + log.Fatalf("derper: loading rate config: %v", err) } - s.SetPerClientRateLimit(*perClientRateLimit, burst) + go watchRateConfig(ctx, s, *rateConfigPath) } var meshKey string @@ -436,6 +434,27 @@ func main() { } } +// watchRateConfig listens for SIGHUP signals and reloads the rate config +// file on each signal, applying it to the server. It returns when ctx is done. 
+func watchRateConfig(ctx context.Context, s *derpserver.Server, path string) { + sighup := make(chan os.Signal, 1) + signal.Notify(sighup, syscall.SIGHUP) + defer signal.Stop(sighup) + for { + select { + case <-ctx.Done(): + return + case <-sighup: + log.Printf("derper: received SIGHUP, reloading rate config from %s", path) + if err := s.LoadAndApplyRateConfig(path); err != nil { + log.Printf("derper: rate config reload failed: %v", err) + continue + } + log.Printf("derper: rate config reloaded successfully") + } + } +} + var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`) func prodAutocertHostPolicy(_ context.Context, host string) error { diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index ae8e9d433..8a393d011 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -209,11 +209,11 @@ type Server struct { // perClientRecvBytesPerSec is the rate limit for receiving data from // a single client connection, in bytes per second. 0 means unlimited. // Mesh peers are exempt from this limit. - perClientRecvBytesPerSec uint + perClientRecvBytesPerSec uint64 // perClientRecvBurst is the burst size in bytes for the per-client - // receive rate limiter. perClientRecvBurst is only relevant when - // perClientRecvBytesPerSec is nonzero. - perClientRecvBurst uint + // receive rate limiter. Always at least [derp.MaxPacketSize] when + // set via [Server.UpdatePerClientRateLimit]. + perClientRecvBurst uint64 clock tstime.Clock } @@ -239,10 +239,6 @@ type clientSet struct { // activeClient holds the currently active connection for the set. It's nil // if there are no connections or the connection is disabled. // - // A pointer to a clientSet can be held by peers for long periods of time - // without holding Server.mu to avoid mutex contention on Server.mu, only - // re-acquiring the mutex and checking the clients map if activeClient is - // nil. 
activeClient atomic.Pointer[sclient] // dup is non-nil if there are multiple connections for the @@ -518,14 +514,62 @@ func (s *Server) SetTCPWriteTimeout(d time.Duration) { s.tcpWriteTimeout = d } -// SetPerClientRateLimit sets the per-client receive rate limit in bytes per -// second and the burst size in bytes. Mesh peers are exempt from this limit. -// The burst is at least [derp.MaxPacketSize], or burst, if burst is greater -// than [derp.MaxPacketSize]. This ensures at least a full packet can -// be received in a burst, even if the rate limit is low. -func (s *Server) SetPerClientRateLimit(bytesPerSec, burst uint) { +// RateConfig is a JSON-serializable configuration for per-client rate limits. +// Values are in bytes. +type RateConfig struct { + // PerClientRateLimitBytesPerSec represents the per-client + // rate limit in bytes per second. A zero value disables rate-limiting. + PerClientRateLimitBytesPerSec uint64 `json:",omitzero"` + // PerClientRateBurstBytes represents the per-client token bucket depth, + // or burst, in bytes. Any value lower than [derp.MaxPacketSize] + // will be increased to [derp.MaxPacketSize] before application. + PerClientRateBurstBytes uint64 `json:",omitzero"` +} + +// LoadRateConfig reads and JSON-unmarshals a [RateConfig] from the file at path. +func LoadRateConfig(path string) (RateConfig, error) { + if path == "" { + return RateConfig{}, errors.New("rate config path is empty") + } + b, err := os.ReadFile(path) + if err != nil { + return RateConfig{}, fmt.Errorf("reading rate config: %w", err) + } + var rc RateConfig + if err := json.Unmarshal(b, &rc); err != nil { + return RateConfig{}, fmt.Errorf("parsing rate config: %w", err) + } + return rc, nil +} + +// LoadAndApplyRateConfig reads a [RateConfig] from the file at path and +// applies it to the server via [Server.UpdatePerClientRateLimit]. 
+func (s *Server) LoadAndApplyRateConfig(path string) error { + rc, err := LoadRateConfig(path) + if err != nil { + return err + } + s.UpdatePerClientRateLimit(rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) + s.logf("rate config applied: rate=%d bytes/sec, burst=%d bytes", rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) + return nil +} + +// UpdatePerClientRateLimit sets the per-client receive rate limit in bytes per +// second and the burst size in bytes, updating all existing client connections. +// The burst is at least [derp.MaxPacketSize], ensuring at least a full packet +// can be received in a burst even if the rate limit is low. If bytesPerSec is +// 0, rate limiting is set to infinity. Mesh peers are always exempt from rate +// limiting. +func (s *Server) UpdatePerClientRateLimit(bytesPerSec, burst uint64) { + s.mu.Lock() + defer s.mu.Unlock() s.perClientRecvBytesPerSec = bytesPerSec s.perClientRecvBurst = max(burst, derp.MaxPacketSize) + for _, cs := range s.clients { + cs.ForeachClient(func(c *sclient) { + c.setRateLimit(s.perClientRecvBytesPerSec, s.perClientRecvBurst) + }) + } } // HasMeshKey reports whether the server is configured with a mesh key. 
@@ -690,6 +734,8 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() + c.setRateLimit(s.perClientRecvBytesPerSec, s.perClientRecvBurst) + cs, ok := s.clients[c.key] if !ok { c.debugLogf("register single client") @@ -975,9 +1021,6 @@ func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3), } - if s.perClientRecvBytesPerSec > 0 && !c.canMesh { - c.recvLim = xrate.NewLimiter(xrate.Limit(s.perClientRecvBytesPerSec), int(s.perClientRecvBurst)) - } if c.canMesh { c.meshUpdate = make(chan struct{}, 1) // must be buffered; >1 is fine but wasteful } @@ -1267,19 +1310,38 @@ func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { return c.sendPkt(dst, p) } -// rateLimit applies the per-client receive rate limit, if configured. +// setRateLimit updates the per-client receive rate limiter. +// When bytesPerSec is 0 or the client is a mesh peer, the limiter is +// set to nil so that [sclient.rateLimit] is a no-op. +func (c *sclient) setRateLimit(bytesPerSec uint64, burst uint64) { + if bytesPerSec == 0 || c.canMesh { + c.recvLim.Store(nil) + return + } + if lim := c.recvLim.Load(); lim != nil { + // Update in place. SetBurst before SetLimit to avoid a transient + // state where a new higher rate exceeds the old lower burst. + lim.SetBurst(int(burst)) + lim.SetLimit(xrate.Limit(bytesPerSec)) + return + } + lim := xrate.NewLimiter(xrate.Limit(bytesPerSec), int(burst)) + c.recvLim.Store(lim) +} + +// rateLimit applies the per-client receive rate limit. // By limiting here we prevent reading from the buffered reader // [sclient.br] if the limit has been exceeded. Any reads done here provide space // within the buffered reader to fill back in with data from // the TCP socket. Pacing reads acts as a form of natural // backpressure via TCP flow control. -// meshed clients are exempt from rate limits. 
+// When rate limiting is disabled or the client is a mesh peer, recvLim is nil + // and this is a no-op. func (c *sclient) rateLimit(n int) error { - if c.recvLim == nil || c.canMesh { - return nil + if lim := c.recvLim.Load(); lim != nil { + return lim.WaitN(c.ctx, n) } - - return c.recvLim.WaitN(c.ctx, n) + return nil } func (c *sclient) debugLogf(format string, v ...any) { @@ -1714,10 +1776,15 @@ type sclient struct { // through us with a peer we have no record of. peerGoneLim *rate.Limiter - // recvLim is the per-connection receive rate limiter. If non-nil, - // the server calls WaitN per received DERP frame in order to - // apply TCP backpressure and throttle the sender. - recvLim *xrate.Limiter + // recvLim is the per-connection receive rate limiter. When rate + // limiting is enabled for a non-mesh client, it points to an + // [xrate.Limiter]. When rate limiting is disabled or the client is a + // mesh peer, it is nil and [sclient.rateLimit] is a no-op. + // Updated atomically by [sclient.setRateLimit] so that + // [sclient.rateLimit] can load it without holding Server.mu. 
+ // TODO(mikeodr): update to use mono time, requires updates + // to tstime/rate.Limiter + recvLim atomic.Pointer[xrate.Limiter] } func (c *sclient) presentFlags() derp.PeerPresentFlags { diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index 3fb4b838e..b2ecea29f 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -15,6 +15,7 @@ import ( "log" "net" "os" + "path/filepath" "reflect" "strconv" "sync" @@ -30,6 +31,7 @@ import ( "tailscale.com/derp/derpconst" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/set" ) const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" @@ -965,9 +967,9 @@ func TestPerClientRateLimit(t *testing.T) { t.Cleanup(cancel) c := &sclient{ - ctx: ctx, - recvLim: rate.NewLimiter(rate.Limit(bytesPerSec), burst), + ctx: ctx, } + c.recvLim.Store(rate.NewLimiter(rate.Limit(bytesPerSec), burst)) // First call within burst should not block. c.rateLimit(burst) @@ -1006,9 +1008,9 @@ func TestPerClientRateLimit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) c := &sclient{ - ctx: ctx, - recvLim: rate.NewLimiter(rate.Limit(100), 100), + ctx: ctx, } + c.recvLim.Store(rate.NewLimiter(rate.Limit(100), 100)) // Exhaust burst. if err := c.rateLimit(100); err != nil { @@ -1040,32 +1042,17 @@ func TestPerClientRateLimit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + // Mesh peers have nil recvLim, so rate limiting is a no-op. c := &sclient{ ctx: ctx, canMesh: true, - recvLim: rate.NewLimiter(rate.Limit(1), 1), // would block immediately if not exempt } - // rateLimit should be a no-op for mesh peers. 
if err := c.rateLimit(1000); err != nil { t.Fatalf("mesh peer rateLimit should be no-op: %v", err) } }) - t.Run("nil_limiter_no_op", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - c := &sclient{ - ctx: ctx, - } - - // rateLimit with nil recvLim should be a no-op. - if err := c.rateLimit(1000); err != nil { - t.Fatalf("nil limiter rateLimit should be no-op: %v", err) - } - }) - t.Run("zero_config_no_limiter", func(t *testing.T) { s := New(key.NewNode(), logger.Discard) defer s.Close() @@ -1075,6 +1062,293 @@ func TestPerClientRateLimit(t *testing.T) { }) } +func TestUpdatePerClientRateLimit(t *testing.T) { + const ( + testBurst1 = derp.MaxPacketSize * 2 + testRate1 = 1000 + testBurst2 = derp.MaxPacketSize * 4 + testRate2 = 5000 + ) + + s := New(key.NewNode(), t.Logf) + defer s.Close() + + // Create a non-mesh client with no initial limiter. + clientKey := key.NewNode().Public() + c := &sclient{ + key: clientKey, + s: s, + logf: logger.Discard, + canMesh: false, + } + cs := &clientSet{} + cs.activeClient.Store(c) + + s.mu.Lock() + s.clients[clientKey] = cs + s.mu.Unlock() + + s.UpdatePerClientRateLimit(testRate1, testBurst1) + + lim := c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter after update") + } + if got := lim.Limit(); got != rate.Limit(testRate1) { + t.Errorf("rate limit = %v; want %d", got, testRate1) + } + if got := lim.Burst(); got != int(testBurst1) { + t.Errorf("burst = %v; want %d", got, testBurst1) + } + + // Verify server fields updated. + s.mu.Lock() + if s.perClientRecvBytesPerSec != testRate1 { + t.Errorf("server rate = %d; want %d", s.perClientRecvBytesPerSec, testRate1) + } + if s.perClientRecvBurst != testBurst1 { + t.Errorf("server burst = %d; want %d", s.perClientRecvBurst, testBurst1) + } + s.mu.Unlock() + + // Update again with different nonzero values. This exercises the + // in-place update path (existing limiter is reused, not recreated). 
+ prevLim := c.recvLim.Load() + s.UpdatePerClientRateLimit(testRate2, testBurst2) + lim = c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter after in-place update") + } + if lim != prevLim { + t.Error("expected same limiter pointer after in-place update") + } + if got := lim.Limit(); got != rate.Limit(testRate2) { + t.Errorf("rate limit after in-place update = %v; want %d", got, testRate2) + } + if got := lim.Burst(); got != int(testBurst2) { + t.Errorf("burst after in-place update = %v; want %d", got, testBurst2) + } + + // Disable rate limiting (set to 0). + s.UpdatePerClientRateLimit(0, 0) + + if got := c.recvLim.Load(); got != nil { + t.Errorf("expected nil limiter after disable, got limit=%v", got.Limit()) + } + + // Mesh peer should always have nil limiter regardless of update. + meshKey := key.NewNode().Public() + meshClient := &sclient{ + key: meshKey, + s: s, + logf: logger.Discard, + canMesh: true, + } + meshCS := &clientSet{} + meshCS.activeClient.Store(meshClient) + + s.mu.Lock() + s.clients[meshKey] = meshCS + s.mu.Unlock() + + s.UpdatePerClientRateLimit(testRate2, testBurst2) + + if got := meshClient.recvLim.Load(); got != nil { + t.Errorf("mesh peer should have nil limiter, got limit=%v", got.Limit()) + } + // Non-mesh client should be updated. + lim = c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter for non-mesh client") + } + if got := lim.Limit(); got != rate.Limit(testRate2) { + t.Errorf("rate limit = %v; want %d", got, testRate2) + } + if got := lim.Burst(); got != int(testBurst2) { + t.Errorf("burst = %v; want %d", got, testBurst2) + } + + // Verify dup clients are also updated. 
+ dupKey := key.NewNode().Public() + d1 := &sclient{key: dupKey, s: s, logf: logger.Discard} + d2 := &sclient{key: dupKey, s: s, logf: logger.Discard} + dupCS := &clientSet{} + dupCS.activeClient.Store(d1) + dupCS.dup = &dupClientSet{set: set.Of(d1, d2)} + s.mu.Lock() + s.clients[dupKey] = dupCS + s.mu.Unlock() + + s.UpdatePerClientRateLimit(testRate1, testBurst1) + for i, d := range []*sclient{d1, d2} { + dl := d.recvLim.Load() + if dl == nil { + t.Fatalf("dup client %d: expected non-nil limiter", i) + } + if got := dl.Limit(); got != rate.Limit(testRate1) { + t.Errorf("dup client %d: rate = %v; want %d", i, got, testRate1) + } + if got := dl.Burst(); got != int(testBurst1) { + t.Errorf("dup client %d: burst = %v; want %d", i, got, testBurst1) + } + } +} + +func TestLoadRateConfig(t *testing.T) { + for _, tt := range []struct { + name string + json string + wantRate uint64 + wantBurst uint64 + }{ + {"both_set", `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`, 1250000, 2500000}, + {"rate_only", `{"PerClientRateLimitBytesPerSec": 500000}`, 500000, 0}, + {"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0}`, 0, 0}, + {"empty_json", `{}`, 0, 0}, + } { + t.Run(tt.name, func(t *testing.T) { + f := filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(f, []byte(tt.json), 0644); err != nil { + t.Fatal(err) + } + rc, err := LoadRateConfig(f) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rc.PerClientRateLimitBytesPerSec != tt.wantRate { + t.Errorf("rate = %d; want %d", rc.PerClientRateLimitBytesPerSec, tt.wantRate) + } + if rc.PerClientRateBurstBytes != tt.wantBurst { + t.Errorf("burst = %d; want %d", rc.PerClientRateBurstBytes, tt.wantBurst) + } + }) + } + + for _, tt := range []struct { + name string + path string + content string // written to path if non-empty; path used as-is if empty + }{ + {"empty_path", "", ""}, + {"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), 
""}, + {"invalid_json", "", "not json"}, + } { + t.Run(tt.name, func(t *testing.T) { + path := tt.path + if tt.content != "" { + path = filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil { + t.Fatal(err) + } + } + _, err := LoadRateConfig(path) + if err == nil { + t.Fatal("expected error") + } + }) + } +} + +func TestLoadAndApplyRateConfig(t *testing.T) { + writeConfig := func(t *testing.T, json string) string { + t.Helper() + f := filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(f, []byte(json), 0644); err != nil { + t.Fatal(err) + } + return f + } + + t.Run("applies_and_updates_clients", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + clientKey := key.NewNode().Public() + c := &sclient{key: clientKey, s: s, logf: logger.Discard} + cs := &clientSet{} + cs.activeClient.Store(c) + s.mu.Lock() + s.clients[clientKey] = cs + s.mu.Unlock() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatalf("LoadAndApplyRateConfig: %v", err) + } + + // Verify server fields. + s.mu.Lock() + gotRate := s.perClientRecvBytesPerSec + gotBurst := s.perClientRecvBurst + s.mu.Unlock() + if gotRate != 1250000 { + t.Errorf("server rate = %d; want 1250000", gotRate) + } + if gotBurst != 2500000 { + t.Errorf("server burst = %d; want 2500000", gotBurst) + } + + // Verify client limiter. 
+ lim := c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter") + } + if got := lim.Limit(); got != rate.Limit(1250000) { + t.Errorf("client rate = %v; want 1250000", got) + } + }) + + t.Run("burst_is_at_least_max_packet_size", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatalf("LoadAndApplyRateConfig: %v", err) + } + + s.mu.Lock() + gotBurst := s.perClientRecvBurst + s.mu.Unlock() + if gotBurst != derp.MaxPacketSize { + t.Errorf("burst = %d; want at least %d", gotBurst, derp.MaxPacketSize) + } + }) + + t.Run("reload_disables_limiting", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatal(err) + } + + if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil { + t.Fatal(err) + } + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatal(err) + } + + s.mu.Lock() + gotRate := s.perClientRecvBytesPerSec + s.mu.Unlock() + if gotRate != 0 { + t.Errorf("rate = %d; want 0 (unlimited)", gotRate) + } + }) + + t.Run("propagates_errors", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + if err := s.LoadAndApplyRateConfig(filepath.Join(t.TempDir(), "nonexistent.json")); err == nil { + t.Fatal("expected error") + } + }) +} + func BenchmarkSenderCardinalityOverhead(b *testing.B) { hll := hyperloglog.New() sender := key.NewNode().Public()