diff --git a/client/local/local.go b/client/local/local.go index 75fdbe5a5..6508af29e 100644 --- a/client/local/local.go +++ b/client/local/local.go @@ -972,6 +972,19 @@ func (lc *Client) UserDial(ctx context.Context, network, host string, port uint1 if res.StatusCode != http.StatusSwitchingProtocols { body, _ := io.ReadAll(res.Body) res.Body.Close() + if res.StatusCode == http.StatusOK && res.Header.Get("Dial-Self") == "true" { + // Server told us to dial the address ourselves rather than + // proxying through the daemon. This happens for non-Tailscale + // addresses where the daemon shouldn't dial as root on the + // client's behalf. The server provides the resolved address + // to avoid a TOCTOU race with DNS re-resolution. + addr := res.Header.Get("Dial-Addr") + if addr == "" { + return nil, errors.New("server returned Dial-Self without Dial-Addr") + } + var d net.Dialer + return d.DialContext(ctx, network, addr) + } return nil, fmt.Errorf("unexpected HTTP response: %s, %s", res.Status, body) } // From here on, the underlying net.Conn is ours to use, but there diff --git a/client/local/local_test.go b/client/local/local_test.go index a5377fbd6..58a87b224 100644 --- a/client/local/local_test.go +++ b/client/local/local_test.go @@ -61,6 +61,57 @@ func TestWhoIsPeerNotFound(t *testing.T) { } } +func TestUserDialSelf(t *testing.T) { + // Start a real TCP listener that the client should dial directly + // when the server tells it to dial-self. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Write([]byte("hello")) + c.Close() + } + }() + targetAddr := ln.Addr().(*net.TCPAddr) + + // Mock LocalAPI server that returns Dial-Self response. + nw := nettest.GetNetwork(t) + ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Dial-Self", "true") + w.Header().Set("Dial-Addr", targetAddr.String()) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + lc := &Client{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nw.Dial(ctx, network, ts.Listener.Addr().String()) + }, + } + + conn, err := lc.UserDial(context.Background(), "tcp", targetAddr.IP.String(), uint16(targetAddr.Port)) + if err != nil { + t.Fatalf("UserDial: %v", err) + } + defer conn.Close() + + buf := make([]byte, 5) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("Read: %v", err) + } + if got := string(buf[:n]); got != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + func TestDeps(t *testing.T) { deptest.DepChecker{ BadDeps: map[string]string{ diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index b06b69b04..f16fab2aa 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -1169,16 +1169,34 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing Dial-Host or Dial-Port header", http.StatusBadRequest) return } + network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") + + addr := net.JoinHostPort(hostStr, portStr) + + // Check whether the resolved address is a Tailscale route. + // If not, tell the client to dial it directly so the connection + // comes from the calling user's UID rather than our root-owned daemon. + ipp, viaTailscale, err := h.b.Dialer().UserDialPlan(r.Context(), network, addr) + if err != nil { + http.Error(w, "resolve failure: "+err.Error(), http.StatusBadGateway) + return + } + if !viaTailscale { + w.Header().Set("Dial-Self", "true") + w.Header().Set("Dial-Addr", ipp.String()) + w.WriteHeader(http.StatusOK) + return + } + hijacker, ok := w.(http.Hijacker) if !ok { http.Error(w, "make request over HTTP/1", http.StatusBadRequest) return } - network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") - - addr := net.JoinHostPort(hostStr, portStr) - outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr) + // Dial via Tailscale using the resolved IP:port to avoid a TOCTOU + // race with DNS re-resolution. + outConn, err := h.b.Dialer().UserDial(r.Context(), network, ipp.String()) if err != nil { http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway) return diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 47e334571..a755221bf 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -500,3 +500,69 @@ func TestServeWithUnhealthyState(t *testing.T) { }) } } + +func TestServeDialSelf(t *testing.T) { + h := handlerForTest(t, &Handler{ + PermitRead: true, + PermitWrite: true, + b: newTestLocalBackend(t), + }) + + tests := []struct { + name string + host string + port string + wantSelf bool + wantAddr string + wantStatus int + }{ + { + name: "loopback_v4", + host: "127.0.0.1", + port: "8080", + wantSelf: true, + wantAddr: "127.0.0.1:8080", + wantStatus: http.StatusOK, + }, + { + name: "loopback_v6", + host: "::1", + port: "8080", + wantSelf: true, + wantAddr: "[::1]:8080", + wantStatus: http.StatusOK, + }, + { + name: "localhost", + host: "localhost", + port: "3000", + wantSelf: true, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "http://local-tailscaled.sock/localapi/v0/dial", nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "ts-dial") + req.Header.Set("Dial-Host", tt.host) + req.Header.Set("Dial-Port", tt.port) + resp := httptest.NewRecorder() + h.serveDial(resp, req) + + if resp.Code != tt.wantStatus { + t.Fatalf("status = %d, want %d; body: %s", resp.Code, tt.wantStatus, resp.Body.String()) + } + gotSelf := resp.Header().Get("Dial-Self") == "true" + if gotSelf != tt.wantSelf { + t.Errorf("Dial-Self = %v, want %v", gotSelf, tt.wantSelf) + } + if tt.wantAddr != "" { + if got := resp.Header().Get("Dial-Addr"); got != tt.wantAddr { + t.Errorf("Dial-Addr = %q, want %q", got, tt.wantAddr) + } + } + }) + } +} diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index ebbafa52b..ca08810a3 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -515,6 +515,33 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, return stdDialer.DialContext(ctx, network, ipp.String()) } +// UserDialPlan resolves addr and reports whether the dialer would +// handle it via Tailscale. If viaTailscale is false, the resolved +// address is not a Tailscale route and the caller may dial it directly. +// +// Warning: there is a TOCTOU race if addr contains a DNS name and the +// caller subsequently passes the same DNS name to [Dialer.UserDial], as DNS +// may resolve differently the second time. Callers who want to only +// dial over Tailscale should call [Dialer.UserDial] with the returned +// ipp.String() (an IP:port) rather than the original DNS name. +func (d *Dialer) UserDialPlan(ctx context.Context, network, addr string) (ipp netip.AddrPort, viaTailscale bool, err error) { + ipp, err = d.userDialResolve(ctx, network, addr) + if err != nil { + return netip.AddrPort{}, false, err + } + if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { + return ipp, true, nil + } + if routes := d.routes.Load(); routes != nil { + isTailscaleRoute, _ := routes.Lookup(ipp.Addr()) + return ipp, isTailscaleRoute, nil + } + if version.IsMacGUIVariant() && tsaddr.IsTailscaleIP(ipp.Addr()) { + return ipp, true, nil + } + return ipp, false, nil +} + // dialPeerAPI connects to a Tailscale peer's peerapi over TCP. // // network must a "tcp" type, and addr must be an ip:port. Name resolution diff --git a/net/tsdial/tsdial_test.go b/net/tsdial/tsdial_test.go new file mode 100644 index 000000000..92960acbe --- /dev/null +++ b/net/tsdial/tsdial_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "context" + "net/netip" + "testing" + + "github.com/gaissmai/bart" +) + +func TestUserDialPlan(t *testing.T) { + tests := []struct { + name string + addr string + routes map[netip.Prefix]bool // nil means no routes configured + useNetstackFor func(netip.Addr) bool // nil means not set + wantVia bool + wantAddr netip.AddrPort + }{ + { + name: "loopback_no_routes", + addr: "127.0.0.1:8080", + wantVia: false, + wantAddr: netip.MustParseAddrPort("127.0.0.1:8080"), + }, + { + name: "loopback_v6_no_routes", + addr: "[::1]:8080", + wantVia: false, + wantAddr: netip.MustParseAddrPort("[::1]:8080"), + }, + { + name: "tailscale_ip_in_routes", + addr: "100.64.1.1:22", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + }, + wantVia: true, + wantAddr: netip.MustParseAddrPort("100.64.1.1:22"), + }, + { + name: "non_tailscale_ip_in_local_routes", + addr: "10.0.0.5:80", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + netip.MustParsePrefix("10.0.0.0/8"): false, // local route + }, + wantVia: false, + wantAddr: netip.MustParseAddrPort("10.0.0.5:80"), + }, + { + name: "loopback_with_routes_configured", + addr: "127.0.0.1:3000", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + }, + wantVia: false, + wantAddr: netip.MustParseAddrPort("127.0.0.1:3000"), + }, + { + name: "netstack_for_ip", + addr: "100.100.100.100:53", + useNetstackFor: func(ip netip.Addr) bool { + return ip == netip.MustParseAddr("100.100.100.100") + }, + wantVia: true, + wantAddr: netip.MustParseAddrPort("100.100.100.100:53"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &Dialer{} + if tt.routes != nil { + rt := &bart.Table[bool]{} + for pfx, v := range tt.routes { + rt.Insert(pfx, v) + } + d.routes.Store(rt) + } + d.UseNetstackForIP = tt.useNetstackFor + + ipp, viaTailscale, err := d.UserDialPlan(context.Background(), "tcp", tt.addr) + if err != nil { + t.Fatalf("UserDialPlan: %v", err) + } + if viaTailscale != tt.wantVia { + t.Errorf("viaTailscale = %v, want %v", viaTailscale, tt.wantVia) + } + if ipp != tt.wantAddr { + t.Errorf("addr = %v, want %v", ipp, tt.wantAddr) + } + }) + } +}