net/porttrack: add net.Listen wrapper to help tests allocate ports race-free

Updates tailscale/corp#27805
Updates tailscale/corp#27806
Updates tailscale/corp#37964

Change-Id: I7bb5ed7f258e840a8208e5d725c7b2f126d7ef96
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2026-03-04 03:31:13 +00:00 committed by Brad Fitzpatrick
parent 120f27f383
commit d42b3743b7
2 changed files with 271 additions and 0 deletions

176
net/porttrack/porttrack.go Normal file
View File

@ -0,0 +1,176 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
// Package porttrack provides race-free ephemeral port assignment for
// subprocess tests. The parent test process creates a [Collector] that
// listens on a TCP port; the child process uses [Listen] which, when
// given a magic address, binds to localhost:0 and reports the actual
// port back to the collector.
//
// The magic address format is:
//
// testport-report:HOST:PORT/LABEL
//
// where HOST:PORT is the collector's TCP address and LABEL identifies
// which listener this is (e.g. "main", "plaintext").
//
// When [Listen] is called with a non-magic address, it falls through to
// [net.Listen] with zero overhead beyond a single [strings.HasPrefix]
// check.
package porttrack
import (
"bufio"
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
"tailscale.com/util/testenv"
)
const magicPrefix = "testport-report:"
// Collector is the parent/test side of the porttrack protocol. It
// listens for port reports from child processes that used [Listen]
// with a magic address obtained from [Collector.Addr].
type Collector struct {
ln net.Listener
mu sync.Mutex
cond *sync.Cond
ports map[string]int
err error // non-nil if a context passed to Port was cancelled
}
// NewCollector creates a new Collector. The collector's TCP listener is
// closed when t finishes.
func NewCollector(t testenv.TB) *Collector {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("porttrack.NewCollector: %v", err)
}
c := &Collector{
ln: ln,
ports: make(map[string]int),
}
c.cond = sync.NewCond(&c.mu)
go c.accept(t)
t.Cleanup(func() { ln.Close() })
return c
}
// accept runs in a goroutine, accepting connections and parsing port
// reports until the listener is closed.
func (c *Collector) accept(t testenv.TB) {
for {
conn, err := c.ln.Accept()
if err != nil {
return // listener closed
}
go c.handleConn(t, conn)
}
}
func (c *Collector) handleConn(t testenv.TB, conn net.Conn) {
defer conn.Close()
scanner := bufio.NewScanner(conn)
for scanner.Scan() {
line := scanner.Text()
label, portStr, ok := strings.Cut(line, "\t")
if !ok {
t.Errorf("porttrack: malformed report line: %q", line)
return
}
port, err := strconv.Atoi(portStr)
if err != nil {
t.Errorf("porttrack: bad port in report %q: %v", line, err)
return
}
c.mu.Lock()
c.ports[label] = port
c.cond.Broadcast()
c.mu.Unlock()
}
}
// Addr returns a magic address string that, when passed to [Listen],
// causes the child to bind to localhost:0 and report its actual port
// back to this collector under the given label.
func (c *Collector) Addr(label string) string {
return magicPrefix + c.ln.Addr().String() + "/" + label
}
// Port blocks until the child process has reported the port for the
// given label, then returns it. If ctx is cancelled before a port is
// reported, Port returns the context's cause as an error.
func (c *Collector) Port(ctx context.Context, label string) (int, error) {
stop := context.AfterFunc(ctx, func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.err = context.Cause(ctx)
}
c.cond.Broadcast()
})
defer stop()
c.mu.Lock()
defer c.mu.Unlock()
for {
if p, ok := c.ports[label]; ok {
return p, nil
}
if c.err != nil {
return 0, c.err
}
c.cond.Wait()
}
}
// Listen is the child/production side of the porttrack protocol.
//
// If address has the magic prefix (as returned by [Collector.Addr]),
// Listen binds to localhost:0 on the given network, then TCP-connects
// to the collector and writes "LABEL\tPORT\n" to report the actual
// port. The collector connection is closed before returning.
//
// If address does not have the magic prefix, Listen is simply
// [net.Listen](network, address).
func Listen(network, address string) (net.Listener, error) {
rest, ok := strings.CutPrefix(address, magicPrefix)
if !ok {
return net.Listen(network, address)
}
// rest is "HOST:PORT/LABEL"
slashIdx := strings.LastIndex(rest, "/")
if slashIdx < 0 {
return nil, fmt.Errorf("porttrack: malformed magic address %q: missing /LABEL", address)
}
collectorAddr := rest[:slashIdx]
label := rest[slashIdx+1:]
ln, err := net.Listen(network, "localhost:0")
if err != nil {
return nil, err
}
port := ln.Addr().(*net.TCPAddr).Port
conn, err := net.Dial("tcp", collectorAddr)
if err != nil {
ln.Close()
return nil, fmt.Errorf("porttrack: failed to connect to collector at %s: %v", collectorAddr, err)
}
_, err = fmt.Fprintf(conn, "%s\t%d\n", label, port)
conn.Close()
if err != nil {
ln.Close()
return nil, fmt.Errorf("porttrack: failed to report port to collector: %v", err)
}
return ln, nil
}

View File

@ -0,0 +1,95 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package porttrack
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"testing"
)
func TestCollectorAndListen(t *testing.T) {
c := NewCollector(t)
labels := []string{"main", "plaintext", "debug"}
ports := make([]int, len(labels))
for i, label := range labels {
ln, err := Listen("tcp", c.Addr(label))
if err != nil {
t.Fatalf("Listen(%q): %v", label, err)
}
defer ln.Close()
p, err := c.Port(t.Context(), label)
if err != nil {
t.Fatalf("Port(%q): %v", label, err)
}
ports[i] = p
}
// All ports should be distinct non-zero values.
seen := map[int]string{}
for i, label := range labels {
if ports[i] == 0 {
t.Errorf("Port(%q) = 0", label)
}
if prev, ok := seen[ports[i]]; ok {
t.Errorf("Port(%q) = Port(%q) = %d", label, prev, ports[i])
}
seen[ports[i]] = label
}
}
func TestListenPassthrough(t *testing.T) {
ln, err := Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Listen passthrough: %v", err)
}
defer ln.Close()
if ln.Addr().(*net.TCPAddr).Port == 0 {
t.Fatal("expected non-zero port")
}
}
func TestRoundTrip(t *testing.T) {
c := NewCollector(t)
ln, err := Listen("tcp", c.Addr("http"))
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer ln.Close()
// Start a server on the listener.
go http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
port, err := c.Port(t.Context(), "http")
if err != nil {
t.Fatalf("Port: %v", err)
}
resp, err := http.Get(fmt.Sprintf("http://localhost:%d/", port))
if err != nil {
t.Fatalf("http.Get: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent)
}
}
func TestPortContextCancelled(t *testing.T) {
c := NewCollector(t)
// Nobody will ever report "never", so Port should block until ctx is done.
ctx, cancel := context.WithCancel(t.Context())
cancel()
_, err := c.Port(ctx, "never")
if !errors.Is(err, context.Canceled) {
t.Fatalf("Port with cancelled context: got %v, want %v", err, context.Canceled)
}
}