mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 12:16:44 +02:00
net/netns: interface probe prototype
WIP Experiment with an netns alternative that doesn't rely on the system routing table, but rather probes routes to find a working interface.
This commit is contained in:
parent
9a6282b515
commit
b59d58bb89
@ -72,6 +72,8 @@ func SetDisableBindConnToInterfaceAppleExt(logf logger.Logf, v bool) {
|
||||
}
|
||||
}
|
||||
|
||||
var probeInterfaces atomic.Bool
|
||||
|
||||
// Listener returns a new net.Listener with its Control hook func
|
||||
// initialized as necessary to run in logical network namespace that
|
||||
// doesn't route back into Tailscale.
|
||||
|
||||
@ -8,7 +8,6 @@ package netns
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@ -19,7 +18,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
@ -37,23 +35,103 @@ var errInterfaceStateInvalid = errors.New("interface state invalid")
|
||||
// controlLogf binds c to a particular interface as necessary to dial the
|
||||
// provided (network, address).
|
||||
func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address string, c syscall.RawConn) error {
|
||||
if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isLocalhost(address) {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, err := getInterfaceIndex(logf, netMon, address)
|
||||
if err != nil {
|
||||
// callee logged
|
||||
/// FIXME: (barnstar) Temporary probeInterfaces logic. Maybe set via a cap? By platform? So may caps.
|
||||
probeInterfaces.Store(true)
|
||||
if probeInterfaces.Load() {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("netns: control: SplitHostPort %q: %w", address, err)
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: logf,
|
||||
hpn: HostPortNetwork{Network: network, Host: host, Port: port},
|
||||
filterf: filterInvalidIntefaces,
|
||||
race: true,
|
||||
cache: globalRouteCache,
|
||||
}
|
||||
|
||||
// No netmon and no routing table.
|
||||
iface, err := findInterfaceThatCanReach(opts)
|
||||
|
||||
if err != nil || iface == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bindFn := getBindFn(network, address)
|
||||
logf("netns: post-probe binding to interface %q (index %d) for %s/%s", iface.Name, iface.Index, network, address)
|
||||
return bindFn(c, uint32(iface.Index))
|
||||
}
|
||||
|
||||
// Not probing? Then check if we should bind at all.
|
||||
if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return bindConnToInterface(c, network, address, idx, logf)
|
||||
// Bind using the legacy RIB / netmon method.
|
||||
idx, _ := getInterfaceIndex(logf, netMon, address)
|
||||
bindFn := getBindFn(network, address)
|
||||
return bindFn(c, uint32(idx))
|
||||
}
|
||||
|
||||
func filterInvalidIntefaces(iface net.Interface) bool {
|
||||
uninterestingPrefixes := []string{"awdl", "llw", "gif", "stf", "ipsec", "bond", "fwip", "utun"}
|
||||
|
||||
for _, prefix := range uninterestingPrefixes {
|
||||
if strings.HasPrefix(iface.Name, prefix) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound
|
||||
// to the provided interface index.
|
||||
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
|
||||
if lc == nil {
|
||||
return errors.New("nil ListenConfig")
|
||||
}
|
||||
if lc.Control != nil {
|
||||
return errors.New("ListenConfig.Control already set")
|
||||
}
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
bindFn := getBindFn(network, address)
|
||||
return bindFn(c, uint32(ifIndex))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func bindSocket6(c syscall.RawConn, idx uint32) error {
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
sockErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(idx))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("RawConn.Control on %T: %w", c, err)
|
||||
}
|
||||
return sockErr
|
||||
}
|
||||
|
||||
func bindSocket4(c syscall.RawConn, idx uint32) error {
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
sockErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(idx))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("RawConn.Control on %T: %w", c, err)
|
||||
}
|
||||
return sockErr
|
||||
}
|
||||
|
||||
// Legacy
|
||||
|
||||
// getInterfaceIndex returns the interface index that we should bind to
|
||||
// in order to send traffic to the provided address using netmon's view of
|
||||
// the DefaultRouteInterfaceIndex and/or a direct query to the routing table.
|
||||
func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) (int, error) {
|
||||
// Helper so we can log errors.
|
||||
defaultIdx := func() (int, error) {
|
||||
@ -115,14 +193,9 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
|
||||
}
|
||||
|
||||
// If the address doesn't parse, use the default index.
|
||||
addr, err := parseAddress(address)
|
||||
if err != nil {
|
||||
if err != errUnspecifiedHost {
|
||||
logf("[unexpected] netns: error parsing address %q: %v", address, err)
|
||||
}
|
||||
return defaultIdx()
|
||||
}
|
||||
|
||||
logf("netns: getting interface index for address %q", address)
|
||||
addr, err := parseAddress(address)
|
||||
idx, err := interfaceIndexFor(addr, true /* canRecurse */)
|
||||
if err != nil {
|
||||
logf("netns: error getting interface index for %q: %v", address, err)
|
||||
@ -143,34 +216,6 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
|
||||
return idx, err
|
||||
}
|
||||
|
||||
// tailscaleInterface returns the current machine's Tailscale interface, if any.
|
||||
// If none is found, (nil, nil) is returned.
|
||||
// A non-nil error is only returned on a problem listing the system interfaces.
|
||||
func tailscaleInterface() (*net.Interface, error) {
|
||||
ifs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, iface := range ifs {
|
||||
if !strings.HasPrefix(iface.Name, "utun") {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if ipnet, ok := a.(*net.IPNet); ok {
|
||||
nip, ok := netip.AddrFromSlice(ipnet.IP)
|
||||
if ok && tsaddr.IsTailscaleIP(nip.Unmap()) {
|
||||
return &iface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// interfaceIndexFor returns the interface index that we should bind to in
|
||||
// order to send traffic to the provided address.
|
||||
func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
|
||||
@ -276,40 +321,3 @@ func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
|
||||
|
||||
return 0, fmt.Errorf("no valid address found")
|
||||
}
|
||||
|
||||
// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound
|
||||
// to the provided interface index.
|
||||
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
|
||||
if lc == nil {
|
||||
return errors.New("nil ListenConfig")
|
||||
}
|
||||
if lc.Control != nil {
|
||||
return errors.New("ListenConfig.Control already set")
|
||||
}
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return bindConnToInterface(c, network, address, ifIndex, log.Printf)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func bindConnToInterface(c syscall.RawConn, network, address string, ifIndex int, logf logger.Logf) error {
|
||||
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
|
||||
proto := unix.IPPROTO_IP
|
||||
opt := unix.IP_BOUND_IF
|
||||
if v6 {
|
||||
proto = unix.IPPROTO_IPV6
|
||||
opt = unix.IPV6_BOUND_IF
|
||||
}
|
||||
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
sockErr = unix.SetsockoptInt(int(fd), proto, opt, ifIndex)
|
||||
})
|
||||
if sockErr != nil {
|
||||
logf("[unexpected] netns: bindConnToInterface(%q, %q), v6=%v, index=%v: %v", network, address, v6, ifIndex, sockErr)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("RawConn.Control on %T: %w", c, err)
|
||||
}
|
||||
return sockErr
|
||||
}
|
||||
|
||||
@ -5,27 +5,6 @@
|
||||
|
||||
package netns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
var errUnspecifiedHost = errors.New("unspecified host")
|
||||
|
||||
func parseAddress(address string) (addr netip.Addr, err error) {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// error means the string didn't contain a port number, so use the string directly
|
||||
host = address
|
||||
}
|
||||
if host == "" {
|
||||
return addr, errUnspecifiedHost
|
||||
}
|
||||
|
||||
return netip.ParseAddr(host)
|
||||
}
|
||||
|
||||
func UseSocketMark() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
454
net/netns/netns_probe.go
Normal file
454
net/netns/netns_probe.go
Normal file
@ -0,0 +1,454 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package netns contains the common code for using the Go net package
|
||||
// in a logical "network namespace" to avoid routing loops where
|
||||
// Tailscale-created packets would otherwise loop back through
|
||||
// Tailscale routes.
|
||||
//
|
||||
// Despite the name netns, the exact mechanism used differs by
|
||||
// operating system, and perhaps even by version of the OS.
|
||||
//
|
||||
// The netns package also handles connecting via SOCKS proxies when
|
||||
// configured by the environment.
|
||||
|
||||
package netns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/eventbus"
|
||||
)
|
||||
|
||||
// tailscaleInterface returns the current machine's Tailscale interface, if any.
|
||||
// If none is found, (nil, nil) is returned.
|
||||
// A non-nil error is only returned on a problem listing the system interfaces.
|
||||
// TODO (barnstar): netmon *usually* knows this (at least for darwing), but
|
||||
// this is more portable. It's still wildly different than the Windows method which
|
||||
// checks the description strings.
|
||||
func tailscaleInterface() (*net.Interface, error) {
|
||||
ifs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, iface := range ifs {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if ipnet, ok := a.(*net.IPNet); ok {
|
||||
nip, ok := netip.AddrFromSlice(ipnet.IP)
|
||||
if ok && tsaddr.IsTailscaleIP(nip.Unmap()) {
|
||||
return &iface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// inetReachability describes an interface and whether it was able to reach
|
||||
// the provided address.
|
||||
type inetReachability struct {
|
||||
iface net.Interface
|
||||
// TODO (barnstar): These are invariant. reachable should be true if err==nil.
|
||||
reachable bool
|
||||
err error
|
||||
}
|
||||
|
||||
// Tuple of the destination host, port, and network.
|
||||
// ie: "tcp4", "example.com", "80"
|
||||
type HostPortNetwork struct {
|
||||
Host string
|
||||
Port string
|
||||
Network string
|
||||
}
|
||||
|
||||
func (hpn HostPortNetwork) String() string {
|
||||
return fmt.Sprintf("%s/%s:%s", hpn.Network, hpn.Host, hpn.Port)
|
||||
}
|
||||
|
||||
type probeOpts struct {
|
||||
logf logger.Logf
|
||||
hpn HostPortNetwork
|
||||
race bool // if true, we'll pick the first interface that responds. sortf is ignored.
|
||||
filterf interfaceFilter // optional pre-filter for interfaces
|
||||
cache *routeCache // must be non-nil
|
||||
}
|
||||
|
||||
type DefaultIfaceHintFn func() int
|
||||
|
||||
var defaultIfaceHintFn DefaultIfaceHintFn
|
||||
|
||||
// Platforms may set defaultIFQueryFn to a function that returns the platforms's high
|
||||
// level view of the default interface index.
|
||||
func SetDefaultIFQueryFn(fn DefaultIfaceHintFn) {
|
||||
defaultIfaceHintFn = fn
|
||||
}
|
||||
|
||||
// uint
|
||||
type bindFn func(c syscall.RawConn, ifidx uint32) error
|
||||
|
||||
// Returns the proper bind function for the given network and address.
|
||||
// Currently only differentiates between IPv4 and IPv6 - and poorly.
|
||||
func bindFnByAddrType(network, address string) bindFn {
|
||||
// Very naive check for IPv6.
|
||||
if strings.Contains(address, "]:") || strings.HasSuffix(network, "6") {
|
||||
return bindSocket6
|
||||
}
|
||||
return bindSocket4
|
||||
}
|
||||
|
||||
type bindFunctionHook func(network, address string) bindFn
|
||||
|
||||
var getBindFn bindFunctionHook = bindFnByAddrType
|
||||
|
||||
var interfacesHookFn func() ([]net.Interface, error)
|
||||
|
||||
var interfacesHook = net.Interfaces
|
||||
|
||||
// ProbeInterfacesReachability probes all non-loopback, up interfaces
|
||||
// concurrently to determine which can reach the given address. It returns
|
||||
// a slice with one entry per probed interface in the same order as
|
||||
// net.Interfaces() filtered by the probe criteria.
|
||||
func probeInterfacesReachability(opts probeOpts) ([]inetReachability, error) {
|
||||
ifaces, err := interfacesHook()
|
||||
if err != nil {
|
||||
opts.logf("netns: ProbeInterfacesReachability: net.Interfaces: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results := make(chan inetReachability, len(ifaces))
|
||||
|
||||
tsiface, _ := tailscaleInterface()
|
||||
|
||||
var candidates []net.Interface
|
||||
for _, iface := range ifaces {
|
||||
// Individual platforms can exclude potential intefaces based on platorm-specific logic.
|
||||
// For example, on Darwin, we skip "utun" interfaces.
|
||||
if opts.filterf != nil && !opts.filterf(iface) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only consider up, non-loopback interfaces.
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagRunning == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the Tailscale interface.
|
||||
if tsiface != nil && iface.Index == tsiface.Index {
|
||||
continue
|
||||
}
|
||||
|
||||
// require an IPv4 or IPv6 global unicast address
|
||||
if !ifaceHasV4OrGlobalV6(&iface) {
|
||||
continue
|
||||
}
|
||||
|
||||
candidates = append(candidates, iface)
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
opts.logf("netns: ProbeInterfacesReachability: no candidate interfaces found")
|
||||
return nil, errors.New("no candidate interfaces")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for _, iface := range candidates {
|
||||
go func() {
|
||||
// Per-probe timeout.
|
||||
|
||||
err := reachabilityHook(&iface, opts.hpn)
|
||||
|
||||
select {
|
||||
case results <- inetReachability{iface: iface, reachable: err == nil, err: err}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
out := make([]inetReachability, 0, len(candidates))
|
||||
timeout := time.After(600 * time.Millisecond)
|
||||
received := 0
|
||||
|
||||
for received < len(candidates) {
|
||||
select {
|
||||
case r := <-results:
|
||||
// If we're racing, return the first reachable interface immediately.
|
||||
// TODO (barnstar): We should cache all reachable results so we can try alteratives if we
|
||||
// can't get the conn up and running later but signal early if we're racing.
|
||||
if opts.race && r.reachable {
|
||||
return []inetReachability{r}, nil
|
||||
}
|
||||
// .. otherwise, collect all results including the unreachable ones.
|
||||
out = append(out, r)
|
||||
received++
|
||||
case <-timeout:
|
||||
return out, fmt.Errorf("netns: probe timed out after %v; received %d/%d results", timeout, received, len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// For testing
|
||||
type reachabilityHookFn func(iface *net.Interface, hpn HostPortNetwork) error
|
||||
|
||||
var reachabilityHook reachabilityHookFn = reachabilityCheck
|
||||
|
||||
func reachabilityCheck(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
// Per-probe timeout.
|
||||
dialCtx, dialCancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
|
||||
defer dialCancel()
|
||||
|
||||
d := net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
// (barnstar) TODO: The bind step here is still platform specific
|
||||
bindFn := getBindFn(network, address)
|
||||
return bindFn(c, uint32(iface.Index))
|
||||
},
|
||||
}
|
||||
|
||||
dst := net.JoinHostPort(hpn.Host, hpn.Port)
|
||||
conn, err := d.DialContext(dialCtx, hpn.Network, dst)
|
||||
if err == nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Pre-filter for interfaces. Platform-specific code can provide a filter
|
||||
// to exclude certain interfaces from consideration. For example, on Darwin,
|
||||
// we exclude "utun" interfaces and various other types which will never provie
|
||||
// have general internet connectivity.
|
||||
type interfaceFilter func(net.Interface) bool
|
||||
|
||||
func filterInPlace[T any](s []T, keep func(T) bool) []T {
|
||||
i := 0
|
||||
for _, v := range s {
|
||||
if keep(v) {
|
||||
s[i] = v
|
||||
i++
|
||||
}
|
||||
}
|
||||
return s[:i]
|
||||
}
|
||||
|
||||
var errUnspecifiedHost = errors.New("unspecified host")
|
||||
|
||||
func parseAddress(address string) (addr netip.Addr, err error) {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// error means the string didn't contain a port number, so use the string directly
|
||||
host = address
|
||||
}
|
||||
if host == "" {
|
||||
return addr, errUnspecifiedHost
|
||||
}
|
||||
|
||||
return netip.ParseAddr(host)
|
||||
}
|
||||
|
||||
// findInterfaceThatCanReach finds an interface that can reach the given host:port.
|
||||
// It uses the provided filterf to exclude certain interfaces, and the
|
||||
// sortf to prioritize certain interfaces. It returns the first interface that can reach
|
||||
// the destination.
|
||||
//
|
||||
// TODO (barnstar): What this does NOT do is provide a way to flag an interface as "bad" if
|
||||
// we can't get a connection up and running. Ideally we race for the first candidate, try
|
||||
// it for a partciular route, and if it fails, remove it from the route cache try a "different"
|
||||
// candidate. This requires the Dialer to be aware of this logic, and to be able to signal
|
||||
// back to the route cache that a given interface is "bad" for a given destination. We also
|
||||
// need to cache all of the candidates found during probing so we can try them again later some
|
||||
// related state.
|
||||
//
|
||||
// nil is returned if no interface can reach the destination.
|
||||
func findInterfaceThatCanReach(opts probeOpts) (iface *net.Interface, err error) {
|
||||
// Try to parse the host as an IP address for cache lookup
|
||||
addr, err := parseAddress(opts.hpn.Host)
|
||||
if err == nil && addr.IsValid() {
|
||||
// Check cache first
|
||||
if cached := opts.cache.lookupCachedRoute(addr); cached != nil {
|
||||
opts.logf("netns: using cached interface %v for %v", cached.Name, opts.hpn)
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
res, err := probeInterfacesReachability(opts)
|
||||
if err != nil {
|
||||
opts.logf("netns: ProbeInterfacesReachability error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res = filterInPlace(res, func(r inetReachability) bool { return r.reachable })
|
||||
if len(res) == 0 {
|
||||
opts.logf("netns: could not find interface on network %v to reach %q:%q on %q: %v", opts.hpn.Network, opts.hpn.Host, opts.hpn.Port, opts.hpn.Network, err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
candidatesNames := make([]string, 0, len(res))
|
||||
for _, r := range res {
|
||||
candidatesNames = append(candidatesNames, r.iface.Name)
|
||||
}
|
||||
opts.logf("netns: found candidate interfaces that can reach %v:%v on %v: %v", opts.hpn.Host, opts.hpn.Port, opts.hpn.Network, candidatesNames)
|
||||
iface = &res[0].iface
|
||||
|
||||
if defaultIfaceHintFn != nil {
|
||||
defIdx := defaultIfaceHintFn()
|
||||
for _, r := range res {
|
||||
if r.iface.Index == defIdx {
|
||||
opts.logf("netns: using default iface hint")
|
||||
iface = &r.iface
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
opts.logf("netns: returning interface %v at %v for %v:%v", iface.Name, iface.Index, opts.hpn.Host, opts.hpn.Port)
|
||||
|
||||
// Cache the result if we have a valid IP address
|
||||
if addr.IsValid() {
|
||||
opts.cache.setCachedRoute(addr, iface)
|
||||
}
|
||||
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
var ifaceHasV4AndGlobalV6Hook func(iface *net.Interface) bool
|
||||
|
||||
// ifaceHasV4AndGlobalV6 reports whether iface has at least one IPv4 address
|
||||
// and at least one IPv6 address that is not link-local.
|
||||
func ifaceHasV4OrGlobalV6(iface *net.Interface) bool {
|
||||
if ifaceHasV4AndGlobalV6Hook != nil {
|
||||
return ifaceHasV4AndGlobalV6Hook(iface)
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, a := range addrs {
|
||||
switch v := a.(type) {
|
||||
case *net.IPNet:
|
||||
if v.IP.IsGlobalUnicast() {
|
||||
return true
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var globalRouteCache *routeCache
|
||||
|
||||
// SetGlobalRouteCache sets the global route cache used by netns.
|
||||
// It also subscribes the route cache to network change events from
|
||||
// the provided event bus.
|
||||
func SetGlobalRouteCache(rc *routeCache, e *eventbus.Bus, logf logger.Logf) {
|
||||
globalRouteCache = rc
|
||||
globalRouteCache.subscribeToNetworkChanges(e, logf)
|
||||
}
|
||||
|
||||
func NewRouteCache() *routeCache {
|
||||
return &routeCache{
|
||||
v4: new(bart.Table[*net.Interface]),
|
||||
v6: new(bart.Table[*net.Interface]),
|
||||
}
|
||||
}
|
||||
|
||||
type routeCache struct {
|
||||
mu syncs.Mutex
|
||||
v4 *bart.Table[*net.Interface] // IPv4 routing table
|
||||
v6 *bart.Table[*net.Interface] // IPv6 routing table
|
||||
ec *eventbus.Client
|
||||
}
|
||||
|
||||
func (rc *routeCache) subscribeToNetworkChanges(eventBus *eventbus.Bus, logf logger.Logf) {
|
||||
rc.mu.Lock()
|
||||
defer rc.mu.Unlock()
|
||||
|
||||
if rc.ec != nil {
|
||||
rc.ec.Close()
|
||||
}
|
||||
|
||||
rc.ec = eventBus.Client("routeCache")
|
||||
eventbus.SubscribeFunc(rc.ec, func(cd netmon.ChangeDelta) {
|
||||
if cd.RebindLikelyRequired {
|
||||
logf("netns: routeCache: major clearing all cached routes due to network change: %v", cd)
|
||||
rc.ClearAllCachedRoutes()
|
||||
}
|
||||
})
|
||||
logf("netns: routeCache: subscribed to network change events")
|
||||
}
|
||||
|
||||
func (rc *routeCache) lookupCachedRoute(addr netip.Addr) *net.Interface {
|
||||
rc.mu.Lock()
|
||||
defer rc.mu.Unlock()
|
||||
|
||||
iface, ok := rc.tableForAddr(addr).Lookup(addr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return iface
|
||||
}
|
||||
|
||||
func (rc *routeCache) setCachedRoute(addr netip.Addr, iface *net.Interface) {
|
||||
prefix := netip.PrefixFrom(addr, addrBits(addr))
|
||||
rc.setCachedRoutePrefix(prefix, iface)
|
||||
}
|
||||
|
||||
func (rc *routeCache) setCachedRoutePrefix(prefix netip.Prefix, iface *net.Interface) {
|
||||
rc.mu.Lock()
|
||||
defer rc.mu.Unlock()
|
||||
addr := prefix.Addr()
|
||||
rc.tableForAddr(addr).Insert(prefix, iface)
|
||||
}
|
||||
|
||||
func (rc *routeCache) clearCachedRoutePrefix(prefix netip.Prefix) {
|
||||
rc.mu.Lock()
|
||||
defer rc.mu.Unlock()
|
||||
addr := prefix.Addr()
|
||||
rc.tableForAddr(addr).Delete(prefix)
|
||||
}
|
||||
|
||||
func (rc *routeCache) ClearCachedRoute(addr netip.Addr) {
|
||||
prefix := netip.PrefixFrom(addr, addrBits(addr))
|
||||
rc.clearCachedRoutePrefix(prefix)
|
||||
}
|
||||
|
||||
func (rc *routeCache) ClearAllCachedRoutes() {
|
||||
rc.mu.Lock()
|
||||
defer rc.mu.Unlock()
|
||||
|
||||
rc.v4 = new(bart.Table[*net.Interface])
|
||||
rc.v6 = new(bart.Table[*net.Interface])
|
||||
}
|
||||
|
||||
func addrBits(addr netip.Addr) int {
|
||||
if addr.Is6() {
|
||||
return 128
|
||||
}
|
||||
return 32
|
||||
}
|
||||
|
||||
func (rc *routeCache) tableForAddr(addr netip.Addr) *bart.Table[*net.Interface] {
|
||||
if addr.Is6() {
|
||||
return rc.v6
|
||||
}
|
||||
return rc.v4
|
||||
}
|
||||
@ -14,7 +14,11 @@
|
||||
package netns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -76,3 +80,738 @@ func TestIsLocalhost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalRouteCache(t *testing.T) {
|
||||
iface1 := &net.Interface{Index: 1, Name: "eth0"}
|
||||
iface2 := &net.Interface{Index: 2, Name: "eth1"}
|
||||
iface3 := &net.Interface{Index: 3, Name: "wlan0"}
|
||||
|
||||
t.Run("insert and lookup IPv4", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
addr := netip.MustParseAddr("10.0.1.5")
|
||||
routeCache.setCachedRoute(addr, iface1)
|
||||
|
||||
got := routeCache.lookupCachedRoute(addr)
|
||||
if got != iface1 {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("insert and lookup IPv6", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
addr := netip.MustParseAddr("2001:db8::1")
|
||||
routeCache.setCachedRoute(addr, iface2)
|
||||
|
||||
got := routeCache.lookupCachedRoute(addr)
|
||||
if got != iface2 {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, iface2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("lookup non-existent", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
got := routeCache.lookupCachedRoute(addr)
|
||||
if got != nil {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want nil", addr, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("longest prefix match IPv4", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
// Insert broader prefix
|
||||
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
|
||||
routeCache.setCachedRoutePrefix(prefix1, iface1)
|
||||
|
||||
// Insert more specific prefix
|
||||
prefix2 := netip.MustParsePrefix("10.0.1.0/24")
|
||||
routeCache.setCachedRoutePrefix(prefix2, iface2)
|
||||
|
||||
// Insert even more specific prefix
|
||||
prefix3 := netip.MustParsePrefix("10.0.1.128/25")
|
||||
routeCache.setCachedRoutePrefix(prefix3, iface3)
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
want *net.Interface
|
||||
}{
|
||||
{"10.0.0.1", iface1}, // matches 10.0.0.0/8
|
||||
{"10.0.1.1", iface2}, // matches 10.0.1.0/24
|
||||
{"10.0.1.129", iface3}, // matches 10.0.1.128/25
|
||||
{"10.0.1.127", iface2}, // matches 10.0.1.0/24 (not /25)
|
||||
{"10.0.2.1", iface1}, // matches 10.0.0.0/8
|
||||
{"192.168.1.1", nil}, // no match
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
got := routeCache.lookupCachedRoute(addr)
|
||||
if got != tt.want {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("longest prefix match IPv6", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
// Insert broader prefix
|
||||
prefix1 := netip.MustParsePrefix("2001:db8::/32")
|
||||
routeCache.setCachedRoutePrefix(prefix1, iface1)
|
||||
|
||||
// Insert more specific prefix
|
||||
prefix2 := netip.MustParsePrefix("2001:db8:1::/48")
|
||||
routeCache.setCachedRoutePrefix(prefix2, iface2)
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
want *net.Interface
|
||||
}{
|
||||
{"2001:db8::1", iface1}, // matches 2001:db8::/32
|
||||
{"2001:db8:1::1", iface2}, // matches 2001:db8:1::/48
|
||||
{"2001:db8:2::1", iface1}, // matches 2001:db8::/32
|
||||
{"2001:db9::1", nil}, // no match
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
got := routeCache.lookupCachedRoute(addr)
|
||||
if got != tt.want {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clear cached route by address", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
addr := netip.MustParseAddr("10.0.1.5")
|
||||
routeCache.setCachedRoute(addr, iface1)
|
||||
|
||||
// Verify it's there
|
||||
if got := routeCache.lookupCachedRoute(addr); got != iface1 {
|
||||
t.Errorf("before clear: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
|
||||
}
|
||||
|
||||
// Clear it
|
||||
routeCache.ClearCachedRoute(addr)
|
||||
|
||||
// Verify it's gone
|
||||
if got := routeCache.lookupCachedRoute(addr); got != nil {
|
||||
t.Errorf("after clear: lookupCachedRoute(%v) = %v, want nil", addr, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clear cached route by prefix", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
prefix := netip.MustParsePrefix("10.0.1.0/24")
|
||||
routeCache.setCachedRoutePrefix(prefix, iface1)
|
||||
|
||||
// Verify it's there
|
||||
addr := netip.MustParseAddr("10.0.1.5")
|
||||
if got := routeCache.lookupCachedRoute(addr); got != iface1 {
|
||||
t.Errorf("before clear: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
|
||||
}
|
||||
|
||||
// Clear it
|
||||
routeCache.clearCachedRoutePrefix(prefix)
|
||||
|
||||
// Verify it's gone
|
||||
if got := routeCache.lookupCachedRoute(addr); got != nil {
|
||||
t.Errorf("after clear: lookupCachedRoute(%v) = %v, want nil", addr, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clear specific prefix preserves other prefixes", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
|
||||
prefix2 := netip.MustParsePrefix("192.168.0.0/16")
|
||||
routeCache.setCachedRoutePrefix(prefix1, iface1)
|
||||
routeCache.setCachedRoutePrefix(prefix2, iface2)
|
||||
|
||||
// Clear only prefix1
|
||||
routeCache.clearCachedRoutePrefix(prefix1)
|
||||
|
||||
// Verify prefix1 is gone
|
||||
addr1 := netip.MustParseAddr("10.0.1.5")
|
||||
if got := routeCache.lookupCachedRoute(addr1); got != nil {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want nil", addr1, got)
|
||||
}
|
||||
|
||||
// Verify prefix2 is still there
|
||||
addr2 := netip.MustParseAddr("192.168.1.1")
|
||||
if got := routeCache.lookupCachedRoute(addr2); got != iface2 {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr2, got, iface2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clear all cached routes", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
// Insert multiple routes
|
||||
addr1 := netip.MustParseAddr("10.0.1.5")
|
||||
addr2 := netip.MustParseAddr("192.168.1.1")
|
||||
addr3 := netip.MustParseAddr("2001:db8::1")
|
||||
routeCache.setCachedRoute(addr1, iface1)
|
||||
routeCache.setCachedRoute(addr2, iface2)
|
||||
routeCache.setCachedRoute(addr3, iface3)
|
||||
|
||||
// Clear all
|
||||
routeCache.ClearAllCachedRoutes()
|
||||
|
||||
// Verify all are gone
|
||||
if got := routeCache.lookupCachedRoute(addr1); got != nil {
|
||||
t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr1, got)
|
||||
}
|
||||
if got := routeCache.lookupCachedRoute(addr2); got != nil {
|
||||
t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr2, got)
|
||||
}
|
||||
if got := routeCache.lookupCachedRoute(addr3); got != nil {
|
||||
t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr3, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("overwrite existing route", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
addr := netip.MustParseAddr("10.0.1.5")
|
||||
routeCache.setCachedRoute(addr, iface1)
|
||||
|
||||
// Verify initial value
|
||||
if got := routeCache.lookupCachedRoute(addr); got != iface1 {
|
||||
t.Errorf("initial: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
|
||||
}
|
||||
|
||||
// Overwrite with different interface
|
||||
routeCache.setCachedRoute(addr, iface2)
|
||||
|
||||
// Verify new value
|
||||
if got := routeCache.lookupCachedRoute(addr); got != iface2 {
|
||||
t.Errorf("after overwrite: lookupCachedRoute(%v) = %v, want %v", addr, got, iface2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 are separate", func(t *testing.T) {
|
||||
routeCache := NewRouteCache()
|
||||
|
||||
addr4 := netip.MustParseAddr("10.0.1.5")
|
||||
addr6 := netip.MustParseAddr("2001:db8::1")
|
||||
|
||||
routeCache.setCachedRoute(addr4, iface1)
|
||||
routeCache.setCachedRoute(addr6, iface2)
|
||||
|
||||
// Verify both are stored independently
|
||||
if got := routeCache.lookupCachedRoute(addr4); got != iface1 {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr4, got, iface1)
|
||||
}
|
||||
if got := routeCache.lookupCachedRoute(addr6); got != iface2 {
|
||||
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2)
|
||||
}
|
||||
|
||||
// Clear IPv4, verify IPv6 remains
|
||||
routeCache.ClearCachedRoute(addr4)
|
||||
if got := routeCache.lookupCachedRoute(addr4); got != nil {
|
||||
t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want nil", addr4, got)
|
||||
}
|
||||
if got := routeCache.lookupCachedRoute(addr6); got != iface2 {
|
||||
t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func hookInterfaces(t *testing.T, ifaces []net.Interface) {
|
||||
interfacesHook = func() ([]net.Interface, error) {
|
||||
return ifaces, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
interfacesHook = net.Interfaces
|
||||
})
|
||||
}
|
||||
|
||||
func hookDefaultInterfaces(t *testing.T) {
|
||||
hookInterfaces(t, allTestIfs)
|
||||
}
|
||||
|
||||
var (
|
||||
iface1 net.Interface = net.Interface{
|
||||
Index: 1,
|
||||
MTU: 1500,
|
||||
Name: "eth0",
|
||||
HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
|
||||
Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning,
|
||||
}
|
||||
iface2 net.Interface = net.Interface{
|
||||
Index: 2,
|
||||
MTU: 1500,
|
||||
Name: "wlan0",
|
||||
HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x66},
|
||||
Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning,
|
||||
}
|
||||
iface3 net.Interface = net.Interface{
|
||||
Index: 3,
|
||||
MTU: 1500,
|
||||
Name: "eth1",
|
||||
HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x77},
|
||||
Flags: net.FlagBroadcast | net.FlagMulticast,
|
||||
}
|
||||
allTestIfs = []net.Interface{iface1, iface2, iface3}
|
||||
)
|
||||
|
||||
func TestFindInterfaceThatCanReach(t *testing.T) {
|
||||
origReachabilityHook := reachabilityHook
|
||||
t.Cleanup(func() {
|
||||
ifaceHasV4AndGlobalV6Hook = nil
|
||||
reachabilityHook = origReachabilityHook
|
||||
})
|
||||
|
||||
ifaceHasV4AndGlobalV6Hook = func(iface *net.Interface) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
t.Run("uses route cache on hit", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// Pre-populate cache
|
||||
addr := netip.MustParseAddr("8.8.8.8")
|
||||
cache.setCachedRoute(addr, &iface2)
|
||||
|
||||
// Hook should never be called when cache hits
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
t.Error("reachabilityHookFn should not be called when cache hits")
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Name != "wlan0" {
|
||||
t.Errorf("expected wlan0 from cache, got %s", result.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("populates cache on miss", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// All interfaces succeed
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "1.1.1.1", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Check cache was populated
|
||||
addr := netip.MustParseAddr("1.1.1.1")
|
||||
cached := cache.lookupCachedRoute(addr)
|
||||
if cached == nil {
|
||||
t.Error("expected cache to be populated")
|
||||
} else if cached.Name != result.Name {
|
||||
t.Errorf("cached interface %s != result interface %s", cached.Name, result.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when no interface reachable", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// All interfaces fail
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
return errors.New("unreachable")
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "192.0.2.1", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Logf("expected error: %v", err)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result when unreachable, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cache respects longest prefix match", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// Cache 10.0.0.0/8 -> eth0
|
||||
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
|
||||
cache.setCachedRoutePrefix(prefix1, &iface1)
|
||||
|
||||
// Cache 10.0.1.0/24 -> wlan0
|
||||
prefix2 := netip.MustParsePrefix("10.0.1.0/24")
|
||||
cache.setCachedRoutePrefix(prefix2, &iface2)
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
t.Error("should use cache, not probe")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test 10.0.1.5 -> should match more specific /24
|
||||
opts1 := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "10.0.1.5", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result1, _ := findInterfaceThatCanReach(opts1)
|
||||
if result1 == nil || result1.Name != "wlan0" {
|
||||
t.Errorf("expected wlan0 for 10.0.1.5, got %v", result1)
|
||||
}
|
||||
|
||||
// Test 10.0.2.5 -> should match broader /8
|
||||
opts2 := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "10.0.2.5", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result2, _ := findInterfaceThatCanReach(opts2)
|
||||
if result2 == nil || result2.Name != "eth0" {
|
||||
t.Errorf("expected eth0 for 10.0.2.5, got %v", result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("race mode returns first reachable", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// eth0 (iface1) responds quickly
|
||||
// wlan0 (iface2) responds slowly
|
||||
// eth1 (iface3) responds slowly
|
||||
// Channels to control when each probe completes
|
||||
wlan0Done := make(chan struct{})
|
||||
eth1Done := make(chan struct{})
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
switch iface.Index {
|
||||
case iface1.Index: // eth0 - returns immediately
|
||||
return nil
|
||||
case iface2.Index: // wlan0 - waits for signal
|
||||
<-wlan0Done
|
||||
return nil
|
||||
case iface3.Index: // eth1 - waits for signal
|
||||
<-eth1Done
|
||||
return nil
|
||||
}
|
||||
return errors.New("unknown interface")
|
||||
}
|
||||
defer func() {
|
||||
// Now signal the slower interfaces to complete
|
||||
close(wlan0Done)
|
||||
close(eth1Done)
|
||||
}()
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
|
||||
race: true,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result in race mode")
|
||||
}
|
||||
|
||||
// Should return quickly without waiting for all probes
|
||||
t.Logf("race mode returned interface: %s", result.Name)
|
||||
})
|
||||
|
||||
t.Run("filterf excludes interfaces", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
probeCount := atomic.Int32{}
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
probeCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
filterf: func(iface net.Interface) bool {
|
||||
// Exclude wlan0 and eth1
|
||||
return iface.Name != "wlan0" && iface.Name != "eth1"
|
||||
},
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
// Should only probe filtered interfaces
|
||||
if probeCount.Load() > 1 {
|
||||
t.Logf("probed %d interfaces after filtering", probeCount.Load())
|
||||
}
|
||||
|
||||
if result != nil && (result.Name == "wlan0" || result.Name == "eth1") {
|
||||
t.Errorf("filterf should have excluded %s", result.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles hostname instead of IP", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use a hostname that can't be parsed as an IP
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "example.com", Port: "443", Network: "tcp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Cache should not be used for hostnames
|
||||
addr, parseErr := netip.ParseAddr("example.com")
|
||||
if parseErr == nil && addr.IsValid() {
|
||||
t.Error("example.com should not parse as valid IP")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("default interface hint is respected", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// All interfaces are reachable
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set hint to prefer iface2 (index 2)
|
||||
origHintFn := defaultIfaceHintFn
|
||||
defer func() { defaultIfaceHintFn = origHintFn }()
|
||||
|
||||
defaultIfaceHintFn = func() int {
|
||||
return 2 // iface2 / wlan0
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "1.1.1.1", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Index != 2 {
|
||||
t.Errorf("expected default hint interface (index 2), got index %d (%s)", result.Index, result.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv6 address uses IPv6 cache table", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// Pre-populate IPv6 cache
|
||||
addr6 := netip.MustParseAddr("2001:4860:4860::8888")
|
||||
cache.setCachedRoute(addr6, &iface3)
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
t.Error("should use cache for IPv6")
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "2001:4860:4860::8888", Port: "53", Network: "udp6"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil || result.Name != "eth1" {
|
||||
t.Errorf("expected eth1 from IPv6 cache, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 caches are independent", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
addr4 := netip.MustParseAddr("8.8.8.8")
|
||||
addr6 := netip.MustParseAddr("2001:4860:4860::8888")
|
||||
|
||||
cache.setCachedRoute(addr4, &iface1)
|
||||
cache.setCachedRoute(addr6, &iface2)
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
t.Error("should use cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test IPv4
|
||||
opts4 := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
result4, _ := findInterfaceThatCanReach(opts4)
|
||||
if result4 == nil || result4.Name != "eth0" {
|
||||
t.Errorf("IPv4: expected eth0, got %v", result4)
|
||||
}
|
||||
|
||||
// Test IPv6
|
||||
opts6 := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "2001:4860:4860::8888", Port: "53", Network: "udp6"},
|
||||
cache: cache,
|
||||
}
|
||||
result6, _ := findInterfaceThatCanReach(opts6)
|
||||
if result6 == nil || result6.Name != "wlan0" {
|
||||
t.Errorf("IPv6: expected wlan0, got %v", result6)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty host returns error", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: "", Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, err := findInterfaceThatCanReach(opts)
|
||||
|
||||
// Should handle empty host gracefully
|
||||
if err == nil && result != nil {
|
||||
t.Logf("handled empty host, returned %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("caches subnet prefix correctly", func(t *testing.T) {
|
||||
cache := NewRouteCache()
|
||||
hookDefaultInterfaces(t)
|
||||
|
||||
// Manually cache a /16 subnet
|
||||
prefix := netip.MustParsePrefix("192.168.0.0/16")
|
||||
cache.setCachedRoutePrefix(prefix, &iface1)
|
||||
|
||||
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
|
||||
t.Error("should use cached subnet")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test various IPs in the subnet
|
||||
testIPs := []string{
|
||||
"192.168.0.1",
|
||||
"192.168.1.1",
|
||||
"192.168.255.254",
|
||||
}
|
||||
|
||||
for _, ip := range testIPs {
|
||||
opts := probeOpts{
|
||||
logf: t.Logf,
|
||||
hpn: HostPortNetwork{Host: ip, Port: "53", Network: "udp"},
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
result, _ := findInterfaceThatCanReach(opts)
|
||||
if result == nil || result.Name != "eth0" {
|
||||
t.Errorf("IP %s: expected eth0 from cached subnet, got %v", ip, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TODO (barnstar): Working, but the sleep is egregious. How to test async eventbus properly?
|
||||
// func TestRouteCacheEventBus(t *testing.T) {
|
||||
// t.Run("insert and lookup IPv4", func(t *testing.T) {
|
||||
// rc := NewRouteCache()
|
||||
// bus := eventbus.New()
|
||||
// b := bus.Client("netns_test")
|
||||
// t.Cleanup(func() {
|
||||
// b.Close()
|
||||
// })
|
||||
|
||||
// route := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// // Example of publishing a route cache clear event
|
||||
// publisher := eventbus.Publish[netmon.ChangeDelta](b)
|
||||
// SetGlobalRouteCache(rc, bus, t.Logf)
|
||||
// rc.setCachedRoute(route, &net.Interface{Index: 1, Name: "eth0"})
|
||||
// ifBeforeEvent := rc.lookupCachedRoute(route)
|
||||
// if ifBeforeEvent == nil || ifBeforeEvent.Name != "eth0" {
|
||||
// t.Fatalf("expected cached route before event, got %v", ifBeforeEvent)
|
||||
// }
|
||||
|
||||
// publisher.Publish(netmon.ChangeDelta{RebindLikelyRequired: true})
|
||||
// time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// ifAfterEvent := rc.lookupCachedRoute(route)
|
||||
// if ifAfterEvent != nil {
|
||||
// t.Fatalf("expected cached route to be cleared after event, got %v", ifAfterEvent)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
@ -33,6 +33,7 @@ import (
|
||||
"tailscale.com/net/dns/resolver"
|
||||
"tailscale.com/net/ipset"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/sockstats"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@ -391,6 +392,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
|
||||
// TODO: there's probably a better place for this
|
||||
sockstats.SetNetMon(e.netMon)
|
||||
netns.SetGlobalRouteCache(netns.NewRouteCache(), e.eventBus, logf)
|
||||
|
||||
logf("link state: %+v", e.netMon.InterfaceState())
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user