From 4941cd7c73a69a867521c42b58b095a883ebbc2e Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 10 Apr 2025 20:24:58 -0500 Subject: [PATCH] cmd/tailscaled,ipn/{auditlog,desktop,ipnext,ipnlocal},tsd: extract LocalBackend extension interfaces and implementation In this PR, we refactor the LocalBackend extension system, moving from direct callbacks to a more organized extension host model. Specifically, we: - Extract interface and callback types used by packages extending LocalBackend functionality into a new ipn/ipnext package. - Define ipnext.Host as a new interface that bridges extensions with LocalBackend. It enables extensions to register callbacks and interact with LocalBackend in a concurrency-safe, well-defined, and controlled way. - Move existing callback registration and invocation code from ipnlocal.LocalBackend into a new type called ipnlocal.ExtensionHost, implementing ipnext.Host. - Improve docs for existing types and methods while adding docs for the new interfaces. - Add test coverage for both the extracted and the new code. - Remove ipn/desktop.SessionManager from tsd.System since ipn/desktop is now self-contained. - Update existing extensions (e.g., ipn/auditlog and ipn/desktop) to use the new interfaces where appropriate. We're not introducing new callback and hook types (e.g., for ipn.Prefs changes) just yet, nor are we enhancing current callbacks, such as by improving conflict resolution when more than one extension tries to influence profile selection via a background profile resolver. These further improvements will be submitted separately. Updates #12614 Updates tailscale/corp#27645 Updates tailscale/corp#26435 Updates tailscale/corp#18342 Signed-off-by: Nick Khyl --- cmd/k8s-operator/depaware.txt | 2 +- cmd/tailscaled/depaware.txt | 3 +- cmd/tailscaled/tailscaled_windows.go | 9 +- ipn/auditlog/extension.go | 39 +- .../extension.go} | 106 +- ipn/ipnext/ipnext.go | 284 ++++ ipn/ipnlocal/extension_host.go | 537 ++++++++ ipn/ipnlocal/extension_host_test.go | 1139 +++++++++++++++++ ipn/ipnlocal/local.go | 283 +--- ipn/ipnlocal/profiles.go | 4 + tsd/tsd.go | 4 - 11 files changed, 2079 insertions(+), 331 deletions(-) rename ipn/{ipnlocal/desktop_sessions.go => desktop/extension.go} (62%) create mode 100644 ipn/ipnext/ipnext.go create mode 100644 ipn/ipnlocal/extension_host.go create mode 100644 ipn/ipnlocal/extension_host_test.go diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 7fd4c4b21..416265188 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -815,8 +815,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/internal/noiseconn from tailscale.com/control/controlclient tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ - πŸ’£ tailscale.com/ipn/desktop from tailscale.com/ipn/ipnlocal+ πŸ’£ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/ipn/localapi from tailscale.com/tsnet+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 394056295..9cdebbae1 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -273,8 +273,9 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn from tailscale.com/client/local+ W tailscale.com/ipn/auditlog from tailscale.com/cmd/tailscaled tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ - πŸ’£ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled+ + W πŸ’£ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled πŸ’£ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/auditlog+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled tailscale.com/ipn/ipnstate from tailscale.com/client/local+ diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index dfe53ef61..54ff2af14 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -45,7 +45,7 @@ import ( "tailscale.com/drive/driveimpl" "tailscale.com/envknob" _ "tailscale.com/ipn/auditlog" - "tailscale.com/ipn/desktop" + _ "tailscale.com/ipn/desktop" "tailscale.com/logpolicy" "tailscale.com/logtail/backoff" "tailscale.com/net/dns" @@ -337,13 +337,6 @@ func beWindowsSubprocess() bool { sys.Set(driveimpl.NewFileSystemForRemote(log.Printf)) - if sessionManager, err := desktop.NewSessionManager(log.Printf); err == nil { - sys.Set(sessionManager) - } else { - // Errors creating the session manager are unexpected, but not fatal. - log.Printf("[unexpected]: error creating a desktop session manager: %v", err) - } - publicLogID, _ := logid.ParsePublicID(logID) err = startIPNServer(ctx, log.Printf, publicLogID, sys) if err != nil { diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go index 8be7dfb66..6bbe37398 100644 --- a/ipn/auditlog/extension.go +++ b/ipn/auditlog/extension.go @@ -14,19 +14,23 @@ import ( "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" - "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/lazy" "tailscale.com/types/logger" ) +// featureName is the name of the feature implemented by this package. +// It is also the the [extension] name and the log prefix. +const featureName = "auditlog" + func init() { - feature.Register("auditlog") - ipnlocal.RegisterExtension("auditlog", newExtension) + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newExtension) } -// extension is an [ipnlocal.Extension] managing audit logging +// extension is an [ipnext.Extension] managing audit logging // on platforms that import this package. // As of 2025-03-27, that's only Windows and macOS. type extension struct { @@ -48,19 +52,24 @@ type extension struct { logger *Logger } -// newExtension is an [ipnlocal.NewExtensionFn] that creates a new audit log extension. -// It is registered with [ipnlocal.RegisterExtension] if the package is imported. -func newExtension(logf logger.Logf, _ *tsd.System) (ipnlocal.Extension, error) { - return &extension{logf: logger.WithPrefix(logf, "auditlog: ")}, nil +// newExtension is an [ipnext.NewExtensionFn] that creates a new audit log extension. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { + return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil } -// Init implements [ipnlocal.Extension] by registering callbacks and providers +// Name implements [ipnext.Extension]. +func (e *extension) Name() string { + return featureName +} + +// Init implements [ipnext.Extension] by registering callbacks and providers // for the duration of the extension's lifetime. -func (e *extension) Init(lb *ipnlocal.LocalBackend) error { +func (e *extension) Init(h ipnext.Host) error { e.cleanup = []func(){ - lb.RegisterControlClientCallback(e.controlClientChanged), - lb.RegisterProfileChangeCallback(e.profileChanged, false), - lb.RegisterAuditLogProvider(e.getCurrentLogger), + h.RegisterControlClientCallback(e.controlClientChanged), + h.Profiles().RegisterProfileChangeCallback(e.profileChanged), + h.RegisterAuditLogProvider(e.getCurrentLogger), } return nil } @@ -165,8 +174,8 @@ func noCurrentLogger(_ tailcfg.ClientAuditAction, _ string) error { return errNoLogger } -// getCurrentLogger is an [ipnlocal.AuditLogProvider] registered with [ipnlocal.LocalBackend]. -// It is called when [ipnlocal.LocalBackend] needs to audit an action. +// getCurrentLogger is an [ipnext.AuditLogProvider] registered with [ipnext.Host]. +// It is called when [ipnlocal.LocalBackend] or an extension needs to audit an action. // // It returns a function that enqueues the audit log for the current profile, // or [noCurrentLogger] if the logger is unavailable. diff --git a/ipn/ipnlocal/desktop_sessions.go b/ipn/desktop/extension.go similarity index 62% rename from ipn/ipnlocal/desktop_sessions.go rename to ipn/desktop/extension.go index 29cb196c7..86ae96f5b 100644 --- a/ipn/ipnlocal/desktop_sessions.go +++ b/ipn/desktop/extension.go @@ -7,29 +7,32 @@ //go:build windows && !ts_omit_desktop_sessions -package ipnlocal +package desktop import ( "cmp" - "errors" "fmt" "sync" "tailscale.com/feature" "tailscale.com/ipn" - "tailscale.com/ipn/desktop" + "tailscale.com/ipn/ipnext" "tailscale.com/tsd" "tailscale.com/types/logger" "tailscale.com/util/syspolicy" ) +// featureName is the name of the feature implemented by this package. +// It is also the the [desktopSessionsExt] name and the log prefix. +const featureName = "desktop-sessions" + func init() { - feature.Register("desktop-sessions") - RegisterExtension("desktop-sessions", newDesktopSessionsExt) + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newDesktopSessionsExt) } -// desktopSessionsExt implements [Extension]. -var _ Extension = (*desktopSessionsExt)(nil) +// [desktopSessionsExt] implements [ipnext.Extension]. +var _ ipnext.Extension = (*desktopSessionsExt)(nil) // desktopSessionsExt extends [LocalBackend] with desktop session management. // It keeps Tailscale running in the background if Always-On mode is enabled, @@ -37,32 +40,41 @@ var _ Extension = (*desktopSessionsExt)(nil) // locks their screen, or disconnects a remote session. type desktopSessionsExt struct { logf logger.Logf - sm desktop.SessionManager + sm SessionManager - *LocalBackend // or nil, until Init is called - cleanup []func() // cleanup functions to call on shutdown + host ipnext.Host // or nil, until Init is called + cleanup []func() // cleanup functions to call on shutdown // mu protects all following fields. - // When both mu and [LocalBackend.mu] need to be taken, - // [LocalBackend.mu] must be taken before mu. - mu sync.Mutex - id2sess map[desktop.SessionID]*desktop.Session + mu sync.Mutex + sessByID map[SessionID]*Session } // newDesktopSessionsExt returns a new [desktopSessionsExt], -// or an error if [desktop.SessionManager] is not available. -func newDesktopSessionsExt(logf logger.Logf, sys *tsd.System) (Extension, error) { - sm, ok := sys.SessionManager.GetOK() - if !ok { - return nil, errors.New("session manager is not available") +// or an error if a [SessionManager] cannot be created. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newDesktopSessionsExt(logf logger.Logf, sys *tsd.System) (ipnext.Extension, error) { + logf = logger.WithPrefix(logf, featureName+": ") + sm, err := NewSessionManager(logf) + if err != nil { + return nil, fmt.Errorf("%w: session manager is not available: %w", ipnext.SkipExtension, err) } - return &desktopSessionsExt{logf: logf, sm: sm, id2sess: make(map[desktop.SessionID]*desktop.Session)}, nil + return &desktopSessionsExt{ + logf: logf, + sm: sm, + sessByID: make(map[SessionID]*Session), + }, nil } -// Init implements [localBackendExtension]. -func (e *desktopSessionsExt) Init(lb *LocalBackend) (err error) { - e.LocalBackend = lb - unregisterResolver := lb.RegisterBackgroundProfileResolver(e.getBackgroundProfile) +// Name implements [ipnext.Extension]. +func (e *desktopSessionsExt) Name() string { + return featureName +} + +// Init implements [ipnext.Extension]. +func (e *desktopSessionsExt) Init(host ipnext.Host) (err error) { + e.host = host + unregisterResolver := host.Profiles().RegisterBackgroundProfileResolver(e.getBackgroundProfile) unregisterSessionCb, err := e.sm.RegisterStateCallback(e.updateDesktopSessionState) if err != nil { unregisterResolver() @@ -72,30 +84,30 @@ func (e *desktopSessionsExt) Init(lb *LocalBackend) (err error) { return nil } -// updateDesktopSessionState is a [desktop.SessionStateCallback] -// invoked by [desktop.SessionManager] once for each existing session +// updateDesktopSessionState is a [SessionStateCallback] +// invoked by [SessionManager] once for each existing session // and whenever the session state changes. It updates the session map // and switches to the best profile if necessary. -func (e *desktopSessionsExt) updateDesktopSessionState(session *desktop.Session) { +func (e *desktopSessionsExt) updateDesktopSessionState(session *Session) { e.mu.Lock() - if session.Status != desktop.ClosedSession { - e.id2sess[session.ID] = session + if session.Status != ClosedSession { + e.sessByID[session.ID] = session } else { - delete(e.id2sess, session.ID) + delete(e.sessByID, session.ID) } e.mu.Unlock() var action string switch session.Status { - case desktop.ForegroundSession: + case ForegroundSession: // The user has either signed in or unlocked their session. // For remote sessions, this may also mean the user has connected. // The distinction isn't important for our purposes, // so let's always say "signed in". action = "signed in to" - case desktop.BackgroundSession: + case BackgroundSession: action = "locked" - case desktop.ClosedSession: + case ClosedSession: action = "signed out from" default: panic("unreachable") @@ -104,10 +116,10 @@ func (e *desktopSessionsExt) updateDesktopSessionState(session *desktop.Session) userIdentifier := cmp.Or(maybeUsername, string(session.User.UserID()), "user") reason := fmt.Sprintf("%s %s session %v", userIdentifier, action, session.ID) - e.SwitchToBestProfile(reason) + e.host.Profiles().SwitchToBestProfileAsync(reason) } -// getBackgroundProfile is a [profileResolver] that works as follows: +// getBackgroundProfile is a [ipnext.ProfileResolver] that works as follows: // // If Always-On mode is disabled, it returns no profile. // @@ -121,9 +133,7 @@ func (e *desktopSessionsExt) updateDesktopSessionState(session *desktop.Session) // disconnects without signing out. // // In all other cases, it returns no profile. -// -// It is called with [LocalBackend.mu] locked. -func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { +func (e *desktopSessionsExt) getBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { e.mu.Lock() defer e.mu.Unlock() @@ -135,16 +145,16 @@ func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { isCurrentProfileOwnerSignedIn := false var foregroundUIDs []ipn.WindowsUserID - for _, s := range e.id2sess { + for _, s := range e.sessByID { switch uid := s.User.UserID(); uid { - case e.pm.CurrentProfile().LocalUserID(): + case profiles.CurrentProfile().LocalUserID(): isCurrentProfileOwnerSignedIn = true - if s.Status == desktop.ForegroundSession { + if s.Status == ForegroundSession { // Keep the current profile if the user has a foreground session. - return e.pm.CurrentProfile() + return profiles.CurrentProfile() } default: - if s.Status == desktop.ForegroundSession { + if s.Status == ForegroundSession { foregroundUIDs = append(foregroundUIDs, uid) } } @@ -154,7 +164,7 @@ func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { // or if the current profile's owner has no foreground session, switch to the default profile // of the first user with a foreground session, if any. for _, uid := range foregroundUIDs { - if profile := e.pm.DefaultUserProfile(uid); profile.ID() != "" { + if profile := profiles.DefaultUserProfile(uid); profile.ID() != "" { return profile } } @@ -163,19 +173,19 @@ func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { // keep the current profile even if the session is not in the foreground, // such as when the screen is locked or a remote session is disconnected. if len(foregroundUIDs) == 0 && isCurrentProfileOwnerSignedIn { - return e.pm.CurrentProfile() + return profiles.CurrentProfile() } // Otherwise, there's no background profile. return ipn.LoginProfileView{} } -// Shutdown implements [localBackendExtension]. +// Shutdown implements [ipnext.Extension]. func (e *desktopSessionsExt) Shutdown() error { for _, f := range e.cleanup { f() } e.cleanup = nil - e.LocalBackend = nil - return nil + e.host = nil + return e.sm.Close() } diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go new file mode 100644 index 000000000..af870b53a --- /dev/null +++ b/ipn/ipnext/ipnext.go @@ -0,0 +1,284 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipnext defines types and interfaces used for extending the core LocalBackend +// functionality with additional features and services. +package ipnext + +import ( + "errors" + "fmt" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/mak" +) + +// Extension augments LocalBackend with additional functionality. +// +// An extension uses the provided [Host] to register callbacks +// and interact with the backend in a controlled, well-defined +// and thread-safe manner. +// +// Extensions are registered using [RegisterExtension]. +// +// They must be safe for concurrent use. +type Extension interface { + // Name is a unique name of the extension. + // It must be the same as the name used to register the extension. + Name() string + + // Init is called to initialize the extension when LocalBackend is initialized. + // If the extension cannot be initialized, it must return an error, + // and its Shutdown method will not be called on the host's shutdown. + // Returned errors are not fatal; they are used for logging. + // A [SkipExtension] error indicates an intentional decision rather than a failure. + Init(Host) error + + // Shutdown is called when LocalBackend is shutting down, + // provided the extension was initialized. For multiple extensions, + // Shutdown is called in the reverse order of Init. + // Returned errors are not fatal; they are used for logging. + Shutdown() error +} + +// NewExtensionFn is a function that instantiates an [Extension]. +// If a registered extension cannot be instantiated, the function must return an error. +// If the extension should be skipped at runtime, it must return either [SkipExtension] +// or a wrapped [SkipExtension]. Any other error returned is fatal and will prevent +// the LocalBackend from starting. +type NewExtensionFn func(logger.Logf, *tsd.System) (Extension, error) + +// SkipExtension is an error returned by [NewExtensionFn] to indicate that the extension +// should be skipped rather than prevent the LocalBackend from starting. +// +// Skipping an extension should be reserved for cases where the extension is not supported +// on the current platform or configuration, or depends on a feature that is not available, +// or otherwise should be disabled permanently rather than temporarily. +// +// Specifically, it must not be returned if the extension is not required right now +// based on user preferences, policy settings, the current tailnet, or other factors +// that may change throughout the LocalBackend's lifetime. +var SkipExtension = errors.New("skipping extension") + +// Definition describes a registered [Extension]. +type Definition struct { + name string // name under which the extension is registered + newFn NewExtensionFn // function that creates a new instance of the extension +} + +// Name returns the name of the extension. +func (d *Definition) Name() string { + return d.name +} + +// MakeExtension instantiates the extension. +func (d *Definition) MakeExtension(logf logger.Logf, sys *tsd.System) (Extension, error) { + ext, err := d.newFn(logf, sys) + if err != nil { + return nil, err + } + if ext.Name() != d.name { + return nil, fmt.Errorf("extension name mismatch: registered %q; actual %q", d.name, ext.Name()) + } + return ext, nil +} + +// extensionsByName is a map of registered extensions, +// where the key is the name of the extension. +var extensionsByName map[string]*Definition + +// extensionsByOrder is a slice of registered extensions, +// in the order they were registered. +var extensionsByOrder []*Definition + +// RegisterExtension registers a function that instantiates an [Extension]. +// The name must be the same as returned by the extension's [Extension.Name]. +// +// It must be called on the main goroutine before LocalBackend is created, +// such as from an init function of the package implementing the extension. +// +// It panics if newExt is nil or if an extension with the same name +// has already been registered. +func RegisterExtension(name string, newExt NewExtensionFn) { + if newExt == nil { + panic(fmt.Sprintf("ipnext: newExt is nil: %q", name)) + } + if _, ok := extensionsByName[name]; ok { + panic(fmt.Sprintf("ipnext: duplicate extensions: %q", name)) + } + ext := &Definition{name, newExt} + mak.Set(&extensionsByName, name, ext) + extensionsByOrder = append(extensionsByOrder, ext) +} + +// Extensions returns a read-only view of the extensions +// registered via [RegisterExtension]. It preserves the order +// in which the extensions were registered. +func Extensions() views.Slice[*Definition] { + return views.SliceOf(extensionsByOrder) +} + +// DefinitionForTest returns a [Definition] for the specified [Extension]. +// It is primarily used for testing where the test code needs to instantiate +// and use an extension without registering it. +func DefinitionForTest(ext Extension) *Definition { + return &Definition{ + name: ext.Name(), + newFn: func(logger.Logf, *tsd.System) (Extension, error) { return ext, nil }, + } +} + +// DefinitionWithErrForTest returns a [Definition] with the specified extension name +// whose [Definition.MakeExtension] method returns the specified error. +// It is used for testing. +func DefinitionWithErrForTest(name string, err error) *Definition { + return &Definition{ + name: name, + newFn: func(logger.Logf, *tsd.System) (Extension, error) { return nil, err }, + } +} + +// Host is the API surface used by [Extension]s to interact with LocalBackend +// in a controlled manner. +// +// Extensions can register callbacks, request information, or perform actions +// via the [Host] interface. +// +// Typically, the host invokes registered callbacks when one of the following occurs: +// - LocalBackend notifies it of an event or state change that may be +// of interest to extensions, such as when switching [ipn.LoginProfile]. +// - LocalBackend needs to consult extensions for information, for example, +// determining the most appropriate profile for the current state of the system. +// - LocalBackend performs an extensible action, such as logging an auditable event, +// and delegates its execution to the extension. +// +// The callbacks are invoked synchronously, and the LocalBackend's state +// remains unchanged while callbacks execute. +// +// In contrast, actions initiated by extensions are generally asynchronous, +// as indicated by the "Async" suffix in their names. +// Performing actions may result in callbacks being invoked as described above. +// +// To prevent conflicts between extensions competing for shared state, +// such as the current profile or prefs, the host must not expose methods +// that directly modify that state. For example, instead of allowing extensions +// to switch profiles at-will, the host's [ProfileServices] provides a method +// to switch to the "best" profile. The host can then consult extensions +// to determine the appropriate profile to use and resolve any conflicts +// in a controlled manner. +// +// A host must be safe for concurrent use. +type Host interface { + // Profiles returns the host's [ProfileServices]. + Profiles() ProfileServices + + // RegisterAuditLogProvider registers an audit log provider, + // which returns a function to be called when an auditable action + // is about to be performed. The returned function unregisters the provider. + // It is a runtime error to register a nil provider. + RegisterAuditLogProvider(AuditLogProvider) (unregister func()) + + // AuditLogger returns a function that calls all currently registered audit loggers. + // The function fails if any logger returns an error, indicating that the action + // cannot be logged and must not be performed. + // + // The returned function captures the current state (e.g., the current profile) at + // the time of the call and must not be persisted. + AuditLogger() ipnauth.AuditLogFunc + + // RegisterControlClientCallback registers a function to be called every time a new + // control client is created. The returned function unregisters the callback. + // It is a runtime error to register a nil callback. + RegisterControlClientCallback(NewControlClientCallback) (unregister func()) +} + +// ProfileServices provides access to the [Host]'s profile management services, +// such as switching profiles and registering profile change callbacks. +type ProfileServices interface { + // SwitchToBestProfileAsync asynchronously selects the best profile to use + // and switches to it, unless it is already the current profile. + // + // If an extension needs to know when a profile switch occurs, + // it must use [ProfileServices.RegisterProfileChangeCallback] + // to register a [ProfileChangeCallback]. + // + // The reason indicates why the profile is being switched, such as due + // to a client connecting or disconnecting or a change in the desktop + // session state. It is used for logging. + SwitchToBestProfileAsync(reason string) + + // RegisterBackgroundProfileResolver registers a function to be used when + // resolving the background profile. The returned function unregisters the resolver. + // It is a runtime error to register a nil resolver. + // + // TODO(nickkhyl): allow specifying some kind of priority/altitude for the resolver. + // TODO(nickkhyl): make it a "profile resolver" instead of a "background profile resolver". + // The concepts of the "current user", "foreground profile" and "background profile" + // only exist on Windows, and we're moving away from them anyway. + RegisterBackgroundProfileResolver(ProfileResolver) (unregister func()) + + // RegisterProfileChangeCallback registers a function to be called when the current + // [ipn.LoginProfile] changes. The returned function unregisters the callback. + // It is a runtime error to register a nil callback. + RegisterProfileChangeCallback(ProfileChangeCallback) (unregister func()) +} + +// ProfileStore provides read-only access to available login profiles and their preferences. +// It is not safe for concurrent use and can only be used from the callback it is passed to. +type ProfileStore interface { + // CurrentUserID returns the current user ID. It is only non-empty on + // Windows where we have a multi-user system. + // + // Deprecated: this method exists for compatibility with the current (as of 2024-08-27) + // permission model and will be removed as we progress on tailscale/corp#18342. + CurrentUserID() ipn.WindowsUserID + + // CurrentProfile returns a read-only [ipn.LoginProfileView] of the current profile. + // The returned view is always valid, but the profile's [ipn.LoginProfileView.ID] + // returns "" if the profile is new and has not been persisted yet. + CurrentProfile() ipn.LoginProfileView + + // CurrentPrefs returns a read-only view of the current prefs. + // The returned view is always valid. + CurrentPrefs() ipn.PrefsView + + // DefaultUserProfile returns a read-only view of the default (last used) profile for the specified user. + // It returns a read-only view of a new, non-persisted profile if the specified user does not have a default profile. + DefaultUserProfile(uid ipn.WindowsUserID) ipn.LoginProfileView +} + +// AuditLogProvider is a function that returns an [ipnauth.AuditLogFunc] for +// logging auditable actions. +type AuditLogProvider func() ipnauth.AuditLogFunc + +// ProfileResolver is a function that returns a read-only view of a login profile. +// An invalid view indicates no profile. A valid profile view with an empty [ipn.ProfileID] +// indicates that the profile is new and has not been persisted yet. +// The provided [ProfileStore] can only be used for the duration of the callback. +type ProfileResolver func(ProfileStore) ipn.LoginProfileView + +// ProfileChangeCallback is a function to be called when the current login profile changes. +// The sameNode parameter indicates whether the profile represents the same node as before, +// such as when only the profile metadata is updated but the node ID remains the same, +// or when a new profile is persisted and assigned an [ipn.ProfileID] for the first time. +// The subscribers can use this information to decide whether to reset their state. +// +// The profile and prefs are always valid, but the profile's [ipn.LoginProfileView.ID] +// returns "" if the profile is new and has not been persisted yet. +type ProfileChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) + +// NewControlClientCallback is a function to be called when a new [controlclient.Client] +// is created and before it is first used. The login profile and prefs represent +// the profile for which the cc is created and are always valid; however, the +// profile's [ipn.LoginProfileView.ID] returns "" if the profile is new +// and has not been persisted yet. If the [controlclient.Client] is created +// due to a profile switch, any registered [ProfileChangeCallback]s are called first. +// +// It returns a function to be called when the cc is being shut down, +// or nil if no cleanup is needed. +type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView, ipn.PrefsView) (cleanup func()) diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go new file mode 100644 index 000000000..4a617ed72 --- /dev/null +++ b/ipn/ipnlocal/extension_host.go @@ -0,0 +1,537 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "errors" + "fmt" + "iter" + "maps" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/util/execqueue" + "tailscale.com/util/set" + "tailscale.com/util/testenv" +) + +// ExtensionHost is a bridge between the [LocalBackend] and the registered [ipnext.Extension]s. +// It implements [ipnext.Host] and is safe for concurrent use. +// +// A nil pointer to [ExtensionHost] is a valid, no-op extension host which is primarily used in tests +// that instantiate [LocalBackend] directly without using [NewExtensionHost]. +// +// The [LocalBackend] is not required to hold its mutex when calling the host's methods, +// but it typically does so either to prevent changes to its state (for example, the current profile) +// while callbacks are executing, or because it calls the host's methods as part of a larger operation +// that requires the mutex to be held. +// +// Extensions might invoke the host's methods either from callbacks triggered by the [LocalBackend], +// or in a response to external events. Some methods can be called by both the extensions and the backend. +// +// As a general rule, the host cannot assume anything about the current state of the [LocalBackend]'s +// internal mutex on entry to its methods, and therefore cannot safely call [LocalBackend] methods directly. +// +// The following are typical and supported patterns: +// - LocalBackend notifies the host about an event, such as a change in the current profile. +// The host invokes callbacks registered by Extensions, forwarding the event arguments to them. +// If necessary, the host can also update its own state for future use. +// - LocalBackend requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the [ipn.LoginProfile] to use when no GUI/CLI client is connected. Typically, [LocalBackend] +// provides the required context to the host, and the host returns the result to [LocalBackend] +// after forwarding the request to the extensions. +// - Extension invokes the host's method to perform an action, such as switching to the "best" profile +// in response to a change in the device's state. Since the host does not know whether the [LocalBackend]'s +// internal mutex is held, it cannot invoke any methods on the [LocalBackend] directly and must instead +// do so asynchronously, such as by using [ExtensionHost.enqueueBackendOperation]. +// - Extension requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the current [ipn.LoginProfile]. Since the host cannot invoke any methods on the [LocalBackend] directly, +// it should maintain its own view of the current state, updating it when the [LocalBackend] notifies it +// about a change or event. +// +// To safeguard against adopting incorrect or risky patterns, the host does not store [LocalBackend] in its fields +// and instead provides [ExtensionHost.enqueueBackendOperation]. Additionally, to make it easier to test extensions +// and to further reduce the risk of accessing unexported methods or fields of [LocalBackend], the host interacts +// with it via the [Backend] interface. +type ExtensionHost struct { + logf logger.Logf // prefixed with "ipnext:" + + // allExtensions holds the extensions in the order they were registered, + // including those that have not yet attempted initialization or have failed to initialize. + allExtensions []ipnext.Extension + + // initOnce is used to ensure that the extensions are initialized only once, + // even if [extensionHost.Init] is called multiple times. + initOnce sync.Once + // shutdownOnce is like initOnce, but for [ExtensionHost.Shutdown]. + shutdownOnce sync.Once + + // workQueue maintains execution order for asynchronous operations requested by extensions. + // It is always an [execqueue.ExecQueue] except in some tests. + workQueue execQueue + // doEnqueueBackendOperation adds an asynchronous [LocalBackend] operation to the workQueue. + doEnqueueBackendOperation func(func(Backend)) + + // mu protects the following fields. + // It must not be held when calling [LocalBackend] methods + // or when invoking callbacks registered by extensions. + mu sync.Mutex + // initialized is whether the host and extensions have been fully initialized. + initialized atomic.Bool + // activeExtensions is a subset of allExtensions that have been initialized and are ready to use. + activeExtensions []ipnext.Extension + // extensionsByName are the activeExtensions indexed by their names. + extensionsByName map[string]ipnext.Extension + // postInitWorkQueue is a queue of functions to be executed + // by the workQueue after all extensions have been initialized. + postInitWorkQueue []func(Backend) + + // auditLoggers are registered [AuditLogProvider]s. + // Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action + // is about to be performed. If an audit logger returns an error, the action is denied. + auditLoggers set.HandleSet[ipnext.AuditLogProvider] + // backgroundProfileResolvers are registered background profile resolvers. + // They're used to determine the profile to use when no GUI/CLI client is connected. + backgroundProfileResolvers set.HandleSet[ipnext.ProfileResolver] + // newControlClientCbs are the functions to be called when a new control client is created. + newControlClientCbs set.HandleSet[ipnext.NewControlClientCallback] + // profileChangeCbs are the callbacks to be invoked when the current login profile changes, + // either because of a profile switch, or because the profile information was updated + // by [LocalBackend.SetControlClientStatus], including when the profile is first populated + // and persisted. + profileChangeCbs set.HandleSet[ipnext.ProfileChangeCallback] +} + +// Backend is a subset of [LocalBackend] methods that are used by [ExtensionHost]. +// It is primarily used for testing. +type Backend interface { + // SwitchToBestProfile switches to the best profile for the current state of the system. + // The reason indicates why the profile is being switched. + SwitchToBestProfile(reason string) +} + +// NewExtensionHost returns a new [ExtensionHost] which manages registered extensions for the given backend. +// The extensions are instantiated, but are not initialized until [ExtensionHost.Init] is called. +// It returns an error if instantiating any extension fails. +// +// If overrideExts is non-nil, the registered extensions are ignored and the provided extensions are used instead. +// Overriding extensions is primarily used for testing. +func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) { + host := &ExtensionHost{ + logf: logger.WithPrefix(logf, "ipnext: "), + workQueue: &execqueue.ExecQueue{}, + } + + // All operations on the backend must be executed asynchronously by the work queue. + // DO NOT retain a direct reference to the backend in the host. + // See the docstring for [ExtensionHost] for more details. + host.doEnqueueBackendOperation = func(f func(Backend)) { + if f == nil { + panic("nil backend operation") + } + host.workQueue.Add(func() { f(b) }) + } + + var numExts int + var exts iter.Seq2[int, *ipnext.Definition] + if overrideExts == nil { + // Use registered extensions. + exts = ipnext.Extensions().All() + numExts = ipnext.Extensions().Len() + } else { + // Use the provided, potentially empty, overrideExts + // instead of the registered ones. + exts = slices.All(overrideExts) + numExts = len(overrideExts) + } + + host.allExtensions = make([]ipnext.Extension, 0, numExts) + for _, d := range exts { + ext, err := d.MakeExtension(logf, sys) + if errors.Is(err, ipnext.SkipExtension) { + // The extension wants to be skipped. + host.logf("%q: %v", d.Name(), err) + continue + } else if err != nil { + return nil, fmt.Errorf("failed to create %q extension: %v", d.Name(), err) + } + host.allExtensions = append(host.allExtensions, ext) + } + return host, nil +} + +// Init initializes the host and the extensions it manages. +func (h *ExtensionHost) Init() { + if h != nil { + h.initOnce.Do(h.init) + } +} + +func (h *ExtensionHost) init() { + // Initialize the extensions in the order they were registered. + h.mu.Lock() + h.activeExtensions = make([]ipnext.Extension, 0, len(h.allExtensions)) + h.extensionsByName = make(map[string]ipnext.Extension, len(h.allExtensions)) + h.mu.Unlock() + for _, ext := range h.allExtensions { + // Do not hold the lock while calling [ipnext.Extension.Init]. + // Extensions call back into the host to register their callbacks, + // and that would cause a deadlock if the h.mu is already held. + if err := ext.Init(h); err != nil { + // As per the [ipnext.Extension] interface, failures to initialize + // an extension are never fatal. The extension is simply skipped. + // + // But we handle [ipnext.SkipExtension] differently for nicer logging + // if the extension wants to be skipped and not actually failing. + if errors.Is(err, ipnext.SkipExtension) { + h.logf("%q: %v", ext.Name(), err) + } else { + h.logf("%q init failed: %v", ext.Name(), err) + } + continue + } + // Update the initialized extensions lists as soon as the extension is initialized. + // We'd like to make them visible to other extensions that are initialized later. + h.mu.Lock() + h.activeExtensions = append(h.activeExtensions, ext) + h.extensionsByName[ext.Name()] = ext + h.mu.Unlock() + } + + // Report active extensions to the log. + // TODO(nickkhyl): update client metrics to include the active/failed/skipped extensions. + h.mu.Lock() + extensionNames := slices.Collect(maps.Keys(h.extensionsByName)) + h.mu.Unlock() + h.logf("active extensions: %v", strings.Join(extensionNames, ", ")) + + // Additional init steps that need to be performed after all extensions have been initialized. + h.mu.Lock() + wq := h.postInitWorkQueue + h.postInitWorkQueue = nil + h.initialized.Store(true) + h.mu.Unlock() + + // Enqueue work that was requested and deferred during initialization. + h.doEnqueueBackendOperation(func(b Backend) { + for _, f := range wq { + f(b) + } + }) + +} + +// Profiles implements [ipnext.Host]. +func (h *ExtensionHost) Profiles() ipnext.ProfileServices { + // Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// SwitchToBestProfileAsync implements [ipnext.ProfileServices]. +func (h *ExtensionHost) SwitchToBestProfileAsync(reason string) { + if h == nil { + return + } + h.enqueueBackendOperation(func(b Backend) { + b.SwitchToBestProfile(reason) + }) +} + +// RegisterProfileChangeCallback implements [ipnext.ProfileServices]. +func (h *ExtensionHost) RegisterProfileChangeCallback(cb ipnext.ProfileChangeCallback) (unregister func()) { + if h == nil { + return func() {} + } + if cb == nil { + panic("nil profile change callback") + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.profileChangeCbs.Add(cb) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.profileChangeCbs, handle) + } +} + +// NotifyProfileChange invokes registered profile change callbacks. +// It strips private keys from the [ipn.Prefs] before passing it to the callbacks. +func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + if h == nil { + return + } + h.mu.Lock() + cbs := collectValues(h.profileChangeCbs) + h.mu.Unlock() + if cbs != nil { + // Strip private keys from the prefs before passing it to the callbacks. + // Extensions should not need it (unless proven otherwise in the future), + // and this is a good way to ensure that they won't accidentally leak them. + prefs = stripKeysFromPrefs(prefs) + for _, cb := range cbs { + cb(profile, prefs, sameNode) + } + } +} + +// RegisterBackgroundProfileResolver implements [ipnext.ProfileServices]. +func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.ProfileResolver) (unregister func()) { + if h == nil { + return func() {} + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.backgroundProfileResolvers.Add(resolver) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.backgroundProfileResolvers, handle) + } +} + +// DetermineBackgroundProfile returns a read-only view of the profile +// used when no GUI/CLI client is connected, using background profile +// resolvers registered by extensions. +// +// It returns an invalid view if Tailscale should not run in the background +// and instead disconnect until a GUI/CLI client connects. +// +// As of 2025-02-07, this is only used on Windows. +func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { + if h == nil { + return ipn.LoginProfileView{} + } + // TODO(nickkhyl): check if the returned profile is allowed on the device, + // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + + // Attempt to resolve the background profile using the registered + // background profile resolvers (e.g., [ipn/desktop.desktopSessionsExt] on Windows). + h.mu.Lock() + resolvers := collectValues(h.backgroundProfileResolvers) + h.mu.Unlock() + for _, resolver := range resolvers { + if profile := resolver(profiles); profile.Valid() { + return profile + } + } + + // Otherwise, switch to an empty profile and disconnect Tailscale + // until a GUI or CLI client connects. + return ipn.LoginProfileView{} +} + +// RegisterControlClientCallback implements [ipnext.Host]. +func (h *ExtensionHost) RegisterControlClientCallback(cb ipnext.NewControlClientCallback) (unregister func()) { + if h == nil { + return func() {} + } + if cb == nil { + panic("nil control client callback") + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.newControlClientCbs.Add(cb) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.newControlClientCbs, handle) + } +} + +// NotifyNewControlClient invokes all registered control client callbacks. +// It returns callbacks to be executed when the control client shuts down. +func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile ipn.LoginProfileView, prefs ipn.PrefsView) (ccShutdownCbs []func()) { + if h == nil { + return nil + } + h.mu.Lock() + cbs := collectValues(h.newControlClientCbs) + h.mu.Unlock() + if len(cbs) > 0 { + ccShutdownCbs = make([]func(), 0, len(cbs)) + for _, cb := range cbs { + if shutdown := cb(cc, profile, prefs); shutdown != nil { + ccShutdownCbs = append(ccShutdownCbs, shutdown) + } + } + } + return ccShutdownCbs +} + +// RegisterAuditLogProvider implements [ipnext.Host]. +func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvider) (unregister func()) { + if h == nil { + return func() {} + } + if provider == nil { + panic("nil audit log provider") + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.auditLoggers.Add(provider) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.auditLoggers, handle) + } +} + +// AuditLogger returns a function that reports an auditable action +// to all registered audit loggers. It fails if any of them returns an error, +// indicating that the action cannot be logged and must not be performed. +// +// It implements [ipnext.Host], but is also used by the [LocalBackend]. +// +// The returned function closes over the current state of the host and extensions, +// which typically includes the current profile and the audit loggers registered by extensions. +// It must not be persisted outside of the auditable action context. +func (h *ExtensionHost) AuditLogger() ipnauth.AuditLogFunc { + if h == nil { + return func(tailcfg.ClientAuditAction, string) error { return nil } + } + + h.mu.Lock() + providers := collectValues(h.auditLoggers) + h.mu.Unlock() + + var loggers []ipnauth.AuditLogFunc + if len(providers) > 0 { + loggers = make([]ipnauth.AuditLogFunc, len(providers)) + for i, provider := range providers { + loggers[i] = provider() + } + } + return func(action tailcfg.ClientAuditAction, details string) error { + // Log auditable actions to the host's log regardless of whether + // the audit loggers are available or not. + h.logf("auditlog: %v: %v", action, details) + + // Invoke all registered audit loggers and collect errors. + // If any of them returns an error, the action is denied. + var errs []error + for _, logger := range loggers { + if err := logger(action, details); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) + } +} + +// Shutdown shuts down the extension host and all initialized extensions. +func (h *ExtensionHost) Shutdown() { + if h == nil { + return + } + // Ensure that the init function has completed before shutting down, + // or prevent any further init calls from happening. + h.initOnce.Do(func() {}) + h.shutdownOnce.Do(h.shutdown) +} + +func (h *ExtensionHost) shutdown() { + // Prevent any queued but not yet started operations from running, + // block new operations from being enqueued, and wait for the + // currently executing operation (if any) to finish. + h.shutdownWorkQueue() + // Invoke shutdown callbacks registered by extensions. + h.shutdownExtensions() +} + +func (h *ExtensionHost) shutdownWorkQueue() { + h.workQueue.Shutdown() + var ctx context.Context + if testenv.InTest() { + // In tests, we'd like to wait indefinitely for the current operation to finish, + // mostly to help avoid flaky tests. Test runners can be pretty slow. + ctx = context.Background() + } else { + // In prod, however, we want to avoid blocking indefinitely. + // The 5s timeout is somewhat arbitrary; LocalBackend operations + // should not take that long. + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + // Since callbacks are invoked synchronously, this will also wait + // for in-flight callbacks associated with those operations to finish. + if err := h.workQueue.Wait(ctx); err != nil { + h.logf("work queue shutdown failed: %v", err) + } +} + +func (h *ExtensionHost) shutdownExtensions() { + h.mu.Lock() + extensions := h.activeExtensions + h.mu.Unlock() + + // h.mu must not be held while shutting down extensions. + // Extensions might call back into the host and that would cause + // a deadlock if the h.mu is already held. + // + // Shutdown is called in the reverse order of Init. + for _, ext := range slices.Backward(extensions) { + if err := ext.Shutdown(); err != nil { + // Extension shutdown errors are never fatal, but we log them for debugging purposes. + h.logf("%q: shutdown callback failed: %v", ext.Name(), err) + } + } +} + +// enqueueBackendOperation enqueues a function to perform an operation on the [Backend]. +// If the host has not yet been initialized (e.g., when called from an extension's Init method), +// the operation is deferred until after the host and all extensions have completed initialization. +// It panics if the f is nil. +func (h *ExtensionHost) enqueueBackendOperation(f func(Backend)) { + if h == nil { + return + } + if f == nil { + panic("nil backend operation") + } + h.mu.Lock() // protects h.initialized and h.postInitWorkQueue + defer h.mu.Unlock() + if h.initialized.Load() { + h.doEnqueueBackendOperation(f) + } else { + h.postInitWorkQueue = append(h.postInitWorkQueue, f) + } +} + +// execQueue is an ordered asynchronous queue for executing functions. +// It is implemented by [execqueue.ExecQueue]. The interface is used +// to allow testing with a mock implementation. +type execQueue interface { + Add(func()) + Shutdown() + Wait(context.Context) error +} + +// collectValues is like [slices.Collect] of [maps.Values], +// but pre-allocates the slice to avoid reallocations. +// It returns nil if the map is empty. +func collectValues[K comparable, V any](m map[K]V) []V { + if len(m) == 0 { + return nil + } + s := make([]V, 0, len(m)) + for _, v := range m { + s = append(s, v) + } + return s +} diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go new file mode 100644 index 000000000..1e03abaa1 --- /dev/null +++ b/ipn/ipnlocal/extension_host_test.go @@ -0,0 +1,1139 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "cmp" + "context" + "errors" + "net/netip" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + + deepcmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/types/key" + "tailscale.com/types/persist" + "tailscale.com/util/must" +) + +// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes +// and shuts down extensions. +func TestExtensionInitShutdown(t *testing.T) { + t.Parallel() + + // As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors + // as extension initialization and shutdown errors are not fatal. + // If these methods are updated to return errors, this test should also be updated. + // The conversions below will fail to compile if their signatures change, reminding us to update the test. + _ = (func(*ExtensionHost))((*ExtensionHost).Init) + _ = (func(*ExtensionHost))((*ExtensionHost).Shutdown) + + tests := []struct { + name string + nilHost bool + exts []*testExtension + wantInit []string + wantShutdown []string + skipInit bool + }{ + { + name: "nil-host", + nilHost: true, + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "empty-extensions", + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "single-extension", + exts: []*testExtension{{name: "A"}}, + wantInit: []string{"A"}, + wantShutdown: []string{"A"}, + }, + { + name: "multiple-extensions/all-ok", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B", "A"}, + }, + { + name: "multiple-extensions/no-init-no-shutdown", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{}, + wantShutdown: []string{}, + skipInit: true, + }, + { + name: "multiple-extensions/init-failed/first", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B"}, + }, + { + name: "multiple-extensions/init-failed/second", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + { + name: "multiple-extensions/init-failed/third", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"B", "A"}, + }, + { + name: "multiple-extensions/init-failed/all", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{}, + }, + { + name: "multiple-extensions/init-skipped", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return ipnext.SkipExtension }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Configure all extensions to append their names + // to the gotInit and gotShutdown slices + // during initialization and shutdown, + // so we can check that they are called in the right order + // and that shutdown is not unless init succeeded. + var gotInit, gotShutdown []string + for _, ext := range tt.exts { + oldInitHook := ext.InitHook + ext.InitHook = func(e *testExtension) error { + gotInit = append(gotInit, e.name) + if oldInitHook == nil { + return nil + } + return oldInitHook(e) + } + ext.ShutdownHook = func(e *testExtension) error { + gotShutdown = append(gotShutdown, e.name) + return nil + } + } + + var h *ExtensionHost + if !tt.nilHost { + h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...) + } + + if !tt.skipInit { + h.Init() + } + + // Check that the extensions were initialized in the right order. + if !slices.Equal(gotInit, tt.wantInit) { + t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit) + } + + // Calling Init again on the host should be a no-op. + // The [testExtension.Init] method fails the test if called more than once, + // regardless of which test is running, so we don't need to check it here. + // Similarly, calling Shutdown again on the host should be a no-op as well. + // It is verified by the [testExtension.Shutdown] method itself. + if !tt.skipInit { + h.Init() + } + + // Extensions should not be shut down before the host is shut down, + // even if they are not initialized successfully. + for _, ext := range tt.exts { + if gotShutdown := ext.ShutdownCalled(); gotShutdown { + t.Errorf("%q: Extension shutdown called before host shutdown", ext.name) + } + } + + h.Shutdown() + // Check that the extensions were shut down in the right order, + // and that they were not shut down if they were not initialized successfully. + if !slices.Equal(gotShutdown, tt.wantShutdown) { + t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown) + } + + }) + } +} + +// TestNewExtensionHost tests that [NewExtensionHost] correctly creates +// an [ExtensionHost], instantiates the extensions and handles errors +// if an extension cannot be created. +func TestNewExtensionHost(t *testing.T) { + t.Parallel() + tests := []struct { + name string + defs []*ipnext.Definition + wantErr bool + wantExts []string + }{ + { + name: "no-exts", + defs: []*ipnext.Definition{}, + wantErr: false, + wantExts: []string{}, + }, + { + name: "exts-ok", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionForTest(&testExtension{name: "B"}), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, + wantExts: []string{"A", "B", "C"}, + }, + { + name: "exts-skipped", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, // extension B is skipped, that's ok + wantExts: []string{"A", "C"}, + }, + { + name: "exts-fail", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: true, // extension B failed to create, that's not ok + wantExts: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + logf := tstest.WhileTestRunningLogger(t) + h, err := NewExtensionHost(logf, &tsd.System{}, &testBackend{}, tt.defs...) + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr) + } + if err != nil { + return + } + + var gotExts []string + for _, ext := range h.allExtensions { + gotExts = append(gotExts, ext.Name()) + } + + if !slices.Equal(gotExts, tt.wantExts) { + t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts) + } + }) + } +} + +// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues +// backend operations and executes them asynchronously in the order they were received. +// It also checks that operations requested before the host and all extensions are initialized +// are not executed immediately but rather after the host and extensions are initialized. +func TestExtensionHostEnqueueBackendOperation(t *testing.T) { + t.Parallel() + tests := []struct { + name string + preInitCalls []string // before host init + extInitCalls []string // from [Extension.Init]; "" means no call + wantInitCalls []string // what we expect to be called after host init + postInitCalls []string // after host init + }{ + { + name: "no-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{}, + }, + { + name: "pre-init-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{}, + wantInitCalls: []string{"pre-init-1", "pre-init-2"}, + postInitCalls: []string{}, + }, + { + name: "init-calls", + preInitCalls: []string{}, + extInitCalls: []string{"init-1", "init-2"}, + wantInitCalls: []string{"init-1", "init-2"}, + postInitCalls: []string{}, + }, + { + name: "post-init-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + { + name: "mixed-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{"init-1", "", "init-2"}, + wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var gotCalls []string + var h *ExtensionHost + b := &testBackend{ + switchToBestProfileHook: func(reason string) { + gotCalls = append(gotCalls, reason) + }, + } + + exts := make([]*testExtension, len(tt.extInitCalls)) + for i, reason := range tt.extInitCalls { + exts[i] = &testExtension{} + if reason != "" { + exts[i].InitHook = func(e *testExtension) error { + e.host.Profiles().SwitchToBestProfileAsync(reason) + return nil + } + } + } + + h = newExtensionHostForTest(t, b, false, exts...) + wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue]. + + // Issue some pre-init calls. They should be deferred and not + // added to the queue until the host is initialized. + for _, call := range tt.preInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // The queue should be empty before the host is initialized. + wq.Drain() + if len(gotCalls) != 0 { + t.Errorf("Pre-init calls: got %v; want (none)", gotCalls) + } + gotCalls = nil + + // Initialize the host and all extensions. + // The extensions will make their calls during initialization. + h.Init() + + // Calls made before or during initialization should now be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + + // Let's make some more calls, as if extensions were making them in a response + // to external events. + for _, call := range tt.postInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // Any calls made after initialization should be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + }) + } +} + +// TestExtensionHostProfileChangeCallback verifies that [ExtensionHost] correctly handles the registration, +// invocation, and unregistration of profile change callbacks. It also checks that the callbacks are called +// with the correct arguments and that any private keys are stripped from [ipn.Prefs] before being passed to the callback. +func TestExtensionHostProfileChangeCallback(t *testing.T) { + t.Parallel() + + type profileChange struct { + Profile *ipn.LoginProfile + Prefs *ipn.Prefs + SameNode bool + } + // newProfileChange creates a new profile change with deep copies of the profile and prefs. + newProfileChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) profileChange { + return profileChange{ + Profile: profile.AsStruct(), + Prefs: prefs.AsStruct(), + SameNode: sameNode, + } + } + // makeProfileChangeAppender returns a callback that appends profile changes to the extension's state. + makeProfileChangeAppender := func(e *testExtension) ipnext.ProfileChangeCallback { + return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + UpdateExtState(e, "changes", func(changes []profileChange) []profileChange { + return append(changes, newProfileChange(profile, prefs, sameNode)) + }) + } + } + // getProfileChanges returns the profile changes stored in the extension's state. + getProfileChanges := func(e *testExtension) []profileChange { + changes, _ := GetExtStateOk[[]profileChange](e, "changes") + return changes + } + + tests := []struct { + name string + ext *testExtension + calls []profileChange + wantCalls []profileChange + }{ + { + // Register the callback for the lifetime of the extension. + name: "Register/Lifetime", + ext: &testExtension{}, + calls: []profileChange{ + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + wantCalls: []profileChange{ // all calls are received by the callback + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + }, + { + // Override the default InitHook used in the test to unregister the callback + // after the first call. + name: "Register/Once", + ext: &testExtension{ + InitHook: func(e *testExtension) error { + var unregister func() + handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + makeProfileChangeAppender(e)(profile, prefs, sameNode) + unregister() + } + unregister = e.host.Profiles().RegisterProfileChangeCallback(handler) + return nil + }, + }, + calls: []profileChange{ + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + }, + wantCalls: []profileChange{ // only the first call is received by the callback + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + }, + }, + { + // Ensure that ipn.Prefs are passed to the callback. + name: "CheckPrefs", + ext: &testExtension{}, + calls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + wantCalls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + }, + { + // Ensure that private keys are stripped from persist.Persist shared with extensions. + name: "StripPrivateKeys", + ext: &testExtension{}, + calls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + wantCalls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NodePrivate{}, // stripped + OldPrivateNodeKey: key.NodePrivate{}, // stripped + NetworkLockKey: key.NLPrivate{}, // stripped + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Use the default InitHook if not provided by the test. + if tt.ext.InitHook == nil { + tt.ext.InitHook = func(e *testExtension) error { + // Create and register the callback on init. + handler := makeProfileChangeAppender(e) + e.Cleanup(e.host.Profiles().RegisterProfileChangeCallback(handler)) + return nil + } + } + + h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext) + for _, call := range tt.calls { + h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode) + } + opts := []deepcmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), + } + if diff := deepcmp.Diff(tt.wantCalls, getProfileChanges(tt.ext), opts...); diff != "" { + t.Errorf("ProfileChange callbacks: (-want +got): %v", diff) + } + }) + } +} + +// TestBackgroundProfileResolver tests that the background profile resolvers +// are correctly registered, unregistered and invoked by the [ExtensionHost]. +func TestBackgroundProfileResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + profiles []ipn.LoginProfile // the first one is the current profile + resolvers []ipnext.ProfileResolver + wantProfile *ipn.LoginProfile + }{ + { + name: "No-Profiles/No-Resolvers", + profiles: nil, + resolvers: nil, + wantProfile: nil, + }, + { + // TODO(nickkhyl): update this test as we change "background profile resolvers" + // to just "profile resolvers". The wantProfile should be the current profile by default. + name: "Has-Profiles/No-Resolvers", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: nil, + wantProfile: nil, + }, + { + name: "Has-Profiles/Single-Resolver", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: []ipnext.ProfileResolver{ + func(ps ipnext.ProfileStore) ipn.LoginProfileView { + return ps.CurrentProfile() + }, + }, + wantProfile: &ipn.LoginProfile{ID: "profile-1"}, + }, + // TODO(nickkhyl): add more tests for multiple resolvers and different profiles + // once we change "background profile resolvers" to just "profile resolvers" + // and add proper conflict resolution logic. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a new profile manager and add the profiles to it. + // We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface. + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) + for i, p := range tt.profiles { + // Generate a unique ID and key for each profile, + // unless the profile already has them set + // or is an empty, unnamed profile. + if p.Name != "" { + if p.ID == "" { + p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i)) + } + if p.Key == "" { + p.Key = "key-" + ipn.StateKey(p.ID) + } + } + pv := p.View() + pm.knownProfiles[p.ID] = pv + if i == 0 { + // Set the first profile as the current one. + // A profileManager starts with an empty profile, + // so it's okay if the list of profiles is empty. + pm.SwitchToProfile(pv) + } + } + + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true) + + // Register the resolvers with the host. + // This is typically done by the extensions themselves, + // but we do it here for testing purposes. + for _, r := range tt.resolvers { + t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r)) + } + + // Call the resolver to get the profile. + gotProfile := h.DetermineBackgroundProfile(pm) + if !gotProfile.Equals(tt.wantProfile.View()) { + t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile) + } + }) + } +} + +// TestAuditLogProviders tests that the [ExtensionHost] correctly handles +// the registration and invocation of audit log providers. It verifies that +// the audit loggers are called with the correct actions and details, +// and that any errors returned by the providers are properly propagated. +func TestAuditLogProviders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + auditLoggers []ipnauth.AuditLogFunc // each represents an extension + actions []tailcfg.ClientAuditAction + wantErr bool + }{ + { + name: "No-Providers", + auditLoggers: nil, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, + }, + { + name: "Many-Providers/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Many-Providers/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + func(tailcfg.ClientAuditAction, string) error { + return nil // all good + }, + func(tailcfg.ClientAuditAction, string) error { + return errors.New("also failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, // some providers failed to log, so that's an error + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create extensions that register the audit log providers. + // Each extension/provider will append auditable actions to its state, + // then call the test's auditLogger function. + var exts []*testExtension + for _, auditLogger := range tt.auditLoggers { + ext := &testExtension{} + provider := func() ipnauth.AuditLogFunc { + return func(action tailcfg.ClientAuditAction, details string) error { + UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction { + return append(actions, action) + }) + return auditLogger(action, details) + } + } + ext.InitHook = func(e *testExtension) error { + e.Cleanup(e.host.RegisterAuditLogProvider(provider)) + return nil + } + exts = append(exts, ext) + } + + // Initialize the host and the extensions. + h := newExtensionHostForTest(t, &testBackend{}, true, exts...) + + // Use [ExtensionHost.AuditLogger] to log actions. + for _, action := range tt.actions { + err := h.AuditLogger()(action, "Test details") + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr) + } + } + + // Check that the actions were logged correctly by each provider. + for _, ext := range exts { + gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions") + if !slices.Equal(gotActions, tt.actions) { + t.Errorf("Actions: got %v; want %v", gotActions, tt.actions) + } + } + }) + } +} + +// TestNilExtensionHostMethodCall tests that calling exported methods +// on a nil [ExtensionHost] does not panic. We should treat it as a valid +// value since it's used in various tests that instantiate [LocalBackend] +// manually without calling [NewLocalBackend]. It also verifies that if +// a method returns a single func value (e.g., a cleanup function), +// it should not be nil. This is a basic sanity check to ensure that +// typical method calls on a nil receiver work as expected. +// It does not replace the need for more thorough testing of specific methods. +func TestNilExtensionHostMethodCall(t *testing.T) { + t.Parallel() + + var h *ExtensionHost + typ := reflect.TypeOf(h) + for i := range typ.NumMethod() { + m := typ.Method(i) + if strings.HasSuffix(m.Name, "ForTest") { + // Skip methods that are only for testing. + continue + } + + t.Run(m.Name, func(t *testing.T) { + t.Parallel() + // Calling the method on the nil receiver should not panic. + ret := checkMethodCallWithZeroArgs(t, m, h) + if len(ret) == 1 && ret[0].Kind() == reflect.Func { + // If the method returns a single func, such as a cleanup function, + // it should not be nil. + fn := ret[0] + if fn.IsNil() { + t.Fatalf("(%T).%s returned a nil func", h, m.Name) + } + // We expect it to be a no-op and calling it should not panic. + args := makeZeroArgsFor(fn) + func() { + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e) + } + }() + fn.Call(args) + }() + } + }) + } +} + +// checkMethodCallWithZeroArgs calls the method m on the receiver r +// with zero values for all its arguments, except the receiver itself. +// It returns the result of the method call, or fails the test if the call panics. +func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value { + t.Helper() + args := makeZeroArgsFor(m.Func) + // The first arg is the receiver. + args[0] = reflect.ValueOf(r) + // Calling the method should not panic. + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e) + } + }() + return m.Func.Call(args) +} + +func makeZeroArgsFor(fn reflect.Value) []reflect.Value { + args := make([]reflect.Value, fn.Type().NumIn()) + for i := range args { + args[i] = reflect.Zero(fn.Type().In(i)) + } + return args +} + +// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions. +// It associates each extension that either is or embeds a [testExtension] with the test +// and assigns a name if one isn’t already set. +// +// If the host cannot be created, it fails the test. +// +// The host is initialized if the initialize parameter is true. +// It is shut down automatically when the test ends. +func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost { + t.Helper() + + // testExtensionIface is a subset of the methods implemented by [testExtension] that are used here. + // We use testExtensionIface in type assertions instead of using the [testExtension] type directly, + // which supports scenarios where an extension type embeds a [testExtension]. + type testExtensionIface interface { + Name() string + setName(string) + setT(*testing.T) + checkShutdown() + } + + logf := tstest.WhileTestRunningLogger(t) + defs := make([]*ipnext.Definition, len(exts)) + for i, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i))) + ext.setT(t) + } + defs[i] = ipnext.DefinitionForTest(ext) + } + h, err := NewExtensionHost(logf, &tsd.System{}, b, defs...) + if err != nil { + t.Fatalf("NewExtensionHost: %v", err) + } + // Replace doEnqueueBackendOperation with the one that's marked as a helper, + // so that we'll have better output if [testExecQueue.Add] fails a test. + h.doEnqueueBackendOperation = func(f func(Backend)) { + t.Helper() + h.workQueue.Add(func() { f(b) }) + } + for _, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + t.Cleanup(ext.checkShutdown) + } + } + t.Cleanup(h.Shutdown) + if initialize { + h.Init() + } + return h +} + +// testExtension is an [ipnext.Extension] that: +// - Calls the provided init and shutdown callbacks +// when [Init] and [Shutdown] are called. +// - Ensures that [Init] and [Shutdown] are called at most once, +// that [Shutdown] is called after [Init], but is not called if [Init] fails +// and is called before the test ends if [Init] succeeds. +// +// Typically, [testExtension]s are created and passed to [newExtensionHostForTest] +// when creating an [ExtensionHost] for testing. +type testExtension struct { + t *testing.T // test that created the extension + name string // name of the extension, used for logging + + host ipnext.Host // or nil if not initialized + + // InitHook and ShutdownHook are optional hooks that can be set by tests. + InitHook, ShutdownHook func(*testExtension) error + + // initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown + // are called at most once and in the correct order. + initCnt, initOkCnt, shutdownCnt atomic.Int32 + + // mu protects the following fields. + mu sync.Mutex + // state is the optional state used by tests. + // It can be accessed by tests using [setTestExtensionState], + // [getTestExtensionStateOk] and [getTestExtensionState]. + state map[string]any + // cleanup are functions to be called on shutdown. + cleanup []func() +} + +var _ ipnext.Extension = (*testExtension)(nil) + +func (e *testExtension) setT(t *testing.T) { + e.t = t +} + +func (e *testExtension) setName(name string) { + e.name = name +} + +// Name implements [ipnext.Extension]. +func (e *testExtension) Name() string { + return e.name +} + +// Init implements [ipnext.Extension]. +func (e *testExtension) Init(host ipnext.Host) (err error) { + e.t.Helper() + e.host = host + if e.initCnt.Add(1) == 1 { + e.mu.Lock() + e.state = make(map[string]any) + e.mu.Unlock() + } else { + e.t.Errorf("%q: Init called more than once", e.name) + } + if e.InitHook != nil { + err = e.InitHook(e) + } + if err == nil { + e.initOkCnt.Add(1) + } + return err // may be nil or non-nil +} + +// InitCalled reports whether the Init method was called on the receiver. +func (e *testExtension) InitCalled() bool { + return e.initCnt.Load() != 0 +} + +func (e *testExtension) Cleanup(f func()) { + e.mu.Lock() + e.cleanup = append(e.cleanup, f) + e.mu.Unlock() +} + +// Shutdown implements [ipnext.Extension]. +func (e *testExtension) Shutdown() (err error) { + e.t.Helper() + e.mu.Lock() + cleanup := e.cleanup + e.cleanup = nil + e.mu.Unlock() + for _, f := range cleanup { + f() + } + if e.ShutdownHook != nil { + err = e.ShutdownHook(e) + } + if e.shutdownCnt.Add(1) != 1 { + e.t.Errorf("%q: Shutdown called more than once", e.name) + } + if e.initCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called without Init", e.name) + } else if e.initOkCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called despite failed Init", e.name) + } + e.host = nil + return err // may be nil or non-nil +} + +func (e *testExtension) checkShutdown() { + e.t.Helper() + if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown has not been called before test end", e.name) + } +} + +// ShutdownCalled reports whether the Shutdown method was called on the receiver. +func (e *testExtension) ShutdownCalled() bool { + return e.shutdownCnt.Load() != 0 +} + +// SetExtState sets a keyed state on [testExtension] to the given value. +// Tests use it to propagate test-specific state throughout the extension lifecycle +// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks) +func SetExtState[T any](e *testExtension, key string, value T) { + e.mu.Lock() + defer e.mu.Unlock() + e.state[key] = value +} + +// UpdateExtState updates a keyed state of the extension using the provided update function. +func UpdateExtState[T any](e *testExtension, key string, update func(T) T) { + e.mu.Lock() + defer e.mu.Unlock() + old, _ := e.state[key].(T) + new := update(old) + e.state[key] = new +} + +// GetExtState returns the value of the keyed state of the extension. +// It returns a zero value of T if the state is not set or is of a different type. +func GetExtState[T any](e *testExtension, key string) T { + v, _ := GetExtStateOk[T](e, key) + return v +} + +// GetExtStateOk is like [getExtState], but also reports whether the state +// with the given key exists and is of the expected type. +func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) { + e.mu.Lock() + defer e.mu.Unlock() + v, ok := e.state[key].(T) + return v, ok +} + +// testExecQueue is a test implementation of [execQueue] +// that defers execution of the enqueued funcs until +// [testExecQueue.Drain] is called, and fails the test if +// if [execQueue.Add] is called before the host is initialized. +// +// It is typically used by calling [ExtensionHost.SetWorkQueueForTest]. +type testExecQueue struct { + t *testing.T // test that created the queue + h *ExtensionHost // host to own the queue + + mu sync.Mutex + queue []func() +} + +var _ execQueue = (*testExecQueue)(nil) + +// SetWorkQueueForTest is a helper function that creates a new [testExecQueue] +// and sets it as the work queue for the specified [ExtensionHost], +// returning the new queue. +// +// It fails the test if the host is already initialized. +func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue { + t.Helper() + if h.initialized.Load() { + t.Fatalf("UseTestWorkQueue: host is already initialized") + return nil + } + q := &testExecQueue{t: t, h: h} + h.workQueue = q + return q +} + +// Add implements [execQueue]. +func (q *testExecQueue) Add(f func()) { + q.t.Helper() + + if !q.h.initialized.Load() { + q.t.Fatal("ExecQueue.Add must not be called until the host is initialized") + return + } + + q.mu.Lock() + q.queue = append(q.queue, f) + q.mu.Unlock() +} + +// Drain executes all queued functions in the order they were added. +func (q *testExecQueue) Drain() { + q.mu.Lock() + queue := q.queue + q.queue = nil + q.mu.Unlock() + + for _, f := range queue { + f() + } +} + +// Shutdown implements [execQueue]. +func (q *testExecQueue) Shutdown() {} + +// Wait implements [execQueue]. +func (q *testExecQueue) Wait(context.Context) error { return nil } + +// testBackend implements [ipnext.Backend] for testing purposes +// by calling the provided hooks when its methods are called. +type testBackend struct { + switchToBestProfileHook func(reason string) + + // mu protects the backend state. + // It is acquired on entry to the exported methods of the backend + // and released on exit, mimicking the behavior of the [LocalBackend]. + mu sync.Mutex +} + +func (b *testBackend) SwitchToBestProfile(reason string) { + b.mu.Lock() + defer b.mu.Unlock() + if b.switchToBestProfileHook != nil { + b.switchToBestProfileHook(reason) + } +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index a99d67cda..0f3ea1fbb 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -169,78 +169,6 @@ type watchSession struct { cancel context.CancelFunc // to shut down the session } -// Extension extends [LocalBackend] with additional functionality. -type Extension interface { - // Init is called to initialize the extension when the [LocalBackend] is created - // and before it starts running. If the extension cannot be initialized, - // it must return an error, and the Shutdown method will not be called. - // Any returned errors are not fatal; they are used for logging. - // TODO(nickkhyl): should we allow returning a fatal error? - Init(*LocalBackend) error - - // Shutdown is called when the [LocalBackend] is shutting down, - // if the extension was initialized. Any returned errors are not fatal; - // they are used for logging. - Shutdown() error -} - -// NewExtensionFn is a function that instantiates an [Extension]. -type NewExtensionFn func(logger.Logf, *tsd.System) (Extension, error) - -// registeredExtensions is a map of registered local backend extensions, -// where the key is the name of the extension and the value is the function -// that instantiates the extension. -var registeredExtensions map[string]NewExtensionFn - -// RegisterExtension registers a function that creates a [localBackendExtension]. -// It panics if newExt is nil or if an extension with the same name has already been registered. -func RegisterExtension(name string, newExt NewExtensionFn) { - if newExt == nil { - panic(fmt.Sprintf("lb: newExt is nil: %q", name)) - } - if _, ok := registeredExtensions[name]; ok { - panic(fmt.Sprintf("lb: duplicate extensions: %q", name)) - } - mak.Set(®isteredExtensions, name, newExt) -} - -// profileResolver is any function that returns a read-only view of a login profile. -// An invalid view indicates no profile. A valid profile view with an empty [ipn.ProfileID] -// indicates that the profile is new and has not been persisted yet. -// -// It is called with [LocalBackend.mu] held. -type profileResolver func() ipn.LoginProfileView - -// NewControlClientCallback is a function to be called when a new [controlclient.Client] -// is created and before it is first used. The login profile and prefs represent -// the profile for which the cc is created and are always valid; however, the -// profile's [ipn.LoginProfileView.ID] returns a zero [ipn.ProfileID] if the profile -// is new and has not been persisted yet. -// -// The callback is called with [LocalBackend.mu] held and must not call -// any [LocalBackend] methods. -// -// It returns a function to be called when the cc is being shut down, -// or nil if no cleanup is needed. -type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView, ipn.PrefsView) (cleanup func()) - -// ProfileChangeCallback is a function to be called when the current login profile changes. -// The sameNode parameter indicates whether the profile represents the same node as before, -// such as when only the profile metadata is updated but the node ID remains the same, -// or when a new profile is persisted and assigned an [ipn.ProfileID] for the first time. -// The subscribers can use this information to decide whether to reset their state. -// -// The profile and prefs are always valid, but the profile's [ipn.LoginProfileView.ID] -// returns a zero [ipn.ProfileID] if the profile is new and has not been persisted yet. -// -// The callback is called with [LocalBackend.mu] held and must not call -// any [LocalBackend] methods. -type ProfileChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) - -// AuditLogProvider is a function that returns an [ipnauth.AuditLogFunc] for -// logging auditable actions. -type AuditLogProvider func() ipnauth.AuditLogFunc - // LocalBackend is the glue between the major pieces of the Tailscale // network software: the cloud control plane (via controlclient), the // network data plane (via wgengine), and the user-facing UIs and CLIs @@ -311,6 +239,13 @@ type LocalBackend struct { // for testing and graceful shutdown purposes. goTracker goroutines.Tracker + // extHost is the bridge between [LocalBackend] and the registered [ipnext.Extension]s. + // It may be nil in tests that use direct composite literal initialization of [LocalBackend] + // instead of calling [NewLocalBackend]. A nil pointer is a valid, no-op host. + // It can be used with or without b.mu held, but is typically used with it held + // to prevent state changes while invoking callbacks. + extHost *ExtensionHost + // The mutex protects the following elements. mu sync.Mutex conf *conffile.Config // latest parsed config, or nil if not in declarative mode @@ -378,9 +313,6 @@ type LocalBackend struct { c2nUpdateStatus updateStatus currentUser ipnauth.Actor - // backgroundProfileResolvers are optional background profile resolvers. - backgroundProfileResolvers set.HandleSet[profileResolver] - selfUpdateProgress []ipnstate.UpdateProgress lastSelfUpdateState ipnstate.SelfUpdateStatus // capForcedNetfilter is the netfilter that control instructs Linux clients @@ -481,25 +413,6 @@ type LocalBackend struct { // reconnectTimer is used to schedule a reconnect by setting [ipn.Prefs.WantRunning] // to true after a delay, or nil if no reconnect is scheduled. reconnectTimer tstime.TimerController - - // shutdownCbs are the callbacks to be called when the backend is shutting down. - // Each callback is called exactly once in unspecified order and without b.mu held. - // Returned errors are logged but otherwise ignored and do not affect the shutdown process. - shutdownCbs set.HandleSet[func() error] - - // newControlClientCbs are the functions to be called when a new control client is created. - newControlClientCbs set.HandleSet[NewControlClientCallback] - - // profileChangeCbs are the callbacks to be called when the current login profile changes, - // either because of a profile switch, or because the profile information was updated - // by [LocalBackend.SetControlClientStatus], including when the profile is first populated - // and persisted. - profileChangeCbs set.HandleSet[ProfileChangeCallback] - - // auditLoggers is a collection of registered audit log providers. - // Each [AuditLogProvider] is called to get an [ipnauth.AuditLogFunc] when an auditable action - // is about to be performed. If an audit logger returns an error, the action is denied. - auditLoggers set.HandleSet[AuditLogProvider] } // HealthTracker returns the health tracker for the backend. @@ -614,6 +527,10 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } + if b.extHost, err = NewExtensionHost(logf, sys, b); err != nil { + return nil, fmt.Errorf("failed to create extension host: %w", err) + } + if b.unregisterSysPolicyWatch, err = b.registerSysPolicyWatch(); err != nil { return nil, err } @@ -668,19 +585,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } - for name, newFn := range registeredExtensions { - ext, err := newFn(logf, sys) - if err != nil { - b.logf("lb: failed to create %q extension: %v", name, err) - continue - } - if err := ext.Init(b); err != nil { - b.logf("lb: failed to initialize %q extension: %v", name, err) - continue - } - b.shutdownCbs.Add(ext.Shutdown) - } - + b.extHost.Init() return b, nil } @@ -1143,17 +1048,11 @@ func (b *LocalBackend) Shutdown() { if b.notifyCancel != nil { b.notifyCancel() } - shutdownCbs := slices.Collect(maps.Values(b.shutdownCbs)) - b.shutdownCbs = nil + extHost := b.extHost + b.extHost = nil b.mu.Unlock() b.webClientShutdown() - for _, cb := range shutdownCbs { - if err := cb(); err != nil { - b.logf("shutdown callback failed: %v", err) - } - } - if b.sockstatLogger != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -1170,6 +1069,7 @@ func (b *LocalBackend) Shutdown() { if cc != nil { cc.Shutdown() } + extHost.Shutdown() b.ctxCancel() b.e.Close() <-b.e.Done() @@ -1743,7 +1643,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control // If the profile ID was empty before SetPrefs, it's a new profile // and the user has just completed a login for the first time. sameNode := profile.ID() == "" || profile.ID() == cp.ID() - b.notifyProfileChangeLocked(profile, prefs.View(), sameNode) + b.extHost.NotifyProfileChange(profile, prefs.View(), sameNode) } } @@ -2492,11 +2392,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if err != nil { return err } - for _, cb := range b.newControlClientCbs { - if cleanup := cb(cc, b.pm.CurrentProfile(), prefs); cleanup != nil { - ccShutdownCbs = append(ccShutdownCbs, cleanup) - } - } + ccShutdownCbs = b.extHost.NotifyNewControlClient(cc, b.pm.CurrentProfile(), prefs) b.setControlClientLocked(cc) endpoints := b.endpoints @@ -4060,6 +3956,10 @@ func (b *LocalBackend) switchToBestProfileLockedOnEntry(reason string, unlock un // // b.mu must be held. func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBackground bool) { + // TODO(nickkhyl): delegate all of this to the extensions and remove the distinction + // between "foreground" and "background" profiles as we migrate away from the concept + // of a single "current user" on Windows. See tailscale/corp#18342. + // // If a GUI/CLI client is connected, use the connected user's profile, which means // either the current profile if owned by the user, or their default profile. if b.currentUser != nil { @@ -4079,7 +3979,12 @@ func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBac // If the returned background profileID is "", Tailscale will disconnect // and remain idle until a GUI or CLI client connects. if goos := envknob.GOOS(); goos == "windows" { - profile := b.getBackgroundProfileLocked() + // If Unattended Mode is enabled for the current profile, keep using it. + if b.pm.CurrentPrefs().ForceDaemon() { + return b.pm.CurrentProfile(), true + } + // Otherwise, use the profile returned by the extension. + profile := b.extHost.DetermineBackgroundProfile(b.pm) return profile, true } @@ -4092,47 +3997,6 @@ func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBac return b.pm.CurrentProfile(), false } -// RegisterBackgroundProfileResolver registers a function to be used when -// resolving the background profile, until the returned unregister function is called. -func (b *LocalBackend) RegisterBackgroundProfileResolver(resolver profileResolver) (unregister func()) { - // TODO(nickkhyl): should we allow specifying some kind of priority/altitude for the resolver? - b.mu.Lock() - defer b.mu.Unlock() - handle := b.backgroundProfileResolvers.Add(resolver) - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.backgroundProfileResolvers, handle) - } -} - -// getBackgroundProfileLocked returns a read-only view of the profile to use -// when no GUI/CLI client is connected. If Tailscale should not run in the background -// and should disconnect until a GUI/CLI client connects, the returned view is not valid. -// As of 2025-02-07, it is only used on Windows. -func (b *LocalBackend) getBackgroundProfileLocked() ipn.LoginProfileView { - // TODO(nickkhyl): check if the returned profile is allowed on the device, - // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. - // See tailscale/corp#26249. - - // If Unattended Mode is enabled for the current profile, keep using it. - if b.pm.CurrentPrefs().ForceDaemon() { - return b.pm.CurrentProfile() - } - - // Otherwise, attempt to resolve the background profile using the background - // profile resolvers available on the current platform. - for _, resolver := range b.backgroundProfileResolvers { - if profile := resolver(); profile.Valid() { - return profile - } - } - - // Otherwise, switch to an empty profile and disconnect Tailscale - // until a GUI or CLI client connects. - return ipn.LoginProfileView{} -} - // CurrentUserForTest returns the current user and the associated WindowsUserID. // It is used for testing only, and will be removed along with the rest of the // "current user" functionality as we progress on the multi-user improvements (tailscale/corp#18342). @@ -4351,47 +4215,6 @@ func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error { return err } -// RegisterAuditLogProvider registers an audit log provider, which returns a function -// to be called when an auditable action is about to be performed. -// The returned function unregisters the provider. -// It panics if the provider is nil. -func (b *LocalBackend) RegisterAuditLogProvider(provider AuditLogProvider) (unregister func()) { - if provider == nil { - panic("nil audit log provider") - } - b.mu.Lock() - defer b.mu.Unlock() - handle := b.auditLoggers.Add(provider) - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.auditLoggers, handle) - } -} - -// getAuditLoggerLocked returns a function that calls all currently registered -// audit loggers, failing as soon as any of them returns an error. -// -// b.mu must be held. -func (b *LocalBackend) getAuditLoggerLocked() ipnauth.AuditLogFunc { - var loggers []ipnauth.AuditLogFunc - if len(b.auditLoggers) != 0 { - loggers = make([]ipnauth.AuditLogFunc, 0, len(b.auditLoggers)) - for _, getLogger := range b.auditLoggers { - loggers = append(loggers, getLogger()) - } - } - return func(action tailcfg.ClientAuditAction, details string) error { - b.logf("auditlog: %v: %v", action, details) - for _, logger := range loggers { - if err := logger(action, details); err != nil { - return err - } - } - return nil - } -} - // EditPrefs applies the changes in mp to the current prefs, // acting as the tailscaled itself rather than a specific user. func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { @@ -4417,7 +4240,7 @@ func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ip unlock := b.lockAndGetUnlock() defer unlock() if mp.WantRunningSet && !mp.WantRunning && b.pm.CurrentPrefs().WantRunning() { - if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.getAuditLoggerLocked()); err != nil { + if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.extHost.AuditLogger()); err != nil { b.logf("check profile access failed: %v", err) return ipn.PrefsView{}, err } @@ -6031,23 +5854,6 @@ func (b *LocalBackend) requestEngineStatusAndWait() { b.logf("requestEngineStatusAndWait: got status update.") } -// RegisterControlClientCallback registers a function to be called every time a new -// control client is created, until the returned unregister function is called. -// It panics if the cb is nil. -func (b *LocalBackend) RegisterControlClientCallback(cb NewControlClientCallback) (unregister func()) { - if cb == nil { - panic("nil control client callback") - } - b.mu.Lock() - defer b.mu.Unlock() - handle := b.newControlClientCbs.Add(cb) - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.newControlClientCbs, handle) - } -} - // setControlClientLocked sets the control client to cc, // which may be nil. // @@ -7633,37 +7439,6 @@ func (b *LocalBackend) resetDialPlan() { } } -// RegisterProfileChangeCallback registers a function to be called when the current [ipn.LoginProfile] changes. -// If includeCurrent is true, the callback is called immediately with the current profile. -// The returned function unregisters the callback. -// It panics if the cb is nil. -func (b *LocalBackend) RegisterProfileChangeCallback(cb ProfileChangeCallback, includeCurrent bool) (unregister func()) { - if cb == nil { - panic("nil profile change callback") - } - b.mu.Lock() - defer b.mu.Unlock() - handle := b.profileChangeCbs.Add(cb) - if includeCurrent { - cb(b.pm.CurrentProfile(), stripKeysFromPrefs(b.pm.CurrentPrefs()), false) - } - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.profileChangeCbs, handle) - } -} - -// notifyProfileChangeLocked invokes all registered profile change callbacks. -// -// b.mu must be held. -func (b *LocalBackend) notifyProfileChangeLocked(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { - prefs = stripKeysFromPrefs(prefs) - for _, cb := range b.profileChangeCbs { - cb(profile, prefs, sameNode) - } -} - // getHardwareAddrs returns the hardware addresses for the machine. If the list // of hardware addresses is empty, it will return the previously known hardware // addresses. Both the current, and previously known hardware addresses might be @@ -7711,7 +7486,7 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err b.lastSuggestedExitNode = "" b.keyExpired = false b.resetAlwaysOnOverrideLocked() - b.notifyProfileChangeLocked(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) + b.extHost.NotifyProfileChange(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) b.enterStateLockedOnEntry(ipn.NoState, unlock) // Reset state; releases b.mu b.health.SetLocalLogConfigHealth(nil) diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 901a4a899..057fe2aae 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -17,6 +17,7 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" @@ -24,6 +25,9 @@ import ( var debug = envknob.RegisterBool("TS_DEBUG_PROFILES") +// [profileManager] implements [ipnext.ProfileStore]. +var _ ipnext.ProfileStore = (*profileManager)(nil) + // profileManager is a wrapper around an [ipn.StateStore] that manages // multiple profiles and the current profile. // diff --git a/tsd/tsd.go b/tsd/tsd.go index 1d1f35017..acd09560c 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -26,7 +26,6 @@ import ( "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/conffile" - "tailscale.com/ipn/desktop" "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" @@ -53,7 +52,6 @@ type System struct { Netstack SubSystem[NetstackImpl] // actually a *netstack.Impl DriveForLocal SubSystem[drive.FileSystemForLocal] DriveForRemote SubSystem[drive.FileSystemForRemote] - SessionManager SubSystem[desktop.SessionManager] // InitialConfig is initial server config, if any. // It is nil if the node is not in declarative mode. @@ -112,8 +110,6 @@ func (s *System) Set(v any) { s.DriveForLocal.Set(v) case drive.FileSystemForRemote: s.DriveForRemote.Set(v) - case desktop.SessionManager: - s.SessionManager.Set(v) default: panic(fmt.Sprintf("unknown type %T", v)) }