diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go
index 87f9a0bc0..767afac6b 100644
--- a/cmd/derper/derper.go
+++ b/cmd/derper/derper.go
@@ -87,6 +87,9 @@ var (
 	acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection")
 	acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection")
+	perClientRateLimit = flag.Int("per-client-rate-limit", 0, "per-client receive rate limit in bytes/sec; 0 means unlimited. Mesh peers are exempt.")
+	perClientRateBurst = flag.Int("per-client-rate-burst", 0, "per-client receive rate burst in bytes; 0 defaults to 2x the rate limit")
+
 	// tcpKeepAlive is intentionally long, to reduce battery cost. There is an L7 keepalive on a higher frequency schedule.
 	tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time")
 	// tcpUserTimeout is intentionally short, so that hung connections are cleaned up promptly. DERPs should be nearby users.
@@ -192,6 +195,13 @@ func main() {
 	s.SetVerifyClientURL(*verifyClientURL)
 	s.SetVerifyClientURLFailOpen(*verifyFailOpen)
 	s.SetTCPWriteTimeout(*tcpWriteTimeout)
+	if *perClientRateLimit > 0 {
+		burst := *perClientRateBurst
+		if burst == 0 {
+			burst = *perClientRateLimit * 2
+		}
+		s.SetPerClientRateLimit(*perClientRateLimit, burst)
+	}
 
 	var meshKey string
 	if *dev {
diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go
index 0959a4729..eab47f8e9 100644
--- a/derp/derpserver/derpserver.go
+++ b/derp/derpserver/derpserver.go
@@ -40,6 +40,7 @@ import (
 	"github.com/axiomhq/hyperloglog"
 	"go4.org/mem"
 	"golang.org/x/sync/errgroup"
+	xrate "golang.org/x/time/rate"
 	"tailscale.com/client/local"
 	"tailscale.com/derp"
 	"tailscale.com/derp/derpconst"
@@ -205,6 +206,14 @@ type Server struct {
 	tcpWriteTimeout time.Duration
 
+	// perClientRecvBytesPerSec is the rate limit for receiving data from
+	// a single client connection, in bytes per second. 0 means unlimited.
+	// Mesh peers are exempt from this limit.
+	perClientRecvBytesPerSec int
+	// perClientRecvBurst is the burst size in bytes for the per-client
+	// receive rate limiter.
+	perClientRecvBurst int
+
 	clock tstime.Clock
 }
@@ -508,6 +517,15 @@ func (s *Server) SetTCPWriteTimeout(d time.Duration) {
 	s.tcpWriteTimeout = d
 }
 
+// SetPerClientRateLimit sets the per-client receive rate limit in bytes per
+// second and the burst size in bytes. Mesh peers are exempt from this limit.
+// The burst is clamped to at least derp.MaxPacketSize to ensure a single
+// max-size frame can always be processed.
+func (s *Server) SetPerClientRateLimit(bytesPerSec, burst int) {
+	s.perClientRecvBytesPerSec = bytesPerSec
+	s.perClientRecvBurst = max(burst, int(derp.MaxPacketSize))
+}
+
 // HasMeshKey reports whether the server is configured with a mesh key.
 func (s *Server) HasMeshKey() bool { return !s.meshKey.IsZero() }
 
@@ -943,7 +961,7 @@ func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter
 		br:           br,
 		bw:           bw,
 		logf:         logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v%s: ", remoteAddr, clientKey.ShortString())),
-		done:         ctx.Done(),
+		ctx:          ctx,
 		remoteIPPort: remoteIPPort,
 		connectedAt:  s.clock.Now(),
 		sendQueue:    make(chan pkt, s.perClientSendQueueDepth),
@@ -955,6 +973,9 @@
 		peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3),
 	}
+	if s.perClientRecvBytesPerSec > 0 && !c.canMesh {
+		c.recvLim = xrate.NewLimiter(xrate.Limit(s.perClientRecvBytesPerSec), s.perClientRecvBurst)
+	}
 	if c.canMesh {
 		c.meshUpdate = make(chan struct{}, 1) // must be buffered; >1 is fine but wasteful
 	}
@@ -1190,6 +1211,17 @@ func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error {
 		return fmt.Errorf("client %v: recvPacket: %v", c.key, err)
 	}
 
+	// Rate limit non-DISCO packets via TCP backpressure. By blocking
+	// here, we delay reading the next frame, causing the TCP receive
+	// buffer to fill and the TCP window to shrink, which throttles the
+	// sender. DISCO frames are exempt because they are small control
+	// messages critical for direct connection establishment.
+	if c.recvLim != nil && !disco.LooksLikeDiscoWrapper(contents) {
+		if err := c.recvLim.WaitN(c.ctx, len(contents)); err != nil {
+			return nil // context canceled, connection closing
+		}
+	}
+
 	var fwd PacketForwarder
 	var dstLen int
 	var dst *sclient
@@ -1296,7 +1328,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error {
 	}
 	for attempt := range 3 {
 		select {
-		case <-dst.done:
+		case <-dst.ctx.Done():
			s.recordDrop(p.bs, c.key, dstKey, dropReasonGoneDisconnected)
			dst.debugLogf("sendPkt attempt %d dropped, dst gone", attempt)
			return nil
@@ -1341,7 +1373,7 @@ func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason derp.PeerGone
 		peer:   peer,
 		reason: reason,
 	}:
-	case <-c.done:
+	case <-c.ctx.Done():
 	}
 }
 
@@ -1508,7 +1540,12 @@ func (s *Server) noteClientActivity(c *sclient) {
 type ServerInfo = derp.ServerInfo
 
 func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) error {
-	msg, err := json.Marshal(ServerInfo{Version: derp.ProtocolVersion})
+	si := ServerInfo{Version: derp.ProtocolVersion}
+	if s.perClientRecvBytesPerSec > 0 {
+		si.TokenBucketBytesPerSecond = s.perClientRecvBytesPerSec
+		si.TokenBucketBytesBurst = s.perClientRecvBurst
+	}
+	msg, err := json.Marshal(si)
 	if err != nil {
 		return err
 	}
@@ -1626,7 +1663,7 @@ type sclient struct {
 	key            key.NodePublic
 	info           derp.ClientInfo
 	logf           logger.Logf
-	done           <-chan struct{} // closed when connection closes
+	ctx            context.Context // canceled when the connection closes
 	remoteIPPort   netip.AddrPort  // zero if remoteAddr is not ip:port.
 	sendQueue      chan pkt        // packets queued to this client; never closed
 	discoSendQueue chan pkt        // important packets queued to this client; never closed
@@ -1666,6 +1703,11 @@
 	// client that it's trying to establish a direct connection
 	// through us with a peer we have no record of.
 	peerGoneLim *rate.Limiter
+
+	// recvLim is the per-connection receive rate limiter. If non-nil,
+	// the server calls WaitN after reading non-DISCO data frames to
+	// apply TCP backpressure and throttle the sender.
+	recvLim *xrate.Limiter
 }
 
 func (c *sclient) presentFlags() derp.PeerPresentFlags {
diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go
index 7f956ba78..75a135b0e 100644
--- a/derp/derpserver/derpserver_test.go
+++ b/derp/derpserver/derpserver_test.go
@@ -953,6 +953,171 @@ func BenchmarkHyperLogLogEstimate(b *testing.B) {
 	}
 }
 
+func TestPerClientRateLimit(t *testing.T) {
+	// newServer creates a DERP server with a listener and returns a client factory.
+	newServer := func(t *testing.T, bytesPerSec, burst int) (*Server, func(t *testing.T) *derp.Client) {
+		t.Helper()
+		serverPriv := key.NewNode()
+		s := New(serverPriv, logger.Discard)
+		if bytesPerSec > 0 {
+			s.SetPerClientRateLimit(bytesPerSec, burst)
+		}
+		t.Cleanup(func() { s.Close() })
+
+		ln, err := net.Listen("tcp", "127.0.0.1:0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		t.Cleanup(func() { ln.Close() })
+
+		ctx, cancel := context.WithCancel(context.Background())
+		t.Cleanup(cancel)
+
+		go func() {
+			for {
+				conn, err := ln.Accept()
+				if err != nil {
+					return
+				}
+				brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
+				go s.Accept(ctx, conn, brw, "test-client")
+			}
+		}()
+
+		newClient := func(t *testing.T) *derp.Client {
+			t.Helper()
+			conn, err := net.Dial("tcp", ln.Addr().String())
+			if err != nil {
+				t.Fatal(err)
+			}
+			t.Cleanup(func() { conn.Close() })
+			k := key.NewNode()
+			brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
+			c, err := derp.NewClient(k, conn, brw, logger.Discard)
+			if err != nil {
+				t.Fatalf("NewClient: %v", err)
+			}
+			return c
+		}
+		return s, newClient
+	}
+
+	// recvNPackets receives exactly n ReceivedPacket messages from c,
+	// discarding any other message types (e.g. ServerInfoMessage).
+	// It returns the time taken to receive all n data packets.
+	recvNPackets := func(t *testing.T, c *derp.Client, n int) time.Duration {
+		t.Helper()
+		start := time.Now()
+		got := 0
+		for got < n {
+			m, err := c.Recv()
+			if err != nil {
+				t.Fatalf("Recv: %v (got %d/%d)", err, got, n)
+			}
+			if _, ok := m.(derp.ReceivedPacket); ok {
+				got++
+			}
+		}
+		return time.Since(start)
+	}
+
+	t.Run("non_disco_throttled", func(t *testing.T) {
+		// Use a rate that will show a measurable delay.
+		// SetPerClientRateLimit clamps burst to max(burst, MaxPacketSize=64KB);
+		// we pass burst == rate, above the clamp, so the effective burst is
+		// 100KB. Sending 128KB of data should then take at least ~280ms:
+		// the first 100KB drains the bucket, and the remaining 28KB arrives
+		// at 100KB/s.
+		const bytesPerSec = 100_000
+		_, newClient := newServer(t, bytesPerSec, bytesPerSec)
+		sender := newClient(t)
+		receiver := newClient(t)
+
+		// Drain the ServerInfoMessage from receiver before timing.
+		if _, err := receiver.Recv(); err != nil {
+			t.Fatal(err)
+		}
+
+		const pktSize = 1000
+		const numPkts = 128 // 128KB total
+		msg := make([]byte, pktSize)
+
+		// Send all packets.
+		for i := range numPkts {
+			if err := sender.Send(receiver.PublicKey(), msg); err != nil {
+				t.Fatalf("Send(%d): %v", i, err)
+			}
+		}
+
+		// Measure how long it takes to receive all data packets.
+		elapsed := recvNPackets(t, receiver, numPkts)
+
+		// 128KB total, 100KB burst, 100KB/s rate.
+		// Should take meaningfully longer than without rate limiting.
+		// Without rate limiting, the same data transfers in <1ms on loopback.
+		if elapsed < 100*time.Millisecond {
+			t.Errorf("expected receives to be throttled, but took only %v", elapsed)
+		}
+		t.Logf("received %d packets of %d bytes in %v", numPkts, pktSize, elapsed)
+	})
+
+	t.Run("disco_not_throttled", func(t *testing.T) {
+		// Same rate as above, but DISCO packets should bypass the limiter.
+		// Send the same amount of data to contrast with the throttled case.
+		const bytesPerSec = 100_000
+		_, newClient := newServer(t, bytesPerSec, bytesPerSec)
+		sender := newClient(t)
+		receiver := newClient(t)
+
+		if _, err := receiver.Recv(); err != nil {
+			t.Fatal(err)
+		}
+
+		// disco.Magic (6 bytes) + 32-byte key + 24-byte nonce + payload
+		discoPacket := make([]byte, 6+32+24+932) // ~1000 bytes total
+		copy(discoPacket, "TS💬") // disco.Magic
+
+		const numPkts = 128
+		for i := range numPkts {
+			if err := sender.Send(receiver.PublicKey(), discoPacket); err != nil {
+				t.Fatalf("Send(%d): %v", i, err)
+			}
+		}
+
+		elapsed := recvNPackets(t, receiver, numPkts)
+
+		// DISCO packets bypass the rate limiter; this should complete
+		// quickly, with none of the ~280ms throttling delay seen in the
+		// non-DISCO case.
+		if elapsed > 2*time.Second {
+			t.Errorf("expected DISCO receives to be fast, but took %v", elapsed)
+		}
+		t.Logf("received %d DISCO packets in %v", numPkts, elapsed)
+	})
+
+	t.Run("mesh_peer_exempt", func(t *testing.T) {
+		// Verify the server would not assign a rate limiter to mesh peers.
+		s, _ := newServer(t, 10_000, 10_000)
+		c := &sclient{s: s, canMesh: true}
+		// accept() logic: s.perClientRecvBytesPerSec > 0 && !c.canMesh
+		// For a mesh peer (canMesh=true), the condition is false → no limiter.
+		if s.perClientRecvBytesPerSec > 0 && !c.canMesh {
+			t.Error("mesh peer should be exempt from rate limiting")
+		}
+		if c.recvLim != nil {
+			t.Error("expected nil recvLim for mesh peer")
+		}
+	})
+
+	t.Run("zero_config_no_limiter", func(t *testing.T) {
+		s, _ := newServer(t, 0, 0)
+		if s.perClientRecvBytesPerSec != 0 {
+			t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec)
+		}
+		c := &sclient{s: s, canMesh: false}
+		if c.recvLim != nil {
+			t.Errorf("expected nil recvLim with zero config")
+		}
+	})
+}
+
 func BenchmarkSenderCardinalityOverhead(b *testing.B) {
 	hll := hyperloglog.New()
 	sender := key.NewNode().Public()
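
A minimal standalone sketch of the technique the server change relies on, for
reference: reading-side token-bucket limiting that throttles a TCP sender
purely through flow-control backpressure. It uses golang.org/x/time/rate
directly; throttledDrain, main, and the 512KB/s / 64KB numbers are
illustrative, not part of this patch. Blocking in WaitN between reads lets the
kernel's receive buffer fill, so TCP shrinks the advertised window and the
remote sender slows down. The burst must cover the largest single read, which
is the same reason SetPerClientRateLimit clamps burst to derp.MaxPacketSize.
(With the derper flags above, an operator opts in with e.g.
--per-client-rate-limit=1048576, giving roughly 1MB/s per client with the
burst defaulting to twice that.)

package main

import (
	"context"
	"crypto/rand"
	"io"
	"log"
	"net"
	"time"

	"golang.org/x/time/rate"
)

// throttledDrain reads from conn at no more than bytesPerSec bytes/sec.
// Because it blocks in WaitN before reading more, the kernel's receive
// buffer fills and TCP flow control pushes back on the sender.
func throttledDrain(ctx context.Context, conn net.Conn, bytesPerSec, burst int) (int, error) {
	lim := rate.NewLimiter(rate.Limit(bytesPerSec), burst)
	buf := make([]byte, 32<<10) // must not exceed burst, or WaitN errors
	total := 0
	for {
		n, err := conn.Read(buf)
		total += n
		if n > 0 {
			// Charge the limiter for what was just read; WaitN sleeps
			// until the token bucket holds n tokens.
			if werr := lim.WaitN(ctx, n); werr != nil {
				return total, werr // context canceled
			}
		}
		if err != nil {
			if err == io.EOF {
				err = nil
			}
			return total, err
		}
	}
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	go func() {
		conn, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			log.Fatal(err)
		}
		defer conn.Close()
		io.CopyN(conn, rand.Reader, 1<<20) // blast 1MB as fast as TCP allows
	}()
	conn, err := ln.Accept()
	if err != nil {
		log.Fatal(err)
	}
	start := time.Now()
	n, err := throttledDrain(context.Background(), conn, 512<<10, 64<<10)
	log.Printf("read %d bytes in %v (err=%v)", n, time.Since(start), err)
}

On loopback the 1MB transfer would otherwise finish in well under a
millisecond; with the limiter it takes roughly (1MB - 64KB burst) / 512KB/s,
about 1.9 seconds, the same shape of delay the non_disco_throttled test
asserts.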