net/rioconn: implement UDP control message support

Adding support for USO/URO requires reading and writing control messages.
With RIO, control messages, like other data passed to or from the API,
must reside in a RIO-registered buffer. RIO also requires them to be
wrapped in a RIO_CMSG_BUFFER, which consists of a length followed by
properly aligned control messages.

In this commit, implement control message handling and include
control messages when posting sends and receives to RIO.

But we keep them blank for now.

Updates tailscale/corp#8610

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2026-02-19 08:24:32 -06:00
parent e22a9f5909
commit c4993354ed
No known key found for this signature in database
10 changed files with 1177 additions and 33 deletions

171
net/rioconn/cmsg.go Normal file
View File

@ -0,0 +1,171 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package rioconn
import (
"fmt"
"iter"
"unsafe"
)
// _RIO_CMSG_BUFFER is the header of the control messages buffer for RIO operations,
// as defined by the Windows API. It is followed by zero or more control messages.
//
// The first control message starts at offset _RIO_CMSG_BASE_SIZE, so a buffer
// whose totalLength is at most _RIO_CMSG_BASE_SIZE holds no messages.
type _RIO_CMSG_BUFFER struct {
	totalLength uint32 // total length of the buffer in bytes, including this header
	// followed by control messages aligned to _WSA_CMSGHDR_ALIGN
}
// _WSACMSGHDR is the header for a single control message, as defined by the Windows API.
type _WSACMSGHDR struct {
	len   uintptr // length in bytes of the header (padded to _WSA_CMSGDATA_ALIGN) plus the data; excludes trailing padding
	level int32   // protocol that originated the control information
	typ   int32   // protocol-specific type of control information
	// followed by data aligned to _WSA_CMSGDATA_ALIGN
}
const (
	// _MAX_NATURAL_ALIGN is the alignment of a pointer-sized integer.
	_MAX_NATURAL_ALIGN = unsafe.Alignof(uintptr(0))
	// _WSA_CMSGHDR_ALIGN is the alignment of each control message header.
	_WSA_CMSGHDR_ALIGN = unsafe.Alignof(_WSACMSGHDR{})
	// _WSA_CMSGDATA_ALIGN is the alignment of the data that follows a
	// control message header.
	_WSA_CMSGDATA_ALIGN = _MAX_NATURAL_ALIGN
	// _RIO_CMSG_BASE_SIZE is the size of _RIO_CMSG_BUFFER rounded up to
	// _WSA_CMSGHDR_ALIGN, i.e. the offset of the first control message.
	// The rounding is written as a mask so that it stays a constant
	// expression; it relies on the alignment being a power of two.
	_RIO_CMSG_BASE_SIZE = (unsafe.Sizeof(_RIO_CMSG_BUFFER{}) +
		(_WSA_CMSGHDR_ALIGN - 1)) &^ (_WSA_CMSGHDR_ALIGN - 1)
)
const (
	// controlMessagesSize is the target size of the [controlMessages] struct,
	// which includes the header, padding for alignment, and space for control messages.
	// It is somewhat arbitrary but is large enough to hold typical control messages.
	controlMessagesSize = 64
	// controlMessagesBufferSize is the size of the buffer for control messages,
	// which is the total size minus the size of the header.
	// It also covers the alignment padding between the header and the
	// first control message.
	controlMessagesBufferSize = controlMessagesSize - unsafe.Sizeof(_RIO_CMSG_BUFFER{})
)
// controlMessages is a fixed-size control messages buffer for RIO operations.
// It is large enough to hold the header and typical control messages
// for either send or receive operations.
//
// Control message headers are read and written in place through
// *_WSACMSGHDR pointers, so the struct itself must be at least
// pointer-aligned. Without the zero-size alignment field, the struct would
// only be aligned like its uint32 header on 64-bit platforms.
type controlMessages struct {
	_ [0]uintptr // occupies no space; raises the struct's alignment to that of uintptr
	_RIO_CMSG_BUFFER
	_ [controlMessagesBufferSize]byte // space for cmsgs
}
// controlMessage represents a single control message.
type controlMessage struct {
	Level int32  // protocol that originated the control information
	Type  int32  // protocol-specific type of control information
	Data  []byte // type-specific control data; when yielded by [controlMessages.All] it aliases the buffer's memory rather than copying it
}
// Empty reports whether the buffer holds no control messages, that is,
// whether its recorded total length does not extend past the base header.
// A zero-valued buffer is therefore also reported as empty.
func (cmsgs *controlMessages) Empty() bool {
	used := uintptr(cmsgs.totalLength)
	return used <= _RIO_CMSG_BASE_SIZE
}
// All returns an iterator over all control messages in the buffer.
//
// The Data slice of each yielded message points into the buffer itself;
// it is not a copy.
//
// The iterator panics if the recorded total length exceeds the size of the
// buffer, or if a message header reports a length that is smaller than the
// header or extends past the recorded total length.
func (cmsgs *controlMessages) All() iter.Seq[controlMessage] {
	return func(yield func(cmsg controlMessage) bool) {
		const hdrSize = unsafe.Sizeof(_WSACMSGHDR{})
		// Offset of the data relative to the start of a message header.
		dataOffset := alignUp(hdrSize, _WSA_CMSGDATA_ALIGN)

		totalLen := uintptr(cmsgs.totalLength)
		if totalLen > unsafe.Sizeof(*cmsgs) {
			panic("controlMessages buffer overflow")
		}

		offset := _RIO_CMSG_BASE_SIZE
		for offset+hdrSize <= totalLen {
			hdr := (*_WSACMSGHDR)(unsafe.Add(unsafe.Pointer(cmsgs), offset))
			// The loop condition guarantees offset <= totalLen, so the
			// subtraction cannot wrap. Comparing against the remaining
			// space (instead of computing offset+hdr.len) keeps the check
			// correct even when hdr.len is close to the maximum uintptr.
			if hdr.len < dataOffset || hdr.len > totalLen-offset {
				panic("invalid control message header length")
			}
			data := unsafe.Slice(
				(*byte)(unsafe.Add(unsafe.Pointer(hdr), dataOffset)),
				hdr.len-dataOffset,
			)
			if !yield(controlMessage{
				Level: hdr.level,
				Type:  hdr.typ,
				Data:  data,
			}) {
				return
			}
			// The next header starts at the next header-aligned offset.
			offset += alignUp(hdr.len, _WSA_CMSGHDR_ALIGN)
		}
	}
}
// GetOk looks up the first control message whose level and type match the
// given values and returns its data. The boolean result reports whether a
// matching message exists; when it does not, data is nil.
func (cmsgs *controlMessages) GetOk(level, ctype int32) (data []byte, ok bool) {
	for msg := range cmsgs.All() {
		if msg.Level != level || msg.Type != ctype {
			continue
		}
		return msg.Data, true
	}
	return nil, false
}
// GetUInt32Ok is like GetOk, but reads the first four bytes of the message
// data as a uint32 in the machine's native byte order. It reports false if
// no matching message exists or if its data is shorter than four bytes.
func (cmsgs *controlMessages) GetUInt32Ok(level, ctype int32) (val uint32, ok bool) {
	raw, found := cmsgs.GetOk(level, ctype)
	if found && len(raw) >= 4 {
		return *(*uint32)(unsafe.Pointer(unsafe.SliceData(raw))), true
	}
	return 0, false
}
// GetUInt32 returns the uint32 value of the first control message with the
// given level and type, or zero when GetUInt32Ok would report false.
func (cmsgs *controlMessages) GetUInt32(level, ctype int32) uint32 {
	if v, found := cmsgs.GetUInt32Ok(level, ctype); found {
		return v
	}
	return 0
}
// Append adds a control message with the given level, type, and data to the
// buffer. It returns an error if there is not enough space, in which case
// the buffer is left unmodified.
//
// The message is written at the first header-aligned offset after the
// existing messages, or right after the base header if there are none.
// All bytes reserved for the message, including alignment padding, are
// zeroed first so that stale contents from an earlier use of the buffer
// never end up in the padding.
func (cmsgs *controlMessages) Append(level, ctype int32, data []byte) error {
	const capacity = unsafe.Sizeof(controlMessages{})
	dataOffset := alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN)
	dataLen := uintptr(len(data))

	// Work out how many bytes are available for the new message. A recorded
	// total length beyond the buffer (e.g. a corrupt header) leaves avail at
	// zero instead of letting capacity-offset wrap around.
	used := max(uintptr(cmsgs.totalLength), _RIO_CMSG_BASE_SIZE)
	offset := alignUp(min(used, capacity), _WSA_CMSGHDR_ALIGN)
	var avail uintptr
	if used <= capacity && offset <= capacity {
		avail = capacity - offset
	}

	// Compare the raw data length first so that the aligned-size arithmetic
	// below only runs on small values and cannot overflow.
	var space uintptr
	if dataLen <= avail {
		space = alignUp(dataOffset+alignUp(dataLen, _WSA_CMSGDATA_ALIGN), _WSA_CMSGHDR_ALIGN)
	}
	if dataLen > avail || space > avail {
		return fmt.Errorf("not enough space to append cmsg (Level=%d Type=%d Len=%d)", level, ctype, len(data))
	}

	msg := unsafe.Slice((*byte)(unsafe.Add(unsafe.Pointer(cmsgs), offset)), space)
	clear(msg) // zero header, data area and padding before filling them in

	hdr := (*_WSACMSGHDR)(unsafe.Pointer(unsafe.SliceData(msg)))
	hdr.len = dataOffset + dataLen // excludes trailing padding
	hdr.level = level
	hdr.typ = ctype
	copy(msg[dataOffset:], data)

	cmsgs.totalLength = uint32(offset + space)
	return nil
}
// AppendUInt32 appends a control message whose data is the four bytes of
// val in the machine's native byte order. It fails under the same
// conditions as Append.
func (cmsgs *controlMessages) AppendUInt32(level, ctype int32, val uint32) error {
	raw := unsafe.Slice((*byte)(unsafe.Pointer(&val)), unsafe.Sizeof(val))
	return cmsgs.Append(level, ctype, raw)
}
// Clear discards all control messages by resetting the recorded total
// length to the size of the base header. The message bytes themselves are
// left in place.
func (cmsgs *controlMessages) Clear() {
	const emptyLength = uint32(_RIO_CMSG_BASE_SIZE)
	cmsgs.totalLength = emptyLength
}
// Bytes returns the used portion of the buffer: the first totalLength bytes,
// capped at the size of the buffer. The slice aliases the buffer's memory.
func (cmsgs *controlMessages) Bytes() []byte {
	n := uintptr(cmsgs.totalLength)
	if limit := unsafe.Sizeof(*cmsgs); n > limit {
		n = limit
	}
	return unsafe.Slice((*byte)(unsafe.Pointer(cmsgs)), n)
}
// Buffer returns a slice covering the whole buffer, header and unused space
// included. The slice aliases the buffer's memory.
func (cmsgs *controlMessages) Buffer() []byte {
	return (*[unsafe.Sizeof(controlMessages{})]byte)(unsafe.Pointer(cmsgs))[:]
}

