net/rioconn: implement a batching UDPConn

In this commit, we implement a batching UDPConn using the previously added
RIO plumbing, the base conn, and requestRing. The non-batching Read/Write
methods wrap the batching variants.

Updates tailscale/corp#8610

Signed-off-by: Nick Khyl <nickk@tailscale.com>
Author: Nick Khyl <nickk@tailscale.com>
Date: 2026-02-19 08:22:08 -06:00
parent 05b7b04527
commit e5cb1f48a6
6 changed files with 1084 additions and 0 deletions


@@ -6,6 +6,7 @@
package rioconn
import (
"cmp"
"errors"
"fmt"
"iter"
@@ -13,6 +14,7 @@ import (
"net/netip"
"sync"
"syscall"
"unsafe"
"github.com/tailscale/wireguard-go/conn/winrio"
"golang.org/x/sys/windows"
@@ -363,3 +365,22 @@ func rioSocket(family, sotype, proto int32) (windows.Handle, error) {
windows.WSA_FLAG_OVERLAPPED
return windows.WSASocket(family, sotype, proto, nil, 0, rioWSAFlags)
}
// WSAIoctlIn issues an ioctl command with the provided code and input value
// on the connection's underlying socket. It is a type-safe shorthand for calling
// [syscall.RawConn.Control] with a function that invokes [windows.WSAIoctl]
// with the appropriate arguments, without any output buffer.
func WSAIoctlIn[Input any](conn syscall.Conn, code uint32, in Input) error {
rawConn, err := conn.SyscallConn()
if err != nil {
return err
}
controlErr := rawConn.Control(func(s uintptr) {
ret := uint32(0)
err = windows.WSAIoctl(windows.Handle(s), code,
(*byte)(unsafe.Pointer(&in)), uint32(unsafe.Sizeof(in)),
nil, 0, &ret, nil, 0,
)
})
return cmp.Or(controlErr, err)
}
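// For example, ListenUDP in udp.go below uses WSAIoctlIn to disable
// delivery of ICMP "Port Unreachable" resets as socket errors:
//
//	if err := WSAIoctlIn(udp, windows.SIO_UDP_CONNRESET, uint32(0)); err != nil {
//		return nil, fmt.Errorf("failed to disable SIO_UDP_CONNRESET: %w", err)
//	}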


@@ -3,6 +3,9 @@
//go:build windows
// Package rioconn provides [UDPConn], a UDP socket implementation
// that uses the Windows RIO API extensions and supports batched I/O
// for improved performance on high-throughput UDP workloads.
package rioconn
import (

net/rioconn/udp.go (new file)

@@ -0,0 +1,235 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/tailscale/wireguard-go/conn/winrio"
"golang.org/x/sys/windows"
)
const (
// MaxUDPPayloadIPv4 is the maximum UDP payload size over IPv4.
// IPv4 total length is 65535 bytes, including:
// - 20-byte IPv4 header (no options)
// - 8-byte UDP header
MaxUDPPayloadIPv4 = 1<<16 - 1 - 20 - 8
// MaxUDPPayloadIPv6 is the maximum UDP payload size over IPv6.
// The IPv6 payload length field excludes the 40-byte base header
// and includes the 8-byte UDP header.
MaxUDPPayloadIPv6 = 1<<16 - 1 - 8
// MaxUDPPayload is the maximum UDP payload size across IP versions.
MaxUDPPayload = max(MaxUDPPayloadIPv4, MaxUDPPayloadIPv6)
)
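// For reference, these constants evaluate to 65507 bytes for IPv4
// (65535 - 20 - 8), 65527 bytes for IPv6 (65535 - 8), and therefore
// MaxUDPPayload = 65527.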
// UDPConn implements a UDP socket using the Windows RIO API extensions.
// It supports batched I/O, UDP RSC Offload (URO), and UDP Segmentation
// Offload (USO) to improve performance in high-throughput UDP workloads.
type UDPConn struct {
config UDPConfig
*conn // the underlying socket connection with RIO extensions
udpRx // receiving half-connection
udpTx // transmitting half-connection
}
// ListenUDP listens for incoming UDP packets on the local address using
// the Registered Input/Output (RIO) API and supports URO and USO when
// available. It returns an error if RIO is not available.
//
// The network must be a UDP network name.
//
// If the IP field of addr is nil or an unspecified IP address,
// ListenUDP listens on all available IP addresses of the local system
// except multicast IP addresses. If the network is "udp" and the local
// IP is unspecified, ListenUDP listens on both IPv4 and IPv6 addresses.
//
// If the Port field of addr is 0, a port number is automatically
// chosen.
//
// The provided options configure various aspects of the connection,
// such as RIO buffer sizes, URO and USO parameters, and other socket options.
func ListenUDP(network string, addr *net.UDPAddr, options ...UDPOption) (_ *UDPConn, err error) {
defer func() {
if err != nil {
err = &net.OpError{Op: "listen", Net: network, Addr: addr, Err: err}
}
}()
if err := Initialize(); err != nil {
return nil, err
}
udp := &UDPConn{}
for _, o := range options {
if o != nil {
o.applyUDP(&udp.config)
}
}
laddr, dualStack, err := addrPortFromUDPAddr(network, addr)
if err != nil {
return nil, err
}
// Create the underlying socket with Registered I/O extensions
// and bind it to the local address.
udp.conn, err = newConn(windows.SOCK_DGRAM, windows.IPPROTO_UDP,
dualStack, laddr, &udp.config.Config)
if err != nil {
return nil, err
}
defer func() {
// If initialization fails, close the connection to
// release any allocated resources.
if err != nil {
udp.Close()
}
}()
// Initialize the Rx and Tx halves of the connection,
// which includes allocating memory for RIO buffers
// and creating RIO completion queues for each half.
if err := udp.udpRx.init(udp.conn, udp.config); err != nil {
return nil, fmt.Errorf("failed to initialize Rx: %w", err)
}
if err := udp.udpTx.init(udp.conn, udp.config); err != nil {
return nil, fmt.Errorf("failed to initialize Tx: %w", err)
}
// Create the RIO request queue for the connection and associate it
// with the Rx and Tx completion queues.
if err := udp.createRequestQueue(
udp.udpRx.completionQueue(), udp.udpRx.maxOutstandingRequests(),
udp.udpTx.completionQueue(), udp.udpTx.maxOutstandingRequests(),
); err != nil {
return nil, fmt.Errorf("failed to create RIO request queue: %w", err)
}
// Disable reporting of ICMP "Port Unreachable" errors as socket errors (golang/go#5834).
// https://web.archive.org/web/20260208062329/https://support.microsoft.com/en-US/help/263823
if err := WSAIoctlIn(udp, windows.SIO_UDP_CONNRESET, uint32(0)); err != nil {
return nil, fmt.Errorf("failed to disable SIO_UDP_CONNRESET: %w", err)
}
// Post initial receive requests.
if err := udp.udpRx.postReceiveRequests(); err != nil {
return nil, fmt.Errorf("failed to post initial receive requests: %w", err)
}
return udp, nil
}
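// A minimal usage sketch (values are illustrative; RxMemoryLimit is one
// of the UDPOption values exercised in udp_test.go):
//
//	udp, err := rioconn.ListenUDP("udp", &net.UDPAddr{Port: 0},
//		rioconn.RxMemoryLimit(512<<10))
//	if err != nil {
//		// handle error: RIO may be unavailable on this system
//	}
//	defer udp.Close()
//	fmt.Println("listening on", udp.LocalAddrPort())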
// Config returns the effective configuration of the connection.
// The returned value is immutable for the lifetime of the connection.
func (c *UDPConn) Config() *UDPConfig {
return &c.config
}
// SetDeadline implements [net.Conn.SetDeadline].
func (c *UDPConn) SetDeadline(t time.Time) error {
// TODO(nickkhyl): move this and the other deadline methods to the underlying [conn]?
err1 := c.SetReadDeadline(t)
err2 := c.SetWriteDeadline(t)
return errors.Join(err1, err2)
}
// SetReadDeadline implements [net.Conn.SetReadDeadline].
func (c *UDPConn) SetReadDeadline(t time.Time) error {
// TODO(nickkhyl): implement read and write deadlines
return fmt.Errorf("%w: (%T).SetReadDeadline is not yet implemented", errors.ErrUnsupported, c)
}
// SetWriteDeadline implements [net.Conn.SetWriteDeadline].
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
// TODO(nickkhyl): implement read and write deadlines
return fmt.Errorf("%w: (%T).SetWriteDeadline is not yet implemented", errors.ErrUnsupported, c)
}
// Close closes the connection, canceling any pending operations,
// and freeing all associated resources.
func (c *UDPConn) Close() error {
if err := c.conn.Close(); err != nil {
return err
}
// Close the Rx and Tx halves only after closing the underlying connection.
// This ensures that all in-flight requests complete and that nothing uses
// the RIO buffers or completion queues after they are closed.
return errors.Join(c.udpRx.Close(), c.udpTx.Close())
}
// udpNx is a base struct for [udpRx] and [udpTx] half-connections
// that contains common state and logic.
type udpNx struct {
conn *conn
// mu protects the fields below and serializes access to the completion queue.
// Lock order: udpNx.mu > conn.mu.
mu sync.Mutex
requests *requestRing // ring of RIO request contexts for this half-connection
cq winrio.Cq // completion queue associated with this half-connection
hasCompletionsEvt windows.Handle // signaled by RIO when there are completions to dequeue.
results []winrio.Result // dequeued completion results
}
// init initializes the common state for [udpRx] or [udpTx].
// The conn parameter is the underlying connection associated with this half-connection.
// The dataSize parameter specifies the size of the data buffer for each request in the ring,
// and memoryLimit specifies the maximum total memory used by all requests.
func (nx *udpNx) init(conn *conn, dataSize uint16, memoryLimit uintptr) (err error) {
defer func() {
if err != nil {
nx.Close()
}
}()
if nx.requests, err = newRequestRing(dataSize, memoryLimit); err != nil {
return fmt.Errorf("failed to create request ring: %w", err)
}
if nx.hasCompletionsEvt, err = windows.CreateEvent(nil, 0, 0, nil); err != nil {
return fmt.Errorf("failed to create completion event: %w", err)
}
nx.results = make([]winrio.Result, 0, nx.requests.Cap())
if nx.cq, err = winrio.CreateEventCompletionQueue(nx.requests.Cap(), nx.hasCompletionsEvt, true); err != nil {
return fmt.Errorf("failed to create completion queue: %w", err)
}
nx.conn = conn
return nil
}
// completionQueue returns the RIO completion queue used by
// the half-connection for completion notifications.
func (nx *udpNx) completionQueue() winrio.Cq {
return nx.cq
}
// maxOutstandingRequests returns the maximum number of in-flight
// requests the half-connection can post to the RIO request queue
// without blocking.
func (nx *udpNx) maxOutstandingRequests() uint32 {
return nx.requests.Cap()
}
// Close releases all resources associated with the half-connection.
// It must not be called until the connection using this
// half-connection's buffers and completion queue is closed.
func (nx *udpNx) Close() error {
nx.mu.Lock()
defer nx.mu.Unlock()
if nx.cq != 0 {
winrio.CloseCompletionQueue(nx.cq)
nx.cq = 0
}
if nx.hasCompletionsEvt != 0 {
windows.CloseHandle(nx.hasCompletionsEvt)
nx.hasCompletionsEvt = 0
}
if nx.requests != nil {
if err := nx.requests.Close(); err != nil {
return err
}
nx.requests = nil
}
return nil
}

net/rioconn/udp_test.go (new file)

@@ -0,0 +1,430 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn_test
import (
"bytes"
"cmp"
"fmt"
"net"
"net/netip"
"slices"
"syscall"
"testing"
"golang.org/x/net/ipv6"
"tailscale.com/net/batching"
"tailscale.com/net/packet"
"tailscale.com/net/rioconn"
"tailscale.com/types/nettype"
)
// [UDPConn] implements the following interfaces.
var (
_ batching.Conn = (*rioconn.UDPConn)(nil)
_ net.PacketConn = (*rioconn.UDPConn)(nil)
_ nettype.PacketConn = (*rioconn.UDPConn)(nil)
_ syscall.Conn = (*rioconn.UDPConn)(nil)
)
func TestListenUDP(t *testing.T) {
tests := []struct {
network string
address string
wantLocalAddrPort netip.AddrPort
wantDualStack bool
}{
{
network: "udp", address: "127.0.0.1:0", wantLocalAddrPort: netip.MustParseAddrPort("127.0.0.1:0"),
},
{
network: "udp4", address: "127.0.0.1:0", wantLocalAddrPort: netip.MustParseAddrPort("127.0.0.1:0"),
},
{
network: "udp", address: "[::1]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::1]:0"),
},
{
network: "udp6", address: "[::1]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::1]:0"),
},
{
network: "udp", address: "0.0.0.0:0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"),
},
{
network: "udp4", address: "0.0.0.0:0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"),
},
{
network: "udp", address: "[::]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"),
},
{
network: "udp6", address: "[::]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"),
},
{
network: "udp", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"), wantDualStack: true,
},
{
network: "udp4", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"),
},
{
network: "udp6", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"),
},
{
network: "udp", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("[::]:41613"), wantDualStack: true,
},
{
network: "udp4", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:41613"), wantDualStack: false,
},
{
network: "udp6", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("[::]:41613"), wantDualStack: false,
},
}
for _, tt := range tests {
t.Run(tt.network+"/"+tt.address, func(t *testing.T) {
addr, err := net.ResolveUDPAddr(tt.network, tt.address)
if err != nil {
t.Fatalf("ResolveUDPAddr(%q, %q) error: %v", tt.network, tt.address, err)
}
conn, err := rioconn.ListenUDP(tt.network, addr)
if err != nil {
t.Fatalf("ListenUDP(%q, %q) error: %v", tt.network, tt.address, err)
}
t.Cleanup(func() {
err := conn.Close()
if err != nil {
t.Errorf("Close() error: %v", err)
}
})
gotAddressPort := conn.LocalAddrPort()
if wantAddress := tt.wantLocalAddrPort.Addr(); gotAddressPort.Addr().Compare(wantAddress) != 0 {
t.Errorf("LocalAddrPort() Addr = %v; want %v", gotAddressPort.Addr(), tt.wantLocalAddrPort.Addr())
}
if wantPort := tt.wantLocalAddrPort.Port(); wantPort != 0 && gotAddressPort.Port() != wantPort {
t.Errorf("LocalAddrPort() Port = %v; want %v", gotAddressPort.Port(), wantPort)
}
if gotDualStack := conn.IsDualStack(); gotDualStack != tt.wantDualStack {
t.Errorf("IsDualStack() = %v; want %v", gotDualStack, tt.wantDualStack)
}
})
}
}
func TestUDPSendReceiveBatch(t *testing.T) {
const defaultBatchSize = 64
t.Parallel()
tests := []struct {
name string
network string
pattern []int
iterations int
sendBatchSize int
receiveBatchSize int
}{
{
name: "udp4/single",
network: "udp4",
pattern: []int{1312},
},
{
name: "udp4/single/max",
network: "udp4",
pattern: []int{rioconn.MaxUDPPayloadIPv4},
},
{
name: "udp4/batch/max",
network: "udp4",
pattern: []int{rioconn.MaxUDPPayloadIPv4},
},
{
name: "udp6/single",
network: "udp6",
pattern: []int{1312},
},
{
name: "udp6/single/max",
network: "udp6",
pattern: []int{rioconn.MaxUDPPayloadIPv6},
},
{
name: "udp6/batch/max",
network: "udp6",
pattern: []int{rioconn.MaxUDPPayloadIPv6},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
udpSendReceiveBatch(t,
tt.pattern, max(1, tt.iterations),
cmp.Or(tt.sendBatchSize, defaultBatchSize),
cmp.Or(tt.receiveBatchSize, defaultBatchSize),
tt.network, tt.network,
nil, nil,
)
})
}
}
func FuzzUDPSendReceiveBatch(f *testing.F) {
batchSizes := []uint16{1, 64}
packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4}
numIterations := []uint16{1024}
for _, packetLen := range packetSizes {
for _, numIter := range numIterations {
for _, batchSize := range batchSizes {
f.Add(packetLen, numIter, batchSize, batchSize)
}
}
}
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16) {
network := "udp4"
maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4)
if packetLen > maxPacketLen {
t.Skipf("packetLen is too large: %d", packetLen)
}
if numIterations > 10_000 {
t.Skipf("numIterations is too large: %d", numIterations)
}
if sendBatchSize == 0 || sendBatchSize > 1024 {
t.Skipf("sendBatchSize is out of range: %d", sendBatchSize)
}
if receiveBatchSize == 0 || receiveBatchSize > 1024 {
t.Skipf("receiveBatchSize is out of range: %d", receiveBatchSize)
}
packetLengthPattern := []int{int(packetLen)}
udpSendReceiveBatch(t, packetLengthPattern, int(numIterations),
int(sendBatchSize), int(receiveBatchSize), network, network,
[]rioconn.UDPOption{
rioconn.RxMemoryLimit(128 << 10),
rioconn.TxMemoryLimit(512 << 10),
},
[]rioconn.UDPOption{
rioconn.RxMemoryLimit(512 << 10),
rioconn.TxMemoryLimit(128 << 10),
},
)
})
}
// udpSendReceiveBatch sends and receives batches of UDP packets between two
// [rioconn.UDPConn] instances over the loopback interface.
//
// It uses the provided packet length pattern, iteration count,
// batch sizes, networks, and connection options.
func udpSendReceiveBatch(
tb testing.TB,
packetLengthPattern []int,
numIterations int,
sendBatchSize, receiveBatchSize int,
senderNetwork, receiverNetwork string,
senderOpts, receiverOpts []rioconn.UDPOption,
) {
stopMsg := []byte("STOP")
sender, err := rioconn.ListenUDP(senderNetwork, loopbackUDPAddr(senderNetwork, 0), senderOpts...)
if err != nil {
tb.Fatalf("ListenUDP(%s, nil) error: %v", senderNetwork, err)
}
defer sender.Close()
receiver, err := rioconn.ListenUDP(receiverNetwork, loopbackUDPAddr(receiverNetwork, 0), receiverOpts...)
if err != nil {
tb.Fatalf("ListenUDP(%s, nil) error: %v", receiverNetwork, err)
}
defer receiver.Close()
// Do not allocate buffers larger than needed for the test.
maxPacketLen := max(len(stopMsg), slices.Max(packetLengthPattern))
outBuffs := make([][]byte, sendBatchSize)
for i := range outBuffs {
outBuffs[i] = make([]byte, maxPacketLen)
}
inMsgs := make([]ipv6.Message, receiveBatchSize)
for i := range inMsgs {
inMsgs[i].Buffers = make([][]byte, 1)
inMsgs[i].Buffers[0] = make([]byte, maxPacketLen)
}
readerResult := make(chan error, 1)
writerResult := make(chan error, 1)
go func() {
defer close(writerResult)
dstAddr := receiver.LocalAddrPort()
bytes := 0
packets := 0
iteration := 0
for iteration < numIterations {
outBuffs := outBuffs[:cap(outBuffs)]
for k := range outBuffs {
packetLen := packetLengthPattern[packets%len(packetLengthPattern)]
out := outBuffs[k][:packetLen]
outBuffs[k] = out
for j := 0; j < packetLen; j++ {
out[j] = byte('A' + bytes%26)
bytes++
}
packets++
if packets%len(packetLengthPattern) == 0 {
iteration++
}
if iteration >= numIterations {
outBuffs = outBuffs[:k+1]
break
}
}
if err := sender.WriteBatchTo(outBuffs, dstAddr, packet.GeneveHeader{}, 0); err != nil {
writerResult <- fmt.Errorf("failed to send batch #%d: %w", iteration, err)
return
}
}
tb.Logf("Writer done sending %d packets and %d bytes in %d iterations", packets, bytes, iteration)
tb.Logf("Sending STOP messages to signal the reader to stop")
for {
select {
case <-readerResult:
tb.Logf("Reader has stopped, no need to send more STOP messages")
return
default:
}
if _, err := sender.WriteTo(stopMsg, net.UDPAddrFromAddrPort(dstAddr)); err != nil {
writerResult <- fmt.Errorf("failed to send a STOP message: %w", err)
return
}
}
}()
go func() {
defer close(readerResult)
bytesReceived := 0
for {
n, err := receiver.ReadBatch(inMsgs, 0)
if err != nil {
readerResult <- fmt.Errorf("ReadBatch() error: %w", err)
return
}
for i := range n {
msg := inMsgs[i]
if bytes.Equal(msg.Buffers[0][:msg.N], stopMsg) {
tb.Logf("Received a STOP message, reader is stopping")
return
}
for j := 0; j < msg.N; j++ {
expectedByte := byte('A' + bytesReceived%26)
if msg.Buffers[0][j] != expectedByte {
readerResult <- fmt.Errorf("unexpected byte at position %d: got %v, want %v",
bytesReceived, msg.Buffers[0][j], expectedByte)
return
}
bytesReceived++
}
}
}
}()
if err := <-writerResult; err != nil {
tb.Fatalf("writer error: %v", err)
}
if err := <-readerResult; err != nil {
tb.Fatalf("reader error: %v", err)
}
}
func TestUDPReadWrite(t *testing.T) {
sender, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
if err != nil {
t.Fatalf("ListenUDP: %v", err)
}
defer sender.Close()
receiver, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
if err != nil {
t.Fatalf("ListenUDP: %v", err)
}
defer receiver.Close()
message := []byte("Hello, world!")
n, err := sender.WriteTo(message, net.UDPAddrFromAddrPort(receiver.LocalAddrPort()))
if err != nil {
t.Fatalf("WriteTo: %v", err)
}
if n != len(message) {
t.Fatalf("WriteTo: wrote %d bytes, want %d", n, len(message))
}
buf := make([]byte, 1024)
n, addr, err := receiver.ReadFrom(buf)
if err != nil {
t.Fatalf("ReadFrom: %v", err)
}
if !bytes.Equal(buf[:n], message) {
t.Fatalf("ReadFrom: got %q, want %q", buf[:n], message)
}
if addr.String() != net.UDPAddrFromAddrPort(sender.LocalAddrPort()).String() {
t.Fatalf("ReadFrom: got addr %v, want %v", addr, sender.LocalAddrPort())
}
}
func TestUDPReadFromUDPAddrPort(t *testing.T) {
sender, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
if err != nil {
t.Fatalf("ListenUDP: %v", err)
}
defer sender.Close()
receiver, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
if err != nil {
t.Fatalf("ListenUDP: %v", err)
}
defer receiver.Close()
message := []byte("Hello, world!")
n, err := sender.WriteToUDPAddrPort(message, receiver.LocalAddrPort())
if err != nil {
t.Fatalf("WriteToUDPAddrPort: %v", err)
}
if n != len(message) {
t.Fatalf("WriteToUDPAddrPort: wrote %d bytes, want %d", n, len(message))
}
buf := make([]byte, 1024)
n, addr, err := receiver.ReadFromUDPAddrPort(buf)
if err != nil {
t.Fatalf("ReadFromUDPAddrPort: %v", err)
}
if !bytes.Equal(buf[:n], message) {
t.Fatalf("ReadFromUDPAddrPort: got %q, want %q", buf[:n], message)
}
if addr != sender.LocalAddrPort() {
t.Fatalf("ReadFromUDPAddrPort: got addr %v, want %v", addr, sender.LocalAddrPort())
}
}
func loopbackUDPAddr(network string, port int) *net.UDPAddr {
switch network {
case "udp4":
return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}
case "udp6":
return &net.UDPAddr{IP: net.IPv6loopback, Port: port}
default:
panic(fmt.Sprintf("unsupported network: %s", network))
}
}

net/rioconn/udprx.go (new file)

@@ -0,0 +1,204 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"errors"
"fmt"
"net"
"net/netip"
"unsafe"
"github.com/tailscale/wireguard-go/conn/winrio"
"golang.org/x/net/ipv6"
"golang.org/x/sys/windows"
)
// udpRx is the receive half of [UDPConn].
//
// Its exported methods are safe for concurrent use.
// The caller must ensure that the connection is not closed
// while any unexported methods are in flight, unless
// otherwise specified by the method.
type udpRx struct {
udpNx
// pendingResultIdx is the index in [udpNx.results]
// of the next pending result to process.
pendingResultIdx int
}
// init initializes the receive half of a [UDPConn] with the
// specified underlying connection and options.
func (rx *udpRx) init(conn *conn, options UDPConfig) error {
// Without URO, the data buffer for each receive request only needs
// to hold a single packet's payload.
dataSize := min(options.Rx().MaxPayloadLen(), MaxUDPPayload)
if err := rx.udpNx.init(conn, dataSize, options.Rx().MemoryLimit()); err != nil {
return fmt.Errorf("failed to initialize udpRx: %w", err)
}
return nil
}
// ReadBatch implements [batching.Conn] by reading messages into msgs.
// It returns the number of messages the caller should evaluate for nonzero len,
// as a zero len message may fall on either side of a nonzero.
// The flags parameter is reserved for future use and must be zero.
func (rx *udpRx) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
// Prevent the connection from closing while in use.
if err := rx.conn.acquire(); err != nil {
return 0, &net.OpError{Op: "read", Net: rx.conn.Network(), Source: rx.conn.LocalAddr(), Err: err}
}
defer rx.conn.release()
rx.mu.Lock()
defer rx.mu.Unlock()
// Keep trying to read until we get at least one message or an error.
for n == 0 && err == nil {
if err := rx.awaitCompletionsLocked(); err != nil {
return 0, err
}
n, err = rx.processCompletionsLocked(msgs)
}
// Always try to post more receive requests, even if an error
// occurred while processing completed ones.
if postErr := rx.postReceiveRequestsLocked(); postErr != nil {
err = errors.Join(err, postErr)
}
if err != nil {
err = &net.OpError{Op: "read", Net: rx.conn.Network(), Source: rx.conn.LocalAddr(), Err: err}
}
return n, err
}
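// A caller-side sketch (mirroring the harness in udp_test.go): give each
// message a single buffer, then check msg.N for every returned message,
// since a zero-length message may appear anywhere in the returned prefix
// (batchSize and maxPacketLen are placeholders):
//
//	msgs := make([]ipv6.Message, batchSize)
//	for i := range msgs {
//		msgs[i].Buffers = [][]byte{make([]byte, maxPacketLen)}
//	}
//	n, err := conn.ReadBatch(msgs, 0)
//	if err != nil {
//		// handle error
//	}
//	for i := 0; i < n; i++ {
//		payload := msgs[i].Buffers[0][:msgs[i].N] // msgs[i].Addr is the sender
//		_ = payload
//	}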
// ReadFromUDPAddrPort implements [nettype.PacketConn.ReadFromUDPAddrPort].
func (rx *udpRx) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
n, netAddr, err := rx.ReadFrom(p)
if netAddr != nil {
addr = netAddr.(*net.UDPAddr).AddrPort()
}
return n, addr, err
}
// ReadFrom implements [net.PacketConn.ReadFrom].
func (rx *udpRx) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
msgs := []ipv6.Message{{
Buffers: [][]byte{p},
}}
numMsgs, err := rx.ReadBatch(msgs, 0)
if numMsgs != 0 {
n = msgs[0].N
addr = msgs[0].Addr
}
return n, addr, err
}
// postReceiveRequests posts available receive requests to the
// RIO request queue. The caller must ensure that the connection
// is not closed until this call returns.
func (rx *udpRx) postReceiveRequests() error {
rx.mu.Lock()
defer rx.mu.Unlock()
return rx.postReceiveRequestsLocked()
}
// postReceiveRequestsLocked posts all available receive requests
// to the RIO request queue.
// rx.mu must be held.
func (rx *udpRx) postReceiveRequestsLocked() (err error) {
return rx.conn.postReceiveRequests(rx.requests.AcquireSeq())
}
// awaitCompletionsLocked dequeues completed receive requests, returning when
// there's at least one completion to process, the connection is closed,
// or an error occurs.
// rx.mu must be held.
func (rx *udpRx) awaitCompletionsLocked() error {
if rx.pendingResultIdx < len(rx.results) {
// We have already dequeued some completions that haven't been
// fully processed yet. Return immediately.
return nil
}
rx.results = rx.results[:cap(rx.results)]
rx.pendingResultIdx = 0
var count uint32
for {
if count = winrio.DequeueCompletion(rx.cq, rx.results[:]); count != 0 {
// Got new completions to process, no need to wait.
break
}
// Otherwise, arm the notification...
if err := winrio.Notify(rx.cq); err != nil {
return err
}
// ...and wait until RIO signals that more completions are available
// or the connection is closed.
handles := []windows.Handle{rx.conn.closedEvt, rx.hasCompletionsEvt}
switch evtIdx, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE); {
case err != nil:
return fmt.Errorf("waiting for completed receives failed: %w", err)
case evtIdx == 0:
return net.ErrClosed
case evtIdx == 1:
continue // try dequeueing completions again
default:
panic("unreachable")
}
}
rx.results = rx.results[:count]
return nil
}
// processCompletionsLocked processes completed receive requests and fills msgs
// with the received packets. It returns the number of messages the caller
// should evaluate for nonzero len, as a zero len message may fall on either
// side of a nonzero.
// rx.mu must be held.
func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error) {
firstResultIdx := rx.pendingResultIdx
defer func() {
// Always release processed results, even if an error occurred.
rx.requests.ReleaseN(rx.pendingResultIdx - firstResultIdx)
}()
for rx.pendingResultIdx < len(rx.results) && n < len(msgs) {
res := &rx.results[rx.pendingResultIdx]
req := (*request)(unsafe.Pointer(uintptr(res.RequestContext)))
r, err := req.CompleteReceive(res.Status, res.BytesTransferred)
if err != nil {
rx.pendingResultIdx++
if err == windows.WSAEMSGSIZE {
// The packet is larger than [RxConfig.MaxPayloadLen].
// Skip it and try to process the next one, if any.
continue
}
// In case of other errors, skip the packet and return
// the error to the caller.
return n, err
}
// TODO(nickkhyl): Maintain an LRU cache of remote addresses to
// avoid allocating a new [netip.AddrPort] / [net.UDPAddr] for each packet.
// Profiling suggests this accounts for ~5% of total processing time.
udpAddr, err := r.RemoteAddr().ToUDPAddr()
if err != nil {
return n, fmt.Errorf("invalid remote address: %w", err)
}
if r.Len() <= len(msgs[n].Buffers[0]) {
// TODO(nickkhyl): avoid the copy? We could transfer ownership of the underlying
// buffer to the reader until the next read or an explicit release.
msgs[n].N = copy(msgs[n].Buffers[0], r.Bytes())
} else {
msgs[n].N = 0 // packet is too large; ignore it
}
msgs[n].Addr = udpAddr
rx.pendingResultIdx++
n++
}
return n, nil
}

net/rioconn/udptx.go (new file)

@@ -0,0 +1,191 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"errors"
"fmt"
"net"
"net/netip"
"unsafe"
"github.com/tailscale/wireguard-go/conn/winrio"
"golang.org/x/sys/windows"
"tailscale.com/net/packet"
)
// udpTx is the transmit half of [UDPConn].
//
// Its exported methods are safe for concurrent use.
// The caller must ensure that the connection is not closed
// while any unexported methods are in flight, unless
// otherwise specified by the method.
type udpTx struct {
udpNx
}
// init initializes the transmit half of a [UDPConn] with the
// specified underlying connection and options.
func (tx *udpTx) init(conn *conn, options UDPConfig) error {
// Without USO, the data buffer for each send request only needs to hold
// a single packet's payload.
dataSize := min(options.Tx().MaxPayloadLen(), MaxUDPPayload)
if err := tx.udpNx.init(conn, dataSize, options.Tx().MemoryLimit()); err != nil {
return fmt.Errorf("failed to initialize udpTx: %w", err)
}
return nil
}
// WriteBatchTo implements [batching.Conn.WriteBatchTo] by writing
// buffs to the specified remote address.
//
// If geneve.VNI.IsSet(), then geneve is encoded into the space preceding
// offset, and offset must equal [packet.GeneveFixedHeaderLength].
// Otherwise, the space preceding offset is ignored.
func (tx *udpTx) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
if tx.conn.IsDualStack() && addr.Addr().Is4() {
// Convert to an IPv4-mapped IPv6 address
addr = netip.AddrPortFrom(netip.AddrFrom16(addr.Addr().As16()), addr.Port())
}
if err := tx.writeBatchTo(buffs, addr, geneve, offset); err != nil {
return &net.OpError{Op: "write", Net: tx.conn.Network(), Source: tx.conn.LocalAddr(), Addr: net.UDPAddrFromAddrPort(addr), Err: err}
}
return nil
}
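// A caller-side sketch (as in udp_test.go): send a batch with no Geneve
// encapsulation by passing a zero GeneveHeader and offset 0. With a VNI
// set, offset must be packet.GeneveFixedHeaderLength and each buffer's
// payload starts at that offset:
//
//	if err := conn.WriteBatchTo(buffs, dstAddr, packet.GeneveHeader{}, 0); err != nil {
//		// handle error
//	}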
// writeBatchTo implements [udpTx.WriteBatchTo]. It returns an
// error if the connection is already closed and prevents the
// connection from closing until it returns.
func (tx *udpTx) writeBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) (err error) {
if len(buffs) == 0 {
return nil
}
raddr, err := rawSockaddrFromAddrPort(addr)
if err != nil {
return fmt.Errorf("failed to convert address: %w", err)
}
// Prevent the connection from closing while in use.
if err := tx.conn.acquire(); err != nil {
return err
}
defer tx.conn.release()
tx.mu.Lock()
defer tx.mu.Unlock()
n := 0
defer func() {
if n != 0 {
if commitErr := tx.conn.commitSendRequests(); commitErr != nil {
err = errors.Join(err, commitErr)
}
}
}()
for n < len(buffs) {
if tx.conn.IsClosed() {
return net.ErrClosed
}
if err := tx.drainCompletionsLocked(); err != nil {
return err
}
req := tx.requests.Peek()
w := req.Writer()
w.SetRemoteAddr(raddr)
if geneve.VNI.IsSet() {
geneveHeader := w.Reserve(packet.GeneveFixedHeaderLength)
geneve.Encode(geneveHeader[:])
}
if _, err := w.Write(buffs[n][offset:]); err != nil {
return err
}
if err = tx.conn.postSendRequest(req, winrio.MsgDefer); err != nil {
return fmt.Errorf("failed to post send request: %w", err)
}
tx.requests.Advance() // advance after posting the request
n++
}
return nil
}
// WriteToUDPAddrPort implements [nettype.PacketConn.WriteToUDPAddrPort].
func (tx *udpTx) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (n int, err error) {
if err := tx.WriteBatchTo([][]byte{p}, addr, packet.GeneveHeader{}, 0); err != nil {
return 0, err
}
return len(p), nil
}
// WriteTo implements [net.PacketConn.WriteTo].
func (tx *udpTx) WriteTo(p []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: tx.conn.Network(),
Source: tx.conn.LocalAddr(),
Addr: addr,
Err: net.InvalidAddrError("address is not a *net.UDPAddr"),
}
}
return tx.WriteToUDPAddrPort(p, udpAddr.AddrPort())
}
// drainCompletionsLocked dequeues and processes completed send requests
// until the request ring is not full (i.e., more requests can be posted)
// or the closedEvt is signaled.
//
// tx.mu must be held, and the caller must ensure that the connection
// is not closed until this call returns.
func (tx *udpTx) drainCompletionsLocked() error {
var count uint32
for {
if count = winrio.DequeueCompletion(tx.cq, tx.results[:cap(tx.results)]); count != 0 {
// Got new completions to process, no need to wait.
break
}
if !tx.requests.IsFull() {
// No completions to process, but also not all requests are in-flight,
// so no need to wait.
break
}
// Otherwise, if all requests are in flight, commit any deferred sends
// before blocking; if the commit fails, the completions we are about
// to wait for may never arrive.
if err := tx.conn.commitSendRequests(); err != nil {
return err
}
// Then arm the notification...
if err := winrio.Notify(tx.cq); err != nil {
return err
}
// ...and wait for either RIO to signal that more completions are available,
// or the connection to be closed.
handles := []windows.Handle{tx.conn.closedEvt, tx.hasCompletionsEvt}
switch evtIdx, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE); {
case err != nil:
return fmt.Errorf("waiting for completed sends failed: %w", err)
case evtIdx == 0:
return net.ErrClosed
case evtIdx == 1:
continue // try dequeueing completions again
default:
panic("unreachable")
}
}
for _, res := range tx.results[:count] {
req := (*request)(unsafe.Pointer(uintptr(res.RequestContext)))
if err := req.CompleteSend(res.Status, res.BytesTransferred); err != nil {
// TODO(nickkhyl): Returning an error here does not make much sense.
// Increment a send error metric or log the error instead?
}
}
tx.results = tx.results[:0]
tx.requests.ReleaseN(int(count))
return nil
}