diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 5db04ff40..3de2e6479 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -54,9 +54,11 @@ type Impl struct { mc *magicsock.Conn logf logger.Logf - mu sync.Mutex - dns DNSMap - connsToSubnetIP map[tcpip.Address]int + mu sync.Mutex + dns DNSMap + // connsOpenBySubnetIP keeps track of subnet IPs temporarily + // registered on netstack for active TCP connections. + connsOpenBySubnetIP map[netaddr.IP]int } const nicID = 1 @@ -100,13 +102,13 @@ func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsoc }, }) ns := &Impl{ - logf: logf, - ipstack: ipstack, - linkEP: linkEP, - tundev: tundev, - e: e, - mc: mc, - connsToSubnetIP: make(map[tcpip.Address]int), + logf: logf, + ipstack: ipstack, + linkEP: linkEP, + tundev: tundev, + e: e, + mc: mc, + connsOpenBySubnetIP: make(map[netaddr.IP]int), } return ns, nil } @@ -170,21 +172,31 @@ func (ns *Impl) updateDNS(nm *netmap.NetworkMap) { ns.dns = DNSMapFromNetworkMap(nm) } -func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, addr tcpip.Address) { +func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, addr tcpip.Address) (ok bool) { + ip, ok := netaddr.FromStdIP(net.IP(addr)) + if !ok { + return false + } ns.mu.Lock() - ns.connsToSubnetIP[addr]++ + ns.connsOpenBySubnetIP[ip]++ ns.mu.Unlock() ns.ipstack.AddAddress(nicID, pn, addr) + return true } -func (ns *Impl) removeSubnetAddress(addr tcpip.Address) { +func (ns *Impl) removeSubnetAddress(addr tcpip.Address) (ok bool) { + ip, ok := netaddr.FromStdIP(net.IP(addr)) + if !ok { + return false + } ns.mu.Lock() defer ns.mu.Unlock() - ns.connsToSubnetIP[addr]-- - if ns.connsToSubnetIP[addr] == 0 { + ns.connsOpenBySubnetIP[ip]-- + if ns.connsOpenBySubnetIP[ip] == 0 { ns.ipstack.RemoveAddress(nicID, addr) - delete(ns.connsToSubnetIP, addr) + delete(ns.connsOpenBySubnetIP, ip) } + return true } func ipPrefixToAddressWithPrefix(ipp netaddr.IPPrefix) tcpip.AddressWithPrefix { @@ -219,8 +231,8 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { } } ns.mu.Lock() - for ip := range ns.connsToSubnetIP { - ipp := ip.WithPrefix() + for ip := range ns.connsOpenBySubnetIP { + ipp := tcpip.Address(ip.IPAddr().IP).WithPrefix() ipsToBeAdded[ipp] = true delete(ipsToBeRemoved, ipp) }