mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-12 15:56:14 +02:00
net/rioconn: add base, protocol-agnostic RIO-enabled connection type
We'd like to implement this as its own distinct type early on. That decouples protocol-specific code from the UDPConn implementation in case we want to add a TCPConn later. More importantly, it lets UDPConn's Rx/Tx half-conns reference it while keeping the locking hierarchy sane, since RIO's request queue isn't safe for concurrent use. Updates tailscale/corp#8610 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
d57b58193a
commit
52ab2b1894
28
net/rioconn/config.go
Normal file
28
net/rioconn/config.go
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Config holds configuration for a RIO connection, independent of the transport protocol.
|
||||
type Config struct {
|
||||
control []func(network, address string, c syscall.RawConn) error
|
||||
}
|
||||
|
||||
// Control invokes all control functions in the Config with the given
|
||||
// network, address, and connection. A failure of one control function
|
||||
// does not prevent the others from running. It returns an error if any
|
||||
// control function fails.
|
||||
func (c Config) Control(network string, address string, conn syscall.RawConn) error {
|
||||
var err []error
|
||||
for _, control := range c.control {
|
||||
if e := control(network, address, conn); e != nil {
|
||||
err = append(err, e)
|
||||
}
|
||||
}
|
||||
return errors.Join(err...)
|
||||
}
|
||||
365
net/rioconn/conn.go
Normal file
365
net/rioconn/conn.go
Normal file
@ -0,0 +1,365 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// conn is a protocol-agnostic base connection with RIO support.
|
||||
//
|
||||
// Its exported methods are safe for concurrent use, including
|
||||
// concurrent calls with each other, with Close, and after Close.
|
||||
//
|
||||
// However, the caller must call [conn.acquire] before invoking
|
||||
// any unexported methods and [conn.release] when done to prevent
|
||||
// the connection from closing while the operation is in flight.
|
||||
type conn struct {
|
||||
// immutable once [newConn] returns:
|
||||
family int32
|
||||
localAddr net.Addr
|
||||
localAddrPort netip.AddrPort
|
||||
dualStack bool
|
||||
sotype int32
|
||||
proto int32
|
||||
net string
|
||||
config *Config
|
||||
|
||||
// guard prevents the connection from closing and its resources from
|
||||
// being freed while operations are in flight. All fields below are
|
||||
// protected by guard and must not be accessed after the connection
|
||||
// is closed and [guard.Acquire] returns false, except by [conn.Close].
|
||||
guard *guard
|
||||
closedEvt windows.Handle
|
||||
socket windows.Handle
|
||||
|
||||
// closeMu serializes calls to [conn.Close].
|
||||
// Lock order: closeMu > mu.
|
||||
closeMu sync.Mutex
|
||||
|
||||
// mu serializes access to the RIO request queue.
|
||||
// Lock order: closeMu > mu.
|
||||
mu sync.Mutex
|
||||
rq winrio.Rq
|
||||
}
|
||||
|
||||
// rawConn implements [syscall.RawConn] for [conn].
|
||||
type rawConn conn
|
||||
|
||||
var (
|
||||
_ syscall.Conn = (*conn)(nil)
|
||||
_ syscall.RawConn = (*rawConn)(nil)
|
||||
)
|
||||
|
||||
func newConn(sotype int32, proto int32, dualStack bool, laddr netip.AddrPort, config *Config) (_ *conn, err error) {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
sa, family, err := sockaddrFromAddrPort(laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
net, err := networkName(sotype, proto, family, dualStack)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &conn{
|
||||
family: family,
|
||||
dualStack: dualStack,
|
||||
sotype: sotype,
|
||||
proto: proto,
|
||||
net: net,
|
||||
config: config,
|
||||
guard: newGuard(),
|
||||
}
|
||||
defer func() {
|
||||
// If initialization fails, close the connection to release
|
||||
// any resources allocated before the error.
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
// Create a manual-reset event to wake up pending operations on Close.
|
||||
if conn.closedEvt, err = windows.CreateEvent(nil, 1, 0, nil); err != nil {
|
||||
return nil, fmt.Errorf("failed to create close notification event: %w", err)
|
||||
}
|
||||
// Create a socket with the WSA_FLAG_REGISTERED_IO flag set.
|
||||
if conn.socket, err = rioSocket(family, sotype, proto); err != nil {
|
||||
return nil, fmt.Errorf("failed to create socket(%d, %d, %d): %w", family, sotype, proto, err)
|
||||
}
|
||||
// Enable dual-stack mode by clearing the IPV6_V6ONLY option, if necessary.
|
||||
// https://web.archive.org/web/20260208062136/https://learn.microsoft.com/en-us/windows/win32/winsock/dual-stack-sockets
|
||||
if dualStack {
|
||||
if err := windows.SetsockoptInt(conn.socket, windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable dual-stack mode: %w", err)
|
||||
}
|
||||
}
|
||||
// Invoke caller-provided control functions to set socket options before binding.
|
||||
if err := conn.config.Control(net, laddr.String(), (*rawConn)(conn)); err != nil {
|
||||
return nil, fmt.Errorf("control failed: %w", err)
|
||||
}
|
||||
if err := windows.Bind(conn.socket, sa); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind socket: %w", err)
|
||||
}
|
||||
// Record the local address from the actual socket, since the caller
|
||||
// may have specified port 0 for automatic assignment.
|
||||
if conn.localAddrPort, err = addrPortFromSocket(conn.socket); err != nil {
|
||||
return nil, fmt.Errorf("failed to get local address and port: %w", err)
|
||||
}
|
||||
if conn.localAddr, err = netAddrFromAddrPort(conn.localAddrPort, sotype); err != nil {
|
||||
return nil, fmt.Errorf("failed to convert local address and port to net.Addr: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// IsClosed reports whether Close has been called.
|
||||
func (c *conn) IsClosed() bool {
|
||||
return c.guard.IsClosed()
|
||||
}
|
||||
|
||||
// acquire increments the connection's reference count, preventing
|
||||
// it from closing. If it returns no error, the caller may use the
|
||||
// connection and must call [conn.release] when done. Otherwise,
|
||||
// the connection must not be used.
|
||||
func (c *conn) acquire() error {
|
||||
if !c.guard.Acquire() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// release decrements the connection's reference count.
|
||||
// Calling release without a matching acquire is a run-time error.
|
||||
func (c *conn) release() {
|
||||
c.guard.Release()
|
||||
}
|
||||
|
||||
// Family returns the socket address family of the connection.
|
||||
func (c *conn) Family() int32 {
|
||||
return c.family
|
||||
}
|
||||
|
||||
// Network returns the network name of the connection.
|
||||
func (c *conn) Network() string {
|
||||
return c.net
|
||||
}
|
||||
|
||||
// IsDualStack reports whether the connection is dual-stack and can send
|
||||
// and receive packets to and from IPv6 or IPv4-mapped IPv6 addresses.
|
||||
func (c *conn) IsDualStack() bool {
|
||||
return c.dualStack
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
// LocalAddrPort returns the local network address and port.
|
||||
func (c *conn) LocalAddrPort() netip.AddrPort {
|
||||
return c.localAddrPort
|
||||
}
|
||||
|
||||
// SyscallConn returns a raw network connection, or an error if the connection is closed.
|
||||
func (c *conn) SyscallConn() (syscall.RawConn, error) {
|
||||
if c.IsClosed() {
|
||||
// Return the error immediately if the connection is already closed.
|
||||
// The [conn] implementation handles the case where the connection
|
||||
// closes after this call returns, so this is only an optimization.
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c.syscallConn(), nil
|
||||
}
|
||||
|
||||
func (c *conn) syscallConn() syscall.RawConn {
|
||||
return (*rawConn)(c)
|
||||
}
|
||||
|
||||
// Control implements [syscall.RawConn.Control].
|
||||
func (c *rawConn) Control(f func(uintptr)) error {
|
||||
return (*conn)(c).rawControl(f)
|
||||
}
|
||||
|
||||
// rawControl implements [rawConn.Control].
|
||||
func (c *conn) rawControl(f func(uintptr)) error {
|
||||
if !c.guard.Acquire() {
|
||||
return &net.OpError{Op: "raw-control", Net: c.net, Addr: c.localAddr, Err: net.ErrClosed}
|
||||
}
|
||||
defer c.guard.Release()
|
||||
f(uintptr(c.socket))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read implements [syscall.RawConn.Read].
|
||||
func (c *rawConn) Read(f func(uintptr) bool) error {
|
||||
return &net.OpError{Op: "raw-read", Net: c.net, Source: c.localAddr, Err: errors.ErrUnsupported}
|
||||
}
|
||||
|
||||
// Write implements [syscall.RawConn.Write].
|
||||
func (c *rawConn) Write(f func(uintptr) bool) error {
|
||||
return &net.OpError{Op: "raw-write", Net: c.net, Source: c.localAddr, Err: errors.ErrUnsupported}
|
||||
}
|
||||
|
||||
// createRequestQueue creates a RIO request queue for the connection.
|
||||
// It must be called before sending or receiving data. The call fails
|
||||
// if a request queue already exists, if any parameter is invalid,
|
||||
// or if RIO request queue creation fails.
|
||||
//
|
||||
// The caller must ensure that the connection is not closed while
|
||||
// this call is in progress and that the provided completion queues
|
||||
// have sufficient capacity for the specified number of outstanding
|
||||
// requests, plus any requests posted by other connections sharing
|
||||
// the same completion queues. As of 2026-02-17, completion queues
|
||||
// are currently not shared between connections.
|
||||
func (c *conn) createRequestQueue(
|
||||
receiveCq winrio.Cq, maxOutstandingReceives uint32,
|
||||
sendCq winrio.Cq, maxOutstandingSends uint32,
|
||||
) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.rq != 0 {
|
||||
return errors.New("already created")
|
||||
}
|
||||
if receiveCq == 0 {
|
||||
return errors.New("invalid Rx completion queue")
|
||||
}
|
||||
if sendCq == 0 {
|
||||
return errors.New("invalid Tx completion queue")
|
||||
}
|
||||
if maxOutstandingReceives == 0 {
|
||||
return errors.New("invalid max outstanding receives")
|
||||
}
|
||||
if maxOutstandingSends == 0 {
|
||||
return errors.New("invalid max outstanding sends")
|
||||
}
|
||||
var err error
|
||||
c.rq, err = winrio.CreateRequestQueue(c.socket,
|
||||
maxOutstandingReceives, 1,
|
||||
maxOutstandingSends, 1,
|
||||
receiveCq, sendCq,
|
||||
0,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// postReceiveRequests posts multiple receive requests to the RIO request queue.
|
||||
// It returns an error if posting any request fails.
|
||||
//
|
||||
// The caller must ensure that the connection is not closed until this call returns.
|
||||
func (c *conn) postReceiveRequests(reqs iter.Seq[*request]) (err error) {
|
||||
var deferred int
|
||||
defer func() {
|
||||
// Always commit any deferred receive requests, even if an error occurred.
|
||||
if deferred != 0 {
|
||||
if commitErr := c.commitReceiveRequests(); commitErr != nil {
|
||||
err = errors.Join(err, commitErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for req := range reqs {
|
||||
if err := c.postReceiveRequestLocked(req, winrio.MsgDefer); err != nil {
|
||||
return fmt.Errorf("failed to post receive request #%d: %w", deferred, err)
|
||||
}
|
||||
deferred++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// postReceiveRequestLocked posts a single receive request to the
|
||||
// RIO request queue.
|
||||
//
|
||||
// c.mu must be held, and the caller must ensure
|
||||
// that the connection is not closed until this call returns.
|
||||
func (c *conn) postReceiveRequestLocked(req *request, flags uint32) error {
|
||||
return req.PostReceive(c.rq, flags)
|
||||
}
|
||||
|
||||
// commitReceiveRequests commits previously deferred receive requests.
|
||||
//
|
||||
// The caller must ensure that the connection is not closed until
|
||||
// this call returns. It may be called with or without c.mu held.
|
||||
func (c *conn) commitReceiveRequests() error {
|
||||
// Unlike other ReceiveEx calls, commits do not need to be serialized:
|
||||
// https://web.archive.org/web/20260216052922/https://learn.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_rioreceiveex
|
||||
if err := winrio.ReceiveEx(c.rq, nil, 0, nil, nil, nil, nil, winrio.MsgCommitOnly, 0); err != nil {
|
||||
return fmt.Errorf("failed to commit deferred receive requests: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// postSendRequest posts a single send request to the RIO request queue.
|
||||
// The caller must ensure that the connection is not closed until this call returns.
|
||||
func (c *conn) postSendRequest(req *request, flags uint32) error {
|
||||
// Submit the send request. As the underlying RIO request queue
|
||||
// is not thread-safe, we need to serialize access to it.
|
||||
c.mu.Lock()
|
||||
err := req.PostSend(c.rq, flags)
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// commitSendRequests commits previously deferred send requests.
|
||||
//
|
||||
// The caller must ensure that the connection is not closed until
|
||||
// this call returns. It may be called with or without c.mu held.
|
||||
func (c *conn) commitSendRequests() error {
|
||||
// Unlike other SendEx calls, commits do not need to be serialized:
|
||||
// https://web.archive.org/web/20260216053051/https://learn.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_riosendex
|
||||
if err := winrio.SendEx(c.rq, nil, 0, nil, nil, nil, nil, winrio.MsgCommitOnly, 0); err != nil {
|
||||
return fmt.Errorf("failed to commit deferred send requests: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection, cancels any pending operations,
|
||||
// and releases all associated resources.
|
||||
// Close is safe for concurrent use.
|
||||
func (c *conn) Close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.closeMu.Lock()
|
||||
defer c.closeMu.Unlock()
|
||||
|
||||
c.guard.Close() // prevent new operations
|
||||
if c.closedEvt != 0 { // wake up blocked operations
|
||||
if err := windows.SetEvent(c.closedEvt); err != nil {
|
||||
return fmt.Errorf("failed to set close notification event: %w", err)
|
||||
}
|
||||
}
|
||||
c.guard.Wait()
|
||||
// At this point, no operations are in flight and no new ones can start,
|
||||
// so it is safe to release resources.
|
||||
if c.socket != 0 {
|
||||
windows.Closesocket(c.socket)
|
||||
c.socket = 0
|
||||
}
|
||||
if c.closedEvt != 0 {
|
||||
windows.CloseHandle(c.closedEvt)
|
||||
c.closedEvt = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rioSocket(family, sotype, proto int32) (windows.Handle, error) {
|
||||
const rioWSAFlags = windows.WSA_FLAG_REGISTERED_IO |
|
||||
windows.WSA_FLAG_NO_HANDLE_INHERIT |
|
||||
windows.WSA_FLAG_OVERLAPPED
|
||||
return windows.WSASocket(family, sotype, proto, nil, 0, rioWSAFlags)
|
||||
}
|
||||
239
net/rioconn/conn_test.go
Normal file
239
net/rioconn/conn_test.go
Normal file
@ -0,0 +1,239 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sotype int32
|
||||
proto int32
|
||||
dualStack bool
|
||||
laddr netip.AddrPort
|
||||
wantErr bool
|
||||
wantNetwork string
|
||||
wantFamily int32
|
||||
wantPort uint16 // 0 means any port
|
||||
wantAddr netip.Addr
|
||||
}{
|
||||
{
|
||||
name: "IPv4/UDP/AnyAddr/EphemeralPort",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
|
||||
wantAddr: netip.MustParseAddr("0.0.0.0"),
|
||||
wantFamily: syscall.AF_INET,
|
||||
wantNetwork: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6/UDP/AnyAddr/EphemeralPort",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
laddr: netip.MustParseAddrPort("[::]:0"),
|
||||
wantAddr: netip.MustParseAddr("::"),
|
||||
wantFamily: syscall.AF_INET6,
|
||||
wantNetwork: "udp6",
|
||||
},
|
||||
{
|
||||
name: "IPv4/UDP/LoopbackAddr/EphemeralPort",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
laddr: netip.MustParseAddrPort("127.0.0.1:0"),
|
||||
wantAddr: netip.MustParseAddr("127.0.0.1"),
|
||||
wantFamily: syscall.AF_INET,
|
||||
wantNetwork: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6/UDP/LoopbackAddr/EphemeralPort",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
laddr: netip.MustParseAddrPort("[::1]:0"),
|
||||
wantAddr: netip.MustParseAddr("::1"),
|
||||
wantFamily: syscall.AF_INET6,
|
||||
wantNetwork: "udp6",
|
||||
},
|
||||
{
|
||||
name: "IPv6/UDP/AnyAddr/EphemeralPort/DualStack",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
dualStack: true,
|
||||
laddr: netip.MustParseAddrPort("[::]:0"),
|
||||
wantAddr: netip.MustParseAddr("::"),
|
||||
wantFamily: syscall.AF_INET6,
|
||||
wantNetwork: "udp",
|
||||
},
|
||||
{
|
||||
name: "IPv4/TCP/AnyAddr/EphemeralPort",
|
||||
sotype: syscall.SOCK_STREAM,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
|
||||
wantAddr: netip.MustParseAddr("0.0.0.0"),
|
||||
wantFamily: syscall.AF_INET,
|
||||
wantNetwork: "tcp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6/TCP/AnyAddr/EphemeralPort",
|
||||
sotype: syscall.SOCK_STREAM,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
laddr: netip.MustParseAddrPort("[::]:0"),
|
||||
wantAddr: netip.MustParseAddr("::"),
|
||||
wantFamily: syscall.AF_INET6,
|
||||
wantNetwork: "tcp6",
|
||||
},
|
||||
{
|
||||
name: "IPv6/TCP/AnyAddr/EphemeralPort/DualStack",
|
||||
sotype: syscall.SOCK_STREAM,
|
||||
proto: syscall.IPPROTO_TCP,
|
||||
dualStack: true,
|
||||
laddr: netip.MustParseAddrPort("[::]:0"),
|
||||
wantAddr: netip.MustParseAddr("::"),
|
||||
wantFamily: syscall.AF_INET6,
|
||||
wantNetwork: "tcp",
|
||||
},
|
||||
{
|
||||
name: "InvalidSOType",
|
||||
sotype: 12345,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidProtocol",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: 12345,
|
||||
laddr: netip.MustParseAddrPort("0.0.0.0:0"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidAddress",
|
||||
sotype: syscall.SOCK_DGRAM,
|
||||
proto: syscall.IPPROTO_UDP,
|
||||
laddr: netip.AddrPort{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, err := newConn(tt.sotype, tt.proto, tt.dualStack, tt.laddr, nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("newConn: got error %v, want: %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if got := c.Network(); got != tt.wantNetwork {
|
||||
t.Errorf("newConn: got network %q, want %q", got, tt.wantNetwork)
|
||||
}
|
||||
|
||||
if got := c.IsDualStack(); got != tt.dualStack {
|
||||
t.Errorf("newConn: got dualStack %v, want %v", got, tt.dualStack)
|
||||
}
|
||||
|
||||
if got := c.Family(); got != tt.wantFamily {
|
||||
t.Errorf("newConn: got family %d, want %d", got, tt.wantFamily)
|
||||
}
|
||||
|
||||
gotAddrPort := c.LocalAddrPort()
|
||||
if !gotAddrPort.IsValid() {
|
||||
t.Fatalf("newConn: got invalid local address %v", gotAddrPort)
|
||||
}
|
||||
|
||||
if gotAddrPort.Addr() != tt.wantAddr {
|
||||
t.Errorf("newConn: got address %v, want %v", gotAddrPort.Addr(), tt.wantAddr)
|
||||
}
|
||||
|
||||
if tt.wantPort != 0 && gotAddrPort.Port() != tt.wantPort {
|
||||
t.Errorf("newConn: got port %d, want %d", gotAddrPort.Port(), tt.wantPort)
|
||||
} else if gotAddrPort.Port() == 0 {
|
||||
t.Errorf("newConn: got port 0, want non-zero")
|
||||
}
|
||||
|
||||
if c.LocalAddr().String() != gotAddrPort.String() {
|
||||
t.Errorf("newConn: LocalAddr %q, LocalAddrPort %q", c.LocalAddr(), gotAddrPort)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnAcquireReleaseClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, err := newConn(syscall.SOCK_DGRAM, syscall.IPPROTO_UDP, false, netip.MustParseAddrPort("0.0.0.0:0"), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("newConn failed: %v", err)
|
||||
}
|
||||
if c.IsClosed() {
|
||||
t.Fatal("newConn: got closed connection")
|
||||
}
|
||||
if err := c.acquire(); err != nil {
|
||||
t.Errorf("acquire failed: %v", err)
|
||||
}
|
||||
go c.release() // race with close
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
if !c.IsClosed() {
|
||||
t.Fatal("Close did not mark connection as closed")
|
||||
}
|
||||
if err := c.acquire(); err == nil {
|
||||
t.Fatal("acquire succeeded on closed connection")
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("Close failed on already closed connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyscallConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, err := newConn(syscall.SOCK_DGRAM, syscall.IPPROTO_UDP, false,
|
||||
netip.MustParseAddrPort("0.0.0.0:0"), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("newConn failed: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
syscallConn, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatalf("SyscallConn failed: %v", err)
|
||||
}
|
||||
|
||||
err = syscallConn.Control(func(fd uintptr) {
|
||||
if fd != uintptr(c.socket) {
|
||||
t.Fatalf("Control: got fd %v, want %v", fd, c.socket)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Control failed: %v", err)
|
||||
}
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
err = syscallConn.Control(func(fd uintptr) {
|
||||
t.Fatalf("Control succeeded on closed connection with fd %v", fd)
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Control succeeded on closed connection")
|
||||
}
|
||||
}
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// request represents a portion of a RIO-registered memory buffer used for
|
||||
@ -85,6 +86,82 @@ func (r *request) Reader() *requestReader {
|
||||
return (*requestReader)(r)
|
||||
}
|
||||
|
||||
// PostSend posts the request as a send operation to the given RIO request queue
|
||||
// with the specified flags.
|
||||
func (r *request) PostSend(rq winrio.Rq, flags uint32) error {
|
||||
data := winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(len(r.data)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
|
||||
}
|
||||
remoteAddr := r.remoteAddrDesc()
|
||||
return winrio.SendEx(rq, &data, 1, nil, &remoteAddr, nil, nil, flags, uintptr(unsafe.Pointer(r)))
|
||||
}
|
||||
|
||||
// PostReceive posts the request as a receive operation to the given RIO request queue
|
||||
// with the specified flags.
|
||||
func (r *request) PostReceive(rq winrio.Rq, flags uint32) error {
|
||||
r.data = r.data[:0]
|
||||
data := winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(cap(r.data)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(unsafe.SliceData(r.data))) - r.buffBase),
|
||||
}
|
||||
remoteAddress := r.remoteAddrDesc()
|
||||
return winrio.ReceiveEx(rq, &data, 1, nil, &remoteAddress, nil,
|
||||
nil, flags, uintptr(unsafe.Pointer(r)))
|
||||
}
|
||||
|
||||
// CompleteSend finalizes a send request.
|
||||
//
|
||||
// It validates the completion status and the number of bytes written,
|
||||
// returning an error if the status indicates a failure, or if the number
|
||||
// of bytes written does not match the length of the request's data buffer.
|
||||
func (r *request) CompleteSend(status int32, bytesWritten uint32) error {
|
||||
expected := len(r.data)
|
||||
if status != 0 {
|
||||
return windows.Errno(status)
|
||||
}
|
||||
if uint64(bytesWritten) != uint64(expected) {
|
||||
return fmt.Errorf(
|
||||
"bytes written (%d) does not match data buffer length (%d)",
|
||||
bytesWritten,
|
||||
expected,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteReceive finalizes a receive request.
|
||||
//
|
||||
// It validates the completion status and the number of bytes read
|
||||
// returning an error if the status indicates a failure, or if the number
|
||||
// of bytes read exceeds the capacity of the request's data buffer.
|
||||
//
|
||||
// On success, it returns a reader view of the request.
|
||||
func (r *request) CompleteReceive(status int32, bytesRead uint32) (*requestReader, error) {
|
||||
if status != 0 {
|
||||
return nil, windows.Errno(status)
|
||||
}
|
||||
if uint64(bytesRead) > uint64(cap(r.data)) {
|
||||
return nil, fmt.Errorf(
|
||||
"bytes read (%d) exceeds data buffer capacity (%d)",
|
||||
bytesRead,
|
||||
cap(r.data),
|
||||
)
|
||||
}
|
||||
r.data = r.data[:bytesRead]
|
||||
return r.Reader(), nil
|
||||
}
|
||||
|
||||
func (r *request) remoteAddrDesc() winrio.Buffer {
|
||||
return winrio.Buffer{
|
||||
Id: r.buffID,
|
||||
Length: uint32(unsafe.Sizeof(r.raddr)),
|
||||
Offset: uint32(uintptr(unsafe.Pointer(&r.raddr)) - r.buffBase),
|
||||
}
|
||||
}
|
||||
|
||||
// Reset prepares the request for reuse by resetting its state.
|
||||
func (r *request) Reset() {
|
||||
r.raddr = rawSockaddr{}
|
||||
|
||||
@ -10,6 +10,8 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestUnsafeMapRequest(t *testing.T) {
|
||||
@ -279,6 +281,128 @@ func TestRequestReader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCompleteSend(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payloadLength int
|
||||
bytesWritten uint32
|
||||
status int32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty/success",
|
||||
bytesWritten: 0,
|
||||
status: 0,
|
||||
},
|
||||
{
|
||||
name: "non-empty/success",
|
||||
payloadLength: 100,
|
||||
bytesWritten: 100,
|
||||
status: 0,
|
||||
},
|
||||
{
|
||||
name: "non-empty/partial-write",
|
||||
payloadLength: 100,
|
||||
bytesWritten: 50,
|
||||
status: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-empty/overflow-write",
|
||||
payloadLength: 100,
|
||||
bytesWritten: 150,
|
||||
status: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-empty/error",
|
||||
payloadLength: 100,
|
||||
bytesWritten: 0,
|
||||
status: int32(windows.ERROR_GEN_FAILURE),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := makeRequest(t, tt.payloadLength)
|
||||
req.Writer().Write(bytes.Repeat([]byte{'x'}, tt.payloadLength))
|
||||
|
||||
err := req.CompleteSend(tt.status, tt.bytesWritten)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Error("request.CompleteSend: expected error; got nil")
|
||||
} else if !tt.wantErr && err != nil {
|
||||
t.Errorf("request.CompleteSend: unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCompleteReceive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
bytesRead uint32
|
||||
status int32
|
||||
wantErr bool
|
||||
wantBytes []byte
|
||||
}{
|
||||
{
|
||||
name: "empty/success",
|
||||
data: nil,
|
||||
bytesRead: 0,
|
||||
wantBytes: nil,
|
||||
},
|
||||
{
|
||||
name: "non-empty/success/full",
|
||||
data: []byte("Hello World"),
|
||||
bytesRead: 11,
|
||||
wantBytes: []byte("Hello World"),
|
||||
},
|
||||
{
|
||||
name: "non-empty/success/partial",
|
||||
data: []byte("Hello World"),
|
||||
// Simulate a partial read by reporting fewer bytes
|
||||
// that the request's data buffer length.
|
||||
bytesRead: 5,
|
||||
wantErr: false,
|
||||
wantBytes: []byte("Hello"),
|
||||
},
|
||||
{
|
||||
name: "non-empty/overflow-read",
|
||||
data: []byte("Hello World"),
|
||||
bytesRead: 20, // more than the data buffer can hold
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-empty/error",
|
||||
data: []byte("Hello World"),
|
||||
bytesRead: 0,
|
||||
status: int32(windows.ERROR_GEN_FAILURE),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := makeRequest(t, len(tt.data))
|
||||
copy(req.data[:cap(req.data)], tt.data)
|
||||
|
||||
r, err := req.CompleteReceive(tt.status, tt.bytesRead)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Error("request.CompleteReceive: expected error; got nil")
|
||||
} else if !tt.wantErr && err != nil {
|
||||
t.Errorf("request.CompleteReceive: unexpected error: %v", err)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(r.Bytes(), tt.wantBytes) {
|
||||
t.Errorf("requestReader.Bytes: got %q; want %q", r.Bytes(), tt.wantBytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeRequest(tb testing.TB, dataSize int) *request {
|
||||
tb.Helper()
|
||||
buff := makeAligned(totalRequestSize(uintptr(dataSize)), unsafe.Alignof(request{}))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user