diff --git a/net/rioconn/doc.go b/net/rioconn/doc.go new file mode 100644 index 000000000..406cdef73 --- /dev/null +++ b/net/rioconn/doc.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "fmt" + + "github.com/tailscale/wireguard-go/conn/winrio" +) + +// ErrRIOUnavailable is returned when Windows RIO is required but not available. +var ErrRIOUnavailable = fmt.Errorf("Registered I/O (RIO) is not available on this system") + +// Initialize initializes the Windows RIO API extensions. +// It returns [ErrRIOUnavailable] if RIO cannot be used. +func Initialize() error { + if !winrio.Initialize() { + return ErrRIOUnavailable + } + return nil +} diff --git a/net/rioconn/memory.go b/net/rioconn/memory.go new file mode 100644 index 000000000..38e9705bd --- /dev/null +++ b/net/rioconn/memory.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package rioconn + +import ( + "math/bits" + + "golang.org/x/exp/constraints" +) + +// alignUp returns the smallest value >= v that is a multiple of alignment. +// The alignment must be a power of two. +func alignUp[V, A constraints.Integer](v V, alignment A) V { + return (v + V(alignment) - 1) &^ (V(alignment) - 1) +} + +// alignUpOffset rounds offset up so that base+offset is aligned to the +// specified boundary. Alignment must be a power of two. +func alignUpOffset(base, offset, alignment uintptr) uintptr { + return alignUp(base+offset, alignment) - base +} + +// isPowerOfTwo reports whether n is a power of two. +func isPowerOfTwo[T constraints.Integer](n T) bool { + return n > 0 && (n&(n-1)) == 0 +} + +// floorPowerOfTwo returns the largest power of two <= n. 
+func floorPowerOfTwo[T constraints.Unsigned](n T) T { + if n == 0 { + return 0 + } + return 1 << (bits.Len64(uint64(n)) - 1) +} diff --git a/net/rioconn/request.go b/net/rioconn/request.go new file mode 100644 index 000000000..e194f7e22 --- /dev/null +++ b/net/rioconn/request.go @@ -0,0 +1,187 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "errors" + "fmt" + "net/netip" + "unsafe" + + "github.com/tailscale/wireguard-go/conn/winrio" +) + +// request represents a portion of a RIO-registered memory buffer used for +// a single send or receive operation. It is always heap-allocated, and the +// fixed-size struct is followed in memory by a variable-size data buffer. +// +// A memory buffer implementation, such as a [requestRing], is responsible for +// allocating requests within a registered buffer via [unsafeMapRequest] and +// ensuring that the buffer remains registered and valid for the lifetime +// of the requests. +type request struct { + buffID winrio.BufferId // ID of the registered buffer containing this request + buffBase uintptr // base address of the registered buffer + + raddr rawSockaddr // remote address for RIO send/receive operations + data []byte // a slice pointing into the data buffer area after the struct + // followed by the actual data at [requestDataOffset] + // from the start of the struct. +} + +const ( + requestDataAlignment = 8 + requestDataOffset = (unsafe.Sizeof(request{}) + requestDataAlignment - 1) &^ (requestDataAlignment - 1) +) + +// totalRequestSize returns the total number of bytes required to hold +// a [request] struct followed by a data buffer of the given size. 
+func totalRequestSize(dataSize uintptr) uintptr { + return alignUp(requestDataOffset+dataSize, unsafe.Alignof(request{})) +} + +// unsafeMapRequest maps a [request] into the given RIO-registered buffer at the +// specified offset and returns a pointer to it, the number of bytes used, +// and whether the mapping succeeded. +// +// On success, the returned pointer is aligned to the [request]'s natural +// alignment and the request can hold up to dataSize bytes of data. +// +// It is the caller's responsibility to ensure that the buffer remains +// registered and valid for the lifetime of the returned request. +func unsafeMapRequest(buffID winrio.BufferId, buff []byte, offset, dataSize uintptr) (_ *request, n uintptr, ok bool) { + baseAddr := uintptr(unsafe.Pointer(unsafe.SliceData(buff))) + alignedOffset := alignUpOffset(baseAddr, offset, unsafe.Alignof(request{})) + bytesNeeded := totalRequestSize(dataSize) + uintptr(alignedOffset-offset) + if offset >= uintptr(len(buff)) { + return nil, 0, false + } + bytesAvailable := uintptr(len(buff)) - offset + if bytesAvailable < bytesNeeded { + return nil, 0, false + } + + requestBytes := unsafe.SliceData(buff[alignedOffset:]) + request := (*request)(unsafe.Pointer(requestBytes)) + request.buffID = buffID + request.buffBase = baseAddr + request.data = unsafe.Slice( + (*byte)(unsafe.Add(unsafe.Pointer(request), requestDataOffset)), + dataSize, + )[:0] // zero-length data slice with capacity of dataSize + return request, bytesNeeded, true +} + +// Writer returns a [requestWriter] for the request. +func (r *request) Writer() *requestWriter { + return (*requestWriter)(r) +} + +// Reader returns a [requestReader] for the request. +func (r *request) Reader() *requestReader { + return (*requestReader)(r) +} + +// Reset prepares the request for reuse by resetting its state. 
+func (r *request) Reset() { + r.raddr = rawSockaddr{} + r.data = r.data[:0] +} + +type ( + requestWriter request + requestReader request +) + +// Len returns the number of bytes written or reserved so far. +func (w *requestWriter) Len() int { + return len(w.data) +} + +// Cap returns the maximum number of bytes that can be written or reserved. +func (w *requestWriter) Cap() int { + return cap(w.data) +} + +// Available returns the number of bytes available for writing or reserving. +func (w *requestWriter) Available() int { + return cap(w.data) - len(w.data) +} + +// SetRemoteAddrPort sets the remote address for the request from a [netip.AddrPort]. +// It returns an error if the specified address cannot converted to a [rawSockaddr]. +func (w *requestWriter) SetRemoteAddrPort(raddr netip.AddrPort) error { + var err error + w.raddr, err = rawSockaddrFromAddrPort(raddr) + return err +} + +// SetRemoteAddr sets the remote address for the request from a [rawSockaddr]. +func (w *requestWriter) SetRemoteAddr(raddr rawSockaddr) { + w.raddr = raddr +} + +// Reserve reserves n bytes in the request's data buffer, +// and returns a slice pointing to the reserved space. +// It panics if n is negative or exceeds the available capacity. +func (w *requestWriter) Reserve(n int) []byte { + if n < 0 { + panic(fmt.Errorf("cannot reserve negative bytes: %d", n)) + } + if avail := w.Available(); n > avail { + panic(fmt.Errorf("cannot reserve %d bytes: only %d available", n, avail)) + } + oldLen := len(w.data) + newLen := oldLen + n + w.data = w.data[:newLen] + return w.data[oldLen:newLen:newLen] // prevent reslicing beyond newLen +} + +// Write implements [io.Writer]. 
+func (w *requestWriter) Write(p []byte) (n int, err error) { + if len(p) > w.Available() { + return 0, errors.New("not enough space to write data") + } + oldLen := len(w.data) + newLen := oldLen + len(p) + w.data = w.data[:newLen] + copy(w.data[oldLen:], p) + return len(p), nil +} + +// SetLen sets the length of the request's data buffer to n. +// It panics if n is negative or exceeds the number of bytes written. +func (w *requestWriter) SetLen(n int) { + if n < 0 { + panic(fmt.Errorf("cannot set negative data length: %d", n)) + } + if n > len(w.data) { + panic(fmt.Errorf("data length %d exceeds the number of bytes written (%d)", n, len(w.data))) + } + w.data = w.data[:n] +} + +// Len returns the number of bytes available to read. +func (r *requestReader) Len() int { + return len(r.data) +} + +// Bytes returns the request's payload data. +func (r *requestReader) Bytes() []byte { + return r.data[:len(r.data):len(r.data)] // prevent reslicing beyond len +} + +// RemoteAddrPort returns the request's remote address as a [netip.AddrPort]. +func (r *requestReader) RemoteAddrPort() (netip.AddrPort, error) { + return r.raddr.ToAddrPort() +} + +// RemoteAddr returns the request's remote address as an [rawSockaddr]. +// It is more efficient than [requestReader.RemoteAddrPort] if the caller +// only needs the raw socket address. +func (r *requestReader) RemoteAddr() rawSockaddr { + return r.raddr +} diff --git a/net/rioconn/request_ring.go b/net/rioconn/request_ring.go new file mode 100644 index 000000000..2183f897a --- /dev/null +++ b/net/rioconn/request_ring.go @@ -0,0 +1,310 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "fmt" + "iter" + "math" + "unsafe" + + "github.com/tailscale/wireguard-go/conn/winrio" + "golang.org/x/sys/windows" + "tailscale.com/util/winutil" +) + +// requestRing is a circular buffer of [request]s. +// It is not safe for concurrent use. 
+type requestRing struct { + capacity uint32 // number of requests in the ring; always a power of two + dataSize uint32 // per-request data buffer length in bytes + stride uintptr // byte offset from one request to the next + indexMask uint32 // masks an index to [0, capacity) + + ptr uintptr // base address of the ring buffer allocation + size uintptr // size of the allocation in bytes + buff []byte // a byte slice view of the allocated buffer + largePages bool // whether the allocation used large pages + + head, tail uint32 // monotonic counters; apply &indexMask for indexing + + id winrio.BufferId +} + +// When a memory buffer is registered with RIO, the virtual memory pages +// containing the buffer are locked into physical memory. +// Set hard limits to avoid excessive memory usage. +// TODO(nickkhyl): derive preferred values from system parameters? +const ( + maxRequestRingSize = 1 << 30 // 1 GiB; arbitrary limit + maxNumberOfRequests = 16384 // arbitrary power-of-two limit + maxRequestDataSize = math.MaxUint16 // IP packet size limit +) + +// requestStride returns the byte offset from the start of one [request] +// in a [requestRing] to the start of the next, given the per-request data buffer length. +func requestStride(dataSize uint16) uintptr { + return totalRequestSize(uintptr(dataSize)) +} + +// maxRequestRingCapacity returns the maximum ring buffer capacity that fits +// within maxBytes, given the per-request data size. The returned capacity +// does not exceed idealCapacity unless idealCapacity is zero, in which case the +// maximum possible value is returned. +// +// dataSize must be in (0, [maxRequestDataSize]], and idealCapacity, if non-zero, +// must be a power of two. The function returns an error if the buffer cannot +// +// It returns an error if the parameters are invalid, if the resulting +// capacity cannot hold at least one request, or if maxBytes exceeds +// [maxRequestRingSize]. 
+func maxRequestRingCapacity(idealCapacity uint32, dataSize uint16, maxBytes uintptr) (uint32, error) { + if idealCapacity != 0 && !isPowerOfTwo(idealCapacity) { + return 0, fmt.Errorf("the capacity must be a power of two, got %d", idealCapacity) + } + capacity := min(idealCapacity, maxNumberOfRequests) + + if dataSize == 0 || dataSize > maxRequestDataSize { + return 0, fmt.Errorf("the data size must be in (0, %d], got %d", maxRequestDataSize, dataSize) + } + + stride := requestStride(dataSize) + if maxBytes < stride { + return 0, fmt.Errorf("cannot fit any requests within maxBytes %d", maxBytes) + } + if maxBytes > maxRequestRingSize { + return 0, fmt.Errorf("maxBytes %d exceeds limit of %d", maxBytes, maxRequestRingSize) + } + + if capacity == 0 || uintptr(capacity)*stride > maxBytes { + capacity = uint32(floorPowerOfTwo(maxBytes / stride)) + } + return capacity, nil +} + +const ( + // seLockMemoryPrivilege is the name of the Windows privilege required to allocate large pages. + seLockMemoryPrivilege = "SeLockMemoryPrivilege" +) + +// newRequestRing creates a ring buffer of up to maxBytes bytes, +// with each element representing a RIO request backed by a data buffer +// of dataSize bytes. +// +// It determines the buffer capacity as the maximum power-of-two number +// of requests that fits within the allocation size limit, using large +// pages when possible. +// +// If the allocation fails due to insufficient memory, it retries +// with progressively smaller sizes until it succeeds or cannot fit +// at least one request. +// +// The returned buffer is registered with RIO and must be closed with +// [requestRing.Close] to unregister it and free the memory. 
+func newRequestRing(dataSize uint16, maxBytes uintptr) (_ *requestRing, err error) { + rb := &requestRing{ + dataSize: uint32(dataSize), + stride: requestStride(dataSize), + } + defer func() { + if err != nil { + rb.Close() + } + }() + + var largePageSize uintptr // 0 means "do not use large pages" + // The SeLockMemoryPrivilege privilege is required to allocate large pages. + // By default, this privilege can be requested only by processes + // running as Local System, such as the Tailscale service, + // and is not available to regular user processes (e.g., when running tests). + // If enabling the privilege fails, we fall back to normal pages. + // For testing, you can grant the privilege to your user account + // using the Local Security Policy management console (secpol.msc). + dropPrivs, err := winutil.EnableCurrentThreadPrivilege(seLockMemoryPrivilege) + if err == nil { + defer dropPrivs() + largePageSize = windows.GetLargePageMinimum() + } + +loop: + for { + capacity, err := maxRequestRingCapacity(0, dataSize, maxBytes) + if err != nil { + // The requested parameters are invalid. + return nil, err + } + + rb.capacity = uint32(capacity) + rb.indexMask = uint32(capacity - 1) + rb.size = rb.stride * uintptr(capacity) + + var largePageFlags uint32 + if largePageSize != 0 { + if alignedSize := alignUp(rb.size, largePageSize); alignedSize <= maxBytes { + largePageFlags = windows.MEM_LARGE_PAGES + rb.size = alignedSize + } + } + + rb.ptr, err = windows.VirtualAlloc( + 0, // no preferred address + rb.size, + windows.MEM_COMMIT|windows.MEM_RESERVE|largePageFlags, + windows.PAGE_READWRITE, + ) + switch err { + case nil: + // Allocation succeeded. + rb.buff = unsafe.Slice((*byte)(unsafe.Pointer(rb.ptr)), rb.size) + rb.largePages = largePageFlags != 0 + break loop + case windows.ERROR_NOT_ENOUGH_MEMORY: + // Try again with a smaller buffer. 
+ maxBytes /= 2 + continue + case windows.ERROR_NO_SYSTEM_RESOURCES, windows.ERROR_PRIVILEGE_NOT_HELD: + // Cannot use large pages, try again without them. + largePageSize = 0 + continue + default: + return nil, fmt.Errorf("failed to allocate request ring buffer: %w", err) + } + } + + // The actual RIO initialization is guarded by [sync.Once], and we usually + // perform it much earlier. We check it here as well to ensure that calling + // [winrio.RegisterPointer] won't panic (e.g., in tests). + if err := Initialize(); err != nil { + return nil, err + } + + // Register the allocated buffer with RIO. + if rb.id, err = winrio.RegisterPointer(unsafe.Pointer(unsafe.SliceData(rb.buff)), uint32(rb.size)); err != nil { + return nil, fmt.Errorf("failed to register request ring buffer with RIO: %w", err) + } + + // Initialize each request in the ring. + for i := uintptr(0); i < uintptr(rb.capacity); i++ { + _, bytesUsed, ok := unsafeMapRequest(rb.id, rb.buff, i*rb.stride, uintptr(dataSize)) + if !ok || bytesUsed != rb.stride { + // This should never happen. + panic("failed to map request in newly created request ring") + } + } + return rb, nil +} + +// newRequestRingWithCapacity creates a ring buffer with the specified +// power-of-two capacity and per-request data length. +// +// If allocating that many requests exceeds the maximum allowed request ring +// size or fails due to insufficient memory, it retries with progressively +// smaller sizes until it succeeds or cannot fit at least one request. +// +// The caller is responsible for calling [requestRing.close] to free +// the allocated memory when done. 
+func newRequestRingWithCapacity(dataSize uint16, capacity uint32) (*requestRing, error) { + if !isPowerOfTwo(capacity) { + return nil, fmt.Errorf("capacity must be a power of two, got %d", capacity) + } + maxSizeInBytes := requestStride(dataSize) * uintptr(capacity) + return newRequestRing(dataSize, maxSizeInBytes) +} + +// Cap returns the total number of requests the ring can hold. +func (rb *requestRing) Cap() uint32 { + return rb.capacity +} + +// Len returns the number requests currently in use (acquired and not yet released). +func (rb *requestRing) Len() uint32 { + return rb.tail - rb.head +} + +// IsEmpty reports whether no requests are currently in use. +func (rb *requestRing) IsEmpty() bool { + return rb.head == rb.tail +} + +// IsFull reports whether all requests are currently in use. +func (rb *requestRing) IsFull() bool { + return rb.Len() == rb.Cap() +} + +// Peek returns the next request without advancing the tail. +// It panics if [requestRing.IsFull] reports true. +func (rb *requestRing) Peek() *request { + if rb.IsFull() { + panic("ring is full") + } + return rb.peek() +} + +// Advance marks the next request as in use by advancing the tail. +// It panics if [requestRing.IsFull] reports true. +func (rb *requestRing) Advance() { + if rb.IsFull() { + panic("ring is full") + } + rb.tail += 1 +} + +// Acquire returns the next [request] from the ring, advancing the tail. +// It panics if [requestRing.IsFull] reports true. +func (rb *requestRing) Acquire() *request { + req := rb.Peek() + rb.tail += 1 + return req +} + +// AcquireSeq yields available [request]s one by one until the ring +// runs out of unused requests or the caller stops the iteration. 
+func (rb *requestRing) AcquireSeq() iter.Seq[*request] { + return func(yield func(req *request) bool) { + end := rb.head + rb.capacity + for ; rb.tail != end; rb.tail++ { + if !yield(rb.peek()) { + return + } + } + } +} + +func (rb *requestRing) peek() *request { + offset := uintptr(rb.tail&rb.indexMask) * rb.stride + ptr := unsafe.SliceData(rb.buff[offset : offset+rb.stride]) + req := (*request)(unsafe.Pointer(ptr)) + req.Reset() + return req +} + +// ReleaseN marks n requests at the head of the ring as free. +// It is a run-time error to release more requests than have been acquired. +func (rb *requestRing) ReleaseN(n int) { + if n < 0 { + panic("ring: negative release count") + } + if uint64(n) > uint64(rb.Len()) { + panic("ring: releasing more requests than acquired") + } + rb.head += uint32(n) +} + +// Close frees the memory allocated for the request ring. +func (rb *requestRing) Close() error { + if rb.id != 0 { + winrio.DeregisterBuffer(rb.id) + rb.id = 0 + } + + if rb.ptr != 0 { + if err := windows.VirtualFree(rb.ptr, 0, windows.MEM_RELEASE); err != nil { + return fmt.Errorf("failed to free request ring buffer: %w", err) + } + rb.ptr = 0 + } + return nil +} diff --git a/net/rioconn/request_ring_test.go b/net/rioconn/request_ring_test.go new file mode 100644 index 000000000..8249162ea --- /dev/null +++ b/net/rioconn/request_ring_test.go @@ -0,0 +1,507 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "fmt" + "slices" + "testing" + "unsafe" + + "golang.org/x/sys/windows" +) + +func TestMaxNumberOfRequestsIsPow2(t *testing.T) { + if !isPowerOfTwo(maxNumberOfRequests) { + t.Fatalf("maxNumberOfRRequests %d is not a power of two", maxNumberOfRequests) + } +} + +func TestFloorPowerOfTwo(t *testing.T) { + tests := []struct { + n uint64 + want uint64 + }{ + {0, 0}, + {1, 1}, + {16, 16}, + {17, 16}, + {31, 16}, + {32, 32}, + {uint64(1 << 63), uint64(1 << 63)}, + 
{uint64(1<<64 - 1), uint64(1) << 63}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%x", tt.n), func(t *testing.T) { + if got := floorPowerOfTwo(tt.n); got != tt.want { + t.Fatalf("got %d; want %d", got, tt.want) + } + }) + } +} + +func TestMaxRingBufferCapacity(t *testing.T) { + tests := []struct { + name string + idealCapacity uint32 + dataSize uint16 + maxBytes uintptr + wantCapacity uint32 + wantErr bool + }{ + { + name: "invalid/not-pow2", + idealCapacity: 3, + dataSize: 512, + maxBytes: 65536, + wantErr: true, + }, + { + name: "invalid/data-length-zero", + idealCapacity: 16, + dataSize: 0, + maxBytes: 65536, + wantErr: true, + }, + { + name: "invalid/max-bytes-too-small", + idealCapacity: 16, + dataSize: 512, + maxBytes: requestStride(512) - 1, // less than one request + wantErr: true, + }, + { + name: "valid/no-clamp", + idealCapacity: 16, + dataSize: 512, + maxBytes: 65536, // can fit [0; 128) requests + wantCapacity: 16, // and we asked for 16 + }, + { + name: "valid/exact-fit", + idealCapacity: 16, + dataSize: 512, + maxBytes: requestStride(512) * 16, + wantCapacity: 16, + }, + { + name: "valid/clamp-down", + idealCapacity: 128, + dataSize: 512, + maxBytes: requestStride(512) * 64, // can fit only 64 requests + wantCapacity: 64, // clamps down to 64 + }, + { + name: "valid/max-requests", + idealCapacity: 0, // want as many as possible + dataSize: 512, + maxBytes: 65536, // can fit [0; 128) requests + wantCapacity: 64, // the max power of two that fits + }, + { + name: "valid/large-buffer/no-clamp", + idealCapacity: 8192, + dataSize: maxRequestDataSize, + maxBytes: requestStride(maxRequestDataSize) * 8192, + wantCapacity: 8192, + }, + { + name: "valid/large-buffer/clamp-down", + idealCapacity: 8192, + dataSize: maxRequestDataSize, + maxBytes: requestStride(maxRequestDataSize) * 8191, // can fit only 8191 + wantCapacity: 4096, // clamps to next lower power of two + }, + { + name: "invalid/too-many-requests", + idealCapacity: maxNumberOfRequests * 2, 
+ dataSize: 512, + maxBytes: 1 << 30, + wantCapacity: maxNumberOfRequests, // cannot exceed [maxNumberOfRRequests] + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := maxRequestRingCapacity(tt.idealCapacity, tt.dataSize, uintptr(tt.maxBytes)) + if (err != nil) != tt.wantErr { + t.Fatalf("maxRingBufferCapacity error: got %v; want %v", err, tt.wantErr) + } + if got != tt.wantCapacity { + t.Fatalf("maxRingBufferCapacity: got %v; want %v", got, tt.wantCapacity) + } + }) + } +} + +func TestNewRingBuffer(t *testing.T) { + tests := []struct { + name string + dataSize uint16 + maxSizeInBytes uintptr + wantCapacity uint32 + wantErr bool + }{ + { + name: "small-buffer", + dataSize: 256, + maxSizeInBytes: requestStride(256) * 4, + wantCapacity: 4, + }, + { + name: "large-buffer/small-requests", + dataSize: 1280, + maxSizeInBytes: requestStride(1280) * 8192, + wantCapacity: 8192, + }, + { + name: "large-buffer/large-requests", + dataSize: 65535, + maxSizeInBytes: requestStride(65535) * 32, + wantCapacity: 32, + }, + { + name: "invalid/limit-too-small", + dataSize: 256, + maxSizeInBytes: 100, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rb, err := newRequestRing(tt.dataSize, tt.maxSizeInBytes) + if (err != nil) != tt.wantErr { + t.Fatalf("newRingBuffer error: got %v; wantErr %v", err, tt.wantErr) + } + if err != nil { + return + } + t.Cleanup(func() { + if err := rb.Close(); err != nil { + t.Fatalf("ringBuffer.close() failed: %v", err) + } + }) + + if gotCapacity := rb.Cap(); gotCapacity != tt.wantCapacity { + t.Errorf("ringBuffer.Cap() = %v; want %v", gotCapacity, tt.wantCapacity) + } + }) + } +} + +func TestRingBufferAcquireRelease(t *testing.T) { + tests := []struct { + name string + capacity uint32 + initialHead uint32 + initialTail uint32 + acquires int + wantAcquirePanic bool + releaseN int + wantReleasePanic bool + wantHead uint32 + wantTail uint32 + wantDepth int + }{ + { + name: 
"empty/acquire-one", + capacity: 16, + acquires: 1, + wantAcquirePanic: false, + wantTail: 1, + wantDepth: 1, + }, + { + name: "empty/acquire-few", + capacity: 16, + acquires: 4, + wantAcquirePanic: false, + wantTail: 4, + wantDepth: 4, + }, + { + name: "empty/acquire-all", + capacity: 16, + acquires: 16, + wantAcquirePanic: false, + wantTail: 16, + wantDepth: 16, + }, + { + name: "empty/acquire-too-many", + capacity: 16, + acquires: 17, // one more than the buffer can hold + wantAcquirePanic: true, + wantTail: 16, + wantDepth: 16, + }, + { + name: "empty/release-one", + capacity: 16, + releaseN: 1, + wantReleasePanic: true, + }, + { + name: "partially-full/acquire-few", + capacity: 16, + initialTail: 8, + acquires: 4, + wantAcquirePanic: false, + wantTail: 12, + wantDepth: 8, + }, + { + name: "partially-full/release-few", + capacity: 16, + initialTail: 8, + releaseN: 4, + wantReleasePanic: false, + wantHead: 4, + wantTail: 8, + wantDepth: 4, + }, + { + name: "partially-full/acquire-all/wrap-around", + capacity: 16, + initialHead: 4, + initialTail: 8, + acquires: 12, + wantAcquirePanic: false, + wantHead: 4, + wantTail: 20, + wantDepth: 16, + }, + { + name: "partially-full/acquire-too-many", + capacity: 16, + initialHead: 4, + initialTail: 8, + acquires: 13, // one more than can fit + wantAcquirePanic: true, + wantTail: 20, + wantHead: 4, + wantDepth: 16, + }, + { + name: "partially-full/release-too-many", + capacity: 16, + initialHead: 4, + initialTail: 8, + releaseN: 9, // one more than acquired + wantReleasePanic: true, + wantHead: 4, + wantTail: 8, + }, + { + name: "full/acquire-one", + capacity: 16, + initialTail: 16, + acquires: 1, + wantAcquirePanic: true, + wantTail: 16, + }, + { + name: "full/release-one", + capacity: 16, + initialTail: 16, + releaseN: 1, + wantReleasePanic: false, + wantHead: 1, + wantTail: 16, + }, + { + name: "full/release-all", + capacity: 16, + initialTail: 16, + releaseN: 16, + wantReleasePanic: false, + wantHead: 16, + wantTail: 16, 
+ }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rb, err := newRequestRingWithCapacity(4, tt.capacity) + if err != nil { + t.Fatalf("newRingBuffer failed: %v", err) + } + t.Cleanup(func() { rb.Close() }) + + rb.head = uint32(tt.initialHead) + rb.tail = uint32(tt.initialTail) + + CheckPanic(t, tt.wantAcquirePanic, func() { + for range tt.acquires { + _ = rb.Acquire() + } + }) + + CheckPanic(t, tt.wantReleasePanic, func() { + rb.ReleaseN(tt.releaseN) + }) + + if rb.head != tt.wantHead { + t.Fatalf("rb.head = %d; want %d", rb.head, tt.wantHead) + } + if rb.tail != tt.wantTail { + t.Fatalf("rb.tail = %d; want %d", rb.tail, tt.wantTail) + } + }) + } +} + +func TestRingBufferAcquireSeq(t *testing.T) { + tests := []struct { + name string + capacity uint32 + initialHead uint32 + initialTail uint32 + wantTail uint32 + wantCount int + }{ + { + name: "empty", + capacity: 16, + initialHead: 0, + initialTail: 0, + wantTail: 16, + wantCount: 16, // 16 requests to acquire: [0; 16) + }, + { + name: "partially-full", + capacity: 16, + initialHead: 4, + initialTail: 10, + wantTail: 20, + wantCount: 10, // 10 requests to acquire: [10; 15) and [0; 4) + }, + { + name: "full", + capacity: 16, + initialHead: 0, + initialTail: 16, + wantTail: 16, + wantCount: 0, // nothing to acquire + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rb, err := newRequestRingWithCapacity(4, tt.capacity) + if err != nil { + t.Fatalf("newRingBuffer failed: %v", err) + } + t.Cleanup(func() { rb.Close() }) + + // Initialize head and tail. + rb.head = tt.initialHead + rb.tail = tt.initialTail + + // Acquire all available requests and count how many we got. + gotCount := len(slices.Collect(rb.AcquireSeq())) + + // Check that we got the expected count, the head didn't change, + // and the tail advanced as expected. 
+ if gotCount != tt.wantCount { + t.Fatalf("gotCount = %d; want %d", gotCount, tt.wantCount) + } + if rb.head != tt.initialHead { + t.Fatalf("rb.head = %d; want %d", rb.head, tt.initialHead) + } + if rb.tail != tt.wantTail { + t.Fatalf("rb.tail = %d; want %d", rb.tail, tt.wantTail) + } + }) + } +} + +func TestRingBufferWrapAround(t *testing.T) { + const capacity = 16 + rb, err := newRequestRingWithCapacity(4, capacity) + if err != nil { + t.Fatalf("newRingBuffer failed: %v", err) + } + t.Cleanup(func() { rb.Close() }) + + // Acquire all requests and store pointers to them in a slice. + requests := slices.Collect(rb.AcquireSeq()) + if len(requests) != capacity { + t.Fatalf("acquired %d requests; want %d", len(requests), capacity) + } + + rb.ReleaseN(capacity) // release all + + // Acquire again and ensure we get the same requests in the same order. + for i, wantReq := range requests { + if gotReq := rb.Acquire(); gotReq != wantReq { + t.Fatalf("acquired request %d = %p; want %p", i, gotReq, wantReq) + } + } +} + +// CheckPanic checks whether the given function panics or not, +// and fails the test if the result does not match wantPanic. +func CheckPanic(t *testing.T, wantPanic bool, fn func()) { + t.Helper() + defer func() { + if r := recover(); r != nil { + if !wantPanic { + t.Fatalf("unexpected panic: %v", r) + } + } else if wantPanic { + t.Fatal("expected panic but none occurred") + } + }() + fn() +} + +// checkProcessPrivilege reports whether the given privilege is present +// and/or enabled in the current process token. +func checkProcessPrivilege(privName string) (present bool, enabled bool, err error) { + var tok windows.Token + if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &tok); err != nil { + return false, false, err + } + defer tok.Close() + return checkTokenPrivilege(tok, privName) +} + +// checkTokenPrivilege reports whether the given privilege is present +// and/or enabled in the specified token. 
+func checkTokenPrivilege(token windows.Token, privName string) (present bool, enabled bool, err error) { + var luid windows.LUID + namePtr, err := windows.UTF16PtrFromString(privName) + if err != nil { + return false, false, err + } + if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil { + return false, false, err + } + + var needed uint32 + _ = windows.GetTokenInformation(token, windows.TokenPrivileges, nil, 0, &needed) + if needed == 0 { + return false, false, windows.GetLastError() + } + + buf := make([]byte, needed) + if err := windows.GetTokenInformation(token, windows.TokenPrivileges, &buf[0], uint32(len(buf)), &needed); err != nil { + return false, false, err + } + + tp := (*windows.Tokenprivileges)(unsafe.Pointer(&buf[0])) + count := int(tp.PrivilegeCount) + + laa := unsafe.Slice((*windows.LUIDAndAttributes)(unsafe.Pointer(&tp.Privileges[0])), count) + for _, p := range laa { + if p.Luid == luid { + present = true + enabled = (p.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0 + return present, enabled, nil + } + } + + return false, false, nil +} diff --git a/net/rioconn/request_test.go b/net/rioconn/request_test.go new file mode 100644 index 000000000..b1f4c498b --- /dev/null +++ b/net/rioconn/request_test.go @@ -0,0 +1,319 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "bytes" + "net/netip" + "testing" + "unsafe" +) + +func TestUnsafeMapRequest(t *testing.T) { + tests := []struct { + name string + buff []byte + offset uintptr + dataSize uintptr + wantOk bool + wantOffset uintptr + wantBytesUsed uintptr + }{ + { + name: "enough-space", + buff: makeAligned(totalRequestSize(10), unsafe.Alignof(request{})), + dataSize: 10, + wantOk: true, + wantOffset: 0, + wantBytesUsed: totalRequestSize(10), + }, + { + name: "enough-space/unaligned-buffer", + buff: makeMisaligned(totalRequestSize(10)+2, unsafe.Alignof(request{}), 2), + offset: 0, + 
dataSize: 10, + wantOk: true, + wantOffset: 2, // the request starts at offset 2 for proper alignment + wantBytesUsed: totalRequestSize(10) + 2, // includes padding before the request + }, + { + name: "enough-space/unaligned-buffer/aligned-offset", + buff: makeMisaligned(totalRequestSize(10)+2, unsafe.Alignof(request{}), 2), + offset: 2, + dataSize: 10, + wantOk: true, + wantOffset: 2, // same as offset requested + wantBytesUsed: totalRequestSize(10), // no extra padding needed + }, + { + name: "enough-space/aligned-buffer/non-zero-offset", + buff: makeAligned(totalRequestSize(10)+16, unsafe.Alignof(request{})), + offset: 16, + dataSize: 10, + wantOk: true, + wantOffset: 16, + wantBytesUsed: totalRequestSize(10), + }, + { + name: "not-enough-space/nil-buffer", + buff: nil, + wantOk: false, + }, + { + name: "not-enough-space/small-buffer", + buff: makeAligned(totalRequestSize(10)-1, unsafe.Alignof(request{})), + dataSize: 10, + wantOk: false, + }, + { + name: "not-enough-space/due-to-offset", + buff: makeAligned(totalRequestSize(10), unsafe.Alignof(request{})), + offset: 8, + dataSize: 10, + wantOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request, bytesUsed, ok := unsafeMapRequest(0, tt.buff, tt.offset, tt.dataSize) + if ok != tt.wantOk { + t.Errorf("mapRequest: ok: got %v; want %v", ok, tt.wantOk) + } + if !ok { + return + } + if bytesUsed != tt.wantBytesUsed { + t.Errorf("mapRequest: bytesUsed: got %d; want %d", bytesUsed, tt.wantBytesUsed) + } + gotOffset := uintptr(unsafe.Pointer(request)) - uintptr(unsafe.Pointer(&tt.buff[0])) + if gotOffset != tt.wantOffset { + t.Errorf("mapRequest: offset: got %d; want %d", gotOffset, tt.wantOffset) + } + }) + } +} + +func TestRequestWriterReserve(t *testing.T) { + const dataSize = 1312 + req := makeRequest(t, dataSize) + writer := req.Writer() + + // The capacity of the writer's data buffer must match the requested data size. 
+	if gotCap := writer.Cap(); gotCap != dataSize {
+		t.Errorf("requestWriter.Cap: got %d; want %d", gotCap, dataSize)
+	}
+
+	// Initially, the entire data buffer must be available for writing.
+	if gotAvail := writer.Available(); gotAvail != dataSize {
+		t.Errorf("requestWriter.Available: got %d; want %d", gotAvail, dataSize)
+	}
+
+	// And the length of the data buffer must be zero
+	// since nothing has been written or reserved yet.
+	if gotLen := writer.Len(); gotLen != 0 {
+		t.Errorf("requestWriter.Len: got %d; want %d", gotLen, 0)
+	}
+
+	// Reserving more than the available capacity must fail.
+	// CheckPanic is a helper defined elsewhere in this package; it appears
+	// to assert that the call panics — NOTE(review): confirm its semantics.
+	CheckPanic(t, true, func() { writer.Reserve(dataSize + 1) })
+	// A failed Reserve must not consume any capacity.
+	if gotAvail := writer.Available(); gotAvail != dataSize {
+		t.Errorf("requestWriter.Available after failed Reserve: got %d; want %d",
+			gotAvail, dataSize)
+	}
+
+	// Reserving zero bytes must return an empty buffer.
+	buf := writer.Reserve(0)
+	if len(buf) != 0 {
+		t.Errorf("Reserve: got buffer of length %d; want %d", len(buf), 0)
+	}
+	// The capacity of the returned buffer must also be zero.
+	if cap(buf) != 0 {
+		t.Errorf("Reserve: got buffer of capacity %d; want %d", cap(buf), 0)
+	}
+
+	// Reserving part of the available capacity should succeed.
+	buf = writer.Reserve(64)
+	if len(buf) != 64 {
+		t.Errorf("Reserve: got buffer of length %d; want %d", len(buf), 64)
+	}
+	// The capacity of the returned buffer should be equal to its length
+	// to prevent reslicing beyond the reserved length.
+	if cap(buf) != len(buf) {
+		t.Errorf("Reserve: got buffer of capacity %d; want %d", cap(buf), len(buf))
+	}
+	// The returned buffer should point to the start of the writer's data buffer.
+	// (White-box: the test reads writer.data directly.)
+	if &buf[0] != &writer.data[0] {
+		t.Errorf("Reserve: got buffer starting at %p; want %p", &buf[0], &writer.data[0])
+	}
+	// After reserving X bytes, the length of the data buffer should be X.
+	if writer.Len() != len(buf) {
+		t.Errorf("requestWriter.Len after Reserve: got %d; want %d", writer.Len(), len(buf))
+	}
+	// And the available capacity should decrease by X.
+	if gotAvail := writer.Available(); gotAvail != dataSize-len(buf) {
+		t.Errorf("requestWriter.Available after Reserve: got %d; want %d",
+			gotAvail, dataSize-len(buf))
+	}
+	// However, the total capacity of the writer's data buffer should remain unchanged.
+	if cap(writer.data) != dataSize {
+		t.Errorf("requestWriter.data capacity: got %d; want %d", cap(writer.data), dataSize)
+	}
+
+	bytesWritten := copy(buf, []byte("Hello World"))
+
+	// SetLen must panic if the desired length exceeds the number of bytes
+	// written or reserved so far.
+	CheckPanic(t, true, func() { writer.SetLen(len(buf) + 1) })
+	// Or if the desired length is negative.
+	CheckPanic(t, true, func() { writer.SetLen(-1) })
+
+	// Otherwise, it must succeed and update the length of the data buffer accordingly.
+	writer.SetLen(len(buf))
+	if writer.Len() != len(buf) { // no change expected
+		t.Errorf("requestWriter.Len after SetLen: got %d; want %d", writer.Len(), len(buf))
+	}
+
+	// The contents of the data buffer up to the set length should be what was written.
+	// (Shrinking the length also shrinks the visible writer.data slice, as the
+	// whole-slice comparison below demonstrates.)
+	writer.SetLen(bytesWritten)
+	if !bytes.Equal(writer.data, []byte("Hello World")) {
+		t.Errorf("requestWriter.data: got %q; want %q", writer.data, "Hello World")
+	}
+
+	// Reserving more bytes should succeed and return a buffer starting immediately
+	// after the previously reserved and written data.
+	newBuf := writer.Reserve(10)
+	if &newBuf[0] != &writer.data[bytesWritten] {
+		t.Errorf("Reserve: got buffer starting at %p; want %p", &newBuf[0], &writer.data[bytesWritten])
+	}
+	if writer.Len() != bytesWritten+len(newBuf) {
+		t.Errorf("requestWriter.Len after second Reserve: got %d; want %d",
+			writer.Len(), bytesWritten+len(newBuf))
+	}
+}
+
+// TestRequestWriterWrite verifies that sequential Write calls append to the
+// request's data buffer, and that a Write exceeding the remaining capacity
+// fails without a partial write and without changing the length.
+func TestRequestWriterWrite(t *testing.T) {
+	const dataSize = 128
+	req := makeRequest(t, dataSize)
+	writer := req.Writer()
+	n, err := writer.Write([]byte("Hello"))
+	if err != nil {
+		t.Errorf("requestWriter.Write: unexpected error: %v", err)
+	}
+	if n != 5 {
+		t.Errorf("requestWriter.Write: got %d bytes written; want %d", n, 5)
+	}
+	// The return values of this second Write are not checked; the content
+	// assertion below covers its effect.
+	writer.Write([]byte(" World"))
+	if !bytes.Equal(writer.data[:11], []byte("Hello World")) {
+		t.Errorf("requestWriter.data: got %q; want %q", writer.data[:11], "Hello World")
+	}
+	if writer.Len() != 11 {
+		t.Errorf("requestWriter.Len after Write: got %d; want %d", writer.Len(), 11)
+	}
+	// Writing more bytes than the available capacity should fail.
+	n, err = writer.Write(make([]byte, dataSize))
+	if err == nil {
+		t.Error("requestWriter.Write: expected error when writing beyond capacity; got nil")
+	}
+	// We do not allow partial writes, so n should be zero...
+	if n != 0 {
+		t.Errorf("requestWriter.Write: got %d bytes written; want %d", n, 0)
+	}
+	// ... and the length should remain unchanged.
+	if writer.Len() != 11 {
+		t.Errorf("requestWriter.Len after failed Write: got %d; want %d", writer.Len(), 11)
+	}
+}
+
+// TestRequestWriterSetAddrPort verifies that SetRemoteAddrPort stores the
+// remote address in the request's raw sockaddr (req.raddr) and that it
+// round-trips back through ToAddrPort unchanged.
+func TestRequestWriterSetAddrPort(t *testing.T) {
+	const dataSize = 128
+	addrPort := netip.MustParseAddrPort("192.0.2.1:1234")
+
+	req := makeRequest(t, dataSize)
+	writer := req.Writer()
+	err := writer.SetRemoteAddrPort(addrPort)
+	if err != nil {
+		t.Errorf("requestWriter.SetRemoteAddrPort: unexpected error: %v", err)
+	}
+	gotAddrPort, err := req.raddr.ToAddrPort()
+	if err != nil {
+		t.Errorf("request.raddr.ToAddrPort: unexpected error: %v", err)
+	}
+	if gotAddrPort != addrPort {
+		t.Errorf("request.raddr: got %v; want %v", gotAddrPort, addrPort)
+	}
+}
+
+// TestRequestReader verifies the read-side view of a request: an empty
+// request reads as empty, received bytes are exposed via Bytes/Len, and the
+// remote address stored in req.raddr is surfaced by RemoteAddrPort.
+func TestRequestReader(t *testing.T) {
+	const dataSize = 128
+	req := makeRequest(t, dataSize)
+	reader := req.Reader()
+	if reader.Len() != 0 {
+		t.Errorf("requestReader.Len: got %d; want %d", reader.Len(), 0)
+	}
+	if len(reader.Bytes()) != 0 {
+		t.Errorf("requestReader.Bytes: got buffer of length %d; want %d", len(reader.Bytes()), 0)
+	}
+
+	// Simulate receiving data by directly modifying the request's data buffer and length.
+	req.data = req.data[:11]
+	copy(req.data, []byte("Hello World"))
+
+	if !bytes.Equal(reader.Bytes(), []byte("Hello World")) {
+		t.Errorf("requestReader.Bytes: got %q; want %q", reader.Bytes(), "Hello World")
+	}
+	if reader.Len() != 11 {
+		t.Errorf("requestReader.Len: got %d; want %d", reader.Len(), 11)
+	}
+
+	addrPort := netip.MustParseAddrPort("192.0.2.1:1234")
+	req.raddr, _ = rawSockaddrFromAddrPort(addrPort)
+	gotAddrPort, err := reader.RemoteAddrPort()
+	if err != nil {
+		t.Errorf("requestReader.RemoteAddrPort: unexpected error: %v", err)
+	}
+	if gotAddrPort != addrPort {
+		t.Errorf("requestReader.RemoteAddrPort: got %v; want %v", gotAddrPort, addrPort)
+	}
+}
+
+// makeRequest maps a single request of the given data size onto a freshly
+// allocated, properly aligned buffer and fails the test if mapping fails.
+// NOTE(review): a zero buffer ID appears sufficient since these tests never
+// issue real RIO operations against the request — confirm.
+func makeRequest(tb testing.TB, dataSize int) *request {
+	tb.Helper()
+	buff := makeAligned(totalRequestSize(uintptr(dataSize)), unsafe.Alignof(request{}))
+	req, _, ok := unsafeMapRequest(0, buff, 0, uintptr(dataSize))
+	if !ok {
+		tb.Fatalf("failed to make request of size %d", dataSize)
+	}
+	return req
+}
+
+// makeAligned returns a slice of n bytes such that the address of
+// the first byte is aligned to the given alignment.
+// Alignment must be a power of two.
+func makeAligned(n, alignment uintptr) []byte {
+	if n == 0 {
+		return nil
+	}
+	// Over-allocate by alignment-1 so an aligned window of n bytes must exist.
+	buff := make([]byte, n+alignment-1)
+	base := uintptr(unsafe.Pointer(&buff[0]))
+	offset := alignUpOffset(base, 0, alignment)
+	return buff[offset : offset+n]
+}
+
+// makeMisaligned returns a slice of n bytes whose start address is
+// misaligned by misalign bytes relative to alignment, such that advancing
+// the start address by misalign bytes yields an alignment-aligned address.
+// Alignment must be a power of two.
+func makeMisaligned(n, alignment, misalign uintptr) []byte {
+	// An empty buffer is represented as nil, matching makeAligned.
+	if n == 0 {
+		return nil
+	}
+	// Over-allocate so that a window of n bytes with the requested
+	// misalignment is guaranteed to exist within the backing array.
+	backing := make([]byte, n+alignment+misalign)
+	addr := uintptr(unsafe.Pointer(&backing[0]))
+	// Advance to the first position p >= addr such that p+misalign lies on
+	// an alignment boundary: round addr+misalign up to the boundary, then
+	// step back by misalign. Expressed via modular arithmetic, which for
+	// power-of-two alignments matches the usual mask-based rounding.
+	skip := (alignment - (addr+misalign)%alignment) % alignment
+	return backing[skip : skip+n]
+}