net/rioconn: add URO support for UDPConn

Implement UDP receive segment coalescing offload (URO) support for UDPConn.

Updates tailscale/corp#8610

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2026-02-19 08:25:06 -06:00
parent c4993354ed
commit 4ec71f6480
No known key found for this signature in database
8 changed files with 382 additions and 16 deletions

View File

@ -20,6 +20,11 @@ const (
// It should be smaller than or equal to the typical hardware offload size
// to avoid software fallback. Mellanox NICs typically support up to 64000.
defaultMaxUSOOffloadSize = 64000
// defaultMaxUROCoalesceSize is the default maximum coalesce size for URO.
// It should be greater than or equal to the typical hardware offload size
// to avoid software fallback. Mellanox NICs typically support up to 64000.
defaultMaxUROCoalesceSize = math.MaxUint16
)
// Config holds configuration for a RIO connection, independent of the transport protocol.
@ -95,6 +100,7 @@ func (o TxConfig) MaxPayloadLen() uint16 {
type UDPConfig struct {
// Config is the embedded, transport-independent connection configuration.
Config
// uso holds the UDP segmentation offload (USO) settings
// returned by [UDPConfig.USO].
uso USOConfig
// uro holds the UDP receive segment coalescing offload (URO) settings
// returned by [UDPConfig.URO].
uro UROConfig
}
// USO returns the UDP segmentation offload (USO) configuration.
@ -102,6 +108,11 @@ func (o UDPConfig) USO() *USOConfig {
return &o.uso
}
// URO exposes the settings for UDP receive segment coalescing offload (URO).
// The receiver is passed by value, so the returned pointer refers to a copy
// of the settings rather than to the caller's [UDPConfig].
func (o UDPConfig) URO() *UROConfig {
	uro := o.uro
	return &uro
}
// USOConfig holds the UDP segmentation offload (USO) configuration.
type USOConfig struct {
enabled bool
@ -121,3 +132,23 @@ func (o USOConfig) MaxOffloadSize() uint16 {
}
return cmp.Or(o.maxOffloadSize, defaultMaxUSOOffloadSize)
}
// UROConfig describes how UDP receive segment coalescing offload (URO)
// is configured for a connection.
type UROConfig struct {
	// enabled records whether URO was requested.
	enabled bool
	// maxCoalesceSize is the requested coalescing limit in bytes.
	// Zero selects [defaultMaxUROCoalesceSize].
	maxCoalesceSize uint16
}

// Enabled reports whether URO has been turned on.
func (o UROConfig) Enabled() bool {
	return o.enabled
}

// MaxCoalesceSize reports the upper bound, in bytes, on how much data from
// multiple packets may be coalesced into a single receive buffer.
// It is zero when URO is disabled, and [defaultMaxUROCoalesceSize]
// when URO is enabled but no explicit limit was configured.
func (o UROConfig) MaxCoalesceSize() uint16 {
	if o.Enabled() {
		return cmp.Or(o.maxCoalesceSize, defaultMaxUROCoalesceSize)
	}
	return 0
}

View File

@ -384,3 +384,19 @@ func WSAIoctlIn[Input any](conn syscall.Conn, code uint32, in Input) error {
})
return cmp.Or(controlErr, err)
}
// SetSockOption applies a socket option to the socket underlying conn.
// The option payload is the in-memory representation of value, and its
// length is the size of T. It is a type-safe wrapper that runs
// [windows.Setsockopt] from within [syscall.RawConn.Control].
//
// It returns the error from obtaining or controlling the raw connection
// if there is one, and otherwise the error from [windows.Setsockopt].
func SetSockOption[T any](conn syscall.Conn, level int32, optname int32, value T) error {
	rawConn, err := conn.SyscallConn()
	if err != nil {
		return err
	}
	var setErr error
	controlErr := rawConn.Control(func(fd uintptr) {
		optval := (*byte)(unsafe.Pointer(&value))
		optlen := int32(unsafe.Sizeof(value))
		setErr = windows.Setsockopt(windows.Handle(fd), level, optname, optval, optlen)
	})
	// A failure to run the control function takes precedence
	// over the result of the setsockopt call itself.
	return cmp.Or(controlErr, setErr)
}

