diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index e98a9e078..680fe2d02 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -40,6 +40,7 @@ import ( "tailscale.com/net/tshttpproxy" "tailscale.com/paths" "tailscale.com/safesocket" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -844,6 +845,24 @@ func runTS2021(ctx context.Context, args []string) error { if ts2021Args.verbose { logf = log.Printf } + + h2Transport, err := http2.ConfigureTransports(&http.Transport{ + IdleConnTimeout: time.Second, + }) + if err != nil { + return fmt.Errorf("http2.ConfigureTransports: %w", err) + } + + var noiseConns syncs.Map[*controlhttp.ClientConn, *noiseconn.Conn] + + // Close all noise conns when we're done. + defer func() { + noiseConns.Range(func(_ *controlhttp.ClientConn, ncc *noiseconn.Conn) bool { + ncc.Close() + return true + }) + }() + conn, err := (&controlhttp.Dialer{ Hostname: ts2021Args.host, HTTPPort: "80", @@ -853,6 +872,24 @@ func runTS2021(ctx context.Context, args []string) error { ProtocolVersion: uint16(ts2021Args.version), Dialer: dialFunc, Logf: logf, + TestConn: func(cc *controlhttp.ClientConn) (retErr error) { + log.Printf("testing ClientConn %p ...", cc) + + nc, err := noiseconn.New(cc.Conn, h2Transport, 0, nil) + if err != nil { + return fmt.Errorf("noiseconn.New: %w", err) + } + + // Store this conn for later use. 
+ noiseConns.Store(cc, nc) + defer func() { + if retErr != nil { + noiseConns.Delete(cc) + nc.Close() + } + }() + return noiseconn.TestConn(ctx, log.Printf, nc, ts2021Args.host) + }, }).Dial(ctx) log.Printf("controlhttp.Dial = %p, %v", conn, err) if err != nil { @@ -867,52 +904,6 @@ func runTS2021(ctx context.Context, args []string) error { } log.Printf("final underlying conn: %v / %v", conn.LocalAddr(), conn.RemoteAddr()) - - h2Transport, err := http2.ConfigureTransports(&http.Transport{ - IdleConnTimeout: time.Second, - }) - if err != nil { - return fmt.Errorf("http2.ConfigureTransports: %w", err) - } - - // Now, create a Noise conn over the existing conn. - nc, err := noiseconn.New(conn.Conn, h2Transport, 0, nil) - if err != nil { - return fmt.Errorf("noiseconn.New: %w", err) - } - defer nc.Close() - - // Reserve a RoundTrip for the whoami request. - ok, _, err := nc.ReserveNewRequest(ctx) - if err != nil { - return fmt.Errorf("ReserveNewRequest: %w", err) - } - if !ok { - return errors.New("ReserveNewRequest failed") - } - - // Make a /whoami request to the server to verify that we can actually - // communicate over the newly-established connection. 
- whoamiURL := "http://" + ts2021Args.host + "/machine/whoami" - req, err = http.NewRequestWithContext(ctx, "GET", whoamiURL, nil) - if err != nil { - return err - } - resp, err := nc.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip whoami request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - log.Printf("whoami request returned status %v", resp.Status) - } else { - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("reading whoami response: %w", err) - } - log.Printf("whoami response: %q", body) - } return nil } diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 44437e2f3..39508dd16 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -21,6 +21,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -337,6 +338,10 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseconn.Conn, error) { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() + // We create a new noiseconn in the TestConn function, below, so to + // avoid re-creating it multiple times (after the Dial), we store a + // reference to created ones here. + var noiseConns syncs.Map[*controlhttp.ClientConn, *noiseconn.Conn] clientConn, err := (&controlhttp.Dialer{ Hostname: nc.host, HTTPPort: nc.httpPort, @@ -351,16 +356,49 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseconn.Conn, error) { NetMon: nc.netMon, HealthTracker: nc.health, Clock: tstime.StdClock{}, + TestConn: func(cc *controlhttp.ClientConn) error { + ncc, err := noiseconn.New(cc.Conn, nc.h2t, connID, nc.connClosed) + if err != nil { + return err + } + + // Store this conn for later extraction. 
+ noiseConns.Store(cc, ncc) + + if err := noiseconn.TestConn(ctx, nc.logf, ncc, nc.host); err != nil { + noiseConns.Delete(cc) + ncc.Close() + return err + } + + if nc.logf != nil { + nc.logf("tested noise connection successfully") + } + return nil + }, }).Dial(ctx) if err != nil { + // Ensure that we close any noise connections that we created. + noiseConns.Range(func(_ *controlhttp.ClientConn, ncc *noiseconn.Conn) bool { + ncc.Close() + return true + }) return nil, err } - ncc, err := noiseconn.New(clientConn.Conn, nc.h2t, connID, nc.connClosed) - if err != nil { - return nil, err + // If we get here, we know that we successfully created a noiseConn, + // above, so we extract and use it. + ncc, found := noiseConns.LoadAndDelete(clientConn) + if !found { + return nil, errors.New("[unexpected] no noiseConn found") } + // Close all other noiseConns that we created but didn't use. + noiseConns.Range(func(_ *controlhttp.ClientConn, ncc *noiseconn.Conn) bool { + ncc.Close() + return true + }) + nc.mu.Lock() if nc.closed { nc.mu.Unlock() diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index e01cb1f9a..522318c0e 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -310,6 +310,21 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, er if debugNoiseDial() { a.logf("noise dial (%v, %v) = (%v, %v)", u, addr, cbConn, err) } + + // We've seen some networks where the connection upgrades + // successfully, but then fails when we make a request after + // the upgrade. Work around this by making a request over the + // now-upgraded connection before we tell the outer function + // that we've got a connection. + if err == nil && a.TestConn != nil { + err = a.TestConn(cbConn) + if err != nil { + // Close and don't leak the connection. 
+ cbConn.Close() + cbConn = nil + } + } + select { case ch <- tryURLRes{u, cbConn, err}: case <-ctx.Done(): diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 6b5116262..ffa61edec 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -88,6 +88,13 @@ type Dialer struct { // plan before falling back to DNS. DialPlan *tailcfg.ControlDialPlan + // TestConn, if non-nil, is called with a dialed connection to verify + // that it's ready to serve real requests. If this function returns an + // error, the connection is closed and not used. If this function + // returns an error for all dialed connections, an error is returned + // from Dial. + TestConn func(*ClientConn) error + proxyFunc func(*http.Request) (*url.URL, error) // or nil // For tests only diff --git a/internal/noiseconn/conn.go b/internal/noiseconn/conn.go index 7476b7ecc..abc805b74 100644 --- a/internal/noiseconn/conn.go +++ b/internal/noiseconn/conn.go @@ -14,13 +14,19 @@ import ( "encoding/binary" "encoding/json" "errors" + "fmt" "io" "net/http" + "net/url" "sync" "golang.org/x/net/http2" "tailscale.com/control/controlbase" + "tailscale.com/envknob" + "tailscale.com/ipn" "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/must" ) // Conn is a wrapper around controlbase.Conn. @@ -185,3 +191,55 @@ func (c *Conn) Close() error { } return nil } + +var defaultControlHost string = (func() string { + uu := must.Get(url.Parse(ipn.DefaultControlURL)) + return uu.Hostname() +})() + +// TestConn is a shared implementation to allow testing that a Conn is +// operating successfully, by making a request to the server and verifying that +// the response is what we expect. +// +// Since this issues a real GET /machine/whoami request over conn, a non-nil error means the check failed and callers should not use conn. +func TestConn(ctx context.Context, logf logger.Logf, conn *Conn, host string) error { + // TODO(andrew-d): double-check that reserving a request here doesn't + // mess with our early payload. 
+ ok, _, err := conn.ReserveNewRequest(ctx) + if err != nil { + return fmt.Errorf("ReserveNewRequest: %w", err) + } + if !ok { + return errors.New("ReserveNewRequest failed") + } + + whoamiURL := "http://" + host + "/machine/whoami" + req, err := http.NewRequestWithContext(ctx, "GET", whoamiURL, nil) + if err != nil { + return err + } + resp, err := conn.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip whoami request: %w", err) + } + defer resp.Body.Close() + defer io.Copy(io.Discard, resp.Body) + + // If we're talking to the default control plane, we know that this + // endpoint exists and should return 200. Thus, we can error if this + // request doesn't work. However, for custom control planes this + // endpoint may not be present, so we treat any valid HTTP response as + // a success. + if host == defaultControlHost && resp.StatusCode != 200 { + return fmt.Errorf("whoami request failed; got status %d", resp.StatusCode) + } + + if envknob.Bool("TS_DEBUG_NOISE_DIAL") && logf != nil { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading whoami response: %w", err) + } + logf("[v1] TestConn: whoami response: %q", body) + } + return nil +}