mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 12:16:44 +02:00
net/rioconn: implement a batching UDPConn
In this commit, we implement a batching UDPConn using the previously added RIO plumbing, the base conn, and requestRing. The non-batching Read/Write methods wrap the batching variants. Updates tailscale/corp#8610 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
05b7b04527
commit
e5cb1f48a6
@ -6,6 +6,7 @@
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
@ -13,6 +14,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/sys/windows"
|
||||
@ -363,3 +365,22 @@ func rioSocket(family, sotype, proto int32) (windows.Handle, error) {
|
||||
windows.WSA_FLAG_OVERLAPPED
|
||||
return windows.WSASocket(family, sotype, proto, nil, 0, rioWSAFlags)
|
||||
}
|
||||
|
||||
// WSAIoctlIn issues an ioctl command with the provided code and input value
|
||||
// on the connection's underlying socket. It is a type-safe shorthand for calling
|
||||
// [syscall.RawConn.Control] with a function that invokes [windows.WSAIoctl]
|
||||
// with the appropriate arguments, without any output buffer.
|
||||
func WSAIoctlIn[Input any](conn syscall.Conn, code uint32, in Input) error {
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
controlErr := rawConn.Control(func(s uintptr) {
|
||||
ret := uint32(0)
|
||||
err = windows.WSAIoctl(windows.Handle(s), code,
|
||||
(*byte)(unsafe.Pointer(&in)), uint32(unsafe.Sizeof(in)),
|
||||
nil, 0, &ret, nil, 0,
|
||||
)
|
||||
})
|
||||
return cmp.Or(controlErr, err)
|
||||
}
|
||||
|
||||
@ -3,6 +3,9 @@
|
||||
|
||||
//go:build windows
|
||||
|
||||
// Package rioconn provides [UDPConn], a UDP socket implementation
|
||||
// that uses the Windows RIO API extensions and supports batched I/O
|
||||
// for improved performance on high-throughput UDP workloads.
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
|
||||
235
net/rioconn/udp.go
Normal file
235
net/rioconn/udp.go
Normal file
@ -0,0 +1,235 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
	// MaxUDPPayloadIPv4 is the maximum UDP payload size over IPv4 (65507 bytes).
	// IPv4 total length is 65535 bytes, including:
	//   - 20-byte IPv4 header (no options)
	//   - 8-byte UDP header
	MaxUDPPayloadIPv4 = 1<<16 - 1 - 20 - 8
	// MaxUDPPayloadIPv6 is the maximum UDP payload size over IPv6 (65527 bytes).
	// The IPv6 payload length field excludes the 40-byte base header
	// and includes the 8-byte UDP header.
	MaxUDPPayloadIPv6 = 1<<16 - 1 - 8
	// MaxUDPPayload is the maximum UDP payload size across IP versions
	// (equal to MaxUDPPayloadIPv6, the larger of the two).
	MaxUDPPayload = max(MaxUDPPayloadIPv4, MaxUDPPayloadIPv6)
)
|
||||
|
||||
// UDPConn implements a UDP socket using the Windows RIO API extensions.
// It supports batched I/O, UDP RSC Offload (URO), and UDP Segmentation
// Offload (USO) to improve performance in high-throughput UDP workloads.
type UDPConn struct {
	config UDPConfig // effective configuration; immutable after ListenUDP returns

	*conn // the underlying socket connection with RIO extensions
	udpRx // receiving half-connection
	udpTx // transmitting half-connection
}
|
||||
|
||||
// ListenUDP listens for incoming UDP packets on the local address using
// the Registered Input/Output (RIO) API and supports URO and USO when
// available. It returns an error if RIO is not available.
//
// The network must be a UDP network name.
//
// If the IP field of addr is nil or an unspecified IP address,
// ListenUDP listens on all available IP addresses of the local system
// except multicast IP addresses. If the network is "udp" and the local
// IP is unspecified, ListenUDP listens on both IPv4 and IPv6 addresses.
//
// If the Port field of addr is 0, a port number is automatically
// chosen.
//
// The provided options are to configure various aspects of the connection,
// such as RIO buffer sizes, URO and USO parameters and other socket options.
func ListenUDP(network string, addr *net.UDPAddr, options ...UDPOption) (_ *UDPConn, err error) {
	// Wrap any failure in a *net.OpError, matching net.ListenUDP conventions.
	defer func() {
		if err != nil {
			err = &net.OpError{Op: "listen", Net: network, Addr: addr, Err: err}
		}
	}()

	if err := Initialize(); err != nil {
		return nil, err
	}

	// Apply caller-supplied options to the default configuration.
	// Nil options are permitted and skipped.
	udp := &UDPConn{}
	for _, o := range options {
		if o != nil {
			o.applyUDP(&udp.config)
		}
	}

	laddr, dualStack, err := addrPortFromUDPAddr(network, addr)
	if err != nil {
		return nil, err
	}

	// Create the underlying socket with Registered I/O extensions
	// and bind it to the local address.
	udp.conn, err = newConn(windows.SOCK_DGRAM, windows.IPPROTO_UDP,
		dualStack, laddr, &udp.config.Config)
	if err != nil {
		return nil, err
	}
	defer func() {
		// If initialization fails, close the connection to
		// release any allocated resources.
		if err != nil {
			udp.Close()
		}
	}()

	// Initialize the Rx and Tx halves of the connection,
	// which includes allocating memory for RIO buffers
	// and creating RIO completion queues for each half.
	if err := udp.udpRx.init(udp.conn, udp.config); err != nil {
		return nil, fmt.Errorf("failed to initialize Rx: %w", err)
	}
	if err := udp.udpTx.init(udp.conn, udp.config); err != nil {
		return nil, fmt.Errorf("failed to initialize Tx: %w", err)
	}
	// Create the RIO request queue for the connection and associate it
	// with the Rx and Tx completion queues.
	if err := udp.createRequestQueue(
		udp.udpRx.completionQueue(), udp.udpRx.maxOutstandingRequests(),
		udp.udpTx.completionQueue(), udp.udpTx.maxOutstandingRequests(),
	); err != nil {
		return nil, fmt.Errorf("failed to create RIO request queue: %w", err)
	}
	// Disable reporting of ICMP "Port Unreachable" errors as socket errors (golang/go#5834).
	// https://web.archive.org/web/20260208062329/https://support.microsoft.com/en-US/help/263823
	if err := WSAIoctlIn(udp, windows.SIO_UDP_CONNRESET, uint32(0)); err != nil {
		return nil, fmt.Errorf("failed to disable SIO_UDP_CONNRESET: %w", err)
	}
	// Post initial receive requests so packets arriving before the first
	// ReadBatch call are not dropped.
	if err := udp.udpRx.postReceiveRequests(); err != nil {
		return nil, fmt.Errorf("failed to post initial receive requests: %w", err)
	}
	return udp, nil
}
|
||||
|
||||
// Config returns the effective configuration of the connection.
|
||||
// The returned value is immutable for the lifetime of the connection.
|
||||
func (c *UDPConn) Config() *UDPConfig {
|
||||
return &c.config
|
||||
}
|
||||
|
||||
// SetDeadline implements [net.Conn.SetDeadline].
|
||||
func (c *UDPConn) SetDeadline(t time.Time) error {
|
||||
// TODO(nickkhyl): move this and the other deadline methods to the underlying [conn]?
|
||||
err1 := c.SetReadDeadline(t)
|
||||
err2 := c.SetWriteDeadline(t)
|
||||
return errors.Join(err1, err2)
|
||||
}
|
||||
|
||||
// SetReadDeadline implements [net.Conn.SetReadDeadline].
//
// It is not yet implemented and always returns an error wrapping
// [errors.ErrUnsupported].
func (c *UDPConn) SetReadDeadline(t time.Time) error {
	// TODO(nickkhyl): implement read and write deadlines
	return fmt.Errorf("%w: (%T).SetReadDeadline is not yet implemented", errors.ErrUnsupported, c)
}
|
||||
|
||||
// SetWriteDeadline implements [net.Conn.SetWriteDeadline].
//
// It is not yet implemented and always returns an error wrapping
// [errors.ErrUnsupported].
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
	// TODO(nickkhyl): implement read and write deadlines
	return fmt.Errorf("%w: (%T).SetWriteDeadline is not yet implemented", errors.ErrUnsupported, c)
}
|
||||
|
||||
// Close closes the connection, canceling any pending operations,
// and freeing all associated resources.
func (c *UDPConn) Close() error {
	if err := c.conn.Close(); err != nil {
		return err
	}
	// Close the Rx and Tx halves only after closing the underlying connection.
	// This ensures that all in-flight requests complete and that nothing uses
	// the RIO buffers or completion queues after they are closed.
	return errors.Join(c.udpRx.Close(), c.udpTx.Close())
}
|
||||
|
||||
// udpNx is a base struct for [udpRx] and [udpTx] half-connections
// that contains common state and logic.
type udpNx struct {
	conn *conn // underlying socket connection; set by init

	// mu protects the fields below and serializes access to the completion queue.
	// Lock order: udpNx.mu > conn.mu.
	mu                sync.Mutex
	requests          *requestRing   // ring of RIO request contexts for this half-connection
	cq                winrio.Cq      // completion queue associated with this half-connection
	hasCompletionsEvt windows.Handle // signaled by RIO when there are completions to dequeue.
	results           []winrio.Result // dequeued completion results
}
|
||||
|
||||
// init initializes the common state for [udpRx] or [udpTx].
// The conn parameter is the underlying connection associated with this half-connection.
// The dataSize parameter specifies the size of the data buffer for each request in the ring,
// and memoryLimit specifies the maximum total memory used by all requests.
func (nx *udpNx) init(conn *conn, dataSize uint16, memoryLimit uintptr) (err error) {
	// On any failure, release whatever was allocated so far.
	// Close tolerates partially initialized state (zero fields are skipped).
	defer func() {
		if err != nil {
			nx.Close()
		}
	}()
	if nx.requests, err = newRequestRing(dataSize, memoryLimit); err != nil {
		return fmt.Errorf("failed to create request ring: %w", err)
	}
	// Auto-reset, initially non-signaled event; RIO signals it when
	// completions are available on the queue.
	if nx.hasCompletionsEvt, err = windows.CreateEvent(nil, 0, 0, nil); err != nil {
		return fmt.Errorf("failed to create completion event: %w", err)
	}
	// Size the results buffer to hold a full ring's worth of completions.
	nx.results = make([]winrio.Result, 0, nx.requests.Cap())
	if nx.cq, err = winrio.CreateEventCompletionQueue(nx.requests.Cap(), nx.hasCompletionsEvt, true); err != nil {
		return fmt.Errorf("failed to create completion queue: %w", err)
	}
	// Set conn last, once initialization can no longer fail.
	nx.conn = conn
	return nil
}
|
||||
|
||||
// completionQueue returns the RIO completion queue used by
|
||||
// the half-connection for completion notifications.
|
||||
func (nx *udpNx) completionQueue() winrio.Cq {
|
||||
return nx.cq
|
||||
}
|
||||
|
||||
// maxOutstandingRequests returns the maximum number of in-flight
|
||||
// requests the half-connection can post to the RIO request queue
|
||||
// without blocking.
|
||||
func (nx *udpNx) maxOutstandingRequests() uint32 {
|
||||
return nx.requests.Cap()
|
||||
}
|
||||
|
||||
// Close releases all resources associated with the half-connection.
// It must not be called until the connection using this
// half-connection's buffers and completion queue is closed.
//
// Close is safe to call on a partially initialized udpNx: each
// resource is released only if it was created, and the field is
// zeroed afterwards so a second Close is a no-op.
func (nx *udpNx) Close() error {
	nx.mu.Lock()
	defer nx.mu.Unlock()
	if nx.cq != 0 {
		winrio.CloseCompletionQueue(nx.cq)
		nx.cq = 0
	}
	if nx.hasCompletionsEvt != 0 {
		windows.CloseHandle(nx.hasCompletionsEvt)
		nx.hasCompletionsEvt = 0
	}
	if nx.requests != nil {
		// Releasing the ring can fail (it owns registered memory);
		// keep nx.requests set so a retry is possible.
		if err := nx.requests.Close(); err != nil {
			return err
		}
		nx.requests = nil
	}
	return nil
}
|
||||
430
net/rioconn/udp_test.go
Normal file
430
net/rioconn/udp_test.go
Normal file
@ -0,0 +1,430 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"tailscale.com/net/batching"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/rioconn"
|
||||
"tailscale.com/types/nettype"
|
||||
)
|
||||
|
||||
// [UDPConn] implements the following interfaces.
// These compile-time assertions fail the build if an interface
// method is dropped or its signature drifts.
var (
	_ batching.Conn      = (*rioconn.UDPConn)(nil)
	_ net.PacketConn     = (*rioconn.UDPConn)(nil)
	_ nettype.PacketConn = (*rioconn.UDPConn)(nil)
	_ syscall.Conn       = (*rioconn.UDPConn)(nil)
)
|
||||
|
||||
// TestListenUDP verifies that ListenUDP binds to the expected local
// address for each combination of network name and address string,
// and that dual-stack mode is enabled exactly when the network is
// "udp" with an unspecified local IP.
func TestListenUDP(t *testing.T) {
	tests := []struct {
		network           string
		address           string
		wantLocalAddrPort netip.AddrPort // port 0 means "any port is acceptable"
		wantDualStack     bool
	}{
		{
			network: "udp", address: "127.0.0.1:0", wantLocalAddrPort: netip.MustParseAddrPort("127.0.0.1:0"),
		},
		{
			network: "udp4", address: "127.0.0.1:0", wantLocalAddrPort: netip.MustParseAddrPort("127.0.0.1:0"),
		},
		{
			network: "udp", address: "[::1]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::1]:0"),
		},
		{
			network: "udp6", address: "[::1]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::1]:0"),
		},
		{
			network: "udp", address: "0.0.0.0:0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"),
		},
		{
			network: "udp4", address: "0.0.0.0:0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"),
		},
		{
			network: "udp", address: "[::]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"),
		},
		{
			network: "udp6", address: "[::]:0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"),
		},
		{
			network: "udp", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"), wantDualStack: true,
		},
		{
			network: "udp4", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:0"),
		},
		{
			network: "udp6", address: ":0", wantLocalAddrPort: netip.MustParseAddrPort("[::]:0"),
		},
		{
			network: "udp", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("[::]:41613"), wantDualStack: true,
		},
		{
			network: "udp4", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("0.0.0.0:41613"), wantDualStack: false,
		},
		{
			network: "udp6", address: ":41613", wantLocalAddrPort: netip.MustParseAddrPort("[::]:41613"), wantDualStack: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.network+"/"+tt.address, func(t *testing.T) {
			addr, err := net.ResolveUDPAddr(tt.network, tt.address)
			if err != nil {
				t.Fatalf("ResolveUDPAddr(%q, %q) error: %v", tt.network, tt.address, err)
			}

			conn, err := rioconn.ListenUDP(tt.network, addr)
			if err != nil {
				t.Fatalf("ListenUDP(%q, %q) error: %v", tt.network, tt.address, err)
			}
			t.Cleanup(func() {
				err := conn.Close()
				if err != nil {
					t.Errorf("Close() error: %v", err)
				}
			})

			gotAddressPort := conn.LocalAddrPort()
			if wantAddress := tt.wantLocalAddrPort.Addr(); gotAddressPort.Addr().Compare(wantAddress) != 0 {
				t.Errorf("LocalAddrPort() Addr = %v; want %v", gotAddressPort.Addr(), tt.wantLocalAddrPort.Addr())
			}
			// Requested port 0 means the OS picks one; only check explicit ports.
			if wantPort := tt.wantLocalAddrPort.Port(); wantPort != 0 && gotAddressPort.Port() != wantPort {
				t.Errorf("LocalAddrPort() Port = %v; want %v", gotAddressPort.Port(), wantPort)
			}
			if gotDualStack := conn.IsDualStack(); gotDualStack != tt.wantDualStack {
				t.Errorf("IsDualStack() = %v; want %v", gotDualStack, tt.wantDualStack)
			}
		})
	}
}
|
||||
|
||||
func TestUDPSendReceiveBatch(t *testing.T) {
|
||||
const defaultBatchSize = 64
|
||||
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
pattern []int
|
||||
iterations int
|
||||
sendBatchSize int
|
||||
receiveBatchSize int
|
||||
}{
|
||||
{
|
||||
name: "udp4/single",
|
||||
network: "udp4",
|
||||
pattern: []int{1312},
|
||||
},
|
||||
{
|
||||
name: "udp4/single/max",
|
||||
network: "udp4",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv4},
|
||||
},
|
||||
{
|
||||
name: "udp4/batch/max",
|
||||
network: "udp4",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv4},
|
||||
},
|
||||
{
|
||||
name: "udp6/single",
|
||||
network: "udp6",
|
||||
pattern: []int{1312},
|
||||
},
|
||||
{
|
||||
name: "udp6/single/max",
|
||||
network: "udp6",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv6},
|
||||
},
|
||||
{
|
||||
name: "udp6/batch/max",
|
||||
network: "udp6",
|
||||
pattern: []int{rioconn.MaxUDPPayloadIPv6},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
udpSendReceiveBatch(t,
|
||||
tt.pattern, max(1, tt.iterations),
|
||||
cmp.Or(tt.sendBatchSize, defaultBatchSize),
|
||||
cmp.Or(tt.receiveBatchSize, defaultBatchSize),
|
||||
tt.network, tt.network,
|
||||
nil, nil,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FuzzUDPSendReceiveBatch fuzzes the batched send/receive path over IPv4
// loopback with fuzzer-chosen packet length, iteration count, and batch
// sizes, using deliberately small Rx/Tx memory limits to exercise ring
// wraparound and backpressure.
func FuzzUDPSendReceiveBatch(f *testing.F) {
	// Seed corpus: a small cross-product of batch and packet sizes.
	batchSizes := []uint16{1, 64}
	packetSizes := []uint16{0, 1, 64, 1312, 9000, rioconn.MaxUDPPayloadIPv4}
	numIterations := []uint16{1024}

	for _, packetLen := range packetSizes {
		for _, numIter := range numIterations {
			for _, batchSize := range batchSizes {
				f.Add(packetLen, numIter, batchSize, batchSize)
			}
		}
	}

	f.Fuzz(func(t *testing.T, packetLen, numIterations, sendBatchSize, receiveBatchSize uint16) {
		network := "udp4"
		maxPacketLen := uint16(rioconn.MaxUDPPayloadIPv4)

		// Skip (rather than fail) inputs outside the supported ranges so
		// the fuzzer can explore freely.
		if packetLen > maxPacketLen {
			t.Skipf("packetLen is too large: %d", packetLen)
		}
		if numIterations > 10_000 {
			t.Skipf("numIterations is too large: %d", numIterations)
		}
		if sendBatchSize == 0 || sendBatchSize > 1024 {
			t.Skipf("sendBatchSize is out of range: %d", sendBatchSize)
		}
		if receiveBatchSize == 0 || receiveBatchSize > 1024 {
			t.Skipf("receiveBatchSize is out of range: %d", receiveBatchSize)
		}

		packetLengthPattern := []int{int(packetLen)}
		// Asymmetric memory limits: the sender's Tx is larger than the
		// receiver's Rx in one direction and vice versa.
		udpSendReceiveBatch(t, packetLengthPattern, int(numIterations),
			int(sendBatchSize), int(receiveBatchSize), network, network,
			[]rioconn.UDPOption{
				rioconn.RxMemoryLimit(128 << 10),
				rioconn.TxMemoryLimit(512 << 10),
			},
			[]rioconn.UDPOption{
				rioconn.RxMemoryLimit(512 << 10),
				rioconn.TxMemoryLimit(128 << 10),
			},
		)
	})
}
|
||||
|
||||
// udpSendReceive sends and receives batches of UDP packets between two
|
||||
// [rioconn.UDPConn] instances over the loopback interface.
|
||||
//
|
||||
// It uses the provided packet length pattern, iteration count,
|
||||
// batch sizes, networks, and connection options.
|
||||
func udpSendReceiveBatch(
|
||||
tb testing.TB,
|
||||
packetLengthPattern []int,
|
||||
numIterations int,
|
||||
sendBatchSize, receiveBatchSize int,
|
||||
senderNetwork, receiverNetwork string,
|
||||
senderOpts, receiverOpts []rioconn.UDPOption,
|
||||
) {
|
||||
stopMsg := []byte("STOP")
|
||||
|
||||
sender, err := rioconn.ListenUDP(senderNetwork, loopbackUDPAddr(senderNetwork, 0), senderOpts...)
|
||||
if err != nil {
|
||||
tb.Fatalf("ListenUDP(%s, nil) error: %v", senderNetwork, err)
|
||||
}
|
||||
defer sender.Close()
|
||||
|
||||
receiver, err := rioconn.ListenUDP(receiverNetwork, loopbackUDPAddr(receiverNetwork, 0), receiverOpts...)
|
||||
if err != nil {
|
||||
tb.Fatalf("ListenUDP(%s, nil) error: %v", receiverNetwork, err)
|
||||
}
|
||||
defer receiver.Close()
|
||||
|
||||
// Do not allocate buffers larger than needed for the test.
|
||||
maxPacketLen := max(len(stopMsg), slices.Max(packetLengthPattern))
|
||||
|
||||
outBuffs := make([][]byte, sendBatchSize)
|
||||
for i := range outBuffs {
|
||||
outBuffs[i] = make([]byte, maxPacketLen)
|
||||
}
|
||||
|
||||
inMsgs := make([]ipv6.Message, receiveBatchSize)
|
||||
for i := range inMsgs {
|
||||
inMsgs[i].Buffers = make([][]byte, 1)
|
||||
inMsgs[i].Buffers[0] = make([]byte, maxPacketLen)
|
||||
}
|
||||
|
||||
readerResult := make(chan error, 1)
|
||||
writerResult := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer close(writerResult)
|
||||
|
||||
dstAddr := receiver.LocalAddrPort()
|
||||
|
||||
bytes := 0
|
||||
packets := 0
|
||||
iteration := 0
|
||||
for iteration < numIterations {
|
||||
outBuffs := outBuffs[:cap(outBuffs)]
|
||||
for k := range outBuffs {
|
||||
packetLen := packetLengthPattern[packets%len(packetLengthPattern)]
|
||||
out := outBuffs[k][:packetLen]
|
||||
outBuffs[k] = out
|
||||
for j := 0; j < packetLen; j++ {
|
||||
out[j] = byte('A' + bytes%26)
|
||||
bytes++
|
||||
}
|
||||
packets++
|
||||
if packets%len(packetLengthPattern) == 0 {
|
||||
iteration++
|
||||
}
|
||||
if iteration >= numIterations {
|
||||
outBuffs = outBuffs[:k+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := sender.WriteBatchTo(outBuffs, dstAddr, packet.GeneveHeader{}, 0); err != nil {
|
||||
writerResult <- fmt.Errorf("failed to send batch #%d: %w", iteration, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tb.Logf("Writer done sending %d packets and %d bytes in %d iterations", packets, bytes, iteration)
|
||||
tb.Logf("Sending STOP messages to signal the reader to stop")
|
||||
for {
|
||||
select {
|
||||
case <-readerResult:
|
||||
tb.Logf("Reader has stopped, no need to send more STOP messages")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if _, err := sender.WriteTo(stopMsg, net.UDPAddrFromAddrPort(dstAddr)); err != nil {
|
||||
writerResult <- fmt.Errorf("failed to send a STOP message: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer close(readerResult)
|
||||
|
||||
bytesReceived := 0
|
||||
for {
|
||||
n, err := receiver.ReadBatch(inMsgs, 0)
|
||||
if err != nil {
|
||||
readerResult <- fmt.Errorf("ReadBatch() error: %w", err)
|
||||
return
|
||||
}
|
||||
for i := range n {
|
||||
msg := inMsgs[i]
|
||||
if bytes.Equal(msg.Buffers[0][:msg.N], stopMsg) {
|
||||
tb.Logf("Received a STOP message, reader is stopping")
|
||||
return
|
||||
}
|
||||
for j := 0; j < msg.N; j++ {
|
||||
expectedByte := byte('A' + bytesReceived%26)
|
||||
if msg.Buffers[0][j] != expectedByte {
|
||||
readerResult <- fmt.Errorf("unexpected byte at position %d: got %v, want %v",
|
||||
bytesReceived, msg.Buffers[0][j], expectedByte)
|
||||
return
|
||||
}
|
||||
bytesReceived++
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := <-writerResult; err != nil {
|
||||
tb.Fatalf("writer error: %v", err)
|
||||
}
|
||||
if err := <-readerResult; err != nil {
|
||||
tb.Fatalf("reader error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPReadWrite(t *testing.T) {
|
||||
sender, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP: %v", err)
|
||||
}
|
||||
defer sender.Close()
|
||||
|
||||
receiver, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP: %v", err)
|
||||
}
|
||||
defer receiver.Close()
|
||||
|
||||
message := []byte("Hello, world!")
|
||||
|
||||
n, err := sender.WriteTo(message, net.UDPAddrFromAddrPort(receiver.LocalAddrPort()))
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTo: %v", err)
|
||||
}
|
||||
if n != len(message) {
|
||||
t.Fatalf("WriteTo: wrote %d bytes, want %d", n, len(message))
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, addr, err := receiver.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom: %v", err)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], message) {
|
||||
t.Fatalf("ReadFrom: got %q, want %q", buf[:n], message)
|
||||
}
|
||||
if addr.String() != net.UDPAddrFromAddrPort(sender.LocalAddrPort()).String() {
|
||||
t.Fatalf("ReadFrom: got addr %v, want %v", addr, sender.LocalAddrPort())
|
||||
}
|
||||
}
|
||||
|
||||
// TestUDPReadFromUDPAddrPort exercises the netip.AddrPort-based
// WriteToUDPAddrPort/ReadFromUDPAddrPort round trip between two UDPConn
// instances over IPv4 loopback.
func TestUDPReadFromUDPAddrPort(t *testing.T) {
	sender, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
	if err != nil {
		t.Fatalf("ListenUDP: %v", err)
	}
	defer sender.Close()

	receiver, err := rioconn.ListenUDP("udp4", loopbackUDPAddr("udp4", 0))
	if err != nil {
		t.Fatalf("ListenUDP: %v", err)
	}
	defer receiver.Close()

	message := []byte("Hello, world!")

	n, err := sender.WriteToUDPAddrPort(message, receiver.LocalAddrPort())
	if err != nil {
		t.Fatalf("WriteToUDPAddrPort: %v", err)
	}
	if n != len(message) {
		t.Fatalf("WriteToUDPAddrPort: wrote %d bytes, want %d", n, len(message))
	}

	buf := make([]byte, 1024)
	n, addr, err := receiver.ReadFromUDPAddrPort(buf)
	if err != nil {
		t.Fatalf("ReadFromUDPAddrPort: %v", err)
	}
	if !bytes.Equal(buf[:n], message) {
		t.Fatalf("ReadFromUDPAddrPort: got %q, want %q", buf[:n], message)
	}
	// netip.AddrPort is comparable, so direct == works here.
	if addr != sender.LocalAddrPort() {
		t.Fatalf("ReadFromUDPAddrPort: got addr %v, want %v", addr, sender.LocalAddrPort())
	}
}
|
||||
|
||||
func loopbackUDPAddr(network string, port int) *net.UDPAddr {
|
||||
switch network {
|
||||
case "udp4":
|
||||
return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}
|
||||
case "udp6":
|
||||
return &net.UDPAddr{IP: net.IPv6loopback, Port: port}
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported network: %s", network))
|
||||
}
|
||||
}
|
||||
204
net/rioconn/udprx.go
Normal file
204
net/rioconn/udprx.go
Normal file
@ -0,0 +1,204 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// udpRx is the receive half of [UDPConn].
//
// Its exported methods are safe for concurrent use.
// The caller must ensure that the connection is not closed
// while any unexported methods are in flight, unless
// otherwise specified by the method.
type udpRx struct {
	udpNx
	// pendingResultIdx is the index in [udpNx.results]
	// of the next pending result to process.
	// Results before this index have already been consumed
	// (and their requests released).
	pendingResultIdx int
}
|
||||
|
||||
// init initializes the receive half of a [UDPConn] with the
// specified underlying connection and options.
func (rx *udpRx) init(conn *conn, options UDPConfig) error {
	// Without URO, the data buffer for each receive request only needs
	// to hold a single packet's payload, capped at the largest UDP
	// payload possible across IP versions.
	dataSize := min(options.Rx().MaxPayloadLen(), MaxUDPPayload)
	if err := rx.udpNx.init(conn, dataSize, options.Rx().MemoryLimit()); err != nil {
		return fmt.Errorf("failed to initialize udpRx: %w", err)
	}
	return nil
}
|
||||
|
||||
// ReadBatch implements [batching.Conn] by reading messages into msgs.
// It returns the number of messages the caller should evaluate for nonzero len,
// as a zero len message may fall on either side of a nonzero.
// The flags parameter is reserved for future use and must be zero.
func (rx *udpRx) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
	// Prevent the connection from closing while in use.
	if err := rx.conn.acquire(); err != nil {
		return 0, &net.OpError{Op: "read", Net: rx.conn.Network(), Source: rx.conn.LocalAddr(), Err: err}
	}
	defer rx.conn.release()

	// Serialize readers: only one goroutine may dequeue from the
	// completion queue at a time.
	rx.mu.Lock()
	defer rx.mu.Unlock()
	// Keep trying to read until we get at least one message or an error.
	// A pass can legitimately yield zero messages (e.g. every completed
	// request was an oversized packet that got skipped).
	for n == 0 && err == nil {
		if err := rx.awaitCompletionsLocked(); err != nil {
			return 0, err
		}
		n, err = rx.processCompletionsLocked(msgs)
	}
	// Always try to post more receive requests, even if an error
	// occurred while processing completed ones.
	if postErr := rx.postReceiveRequestsLocked(); postErr != nil {
		err = errors.Join(err, postErr)
	}
	if err != nil {
		err = &net.OpError{Op: "read", Net: rx.conn.Network(), Source: rx.conn.LocalAddr(), Err: err}
	}
	return n, err
}
|
||||
|
||||
// ReadFromUDPAddrPort implements [nettype.PacketConn.ReadFromUDPAddrPort].
|
||||
func (rx *udpRx) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
|
||||
n, netAddr, err := rx.ReadFrom(p)
|
||||
if netAddr != nil {
|
||||
addr = netAddr.(*net.UDPAddr).AddrPort()
|
||||
}
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
// ReadFrom implements [net.PacketConn.ReadFrom].
|
||||
func (rx *udpRx) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
msgs := []ipv6.Message{{
|
||||
Buffers: [][]byte{p},
|
||||
}}
|
||||
numMsgs, err := rx.ReadBatch(msgs, 0)
|
||||
if numMsgs != 0 {
|
||||
n = msgs[0].N
|
||||
addr = msgs[0].Addr
|
||||
}
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
// postReceiveRequests posts available receive requests to the
// RIO request queue. The caller must ensure that the connection
// is not closed until this call returns.
//
// Unlike postReceiveRequestsLocked, it acquires rx.mu itself.
func (rx *udpRx) postReceiveRequests() error {
	rx.mu.Lock()
	defer rx.mu.Unlock()
	return rx.postReceiveRequestsLocked()
}
|
||||
|
||||
// postReceiveRequestsLocked posts all available receive requests
// to the RIO request queue, acquiring request contexts from the
// ring via AcquireSeq.
// rx.mu must be held.
func (rx *udpRx) postReceiveRequestsLocked() (err error) {
	return rx.conn.postReceiveRequests(rx.requests.AcquireSeq())
}
|
||||
|
||||
// awaitCompletionsLocked dequeues completed receive requests, returning when
// there's at least one completion to process, the connection is closed,
// or an error occurs.
// rx.mu must be held.
func (rx *udpRx) awaitCompletionsLocked() error {
	if rx.pendingResultIdx < len(rx.results) {
		// We have already dequeued some completions that haven't been
		// fully processed yet. Return immediately.
		return nil
	}

	// All previous results were consumed; reuse the full backing array
	// for the next dequeue.
	rx.results = rx.results[:cap(rx.results)]
	rx.pendingResultIdx = 0

	var count uint32
	for {
		if count = winrio.DequeueCompletion(rx.cq, rx.results[:]); count != 0 {
			// Got new completions to process, no need to wait.
			break
		}
		// Otherwise, arm the notification...
		// (RIO requires re-arming via Notify before each wait.)
		if err := winrio.Notify(rx.cq); err != nil {
			return err
		}
		// ...and wait until RIO signals that more completions are available
		// or the connection is closed.
		handles := []windows.Handle{rx.conn.closedEvt, rx.hasCompletionsEvt}
		switch evtIdx, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE); {
		case err != nil:
			return fmt.Errorf("waiting for completed receives failed: %w", err)
		case evtIdx == 0:
			// The connection's closed event fired.
			return net.ErrClosed
		case evtIdx == 1:
			continue // try dequeueing completions again
		default:
			panic("unreachable")
		}
	}
	// Shrink results to exactly the completions dequeued this round.
	rx.results = rx.results[:count]
	return nil
}
|
||||
|
||||
// processCompletionsLocked processes completed receive requests and fills msgs
// with the received packets. It returns the number of messages the caller
// should evaluate for nonzero len, as a zero len message may fall on either
// side of a nonzero.
// rx.mu must be held.
func (rx *udpRx) processCompletionsLocked(msgs []ipv6.Message) (n int, err error) {
	firstResultIdx := rx.pendingResultIdx

	defer func() {
		// Always release processed results, even if an error occurred.
		// Each advance of pendingResultIdx corresponds to exactly one
		// consumed ring request.
		rx.requests.ReleaseN(rx.pendingResultIdx - firstResultIdx)
	}()

	for rx.pendingResultIdx < len(rx.results) && n < len(msgs) {
		res := &rx.results[rx.pendingResultIdx]
		// RequestContext carries the request pointer registered when the
		// receive was posted.
		req := (*request)(unsafe.Pointer(uintptr(res.RequestContext)))
		r, err := req.CompleteReceive(res.Status, res.BytesTransferred)
		if err != nil {
			rx.pendingResultIdx++
			if err == windows.WSAEMSGSIZE {
				// The packet is larger than [RxConfig.MaxPayloadLen].
				// Skip it and try to process the next one, if any.
				continue
			}
			// In case of other errors, skip the packet and return
			// the error to the caller.
			return n, err
		}
		// TODO(nickkhyl): Maintain an LRU cache of remote addresses to
		// avoid allocating a new [netip.AddrPort] / [net.UDPAddr] for each packet.
		// Profiling suggests this accounts for ~5% of total processing time.
		udpAddr, err := r.RemoteAddr().ToUDPAddr()
		if err != nil {
			return n, fmt.Errorf("invalid remote address: %w", err)
		}

		if r.Len() <= len(msgs[n].Buffers[0]) {
			// TODO(nickkhyl): avoid the copy? We could transfer ownership of the underlying
			// buffer to the reader until the next read or an explicit release.
			msgs[n].N = copy(msgs[n].Buffers[0], r.Bytes())
		} else {
			msgs[n].N = 0 // packet is too large; ignore it
		}
		msgs[n].Addr = udpAddr
		rx.pendingResultIdx++
		n++
	}
	return n, nil
}
|
||||
191
net/rioconn/udptx.go
Normal file
191
net/rioconn/udptx.go
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/sys/windows"
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
// udpTx is the transmit half of [UDPConn].
//
// Its exported methods are safe for concurrent use.
// The caller must ensure that the connection is not closed
// while any unexported methods are in flight, unless
// otherwise specified by the method.
type udpTx struct {
	// udpNx supplies the shared per-direction state used by the methods
	// below (tx.mu, tx.conn, tx.requests, tx.cq, tx.results,
	// tx.hasCompletionsEvt). NOTE(review): inferred from field usage in
	// this file — confirm against the udpNx declaration.
	udpNx
}
|
||||
|
||||
// init initializes the transmit half of a [UDPConn] with the
|
||||
// specified underlying connection and options.
|
||||
func (tx *udpTx) init(conn *conn, options UDPConfig) error {
|
||||
// Without USO, the data buffer for each send request only needs to hold
|
||||
// a single packet's payload.
|
||||
dataSize := min(options.Tx().MaxPayloadLen(), MaxUDPPayload)
|
||||
if err := tx.udpNx.init(conn, dataSize, options.Tx().MemoryLimit()); err != nil {
|
||||
return fmt.Errorf("failed to initialize udpTx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteBatchTo implements [batching.Conn.WriteBatchTo] by writing
|
||||
// buffs to the specified remote address.
|
||||
//
|
||||
// If geneve.VNI.IsSet(), then geneve is encoded into the space preceding
|
||||
// offset, and offset must equal [packet.GeneveFixedHeaderLength].
|
||||
// Otherwise, the space preceding offset is ignored.
|
||||
func (tx *udpTx) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
|
||||
if tx.conn.IsDualStack() && addr.Addr().Is4() {
|
||||
// Convert to an IPv4-mapped IPv6 address
|
||||
addr = netip.AddrPortFrom(netip.AddrFrom16(addr.Addr().As16()), addr.Port())
|
||||
}
|
||||
if err := tx.writeBatchTo(buffs, addr, geneve, offset); err != nil {
|
||||
return &net.OpError{Op: "write", Net: tx.conn.Network(), Source: tx.conn.LocalAddr(), Addr: net.UDPAddrFromAddrPort(addr), Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeBatchTo implements [udpTx.WriteBatchTo]. It returns an
// error if the connection is already closed and prevents the
// connection from closing until it returns.
func (tx *udpTx) writeBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) (err error) {
	if len(buffs) == 0 {
		return nil
	}

	// Convert the destination once; every packet in the batch goes to
	// the same remote address.
	raddr, err := rawSockaddrFromAddrPort(addr)
	if err != nil {
		return fmt.Errorf("failed to convert address: %w", err)
	}

	// Prevent the connection from closing while in use.
	if err := tx.conn.acquire(); err != nil {
		return err
	}
	defer tx.conn.release()

	tx.mu.Lock()
	defer tx.mu.Unlock()

	// Sends are posted with winrio.MsgDefer below; commit them all at once
	// on exit if at least one was posted. A commit failure is joined with
	// (not replacing) any error already being returned.
	n := 0
	defer func() {
		if n != 0 {
			if commitErr := tx.conn.commitSendRequests(); commitErr != nil {
				err = errors.Join(err, commitErr)
			}
		}
	}()

	for n < len(buffs) {
		if tx.conn.IsClosed() {
			return net.ErrClosed
		}
		// Reap completed sends so the ring has a free slot to post into.
		if err := tx.drainCompletionsLocked(); err != nil {
			return err
		}

		req := tx.requests.Peek()
		w := req.Writer()
		w.SetRemoteAddr(raddr)

		if geneve.VNI.IsSet() {
			// Per the contract above, offset equals
			// [packet.GeneveFixedHeaderLength] in this case, so the header
			// occupies exactly the space preceding the payload.
			geneveHeader := w.Reserve(packet.GeneveFixedHeaderLength)
			geneve.Encode(geneveHeader[:])
		}
		if _, err := w.Write(buffs[n][offset:]); err != nil {
			return err
		}

		// Assigns the named result so the deferred commit sees the error.
		if err = tx.conn.postSendRequest(req, winrio.MsgDefer); err != nil {
			return fmt.Errorf("failed to post send request: %w", err)
		}

		tx.requests.Advance() // advance after posting the request
		n++
	}
	return nil
}
|
||||
|
||||
// WriteToUDPAddrPort implements [nettype.PacketConn.WriteToUDPAddrPort].
|
||||
func (tx *udpTx) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (n int, err error) {
|
||||
if err := tx.WriteBatchTo([][]byte{p}, addr, packet.GeneveHeader{}, 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// WriteTo implements [net.PacketConn.WriteTo].
|
||||
func (tx *udpTx) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: tx.conn.Network(),
|
||||
Source: tx.conn.LocalAddr(),
|
||||
Addr: addr,
|
||||
Err: net.InvalidAddrError("address is not a *net.UDPAddr"),
|
||||
}
|
||||
}
|
||||
return tx.WriteToUDPAddrPort(p, udpAddr.AddrPort())
|
||||
}
|
||||
|
||||
// drainCompletionsLocked dequeues and processes completed send requests
|
||||
// until the request ring is not full (i.e., more requests can be posted)
|
||||
// or the closedEvt is signaled.
|
||||
//
|
||||
// tx.mu must be held, and the caller must ensure that the connection
|
||||
// is not closed until this call returns.
|
||||
func (tx *udpTx) drainCompletionsLocked() error {
|
||||
var count uint32
|
||||
for {
|
||||
if count = winrio.DequeueCompletion(tx.cq, tx.results[:cap(tx.results)]); count != 0 {
|
||||
// Got new completions to process, no need to wait.
|
||||
break
|
||||
}
|
||||
if !tx.requests.IsFull() {
|
||||
// No completions to process, but also not all requests are in-flight,
|
||||
// so no need to wait.
|
||||
break
|
||||
}
|
||||
// Otherwise, if all requests are in flight, commit any deferred sends.
|
||||
tx.conn.commitSendRequests()
|
||||
// Then arm the notification...
|
||||
if err := winrio.Notify(tx.cq); err != nil {
|
||||
return err
|
||||
}
|
||||
// ...and wait for either RIO to signal that more completions are available,
|
||||
// or the connection to be closed.
|
||||
handles := []windows.Handle{tx.conn.closedEvt, tx.hasCompletionsEvt}
|
||||
switch evtIdx, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE); {
|
||||
case err != nil:
|
||||
return fmt.Errorf("waiting for completed sends failed: %w", err)
|
||||
case evtIdx == 0:
|
||||
return net.ErrClosed
|
||||
case evtIdx == 1:
|
||||
continue // try dequeueing completions again
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
for _, res := range tx.results[:count] {
|
||||
req := (*request)(unsafe.Pointer(uintptr(res.RequestContext)))
|
||||
if err := req.CompleteSend(res.Status, res.BytesTransferred); err != nil {
|
||||
// TODO(nickkhyl): Returning an error here does not make much sense.
|
||||
// Increment a send error metric or log the error instead?
|
||||
}
|
||||
}
|
||||
tx.results = tx.results[:0]
|
||||
tx.requests.ReleaseN(int(count))
|
||||
return nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user