View File

@ -4,8 +4,8 @@
//go:build windows
// Package rioconn provides [UDPConn], a UDP socket implementation
// that uses the Windows RIO API extensions and supports batched I/O
// and USO for improved performance on high-throughput UDP workloads.
// that uses the Windows RIO API extensions and supports batched I/O,
// USO and URO for improved performance on high-throughput UDP workloads.
package rioconn
import (

View File

@ -6,7 +6,9 @@ package rioconn
import (
"fmt"
"io"
"net"
"golang.org/x/net/ipv6"
"tailscale.com/net/packet"
)
@ -82,3 +84,34 @@ func coalescePackets(
}
return packets, bytes, packetSize, nil
}
// splitCoalescedPackets distributes the coalesced packets held in src across
// msgs, one packet per message. Every packet is packetSize bytes long, except
// possibly the last one, which may be shorter. Each message that receives
// a packet has its Addr set to addr.
//
// A packet that does not fit into the first buffer of its message is dropped:
// the message's N is set to zero and nothing is copied, yet the packet is
// still included in the returned packet and byte counts.
//
// A packetSize of zero or less means src holds exactly one packet,
// and an empty src is reported as one zero-length packet.
//
// The results are the number of leading msgs the caller should inspect
// (skipping those with N == 0) and the number of bytes of src they account for.
func splitCoalescedPackets(addr *net.UDPAddr, src []byte, packetSize int, msgs []ipv6.Message) (packets, bytes int) {
	if packetSize <= 0 {
		packetSize = len(src)
	}
	for i := range msgs {
		rest := src[bytes:]
		if len(rest) == 0 && i > 0 {
			// Everything has been consumed. The i > 0 guard lets
			// an empty src still yield one zero-length packet.
			break
		}
		// The trailing packet may be shorter than packetSize.
		chunk := rest[:min(packetSize, len(rest))]
		msg := &msgs[i]
		dst := msg.Buffers[0]
		n := 0 // stays zero if the packet is too large for dst
		if len(chunk) <= len(dst) {
			// TODO(nickkhyl): avoid the copy? We could transfer ownership of the underlying
			// buffer to the reader until the next read or an explicit release.
			n = copy(dst, chunk)
		}
		msg.N = n
		msg.Addr = addr
		bytes += len(chunk)
		packets++
	}
	return packets, bytes
}

View File

