net/udprelay: use batching.Conn (#16866)

This significantly improves throughput of a peer relay server on Linux.

Server.packetReadLoop no longer passes sockets down the stack. Instead,
packet handling methods return a netip.AddrPort and []byte, which
packetReadLoop gathers together for eventual batched writes on the
appropriate socket(s).

Updates tailscale/corp#31164

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited 2025-08-19 14:44:39 -07:00 committed by GitHub
parent 5c560d7489
commit d4b7200129
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 153 additions and 63 deletions

View File

@ -311,7 +311,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+
tailscale.com/metrics from tailscale.com/derp+ tailscale.com/metrics from tailscale.com/derp+
tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+ tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+
💣 tailscale.com/net/batching from tailscale.com/wgengine/magicsock 💣 tailscale.com/net/batching from tailscale.com/wgengine/magicsock+
tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+
tailscale.com/net/connstats from tailscale.com/net/tstun+ tailscale.com/net/connstats from tailscale.com/net/tstun+
tailscale.com/net/dns from tailscale.com/cmd/tailscaled+ tailscale.com/net/dns from tailscale.com/cmd/tailscaled+

View File

@ -32,7 +32,6 @@ type Conn interface {
// message may fall on either side of a nonzero. // message may fall on either side of a nonzero.
// //
// Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize(). // Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize().
// len(msgs) must be at least MinReadBatchMsgsLen().
ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
// WriteBatchTo writes buffs to addr. // WriteBatchTo writes buffs to addr.
// //

View File

@ -19,3 +19,5 @@ var controlMessageSize = 0
func MinControlMessageSize() int { func MinControlMessageSize() int {
return controlMessageSize return controlMessageSize
} }
const IdealBatchSize = 1

View File

@ -384,7 +384,7 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
} }
// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades // TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
// pconn to a [Conn] if appropriate. A batch size of MinReadBatchMsgsLen() is // pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is
// suggested for the best performance. // suggested for the best performance.
func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
@ -457,6 +457,4 @@ func MinControlMessageSize() int {
return controlMessageSize return controlMessageSize
} }
func MinReadBatchMsgsLen() int { const IdealBatchSize = 128
return 128
}

View File

@ -310,7 +310,7 @@ func TestMinReadBatchMsgsLen(t *testing.T) {
// So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is // So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
// shaped for wireguard-go to control packet memory, these values should be // shaped for wireguard-go to control packet memory, these values should be
// aligned. // aligned.
if MinReadBatchMsgsLen() != conn.IdealBatchSize { if IdealBatchSize != conn.IdealBatchSize {
t.Fatalf("MinReadBatchMsgsLen():%d != conn.IdealBatchSize(): %d", MinReadBatchMsgsLen(), conn.IdealBatchSize) t.Fatalf("IdealBatchSize: %d != conn.IdealBatchSize(): %d", IdealBatchSize, conn.IdealBatchSize)
} }
} }

View File

