diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt
index 9e6f24419..4e2215aec 100644
--- a/cmd/k8s-operator/depaware.txt
+++ b/cmd/k8s-operator/depaware.txt
@@ -872,6 +872,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
   tailscale.com/net/tsdial from tailscale.com/control/controlclient+
 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+
   tailscale.com/net/tstun from tailscale.com/tsd+
+  tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock
   tailscale.com/omit from tailscale.com/ipn/conffile
   tailscale.com/paths from tailscale.com/client/local+
 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal
diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt
index 4c9c6831e..f5cd1232d 100644
--- a/tsnet/depaware.txt
+++ b/tsnet/depaware.txt
@@ -303,6 +303,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware)
   tailscale.com/net/tsdial from tailscale.com/control/controlclient+
 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+
   tailscale.com/net/tstun from tailscale.com/tsd+
+  tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock
   tailscale.com/omit from tailscale.com/ipn/conffile
   tailscale.com/paths from tailscale.com/client/local+
 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal
diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go
index 5f4f0bd8c..f88dab29d 100644
--- a/wgengine/magicsock/endpoint.go
+++ b/wgengine/magicsock/endpoint.go
@@ -95,6 +95,7 @@ type endpoint struct {
 	expired         bool // whether the node has expired
 	isWireguardOnly bool // whether the endpoint is WireGuard only
+	relayCapable    bool // whether the node is capable of speaking via a [tailscale.com/net/udprelay.Server]
 }
 
 func (de *endpoint) setBestAddrLocked(v addrQuality) {
@@ -1249,6 +1250,13 @@ func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) {
 		// sent so our firewall ports are probably open and now
 		// would be a good time for them to connect.
 		go de.c.enqueueCallMeMaybe(derpAddr, de)
+
+		// Schedule allocation of relay endpoints. For now we keep it simple
+		// and take no account of existing relay endpoints or the state of
+		// the best UDP path.
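+		// Note: the goroutine hand-off below matters, because
+		// relayManagerInputEvent can block until runLoop() consumes the
+		// event, and de.mu is held here.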
+		if de.relayCapable {
+			go de.c.relayManager.allocateAndHandshakeAllServers(de)
+		}
 	}
 }
 
@@ -1863,6 +1871,7 @@ func (de *endpoint) resetLocked() {
 		}
 	}
 	de.probeUDPLifetime.resetCycleEndpointLocked()
+	de.c.relayManager.cancelOutstandingWork(de)
 }
 
 func (de *endpoint) numStopAndReset() int64 {
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index 7df46f76c..cf3ef2352 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -1939,6 +1939,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
 			c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString())
 			return
 		}
+		ep.mu.Lock()
+		relayCapable := ep.relayCapable
+		ep.mu.Unlock()
+		if isVia && !relayCapable {
+			c.logf("magicsock: disco: ignoring %s from %v; %v is not known to be relay capable", msgType, sender.ShortString(), derpNodeSrc.ShortString())
+			return
+		}
 		epDisco := ep.disco.Load()
 		if epDisco == nil {
 			return
diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go
index bf737b078..b1732ff41 100644
--- a/wgengine/magicsock/relaymanager.go
+++ b/wgengine/magicsock/relaymanager.go
@@ -4,48 +4,283 @@ package magicsock
 
 import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
 	"net/netip"
 	"sync"
+	"time"
 
 	"tailscale.com/disco"
+	udprelay "tailscale.com/net/udprelay/endpoint"
 	"tailscale.com/types/key"
+	"tailscale.com/util/httpm"
+	"tailscale.com/util/set"
 )
 
 // relayManager manages allocation and handshaking of
 // [tailscale.com/net/udprelay.Server] endpoints. The zero value is ready for
 // use.
 type relayManager struct {
-	mu                     sync.Mutex // guards the following fields
+	initOnce sync.Once
+
+	// ===================================================================
+	// The following fields are owned by a single goroutine, runLoop().
+	serversByAddrPort   set.Set[netip.AddrPort]
+	allocWorkByEndpoint map[*endpoint]*relayEndpointAllocWork
+
+	// ===================================================================
+	// The following chan fields provide event input to a single goroutine,
+	// runLoop().
+	allocateHandshakeCh chan *endpoint
+	allocateWorkDoneCh  chan relayEndpointAllocWorkDoneEvent
+	cancelWorkCh        chan *endpoint
+	newServerEndpointCh chan newRelayServerEndpointEvent
+	rxChallengeCh       chan relayHandshakeChallengeEvent
+	rxCallMeMaybeViaCh  chan *disco.CallMeMaybeVia
+
+	discoInfoMu            sync.Mutex // guards the following field
 	discoInfoByServerDisco map[key.DiscoPublic]*discoInfo
+
+	// runLoopStoppedCh is written to by runLoop() upon return, enabling event
+	// writers to restart it when they are blocked (see
+	// relayManagerInputEvent()).
+	runLoopStoppedCh chan struct{}
 }
 
-func (h *relayManager) initLocked() {
-	if h.discoInfoByServerDisco != nil {
-		return
+type newRelayServerEndpointEvent struct {
+	ep *endpoint
+	se udprelay.ServerEndpoint
+}
+
+type relayEndpointAllocWorkDoneEvent struct {
+	ep   *endpoint
+	work *relayEndpointAllocWork
+}
+
+// activeWork returns true if there is outstanding allocation or handshaking
+// work, otherwise it returns false.
+func (r *relayManager) activeWork() bool {
+	return len(r.allocWorkByEndpoint) > 0
+	// TODO(jwhited): consider handshaking work
+}
+
+// runLoop is a form of event loop. It ensures exclusive access to most of
+// [relayManager] state.
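+// runLoop returns when no work remains, leaving a token on runLoopStoppedCh;
+// relayManagerInputEvent() restarts it on demand when an event write would
+// otherwise block forever.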
+func (r *relayManager) runLoop() {
+	defer func() {
+		r.runLoopStoppedCh <- struct{}{}
+	}()
+
+	for {
+		select {
+		case ep := <-r.allocateHandshakeCh:
+			r.cancelAndClearWork(ep)
+			r.allocateAllServersForEndpoint(ep)
+			if !r.activeWork() {
+				return
+			}
+		case msg := <-r.allocateWorkDoneCh:
+			work, ok := r.allocWorkByEndpoint[msg.ep]
+			if ok && work == msg.work {
+				// Verify the work in the map is the same as the one that we're
+				// cleaning up. New events on r.allocateHandshakeCh can
+				// overwrite pre-existing keys.
+				delete(r.allocWorkByEndpoint, msg.ep)
+			}
+			if !r.activeWork() {
+				return
+			}
+		case ep := <-r.cancelWorkCh:
+			r.cancelAndClearWork(ep)
+			if !r.activeWork() {
+				return
+			}
+		case newEndpoint := <-r.newServerEndpointCh:
+			_ = newEndpoint
+			// TODO(jwhited): implement
+			if !r.activeWork() {
+				return
+			}
+		case challenge := <-r.rxChallengeCh:
+			_ = challenge
+			// TODO(jwhited): implement
+			if !r.activeWork() {
+				return
+			}
+		case via := <-r.rxCallMeMaybeViaCh:
+			_ = via
+			// TODO(jwhited): implement
+			if !r.activeWork() {
+				return
+			}
+		}
 	}
-	h.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo)
+}
+
+type relayHandshakeChallengeEvent struct {
+	challenge [32]byte
+	disco     key.DiscoPublic
+	from      netip.AddrPort
+	vni       uint32
+	at        time.Time
+}
+
+// relayEndpointAllocWork serves to track in-progress relay endpoint allocation
+// for an [*endpoint]. This structure is immutable once initialized.
+type relayEndpointAllocWork struct {
+	// ep is the [*endpoint] associated with the work
+	ep *endpoint
+	// cancel() will signal all associated goroutines to return
+	cancel context.CancelFunc
+	// wg.Wait() will return once all associated goroutines have returned
+	wg *sync.WaitGroup
+}
+
+// init initializes [relayManager] if it is not already initialized.
+func (r *relayManager) init() {
+	r.initOnce.Do(func() {
+		r.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo)
+		r.allocWorkByEndpoint = make(map[*endpoint]*relayEndpointAllocWork)
+		r.allocateHandshakeCh = make(chan *endpoint)
+		r.allocateWorkDoneCh = make(chan relayEndpointAllocWorkDoneEvent)
+		r.cancelWorkCh = make(chan *endpoint)
+		r.newServerEndpointCh = make(chan newRelayServerEndpointEvent)
+		r.rxChallengeCh = make(chan relayHandshakeChallengeEvent)
+		r.rxCallMeMaybeViaCh = make(chan *disco.CallMeMaybeVia)
+		r.runLoopStoppedCh = make(chan struct{}, 1)
+		go r.runLoop()
+	})
+}
 
 // discoInfo returns a [*discoInfo] for 'serverDisco' if there is an
 // active/ongoing handshake with it, otherwise it returns nil, false.
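+//
+// discoInfo is guarded by its own mutex (discoInfoMu) rather than the
+// runLoop() event loop, so callers in the disco receive path can query it
+// synchronously without round-tripping through runLoop().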
-func (h *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	h.initLocked()
-	di, ok := h.discoInfoByServerDisco[serverDisco]
+func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) {
+	r.discoInfoMu.Lock()
+	defer r.discoInfoMu.Unlock()
+	di, ok := r.discoInfoByServerDisco[serverDisco]
 	return di, ok
 }
 
-func (h *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	h.initLocked()
-	// TODO(jwhited): implement
+func (r *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) {
+	relayManagerInputEvent(r, nil, &r.rxCallMeMaybeViaCh, dm)
 }
 
-func (h *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	h.initLocked()
-	// TODO(jwhited): implement
+func (r *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) {
+	relayManagerInputEvent(r, nil, &r.rxChallengeCh, relayHandshakeChallengeEvent{challenge: dm.Challenge, disco: di.discoKey, from: src, vni: vni, at: time.Now()})
+}
+
+// relayManagerInputEvent initializes [relayManager] if necessary, starts
+// relayManager.runLoop() if it is not running, and writes 'event' on 'eventCh'.
+//
+// [relayManager] initialization will make '*eventCh', so it must be passed as
+// a pointer to a channel.
+//
+// 'ctx' enables returning early when runLoop() is waiting on the caller to
+// return, i.e. when the calling goroutine was spawned by runLoop() and is
+// cancelable via 'ctx'. 'ctx' may be nil.
+func relayManagerInputEvent[T any](r *relayManager, ctx context.Context, eventCh *chan T, event T) {
+	r.init()
+	var ctxDoneCh <-chan struct{}
+	if ctx != nil {
+		ctxDoneCh = ctx.Done()
+	}
+	for {
+		select {
+		case <-ctxDoneCh:
+			return
+		case *eventCh <- event:
+			return
+		case <-r.runLoopStoppedCh:
+			go r.runLoop()
+		}
+	}
+}
+
+// allocateAndHandshakeAllServers kicks off allocation and handshaking of relay
+// endpoints for 'ep' on all known relay servers, canceling any existing
+// in-progress work.
+func (r *relayManager) allocateAndHandshakeAllServers(ep *endpoint) {
+	relayManagerInputEvent(r, nil, &r.allocateHandshakeCh, ep)
+}
+
+// cancelOutstandingWork cancels all outstanding allocation & handshaking work
+// for 'ep'.
+func (r *relayManager) cancelOutstandingWork(ep *endpoint) {
+	relayManagerInputEvent(r, nil, &r.cancelWorkCh, ep)
+}
+
+// cancelAndClearWork cancels & clears any outstanding work for 'ep'.
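+// It must only be called from runLoop(), which owns allocWorkByEndpoint;
+// other goroutines cancel work via cancelOutstandingWork() instead.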
+func (r *relayManager) cancelAndClearWork(ep *endpoint) {
+	allocWork, ok := r.allocWorkByEndpoint[ep]
+	if ok {
+		allocWork.cancel()
+		allocWork.wg.Wait()
+		delete(r.allocWorkByEndpoint, ep)
+	}
+	// TODO(jwhited): cancel & clear handshake work
+}
+
+func (r *relayManager) allocateAllServersForEndpoint(ep *endpoint) {
+	if len(r.serversByAddrPort) == 0 {
+		return
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	started := &relayEndpointAllocWork{ep: ep, cancel: cancel, wg: &sync.WaitGroup{}}
+	for k := range r.serversByAddrPort {
+		started.wg.Add(1)
+		go r.allocateEndpoint(ctx, started.wg, k, ep)
+	}
+	r.allocWorkByEndpoint[ep] = started
+	go func() {
+		started.wg.Wait()
+		started.cancel()
+		relayManagerInputEvent(r, ctx, &r.allocateWorkDoneCh, relayEndpointAllocWorkDoneEvent{ep: ep, work: started})
+	}()
+}
+
+func (r *relayManager) allocateEndpoint(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, ep *endpoint) {
+	// TODO(jwhited): introduce client metrics counters for notable failures
+	defer wg.Done()
+	var b bytes.Buffer
+	remoteDisco := ep.disco.Load()
+	if remoteDisco == nil {
+		return
+	}
+	type allocateRelayEndpointReq struct {
+		DiscoKeys []key.DiscoPublic
+	}
+	a := &allocateRelayEndpointReq{
+		DiscoKeys: []key.DiscoPublic{ep.c.discoPublic, remoteDisco.key},
+	}
+	err := json.NewEncoder(&b).Encode(a)
+	if err != nil {
+		return
+	}
+	const reqTimeout = time.Second * 10
+	reqCtx, cancel := context.WithTimeout(ctx, reqTimeout)
+	defer cancel()
+	req, err := http.NewRequestWithContext(reqCtx, httpm.POST, "http://"+server.String()+"/relay/endpoint", &b)
+	if err != nil {
+		return
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return
+	}
+	var se udprelay.ServerEndpoint
+	err = json.NewDecoder(io.LimitReader(resp.Body, 4096)).Decode(&se)
+	if err != nil {
+		return
+	}
+	relayManagerInputEvent(r, ctx, &r.newServerEndpointCh, newRelayServerEndpointEvent{
+		ep: ep,
+		se: se,
+	})
 }
diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go
new file mode 100644
index 000000000..579dceb53
--- /dev/null
+++ b/wgengine/magicsock/relaymanager_test.go
@@ -0,0 +1,29 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+	"net/netip"
+	"testing"
+
+	"tailscale.com/disco"
+)
+
+func TestRelayManagerInitAndIdle(t *testing.T) {
+	rm := relayManager{}
+	rm.allocateAndHandshakeAllServers(&endpoint{})
+	<-rm.runLoopStoppedCh
+
+	rm = relayManager{}
+	rm.cancelOutstandingWork(&endpoint{})
+	<-rm.runLoopStoppedCh
+
+	rm = relayManager{}
+	rm.handleCallMeMaybeVia(&disco.CallMeMaybeVia{})
+	<-rm.runLoopStoppedCh
+
+	rm = relayManager{}
+	rm.handleBindUDPRelayEndpointChallenge(&disco.BindUDPRelayEndpointChallenge{}, &discoInfo{}, netip.AddrPort{}, 0)
+	<-rm.runLoopStoppedCh
+}
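
Reviewer note: the least obvious piece above is the restart-on-demand event loop formed by relayManagerInputEvent() and runLoopStoppedCh. Below is a minimal, self-contained sketch of that pattern; the names (loop, send, run) are hypothetical and it is not code from this change:

```go
package main

import (
	"fmt"
	"sync"
)

// loop mirrors relayManager's scheme: the zero value is ready for use, the
// event loop exits whenever it goes idle, and senders restart it on demand.
type loop struct {
	initOnce  sync.Once
	eventCh   chan int
	stoppedCh chan struct{} // buffered (cap 1): holds a token while stopped
	active    int           // owned exclusively by run()
}

func (l *loop) init() {
	l.initOnce.Do(func() {
		l.eventCh = make(chan int)
		l.stoppedCh = make(chan struct{}, 1)
		go l.run()
	})
}

// run consumes events until no work remains, then returns, leaving a token
// on stoppedCh so the next sender knows to restart it.
func (l *loop) run() {
	defer func() { l.stoppedCh <- struct{}{} }()
	for n := range l.eventCh {
		l.active += n
		if l.active == 0 {
			return // idle: stop until the next event arrives
		}
	}
}

// send delivers an event, restarting run() if it has stopped: either a live
// run() receives the event, or we take the stopped token, start a new run(),
// and retry the send.
func (l *loop) send(n int) {
	l.init()
	for {
		select {
		case l.eventCh <- n:
			return
		case <-l.stoppedCh:
			go l.run()
		}
	}
}

func main() {
	var l loop
	l.send(1)  // starts the loop; one unit of work outstanding
	l.send(-1) // work complete; the loop goes idle and exits
	fmt.Println("loop idle")
}
```

The buffered stopped channel is what makes this race-free: the token exists only while no run() goroutine is live, so a sender can never both deliver to a live loop and start a second one, and TestRelayManagerInitAndIdle exercises exactly this stop/restart edge for each input type.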