Brad Fitzpatrick 00a08ea86d control/tsp: add lite map update support
Updates #12542
Updates tailscale/corp#40088

Change-Id: Idb4526f1bf1f3f424d6fb3d7e34ebe89a474b57b
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2026-04-17 04:19:50 -07:00

416 lines
13 KiB
Go

// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
import (
"bytes"
"cmp"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"github.com/klauspost/compress/zstd"
"tailscale.com/control/ts2021"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// errSessionClosed is returned by [MapSession.Next] and
// [MapSession.NextInto] when called after [MapSession.Close].
var errSessionClosed = errors.New("tsp: map session closed")
// DefaultMaxMessageSize is the default cap, in bytes, on the size of a
// single compressed map response frame. See [MapOpts.MaxMessageSize].
const DefaultMaxMessageSize = 4 << 20
// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to
// amortize the cost of setting up zstd state. Decoders are returned via
// [MapSession.Close]; entries are reclaimed by the runtime under memory
// pressure via sync.Pool semantics.
var zstdDecoderPool sync.Pool // of *zstd.Decoder
// MapOpts contains options for sending a map request.
type MapOpts struct {
// NodeKey is the node's private key. Required.
NodeKey key.NodePrivate
// Hostinfo is the host information to send. Optional;
// if nil, a minimal default is used.
Hostinfo *tailcfg.Hostinfo
// Stream is whether to receive multiple MapResponses over
// the same HTTP connection.
Stream bool
// OmitPeers is whether the client is okay with the Peers list
// being omitted in the response.
OmitPeers bool
// MaxMessageSize is the maximum size in bytes of any single
// compressed map response frame on the wire. If zero,
// [DefaultMaxMessageSize] is used.
MaxMessageSize int64
}
// framedReader is an io.Reader that consumes a stream of length-prefixed
// frames (each a little-endian uint32 length followed by that many bytes)
// from r and yields only the frame payloads back-to-back.
//
// This lets us feed the concatenated zstd frames from our wire protocol
// into a single streaming zstd decoder. Zstd's file format permits
// concatenation (RFC 8478 §2), and klauspost's decoder handles it
// transparently.
//
// If onNewFrame is non-nil, it is called after each new 4-byte length
// header is successfully read. Used to reset the per-message decoded-size
// budget downstream.
type framedReader struct {
r io.Reader
maxSize int64 // per-frame compressed-size cap
remain int // bytes remaining in the current frame
onNewFrame func()
}
func (f *framedReader) Read(p []byte) (int, error) {
if f.remain == 0 {
var hdr [4]byte
if _, err := io.ReadFull(f.r, hdr[:]); err != nil {
return 0, err
}
sz := int64(binary.LittleEndian.Uint32(hdr[:]))
if sz == 0 {
return 0, fmt.Errorf("map response: zero-length frame")
}
if sz > f.maxSize {
return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize)
}
f.remain = int(sz)
if f.onNewFrame != nil {
f.onNewFrame()
}
}
if len(p) > f.remain {
p = p[:f.remain]
}
n, err := f.r.Read(p)
f.remain -= n
return n, err
}
// boundedReader is an io.Reader that yields at most remain bytes from r
// before returning an error. Call reset to raise the budget back to max,
// typically at a new message boundary.
//
// Used to cap the decoded size of a single map response so a malicious
// server can't send a small zstd frame that explodes into gigabytes of
// junk for the json.Decoder to consume.
type boundedReader struct {
r io.Reader
max int64
remain int64
}
func (b *boundedReader) Read(p []byte) (int, error) {
if b.remain <= 0 {
return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max)
}
if int64(len(p)) > b.remain {
p = p[:b.remain]
}
n, err := b.r.Read(p)
b.remain -= int64(n)
return n, err
}
func (b *boundedReader) reset() { b.remain = b.max }
// MapSession wraps an in-progress map response stream. Call Next to read
// each MapResponse. Call Close when done.
type MapSession struct {
res *http.Response
stream bool
noiseDoer func(*http.Request) (*http.Response, error)
// inNext detects concurrent NextInto callers. It CAS-flips
// false→true on entry and back to false on exit; a failed CAS
// panics, akin to how the Go runtime detects concurrent map
// access. It does not serialize Close vs. NextInto; that's
// nextMu's job.
inNext atomic.Bool
// nextMu is held while [MapSession.NextInto] is running jdec.Decode,
// so that Close can wait for an in-flight Decode to unwind before it
// touches zdec (Reset, pool-Put) and avoids racing with the running
// Read chain that Decode drives.
nextMu sync.Mutex
read int // guarded by nextMu
closed bool // guarded by nextMu
zdec *zstd.Decoder // reads from a framedReader wrapping res.Body
jdec *json.Decoder // reads decompressed JSON from zdec
closeOnce sync.Once
closeErr error
}
// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session.
func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) {
return s.noiseDoer(req)
}
// Next reads and returns the next MapResponse from the stream.
// For non-streaming sessions, the first call returns the single response
// and subsequent calls return io.EOF.
// For streaming sessions, Next blocks until the next response arrives
// or the server closes the connection.
//
// Each call allocates a fresh MapResponse. Callers that want to amortize
// the allocation across calls can use [MapSession.NextInto].
//
// Next and NextInto are not safe to call concurrently from multiple
// goroutines on the same [MapSession]; a concurrent call panics, akin
// to the Go runtime's concurrent map access detection. [MapSession.Close]
// may be called concurrently to abort an in-flight Next.
func (s *MapSession) Next() (*tailcfg.MapResponse, error) {
var resp tailcfg.MapResponse
if err := s.NextInto(&resp); err != nil {
return nil, err
}
return &resp, nil
}
// NextInto is like [MapSession.Next] but decodes the next MapResponse into
// the caller-supplied *resp rather than allocating a new one. The pointer's
// pointee is zeroed before decoding so fields from a prior response do not
// persist.
//
// For non-streaming sessions, the first call decodes the single response
// and subsequent calls return io.EOF.
// For streaming sessions, NextInto blocks until the next response arrives
// or the server closes the connection.
//
// See [MapSession.Next] for concurrency rules; those apply to NextInto too.
func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error {
if !s.inNext.CompareAndSwap(false, true) {
panic("tsp: invalid concurrent call to MapSession.Next/NextInto")
}
defer s.inNext.Store(false)
s.nextMu.Lock()
defer s.nextMu.Unlock()
if s.closed {
return errSessionClosed
}
if !s.stream && s.read > 0 {
return io.EOF
}
*resp = tailcfg.MapResponse{}
if err := s.jdec.Decode(resp); err != nil {
return err
}
s.read++
return nil
}
// Close returns the session's zstd decoder to the pool and closes the
// underlying HTTP response body. It is safe to call Close multiple times
// and from multiple goroutines, including while a [MapSession.Next] or
// [MapSession.NextInto] call is in flight on another goroutine (which
// will return an error once the body close propagates).
func (s *MapSession) Close() error {
// Callers are likely to race a deferred Close with a time.AfterFunc
// timeout (or similar) Close that aborts a hung Next. Without the
// Once, both Closes would Put the same *zstd.Decoder into the pool,
// corrupting it, and the Reset/Put in one would race with the
// zdec.Read that the hung Next is driving.
//
// Ordering inside the Once: close the body first to unblock any
// in-flight NextInto (its Read chain ends at res.Body and will
// return an error once it's closed). That lets NextInto unwind and
// release nextMu. Only then do we take nextMu ourselves and touch
// zdec, which is safe because no goroutine is still reading from
// it. Acquiring nextMu before closing the body would deadlock
// against a hung NextInto.
s.closeOnce.Do(func() {
s.closeErr = s.res.Body.Close()
s.nextMu.Lock()
defer s.nextMu.Unlock()
s.closed = true
s.zdec.Reset(nil)
zstdDecoderPool.Put(s.zdec)
})
return s.closeErr
}
// SendMapUpdateOpts contains options for [Client.SendMapUpdate].
type SendMapUpdateOpts struct {
// NodeKey is the node's private key. Required.
NodeKey key.NodePrivate
// DiscoKey, if non-zero, is the node's disco public key.
// Peers use it to verify disco pings from this node, which is
// what enables direct (non-DERP) paths.
DiscoKey key.DiscoPublic
// Hostinfo is the host information to send. Optional;
// if nil, a minimal default is used.
Hostinfo *tailcfg.Hostinfo
}
// SendMapUpdate sends a one-shot, non-streaming MapRequest to push small
// updates (such as the node's endpoints, hostinfo, or disco public key) to the
// coordination server without starting or disturbing a streaming map session.
func (c *Client) SendMapUpdate(ctx context.Context, opts SendMapUpdateOpts) error {
if opts.NodeKey.IsZero() {
return fmt.Errorf("NodeKey is required")
}
hi := opts.Hostinfo
if hi == nil {
hi = defaultHostinfo()
}
mapReq := tailcfg.MapRequest{
Version: tailcfg.CurrentCapabilityVersion,
NodeKey: opts.NodeKey.Public(),
DiscoKey: opts.DiscoKey,
Hostinfo: hi,
Compress: "zstd",
// A lite update that lets the server persist our state without breaking
// any existing streaming map session. See the [tailcfg.MapResponse]
// OmitPeers docs.
OmitPeers: true,
Stream: false,
ReadOnly: false,
}
body, err := json.Marshal(mapReq)
if err != nil {
return fmt.Errorf("encoding map request: %w", err)
}
nc, err := c.noiseClient(ctx)
if err != nil {
return fmt.Errorf("establishing noise connection: %w", err)
}
url := c.serverURL + "/machine/map"
url = strings.Replace(url, "http:", "https:", 1)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("creating map request: %w", err)
}
ts2021.AddLBHeader(req, opts.NodeKey.Public())
res, err := nc.Do(req)
if err != nil {
return fmt.Errorf("map request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != 200 {
msg, _ := io.ReadAll(res.Body)
return fmt.Errorf("map request: http %d: %.200s",
res.StatusCode, strings.TrimSpace(string(msg)))
}
io.Copy(io.Discard, res.Body)
return nil
}
// Map sends a map request to the coordination server and returns a MapSession
// for reading the framed, zstd-compressed response(s).
func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) {
if opts.NodeKey.IsZero() {
return nil, fmt.Errorf("NodeKey is required")
}
hi := opts.Hostinfo
if hi == nil {
hi = defaultHostinfo()
}
mapReq := tailcfg.MapRequest{
Version: tailcfg.CurrentCapabilityVersion,
NodeKey: opts.NodeKey.Public(),
Hostinfo: hi,
Stream: opts.Stream,
Compress: "zstd",
OmitPeers: opts.OmitPeers,
// Streaming requires the server to track us as "connected",
// which in turn requires ReadOnly=false. Non-streaming polls
// stay ReadOnly to minimize side effects.
ReadOnly: !opts.Stream,
}
body, err := json.Marshal(mapReq)
if err != nil {
return nil, fmt.Errorf("encoding map request: %w", err)
}
nc, err := c.noiseClient(ctx)
if err != nil {
return nil, fmt.Errorf("establishing noise connection: %w", err)
}
url := c.serverURL + "/machine/map"
url = strings.Replace(url, "http:", "https:", 1)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating map request: %w", err)
}
ts2021.AddLBHeader(req, opts.NodeKey.Public())
res, err := nc.Do(req)
if err != nil {
return nil, fmt.Errorf("map request: %w", err)
}
if res.StatusCode != 200 {
msg, _ := io.ReadAll(res.Body)
res.Body.Close()
return nil, fmt.Errorf("map request: http %d: %.200s",
res.StatusCode, strings.TrimSpace(string(msg)))
}
maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize)
bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
fr := &framedReader{
r: res.Body,
maxSize: maxMessageSize,
onNewFrame: bounded.reset,
}
zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder)
if zdec != nil {
if err := zdec.Reset(fr); err != nil {
// Reset can fail if the previous stream is in a bad state; drop
// the decoder and create a fresh one.
zdec = nil
}
}
if zdec == nil {
zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
if err != nil {
res.Body.Close()
return nil, fmt.Errorf("creating zstd decoder: %w", err)
}
}
bounded.r = zdec
return &MapSession{
res: res,
stream: opts.Stream,
noiseDoer: nc.Do,
zdec: zdec,
jdec: json.NewDecoder(bounded),
}, nil
}