diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index b087e1444..05c7552c8 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -356,7 +356,15 @@ func (c *Auto) authRoutine() { if err != nil { c.direct.health.SetAuthRoutineInError(err) report(err, f) - bo.BackOff(ctx, err) + if rle, ok := errors.AsType[*rateLimitError](err); ok { + c.logf("authRoutine: %s", rle) + select { + case <-ctx.Done(): + case <-time.After(rle.retryAfter): + } + } else { + bo.BackOff(ctx, err) + } continue } if url != "" { diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index 2205a0eb3..5c25af0f4 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -406,6 +406,118 @@ func testHTTPS(t *testing.T, withProxy bool) { } } +// TestRegisterRateLimited verifies that the client correctly handles 429 +// responses to registration requests by parsing the Retry-After header +// and returning a rateLimitError. 
+func TestRegisterRateLimited(t *testing.T) { + bakedroots.ResetForTest(t, tlstest.TestRootCA()) + + bus := eventbustest.NewBus(t) + + controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer controlLn.Close() + + var registerAttempts atomic.Int64 + tc := &testcontrol.Server{ + Logf: tstest.WhileTestRunningLogger(t), + MaybeRateLimitRegister: func() (bool, string, string) { + if registerAttempts.Add(1) == 1 { + return true, "30", "try again later" + } + return false, "", "" + }, + } + controlSrv := &http.Server{ + Handler: tc, + ErrorLog: logger.StdLogger(t.Logf), + } + go controlSrv.Serve(controlLn) + + const fakeControlIP = "1.2.3.4" + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + dialer.SetBus(bus) + dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err) + } + var d net.Dialer + if host == fakeControlIP { + return d.DialContext(ctx, network, controlLn.Addr().String()) + } + return nil, fmt.Errorf("unexpected dial to %q", addr) + }) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + ServerURL: "https://controlplane.tstest", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: health.NewTracker(bus), + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + Bus: bus, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + + d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) { + if host == "controlplane.tstest" { + return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil 
+ } + t.Errorf("unexpected DNS query for %q", host) + return nil, fmt.Errorf("unexpected DNS lookup for %q", host) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // First attempt should get a 429 and return a rateLimitError. + _, err = d.TryLogin(ctx, LoginEphemeral) + if err == nil { + t.Fatal("expected rate limit error on first attempt, got nil") + } + var rle *rateLimitError + if !errors.As(err, &rle) { + t.Fatalf("expected *rateLimitError, got %T: %v", err, err) + } + if rle.retryAfter != 30*time.Second { + t.Errorf("retryAfter = %v, want 30s", rle.retryAfter) + } + if rle.msg != "try again later" { + t.Errorf("msg = %q, want %q", rle.msg, "try again later") + } + + // Second attempt should succeed (server no longer rate-limiting). + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin after rate limit: %v", err) + } + if url != "" { + t.Errorf("got URL %q, want empty", url) + } + + if got := registerAttempts.Load(); got != 2 { + t.Errorf("register attempts = %d, want 2", got) + } +} + func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI != target { diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index d873cc745..032999cb9 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "log" + "math/rand/v2" "net" "net/http" "net/netip" @@ -24,6 +25,7 @@ import ( "reflect" "runtime" "slices" + "strconv" "strings" "sync/atomic" "time" @@ -575,6 +577,37 @@ var macOSScreenTime = health.Register(&health.Warnable{ ImpactsConnectivity: true, }) +type rateLimitError struct { + msg string + retryAfter time.Duration +} + +func (e *rateLimitError) Error() string { + return fmt.Sprintf("rate limited: %s (retry after %v)", e.msg, e.retryAfter) +} + +func 
parseRateLimitError(res *http.Response) *rateLimitError { + msg, _ := io.ReadAll(res.Body) + res.Body.Close() + + ret := &rateLimitError{ + msg: strings.TrimSpace(string(msg)), + } + + v := res.Header.Get("Retry-After") + if i, err := strconv.Atoi(v); err == nil { + ret.retryAfter = time.Duration(i) * time.Second + } else if t, err := http.ParseTime(v); err == nil { + ret.retryAfter = time.Until(t) + } + + // Without a usable Retry-After (missing, unparseable, non-positive, or more than an hour away), fall back to a jittered delay of 5-10s. + if ret.retryAfter <= 0 || ret.retryAfter > time.Hour { + ret.retryAfter = 5*time.Second + rand.N(5*time.Second) + } + return ret +} + func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, newURL string, nks tkatype.MarshaledSignature, err error) { if c.panicOnUse { panic("tainted client") @@ -769,6 +802,12 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new if err != nil { return regen, opt.URL, nil, fmt.Errorf("register request: %w", err) } + // Handle 429 Too Many Requests with a specific error type that includes the retry-after duration. 
+ if res.StatusCode == 429 { + rle := parseRateLimitError(res) + msg := fmt.Sprintf("node registration rate limited; will retry after %v", rle.retryAfter) + return false, "", nil, vizerror.WrapWithMessage(rle, msg) + } if res.StatusCode != 200 { msg, _ := io.ReadAll(res.Body) res.Body.Close() diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index d10b346ae..98741482f 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -5,9 +5,11 @@ package controlclient import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "net/netip" + "strings" "testing" "time" @@ -126,6 +128,109 @@ func fakeEndpoints(ports ...uint16) (ret []tailcfg.Endpoint) { return } +func TestParseRateLimitError(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + retryAfter string // Retry-After header value + wantMsg string + wantMin time.Duration // minimum expected retryAfter + wantMax time.Duration // maximum expected retryAfter + }{ + { + name: "retry-after-seconds", + statusCode: 429, + body: "too many requests", + retryAfter: "30", + wantMsg: "too many requests", + wantMin: 30 * time.Second, + wantMax: 30 * time.Second, + }, + { + name: "no-retry-after-header", + statusCode: 429, + body: "slow down", + retryAfter: "", + wantMsg: "slow down", + wantMin: 5 * time.Second, + wantMax: 10 * time.Second, + }, + { + name: "unparseable-retry-after", + statusCode: 429, + body: "rate limited", + retryAfter: "not-a-number", + wantMsg: "rate limited", + wantMin: 5 * time.Second, + wantMax: 10 * time.Second, + }, + { + name: "empty-body", + statusCode: 429, + body: "", + retryAfter: "5", + wantMsg: "", + wantMin: 5 * time.Second, + wantMax: 5 * time.Second, + }, + { + name: "body-with-whitespace", + statusCode: 429, + body: " too many requests \n", + retryAfter: "10", + wantMsg: "too many requests", + wantMin: 10 * time.Second, + wantMax: 10 * time.Second, + }, + } + + for _, tt := range 
tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + if tt.retryAfter != "" { + rec.Header().Set("Retry-After", tt.retryAfter) + } + rec.WriteHeader(tt.statusCode) + rec.Body.WriteString(tt.body) + res := rec.Result() + + err := parseRateLimitError(res) + if err == nil { + t.Fatal("expected non-nil error") + } + + var rle *rateLimitError + if !errors.As(err, &rle) { + t.Fatalf("error is not a *rateLimitError: %T", err) + } + if rle.msg != tt.wantMsg { + t.Errorf("msg = %q, want %q", rle.msg, tt.wantMsg) + } + if rle.retryAfter < tt.wantMin || rle.retryAfter > tt.wantMax { + t.Errorf("retryAfter = %v, want between %v and %v", rle.retryAfter, tt.wantMin, tt.wantMax) + } + + // Verify the Error() string contains useful information. + errStr := err.Error() + if !strings.Contains(errStr, "rate limited") { + t.Errorf("Error() = %q, want it to contain 'rate limited'", errStr) + } + }) + } +} + +func TestRateLimitErrorIsError(t *testing.T) { + err := &rateLimitError{msg: "test", retryAfter: 5 * time.Second} + var target *rateLimitError + if !errors.As(err, &target) { + t.Fatal("errors.As should match *rateLimitError") + } + if target.retryAfter != 5*time.Second { + t.Errorf("retryAfter = %v, want 5s", target.retryAfter) + } +} + func TestTsmpPing(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index c704a6248..7405ec0e5 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -183,7 +183,8 @@ type CapabilityVersion int // - 134: 2026-03-09: Client understands [NodeAttrDisableAndroidBindToActiveNetwork] // - 135: 2026-03-30: Client understands [NodeAttrCacheNetworkMaps] // - 136: 2026-04-09: Client understands [NodeAttrDisableLinuxCGNATDropRule] -const CurrentCapabilityVersion CapabilityVersion = 136 +// - 137: 2026-04-15: Client handles 429 responses to /machine/register. 
+const CurrentCapabilityVersion CapabilityVersion = 137 // ID is an integer ID for a user, node, or login allocated by the // control plane. diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 486bc8b81..53d9137c4 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -80,6 +80,11 @@ type Server struct { ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL + // MaybeRateLimitRegister, if non-nil, is called before processing + // register requests. If it returns true, a 429 response is sent + // with the given Retry-After header value and body string. + MaybeRateLimitRegister func() (reject bool, retryAfter string, msg string) + // ModifyFirstMapResponse, if non-nil, is called exactly once per // MapResponse stream to modify the first MapResponse sent in response to it. ModifyFirstMapResponse func(*tailcfg.MapResponse, *tailcfg.MapRequest) @@ -768,6 +773,16 @@ func (s *Server) CompleteDeviceApproval(controlUrl string, urlStr string, nodeKe } func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) { + if fn := s.MaybeRateLimitRegister; fn != nil { + if reject, retryAfter, msg := fn(); reject { + if retryAfter != "" { + w.Header().Set("Retry-After", retryAfter) + } + http.Error(w, msg, http.StatusTooManyRequests) + return + } + } + msg, err := io.ReadAll(io.LimitReader(r.Body, msgLimit)) r.Body.Close() if err != nil {