diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 91d07484c..95bf29a11 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -82,11 +82,11 @@ type extension struct { logf logger.Logf bus *eventbus.Bus - mu sync.Mutex // guards the following fields - shutdown bool + mu sync.Mutex // guards the following fields + shutdown bool + port *int // ipn.Prefs.RelayServerPort, nil if disabled - disconnectFromBusCh chan struct{} // non-nil if consumeEventbusTopics is running, closed to signal it to return - busDoneCh chan struct{} // non-nil if consumeEventbusTopics is running, closed when it returns + eventSubs *eventbus.Monitor // nil if not connected to eventbus debugSessionsCh chan chan []status.ServerSession // non-nil if consumeEventbusTopics is running hasNodeAttrDisableRelayServer bool // tailcfg.NodeAttrDisableRelayServer } @@ -119,15 +119,13 @@ func (e *extension) handleBusLifetimeLocked() { if !busShouldBeRunning { e.disconnectFromBusLocked() return - } - if e.busDoneCh != nil { + } else if e.eventSubs != nil { return // already running } - port := *e.port - e.disconnectFromBusCh = make(chan struct{}) - e.busDoneCh = make(chan struct{}) + + ec := e.bus.Client("relayserver.extension") e.debugSessionsCh = make(chan chan []status.ServerSession) - go e.consumeEventbusTopics(port) + e.eventSubs = ptr.To(ec.Monitor(e.consumeEventbusTopics(ec, *e.port))) } func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { @@ -175,77 +173,72 @@ var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) { // consumeEventbusTopics serves endpoint allocation requests over the eventbus. // It also serves [relayServer] debug information on a channel. -// consumeEventbusTopics must never acquire [extension.mu], which can be held by -// other goroutines while waiting to receive on [extension.busDoneCh] or the +// consumeEventbusTopics must never acquire [extension.mu], which can be held +// by other goroutines while waiting to receive on [extension.eventSubs] or the // inner [extension.debugSessionsCh] channel. -func (e *extension) consumeEventbusTopics(port int) { - defer close(e.busDoneCh) +func (e *extension) consumeEventbusTopics(ec *eventbus.Client, port int) func(*eventbus.Client) { + reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](ec) + respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](ec) + debugSessionsCh := e.debugSessionsCh - eventClient := e.bus.Client("relayserver.extension") - reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](eventClient) - respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](eventClient) - defer eventClient.Close() - - var rs relayServer // lazily initialized - defer func() { - if rs != nil { - rs.Close() - } - }() - for { - select { - case <-e.disconnectFromBusCh: - return - case <-eventClient.Done(): - return - case respCh := <-e.debugSessionsCh: - if rs == nil { - // Don't initialize the server simply for a debug request. - respCh <- nil - continue + return func(ec *eventbus.Client) { + var rs relayServer // lazily initialized + defer func() { + if rs != nil { + rs.Close() } - sessions := rs.GetSessions() - respCh <- sessions - case req := <-reqSub.Events(): - if rs == nil { - var err error - rs, err = udprelay.NewServer(e.logf, port, overrideAddrs()) - if err != nil { - e.logf("error initializing server: %v", err) + }() + for { + select { + case <-ec.Done(): + return + case respCh := <-debugSessionsCh: + if rs == nil { + // Don't initialize the server simply for a debug request. + respCh <- nil continue } - } - se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) - if err != nil { - e.logf("error allocating endpoint: %v", err) - continue - } - respPub.Publish(magicsock.UDPRelayAllocResp{ - ReqRxFromNodeKey: req.RxFromNodeKey, - ReqRxFromDiscoKey: req.RxFromDiscoKey, - Message: &disco.AllocateUDPRelayEndpointResponse{ - Generation: req.Message.Generation, - UDPRelayEndpoint: disco.UDPRelayEndpoint{ - ServerDisco: se.ServerDisco, - ClientDisco: se.ClientDisco, - LamportID: se.LamportID, - VNI: se.VNI, - BindLifetime: se.BindLifetime.Duration, - SteadyStateLifetime: se.SteadyStateLifetime.Duration, - AddrPorts: se.AddrPorts, + sessions := rs.GetSessions() + respCh <- sessions + case req := <-reqSub.Events(): + if rs == nil { + var err error + rs, err = udprelay.NewServer(e.logf, port, overrideAddrs()) + if err != nil { + e.logf("error initializing server: %v", err) + continue + } + } + se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) + if err != nil { + e.logf("error allocating endpoint: %v", err) + continue + } + respPub.Publish(magicsock.UDPRelayAllocResp{ + ReqRxFromNodeKey: req.RxFromNodeKey, + ReqRxFromDiscoKey: req.RxFromDiscoKey, + Message: &disco.AllocateUDPRelayEndpointResponse{ + Generation: req.Message.Generation, + UDPRelayEndpoint: disco.UDPRelayEndpoint{ + ServerDisco: se.ServerDisco, + ClientDisco: se.ClientDisco, + LamportID: se.LamportID, + VNI: se.VNI, + BindLifetime: se.BindLifetime.Duration, + SteadyStateLifetime: se.SteadyStateLifetime.Duration, + AddrPorts: se.AddrPorts, + }, }, - }, - }) + }) + } } } } func (e *extension) disconnectFromBusLocked() { - if e.busDoneCh != nil { - close(e.disconnectFromBusCh) - <-e.busDoneCh - e.busDoneCh = nil - e.disconnectFromBusCh = nil + if e.eventSubs != nil { + e.eventSubs.Close() + e.eventSubs = nil e.debugSessionsCh = nil } } @@ -270,7 +263,7 @@ func (e *extension) serverStatus() status.ServerStatus { UDPPort: nil, Sessions: nil, } - if e.port == nil || e.busDoneCh == nil { + if e.port == nil || e.eventSubs == nil { return st } st.UDPPort = ptr.To(*e.port) @@ -281,7 +274,7 @@ func (e *extension) serverStatus() status.ServerStatus { resp := <-ch st.Sessions = resp return st - case <-e.busDoneCh: + case <-e.eventSubs.Done(): return st } } diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index d3fc36a83..89c004dc7 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -101,8 +101,8 @@ func Test_extension_profileStateChanged(t *testing.T) { } defer e.disconnectFromBusLocked() e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) - if tt.wantBusRunning != (e.busDoneCh != nil) { - t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil) + if tt.wantBusRunning != (e.eventSubs != nil) { + t.Errorf("wantBusRunning: %v != (e.eventSubs != nil): %v", tt.wantBusRunning, e.eventSubs != nil) } if (tt.wantPort == nil) != (e.port == nil) { t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil) @@ -118,7 +118,7 @@ func Test_extension_handleBusLifetimeLocked(t *testing.T) { name string shutdown bool port *int - busDoneCh chan struct{} + eventSubs *eventbus.Monitor hasNodeAttrDisableRelayServer bool wantBusRunning bool }{ @@ -157,13 +157,13 @@ func Test_extension_handleBusLifetimeLocked(t *testing.T) { bus: eventbus.New(), shutdown: tt.shutdown, port: tt.port, - busDoneCh: tt.busDoneCh, + eventSubs: tt.eventSubs, hasNodeAttrDisableRelayServer: tt.hasNodeAttrDisableRelayServer, } e.handleBusLifetimeLocked() defer e.disconnectFromBusLocked() - if tt.wantBusRunning != (e.busDoneCh != nil) { - t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil) + if tt.wantBusRunning != (e.eventSubs != nil) { + t.Errorf("wantBusRunning: %v != (e.eventSubs != nil): %v", tt.wantBusRunning, e.eventSubs != nil) } }) }