From 2c2b2f8cf9c2f199a42a67bf404350aedd990828 Mon Sep 17 00:00:00 2001 From: Harry Harpham Date: Tue, 23 Dec 2025 18:57:04 -0700 Subject: [PATCH] tsnet: add support to ListenService for identity and app capability headers Signed-off-by: Harry Harpham --- tsnet/tsnet.go | 73 ++++++- tsnet/tsnet_test.go | 198 ++++++++++++++---- tstest/integration/testcontrol/testcontrol.go | 38 ++++ 3 files changed, 261 insertions(+), 48 deletions(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 24553c220..b1057478a 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1260,6 +1260,32 @@ func ServiceOptionPROXYProtocol(version int) ServiceOption { return serviceOptionPROXYProtocol{version} } +type serviceOptionAppCapabilities struct { + path string + caps []string +} + +func (serviceOptionAppCapabilities) serviceOption() {} + +// TODO: doc +func ServiceOptionAppCapabilities(capabilities ...string) ServiceOption { + return ServiceOptionAppCapabilitiesForPath("/", capabilities...) +} + +// TODO: doc +func ServiceOptionAppCapabilitiesForPath(path string, capabilities ...string) ServiceOption { + return serviceOptionAppCapabilities{path, capabilities} +} + +type serviceOptionWithHeaders struct{} + +func (serviceOptionWithHeaders) serviceOption() {} + +// TODO: doc +func ServiceOptionWithHeaders() ServiceOption { + return serviceOptionWithHeaders{} +} + // ErrUntaggedServiceHost is returned by ListenService when run on a node // without any ACL tags. A node must use a tag-based identity to act as a // Service host. For more information, see: @@ -1272,6 +1298,7 @@ func (s *Server) ListenService(name string, port uint16, opts ...ServiceOption) if err := tailcfg.ServiceName(name).Validate(); err != nil { return nil, err } + svcName := name // TODO: // - create example for a Service with multiple ports @@ -1284,12 +1311,23 @@ func (s *Server) ListenService(name string, port uint16, opts ...ServiceOption) // Process options. terminateTLS := false proxyProtocol := 0 + capsMap := map[string][]tailcfg.PeerCapability{} // mount point => caps + isHTTP := false for _, o := range opts { switch opt := o.(type) { case serviceOptionTerminateTLS: terminateTLS = true case serviceOptionPROXYProtocol: proxyProtocol = opt.version + case serviceOptionWithHeaders: + isHTTP = true + case serviceOptionAppCapabilities: + isHTTP = true + caps := make([]tailcfg.PeerCapability, 0, len(opt.caps)) + for _, c := range opt.caps { + caps = append(caps, tailcfg.PeerCapability(c)) + } + capsMap[opt.path] = append(capsMap[opt.path], caps...) default: return nil, fmt.Errorf("unknown opts FunnelOption type %T", o) } @@ -1315,12 +1353,12 @@ func (s *Server) ListenService(name string, port uint16, opts ...ServiceOption) if err != nil { return nil, fmt.Errorf("fetching node preferences: %w", err) } - if !slices.Contains(prefs.AdvertiseServices, name) { + if !slices.Contains(prefs.AdvertiseServices, svcName) { // TODO: do we need to undo this edit on error? _, err = lc.EditPrefs(ctx, &ipn.MaskedPrefs{ AdvertiseServicesSet: true, Prefs: ipn.Prefs{ - AdvertiseServices: append(prefs.AdvertiseServices, name), + AdvertiseServices: append(prefs.AdvertiseServices, svcName), }, }) if err != nil { @@ -1341,10 +1379,33 @@ func (s *Server) ListenService(name string, port uint16, opts ...ServiceOption) if err != nil { return nil, fmt.Errorf("starting local listener: %w", err) } - // Forward all connections from service-hostname:port to our socket. - srvConfig.SetTCPForwardingForService( // TODO: tangent, but can we reduce the number of args here? - port, ln.Addr().String(), tailcfg.ServiceName(name), - terminateTLS, proxyProtocol, st.CurrentTailnet.MagicDNSSuffix) + + if isHTTP { + useTLS := false // TODO: set correctly + mds := st.CurrentTailnet.MagicDNSSuffix + setHandler := func(h ipn.HTTPHandler, path string) { + // TODO: do we need to add the path to the end of the proxy value? + h.Proxy = ln.Addr().String() + srvConfig.SetWebHandler(&h, svcName, port, path, useTLS, mds) + } + // Set a web handler for every mount point in the caps map. If we don't + // end up with a root handler after that, we need to set one. + haveRootHandler := false + for path, caps := range capsMap { + if path == "/" { + haveRootHandler = true + } + setHandler(ipn.HTTPHandler{AcceptAppCaps: caps}, path) + } + if !haveRootHandler { + setHandler(ipn.HTTPHandler{}, "/") + } + } else { + // Forward all connections from service-hostname:port to our socket. + srvConfig.SetTCPForwardingForService( + port, ln.Addr().String(), tailcfg.ServiceName(svcName), + terminateTLS, proxyProtocol, st.CurrentTailnet.MagicDNSSuffix) + } if err := lc.SetServeConfig(ctx, srvConfig); err != nil { ln.Close() diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 15b05612c..aef0cc737 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -758,16 +758,159 @@ func TestFunnel(t *testing.T) { } func TestListenService(t *testing.T) { + type dialFn func(context.Context, string, string) (net.Conn, error) + + // TCP helpers + acceptAndEcho := func(t *testing.T, ln net.Listener) { + t.Helper() + conn, err := ln.Accept() + if err != nil { + t.Error("accept error:", err) + return + } + defer conn.Close() + if _, err := io.Copy(conn, conn); err != nil { + t.Error("copy error:", err) + } + } + assertEcho := func(t *testing.T, conn net.Conn) { + t.Helper() + msg := "echo" + buf := make([]byte, 1024) + if _, err := conn.Write([]byte(msg)); err != nil { + t.Fatal("write failed:", err) + } + n, err := conn.Read(buf) + if err != nil { + t.Fatal("read failed:", err) + } + got := string(buf[:n]) + if got != msg { + t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) + } + } + + // HTTP helpers + checkAndEcho := func(t *testing.T, ln net.Listener, check func(r *http.Request)) { + t.Helper() + http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + check(r) + if _, err := io.Copy(w, r.Body); err != nil { + t.Error("copy error:", err) + w.WriteHeader(http.StatusInternalServerError) + } + })) + } + assertEchoHTTP := func(t *testing.T, hostname string, dial dialFn) { + t.Helper() + c := http.Client{ + Transport: &http.Transport{ + DialContext: dial, + }, + } + msg := "echo" + resp, err := c.Post("http://"+hostname, "text/plain", strings.NewReader(msg)) + if err != nil { + t.Fatal("posting request:", err) + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("reading body:", err) + } + got := string(b) + if got != msg { + t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) + } + } + tests := []struct { name string + port uint16 opts []ServiceOption + + extraSetup func(t *testing.T, serviceHost, peer *Server, control *testcontrol.Server) + + // run the test. This function does not need to close any of the input + // resources, but it should close any new resources it opens. + run func(t *testing.T, serviceListener net.Listener, peer *Server, serviceFQDN string) }{ { - name: "basic_TCP_service", + name: "basic_TCP", + port: 99, + run: func(t *testing.T, serviceListener net.Listener, peer *Server, serviceFQDN string) { + go acceptAndEcho(t, serviceListener) + + target := fmt.Sprintf("%s:%d", serviceFQDN, 99) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + + assertEcho(t, conn) + }, }, { name: "TLS_terminated_TCP", opts: []ServiceOption{ServiceOptionTerminateTLS()}, + port: 443, + run: func(t *testing.T, serviceListener net.Listener, peer *Server, serviceFQDN string) { + go acceptAndEcho(t, serviceListener) + + target := fmt.Sprintf("%s:%d", serviceFQDN, 443) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + + assertEcho(t, tls.Client(conn, &tls.Config{ + ServerName: serviceFQDN, + RootCAs: testCertRoot.Pool(), + })) + }, + }, + { + name: "identity_headers", + opts: []ServiceOption{ServiceOptionWithHeaders()}, + port: 80, + run: func(t *testing.T, serviceListener net.Listener, peer *Server, serviceFQDN string) { + expectHeader := "Tailscale-User-Name" + go checkAndEcho(t, serviceListener, func(r *http.Request) { + if _, ok := r.Header[expectHeader]; !ok { + t.Error("did not see expected header:", expectHeader) + } + }) + assertEchoHTTP(t, serviceFQDN, peer.Dial) + }, + }, + { + name: "app_capabilities", + opts: []ServiceOption{ServiceOptionAppCapabilities("example.com/cap/want")}, + port: 80, + extraSetup: func(t *testing.T, serviceHost, peer *Server, control *testcontrol.Server) { + control.SetGlobalAppCaps(tailcfg.PeerCapMap{ + "example.com/cap/want": []tailcfg.RawMessage{`true`}, + }) + }, + run: func(t *testing.T, serviceListener net.Listener, peer *Server, serviceFQDN string) { + go checkAndEcho(t, serviceListener, func(r *http.Request) { + rawCaps, ok := r.Header["Tailscale-App-Capabilities"] + if !ok { + t.Error("no app capabilities header") + return + } + if len(rawCaps) != 1 { + t.Error("expected one app capabilities header value, got", len(rawCaps)) + return + } + var caps map[string][]any + if err := json.Unmarshal([]byte(rawCaps[0]), &caps); err != nil { + t.Error("error unmarshaling app caps:", err) + return + } + if _, ok := caps["example.com/cap/want"]; !ok { + t.Errorf("got app caps, but expected cap is not present; saw:\n%v", caps) + } + }) + assertEchoHTTP(t, serviceFQDN, peer.Dial) + }, }, // TODO: // Success cases: @@ -799,7 +942,6 @@ func TestListenService(t *testing.T) { serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") const serviceName = tailcfg.ServiceName("svc:foo") - const servicePort uint16 = 99 const serviceVIP = "100.11.22.33" serviceFQDN := serviceName.WithoutPrefix() + "." + control.MagicDNSDomain @@ -834,6 +976,16 @@ func TestListenService(t *testing.T) { }, })) + // Set up DNS for our Service. + control.DNSConfig.ExtraRecords = append(control.DNSConfig.ExtraRecords, tailcfg.DNSRecord{ + Name: serviceFQDN, + Value: serviceVIP, + }) + + if tt.extraSetup != nil { + tt.extraSetup(t, serviceHost, serviceClient, control) + } + // Force netmap updates to avoid race conditions. The nodes need to // see our control updates before we can start the test. serviceClient.lb.DebugForceNetmapUpdate() @@ -842,48 +994,10 @@ func TestListenService(t *testing.T) { // == Done setting up mock state == // Start a Service listener. - ln := must.Get(serviceHost.ListenService(serviceName.String(), servicePort, tt.opts...)) + ln := must.Get(serviceHost.ListenService(serviceName.String(), tt.port, tt.opts...)) defer ln.Close() - // Accept the first connection on ln and echo back what we receive. - go func() { - conn, err := ln.Accept() - if err != nil { - t.Error("accept error:", err) - return - } - defer conn.Close() - if _, err := io.Copy(conn, conn); err != nil { - t.Error("copy error:", err) - } - }() - - target := fmt.Sprintf("%s:%d", serviceVIP, servicePort) - conn := must.Get(serviceClient.Dial(ctx, "tcp", target)) - defer conn.Close() - - for _, opt := range tt.opts { - if _, ok := opt.(serviceOptionTerminateTLS); ok { - conn = tls.Client(conn, &tls.Config{ - ServerName: serviceFQDN, - RootCAs: testCertRoot.Pool(), - }) - } - } - - msg := "hello, Service" - buf := make([]byte, 1024) - if _, err := conn.Write([]byte(msg)); err != nil { - t.Fatal("write failed:", err) - } - n, err := conn.Read(buf) - if err != nil { - t.Fatal("read failed:", err) - } - got := string(buf[:n]) - if got != msg { - t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) - } + tt.run(t, ln, serviceClient, serviceFQDN) }) } } diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 2f23384bd..cbb9f3663 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -110,6 +110,16 @@ type Server struct { // nodeCapMaps overrides the capability map sent down to a client. nodeCapMaps map[key.NodePublic]tailcfg.NodeCapMap + // globalAppCaps configures global app capabilities, equivalent to: + // "grants": [ + // { + // "src": ["*"], + // "dst": ["*"], + // "app": + // } + // ] + globalAppCaps tailcfg.PeerCapMap + // suppressAutoMapResponses is the set of nodes that should not be sent // automatic map responses from serveMap. (They should only get manually sent ones) suppressAutoMapResponses set.Set[key.NodePublic] @@ -531,6 +541,21 @@ func (s *Server) SetNodeCapMap(nodeKey key.NodePublic, capMap tailcfg.NodeCapMap s.updateLocked("SetNodeCapMap", s.nodeIDsLocked(0)) } +// SetGlobalAppCaps configures global app capabilities. This is equivalent to +// +// "grants": [ +// { +// "src": ["*"], +// "dst": ["*"], +// "app": +// } +// ] +func (s *Server) SetGlobalAppCaps(appCaps tailcfg.PeerCapMap) { + s.mu.Lock() + s.globalAppCaps = appCaps + s.mu.Unlock() +} + // nodeIDsLocked returns the node IDs of all nodes in the server, except // for the node with the given ID. func (s *Server) nodeIDsLocked(except tailcfg.NodeID) []tailcfg.NodeID { @@ -1280,6 +1305,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, s.mu.Lock() nodeMasqs := s.masquerades[node.Key] jailed := maps.Clone(s.peerIsJailed[node.Key]) + globalAppCaps := s.globalAppCaps s.mu.Unlock() for _, p := range s.AllNodes() { if p.StableID == node.StableID { @@ -1331,6 +1357,18 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, v6Prefix, } + if globalAppCaps != nil { + res.PacketFilter = append(res.PacketFilter, tailcfg.FilterRule{ + SrcIPs: []string{"*"}, + CapGrant: []tailcfg.CapGrant{ + { + Dsts: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + CapMap: globalAppCaps, + }, + }, + }) + } + // If the server is tracking TKA state, and there's a single TKA head, // add it to the MapResponse. if s.tkaStorage != nil {