diff --git a/net/rioconn/cmsg.go b/net/rioconn/cmsg.go new file mode 100644 index 000000000..033e6f81e --- /dev/null +++ b/net/rioconn/cmsg.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package rioconn + +import ( + "fmt" + "iter" + "unsafe" +) + +// _RIO_CMSG_BUFFER is the header of the control messages buffer for RIO operations, +// as defined by the Windows API. It is followed by zero or more control messages. +type _RIO_CMSG_BUFFER struct { + totalLength uint32 // total length of the buffer, including this header + // followed by control messages aligned to _WSA_CMSGHDR_ALIGN +} + +// _WSACMSGHDR is the header for a single control message, as defined by the Windows API. +type _WSACMSGHDR struct { + len uintptr + level int32 + typ int32 + // followed by data aligned to _WSA_CMSGDATA_ALIGN +} + +const ( + _MAX_NATURAL_ALIGN = unsafe.Alignof(uintptr(0)) + _WSA_CMSGHDR_ALIGN = unsafe.Alignof(_WSACMSGHDR{}) + _WSA_CMSGDATA_ALIGN = _MAX_NATURAL_ALIGN + _RIO_CMSG_BASE_SIZE = (unsafe.Sizeof(_RIO_CMSG_BUFFER{}) + + (_WSA_CMSGHDR_ALIGN - 1)) &^ (_WSA_CMSGHDR_ALIGN - 1) +) + +const ( + // controlMessagesSize is the target size of the [controlMessages] struct, + // which includes the header, padding for alignment, and space for control messages. + // It is somewhat arbitrary but is large enough to hold typical control messages. + controlMessagesSize = 64 + + // controlMessagesBufferSize is the size of the buffer for control messages, + // which is the total size minus the size of the header. + controlMessagesBufferSize = controlMessagesSize - unsafe.Sizeof(_RIO_CMSG_BUFFER{}) +) + +// controlMessages is a fixed-size control messages buffer for RIO operations. +// It is large enough to hold the header and typical control messages +// for either send or receive operations. +type controlMessages struct { + _RIO_CMSG_BUFFER + _ [controlMessagesBufferSize]byte // space for cmsgs +} + +// controlMessage represents a single control message. +type controlMessage struct { + Level int32 // protocol that originated the control information + Type int32 // protocol-specific type of control information + Data []byte // type-specific control data (backed by [controlMessages.buffer]) +} + +// Empty reports whether the control messages buffer contains no control messages. +func (cmsgs *controlMessages) Empty() bool { + return uintptr(cmsgs.totalLength) <= _RIO_CMSG_BASE_SIZE +} + +// All returns an iterator over all control messages in the buffer. +func (cmsgs *controlMessages) All() iter.Seq[controlMessage] { + return func(yield func(cmsg controlMessage) bool) { + offset := _RIO_CMSG_BASE_SIZE + totalLen := uintptr(cmsgs.totalLength) + if totalLen > unsafe.Sizeof(*cmsgs) { + panic("controlMessages buffer overflow") + } + for offset+unsafe.Sizeof(_WSACMSGHDR{}) <= totalLen { + hdr := (*_WSACMSGHDR)(unsafe.Add(unsafe.Pointer(cmsgs), offset)) + dataOffset := alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN) + if hdr.len < dataOffset || offset+hdr.len > totalLen { + panic("invalid control message header length") + } + dataLen := uintptr(hdr.len) - dataOffset + data := unsafe.Slice((*byte)( + unsafe.Add(unsafe.Pointer(hdr), dataOffset)), + dataLen, + ) + if !yield(controlMessage{ + Level: hdr.level, + Type: hdr.typ, + Data: data, + }) { + break + } + offset += alignUp(hdr.len, _WSA_CMSGHDR_ALIGN) + } + } +} + +// GetOk retrieves the data for the first control message with the given +// level and type. It reports whether such a control message was found. +func (cmsgs *controlMessages) GetOk(level, ctype int32) (data []byte, ok bool) { + for cmsg := range cmsgs.All() { + if cmsg.Level == level && cmsg.Type == ctype { + return cmsg.Data, true + } + } + return nil, false +} + +// GetUInt32Ok is like GetOk but interprets the data as a uint32. +func (cmsgs *controlMessages) GetUInt32Ok(level, ctype int32) (val uint32, ok bool) { + data, ok := cmsgs.GetOk(level, ctype) + if !ok || len(data) < 4 { + return 0, false + } + return *(*uint32)(unsafe.Pointer(unsafe.SliceData(data))), true +} + +// GetUInt32 is like GetUInt32Ok, but returns zero if the specified control +// message is not found. +func (cmsgs *controlMessages) GetUInt32(level, ctype int32) uint32 { + val, _ := cmsgs.GetUInt32Ok(level, ctype) + return val +} + +// Append adds a control message with the given level, type, and data to the +// buffer. It returns an error if there is not enough space. +func (cmsgs *controlMessages) Append(level, ctype int32, data []byte) error { + space := alignUp( + unsafe.Sizeof(_WSACMSGHDR{})+ + alignUp(uintptr(len(data)), _WSA_CMSGDATA_ALIGN), + _WSA_CMSGHDR_ALIGN, + ) + // Append the new control message at the end of the existing messages, + // or after the base header if there are no existing messages. + offset := max(uintptr(cmsgs.totalLength), _RIO_CMSG_BASE_SIZE) + if space > unsafe.Sizeof(*cmsgs)-offset { + return fmt.Errorf("not enough space to append cmsg (Level=%d Type=%d Len=%d)", level, ctype, len(data)) + } + hdr := (*_WSACMSGHDR)(unsafe.Add(unsafe.Pointer(cmsgs), offset)) + hdr.level = level + hdr.typ = ctype + hdr.len = alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN) + uintptr(len(data)) + dataPtr := unsafe.Add( + unsafe.Pointer(hdr), + alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN), + ) + copy(unsafe.Slice((*byte)(dataPtr), len(data)), data) + cmsgs.totalLength = uint32(offset + space) + return nil +} + +// AppendUInt32 is like Append but appends a uint32 value. +func (cmsgs *controlMessages) AppendUInt32(level, ctype int32, val uint32) error { + data := (*[4]byte)(unsafe.Pointer(&val))[:] + return cmsgs.Append(level, ctype, data) +} + +// Clear removes all control messages from the buffer. +func (cmsgs *controlMessages) Clear() { + cmsgs.totalLength = uint32(_RIO_CMSG_BASE_SIZE) +} + +// Bytes returns the raw bytes of the control messages buffer. +func (cmsgs *controlMessages) Bytes() []byte { + totalLen := min(uintptr(cmsgs.totalLength), unsafe.Sizeof(*cmsgs)) + return unsafe.Slice((*byte)(unsafe.Pointer(cmsgs)), totalLen) +} + +// Buffer returns the entire control messages buffer, including unused space. +func (cmsgs *controlMessages) Buffer() []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(cmsgs)), unsafe.Sizeof(*cmsgs)) +} diff --git a/net/rioconn/cmsg_test.go b/net/rioconn/cmsg_test.go new file mode 100644 index 000000000..9f55dbc8a --- /dev/null +++ b/net/rioconn/cmsg_test.go @@ -0,0 +1,428 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package rioconn + +import ( + "bytes" + "encoding/hex" + "fmt" + "slices" + "testing" + "unsafe" +) + +func TestControlMessagesSize(t *testing.T) { + t.Parallel() + if gotSize := unsafe.Sizeof(controlMessages{}); gotSize != controlMessagesSize { + t.Errorf("got: %d; want %d", gotSize, controlMessagesSize) + } +} + +func TestControlMessageAppend(t *testing.T) { + t.Parallel() + tests := []struct { + name string + clear bool // whether to call Clear() before appending + messages []controlMessage + wantBytes []byte + }{ + { + name: "zero", + wantBytes: []byte{}, + }, + { + name: "clear", + clear: true, + wantBytes: chooseForArch( + []byte{ + 0x04, 0x00, 0x00, 0x00, // TotalLength = 4 (_RIO_CMSG_BASE_SIZE) + // no padding needed on 32-bit platforms + }, + []byte{ + 0x08, 0x00, 0x00, 0x00, // TotalLength = 8 (_RIO_CMSG_BASE_SIZE) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + ), + }, + { + name: "single", + messages: []controlMessage{ + {Level: 15, Type: 42, Data: []byte{0xAA, 0xBB, 0xCC}}, + }, + wantBytes: chooseForArch( + []byte{ + 0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size) + 0x0F, 0x00, 0x00, 0x00, // cmsg Len = 15 (excluding padding) + 0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15 + 0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42 + 0xAA, 0xBB, 0xCC, // cmsg Data + 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + []byte{ + 0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + 0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 19 (excluding padding) + 0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15 + 0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42 + 0xAA, 0xBB, 0xCC, // cmsg Data + 0x00, 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + ), + }, + { + name: "single/after-clear", + clear: true, + messages: []controlMessage{ + {Level: 15, Type: 42, Data: []byte{0xAA, 0xBB, 0xCC}}, + }, + wantBytes: chooseForArch( + []byte{ // same as "single" test case + 0x14, 0x00, 0x00, 0x00, + 0x0F, 0x00, 0x00, 0x00, + 0x0F, 0x00, 0x00, 0x00, + 0x2A, 0x00, 0x00, 0x00, + 0xAA, 0xBB, 0xCC, + 0x00, + }, + []byte{ // same as "single" test case + 0x20, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0F, 0x00, 0x00, 0x00, + 0x2A, 0x00, 0x00, 0x00, + 0xAA, 0xBB, 0xCC, + 0x00, 0x00, 0x00, 0x00, 0x00, + }, + ), + }, + { + name: "multiple", + messages: []controlMessage{ + {Level: 1, Type: 2, Data: []byte{ + 0xAA, 0xBB, 0xCC, 0xDD, + }}, + {Level: 3, Type: 4, Data: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }}, + }, + wantBytes: chooseForArch( + []byte{ + 0x30, 0x00, 0x00, 0x00, // TotalLength = 48 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size) + // cmsg 1 + 0x10, 0x00, 0x00, 0x00, // cmsg Len = 16 + 0x01, 0x00, 0x00, 0x00, // cmsg Level = 1 + 0x02, 0x00, 0x00, 0x00, // cmsg Type = 2 + 0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data + // cmsg 2 + 0x1C, 0x00, 0x00, 0x00, // cmsg Len = 28 + 0x03, 0x00, 0x00, 0x00, // cmsg Level = 3 + 0x04, 0x00, 0x00, 0x00, // cmsg Type = 4 + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }, + []byte{ + 0x40, 0x00, 0x00, 0x00, // TotalLength = 64 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + // cmsg 1 + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding) + 0x01, 0x00, 0x00, 0x00, // cmsg Level = 1 + 0x02, 0x00, 0x00, 0x00, // cmsg Type = 2 + 0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + // cmsg 2 + 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 32 + 0x03, 0x00, 0x00, 0x00, // cmsg Level = 3 + 0x04, 0x00, 0x00, 0x00, // cmsg Type = 4 + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }, + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var cmsgs controlMessages + if tt.clear { + cmsgs.Clear() + } + for _, cm := range tt.messages { + if err := cmsgs.Append(cm.Level, cm.Type, cm.Data); err != nil { + t.Fatalf("Append failed: %v", err) + } + } + + if gotBytes := cmsgs.Bytes(); !bytes.Equal(gotBytes, tt.wantBytes) { + t.Fatalf("buffer bytes:\ngot\n%v\nwant\n%v", + hex.Dump(gotBytes), hex.Dump(tt.wantBytes), + ) + } + }) + } +} + +func TestControlMessageAppendNotEnoughSpace(t *testing.T) { + t.Parallel() + var cmsgs controlMessages + if err := cmsgs.Append(1, 2, bytes.Repeat([]byte{0xCC}, 1024)); err == nil { + t.Errorf("Append succeeded unexpectedly") + } + if cmsgs.totalLength > uint32(_RIO_CMSG_BASE_SIZE) { + t.Errorf("Unexpected TotalLength after failed Append: %d", cmsgs.totalLength) + } +} + +func TestControlMessageIterator(t *testing.T) { + t.Parallel() + tests := []struct { + name string + bytes []byte + wantMessages []controlMessage + }{ + { + name: "zero", + bytes: []byte{}, + wantMessages: []controlMessage{}, + }, + { + name: "empty", + bytes: chooseForArch( + []byte{ + 0x04, 0x00, 0x00, 0x00, // TotalLength = 4 (_RIO_CMSG_BASE_SIZE) + }, + []byte{ + 0x08, 0x00, 0x00, 0x00, // TotalLength = 8 (_RIO_CMSG_BASE_SIZE) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + ), + wantMessages: []controlMessage{}, + }, + { + name: "single", + bytes: chooseForArch( + []byte{ + 0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size) + 0x0F, 0x00, 0x00, 0x00, // cmsg Len = 15 (excluding padding) + 0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15 + 0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42 + 0xAA, 0xBB, 0xCC, // cmsg Data + 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + []byte{ + 0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + 0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 19 (excluding padding) + 0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15 + 0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42 + 0xAA, 0xBB, 0xCC, // cmsg Data + 0x00, 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + ), + wantMessages: []controlMessage{ + {Level: 15, Type: 42, Data: []byte{ + 0xAA, 0xBB, 0xCC, + }}, + }, + }, + { + name: "multiple", + bytes: chooseForArch( + []byte{ + 0x30, 0x00, 0x00, 0x00, // TotalLength = 48 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size) + // cmsg 1 + 0x10, 0x00, 0x00, 0x00, // cmsg Len = 16 + 0x01, 0x00, 0x00, 0x00, // cmsg Level = 1 + 0x02, 0x00, 0x00, 0x00, // cmsg Type = 2 + 0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data + // cmsg 2 + 0x1C, 0x00, 0x00, 0x00, // cmsg Len = 28 + 0x03, 0x00, 0x00, 0x00, // cmsg Level = 3 + 0x04, 0x00, 0x00, 0x00, // cmsg Type = 4 + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }, + []byte{ + 0x40, 0x00, 0x00, 0x00, // TotalLength = 64 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + // cmsg 1 + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding) + 0x01, 0x00, 0x00, 0x00, // cmsg Level = 1 + 0x02, 0x00, 0x00, 0x00, // cmsg Type = 2 + 0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + // cmsg 2 + 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 32 + 0x03, 0x00, 0x00, 0x00, // cmsg Level = 3 + 0x04, 0x00, 0x00, 0x00, // cmsg Type = 4 + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }, + ), + wantMessages: []controlMessage{ + {Level: 1, Type: 2, Data: []byte{ + 0xAA, 0xBB, 0xCC, 0xDD, + }}, + {Level: 3, Type: 4, Data: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cmsgs controlMessages + copy(cmsgs.Buffer(), tt.bytes) + + gotMessages := slices.Collect(cmsgs.All()) + if len(gotMessages) != len(tt.wantMessages) { + t.Fatalf("number of messages: got %d; want %d", len(gotMessages), len(tt.wantMessages)) + } + + for i := range gotMessages { + if got, want := gotMessages[i], tt.wantMessages[i]; got.Level != want.Level || got.Type != want.Type || !bytes.Equal(got.Data, want.Data) { + t.Errorf("message %d:\ngot %v\nwant %v", i, got, want) + } + } + }) + } +} + +func TestControlMessageUInt32(t *testing.T) { + t.Parallel() + + const ( + msgLvl = 1 + msgType = 2 + msgVal = uint32(0xAABBCCDD) + ) + var cmsgs controlMessages + if err := cmsgs.AppendUInt32(msgLvl, msgType, msgVal); err != nil { + t.Fatalf("AppendUInt32 failed: %v", err) + } + wantBytes := chooseForArch( + []byte{ + 0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size) + 0x10, 0x00, 0x00, 0x00, // cmsg Len = 16 + 0x01, 0x00, 0x00, 0x00, // cmsg Level = 1 + 0x02, 0x00, 0x00, 0x00, // cmsg Type = 2 + 0xDD, 0xCC, 0xBB, 0xAA, // cmsg Data + }, + []byte{ + 0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size) + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding) + 0x01, 0x00, 0x00, 0x00, // cmsg Level = 1 + 0x02, 0x00, 0x00, 0x00, // cmsg Type = 2 + 0xDD, 0xCC, 0xBB, 0xAA, // cmsg Data + 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT + }, + ) + if gotBytes := cmsgs.Bytes(); !bytes.Equal(gotBytes, wantBytes) { + t.Fatalf("buffer bytes:\ngot\n%v\nwant\n%v", + hex.Dump(gotBytes), hex.Dump(wantBytes), + ) + } + gotVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType) + if !ok { + t.Fatal("GetUInt32Ok failed to find value") + } + if gotVal != msgVal { + t.Fatalf("uint32: got 0x%X; want 0x%X", gotVal, msgVal) + } +} + +func TestControlMessageGetUInt32(t *testing.T) { + t.Parallel() + + var cmsgs controlMessages + if _, ok := cmsgs.GetUInt32Ok(1, 2); ok { + t.Fatal("GetUInt32Ok found unexpected value") + } + cmsgs.AppendUInt32(3, 4, 0xDEADBEEF) + if _, ok := cmsgs.GetUInt32Ok(1, 2); ok { + t.Fatal("GetUInt32Ok found unexpected value") + } + if gotVal, ok := cmsgs.GetUInt32Ok(3, 4); !ok { + t.Fatal("GetUInt32Ok found unexpected value") + } else if gotVal != 0xDEADBEEF { + t.Fatalf("GetUInt32Ok: got 0x%X; want 0xDEADBEEF", gotVal) + } +} + +func TestControlMessageNoAlloc(t *testing.T) { + const ( + msgLvl = 1 + msgType = 2 + msgVal = uint32(0xAABBCCDD) + ) + + var cmsgs controlMessages + allocs := testing.AllocsPerRun(1000, func() { + cmsgs.Clear() + cmsgs.AppendUInt32(msgLvl, msgType, msgVal) + }) + if allocs != 0 { + t.Fatalf("AppendUInt32 allocated %f times; want 0", allocs) + } + allocs = testing.AllocsPerRun(1000, func() { + gotVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType) + if !ok { + t.Fatal("GetUInt32Ok failed to find value") + } + if gotVal != msgVal { + t.Fatalf("GetUInt32Ok: got 0x%X; want 0x%X", gotVal, msgVal) + } + }) + if allocs != 0 { + t.Fatalf("GetUInt32Ok allocated %f times; want 0", allocs) + } +} + +func BenchmarkControlMessagesAppendUInt32(b *testing.B) { + var cmsgs controlMessages + + b.ReportAllocs() + for b.Loop() { + cmsgs.Clear() + cmsgs.AppendUInt32(1, 2, uint32(3)) + } +} + +func BenchmarkControlMessagesGetUInt32Ok(b *testing.B) { + const ( + msgLvl = 1 + msgType = 2 + msgVal = uint32(0xAABBCCDD) + ) + + var cmsgs controlMessages + cmsgs.AppendUInt32(msgLvl, msgType, msgVal) + + b.ReportAllocs() + for b.Loop() { + gotMsgVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType) + if !ok { + b.Fatal("GetUInt32Ok failed to find value") + } + if gotMsgVal != msgVal { + b.Fatalf("GetUInt32Ok: got 0x%X; want 0x%X", gotMsgVal, msgVal) + } + } +} + +func chooseForArch[T any](val32, val64 T) T { + if unsafe.Sizeof(uintptr(0)) == 4 { + return val32 + } + return val64 +} + +// String returns a string representation of the control message. +func (cmsg controlMessage) String() string { + return fmt.Sprintf("cmsg{Level=%d Type=%d}\n%s", cmsg.Level, cmsg.Type, hex.Dump(cmsg.Data)) +} diff --git a/net/rioconn/config.go b/net/rioconn/config.go index 359e33457..43d520d54 100644 --- a/net/rioconn/config.go +++ b/net/rioconn/config.go @@ -15,6 +15,11 @@ const ( defaultRXMemoryLimit = 2 << 20 // 2 MiB defaultTXMemoryLimit = 2 << 20 // 2 MiB + + // defaultMaxUSOOffloadSize is the default maximum offload size for USO. + // It should be smaller than or equal to the typical hardware offload size + // to avoid software fallback. Mellanox NICs typically support up to 64000. + defaultMaxUSOOffloadSize = 64000 ) // Config holds configuration for a RIO connection, independent of the transport protocol. @@ -89,4 +94,30 @@ func (o TxConfig) MaxPayloadLen() uint16 { // UDPConfig holds configuration for a [UDPConn]. type UDPConfig struct { Config + uso USOConfig +} + +// USO returns the UDP segmentation offload (USO) configuration. +func (o UDPConfig) USO() *USOConfig { + return &o.uso +} + +// USOConfig holds the UDP segmentation offload (USO) configuration. +type USOConfig struct { + enabled bool + maxOffloadSize uint16 // 0 means default (i.e., [defaultMaxUSOOffloadSize]) +} + +// Enabled reports whether USO is enabled. +func (o USOConfig) Enabled() bool { + return o.enabled +} + +// MaxOffloadSize returns the maximum number of bytes that can be batched into a +// single send for UDP segmentation offload. +func (o USOConfig) MaxOffloadSize() uint16 { + if !o.Enabled() { + return 0 + } + return cmp.Or(o.maxOffloadSize, defaultMaxUSOOffloadSize) } diff --git a/net/rioconn/doc.go b/net/rioconn/doc.go index 3ea1d49a8..2406f68fb 100644 --- a/net/rioconn/doc.go +++ b/net/rioconn/doc.go @@ -5,7 +5,7 @@ // Package rioconn provides [UDPConn], a UDP socket implementation // that uses the Windows RIO API extensions and supports batched I/O -// for improved performance on high-throughput UDP workloads. +// and USO for improved performance on high-throughput UDP workloads. package rioconn import ( diff --git a/net/rioconn/offloads.go b/net/rioconn/offloads.go new file mode 100644 index 000000000..d02b98f0d --- /dev/null +++ b/net/rioconn/offloads.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rioconn + +import ( + "fmt" + "io" + + "tailscale.com/net/packet" +) + +// coalescePackets copies packets from buffs into dst until dst is full, it +// copies a packet shorter than the first, the maximum coalesced length +// or the maximum number of packets is reached, or there are no more packets. +// A zero maxCoalescedLen or maxCoalescedPackets means no limit. +// +// Each packet is copied starting at offset. +// Each copied packet is preceded by a Geneve header if geneve.VNI.IsSet(). +// +// It returns the number of packets and bytes copied into dst, the packet +// size for the batch, or an error if Geneve header encoding fails. +func coalescePackets( + dst []byte, geneve packet.GeneveHeader, buffs [][]byte, + offset, maxCoalescedPackets, maxCoalescedBytes int, +) (packets, bytes, packetSize int, err error) { + var header []byte + if geneve.VNI.IsSet() { + var geneveHeader [packet.GeneveFixedHeaderLength]byte + if err := geneve.Encode(geneveHeader[:]); err != nil { + return 0, 0, 0, err + } + header = geneveHeader[:] + } + if len(buffs) != 0 { + // The first packet determines the packet size for the batch, + // which is the size of each packet in the coalesced batch + // except possibly the last one. If the first packet cannot fit + // in dst, we cannot coalesce any packets. + packetSize = len(header) + len(buffs[0]) - offset + if packetSize > len(dst) { + return 0, 0, 0, fmt.Errorf("%w: packet size %d exceeds dst size %d", + io.ErrShortBuffer, packetSize, len(dst), + ) + } + } + for _, buff := range buffs { + buff = buff[offset:] + packetLen := len(header) + len(buff) + newBytes := bytes + packetLen + if newBytes > len(dst) { + break // no more space + } + if packetLen > packetSize { + break // packet is too large for this batch + } + if bytes != 0 && maxCoalescedBytes != 0 && newBytes > maxCoalescedBytes { + break // would exceed the maximum coalesced length + } + if maxCoalescedPackets != 0 && packets >= maxCoalescedPackets { + break // would exceed the maximum number of coalesced packets + } + if packetLen == 0 { + // Consume the zero-length packet if it's the first packet, + // but never coalesce them. + if packets == 0 { + packets = 1 + } + break + } + + copy(dst[bytes:], header) + copy(dst[bytes+len(header):], buff) + + packets++ + bytes = newBytes + if packetLen < packetSize { + // A smaller than packetSize packet on the tail is legal, + // but it must end the batch. + break + } + } + return packets, bytes, packetSize, nil +} diff --git a/net/rioconn/offloads_test.go b/net/rioconn/offloads_test.go new file mode 100644 index 000000000..417d202e6 --- /dev/null +++ b/net/rioconn/offloads_test.go @@ -0,0 +1,334 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rioconn + +import ( + "bytes" + "testing" + + "tailscale.com/net/packet" +) + +func TestCoalescePackets(t *testing.T) { + geneve := packet.GeneveHeader{ + Protocol: packet.GeneveProtocolWireGuard, + } + geneve.VNI.Set(7) + + tests := []struct { + name string + buffs [][]byte + offset int + geneve packet.GeneveHeader + dst []byte + maxCoalescedPackets int + maxCoalescedBytes int + wantBytes []byte + wantPackets int + wantPacketSize int + wantErr bool + }{ + { + name: "no-packets", + buffs: nil, + geneve: packet.GeneveHeader{}, + dst: make([]byte, 100), + wantBytes: nil, + wantPackets: 0, + wantPacketSize: 0, + }, + { + name: "single-packet", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + }, + geneve: packet.GeneveHeader{}, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + }, + wantPackets: 1, + wantPacketSize: 3, + }, + { + name: "single-packet/zero-length", + buffs: [][]byte{ + {}, + }, + geneve: packet.GeneveHeader{}, + dst: make([]byte, 100), + wantBytes: []byte{}, + wantPackets: 1, + wantPacketSize: 0, + }, + { + name: "single-packet/with-offset", + buffs: [][]byte{ + { + 0x00, 0x00, + 0x01, 0x02, 0x03, + }, + }, + offset: 2, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + }, + wantPackets: 1, + wantPacketSize: 3, + }, + { + name: "single-packet/with-geneve", + buffs: [][]byte{ + { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Geneve header space + 0x01, 0x02, 0x03, + }, + }, + offset: 8, + geneve: geneve, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header + 0x01, 0x02, 0x03, + }, + wantPackets: 1, + wantPacketSize: 11, // 8 bytes Geneve header + 3 bytes packet + }, + { + name: "single-packet/exact-fit", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + }, + dst: make([]byte, 3), + wantBytes: []byte{0x01, 0x02, 0x03}, + wantPackets: 1, + wantPacketSize: 3, + }, + { + name: "single-packet/too-large-for-dst", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + }, + dst: make([]byte, 2), + wantErr: true, + }, + { + name: "single-packet/with-geneve/too-large-for-dst", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + }, + geneve: geneve, + dst: make([]byte, 10), // smaller than 8 bytes Geneve header + 3 bytes packet + wantErr: true, + }, + { + name: "multiple-packets/coalesce-all", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09}, + }, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, + }, + wantPackets: 3, + wantPacketSize: 3, + }, + { + name: "multiple-packets/coalesce-all/with-offset", + buffs: [][]byte{ + {0x00, 0x00, 0x01, 0x02, 0x03}, + {0x00, 0x00, 0x04, 0x05, 0x06}, + {0x00, 0x00, 0x07, 0x08, 0x09}, + }, + offset: 2, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, + }, + wantPackets: 3, + wantPacketSize: 3, + }, + { + name: "multiple-packets/smaller-packet-ends-batch", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08}, // will be coalesced, but ends the batch + {0x09, 0x0a, 0x0b}, // will not be coalesced in this batch + }, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + 0x07, 0x08, + }, + wantPackets: 3, + wantPacketSize: 3, + }, + { + name: "multiple-packets/larger-packet-ends-batch", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09, 0x0a}, // will not be coalesced in this batch + + }, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + }, + wantPackets: 2, + wantPacketSize: 3, + }, + { + name: "multiple-packets/partial-fit", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09}, // could be coalesced, but won't fit + + }, + dst: make([]byte, 7), // can only fit the first two packets + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + }, + wantPackets: 2, + wantPacketSize: 3, + wantErr: false, // partial fit is not an error + }, + { + name: "multiple-packets/exact-fit", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09}, + }, + dst: make([]byte, 9), // can exactly fit all three packets + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, + }, + wantPackets: 3, + wantPacketSize: 3, + }, + { + name: "multiple-packets/with-geneve", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09}, + {0x0a, 0x0b}, // ends the batch + {0x0c, 0x0d, 0x0e}, + }, + geneve: geneve, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for first packet + 0x01, 0x02, 0x03, + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for second packet + 0x04, 0x05, 0x06, + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for third packet + 0x07, 0x08, 0x09, + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for fourth packet + 0x0a, 0x0b, + }, + wantPackets: 4, + wantPacketSize: 11, // 8 bytes Geneve header + 3 bytes packet + wantErr: false, + }, + { + name: "multiple-packets/all-zero-length", + buffs: [][]byte{ + {}, + {}, + {}, + }, + dst: make([]byte, 100), + wantBytes: []byte{}, + wantPackets: 1, // zero-length packets cannot be coalesced + wantPacketSize: 0, + }, + { + name: "multiple-packets/all-zero-length/with-geneve", + buffs: [][]byte{ + {}, + {}, + {}, + }, + geneve: geneve, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for first packet + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for second packet + 0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for third packet + }, + wantPackets: 3, + wantPacketSize: 8, // Geneve header size, since the packets are zero-length + }, + { + name: "multiple-packets/zero-length-packet-ends-batch", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {}, // cannot be coalesced, ends the batch + + }, + dst: make([]byte, 100), + wantBytes: []byte{ + 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, + }, + wantPackets: 2, + wantPacketSize: 3, + }, + { + name: "invalid-geneve-header", + buffs: [][]byte{ + {0x01, 0x02, 0x03}, + }, + geneve: packet.GeneveHeader{ + Version: 0xFF, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allocs := testing.AllocsPerRun(1000, func() { + gotPackets, gotBytes, gotPacketSize, err := coalescePackets( + tt.dst, tt.geneve, tt.buffs, tt.offset, + tt.maxCoalescedPackets, + tt.maxCoalescedBytes, + ) + if (err != nil) != tt.wantErr { + t.Fatalf("error: got %v; want error: %v", err, tt.wantErr) + } + if gotPackets != tt.wantPackets { + t.Errorf("packets: got %d; want %d", gotPackets, tt.wantPackets) + } + if gotBytes := tt.dst[:gotBytes]; !bytes.Equal(gotBytes, tt.wantBytes) { + t.Errorf("bytes: got %v; want %v", gotBytes, tt.wantBytes) + } + if gotPacketSize != tt.wantPacketSize { + t.Errorf("packet size: got %d; want %d", gotPacketSize, tt.wantPacketSize) + } + }) + // Coalescing packets should not cause any allocations, + // except for when it returns an error. + if !tt.wantErr && allocs != 0 { + t.Errorf("unexpected allocations: got %f; want 0", allocs) + } + }) + } +} diff --git a/net/rioconn/options.go b/net/rioconn/options.go index 5e915a89b..fb1705c7f 100644 --- a/net/rioconn/options.go +++ b/net/rioconn/options.go @@ -76,3 +76,10 @@ type udpOption func(*UDPConfig) func (o udpOption) applyUDP(opts *UDPConfig) { o(opts) } + +// USO specifies whether UDP segmentation offload (USO) should be enabled. +func USO(enabled bool) UDPOption { + return udpOption(func(opts *UDPConfig) { + opts.uso.enabled = enabled + }) +} diff --git a/net/rioconn/request.go b/net/rioconn/request.go index 3e2b04ebb..d4f6ae5bb 100644 --- a/net/rioconn/request.go +++ b/net/rioconn/request.go @@ -27,8 +27,9 @@ type request struct { buffID winrio.BufferId // ID of the registered buffer containing this request buffBase uintptr // base address of the registered buffer - raddr rawSockaddr // remote address for RIO send/receive operations - data []byte // a slice pointing into the data buffer area after the struct + raddr rawSockaddr // remote address for RIO send/receive operations + data []byte // a slice pointing into the data buffer area after the struct + control controlMessages // control messages buffer // followed by the actual data at [requestDataOffset] // from the start of the struct. } @@ -94,8 +95,22 @@ func (r *request) PostSend(rq winrio.Rq, flags uint32) error { Length: uint32(len(r.data)), Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase), } - remoteAddr := r.remoteAddrDesc() - return winrio.SendEx(rq, &data, 1, nil, &remoteAddr, nil, nil, flags, uintptr(unsafe.Pointer(r))) + remoteAddress := winrio.Buffer{ + Id: r.buffID, + Length: uint32(unsafe.Sizeof(r.raddr)), + Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase), + } + var control *winrio.Buffer + if !r.control.Empty() { + cmsgs := r.control.Buffer() + control = &winrio.Buffer{ + Id: r.buffID, + Length: uint32(len(cmsgs)), + Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(cmsgs))) - r.buffBase), + } + } + return winrio.SendEx(rq, &data, 1, nil, &remoteAddress, control, + nil, flags, uintptr(unsafe.Pointer(r))) } // PostReceive posts the request as a receive operation to the given RIO request queue @@ -107,8 +122,18 @@ func (r *request) PostReceive(rq winrio.Rq, flags uint32) error { Length: uint32(cap(r.data)), Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase), } - remoteAddress := r.remoteAddrDesc() - return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, nil, + remoteAddress := winrio.Buffer{ + Id: r.buffID, + Length: uint32(unsafe.Sizeof(r.raddr)), + Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase), + } + controlBuffer := r.control.Buffer() + control := &winrio.Buffer{ + Id: r.buffID, + Length: uint32(len(controlBuffer)), + Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(controlBuffer))) - r.buffBase), + } + return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, control, nil, flags, uintptr(unsafe.Pointer(r))) } @@ -119,6 +144,8 @@ func (r *request) PostReceive(rq winrio.Rq, flags uint32) error { // of bytes written does not match the length of the request's data buffer. func (r *request) CompleteSend(status int32, bytesWritten uint32) error { expected := len(r.data) + r.data = r.data[:0] + r.control.Clear() if status != 0 { return windows.Errno(status) } @@ -154,18 +181,11 @@ func (r *request) CompleteReceive(status int32, bytesRead uint32) (*requestReade return r.Reader(), nil } -func (r *request) remoteAddrDesc() winrio.Buffer { - return winrio.Buffer{ - Id: r.buffID, - Length: uint32(unsafe.Sizeof(r.raddr)), - Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase), - } -} - // Reset prepares the request for reuse by resetting its state. func (r *request) Reset() { r.raddr = rawSockaddr{} r.data = r.data[:0] + r.control.Clear() } type ( @@ -201,6 +221,11 @@ func (w *requestWriter) SetRemoteAddr(raddr rawSockaddr) { w.raddr = raddr } +// ControlMessages returns a pointer to the request's control messages buffer. +func (w *requestWriter) ControlMessages() *controlMessages { + return &w.control +} + // Reserve reserves n bytes in the request's data buffer, // and returns a slice pointing to the reserved space. // It panics if n is negative or exceeds the available capacity. @@ -262,3 +287,8 @@ func (r *requestReader) RemoteAddrPort() (netip.AddrPort, error) { func (r *requestReader) RemoteAddr() rawSockaddr { return r.raddr } + +// ControlMessages returns a pointer to the request's control messages buffer. +func (r *requestReader) ControlMessages() *controlMessages { + return &r.control +} diff --git a/net/rioconn/udp_test.go b/net/rioconn/udp_test.go index befe613f2..589265072 100644 --- a/net/rioconn/udp_test.go +++ b/net/rioconn/udp_test.go @@ -124,36 +124,65 @@ func TestUDPSendReceiveBatch(t *testing.T) { iterations int sendBatchSize int receiveBatchSize int + uso bool }{ { name: "udp4/single", network: "udp4", pattern: []int{1312}, }, + { + name: "udp4/batch", + network: "udp4", + pattern: []int{1312}, + iterations: 1024, + }, { name: "udp4/single/max", network: "udp4", pattern: []int{rioconn.MaxUDPPayloadIPv4}, }, { - name: "udp4/batch/max", - network: "udp4", - pattern: []int{rioconn.MaxUDPPayloadIPv4}, + name: "udp4/batch/max", + network: "udp4", + pattern: []int{rioconn.MaxUDPPayloadIPv4}, + iterations: 1024, }, { name: "udp6/single", network: "udp6", pattern: []int{1312}, }, + { + name: "udp6/batch", + network: "udp6", + pattern: []int{1312}, + iterations: 1024, + }, { name: "udp6/single/max", network: "udp6", pattern: []int{rioconn.MaxUDPPayloadIPv6}, }, { - name: "udp6/batch/max", - network: "udp6", - pattern: []int{rioconn.MaxUDPPayloadIPv6}, + name: "udp6/batch/max", + network: "udp6", + pattern: []int{rioconn.MaxUDPPayloadIPv6}, + iterations: 10, + }, + { + name: "udp6/batch/uso", + network: "udp6", + pattern: []int{1312}, + iterations: 10, + uso: true, + }, + { + name: "udp6/batch/max/uso", + network: "udp6", + pattern: []int{rioconn.MaxUDPPayloadIPv6}, + iterations: 10, + uso: true, }, } for _, tt := range tests { @@ -164,7 +193,9 @@ func TestUDPSendReceiveBatch(t *testing.T) { cmp.Or(tt.sendBatchSize, defaultBatchSize), cmp.Or(tt.receiveBatchSize, defaultBatchSize), tt.network, tt.network, - nil, nil, + []rioconn.UDPOption{ + rioconn.USO(tt.uso), + }, nil, ) }) } @@ -174,16 +205,19 @@ func FuzzUDPSendReceiveBatch(f *testing.F) { batchSizes := []uint16{1, 64} packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4} numIterations := []uint16{1024} + uso := []bool{false, true} for _, packetLen := range packetSizes { for _, numIter := range numIterations { for _, batchSize := range batchSizes { - f.Add(packetLen, numIter, batchSize, batchSize) + for _, usoEnabled := range uso { + f.Add(packetLen, numIter, batchSize, batchSize, usoEnabled) + } } } } - f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16) { + f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16, usoEnabled bool) { network := "udp4" maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4) @@ -206,6 +240,7 @@ func FuzzUDPSendReceiveBatch(f *testing.F) { []rioconn.UDPOption{ rioconn.RxMemoryLimit(128 << 10), rioconn.TxMemoryLimit(512 << 10), + rioconn.USO(usoEnabled), }, []rioconn.UDPOption{ rioconn.RxMemoryLimit(512 << 10), diff --git a/net/rioconn/udptx.go b/net/rioconn/udptx.go index ee79a9d54..029a370e3 100644 --- a/net/rioconn/udptx.go +++ b/net/rioconn/udptx.go @@ -8,6 +8,7 @@ package rioconn import ( "errors" "fmt" + "io" "net" "net/netip" "unsafe" @@ -25,14 +26,26 @@ import ( // otherwise specified by the method. type udpTx struct { udpNx + maxCoalescedPackets int // the maximum number of coalesced packets, or 0 if no limit + maxCoalescedBytes int // the maximum total length of coalesced packets, or 0 if no limit } // init initializes the transmit half of a [UDPConn] with the // specified underlying connection and options. func (tx *udpTx) init(conn *conn, options UDPConfig) error { - // Without USO, the data buffer for each send request only needs to hold - // a single packet's payload. - dataSize := min(options.Tx().MaxPayloadLen(), MaxUDPPayload) + var dataSize uint16 + if uso := options.USO(); uso.Enabled() { + // When USO is enabled, the request's data buffer may need to + // hold multiple coalesced packets up to the maximum offload size, + // or a single packet up to the maximum payload size, whichever is larger. + dataSize = max(options.Tx().MaxPayloadLen(), uso.MaxOffloadSize()) + tx.maxCoalescedBytes = int(uso.MaxOffloadSize()) + } else { + // Otherwise, the data buffer for each send request only needs to hold + // a single packet's payload. + dataSize = min(options.Tx().MaxPayloadLen(), MaxUDPPayload) + tx.maxCoalescedPackets = 1 + } if err := tx.udpNx.init(conn, dataSize, options.Tx().MemoryLimit()); err != nil { return fmt.Errorf("failed to initialize udpTx: %w", err) } @@ -99,20 +112,31 @@ func (tx *udpTx) writeBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet w := req.Writer() w.SetRemoteAddr(raddr) - if geneve.VNI.IsSet() { - geneveHeader := w.Reserve(packet.GeneveFixedHeaderLength) - geneve.Encode(geneveHeader[:]) - } - if _, err := w.Write(buffs[n][offset:]); err != nil { + buf := w.Reserve(w.Cap()) // reserve the entire buffer + packets, bytes, packetSize, err := coalescePackets( + buf, geneve, buffs[n:], offset, + tx.maxCoalescedPackets, + tx.maxCoalescedBytes, + ) + w.SetLen(bytes) // shrink the buffer to the actual number of bytes coalesced + if err != nil { return err } + if packets == 0 { + return io.ErrNoProgress + } + if packets > 1 { + if err := w.ControlMessages().AppendUInt32(windows.IPPROTO_UDP, windows.UDP_SEND_MSG_SIZE, uint32(packetSize)); err != nil { + return fmt.Errorf("failed to append UDP_SEND_MSG_SIZE cmsg: %w", err) + } + } if err = tx.conn.postSendRequest(req, winrio.MsgDefer); err != nil { return fmt.Errorf("failed to post send request: %w", err) } tx.requests.Advance() // advance after posting the request - n++ + n += packets } return nil }