diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 4634f3ac2..5a82a9d11 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -137,7 +137,7 @@ func (e *extension) relayServerOrInit() (relayServer, error) { return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set") } var err error - e.server, _, err = udprelay.NewServer(e.logf, *e.port, nil) + e.server, err = udprelay.NewServer(e.logf, *e.port, nil) if err != nil { return nil, err } diff --git a/net/udprelay/server.go b/net/udprelay/server.go index e32f8917c..d2661e59f 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -57,7 +57,10 @@ type Server struct { bindLifetime time.Duration steadyStateLifetime time.Duration bus *eventbus.Bus - uc *net.UDPConn + uc4 *net.UDPConn // always non-nil + uc4Port uint16 // always nonzero + uc6 *net.UDPConn // may be nil if IPv6 bind fails during initialization + uc6Port uint16 // may be zero if IPv6 bind fails during initialization closeOnce sync.Once wg sync.WaitGroup closeCh chan struct{} @@ -278,13 +281,11 @@ func (e *serverEndpoint) isBound() bool { e.boundAddrPorts[1].IsValid() } -// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet -// supported. Port may be 0, and what ultimately gets bound is returned as -// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic -// discovery, which is useful to override in tests. -// -// TODO: IPv6 support -func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) { +// NewServer constructs a [Server] listening on port. If port is zero, then +// port selection is left up to the host networking stack. If +// len(overrideAddrs) > 0 these will be used in place of dynamic discovery, +// which is useful to override in tests. +func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) { s = &Server{ logf: logger.WithPrefix(logf, "relayserver"), disco: key.NewDisco(), @@ -306,30 +307,36 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve s.bus = bus netMon, err := netmon.New(s.bus, logf) if err != nil { - return nil, 0, err + return nil, err } s.netChecker = &netcheck.Client{ NetMon: netMon, Logf: logger.WithPrefix(logf, "relayserver: netcheck:"), SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) { - return s.uc.WriteToUDPAddrPort(b, addrPort) + if addrPort.Addr().Is4() { + return s.uc4.WriteToUDPAddrPort(b, addrPort) + } else if s.uc6 != nil { + return s.uc6.WriteToUDPAddrPort(b, addrPort) + } else { + return 0, errors.New("IPv6 socket is not bound") + } }, } - boundPort, err = s.listenOn(port) + err = s.listenOn(port) if err != nil { - return nil, 0, err + return nil, err } - s.wg.Add(1) - go s.packetReadLoop() - s.wg.Add(1) - go s.endpointGCLoop() if len(overrideAddrs) > 0 { addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs)) for _, addr := range overrideAddrs { if addr.IsValid() { - addrPorts.Add(netip.AddrPortFrom(addr, boundPort)) + if addr.Is4() { + addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port)) + } else if s.uc6 != nil { + addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port)) + } } } s.addrPorts = addrPorts.Slice() @@ -337,7 +344,17 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve s.wg.Add(1) go s.addrDiscoveryLoop() } - return s, boundPort, nil + + s.wg.Add(1) + go s.packetReadLoop(s.uc4) + if s.uc6 != nil { + s.wg.Add(1) + go s.packetReadLoop(s.uc6) + } + s.wg.Add(1) + go s.endpointGCLoop() + + return s, nil } func (s *Server) addrDiscoveryLoop() { @@ -351,14 +368,17 @@ func (s *Server) addrDiscoveryLoop() { addrPorts.Make() // get local addresses - localPort := s.uc.LocalAddr().(*net.UDPAddr).Port ips, _, err := netmon.LocalAddresses() if err != nil { return nil, err } for _, ip := range ips { if ip.IsValid() { - addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort))) + if ip.Is4() { + addrPorts.Add(netip.AddrPortFrom(ip, s.uc4Port)) + } else { + addrPorts.Add(netip.AddrPortFrom(ip, s.uc6Port)) + } } } @@ -413,24 +433,52 @@ func (s *Server) addrDiscoveryLoop() { } } -func (s *Server) listenOn(port int) (uint16, error) { - uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port}) - if err != nil { - return 0, err +// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if +// we manage to bind the IPv4 socket. +// +// The requested port may be zero, in which case port selection is left up to +// the host networking stack. We make no attempt to bind a consistent port +// across IPv4 and IPv6 if the requested port is zero. +// +// TODO: make these "re-bindable" in similar fashion to magicsock as a means to +// deal with EDR software closing them. http://go/corp/30118 +func (s *Server) listenOn(port int) error { + for _, network := range []string{"udp4", "udp6"} { + uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + if network == "udp4" { + return err + } else { + s.logf("ignoring IPv6 bind failure: %v", err) + break + } + } + // TODO: set IP_PKTINFO sockopt + _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) + if err != nil { + uc.Close() + if s.uc4 != nil { + s.uc4.Close() + } + return err + } + portUint, err := strconv.ParseUint(boundPortStr, 10, 16) + if err != nil { + uc.Close() + if s.uc4 != nil { + s.uc4.Close() + } + return err + } + if network == "udp4" { + s.uc4 = uc + s.uc4Port = uint16(portUint) + } else { + s.uc6 = uc + s.uc6Port = uint16(portUint) + } } - // TODO: set IP_PKTINFO sockopt - _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) - if err != nil { - s.uc.Close() - return 0, err - } - boundPort, err := strconv.ParseUint(boundPortStr, 10, 16) - if err != nil { - s.uc.Close() - return 0, err - } - s.uc = uc - return uint16(boundPort), nil + return nil } // Close closes the server. @@ -438,7 +486,10 @@ func (s *Server) Close() error { s.closeOnce.Do(func() { s.mu.Lock() defer s.mu.Unlock() - s.uc.Close() + s.uc4.Close() + if s.uc6 != nil { + s.uc6.Close() + } close(s.closeCh) s.wg.Wait() clear(s.byVNI) @@ -507,7 +558,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) { e.handlePacket(from, gh, b, uw, s.discoPublic) } -func (s *Server) packetReadLoop() { +func (s *Server) packetReadLoop(uc *net.UDPConn) { defer func() { s.wg.Done() s.Close() @@ -515,11 +566,11 @@ func (s *Server) packetReadLoop() { b := make([]byte, 1<<16-1) for { // TODO: extract laddr from IP_PKTINFO for use in reply - n, from, err := s.uc.ReadFromUDPAddrPort(b) + n, from, err := uc.ReadFromUDPAddrPort(b) if err != nil { return } - s.handlePacket(from, b[:n], s.uc) + s.handlePacket(from, b[:n], uc) } } diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index 3fcb9b8b1..8c0c5aff6 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -29,7 +29,7 @@ type testClient struct { func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient { rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} - uc, err := net.DialUDP("udp4", nil, rAddr) + uc, err := net.DialUDP("udp", nil, rAddr) if err != nil { t.Fatal(err) } @@ -180,85 +180,101 @@ func TestServer(t *testing.T) { discoA := key.NewDisco() discoB := key.NewDisco() - ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1") - - server, _, err := NewServer(t.Logf, 0, []netip.Addr{ipv4LoopbackAddr}) - if err != nil { - t.Fatal(err) - } - defer server.Close() - - endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) - if err != nil { - t.Fatal(err) - } - dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) - if err != nil { - t.Fatal(err) + cases := []struct { + name string + overrideAddrs []netip.Addr + }{ + { + name: "over ipv4", + overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + { + name: "over ipv6", + overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")}, + }, } - // We expect the same endpoint details pre-handshake. - if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { - t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) - } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + server, err := NewServer(t.Logf, 0, tt.overrideAddrs) + if err != nil { + t.Fatal(err) + } + defer server.Close() - if len(endpoint.AddrPorts) != 1 { - t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) - } - tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) - defer tcA.close() - tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) - defer tcB.close() + endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } - tcA.handshake(t) - tcB.handshake(t) + // We expect the same endpoint details pre-handshake. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } - dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public()) - if err != nil { - t.Fatal(err) - } - // We expect the same endpoint details post-handshake. - if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { - t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) - } + if len(endpoint.AddrPorts) != 1 { + t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) + } + tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) + defer tcA.close() + tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) + defer tcB.close() - txToB := []byte{1, 2, 3} - tcA.writeDataPkt(t, txToB) - rxFromA := tcB.readDataPkt(t) - if !bytes.Equal(txToB, rxFromA) { - t.Fatal("unexpected msg A->B") - } + tcA.handshake(t) + tcB.handshake(t) - txToA := []byte{4, 5, 6} - tcB.writeDataPkt(t, txToA) - rxFromB := tcA.readDataPkt(t) - if !bytes.Equal(txToA, rxFromB) { - t.Fatal("unexpected msg B->A") - } + dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + // We expect the same endpoint details post-handshake. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } - tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) - tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1 - defer tcAOnNewPort.close() + txToB := []byte{1, 2, 3} + tcA.writeDataPkt(t, txToB) + rxFromA := tcB.readDataPkt(t) + if !bytes.Equal(txToB, rxFromA) { + t.Fatal("unexpected msg A->B") + } - // Handshake client A on a new source IP:port, verify we receive packets on the new binding - tcAOnNewPort.handshake(t) - txToAOnNewPort := []byte{7, 8, 9} - tcB.writeDataPkt(t, txToAOnNewPort) - rxFromB = tcAOnNewPort.readDataPkt(t) - if !bytes.Equal(txToAOnNewPort, rxFromB) { - t.Fatal("unexpected msg B->A") - } + txToA := []byte{4, 5, 6} + tcB.writeDataPkt(t, txToA) + rxFromB := tcA.readDataPkt(t) + if !bytes.Equal(txToA, rxFromB) { + t.Fatal("unexpected msg B->A") + } - tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) - tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1 - defer tcBOnNewPort.close() + tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) + tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1 + defer tcAOnNewPort.close() - // Handshake client B on a new source IP:port, verify we receive packets on the new binding - tcBOnNewPort.handshake(t) - txToBOnNewPort := []byte{7, 8, 9} - tcAOnNewPort.writeDataPkt(t, txToBOnNewPort) - rxFromA = tcBOnNewPort.readDataPkt(t) - if !bytes.Equal(txToBOnNewPort, rxFromA) { - t.Fatal("unexpected msg A->B") + // Handshake client A on a new source IP:port, verify we receive packets on the new binding + tcAOnNewPort.handshake(t) + txToAOnNewPort := []byte{7, 8, 9} + tcB.writeDataPkt(t, txToAOnNewPort) + rxFromB = tcAOnNewPort.readDataPkt(t) + if !bytes.Equal(txToAOnNewPort, rxFromB) { + t.Fatal("unexpected msg B->A") + } + + tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) + tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1 + defer tcBOnNewPort.close() + + // Handshake client B on a new source IP:port, verify we receive packets on the new binding + tcBOnNewPort.handshake(t) + txToBOnNewPort := []byte{7, 8, 9} + tcAOnNewPort.writeDataPkt(t, txToBOnNewPort) + rxFromA = tcBOnNewPort.readDataPkt(t) + if !bytes.Equal(txToBOnNewPort, rxFromA) { + t.Fatal("unexpected msg A->B") + } + }) } }