net/rioconn: add base, protocol-agnostic RIO-enabled connection type

We'd like to implement this as its own distinct type early on. That decouples protocol-specific
code from the UDPConn implementation in case we want to add a TCPConn later.

More importantly, it lets UDPConn's Rx/Tx half-conns reference it while keeping the locking
hierarchy sane, since RIO's request queue isn't safe for concurrent use.

Updates tailscale/corp#8610

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2026-02-19 08:33:11 -06:00
parent d57b58193a
commit 52ab2b1894
No known key found for this signature in database
5 changed files with 833 additions and 0 deletions

28
net/rioconn/config.go Normal file
View File

@ -0,0 +1,28 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package rioconn
import (
"errors"
"syscall"
)
// Config holds configuration for a RIO connection, independent of the transport protocol.
type Config struct {
control []func(network, address string, c syscall.RawConn) error
}
// Control invokes all control functions in the Config with the given
// network, address, and connection. A failure of one control function
// does not prevent the others from running. It returns an error if any
// control function fails.
func (c Config) Control(network string, address string, conn syscall.RawConn) error {
var err []error
for _, control := range c.control {
if e := control(network, address, conn); e != nil {
err = append(err, e)
}
}
return errors.Join(err...)
}

365
net/rioconn/conn.go Normal file
View File