428
net/rioconn/cmsg_test.go Normal file
View File

@ -0,0 +1,428 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package rioconn
import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"testing"
"unsafe"
)
// TestControlMessagesSize checks that the controlMessages struct occupies
// exactly controlMessagesSize bytes.
func TestControlMessagesSize(t *testing.T) {
	t.Parallel()
	gotSize := unsafe.Sizeof(controlMessages{})
	if gotSize == controlMessagesSize {
		return
	}
	t.Errorf("got: %d; want %d", gotSize, controlMessagesSize)
}
// TestControlMessageAppend checks the exact bytes produced by appending
// control messages to a zero-valued or cleared buffer, on both 32-bit and
// 64-bit platforms.
func TestControlMessageAppend(t *testing.T) {
	t.Parallel()

	// The "single" and "single/after-clear" cases append the same message
	// and must yield the same bytes, so the expectation is shared.
	singleMessage := []controlMessage{
		{Level: 15, Type: 42, Data: []byte{0xAA, 0xBB, 0xCC}},
	}
	singleWant := chooseForArch(
		[]byte{
			0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
			0x0F, 0x00, 0x00, 0x00, // cmsg Len = 15 (excluding padding)
			0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
			0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
			0xAA, 0xBB, 0xCC, // cmsg Data
			0x00, // padding to align to _WSA_CMSGHDR_ALIGN
		},
		[]byte{
			0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
			0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
			0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 19 (excluding padding)
			0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
			0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
			0xAA, 0xBB, 0xCC, // cmsg Data
			0x00, 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
		},
	)

	type appendCase struct {
		name      string
		clear     bool // whether to call Clear() before appending
		messages  []controlMessage
		wantBytes []byte
	}
	cases := []appendCase{
		{
			name:      "zero",
			wantBytes: []byte{},
		},
		{
			name:  "clear",
			clear: true,
			wantBytes: chooseForArch(
				[]byte{
					0x04, 0x00, 0x00, 0x00, // TotalLength = 4 (_RIO_CMSG_BASE_SIZE)
					// no padding needed on 32-bit platforms
				},
				[]byte{
					0x08, 0x00, 0x00, 0x00, // TotalLength = 8 (_RIO_CMSG_BASE_SIZE)
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
				},
			),
		},
		{
			name:      "single",
			messages:  singleMessage,
			wantBytes: singleWant,
		},
		{
			name:      "single/after-clear",
			clear:     true,
			messages:  singleMessage,
			wantBytes: singleWant,
		},
		{
			name: "multiple",
			messages: []controlMessage{
				{Level: 1, Type: 2, Data: []byte{
					0xAA, 0xBB, 0xCC, 0xDD,
				}},
				{Level: 3, Type: 4, Data: []byte{
					0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
					0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
				}},
			},
			wantBytes: chooseForArch(
				[]byte{
					0x30, 0x00, 0x00, 0x00, // TotalLength = 48 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
					// cmsg 1
					0x10, 0x00, 0x00, 0x00, // cmsg Len = 16
					0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
					0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
					0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
					// cmsg 2
					0x1C, 0x00, 0x00, 0x00, // cmsg Len = 28
					0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
					0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
					0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data
					0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
				},
				[]byte{
					0x40, 0x00, 0x00, 0x00, // TotalLength = 64 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
					// cmsg 1
					0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding)
					0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
					0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
					0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
					// cmsg 2
					0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 32
					0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
					0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
					0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data
					0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
				},
			),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var buf controlMessages
			if tc.clear {
				buf.Clear()
			}
			for _, msg := range tc.messages {
				err := buf.Append(msg.Level, msg.Type, msg.Data)
				if err != nil {
					t.Fatalf("Append failed: %v", err)
				}
			}
			gotBytes := buf.Bytes()
			if bytes.Equal(gotBytes, tc.wantBytes) {
				return
			}
			t.Fatalf("buffer bytes:\ngot\n%v\nwant\n%v",
				hex.Dump(gotBytes), hex.Dump(tc.wantBytes),
			)
		})
	}
}
// TestControlMessageAppendNotEnoughSpace checks that appending data larger
// than the buffer fails and does not grow the recorded total length.
func TestControlMessageAppendNotEnoughSpace(t *testing.T) {
	t.Parallel()
	var cmsgs controlMessages
	oversized := bytes.Repeat([]byte{0xCC}, 1024)
	err := cmsgs.Append(1, 2, oversized)
	if err == nil {
		t.Errorf("Append succeeded unexpectedly")
	}
	if limit := uint32(_RIO_CMSG_BASE_SIZE); cmsgs.totalLength > limit {
		t.Errorf("Unexpected TotalLength after failed Append: %d", cmsgs.totalLength)
	}
}
// TestControlMessageIterator loads raw bytes into a control messages buffer
// and checks that All yields the expected messages, on both 32-bit and
// 64-bit platforms.
func TestControlMessageIterator(t *testing.T) {
	t.Parallel()

	type iterCase struct {
		name         string
		bytes        []byte
		wantMessages []controlMessage
	}
	cases := []iterCase{
		{
			name:         "zero",
			bytes:        []byte{},
			wantMessages: []controlMessage{},
		},
		{
			name: "empty",
			bytes: chooseForArch(
				[]byte{
					0x04, 0x00, 0x00, 0x00, // TotalLength = 4 (_RIO_CMSG_BASE_SIZE)
				},
				[]byte{
					0x08, 0x00, 0x00, 0x00, // TotalLength = 8 (_RIO_CMSG_BASE_SIZE)
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
				},
			),
			wantMessages: []controlMessage{},
		},
		{
			name: "single",
			bytes: chooseForArch(
				[]byte{
					0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
					0x0F, 0x00, 0x00, 0x00, // cmsg Len = 15 (excluding padding)
					0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
					0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
					0xAA, 0xBB, 0xCC, // cmsg Data
					0x00, // padding to align to _WSA_CMSGHDR_ALIGN
				},
				[]byte{
					0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
					0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 19 (excluding padding)
					0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
					0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
					0xAA, 0xBB, 0xCC, // cmsg Data
					0x00, 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
				},
			),
			wantMessages: []controlMessage{
				{Level: 15, Type: 42, Data: []byte{
					0xAA, 0xBB, 0xCC,
				}},
			},
		},
		{
			name: "multiple",
			bytes: chooseForArch(
				[]byte{
					0x30, 0x00, 0x00, 0x00, // TotalLength = 48 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
					// cmsg 1
					0x10, 0x00, 0x00, 0x00, // cmsg Len = 16
					0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
					0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
					0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
					// cmsg 2
					0x1C, 0x00, 0x00, 0x00, // cmsg Len = 28
					0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
					0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
					0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data
					0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
				},
				[]byte{
					0x40, 0x00, 0x00, 0x00, // TotalLength = 64 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
					// cmsg 1
					0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding)
					0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
					0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
					0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
					0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
					// cmsg 2
					0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 32
					0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
					0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
					0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data
					0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
				},
			),
			wantMessages: []controlMessage{
				{Level: 1, Type: 2, Data: []byte{
					0xAA, 0xBB, 0xCC, 0xDD,
				}},
				{Level: 3, Type: 4, Data: []byte{
					0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
					0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
				}},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var buf controlMessages
			copy(buf.Buffer(), tc.bytes)
			gotMessages := slices.Collect(buf.All())
			if gotN, wantN := len(gotMessages), len(tc.wantMessages); gotN != wantN {
				t.Fatalf("number of messages: got %d; want %d", gotN, wantN)
			}
			for i, got := range gotMessages {
				want := tc.wantMessages[i]
				same := got.Level == want.Level &&
					got.Type == want.Type &&
					bytes.Equal(got.Data, want.Data)
				if !same {
					t.Errorf("message %d:\ngot %v\nwant %v", i, got, want)
				}
			}
		})
	}
}
// TestControlMessageUInt32 appends a uint32 control message, checks the
// resulting buffer bytes, and reads the value back with GetUInt32Ok.
func TestControlMessageUInt32(t *testing.T) {
	t.Parallel()
	const (
		msgLvl  = 1
		msgType = 2
		msgVal  = uint32(0xAABBCCDD)
	)

	var cmsgs controlMessages
	err := cmsgs.AppendUInt32(msgLvl, msgType, msgVal)
	if err != nil {
		t.Fatalf("AppendUInt32 failed: %v", err)
	}

	want32 := []byte{
		0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
		0x10, 0x00, 0x00, 0x00, // cmsg Len = 16
		0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
		0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
		0xDD, 0xCC, 0xBB, 0xAA, // cmsg Data
	}
	want64 := []byte{
		0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
		0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
		0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding)
		0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
		0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
		0xDD, 0xCC, 0xBB, 0xAA, // cmsg Data
		0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGN
	}
	wantBytes := chooseForArch(want32, want64)

	gotBytes := cmsgs.Bytes()
	if !bytes.Equal(gotBytes, wantBytes) {
		t.Fatalf("buffer bytes:\ngot\n%v\nwant\n%v",
			hex.Dump(gotBytes), hex.Dump(wantBytes),
		)
	}

	gotVal, found := cmsgs.GetUInt32Ok(msgLvl, msgType)
	switch {
	case !found:
		t.Fatal("GetUInt32Ok failed to find value")
	case gotVal != msgVal:
		t.Fatalf("uint32: got 0x%X; want 0x%X", gotVal, msgVal)
	}
}
// TestControlMessageGetUInt32 checks that GetUInt32Ok only reports a value
// for a level/type pair that is actually present in the buffer.
func TestControlMessageGetUInt32(t *testing.T) {
	t.Parallel()
	var cmsgs controlMessages
	// Nothing has been appended yet, so no lookup may succeed.
	if _, ok := cmsgs.GetUInt32Ok(1, 2); ok {
		t.Fatal("GetUInt32Ok found unexpected value")
	}
	// A failed append would make the lookups below fail for a misleading
	// reason, so stop right here if it happens.
	if err := cmsgs.AppendUInt32(3, 4, 0xDEADBEEF); err != nil {
		t.Fatalf("AppendUInt32 failed: %v", err)
	}
	// A message with a different level/type must still not be found.
	if _, ok := cmsgs.GetUInt32Ok(1, 2); ok {
		t.Fatal("GetUInt32Ok found unexpected value")
	}
	if gotVal, ok := cmsgs.GetUInt32Ok(3, 4); !ok {
		t.Fatal("GetUInt32Ok failed to find value")
	} else if gotVal != 0xDEADBEEF {
		t.Fatalf("GetUInt32Ok: got 0x%X; want 0xDEADBEEF", gotVal)
	}
}
// TestControlMessageNoAlloc checks that neither appending nor looking up a
// uint32 control message allocates.
//
// It is intentionally not parallel: testing.AllocsPerRun measures
// process-wide allocations.
func TestControlMessageNoAlloc(t *testing.T) {
	const (
		msgLvl  = 1
		msgType = 2
		msgVal  = uint32(0xAABBCCDD)
	)
	var cmsgs controlMessages
	allocs := testing.AllocsPerRun(1000, func() {
		cmsgs.Clear()
		// The success path must not allocate; a failure is reported
		// instead of being silently ignored.
		if err := cmsgs.AppendUInt32(msgLvl, msgType, msgVal); err != nil {
			t.Fatalf("AppendUInt32 failed: %v", err)
		}
	})
	if allocs != 0 {
		t.Fatalf("AppendUInt32 allocated %f times; want 0", allocs)
	}
	allocs = testing.AllocsPerRun(1000, func() {
		gotVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType)
		if !ok {
			t.Fatal("GetUInt32Ok failed to find value")
		}
		if gotVal != msgVal {
			t.Fatalf("GetUInt32Ok: got 0x%X; want 0x%X", gotVal, msgVal)
		}
	})
	if allocs != 0 {
		t.Fatalf("GetUInt32Ok allocated %f times; want 0", allocs)
	}
}
// BenchmarkControlMessagesAppendUInt32 measures clearing the buffer and
// appending a single uint32 control message, reporting allocations.
func BenchmarkControlMessagesAppendUInt32(b *testing.B) {
	var cmsgs controlMessages
	b.ReportAllocs()
	for b.Loop() {
		// Clear first so that every iteration appends into an empty
		// buffer instead of eventually running out of space.
		cmsgs.Clear()
		cmsgs.AppendUInt32(1, 2, uint32(3))
	}
}
// BenchmarkControlMessagesGetUInt32Ok measures looking up a uint32 control
// message in a buffer holding exactly that message, reporting allocations.
func BenchmarkControlMessagesGetUInt32Ok(b *testing.B) {
	const (
		msgLvl  = 1
		msgType = 2
		msgVal  = uint32(0xAABBCCDD)
	)
	var cmsgs controlMessages
	// Fail fast if the setup did not work, rather than reporting a
	// confusing lookup failure from inside the measured loop.
	if err := cmsgs.AppendUInt32(msgLvl, msgType, msgVal); err != nil {
		b.Fatalf("AppendUInt32 failed: %v", err)
	}
	b.ReportAllocs()
	for b.Loop() {
		gotMsgVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType)
		if !ok {
			b.Fatal("GetUInt32Ok failed to find value")
		}
		if gotMsgVal != msgVal {
			b.Fatalf("GetUInt32Ok: got 0x%X; want 0x%X", gotMsgVal, msgVal)
		}
	}
}
// chooseForArch returns val32 when pointers are four bytes wide and val64
// otherwise.
func chooseForArch[T any](val32, val64 T) T {
	const pointerIs32Bit = unsafe.Sizeof(uintptr(0)) == 4
	if !pointerIs32Bit {
		return val64
	}
	return val32
}
// String formats the control message as a line with its level and type,
// followed by a hex dump of its data.
func (cmsg controlMessage) String() string {
	dump := hex.Dump(cmsg.Data)
	return fmt.Sprintf("cmsg{Level=%d Type=%d}\n%s", cmsg.Level, cmsg.Type, dump)
}

