diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index aeee43646..671120f99 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -861,6 +861,370 @@ func TestFunnelClose(t *testing.T) { }) } +func TestListenServiceAlongsideTUN(t *testing.T) { + // First test an error case which doesn't require all of the fancy setup. + t.Run("untagged_node_error", func(t *testing.T) { + lt := setupListenTest(t, true) + serviceHost := lt.s2 + + ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080}) + if ln != nil { + ln.Close() + } + if !errors.Is(err, ErrUntaggedServiceHost) { + t.Fatalf("expected %v, got %v", ErrUntaggedServiceHost, err) + } + }) + + // Now on to the fancier tests. + + type dialFn func(context.Context, string, string) (net.Conn, error) + + // TCP helpers + acceptAndEcho := func(t *testing.T, ln net.Listener) { + t.Helper() + conn, err := ln.Accept() + if err != nil { + t.Error("accept error:", err) + return + } + defer conn.Close() + if _, err := io.Copy(conn, conn); err != nil { + t.Error("copy error:", err) + } + } + assertEcho := func(t *testing.T, conn net.Conn) { + t.Helper() + msg := "echo" + buf := make([]byte, 1024) + if _, err := conn.Write([]byte(msg)); err != nil { + t.Fatal("write failed:", err) + } + n, err := conn.Read(buf) + if err != nil { + t.Fatal("read failed:", err) + } + got := string(buf[:n]) + if got != msg { + t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) + } + } + + // HTTP helpers + checkAndEcho := func(t *testing.T, ln net.Listener, check func(r *http.Request)) { + t.Helper() + if check == nil { + check = func(*http.Request) {} + } + http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + check(r) + if _, err := io.Copy(w, r.Body); err != nil { + t.Error("copy error:", err) + w.WriteHeader(http.StatusInternalServerError) + } + })) + } + assertEchoHTTP := func(t *testing.T, hostname, path string, dial dialFn) { + t.Helper() + c := http.Client{ + Transport: &http.Transport{ + DialContext: dial, + }, + } + msg := "echo" + resp, err := c.Post("http://"+hostname+path, "text/plain", strings.NewReader(msg)) + if err != nil { + t.Fatal("posting request:", err) + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("reading body:", err) + } + got := string(b) + if got != msg { + t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) + } + } + + tests := []struct { + name string + + // modes is used as input to [Server.ListenService]. + // + // If this slice has multiple modes, then ListenService will be invoked + // multiple times. The number of listeners provided to the run function + // (below) will always match the number of elements in this slice. + modes []ServiceMode + + extraSetup func(t *testing.T, control *testcontrol.Server) + + // run executes the test. This function does not need to close any of + // the input resources, but it should close any new resources it opens. + // listeners[i] corresponds to inputs[i]. + run func(t *testing.T, listeners []*ServiceListener, peer *Server) + }{ + { + name: "basic_TCP", + modes: []ServiceMode{ + ServiceModeTCP{Port: 99}, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + go acceptAndEcho(t, listeners[0]) + + target := fmt.Sprintf("%s:%d", listeners[0].FQDN, 99) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + + assertEcho(t, conn) + }, + }, + { + name: "TLS_terminated_TCP", + modes: []ServiceMode{ + ServiceModeTCP{ + TerminateTLS: true, + Port: 443, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + go acceptAndEcho(t, listeners[0]) + + target := fmt.Sprintf("%s:%d", listeners[0].FQDN, 443) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + + assertEcho(t, tls.Client(conn, &tls.Config{ + ServerName: listeners[0].FQDN, + RootCAs: testCertRoot.Pool(), + })) + }, + }, + { + name: "identity_headers", + modes: []ServiceMode{ + ServiceModeHTTP{ + Port: 80, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + expectHeader := "Tailscale-User-Name" + go checkAndEcho(t, listeners[0], func(r *http.Request) { + if _, ok := r.Header[expectHeader]; !ok { + t.Error("did not see expected header:", expectHeader) + } + }) + assertEchoHTTP(t, listeners[0].FQDN, "", peer.Dial) + }, + }, + { + name: "identity_headers_TLS", + modes: []ServiceMode{ + ServiceModeHTTP{ + HTTPS: true, + Port: 80, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + expectHeader := "Tailscale-User-Name" + go checkAndEcho(t, listeners[0], func(r *http.Request) { + if _, ok := r.Header[expectHeader]; !ok { + t.Error("did not see expected header:", expectHeader) + } + }) + + dial := func(ctx context.Context, network, addr string) (net.Conn, error) { + tcpConn, err := peer.Dial(ctx, network, addr) + if err != nil { + return nil, err + } + return tls.Client(tcpConn, &tls.Config{ + ServerName: listeners[0].FQDN, + RootCAs: testCertRoot.Pool(), + }), nil + } + + assertEchoHTTP(t, listeners[0].FQDN, "", dial) + }, + }, + { + name: "app_capabilities", + modes: []ServiceMode{ + ServiceModeHTTP{ + Port: 80, + AcceptAppCaps: map[string][]string{ + "/": {"example.com/cap/all-paths"}, + "/foo": {"example.com/cap/all-paths", "example.com/cap/foo"}, + }, + }, + }, + extraSetup: func(t *testing.T, control *testcontrol.Server) { + control.SetGlobalAppCaps(tailcfg.PeerCapMap{ + "example.com/cap/all-paths": []tailcfg.RawMessage{`true`}, + "example.com/cap/foo": []tailcfg.RawMessage{`true`}, + }) + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + allPathsCap := "example.com/cap/all-paths" + fooCap := "example.com/cap/foo" + checkCaps := func(r *http.Request) { + rawCaps, ok := r.Header["Tailscale-App-Capabilities"] + if !ok { + t.Error("no app capabilities header") + return + } + if len(rawCaps) != 1 { + t.Error("expected one app capabilities header value, got", len(rawCaps)) + return + } + var caps map[string][]any + if err := json.Unmarshal([]byte(rawCaps[0]), &caps); err != nil { + t.Error("error unmarshaling app caps:", err) + return + } + if _, ok := caps[allPathsCap]; !ok { + t.Errorf("got app caps, but %v is not present; saw:\n%v", allPathsCap, caps) + } + if strings.HasPrefix(r.URL.Path, "/foo") { + if _, ok := caps[fooCap]; !ok { + t.Errorf("%v should be present for /foo request; saw:\n%v", fooCap, caps) + } + } else { + if _, ok := caps[fooCap]; ok { + t.Errorf("%v should not be present for non-/foo request; saw:\n%v", fooCap, caps) + } + } + } + + go checkAndEcho(t, listeners[0], checkCaps) + assertEchoHTTP(t, listeners[0].FQDN, "", peer.Dial) + assertEchoHTTP(t, listeners[0].FQDN, "/foo", peer.Dial) + assertEchoHTTP(t, listeners[0].FQDN, "/foo/bar", peer.Dial) + }, + }, + { + name: "multiple_ports", + modes: []ServiceMode{ + ServiceModeTCP{ + Port: 99, + }, + ServiceModeHTTP{ + Port: 80, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + go acceptAndEcho(t, listeners[0]) + + target := fmt.Sprintf("%s:%d", listeners[0].FQDN, 99) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + assertEcho(t, conn) + + go checkAndEcho(t, listeners[1], nil) + assertEchoHTTP(t, listeners[1].FQDN, "", peer.Dial) + }, + }, + } + + for _, tt := range tests { + // Overview: + // - start test control + // - start 2 tsnet nodes: + // one to act as Service host and a second to act as a peer client + // - configure necessary state on control mock + // - start a Service listener from the host + // - call tt.run with our test bed + // + // This ends up also testing the Service forwarding logic in + // LocalBackend, but that's useful too. + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + + // controlURL, control := startControl(t) + // serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") + // serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") + + lt := setupListenTest(t, true) + serviceHost := lt.s2 + serviceClient := lt.s1 + control := lt.control + + const serviceName = tailcfg.ServiceName("svc:foo") + const serviceVIP = "100.11.22.33" + + // == Set up necessary state in our mock == + + // The Service host must have the 'service-host' capability, which + // is a mapping from the Service name to the Service VIP. + var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] + mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) + j := must.Get(json.Marshal(serviceHostCaps)) + cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() + mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) + control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) + + // The Service host must be allowed to advertise the Service VIP. + control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ + netip.MustParsePrefix(serviceVIP + `/32`), + }) + + // The Service host must be a tagged node (any tag will do). + serviceHostNode := control.Node(serviceHost.lb.NodeKey()) + serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") + control.UpdateNode(serviceHostNode) + + // The service client must accept routes advertised by other nodes + // (RouteAll is equivalent to --accept-routes). + must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ + RouteAllSet: true, + Prefs: ipn.Prefs{ + RouteAll: true, + }, + })) + + // Set up DNS for our Service. + control.AddDNSRecords(tailcfg.DNSRecord{ + Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, + Value: serviceVIP, + }) + + if tt.extraSetup != nil { + tt.extraSetup(t, control) + } + + // Force netmap updates to avoid race conditions. The nodes need to + // see our control updates before we can start the test. + must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey())) + must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey())) + netmapUpToDate := func(s *Server) bool { + nm := s.lb.NetMap() + return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { + return r.Value == serviceVIP + }) + } + for !netmapUpToDate(serviceClient) { + time.Sleep(10 * time.Millisecond) + } + for !netmapUpToDate(serviceHost) { + time.Sleep(10 * time.Millisecond) + } + + // == Done setting up mock state == + + // Start the Service listeners. + listeners := make([]*ServiceListener, 0, len(tt.modes)) + for _, input := range tt.modes { + ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) + defer ln.Close() + listeners = append(listeners, ln) + } + + tt.run(t, listeners, serviceClient) + }) + } +} + func TestListenService(t *testing.T) { // First test an error case which doesn't require all of the fancy setup. t.Run("untagged_node_error", func(t *testing.T) { @@ -1928,6 +2292,7 @@ func (t *chanTUN) BatchSize() int { return 1 } // listenTest provides common setup for listener and TUN tests. type listenTest struct { + control *testcontrol.Server s1, s2 *Server s1ip4, s1ip6 netip.Addr s2ip4, s2ip6 netip.Addr @@ -1941,7 +2306,7 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest { tstest.Shard(t) tstest.ResourceCheck(t) ctx := t.Context() - controlURL, _ := startControl(t) + controlURL, control := startControl(t) s1, _, _ := startServer(t, ctx, controlURL, "s1") tmp := filepath.Join(t.TempDir(), "s2") @@ -1981,13 +2346,14 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest { must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP)) return &listenTest{ - s1: s1, - s2: s2, - s1ip4: s1ip4, - s1ip6: s1ip6, - s2ip4: s2ip4, - s2ip6: s2ip6, - tun: tun, + control: control, + s1: s1, + s2: s2, + s1ip4: s1ip4, + s1ip6: s1ip6, + s2ip4: s2ip4, + s2ip6: s2ip6, + tun: tun, } }