@ -20,8 +20,11 @@ import (
"time" "time"
"go4.org/mem" "go4.org/mem"
"golang.org/x/net/ipv6"
"tailscale.com/client/local" "tailscale.com/client/local"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/net/batching"
"tailscale.com/net/netaddr"
"tailscale.com/net/netcheck" "tailscale.com/net/netcheck"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/packet" "tailscale.com/net/packet"
@ -57,9 +60,9 @@ type Server struct {
bindLifetime time.Duration bindLifetime time.Duration
steadyStateLifetime time.Duration steadyStateLifetime time.Duration
bus *eventbus.Bus bus *eventbus.Bus
uc4 *net.UDPConn // always non-nil uc4 batching.Conn // always non-nil
uc4Port uint16 // always nonzero uc4Port uint16 // always nonzero
uc6 *net.UDPConn // may be nil if IPv6 bind fails during initialization uc6 batching.Conn // may be nil if IPv6 bind fails during initialization
uc6Port uint16 // may be zero if IPv6 bind fails during initialization uc6Port uint16 // may be zero if IPv6 bind fails during initialization
closeOnce sync.Once closeOnce sync.Once
wg sync.WaitGroup wg sync.WaitGroup
@ -96,9 +99,9 @@ type serverEndpoint struct {
allocatedAt time.Time allocatedAt time.Time
} }
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, conn *net.UDPConn, serverDisco key.DiscoPublic) { func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
if senderIndex != 0 && senderIndex != 1 { if senderIndex != 0 && senderIndex != 1 {
return return nil, netip.AddrPort{}
} }
otherSender := 0 otherSender := 0
@ -121,15 +124,15 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
if err != nil { if err != nil {
// silently drop // silently drop
return return nil, netip.AddrPort{}
} }
if discoMsg.Generation == 0 { if discoMsg.Generation == 0 {
// Generation must be nonzero, silently drop // Generation must be nonzero, silently drop
return return nil, netip.AddrPort{}
} }
if e.handshakeGeneration[senderIndex] == discoMsg.Generation { if e.handshakeGeneration[senderIndex] == discoMsg.Generation {
// we've seen this generation before, silently drop // we've seen this generation before, silently drop
return return nil, netip.AddrPort{}
} }
e.handshakeGeneration[senderIndex] = discoMsg.Generation e.handshakeGeneration[senderIndex] = discoMsg.Generation
e.handshakeAddrPorts[senderIndex] = from e.handshakeAddrPorts[senderIndex] = from
@ -144,19 +147,18 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
gh.VNI.Set(e.vni) gh.VNI.Set(e.vni)
err = gh.Encode(reply) err = gh.Encode(reply)
if err != nil { if err != nil {
return return nil, netip.AddrPort{}
} }
reply = append(reply, disco.Magic...) reply = append(reply, disco.Magic...)
reply = serverDisco.AppendTo(reply) reply = serverDisco.AppendTo(reply)
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
reply = append(reply, box...) reply = append(reply, box...)
conn.WriteMsgUDPAddrPort(reply, nil, from) return reply, from
return
case *disco.BindUDPRelayEndpointAnswer: case *disco.BindUDPRelayEndpointAnswer:
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
if err != nil { if err != nil {
// silently drop // silently drop
return return nil, netip.AddrPort{}
} }
generation := e.handshakeGeneration[senderIndex] generation := e.handshakeGeneration[senderIndex]
if generation == 0 || // we have no active handshake if generation == 0 || // we have no active handshake
@ -164,23 +166,23 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake
!bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake !bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake
// silently drop // silently drop
return return nil, netip.AddrPort{}
} }
// Handshake complete. Update the binding for this sender. // Handshake complete. Update the binding for this sender.
e.boundAddrPorts[senderIndex] = from e.boundAddrPorts[senderIndex] = from
e.lastSeen[senderIndex] = time.Now() // record last seen as bound time e.lastSeen[senderIndex] = time.Now() // record last seen as bound time
return return nil, netip.AddrPort{}
default: default:
// unexpected message types, silently drop // unexpected message types, silently drop
return return nil, netip.AddrPort{}
} }
} }
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, conn *net.UDPConn, serverDisco key.DiscoPublic) { func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
senderRaw, isDiscoMsg := disco.Source(b) senderRaw, isDiscoMsg := disco.Source(b)
if !isDiscoMsg { if !isDiscoMsg {
// Not a Disco message // Not a Disco message
return return nil, netip.AddrPort{}
} }
sender := key.DiscoPublicFromRaw32(mem.B(senderRaw)) sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
senderIndex := -1 senderIndex := -1
@ -191,63 +193,51 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by
senderIndex = 1 senderIndex = 1
default: default:
// unknown Disco public key // unknown Disco public key
return return nil, netip.AddrPort{}
} }
const headerLen = len(disco.Magic) + key.DiscoPublicRawLen const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:]) discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:])
if !ok { if !ok {
// unable to decrypt the Disco payload // unable to decrypt the Disco payload
return return nil, netip.AddrPort{}
} }
discoMsg, err := disco.Parse(discoPayload) discoMsg, err := disco.Parse(discoPayload)
if err != nil { if err != nil {
// unable to parse the Disco payload // unable to parse the Disco payload
return return nil, netip.AddrPort{}
} }
e.handleDiscoControlMsg(from, senderIndex, discoMsg, conn, serverDisco) return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco)
} }
func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, rxSocket, otherAFSocket *net.UDPConn, serverDisco key.DiscoPublic) { func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
if !gh.Control { if !gh.Control {
if !e.isBound() { if !e.isBound() {
// not a control packet, but serverEndpoint isn't bound // not a control packet, but serverEndpoint isn't bound
return return nil, netip.AddrPort{}
} }
var to netip.AddrPort
switch { switch {
case from == e.boundAddrPorts[0]: case from == e.boundAddrPorts[0]:
e.lastSeen[0] = time.Now() e.lastSeen[0] = time.Now()
to = e.boundAddrPorts[1] return b, e.boundAddrPorts[1]
case from == e.boundAddrPorts[1]: case from == e.boundAddrPorts[1]:
e.lastSeen[1] = time.Now() e.lastSeen[1] = time.Now()
to = e.boundAddrPorts[0] return b, e.boundAddrPorts[0]
default: default:
// unrecognized source // unrecognized source
return return nil, netip.AddrPort{}
} }
// Relay the packet towards the other party via the socket associated
// with the destination's address family. If source and destination
// address families are matching we tx on the same socket the packet
// was received (rxSocket), otherwise we use the "other" socket
// (otherAFSocket). [Server] makes no use of dual-stack sockets.
if from.Addr().Is4() == to.Addr().Is4() {
rxSocket.WriteMsgUDPAddrPort(b, nil, to)
} else if otherAFSocket != nil {
otherAFSocket.WriteMsgUDPAddrPort(b, nil, to)
}
return
} }
if gh.Protocol != packet.GeneveProtocolDisco { if gh.Protocol != packet.GeneveProtocolDisco {
// control packet, but not Disco // control packet, but not Disco
return return nil, netip.AddrPort{}
} }
msg := b[packet.GeneveFixedHeaderLength:] msg := b[packet.GeneveFixedHeaderLength:]
e.handleSealedDiscoControlMsg(from, msg, rxSocket, serverDisco) return e.handleSealedDiscoControlMsg(from, msg, serverDisco)
} }
func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool { func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
@ -338,10 +328,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
} }
s.wg.Add(1) s.wg.Add(1)
go s.packetReadLoop(s.uc4, s.uc6) go s.packetReadLoop(s.uc4, s.uc6, true)
if s.uc6 != nil { if s.uc6 != nil {
s.wg.Add(1) s.wg.Add(1)
go s.packetReadLoop(s.uc6, s.uc4) go s.packetReadLoop(s.uc6, s.uc4, false)
} }
s.wg.Add(1) s.wg.Add(1)
go s.endpointGCLoop() go s.endpointGCLoop()
@ -425,6 +415,41 @@ func (s *Server) addrDiscoveryLoop() {
} }
} }
// This is a compile-time assertion that [singlePacketConn] implements the
// [batching.Conn] interface.
var _ batching.Conn = (*singlePacketConn)(nil)
// singlePacketConn implements [batching.Conn] with single packet syscall
// operations.
//
// It is the fallback used by [Server.listenOn] when
// [batching.TryUpgradeToConn] does not hand back a [batching.Conn] for the
// freshly bound socket. Each ReadBatch call reads exactly one datagram, and
// WriteBatchTo writes one datagram per buffer.
type singlePacketConn struct {
// The embedded conn performs the underlying reads and writes. Only ReadBatch
// and WriteBatchTo are defined on singlePacketConn itself; presumably the
// remaining [batching.Conn] methods (e.g. LocalAddr) are satisfied through
// promotion from this field.
*net.UDPConn
}
// ReadBatch reads a single datagram from the underlying [*net.UDPConn] into
// msgs[0].Buffers[0]. On success it records the datagram length in msgs[0].N,
// stores the unmapped source address in msgs[0].Addr as a [*net.UDPAddr], and
// reports a batch of one message. The flags argument is ignored.
//
// Only the first element of msgs is ever populated; the caller must supply at
// least one message with at least one buffer.
func (c *singlePacketConn) ReadBatch(msgs []ipv6.Message, _ int) (int, error) {
	first := &msgs[0]
	size, src, err := c.UDPConn.ReadFromUDPAddrPort(first.Buffers[0])
	if err != nil {
		return 0, err
	}
	first.Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(src))
	first.N = size
	return 1, nil
}
// WriteBatchTo writes every element of buffs to addr, in order, using one
// WriteToUDPAddrPort call per buffer. It stops at, and returns, the first
// error encountered; buffers after the failing one are not written.
//
// If geneve.VNI is set, geneve is encoded into each buffer and the whole
// buffer is written (offset is not consulted). Otherwise the first offset
// bytes of each buffer are skipped and only the remainder is written.
//
// An error from encoding the Geneve header is returned to the caller instead
// of being dropped, so a datagram is never sent with a header that failed to
// encode.
func (c *singlePacketConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
	for _, buff := range buffs {
		if geneve.VNI.IsSet() {
			// Encode can fail (its error is checked elsewhere in this
			// package); surface that rather than transmitting the buffer
			// with an unencoded header.
			if err := geneve.Encode(buff); err != nil {
				return err
			}
		} else {
			buff = buff[offset:]
		}
		if _, err := c.UDPConn.WriteToUDPAddrPort(buff, addr); err != nil {
			return err
		}
	}
	return nil
}
// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if // listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if
// we manage to bind the IPv4 socket. // we manage to bind the IPv4 socket.
// //
@ -433,7 +458,10 @@ func (s *Server) addrDiscoveryLoop() {
// across IPv4 and IPv6 if the requested port is zero. // across IPv4 and IPv6 if the requested port is zero.
// //
// TODO: make these "re-bindable" in similar fashion to magicsock as a means to // TODO: make these "re-bindable" in similar fashion to magicsock as a means to
// deal with EDR software closing them. http://go/corp/30118 // deal with EDR software closing them. http://go/corp/30118. We could re-use
// [magicsock.RebindingConn], which would also remove the need for
// [singlePacketConn], as [magicsock.RebindingConn] also handles fallback to
// single packet syscall operations.
func (s *Server) listenOn(port int) error { func (s *Server) listenOn(port int) error {
for _, network := range []string{"udp4", "udp6"} { for _, network := range []string{"udp4", "udp6"} {
uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
@ -462,11 +490,16 @@ func (s *Server) listenOn(port int) error {
} }
return err return err
} }
pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize)
bc, ok := pc.(batching.Conn)
if !ok {
bc = &singlePacketConn{uc}
}
if network == "udp4" { if network == "udp4" {
s.uc4 = uc s.uc4 = bc
s.uc4Port = uint16(portUint) s.uc4Port = uint16(portUint)
} else { } else {
s.uc6 = uc s.uc6 = bc
s.uc6Port = uint16(portUint) s.uc6Port = uint16(portUint)
} }
} }
@ -526,18 +559,18 @@ func (s *Server) endpointGCLoop() {
} }
} }
func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSocket *net.UDPConn) { func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to netip.AddrPort) {
if stun.Is(b) && b[1] == 0x01 { if stun.Is(b) && b[1] == 0x01 {
// A b[1] value of 0x01 (STUN method binding) is sufficiently // A b[1] value of 0x01 (STUN method binding) is sufficiently
// non-overlapping with the Geneve header where the LSB is always 0 // non-overlapping with the Geneve header where the LSB is always 0
// (part of 6 "reserved" bits). // (part of 6 "reserved" bits).
s.netChecker.ReceiveSTUNPacket(b, from) s.netChecker.ReceiveSTUNPacket(b, from)
return return nil, netip.AddrPort{}
} }
gh := packet.GeneveHeader{} gh := packet.GeneveHeader{}
err := gh.Decode(b) err := gh.Decode(b)
if err != nil { if err != nil {
return return nil, netip.AddrPort{}
} }
// TODO: consider performance implications of holding s.mu for the remainder // TODO: consider performance implications of holding s.mu for the remainder
// of this method, which does a bunch of disco/crypto work depending. Keep // of this method, which does a bunch of disco/crypto work depending. Keep
@ -547,13 +580,13 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSo
e, ok := s.byVNI[gh.VNI.Get()] e, ok := s.byVNI[gh.VNI.Get()]
if !ok { if !ok {
// unknown VNI // unknown VNI
return return nil, netip.AddrPort{}
} }
e.handlePacket(from, gh, b, rxSocket, otherAFSocket, s.discoPublic) return e.handlePacket(from, gh, b, s.discoPublic)
} }
func (s *Server) packetReadLoop(readFromSocket, otherSocket *net.UDPConn) { func (s *Server) packetReadLoop(readFromSocket, otherSocket batching.Conn, readFromSocketIsIPv4 bool) {
defer func() { defer func() {
// We intentionally close the [Server] if we encounter a socket read // We intentionally close the [Server] if we encounter a socket read
// error below, at least until socket "re-binding" is implemented as // error below, at least until socket "re-binding" is implemented as
@ -564,15 +597,73 @@ func (s *Server) packetReadLoop(readFromSocket, otherSocket *net.UDPConn) {
s.wg.Done() s.wg.Done()
s.Close() s.Close()
}() }()
b := make([]byte, 1<<16-1)
msgs := make([]ipv6.Message, batching.IdealBatchSize)
for i := range msgs {
msgs[i].OOB = make([]byte, batching.MinControlMessageSize())
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Buffers[0] = make([]byte, 1<<16-1)
}
writeBuffsByDest := make(map[netip.AddrPort][][]byte, batching.IdealBatchSize)
for { for {
for i := range msgs {
msgs[i] = ipv6.Message{Buffers: msgs[i].Buffers, OOB: msgs[i].OOB[:cap(msgs[i].OOB)]}
}
// TODO: extract laddr from IP_PKTINFO for use in reply // TODO: extract laddr from IP_PKTINFO for use in reply
n, from, err := readFromSocket.ReadFromUDPAddrPort(b) // ReadBatch will split coalesced datagrams before returning, which
// WriteBatchTo will re-coalesce further down. We _could_ be more
// efficient and not split datagrams that belong to the same VNI if they
// are non-control/handshake packets. We pay the memmove/memcopy
// performance penalty for now in the interest of simple single packet
// handlers.
n, err := readFromSocket.ReadBatch(msgs, 0)
if err != nil { if err != nil {
s.logf("error reading from socket(%v): %v", readFromSocket.LocalAddr(), err) s.logf("error reading from socket(%v): %v", readFromSocket.LocalAddr(), err)
return return
} }
s.handlePacket(from, b[:n], readFromSocket, otherSocket)
for _, msg := range msgs[:n] {
if msg.N == 0 {
continue
}
buf := msg.Buffers[0][:msg.N]
from := msg.Addr.(*net.UDPAddr).AddrPort()
write, to := s.handlePacket(from, buf)
if !to.IsValid() {
continue
}
if from.Addr().Is4() == to.Addr().Is4() || otherSocket != nil {
buffs, ok := writeBuffsByDest[to]
if !ok {
buffs = make([][]byte, 0, batching.IdealBatchSize)
}
buffs = append(buffs, write)
writeBuffsByDest[to] = buffs
} else {
// This is unexpected. We should never produce a packet to write
// to the "other" socket if the other socket is nil/unbound.
// [Server.handlePacket] has to see a packet from a particular
// address family at least once in order for it to return a
// packet to write towards a dest for the same address family.
s.logf("[unexpected] packet from: %v produced packet to: %v while otherSocket is nil", from, to)
}
}
for dest, buffs := range writeBuffsByDest {
// Write the packet batches via the socket associated with the
// destination's address family. If source and destination address
// families are matching we tx on the same socket the packet was
// received, otherwise we use the "other" socket. [Server] makes no
// use of dual-stack sockets.
if dest.Addr().Is4() == readFromSocketIsIPv4 {
readFromSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0)
} else {
otherSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0)
}
delete(writeBuffsByDest, dest)
}
} }
} }