diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 27d7864ae..b9b7db525 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -186,7 +186,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/cmd/tailscale/cli+ - tailscale.com/util/multierr from tailscale.com/control/controlhttp+ + tailscale.com/util/multierr from tailscale.com/health+ tailscale.com/util/must from tailscale.com/clientupdate/distsign+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto tailscale.com/util/prompt from tailscale.com/cmd/tailscale/cli diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 87061c310..da9590c48 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -27,14 +27,12 @@ import ( "errors" "fmt" "io" - "math" "net" "net/http" "net/http/httptrace" "net/netip" "net/url" "runtime" - "sort" "sync/atomic" "time" @@ -53,7 +51,6 @@ import ( "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" - "tailscale.com/util/multierr" ) var stdDialer net.Dialer @@ -110,18 +107,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { } candidates := a.DialPlan.Candidates - // Otherwise, we try dialing per the plan. Store the highest priority - // in the list, so that if we get a connection to one of those - // candidates we can return quickly. - var highestPriority int = math.MinInt - for _, c := range candidates { - if c.Priority > highestPriority { - highestPriority = c.Priority - } - } - - // This context allows us to cancel in-flight connections if we get a - // highest-priority connection before we're all done. + // Create a context to be canceled as we return, so once we get a good connection, + // we can drop all the other ones. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -129,142 +116,58 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { type dialResult struct { conn *ClientConn err error - cand tailcfg.ControlIPCandidate } - resultsCh := make(chan dialResult, len(candidates)) + resultsCh := make(chan dialResult) // unbuffered, never closed - var pending atomic.Int32 - pending.Store(int32(len(candidates))) - for _, c := range candidates { - go func(ctx context.Context, c tailcfg.ControlIPCandidate) { - var ( - conn *ClientConn - err error - ) + dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) { + if cand.ACEHost != "" { + a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q via ACE %s (%s)", cand.DialStartDelaySec, a.Hostname, cand.ACEHost, cmp.Or(cand.IP.String(), "dns")) + } else { + a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %s", cand.DialStartDelaySec, a.Hostname, cand.IP.String()) + } - // Always send results back to our channel. - defer func() { - resultsCh <- dialResult{conn, err, c} - if pending.Add(-1) == 0 { - close(resultsCh) - } - }() - - // If non-zero, wait the configured start timeout - // before we do anything. 
- if c.DialStartDelaySec > 0 { - a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP) - tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) - defer tmr.Stop() - select { - case <-ctx.Done(): - err = ctx.Err() - return - case <-tmrChannel: - } - } - - // Now, create a sub-context with the given timeout and - // try dialing the provided host. - ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second))) - defer cancel() - - if c.IP.IsValid() { - a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP) - } else if c.ACEHost != "" { - a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost) - } - // This will dial, and the defer above sends it back to our parent. - conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost) - }(ctx, c) + ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second))) + defer cancel() + return a.dialHostOpt(ctx, cand.IP, cand.ACEHost) } - var results []dialResult - for res := range resultsCh { - // If we get a response that has the highest priority, we don't - // need to wait for any of the other connections to finish; we - // can just return this connection. - // - // TODO(andrew): we could make this better by keeping track of - // the highest remaining priority dynamically, instead of just - // checking for the highest total - if res.cand.Priority == highestPriority && res.conn != nil { - a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String())) - - // Drain the channel and any existing connections in - // the background. + for _, cand := range candidates { + timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() { go func() { - for _, res := range results { - if res.conn != nil { - res.conn.Close() + conn, err := dialCand(cand) + select { + case resultsCh <- dialResult{conn, err}: + if err == nil { + a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String())) } - } - for res := range resultsCh { - if res.conn != nil { - res.conn.Close() + case <-ctx.Done(): + if conn != nil { + conn.Close() } } - if a.drainFinished != nil { - close(a.drainFinished) - } }() - return res.conn, nil - } - - // This isn't a highest-priority result, so just store it until - // we're done. - results = append(results, res) + }) + defer timer.Stop() } - // After we finish this function, close any remaining open connections. - defer func() { - for _, result := range results { - // Note: below, we nil out the returned connection (if - // any) in the slice so we don't close it. - if result.conn != nil { - result.conn.Close() + var errs []error + for { + select { + case res := <-resultsCh: + if res.err == nil { + return res.conn, nil } + errs = append(errs, res.err) + if len(errs) == len(candidates) { + // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. + a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...)) + return a.dialHost(ctx) + } + case <-ctx.Done(): + a.logf("controlhttp: context aborted dialing") + return nil, ctx.Err() } - - // We don't drain asynchronously after this point, so notify our - // channel when we return. 
- if a.drainFinished != nil { - close(a.drainFinished) - } - }() - - // Sort by priority, then take the first non-error response. - sort.Slice(results, func(i, j int) bool { - // NOTE: intentionally inverted so that the highest priority - // item comes first - return results[i].cand.Priority > results[j].cand.Priority - }) - - var ( - conn *ClientConn - errs []error - ) - for i, result := range results { - if result.err != nil { - errs = append(errs, result.err) - continue - } - - a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String())) - conn = result.conn - results[i].conn = nil // so we don't close it in the defer - return conn, nil } - if ctx.Err() != nil { - a.logf("controlhttp: context aborted dialing") - return nil, ctx.Err() - } - - merr := multierr.New(errs...) - - // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. - a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error()) - return a.dialHost(ctx) } // The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to @@ -402,6 +305,9 @@ func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost } var err80, err443 error + if forceTLS { + err80 = errors.New("TLS forced: no port 80 dialed") + } for { select { case <-ctx.Done(): diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 12038fae4..58fed1b76 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -98,7 +98,6 @@ type Dialer struct { logPort80Failure atomic.Bool // For tests only - drainFinished chan struct{} omitCertErrorLogging bool testFallbackDelay time.Duration diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 0b4e117f9..6485761ac 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -15,19 +15,20 @@ import ( "net/http/httputil" "net/netip" "net/url" - "runtime" "slices" "strconv" + "strings" "sync" "testing" + "testing/synctest" "time" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/health" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" - "tailscale.com/net/netx" "tailscale.com/net/socks5" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -36,6 +37,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/must" ) type httpTestParam struct { @@ -532,6 +534,28 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== } } +// slowListener wraps a memnet listener to delay accept operations +type slowListener struct { + net.Listener + delay time.Duration +} + +func (sl *slowListener) Accept() (net.Conn, error) { + // Add delay before accepting connections + timer := time.NewTimer(sl.delay) + defer timer.Stop() + <-timer.C + + return sl.Listener.Accept() +} + +func newSlowListener(inner net.Listener, delay time.Duration) net.Listener { + return &slowListener{ + Listener: inner, + delay: delay, + } +} + func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) @@ -545,33 +569,102 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { } func TestDialPlan(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("only works on Linux due to multiple 
localhost addresses") + testCases := []struct { + name string + plan *tailcfg.ControlDialPlan + want []netip.Addr + allowFallback bool + maxDuration time.Duration + }{ + { + name: "single", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + }, + { + name: "broken-then-good", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10, DialStartDelaySec: 1}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + }, + { + name: "multiple-candidates-with-broken", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + // Multiple good IPs plus a broken one + // Should succeed with any of the good ones + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.4"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.3"), DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.4"), netip.MustParseAddr("10.0.0.3")}, + }, + { + name: "multiple-candidates-race", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.3"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.3"), netip.MustParseAddr("10.0.0.2")}, + }, + { + name: "fallback", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 1}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + allowFallback: true, + }, + { + // In tailscale/corp#32534 we discovered that a prior implementation + // of the dial race was waiting for all dials to complete when the + // top priority dial was failing. This delay was long enough that in + // real scenarios the server will close the connection due to + // inactivity, because the client does not send the first inside of + // noise request soon enough. This test is a regression guard + // against that behavior - proving that the dial returns promptly + // even if there is some cause of a slow race. 
+ name: "slow-endpoint-doesnt-block", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.12"), Priority: 5, DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), Priority: 1, DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + maxDuration: 2 * time.Second, // Must complete quickly, not wait for slow endpoint + }, } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + runDialPlanTest(t, tt.plan, tt.want, tt.allowFallback, tt.maxDuration) + }) + }) + } +} + +func runDialPlanTest(t *testing.T, plan *tailcfg.ControlDialPlan, want []netip.Addr, allowFallback bool, maxDuration time.Duration) { client, server := key.NewMachine(), key.NewMachine() const ( testProtocolVersion = 1 + httpPort = "80" + httpsPort = "443" ) - getRandomPort := func() string { - ln, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("net.Listen: %v", err) - } - defer ln.Close() - _, port, err := net.SplitHostPort(ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - return port - } + memNetwork := &memnet.Network{} - // We need consistent ports for each address; these are chosen - // randomly and we hope that they won't conflict during this test. - httpPort := getRandomPort() - httpsPort := getRandomPort() + fallbackAddr := netip.MustParseAddr("10.0.0.1") + goodAddr := netip.MustParseAddr("10.0.0.2") + otherAddr := netip.MustParseAddr("10.0.0.3") + other2Addr := netip.MustParseAddr("10.0.0.4") + brokenAddr := netip.MustParseAddr("10.0.0.10") + slowAddr := netip.MustParseAddr("10.0.0.12") makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) { done := make(chan struct{}) @@ -592,14 +685,8 @@ func TestDialPlan(t *testing.T) { handler = wrap(handler) } - httpLn, err := net.Listen("tcp", host.String()+":"+httpPort) - if err != nil { - t.Fatalf("HTTP listen: %v", err) - } - httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort) - if err != nil { - t.Fatalf("HTTPS listen: %v", err) - } + httpLn := must.Get(memNetwork.Listen("tcp", host.String()+":"+httpPort)) + httpsLn := must.Get(memNetwork.Listen("tcp", host.String()+":"+httpsPort)) httpServer := &http.Server{Handler: handler} go httpServer.Serve(httpLn) @@ -616,209 +703,199 @@ func TestDialPlan(t *testing.T) { t.Cleanup(func() { httpsServer.Close() }) - return } - fallbackAddr := netip.MustParseAddr("127.0.0.1") - goodAddr := netip.MustParseAddr("127.0.0.2") - otherAddr := netip.MustParseAddr("127.0.0.3") - other2Addr := netip.MustParseAddr("127.0.0.4") - brokenAddr := netip.MustParseAddr("127.0.0.10") - - testCases := []struct { - name string - plan *tailcfg.ControlDialPlan - wrap func(http.Handler) http.Handler - want netip.Addr - - allowFallback bool - }{ - { - name: "single", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - {IP: goodAddr, Priority: 1, DialTimeoutSec: 10}, - }}, - want: goodAddr, - }, - { - name: "broken-then-good", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - // Dials the broken one, which fails, and then - // eventually dials the good one and succeeds - {IP: brokenAddr, Priority: 2, DialTimeoutSec: 10}, - {IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1}, - }}, - want: goodAddr, - }, - // TODO(#8442): fix this test - // { - // name: "multiple-priority-fast-path", - // plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - // // 
Dials some good IPs and our bad one (which - // // hangs forever), which then hits the fast - // // path where we bail without waiting. - // {IP: brokenAddr, Priority: 1, DialTimeoutSec: 10}, - // {IP: goodAddr, Priority: 1, DialTimeoutSec: 10}, - // {IP: other2Addr, Priority: 1, DialTimeoutSec: 10}, - // {IP: otherAddr, Priority: 2, DialTimeoutSec: 10}, - // }}, - // want: otherAddr, - // }, - { - name: "multiple-priority-slow-path", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - // Our broken address is the highest priority, - // so we don't hit our fast path. - {IP: brokenAddr, Priority: 10, DialTimeoutSec: 10}, - {IP: otherAddr, Priority: 2, DialTimeoutSec: 10}, - {IP: goodAddr, Priority: 1, DialTimeoutSec: 10}, - }}, - want: otherAddr, - }, - { - name: "fallback", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - {IP: brokenAddr, Priority: 1, DialTimeoutSec: 1}, - }}, - want: fallbackAddr, - allowFallback: true, - }, - } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - // TODO(awly): replace this with tstest.NewClock and update the - // test to advance the clock correctly. - clock := tstime.StdClock{} - makeHandler(t, "fallback", fallbackAddr, nil) - makeHandler(t, "good", goodAddr, nil) - makeHandler(t, "other", otherAddr, nil) - makeHandler(t, "other2", other2Addr, nil) - makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { - return brokenMITMHandler(clock) - }) - - dialer := closeTrackDialer{ - t: t, - inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial, - conns: make(map[*closeTrackConn]bool), - } - defer dialer.Done() - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - // By default, we intentionally point to something that - // we know won't connect, since we want a fallback to - // DNS to be an error. 
- host := "example.com" - if tt.allowFallback { - host = "localhost" - } - - drained := make(chan struct{}) - a := &Dialer{ - Hostname: host, - HTTPPort: httpPort, - HTTPSPort: httpsPort, - MachineKey: client, - ControlKey: server.Public(), - ProtocolVersion: testProtocolVersion, - Dialer: dialer.Dial, - Logf: t.Logf, - DialPlan: tt.plan, - proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, - drainFinished: drained, - omitCertErrorLogging: true, - testFallbackDelay: 50 * time.Millisecond, - Clock: clock, - HealthTracker: health.NewTracker(eventbustest.NewBus(t)), - } - - conn, err := a.dial(ctx) + // Use synctest's controlled time + clock := tstime.StdClock{} + makeHandler(t, "fallback", fallbackAddr, nil) + makeHandler(t, "good", goodAddr, nil) + makeHandler(t, "other", otherAddr, nil) + makeHandler(t, "other2", other2Addr, nil) + makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { + return brokenMITMHandler(clock) + }) + // Create slow listener that delays accept by 5 seconds + makeSlowHandler := func(t *testing.T, name string, host netip.Addr, delay time.Duration) { + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, nil) if err != nil { - t.Fatalf("dialing controlhttp: %v", err) - } - defer conn.Close() - - raddr := conn.RemoteAddr().(*net.TCPAddr) - - got, ok := netip.AddrFromSlice(raddr.IP) - if !ok { - t.Errorf("invalid remote IP: %v", raddr.IP) - } else if got != tt.want { - t.Errorf("got connection from %q; want %q", got, tt.want) + log.Print(err) } else { - t.Logf("successfully connected to %q", raddr.String()) + defer conn.Close() } + w.Header().Set("X-Handler-Name", name) + <-done + }) - // Wait until our dialer drains so we can verify that - // all connections are closed. - <-drained + httpLn, err := memNetwork.Listen("tcp", host.String()+":"+httpPort) + if err != nil { + t.Fatalf("HTTP listen: %v", err) + } + httpsLn, err := memNetwork.Listen("tcp", host.String()+":"+httpsPort) + if err != nil { + t.Fatalf("HTTPS listen: %v", err) + } + + slowHttpLn := newSlowListener(httpLn, delay) + slowHttpsLn := newSlowListener(httpsLn, delay) + + httpServer := &http.Server{Handler: handler} + go httpServer.Serve(slowHttpLn) + t.Cleanup(func() { + httpServer.Close() + }) + + httpsServer := &http.Server{ + Handler: handler, + TLSConfig: tlsConfig(t), + ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")), + } + go httpsServer.ServeTLS(slowHttpsLn, "", "") + t.Cleanup(func() { + httpsServer.Close() }) } + makeSlowHandler(t, "slow", slowAddr, 5*time.Second) + + // memnetDialer with connection tracking, so we can catch connection leaks. 
+ dialer := &memnetDialer{ + inner: memNetwork.Dial, + t: t, + } + defer dialer.waitForAllClosedSynctest() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + host := "example.com" + if allowFallback { + host = fallbackAddr.String() + } + + a := &Dialer{ + Hostname: host, + HTTPPort: httpPort, + HTTPSPort: httpsPort, + MachineKey: client, + ControlKey: server.Public(), + ProtocolVersion: testProtocolVersion, + Dialer: dialer.Dial, + Logf: t.Logf, + DialPlan: plan, + proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, + omitCertErrorLogging: true, + testFallbackDelay: 50 * time.Millisecond, + Clock: clock, + HealthTracker: health.NewTracker(eventbustest.NewBus(t)), + } + + start := time.Now() + conn, err := a.dial(ctx) + duration := time.Since(start) + + if err != nil { + t.Fatalf("dialing controlhttp: %v", err) + } + defer conn.Close() + + if maxDuration > 0 && duration > maxDuration { + t.Errorf("dial took %v, expected < %v (should not wait for slow endpoints)", duration, maxDuration) + } + + raddr := conn.RemoteAddr() + raddrStr := raddr.String() + + // split on "|" first to remove memnet pipe suffix + addrPart := raddrStr + if idx := strings.Index(raddrStr, "|"); idx >= 0 { + addrPart = raddrStr[:idx] + } + + host, _, err2 := net.SplitHostPort(addrPart) + if err2 != nil { + t.Fatalf("failed to parse remote address %q: %v", addrPart, err2) + } + + got, err3 := netip.ParseAddr(host) + if err3 != nil { + t.Errorf("invalid remote IP: %v", host) + } else { + found := slices.Contains(want, got) + if !found { + t.Errorf("got connection from %q; want one of %v", got, want) + } else { + t.Logf("successfully connected to %q", raddr.String()) + } + } } -type closeTrackDialer struct { - t testing.TB - inner netx.DialFunc +// memnetDialer wraps memnet.Network.Dial to track connections for testing +type memnetDialer struct { + inner func(ctx context.Context, network, addr string) (net.Conn, error) + t *testing.T mu sync.Mutex - conns map[*closeTrackConn]bool + conns map[net.Conn]string // conn -> remote address for debugging } -func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := d.inner(ctx, network, addr) +func (d *memnetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := d.inner(ctx, network, addr) if err != nil { return nil, err } - ct := &closeTrackConn{Conn: c, d: d} d.mu.Lock() - d.conns[ct] = true + if d.conns == nil { + d.conns = make(map[net.Conn]string) + } + d.conns[conn] = conn.RemoteAddr().String() + d.t.Logf("tracked connection opened to %s", conn.RemoteAddr()) d.mu.Unlock() - return ct, nil + + return &memnetTrackedConn{Conn: conn, dialer: d}, nil } -func (d *closeTrackDialer) Done() { - // Unfortunately, tsdial.Dialer.SystemDial closes connections - // asynchronously in a goroutine, so we can't assume that everything is - // closed by the time we get here. - // - // Sleep/wait a few times on the assumption that things will close - // "eventually". 
- const iters = 100 - for i := range iters { +func (d *memnetDialer) waitForAllClosedSynctest() { + const maxWait = 15 * time.Second + const checkInterval = 100 * time.Millisecond + + for range int(maxWait / checkInterval) { d.mu.Lock() - if len(d.conns) == 0 { + remaining := len(d.conns) + if remaining == 0 { d.mu.Unlock() return } - - // Only error on last iteration - if i != iters-1 { - d.mu.Unlock() - time.Sleep(100 * time.Millisecond) - continue - } - - for conn := range d.conns { - d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String()) - } d.mu.Unlock() + + time.Sleep(checkInterval) + } + + d.mu.Lock() + defer d.mu.Unlock() + for _, addr := range d.conns { + d.t.Errorf("connection to %s was not closed after %v", addr, maxWait) } } -func (d *closeTrackDialer) noteClose(c *closeTrackConn) { +func (d *memnetDialer) noteClose(conn net.Conn) { d.mu.Lock() - delete(d.conns, c) // safe if already deleted + if addr, exists := d.conns[conn]; exists { + d.t.Logf("tracked connection closed to %s", addr) + delete(d.conns, conn) + } d.mu.Unlock() } -type closeTrackConn struct { +type memnetTrackedConn struct { net.Conn - d *closeTrackDialer + dialer *memnetDialer } -func (c *closeTrackConn) Close() error { - c.d.noteClose(c) +func (c *memnetTrackedConn) Close() error { + c.dialer.noteClose(c.Conn) return c.Conn.Close() }
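
Note on the reworked dial logic: the new Dialer.dial above drops the priority bookkeeping and simply races the candidates. Each candidate is scheduled with time.AfterFunc after its DialStartDelaySec, the first successful connection is returned immediately, the shared context cancellation closes whatever is still in flight, and only after every candidate has failed does it fall back to plain DNS dialing. The following is a standalone sketch of that pattern only; the candidate type, raceDial helper, and addresses are hypothetical stand-ins, not the real controlhttp/tailcfg types.

// dialrace_sketch.go — minimal sketch of the candidate-racing shape used in
// Dialer.dial above (AfterFunc per candidate, unbuffered result channel,
// shared cancellation, fallback after all candidates fail). Hypothetical
// names; assumes Go 1.22+ per-iteration loop variables.
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"
)

type candidate struct {
	addr       string        // host:port to dial
	startDelay time.Duration // wait this long before starting the dial
	timeout    time.Duration // per-candidate dial timeout
}

func raceDial(ctx context.Context, cands []candidate, fallback func(context.Context) (net.Conn, error)) (net.Conn, error) {
	// Canceling this context on return aborts the losing dials.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	type result struct {
		conn net.Conn
		err  error
	}
	results := make(chan result) // unbuffered, never closed

	for _, c := range cands {
		timer := time.AfterFunc(c.startDelay, func() {
			dctx, dcancel := context.WithTimeout(ctx, c.timeout)
			defer dcancel()
			var d net.Dialer
			conn, err := d.DialContext(dctx, "tcp", c.addr)
			select {
			case results <- result{conn, err}:
			case <-ctx.Done():
				// A winner was already chosen (or the caller gave up);
				// close this late connection rather than leaking it.
				if conn != nil {
					conn.Close()
				}
			}
		})
		defer timer.Stop()
	}

	var errs []error
	for {
		select {
		case r := <-results:
			if r.err == nil {
				return r.conn, nil // first success wins
			}
			errs = append(errs, r.err)
			if len(errs) == len(cands) {
				// Every candidate failed; fall back (plain DNS in the real code).
				return fallback(ctx)
			}
		case <-ctx.Done():
			return nil, errors.Join(append(errs, ctx.Err())...)
		}
	}
}

func main() {
	cands := []candidate{
		{addr: "192.0.2.1:443", timeout: 2 * time.Second}, // TEST-NET-1, expected to fail
		{addr: "example.com:443", startDelay: 500 * time.Millisecond, timeout: 5 * time.Second},
	}
	conn, err := raceDial(context.Background(), cands, func(ctx context.Context) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "tcp", "example.com:443")
	})
	if err != nil {
		fmt.Println("dial failed:", err)
		return
	}
	defer conn.Close()
	fmt.Println("connected to", conn.RemoteAddr())
}

The unbuffered result channel plus the select on ctx.Done() is what lets losing goroutines notice that a winner was picked and close their connections instead of leaking them.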
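
Note on the test changes: TestDialPlan now runs each case inside testing/synctest over an in-memory memnet network, which is why the 5-second slowListener delay and the 2-second maxDuration bound can be asserted without slowing the suite down. A minimal sketch of the synctest behavior being relied on, assuming Go 1.25's testing/synctest API (the test name and package are hypothetical):

// synctest_sketch_test.go — inside the bubble, timers run on a fake clock
// that only advances once every goroutine in the bubble is durably blocked,
// so long delays cost no wall-clock time and elapsed durations are exact.
package example

import (
	"testing"
	"testing/synctest"
	"time"
)

func TestFakeClock(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		start := time.Now()
		fired := make(chan struct{})
		time.AfterFunc(5*time.Second, func() { close(fired) })

		<-fired // blocks; the fake clock jumps straight to the timer

		if got := time.Since(start); got != 5*time.Second {
			t.Errorf("elapsed = %v; want exactly 5s on the fake clock", got)
		}
	})
}

The same property is what lets memnetDialer.waitForAllClosedSynctest poll in 100ms steps for leaked connections without adding real delay to the test.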