diff --git a/feature/conn25/addrAssignments.go b/feature/conn25/addrAssignments.go index 988932113..26b2543d3 100644 --- a/feature/conn25/addrAssignments.go +++ b/feature/conn25/addrAssignments.go @@ -4,7 +4,6 @@ package conn25 import ( - "context" "errors" "net/netip" "sync" @@ -123,16 +122,3 @@ func (a *addrAssignments) removeExpiredAddrs() []addrs { } return removed } - -func (a *addrAssignments) expireAddrAssignmentsLoop(ctx context.Context) { - ticker, ch := a.clock.NewTicker(61 * time.Second) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ch: - a.removeExpiredAddrs() - } - } -} diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index dc867a34e..158ac06c1 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -642,10 +642,56 @@ func newClient(ctx context.Context, logf logger.Logf) *client { addrsCh: make(chan addrs, 64), assignments: addrAssignments{clock: tstime.StdClock{}}, } - go c.assignments.expireAddrAssignmentsLoop(ctx) + // It gets racy in the tests whether the ticker fires when you advance the clock, + // so in the tests we'll call handleExpireAddrAssignmentsLoopTick by hand. + if !testenv.InTest() { + go c.expireAddrAssignmentsLoop(ctx) + } return c } +func (c *client) handleExpireAddrAssignmentsLoopTick() { + expired := c.assignments.removeExpiredAddrs() + c.returnExpiredToPool(expired) +} + +func (c *client) expireAddrAssignmentsLoop(ctx context.Context) { + ticker, ch := c.assignments.clock.NewTicker(61 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ch: + c.handleExpireAddrAssignmentsLoopTick() + } + } +} + +func (c *client) returnExpiredToPool(expired []addrs) { + if len(expired) == 0 { + return + } + c.mu.Lock() + defer c.mu.Unlock() + for _, as := range expired { + var magicPool, transitPool *ippool + if as.magic.Is4() { + magicPool = c.v4MagicIPPool + transitPool = c.v4TransitIPPool + } else { + magicPool = c.v6MagicIPPool + transitPool = c.v6TransitIPPool + } + if err := magicPool.returnAddr(as.magic); err != nil { + c.logf("error returning magic IP %v to pool: %v", as.magic, err) + } + if err := transitPool.returnAddr(as.transit); err != nil { + c.logf("error returning transit IP %v to pool: %v", as.transit, err) + } + } +} + func (c *client) getConfig() config { c.mu.Lock() defer c.mu.Unlock() diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 97a6f19e8..b0e61b270 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -5,6 +5,7 @@ package conn25 import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "net/netip" @@ -26,6 +27,7 @@ import ( "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tsd" + "tailscale.com/tstest" "tailscale.com/types/appctype" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -2013,3 +2015,42 @@ func TestGetMagicRange(t *testing.T) { } } } + +func TestExpiredAddrsReturnedToPool(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Now()}) + c := newConn25(t.Context(), logger.Discard) + c.client.assignments.clock = clock + // Single address pools. + c.client.v6MagicIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/128")) + c.client.v6TransitIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/128")) + c.client.config.nv.appNamesByDomain = map[dnsname.FQDN][]string{"example.com.": {"app"}} + + // Use the one address. + first, err := c.client.reserveAddresses("example.com.", netip.MustParseAddr("::1")) + if err != nil { + t.Fatal(err) + } + + // The pools are exhausted. + _, err = c.client.reserveAddresses("example.com.", netip.MustParseAddr("::2")) + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("want errPoolExhausted, got: %v", err) + } + + // Advance the clock past the expiry window and run the expiry loop tick. + // The addresses are returned to their pools. + clock.Advance(defaultExpiry * 2) + c.client.handleExpireAddrAssignmentsLoopTick() + + // The addresses are available for use again. + second, err := c.client.reserveAddresses("example.com.", netip.MustParseAddr("::2")) + if err != nil { + t.Fatalf("want nil error after pool return, got: %v", err) + } + if second.magic != first.magic { + t.Errorf("magic: want %v, got %v", first.magic, second.magic) + } + if second.transit != first.transit { + t.Errorf("transit: want %v, got %v", first.transit, second.transit) + } +}