@ -0,0 +1,365 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"errors"
"fmt"
"iter"
"net"
"net/netip"
"sync"
"syscall"
"github.com/tailscale/wireguard-go/conn/winrio"
"golang.org/x/sys/windows"
)
// conn is a protocol-agnostic base connection with RIO support.
//
// Its exported methods are safe for concurrent use, including
// concurrent calls with each other, with Close, and after Close.
//
// However, the caller must call [conn.acquire] before invoking
// any unexported methods and [conn.release] when done to prevent
// the connection from closing while the operation is in flight.
type conn struct {
// immutable once [newConn] returns:
family int32
localAddr net.Addr
localAddrPort netip.AddrPort
dualStack bool
sotype int32
proto int32
net string
config *Config
// guard prevents the connection from closing and its resources from
// being freed while operations are in flight. All fields below are
// protected by guard and must not be accessed after the connection
// is closed and [guard.Acquire] returns false, except by [conn.Close].
guard *guard
closedEvt windows.Handle
socket windows.Handle
// closeMu serializes calls to [conn.Close].
// Lock order: closeMu > mu.
closeMu sync.Mutex
// mu serializes access to the RIO request queue.
// Lock order: closeMu > mu.
mu sync.Mutex
rq winrio.Rq
}
// rawConn implements [syscall.RawConn] for [conn].
type rawConn conn
var (
_ syscall.Conn = (*conn)(nil)
_ syscall.RawConn = (*rawConn)(nil)
)
func newConn(sotype int32, proto int32, dualStack bool, laddr netip.AddrPort, config *Config) (_ *conn, err error) {
if config == nil {
config = &Config{}
}
sa, family, err := sockaddrFromAddrPort(laddr)
if err != nil {
return nil, err
}
net, err := networkName(sotype, proto, family, dualStack)
if err != nil {
return nil, err
}
conn := &conn{
family: family,
dualStack: dualStack,
sotype: sotype,
proto: proto,
net: net,
config: config,
guard: newGuard(),
}
defer func() {
// If initialization fails, close the connection to release
// any resources allocated before the error.
if err != nil {
conn.Close()
}
}()
// Create a manual-reset event to wake up pending operations on Close.
if conn.closedEvt, err = windows.CreateEvent(nil, 1, 0, nil); err != nil {
return nil, fmt.Errorf("failed to create close notification event: %w", err)
}
// Create a socket with the WSA_FLAG_REGISTERED_IO flag set.
if conn.socket, err = rioSocket(family, sotype, proto); err != nil {
return nil, fmt.Errorf("failed to create socket(%d, %d, %d): %w", family, sotype, proto, err)
}
// Enable dual-stack mode by clearing the IPV6_V6ONLY option, if necessary.
// https://web.archive.org/web/20260208062136/https://learn.microsoft.com/en-us/windows/win32/winsock/dual-stack-sockets
if dualStack {
if err := windows.SetsockoptInt(conn.socket, windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, 0); err != nil {
return nil, fmt.Errorf("failed to enable dual-stack mode: %w", err)
}
}
// Invoke caller-provided control functions to set socket options before binding.
if err := conn.config.Control(net, laddr.String(), (*rawConn)(conn)); err != nil {
return nil, fmt.Errorf("control failed: %w", err)
}
if err := windows.Bind(conn.socket, sa); err != nil {
return nil, fmt.Errorf("failed to bind socket: %w", err)
}
// Record the local address from the actual socket, since the caller
// may have specified port 0 for automatic assignment.
if conn.localAddrPort, err = addrPortFromSocket(conn.socket); err != nil {
return nil, fmt.Errorf("failed to get local address and port: %w", err)
}
if conn.localAddr, err = netAddrFromAddrPort(conn.localAddrPort, sotype); err != nil {
return nil, fmt.Errorf("failed to convert local address and port to net.Addr: %w", err)
}
return conn, nil
}
// IsClosed reports whether Close has been called.
func (c *conn) IsClosed() bool {
return c.guard.IsClosed()
}
// acquire increments the connection's reference count, preventing
// it from closing. If it returns no error, the caller may use the
// connection and must call [conn.release] when done. Otherwise,
// the connection must not be used.
func (c *conn) acquire() error {
if !c.guard.Acquire() {
return net.ErrClosed
}
return nil
}
// release decrements the connection's reference count.
// Calling release without a matching acquire is a run-time error.
func (c *conn) release() {
c.guard.Release()
}
// Family returns the socket address family of the connection.
func (c *conn) Family() int32 {
return c.family
}
// Network returns the network name of the connection.
func (c *conn) Network() string {
return c.net
}
// IsDualStack reports whether the connection is dual-stack and can send
// and receive packets to and from IPv6 or IPv4-mapped IPv6 addresses.
func (c *conn) IsDualStack() bool {
return c.dualStack
}
// LocalAddr returns the local network address.
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
// LocalAddrPort returns the local network address and port.
func (c *conn) LocalAddrPort() netip.AddrPort {
return c.localAddrPort
}
// SyscallConn returns a raw network connection, or an error if the connection is closed.
func (c *conn) SyscallConn() (syscall.RawConn, error) {
if c.IsClosed() {
// Return the error immediately if the connection is already closed.
// The [conn] implementation handles the case where the connection
// closes after this call returns, so this is only an optimization.
return nil, net.ErrClosed
}
return c.syscallConn(), nil
}
func (c *conn) syscallConn() syscall.RawConn {
return (*rawConn)(c)
}
// Control implements [syscall.RawConn.Control].
func (c *rawConn) Control(f func(uintptr)) error {
return (*conn)(c).rawControl(f)
}
// rawControl implements [rawConn.Control].
func (c *conn) rawControl(f func(uintptr)) error {
if !c.guard.Acquire() {
return &net.OpError{Op: "raw-control", Net: c.net, Addr: c.localAddr, Err: net.ErrClosed}
}
defer c.guard.Release()
f(uintptr(c.socket))
return nil
}
// Read implements [syscall.RawConn.Read].
func (c *rawConn) Read(f func(uintptr) bool) error {
return &net.OpError{Op: "raw-read", Net: c.net, Source: c.localAddr, Err: errors.ErrUnsupported}
}
// Write implements [syscall.RawConn.Write].
func (c *rawConn) Write(f func(uintptr) bool) error {
return &net.OpError{Op: "raw-write", Net: c.net, Source: c.localAddr, Err: errors.ErrUnsupported}
}
// createRequestQueue creates a RIO request queue for the connection.
// It must be called before sending or receiving data. The call fails
// if a request queue already exists, if any parameter is invalid,
// or if RIO request queue creation fails.
//
// The caller must ensure that the connection is not closed while
// this call is in progress and that the provided completion queues
// have sufficient capacity for the specified number of outstanding
// requests, plus any requests posted by other connections sharing
// the same completion queues. As of 2026-02-17, completion queues
// are currently not shared between connections.
func (c *conn) createRequestQueue(
receiveCq winrio.Cq, maxOutstandingReceives uint32,
sendCq winrio.Cq, maxOutstandingSends uint32,
) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.rq != 0 {
return errors.New("already created")
}
if receiveCq == 0 {
return errors.New("invalid Rx completion queue")
}
if sendCq == 0 {
return errors.New("invalid Tx completion queue")
}
if maxOutstandingReceives == 0 {
return errors.New("invalid max outstanding receives")
}
if maxOutstandingSends == 0 {
return errors.New("invalid max outstanding sends")
}
var err error
c.rq, err = winrio.CreateRequestQueue(c.socket,
maxOutstandingReceives, 1,
maxOutstandingSends, 1,
receiveCq, sendCq,
0,
)
return err
}
// postReceiveRequests posts multiple receive requests to the RIO request queue.
// It returns an error if posting any request fails.
//
// The caller must ensure that the connection is not closed until this call returns.
func (c *conn) postReceiveRequests(reqs iter.Seq[*request]) (err error) {
var deferred int
defer func() {
// Always commit any deferred receive requests, even if an error occurred.
if deferred != 0 {
if commitErr := c.commitReceiveRequests(); commitErr != nil {
err = errors.Join(err, commitErr)
}
}
}()
c.mu.Lock()
defer c.mu.Unlock()
for req := range reqs {
if err := c.postReceiveRequestLocked(req, winrio.MsgDefer); err != nil {
return fmt.Errorf("failed to post receive request #%d: %w", deferred, err)
}
deferred++
}
return nil
}
// postReceiveRequestLocked posts a single receive request to the
// RIO request queue.
//
// c.mu must be held, and the caller must ensure
// that the connection is not closed until this call returns.
func (c *conn) postReceiveRequestLocked(req *request, flags uint32) error {
return req.PostReceive(c.rq, flags)
}
// commitReceiveRequests commits previously deferred receive requests.
//
// The caller must ensure that the connection is not closed until
// this call returns. It may be called with or without c.mu held.
func (c *conn) commitReceiveRequests() error {
// Unlike other ReceiveEx calls, commits do not need to be serialized:
// https://web.archive.org/web/20260216052922/https://learn.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_rioreceiveex
if err := winrio.ReceiveEx(c.rq, nil, 0, nil, nil, nil, nil, winrio.MsgCommitOnly, 0); err != nil {
return fmt.Errorf("failed to commit deferred receive requests: %w", err)
}
return nil
}
// postSendRequest posts a single send request to the RIO request queue.
// The caller must ensure that the connection is not closed until this call returns.
func (c *conn) postSendRequest(req *request, flags uint32) error {
// Submit the send request. As the underlying RIO request queue
// is not thread-safe, we need to serialize access to it.
c.mu.Lock()
err := req.PostSend(c.rq, flags)
c.mu.Unlock()
return err
}
// commitSendRequests commits previously deferred send requests.
//
// The caller must ensure that the connection is not closed until
// this call returns. It may be called with or without c.mu held.
func (c *conn) commitSendRequests() error {
// Unlike other SendEx calls, commits do not need to be serialized:
// https://web.archive.org/web/20260216053051/https://learn.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_riosendex
if err := winrio.SendEx(c.rq, nil, 0, nil, nil, nil, nil, winrio.MsgCommitOnly, 0); err != nil {
return fmt.Errorf("failed to commit deferred send requests: %w", err)
}
return nil
}
// Close closes the connection, cancels any pending operations,
// and releases all associated resources.
// Close is safe for concurrent use.
func (c *conn) Close() error {
if c == nil {
return nil
}
c.closeMu.Lock()
defer c.closeMu.Unlock()
c.guard.Close() // prevent new operations
if c.closedEvt != 0 { // wake up blocked operations
if err := windows.SetEvent(c.closedEvt); err != nil {
return fmt.Errorf("failed to set close notification event: %w", err)
}
}
c.guard.Wait()
// At this point, no operations are in flight and no new ones can start,
// so it is safe to release resources.
if c.socket != 0 {
windows.Closesocket(c.socket)
c.socket = 0
}
if c.closedEvt != 0 {
windows.CloseHandle(c.closedEvt)
c.closedEvt = 0
}
return nil
}
func rioSocket(family, sotype, proto int32) (windows.Handle, error) {
const rioWSAFlags = windows.WSA_FLAG_REGISTERED_IO |
windows.WSA_FLAG_NO_HANDLE_INHERIT |
windows.WSA_FLAG_OVERLAPPED
return windows.WSASocket(family, sotype, proto, nil, 0, rioWSAFlags)
}

