diff --git a/net/rioconn/conn.go b/net/rioconn/conn.go index f71719232..f71071401 100644 --- a/net/rioconn/conn.go +++ b/net/rioconn/conn.go @@ -6,6 +6,7 @@ package rioconn import ( + "cmp" "errors" "fmt" "iter" @@ -13,6 +14,7 @@ import ( "net/netip" "sync" "syscall" + "unsafe" "github.com/tailscale/wireguard-go/conn/winrio" "golang.org/x/sys/windows" @@ -363,3 +365,22 @@ func rioSocket(family, sotype, proto int32) (windows.Handle, error) { windows.WSA_FLAG_OVERLAPPED return windows.WSASocket(family, sotype, proto, nil, 0, rioWSAFlags) } + +// WSAIoctlIn issues an ioctl command with the provided code and input value +// on the connection's underlying socket. It is a type-safe shorthand for calling +// [syscall.RawConn.Control] with a function that invokes [windows.WSAIoctl] +// with the appropriate arguments, without any output buffer. +func WSAIoctlIn[Input any](conn syscall.Conn, code uint32, in Input) error { + rawConn, err := conn.SyscallConn() + if err != nil { + return err + } + controlErr := rawConn.Control(func(s uintptr) { + ret := uint32(0) + err = windows.WSAIoctl(windows.Handle(s), code, + (*byte)(unsafe.Pointer(&in)), uint32(unsafe.Sizeof(in)), + nil, 0, &ret, nil, 0, + ) + }) + return cmp.Or(controlErr, err) +} diff --git a/net/rioconn/doc.go b/net/rioconn/doc.go index 406cdef73..3ea1d49a8 100644 --- a/net/rioconn/doc.go +++ b/net/rioconn/doc.go @@ -3,6 +3,9 @@ //go:build windows +// Package rioconn provides [UDPConn], a UDP socket implementation +// that uses the Windows RIO API extensions and supports batched I/O +// for improved performance on high-throughput UDP workloads. 
package rioconn import ( diff --git a/net/rioconn/udp.go b/net/rioconn/udp.go new file mode 100644 index 000000000..4fb1a54df --- /dev/null +++ b/net/rioconn/udp.go @@ -0,0 +1,235 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/tailscale/wireguard-go/conn/winrio" + "golang.org/x/sys/windows" +) + +const ( + // MaxUDPPayloadIPv4 is the maximum UDP payload size over IPv4. + // IPv4 total length is 65535 bytes, including: + // - 20-byte IPv4 header (no options) + // - 8-byte UDP header + MaxUDPPayloadIPv4 = 1<<16 - 1 - 20 - 8 + // MaxUDPPayloadIPv6 is the maximum UDP payload size over IPv6. + // The IPv6 payload length field excludes the 40-byte base header + // and includes the 8-byte UDP header. + MaxUDPPayloadIPv6 = 1<<16 - 1 - 8 + // MaxUDPPayload is the maximum UDP payload size across IP versions. + MaxUDPPayload = max(MaxUDPPayloadIPv4, MaxUDPPayloadIPv6) +) + +// UDPConn implements a UDP socket using the Windows RIO API extensions. +// It supports batched I/O, UDP RSC Offload (URO), and UDP Segmentation +// Offload (USO) to improve performance in high-throughput UDP workloads. +type UDPConn struct { + config UDPConfig + + *conn // the underlying socket connection with RIO extensions + udpRx // receiving half-connection + udpTx // transmitting half-connection +} + +// ListenUDP listens for incoming UDP packets on the local address using +// the Registered Input/Output (RIO) API and supports URO and USO when +// available. It returns an error if RIO is not available. +// +// The network must be a UDP network name. +// +// If the IP field of addr is nil or an unspecified IP address, +// ListenUDP listens on all available IP addresses of the local system +// except multicast IP addresses. If the network is "udp" and the local +// IP is unspecified, ListenUDP listens on both IPv4 and IPv6 addresses. 
+// +// If the Port field of addr is 0, a port number is automatically +// chosen. +// +// The provided options are to configure various aspects of the connection, +// such as RIO buffer sizes, URO and USO parameters and other socket options. +func ListenUDP(network string, addr *net.UDPAddr, options ...UDPOption) (_ *UDPConn, err error) { + defer func() { + if err != nil { + err = &net.OpError{Op: "listen", Net: network, Addr: addr, Err: err} + } + }() + + if err := Initialize(); err != nil { + return nil, err + } + + udp := &UDPConn{} + for _, o := range options { + if o != nil { + o.applyUDP(&udp.config) + } + } + + laddr, dualStack, err := addrPortFromUDPAddr(network, addr) + if err != nil { + return nil, err + } + + // Create the underlying socket with Registered I/O extensions + // and bind it to the local address. + udp.conn, err = newConn(windows.SOCK_DGRAM, windows.IPPROTO_UDP, + dualStack, laddr, &udp.config.Config) + if err != nil { + return nil, err + } + defer func() { + // If initialization fails, close the connection to + // release any allocated resources. + if err != nil { + udp.Close() + } + }() + + // Initialize the Rx and Tx halves of the connection, + // which includes allocating memory for RIO buffers + // and creating RIO completion queues for each half. + if err := udp.udpRx.init(udp.conn, udp.config); err != nil { + return nil, fmt.Errorf("failed to initialize Rx: %w", err) + } + if err := udp.udpTx.init(udp.conn, udp.config); err != nil { + return nil, fmt.Errorf("failed to initialize Tx: %w", err) + } + // Create the RIO request queue for the connection and associate it + // with the Rx and Tx completion queues. 
+ if err := udp.createRequestQueue( + udp.udpRx.completionQueue(), udp.udpRx.maxOutstandingRequests(), + udp.udpTx.completionQueue(), udp.udpTx.maxOutstandingRequests(), + ); err != nil { + return nil, fmt.Errorf("failed to create RIO request queue: %w", err) + } + // Disable reporting of ICMP "Port Unreachable" errors as socket errors (golang/go#5834). + // https://web.archive.org/web/20260208062329/https://support.microsoft.com/en-US/help/263823 + if err := WSAIoctlIn(udp, windows.SIO_UDP_CONNRESET, uint32(0)); err != nil { + return nil, fmt.Errorf("failed to disable SIO_UDP_CONNRESET: %w", err) + } + // Post initial receive requests. + if err := udp.udpRx.postReceiveRequests(); err != nil { + return nil, fmt.Errorf("failed to post initial receive requests: %w", err) + } + return udp, nil +} + +// Config returns the effective configuration of the connection. +// The returned value is immutable for the lifetime of the connection. +func (c *UDPConn) Config() *UDPConfig { + return &c.config +} + +// SetDeadline implements [net.Conn.SetDeadline]. +func (c *UDPConn) SetDeadline(t time.Time) error { + // TODO(nickkhyl): move this and the other deadline methods to the underlying [conn]? + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + return errors.Join(err1, err2) +} + +func (c *UDPConn) SetReadDeadline(t time.Time) error { + // TODO(nickkhyl): implement read and write deadlines + return fmt.Errorf("%w: (%T).SetReadDeadline is not yet implemented", errors.ErrUnsupported, c) +} + +func (c *UDPConn) SetWriteDeadline(t time.Time) error { + // TODO(nickkhyl): implement read and write deadlines + return fmt.Errorf("%w: (%T).SetWriteDeadline is not yet implemented", errors.ErrUnsupported, c) +} + +// Close closes the connection, canceling any pending operations, +// and freeing all associated resources. 
+func (c *UDPConn) Close() error { + if err := c.conn.Close(); err != nil { + return err + } + // Close the Rx and Tx halves only after closing the underlying connection. + // This ensures that all in-flight requests complete and that nothing uses + // the RIO buffers or completion queues after they are closed. + return errors.Join(c.udpRx.Close(), c.udpTx.Close()) +} + +// udpNx is a base struct for [udpRx] and [udpTx] half-connections +// that contains common state and logic. +type udpNx struct { + conn *conn + + // mu protects the fields below and serializes access to the completion queue. + // Lock order: udpNx.mu > conn.mu. + mu sync.Mutex + requests *requestRing // ring of RIO request contexts for this half-connection + cq winrio.Cq // completion queue associated with this half-connection + hasCompletionsEvt windows.Handle // signaled by RIO when there are completions to dequeue. + results []winrio.Result // dequeued completion results +} + +// init initializes the common state for [udpRx] or [udpTx]. +// The conn parameter is the underlying connection associated with this half-connection. +// The dataSize parameter specifies the size of the data buffer for each request in the ring, +// and memoryLimit specifies the maximum total memory used by all requests. 
+func (nx *udpNx) init(conn *conn, dataSize uint16, memoryLimit uintptr) (err error) { + defer func() { + if err != nil { + nx.Close() + } + }() + if nx.requests, err = newRequestRing(dataSize, memoryLimit); err != nil { + return fmt.Errorf("failed to create request ring: %w", err) + } + if nx.hasCompletionsEvt, err = windows.CreateEvent(nil, 0, 0, nil); err != nil { + return fmt.Errorf("failed to create completion event: %w", err) + } + nx.results = make([]winrio.Result, 0, nx.requests.Cap()) + if nx.cq, err = winrio.CreateEventCompletionQueue(nx.requests.Cap(), nx.hasCompletionsEvt, true); err != nil { + return fmt.Errorf("failed to create completion queue: %w", err) + } + nx.conn = conn + return nil +} + +// completionQueue returns the RIO completion queue used by +// the half-connection for completion notifications. +func (nx *udpNx) completionQueue() winrio.Cq { + return nx.cq +} + +// maxOutstandingRequests returns the maximum number of in-flight +// requests the half-connection can post to the RIO request queue +// without blocking. +func (nx *udpNx) maxOutstandingRequests() uint32 { + return nx.requests.Cap() +} + +// Close releases all resources associated with the half-connection. +// It must not be called until the connection using this +// half-connection's buffers and completion queue is closed. 
+func (nx *udpNx) Close() error { + nx.mu.Lock() + defer nx.mu.Unlock() + if nx.cq != 0 { + winrio.CloseCompletionQueue(nx.cq) + nx.cq = 0 + } + if nx.hasCompletionsEvt != 0 { + windows.CloseHandle(nx.hasCompletionsEvt) + nx.hasCompletionsEvt = 0 + } + if nx.requests != nil { + if err := nx.requests.Close(); err != nil { + return err + } + nx.requests = nil + } + return nil +} diff --git a/net/rioconn/udp_test.go b/net/rioconn/udp_test.go new file mode 100644 index 000000000..befe613f2 --- /dev/null +++ b/net/rioconn/udp_test.go @@ -0,0 +1,430 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn_test + +import ( + "bytes" + "cmp" + "fmt" + "net" + "net/netip" + "slices" + "syscall" + "testing" + + "golang.org/x/net/ipv6" + "tailscale.com/net/batching" + "tailscale.com/net/packet" + "tailscale.com/net/rioconn" + "tailscale.com/types/nettype" +) + +// [UDPConn] implements the following interfaces. +var ( + _ batching.Conn = (*rioconn.UDPConn)(nil) + _ net.PacketConn = (*rioconn.UDPConn)(nil) + _ nettype.PacketConn = (*rioconn.UDPConn)(nil) + _ syscall.Conn = (*rioconn.UDPConn)(nil) +) + +func TestListenUDP(t *testing.T) { + tests := []struct { + network string + address string + wantLocalAddrPort netip.AddrPort + wantDualStack bool + }{ + { + network: "udp", address: "127.0.0.1:0", wantLocalAddrPort: netip.MustParseAddrPort("127.0.0.1:0"), + }, + { + network: "udp4", address: "127.0.0.1:0", wantLocalAddrPort: netip.MustParseAddrPort("127.0.0.1:0"), + }, + { + network: "udp", address: "[::1]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::1]:0"), + }, + { + network: "udp6", address: "[::1]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::1]:0"), + }, + { + network: "udp", address: "0.0.0.0:0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"), + }, + { + network: "udp4", address: "0.0.0.0:0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"), + }, + { + network: "udp", 
address: "[::]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"), + }, + { + network: "udp6", address: "[::]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"), + }, + { + network: "udp", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"), wantDualStack: true, + }, + { + network: "udp4", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"), + }, + { + network: "udp6", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"), + }, + { + network: "udp", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("[::]:41613"), wantDualStack: true, + }, + { + network: "udp4", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:41613"), wantDualStack: false, + }, + { + network: "udp6", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("[::]:41613"), wantDualStack: false, + }, + } + for _, tt := range tests { + t.Run(tt.network+"/"+tt.address, func(t *testing.T) { + addr, err := net.ResolveUDPAddr(tt.network, tt.address) + if err != nil { + t.Fatalf("ResolveUDPAddr(%q, %q) error: %v", tt.network, tt.address, err) + } + + conn, err := rioconn.ListenUDP(tt.network, addr) + if err != nil { + t.Fatalf("ListenUDP(%q, %q) error: %v", tt.network, tt.address, err) + } + t.Cleanup(func() { + err := conn.Close() + if err != nil { + t.Errorf("Close() error: %v", err) + } + }) + + gotAddressPort := conn.LocalAddrPort() + if wantAddress := tt.wantLocalAddrPort.Addr(); gotAddressPort.Addr().Compare(wantAddress) != 0 { + t.Errorf("LocalAddrPort() Addr = %v; want %v", gotAddressPort.Addr(), tt.wantLocalAddrPort.Addr()) + } + if wantPort := tt.wantLocalAddrPort.Port(); wantPort != 0 && gotAddressPort.Port() != wantPort { + t.Errorf("LocalAddrPort() Port = %v; want %v", gotAddressPort.Port(), wantPort) + } + if gotDualStack := conn.IsDualStack(); gotDualStack != tt.wantDualStack { + t.Errorf("IsDualStack() = %v; want %v", gotDualStack, tt.wantDualStack) + } + }) + } +} + +func 
TestUDPSendReceiveBatch(t *testing.T) { + const defaultBatchSize = 64 + + t.Parallel() + + tests := []struct { + name string + network string + pattern []int + iterations int + sendBatchSize int + receiveBatchSize int + }{ + { + name: "udp4/single", + network: "udp4", + pattern: []int{1312}, + }, + { + name: "udp4/single/max", + network: "udp4", + pattern: []int{rioconn.MaxUDPPayloadIPv4}, + }, + { + name: "udp4/batch/max", + network: "udp4", + pattern: []int{rioconn.MaxUDPPayloadIPv4}, + }, + { + name: "udp6/single", + network: "udp6", + pattern: []int{1312}, + }, + { + name: "udp6/single/max", + network: "udp6", + pattern: []int{rioconn.MaxUDPPayloadIPv6}, + }, + { + name: "udp6/batch/max", + network: "udp6", + pattern: []int{rioconn.MaxUDPPayloadIPv6}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + udpSendReceiveBatch(t, + tt.pattern, max(1, tt.iterations), + cmp.Or(tt.sendBatchSize, defaultBatchSize), + cmp.Or(tt.receiveBatchSize, defaultBatchSize), + tt.network, tt.network, + nil, nil, + ) + }) + } +} + +func FuzzUDPSendReceiveBatch(f *testing.F) { + batchSizes := []uint16{1, 64} + packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4} + numIterations := []uint16{1024} + + for _, packetLen := range packetSizes { + for _, numIter := range numIterations { + for _, batchSize := range batchSizes { + f.Add(packetLen, numIter, batchSize, batchSize) + } + } + } + + f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16) { + network := "udp4" + maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4) + + if packetLen > maxPacketLen { + t.Skipf("packetLen is too large: %d", packetLen) + } + if numIterations > 10_000 { + t.Skipf("numIterations is too large: %d", numIterations) + } + if sendBatchSize == 0 || sendBatchSize > 1024 { + t.Skipf("sendBatchSize is out of range: %d", sendBatchSize) + } + if receiveBatchSize == 0 || receiveBatchSize > 1024 { + t.Skipf("receiveBatchSize 
is out of range: %d", receiveBatchSize) + } + + packetLengthPattern := []int{int(packetLen)} + udpSendReceiveBatch(t, packetLengthPattern, int(numIterations), + int(sendBatchSize), int(receiveBatchSize), network, network, + []rioconn.UDPOption{ + rioconn.RxMemoryLimit(128 << 10), + rioconn.TxMemoryLimit(512 << 10), + }, + []rioconn.UDPOption{ + rioconn.RxMemoryLimit(512 << 10), + rioconn.TxMemoryLimit(128 << 10), + }, + ) + }) +} + +// udpSendReceiveBatch sends and receives batches of UDP packets between two +// [rioconn.UDPConn] instances over the loopback interface. +// +// It uses the provided packet length pattern, iteration count, +// batch sizes, networks, and connection options. +func udpSendReceiveBatch( + tb testing.TB, + packetLengthPattern []int, + numIterations int, + sendBatchSize, receiveBatchSize int, + senderNetwork, receiverNetwork string, + senderOpts, receiverOpts []rioconn.UDPOption, +) { + stopMsg := []byte("STOP") + + sender, err := rioconn.ListenUDP(senderNetwork, loopbackUDPAddr(senderNetwork, 0), senderOpts...) + if err != nil { + tb.Fatalf("ListenUDP(%s, nil) error: %v", senderNetwork, err) + } + defer sender.Close() + + receiver, err := rioconn.ListenUDP(receiverNetwork, loopbackUDPAddr(receiverNetwork, 0), receiverOpts...) + if err != nil { + tb.Fatalf("ListenUDP(%s, nil) error: %v", receiverNetwork, err) + } + defer receiver.Close() + + // Do not allocate buffers larger than needed for the test. 
+ maxPacketLen := max(len(stopMsg), slices.Max(packetLengthPattern)) + + outBuffs := make([][]byte, sendBatchSize) + for i := range outBuffs { + outBuffs[i] = make([]byte, maxPacketLen) + } + + inMsgs := make([]ipv6.Message, receiveBatchSize) + for i := range inMsgs { + inMsgs[i].Buffers = make([][]byte, 1) + inMsgs[i].Buffers[0] = make([]byte, maxPacketLen) + } + + readerResult := make(chan error, 1) + writerResult := make(chan error, 1) + + go func() { + defer close(writerResult) + + dstAddr := receiver.LocalAddrPort() + + bytes := 0 + packets := 0 + iteration := 0 + for iteration < numIterations { + outBuffs := outBuffs[:cap(outBuffs)] + for k := range outBuffs { + packetLen := packetLengthPattern[packets%len(packetLengthPattern)] + out := outBuffs[k][:packetLen] + outBuffs[k] = out + for j := 0; j < packetLen; j++ { + out[j] = byte('A' + bytes%26) + bytes++ + } + packets++ + if packets%len(packetLengthPattern) == 0 { + iteration++ + } + if iteration >= numIterations { + outBuffs = outBuffs[:k+1] + break + } + } + if err := sender.WriteBatchTo(outBuffs, dstAddr, packet.GeneveHeader{}, 0); err != nil { + writerResult <- fmt.Errorf("failed to send batch #%d: %w", iteration, err) + return + } + } + + tb.Logf("Writer done sending %d packets and %d bytes in %d iterations", packets, bytes, iteration) + tb.Logf("Sending STOP messages to signal the reader to stop") + for { + select { + case <-readerResult: + tb.Logf("Reader has stopped, no need to send more STOP messages") + return + default: + } + + if _, err := sender.WriteTo(stopMsg, net.UDPAddrFromAddrPort(dstAddr)); err != nil { + writerResult <- fmt.Errorf("failed to send a STOP message: %w", err) + return + } + } + }() + + go func() { + defer close(readerResult) + + bytesReceived := 0 + for { + n, err := receiver.ReadBatch(inMsgs, 0) + if err != nil { + readerResult <- fmt.Errorf("ReadBatch() error: %w", err) + return + } + for i := range n { + msg := inMsgs[i] + if bytes.Equal(msg.Buffers[0][:msg.N], stopMsg) { 
+ tb.Logf("Received a STOP message, reader is stopping") + return + } + for j := 0; j < msg.N; j++ { + expectedByte := byte('A' + bytesReceived%26) + if msg.Buffers[0][j] != expectedByte { + readerResult <- fmt.Errorf("unexpected byte at position %d: got %v, want %v", + bytesReceived, msg.Buffers[0][j], expectedByte) + return + } + bytesReceived++ + } + } + } + }() + + if err := <-writerResult; err != nil { + tb.Fatalf("writer error: %v", err) + } + if err := <-readerResult; err != nil { + tb.Fatalf("reader error: %v", err) + } +} + +func TestUDPReadWrite(t *testing.T) { + sender, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0)) + if err != nil { + t.Fatalf("ListenUDP: %v", err) + } + defer sender.Close() + + receiver, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0)) + if err != nil { + t.Fatalf("ListenUDP: %v", err) + } + defer receiver.Close() + + message := []byte("Hello, world!") + + n, err := sender.WriteTo(message, net.UDPAddrFromAddrPort(receiver.LocalAddrPort())) + if err != nil { + t.Fatalf("WriteTo: %v", err) + } + if n != len(message) { + t.Fatalf("WriteTo: wrote %d bytes, want %d", n, len(message)) + } + + buf := make([]byte, 1024) + n, addr, err := receiver.ReadFrom(buf) + if err != nil { + t.Fatalf("ReadFrom: %v", err) + } + if !bytes.Equal(buf[:n], message) { + t.Fatalf("ReadFrom: got %q, want %q", buf[:n], message) + } + if addr.String() != net.UDPAddrFromAddrPort(sender.LocalAddrPort()).String() { + t.Fatalf("ReadFrom: got addr %v, want %v", addr, sender.LocalAddrPort()) + } +} + +func TestUDPReadFromUDPAddrPort(t *testing.T) { + sender, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0)) + if err != nil { + t.Fatalf("ListenUDP: %v", err) + } + defer sender.Close() + + receiver, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0)) + if err != nil { + t.Fatalf("ListenUDP: %v", err) + } + defer receiver.Close() + + message := []byte("Hello, world!") + + n, err := sender.WriteToUDPAddrPort(message, 
receiver.LocalAddrPort()) + if err != nil { + t.Fatalf("WriteToUDPAddrPort: %v", err) + } + if n != len(message) { + t.Fatalf("WriteToUDPAddrPort: wrote %d bytes, want %d", n, len(message)) + } + + buf := make([]byte, 1024) + n, addr, err := receiver.ReadFromUDPAddrPort(buf) + if err != nil { + t.Fatalf("ReadFromUDPAddrPort: %v", err) + } + if !bytes.Equal(buf[:n], message) { + t.Fatalf("ReadFromUDPAddrPort: got %q, want %q", buf[:n], message) + } + if addr != sender.LocalAddrPort() { + t.Fatalf("ReadFromUDPAddrPort: got addr %v, want %v", addr, sender.LocalAddrPort()) + } +} + +func loopbackUDPAddr(network string, port int) *net.UDPAddr { + switch network { + case "udp4": + return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port} + case "udp6": + return &net.UDPAddr{IP: net.IPv6loopback, Port: port} + default: + panic(fmt.Sprintf("unsupported network: %s", network)) + } +} diff --git a/net/rioconn/udprx.go b/net/rioconn/udprx.go new file mode 100644 index 000000000..e07b3047d --- /dev/null +++ b/net/rioconn/udprx.go @@ -0,0 +1,204 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "errors" + "fmt" + "net" + "net/netip" + "unsafe" + + "github.com/tailscale/wireguard-go/conn/winrio" + "golang.org/x/net/ipv6" + "golang.org/x/sys/windows" +) + +// udpRx is the receive half of [UDPConn]. +// +// Its exported methods are safe for concurrent use. +// The caller must ensure that the connection is not closed +// while any unexported methods are in flight, unless +// otherwise specified by the method. +type udpRx struct { + udpNx + // pendingResultIdx is the index in [udpNx.results] + // of the next pending result to process. + pendingResultIdx int +} + +// init initializes the receive half of a [UDPConn] with the +// specified underlying connection and options. 
+func (rx *udpRx) init(conn *conn, options UDPConfig) error { + // Without URO, the data buffer for each receive request only needs + // to hold a single packet's payload. + dataSize := min(options.Rx().MaxPayloadLen(), MaxUDPPayload) + if err := rx.udpNx.init(conn, dataSize, options.Rx().MemoryLimit()); err != nil { + return fmt.Errorf("failed to initialize udpRx: %w", err) + } + return nil +} + +// ReadBatch implements [batching.Conn] by reading messages into msgs. +// It returns the number of messages the caller should evaluate for nonzero len, +// as a zero len message may fall on either side of a nonzero. +// The flags parameter is reserved for future use and must be zero. +func (rx *udpRx) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) { + // Prevent the connection from closing while in use. + if err := rx.conn.acquire(); err != nil { + return 0, &net.OpError{Op: "read", Net: rx.conn.Network(), Source: rx.conn.LocalAddr(), Err: err} + } + defer rx.conn.release() + + rx.mu.Lock() + defer rx.mu.Unlock() + // Keep trying to read until we get at least one message or an error. + for n == 0 && err == nil { + if err := rx.awaitCompletionsLocked(); err != nil { + return 0, err + } + n, err = rx.processCompletionsLocked(msgs) + } + // Always try to post more receive requests, even if an error + // occurred while processing completed ones. + if postErr := rx.postReceiveRequestsLocked(); postErr != nil { + err = errors.Join(err, postErr) + } + if err != nil { + err = &net.OpError{Op: "read", Net: rx.conn.Network(), Source: rx.conn.LocalAddr(), Err: err} + } + return n, err +} + +// ReadFromUDPAddrPort implements [nettype.PacketConn.ReadFromUDPAddrPort]. +func (rx *udpRx) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + n, netAddr, err := rx.ReadFrom(p) + if netAddr != nil { + addr = netAddr.(*net.UDPAddr).AddrPort() + } + return n, addr, err +} + +// ReadFrom implements [net.PacketConn.ReadFrom]. 
+func (rx *udpRx) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + msgs := []ipv6.Message{{ + Buffers: [][]byte{p}, + }} + numMsgs, err := rx.ReadBatch(msgs, 0) + if numMsgs != 0 { + n = msgs[0].N + addr = msgs[0].Addr + } + return n, addr, err +} + +// postReceiveRequests posts available receive requests to the +// RIO request queue. The caller must ensure that the connection +// is not closed until this call returns. +func (rx *udpRx) postReceiveRequests() error { + rx.mu.Lock() + defer rx.mu.Unlock() + return rx.postReceiveRequestsLocked() +} + +// postReceiveRequestsLocked posts all available receive requests +// to the RIO request queue. +// rx.mu must be held. +func (rx *udpRx) postReceiveRequestsLocked() (err error) { + return rx.conn.postReceiveRequests(rx.requests.AcquireSeq()) +} + +// awaitCompletionsLocked dequeues completed receive requests, returning when +// there's at least one completion to process, the connection is closed, +// or an error occurs. +// rx.mu must be held. +func (rx *udpRx) awaitCompletionsLocked() error { + if rx.pendingResultIdx < len(rx.results) { + // We have already dequeued some completions that haven't been + // fully processed yet. Return immediately. + return nil + } + + rx.results = rx.results[:cap(rx.results)] + rx.pendingResultIdx = 0 + + var count uint32 + for { + if count = winrio.DequeueCompletion(rx.cq, rx.results[:]); count != 0 { + // Got new completions to process, no need to wait. + break + } + // Otherwise, arm the notification... + if err := winrio.Notify(rx.cq); err != nil { + return err + } + // ...and wait until RIO signals that more completions are available + // or the connection is closed. 
+ handles := []windows.Handle{rx.conn.closedEvt, rx.hasCompletionsEvt} + switch evtIdx, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE); { + case err != nil: + return fmt.Errorf("waiting for completed receives failed: %w", err) + case evtIdx == 0: + return net.ErrClosed + case evtIdx == 1: + continue // try dequeueing completions again + default: + panic("unreachable") + } + } + rx.results = rx.results[:count] + return nil +} + +// processCompletionsLocked processes completed receive requests and fills msgs +// with the received packets. It returns the number of messages the caller +// should evaluate for nonzero len, as a zero len message may fall on either +// side of a nonzero. +// rx.mu must be held. +func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error) { + firstResultIdx := rx.pendingResultIdx + + defer func() { + // Always release processed results, even if an error occurred. + rx.requests.ReleaseN(rx.pendingResultIdx - firstResultIdx) + }() + + for rx.pendingResultIdx < len(rx.results) && n < len(msgs) { + res := &rx.results[rx.pendingResultIdx] + req := (*request)(unsafe.Pointer(uintptr(res.RequestContext))) + r, err := req.CompleteReceive(res.Status, res.BytesTransferred) + if err != nil { + rx.pendingResultIdx++ + if err == windows.WSAEMSGSIZE { + // The packet is larger than [RxConfig.MaxPayloadLen]. + // Skip it and try to process the next one, if any. + continue + } + // In case of other errors, skip the packet and return + // the error to the caller. + return n, err + } + // TODO(nickkhyl): Maintain an LRU cache of remote addresses to + // avoid allocating a new [netip.AddrPort] / [net.UDPAddr] for each packet. + // Profiling suggests this accounts for ~5% of total processing time. + udpAddr, err := r.RemoteAddr().ToUDPAddr() + if err != nil { + return n, fmt.Errorf("invalid remote address: %w", err) + } + + if r.Len() <= len(msgs[n].Buffers[0]) { + // TODO(nickkhyl): avoid the copy? 
We could transfer ownership of the underlying + // buffer to the reader until the next read or an explicit release. + msgs[n].N = copy(msgs[n].Buffers[0], r.Bytes()) + } else { + msgs[n].N = 0 // packet is too large; ignore it + } + msgs[n].Addr = udpAddr + rx.pendingResultIdx++ + n++ + } + return n, nil +} diff --git a/net/rioconn/udptx.go b/net/rioconn/udptx.go new file mode 100644 index 000000000..ee79a9d54 --- /dev/null +++ b/net/rioconn/udptx.go @@ -0,0 +1,191 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "errors" + "fmt" + "net" + "net/netip" + "unsafe" + + "github.com/tailscale/wireguard-go/conn/winrio" + "golang.org/x/sys/windows" + "tailscale.com/net/packet" +) + +// udpTx is the transmit half of [UDPConn]. +// +// Its exported methods are safe for concurrent use. +// The caller must ensure that the connection is not closed +// while any unexported methods are in flight, unless +// otherwise specified by the method. +type udpTx struct { + udpNx +} + +// init initializes the transmit half of a [UDPConn] with the +// specified underlying connection and options. +func (tx *udpTx) init(conn *conn, options UDPConfig) error { + // Without USO, the data buffer for each send request only needs to hold + // a single packet's payload. + dataSize := min(options.Tx().MaxPayloadLen(), MaxUDPPayload) + if err := tx.udpNx.init(conn, dataSize, options.Tx().MemoryLimit()); err != nil { + return fmt.Errorf("failed to initialize udpTx: %w", err) + } + return nil +} + +// WriteBatchTo implements [batching.Conn.WriteBatchTo] by writing +// buffs to the specified remote address. +// +// If geneve.VNI.IsSet(), then geneve is encoded into the space preceding +// offset, and offset must equal [packet.GeneveFixedHeaderLength]. +// Otherwise, the space preceding offset is ignored. 
+func (tx *udpTx) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error { + if tx.conn.IsDualStack() && addr.Addr().Is4() { + // Convert to an IPv4-mapped IPv6 address + addr = netip.AddrPortFrom(netip.AddrFrom16(addr.Addr().As16()), addr.Port()) + } + if err := tx.writeBatchTo(buffs, addr, geneve, offset); err != nil { + return &net.OpError{Op: "write", Net: tx.conn.Network(), Source: tx.conn.LocalAddr(), Addr: net.UDPAddrFromAddrPort(addr), Err: err} + } + return nil +} + +// writeBatchTo implements [udpTx.WriteBatchTo]. It returns an +// error if the connection is already closed and prevents the +// connection from closing until it returns. +func (tx *udpTx) writeBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) (err error) { + if len(buffs) == 0 { + return nil + } + + raddr, err := rawSockaddrFromAddrPort(addr) + if err != nil { + return fmt.Errorf("failed to convert address: %w", err) + } + + // Prevent the connection from closing while in use. 
+ if err := tx.conn.acquire(); err != nil { + return err + } + defer tx.conn.release() + + tx.mu.Lock() + defer tx.mu.Unlock() + + n := 0 + defer func() { + if n != 0 { + if commitErr := tx.conn.commitSendRequests(); commitErr != nil { + err = errors.Join(err, commitErr) + } + } + }() + + for n < len(buffs) { + if tx.conn.IsClosed() { + return net.ErrClosed + } + if err := tx.drainCompletionsLocked(); err != nil { + return err + } + + req := tx.requests.Peek() + w := req.Writer() + w.SetRemoteAddr(raddr) + + if geneve.VNI.IsSet() { + geneveHeader := w.Reserve(packet.GeneveFixedHeaderLength) + geneve.Encode(geneveHeader[:]) + } + if _, err := w.Write(buffs[n][offset:]); err != nil { + return err + } + + if err = tx.conn.postSendRequest(req, winrio.MsgDefer); err != nil { + return fmt.Errorf("failed to post send request: %w", err) + } + + tx.requests.Advance() // advance after posting the request + n++ + } + return nil +} + +// WriteToUDPAddrPort implements [nettype.PacketConn.WriteToUDPAddrPort]. +func (tx *udpTx) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (n int, err error) { + if err := tx.WriteBatchTo([][]byte{p}, addr, packet.GeneveHeader{}, 0); err != nil { + return 0, err + } + return len(p), nil +} + +// WriteTo implements [net.PacketConn.WriteTo]. +func (tx *udpTx) WriteTo(p []byte, addr net.Addr) (n int, err error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, &net.OpError{ + Op: "write", + Net: tx.conn.Network(), + Source: tx.conn.LocalAddr(), + Addr: addr, + Err: net.InvalidAddrError("address is not a *net.UDPAddr"), + } + } + return tx.WriteToUDPAddrPort(p, udpAddr.AddrPort()) +} + +// drainCompletionsLocked dequeues and processes completed send requests +// until the request ring is not full (i.e., more requests can be posted) +// or the closedEvt is signaled. +// +// tx.mu must be held, and the caller must ensure that the connection +// is not closed until this call returns. 
+func (tx *udpTx) drainCompletionsLocked() error { + var count uint32 + for { + if count = winrio.DequeueCompletion(tx.cq, tx.results[:cap(tx.results)]); count != 0 { + // Got new completions to process, no need to wait. + break + } + if !tx.requests.IsFull() { + // No completions to process, but also not all requests are in-flight, + // so no need to wait. + break + } + // Otherwise, if all requests are in flight, commit any deferred sends. + tx.conn.commitSendRequests() + // Then arm the notification... + if err := winrio.Notify(tx.cq); err != nil { + return err + } + // ...and wait for either RIO to signal that more completions are available, + // or the connection to be closed. + handles := []windows.Handle{tx.conn.closedEvt, tx.hasCompletionsEvt} + switch evtIdx, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE); { + case err != nil: + return fmt.Errorf("waiting for completed sends failed: %w", err) + case evtIdx == 0: + return net.ErrClosed + case evtIdx == 1: + continue // try dequeueing completions again + default: + panic("unreachable") + } + } + for _, res := range tx.results[:count] { + req := (*request)(unsafe.Pointer(uintptr(res.RequestContext))) + if err := req.CompleteSend(res.Status, res.BytesTransferred); err != nil { + // TODO(nickkhyl): Returning an error here does not make much sense. + // Increment a send error metric or log the error instead? + } + } + tx.results = tx.results[:0] + tx.requests.ReleaseN(int(count)) + return nil +}