mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 22:31:03 +02:00 
			
		
		
		
	ServerEndpoint will be used within magicsock and potentially elsewhere, which should be possible without needing to import the server implementation itself. Updates tailscale/corp#27502 Signed-off-by: Jordan Whited <jordan@tailscale.com>
		
			
				
	
	
		
			479 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			479 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| // Package udprelay contains constructs for relaying Disco and WireGuard packets
 | |
| // between Tailscale clients over UDP. This package is currently considered
 | |
| // experimental.
 | |
| package udprelay
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/rand"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"net/netip"
 | |
| 	"slices"
 | |
| 	"strconv"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"go4.org/mem"
 | |
| 	"tailscale.com/disco"
 | |
| 	"tailscale.com/net/packet"
 | |
| 	"tailscale.com/net/udprelay/endpoint"
 | |
| 	"tailscale.com/tstime"
 | |
| 	"tailscale.com/types/key"
 | |
| )
 | |
| 
 | |
const (
	// defaultBindLifetime is how long an allocated endpoint may stay unbound.
	// Three potentially slow paths must all complete inside this window:
	// client to [Server] and back for each client, plus the side channel
	// (e.g. DERP) the clients use to swap [endpoint.ServerEndpoint] details.
	// Budgeting a generous 10s for each gives 30s. Erring long only costs some
	// extra [Server] resources (VNIs, RAM, etc) held by endpoints that never
	// bind; erring short risks failing path discovery outright.
	defaultBindLifetime = 30 * time.Second

	// defaultSteadyStateLifetime is how long either client of a bound endpoint
	// may go unseen before the endpoint is reclaimed.
	defaultSteadyStateLifetime = 5 * time.Minute
)
 | |
| 
 | |
| // Server implements an experimental UDP relay server.
 | |
| type Server struct {
 | |
| 	// disco keypair used as part of 3-way bind handshake
 | |
| 	disco       key.DiscoPrivate
 | |
| 	discoPublic key.DiscoPublic
 | |
| 
 | |
| 	bindLifetime        time.Duration
 | |
| 	steadyStateLifetime time.Duration
 | |
| 
 | |
| 	// addrPorts contains the ip:port pairs returned as candidate server
 | |
| 	// endpoints in response to an allocation request.
 | |
| 	addrPorts []netip.AddrPort
 | |
| 
 | |
| 	uc *net.UDPConn
 | |
| 
 | |
| 	closeOnce sync.Once
 | |
| 	wg        sync.WaitGroup
 | |
| 	closeCh   chan struct{}
 | |
| 	closed    bool
 | |
| 
 | |
| 	mu        sync.Mutex // guards the following fields
 | |
| 	lamportID uint64
 | |
| 	vniPool   []uint32 // the pool of available VNIs
 | |
| 	byVNI     map[uint32]*serverEndpoint
 | |
| 	byDisco   map[pairOfDiscoPubKeys]*serverEndpoint
 | |
| }
 | |
| 
 | |
// pairOfDiscoPubKeys is a pair of key.DiscoPublic. It must be constructed via
// newPairOfDiscoPubKeys to ensure lexicographical ordering. That ordering makes
// the pair usable as a Server.byDisco map key no matter which order the two
// client keys were supplied in.
type pairOfDiscoPubKeys [2]key.DiscoPublic
 | |
| 
 | |
| func (p pairOfDiscoPubKeys) String() string {
 | |
| 	return fmt.Sprintf("%s <=> %s", p[0].ShortString(), p[1].ShortString())
 | |
| }
 | |
| 
 | |
| func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys {
 | |
| 	pair := pairOfDiscoPubKeys([2]key.DiscoPublic{discoA, discoB})
 | |
| 	slices.SortFunc(pair[:], func(a, b key.DiscoPublic) int {
 | |
| 		return a.Compare(b)
 | |
| 	})
 | |
| 	return pair
 | |
| }
 | |
| 
 | |
// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state.
// serverEndpoint methods are not thread-safe; the Server methods in this file
// only touch serverEndpoints while holding Server.mu.
type serverEndpoint struct {
	// discoPubKeys contains the key.DiscoPublic of the served clients. The
	// indexing of this array aligns with the following fields, e.g.
	// discoSharedSecrets[0] is the shared secret to use when sealing
	// Disco protocol messages for transmission towards discoPubKeys[0].
	//
	// handshakeState tracks each client's progress through the bind
	// handshake. addrPorts holds each client's source address, fixed by its
	// first BindUDPRelayEndpoint message. lastSeen is set when a client
	// completes the handshake and refreshed on every data packet relayed from
	// it. challenge holds the random bytes each client must echo back in its
	// BindUDPRelayEndpointAnswer.
	discoPubKeys       pairOfDiscoPubKeys
	discoSharedSecrets [2]key.DiscoShared
	handshakeState     [2]disco.BindUDPRelayHandshakeState
	addrPorts          [2]netip.AddrPort
	lastSeen           [2]time.Time // TODO(jwhited): consider using mono.Time
	challenge          [2][disco.BindUDPRelayEndpointChallengeLen]byte

	lamportID   uint64    // Server.lamportID value at allocation time
	vni         uint32    // taken from Server.vniPool; key into Server.byVNI
	allocatedAt time.Time // start of the bind window checked by isExpired
}
 | |
| 
 | |
// handleDiscoControlMsg advances the bind handshake for the client at
// senderIndex (0 or 1, an index into e.discoPubKeys). It acts on an
// already-decrypted Disco message received from 'from'.
//
// Each client's handshake proceeds:
//
//	Init --BindUDPRelayEndpoint--> ChallengeSent --BindUDPRelayEndpointAnswer--> AnswerReceived
//
// A BindUDPRelayEndpoint causes a sealed BindUDPRelayEndpointChallenge to be
// written back to 'from' via uw. While in ChallengeSent, a retransmitted bind
// causes the challenge to be re-sent. The first bind fixes the client's source
// address; later messages from any other source are discarded. An answer
// matching e.challenge marks the client bound and records lastSeen. All other
// message types, states, and errors are silently ignored.
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) {
	if senderIndex != 0 && senderIndex != 1 {
		return
	}
	handshakeState := e.handshakeState[senderIndex]
	if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived {
		// this sender is already bound
		return
	}
	switch discoMsg := discoMsg.(type) {
	case *disco.BindUDPRelayEndpoint:
		switch handshakeState {
		case disco.BindUDPRelayHandshakeStateInit:
			// set sender addr
			e.addrPorts[senderIndex] = from
			fallthrough
		case disco.BindUDPRelayHandshakeStateChallengeSent:
			if from != e.addrPorts[senderIndex] {
				// this is a later arriving bind from a different source, or
				// a retransmit and the sender's source has changed, discard
				return
			}
			// Reply layout: Geneve header (control, Disco), Disco magic,
			// server disco public key, then the challenge sealed for the sender.
			m := new(disco.BindUDPRelayEndpointChallenge)
			copy(m.Challenge[:], e.challenge[senderIndex][:])
			reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
			gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
			err := gh.Encode(reply)
			if err != nil {
				return
			}
			reply = append(reply, disco.Magic...)
			reply = serverDisco.AppendTo(reply)
			box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
			reply = append(reply, box...)
			// best-effort: the write result is not checked
			uw.WriteMsgUDPAddrPort(reply, nil, from)
			// set new state
			e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent
			return
		default:
			// disco.BindUDPRelayEndpoint is unexpected in all other handshake states
			return
		}
	case *disco.BindUDPRelayEndpointAnswer:
		switch handshakeState {
		case disco.BindUDPRelayHandshakeStateChallengeSent:
			if from != e.addrPorts[senderIndex] {
				// sender source has changed
				return
			}
			if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) {
				// bad answer
				return
			}
			// sender is now bound
			// TODO: Consider installing a fast path via netfilter or similar to
			// relay (NAT) data packets for this serverEndpoint.
			e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived
			// record last seen as bound time
			e.lastSeen[senderIndex] = time.Now()
			return
		default:
			// disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake
			// states, or we've already handled it
			return
		}
	default:
		// unexpected Disco message type
		return
	}
}
 | |
