diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index cbc09272e..ebf125435 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -175,7 +175,10 @@ func (c *Conn25) isConfigured() bool { func newConn25(logf logger.Logf) *Conn25 { c := &Conn25{ - client: &client{logf: logf}, + client: &client{ + logf: logf, + addrs: newAddrTable(), + }, server: &server{logf: logf}, } return c @@ -389,9 +392,8 @@ type client struct { mu sync.Mutex // protects the fields below magicIPPool *ippool transitIPPool *ippool - // map of magic IP -> (transit IP, app) - magicIPs map[netip.Addr]appAddr - config config + config config + addrs *addrTable } func (c *client) isConfigured() bool { @@ -438,15 +440,7 @@ func (c *client) reconfig(selfNode tailcfg.NodeView) error { return nil } -func (c *client) setMagicIP(magicAddr, transitAddr netip.Addr, app string) { - c.mu.Lock() - defer c.mu.Unlock() - mak.Set(&c.magicIPs, magicAddr, appAddr{addr: transitAddr, app: app}) -} - func (c *client) reserveAddresses(domain string, dst netip.Addr) (connection, error) { - c.mu.Lock() - defer c.mu.Unlock() appNames, ok := c.config.appsByDomain[domain] // Is this domain routed by connectors? if !ok || len(appNames) == 0 { @@ -472,17 +466,31 @@ func (c *client) reserveAddresses(domain string, dst netip.Addr) (connection, er magic: mip, transit: tip, app: app, + domain: domain, } c.logf("assigning magic ip for domain: %s, app: %s, %v", domain, app, mip) return connection, nil } func (c *client) enqueueAddressAssignment(conn connection) { - c.setMagicIP(conn.magic, conn.transit, conn.app) // TODO(fran) 2026-02-03 asynchronously send peerapi req to connector to // allocate these addresses for us. } +func (c *client) handleObservedDNS(domain string, dst netip.Addr) (connection, error) { + c.mu.Lock() + defer c.mu.Unlock() + if addrs, ok := c.addrs.getByDomainAndDst(domain, dst); ok { + return addrs, nil + } + conn, err := c.reserveAddresses(domain, dst) + if err != nil || !conn.isValid() { + return conn, err + } + c.addrs.add(conn) + return conn, nil +} + func (c *client) mapDNSResponse(buf []byte) []byte { var msg dnsmessage.Message err := msg.Unpack(buf) @@ -497,7 +505,7 @@ func (c *client) mapDNSResponse(buf []byte) []byte { msgARecord := (a.Body).(*dnsmessage.AResource) domain := a.Header.Name.String() dst := netip.AddrFrom4(msgARecord.A) - connection, err := c.reserveAddresses(domain, dst) + connection, err := c.handleObservedDNS(domain, dst) if err != nil { // TODO(fran) log return buf @@ -536,8 +544,46 @@ type connection struct { magic netip.Addr transit netip.Addr app string + domain string } func (c connection) isValid() bool { return c.dst.IsValid() } + +// not safe for concurrent usage +// correct usage in the context of conn25 requires client to manage the locks +// ie, lock -> check if addr already has an entry -> if not check out an address from the pool -> assign to table -> unlock +type addrTable struct { + entries map[netip.Addr]connection // indexed by magicIP +} + +func (at *addrTable) add(c connection) { + if at == nil { + return + } + at.entries[c.magic] = c +} + +func (at *addrTable) getByMagicIP(a netip.Addr) (connection, bool) { + c, ok := at.entries[a] + return c, ok +} + +func (at *addrTable) getByDomainAndDst(domain string, dst netip.Addr) (connection, bool) { + if at == nil { + return connection{}, false + } + for _, val := range at.entries { + if val.domain == domain && val.dst == dst { + return val, true + } + } + return connection{}, false +} + +func newAddrTable() *addrTable { + return &addrTable{ + entries: make(map[netip.Addr]connection), + } +} diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 92951f915..e55dd5dd7 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -206,24 +206,6 @@ func TestTransitIPTargetUnknownTIP(t *testing.T) { } } -func TestSetMagicIP(t *testing.T) { - c := newConn25(logger.Discard) - mip := netip.MustParseAddr("0.0.0.1") - tip := netip.MustParseAddr("0.0.0.2") - app := "a" - c.client.setMagicIP(mip, tip, app) - val, ok := c.client.magicIPs[mip] - if !ok { - t.Fatal("expected there to be a value stored for the magic IP") - } - if val.addr != tip { - t.Fatalf("want %v, got %v", tip, val.addr) - } - if val.app != app { - t.Fatalf("want %s, got %s", app, val.app) - } -} - func TestReserveIPs(t *testing.T) { c := newConn25(logger.Discard) c.client.magicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) @@ -439,64 +421,70 @@ func TestConfigReconfigUpdate(t *testing.T) { assertChanged() } -func TestMapDNSResponse(t *testing.T) { - makeDNSResponse := func(domain string, addrs []dnsmessage.AResource) []byte { - b := dnsmessage.NewBuilder(nil, - dnsmessage.Header{ - ID: 1, - Response: true, - Authoritative: true, - RCode: dnsmessage.RCodeSuccess, - }) - b.EnableCompression() +func makeDNSResponse(t *testing.T, domain string, addrs []dnsmessage.AResource) []byte { + b := dnsmessage.NewBuilder(nil, + dnsmessage.Header{ + ID: 1, + Response: true, + Authoritative: true, + RCode: dnsmessage.RCodeSuccess, + }) + b.EnableCompression() - if err := b.StartQuestions(); err != nil { - t.Fatal(err) - } - - if err := b.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }); err != nil { - t.Fatal(err) - } - - if err := b.StartAnswers(); err != nil { - t.Fatal(err) - } - - for _, addr := range addrs { - b.AResource( - dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }, - addr, - ) - } - - outbs, err := b.Finish() - if err != nil { - t.Fatal(err) - } - return outbs + if err := b.StartQuestions(); err != nil { + t.Fatal(err) } + if err := b.Question(dnsmessage.Question{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }); err != nil { + t.Fatal(err) + } + + if err := b.StartAnswers(); err != nil { + t.Fatal(err) + } + + for _, addr := range addrs { + b.AResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + addr, + ) + } + + outbs, err := b.Finish() + if err != nil { + t.Fatal(err) + } + return outbs +} + +func TestMapDNSResponse(t *testing.T) { + for _, tt := range []struct { - name string - domain string - addrs []dnsmessage.AResource - wantMagicIPs map[netip.Addr]appAddr + name string + domain string + addrs []dnsmessage.AResource + wantAddrTable map[netip.Addr]connection }{ { name: "one-ip-matches", domain: "example.com.", addrs: []dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, - // these are 'expected' because they are the beginning of the provided pools - wantMagicIPs: map[netip.Addr]appAddr{ - netip.MustParseAddr("100.64.0.0"): {app: "app1", addr: netip.MustParseAddr("100.64.0.40")}, + wantAddrTable: map[netip.Addr]connection{ + netip.MustParseAddr("100.64.0.0"): { + app: "app1", + transit: netip.MustParseAddr("100.64.0.40"), + magic: netip.MustParseAddr("100.64.0.0"), + dst: netip.MustParseAddr("1.0.0.0"), + domain: "example.com.", + }, }, }, { @@ -506,9 +494,21 @@ func TestMapDNSResponse(t *testing.T) { {A: [4]byte{1, 0, 0, 0}}, {A: [4]byte{2, 0, 0, 0}}, }, - wantMagicIPs: map[netip.Addr]appAddr{ - netip.MustParseAddr("100.64.0.0"): {app: "app1", addr: netip.MustParseAddr("100.64.0.40")}, - netip.MustParseAddr("100.64.0.1"): {app: "app1", addr: netip.MustParseAddr("100.64.0.41")}, + wantAddrTable: map[netip.Addr]connection{ + netip.MustParseAddr("100.64.0.0"): { + app: "app1", + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + domain: "example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + }, + netip.MustParseAddr("100.64.0.1"): { + app: "app1", + magic: netip.MustParseAddr("100.64.0.1"), + transit: netip.MustParseAddr("100.64.0.41"), + domain: "example.com.", + dst: netip.MustParseAddr("2.0.0.0"), + }, }, }, { @@ -518,10 +518,11 @@ func TestMapDNSResponse(t *testing.T) { {A: [4]byte{1, 0, 0, 0}}, {A: [4]byte{2, 0, 0, 0}}, }, + wantAddrTable: make(map[netip.Addr]connection), }, } { t.Run(tt.name, func(t *testing.T) { - dnsResp := makeDNSResponse(tt.domain, tt.addrs) + dnsResp := makeDNSResponse(t, tt.domain, tt.addrs) sn := makeSelfNode(t, appctype.Conn25Attr{ Name: "app1", Connectors: []string{"tag:woo"}, @@ -536,9 +537,48 @@ func TestMapDNSResponse(t *testing.T) { if !reflect.DeepEqual(dnsResp, bs) { t.Fatal("shouldn't be changing the bytes (yet)") } - if diff := cmp.Diff(tt.wantMagicIPs, c.client.magicIPs, cmpopts.EquateComparable(appAddr{}, netip.Addr{})); diff != "" { - t.Errorf("magicIPs diff (-want, +got):\n%s", diff) + var toCmp map[netip.Addr]connection + if c.client.addrs != nil { + toCmp = c.client.addrs.entries + } + if diff := cmp.Diff(tt.wantAddrTable, toCmp, cmp.AllowUnexported(connection{}), cmpopts.EquateComparable(netip.Addr{})); diff != "" { + t.Errorf("addrTable entries diff (-want, +got):\n%s", diff) } }) } } + +func TestMapDNSResponseReservesOnce(t *testing.T) { + dnsResp := makeDNSResponse(t, "example.com.", []dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}) + sn := makeSelfNode(t, appctype.Conn25Attr{ + Name: "app1", + Connectors: []string{"tag:woo"}, + Domains: []string{"example.com"}, + MagicIPPool: []netipx.IPRange{rangeFrom("0", "10"), rangeFrom("20", "30")}, + TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")}, + }, []string{}) + c := newConn25(logger.Discard) + c.reconfig(sn) + + wantAddrEntries := map[netip.Addr]connection{ + netip.MustParseAddr("100.64.0.0"): { + app: "app1", + transit: netip.MustParseAddr("100.64.0.40"), + magic: netip.MustParseAddr("100.64.0.0"), + dst: netip.MustParseAddr("1.0.0.0"), + domain: "example.com.", + }, + } + assertWanted := func(addrs *addrTable) { + if diff := cmp.Diff(wantAddrEntries, addrs.entries, cmp.AllowUnexported(connection{}), cmpopts.EquateComparable(netip.Addr{})); diff != "" { + t.Errorf("addrTable entries diff (-want, +got):\n%s", diff) + } + } + + c.mapDNSResponse(dnsResp) + assertWanted(c.client.addrs) + + // doing it again doesn't change anything + c.mapDNSResponse(dnsResp) + assertWanted(c.client.addrs) +}