From 52ab2b18947f0adf2cca82c101b7537d5c2fb6c5 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 19 Feb 2026 08:33:11 -0600 Subject: [PATCH] net/rioconn: add base, protocol-agnostic RIO-enabled connection type We'd like to implement this as its own distinct type early on. That decouples protocol-specific code from the UDPConn implementation in case we want to add a TCPConn later. More importantly, it lets UDPConn's Rx/Tx half-conns reference it while keeping the locking hierarchy sane, since RIO's request queue isn't safe for concurrent use. Updates tailscale/corp#8610 Signed-off-by: Nick Khyl --- net/rioconn/config.go | 28 +++ net/rioconn/conn.go | 365 ++++++++++++++++++++++++++++++++++++ net/rioconn/conn_test.go | 239 +++++++++++++++++++++++ net/rioconn/request.go | 77 ++++++++ net/rioconn/request_test.go | 124 ++++++++++++ 5 files changed, 833 insertions(+) create mode 100644 net/rioconn/config.go create mode 100644 net/rioconn/conn.go create mode 100644 net/rioconn/conn_test.go diff --git a/net/rioconn/config.go b/net/rioconn/config.go new file mode 100644 index 000000000..8ccc88306 --- /dev/null +++ b/net/rioconn/config.go @@ -0,0 +1,28 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package rioconn + +import ( + "errors" + "syscall" +) + +// Config holds configuration for a RIO connection, independent of the transport protocol. +type Config struct { + control []func(network, address string, c syscall.RawConn) error +} + +// Control invokes all control functions in the Config with the given +// network, address, and connection. A failure of one control function +// does not prevent the others from running. It returns an error if any +// control function fails. +func (c Config) Control(network string, address string, conn syscall.RawConn) error { + var err []error + for _, control := range c.control { + if e := control(network, address, conn); e != nil { + err = append(err, e) + } + } + return errors.Join(err...) +} diff --git a/net/rioconn/conn.go b/net/rioconn/conn.go new file mode 100644 index 000000000..f71719232 --- /dev/null +++ b/net/rioconn/conn.go @@ -0,0 +1,365 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "errors" + "fmt" + "iter" + "net" + "net/netip" + "sync" + "syscall" + + "github.com/tailscale/wireguard-go/conn/winrio" + "golang.org/x/sys/windows" +) + +// conn is a protocol-agnostic base connection with RIO support. +// +// Its exported methods are safe for concurrent use, including +// concurrent calls with each other, with Close, and after Close. +// +// However, the caller must call [conn.acquire] before invoking +// any unexported methods and [conn.release] when done to prevent +// the connection from closing while the operation is in flight. +type conn struct { + // immutable once [newConn] returns: + family int32 + localAddr net.Addr + localAddrPort netip.AddrPort + dualStack bool + sotype int32 + proto int32 + net string + config *Config + + // guard prevents the connection from closing and its resources from + // being freed while operations are in flight. All fields below are + // protected by guard and must not be accessed after the connection + // is closed and [guard.Acquire] returns false, except by [conn.Close]. + guard *guard + closedEvt windows.Handle + socket windows.Handle + + // closeMu serializes calls to [conn.Close]. + // Lock order: closeMu > mu. + closeMu sync.Mutex + + // mu serializes access to the RIO request queue. + // Lock order: closeMu > mu. + mu sync.Mutex + rq winrio.Rq +} + +// rawConn implements [syscall.RawConn] for [conn]. +type rawConn conn + +var ( + _ syscall.Conn = (*conn)(nil) + _ syscall.RawConn = (*rawConn)(nil) +) + +func newConn(sotype int32, proto int32, dualStack bool, laddr netip.AddrPort, config *Config) (_ *conn, err error) { + if config == nil { + config = &Config{} + } + sa, family, err := sockaddrFromAddrPort(laddr) + if err != nil { + return nil, err + } + net, err := networkName(sotype, proto, family, dualStack) + if err != nil { + return nil, err + } + conn := &conn{ + family: family, + dualStack: dualStack, + sotype: sotype, + proto: proto, + net: net, + config: config, + guard: newGuard(), + } + defer func() { + // If initialization fails, close the connection to release + // any resources allocated before the error. + if err != nil { + conn.Close() + } + }() + // Create a manual-reset event to wake up pending operations on Close. + if conn.closedEvt, err = windows.CreateEvent(nil, 1, 0, nil); err != nil { + return nil, fmt.Errorf("failed to create close notification event: %w", err) + } + // Create a socket with the WSA_FLAG_REGISTERED_IO flag set. + if conn.socket, err = rioSocket(family, sotype, proto); err != nil { + return nil, fmt.Errorf("failed to create socket(%d, %d, %d): %w", family, sotype, proto, err) + } + // Enable dual-stack mode by clearing the IPV6_V6ONLY option, if necessary. + // https://web.archive.org/web/20260208062136/https://learn.microsoft.com/en-us/windows/win32/winsock/dual-stack-sockets + if dualStack { + if err := windows.SetsockoptInt(conn.socket, windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, 0); err != nil { + return nil, fmt.Errorf("failed to enable dual-stack mode: %w", err) + } + } + // Invoke caller-provided control functions to set socket options before binding. + if err := conn.config.Control(net, laddr.String(), (*rawConn)(conn)); err != nil { + return nil, fmt.Errorf("control failed: %w", err) + } + if err := windows.Bind(conn.socket, sa); err != nil { + return nil, fmt.Errorf("failed to bind socket: %w", err) + } + // Record the local address from the actual socket, since the caller + // may have specified port 0 for automatic assignment. + if conn.localAddrPort, err = addrPortFromSocket(conn.socket); err != nil { + return nil, fmt.Errorf("failed to get local address and port: %w", err) + } + if conn.localAddr, err = netAddrFromAddrPort(conn.localAddrPort, sotype); err != nil { + return nil, fmt.Errorf("failed to convert local address and port to net.Addr: %w", err) + } + return conn, nil +} + +// IsClosed reports whether Close has been called. +func (c *conn) IsClosed() bool { + return c.guard.IsClosed() +} + +// acquire increments the connection's reference count, preventing +// it from closing. If it returns no error, the caller may use the +// connection and must call [conn.release] when done. Otherwise, +// the connection must not be used. +func (c *conn) acquire() error { + if !c.guard.Acquire() { + return net.ErrClosed + } + return nil +} + +// release decrements the connection's reference count. +// Calling release without a matching acquire is a run-time error. +func (c *conn) release() { + c.guard.Release() +} + +// Family returns the socket address family of the connection. +func (c *conn) Family() int32 { + return c.family +} + +// Network returns the network name of the connection. +func (c *conn) Network() string { + return c.net +} + +// IsDualStack reports whether the connection is dual-stack and can send +// and receive packets to and from IPv6 or IPv4-mapped IPv6 addresses. +func (c *conn) IsDualStack() bool { + return c.dualStack +} + +// LocalAddr returns the local network address. +func (c *conn) LocalAddr() net.Addr { + return c.localAddr +} + +// LocalAddrPort returns the local network address and port. +func (c *conn) LocalAddrPort() netip.AddrPort { + return c.localAddrPort +} + +// SyscallConn returns a raw network connection, or an error if the connection is closed. +func (c *conn) SyscallConn() (syscall.RawConn, error) { + if c.IsClosed() { + // Return the error immediately if the connection is already closed. + // The [conn] implementation handles the case where the connection + // closes after this call returns, so this is only an optimization. + return nil, net.ErrClosed + } + return c.syscallConn(), nil +} + +func (c *conn) syscallConn() syscall.RawConn { + return (*rawConn)(c) +} + +// Control implements [syscall.RawConn.Control]. +func (c *rawConn) Control(f func(uintptr)) error { + return (*conn)(c).rawControl(f) +} + +// rawControl implements [rawConn.Control]. +func (c *conn) rawControl(f func(uintptr)) error { + if !c.guard.Acquire() { + return &net.OpError{Op: "raw-control", Net: c.net, Addr: c.localAddr, Err: net.ErrClosed} + } + defer c.guard.Release() + f(uintptr(c.socket)) + return nil +} + +// Read implements [syscall.RawConn.Read]. +func (c *rawConn) Read(f func(uintptr) bool) error { + return &net.OpError{Op: "raw-read", Net: c.net, Source: c.localAddr, Err: errors.ErrUnsupported} +} + +// Write implements [syscall.RawConn.Write]. +func (c *rawConn) Write(f func(uintptr) bool) error { + return &net.OpError{Op: "raw-write", Net: c.net, Source: c.localAddr, Err: errors.ErrUnsupported} +} + +// createRequestQueue creates a RIO request queue for the connection. +// It must be called before sending or receiving data. The call fails +// if a request queue already exists, if any parameter is invalid, +// or if RIO request queue creation fails. +// +// The caller must ensure that the connection is not closed while +// this call is in progress and that the provided completion queues +// have sufficient capacity for the specified number of outstanding +// requests, plus any requests posted by other connections sharing +// the same completion queues. As of 2026-02-17, completion queues +// are currently not shared between connections. +func (c *conn) createRequestQueue( + receiveCq winrio.Cq, maxOutstandingReceives uint32, + sendCq winrio.Cq, maxOutstandingSends uint32, +) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.rq != 0 { + return errors.New("already created") + } + if receiveCq == 0 { + return errors.New("invalid Rx completion queue") + } + if sendCq == 0 { + return errors.New("invalid Tx completion queue") + } + if maxOutstandingReceives == 0 { + return errors.New("invalid max outstanding receives") + } + if maxOutstandingSends == 0 { + return errors.New("invalid max outstanding sends") + } + var err error + c.rq, err = winrio.CreateRequestQueue(c.socket, + maxOutstandingReceives, 1, + maxOutstandingSends, 1, + receiveCq, sendCq, + 0, + ) + return err +} + +// postReceiveRequests posts multiple receive requests to the RIO request queue. +// It returns an error if posting any request fails. +// +// The caller must ensure that the connection is not closed until this call returns. +func (c *conn) postReceiveRequests(reqs iter.Seq[*request]) (err error) { + var deferred int + defer func() { + // Always commit any deferred receive requests, even if an error occurred. + if deferred != 0 { + if commitErr := c.commitReceiveRequests(); commitErr != nil { + err = errors.Join(err, commitErr) + } + } + }() + + c.mu.Lock() + defer c.mu.Unlock() + for req := range reqs { + if err := c.postReceiveRequestLocked(req, winrio.MsgDefer); err != nil { + return fmt.Errorf("failed to post receive request #%d: %w", deferred, err) + } + deferred++ + } + return nil +} + +// postReceiveRequestLocked posts a single receive request to the +// RIO request queue. +// +// c.mu must be held, and the caller must ensure +// that the connection is not closed until this call returns. +func (c *conn) postReceiveRequestLocked(req *request, flags uint32) error { + return req.PostReceive(c.rq, flags) +} + +// commitReceiveRequests commits previously deferred receive requests. +// +// The caller must ensure that the connection is not closed until +// this call returns. It may be called with or without c.mu held. +func (c *conn) commitReceiveRequests() error { + // Unlike other ReceiveEx calls, commits do not need to be serialized: + // https://web.archive.org/web/20260216052922/https://learn.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_rioreceiveex + if err := winrio.ReceiveEx(c.rq, nil, 0, nil, nil, nil, nil, winrio.MsgCommitOnly, 0); err != nil { + return fmt.Errorf("failed to commit deferred receive requests: %w", err) + } + return nil +} + +// postSendRequest posts a single send request to the RIO request queue. +// The caller must ensure that the connection is not closed until this call returns. +func (c *conn) postSendRequest(req *request, flags uint32) error { + // Submit the send request. As the underlying RIO request queue + // is not thread-safe, we need to serialize access to it. + c.mu.Lock() + err := req.PostSend(c.rq, flags) + c.mu.Unlock() + return err +} + +// commitSendRequests commits previously deferred send requests. +// +// The caller must ensure that the connection is not closed until +// this call returns. It may be called with or without c.mu held. +func (c *conn) commitSendRequests() error { + // Unlike other SendEx calls, commits do not need to be serialized: + // https://web.archive.org/web/20260216053051/https://learn.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_riosendex + if err := winrio.SendEx(c.rq, nil, 0, nil, nil, nil, nil, winrio.MsgCommitOnly, 0); err != nil { + return fmt.Errorf("failed to commit deferred send requests: %w", err) + } + return nil +} + +// Close closes the connection, cancels any pending operations, +// and releases all associated resources. +// Close is safe for concurrent use. +func (c *conn) Close() error { + if c == nil { + return nil + } + + c.closeMu.Lock() + defer c.closeMu.Unlock() + + c.guard.Close() // prevent new operations + if c.closedEvt != 0 { // wake up blocked operations + if err := windows.SetEvent(c.closedEvt); err != nil { + return fmt.Errorf("failed to set close notification event: %w", err) + } + } + c.guard.Wait() + // At this point, no operations are in flight and no new ones can start, + // so it is safe to release resources. + if c.socket != 0 { + windows.Closesocket(c.socket) + c.socket = 0 + } + if c.closedEvt != 0 { + windows.CloseHandle(c.closedEvt) + c.closedEvt = 0 + } + return nil +} + +func rioSocket(family, sotype, proto int32) (windows.Handle, error) { + const rioWSAFlags = windows.WSA_FLAG_REGISTERED_IO | + windows.WSA_FLAG_NO_HANDLE_INHERIT | + windows.WSA_FLAG_OVERLAPPED + return windows.WSASocket(family, sotype, proto, nil, 0, rioWSAFlags) +} diff --git a/net/rioconn/conn_test.go b/net/rioconn/conn_test.go new file mode 100644 index 000000000..50b37c717 --- /dev/null +++ b/net/rioconn/conn_test.go @@ -0,0 +1,239 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "net/netip" + "syscall" + "testing" +) + +func TestNewConn(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + sotype int32 + proto int32 + dualStack bool + laddr netip.AddrPort + wantErr bool + wantNetwork string + wantFamily int32 + wantPort uint16 // 0 means any port + wantAddr netip.Addr + }{ + { + name: "IPv4/UDP/AnyAddr/EphemeralPort", + sotype: syscall.SOCK_DGRAM, + proto: syscall.IPPROTO_UDP, + laddr: netip.MustParseAddrPort("0.0.0.0:0"), + wantAddr: netip.MustParseAddr("0.0.0.0"), + wantFamily: syscall.AF_INET, + wantNetwork: "udp4", + }, + { + name: "IPv6/UDP/AnyAddr/EphemeralPort", + sotype: syscall.SOCK_DGRAM, + proto: syscall.IPPROTO_UDP, + laddr: netip.MustParseAddrPort("[::]:0"), + wantAddr: netip.MustParseAddr("::"), + wantFamily: syscall.AF_INET6, + wantNetwork: "udp6", + }, + { + name: "IPv4/UDP/LoopbackAddr/EphemeralPort", + sotype: syscall.SOCK_DGRAM, + proto: syscall.IPPROTO_UDP, + laddr: netip.MustParseAddrPort("127.0.0.1:0"), + wantAddr: netip.MustParseAddr("127.0.0.1"), + wantFamily: syscall.AF_INET, + wantNetwork: "udp4", + }, + { + name: "IPv6/UDP/LoopbackAddr/EphemeralPort", + sotype: syscall.SOCK_DGRAM, + proto: syscall.IPPROTO_UDP, + laddr: netip.MustParseAddrPort("[::1]:0"), + wantAddr: netip.MustParseAddr("::1"), + wantFamily: syscall.AF_INET6, + wantNetwork: "udp6", + }, + { + name: "IPv6/UDP/AnyAddr/EphemeralPort/DualStack", + sotype: syscall.SOCK_DGRAM, + proto: syscall.IPPROTO_UDP, + dualStack: true, + laddr: netip.MustParseAddrPort("[::]:0"), + wantAddr: netip.MustParseAddr("::"), + wantFamily: syscall.AF_INET6, + wantNetwork: "udp", + }, + { + name: "IPv4/TCP/AnyAddr/EphemeralPort", + sotype: syscall.SOCK_STREAM, + proto: syscall.IPPROTO_TCP, + laddr: netip.MustParseAddrPort("0.0.0.0:0"), + wantAddr: netip.MustParseAddr("0.0.0.0"), + wantFamily: syscall.AF_INET, + wantNetwork: "tcp4", + }, + { + name: "IPv6/TCP/AnyAddr/EphemeralPort", + sotype: syscall.SOCK_STREAM, + proto: syscall.IPPROTO_TCP, + laddr: netip.MustParseAddrPort("[::]:0"), + wantAddr: netip.MustParseAddr("::"), + wantFamily: syscall.AF_INET6, + wantNetwork: "tcp6", + }, + { + name: "IPv6/TCP/AnyAddr/EphemeralPort/DualStack", + sotype: syscall.SOCK_STREAM, + proto: syscall.IPPROTO_TCP, + dualStack: true, + laddr: netip.MustParseAddrPort("[::]:0"), + wantAddr: netip.MustParseAddr("::"), + wantFamily: syscall.AF_INET6, + wantNetwork: "tcp", + }, + { + name: "InvalidSOType", + sotype: 12345, + proto: syscall.IPPROTO_UDP, + laddr: netip.MustParseAddrPort("0.0.0.0:0"), + wantErr: true, + }, + { + name: "InvalidProtocol", + sotype: syscall.SOCK_DGRAM, + proto: 12345, + laddr: netip.MustParseAddrPort("0.0.0.0:0"), + wantErr: true, + }, + { + name: "InvalidAddress", + sotype: syscall.SOCK_DGRAM, + proto: syscall.IPPROTO_UDP, + laddr: netip.AddrPort{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c, err := newConn(tt.sotype, tt.proto, tt.dualStack, tt.laddr, nil) + if (err != nil) != tt.wantErr { + t.Fatalf("newConn: got error %v, want: %v", err, tt.wantErr) + } + if err != nil { + return + } + t.Cleanup(func() { + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + }) + + if got := c.Network(); got != tt.wantNetwork { + t.Errorf("newConn: got network %q, want %q", got, tt.wantNetwork) + } + + if got := c.IsDualStack(); got != tt.dualStack { + t.Errorf("newConn: got dualStack %v, want %v", got, tt.dualStack) + } + + if got := c.Family(); got != tt.wantFamily { + t.Errorf("newConn: got family %d, want %d", got, tt.wantFamily) + } + + gotAddrPort := c.LocalAddrPort() + if !gotAddrPort.IsValid() { + t.Fatalf("newConn: got invalid local address %v", gotAddrPort) + } + + if gotAddrPort.Addr() != tt.wantAddr { + t.Errorf("newConn: got address %v, want %v", gotAddrPort.Addr(), tt.wantAddr) + } + + if tt.wantPort != 0 && gotAddrPort.Port() != tt.wantPort { + t.Errorf("newConn: got port %d, want %d", gotAddrPort.Port(), tt.wantPort) + } else if gotAddrPort.Port() == 0 { + t.Errorf("newConn: got port 0, want non-zero") + } + + if c.LocalAddr().String() != gotAddrPort.String() { + t.Errorf("newConn: LocalAddr %q, LocalAddrPort %q", c.LocalAddr(), gotAddrPort) + } + + }) + } +} + +func TestConnAcquireReleaseClose(t *testing.T) { + t.Parallel() + + c, err := newConn(syscall.SOCK_DGRAM, syscall.IPPROTO_UDP, false, netip.MustParseAddrPort("0.0.0.0:0"), nil) + if err != nil { + t.Fatalf("newConn failed: %v", err) + } + if c.IsClosed() { + t.Fatal("newConn: got closed connection") + } + if err := c.acquire(); err != nil { + t.Errorf("acquire failed: %v", err) + } + go c.release() // race with close + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + if !c.IsClosed() { + t.Fatal("Close did not mark connection as closed") + } + if err := c.acquire(); err == nil { + t.Fatal("acquire succeeded on closed connection") + } + if err := c.Close(); err != nil { + t.Fatalf("Close failed on already closed connection: %v", err) + } +} + +func TestSyscallConn(t *testing.T) { + t.Parallel() + + c, err := newConn(syscall.SOCK_DGRAM, syscall.IPPROTO_UDP, false, + netip.MustParseAddrPort("0.0.0.0:0"), nil) + if err != nil { + t.Fatalf("newConn failed: %v", err) + } + defer c.Close() + + syscallConn, err := c.SyscallConn() + if err != nil { + t.Fatalf("SyscallConn failed: %v", err) + } + + err = syscallConn.Control(func(fd uintptr) { + if fd != uintptr(c.socket) { + t.Fatalf("Control: got fd %v, want %v", fd, c.socket) + } + }) + if err != nil { + t.Fatalf("Control failed: %v", err) + } + + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + err = syscallConn.Control(func(fd uintptr) { + t.Fatalf("Control succeeded on closed connection with fd %v", fd) + }) + if err == nil { + t.Fatal("Control succeeded on closed connection") + } +} diff --git a/net/rioconn/request.go b/net/rioconn/request.go index e194f7e22..3e2b04ebb 100644 --- a/net/rioconn/request.go +++ b/net/rioconn/request.go @@ -12,6 +12,7 @@ import ( "unsafe" "github.com/tailscale/wireguard-go/conn/winrio" + "golang.org/x/sys/windows" ) // request represents a portion of a RIO-registered memory buffer used for @@ -85,6 +86,82 @@ func (r *request) Reader() *requestReader { return (*requestReader)(r) } +// PostSend posts the request as a send operation to the given RIO request queue +// with the specified flags. +func (r *request) PostSend(rq winrio.Rq, flags uint32) error { + data := winrio.Buffer{ + Id: r.buffID, + Length: uint32(len(r.data)), + Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase), + } + remoteAddr := r.remoteAddrDesc() + return winrio.SendEx(rq, &data, 1, nil, &remoteAddr, nil, nil, flags, uintptr(unsafe.Pointer(r))) +} + +// PostReceive posts the request as a receive operation to the given RIO request queue +// with the specified flags. +func (r *request) PostReceive(rq winrio.Rq, flags uint32) error { + r.data = r.data[:0] + data := winrio.Buffer{ + Id: r.buffID, + Length: uint32(cap(r.data)), + Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase), + } + remoteAddress := r.remoteAddrDesc() + return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, nil, + nil, flags, uintptr(unsafe.Pointer(r))) +} + +// CompleteSend finalizes a send request. +// +// It validates the completion status and the number of bytes written, +// returning an error if the status indicates a failure, or if the number +// of bytes written does not match the length of the request's data buffer. +func (r *request) CompleteSend(status int32, bytesWritten uint32) error { + expected := len(r.data) + if status != 0 { + return windows.Errno(status) + } + if uint64(bytesWritten) != uint64(expected) { + return fmt.Errorf( + "bytes written (%d) does not match data buffer length (%d)", + bytesWritten, + expected, + ) + } + return nil +} + +// CompleteReceive finalizes a receive request. +// +// It validates the completion status and the number of bytes read +// returning an error if the status indicates a failure, or if the number +// of bytes read exceeds the capacity of the request's data buffer. +// +// On success, it returns a reader view of the request. +func (r *request) CompleteReceive(status int32, bytesRead uint32) (*requestReader, error) { + if status != 0 { + return nil, windows.Errno(status) + } + if uint64(bytesRead) > uint64(cap(r.data)) { + return nil, fmt.Errorf( + "bytes read (%d) exceeds data buffer capacity (%d)", + bytesRead, + cap(r.data), + ) + } + r.data = r.data[:bytesRead] + return r.Reader(), nil +} + +func (r *request) remoteAddrDesc() winrio.Buffer { + return winrio.Buffer{ + Id: r.buffID, + Length: uint32(unsafe.Sizeof(r.raddr)), + Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase), + } +} + // Reset prepares the request for reuse by resetting its state. func (r *request) Reset() { r.raddr = rawSockaddr{} diff --git a/net/rioconn/request_test.go b/net/rioconn/request_test.go index b1f4c498b..9885eaff6 100644 --- a/net/rioconn/request_test.go +++ b/net/rioconn/request_test.go @@ -10,6 +10,8 @@ import ( "net/netip" "testing" "unsafe" + + "golang.org/x/sys/windows" ) func TestUnsafeMapRequest(t *testing.T) { @@ -279,6 +281,128 @@ func TestRequestReader(t *testing.T) { } } +func TestRequestCompleteSend(t *testing.T) { + tests := []struct { + name string + payloadLength int + bytesWritten uint32 + status int32 + wantErr bool + }{ + { + name: "empty/success", + bytesWritten: 0, + status: 0, + }, + { + name: "non-empty/success", + payloadLength: 100, + bytesWritten: 100, + status: 0, + }, + { + name: "non-empty/partial-write", + payloadLength: 100, + bytesWritten: 50, + status: 0, + wantErr: true, + }, + { + name: "non-empty/overflow-write", + payloadLength: 100, + bytesWritten: 150, + status: 0, + wantErr: true, + }, + { + name: "non-empty/error", + payloadLength: 100, + bytesWritten: 0, + status: int32(windows.ERROR_GEN_FAILURE), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := makeRequest(t, tt.payloadLength) + req.Writer().Write(bytes.Repeat([]byte{'x'}, tt.payloadLength)) + + err := req.CompleteSend(tt.status, tt.bytesWritten) + if tt.wantErr && err == nil { + t.Error("request.CompleteSend: expected error; got nil") + } else if !tt.wantErr && err != nil { + t.Errorf("request.CompleteSend: unexpected error: %v", err) + } + }) + } +} + +func TestRequestCompleteReceive(t *testing.T) { + tests := []struct { + name string + data []byte + bytesRead uint32 + status int32 + wantErr bool + wantBytes []byte + }{ + { + name: "empty/success", + data: nil, + bytesRead: 0, + wantBytes: nil, + }, + { + name: "non-empty/success/full", + data: []byte("Hello World"), + bytesRead: 11, + wantBytes: []byte("Hello World"), + }, + { + name: "non-empty/success/partial", + data: []byte("Hello World"), + // Simulate a partial read by reporting fewer bytes + // that the request's data buffer length. + bytesRead: 5, + wantErr: false, + wantBytes: []byte("Hello"), + }, + { + name: "non-empty/overflow-read", + data: []byte("Hello World"), + bytesRead: 20, // more than the data buffer can hold + wantErr: true, + }, + { + name: "non-empty/error", + data: []byte("Hello World"), + bytesRead: 0, + status: int32(windows.ERROR_GEN_FAILURE), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := makeRequest(t, len(tt.data)) + copy(req.data[:cap(req.data)], tt.data) + + r, err := req.CompleteReceive(tt.status, tt.bytesRead) + if tt.wantErr && err == nil { + t.Error("request.CompleteReceive: expected error; got nil") + } else if !tt.wantErr && err != nil { + t.Errorf("request.CompleteReceive: unexpected error: %v", err) + } + if err != nil { + return + } + + if !bytes.Equal(r.Bytes(), tt.wantBytes) { + t.Errorf("requestReader.Bytes: got %q; want %q", r.Bytes(), tt.wantBytes) + } + }) + } +} + func makeRequest(tb testing.TB, dataSize int) *request { tb.Helper() buff := makeAligned(totalRequestSize(uintptr(dataSize)), unsafe.Alignof(request{}))