From 3c32f87624ca2cbe384dc4b7a2e3b1925c672e5d Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Thu, 2 Oct 2025 09:18:55 -0700 Subject: [PATCH] feature/relayserver: use eventbus.Monitor to simplify lifecycle management (#17234) Instead of using separate channels to manage the lifecycle of the eventbus client, use the recently-added eventbus.Monitor, which handles signaling the processing loop to stop and waiting for it to complete. This allows us to simplify some of the setup and cleanup code in the relay server. Updates #15160 Change-Id: Ia1a47ce2e5a31bc8f546dca4c56c3141a40d67af Signed-off-by: M. J. Fromberger --- feature/relayserver/relayserver.go | 137 +++++++++++------------- feature/relayserver/relayserver_test.go | 12 +-- 2 files changed, 71 insertions(+), 78 deletions(-) diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 91d07484c..95bf29a11 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -82,11 +82,11 @@ type extension struct { logf logger.Logf bus *eventbus.Bus - mu sync.Mutex // guards the following fields - shutdown bool + mu sync.Mutex // guards the following fields + shutdown bool + port *int // ipn.Prefs.RelayServerPort, nil if disabled - disconnectFromBusCh chan struct{} // non-nil if consumeEventbusTopics is running, closed to signal it to return - busDoneCh chan struct{} // non-nil if consumeEventbusTopics is running, closed when it returns + eventSubs *eventbus.Monitor // nil if not connected to eventbus debugSessionsCh chan chan []status.ServerSession // non-nil if consumeEventbusTopics is running hasNodeAttrDisableRelayServer bool // tailcfg.NodeAttrDisableRelayServer } @@ -119,15 +119,13 @@ func (e *extension) handleBusLifetimeLocked() { if !busShouldBeRunning { e.disconnectFromBusLocked() return - } - if e.busDoneCh != nil { + } else if e.eventSubs != nil { return // already running } - port := *e.port - e.disconnectFromBusCh = make(chan struct{}) - e.busDoneCh = make(chan struct{}) + + ec := e.bus.Client("relayserver.extension") e.debugSessionsCh = make(chan chan []status.ServerSession) - go e.consumeEventbusTopics(port) + e.eventSubs = ptr.To(ec.Monitor(e.consumeEventbusTopics(ec, *e.port))) } func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { @@ -175,77 +173,72 @@ var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) { // consumeEventbusTopics serves endpoint allocation requests over the eventbus. // It also serves [relayServer] debug information on a channel. -// consumeEventbusTopics must never acquire [extension.mu], which can be held by -// other goroutines while waiting to receive on [extension.busDoneCh] or the +// consumeEventbusTopics must never acquire [extension.mu], which can be held +// by other goroutines while waiting to receive on [extension.eventSubs] or the // inner [extension.debugSessionsCh] channel. -func (e *extension) consumeEventbusTopics(port int) { - defer close(e.busDoneCh) +func (e *extension) consumeEventbusTopics(ec *eventbus.Client, port int) func(*eventbus.Client) { + reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](ec) + respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](ec) + debugSessionsCh := e.debugSessionsCh - eventClient := e.bus.Client("relayserver.extension") - reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](eventClient) - respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](eventClient) - defer eventClient.Close() - - var rs relayServer // lazily initialized - defer func() { - if rs != nil { - rs.Close() - } - }() - for { - select { - case <-e.disconnectFromBusCh: - return - case <-eventClient.Done(): - return - case respCh := <-e.debugSessionsCh: - if rs == nil { - // Don't initialize the server simply for a debug request. - respCh <- nil - continue + return func(ec *eventbus.Client) { + var rs relayServer // lazily initialized + defer func() { + if rs != nil { + rs.Close() } - sessions := rs.GetSessions() - respCh <- sessions - case req := <-reqSub.Events(): - if rs == nil { - var err error - rs, err = udprelay.NewServer(e.logf, port, overrideAddrs()) - if err != nil { - e.logf("error initializing server: %v", err) + }() + for { + select { + case <-ec.Done(): + return + case respCh := <-debugSessionsCh: + if rs == nil { + // Don't initialize the server simply for a debug request. + respCh <- nil continue } - } - se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) - if err != nil { - e.logf("error allocating endpoint: %v", err) - continue - } - respPub.Publish(magicsock.UDPRelayAllocResp{ - ReqRxFromNodeKey: req.RxFromNodeKey, - ReqRxFromDiscoKey: req.RxFromDiscoKey, - Message: &disco.AllocateUDPRelayEndpointResponse{ - Generation: req.Message.Generation, - UDPRelayEndpoint: disco.UDPRelayEndpoint{ - ServerDisco: se.ServerDisco, - ClientDisco: se.ClientDisco, - LamportID: se.LamportID, - VNI: se.VNI, - BindLifetime: se.BindLifetime.Duration, - SteadyStateLifetime: se.SteadyStateLifetime.Duration, - AddrPorts: se.AddrPorts, + sessions := rs.GetSessions() + respCh <- sessions + case req := <-reqSub.Events(): + if rs == nil { + var err error + rs, err = udprelay.NewServer(e.logf, port, overrideAddrs()) + if err != nil { + e.logf("error initializing server: %v", err) + continue + } + } + se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) + if err != nil { + e.logf("error allocating endpoint: %v", err) + continue + } + respPub.Publish(magicsock.UDPRelayAllocResp{ + ReqRxFromNodeKey: req.RxFromNodeKey, + ReqRxFromDiscoKey: req.RxFromDiscoKey, + Message: &disco.AllocateUDPRelayEndpointResponse{ + Generation: req.Message.Generation, + UDPRelayEndpoint: disco.UDPRelayEndpoint{ + ServerDisco: se.ServerDisco, + ClientDisco: se.ClientDisco, + LamportID: se.LamportID, + VNI: se.VNI, + BindLifetime: se.BindLifetime.Duration, + SteadyStateLifetime: se.SteadyStateLifetime.Duration, + AddrPorts: se.AddrPorts, + }, }, - }, - }) + }) + } } } } func (e *extension) disconnectFromBusLocked() { - if e.busDoneCh != nil { - close(e.disconnectFromBusCh) - <-e.busDoneCh - e.busDoneCh = nil - e.disconnectFromBusCh = nil + if e.eventSubs != nil { + e.eventSubs.Close() + e.eventSubs = nil e.debugSessionsCh = nil } } @@ -270,7 +263,7 @@ func (e *extension) serverStatus() status.ServerStatus { UDPPort: nil, Sessions: nil, } - if e.port == nil || e.busDoneCh == nil { + if e.port == nil || e.eventSubs == nil { return st } st.UDPPort = ptr.To(*e.port) @@ -281,7 +274,7 @@ func (e *extension) serverStatus() status.ServerStatus { resp := <-ch st.Sessions = resp return st - case <-e.busDoneCh: + case <-e.eventSubs.Done(): return st } } diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index d3fc36a83..89c004dc7 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -101,8 +101,8 @@ func Test_extension_profileStateChanged(t *testing.T) { } defer e.disconnectFromBusLocked() e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) - if tt.wantBusRunning != (e.busDoneCh != nil) { - t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil) + if tt.wantBusRunning != (e.eventSubs != nil) { + t.Errorf("wantBusRunning: %v != (e.eventSubs != nil): %v", tt.wantBusRunning, e.eventSubs != nil) } if (tt.wantPort == nil) != (e.port == nil) { t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil) @@ -118,7 +118,7 @@ func Test_extension_handleBusLifetimeLocked(t *testing.T) { name string shutdown bool port *int - busDoneCh chan struct{} + eventSubs *eventbus.Monitor hasNodeAttrDisableRelayServer bool wantBusRunning bool }{ @@ -157,13 +157,13 @@ func Test_extension_handleBusLifetimeLocked(t *testing.T) { bus: eventbus.New(), shutdown: tt.shutdown, port: tt.port, - busDoneCh: tt.busDoneCh, + eventSubs: tt.eventSubs, hasNodeAttrDisableRelayServer: tt.hasNodeAttrDisableRelayServer, } e.handleBusLifetimeLocked() defer e.disconnectFromBusLocked() - if tt.wantBusRunning != (e.busDoneCh != nil) { - t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil) + if tt.wantBusRunning != (e.eventSubs != nil) { + t.Errorf("wantBusRunning: %v != (e.eventSubs != nil): %v", tt.wantBusRunning, e.eventSubs != nil) } }) }