@ -5,8 +5,11 @@ package rioconn
import (
"bytes"
"net"
"net/netip"
"testing"
"golang.org/x/net/ipv6"
"tailscale.com/net/packet"
)
@ -332,3 +335,244 @@ func TestCoalescePackets(t *testing.T) {
})
}
}
// TestSplitCoalescedPackets is a table-driven test for [splitCoalescedPackets].
// Each case supplies a source buffer, a packet size and a set of destination
// messages, then checks the returned packet and byte counts, the contents of
// the filled messages, and that splitting performs no allocations.
func TestSplitCoalescedPackets(t *testing.T) {
addr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("192.0.2.0:50000"))
tests := []struct {
name string
// addr is the remote address expected to be stamped on every message.
addr *net.UDPAddr
// src holds the coalesced payload to split.
src []byte
// packetSize is the size of each coalesced packet; zero means src is a single packet.
packetSize int
// msgs are the destination messages passed to splitCoalescedPackets.
msgs []ipv6.Message
// wantMsgs are the expected first wantPackets messages.
wantMsgs []ipv6.Message
wantPackets int
wantBytes int
}{
{
name: "single-packet/zero-length",
addr: addr,
src: []byte{}, // zero-length src is treated as a single zero-length packet
msgs: makeMessages(2, 10),
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{
{},
},
N: 0,
},
},
wantPackets: 1,
wantBytes: 0,
},
{
name: "single-packet/no-packet-size",
addr: addr,
src: []byte{0x01, 0x02, 0x03},
msgs: makeMessages(2, 10),
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{{0x01, 0x02, 0x03}},
N: 3,
},
},
wantPackets: 1,
wantBytes: 3,
},
{
name: "single-packet/with-packet-size",
addr: addr,
src: []byte{0x01, 0x02, 0x03},
packetSize: 3,
msgs: makeMessages(2, 10),
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{{0x01, 0x02, 0x03}},
N: 3,
},
},
wantPackets: 1,
wantBytes: 3,
},
{
name: "single-packet/too-large-for-msg",
addr: addr,
src: []byte{0x01, 0x02, 0x03},
msgs: makeMessages(2, 2),
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{
{},
},
N: 0, // no bytes copied
},
},
wantPackets: 1, // but the packet is still counted
wantBytes: 3, // and all bytes are still counted as read from src
},
{
name: "single-packet/no-msgs",
addr: addr,
src: []byte{0x01, 0x02, 0x03},
msgs: nil, // no msgs to copy into
wantMsgs: nil,
wantPackets: 0,
wantBytes: 0,
},
{
name: "multiple-packets/equal-packet-size",
addr: addr,
src: []byte{
0x01, 0x02, 0x03, // first packet
0x04, 0x05, 0x06, // second packet
},
packetSize: 3,
msgs: makeMessages(3, 10),
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{{0x01, 0x02, 0x03}},
N: 3,
},
{
Addr: addr,
Buffers: [][]byte{{0x04, 0x05, 0x06}},
N: 3,
},
},
wantPackets: 2,
wantBytes: 6,
},
{
name: "multiple-packets/last-packet-smaller",
addr: addr,
src: []byte{
0x01, 0x02, 0x03, // first packet
0x04, 0x05, 0x06, // second packet
0x07, 0x08, // third packet, smaller than packetSize, ends the batch
},
packetSize: 3,
msgs: makeMessages(4, 10),
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{{0x01, 0x02, 0x03}},
N: 3,
},
{
Addr: addr,
Buffers: [][]byte{{0x04, 0x05, 0x06}},
N: 3,
},
{
Addr: addr,
Buffers: [][]byte{{0x07, 0x08}},
N: 2,
},
},
wantPackets: 3,
wantBytes: 8,
},
{
name: "multiple-packets/partial-fit",
addr: addr,
src: []byte{
0x01, 0x02, 0x03, // first packet
0x04, 0x05, 0x06, // second packet
0x07, 0x08, // third packet, smaller than packetSize, ends the batch
},
packetSize: 3,
msgs: makeMessages(2, 10), // can only fit the first two packets
wantMsgs: []ipv6.Message{
{
Addr: addr,
Buffers: [][]byte{{0x01, 0x02, 0x03}},
N: 3,
},
{
Addr: addr,
Buffers: [][]byte{{0x04, 0x05, 0x06}},
N: 3,
},
},
wantPackets: 2, // the third packet is not included in the msgs
wantBytes: 6, // and only the first two packets' bytes are counted as read from src
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// NOTE(review): the assertions live inside the measured function, so they
// run on every AllocsPerRun iteration and a failure may be reported
// repeatedly rather than once.
allocs := testing.AllocsPerRun(1000, func() {
packets, bytes := splitCoalescedPackets(tt.addr, tt.src, tt.packetSize, tt.msgs)
if packets != tt.wantPackets {
t.Errorf("packets: got %d; want %d", packets, tt.wantPackets)
}
if bytes != tt.wantBytes {
t.Errorf("bytes: got %d; want %d", bytes, tt.wantBytes)
}
// Only the first `packets` messages are expected to have been filled in.
checkMessagesEqual(t, tt.msgs[:packets], tt.wantMsgs)
})
// Splitting packets should not cause any allocations.
if allocs != 0 {
t.Errorf("unexpected allocations: got %f; want 0", allocs)
}
})
}
}
// makeMessages returns num messages, each carrying a single
// zero-filled buffer of size bytes.
func makeMessages(num, size int) []ipv6.Message {
	msgs := make([]ipv6.Message, num)
	for i := range msgs {
		msgs[i].Buffers = [][]byte{make([]byte, size)}
	}
	return msgs
}
// checkMessagesEqual verifies that got and want contain the same number of
// messages and that corresponding messages agree on the address, on N, and
// on the first N bytes of their first buffer. A length or N mismatch stops
// the test immediately; address and payload mismatches are reported as errors.
func checkMessagesEqual(t *testing.T, got, want []ipv6.Message) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("number of messages: got %d; want %d", len(got), len(want))
	}
	for i := range got {
		g, w := &got[i], &want[i]
		checkNetAddrEqual(t, g.Addr, w.Addr)
		if g.N != w.N {
			t.Fatalf("message %d, N: got %d; want %d", i, g.N, w.N)
		}
		gotBuf := g.Buffers[0]
		if g.N > len(gotBuf) {
			t.Fatalf("message %d, N: got %d exceeds buffer size %d", i, g.N, len(gotBuf))
		}
		// Compare only the bytes each message claims to hold.
		gotData, wantData := gotBuf[:g.N], w.Buffers[0][:w.N]
		if !bytes.Equal(gotData, wantData) {
			t.Errorf("message %d, buffer: got %v; want %v", i, gotData, wantData)
		}
	}
}
// checkNetAddrEqual reports a test error if got and want do not describe
// the same network address.
//
// Two nil addresses are equal, and a nil address never equals a non-nil one.
// Only [*net.UDPAddr] values are supported; they are compared by their
// [net.UDPAddr.AddrPort]. Any other type of got is reported as a type error.
func checkNetAddrEqual(t *testing.T, got, want net.Addr) {
	t.Helper()
	if got == nil && want == nil {
		return
	}
	if got == nil || want == nil {
		t.Errorf("address: got %v; want %v", got, want)
		return
	}
	switch got := got.(type) {
	case *net.UDPAddr:
		// Use a distinct name for the asserted value so that, when the
		// assertion fails, the error still reports the actual dynamic type
		// of want instead of the nil *net.UDPAddr the assertion produced.
		wantUDP, ok := want.(*net.UDPAddr)
		if !ok {
			t.Errorf("address type: got %T; want %T", got, want)
			return
		}
		if got.AddrPort() != wantUDP.AddrPort() {
			t.Errorf("address: got %v; want %v", got, wantUDP)
		}
	default:
		t.Errorf("address type: got %T; want %T", got, want)
	}
}