| 
 | |
| func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
 | |
| 	senderRaw, isDiscoMsg := disco.Source(b)
 | |
| 	if !isDiscoMsg {
 | |
| 		// Not a Disco message
 | |
| 		return
 | |
| 	}
 | |
| 	sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
 | |
| 	senderIndex := -1
 | |
| 	switch {
 | |
| 	case sender.Compare(e.discoPubKeys[0]) == 0:
 | |
| 		senderIndex = 0
 | |
| 	case sender.Compare(e.discoPubKeys[1]) == 0:
 | |
| 		senderIndex = 1
 | |
| 	default:
 | |
| 		// unknown Disco public key
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
 | |
| 	discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:])
 | |
| 	if !ok {
 | |
| 		// unable to decrypt the Disco payload
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	discoMsg, err := disco.Parse(discoPayload)
 | |
| 	if err != nil {
 | |
| 		// unable to parse the Disco payload
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	e.handleDiscoControlMsg(from, senderIndex, discoMsg, uw, serverDisco)
 | |
| }
 | |
| 
 | |
// udpWriter is the single sending method the relay needs. packetReadLoop
// passes the Server's *net.UDPConn as the implementation.
type udpWriter interface {
	WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error)
}
 | |
| 
 | |
| func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
 | |
| 	if !gh.Control {
 | |
| 		if !e.isBound() {
 | |
| 			// not a control packet, but serverEndpoint isn't bound
 | |
| 			return
 | |
| 		}
 | |
| 		var to netip.AddrPort
 | |
| 		switch {
 | |
| 		case from == e.addrPorts[0]:
 | |
| 			e.lastSeen[0] = time.Now()
 | |
| 			to = e.addrPorts[1]
 | |
| 		case from == e.addrPorts[1]:
 | |
| 			e.lastSeen[1] = time.Now()
 | |
| 			to = e.addrPorts[0]
 | |
| 		default:
 | |
| 			// unrecognized source
 | |
| 			return
 | |
| 		}
 | |
| 		// relay packet
 | |
| 		uw.WriteMsgUDPAddrPort(b, nil, to)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if e.isBound() {
 | |
| 		// control packet, but serverEndpoint is already bound
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if gh.Protocol != packet.GeneveProtocolDisco {
 | |
| 		// control packet, but not Disco
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	msg := b[packet.GeneveFixedHeaderLength:]
 | |
| 	e.handleSealedDiscoControlMsg(from, msg, uw, serverDisco)
 | |
| }
 | |
| 
 | |
| func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
 | |
| 	if !e.isBound() {
 | |
| 		if now.Sub(e.allocatedAt) > bindLifetime {
 | |
| 			return true
 | |
| 		}
 | |
| 		return false
 | |
| 	}
 | |
| 	if now.Sub(e.lastSeen[0]) > steadyStateLifetime || now.Sub(e.lastSeen[1]) > steadyStateLifetime {
 | |
| 		return true
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| // isBound returns true if both clients have completed their 3-way handshake,
 | |
| // otherwise false.
 | |
| func (e *serverEndpoint) isBound() bool {
 | |
| 	return e.handshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived &&
 | |
| 		e.handshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived
 | |
| }
 | |
| 
 | |
| // NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
 | |
| // supported. Port may be 0, and what ultimately gets bound is returned as
 | |
| // 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
 | |
| // [endpoint.ServerEndpoint.AddrPorts] in response to Server.AllocateEndpoint()
 | |
| // requests.
 | |
| //
 | |
| // TODO: IPv6 support
 | |
| // TODO: dynamic addrs:port discovery
 | |
| func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) {
 | |
| 	s = &Server{
 | |
| 		disco:               key.NewDisco(),
 | |
| 		bindLifetime:        defaultBindLifetime,
 | |
| 		steadyStateLifetime: defaultSteadyStateLifetime,
 | |
| 		closeCh:             make(chan struct{}),
 | |
| 		byDisco:             make(map[pairOfDiscoPubKeys]*serverEndpoint),
 | |
| 		byVNI:               make(map[uint32]*serverEndpoint),
 | |
| 	}
 | |
| 	s.discoPublic = s.disco.Public()
 | |
| 	// TODO: instead of allocating 10s of MBs for the full pool, allocate
 | |
| 	// smaller chunks and increase as needed
 | |
| 	s.vniPool = make([]uint32, 0, 1<<24-1)
 | |
| 	for i := 1; i < 1<<24; i++ {
 | |
| 		s.vniPool = append(s.vniPool, uint32(i))
 | |
| 	}
 | |
| 	boundPort, err = s.listenOn(port)
 | |
| 	if err != nil {
 | |
| 		return nil, 0, err
 | |
| 	}
 | |
| 	addrPorts := make([]netip.AddrPort, 0, len(addrs))
 | |
| 	for _, addr := range addrs {
 | |
| 		addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort)))
 | |
| 		if err != nil {
 | |
| 			return nil, 0, err
 | |
| 		}
 | |
| 		addrPorts = append(addrPorts, addrPort)
 | |
| 	}
 | |
