feature/conn25: return expired assignments to address pools

in the expiry ticker.
Now over time the client will only exhaust its pools if it uses all the
addresses within the defaultExpiry time.
Move the loop up to the client, so that it has access to the pools.

Updates tailscale/corp#39975

Signed-off-by: Fran Bull <fran@tailscale.com>
This commit is contained in:
Fran Bull 2026-04-27 09:33:25 -07:00
parent 9d624f3d4b
commit 8b48819a87
3 changed files with 88 additions and 15 deletions

View File

@ -4,7 +4,6 @@
package conn25
import (
"context"
"errors"
"net/netip"
"sync"
@ -123,16 +122,3 @@ func (a *addrAssignments) removeExpiredAddrs() []addrs {
}
return removed
}
func (a *addrAssignments) expireAddrAssignmentsLoop(ctx context.Context) {
ticker, ch := a.clock.NewTicker(61 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ch:
a.removeExpiredAddrs()
}
}
}

View File

@ -642,10 +642,56 @@ func newClient(ctx context.Context, logf logger.Logf) *client {
addrsCh: make(chan addrs, 64),
assignments: addrAssignments{clock: tstime.StdClock{}},
}
go c.assignments.expireAddrAssignmentsLoop(ctx)
// It gets racy in the tests whether the ticker fires when you advance the clock,
// so in the tests we'll call handleExpireAddrAssignmentsLoopTick by hand.
if !testenv.InTest() {
go c.expireAddrAssignmentsLoop(ctx)
}
return c
}
func (c *client) handleExpireAddrAssignmentsLoopTick() {
expired := c.assignments.removeExpiredAddrs()
c.returnExpiredToPool(expired)
}
func (c *client) expireAddrAssignmentsLoop(ctx context.Context) {
ticker, ch := c.assignments.clock.NewTicker(61 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ch:
c.handleExpireAddrAssignmentsLoopTick()
}
}
}
func (c *client) returnExpiredToPool(expired []addrs) {
if len(expired) == 0 {
return
}
c.mu.Lock()
defer c.mu.Unlock()
for _, as := range expired {
var magicPool, transitPool *ippool
if as.magic.Is4() {
magicPool = c.v4MagicIPPool
transitPool = c.v4TransitIPPool
} else {
magicPool = c.v6MagicIPPool
transitPool = c.v6TransitIPPool
}
if err := magicPool.returnAddr(as.magic); err != nil {
c.logf("error returning magic IP %v to pool: %v", as.magic, err)
}
if err := transitPool.returnAddr(as.transit); err != nil {
c.logf("error returning transit IP %v to pool: %v", as.transit, err)
}
}
}
func (c *client) getConfig() config {
c.mu.Lock()
defer c.mu.Unlock()

View File

@ -5,6 +5,7 @@ package conn25
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/netip"
@ -26,6 +27,7 @@ import (
"tailscale.com/net/tstun"
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/tstest"
"tailscale.com/types/appctype"
"tailscale.com/types/key"
"tailscale.com/types/logger"
@ -2013,3 +2015,42 @@ func TestGetMagicRange(t *testing.T) {
}
}
}
func TestExpiredAddrsReturnedToPool(t *testing.T) {
clock := tstest.NewClock(tstest.ClockOpts{Start: time.Now()})
c := newConn25(t.Context(), logger.Discard)
c.client.assignments.clock = clock
// Single address pools.
c.client.v6MagicIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/128"))
c.client.v6TransitIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/128"))
c.client.config.nv.appNamesByDomain = map[dnsname.FQDN][]string{"example.com.": {"app"}}
// Use the one address.
first, err := c.client.reserveAddresses("example.com.", netip.MustParseAddr("::1"))
if err != nil {
t.Fatal(err)
}
// The pools are exhausted.
_, err = c.client.reserveAddresses("example.com.", netip.MustParseAddr("::2"))
if !errors.Is(err, errPoolExhausted) {
t.Fatalf("want errPoolExhausted, got: %v", err)
}
// Advance the clock past the expiry window and run the expiry loop tick.
// The addresses are returned to their pools.
clock.Advance(defaultExpiry * 2)
c.client.handleExpireAddrAssignmentsLoopTick()
// The addresses are available for use again.
second, err := c.client.reserveAddresses("example.com.", netip.MustParseAddr("::2"))
if err != nil {
t.Fatalf("want nil error after pool return, got: %v", err)
}
if second.magic != first.magic {
t.Errorf("magic: want %v, got %v", first.magic, second.magic)
}
if second.transit != first.transit {
t.Errorf("transit: want %v, got %v", first.transit, second.transit)
}
}