diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index b90a62345..24304e8ec 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -6,9 +6,13 @@ package relayserver import ( + "log" + "net/netip" + "strings" "sync" "tailscale.com/disco" + "tailscale.com/envknob" "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnext" @@ -115,6 +119,26 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV e.handleBusLifetimeLocked() } +// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. It +// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is +// not a stable interface, and is subject to change. +var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) { + all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS") + const max = 3 + remain := all + for remain != "" && len(ret) < max { + var s string + s, remain, _ = strings.Cut(remain, ",") + addr, err := netip.ParseAddr(s) + if err != nil { + log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err) + continue + } + ret = append(ret, addr) + } + return +}) + func (e *extension) consumeEventbusTopics(port int) { defer close(e.busDoneCh) @@ -140,7 +164,7 @@ func (e *extension) consumeEventbusTopics(port int) { case req := <-reqSub.Events(): if rs == nil { var err error - rs, err = udprelay.NewServer(e.logf, port, nil) + rs, err = udprelay.NewServer(e.logf, port, overrideAddrs()) if err != nil { e.logf("error initializing server: %v", err) continue diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index 987bb569a..b28ebaba1 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -480,11 +480,13 @@ func (lc *LogCatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) { // TestEnv contains the test environment (set of servers) used by one // or more nodes. type TestEnv struct { - t testing.TB - tunMode bool - cli string - daemon string - loopbackPort *int + t testing.TB + tunMode bool + cli string + daemon string + loopbackPort *int + neverDirectUDP bool + relayServerUseLoopback bool LogCatcher *LogCatcher LogCatcherServer *httptest.Server @@ -842,6 +844,12 @@ func (n *TestNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { if n.env.loopbackPort != nil { cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) } + if n.env.neverDirectUDP { + cmd.Env = append(cmd.Env, "TS_DEBUG_NEVER_DIRECT_UDP=1") + } + if n.env.relayServerUseLoopback { + cmd.Env = append(cmd.Env, "TS_DEBUG_RELAY_SERVER_ADDRS=::1,127.0.0.1") + } if version.IsRace() { cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") } diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index de464108c..b282adcf8 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -44,6 +44,7 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/ptr" "tailscale.com/util/must" + "tailscale.com/util/set" ) func TestMain(m *testing.M) { @@ -1530,3 +1531,105 @@ func TestEncryptStateMigration(t *testing.T) { runNode(t, wantPlaintextStateKeys) }) } + +// TestPeerRelayPing creates three nodes with one acting as a peer relay. +// The test succeeds when "tailscale ping" flows through the peer +// relay between all 3 nodes. +func TestPeerRelayPing(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + + env := NewTestEnv(t, ConfigureControl(func(server *testcontrol.Server) { + server.PeerRelayGrants = true + })) + env.neverDirectUDP = true + env.relayServerUseLoopback = true + + n1 := NewTestNode(t, env) + n2 := NewTestNode(t, env) + peerRelay := NewTestNode(t, env) + + allNodes := []*TestNode{n1, n2, peerRelay} + wantPeerRelayServers := make(set.Set[string]) + for _, n := range allNodes { + n.StartDaemon() + n.AwaitResponding() + n.MustUp() + wantPeerRelayServers.Add(n.AwaitIP4().String()) + n.AwaitRunning() + } + + if err := peerRelay.Tailscale("set", "--relay-server-port=0").Run(); err != nil { + t.Fatal(err) + } + + errCh := make(chan error) + for _, a := range allNodes { + go func() { + err := tstest.WaitFor(time.Second*5, func() error { + out, err := a.Tailscale("debug", "peer-relay-servers").CombinedOutput() + if err != nil { + return fmt.Errorf("debug peer-relay-servers failed: %v", err) + } + servers := make([]string, 0) + err = json.Unmarshal(out, &servers) + if err != nil { + return fmt.Errorf("failed to unmarshal debug peer-relay-servers: %v", err) + } + gotPeerRelayServers := make(set.Set[string]) + for _, server := range servers { + gotPeerRelayServers.Add(server) + } + if !gotPeerRelayServers.Equal(wantPeerRelayServers) { + return fmt.Errorf("got peer relay servers: %v want: %v", gotPeerRelayServers, wantPeerRelayServers) + } + return nil + }) + errCh <- err + }() + } + for range allNodes { + err := <-errCh + if err != nil { + t.Fatal(err) + } + } + + pingPairs := make([][2]*TestNode, 0) + for _, a := range allNodes { + for _, z := range allNodes { + if a == z { + continue + } + pingPairs = append(pingPairs, [2]*TestNode{a, z}) + } + } + for _, pair := range pingPairs { + go func() { + a := pair[0] + z := pair[1] + err := tstest.WaitFor(time.Second*10, func() error { + remoteKey := z.MustStatus().Self.PublicKey + if err := a.Tailscale("ping", "--until-direct=false", "--c=1", "--timeout=1s", z.AwaitIP4().String()).Run(); err != nil { + return err + } + remotePeer, ok := a.MustStatus().Peer[remoteKey] + if !ok { + return fmt.Errorf("%v->%v remote peer not found", a.MustStatus().Self.ID, z.MustStatus().Self.ID) + } + if len(remotePeer.PeerRelay) == 0 { + return fmt.Errorf("%v->%v not using peer relay, curAddr=%v relay=%v", a.MustStatus().Self.ID, z.MustStatus().Self.ID, remotePeer.CurAddr, remotePeer.Relay) + } + t.Logf("%v->%v using peer relay addr: %v", a.MustStatus().Self.ID, z.MustStatus().Self.ID, remotePeer.PeerRelay) + return nil + }) + errCh <- err + }() + } + for range pingPairs { + err := <-errCh + if err != nil { + t.Fatal(err) + } + } +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 2fbf37de9..66d868aca 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -55,6 +55,10 @@ type Server struct { MagicDNSDomain string HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + // PeerRelayGrants, if true, inserts relay capabilities into the wildcard + // grants rules. + PeerRelayGrants bool + // AllNodesSameUser, if true, makes all created nodes // belong to the same user. AllNodesSameUser bool @@ -931,14 +935,21 @@ var keepAliveMsg = &struct { KeepAlive: true, } -func packetFilterWithIngressCaps() []tailcfg.FilterRule { +func packetFilterWithIngress(addRelayCaps bool) []tailcfg.FilterRule { out := slices.Clone(tailcfg.FilterAllowAll) + caps := []tailcfg.PeerCapability{ + tailcfg.PeerCapabilityIngress, + } + if addRelayCaps { + caps = append(caps, tailcfg.PeerCapabilityRelay) + caps = append(caps, tailcfg.PeerCapabilityRelayTarget) + } out = append(out, tailcfg.FilterRule{ SrcIPs: []string{"*"}, CapGrant: []tailcfg.CapGrant{ { Dsts: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - Caps: []tailcfg.PeerCapability{tailcfg.PeerCapabilityIngress}, + Caps: caps, }, }, }) @@ -977,7 +988,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, DERPMap: s.DERPMap, Domain: domain, CollectServices: "true", - PacketFilter: packetFilterWithIngressCaps(), + PacketFilter: packetFilterWithIngress(s.PeerRelayGrants), DNSConfig: dns, ControlTime: &t, }