diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index 2afc06052..4e144ce40 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -19,6 +19,7 @@ import ( "slices" "strings" "sync" + "time" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" @@ -32,6 +33,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/appctype" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -327,8 +329,9 @@ func (c *Conn25) isConfigured() bool { func newConn25(logf logger.Logf) *Conn25 { c := &Conn25{ client: &client{ - logf: logf, - addrsCh: make(chan addrs, 64), + logf: logf, + addrsCh: make(chan addrs, 64), + assignments: addrAssignments{clock: tstime.StdClock{}}, }, connector: &connector{logf: logf}, } @@ -1160,11 +1163,12 @@ func (c *connector) reconfig(newCfg config) { } type addrs struct { - dst netip.Addr - magic netip.Addr - transit netip.Addr - domain dnsname.FQDN - app string + dst netip.Addr + magic netip.Addr + transit netip.Addr + domain dnsname.FQDN + app string + expiresAt time.Time } func (c addrs) isValid() bool { @@ -1187,22 +1191,38 @@ type addrAssignments struct { byMagicIP map[netip.Addr]addrs byTransitIP map[netip.Addr]addrs byDomainDst map[domainDst]addrs + clock tstime.Clock } +const defaultExpiry = 48 * time.Hour + func (a *addrAssignments) insert(as addrs) error { - // we likely will want to allow overwriting in the future when we - // have address expiry, but for now this should not happen - if _, ok := a.byMagicIP[as.magic]; ok { - return errors.New("byMagicIP key exists") + return a.insertWithExpiry(as, defaultExpiry) +} + +func (a *addrAssignments) insertWithExpiry(as addrs, d time.Duration) error { + if !as.expiresAt.IsZero() { + return errors.New("expiresAt already set") + } + now := a.clock.Now() + as.expiresAt = now.Add(d) + // we don't expect for addresses to be reused before expiry + if existing, ok := a.byMagicIP[as.magic]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byMagicIP key exists") + } } ddst := domainDst{domain: as.domain, dst: as.dst} - if _, ok := a.byDomainDst[ddst]; ok { - return errors.New("byDomainDst key exists") + if existing, ok := a.byDomainDst[ddst]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byDomainDst key exists") + } } - if _, ok := a.byTransitIP[as.transit]; ok { - return errors.New("byTransitIP key exists") + if existing, ok := a.byTransitIP[as.transit]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byTransitIP key exists") + } } - mak.Set(&a.byMagicIP, as.magic, as) mak.Set(&a.byTransitIP, as.transit, as) mak.Set(&a.byDomainDst, ddst, as) @@ -1211,17 +1231,26 @@ func (a *addrAssignments) insert(as addrs) error { func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) { v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] - return v, ok + if !ok || v.expiresAt.Before(a.clock.Now()) { + return addrs{}, false + } + return v, true } func (a *addrAssignments) lookupByMagicIP(mip netip.Addr) (addrs, bool) { v, ok := a.byMagicIP[mip] - return v, ok + if !ok || v.expiresAt.Before(a.clock.Now()) { + return addrs{}, false + } + return v, true } func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) { v, ok := a.byTransitIP[tip] - return v, ok + if !ok || v.expiresAt.Before(a.clock.Now()) { + return addrs{}, false + } + return v, true } // insertTransitConnMapping adds an entry to the byConnKey map diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 1784ccb68..124b739e4 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -26,6 +26,7 @@ import ( "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tsd" + "tailscale.com/tstest" "tailscale.com/types/appctype" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -950,7 +951,13 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) { c.reconfig(cfg) c.mapDNSResponse(dnsResp) - if diff := cmp.Diff(tt.wantByMagicIP, c.client.assignments.byMagicIP, cmpopts.EquateComparable(addrs{}, netip.Addr{})); diff != "" { + if diff := cmp.Diff( + tt.wantByMagicIP, + c.client.assignments.byMagicIP, + cmp.AllowUnexported(addrs{}), + cmpopts.IgnoreFields(addrs{}, "expiresAt"), + cmpopts.EquateComparable(netip.Addr{}), + ); diff != "" { t.Errorf("byMagicIP diff (-want, +got):\n%s", diff) } }) @@ -989,8 +996,8 @@ func TestReserveAddressesDeduplicated(t *testing.T) { t.Fatal(err) } - if first != second { - t.Errorf("expected same addrs on repeated call, got first=%v second=%v", first, second) + if first.magic != second.magic { + t.Errorf("expected same magic addrs on repeated call, got first=%v second=%v", first.magic, second.magic) } if got := len(c.client.assignments.byMagicIP); got != 1 { t.Errorf("want 1 entry in byMagicIP, got %d", got) @@ -2007,3 +2014,57 @@ func TestGetMagicRange(t *testing.T) { } } } + +func TestAssignmentsExpire(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Now()}) + assignments := addrAssignments{clock: clock} + as := addrs{ + dst: netip.MustParseAddr("0.0.0.1"), + magic: netip.MustParseAddr("0.0.0.2"), + transit: netip.MustParseAddr("0.0.0.3"), + app: "a", + domain: "example.com.", + } + err := assignments.insert(as) + if err != nil { + t.Fatal(err) + } + // Time has not passed since the insert, the assignment should be returned. + foundAs, ok := assignments.lookupByMagicIP(as.magic) + if !ok { + t.Fatal("expected to find") + } + if foundAs.dst != as.dst { + t.Fatalf("want %v; got %v", as.dst, foundAs.dst) + } + // and we cannot insert over the addresses + err = assignments.insert(as) + if err == nil { + t.Fatal("expected an error but got nil") + } + // After a time greater than the default expiry passes, the assignment should + // not be returned. + clock.Advance(defaultExpiry * 2) + foundAsAfter, okAfter := assignments.lookupByMagicIP(as.magic) + if okAfter { + t.Fatal("expected not to find (expired)") + } + if foundAsAfter.isValid() { + t.Fatal("expected zero val") + } + // Now we can reuse the addresses + err = assignments.insert(as) + if err != nil { + t.Fatal(err) + } + foundAs, ok = assignments.lookupByMagicIP(as.magic) + if !ok { + t.Fatal("expected to find") + } + if foundAs.dst != as.dst { + t.Fatalf("want %v; got %v", as.dst, foundAs.dst) + } + if !foundAs.expiresAt.After(clock.Now()) { + t.Fatalf("expected foundAs to expire after now") + } +}