View File

@ -15,6 +15,11 @@ const (
defaultRXMemoryLimit = 2 << 20 // 2 MiB
defaultTXMemoryLimit = 2 << 20 // 2 MiB
// defaultMaxUSOOffloadSize is the default maximum offload size for USO.
// It should be smaller than or equal to the typical hardware offload size
// to avoid software fallback. Mellanox NICs typically support up to 64000.
defaultMaxUSOOffloadSize = 64000
)
// Config holds configuration for a RIO connection, independent of the transport protocol.
@ -89,4 +94,30 @@ func (o TxConfig) MaxPayloadLen() uint16 {
// UDPConfig holds configuration for a [UDPConn].
// It embeds the transport-independent [Config] and adds UDP-specific settings.
type UDPConfig struct {
	Config
	uso USOConfig // UDP segmentation offload settings; exposed via [UDPConfig.USO]
}
// USO returns the UDP segmentation offload (USO) configuration.
//
// Because the receiver is a value, the returned pointer refers to the uso
// field of a copy of o rather than to the caller's UDPConfig: changes made
// through it are not reflected in the original configuration.
func (o UDPConfig) USO() *USOConfig {
	return &o.uso
}
// USOConfig holds the UDP segmentation offload (USO) configuration.
type USOConfig struct {
	enabled        bool   // whether USO is enabled; see [USOConfig.Enabled]
	maxOffloadSize uint16 // 0 means default (i.e., [defaultMaxUSOOffloadSize])
}
// Enabled reports whether USO is enabled.
// The zero value of USOConfig reports false.
func (o USOConfig) Enabled() bool {
	return o.enabled
}
// MaxOffloadSize returns the maximum number of bytes that can be batched
// into a single send for UDP segmentation offload. It is zero when USO is
// disabled, and defaultMaxUSOOffloadSize when USO is enabled but no explicit
// size has been configured.
func (o USOConfig) MaxOffloadSize() uint16 {
	switch {
	case !o.enabled:
		return 0
	case o.maxOffloadSize != 0:
		return o.maxOffloadSize
	default:
		return defaultMaxUSOOffloadSize
	}
}

View File

@ -5,7 +5,7 @@
// Package rioconn provides [UDPConn], a UDP socket implementation
// that uses the Windows RIO API extensions and supports batched I/O
// for improved performance on high-throughput UDP workloads.
// and USO for improved performance on high-throughput UDP workloads.
package rioconn
import (

84
net/rioconn/offloads.go Normal file
View File

@ -0,0 +1,84 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rioconn
import (
"fmt"
"io"
"tailscale.com/net/packet"
)
// coalescePackets copies packets from buffs into dst, back to back, to form
// a single coalesced batch.
//
// Only the bytes of each buffer starting at offset are copied. If
// geneve.VNI.IsSet(), every copied packet is preceded by the encoded Geneve
// header.
//
// The first packet (header included) defines the packet size of the batch.
// Copying stops at the first packet that:
//   - would not fit in the remaining space of dst,
//   - is larger than the packet size,
//   - would make the batch exceed maxCoalescedBytes (the first packet is
//     exempt from this limit),
//   - would make the batch exceed maxCoalescedPackets, or
//   - has zero length (header included).
//
// A packet smaller than the packet size is still copied, but ends the batch.
// A zero-length first packet is counted as consumed without copying anything.
// A zero maxCoalescedBytes or maxCoalescedPackets means no limit.
//
// It returns the number of packets consumed, the number of bytes written to
// dst, and the packet size of the batch. It returns an error if encoding the
// Geneve header fails, or an error wrapping io.ErrShortBuffer if the first
// packet does not fit in dst.
func coalescePackets(
	dst []byte, geneve packet.GeneveHeader, buffs [][]byte,
	offset, maxCoalescedPackets, maxCoalescedBytes int,
) (packets, bytes, packetSize int, err error) {
	// Encode the Geneve header once; it is identical for every packet.
	var prefix []byte
	if geneve.VNI.IsSet() {
		var encoded [packet.GeneveFixedHeaderLength]byte
		if encErr := geneve.Encode(encoded[:]); encErr != nil {
			return 0, 0, 0, encErr
		}
		prefix = encoded[:]
	}

	if len(buffs) > 0 {
		// The first packet fixes the packet size of the whole batch.
		// If it does not fit in dst, nothing can be coalesced.
		packetSize = len(prefix) + len(buffs[0]) - offset
		if room := len(dst); packetSize > room {
			return 0, 0, 0, fmt.Errorf("%w: packet size %d exceeds dst size %d",
				io.ErrShortBuffer, packetSize, room,
			)
		}
	}

batch:
	for _, raw := range buffs {
		payload := raw[offset:]
		packetLen := len(prefix) + len(payload)
		end := bytes + packetLen
		switch {
		case end > len(dst), // no more space
			packetLen > packetSize, // packet is too large for this batch
			bytes != 0 && maxCoalescedBytes != 0 && end > maxCoalescedBytes, // byte limit reached
			maxCoalescedPackets != 0 && packets >= maxCoalescedPackets: // packet limit reached
			break batch
		case packetLen == 0:
			// A zero-length packet is consumed when it comes first,
			// but is never coalesced with other packets.
			if packets == 0 {
				packets = 1
			}
			break batch
		}
		n := copy(dst[bytes:], prefix)
		copy(dst[bytes+n:], payload)
		packets++
		bytes = end
		if packetLen < packetSize {
			// A short packet may only appear at the tail of a batch.
			break
		}
	}
	return packets, bytes, packetSize, nil
}

View File

@ -0,0 +1,334 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rioconn
import (
"bytes"
"testing"
"tailscale.com/net/packet"
)
// TestCoalescePackets exercises coalescePackets with and without a Geneve
// header, with offsets, size mismatches, destination and batch limits, and
// error conditions. It also checks that successful calls do not allocate.
func TestCoalescePackets(t *testing.T) {
	geneve := packet.GeneveHeader{
		Protocol: packet.GeneveProtocolWireGuard,
	}
	geneve.VNI.Set(7)

	// coalescePackets only encodes the Geneve header when the VNI is set,
	// so the VNI must be set for the invalid version to reach the encoder.
	invalidGeneve := packet.GeneveHeader{
		Version: 0xFF,
	}
	invalidGeneve.VNI.Set(7)

	tests := []struct {
		name                string
		buffs               [][]byte
		offset              int
		geneve              packet.GeneveHeader
		dst                 []byte
		maxCoalescedPackets int
		maxCoalescedBytes   int
		wantBytes           []byte
		wantPackets         int
		wantPacketSize      int
		wantErr             bool
	}{
		{
			name:           "no-packets",
			buffs:          nil,
			geneve:         packet.GeneveHeader{},
			dst:            make([]byte, 100),
			wantBytes:      nil,
			wantPackets:    0,
			wantPacketSize: 0,
		},
		{
			name: "single-packet",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
			},
			geneve: packet.GeneveHeader{},
			dst:    make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
			},
			wantPackets:    1,
			wantPacketSize: 3,
		},
		{
			name: "single-packet/zero-length",
			buffs: [][]byte{
				{},
			},
			geneve:         packet.GeneveHeader{},
			dst:            make([]byte, 100),
			wantBytes:      []byte{},
			wantPackets:    1,
			wantPacketSize: 0,
		},
		{
			name: "single-packet/with-offset",
			buffs: [][]byte{
				{
					0x00, 0x00,
					0x01, 0x02, 0x03,
				},
			},
			offset: 2,
			dst:    make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
			},
			wantPackets:    1,
			wantPacketSize: 3,
		},
		{
			name: "single-packet/with-geneve",
			buffs: [][]byte{
				{
					0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Geneve header space
					0x01, 0x02, 0x03,
				},
			},
			offset: 8,
			geneve: geneve,
			dst:    make([]byte, 100),
			wantBytes: []byte{
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header
				0x01, 0x02, 0x03,
			},
			wantPackets:    1,
			wantPacketSize: 11, // 8 bytes Geneve header + 3 bytes packet
		},
		{
			name: "single-packet/exact-fit",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
			},
			dst:            make([]byte, 3),
			wantBytes:      []byte{0x01, 0x02, 0x03},
			wantPackets:    1,
			wantPacketSize: 3,
		},
		{
			name: "single-packet/too-large-for-dst",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
			},
			dst:     make([]byte, 2),
			wantErr: true,
		},
		{
			name: "single-packet/with-geneve/too-large-for-dst",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
			},
			geneve:  geneve,
			dst:     make([]byte, 10), // smaller than 8 bytes Geneve header + 3 bytes packet
			wantErr: true,
		},
		{
			name: "single-packet/exceeds-max-bytes",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
			},
			dst:               make([]byte, 100),
			maxCoalescedBytes: 2, // the first packet is exempt from the byte limit
			wantBytes: []byte{
				0x01, 0x02, 0x03,
			},
			wantPackets:    1,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/coalesce-all",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09},
			},
			dst: make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
				0x07, 0x08, 0x09,
			},
			wantPackets:    3,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/coalesce-all/with-offset",
			buffs: [][]byte{
				{0x00, 0x00, 0x01, 0x02, 0x03},
				{0x00, 0x00, 0x04, 0x05, 0x06},
				{0x00, 0x00, 0x07, 0x08, 0x09},
			},
			offset: 2,
			dst:    make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
				0x07, 0x08, 0x09,
			},
			wantPackets:    3,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/smaller-packet-ends-batch",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08},       // will be coalesced, but ends the batch
				{0x09, 0x0a, 0x0b}, // will not be coalesced in this batch
			},
			dst: make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
				0x07, 0x08,
			},
			wantPackets:    3,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/larger-packet-ends-batch",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09, 0x0a}, // will not be coalesced in this batch
			},
			dst: make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
			},
			wantPackets:    2,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/partial-fit",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09}, // could be coalesced, but won't fit
			},
			dst: make([]byte, 7), // can only fit the first two packets
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
			},
			wantPackets:    2,
			wantPacketSize: 3,
			wantErr:        false, // partial fit is not an error
		},
		{
			name: "multiple-packets/exact-fit",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09},
			},
			dst: make([]byte, 9), // can exactly fit all three packets
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
				0x07, 0x08, 0x09,
			},
			wantPackets:    3,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/max-packets",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09}, // would exceed the packet limit
			},
			dst:                 make([]byte, 100),
			maxCoalescedPackets: 2,
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
			},
			wantPackets:    2,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/max-bytes",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09}, // would exceed the byte limit
			},
			dst:               make([]byte, 100),
			maxCoalescedBytes: 7,
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
			},
			wantPackets:    2,
			wantPacketSize: 3,
		},
		{
			name: "multiple-packets/with-geneve",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{0x07, 0x08, 0x09},
				{0x0a, 0x0b}, // ends the batch
				{0x0c, 0x0d, 0x0e},
			},
			geneve: geneve,
			dst:    make([]byte, 100),
			wantBytes: []byte{
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for first packet
				0x01, 0x02, 0x03,
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for second packet
				0x04, 0x05, 0x06,
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for third packet
				0x07, 0x08, 0x09,
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for fourth packet
				0x0a, 0x0b,
			},
			wantPackets:    4,
			wantPacketSize: 11, // 8 bytes Geneve header + 3 bytes packet
			wantErr:        false,
		},
		{
			name: "multiple-packets/all-zero-length",
			buffs: [][]byte{
				{},
				{},
				{},
			},
			dst:            make([]byte, 100),
			wantBytes:      []byte{},
			wantPackets:    1, // zero-length packets cannot be coalesced
			wantPacketSize: 0,
		},
		{
			name: "multiple-packets/all-zero-length/with-geneve",
			buffs: [][]byte{
				{},
				{},
				{},
			},
			geneve: geneve,
			dst:    make([]byte, 100),
			wantBytes: []byte{
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for first packet
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for second packet
				0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for third packet
			},
			wantPackets:    3,
			wantPacketSize: 8, // Geneve header size, since the packets are zero-length
		},
		{
			name: "multiple-packets/zero-length-packet-ends-batch",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
				{0x04, 0x05, 0x06},
				{}, // cannot be coalesced, ends the batch
			},
			dst: make([]byte, 100),
			wantBytes: []byte{
				0x01, 0x02, 0x03,
				0x04, 0x05, 0x06,
			},
			wantPackets:    2,
			wantPacketSize: 3,
		},
		{
			name: "invalid-geneve-header",
			buffs: [][]byte{
				{0x01, 0x02, 0x03},
			},
			geneve: invalidGeneve,
			// dst is large enough for the packet, so the only possible
			// failure is the Geneve header encoding.
			dst:     make([]byte, 100),
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			allocs := testing.AllocsPerRun(1000, func() {
				gotPackets, gotBytes, gotPacketSize, err := coalescePackets(
					tt.dst, tt.geneve, tt.buffs, tt.offset,
					tt.maxCoalescedPackets,
					tt.maxCoalescedBytes,
				)
				if (err != nil) != tt.wantErr {
					t.Fatalf("error: got %v; want error: %v", err, tt.wantErr)
				}
				if gotPackets != tt.wantPackets {
					t.Errorf("packets: got %d; want %d", gotPackets, tt.wantPackets)
				}
				if gotBytes := tt.dst[:gotBytes]; !bytes.Equal(gotBytes, tt.wantBytes) {
					t.Errorf("bytes: got %v; want %v", gotBytes, tt.wantBytes)
				}
				if gotPacketSize != tt.wantPacketSize {
					t.Errorf("packet size: got %d; want %d", gotPacketSize, tt.wantPacketSize)
				}
			})
			// Coalescing packets should not cause any allocations,
			// except for when it returns an error.
			if !tt.wantErr && allocs != 0 {
				t.Errorf("unexpected allocations: got %f; want 0", allocs)
			}
		})
	}
}

