diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index e6335e54d..7bca6c8d8 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -205,7 +205,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) { } }) return c, nil - } // SetPaused controls whether HTTP activity should be paused. @@ -424,6 +423,11 @@ func (c *Auto) unpausedChanLocked() <-chan bool { return unpaused } +// ClientID returns the ClientID of the direct controlClient +func (c *Auto) ClientID() int64 { + return c.direct.ClientID() +} + // mapRoutineState is the state of Auto.mapRoutine while it's running. type mapRoutineState struct { c *Auto diff --git a/control/controlclient/client.go b/control/controlclient/client.go index 8df64f9e8..d0aa129ae 100644 --- a/control/controlclient/client.go +++ b/control/controlclient/client.go @@ -81,6 +81,9 @@ type Client interface { // in a separate http request. It has nothing to do with the rest of // the state machine. UpdateEndpoints(endpoints []tailcfg.Endpoint) + // ClientID returns the ClientID of a client. This ID is meant to + // distinguish one client from another. + ClientID() int64 } // UserVisibleError is an error that should be shown to users. diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index 792c26955..2efc27b5e 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -35,6 +35,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/persist" + "tailscale.com/util/eventbus/eventbustest" ) func fieldsOf(t reflect.Type) (fields []string) { @@ -218,6 +219,8 @@ func TestDirectProxyManual(t *testing.T) { t.Skip("skipping without --live-network-test") } + bus := eventbustest.NewBus(t) + dialer := &tsdial.Dialer{} dialer.SetNetMon(netmon.NewStatic()) @@ -239,6 +242,7 @@ func TestDirectProxyManual(t *testing.T) { }, Dialer: dialer, ControlKnobs: &controlknobs.Knobs{}, + Bus: bus, } d, err := NewDirect(opts) if err != nil { @@ -263,6 +267,8 @@ func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) } func testHTTPS(t *testing.T, withProxy bool) { bakedroots.ResetForTest(t, tlstest.TestRootCA()) + bus := eventbustest.NewBus(t) + controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig()) if err != nil { t.Fatal(err) @@ -327,6 +333,7 @@ func testHTTPS(t *testing.T, withProxy bool) { t.Logf("PopBrowserURL: %q", url) }, Dialer: dialer, + Bus: bus, } d, err := NewDirect(opts) if err != nil { diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 47283a673..b9e26cc98 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "log" + "math/rand/v2" "net" "net/http" "net/netip" @@ -52,6 +53,7 @@ import ( "tailscale.com/types/ptr" "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/multierr" "tailscale.com/util/singleflight" "tailscale.com/util/syspolicy/pkey" @@ -63,30 +65,31 @@ import ( // Direct is the client that connects to a tailcontrol server for a node. type Direct struct { - httpc *http.Client // HTTP client used to talk to tailcontrol - interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial - dialer *tsdial.Dialer - dnsCache *dnscache.Resolver - controlKnobs *controlknobs.Knobs // always non-nil - serverURL string // URL of the tailcontrol server - clock tstime.Clock - logf logger.Logf - netMon *netmon.Monitor // non-nil - health *health.Tracker - discoPubKey key.DiscoPublic - getMachinePrivKey func() (key.MachinePrivate, error) - debugFlags []string - skipIPForwardingCheck bool - pinger Pinger - polc policyclient.Client // always non-nil - popBrowser func(url string) // or nil - c2nHandler http.Handler // or nil - onClientVersion func(*tailcfg.ClientVersion) // or nil - onControlTime func(time.Time) // or nil - onTailnetDefaultAutoUpdate func(bool) // or nil - panicOnUse bool // if true, panic if client is used (for testing) - closedCtx context.Context // alive until Direct.Close is called - closeCtx context.CancelFunc // cancels closedCtx + httpc *http.Client // HTTP client used to talk to tailcontrol + interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial + dialer *tsdial.Dialer + dnsCache *dnscache.Resolver + controlKnobs *controlknobs.Knobs // always non-nil + serverURL string // URL of the tailcontrol server + clock tstime.Clock + logf logger.Logf + netMon *netmon.Monitor // non-nil + health *health.Tracker + discoPubKey key.DiscoPublic + busClient *eventbus.Client + clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion] + autoUpdatePub *eventbus.Publisher[AutoUpdate] + controlTimePub *eventbus.Publisher[ControlTime] + getMachinePrivKey func() (key.MachinePrivate, error) + debugFlags []string + skipIPForwardingCheck bool + pinger Pinger + popBrowser func(url string) // or nil + polc policyclient.Client // always non-nil + c2nHandler http.Handler // or nil + panicOnUse bool // if true, panic if client is used (for testing) + closedCtx context.Context // alive until Direct.Close is called + closeCtx context.CancelFunc // cancels closedCtx dialPlan ControlDialPlanner // can be nil @@ -107,6 +110,8 @@ type Direct struct { tkaHead string lastPingURL string // last PingRequest.URL received, for dup suppression connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest + + controlClientID int64 // Random ID used to differentiate clients for consumers of messages. } // Observer is implemented by users of the control client (such as LocalBackend) @@ -120,26 +125,24 @@ type Observer interface { } type Options struct { - Persist persist.Persist // initial persistent data - GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use - ServerURL string // URL of the tailcontrol server - AuthKey string // optional node auth key for auto registration - Clock tstime.Clock - Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc - DiscoPublicKey key.DiscoPublic - PolicyClient policyclient.Client // or nil for none - Logf logger.Logf - HTTPTestClient *http.Client // optional HTTP client to use (for tests only) - NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) - DebugFlags []string // debug settings to send to control - HealthTracker *health.Tracker - PopBrowserURL func(url string) // optional func to open browser - OnClientVersion func(*tailcfg.ClientVersion) // optional func to inform GUI of client version status - OnControlTime func(time.Time) // optional func to notify callers of new time from control - OnTailnetDefaultAutoUpdate func(bool) // optional func to inform GUI of default auto-update setting for the tailnet - Dialer *tsdial.Dialer // non-nil - C2NHandler http.Handler // or nil - ControlKnobs *controlknobs.Knobs // or nil to ignore + Persist persist.Persist // initial persistent data + GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use + ServerURL string // URL of the tailcontrol server + AuthKey string // optional node auth key for auto registration + Clock tstime.Clock + Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc + DiscoPublicKey key.DiscoPublic + PolicyClient policyclient.Client // or nil for none + Logf logger.Logf + HTTPTestClient *http.Client // optional HTTP client to use (for tests only) + NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) + DebugFlags []string // debug settings to send to control + HealthTracker *health.Tracker + PopBrowserURL func(url string) // optional func to open browser + Dialer *tsdial.Dialer // non-nil + C2NHandler http.Handler // or nil + ControlKnobs *controlknobs.Knobs // or nil to ignore + Bus *eventbus.Bus // Observer is called when there's a change in status to report // from the control client. @@ -287,33 +290,32 @@ func NewDirect(opts Options) (*Direct, error) { } c := &Direct{ - httpc: httpc, - interceptedDial: interceptedDial, - controlKnobs: opts.ControlKnobs, - getMachinePrivKey: opts.GetMachinePrivateKey, - serverURL: opts.ServerURL, - clock: opts.Clock, - logf: opts.Logf, - persist: opts.Persist.View(), - authKey: opts.AuthKey, - discoPubKey: opts.DiscoPublicKey, - debugFlags: opts.DebugFlags, - netMon: netMon, - health: opts.HealthTracker, - skipIPForwardingCheck: opts.SkipIPForwardingCheck, - pinger: opts.Pinger, - polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})), - popBrowser: opts.PopBrowserURL, - onClientVersion: opts.OnClientVersion, - onTailnetDefaultAutoUpdate: opts.OnTailnetDefaultAutoUpdate, - onControlTime: opts.OnControlTime, - c2nHandler: opts.C2NHandler, - dialer: opts.Dialer, - dnsCache: dnsCache, - dialPlan: opts.DialPlan, + httpc: httpc, + interceptedDial: interceptedDial, + controlKnobs: opts.ControlKnobs, + getMachinePrivKey: opts.GetMachinePrivateKey, + serverURL: opts.ServerURL, + clock: opts.Clock, + logf: opts.Logf, + persist: opts.Persist.View(), + authKey: opts.AuthKey, + discoPubKey: opts.DiscoPublicKey, + debugFlags: opts.DebugFlags, + netMon: netMon, + health: opts.HealthTracker, + skipIPForwardingCheck: opts.SkipIPForwardingCheck, + pinger: opts.Pinger, + polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})), + popBrowser: opts.PopBrowserURL, + c2nHandler: opts.C2NHandler, + dialer: opts.Dialer, + dnsCache: dnsCache, + dialPlan: opts.DialPlan, } c.closedCtx, c.closeCtx = context.WithCancel(context.Background()) + c.controlClientID = rand.Int64() + if opts.Hostinfo == nil { c.SetHostinfo(hostinfo.New()) } else { @@ -331,6 +333,12 @@ func NewDirect(opts Options) (*Direct, error) { if strings.Contains(opts.ServerURL, "controlplane.tailscale.com") && envknob.Bool("TS_PANIC_IF_HIT_MAIN_CONTROL") { c.panicOnUse = true } + + c.busClient = opts.Bus.Client("controlClient.direct") + c.clientVersionPub = eventbus.Publish[tailcfg.ClientVersion](c.busClient) + c.autoUpdatePub = eventbus.Publish[AutoUpdate](c.busClient) + c.controlTimePub = eventbus.Publish[ControlTime](c.busClient) + return c, nil } @@ -340,6 +348,7 @@ func (c *Direct) Close() error { c.mu.Lock() defer c.mu.Unlock() + c.busClient.Close() if c.noiseClient != nil { if err := c.noiseClient.Close(); err != nil { return err @@ -826,6 +835,23 @@ func (c *Direct) SendUpdate(ctx context.Context) error { return c.sendMapRequest(ctx, false, nil) } +// ClientID returns the ControlClientID of the controlClient +func (c *Direct) ClientID() int64 { + return c.controlClientID +} + +// AutoUpdate wraps a bool for naming on the eventbus +type AutoUpdate struct { + ClientID int64 // The ID field is used for consumers to differentiate instances of Direct + Value bool +} + +// ControlTime wraps a [time.Time] for naming on the eventbus +type ControlTime struct { + ClientID int64 // The ID field is used for consumers to differentiate instances of Direct + Value time.Time +} + // If we go more than watchdogTimeout without hearing from the server, // end the long poll. We should be receiving a keep alive ping // every minute. @@ -1085,14 +1111,12 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap c.logf("netmap: control says to open URL %v; no popBrowser func", u) } } - if resp.ClientVersion != nil && c.onClientVersion != nil { - c.onClientVersion(resp.ClientVersion) + if resp.ClientVersion != nil { + c.clientVersionPub.Publish(*resp.ClientVersion) } if resp.ControlTime != nil && !resp.ControlTime.IsZero() { c.logf.JSON(1, "controltime", resp.ControlTime.UTC()) - if c.onControlTime != nil { - c.onControlTime(*resp.ControlTime) - } + c.controlTimePub.Publish(ControlTime{c.controlClientID, *resp.ControlTime}) } if resp.KeepAlive { vlogf("netmap: got keep-alive") @@ -1112,9 +1136,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap continue } if au, ok := resp.DefaultAutoUpdate.Get(); ok { - if c.onTailnetDefaultAutoUpdate != nil { - c.onTailnetDefaultAutoUpdate(au) - } + c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, au}) } metricMapResponseMap.Add(1) diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index e2a6f9fa4..bba76d6f0 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -17,12 +17,14 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/util/eventbus/eventbustest" ) func TestNewDirect(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni + bus := eventbustest.NewBus(t) k := key.NewMachine() opts := Options{ @@ -32,6 +34,7 @@ func TestNewDirect(t *testing.T) { return k, nil }, Dialer: tsdial.NewDialer(netmon.NewStatic()), + Bus: bus, } c, err := NewDirect(opts) if err != nil { @@ -99,6 +102,7 @@ func TestTsmpPing(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni + bus := eventbustest.NewBus(t) k := key.NewMachine() opts := Options{ @@ -108,6 +112,7 @@ func TestTsmpPing(t *testing.T) { return k, nil }, Dialer: tsdial.NewDialer(netmon.NewStatic()), + Bus: bus, } c, err := NewDirect(opts) diff --git a/ipn/ipnlocal/expiry.go b/ipn/ipnlocal/expiry.go index d11199815..3d20d57b4 100644 --- a/ipn/ipnlocal/expiry.go +++ b/ipn/ipnlocal/expiry.go @@ -6,12 +6,14 @@ package ipnlocal import ( "time" + "tailscale.com/control/controlclient" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" ) // For extra defense-in-depth, when we're testing expired nodes we check @@ -40,14 +42,46 @@ type expiryManager struct { logf logger.Logf clock tstime.Clock + + eventClient *eventbus.Client + controlTimeSub *eventbus.Subscriber[controlclient.ControlTime] + subsDoneCh chan struct{} // closed when consumeEventbusTopics returns } -func newExpiryManager(logf logger.Logf) *expiryManager { - return &expiryManager{ +func newExpiryManager(logf logger.Logf, bus *eventbus.Bus) *expiryManager { + em := &expiryManager{ previouslyExpired: map[tailcfg.StableNodeID]bool{}, logf: logf, clock: tstime.StdClock{}, } + + em.eventClient = bus.Client("ipnlocal.expiryManager") + em.controlTimeSub = eventbus.Subscribe[controlclient.ControlTime](em.eventClient) + + em.subsDoneCh = make(chan struct{}) + go em.consumeEventbusTopics() + + return em +} + +// consumeEventbusTopics consumes events from all relevant +// [eventbus.Subscriber]'s and passes them to their related handler. Events are +// always handled in the order they are received, i.e. the next event is not +// read until the previous event's handler has returned. It returns when the +// [controlclient.ControlTime] subscriber is closed, which is interpreted to be the +// same as the [eventbus.Client] closing ([eventbus.Subscribers] are either +// all open or all closed). +func (em *expiryManager) consumeEventbusTopics() { + defer close(em.subsDoneCh) + + for { + select { + case <-em.controlTimeSub.Done(): + return + case time := <-em.controlTimeSub.Events(): + em.onControlTime(time.Value) + } + } } // onControlTime is called whenever we receive a new timestamp from the control @@ -218,6 +252,11 @@ func (em *expiryManager) nextPeerExpiry(nm *netmap.NetworkMap, localNow time.Tim return nextExpiry } +func (em *expiryManager) close() { + em.eventClient.Close() + <-em.subsDoneCh +} + // ControlNow estimates the current time on the control server, calculated as // localNow + the delta between local and control server clocks as recorded // when the LocalBackend last received a time message from the control server. diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index a2b10fe32..2c646ca72 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -14,6 +14,7 @@ import ( "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus/eventbustest" ) func TestFlagExpiredPeers(t *testing.T) { @@ -110,7 +111,8 @@ func TestFlagExpiredPeers(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) + bus := eventbustest.NewBus(t) + em := newExpiryManager(t.Logf, bus) em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) if tt.controlTime != nil { em.onControlTime(*tt.controlTime) @@ -240,7 +242,8 @@ func TestNextPeerExpiry(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) + bus := eventbustest.NewBus(t) + em := newExpiryManager(t.Logf, bus) em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) got := em.nextPeerExpiry(tt.netmap, now) if !got.Equal(tt.want) { @@ -253,7 +256,8 @@ func TestNextPeerExpiry(t *testing.T) { t.Run("ClockSkew", func(t *testing.T) { t.Logf("local time: %q", now.Format(time.RFC3339)) - em := newExpiryManager(t.Logf) + bus := eventbustest.NewBus(t) + em := newExpiryManager(t.Logf, bus) em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) // The local clock is "running fast"; our clock skew is -2h diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 6108aa830..c98a0810d 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -99,6 +99,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/deephash" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus" "tailscale.com/util/goroutines" "tailscale.com/util/mak" "tailscale.com/util/multierr" @@ -202,6 +203,10 @@ type LocalBackend struct { keyLogf logger.Logf // for printing list of peers on change statsLogf logger.Logf // for printing peers stats on change sys *tsd.System + eventClient *eventbus.Client + clientVersionSub *eventbus.Subscriber[tailcfg.ClientVersion] + autoUpdateSub *eventbus.Subscriber[controlclient.AutoUpdate] + subsDoneCh chan struct{} // closed when consumeEventbusTopics returns health *health.Tracker // always non-nil polc policyclient.Client // always non-nil metrics metrics @@ -525,7 +530,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo backendLogID: logID, state: ipn.NoState, portpoll: new(portlist.Poller), - em: newExpiryManager(logf), + em: newExpiryManager(logf, sys.Bus.Get()), loginFlags: loginFlags, clock: clock, selfUpdateProgress: make([]ipnstate.UpdateProgress, 0), @@ -533,7 +538,11 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo captiveCtx: captiveCtx, captiveCancel: nil, // so that we start checkCaptivePortalLoop when Running needsCaptiveDetection: make(chan bool), + subsDoneCh: make(chan struct{}), } + b.eventClient = b.Sys().Bus.Get().Client("ipnlocal.LocalBackend") + b.clientVersionSub = eventbus.Subscribe[tailcfg.ClientVersion](b.eventClient) + b.autoUpdateSub = eventbus.Subscribe[controlclient.AutoUpdate](b.eventClient) nb := newNodeBackend(ctx, b.sys.Bus.Get()) b.currentNodeAtomic.Store(nb) nb.ready() @@ -604,9 +613,32 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } } + go b.consumeEventbusTopics() return b, nil } +// consumeEventbusTopics consumes events from all relevant +// [eventbus.Subscriber]'s and passes them to their related handler. Events are +// always handled in the order they are received, i.e. the next event is not +// read until the previous event's handler has returned. It returns when the +// [tailcfg.ClientVersion] subscriber is closed, which is interpreted to be the +// same as the [eventbus.Client] closing ([eventbus.Subscribers] are either +// all open or all closed). +func (b *LocalBackend) consumeEventbusTopics() { + defer close(b.subsDoneCh) + + for { + select { + case <-b.clientVersionSub.Done(): + return + case clientVersion := <-b.clientVersionSub.Events(): + b.onClientVersion(&clientVersion) + case au := <-b.autoUpdateSub.Events(): + b.onTailnetDefaultAutoUpdate(au.Value) + } + } +} + func (b *LocalBackend) Clock() tstime.Clock { return b.clock } func (b *LocalBackend) Sys() *tsd.System { return b.sys } @@ -1065,6 +1097,17 @@ func (b *LocalBackend) ClearCaptureSink() { // Shutdown halts the backend and all its sub-components. The backend // can no longer be used after Shutdown returns. func (b *LocalBackend) Shutdown() { + // Close the [eventbus.Client] and wait for LocalBackend.consumeEventbusTopics + // to return. Do this before acquiring b.mu: + // 1. LocalBackend.consumeEventbusTopics event handlers also acquire b.mu, + // they can deadlock with c.Shutdown(). + // 2. LocalBackend.consumeEventbusTopics event handlers may not guard against + // undesirable post/in-progress LocalBackend.Shutdown() behaviors. + b.eventClient.Close() + <-b.subsDoneCh + + b.em.close() + b.mu.Lock() if b.shutdownCalled { b.mu.Unlock() @@ -2465,33 +2508,32 @@ func (b *LocalBackend) Start(opts ipn.Options) error { cb() } } + // TODO(apenwarr): The only way to change the ServerURL is to // re-run b.Start, because this is the only place we create a // new controlclient. EditPrefs allows you to overwrite ServerURL, // but it won't take effect until the next Start. cc, err := b.getNewControlClientFuncLocked()(controlclient.Options{ - GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(), - Logf: logger.WithPrefix(b.logf, "control: "), - Persist: *persistv, - ServerURL: serverURL, - AuthKey: opts.AuthKey, - Hostinfo: hostinfo, - HTTPTestClient: httpTestClient, - DiscoPublicKey: discoPublic, - DebugFlags: debugFlags, - HealthTracker: b.health, - PolicyClient: b.sys.PolicyClientOrDefault(), - Pinger: b, - PopBrowserURL: b.tellClientToBrowseToURL, - OnClientVersion: b.onClientVersion, - OnTailnetDefaultAutoUpdate: b.onTailnetDefaultAutoUpdate, - OnControlTime: b.em.onControlTime, - Dialer: b.Dialer(), - Observer: b, - C2NHandler: http.HandlerFunc(b.handleC2N), - DialPlan: &b.dialPlan, // pointer because it can't be copied - ControlKnobs: b.sys.ControlKnobs(), - Shutdown: ccShutdown, + GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(), + Logf: logger.WithPrefix(b.logf, "control: "), + Persist: *persistv, + ServerURL: serverURL, + AuthKey: opts.AuthKey, + Hostinfo: hostinfo, + HTTPTestClient: httpTestClient, + DiscoPublicKey: discoPublic, + DebugFlags: debugFlags, + HealthTracker: b.health, + PolicyClient: b.sys.PolicyClientOrDefault(), + Pinger: b, + PopBrowserURL: b.tellClientToBrowseToURL, + Dialer: b.Dialer(), + Observer: b, + C2NHandler: http.HandlerFunc(b.handleC2N), + DialPlan: &b.dialPlan, // pointer because it can't be copied + ControlKnobs: b.sys.ControlKnobs(), + Shutdown: ccShutdown, + Bus: b.sys.Bus.Get(), // Don't warn about broken Linux IP forwarding when // netstack is being used. @@ -4482,7 +4524,6 @@ func (b *LocalBackend) changeDisablesExitNodeLocked(prefs ipn.PrefsView, change // but wasn't empty before, then the change disables // exit node usage. return tmpPrefs.ExitNodeID == "" - } // adjustEditPrefsLocked applies additional changes to mp if necessary, @@ -8001,7 +8042,6 @@ func isAllowedAutoExitNodeID(polc policyclient.Client, exitNodeID tailcfg.Stable } if nodes, _ := polc.GetStringArray(pkey.AllowedSuggestedExitNodes, nil); nodes != nil { return slices.Contains(nodes, string(exitNodeID)) - } return true // no policy configured; allow all exit nodes } @@ -8145,9 +8185,7 @@ func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcf return servicesList } -var ( - metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") -) +var metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") func (b *LocalBackend) stateEncrypted() opt.Bool { switch runtime.GOOS { diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 7d1c452f3..261d5c4c2 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -59,6 +59,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/set" @@ -455,7 +456,8 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) { } func newTestLocalBackend(t testing.TB) *LocalBackend { - return newTestLocalBackendWithSys(t, tsd.NewSystem()) + bus := eventbustest.NewBus(t) + return newTestLocalBackendWithSys(t, tsd.NewSystemWithBus(bus)) } // newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System. @@ -533,7 +535,6 @@ func TestZeroExitNodeViaLocalAPI(t *testing.T) { ExitNodeID: "", }, }, user) - if err != nil { t.Fatalf("enabling first exit node: %v", err) } @@ -543,7 +544,6 @@ func TestZeroExitNodeViaLocalAPI(t *testing.T) { if got, want := pv.InternalExitNodePrior(), tailcfg.StableNodeID(""); got != want { t.Fatalf("unexpected InternalExitNodePrior %q, want: %q", got, want) } - } func TestSetUseExitNodeEnabled(t *testing.T) { @@ -3619,7 +3619,8 @@ func TestPreferencePolicyInfo(t *testing.T) { prefs := defaultPrefs.AsStruct() pp.set(prefs, tt.initialValue) - sys := tsd.NewSystem() + bus := eventbustest.NewBus(t) + sys := tsd.NewSystemWithBus(bus) sys.PolicyClient.Set(polc) lb := newTestLocalBackendWithSys(t, sys) @@ -5786,7 +5787,8 @@ func TestNotificationTargetMatch(t *testing.T) { type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl newTestControlFn) *LocalBackend { - return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystem(), newControl) + bus := eventbustest.NewBus(t) + return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystemWithBus(bus), newControl) } func newLocalBackendWithSysAndTestControl(t *testing.T, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { @@ -5945,7 +5947,6 @@ func (w *notificationWatcher) watch(mask ipn.NotifyWatchOpt, wanted []wantedNoti return true }) - }() <-watchAddedCh } diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index 842b75c43..93ecd977f 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -35,6 +35,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/tkatype" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" "tailscale.com/util/set" ) @@ -49,6 +50,7 @@ func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni + bus := eventbustest.NewBus(t) k := key.NewMachine() opts := controlclient.Options{ @@ -61,6 +63,7 @@ func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto { NoiseTestClient: c, Observer: observerFunc(func(controlclient.Status) {}), Dialer: tsdial.NewDialer(netmon.NewStatic()), + Bus: bus, } cc, err := controlclient.NewNoStart(opts) diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index e2561cba9..86b56ab4b 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -33,6 +33,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/syspolicy/policyclient" @@ -240,11 +241,15 @@ func TestServeConfigForeground(t *testing.T) { err := b.SetServeConfig(&ipn.ServeConfig{ Foreground: map[string]*ipn.ServeConfig{ - session1: {TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {TCPForward: "http://localhost:3000"}}, + session1: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {TCPForward: "http://localhost:3000"}, + }, }, - session2: {TCP: map[uint16]*ipn.TCPPortHandler{ - 999: {TCPForward: "http://localhost:4000"}}, + session2: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 999: {TCPForward: "http://localhost:4000"}, + }, }, }, }, "") @@ -267,8 +272,10 @@ func TestServeConfigForeground(t *testing.T) { 5000: {TCPForward: "http://localhost:5000"}, }, Foreground: map[string]*ipn.ServeConfig{ - session2: {TCP: map[uint16]*ipn.TCPPortHandler{ - 999: {TCPForward: "http://localhost:4000"}}, + session2: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 999: {TCPForward: "http://localhost:4000"}, + }, }, }, }, "") @@ -491,7 +498,6 @@ func TestServeConfigServices(t *testing.T) { } }) } - } func TestServeConfigETag(t *testing.T) { @@ -659,6 +665,7 @@ func TestServeHTTPProxyPath(t *testing.T) { }) } } + func TestServeHTTPProxyHeaders(t *testing.T) { b := newTestBackend(t) @@ -859,7 +866,6 @@ func Test_reverseProxyConfiguration(t *testing.T) { wantsURL: mustCreateURL(t, "https://example3.com"), }, }) - } func mustCreateURL(t *testing.T, u string) url.URL { @@ -878,7 +884,8 @@ func newTestBackend(t *testing.T, opts ...any) *LocalBackend { logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ") } - sys := tsd.NewSystem() + bus := eventbustest.NewBus(t) + sys := tsd.NewSystemWithBus(bus) for _, o := range opts { switch v := o.(type) { @@ -952,13 +959,13 @@ func newTestBackend(t *testing.T, opts ...any) *LocalBackend { func TestServeFileOrDirectory(t *testing.T) { td := t.TempDir() writeFile := func(suffix, contents string) { - if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0600); err != nil { + if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0o600); err != nil { t.Fatal(err) } } writeFile("foo", "this is foo") writeFile("bar", "this is bar") - os.MkdirAll(filepath.Join(td, "subdir"), 0700) + os.MkdirAll(filepath.Join(td, "subdir"), 0o700) writeFile("subdir/file-a", "this is A") writeFile("subdir/file-b", "this is B") writeFile("subdir/file-c", "this is C") diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index 4097a3773..30538f2c8 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "math/rand/v2" "net/netip" "strings" "sync" @@ -39,6 +40,7 @@ import ( "tailscale.com/types/persist" "tailscale.com/types/preftype" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/wgengine" @@ -113,10 +115,11 @@ func (nt *notifyThrottler) drain(count int) []ipn.Notify { // in the controlclient.Client, so by controlling it, we can check that // the state machine works as expected. type mockControl struct { - tb testing.TB - logf logger.Logf - opts controlclient.Options - paused atomic.Bool + tb testing.TB + logf logger.Logf + opts controlclient.Options + paused atomic.Bool + controlClientID int64 mu sync.Mutex persist *persist.Persist @@ -127,12 +130,13 @@ type mockControl struct { func newClient(tb testing.TB, opts controlclient.Options) *mockControl { return &mockControl{ - tb: tb, - authBlocked: true, - logf: opts.Logf, - opts: opts, - shutdown: make(chan struct{}), - persist: opts.Persist.Clone(), + tb: tb, + authBlocked: true, + logf: opts.Logf, + opts: opts, + shutdown: make(chan struct{}), + persist: opts.Persist.Clone(), + controlClientID: rand.Int64(), } } @@ -287,6 +291,10 @@ func (cc *mockControl) UpdateEndpoints(endpoints []tailcfg.Endpoint) { cc.called("UpdateEndpoints") } +func (cc *mockControl) ClientID() int64 { + return cc.controlClientID +} + func (b *LocalBackend) nonInteractiveLoginForStateTest() { b.mu.Lock() if b.cc == nil { @@ -1507,7 +1515,8 @@ func newLocalBackendWithMockEngineAndControl(t *testing.T, enableLogging bool) ( dialer := &tsdial.Dialer{Logf: logf} dialer.SetNetMon(netmon.NewStatic()) - sys := tsd.NewSystem() + bus := eventbustest.NewBus(t) + sys := tsd.NewSystemWithBus(bus) sys.Set(dialer) sys.Set(dialer.NetMon()) diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 970f798d0..046eb744d 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -35,6 +35,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/slicesx" "tailscale.com/wgengine" ) @@ -158,7 +159,6 @@ func TestWhoIsArgTypes(t *testing.T) { t.Fatalf("backend called with %v; want %v", k, keyStr) } return match() - }, peerCaps: map[netip.Addr]tailcfg.PeerCapMap{ netip.MustParseAddr("100.101.102.103"): map[tailcfg.PeerCapability][]tailcfg.RawMessage{ @@ -336,7 +336,7 @@ func TestServeWatchIPNBus(t *testing.T) { func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend { var logf logger.Logf = logger.Discard - sys := tsd.NewSystem() + sys := tsd.NewSystemWithBus(eventbustest.NewBus(t)) store := new(mem.Store) sys.Set(store) eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) diff --git a/tsd/tsd.go b/tsd/tsd.go index bd333bd31..e4a512e4b 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -80,9 +80,17 @@ type System struct { // NewSystem constructs a new otherwise-empty [System] with a // freshly-constructed event bus populated. -func NewSystem() *System { +func NewSystem() *System { return NewSystemWithBus(eventbus.New()) } + +// NewSystemWithBus constructs a new otherwise-empty [System] with an +// eventbus provided by the caller. The provided bus must not be nil. +// This is mainly intended for testing; for production use call [NewBus]. +func NewSystemWithBus(bus *eventbus.Bus) *System { + if bus == nil { + panic("nil eventbus") + } sys := new(System) - sys.Set(eventbus.New()) + sys.Set(bus) return sys } diff --git a/util/eventbus/eventbustest/eventbustest.go b/util/eventbus/eventbustest/eventbustest.go index 98536ae0a..b7375adc4 100644 --- a/util/eventbus/eventbustest/eventbustest.go +++ b/util/eventbus/eventbustest/eventbustest.go @@ -15,7 +15,7 @@ import ( // NewBus constructs an [eventbus.Bus] that will be shut automatically when // its controlling test ends. -func NewBus(t *testing.T) *eventbus.Bus { +func NewBus(t testing.TB) *eventbus.Bus { bus := eventbus.New() t.Cleanup(bus.Close) return bus