// Mirror of https://github.com/tailscale/tailscale.git (synced 2026-04-21).
// Updates #12542; Updates tailscale/corp#40088.
// Change-Id: Idb4526f1bf1f3f424d6fb3d7e34ebe89a474b57b
// Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
"tailscale.com/control/ts2021"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/key"
|
|
)
|
|
|
|
// errSessionClosed is returned by [MapSession.Next] and
// [MapSession.NextInto] when called after [MapSession.Close].
var errSessionClosed = errors.New("tsp: map session closed")

// DefaultMaxMessageSize is the default cap, in bytes, on the size of a
// single compressed map response frame. See [MapOpts.MaxMessageSize].
const DefaultMaxMessageSize = 4 << 20 // 4 MiB

// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to
// amortize the cost of setting up zstd state. Decoders are returned via
// [MapSession.Close]; entries are reclaimed by the runtime under memory
// pressure via sync.Pool semantics.
var zstdDecoderPool sync.Pool // of *zstd.Decoder
|
|
|
|
// MapOpts contains options for sending a map request.
type MapOpts struct {
	// NodeKey is the node's private key. Required.
	NodeKey key.NodePrivate

	// Hostinfo is the host information to send. Optional;
	// if nil, a minimal default is used.
	Hostinfo *tailcfg.Hostinfo

	// Stream is whether to receive multiple MapResponses over
	// the same HTTP connection. Streaming requests are also sent
	// with ReadOnly=false so the server tracks the node as connected.
	Stream bool

	// OmitPeers is whether the client is okay with the Peers list
	// being omitted in the response.
	OmitPeers bool

	// MaxMessageSize is the maximum size in bytes of any single
	// compressed map response frame on the wire. It also bounds the
	// decoded size of each response. If zero,
	// [DefaultMaxMessageSize] is used.
	MaxMessageSize int64
}
|
|
|
|
// framedReader is an io.Reader that consumes a stream of length-prefixed
|
|
// frames (each a little-endian uint32 length followed by that many bytes)
|
|
// from r and yields only the frame payloads back-to-back.
|
|
//
|
|
// This lets us feed the concatenated zstd frames from our wire protocol
|
|
// into a single streaming zstd decoder. Zstd's file format permits
|
|
// concatenation (RFC 8478 §2), and klauspost's decoder handles it
|
|
// transparently.
|
|
//
|
|
// If onNewFrame is non-nil, it is called after each new 4-byte length
|
|
// header is successfully read. Used to reset the per-message decoded-size
|
|
// budget downstream.
|
|
type framedReader struct {
|
|
r io.Reader
|
|
maxSize int64 // per-frame compressed-size cap
|
|
remain int // bytes remaining in the current frame
|
|
onNewFrame func()
|
|
}
|
|
|
|
func (f *framedReader) Read(p []byte) (int, error) {
|
|
if f.remain == 0 {
|
|
var hdr [4]byte
|
|
if _, err := io.ReadFull(f.r, hdr[:]); err != nil {
|
|
return 0, err
|
|
}
|
|
sz := int64(binary.LittleEndian.Uint32(hdr[:]))
|
|
if sz == 0 {
|
|
return 0, fmt.Errorf("map response: zero-length frame")
|
|
}
|
|
if sz > f.maxSize {
|
|
return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize)
|
|
}
|
|
f.remain = int(sz)
|
|
if f.onNewFrame != nil {
|
|
f.onNewFrame()
|
|
}
|
|
}
|
|
if len(p) > f.remain {
|
|
p = p[:f.remain]
|
|
}
|
|
n, err := f.r.Read(p)
|
|
f.remain -= n
|
|
return n, err
|
|
}
|
|
|
|
// boundedReader is an io.Reader that yields at most remain bytes from r
|
|
// before returning an error. Call reset to raise the budget back to max,
|
|
// typically at a new message boundary.
|
|
//
|
|
// Used to cap the decoded size of a single map response so a malicious
|
|
// server can't send a small zstd frame that explodes into gigabytes of
|
|
// junk for the json.Decoder to consume.
|
|
type boundedReader struct {
|
|
r io.Reader
|
|
max int64
|
|
remain int64
|
|
}
|
|
|
|
func (b *boundedReader) Read(p []byte) (int, error) {
|
|
if b.remain <= 0 {
|
|
return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max)
|
|
}
|
|
if int64(len(p)) > b.remain {
|
|
p = p[:b.remain]
|
|
}
|
|
n, err := b.r.Read(p)
|
|
b.remain -= int64(n)
|
|
return n, err
|
|
}
|
|
|
|
func (b *boundedReader) reset() { b.remain = b.max }
|
|
|
|
// MapSession wraps an in-progress map response stream. Call Next to read
// each MapResponse. Call Close when done.
type MapSession struct {
	res    *http.Response // streaming map response; Body feeds zdec
	stream bool           // whether the server sends multiple responses

	// noiseDoer issues extra HTTP requests over the same Noise channel;
	// see [MapSession.NoiseRoundTrip].
	noiseDoer func(*http.Request) (*http.Response, error)

	// inNext detects concurrent NextInto callers. It CAS-flips
	// false→true on entry and back to false on exit; a failed CAS
	// panics, akin to how the Go runtime detects concurrent map
	// access. It does not serialize Close vs. NextInto; that's
	// nextMu's job.
	inNext atomic.Bool

	// nextMu is held while [MapSession.NextInto] is running jdec.Decode,
	// so that Close can wait for an in-flight Decode to unwind before it
	// touches zdec (Reset, pool-Put) and avoids racing with the running
	// Read chain that Decode drives.
	nextMu sync.Mutex
	read   int           // guarded by nextMu; responses decoded so far
	closed bool          // guarded by nextMu; set once by Close
	zdec   *zstd.Decoder // reads from a framedReader wrapping res.Body
	jdec   *json.Decoder // reads decompressed JSON from zdec

	closeOnce sync.Once // ensures single body-close and pool-Put
	closeErr  error     // result of the one res.Body.Close, returned by every Close
}
|
|
|
|
// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session.
|
|
func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) {
|
|
return s.noiseDoer(req)
|
|
}
|
|
|
|
// Next reads and returns the next MapResponse from the stream.
|
|
// For non-streaming sessions, the first call returns the single response
|
|
// and subsequent calls return io.EOF.
|
|
// For streaming sessions, Next blocks until the next response arrives
|
|
// or the server closes the connection.
|
|
//
|
|
// Each call allocates a fresh MapResponse. Callers that want to amortize
|
|
// the allocation across calls can use [MapSession.NextInto].
|
|
//
|
|
// Next and NextInto are not safe to call concurrently from multiple
|
|
// goroutines on the same [MapSession]; a concurrent call panics, akin
|
|
// to the Go runtime's concurrent map access detection. [MapSession.Close]
|
|
// may be called concurrently to abort an in-flight Next.
|
|
func (s *MapSession) Next() (*tailcfg.MapResponse, error) {
|
|
var resp tailcfg.MapResponse
|
|
if err := s.NextInto(&resp); err != nil {
|
|
return nil, err
|
|
}
|
|
return &resp, nil
|
|
}
|
|
|
|
// NextInto is like [MapSession.Next] but decodes the next MapResponse into
|
|
// the caller-supplied *resp rather than allocating a new one. The pointer's
|
|
// pointee is zeroed before decoding so fields from a prior response do not
|
|
// persist.
|
|
//
|
|
// For non-streaming sessions, the first call decodes the single response
|
|
// and subsequent calls return io.EOF.
|
|
// For streaming sessions, NextInto blocks until the next response arrives
|
|
// or the server closes the connection.
|
|
//
|
|
// See [MapSession.Next] for concurrency rules; those apply to NextInto too.
|
|
func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error {
|
|
if !s.inNext.CompareAndSwap(false, true) {
|
|
panic("tsp: invalid concurrent call to MapSession.Next/NextInto")
|
|
}
|
|
defer s.inNext.Store(false)
|
|
|
|
s.nextMu.Lock()
|
|
defer s.nextMu.Unlock()
|
|
if s.closed {
|
|
return errSessionClosed
|
|
}
|
|
if !s.stream && s.read > 0 {
|
|
return io.EOF
|
|
}
|
|
*resp = tailcfg.MapResponse{}
|
|
if err := s.jdec.Decode(resp); err != nil {
|
|
return err
|
|
}
|
|
s.read++
|
|
return nil
|
|
}
|
|
|
|
// Close returns the session's zstd decoder to the pool and closes the
// underlying HTTP response body. It is safe to call Close multiple times
// and from multiple goroutines, including while a [MapSession.Next] or
// [MapSession.NextInto] call is in flight on another goroutine (which
// will return an error once the body close propagates).
func (s *MapSession) Close() error {
	// Callers are likely to race a deferred Close with a time.AfterFunc
	// timeout (or similar) Close that aborts a hung Next. Without the
	// Once, both Closes would Put the same *zstd.Decoder into the pool,
	// corrupting it, and the Reset/Put in one would race with the
	// zdec.Read that the hung Next is driving.
	//
	// Ordering inside the Once: close the body first to unblock any
	// in-flight NextInto (its Read chain ends at res.Body and will
	// return an error once it's closed). That lets NextInto unwind and
	// release nextMu. Only then do we take nextMu ourselves and touch
	// zdec, which is safe because no goroutine is still reading from
	// it. Acquiring nextMu before closing the body would deadlock
	// against a hung NextInto.
	s.closeOnce.Do(func() {
		s.closeErr = s.res.Body.Close()
		s.nextMu.Lock()
		defer s.nextMu.Unlock()
		s.closed = true
		// Detach the decoder from the (now-closed) body before pooling
		// so the pool entry holds no reference to this session's stream.
		s.zdec.Reset(nil)
		zstdDecoderPool.Put(s.zdec)
	})
	// Every call (including later no-op calls) reports the one true
	// body-close result.
	return s.closeErr
}
|
|
|
|
// SendMapUpdateOpts contains options for [Client.SendMapUpdate].
type SendMapUpdateOpts struct {
	// NodeKey is the node's private key. Required.
	NodeKey key.NodePrivate

	// DiscoKey, if non-zero, is the node's disco public key.
	// Peers use it to verify disco pings from this node, which is
	// what enables direct (non-DERP) paths.
	DiscoKey key.DiscoPublic

	// Hostinfo is the host information to send. Optional;
	// if nil, a minimal default is used.
	Hostinfo *tailcfg.Hostinfo
}
|
|
|
|
// SendMapUpdate sends a one-shot, non-streaming MapRequest to push small
|
|
// updates (such as the node's endpoints, hostinfo, or disco public key) to the
|
|
// coordination server without starting or disturbing a streaming map session.
|
|
func (c *Client) SendMapUpdate(ctx context.Context, opts SendMapUpdateOpts) error {
|
|
if opts.NodeKey.IsZero() {
|
|
return fmt.Errorf("NodeKey is required")
|
|
}
|
|
|
|
hi := opts.Hostinfo
|
|
if hi == nil {
|
|
hi = defaultHostinfo()
|
|
}
|
|
|
|
mapReq := tailcfg.MapRequest{
|
|
Version: tailcfg.CurrentCapabilityVersion,
|
|
NodeKey: opts.NodeKey.Public(),
|
|
DiscoKey: opts.DiscoKey,
|
|
Hostinfo: hi,
|
|
Compress: "zstd",
|
|
|
|
// A lite update that lets the server persist our state without breaking
|
|
// any existing streaming map session. See the [tailcfg.MapResponse]
|
|
// OmitPeers docs.
|
|
OmitPeers: true,
|
|
Stream: false,
|
|
ReadOnly: false,
|
|
}
|
|
|
|
body, err := json.Marshal(mapReq)
|
|
if err != nil {
|
|
return fmt.Errorf("encoding map request: %w", err)
|
|
}
|
|
|
|
nc, err := c.noiseClient(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("establishing noise connection: %w", err)
|
|
}
|
|
|
|
url := c.serverURL + "/machine/map"
|
|
url = strings.Replace(url, "http:", "https:", 1)
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return fmt.Errorf("creating map request: %w", err)
|
|
}
|
|
ts2021.AddLBHeader(req, opts.NodeKey.Public())
|
|
|
|
res, err := nc.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("map request: %w", err)
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
if res.StatusCode != 200 {
|
|
msg, _ := io.ReadAll(res.Body)
|
|
return fmt.Errorf("map request: http %d: %.200s",
|
|
res.StatusCode, strings.TrimSpace(string(msg)))
|
|
}
|
|
io.Copy(io.Discard, res.Body)
|
|
return nil
|
|
}
|
|
|
|
// Map sends a map request to the coordination server and returns a MapSession
|
|
// for reading the framed, zstd-compressed response(s).
|
|
func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) {
|
|
if opts.NodeKey.IsZero() {
|
|
return nil, fmt.Errorf("NodeKey is required")
|
|
}
|
|
|
|
hi := opts.Hostinfo
|
|
if hi == nil {
|
|
hi = defaultHostinfo()
|
|
}
|
|
|
|
mapReq := tailcfg.MapRequest{
|
|
Version: tailcfg.CurrentCapabilityVersion,
|
|
NodeKey: opts.NodeKey.Public(),
|
|
Hostinfo: hi,
|
|
Stream: opts.Stream,
|
|
Compress: "zstd",
|
|
OmitPeers: opts.OmitPeers,
|
|
// Streaming requires the server to track us as "connected",
|
|
// which in turn requires ReadOnly=false. Non-streaming polls
|
|
// stay ReadOnly to minimize side effects.
|
|
ReadOnly: !opts.Stream,
|
|
}
|
|
|
|
body, err := json.Marshal(mapReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("encoding map request: %w", err)
|
|
}
|
|
|
|
nc, err := c.noiseClient(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("establishing noise connection: %w", err)
|
|
}
|
|
|
|
url := c.serverURL + "/machine/map"
|
|
url = strings.Replace(url, "http:", "https:", 1)
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating map request: %w", err)
|
|
}
|
|
ts2021.AddLBHeader(req, opts.NodeKey.Public())
|
|
|
|
res, err := nc.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("map request: %w", err)
|
|
}
|
|
|
|
if res.StatusCode != 200 {
|
|
msg, _ := io.ReadAll(res.Body)
|
|
res.Body.Close()
|
|
return nil, fmt.Errorf("map request: http %d: %.200s",
|
|
res.StatusCode, strings.TrimSpace(string(msg)))
|
|
}
|
|
|
|
maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize)
|
|
bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
|
|
fr := &framedReader{
|
|
r: res.Body,
|
|
maxSize: maxMessageSize,
|
|
onNewFrame: bounded.reset,
|
|
}
|
|
|
|
zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder)
|
|
if zdec != nil {
|
|
if err := zdec.Reset(fr); err != nil {
|
|
// Reset can fail if the previous stream is in a bad state; drop
|
|
// the decoder and create a fresh one.
|
|
zdec = nil
|
|
}
|
|
}
|
|
if zdec == nil {
|
|
zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
|
|
if err != nil {
|
|
res.Body.Close()
|
|
return nil, fmt.Errorf("creating zstd decoder: %w", err)
|
|
}
|
|
}
|
|
bounded.r = zdec
|
|
|
|
return &MapSession{
|
|
res: res,
|
|
stream: opts.Stream,
|
|
noiseDoer: nc.Do,
|
|
zdec: zdec,
|
|
jdec: json.NewDecoder(bounded),
|
|
}, nil
|
|
}
|