239
net/rioconn/conn_test.go Normal file
View File

@ -0,0 +1,239 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"net/netip"
"syscall"
"testing"
)
func TestNewConn(t *testing.T) {
t.Parallel()
tests := []struct {
name string
sotype int32
proto int32
dualStack bool
laddr netip.AddrPort
wantErr bool
wantNetwork string
wantFamily int32
wantPort uint16 // 0 means any port
wantAddr netip.Addr
}{
{
name: "IPv4/UDP/AnyAddr/EphemeralPort",
sotype: syscall.SOCK_DGRAM,
proto: syscall.IPPROTO_UDP,
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
wantAddr: netip.MustParseAddr("0.0.0.0"),
wantFamily: syscall.AF_INET,
wantNetwork: "udp4",
},
{
name: "IPv6/UDP/AnyAddr/EphemeralPort",
sotype: syscall.SOCK_DGRAM,
proto: syscall.IPPROTO_UDP,
laddr: netip.MustParseAddrPort("[::]:0"),
wantAddr: netip.MustParseAddr("::"),
wantFamily: syscall.AF_INET6,
wantNetwork: "udp6",
},
{
name: "IPv4/UDP/LoopbackAddr/EphemeralPort",
sotype: syscall.SOCK_DGRAM,
proto: syscall.IPPROTO_UDP,
laddr: netip.MustParseAddrPort("127.0.0.1:0"),
wantAddr: netip.MustParseAddr("127.0.0.1"),
wantFamily: syscall.AF_INET,
wantNetwork: "udp4",
},
{
name: "IPv6/UDP/LoopbackAddr/EphemeralPort",
sotype: syscall.SOCK_DGRAM,
proto: syscall.IPPROTO_UDP,
laddr: netip.MustParseAddrPort("[::1]:0"),
wantAddr: netip.MustParseAddr("::1"),
wantFamily: syscall.AF_INET6,
wantNetwork: "udp6",
},
{
name: "IPv6/UDP/AnyAddr/EphemeralPort/DualStack",
sotype: syscall.SOCK_DGRAM,
proto: syscall.IPPROTO_UDP,
dualStack: true,
laddr: netip.MustParseAddrPort("[::]:0"),
wantAddr: netip.MustParseAddr("::"),
wantFamily: syscall.AF_INET6,
wantNetwork: "udp",
},
{
name: "IPv4/TCP/AnyAddr/EphemeralPort",
sotype: syscall.SOCK_STREAM,
proto: syscall.IPPROTO_TCP,
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
wantAddr: netip.MustParseAddr("0.0.0.0"),
wantFamily: syscall.AF_INET,
wantNetwork: "tcp4",
},
{
name: "IPv6/TCP/AnyAddr/EphemeralPort",
sotype: syscall.SOCK_STREAM,
proto: syscall.IPPROTO_TCP,
laddr: netip.MustParseAddrPort("[::]:0"),
wantAddr: netip.MustParseAddr("::"),
wantFamily: syscall.AF_INET6,
wantNetwork: "tcp6",
},
{
name: "IPv6/TCP/AnyAddr/EphemeralPort/DualStack",
sotype: syscall.SOCK_STREAM,
proto: syscall.IPPROTO_TCP,
dualStack: true,
laddr: netip.MustParseAddrPort("[::]:0"),
wantAddr: netip.MustParseAddr("::"),
wantFamily: syscall.AF_INET6,
wantNetwork: "tcp",
},
{
name: "InvalidSOType",
sotype: 12345,
proto: syscall.IPPROTO_UDP,
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
wantErr: true,
},
{
name: "InvalidProtocol",
sotype: syscall.SOCK_DGRAM,
proto: 12345,
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
wantErr: true,
},
{
name: "InvalidAddress",
sotype: syscall.SOCK_DGRAM,
proto: syscall.IPPROTO_UDP,
laddr: netip.AddrPort{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
c, err := newConn(tt.sotype, tt.proto, tt.dualStack, tt.laddr, nil)
if (err != nil) != tt.wantErr {
t.Fatalf("newConn: got error %v, want: %v", err, tt.wantErr)
}
if err != nil {
return
}
t.Cleanup(func() {
if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
})
if got := c.Network(); got != tt.wantNetwork {
t.Errorf("newConn: got network %q, want %q", got, tt.wantNetwork)
}
if got := c.IsDualStack(); got != tt.dualStack {
t.Errorf("newConn: got dualStack %v, want %v", got, tt.dualStack)
}
if got := c.Family(); got != tt.wantFamily {
t.Errorf("newConn: got family %d, want %d", got, tt.wantFamily)
}
gotAddrPort := c.LocalAddrPort()
if !gotAddrPort.IsValid() {
t.Fatalf("newConn: got invalid local address %v", gotAddrPort)
}
if gotAddrPort.Addr() != tt.wantAddr {
t.Errorf("newConn: got address %v, want %v", gotAddrPort.Addr(), tt.wantAddr)
}
if tt.wantPort != 0 && gotAddrPort.Port() != tt.wantPort {
t.Errorf("newConn: got port %d, want %d", gotAddrPort.Port(), tt.wantPort)
} else if gotAddrPort.Port() == 0 {
t.Errorf("newConn: got port 0, want non-zero")
}
if c.LocalAddr().String() != gotAddrPort.String() {
t.Errorf("newConn: LocalAddr %q, LocalAddrPort %q", c.LocalAddr(), gotAddrPort)
}
})
}
}
func TestConnAcquireReleaseClose(t *testing.T) {
t.Parallel()
c, err := newConn(syscall.SOCK_DGRAM, syscall.IPPROTO_UDP, false, netip.MustParseAddrPort("0.0.0.0:0"), nil)
if err != nil {
t.Fatalf("newConn failed: %v", err)
}
if c.IsClosed() {
t.Fatal("newConn: got closed connection")
}
if err := c.acquire(); err != nil {
t.Errorf("acquire failed: %v", err)
}
go c.release() // race with close
if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
if !c.IsClosed() {
t.Fatal("Close did not mark connection as closed")
}
if err := c.acquire(); err == nil {
t.Fatal("acquire succeeded on closed connection")
}
if err := c.Close(); err != nil {
t.Fatalf("Close failed on already closed connection: %v", err)
}
}
func TestSyscallConn(t *testing.T) {
t.Parallel()
c, err := newConn(syscall.SOCK_DGRAM, syscall.IPPROTO_UDP, false,
netip.MustParseAddrPort("0.0.0.0:0"), nil)
if err != nil {
t.Fatalf("newConn failed: %v", err)
}
defer c.Close()
syscallConn, err := c.SyscallConn()
if err != nil {
t.Fatalf("SyscallConn failed: %v", err)
}
err = syscallConn.Control(func(fd uintptr) {
if fd != uintptr(c.socket) {
t.Fatalf("Control: got fd %v, want %v", fd, c.socket)
}
})
if err != nil {
t.Fatalf("Control failed: %v", err)
}
if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
err = syscallConn.Control(func(fd uintptr) {
t.Fatalf("Control succeeded on closed connection with fd %v", fd)
})
if err == nil {
t.Fatal("Control succeeded on closed connection")
}
}

