diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index aeee43646..41d239e3b 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -1141,83 +1141,91 @@ func TestListenService(t *testing.T) { // This ends up also testing the Service forwarding logic in // LocalBackend, but that's useful too. t.Run(tt.name, func(t *testing.T) { - ctx := t.Context() + // We run each test with and without a TUN device ([Server.Tun]). + // Note that this TUN device is distinct from TUN mode for Services. + doTest := func(t *testing.T, withTUNDevice bool) { + ctx := t.Context() - controlURL, control := startControl(t) - serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") - serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") + lt := setupTwoClientTest(t, withTUNDevice) + serviceHost := lt.s2 + serviceClient := lt.s1 + control := lt.control - const serviceName = tailcfg.ServiceName("svc:foo") - const serviceVIP = "100.11.22.33" + const serviceName = tailcfg.ServiceName("svc:foo") + const serviceVIP = "100.11.22.33" - // == Set up necessary state in our mock == + // == Set up necessary state in our mock == - // The Service host must have the 'service-host' capability, which - // is a mapping from the Service name to the Service VIP. - var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] - mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) - j := must.Get(json.Marshal(serviceHostCaps)) - cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() - mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) - control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) + // The Service host must have the 'service-host' capability, which + // is a mapping from the Service name to the Service VIP. + var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] + mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) + j := must.Get(json.Marshal(serviceHostCaps)) + cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() + mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) + control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) - // The Service host must be allowed to advertise the Service VIP. - control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ - netip.MustParsePrefix(serviceVIP + `/32`), - }) - - // The Service host must be a tagged node (any tag will do). - serviceHostNode := control.Node(serviceHost.lb.NodeKey()) - serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") - control.UpdateNode(serviceHostNode) - - // The service client must accept routes advertised by other nodes - // (RouteAll is equivalent to --accept-routes). - must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - RouteAllSet: true, - Prefs: ipn.Prefs{ - RouteAll: true, - }, - })) - - // Set up DNS for our Service. - control.AddDNSRecords(tailcfg.DNSRecord{ - Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, - Value: serviceVIP, - }) - - if tt.extraSetup != nil { - tt.extraSetup(t, control) - } - - // Force netmap updates to avoid race conditions. The nodes need to - // see our control updates before we can start the test. - must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey())) - must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey())) - netmapUpToDate := func(s *Server) bool { - nm := s.lb.NetMap() - return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { - return r.Value == serviceVIP + // The Service host must be allowed to advertise the Service VIP. + control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ + netip.MustParsePrefix(serviceVIP + `/32`), }) - } - for !netmapUpToDate(serviceClient) { - time.Sleep(10 * time.Millisecond) - } - for !netmapUpToDate(serviceHost) { - time.Sleep(10 * time.Millisecond) + + // The Service host must be a tagged node (any tag will do). + serviceHostNode := control.Node(serviceHost.lb.NodeKey()) + serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") + control.UpdateNode(serviceHostNode) + + // The service client must accept routes advertised by other nodes + // (RouteAll is equivalent to --accept-routes). + must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ + RouteAllSet: true, + Prefs: ipn.Prefs{ + RouteAll: true, + }, + })) + + // Set up DNS for our Service. + control.AddDNSRecords(tailcfg.DNSRecord{ + Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, + Value: serviceVIP, + }) + + if tt.extraSetup != nil { + tt.extraSetup(t, control) + } + + // Force netmap updates to avoid race conditions. The nodes need to + // see our control updates before we can start the test. + must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey())) + must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey())) + netmapUpToDate := func(s *Server) bool { + nm := s.lb.NetMap() + return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { + return r.Value == serviceVIP + }) + } + for !netmapUpToDate(serviceClient) { + time.Sleep(10 * time.Millisecond) + } + for !netmapUpToDate(serviceHost) { + time.Sleep(10 * time.Millisecond) + } + + // == Done setting up mock state == + + // Start the Service listeners. + listeners := make([]*ServiceListener, 0, len(tt.modes)) + for _, input := range tt.modes { + ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) + defer ln.Close() + listeners = append(listeners, ln) + } + + tt.run(t, listeners, serviceClient) } - // == Done setting up mock state == - - // Start the Service listeners. - listeners := make([]*ServiceListener, 0, len(tt.modes)) - for _, input := range tt.modes { - ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) - defer ln.Close() - listeners = append(listeners, ln) - } - - tt.run(t, listeners, serviceClient) + t.Run("TUN", func(t *testing.T) { doTest(t, true) }) + t.Run("netstack", func(t *testing.T) { doTest(t, false) }) }) } } @@ -1928,20 +1936,21 @@ func (t *chanTUN) BatchSize() int { return 1 } // listenTest provides common setup for listener and TUN tests. type listenTest struct { + control *testcontrol.Server s1, s2 *Server s1ip4, s1ip6 netip.Addr s2ip4, s2ip6 netip.Addr tun *chanTUN // nil for netstack mode } -// setupListenTest creates two tsnet servers for testing. +// setupTwoClientTest creates two tsnet servers for testing. // If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only. -func setupListenTest(t *testing.T, useTUN bool) *listenTest { +func setupTwoClientTest(t *testing.T, useTUN bool) *listenTest { t.Helper() tstest.Shard(t) tstest.ResourceCheck(t) ctx := t.Context() - controlURL, _ := startControl(t) + controlURL, control := startControl(t) s1, _, _ := startServer(t, ctx, controlURL, "s1") tmp := filepath.Join(t.TempDir(), "s2") @@ -1969,6 +1978,7 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest { if err != nil { t.Fatal(err) } + s2.lb.ConfigureCertsForTest(testCertRoot.getCert) s1ip4, s1ip6 := s1.TailscaleIPs() s2ip4 := s2status.TailscaleIPs[0] @@ -1981,13 +1991,14 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest { must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP)) return &listenTest{ - s1: s1, - s2: s2, - s1ip4: s1ip4, - s1ip6: s1ip6, - s2ip4: s2ip4, - s2ip6: s2ip6, - tun: tun, + control: control, + s1: s1, + s2: s2, + s1ip4: s1ip4, + s1ip6: s1ip6, + s2ip4: s2ip4, + s2ip6: s2ip6, + tun: tun, } } @@ -2016,7 +2027,7 @@ func echoUDP(pkt []byte) []byte { } func TestTUN(t *testing.T) { - tt := setupListenTest(t, true) + tt := setupTwoClientTest(t, true) go func() { for pkt := range tt.tun.Inbound { @@ -2059,7 +2070,7 @@ func TestTUN(t *testing.T) { // responses. This verifies that handleLocalPackets intercepts outbound traffic // to the service IP. func TestTUNDNS(t *testing.T) { - tt := setupListenTest(t, true) + tt := setupTwoClientTest(t, true) test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) { tt.tun.Outbound <- buildDNSQuery("s2", srcIP) @@ -2149,13 +2160,13 @@ func TestListenPacket(t *testing.T) { } t.Run("Netstack", func(t *testing.T) { - lt := setupListenTest(t, false) + lt := setupTwoClientTest(t, false) t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) }) t.Run("TUN", func(t *testing.T) { - lt := setupListenTest(t, true) + lt := setupTwoClientTest(t, true) t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) }) @@ -2221,13 +2232,13 @@ func TestListenTCP(t *testing.T) { } t.Run("Netstack", func(t *testing.T) { - lt := setupListenTest(t, false) + lt := setupTwoClientTest(t, false) t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) }) t.Run("TUN", func(t *testing.T) { - lt := setupListenTest(t, true) + lt := setupTwoClientTest(t, true) t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) }) @@ -2299,13 +2310,13 @@ func TestListenTCPDualStack(t *testing.T) { } t.Run("Netstack", func(t *testing.T) { - lt := setupListenTest(t, false) + lt := setupTwoClientTest(t, false) t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) }) t.Run("TUN", func(t *testing.T) { - lt := setupListenTest(t, true) + lt := setupTwoClientTest(t, true) t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) }) @@ -2372,13 +2383,13 @@ func TestDialTCP(t *testing.T) { } t.Run("Netstack", func(t *testing.T) { - lt := setupListenTest(t, false) + lt := setupTwoClientTest(t, false) t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) }) t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) }) }) t.Run("TUN", func(t *testing.T) { - lt := setupListenTest(t, true) + lt := setupTwoClientTest(t, true) var escapedTCPPackets atomic.Int32 var wg sync.WaitGroup @@ -2460,13 +2471,13 @@ func TestDialUDP(t *testing.T) { } t.Run("Netstack", func(t *testing.T) { - lt := setupListenTest(t, false) + lt := setupTwoClientTest(t, false) t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) }) t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) }) }) t.Run("TUN", func(t *testing.T) { - lt := setupListenTest(t, true) + lt := setupTwoClientTest(t, true) var escapedUDPPackets atomic.Int32 var wg sync.WaitGroup