mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 12:16:44 +02:00
net/rioconn: implement UDP control message support
Adding support for USO/URO requires reading and writing control messages. With RIO, control messages, like other data passed to or from the API, must reside in a RIO-registered buffer. RIO also requires them to be wrapped in a RIO_CMSG_BUFFER, which consists of a length followed by properly aligned control messages. In this commit, implement control message handling and include control messages when posting sends and receives to RIO. But we keep them blank fo now. Updates tailscale/corp#8610 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
e22a9f5909
commit
c4993354ed
171
net/rioconn/cmsg.go
Normal file
171
net/rioconn/cmsg.go
Normal file
@ -0,0 +1,171 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// _RIO_CMSG_BUFFER is the header of the control messages buffer for RIO operations,
|
||||
// as defined by the Windows API. It is followed by zero or more control messages.
|
||||
type _RIO_CMSG_BUFFER struct {
|
||||
totalLength uint32 // total length of the buffer, including this header
|
||||
// followed by control messages aligned to _WSA_CMSGHDR_ALIGN
|
||||
}
|
||||
|
||||
// _WSACMSGHDR is the header for a single control message, as defined by the Windows API.
|
||||
type _WSACMSGHDR struct {
|
||||
len uintptr
|
||||
level int32
|
||||
typ int32
|
||||
// followed by data aligned to _WSA_CMSGDATA_ALIGN
|
||||
}
|
||||
|
||||
const (
|
||||
_MAX_NATURAL_ALIGN = unsafe.Alignof(uintptr(0))
|
||||
_WSA_CMSGHDR_ALIGN = unsafe.Alignof(_WSACMSGHDR{})
|
||||
_WSA_CMSGDATA_ALIGN = _MAX_NATURAL_ALIGN
|
||||
_RIO_CMSG_BASE_SIZE = (unsafe.Sizeof(_RIO_CMSG_BUFFER{}) +
|
||||
(_WSA_CMSGHDR_ALIGN - 1)) &^ (_WSA_CMSGHDR_ALIGN - 1)
|
||||
)
|
||||
|
||||
const (
|
||||
// controlMessagesSize is the target size of the [controlMessages] struct,
|
||||
// which includes the header, padding for alignment, and space for control messages.
|
||||
// It is somewhat arbitrary but is large enough to hold typical control messages.
|
||||
controlMessagesSize = 64
|
||||
|
||||
// controlMessagesBufferSize is the size of the buffer for control messages,
|
||||
// which is the total size minus the size of the header.
|
||||
controlMessagesBufferSize = controlMessagesSize - unsafe.Sizeof(_RIO_CMSG_BUFFER{})
|
||||
)
|
||||
|
||||
// controlMessages is a fixed-size control messages buffer for RIO operations.
|
||||
// It is large enough to hold the header and typical control messages
|
||||
// for either send or receive operations.
|
||||
type controlMessages struct {
|
||||
_RIO_CMSG_BUFFER
|
||||
_ [controlMessagesBufferSize]byte // space for cmsgs
|
||||
}
|
||||
|
||||
// controlMessage represents a single control message.
|
||||
type controlMessage struct {
|
||||
Level int32 // protocol that originated the control information
|
||||
Type int32 // protocol-specific type of control information
|
||||
Data []byte // type-specific control data (backed by [controlMessages.buffer])
|
||||
}
|
||||
|
||||
// Empty reports whether the control messages buffer contains no control messages.
|
||||
func (cmsgs *controlMessages) Empty() bool {
|
||||
return uintptr(cmsgs.totalLength) <= _RIO_CMSG_BASE_SIZE
|
||||
}
|
||||
|
||||
// All returns an iterator over all control messages in the buffer.
|
||||
func (cmsgs *controlMessages) All() iter.Seq[controlMessage] {
|
||||
return func(yield func(cmsg controlMessage) bool) {
|
||||
offset := _RIO_CMSG_BASE_SIZE
|
||||
totalLen := uintptr(cmsgs.totalLength)
|
||||
if totalLen > unsafe.Sizeof(*cmsgs) {
|
||||
panic("controlMessages buffer overflow")
|
||||
}
|
||||
for offset+unsafe.Sizeof(_WSACMSGHDR{}) <= totalLen {
|
||||
hdr := (*_WSACMSGHDR)(unsafe.Add(unsafe.Pointer(cmsgs), offset))
|
||||
dataOffset := alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN)
|
||||
if hdr.len < dataOffset || offset+hdr.len > totalLen {
|
||||
panic("invalid control message header length")
|
||||
}
|
||||
dataLen := uintptr(hdr.len) - dataOffset
|
||||
data := unsafe.Slice((*byte)(
|
||||
unsafe.Add(unsafe.Pointer(hdr), dataOffset)),
|
||||
dataLen,
|
||||
)
|
||||
if !yield(controlMessage{
|
||||
Level: hdr.level,
|
||||
Type: hdr.typ,
|
||||
Data: data,
|
||||
}) {
|
||||
break
|
||||
}
|
||||
offset += alignUp(hdr.len, _WSA_CMSGHDR_ALIGN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetOk retrieves the data for the first control message with the given
|
||||
// level and type. It reports whether such a control message was found.
|
||||
func (cmsgs *controlMessages) GetOk(level, ctype int32) (data []byte, ok bool) {
|
||||
for cmsg := range cmsgs.All() {
|
||||
if cmsg.Level == level && cmsg.Type == ctype {
|
||||
return cmsg.Data, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetUInt32Ok is like GetOk but interprets the data as a uint32.
|
||||
func (cmsgs *controlMessages) GetUInt32Ok(level, ctype int32) (val uint32, ok bool) {
|
||||
data, ok := cmsgs.GetOk(level, ctype)
|
||||
if !ok || len(data) < 4 {
|
||||
return 0, false
|
||||
}
|
||||
return *(*uint32)(unsafe.Pointer(unsafe.SliceData(data))), true
|
||||
}
|
||||
|
||||
// GetUInt32 is like GetUInt32Ok, but returns zero if the specified control
|
||||
// message is not found.
|
||||
func (cmsgs *controlMessages) GetUInt32(level, ctype int32) uint32 {
|
||||
val, _ := cmsgs.GetUInt32Ok(level, ctype)
|
||||
return val
|
||||
}
|
||||
|
||||
// Append adds a control message with the given level, type, and data to the
|
||||
// buffer. It returns an error if there is not enough space.
|
||||
func (cmsgs *controlMessages) Append(level, ctype int32, data []byte) error {
|
||||
space := alignUp(
|
||||
unsafe.Sizeof(_WSACMSGHDR{})+
|
||||
alignUp(uintptr(len(data)), _WSA_CMSGDATA_ALIGN),
|
||||
_WSA_CMSGHDR_ALIGN,
|
||||
)
|
||||
// Append the new control message at the end of the existing messages,
|
||||
// or after the base header if there are no existing messages.
|
||||
offset := max(uintptr(cmsgs.totalLength), _RIO_CMSG_BASE_SIZE)
|
||||
if space > unsafe.Sizeof(*cmsgs)-offset {
|
||||
return fmt.Errorf("not enough space to append cmsg (Level=%d Type=%d Len=%d)", level, ctype, len(data))
|
||||
}
|
||||
hdr := (*_WSACMSGHDR)(unsafe.Add(unsafe.Pointer(cmsgs), offset))
|
||||
hdr.level = level
|
||||
hdr.typ = ctype
|
||||
hdr.len = alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN) + uintptr(len(data))
|
||||
dataPtr := unsafe.Add(
|
||||
unsafe.Pointer(hdr),
|
||||
alignUp(unsafe.Sizeof(_WSACMSGHDR{}), _WSA_CMSGDATA_ALIGN),
|
||||
)
|
||||
copy(unsafe.Slice((*byte)(dataPtr), len(data)), data)
|
||||
cmsgs.totalLength = uint32(offset + space)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppendUInt32 is like Append but appends a uint32 value.
|
||||
func (cmsgs *controlMessages) AppendUInt32(level, ctype int32, val uint32) error {
|
||||
data := (*[4]byte)(unsafe.Pointer(&val))[:]
|
||||
return cmsgs.Append(level, ctype, data)
|
||||
}
|
||||
|
||||
// Clear removes all control messages from the buffer.
|
||||
func (cmsgs *controlMessages) Clear() {
|
||||
cmsgs.totalLength = uint32(_RIO_CMSG_BASE_SIZE)
|
||||
}
|
||||
|
||||
// Bytes returns the raw bytes of the control messages buffer.
|
||||
func (cmsgs *controlMessages) Bytes() []byte {
|
||||
totalLen := min(uintptr(cmsgs.totalLength), unsafe.Sizeof(*cmsgs))
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(cmsgs)), totalLen)
|
||||
}
|
||||
|
||||
// Buffer returns the entire control messages buffer, including unused space.
|
||||
func (cmsgs *controlMessages) Buffer() []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(cmsgs)), unsafe.Sizeof(*cmsgs))
|
||||
}
|
||||
428
net/rioconn/cmsg_test.go
Normal file
428
net/rioconn/cmsg_test.go
Normal file
@ -0,0 +1,428 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func TestControlMessagesSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
if gotSize := unsafe.Sizeof(controlMessages{}); gotSize != controlMessagesSize {
|
||||
t.Errorf("got: %d; want %d", gotSize, controlMessagesSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlMessageAppend(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
clear bool // whether to call Clear() before appending
|
||||
messages []controlMessage
|
||||
wantBytes []byte
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
wantBytes: []byte{},
|
||||
},
|
||||
{
|
||||
name: "clear",
|
||||
clear: true,
|
||||
wantBytes: chooseForArch(
|
||||
[]byte{
|
||||
0x04, 0x00, 0x00, 0x00, // TotalLength = 4 (_RIO_CMSG_BASE_SIZE)
|
||||
// no padding needed on 32-bit platforms
|
||||
},
|
||||
[]byte{
|
||||
0x08, 0x00, 0x00, 0x00, // TotalLength = 8 (_RIO_CMSG_BASE_SIZE)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
messages: []controlMessage{
|
||||
{Level: 15, Type: 42, Data: []byte{0xAA, 0xBB, 0xCC}},
|
||||
},
|
||||
wantBytes: chooseForArch(
|
||||
[]byte{
|
||||
0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
|
||||
0x0F, 0x00, 0x00, 0x00, // cmsg Len = 15 (excluding padding)
|
||||
0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
|
||||
0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
|
||||
0xAA, 0xBB, 0xCC, // cmsg Data
|
||||
0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
[]byte{
|
||||
0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 19 (excluding padding)
|
||||
0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
|
||||
0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
|
||||
0xAA, 0xBB, 0xCC, // cmsg Data
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "single/after-clear",
|
||||
clear: true,
|
||||
messages: []controlMessage{
|
||||
{Level: 15, Type: 42, Data: []byte{0xAA, 0xBB, 0xCC}},
|
||||
},
|
||||
wantBytes: chooseForArch(
|
||||
[]byte{ // same as "single" test case
|
||||
0x14, 0x00, 0x00, 0x00,
|
||||
0x0F, 0x00, 0x00, 0x00,
|
||||
0x0F, 0x00, 0x00, 0x00,
|
||||
0x2A, 0x00, 0x00, 0x00,
|
||||
0xAA, 0xBB, 0xCC,
|
||||
0x00,
|
||||
},
|
||||
[]byte{ // same as "single" test case
|
||||
0x20, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x0F, 0x00, 0x00, 0x00,
|
||||
0x2A, 0x00, 0x00, 0x00,
|
||||
0xAA, 0xBB, 0xCC,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
messages: []controlMessage{
|
||||
{Level: 1, Type: 2, Data: []byte{
|
||||
0xAA, 0xBB, 0xCC, 0xDD,
|
||||
}},
|
||||
{Level: 3, Type: 4, Data: []byte{
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
}},
|
||||
},
|
||||
wantBytes: chooseForArch(
|
||||
[]byte{
|
||||
0x30, 0x00, 0x00, 0x00, // TotalLength = 48 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
|
||||
// cmsg 1
|
||||
0x10, 0x00, 0x00, 0x00, // cmsg Len = 16
|
||||
0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
|
||||
0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
|
||||
0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
|
||||
// cmsg 2
|
||||
0x1C, 0x00, 0x00, 0x00, // cmsg Len = 28
|
||||
0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
|
||||
0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
},
|
||||
[]byte{
|
||||
0x40, 0x00, 0x00, 0x00, // TotalLength = 64 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
// cmsg 1
|
||||
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding)
|
||||
0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
|
||||
0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
|
||||
0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
// cmsg 2
|
||||
0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 32
|
||||
0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
|
||||
0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var cmsgs controlMessages
|
||||
if tt.clear {
|
||||
cmsgs.Clear()
|
||||
}
|
||||
for _, cm := range tt.messages {
|
||||
if err := cmsgs.Append(cm.Level, cm.Type, cm.Data); err != nil {
|
||||
t.Fatalf("Append failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotBytes := cmsgs.Bytes(); !bytes.Equal(gotBytes, tt.wantBytes) {
|
||||
t.Fatalf("buffer bytes:\ngot\n%v\nwant\n%v",
|
||||
hex.Dump(gotBytes), hex.Dump(tt.wantBytes),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlMessageAppendNotEnoughSpace(t *testing.T) {
|
||||
t.Parallel()
|
||||
var cmsgs controlMessages
|
||||
if err := cmsgs.Append(1, 2, bytes.Repeat([]byte{0xCC}, 1024)); err == nil {
|
||||
t.Errorf("Append succeeded unexpectedly")
|
||||
}
|
||||
if cmsgs.totalLength > uint32(_RIO_CMSG_BASE_SIZE) {
|
||||
t.Errorf("Unexpected TotalLength after failed Append: %d", cmsgs.totalLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlMessageIterator(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
bytes []byte
|
||||
wantMessages []controlMessage
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
bytes: []byte{},
|
||||
wantMessages: []controlMessage{},
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
bytes: chooseForArch(
|
||||
[]byte{
|
||||
0x04, 0x00, 0x00, 0x00, // TotalLength = 4 (_RIO_CMSG_BASE_SIZE)
|
||||
},
|
||||
[]byte{
|
||||
0x08, 0x00, 0x00, 0x00, // TotalLength = 8 (_RIO_CMSG_BASE_SIZE)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
),
|
||||
wantMessages: []controlMessage{},
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
bytes: chooseForArch(
|
||||
[]byte{
|
||||
0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
|
||||
0x0F, 0x00, 0x00, 0x00, // cmsg Len = 15 (excluding padding)
|
||||
0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
|
||||
0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
|
||||
0xAA, 0xBB, 0xCC, // cmsg Data
|
||||
0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
[]byte{
|
||||
0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 19 (excluding padding)
|
||||
0x0F, 0x00, 0x00, 0x00, // cmsg Level = 15
|
||||
0x2A, 0x00, 0x00, 0x00, // cmsg Type = 42
|
||||
0xAA, 0xBB, 0xCC, // cmsg Data
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
),
|
||||
wantMessages: []controlMessage{
|
||||
{Level: 15, Type: 42, Data: []byte{
|
||||
0xAA, 0xBB, 0xCC,
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
bytes: chooseForArch(
|
||||
[]byte{
|
||||
0x30, 0x00, 0x00, 0x00, // TotalLength = 48 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
|
||||
// cmsg 1
|
||||
0x10, 0x00, 0x00, 0x00, // cmsg Len = 16
|
||||
0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
|
||||
0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
|
||||
0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
|
||||
// cmsg 2
|
||||
0x1C, 0x00, 0x00, 0x00, // cmsg Len = 28
|
||||
0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
|
||||
0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
},
|
||||
[]byte{
|
||||
0x40, 0x00, 0x00, 0x00, // TotalLength = 64 (_RIO_CMSG_BASE_SIZE + aligned cmsg1 size + aligned cmsg2 size)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
// cmsg 1
|
||||
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding)
|
||||
0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
|
||||
0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
|
||||
0xAA, 0xBB, 0xCC, 0xDD, // cmsg Data
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
// cmsg 2
|
||||
0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 32
|
||||
0x03, 0x00, 0x00, 0x00, // cmsg Level = 3
|
||||
0x04, 0x00, 0x00, 0x00, // cmsg Type = 4
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // cmsg Data
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
},
|
||||
),
|
||||
wantMessages: []controlMessage{
|
||||
{Level: 1, Type: 2, Data: []byte{
|
||||
0xAA, 0xBB, 0xCC, 0xDD,
|
||||
}},
|
||||
{Level: 3, Type: 4, Data: []byte{
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var cmsgs controlMessages
|
||||
copy(cmsgs.Buffer(), tt.bytes)
|
||||
|
||||
gotMessages := slices.Collect(cmsgs.All())
|
||||
if len(gotMessages) != len(tt.wantMessages) {
|
||||
t.Fatalf("number of messages: got %d; want %d", len(gotMessages), len(tt.wantMessages))
|
||||
}
|
||||
|
||||
for i := range gotMessages {
|
||||
if got, want := gotMessages[i], tt.wantMessages[i]; got.Level != want.Level || got.Type != want.Type || !bytes.Equal(got.Data, want.Data) {
|
||||
t.Errorf("message %d:\ngot %v\nwant %v", i, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlMessageUInt32(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
msgLvl = 1
|
||||
msgType = 2
|
||||
msgVal = uint32(0xAABBCCDD)
|
||||
)
|
||||
var cmsgs controlMessages
|
||||
if err := cmsgs.AppendUInt32(msgLvl, msgType, msgVal); err != nil {
|
||||
t.Fatalf("AppendUInt32 failed: %v", err)
|
||||
}
|
||||
wantBytes := chooseForArch(
|
||||
[]byte{
|
||||
0x14, 0x00, 0x00, 0x00, // TotalLength = 20 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
|
||||
0x10, 0x00, 0x00, 0x00, // cmsg Len = 16
|
||||
0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
|
||||
0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
|
||||
0xDD, 0xCC, 0xBB, 0xAA, // cmsg Data
|
||||
},
|
||||
[]byte{
|
||||
0x20, 0x00, 0x00, 0x00, // TotalLength = 32 (_RIO_CMSG_BASE_SIZE + aligned cmsg size)
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // cmsg Len = 20 (excluding padding)
|
||||
0x01, 0x00, 0x00, 0x00, // cmsg Level = 1
|
||||
0x02, 0x00, 0x00, 0x00, // cmsg Type = 2
|
||||
0xDD, 0xCC, 0xBB, 0xAA, // cmsg Data
|
||||
0x00, 0x00, 0x00, 0x00, // padding to align to _WSA_CMSGHDR_ALIGNMENT
|
||||
},
|
||||
)
|
||||
if gotBytes := cmsgs.Bytes(); !bytes.Equal(gotBytes, wantBytes) {
|
||||
t.Fatalf("buffer bytes:\ngot\n%v\nwant\n%v",
|
||||
hex.Dump(gotBytes), hex.Dump(wantBytes),
|
||||
)
|
||||
}
|
||||
gotVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType)
|
||||
if !ok {
|
||||
t.Fatal("GetUInt32Ok failed to find value")
|
||||
}
|
||||
if gotVal != msgVal {
|
||||
t.Fatalf("uint32: got 0x%X; want 0x%X", gotVal, msgVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlMessageGetUInt32(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var cmsgs controlMessages
|
||||
if _, ok := cmsgs.GetUInt32Ok(1, 2); ok {
|
||||
t.Fatal("GetUInt32Ok found unexpected value")
|
||||
}
|
||||
cmsgs.AppendUInt32(3, 4, 0xDEADBEEF)
|
||||
if _, ok := cmsgs.GetUInt32Ok(1, 2); ok {
|
||||
t.Fatal("GetUInt32Ok found unexpected value")
|
||||
}
|
||||
if gotVal, ok := cmsgs.GetUInt32Ok(3, 4); !ok {
|
||||
t.Fatal("GetUInt32Ok found unexpected value")
|
||||
} else if gotVal != 0xDEADBEEF {
|
||||
t.Fatalf("GetUInt32Ok: got 0x%X; want 0xDEADBEEF", gotVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlMessageNoAlloc(t *testing.T) {
|
||||
const (
|
||||
msgLvl = 1
|
||||
msgType = 2
|
||||
msgVal = uint32(0xAABBCCDD)
|
||||
)
|
||||
|
||||
var cmsgs controlMessages
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
cmsgs.Clear()
|
||||
cmsgs.AppendUInt32(msgLvl, msgType, msgVal)
|
||||
})
|
||||
if allocs != 0 {
|
||||
t.Fatalf("AppendUInt32 allocated %f times; want 0", allocs)
|
||||
}
|
||||
allocs = testing.AllocsPerRun(1000, func() {
|
||||
gotVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType)
|
||||
if !ok {
|
||||
t.Fatal("GetUInt32Ok failed to find value")
|
||||
}
|
||||
if gotVal != msgVal {
|
||||
t.Fatalf("GetUInt32Ok: got 0x%X; want 0x%X", gotVal, msgVal)
|
||||
}
|
||||
})
|
||||
if allocs != 0 {
|
||||
t.Fatalf("GetUInt32Ok allocated %f times; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkControlMessagesAppendUInt32(b *testing.B) {
|
||||
var cmsgs controlMessages
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
cmsgs.Clear()
|
||||
cmsgs.AppendUInt32(1, 2, uint32(3))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkControlMessagesGetUInt32Ok(b *testing.B) {
|
||||
const (
|
||||
msgLvl = 1
|
||||
msgType = 2
|
||||
msgVal = uint32(0xAABBCCDD)
|
||||
)
|
||||
|
||||
var cmsgs controlMessages
|
||||
cmsgs.AppendUInt32(msgLvl, msgType, msgVal)
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
gotMsgVal, ok := cmsgs.GetUInt32Ok(msgLvl, msgType)
|
||||
if !ok {
|
||||
b.Fatal("GetUInt32Ok failed to find value")
|
||||
}
|
||||
if gotMsgVal != msgVal {
|
||||
b.Fatalf("GetUInt32Ok: got 0x%X; want 0x%X", gotMsgVal, msgVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chooseForArch[T any](val32, val64 T) T {
|
||||
if unsafe.Sizeof(uintptr(0)) == 4 {
|
||||
return val32
|
||||
}
|
||||
return val64
|
||||
}
|
||||
|
||||
// String returns a string representation of the control message.
|
||||
func (cmsg controlMessage) String() string {
|
||||
return fmt.Sprintf("cmsg{Level=%d Type=%d}\n%s", cmsg.Level, cmsg.Type, hex.Dump(cmsg.Data))
|
||||
}
|
||||
@ -15,6 +15,11 @@ const (
|
||||
|
||||
defaultRXMemoryLimit = 2 << 20 // 2 MiB
|
||||
defaultTXMemoryLimit = 2 << 20 // 2 MiB
|
||||
|
||||
// defaultMaxUSOOffloadSize is the default maximum offload size for USO.
|
||||
// It should be smaller than or equal to the typical hardware offload size
|
||||
// to avoid software fallback. Mellanox NICs typically support up to 64000.
|
||||
defaultMaxUSOOffloadSize = 64000
|
||||
)
|
||||
|
||||
// Config holds configuration for a RIO connection, independent of the transport protocol.
|
||||
@ -89,4 +94,30 @@ func (o TxConfig) MaxPayloadLen() uint16 {
|
||||
// UDPConfig holds configuration for a [UDPConn].
|
||||
type UDPConfig struct {
|
||||
Config
|
||||
uso USOConfig
|
||||
}
|
||||
|
||||
// USO returns the UDP segmentation offload (USO) configuration.
|
||||
func (o UDPConfig) USO() *USOConfig {
|
||||
return &o.uso
|
||||
}
|
||||
|
||||
// USOConfig holds the UDP segmentation offload (USO) configuration.
|
||||
type USOConfig struct {
|
||||
enabled bool
|
||||
maxOffloadSize uint16 // 0 means default (i.e., [defaultMaxUSOOffloadSize])
|
||||
}
|
||||
|
||||
// Enabled reports whether USO is enabled.
|
||||
func (o USOConfig) Enabled() bool {
|
||||
return o.enabled
|
||||
}
|
||||
|
||||
// MaxOffloadSize returns the maximum number of bytes that can be batched into a
|
||||
// single send for UDP segmentation offload.
|
||||
func (o USOConfig) MaxOffloadSize() uint16 {
|
||||
if !o.Enabled() {
|
||||
return 0
|
||||
}
|
||||
return cmp.Or(o.maxOffloadSize, defaultMaxUSOOffloadSize)
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
// Package rioconn provides [UDPConn], a UDP socket implementation
|
||||
// that uses the Windows RIO API extensions and supports batched I/O
|
||||
// for improved performance on high-throughput UDP workloads.
|
||||
// and USO for improved performance on high-throughput UDP workloads.
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
|
||||
84
net/rioconn/offloads.go
Normal file
84
net/rioconn/offloads.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
// coalescePackets copies packets from buffs into dst until dst is full, it
|
||||
// copies a packet shorter than the first, the maximum coalesced length
|
||||
// or the maximum number of packets is reached, or there are no more packets.
|
||||
// A zero maxCoalescedLen or maxCoalescedPackets means no limit.
|
||||
//
|
||||
// Each packet is copied starting at offset.
|
||||
// Each copied packet is preceded by a Geneve header if geneve.VNI.IsSet().
|
||||
//
|
||||
// It returns the number of packets and bytes copied into dst, the packet
|
||||
// size for the batch, or an error if Geneve header encoding fails.
|
||||
func coalescePackets(
|
||||
dst []byte, geneve packet.GeneveHeader, buffs [][]byte,
|
||||
offset, maxCoalescedPackets, maxCoalescedBytes int,
|
||||
) (packets, bytes, packetSize int, err error) {
|
||||
var header []byte
|
||||
if geneve.VNI.IsSet() {
|
||||
var geneveHeader [packet.GeneveFixedHeaderLength]byte
|
||||
if err := geneve.Encode(geneveHeader[:]); err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
header = geneveHeader[:]
|
||||
}
|
||||
if len(buffs) != 0 {
|
||||
// The first packet determines the packet size for the batch,
|
||||
// which is the size of each packet in the coalesced batch
|
||||
// except possibly the last one. If the first packet cannot fit
|
||||
// in dst, we cannot coalesce any packets.
|
||||
packetSize = len(header) + len(buffs[0]) - offset
|
||||
if packetSize > len(dst) {
|
||||
return 0, 0, 0, fmt.Errorf("%w: packet size %d exceeds dst size %d",
|
||||
io.ErrShortBuffer, packetSize, len(dst),
|
||||
)
|
||||
}
|
||||
}
|
||||
for _, buff := range buffs {
|
||||
buff = buff[offset:]
|
||||
packetLen := len(header) + len(buff)
|
||||
newBytes := bytes + packetLen
|
||||
if newBytes > len(dst) {
|
||||
break // no more space
|
||||
}
|
||||
if packetLen > packetSize {
|
||||
break // packet is too large for this batch
|
||||
}
|
||||
if bytes != 0 && maxCoalescedBytes != 0 && newBytes > maxCoalescedBytes {
|
||||
break // would exceed the maximum coalesced length
|
||||
}
|
||||
if maxCoalescedPackets != 0 && packets >= maxCoalescedPackets {
|
||||
break // would exceed the maximum number of coalesced packets
|
||||
}
|
||||
if packetLen == 0 {
|
||||
// Consume the zero-length packet if it's the first packet,
|
||||
// but never coalesce them.
|
||||
if packets == 0 {
|
||||
packets = 1
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
copy(dst[bytes:], header)
|
||||
copy(dst[bytes+len(header):], buff)
|
||||
|
||||
packets++
|
||||
bytes = newBytes
|
||||
if packetLen < packetSize {
|
||||
// A smaller than packetSize packet on the tail is legal,
|
||||
// but it must end the batch.
|
||||
break
|
||||
}
|
||||
}
|
||||
return packets, bytes, packetSize, nil
|
||||
}
|
||||
334
net/rioconn/offloads_test.go
Normal file
334
net/rioconn/offloads_test.go
Normal file
@ -0,0 +1,334 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
func TestCoalescePackets(t *testing.T) {
|
||||
geneve := packet.GeneveHeader{
|
||||
Protocol: packet.GeneveProtocolWireGuard,
|
||||
}
|
||||
geneve.VNI.Set(7)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
buffs [][]byte
|
||||
offset int
|
||||
geneve packet.GeneveHeader
|
||||
dst []byte
|
||||
maxCoalescedPackets int
|
||||
maxCoalescedBytes int
|
||||
wantBytes []byte
|
||||
wantPackets int
|
||||
wantPacketSize int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no-packets",
|
||||
buffs: nil,
|
||||
geneve: packet.GeneveHeader{},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: nil,
|
||||
wantPackets: 0,
|
||||
wantPacketSize: 0,
|
||||
},
|
||||
{
|
||||
name: "single-packet",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
},
|
||||
geneve: packet.GeneveHeader{},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
},
|
||||
wantPackets: 1,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "single-packet/zero-length",
|
||||
buffs: [][]byte{
|
||||
{},
|
||||
},
|
||||
geneve: packet.GeneveHeader{},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{},
|
||||
wantPackets: 1,
|
||||
wantPacketSize: 0,
|
||||
},
|
||||
{
|
||||
name: "single-packet/with-offset",
|
||||
buffs: [][]byte{
|
||||
{
|
||||
0x00, 0x00,
|
||||
0x01, 0x02, 0x03,
|
||||
},
|
||||
},
|
||||
offset: 2,
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
},
|
||||
wantPackets: 1,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "single-packet/with-geneve",
|
||||
buffs: [][]byte{
|
||||
{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Geneve header space
|
||||
0x01, 0x02, 0x03,
|
||||
},
|
||||
},
|
||||
offset: 8,
|
||||
geneve: geneve,
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header
|
||||
0x01, 0x02, 0x03,
|
||||
},
|
||||
wantPackets: 1,
|
||||
wantPacketSize: 11, // 8 bytes Geneve header + 3 bytes packet
|
||||
},
|
||||
{
|
||||
name: "single-packet/exact-fit",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
},
|
||||
dst: make([]byte, 3),
|
||||
wantBytes: []byte{0x01, 0x02, 0x03},
|
||||
wantPackets: 1,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "single-packet/too-large-for-dst",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
},
|
||||
dst: make([]byte, 2),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "single-packet/with-geneve/too-large-for-dst",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
},
|
||||
geneve: geneve,
|
||||
dst: make([]byte, 10), // smaller than 8 bytes Geneve header + 3 bytes packet
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/coalesce-all",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{0x07, 0x08, 0x09},
|
||||
},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
0x07, 0x08, 0x09,
|
||||
},
|
||||
wantPackets: 3,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/coalesce-all/with-offset",
|
||||
buffs: [][]byte{
|
||||
{0x00, 0x00, 0x01, 0x02, 0x03},
|
||||
{0x00, 0x00, 0x04, 0x05, 0x06},
|
||||
{0x00, 0x00, 0x07, 0x08, 0x09},
|
||||
},
|
||||
offset: 2,
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
0x07, 0x08, 0x09,
|
||||
},
|
||||
wantPackets: 3,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/smaller-packet-ends-batch",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{0x07, 0x08}, // will be coalesced, but ends the batch
|
||||
{0x09, 0x0a, 0x0b}, // will not be coalesced in this batch
|
||||
},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
0x07, 0x08,
|
||||
},
|
||||
wantPackets: 3,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/larger-packet-ends-batch",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{0x07, 0x08, 0x09, 0x0a}, // will not be coalesced in this batch
|
||||
|
||||
},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
},
|
||||
wantPackets: 2,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/partial-fit",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{0x07, 0x08, 0x09}, // could be coalesced, but won't fit
|
||||
|
||||
},
|
||||
dst: make([]byte, 7), // can only fit the first two packets
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
},
|
||||
wantPackets: 2,
|
||||
wantPacketSize: 3,
|
||||
wantErr: false, // partial fit is not an error
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/exact-fit",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{0x07, 0x08, 0x09},
|
||||
},
|
||||
dst: make([]byte, 9), // can exactly fit all three packets
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
0x07, 0x08, 0x09,
|
||||
},
|
||||
wantPackets: 3,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/with-geneve",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{0x07, 0x08, 0x09},
|
||||
{0x0a, 0x0b}, // ends the batch
|
||||
{0x0c, 0x0d, 0x0e},
|
||||
},
|
||||
geneve: geneve,
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for first packet
|
||||
0x01, 0x02, 0x03,
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for second packet
|
||||
0x04, 0x05, 0x06,
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for third packet
|
||||
0x07, 0x08, 0x09,
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for fourth packet
|
||||
0x0a, 0x0b,
|
||||
},
|
||||
wantPackets: 4,
|
||||
wantPacketSize: 11, // 8 bytes Geneve header + 3 bytes packet
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/all-zero-length",
|
||||
buffs: [][]byte{
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{},
|
||||
wantPackets: 1, // zero-length packets cannot be coalesced
|
||||
wantPacketSize: 0,
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/all-zero-length/with-geneve",
|
||||
buffs: [][]byte{
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
},
|
||||
geneve: geneve,
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for first packet
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for second packet
|
||||
0x00, 0x00, 0x7a, 0x12, 0x00, 0x00, 0x07, 0x00, // Geneve header for third packet
|
||||
},
|
||||
wantPackets: 3,
|
||||
wantPacketSize: 8, // Geneve header size, since the packets are zero-length
|
||||
},
|
||||
{
|
||||
name: "multiple-packets/zero-length-packet-ends-batch",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
{0x04, 0x05, 0x06},
|
||||
{}, // cannot be coalesced, ends the batch
|
||||
|
||||
},
|
||||
dst: make([]byte, 100),
|
||||
wantBytes: []byte{
|
||||
0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06,
|
||||
},
|
||||
wantPackets: 2,
|
||||
wantPacketSize: 3,
|
||||
},
|
||||
{
|
||||
name: "invalid-geneve-header",
|
||||
buffs: [][]byte{
|
||||
{0x01, 0x02, 0x03},
|
||||
},
|
||||
geneve: packet.GeneveHeader{
|
||||
Version: 0xFF,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
gotPackets, gotBytes, gotPacketSize, err := coalescePackets(
|
||||
tt.dst, tt.geneve, tt.buffs, tt.offset,
|
||||
tt.maxCoalescedPackets,
|
||||
tt.maxCoalescedBytes,
|
||||
)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("error: got %v; want error: %v", err, tt.wantErr)
|
||||
}
|
||||
if gotPackets != tt.wantPackets {
|
||||
t.Errorf("packets: got %d; want %d", gotPackets, tt.wantPackets)
|
||||
}
|
||||
if gotBytes := tt.dst[:gotBytes]; !bytes.Equal(gotBytes, tt.wantBytes) {
|
||||
t.Errorf("bytes: got %v; want %v", gotBytes, tt.wantBytes)
|
||||
}
|
||||
if gotPacketSize != tt.wantPacketSize {
|
||||
t.Errorf("packet size: got %d; want %d", gotPacketSize, tt.wantPacketSize)
|
||||
}
|
||||
})
|
||||
// Coalescing packets should not cause any allocations,
|
||||
// except for when it returns an error.
|
||||
if !tt.wantErr && allocs != 0 {
|
||||
t.Errorf("unexpected allocations: got %f; want 0", allocs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -76,3 +76,10 @@ type udpOption func(*UDPConfig)
|
||||
func (o udpOption) applyUDP(opts *UDPConfig) {
|
||||
o(opts)
|
||||
}
|
||||
|
||||
// USO specifies whether UDP segmentation offload (USO) should be enabled.
|
||||
func USO(enabled bool) UDPOption {
|
||||
return udpOption(func(opts *UDPConfig) {
|
||||
opts.uso.enabled = enabled
|
||||
})
|
||||
}
|
||||
|
||||
@ -27,8 +27,9 @@ type request struct {
|
||||
buffID winrio.BufferId // ID of the registered buffer containing this request
|
||||
buffBase uintptr // base address of the registered buffer
|
||||
|
||||
raddr rawSockaddr // remote address for RIO send/receive operations
|
||||
data []byte // a slice pointing into the data buffer area after the struct
|
||||
raddr rawSockaddr // remote address for RIO send/receive operations
|
||||
data []byte // a slice pointing into the data buffer area after the struct
|
||||
control controlMessages // control messages buffer
|
||||
// followed by the actual data at [requestDataOffset]
|
||||
// from the start of the struct.
|
||||
}
|
||||
@ -94,8 +95,22 @@ func (r *request) PostSend(rq winrio.Rq, flags uint32) error {
|
||||
Length: uint32(len(r.data)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
|
||||
}
|
||||
remoteAddr := r.remoteAddrDesc()
|
||||
return winrio.SendEx(rq, &data, 1, nil, &remoteAddr, nil, nil, flags, uintptr(unsafe.Pointer(r)))
|
||||
remoteAddress := winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(unsafe.Sizeof(r.raddr)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
|
||||
}
|
||||
var control *winrio.Buffer
|
||||
if !r.control.Empty() {
|
||||
cmsgs := r.control.Buffer()
|
||||
control = &winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(len(cmsgs)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(cmsgs))) - r.buffBase),
|
||||
}
|
||||
}
|
||||
return winrio.SendEx(rq, &data, 1, nil, &remoteAddress, control,
|
||||
nil, flags, uintptr(unsafe.Pointer(r)))
|
||||
}
|
||||
|
||||
// PostReceive posts the request as a receive operation to the given RIO request queue
|
||||
@ -107,8 +122,18 @@ func (r *request) PostReceive(rq winrio.Rq, flags uint32) error {
|
||||
Length: uint32(cap(r.data)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
|
||||
}
|
||||
remoteAddress := r.remoteAddrDesc()
|
||||
return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, nil,
|
||||
remoteAddress := winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(unsafe.Sizeof(r.raddr)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
|
||||
}
|
||||
controlBuffer := r.control.Buffer()
|
||||
control := &winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(len(controlBuffer)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(controlBuffer))) - r.buffBase),
|
||||
}
|
||||
return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, control,
|
||||
nil, flags, uintptr(unsafe.Pointer(r)))
|
||||
}
|
||||
|
||||
@ -119,6 +144,8 @@ func (r *request) PostReceive(rq winrio.Rq, flags uint32) error {
|
||||
// of bytes written does not match the length of the request's data buffer.
|
||||
func (r *request) CompleteSend(status int32, bytesWritten uint32) error {
|
||||
expected := len(r.data)
|
||||
r.data = r.data[:0]
|
||||
r.control.Clear()
|
||||
if status != 0 {
|
||||
return windows.Errno(status)
|
||||
}
|
||||
@ -154,18 +181,11 @@ func (r *request) CompleteReceive(status int32, bytesRead uint32) (*requestReade
|
||||
return r.Reader(), nil
|
||||
}
|
||||
|
||||
func (r *request) remoteAddrDesc() winrio.Buffer {
|
||||
return winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(unsafe.Sizeof(r.raddr)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
|
||||
}
|
||||
}
|
||||
|
||||
// Reset prepares the request for reuse by resetting its state.
|
||||
func (r *request) Reset() {
|
||||
r.raddr = rawSockaddr{}
|
||||
r.data = r.data[:0]
|
||||
r.control.Clear()
|
||||
}
|
||||
|
||||
type (
|
||||
@ -201,6 +221,11 @@ func (w *requestWriter) SetRemoteAddr(raddr rawSockaddr) {
|
||||
w.raddr = raddr
|
||||
}
|
||||
|
||||
// ControlMessages returns a pointer to the request's control messages buffer.
|
||||
func (w *requestWriter) ControlMessages() *controlMessages {
|
||||
return &w.control
|
||||
}
|
||||
|
||||
// Reserve reserves n bytes in the request's data buffer,
|
||||
// and returns a slice pointing to the reserved space.
|
||||
// It panics if n is negative or exceeds the available capacity.
|
||||
@ -262,3 +287,8 @@ func (r *requestReader) RemoteAddrPort() (netip.AddrPort, error) {
|
||||
func (r *requestReader) RemoteAddr() rawSockaddr {
|
||||
return r.raddr
|
||||
}
|
||||
|
||||
// ControlMessages returns a pointer to the request's control messages buffer.
|
||||
func (r *requestReader) ControlMessages() *controlMessages {
|
||||
return &r.control
|
||||
}
|
||||
|
||||
@ -124,36 +124,65 @@ func TestUDPSendReceiveBatch(t *testing.T) {
|
||||
iterations int
|
||||
sendBatchSize int
|
||||
receiveBatchSize int
|
||||
uso bool
|
||||
}{
|
||||
{
|
||||
name: "udp4/single",
|
||||
network: "udp4",
|
||||
pattern: []int{1312},
|
||||
},
|
||||
{
|
||||
name: "udp4/batch",
|
||||
network: "udp4",
|
||||
pattern: []int{1312},
|
||||
iterations: 1024,
|
||||
},
|
||||
{
|
||||
name: "udp4/single/max",
|
||||
network: "udp4",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv4},
|
||||
},
|
||||
{
|
||||
name: "udp4/batch/max",
|
||||
network: "udp4",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv4},
|
||||
name: "udp4/batch/max",
|
||||
network: "udp4",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv4},
|
||||
iterations: 1024,
|
||||
},
|
||||
{
|
||||
name: "udp6/single",
|
||||
network: "udp6",
|
||||
pattern: []int{1312},
|
||||
},
|
||||
{
|
||||
name: "udp6/batch",
|
||||
network: "udp6",
|
||||
pattern: []int{1312},
|
||||
iterations: 1024,
|
||||
},
|
||||
{
|
||||
name: "udp6/single/max",
|
||||
network: "udp6",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv6},
|
||||
},
|
||||
{
|
||||
name: "udp6/batch/max",
|
||||
network: "udp6",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv6},
|
||||
name: "udp6/batch/max",
|
||||
network: "udp6",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv6},
|
||||
iterations: 10,
|
||||
},
|
||||
{
|
||||
name: "udp6/batch/uso",
|
||||
network: "udp6",
|
||||
pattern: []int{1312},
|
||||
iterations: 10,
|
||||
uso: true,
|
||||
},
|
||||
{
|
||||
name: "udp6/batch/max/uso",
|
||||
network: "udp6",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv6},
|
||||
iterations: 10,
|
||||
uso: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@ -164,7 +193,9 @@ func TestUDPSendReceiveBatch(t *testing.T) {
|
||||
cmp.Or(tt.sendBatchSize, defaultBatchSize),
|
||||
cmp.Or(tt.receiveBatchSize, defaultBatchSize),
|
||||
tt.network, tt.network,
|
||||
nil, nil,
|
||||
[]rioconn.UDPOption{
|
||||
rioconn.USO(tt.uso),
|
||||
}, nil,
|
||||
)
|
||||
})
|
||||
}
|
||||
@ -174,16 +205,19 @@ func FuzzUDPSendReceiveBatch(f *testing.F) {
|
||||
batchSizes := []uint16{1, 64}
|
||||
packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4}
|
||||
numIterations := []uint16{1024}
|
||||
uso := []bool{false, true}
|
||||
|
||||
for _, packetLen := range packetSizes {
|
||||
for _, numIter := range numIterations {
|
||||
for _, batchSize := range batchSizes {
|
||||
f.Add(packetLen, numIter, batchSize, batchSize)
|
||||
for _, usoEnabled := range uso {
|
||||
f.Add(packetLen, numIter, batchSize, batchSize, usoEnabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16) {
|
||||
f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16, usoEnabled bool) {
|
||||
network := "udp4"
|
||||
maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4)
|
||||
|
||||
@ -206,6 +240,7 @@ func FuzzUDPSendReceiveBatch(f *testing.F) {
|
||||
[]rioconn.UDPOption{
|
||||
rioconn.RxMemoryLimit(128 << 10),
|
||||
rioconn.TxMemoryLimit(512 << 10),
|
||||
rioconn.USO(usoEnabled),
|
||||
},
|
||||
[]rioconn.UDPOption{
|
||||
rioconn.RxMemoryLimit(512 << 10),
|
||||
|
||||
@ -8,6 +8,7 @@ package rioconn
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
@ -25,14 +26,26 @@ import (
|
||||
// otherwise specified by the method.
|
||||
type udpTx struct {
|
||||
udpNx
|
||||
maxCoalescedPackets int // the maximum number of coalesced packets, or 0 if no limit
|
||||
maxCoalescedBytes int // the maximum total length of coalesced packets, or 0 if no limit
|
||||
}
|
||||
|
||||
// init initializes the transmit half of a [UDPConn] with the
|
||||
// specified underlying connection and options.
|
||||
func (tx *udpTx) init(conn *conn, options UDPConfig) error {
|
||||
// Without USO, the data buffer for each send request only needs to hold
|
||||
// a single packet's payload.
|
||||
dataSize := min(options.Tx().MaxPayloadLen(), MaxUDPPayload)
|
||||
var dataSize uint16
|
||||
if uso := options.USO(); uso.Enabled() {
|
||||
// When USO is enabled, the request's data buffer may need to
|
||||
// hold multiple coalesced packets up to the maximum offload size,
|
||||
// or a single packet up to the maximum payload size, whichever is larger.
|
||||
dataSize = max(options.Tx().MaxPayloadLen(), uso.MaxOffloadSize())
|
||||
tx.maxCoalescedBytes = int(uso.MaxOffloadSize())
|
||||
} else {
|
||||
// Otherwise, the data buffer for each send request only needs to hold
|
||||
// a single packet's payload.
|
||||
dataSize = min(options.Tx().MaxPayloadLen(), MaxUDPPayload)
|
||||
tx.maxCoalescedPackets = 1
|
||||
}
|
||||
if err := tx.udpNx.init(conn, dataSize, options.Tx().MemoryLimit()); err != nil {
|
||||
return fmt.Errorf("failed to initialize udpTx: %w", err)
|
||||
}
|
||||
@ -99,20 +112,31 @@ func (tx *udpTx) writeBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet
|
||||
w := req.Writer()
|
||||
w.SetRemoteAddr(raddr)
|
||||
|
||||
if geneve.VNI.IsSet() {
|
||||
geneveHeader := w.Reserve(packet.GeneveFixedHeaderLength)
|
||||
geneve.Encode(geneveHeader[:])
|
||||
}
|
||||
if _, err := w.Write(buffs[n][offset:]); err != nil {
|
||||
buf := w.Reserve(w.Cap()) // reserve the entire buffer
|
||||
packets, bytes, packetSize, err := coalescePackets(
|
||||
buf, geneve, buffs[n:], offset,
|
||||
tx.maxCoalescedPackets,
|
||||
tx.maxCoalescedBytes,
|
||||
)
|
||||
w.SetLen(bytes) // shrink the buffer to the actual number of bytes coalesced
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if packets == 0 {
|
||||
return io.ErrNoProgress
|
||||
}
|
||||
|
||||
if packets > 1 {
|
||||
if err := w.ControlMessages().AppendUInt32(windows.IPPROTO_UDP, windows.UDP_SEND_MSG_SIZE, uint32(packetSize)); err != nil {
|
||||
return fmt.Errorf("failed to append UDP_SEND_MSG_SIZE cmsg: %w", err)
|
||||
}
|
||||
}
|
||||
if err = tx.conn.postSendRequest(req, winrio.MsgDefer); err != nil {
|
||||
return fmt.Errorf("failed to post send request: %w", err)
|
||||
}
|
||||
|
||||
tx.requests.Advance() // advance after posting the request
|
||||
n++
|
||||
n += packets
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user