View File

@ -12,6 +12,7 @@ import (
"unsafe"
"github.com/tailscale/wireguard-go/conn/winrio"
"golang.org/x/sys/windows"
)
// request represents a portion of a RIO-registered memory buffer used for
@ -85,6 +86,82 @@ func (r *request) Reader() *requestReader {
return (*requestReader)(r)
}
// PostSend posts the request as a send operation to the given RIO request queue
// with the specified flags.
func (r *request) PostSend(rq winrio.Rq, flags uint32) error {
data := winrio.Buffer{
Id: r.buffID,
Length: uint32(len(r.data)),
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
}
remoteAddr := r.remoteAddrDesc()
return winrio.SendEx(rq, &data, 1, nil, &remoteAddr, nil, nil, flags, uintptr(unsafe.Pointer(r)))
}
// PostReceive posts the request as a receive operation to the given RIO request queue
// with the specified flags.
func (r *request) PostReceive(rq winrio.Rq, flags uint32) error {
r.data = r.data[:0]
data := winrio.Buffer{
Id: r.buffID,
Length: uint32(cap(r.data)),
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
}
remoteAddress := r.remoteAddrDesc()
return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, nil,
nil, flags, uintptr(unsafe.Pointer(r)))
}
// CompleteSend finalizes a send request.
//
// It validates the completion status and the number of bytes written,
// returning an error if the status indicates a failure, or if the number
// of bytes written does not match the length of the request's data buffer.
func (r *request) CompleteSend(status int32, bytesWritten uint32) error {
expected := len(r.data)
if status != 0 {
return windows.Errno(status)
}
if uint64(bytesWritten) != uint64(expected) {
return fmt.Errorf(
"bytes written (%d) does not match data buffer length (%d)",
bytesWritten,
expected,
)
}
return nil
}
// CompleteReceive finalizes a receive request.
//
// It validates the completion status and the number of bytes read
// returning an error if the status indicates a failure, or if the number
// of bytes read exceeds the capacity of the request's data buffer.
//
// On success, it returns a reader view of the request.
func (r *request) CompleteReceive(status int32, bytesRead uint32) (*requestReader, error) {
if status != 0 {
return nil, windows.Errno(status)
}
if uint64(bytesRead) > uint64(cap(r.data)) {
return nil, fmt.Errorf(
"bytes read (%d) exceeds data buffer capacity (%d)",
bytesRead,
cap(r.data),
)
}
r.data = r.data[:bytesRead]
return r.Reader(), nil
}
func (r *request) remoteAddrDesc() winrio.Buffer {
return winrio.Buffer{
Id: r.buffID,
Length: uint32(unsafe.Sizeof(r.raddr)),
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
}
}
// Reset prepares the request for reuse by resetting its state.
func (r *request) Reset() {
r.raddr = rawSockaddr{}

View File

@ -10,6 +10,8 @@ import (
"net/netip"
"testing"
"unsafe"
"golang.org/x/sys/windows"
)
func TestUnsafeMapRequest(t *testing.T) {
@ -279,6 +281,128 @@ func TestRequestReader(t *testing.T) {
}
}
func TestRequestCompleteSend(t *testing.T) {
tests := []struct {
name string
payloadLength int
bytesWritten uint32
status int32
wantErr bool
}{
{
name: "empty/success",
bytesWritten: 0,
status: 0,
},
{
name: "non-empty/success",
payloadLength: 100,
bytesWritten: 100,
status: 0,
},
{
name: "non-empty/partial-write",
payloadLength: 100,
bytesWritten: 50,
status: 0,
wantErr: true,
},
{
name: "non-empty/overflow-write",
payloadLength: 100,
bytesWritten: 150,
status: 0,
wantErr: true,
},
{
name: "non-empty/error",
payloadLength: 100,
bytesWritten: 0,
status: int32(windows.ERROR_GEN_FAILURE),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := makeRequest(t, tt.payloadLength)
req.Writer().Write(bytes.Repeat([]byte{'x'}, tt.payloadLength))
err := req.CompleteSend(tt.status, tt.bytesWritten)
if tt.wantErr && err == nil {
t.Error("request.CompleteSend: expected error; got nil")
} else if !tt.wantErr && err != nil {
t.Errorf("request.CompleteSend: unexpected error: %v", err)
}
})
}
}
func TestRequestCompleteReceive(t *testing.T) {
tests := []struct {
name string
data []byte
bytesRead uint32
status int32
wantErr bool
wantBytes []byte
}{
{
name: "empty/success",
data: nil,
bytesRead: 0,
wantBytes: nil,
},
{
name: "non-empty/success/full",
data: []byte("Hello World"),
bytesRead: 11,
wantBytes: []byte("Hello World"),
},
{
name: "non-empty/success/partial",
data: []byte("Hello World"),
// Simulate a partial read by reporting fewer bytes
// that the request's data buffer length.
bytesRead: 5,
wantErr: false,
wantBytes: []byte("Hello"),
},
{
name: "non-empty/overflow-read",
data: []byte("Hello World"),
bytesRead: 20, // more than the data buffer can hold
wantErr: true,
},
{
name: "non-empty/error",
data: []byte("Hello World"),
bytesRead: 0,
status: int32(windows.ERROR_GEN_FAILURE),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := makeRequest(t, len(tt.data))
copy(req.data[:cap(req.data)], tt.data)
r, err := req.CompleteReceive(tt.status, tt.bytesRead)
if tt.wantErr && err == nil {
t.Error("request.CompleteReceive: expected error; got nil")
} else if !tt.wantErr && err != nil {
t.Errorf("request.CompleteReceive: unexpected error: %v", err)
}
if err != nil {
return
}
if !bytes.Equal(r.Bytes(), tt.wantBytes) {
t.Errorf("requestReader.Bytes: got %q; want %q", r.Bytes(), tt.wantBytes)
}
})
}
}
func makeRequest(tb testing.TB, dataSize int) *request {
tb.Helper()
buff := makeAligned(totalRequestSize(uintptr(dataSize)), unsafe.Alignof(request{}))