mirror of
https://github.com/tailscale/tailscale.git
synced 2025-09-21 05:31:36 +02:00
control/controlclient: introduce eventbus messages instead of callbacks (#16956)
This is a small introduction of the eventbus into controlclient that communicates with mainly ipnlocal. While ipnlocal is a complicated part of the codebase, the subscribers here are from the perspective of ipnlocal already called async. Updates #15160 Signed-off-by: Claus Lensbøl <claus@tailscale.com>
This commit is contained in:
parent
782c16c513
commit
b816fd7117
@ -205,7 +205,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) {
|
||||
}
|
||||
})
|
||||
return c, nil
|
||||
|
||||
}
|
||||
|
||||
// SetPaused controls whether HTTP activity should be paused.
|
||||
@ -424,6 +423,11 @@ func (c *Auto) unpausedChanLocked() <-chan bool {
|
||||
return unpaused
|
||||
}
|
||||
|
||||
// ClientID returns the ClientID of the direct controlClient
|
||||
func (c *Auto) ClientID() int64 {
|
||||
return c.direct.ClientID()
|
||||
}
|
||||
|
||||
// mapRoutineState is the state of Auto.mapRoutine while it's running.
|
||||
type mapRoutineState struct {
|
||||
c *Auto
|
||||
|
@ -81,6 +81,9 @@ type Client interface {
|
||||
// in a separate http request. It has nothing to do with the rest of
|
||||
// the state machine.
|
||||
UpdateEndpoints(endpoints []tailcfg.Endpoint)
|
||||
// ClientID returns the ClientID of a client. This ID is meant to
|
||||
// distinguish one client from another.
|
||||
ClientID() int64
|
||||
}
|
||||
|
||||
// UserVisibleError is an error that should be shown to users.
|
||||
|
@ -35,6 +35,7 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/types/persist"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
)
|
||||
|
||||
func fieldsOf(t reflect.Type) (fields []string) {
|
||||
@ -218,6 +219,8 @@ func TestDirectProxyManual(t *testing.T) {
|
||||
t.Skip("skipping without --live-network-test")
|
||||
}
|
||||
|
||||
bus := eventbustest.NewBus(t)
|
||||
|
||||
dialer := &tsdial.Dialer{}
|
||||
dialer.SetNetMon(netmon.NewStatic())
|
||||
|
||||
@ -239,6 +242,7 @@ func TestDirectProxyManual(t *testing.T) {
|
||||
},
|
||||
Dialer: dialer,
|
||||
ControlKnobs: &controlknobs.Knobs{},
|
||||
Bus: bus,
|
||||
}
|
||||
d, err := NewDirect(opts)
|
||||
if err != nil {
|
||||
@ -263,6 +267,8 @@ func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
|
||||
func testHTTPS(t *testing.T, withProxy bool) {
|
||||
bakedroots.ResetForTest(t, tlstest.TestRootCA())
|
||||
|
||||
bus := eventbustest.NewBus(t)
|
||||
|
||||
controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -327,6 +333,7 @@ func testHTTPS(t *testing.T, withProxy bool) {
|
||||
t.Logf("PopBrowserURL: %q", url)
|
||||
},
|
||||
Dialer: dialer,
|
||||
Bus: bus,
|
||||
}
|
||||
d, err := NewDirect(opts)
|
||||
if err != nil {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@ -52,6 +53,7 @@ import (
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/tkatype"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/eventbus"
|
||||
"tailscale.com/util/multierr"
|
||||
"tailscale.com/util/singleflight"
|
||||
"tailscale.com/util/syspolicy/pkey"
|
||||
@ -63,30 +65,31 @@ import (
|
||||
|
||||
// Direct is the client that connects to a tailcontrol server for a node.
|
||||
type Direct struct {
|
||||
httpc *http.Client // HTTP client used to talk to tailcontrol
|
||||
interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial
|
||||
dialer *tsdial.Dialer
|
||||
dnsCache *dnscache.Resolver
|
||||
controlKnobs *controlknobs.Knobs // always non-nil
|
||||
serverURL string // URL of the tailcontrol server
|
||||
clock tstime.Clock
|
||||
logf logger.Logf
|
||||
netMon *netmon.Monitor // non-nil
|
||||
health *health.Tracker
|
||||
discoPubKey key.DiscoPublic
|
||||
getMachinePrivKey func() (key.MachinePrivate, error)
|
||||
debugFlags []string
|
||||
skipIPForwardingCheck bool
|
||||
pinger Pinger
|
||||
polc policyclient.Client // always non-nil
|
||||
popBrowser func(url string) // or nil
|
||||
c2nHandler http.Handler // or nil
|
||||
onClientVersion func(*tailcfg.ClientVersion) // or nil
|
||||
onControlTime func(time.Time) // or nil
|
||||
onTailnetDefaultAutoUpdate func(bool) // or nil
|
||||
panicOnUse bool // if true, panic if client is used (for testing)
|
||||
closedCtx context.Context // alive until Direct.Close is called
|
||||
closeCtx context.CancelFunc // cancels closedCtx
|
||||
httpc *http.Client // HTTP client used to talk to tailcontrol
|
||||
interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial
|
||||
dialer *tsdial.Dialer
|
||||
dnsCache *dnscache.Resolver
|
||||
controlKnobs *controlknobs.Knobs // always non-nil
|
||||
serverURL string // URL of the tailcontrol server
|
||||
clock tstime.Clock
|
||||
logf logger.Logf
|
||||
netMon *netmon.Monitor // non-nil
|
||||
health *health.Tracker
|
||||
discoPubKey key.DiscoPublic
|
||||
busClient *eventbus.Client
|
||||
clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion]
|
||||
autoUpdatePub *eventbus.Publisher[AutoUpdate]
|
||||
controlTimePub *eventbus.Publisher[ControlTime]
|
||||
getMachinePrivKey func() (key.MachinePrivate, error)
|
||||
debugFlags []string
|
||||
skipIPForwardingCheck bool
|
||||
pinger Pinger
|
||||
popBrowser func(url string) // or nil
|
||||
polc policyclient.Client // always non-nil
|
||||
c2nHandler http.Handler // or nil
|
||||
panicOnUse bool // if true, panic if client is used (for testing)
|
||||
closedCtx context.Context // alive until Direct.Close is called
|
||||
closeCtx context.CancelFunc // cancels closedCtx
|
||||
|
||||
dialPlan ControlDialPlanner // can be nil
|
||||
|
||||
@ -107,6 +110,8 @@ type Direct struct {
|
||||
tkaHead string
|
||||
lastPingURL string // last PingRequest.URL received, for dup suppression
|
||||
connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest
|
||||
|
||||
controlClientID int64 // Random ID used to differentiate clients for consumers of messages.
|
||||
}
|
||||
|
||||
// Observer is implemented by users of the control client (such as LocalBackend)
|
||||
@ -120,26 +125,24 @@ type Observer interface {
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Persist persist.Persist // initial persistent data
|
||||
GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use
|
||||
ServerURL string // URL of the tailcontrol server
|
||||
AuthKey string // optional node auth key for auto registration
|
||||
Clock tstime.Clock
|
||||
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
|
||||
DiscoPublicKey key.DiscoPublic
|
||||
PolicyClient policyclient.Client // or nil for none
|
||||
Logf logger.Logf
|
||||
HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
|
||||
NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only)
|
||||
DebugFlags []string // debug settings to send to control
|
||||
HealthTracker *health.Tracker
|
||||
PopBrowserURL func(url string) // optional func to open browser
|
||||
OnClientVersion func(*tailcfg.ClientVersion) // optional func to inform GUI of client version status
|
||||
OnControlTime func(time.Time) // optional func to notify callers of new time from control
|
||||
OnTailnetDefaultAutoUpdate func(bool) // optional func to inform GUI of default auto-update setting for the tailnet
|
||||
Dialer *tsdial.Dialer // non-nil
|
||||
C2NHandler http.Handler // or nil
|
||||
ControlKnobs *controlknobs.Knobs // or nil to ignore
|
||||
Persist persist.Persist // initial persistent data
|
||||
GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use
|
||||
ServerURL string // URL of the tailcontrol server
|
||||
AuthKey string // optional node auth key for auto registration
|
||||
Clock tstime.Clock
|
||||
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
|
||||
DiscoPublicKey key.DiscoPublic
|
||||
PolicyClient policyclient.Client // or nil for none
|
||||
Logf logger.Logf
|
||||
HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
|
||||
NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only)
|
||||
DebugFlags []string // debug settings to send to control
|
||||
HealthTracker *health.Tracker
|
||||
PopBrowserURL func(url string) // optional func to open browser
|
||||
Dialer *tsdial.Dialer // non-nil
|
||||
C2NHandler http.Handler // or nil
|
||||
ControlKnobs *controlknobs.Knobs // or nil to ignore
|
||||
Bus *eventbus.Bus
|
||||
|
||||
// Observer is called when there's a change in status to report
|
||||
// from the control client.
|
||||
@ -287,33 +290,32 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
}
|
||||
|
||||
c := &Direct{
|
||||
httpc: httpc,
|
||||
interceptedDial: interceptedDial,
|
||||
controlKnobs: opts.ControlKnobs,
|
||||
getMachinePrivKey: opts.GetMachinePrivateKey,
|
||||
serverURL: opts.ServerURL,
|
||||
clock: opts.Clock,
|
||||
logf: opts.Logf,
|
||||
persist: opts.Persist.View(),
|
||||
authKey: opts.AuthKey,
|
||||
discoPubKey: opts.DiscoPublicKey,
|
||||
debugFlags: opts.DebugFlags,
|
||||
netMon: netMon,
|
||||
health: opts.HealthTracker,
|
||||
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
|
||||
pinger: opts.Pinger,
|
||||
polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})),
|
||||
popBrowser: opts.PopBrowserURL,
|
||||
onClientVersion: opts.OnClientVersion,
|
||||
onTailnetDefaultAutoUpdate: opts.OnTailnetDefaultAutoUpdate,
|
||||
onControlTime: opts.OnControlTime,
|
||||
c2nHandler: opts.C2NHandler,
|
||||
dialer: opts.Dialer,
|
||||
dnsCache: dnsCache,
|
||||
dialPlan: opts.DialPlan,
|
||||
httpc: httpc,
|
||||
interceptedDial: interceptedDial,
|
||||
controlKnobs: opts.ControlKnobs,
|
||||
getMachinePrivKey: opts.GetMachinePrivateKey,
|
||||
serverURL: opts.ServerURL,
|
||||
clock: opts.Clock,
|
||||
logf: opts.Logf,
|
||||
persist: opts.Persist.View(),
|
||||
authKey: opts.AuthKey,
|
||||
discoPubKey: opts.DiscoPublicKey,
|
||||
debugFlags: opts.DebugFlags,
|
||||
netMon: netMon,
|
||||
health: opts.HealthTracker,
|
||||
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
|
||||
pinger: opts.Pinger,
|
||||
polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})),
|
||||
popBrowser: opts.PopBrowserURL,
|
||||
c2nHandler: opts.C2NHandler,
|
||||
dialer: opts.Dialer,
|
||||
dnsCache: dnsCache,
|
||||
dialPlan: opts.DialPlan,
|
||||
}
|
||||
c.closedCtx, c.closeCtx = context.WithCancel(context.Background())
|
||||
|
||||
c.controlClientID = rand.Int64()
|
||||
|
||||
if opts.Hostinfo == nil {
|
||||
c.SetHostinfo(hostinfo.New())
|
||||
} else {
|
||||
@ -331,6 +333,12 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
if strings.Contains(opts.ServerURL, "controlplane.tailscale.com") && envknob.Bool("TS_PANIC_IF_HIT_MAIN_CONTROL") {
|
||||
c.panicOnUse = true
|
||||
}
|
||||
|
||||
c.busClient = opts.Bus.Client("controlClient.direct")
|
||||
c.clientVersionPub = eventbus.Publish[tailcfg.ClientVersion](c.busClient)
|
||||
c.autoUpdatePub = eventbus.Publish[AutoUpdate](c.busClient)
|
||||
c.controlTimePub = eventbus.Publish[ControlTime](c.busClient)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@ -340,6 +348,7 @@ func (c *Direct) Close() error {
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.busClient.Close()
|
||||
if c.noiseClient != nil {
|
||||
if err := c.noiseClient.Close(); err != nil {
|
||||
return err
|
||||
@ -826,6 +835,23 @@ func (c *Direct) SendUpdate(ctx context.Context) error {
|
||||
return c.sendMapRequest(ctx, false, nil)
|
||||
}
|
||||
|
||||
// ClientID returns the ControlClientID of the controlClient
|
||||
func (c *Direct) ClientID() int64 {
|
||||
return c.controlClientID
|
||||
}
|
||||
|
||||
// AutoUpdate wraps a bool for naming on the eventbus
|
||||
type AutoUpdate struct {
|
||||
ClientID int64 // The ID field is used for consumers to differentiate instances of Direct
|
||||
Value bool
|
||||
}
|
||||
|
||||
// ControlTime wraps a [time.Time] for naming on the eventbus
|
||||
type ControlTime struct {
|
||||
ClientID int64 // The ID field is used for consumers to differentiate instances of Direct
|
||||
Value time.Time
|
||||
}
|
||||
|
||||
// If we go more than watchdogTimeout without hearing from the server,
|
||||
// end the long poll. We should be receiving a keep alive ping
|
||||
// every minute.
|
||||
@ -1085,14 +1111,12 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
|
||||
c.logf("netmap: control says to open URL %v; no popBrowser func", u)
|
||||
}
|
||||
}
|
||||
if resp.ClientVersion != nil && c.onClientVersion != nil {
|
||||
c.onClientVersion(resp.ClientVersion)
|
||||
if resp.ClientVersion != nil {
|
||||
c.clientVersionPub.Publish(*resp.ClientVersion)
|
||||
}
|
||||
if resp.ControlTime != nil && !resp.ControlTime.IsZero() {
|
||||
c.logf.JSON(1, "controltime", resp.ControlTime.UTC())
|
||||
if c.onControlTime != nil {
|
||||
c.onControlTime(*resp.ControlTime)
|
||||
}
|
||||
c.controlTimePub.Publish(ControlTime{c.controlClientID, *resp.ControlTime})
|
||||
}
|
||||
if resp.KeepAlive {
|
||||
vlogf("netmap: got keep-alive")
|
||||
@ -1112,9 +1136,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
|
||||
continue
|
||||
}
|
||||
if au, ok := resp.DefaultAutoUpdate.Get(); ok {
|
||||
if c.onTailnetDefaultAutoUpdate != nil {
|
||||
c.onTailnetDefaultAutoUpdate(au)
|
||||
}
|
||||
c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, au})
|
||||
}
|
||||
|
||||
metricMapResponseMap.Add(1)
|
||||
|
@ -17,12 +17,14 @@ import (
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
)
|
||||
|
||||
func TestNewDirect(t *testing.T) {
|
||||
hi := hostinfo.New()
|
||||
ni := tailcfg.NetInfo{LinkType: "wired"}
|
||||
hi.NetInfo = &ni
|
||||
bus := eventbustest.NewBus(t)
|
||||
|
||||
k := key.NewMachine()
|
||||
opts := Options{
|
||||
@ -32,6 +34,7 @@ func TestNewDirect(t *testing.T) {
|
||||
return k, nil
|
||||
},
|
||||
Dialer: tsdial.NewDialer(netmon.NewStatic()),
|
||||
Bus: bus,
|
||||
}
|
||||
c, err := NewDirect(opts)
|
||||
if err != nil {
|
||||
@ -99,6 +102,7 @@ func TestTsmpPing(t *testing.T) {
|
||||
hi := hostinfo.New()
|
||||
ni := tailcfg.NetInfo{LinkType: "wired"}
|
||||
hi.NetInfo = &ni
|
||||
bus := eventbustest.NewBus(t)
|
||||
|
||||
k := key.NewMachine()
|
||||
opts := Options{
|
||||
@ -108,6 +112,7 @@ func TestTsmpPing(t *testing.T) {
|
||||
return k, nil
|
||||
},
|
||||
Dialer: tsdial.NewDialer(netmon.NewStatic()),
|
||||
Bus: bus,
|
||||
}
|
||||
|
||||
c, err := NewDirect(opts)
|
||||
|
@ -6,12 +6,14 @@ package ipnlocal
|
||||
import (
|
||||
"time"
|
||||
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstime"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/eventbus"
|
||||
)
|
||||
|
||||
// For extra defense-in-depth, when we're testing expired nodes we check
|
||||
@ -40,14 +42,46 @@ type expiryManager struct {
|
||||
|
||||
logf logger.Logf
|
||||
clock tstime.Clock
|
||||
|
||||
eventClient *eventbus.Client
|
||||
controlTimeSub *eventbus.Subscriber[controlclient.ControlTime]
|
||||
subsDoneCh chan struct{} // closed when consumeEventbusTopics returns
|
||||
}
|
||||
|
||||
func newExpiryManager(logf logger.Logf) *expiryManager {
|
||||
return &expiryManager{
|
||||
func newExpiryManager(logf logger.Logf, bus *eventbus.Bus) *expiryManager {
|
||||
em := &expiryManager{
|
||||
previouslyExpired: map[tailcfg.StableNodeID]bool{},
|
||||
logf: logf,
|
||||
clock: tstime.StdClock{},
|
||||
}
|
||||
|
||||
em.eventClient = bus.Client("ipnlocal.expiryManager")
|
||||
em.controlTimeSub = eventbus.Subscribe[controlclient.ControlTime](em.eventClient)
|
||||
|
||||
em.subsDoneCh = make(chan struct{})
|
||||
go em.consumeEventbusTopics()
|
||||
|
||||
return em
|
||||
}
|
||||
|
||||
// consumeEventbusTopics consumes events from all relevant
|
||||
// [eventbus.Subscriber]'s and passes them to their related handler. Events are
|
||||
// always handled in the order they are received, i.e. the next event is not
|
||||
// read until the previous event's handler has returned. It returns when the
|
||||
// [controlclient.ControlTime] subscriber is closed, which is interpreted to be the
|
||||
// same as the [eventbus.Client] closing ([eventbus.Subscribers] are either
|
||||
// all open or all closed).
|
||||
func (em *expiryManager) consumeEventbusTopics() {
|
||||
defer close(em.subsDoneCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-em.controlTimeSub.Done():
|
||||
return
|
||||
case time := <-em.controlTimeSub.Events():
|
||||
em.onControlTime(time.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onControlTime is called whenever we receive a new timestamp from the control
|
||||
@ -218,6 +252,11 @@ func (em *expiryManager) nextPeerExpiry(nm *netmap.NetworkMap, localNow time.Tim
|
||||
return nextExpiry
|
||||
}
|
||||
|
||||
func (em *expiryManager) close() {
|
||||
em.eventClient.Close()
|
||||
<-em.subsDoneCh
|
||||
}
|
||||
|
||||
// ControlNow estimates the current time on the control server, calculated as
|
||||
// localNow + the delta between local and control server clocks as recorded
|
||||
// when the LocalBackend last received a time message from the control server.
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
)
|
||||
|
||||
func TestFlagExpiredPeers(t *testing.T) {
|
||||
@ -110,7 +111,8 @@ func TestFlagExpiredPeers(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
em := newExpiryManager(t.Logf)
|
||||
bus := eventbustest.NewBus(t)
|
||||
em := newExpiryManager(t.Logf, bus)
|
||||
em.clock = tstest.NewClock(tstest.ClockOpts{Start: now})
|
||||
if tt.controlTime != nil {
|
||||
em.onControlTime(*tt.controlTime)
|
||||
@ -240,7 +242,8 @@ func TestNextPeerExpiry(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
em := newExpiryManager(t.Logf)
|
||||
bus := eventbustest.NewBus(t)
|
||||
em := newExpiryManager(t.Logf, bus)
|
||||
em.clock = tstest.NewClock(tstest.ClockOpts{Start: now})
|
||||
got := em.nextPeerExpiry(tt.netmap, now)
|
||||
if !got.Equal(tt.want) {
|
||||
@ -253,7 +256,8 @@ func TestNextPeerExpiry(t *testing.T) {
|
||||
|
||||
t.Run("ClockSkew", func(t *testing.T) {
|
||||
t.Logf("local time: %q", now.Format(time.RFC3339))
|
||||
em := newExpiryManager(t.Logf)
|
||||
bus := eventbustest.NewBus(t)
|
||||
em := newExpiryManager(t.Logf, bus)
|
||||
em.clock = tstest.NewClock(tstest.ClockOpts{Start: now})
|
||||
|
||||
// The local clock is "running fast"; our clock skew is -2h
|
||||
|
@ -99,6 +99,7 @@ import (
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/deephash"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/eventbus"
|
||||
"tailscale.com/util/goroutines"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/multierr"
|
||||
@ -202,6 +203,10 @@ type LocalBackend struct {
|
||||
keyLogf logger.Logf // for printing list of peers on change
|
||||
statsLogf logger.Logf // for printing peers stats on change
|
||||
sys *tsd.System
|
||||
eventClient *eventbus.Client
|
||||
clientVersionSub *eventbus.Subscriber[tailcfg.ClientVersion]
|
||||
autoUpdateSub *eventbus.Subscriber[controlclient.AutoUpdate]
|
||||
subsDoneCh chan struct{} // closed when consumeEventbusTopics returns
|
||||
health *health.Tracker // always non-nil
|
||||
polc policyclient.Client // always non-nil
|
||||
metrics metrics
|
||||
@ -525,7 +530,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
|
||||
backendLogID: logID,
|
||||
state: ipn.NoState,
|
||||
portpoll: new(portlist.Poller),
|
||||
em: newExpiryManager(logf),
|
||||
em: newExpiryManager(logf, sys.Bus.Get()),
|
||||
loginFlags: loginFlags,
|
||||
clock: clock,
|
||||
selfUpdateProgress: make([]ipnstate.UpdateProgress, 0),
|
||||
@ -533,7 +538,11 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
|
||||
captiveCtx: captiveCtx,
|
||||
captiveCancel: nil, // so that we start checkCaptivePortalLoop when Running
|
||||
needsCaptiveDetection: make(chan bool),
|
||||
subsDoneCh: make(chan struct{}),
|
||||
}
|
||||
b.eventClient = b.Sys().Bus.Get().Client("ipnlocal.LocalBackend")
|
||||
b.clientVersionSub = eventbus.Subscribe[tailcfg.ClientVersion](b.eventClient)
|
||||
b.autoUpdateSub = eventbus.Subscribe[controlclient.AutoUpdate](b.eventClient)
|
||||
nb := newNodeBackend(ctx, b.sys.Bus.Get())
|
||||
b.currentNodeAtomic.Store(nb)
|
||||
nb.ready()
|
||||
@ -604,9 +613,32 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
|
||||
}
|
||||
}
|
||||
}
|
||||
go b.consumeEventbusTopics()
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// consumeEventbusTopics consumes events from all relevant
|
||||
// [eventbus.Subscriber]'s and passes them to their related handler. Events are
|
||||
// always handled in the order they are received, i.e. the next event is not
|
||||
// read until the previous event's handler has returned. It returns when the
|
||||
// [tailcfg.ClientVersion] subscriber is closed, which is interpreted to be the
|
||||
// same as the [eventbus.Client] closing ([eventbus.Subscribers] are either
|
||||
// all open or all closed).
|
||||
func (b *LocalBackend) consumeEventbusTopics() {
|
||||
defer close(b.subsDoneCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.clientVersionSub.Done():
|
||||
return
|
||||
case clientVersion := <-b.clientVersionSub.Events():
|
||||
b.onClientVersion(&clientVersion)
|
||||
case au := <-b.autoUpdateSub.Events():
|
||||
b.onTailnetDefaultAutoUpdate(au.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LocalBackend) Clock() tstime.Clock { return b.clock }
|
||||
func (b *LocalBackend) Sys() *tsd.System { return b.sys }
|
||||
|
||||
@ -1065,6 +1097,17 @@ func (b *LocalBackend) ClearCaptureSink() {
|
||||
// Shutdown halts the backend and all its sub-components. The backend
|
||||
// can no longer be used after Shutdown returns.
|
||||
func (b *LocalBackend) Shutdown() {
|
||||
// Close the [eventbus.Client] and wait for LocalBackend.consumeEventbusTopics
|
||||
// to return. Do this before acquiring b.mu:
|
||||
// 1. LocalBackend.consumeEventbusTopics event handlers also acquire b.mu,
|
||||
// they can deadlock with c.Shutdown().
|
||||
// 2. LocalBackend.consumeEventbusTopics event handlers may not guard against
|
||||
// undesirable post/in-progress LocalBackend.Shutdown() behaviors.
|
||||
b.eventClient.Close()
|
||||
<-b.subsDoneCh
|
||||
|
||||
b.em.close()
|
||||
|
||||
b.mu.Lock()
|
||||
if b.shutdownCalled {
|
||||
b.mu.Unlock()
|
||||
@ -2465,33 +2508,32 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(apenwarr): The only way to change the ServerURL is to
|
||||
// re-run b.Start, because this is the only place we create a
|
||||
// new controlclient. EditPrefs allows you to overwrite ServerURL,
|
||||
// but it won't take effect until the next Start.
|
||||
cc, err := b.getNewControlClientFuncLocked()(controlclient.Options{
|
||||
GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(),
|
||||
Logf: logger.WithPrefix(b.logf, "control: "),
|
||||
Persist: *persistv,
|
||||
ServerURL: serverURL,
|
||||
AuthKey: opts.AuthKey,
|
||||
Hostinfo: hostinfo,
|
||||
HTTPTestClient: httpTestClient,
|
||||
DiscoPublicKey: discoPublic,
|
||||
DebugFlags: debugFlags,
|
||||
HealthTracker: b.health,
|
||||
PolicyClient: b.sys.PolicyClientOrDefault(),
|
||||
Pinger: b,
|
||||
PopBrowserURL: b.tellClientToBrowseToURL,
|
||||
OnClientVersion: b.onClientVersion,
|
||||
OnTailnetDefaultAutoUpdate: b.onTailnetDefaultAutoUpdate,
|
||||
OnControlTime: b.em.onControlTime,
|
||||
Dialer: b.Dialer(),
|
||||
Observer: b,
|
||||
C2NHandler: http.HandlerFunc(b.handleC2N),
|
||||
DialPlan: &b.dialPlan, // pointer because it can't be copied
|
||||
ControlKnobs: b.sys.ControlKnobs(),
|
||||
Shutdown: ccShutdown,
|
||||
GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(),
|
||||
Logf: logger.WithPrefix(b.logf, "control: "),
|
||||
Persist: *persistv,
|
||||
ServerURL: serverURL,
|
||||
AuthKey: opts.AuthKey,
|
||||
Hostinfo: hostinfo,
|
||||
HTTPTestClient: httpTestClient,
|
||||
DiscoPublicKey: discoPublic,
|
||||
DebugFlags: debugFlags,
|
||||
HealthTracker: b.health,
|
||||
PolicyClient: b.sys.PolicyClientOrDefault(),
|
||||
Pinger: b,
|
||||
PopBrowserURL: b.tellClientToBrowseToURL,
|
||||
Dialer: b.Dialer(),
|
||||
Observer: b,
|
||||
C2NHandler: http.HandlerFunc(b.handleC2N),
|
||||
DialPlan: &b.dialPlan, // pointer because it can't be copied
|
||||
ControlKnobs: b.sys.ControlKnobs(),
|
||||
Shutdown: ccShutdown,
|
||||
Bus: b.sys.Bus.Get(),
|
||||
|
||||
// Don't warn about broken Linux IP forwarding when
|
||||
// netstack is being used.
|
||||
@ -4482,7 +4524,6 @@ func (b *LocalBackend) changeDisablesExitNodeLocked(prefs ipn.PrefsView, change
|
||||
// but wasn't empty before, then the change disables
|
||||
// exit node usage.
|
||||
return tmpPrefs.ExitNodeID == ""
|
||||
|
||||
}
|
||||
|
||||
// adjustEditPrefsLocked applies additional changes to mp if necessary,
|
||||
@ -8001,7 +8042,6 @@ func isAllowedAutoExitNodeID(polc policyclient.Client, exitNodeID tailcfg.Stable
|
||||
}
|
||||
if nodes, _ := polc.GetStringArray(pkey.AllowedSuggestedExitNodes, nil); nodes != nil {
|
||||
return slices.Contains(nodes, string(exitNodeID))
|
||||
|
||||
}
|
||||
return true // no policy configured; allow all exit nodes
|
||||
}
|
||||
@ -8145,9 +8185,7 @@ func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcf
|
||||
return servicesList
|
||||
}
|
||||
|
||||
var (
|
||||
metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus")
|
||||
)
|
||||
var metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus")
|
||||
|
||||
func (b *LocalBackend) stateEncrypted() opt.Bool {
|
||||
switch runtime.GOOS {
|
||||
|
@ -59,6 +59,7 @@ import (
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/eventbus"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/set"
|
||||
@ -455,7 +456,8 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func newTestLocalBackend(t testing.TB) *LocalBackend {
|
||||
return newTestLocalBackendWithSys(t, tsd.NewSystem())
|
||||
bus := eventbustest.NewBus(t)
|
||||
return newTestLocalBackendWithSys(t, tsd.NewSystemWithBus(bus))
|
||||
}
|
||||
|
||||
// newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System.
|
||||
@ -533,7 +535,6 @@ func TestZeroExitNodeViaLocalAPI(t *testing.T) {
|
||||
ExitNodeID: "",
|
||||
},
|
||||
}, user)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("enabling first exit node: %v", err)
|
||||
}
|
||||
@ -543,7 +544,6 @@ func TestZeroExitNodeViaLocalAPI(t *testing.T) {
|
||||
if got, want := pv.InternalExitNodePrior(), tailcfg.StableNodeID(""); got != want {
|
||||
t.Fatalf("unexpected InternalExitNodePrior %q, want: %q", got, want)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSetUseExitNodeEnabled(t *testing.T) {
|
||||
@ -3619,7 +3619,8 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
||||
prefs := defaultPrefs.AsStruct()
|
||||
pp.set(prefs, tt.initialValue)
|
||||
|
||||
sys := tsd.NewSystem()
|
||||
bus := eventbustest.NewBus(t)
|
||||
sys := tsd.NewSystemWithBus(bus)
|
||||
sys.PolicyClient.Set(polc)
|
||||
|
||||
lb := newTestLocalBackendWithSys(t, sys)
|
||||
@ -5786,7 +5787,8 @@ func TestNotificationTargetMatch(t *testing.T) {
|
||||
type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client
|
||||
|
||||
func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl newTestControlFn) *LocalBackend {
|
||||
return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystem(), newControl)
|
||||
bus := eventbustest.NewBus(t)
|
||||
return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystemWithBus(bus), newControl)
|
||||
}
|
||||
|
||||
func newLocalBackendWithSysAndTestControl(t *testing.T, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend {
|
||||
@ -5945,7 +5947,6 @@ func (w *notificationWatcher) watch(mask ipn.NotifyWatchOpt, wanted []wantedNoti
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
}()
|
||||
<-watchAddedCh
|
||||
}
|
||||
|
@ -35,6 +35,7 @@ import (
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/types/persist"
|
||||
"tailscale.com/types/tkatype"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
@ -49,6 +50,7 @@ func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto {
|
||||
hi := hostinfo.New()
|
||||
ni := tailcfg.NetInfo{LinkType: "wired"}
|
||||
hi.NetInfo = &ni
|
||||
bus := eventbustest.NewBus(t)
|
||||
|
||||
k := key.NewMachine()
|
||||
opts := controlclient.Options{
|
||||
@ -61,6 +63,7 @@ func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto {
|
||||
NoiseTestClient: c,
|
||||
Observer: observerFunc(func(controlclient.Status) {}),
|
||||
Dialer: tsdial.NewDialer(netmon.NewStatic()),
|
||||
Bus: bus,
|
||||
}
|
||||
|
||||
cc, err := controlclient.NewNoStart(opts)
|
||||
|
@ -33,6 +33,7 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/logid"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/syspolicy/policyclient"
|
||||
@ -240,11 +241,15 @@ func TestServeConfigForeground(t *testing.T) {
|
||||
|
||||
err := b.SetServeConfig(&ipn.ServeConfig{
|
||||
Foreground: map[string]*ipn.ServeConfig{
|
||||
session1: {TCP: map[uint16]*ipn.TCPPortHandler{
|
||||
443: {TCPForward: "http://localhost:3000"}},
|
||||
session1: {
|
||||
TCP: map[uint16]*ipn.TCPPortHandler{
|
||||
443: {TCPForward: "http://localhost:3000"},
|
||||
},
|
||||
},
|
||||
session2: {TCP: map[uint16]*ipn.TCPPortHandler{
|
||||
999: {TCPForward: "http://localhost:4000"}},
|
||||
session2: {
|
||||
TCP: map[uint16]*ipn.TCPPortHandler{
|
||||
999: {TCPForward: "http://localhost:4000"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, "")
|
||||
@ -267,8 +272,10 @@ func TestServeConfigForeground(t *testing.T) {
|
||||
5000: {TCPForward: "http://localhost:5000"},
|
||||
},
|
||||
Foreground: map[string]*ipn.ServeConfig{
|
||||
session2: {TCP: map[uint16]*ipn.TCPPortHandler{
|
||||
999: {TCPForward: "http://localhost:4000"}},
|
||||
session2: {
|
||||
TCP: map[uint16]*ipn.TCPPortHandler{
|
||||
999: {TCPForward: "http://localhost:4000"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, "")
|
||||
@ -491,7 +498,6 @@ func TestServeConfigServices(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestServeConfigETag(t *testing.T) {
|
||||
@ -659,6 +665,7 @@ func TestServeHTTPProxyPath(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTPProxyHeaders(t *testing.T) {
|
||||
b := newTestBackend(t)
|
||||
|
||||
@ -859,7 +866,6 @@ func Test_reverseProxyConfiguration(t *testing.T) {
|
||||
wantsURL: mustCreateURL(t, "https://example3.com"),
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func mustCreateURL(t *testing.T, u string) url.URL {
|
||||
@ -878,7 +884,8 @@ func newTestBackend(t *testing.T, opts ...any) *LocalBackend {
|
||||
logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ")
|
||||
}
|
||||
|
||||
sys := tsd.NewSystem()
|
||||
bus := eventbustest.NewBus(t)
|
||||
sys := tsd.NewSystemWithBus(bus)
|
||||
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
@ -952,13 +959,13 @@ func newTestBackend(t *testing.T, opts ...any) *LocalBackend {
|
||||
func TestServeFileOrDirectory(t *testing.T) {
|
||||
td := t.TempDir()
|
||||
writeFile := func(suffix, contents string) {
|
||||
if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0600); err != nil {
|
||||
if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
writeFile("foo", "this is foo")
|
||||
writeFile("bar", "this is bar")
|
||||
os.MkdirAll(filepath.Join(td, "subdir"), 0700)
|
||||
os.MkdirAll(filepath.Join(td, "subdir"), 0o700)
|
||||
writeFile("subdir/file-a", "this is A")
|
||||
writeFile("subdir/file-b", "this is B")
|
||||
writeFile("subdir/file-c", "this is C")
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -39,6 +40,7 @@ import (
|
||||
"tailscale.com/types/persist"
|
||||
"tailscale.com/types/preftype"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/wgengine"
|
||||
@ -113,10 +115,11 @@ func (nt *notifyThrottler) drain(count int) []ipn.Notify {
|
||||
// in the controlclient.Client, so by controlling it, we can check that
|
||||
// the state machine works as expected.
|
||||
type mockControl struct {
|
||||
tb testing.TB
|
||||
logf logger.Logf
|
||||
opts controlclient.Options
|
||||
paused atomic.Bool
|
||||
tb testing.TB
|
||||
logf logger.Logf
|
||||
opts controlclient.Options
|
||||
paused atomic.Bool
|
||||
controlClientID int64
|
||||
|
||||
mu sync.Mutex
|
||||
persist *persist.Persist
|
||||
@ -127,12 +130,13 @@ type mockControl struct {
|
||||
|
||||
func newClient(tb testing.TB, opts controlclient.Options) *mockControl {
|
||||
return &mockControl{
|
||||
tb: tb,
|
||||
authBlocked: true,
|
||||
logf: opts.Logf,
|
||||
opts: opts,
|
||||
shutdown: make(chan struct{}),
|
||||
persist: opts.Persist.Clone(),
|
||||
tb: tb,
|
||||
authBlocked: true,
|
||||
logf: opts.Logf,
|
||||
opts: opts,
|
||||
shutdown: make(chan struct{}),
|
||||
persist: opts.Persist.Clone(),
|
||||
controlClientID: rand.Int64(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -287,6 +291,10 @@ func (cc *mockControl) UpdateEndpoints(endpoints []tailcfg.Endpoint) {
|
||||
cc.called("UpdateEndpoints")
|
||||
}
|
||||
|
||||
func (cc *mockControl) ClientID() int64 {
|
||||
return cc.controlClientID
|
||||
}
|
||||
|
||||
func (b *LocalBackend) nonInteractiveLoginForStateTest() {
|
||||
b.mu.Lock()
|
||||
if b.cc == nil {
|
||||
@ -1507,7 +1515,8 @@ func newLocalBackendWithMockEngineAndControl(t *testing.T, enableLogging bool) (
|
||||
dialer := &tsdial.Dialer{Logf: logf}
|
||||
dialer.SetNetMon(netmon.NewStatic())
|
||||
|
||||
sys := tsd.NewSystem()
|
||||
bus := eventbustest.NewBus(t)
|
||||
sys := tsd.NewSystemWithBus(bus)
|
||||
sys.Set(dialer)
|
||||
sys.Set(dialer.NetMon())
|
||||
|
||||
|
@ -35,6 +35,7 @@ import (
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/logid"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
"tailscale.com/util/slicesx"
|
||||
"tailscale.com/wgengine"
|
||||
)
|
||||
@ -158,7 +159,6 @@ func TestWhoIsArgTypes(t *testing.T) {
|
||||
t.Fatalf("backend called with %v; want %v", k, keyStr)
|
||||
}
|
||||
return match()
|
||||
|
||||
},
|
||||
peerCaps: map[netip.Addr]tailcfg.PeerCapMap{
|
||||
netip.MustParseAddr("100.101.102.103"): map[tailcfg.PeerCapability][]tailcfg.RawMessage{
|
||||
@ -336,7 +336,7 @@ func TestServeWatchIPNBus(t *testing.T) {
|
||||
|
||||
func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend {
|
||||
var logf logger.Logf = logger.Discard
|
||||
sys := tsd.NewSystem()
|
||||
sys := tsd.NewSystemWithBus(eventbustest.NewBus(t))
|
||||
store := new(mem.Store)
|
||||
sys.Set(store)
|
||||
eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get())
|
||||
|
12
tsd/tsd.go
12
tsd/tsd.go
@ -80,9 +80,17 @@ type System struct {
|
||||
|
||||
// NewSystem constructs a new otherwise-empty [System] with a
|
||||
// freshly-constructed event bus populated.
|
||||
func NewSystem() *System {
|
||||
func NewSystem() *System { return NewSystemWithBus(eventbus.New()) }
|
||||
|
||||
// NewSystemWithBus constructs a new otherwise-empty [System] with an
|
||||
// eventbus provided by the caller. The provided bus must not be nil.
|
||||
// This is mainly intended for testing; for production use call [NewBus].
|
||||
func NewSystemWithBus(bus *eventbus.Bus) *System {
|
||||
if bus == nil {
|
||||
panic("nil eventbus")
|
||||
}
|
||||
sys := new(System)
|
||||
sys.Set(eventbus.New())
|
||||
sys.Set(bus)
|
||||
return sys
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
|
||||
// NewBus constructs an [eventbus.Bus] that will be shut automatically when
|
||||
// its controlling test ends.
|
||||
func NewBus(t *testing.T) *eventbus.Bus {
|
||||
func NewBus(t testing.TB) *eventbus.Bus {
|
||||
bus := eventbus.New()
|
||||
t.Cleanup(bus.Close)
|
||||
return bus
|
||||
|
Loading…
x
Reference in New Issue
Block a user