View File

@ -76,3 +76,10 @@ type udpOption func(*UDPConfig)
func (o udpOption) applyUDP(opts *UDPConfig) {
o(opts)
}
// USO returns a UDP option that enables or disables UDP segmentation
// offload (USO) on the configuration it is applied to.
func USO(enabled bool) UDPOption {
	var opt udpOption = func(cfg *UDPConfig) {
		cfg.uso.enabled = enabled
	}
	return opt
}

View File

@ -27,8 +27,9 @@ type request struct {
buffID winrio.BufferId // ID of the registered buffer containing this request
buffBase uintptr // base address of the registered buffer
raddr rawSockaddr // remote address for RIO send/receive operations
data []byte // a slice pointing into the data buffer area after the struct
raddr rawSockaddr // remote address for RIO send/receive operations
data []byte // a slice pointing into the data buffer area after the struct
control controlMessages // control messages buffer
// followed by the actual data at [requestDataOffset]
// from the start of the struct.
}
@ -94,8 +95,22 @@ func (r *request) PostSend(rq winrio.Rq, flags uint32) error {
Length: uint32(len(r.data)),
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
}
remoteAddr := r.remoteAddrDesc()
return winrio.SendEx(rq, &data, 1, nil, &remoteAddr, nil, nil, flags, uintptr(unsafe.Pointer(r)))
remoteAddress := winrio.Buffer{
Id: r.buffID,
Length: uint32(unsafe.Sizeof(r.raddr)),
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
}
var control *winrio.Buffer
if !r.control.Empty() {
cmsgs := r.control.Buffer()
control = &winrio.Buffer{
Id: r.buffID,
Length: uint32(len(cmsgs)),
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(cmsgs))) - r.buffBase),
}
}
return winrio.SendEx(rq, &data, 1, nil, &remoteAddress, control,
nil, flags, uintptr(unsafe.Pointer(r)))
}
// PostReceive posts the request as a receive operation to the given RIO request queue
@ -107,8 +122,18 @@ func (r *request) PostReceive(rq winrio.Rq, flags uint32) error {
Length: uint32(cap(r.data)),
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
}
remoteAddress := r.remoteAddrDesc()
return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, nil,
remoteAddress := winrio.Buffer{
Id: r.buffID,
Length: uint32(unsafe.Sizeof(r.raddr)),
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
}
controlBuffer := r.control.Buffer()
control := &winrio.Buffer{
Id: r.buffID,
Length: uint32(len(controlBuffer)),
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(controlBuffer))) - r.buffBase),
}
return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, control,
nil, flags, uintptr(unsafe.Pointer(r)))
}
@ -119,6 +144,8 @@ func (r *request) PostReceive(rq winrio.Rq, flags uint32) error {
// of bytes written does not match the length of the request's data buffer.
func (r *request) CompleteSend(status int32, bytesWritten uint32) error {
expected := len(r.data)
r.data = r.data[:0]
r.control.Clear()
if status != 0 {
return windows.Errno(status)
}
@ -154,18 +181,11 @@ func (r *request) CompleteReceive(status int32, bytesRead uint32) (*requestReade
return r.Reader(), nil
}
// remoteAddrDesc returns a RIO buffer descriptor that refers to the
// request's remote address field (r.raddr). The descriptor carries the
// request's buffer ID, spans exactly the size of r.raddr, and locates the
// field by its address relative to r.buffBase.
func (r *request) remoteAddrDesc() winrio.Buffer {
	// Address of the raddr field expressed relative to r.buffBase.
	addrOffset := uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase

	var desc winrio.Buffer
	desc.Id = r.buffID
	desc.Length = uint32(unsafe.Sizeof(r.raddr))
	desc.Offset = uint32(addrOffset)
	return desc
}
// Reset returns the request to a clean state so that it can be reused:
// the remote address is zeroed, the data slice is truncated to zero length
// (its backing storage is kept), and the control messages are cleared
// via Clear.
func (r *request) Reset() {
	var noAddr rawSockaddr
	r.raddr = noAddr
	r.data = r.data[0:0]
	r.control.Clear()
}
type (
@ -201,6 +221,11 @@ func (w *requestWriter) SetRemoteAddr(raddr rawSockaddr) {
w.raddr = raddr
}
// ControlMessages exposes the writer's control messages buffer by
// returning its address, so callers operate on the buffer in place
// rather than on a copy.
func (w *requestWriter) ControlMessages() *controlMessages {
	cmsgs := &w.control
	return cmsgs
}
// Reserve reserves n bytes in the request's data buffer,
// and returns a slice pointing to the reserved space.
// It panics if n is negative or exceeds the available capacity.
@ -262,3 +287,8 @@ func (r *requestReader) RemoteAddrPort() (netip.AddrPort, error) {
func (r *requestReader) RemoteAddr() rawSockaddr {
	// Hand back a copy of the stored raw socket address.
	addr := r.raddr
	return addr
}
// ControlMessages exposes the reader's control messages buffer by
// returning its address, so callers inspect the buffer in place
// rather than a copy of it.
func (r *requestReader) ControlMessages() *controlMessages {
	cmsgs := &r.control
	return cmsgs
}

View File

@ -124,36 +124,65 @@ func TestUDPSendReceiveBatch(t *testing.T) {
iterations int
sendBatchSize int
receiveBatchSize int
uso bool
}{
{
name: "udp4/single",
network: "udp4",
pattern: []int{1312},
},
{
name: "udp4/batch",
network: "udp4",
pattern: []int{1312},
iterations: 1024,
},
{
name: "udp4/single/max",
network: "udp4",
pattern: []int{rioconn.MaxUDPPayloadIPv4},
},
{
name: "udp4/batch/max",
network: "udp4",
pattern: []int{rioconn.MaxUDPPayloadIPv4},
name: "udp4/batch/max",
network: "udp4",
pattern: []int{rioconn.MaxUDPPayloadIPv4},
iterations: 1024,
},
{
name: "udp6/single",
network: "udp6",
pattern: []int{1312},
},
{
name: "udp6/batch",
network: "udp6",
pattern: []int{1312},
iterations: 1024,
},
{
name: "udp6/single/max",
network: "udp6",
pattern: []int{rioconn.MaxUDPPayloadIPv6},
},
{
name: "udp6/batch/max",
network: "udp6",
pattern: []int{rioconn.MaxUDPPayloadIPv6},
name: "udp6/batch/max",
network: "udp6",
pattern: []int{rioconn.MaxUDPPayloadIPv6},
iterations: 10,
},
{
name: "udp6/batch/uso",
network: "udp6",
pattern: []int{1312},
iterations: 10,
uso: true,
},
{
name: "udp6/batch/max/uso",
network: "udp6",
pattern: []int{rioconn.MaxUDPPayloadIPv6},
iterations: 10,
uso: true,
},
}
for _, tt := range tests {
@ -164,7 +193,9 @@ func TestUDPSendReceiveBatch(t *testing.T) {
cmp.Or(tt.sendBatchSize, defaultBatchSize),
cmp.Or(tt.receiveBatchSize, defaultBatchSize),
tt.network, tt.network,
nil, nil,
[]rioconn.UDPOption{
rioconn.USO(tt.uso),
}, nil,
)
})
}
@ -174,16 +205,19 @@ func FuzzUDPSendReceiveBatch(f *testing.F) {
batchSizes := []uint16{1, 64}
packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4}
numIterations := []uint16{1024}
uso := []bool{false, true}
for _, packetLen := range packetSizes {
for _, numIter := range numIterations {
for _, batchSize := range batchSizes {
f.Add(packetLen, numIter, batchSize, batchSize)
for _, usoEnabled := range uso {
f.Add(packetLen, numIter, batchSize, batchSize, usoEnabled)
}
}
}
}
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16) {
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16, usoEnabled bool) {
network := "udp4"
maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4)
@ -206,6 +240,7 @@ func FuzzUDPSendReceiveBatch(f *testing.F) {
[]rioconn.UDPOption{
rioconn.RxMemoryLimit(128 << 10),
rioconn.TxMemoryLimit(512 << 10),
rioconn.USO(usoEnabled),
},
[]rioconn.UDPOption{
rioconn.RxMemoryLimit(512 << 10),

View File

@ -8,6 +8,7 @@ package rioconn
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"unsafe"
@ -25,14 +26,26 @@ import (
// otherwise specified by the method.
type udpTx struct {
	udpNx

	// maxCoalescedPackets is the maximum number of packets that may be
	// coalesced together, or 0 if there is no limit. init sets it to 1
	// when USO is not enabled; writeBatchTo passes it to coalescePackets.
	maxCoalescedPackets int

	// maxCoalescedBytes is the maximum total length, in bytes, of the
	// coalesced packets, or 0 if there is no limit. init sets it to the
	// USO maximum offload size when USO is enabled; writeBatchTo passes
	// it to coalescePackets.
	maxCoalescedBytes int
}
// init initializes the transmit half of a [UDPConn] with the
// specified underlying connection and options.
func (tx *udpTx) init(conn *conn, options UDPConfig) error {
// Without USO, the data buffer for each send request only needs to hold
// a single packet's payload.
dataSize := min(options.Tx().MaxPayloadLen(), MaxUDPPayload)
var dataSize uint16
if uso := options.USO(); uso.Enabled() {
// When USO is enabled, the request's data buffer may need to
// hold multiple coalesced packets up to the maximum offload size,
// or a single packet up to the maximum payload size, whichever is larger.
dataSize = max(options.Tx().MaxPayloadLen(), uso.MaxOffloadSize())
tx.maxCoalescedBytes = int(uso.MaxOffloadSize())
} else {
// Otherwise, the data buffer for each send request only needs to hold
// a single packet's payload.
dataSize = min(options.Tx().MaxPayloadLen(), MaxUDPPayload)
tx.maxCoalescedPackets = 1
}
if err := tx.udpNx.init(conn, dataSize, options.Tx().MemoryLimit()); err != nil {
return fmt.Errorf("failed to initialize udpTx: %w", err)
}
@ -99,20 +112,31 @@ func (tx *udpTx) writeBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet
w := req.Writer()
w.SetRemoteAddr(raddr)
if geneve.VNI.IsSet() {
geneveHeader := w.Reserve(packet.GeneveFixedHeaderLength)
geneve.Encode(geneveHeader[:])
}
if _, err := w.Write(buffs[n][offset:]); err != nil {
buf := w.Reserve(w.Cap()) // reserve the entire buffer
packets, bytes, packetSize, err := coalescePackets(
buf, geneve, buffs[n:], offset,
tx.maxCoalescedPackets,
tx.maxCoalescedBytes,
)
w.SetLen(bytes) // shrink the buffer to the actual number of bytes coalesced
if err != nil {
return err
}
if packets == 0 {
return io.ErrNoProgress
}
if packets > 1 {
if err := w.ControlMessages().AppendUInt32(windows.IPPROTO_UDP, windows.UDP_SEND_MSG_SIZE, uint32(packetSize)); err != nil {
return fmt.Errorf("failed to append UDP_SEND_MSG_SIZE cmsg: %w", err)
}
}
if err = tx.conn.postSendRequest(req, winrio.MsgDefer); err != nil {
return fmt.Errorf("failed to post send request: %w", err)
}
tx.requests.Advance() // advance after posting the request
n++
n += packets
}
return nil
}