diff --git a/net/ipset/ipset.go b/net/ipset/ipset.go index 45dc6486f..58a120432 100644 --- a/net/ipset/ipset.go +++ b/net/ipset/ipset.go @@ -10,11 +10,47 @@ import ( "github.com/gaissmai/bart" "tailscale.com/types/views" + "tailscale.com/util/set" ) // FalseContainsIPFunc is shorthand for NewContainsIPFunc(views.Slice[netip.Prefix]{}). func FalseContainsIPFunc() func(ip netip.Addr) bool { - return func(ip netip.Addr) bool { return false } + return emptySet +} + +func emptySet(ip netip.Addr) bool { return false } + +func bartLookup(t *bart.Table[struct{}]) func(netip.Addr) bool { + return func(ip netip.Addr) bool { + _, ok := t.Get(ip) + return ok + } +} + +func prefixContainsLoop(addrs []netip.Prefix) func(netip.Addr) bool { + return func(ip netip.Addr) bool { + for _, p := range addrs { + if p.Contains(ip) { + return true + } + } + return false + } +} + +func oneIP(ip1 netip.Addr) func(netip.Addr) bool { + return func(ip netip.Addr) bool { return ip == ip1 } +} + +func twoIP(ip1, ip2 netip.Addr) func(netip.Addr) bool { + return func(ip netip.Addr) bool { return ip == ip1 || ip == ip2 } +} + +func ipInMap(m set.Set[netip.Addr]) func(netip.Addr) bool { + return func(ip netip.Addr) bool { + _, ok := m[ip] + return ok + } } // pathForTest is a test hook for NewContainsIPFunc, to test that it took the @@ -29,7 +65,7 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool // (or just IPv6), and both IPv4 and IPv6. if addrs.Len() == 0 { pathForTest("empty") - return func(netip.Addr) bool { return false } + return emptySet } // If any addr is a prefix with more than a single IP, then do either a // linear scan or a bart table, depending on the number of addrs. @@ -41,40 +77,27 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool for i := range addrs.Len() { t.Insert(addrs.At(i), struct{}{}) } - return func(ip netip.Addr) bool { - _, ok := t.Get(ip) - return ok - } + return bartLookup(t) } else { pathForTest("linear-contains") // Small enough to do a linear search. - acopy := addrs.AsSlice() - return func(ip netip.Addr) bool { - for _, a := range acopy { - if a.Contains(ip) { - return true - } - } - return false - } + return prefixContainsLoop(addrs.AsSlice()) } } // Fast paths for 1 and 2 IPs: if addrs.Len() == 1 { pathForTest("one-ip") - a := addrs.At(0) - return func(ip netip.Addr) bool { return ip == a.Addr() } + return oneIP(addrs.At(0).Addr()) } if addrs.Len() == 2 { pathForTest("two-ip") - a, b := addrs.At(0), addrs.At(1) - return func(ip netip.Addr) bool { return ip == a.Addr() || ip == b.Addr() } + return twoIP(addrs.At(0).Addr(), addrs.At(1).Addr()) } // General case: pathForTest("ip-map") - m := map[netip.Addr]bool{} + m := set.Set[netip.Addr]{} for i := range addrs.Len() { - m[addrs.At(i).Addr()] = true + m.Add(addrs.At(i).Addr()) } - return func(ip netip.Addr) bool { return m[ip] } + return ipInMap(m) }