diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index 8a2c2b9ef..311d6a4ba 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -196,11 +196,17 @@ func (v *onceFlagValue) IsBoolFlag() bool { // noDupFlagify modifies c recursively to make all the // flag values be wrappers that permit setting the value -// at most once. +// at most once. If a flag is already wrapped, it resets +// the wrapper's state instead of double-wrapping. func noDupFlagify(c *ffcli.Command) { if c.FlagSet != nil { c.FlagSet.VisitAll(func(f *flag.Flag) { - f.Value = &onceFlagValue{Value: f.Value} + if ofv, ok := f.Value.(*onceFlagValue); ok { + // Already wrapped; reset the flag state + ofv.set = false + } else { + f.Value = &onceFlagValue{Value: f.Value} + } }) } for _, sub := range c.Subcommands { diff --git a/kube/services/services.go b/kube/services/services.go index 0c27f888f..a31f4d197 100644 --- a/kube/services/services.go +++ b/kube/services/services.go @@ -16,6 +16,19 @@ import ( "tailscale.com/types/logger" ) +// serviceWaitDuration is the time to wait for services to propagate after +// advertising/unadvertising. This can be overridden in tests via +// SetWaitDurationForTest. +var serviceWaitDuration = 20 * time.Second + +// SetWaitDurationForTest sets the service wait duration and returns a function +// to restore the original value. This should only be used in tests. +func SetWaitDurationForTest(d time.Duration) func() { + old := serviceWaitDuration + serviceWaitDuration = d + return func() { serviceWaitDuration = old } +} + // EnsureServicesAdvertised is a function that gets called on containerboot // startup and ensures that Services get advertised if they exist. func EnsureServicesAdvertised(ctx context.Context, services []string, lc localclient.LocalClient, logf logger.Logf) error { @@ -47,7 +60,7 @@ func EnsureServicesAdvertised(ctx context.Context, services []string, lc localcl select { case <-ctx.Done(): return nil - case <-time.After(20 * time.Second): + case <-time.After(serviceWaitDuration): return nil } } @@ -94,7 +107,7 @@ func EnsureServicesNotAdvertised(ctx context.Context, lc *local.Client, logf log select { case <-ctx.Done(): return nil - case <-time.After(20 * time.Second): + case <-time.After(serviceWaitDuration): return nil } } diff --git a/net/captivedetection/captivedetection.go b/net/captivedetection/captivedetection.go index dfd4bbd87..a95dd1282 100644 --- a/net/captivedetection/captivedetection.go +++ b/net/captivedetection/captivedetection.go @@ -93,6 +93,9 @@ func (d *Detector) detectCaptivePortalWithGOOS(ctx context.Context, netMon *netm // the captive portal alert thrown by the system. If no default route interface is known, // we need to try with anything that might remotely resemble a Wi-Fi interface. for ifName, i := range ifState.Interface { + if ctx.Err() != nil { + return false + } if !i.IsUp() || i.IsLoopback() || interfaceNameDoesNotNeedCaptiveDetection(ifName, goos) { continue } diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index a64c358c5..37dedf96a 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -1015,6 +1015,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe } // Wait for captive portal check before finishing the report. + // Try to stop the captive portal check timer in case it hasn't fired yet. + // This is safe to call multiple times - if the timer already fired, the + // goroutine will close the channel when it completes. + captivePortalStop() <-captivePortalDone return c.finishAndStoreReport(rs, dm), nil diff --git a/util/eventbus/eventbustest/eventbustest.go b/util/eventbus/eventbustest/eventbustest.go index b3ef6c884..c0dff4ab7 100644 --- a/util/eventbus/eventbustest/eventbustest.go +++ b/util/eventbus/eventbustest/eventbustest.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "reflect" + "sync" "testing" "time" @@ -43,9 +44,10 @@ func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher { // // For usage examples, see the documentation in the top of the package. type Watcher struct { - mon *eventbus.Subscriber[eventbus.RoutedEvent] - events chan any - chDone chan bool + mon *eventbus.Subscriber[eventbus.RoutedEvent] + events chan any + chDone chan bool + doneOnce sync.Once } // Type is a helper representing the expectation to see an event of type T, without @@ -174,7 +176,9 @@ func (tw *Watcher) watch() { // done tells the watcher to stop monitoring for new events. func (tw *Watcher) done() { - close(tw.chDone) + tw.doneOnce.Do(func() { + close(tw.chDone) + }) } type filter = func(any) (bool, error)