diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 87f9a0bc0..429aff361 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -87,6 +87,9 @@ var ( acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") + perClientRateLimit = flag.Uint("per-client-rate-limit", 0, "per-client receive rate limit in bytes/sec; 0 means unlimited. Mesh peers are exempt.") + perClientRateBurst = flag.Uint("per-client-rate-burst", 0, "per-client receive rate burst in bytes; 0 defaults to 2x the rate limit (only relevant when using nonzero --per-client-rate-limit)") + // tcpKeepAlive is intentionally long, to reduce battery cost. There is an L7 keepalive on a higher frequency schedule. tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time") // tcpUserTimeout is intentionally short, so that hung connections are cleaned up promptly. DERPs should be nearby users. 
@@ -192,6 +195,13 @@ func main() { s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) s.SetTCPWriteTimeout(*tcpWriteTimeout) + if *perClientRateLimit > 0 { + burst := *perClientRateBurst + if burst < 1 { + burst = *perClientRateLimit * 2 + } + s.SetPerClientRateLimit(*perClientRateLimit, burst) + } var meshKey string if *dev { diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index 0959a4729..ea6b0e99a 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -40,6 +40,7 @@ import ( "github.com/axiomhq/hyperloglog" "go4.org/mem" "golang.org/x/sync/errgroup" + xrate "golang.org/x/time/rate" "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derpconst" @@ -205,6 +206,15 @@ type Server struct { tcpWriteTimeout time.Duration + // perClientRecvBytesPerSec is the rate limit for receiving data from + // a single client connection, in bytes per second. 0 means unlimited. + // Mesh peers are exempt from this limit. + perClientRecvBytesPerSec uint + // perClientRecvBurst is the burst size in bytes for the per-client + // receive rate limiter. perClientRecvBurst is only relevant when + // perClientRecvBytesPerSec is nonzero. + perClientRecvBurst uint + clock tstime.Clock } @@ -508,6 +518,16 @@ func (s *Server) SetTCPWriteTimeout(d time.Duration) { s.tcpWriteTimeout = d } +// SetPerClientRateLimit sets the per-client receive rate limit in bytes per +// second and the burst size in bytes. Mesh peers are exempt from this limit. +// The effective burst is max(burst, [derp.MaxPacketSize]), which ensures +// that at least one full packet can be received in a single burst, +// even when the configured rate limit is low. +func (s *Server) SetPerClientRateLimit(bytesPerSec, burst uint) { + s.perClientRecvBytesPerSec = bytesPerSec + s.perClientRecvBurst = max(burst, derp.MaxPacketSize) +} + // HasMeshKey reports whether the server is configured with a mesh key. 
func (s *Server) HasMeshKey() bool { return !s.meshKey.IsZero() } @@ -943,7 +963,7 @@ func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter br: br, bw: bw, logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v%s: ", remoteAddr, clientKey.ShortString())), - done: ctx.Done(), + ctx: ctx, remoteIPPort: remoteIPPort, connectedAt: s.clock.Now(), sendQueue: make(chan pkt, s.perClientSendQueueDepth), @@ -955,6 +975,9 @@ func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3), } + if s.perClientRecvBytesPerSec > 0 && !c.canMesh { + c.recvLim = xrate.NewLimiter(xrate.Limit(s.perClientRecvBytesPerSec), int(s.perClientRecvBurst)) + } if c.canMesh { c.meshUpdate = make(chan struct{}, 1) // must be buffered; >1 is fine but wasteful } @@ -1027,6 +1050,14 @@ func (c *sclient) run(ctx context.Context) error { } return fmt.Errorf("client %s: readFrameHeader: %w", c.key.ShortString(), err) } + // Rate limit by DERP frame length (fl), which excludes DERP + // and TLS protocol overheads. Mesh clients are exempt; + // rateLimit is a no-op for them. Blocking here paces reads + // and applies TCP backpressure to an over-limit sender. + if err := c.rateLimit(int(fl)); err != nil { + return err // context canceled, connection closing + } + c.s.noteClientActivity(c) switch ft { case derp.FrameNotePreferred: @@ -1096,6 +1127,7 @@ func (c *sclient) handleFramePing(ft derp.FrameType, fl uint32) error { if extra := int64(fl) - int64(len(m)); extra > 0 { _, err = io.CopyN(io.Discard, c.br, extra) } + select { case c.sendPongCh <- [8]byte(m): default: @@ -1139,7 +1171,7 @@ func (c *sclient) handleFrameClosePeer(ft derp.FrameType, fl uint32) error { // handleFrameForwardPacket reads a "forward packet" frame from the client // (which must be a trusted client, a peer in our mesh). 
-func (c *sclient) handleFrameForwardPacket(ft derp.FrameType, fl uint32) error { +func (c *sclient) handleFrameForwardPacket(_ derp.FrameType, fl uint32) error { if !c.canMesh { return fmt.Errorf("insufficient permissions") } @@ -1182,7 +1214,7 @@ func (c *sclient) handleFrameForwardPacket(ft derp.FrameType, fl uint32) error { } // handleFrameSendPacket reads a "send packet" frame from the client. -func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error { +func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { s := c.s dstKey, contents, err := s.recvPacket(c.br, fl) @@ -1235,6 +1267,21 @@ func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error { return c.sendPkt(dst, p) } +// rateLimit applies the per-client receive rate limit, if configured. +// By waiting here, before reading from the buffered reader +// [sclient.br], we stop draining it once the limit is exceeded, +// leaving it to fill back up with data from the TCP socket. +// Pacing reads this way acts as a form of natural +// backpressure via TCP flow control. +// Meshed clients are exempt from rate limits. +func (c *sclient) rateLimit(n int) error { + if c.recvLim == nil || c.canMesh { + return nil + } + + return c.recvLim.WaitN(c.ctx, n) +} + func (c *sclient) debugLogf(format string, v ...any) { if c.debug { c.logf(format, v...) 
@@ -1296,7 +1343,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { } for attempt := range 3 { select { - case <-dst.done: + case <-dst.ctx.Done(): s.recordDrop(p.bs, c.key, dstKey, dropReasonGoneDisconnected) dst.debugLogf("sendPkt attempt %d dropped, dst gone", attempt) return nil @@ -1341,7 +1388,7 @@ func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason derp.PeerGone peer: peer, reason: reason, }: - case <-c.done: + case <-c.ctx.Done(): } } @@ -1626,7 +1673,7 @@ type sclient struct { key key.NodePublic info derp.ClientInfo logf logger.Logf - done <-chan struct{} // closed when connection closes + ctx context.Context // canceled when the connection closes remoteIPPort netip.AddrPort // zero if remoteAddr is not ip:port. sendQueue chan pkt // packets queued to this client; never closed discoSendQueue chan pkt // important packets queued to this client; never closed @@ -1666,6 +1713,11 @@ type sclient struct { // client that it's trying to establish a direct connection // through us with a peer we have no record of. peerGoneLim *rate.Limiter + + // recvLim is the per-connection receive rate limiter. If non-nil, + // the server calls WaitN per received DERP frame in order to + // apply TCP backpressure and throttle the sender. + recvLim *xrate.Limiter } func (c *sclient) presentFlags() derp.PeerPresentFlags { diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index 7f956ba78..3fb4b838e 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -19,6 +19,7 @@ import ( "strconv" "sync" "testing" + "testing/synctest" "time" "github.com/axiomhq/hyperloglog" @@ -953,6 +954,127 @@ func BenchmarkHyperLogLogEstimate(b *testing.B) { } } +func TestPerClientRateLimit(t *testing.T) { + t.Run("throttled", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // 100 bytes/sec with a burst of 100 bytes. 
+ const bytesPerSec = 100 + const burst = 100 + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + c := &sclient{ + ctx: ctx, + recvLim: rate.NewLimiter(rate.Limit(bytesPerSec), burst), + } + + // First call within burst should not block. + c.rateLimit(burst) + + // Next call exceeds burst, should block until tokens replenish. + done := make(chan error, 1) + go func() { + done <- c.rateLimit(burst) + }() + + // After settling, the goroutine should be blocked (no result yet). + synctest.Wait() + select { + case err := <-done: + t.Fatalf("rateLimit should have blocked, but returned: %v", err) + default: + } + + // Advance time by 1 second; 100 bytes/sec * 1s = 100 bytes = burst. + time.Sleep(1 * time.Second) + synctest.Wait() + + select { + case err := <-done: + if err != nil { + t.Fatalf("rateLimit after time advance: %v", err) + } + default: + t.Fatal("rateLimit should have unblocked after 1s") + } + }) + }) + + t.Run("context_canceled", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + c := &sclient{ + ctx: ctx, + recvLim: rate.NewLimiter(rate.Limit(100), 100), + } + + // Exhaust burst. + if err := c.rateLimit(100); err != nil { + t.Fatalf("rateLimit: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- c.rateLimit(100) + }() + synctest.Wait() + + // Cancel the context; the blocked rateLimit should return an error. 
+ cancel() + synctest.Wait() + + select { + case err := <-done: + if err == nil { + t.Fatal("expected error from canceled context") + } + default: + t.Fatal("rateLimit should have returned after context cancelation") + } + }) + }) + + t.Run("mesh_peer_exempt", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + c := &sclient{ + ctx: ctx, + canMesh: true, + recvLim: rate.NewLimiter(rate.Limit(1), 1), // would block immediately if not exempt + } + + // rateLimit should be a no-op for mesh peers. + if err := c.rateLimit(1000); err != nil { + t.Fatalf("mesh peer rateLimit should be no-op: %v", err) + } + }) + + t.Run("nil_limiter_no_op", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + c := &sclient{ + ctx: ctx, + } + + // rateLimit with nil recvLim should be a no-op. + if err := c.rateLimit(1000); err != nil { + t.Fatalf("nil limiter rateLimit should be no-op: %v", err) + } + }) + + t.Run("zero_config_no_limiter", func(t *testing.T) { + s := New(key.NewNode(), logger.Discard) + defer s.Close() + if s.perClientRecvBytesPerSec != 0 { + t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec) + } + }) +} + func BenchmarkSenderCardinalityOverhead(b *testing.B) { hll := hyperloglog.New() sender := key.NewNode().Public()