// defaultMaxUROCoalesceSize is the coalesce limit applied when a [UROConfig]
// leaves its own limit unset. It should be greater than or equal to the
// typical hardware offload size to avoid software fallback. Mellanox NICs
// typically support up to 64000.
const defaultMaxUROCoalesceSize = math.MaxUint16

// UROConfig describes the UDP receive segment coalescing offload (URO)
// settings of a connection.
type UROConfig struct {
	enabled         bool
	maxCoalesceSize uint16 // zero selects [defaultMaxUROCoalesceSize]
}

// Enabled reports whether URO is enabled.
func (o UROConfig) Enabled() bool { return o.enabled }

// MaxCoalesceSize returns the maximum number of bytes from multiple packets
// that can be coalesced into a single receive buffer.
//
// It returns 0 when URO is disabled, the configured limit when one is set,
// and [defaultMaxUROCoalesceSize] otherwise.
func (o UROConfig) MaxCoalesceSize() uint16 {
	switch {
	case !o.enabled:
		return 0
	case o.maxCoalesceSize != 0:
		return o.maxCoalesceSize
	default:
		return defaultMaxUROCoalesceSize
	}
}
package rioconn import ( diff --git a/net/rioconn/offloads.go b/net/rioconn/offloads.go index d02b98f0d..14d63c889 100644 --- a/net/rioconn/offloads.go +++ b/net/rioconn/offloads.go @@ -6,7 +6,9 @@ package rioconn import ( "fmt" "io" + "net" + "golang.org/x/net/ipv6" "tailscale.com/net/packet" ) @@ -82,3 +84,34 @@ func coalescePackets( } return packets, bytes, packetSize, nil } + +// splitCoalescedPackets splits src into msgs, treating it as coalesced packets +// of packetSize. A packet is ignored if it does not fit in the destination buffer +// of the corresponding msg, in which case its bytes are not copied into msgs, +// but it still counts towards the packet count and bytes read from src. +// The final packet in src may be smaller than packetSize. +// +// If packetSize <= 0, it treats src as a single packet. +// A zero-length src is treated as a single zero-length packet. +// +// It returns the number of messages the caller should evaluate for nonzero len +// and the number of bytes read from src for those messages. +func splitCoalescedPackets(addr *net.UDPAddr, src []byte, packetSize int, msgs []ipv6.Message) (packets, bytes int) { + srcLen := len(src) + if packetSize <= 0 { + packetSize = srcLen + } + for ; packets < len(msgs) && (bytes < srcLen || packets == 0); packets++ { + packetLen := min(packetSize, srcLen-bytes) // last packet may be smaller + if packetLen <= len(msgs[packets].Buffers[0]) { + // TODO(nickkhyl): avoid the copy? We could transfer ownership of the underlying + // buffer to the reader until the next read or an explicit release. 
+ msgs[packets].N = copy(msgs[packets].Buffers[0], src[bytes:bytes+packetLen]) + } else { + msgs[packets].N = 0 // packet is too large; ignore it + } + msgs[packets].Addr = addr + bytes += packetLen + } + return packets, bytes +} diff --git a/net/rioconn/offloads_test.go b/net/rioconn/offloads_test.go index 417d202e6..107e11a2a 100644 --- a/net/rioconn/offloads_test.go +++ b/net/rioconn/offloads_test.go @@ -5,8 +5,11 @@ package rioconn import ( "bytes" + "net" + "net/netip" "testing" + "golang.org/x/net/ipv6" "tailscale.com/net/packet" ) @@ -332,3 +335,244 @@ func TestCoalescePackets(t *testing.T) { }) } } + +func TestSplitCoalescedPackets(t *testing.T) { + addr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("192.0.2.0:50000")) + + tests := []struct { + name string + addr *net.UDPAddr + src []byte + packetSize int + msgs []ipv6.Message + wantMsgs []ipv6.Message + wantPackets int + wantBytes int + }{ + { + name: "single-packet/zero-length", + addr: addr, + src: []byte{}, // zero-length src is treated as a single zero-length packet + msgs: makeMessages(2, 10), + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{ + {}, + }, + N: 0, + }, + }, + wantPackets: 1, + wantBytes: 0, + }, + { + name: "single-packet/no-packet-size", + addr: addr, + src: []byte{0x01, 0x02, 0x03}, + msgs: makeMessages(2, 10), + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{{0x01, 0x02, 0x03}}, + N: 3, + }, + }, + wantPackets: 1, + wantBytes: 3, + }, + { + name: "single-packet/with-packet-size", + addr: addr, + src: []byte{0x01, 0x02, 0x03}, + packetSize: 3, + msgs: makeMessages(2, 10), + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{{0x01, 0x02, 0x03}}, + N: 3, + }, + }, + wantPackets: 1, + wantBytes: 3, + }, + { + name: "single-packet/too-large-for-msg", + addr: addr, + src: []byte{0x01, 0x02, 0x03}, + msgs: makeMessages(2, 2), + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{ + {}, + }, + N: 0, // no bytes copied + }, + }, + 
wantPackets: 1, // but the packet is still counted + wantBytes: 3, // and all bytes are still counted as read from src + }, + { + name: "single-packet/no-msgs", + addr: addr, + src: []byte{0x01, 0x02, 0x03}, + msgs: nil, // no msgs to copy into + wantMsgs: nil, + wantPackets: 0, + wantBytes: 0, + }, + { + name: "multiple-packets/equal-packet-size", + addr: addr, + src: []byte{ + 0x01, 0x02, 0x03, // first packet + 0x04, 0x05, 0x06, // second packet + }, + packetSize: 3, + msgs: makeMessages(3, 10), + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{{0x01, 0x02, 0x03}}, + N: 3, + }, + { + Addr: addr, + Buffers: [][]byte{{0x04, 0x05, 0x06}}, + N: 3, + }, + }, + wantPackets: 2, + wantBytes: 6, + }, + { + name: "multiple-packets/last-packet-smaller", + addr: addr, + src: []byte{ + 0x01, 0x02, 0x03, // first packet + 0x04, 0x05, 0x06, // second packet + 0x07, 0x08, // third packet, smaller than packetSize, ends the batch + }, + packetSize: 3, + msgs: makeMessages(4, 10), + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{{0x01, 0x02, 0x03}}, + N: 3, + }, + { + Addr: addr, + Buffers: [][]byte{{0x04, 0x05, 0x06}}, + N: 3, + }, + { + Addr: addr, + Buffers: [][]byte{{0x07, 0x08}}, + N: 2, + }, + }, + wantPackets: 3, + wantBytes: 8, + }, + { + name: "multiple-packets/partial-fit", + addr: addr, + src: []byte{ + 0x01, 0x02, 0x03, // first packet + 0x04, 0x05, 0x06, // second packet + 0x07, 0x08, // third packet, smaller than packetSize, ends the batch + }, + packetSize: 3, + msgs: makeMessages(2, 10), // can only fit the first two packets + wantMsgs: []ipv6.Message{ + { + Addr: addr, + Buffers: [][]byte{{0x01, 0x02, 0x03}}, + N: 3, + }, + { + Addr: addr, + Buffers: [][]byte{{0x04, 0x05, 0x06}}, + N: 3, + }, + }, + wantPackets: 2, // the third packet is not included in the msgs + wantBytes: 6, // and only the first two packets' bytes are counted as read from src + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allocs 
// checkNetAddrEqual reports a test error on t unless got and want denote the
// same network address.
//
// Two nil addresses are equal; a nil and a non-nil address are not. Only
// [*net.UDPAddr] values are supported, and they are compared by their
// [net.UDPAddr.AddrPort]. Any other dynamic type of got is reported as an
// error.
func checkNetAddrEqual(t *testing.T, got, want net.Addr) {
	t.Helper()
	if format := netAddrMismatchFormat(got, want); format != "" {
		// Pass the original interface values, not type-asserted copies:
		// after a failed assertion the asserted variable is a nil
		// *net.UDPAddr, and %T would print that type instead of the
		// real dynamic type of want.
		t.Errorf(format, got, want)
	}
}

// netAddrMismatchFormat classifies how got differs from want. It returns a
// format string taking (got, want) as its two operands, or "" when the
// addresses are equal.
func netAddrMismatchFormat(got, want net.Addr) string {
	const (
		valueMismatch   = "address: got %v; want %v"
		typeMismatch    = "address type: got %T; want %T"
		unsupportedType = "unsupported address type: got %T; want %T"
	)
	switch {
	case got == nil && want == nil:
		return ""
	case got == nil || want == nil:
		return valueMismatch
	}
	gotUDP, ok := got.(*net.UDPAddr)
	if !ok {
		// Only UDP addresses can be compared here; say so explicitly rather
		// than reporting a type mismatch when both types may be identical.
		return unsupportedType
	}
	wantUDP, ok := want.(*net.UDPAddr)
	if !ok {
		return typeMismatch
	}
	if gotUDP.AddrPort() != wantUDP.AddrPort() {
		return valueMismatch
	}
	return ""
}
type udpRx struct { udpNx + useURO bool // whether URO is enabled // pendingResultIdx is the index in [udpNx.results] // of the next pending result to process. pendingResultIdx int + // pendingResultOffset is the offset into the pending result's data. + // It is used when the result contains coalesced packets and only + // part of the data has been processed and returned to the caller. + pendingResultOffset int } // init initializes the receive half of a [UDPConn] with the // specified underlying connection and options. func (rx *udpRx) init(conn *conn, options UDPConfig) error { - // Without URO, the data buffer for each receive request only needs - // to hold a single packet's payload. - dataSize := min(options.Rx().MaxPayloadLen(), MaxUDPPayload) + var dataSize uint16 + if uro := options.URO(); uro.Enabled() { + // When URO is enabled, the data buffer for each receive request + // must be large enough to hold multiple coalesced packets up + // to the maximum coalescing size, or a single packet up to + // the maximum payload size, whichever is larger. + dataSize = max(uro.MaxCoalesceSize(), options.Rx().MaxPayloadLen()) + maxCoalesceSize := uint32(uro.MaxCoalesceSize()) + err := SetSockOption(conn, windows.IPPROTO_UDP, + windows.UDP_RECV_MAX_COALESCED_SIZE, + maxCoalesceSize, + ) + if err != nil { + return fmt.Errorf("failed to enable URO: %w", err) + } + rx.useURO = true + } else { + // Otherwise, the data buffer for each receive request only needs + // to hold a single packet's payload. 
+ dataSize = min(options.Rx().MaxPayloadLen(), MaxUDPPayload) + } if err := rx.udpNx.init(conn, dataSize, options.Rx().MemoryLimit()); err != nil { return fmt.Errorf("failed to initialize udpRx: %w", err) } @@ -124,6 +147,7 @@ func (rx *udpRx) awaitCompletionsLocked() error { rx.results = rx.results[:cap(rx.results)] rx.pendingResultIdx = 0 + rx.pendingResultOffset = 0 var count uint32 for { @@ -172,6 +196,7 @@ func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error r, err := req.CompleteReceive(res.Status, res.BytesTransferred) if err != nil { rx.pendingResultIdx++ + rx.pendingResultOffset = 0 if err == windows.WSAEMSGSIZE { // The packet is larger than [RxConfig.MaxPayloadLen]. // Skip it and try to process the next one, if any. @@ -189,16 +214,22 @@ func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error return n, fmt.Errorf("invalid remote address: %w", err) } - if r.Len() <= len(msgs[n].Buffers[0]) { - // TODO(nickkhyl): avoid the copy? We could transfer ownership of the underlying - // buffer to the reader until the next read or an explicit release. - msgs[n].N = copy(msgs[n].Buffers[0], r.Bytes()) - } else { - msgs[n].N = 0 // packet is too large; ignore it + var packetSize uint32 + if rx.useURO { + // When URO is enabled, the result may contain multiple coalesced packets, + // so we need to get the size of the first packet to know how to split them. + packetSize = r.ControlMessages().GetUInt32(windows.IPPROTO_UDP, windows.UDP_COALESCED_INFO) } - msgs[n].Addr = udpAddr - rx.pendingResultIdx++ - n++ + packetsProcessed, bytesProcessed := splitCoalescedPackets( + udpAddr, r.Bytes()[rx.pendingResultOffset:], + int(packetSize), msgs[n:], + ) + rx.pendingResultOffset += bytesProcessed + if rx.pendingResultOffset >= r.Len() { + rx.pendingResultIdx++ + rx.pendingResultOffset = 0 + } + n += packetsProcessed } return n, nil }