| 	s.addrPorts = addrPorts
 | |
| 	s.wg.Add(2)
 | |
| 	go s.packetReadLoop()
 | |
| 	go s.endpointGCLoop()
 | |
| 	return s, boundPort, nil
 | |
| }
 | |
| 
 | |
| func (s *Server) listenOn(port int) (int, error) {
 | |
| 	uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
 | |
| 	if err != nil {
 | |
| 		return 0, err
 | |
| 	}
 | |
| 	// TODO: set IP_PKTINFO sockopt
 | |
| 	_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
 | |
| 	if err != nil {
 | |
| 		s.uc.Close()
 | |
| 		return 0, err
 | |
| 	}
 | |
| 	boundPort, err := strconv.Atoi(boundPortStr)
 | |
| 	if err != nil {
 | |
| 		s.uc.Close()
 | |
| 		return 0, err
 | |
| 	}
 | |
| 	s.uc = uc
 | |
| 	return boundPort, nil
 | |
| }
 | |
| 
 | |
| // Close closes the server.
 | |
| func (s *Server) Close() error {
 | |
| 	s.closeOnce.Do(func() {
 | |
| 		s.mu.Lock()
 | |
| 		defer s.mu.Unlock()
 | |
| 		s.uc.Close()
 | |
| 		close(s.closeCh)
 | |
| 		s.wg.Wait()
 | |
| 		clear(s.byVNI)
 | |
| 		clear(s.byDisco)
 | |
| 		s.vniPool = nil
 | |
| 		s.closed = true
 | |
| 	})
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (s *Server) endpointGCLoop() {
 | |
| 	defer s.wg.Done()
 | |
| 	ticker := time.NewTicker(s.bindLifetime)
 | |
| 	defer ticker.Stop()
 | |
| 
 | |
| 	gc := func() {
 | |
| 		now := time.Now()
 | |
| 		// TODO: consider performance implications of scanning all endpoints and
 | |
| 		// holding s.mu for the duration. Keep it simple (and slow) for now.
 | |
| 		s.mu.Lock()
 | |
| 		defer s.mu.Unlock()
 | |
| 		for k, v := range s.byDisco {
 | |
| 			if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
 | |
| 				delete(s.byDisco, k)
 | |
| 				delete(s.byVNI, v.vni)
 | |
| 				s.vniPool = append(s.vniPool, v.vni)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-ticker.C:
 | |
| 			gc()
 | |
| 		case <-s.closeCh:
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
 | |
| 	gh := packet.GeneveHeader{}
 | |
| 	err := gh.Decode(b)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	// TODO: consider performance implications of holding s.mu for the remainder
 | |
| 	// of this method, which does a bunch of disco/crypto work depending. Keep
 | |
| 	// it simple (and slow) for now.
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 	e, ok := s.byVNI[gh.VNI]
 | |
| 	if !ok {
 | |
| 		// unknown VNI
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	e.handlePacket(from, gh, b, uw, s.discoPublic)
 | |
| }
 | |
| 
 | |
| func (s *Server) packetReadLoop() {
 | |
| 	defer func() {
 | |
| 		s.wg.Done()
 | |
| 		s.Close()
 | |
| 	}()
 | |
| 	b := make([]byte, 1<<16-1)
 | |
| 	for {
 | |
| 		// TODO: extract laddr from IP_PKTINFO for use in reply
 | |
| 		n, from, err := s.uc.ReadFromUDPAddrPort(b)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		s.handlePacket(from, b[:n], s.uc)
 | |
| 	}
 | |
| }
 | |
| 
 | |
// ErrServerClosed is returned by [Server.AllocateEndpoint] once the [Server]
// has been closed.
var (
	ErrServerClosed = errors.New("server closed")
)
 | |
| 
 | |
| // AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
 | |
| // of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
 | |
| // it is returned without modification/reallocation. AllocateEndpoint returns
 | |
| // [ErrServerClosed] if the server has been closed.
 | |
| func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.ServerEndpoint, error) {
 | |
| 	s.mu.Lock()
 | |
| 	defer s.mu.Unlock()
 | |
| 	if s.closed {
 | |
| 		return endpoint.ServerEndpoint{}, ErrServerClosed
 | |
| 	}
 | |
| 
 | |
| 	if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
 | |
| 		return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString())
 | |
| 	}
 | |
| 
 | |
| 	pair := newPairOfDiscoPubKeys(discoA, discoB)
 | |
| 	e, ok := s.byDisco[pair]
 | |
| 	if ok {
 | |
| 		// Return the existing allocation. Clients can resolve duplicate
 | |
| 		// [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID].
 | |
| 		//
 | |
| 		// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
 | |
| 		// to give the client a more accurate picture of the bind window.
 | |
| 		return endpoint.ServerEndpoint{
 | |
| 			ServerDisco:         s.discoPublic,
 | |
| 			AddrPorts:           s.addrPorts,
 | |
| 			VNI:                 e.vni,
 | |
| 			LamportID:           e.lamportID,
 | |
| 			BindLifetime:        tstime.GoDuration{Duration: s.bindLifetime},
 | |
| 			SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime},
 | |
| 		}, nil
 | |
| 	}
 | |
| 
 | |
| 	if len(s.vniPool) == 0 {
 | |
| 		return endpoint.ServerEndpoint{}, errors.New("VNI pool exhausted")
 | |
| 	}
 | |
| 
 | |
| 	s.lamportID++
 | |
| 	e = &serverEndpoint{
 | |
| 		discoPubKeys: pair,
 | |
| 		lamportID:    s.lamportID,
 | |
| 		allocatedAt:  time.Now(),
 | |
| 	}
 | |
| 	e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
 | |
| 	e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
 | |
| 	e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
 | |
| 	rand.Read(e.challenge[0][:])
 | |
| 	rand.Read(e.challenge[1][:])
 | |
| 
 | |
| 	s.byDisco[pair] = e
 | |
| 	s.byVNI[e.vni] = e
 | |
| 
 | |
| 	return endpoint.ServerEndpoint{
 | |
| 		ServerDisco:         s.discoPublic,
 | |
| 		AddrPorts:           s.addrPorts,
 | |
| 		VNI:                 e.vni,
 | |
| 		LamportID:           e.lamportID,
 | |
| 		BindLifetime:        tstime.GoDuration{Duration: s.bindLifetime},
 | |
| 		SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime},
 | |
| 	}, nil
 | |
| }
 |