View File

@ -83,3 +83,10 @@ func USO(enabled bool) UDPOption {
opts.uso.enabled = enabled
})
}
// URO returns an option that turns UDP receive segment coalescing
// offload (URO) on or off, according to enabled.
func URO(enabled bool) UDPOption {
	setURO := func(cfg *UDPConfig) {
		cfg.uro.enabled = enabled
	}
	return udpOption(setURO)
}

View File

@ -206,18 +206,21 @@ func FuzzUDPSendReceiveBatch(f *testing.F) {
packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4}
numIterations := []uint16{1024}
uso := []bool{false, true}
uro := []bool{false, true}
for _, packetLen := range packetSizes {
for _, numIter := range numIterations {
for _, batchSize := range batchSizes {
for _, usoEnabled := range uso {
f.Add(packetLen, numIter, batchSize, batchSize, usoEnabled)
for _, uroEnabled := range uro {
f.Add(packetLen, numIter, batchSize, batchSize, usoEnabled, uroEnabled)
}
}
}
}
}
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16, usoEnabled bool) {
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16, usoEnabled, uroEnabled bool) {
network := "udp4"
maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4)
@ -245,6 +248,7 @@ func FuzzUDPSendReceiveBatch(f *testing.F) {
[]rioconn.UDPOption{
rioconn.RxMemoryLimit(512 << 10),
rioconn.TxMemoryLimit(128 << 10),
rioconn.URO(uroEnabled),
},
)
})

View File

