feature/relayserver,net/udprelay: add IPv6 support (#16442)

Updates tailscale/corp#27502
Updates tailscale/corp#30043

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited 2025-07-02 20:38:39 -07:00 committed by GitHub
parent 77d19604f4
commit 3a4b439c62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 179 additions and 112 deletions

View File

@ -137,7 +137,7 @@ func (e *extension) relayServerOrInit() (relayServer, error) {
return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set") return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set")
} }
var err error var err error
e.server, _, err = udprelay.NewServer(e.logf, *e.port, nil) e.server, err = udprelay.NewServer(e.logf, *e.port, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -57,7 +57,10 @@ type Server struct {
bindLifetime time.Duration bindLifetime time.Duration
steadyStateLifetime time.Duration steadyStateLifetime time.Duration
bus *eventbus.Bus bus *eventbus.Bus
uc *net.UDPConn uc4 *net.UDPConn // always non-nil
uc4Port uint16 // always nonzero
uc6 *net.UDPConn // may be nil if IPv6 bind fails during initialization
uc6Port uint16 // may be zero if IPv6 bind fails during initialization
closeOnce sync.Once closeOnce sync.Once
wg sync.WaitGroup wg sync.WaitGroup
closeCh chan struct{} closeCh chan struct{}
@ -278,13 +281,11 @@ func (e *serverEndpoint) isBound() bool {
e.boundAddrPorts[1].IsValid() e.boundAddrPorts[1].IsValid()
} }
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet // NewServer constructs a [Server] listening on port. If port is zero, then
// supported. Port may be 0, and what ultimately gets bound is returned as // port selection is left up to the host networking stack. If
// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic // len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
// discovery, which is useful to override in tests. // which is useful to override in tests.
// func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
// TODO: IPv6 support
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) {
s = &Server{ s = &Server{
logf: logger.WithPrefix(logf, "relayserver"), logf: logger.WithPrefix(logf, "relayserver"),
disco: key.NewDisco(), disco: key.NewDisco(),
@ -306,30 +307,36 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
s.bus = bus s.bus = bus
netMon, err := netmon.New(s.bus, logf) netMon, err := netmon.New(s.bus, logf)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
s.netChecker = &netcheck.Client{ s.netChecker = &netcheck.Client{
NetMon: netMon, NetMon: netMon,
Logf: logger.WithPrefix(logf, "relayserver: netcheck:"), Logf: logger.WithPrefix(logf, "relayserver: netcheck:"),
SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) { SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) {
return s.uc.WriteToUDPAddrPort(b, addrPort) if addrPort.Addr().Is4() {
return s.uc4.WriteToUDPAddrPort(b, addrPort)
} else if s.uc6 != nil {
return s.uc6.WriteToUDPAddrPort(b, addrPort)
} else {
return 0, errors.New("IPv6 socket is not bound")
}
}, },
} }
boundPort, err = s.listenOn(port) err = s.listenOn(port)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
s.wg.Add(1)
go s.packetReadLoop()
s.wg.Add(1)
go s.endpointGCLoop()
if len(overrideAddrs) > 0 { if len(overrideAddrs) > 0 {
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs)) addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
for _, addr := range overrideAddrs { for _, addr := range overrideAddrs {
if addr.IsValid() { if addr.IsValid() {
addrPorts.Add(netip.AddrPortFrom(addr, boundPort)) if addr.Is4() {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
} else if s.uc6 != nil {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
}
} }
} }
s.addrPorts = addrPorts.Slice() s.addrPorts = addrPorts.Slice()
@ -337,7 +344,17 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
s.wg.Add(1) s.wg.Add(1)
go s.addrDiscoveryLoop() go s.addrDiscoveryLoop()
} }
return s, boundPort, nil
s.wg.Add(1)
go s.packetReadLoop(s.uc4)
if s.uc6 != nil {
s.wg.Add(1)
go s.packetReadLoop(s.uc6)
}
s.wg.Add(1)
go s.endpointGCLoop()
return s, nil
} }
func (s *Server) addrDiscoveryLoop() { func (s *Server) addrDiscoveryLoop() {
@ -351,14 +368,17 @@ func (s *Server) addrDiscoveryLoop() {
addrPorts.Make() addrPorts.Make()
// get local addresses // get local addresses
localPort := s.uc.LocalAddr().(*net.UDPAddr).Port
ips, _, err := netmon.LocalAddresses() ips, _, err := netmon.LocalAddresses()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, ip := range ips { for _, ip := range ips {
if ip.IsValid() { if ip.IsValid() {
addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort))) if ip.Is4() {
addrPorts.Add(netip.AddrPortFrom(ip, s.uc4Port))
} else {
addrPorts.Add(netip.AddrPortFrom(ip, s.uc6Port))
}
} }
} }
@ -413,24 +433,52 @@ func (s *Server) addrDiscoveryLoop() {
} }
} }
func (s *Server) listenOn(port int) (uint16, error) { // listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if
uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port}) // we manage to bind the IPv4 socket.
//
// The requested port may be zero, in which case port selection is left up to
// the host networking stack. We make no attempt to bind a consistent port
// across IPv4 and IPv6 if the requested port is zero.
//
// TODO: make these "re-bindable" in similar fashion to magicsock as a means to
// deal with EDR software closing them. http://go/corp/30118
func (s *Server) listenOn(port int) error {
for _, network := range []string{"udp4", "udp6"} {
uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil { if err != nil {
return 0, err if network == "udp4" {
return err
} else {
s.logf("ignoring IPv6 bind failure: %v", err)
break
}
} }
// TODO: set IP_PKTINFO sockopt // TODO: set IP_PKTINFO sockopt
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
if err != nil { if err != nil {
s.uc.Close() uc.Close()
return 0, err if s.uc4 != nil {
s.uc4.Close()
} }
boundPort, err := strconv.ParseUint(boundPortStr, 10, 16) return err
}
portUint, err := strconv.ParseUint(boundPortStr, 10, 16)
if err != nil { if err != nil {
s.uc.Close() uc.Close()
return 0, err if s.uc4 != nil {
s.uc4.Close()
} }
s.uc = uc return err
return uint16(boundPort), nil }
if network == "udp4" {
s.uc4 = uc
s.uc4Port = uint16(portUint)
} else {
s.uc6 = uc
s.uc6Port = uint16(portUint)
}
}
return nil
} }
// Close closes the server. // Close closes the server.
@ -438,7 +486,10 @@ func (s *Server) Close() error {
s.closeOnce.Do(func() { s.closeOnce.Do(func() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.uc.Close() s.uc4.Close()
if s.uc6 != nil {
s.uc6.Close()
}
close(s.closeCh) close(s.closeCh)
s.wg.Wait() s.wg.Wait()
clear(s.byVNI) clear(s.byVNI)
@ -507,7 +558,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
e.handlePacket(from, gh, b, uw, s.discoPublic) e.handlePacket(from, gh, b, uw, s.discoPublic)
} }
func (s *Server) packetReadLoop() { func (s *Server) packetReadLoop(uc *net.UDPConn) {
defer func() { defer func() {
s.wg.Done() s.wg.Done()
s.Close() s.Close()
@ -515,11 +566,11 @@ func (s *Server) packetReadLoop() {
b := make([]byte, 1<<16-1) b := make([]byte, 1<<16-1)
for { for {
// TODO: extract laddr from IP_PKTINFO for use in reply // TODO: extract laddr from IP_PKTINFO for use in reply
n, from, err := s.uc.ReadFromUDPAddrPort(b) n, from, err := uc.ReadFromUDPAddrPort(b)
if err != nil { if err != nil {
return return
} }
s.handlePacket(from, b[:n], s.uc) s.handlePacket(from, b[:n], uc)
} }
} }

View File

@ -29,7 +29,7 @@ type testClient struct {
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient { func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient {
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
uc, err := net.DialUDP("udp4", nil, rAddr) uc, err := net.DialUDP("udp", nil, rAddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -180,9 +180,23 @@ func TestServer(t *testing.T) {
discoA := key.NewDisco() discoA := key.NewDisco()
discoB := key.NewDisco() discoB := key.NewDisco()
ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1") cases := []struct {
name string
overrideAddrs []netip.Addr
}{
{
name: "over ipv4",
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
{
name: "over ipv6",
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")},
},
}
server, _, err := NewServer(t.Logf, 0, []netip.Addr{ipv4LoopbackAddr}) for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
server, err := NewServer(t.Logf, 0, tt.overrideAddrs)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -261,4 +275,6 @@ func TestServer(t *testing.T) {
if !bytes.Equal(txToBOnNewPort, rxFromA) { if !bytes.Equal(txToBOnNewPort, rxFromA) {
t.Fatal("unexpected msg A->B") t.Fatal("unexpected msg A->B")
} }
})
}
} }