diff --git a/control/tsp/map_test.go b/control/tsp/map_test.go index 14b64f39a..ddfde3971 100644 --- a/control/tsp/map_test.go +++ b/control/tsp/map_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/klauspost/compress/zstd" + "tailscale.com/health" "tailscale.com/tailcfg" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" @@ -31,6 +32,8 @@ func TestMapAgainstTestControl(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + ht := new(health.Tracker) + serverKey, err := DiscoverServerKey(ctx, baseURL) if err != nil { t.Fatalf("DiscoverServerKey: %v", err) @@ -41,8 +44,9 @@ func TestMapAgainstTestControl(t *testing.T) { nodeKey = key.NewNode() machineKey = key.NewMachine() c, err := NewClient(ClientOpts{ - ServerURL: baseURL, - MachineKey: machineKey, + ServerURL: baseURL, + MachineKey: machineKey, + HealthTracker: ht, }) if err != nil { t.Fatalf("NewClient %s: %v", hostname, err) @@ -62,8 +66,9 @@ func TestMapAgainstTestControl(t *testing.T) { nodeKeyB, _ := register("b") clientA, err := NewClient(ClientOpts{ - ServerURL: baseURL, - MachineKey: machineKeyA, + ServerURL: baseURL, + MachineKey: machineKeyA, + HealthTracker: ht, }) if err != nil { t.Fatalf("NewClient A: %v", err) @@ -144,6 +149,8 @@ func TestSendMapUpdateAgainstTestControl(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + ht := new(health.Tracker) + serverKey, err := DiscoverServerKey(ctx, baseURL) if err != nil { t.Fatalf("DiscoverServerKey: %v", err) @@ -154,8 +161,9 @@ func TestSendMapUpdateAgainstTestControl(t *testing.T) { nodeKey = key.NewNode() machineKey = key.NewMachine() c, err := NewClient(ClientOpts{ - ServerURL: baseURL, - MachineKey: machineKey, + ServerURL: baseURL, + MachineKey: machineKey, + HealthTracker: ht, }) if err != nil { t.Fatalf("NewClient %s: %v", hostname, err) @@ -176,8 +184,9 @@ func TestSendMapUpdateAgainstTestControl(t *testing.T) { // B starts a streaming map poll so we can observe updates about peer A. clientB, err := NewClient(ClientOpts{ - ServerURL: baseURL, - MachineKey: machineKeyB, + ServerURL: baseURL, + MachineKey: machineKeyB, + HealthTracker: ht, }) if err != nil { t.Fatalf("NewClient B: %v", err) @@ -228,8 +237,9 @@ func TestSendMapUpdateAgainstTestControl(t *testing.T) { // A pushes its disco key via SendMapUpdate. clientA, err := NewClient(ClientOpts{ - ServerURL: baseURL, - MachineKey: machineKeyA, + ServerURL: baseURL, + MachineKey: machineKeyA, + HealthTracker: ht, }) if err != nil { t.Fatalf("NewClient A: %v", err) diff --git a/control/tsp/tsp.go b/control/tsp/tsp.go index a75cc7d0e..23f2fc261 100644 --- a/control/tsp/tsp.go +++ b/control/tsp/tsp.go @@ -19,6 +19,7 @@ import ( "sync" "tailscale.com/control/ts2021" + "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -43,6 +44,10 @@ type ClientOpts struct { // Logf is the log function. If nil, logger.Discard is used. Logf logger.Logf + + // HealthTracker, if non-nil, is the health tracker passed through + // to the underlying noise client. May be nil. + HealthTracker *health.Tracker } // Client is a Tailscale protocol client that speaks to a coordination @@ -155,11 +160,12 @@ func (c *Client) noiseClient(ctx context.Context) (*ts2021.Client, error) { } nc, err := ts2021.NewClient(ts2021.ClientOpts{ - ServerURL: c.serverURL, - PrivKey: c.opts.MachineKey, - ServerPubKey: c.serverPub, - Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext), - Logf: c.logf, + ServerURL: c.serverURL, + PrivKey: c.opts.MachineKey, + ServerPubKey: c.serverPub, + Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext), + Logf: c.logf, + HealthTracker: c.opts.HealthTracker, }) if err != nil { return nil, fmt.Errorf("creating noise client: %w", err) diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 95df1f5a6..0fdc885a0 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -156,8 +156,8 @@ type Server struct { updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath nodeKeyAuthed set.Set[key.NodePublic] - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + msgToSend map[key.NodePublic][]any // FIFO queue per node; values are *tailcfg.PingRequest or *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. // tkaStorage records the Tailnet Lock state, if any. // If nil, Tailnet Lock is not enabled in the Tailnet. @@ -300,14 +300,16 @@ func (s *Server) AddRawMapResponse(nodeKeyDst key.NodePublic, mr *tailcfg.MapRes func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { s.mu.Lock() defer s.mu.Unlock() - if s.msgToSend == nil { - s.msgToSend = map[key.NodePublic]any{} - } - // Now send the update to the channel node := s.nodeLocked(nodeKeyDst) if node == nil { return false } + updatesCh := s.updates[node.ID] + if updatesCh == nil { + // No streaming poll is registered, so there's nobody to deliver + // the message to. + return false + } if _, ok := msg.(*tailcfg.MapResponse); ok { if s.suppressAutoMapResponses == nil { @@ -316,10 +318,14 @@ func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { s.suppressAutoMapResponses.Add(nodeKeyDst) } - s.msgToSend[nodeKeyDst] = msg - nodeID := node.ID - oldUpdatesCh := s.updates[nodeID] - return sendUpdate(oldUpdatesCh, updateDebugInjection) + mak.Set(&s.msgToSend, nodeKeyDst, append(s.msgToSend[nodeKeyDst], msg)) + // sendUpdate returning false here is fine: the channel is a lossy + // wake-up signal whose buffer is single-slot. A full buffer means a + // prior wake-up is still pending, and the streaming poll will check + // msgToSend when it processes that wake-up. The queue in msgToSend + // is the source of truth. + sendUpdate(updatesCh, updateDebugInjection) + return true } // Mark the Node key of every node as expired @@ -1472,15 +1478,29 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, res.Node.PrimaryRoutes = s.nodeSubnetRoutes[nk] res.Node.AllowedIPs = append(res.Node.Addresses, s.nodeSubnetRoutes[nk]...) - // Consume a PingRequest while protected by mutex if it exists - switch m := s.msgToSend[nk].(type) { - case *tailcfg.PingRequest: - res.PingRequest = m - delete(s.msgToSend, nk) + // Consume a PingRequest at the head of the queue, if any. + if q := s.msgToSend[nk]; len(q) > 0 { + if pr, ok := q[0].(*tailcfg.PingRequest); ok { + res.PingRequest = pr + s.popMsgToSendLocked(nk) + } } return res, nil } +// popMsgToSendLocked pops the head of the per-node message queue. +// s.mu must be held. +func (s *Server) popMsgToSendLocked(nk key.NodePublic) { + q := s.msgToSend[nk] + if len(q) <= 1 { + delete(s.msgToSend, nk) + return + } + // Zero the head to allow GC of any large referenced response. + q[0] = nil + s.msgToSend[nk] = q[1:] +} + func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() @@ -1490,22 +1510,21 @@ func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { func (s *Server) hasPendingRawMapMessage(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - _, ok := s.msgToSend[nk] - return ok + return len(s.msgToSend[nk]) > 0 } func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok bool) { s.mu.Lock() defer s.mu.Unlock() - mr, ok := s.msgToSend[nk] - if !ok { + q := s.msgToSend[nk] + if len(q) == 0 { return nil, false } - delete(s.msgToSend, nk) + mr := q[0] + s.popMsgToSendLocked(nk) // If it's a bare PingRequest, wrap it in a MapResponse. - switch pr := mr.(type) { - case *tailcfg.PingRequest: + if pr, ok := mr.(*tailcfg.PingRequest); ok { mr = &tailcfg.MapResponse{PingRequest: pr} }