@ -25,17 +25,40 @@ import (
// otherwise specified by the method.
type udpRx struct {
// udpNx is the embedded state; init passes it the connection,
// the per-request data buffer size and the Rx memory limit.
udpNx
// useURO is set by init once the UDP_RECV_MAX_COALESCED_SIZE socket option
// has been applied. When it is set, a receive result may hold several
// coalesced packets, and the per-packet size is read from the
// UDP_COALESCED_INFO control message.
useURO bool // whether URO is enabled
// pendingResultIdx is the index in [udpNx.results]
// of the next pending result to process.
pendingResultIdx int
// pendingResultOffset is the offset into the pending result's data.
// It is used when the result contains coalesced packets and only
// part of the data has been processed and returned to the caller.
// It is reset to zero whenever processing moves on to another result.
pendingResultOffset int
}
// init initializes the receive half of a [UDPConn] with the
// specified underlying connection and options.
func (rx *udpRx) init(conn *conn, options UDPConfig) error {
// Without URO, the data buffer for each receive request only needs
// to hold a single packet's payload.
dataSize := min(options.Rx().MaxPayloadLen(), MaxUDPPayload)
var dataSize uint16
if uro := options.URO(); uro.Enabled() {
// When URO is enabled, the data buffer for each receive request
// must be large enough to hold multiple coalesced packets up
// to the maximum coalescing size, or a single packet up to
// the maximum payload size, whichever is larger.
dataSize = max(uro.MaxCoalesceSize(), options.Rx().MaxPayloadLen())
maxCoalesceSize := uint32(uro.MaxCoalesceSize())
err := SetSockOption(conn, windows.IPPROTO_UDP,
windows.UDP_RECV_MAX_COALESCED_SIZE,
maxCoalesceSize,
)
if err != nil {
return fmt.Errorf("failed to enable URO: %w", err)
}
rx.useURO = true
} else {
// Otherwise, the data buffer for each receive request only needs
// to hold a single packet's payload.
dataSize = min(options.Rx().MaxPayloadLen(), MaxUDPPayload)
}
if err := rx.udpNx.init(conn, dataSize, options.Rx().MemoryLimit()); err != nil {
return fmt.Errorf("failed to initialize udpRx: %w", err)
}
@ -124,6 +147,7 @@ func (rx *udpRx) awaitCompletionsLocked() error {
rx.results = rx.results[:cap(rx.results)]
rx.pendingResultIdx = 0
rx.pendingResultOffset = 0
var count uint32
for {
@ -172,6 +196,7 @@ func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error
r, err := req.CompleteReceive(res.Status, res.BytesTransferred)
if err != nil {
rx.pendingResultIdx++
rx.pendingResultOffset = 0
if err == windows.WSAEMSGSIZE {
// The packet is larger than [RxConfig.MaxPayloadLen].
// Skip it and try to process the next one, if any.
@ -189,16 +214,22 @@ func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error
return n, fmt.Errorf("invalid remote address: %w", err)
}
if r.Len() <= len(msgs[n].Buffers[0]) {
// TODO(nickkhyl): avoid the copy? We could transfer ownership of the underlying
// buffer to the reader until the next read or an explicit release.
msgs[n].N = copy(msgs[n].Buffers[0], r.Bytes())
} else {
msgs[n].N = 0 // packet is too large; ignore it
var packetSize uint32
if rx.useURO {
// When URO is enabled, the result may contain multiple coalesced packets,
// so we need to get the size of the first packet to know how to split them.
packetSize = r.ControlMessages().GetUInt32(windows.IPPROTO_UDP, windows.UDP_COALESCED_INFO)
}
msgs[n].Addr = udpAddr
rx.pendingResultIdx++
n++
packetsProcessed, bytesProcessed := splitCoalescedPackets(
udpAddr, r.Bytes()[rx.pendingResultOffset:],
int(packetSize), msgs[n:],
)
rx.pendingResultOffset += bytesProcessed
if rx.pendingResultOffset >= r.Len() {
rx.pendingResultIdx++
rx.pendingResultOffset = 0
}
n += packetsProcessed
}
return n, nil
}