mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-13 00:06:13 +02:00
net/rioconn: implement request and requestRing
RIO requires all send/receive request data buffers, as well as control/OOB data buffers and address buffers to be allocated from a memory region that was previously registered with RIO. In this commit, we introduce a request: a structure that always resides in a heap-allocated, RIO-registered memory region, followed by the actual data buffer. A request can be created with unsafeMapRequest, which constructs it in place from a byte buffer, an offset into that buffer, and the payload capacity the request should support. The request itself is agnostic to the backing storage and allocation strategy. We then implement a requestRing, a circular buffer of fixed-length, fixed-stride request elements. It allocates memory on the heap, registers it with RIO, and tracks which requests are in use. We can experiment with other allocation or mapping strategies later, or allow requests to return to the ring out of order if necessary. Updates tailscale/corp#8610 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
b6677f2e0f
commit
e79262bced
24
net/rioconn/doc.go
Normal file
24
net/rioconn/doc.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
)
|
||||
|
||||
// ErrRIOUnavailable is returned when Windows RIO is required but not available.
// It is a sentinel error; callers may compare against it with errors.Is.
var ErrRIOUnavailable = fmt.Errorf("Registered I/O (RIO) is not available on this system")
|
||||
|
||||
// Initialize initializes the Windows RIO API extensions.
|
||||
// It returns [ErrRIOUnavailable] if RIO cannot be used.
|
||||
func Initialize() error {
|
||||
if !winrio.Initialize() {
|
||||
return ErrRIOUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
35
net/rioconn/memory.go
Normal file
35
net/rioconn/memory.go
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// alignUp returns the smallest value >= v that is a multiple of alignment.
|
||||
// The alignment must be a power of two.
|
||||
func alignUp[V, A constraints.Integer](v V, alignment A) V {
|
||||
return (v + V(alignment) - 1) &^ (V(alignment) - 1)
|
||||
}
|
||||
|
||||
// alignUpOffset rounds offset up so that base+offset is aligned to the
|
||||
// specified boundary. Alignment must be a power of two.
|
||||
func alignUpOffset(base, offset, alignment uintptr) uintptr {
|
||||
return alignUp(base+offset, alignment) - base
|
||||
}
|
||||
|
||||
// isPowerOfTwo reports whether n is a power of two.
|
||||
func isPowerOfTwo[T constraints.Integer](n T) bool {
|
||||
return n > 0 && (n&(n-1)) == 0
|
||||
}
|
||||
|
||||
// floorPowerOfTwo returns the largest power of two <= n.
|
||||
func floorPowerOfTwo[T constraints.Unsigned](n T) T {
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
return 1 << (bits.Len64(uint64(n)) - 1)
|
||||
}
|
||||
187
net/rioconn/request.go
Normal file
187
net/rioconn/request.go
Normal file
@ -0,0 +1,187 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
)
|
||||
|
||||
// request represents a portion of a RIO-registered memory buffer used for
// a single send or receive operation. It is always heap-allocated, and the
// fixed-size struct is followed in memory by a variable-size data buffer,
// with the actual data starting at [requestDataOffset] from the start of
// the struct.
//
// A memory buffer implementation, such as a [requestRing], is responsible for
// allocating requests within a registered buffer via [unsafeMapRequest] and
// ensuring that the buffer remains registered and valid for the lifetime
// of the requests.
type request struct {
	buffID   winrio.BufferId // ID of the registered buffer containing this request
	buffBase uintptr         // base address of the registered buffer

	raddr rawSockaddr // remote address for RIO send/receive operations
	data  []byte      // a slice pointing into the data buffer area after the struct
}
|
||||
|
||||
const (
	// requestDataAlignment is the byte alignment of the data buffer
	// that follows a [request] struct in memory.
	requestDataAlignment = 8
	// requestDataOffset is the byte offset from the start of a [request]
	// to its data buffer: the struct size rounded up to requestDataAlignment.
	requestDataOffset = (unsafe.Sizeof(request{}) + requestDataAlignment - 1) &^ (requestDataAlignment - 1)
)
|
||||
|
||||
// totalRequestSize returns the total number of bytes required to hold
|
||||
// a [request] struct followed by a data buffer of the given size.
|
||||
func totalRequestSize(dataSize uintptr) uintptr {
|
||||
return alignUp(requestDataOffset+dataSize, unsafe.Alignof(request{}))
|
||||
}
|
||||
|
||||
// unsafeMapRequest maps a [request] into the given RIO-registered buffer at the
// specified offset and returns a pointer to it, the number of bytes used,
// and whether the mapping succeeded.
//
// On success, the returned pointer is aligned to the [request]'s natural
// alignment and the request can hold up to dataSize bytes of data.
//
// It is the caller's responsibility to ensure that the buffer remains
// registered and valid for the lifetime of the returned request.
func unsafeMapRequest(buffID winrio.BufferId, buff []byte, offset, dataSize uintptr) (_ *request, n uintptr, ok bool) {
	baseAddr := uintptr(unsafe.Pointer(unsafe.SliceData(buff)))
	// Round the offset up so the mapped struct is naturally aligned
	// relative to the buffer's actual base address.
	alignedOffset := alignUpOffset(baseAddr, offset, unsafe.Alignof(request{}))
	// Total consumption includes any alignment padding before the struct.
	bytesNeeded := totalRequestSize(dataSize) + uintptr(alignedOffset-offset)
	if offset >= uintptr(len(buff)) {
		return nil, 0, false
	}
	// Note: bytesNeeded >= alignedOffset-offset, so passing this check also
	// guarantees alignedOffset is within the buffer.
	bytesAvailable := uintptr(len(buff)) - offset
	if bytesAvailable < bytesNeeded {
		return nil, 0, false
	}

	// Reinterpret the aligned position in the buffer as a *request and
	// initialize it in place.
	requestBytes := unsafe.SliceData(buff[alignedOffset:])
	request := (*request)(unsafe.Pointer(requestBytes))
	request.buffID = buffID
	request.buffBase = baseAddr
	request.data = unsafe.Slice(
		(*byte)(unsafe.Add(unsafe.Pointer(request), requestDataOffset)),
		dataSize,
	)[:0] // zero-length data slice with capacity of dataSize
	return request, bytesNeeded, true
}
|
||||
|
||||
// Writer returns a [requestWriter] for the request.
|
||||
func (r *request) Writer() *requestWriter {
|
||||
return (*requestWriter)(r)
|
||||
}
|
||||
|
||||
// Reader returns a [requestReader] for the request.
|
||||
func (r *request) Reader() *requestReader {
|
||||
return (*requestReader)(r)
|
||||
}
|
||||
|
||||
// Reset prepares the request for reuse by resetting its state.
|
||||
func (r *request) Reset() {
|
||||
r.raddr = rawSockaddr{}
|
||||
r.data = r.data[:0]
|
||||
}
|
||||
|
||||
type (
	// requestWriter provides write access to a [request]'s payload
	// and remote address, typically used before a send operation.
	requestWriter request
	// requestReader provides read access to a [request]'s payload
	// and remote address, typically used after a receive completes.
	requestReader request
)
|
||||
|
||||
// Len returns the number of bytes written or reserved so far
// in the request's data buffer.
func (w *requestWriter) Len() int {
	return len(w.data)
}
|
||||
|
||||
// Cap returns the maximum number of bytes that can be written or reserved;
// this is the dataSize the request was mapped with.
func (w *requestWriter) Cap() int {
	return cap(w.data)
}
|
||||
|
||||
// Available returns the number of bytes available for writing or reserving.
|
||||
func (w *requestWriter) Available() int {
|
||||
return cap(w.data) - len(w.data)
|
||||
}
|
||||
|
||||
// SetRemoteAddrPort sets the remote address for the request from a [netip.AddrPort].
// It returns an error if the specified address cannot be converted to a [rawSockaddr].
func (w *requestWriter) SetRemoteAddrPort(raddr netip.AddrPort) error {
	var err error
	w.raddr, err = rawSockaddrFromAddrPort(raddr)
	return err
}
|
||||
|
||||
// SetRemoteAddr sets the remote address for the request from a [rawSockaddr].
// Unlike [requestWriter.SetRemoteAddrPort], no conversion is needed, so it
// cannot fail.
func (w *requestWriter) SetRemoteAddr(raddr rawSockaddr) {
	w.raddr = raddr
}
|
||||
|
||||
// Reserve reserves n bytes in the request's data buffer,
|
||||
// and returns a slice pointing to the reserved space.
|
||||
// It panics if n is negative or exceeds the available capacity.
|
||||
func (w *requestWriter) Reserve(n int) []byte {
|
||||
if n < 0 {
|
||||
panic(fmt.Errorf("cannot reserve negative bytes: %d", n))
|
||||
}
|
||||
if avail := w.Available(); n > avail {
|
||||
panic(fmt.Errorf("cannot reserve %d bytes: only %d available", n, avail))
|
||||
}
|
||||
oldLen := len(w.data)
|
||||
newLen := oldLen + n
|
||||
w.data = w.data[:newLen]
|
||||
return w.data[oldLen:newLen:newLen] // prevent reslicing beyond newLen
|
||||
}
|
||||
|
||||
// Write implements [io.Writer].
|
||||
func (w *requestWriter) Write(p []byte) (n int, err error) {
|
||||
if len(p) > w.Available() {
|
||||
return 0, errors.New("not enough space to write data")
|
||||
}
|
||||
oldLen := len(w.data)
|
||||
newLen := oldLen + len(p)
|
||||
w.data = w.data[:newLen]
|
||||
copy(w.data[oldLen:], p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// SetLen sets the length of the request's data buffer to n.
|
||||
// It panics if n is negative or exceeds the number of bytes written.
|
||||
func (w *requestWriter) SetLen(n int) {
|
||||
if n < 0 {
|
||||
panic(fmt.Errorf("cannot set negative data length: %d", n))
|
||||
}
|
||||
if n > len(w.data) {
|
||||
panic(fmt.Errorf("data length %d exceeds the number of bytes written (%d)", n, len(w.data)))
|
||||
}
|
||||
w.data = w.data[:n]
|
||||
}
|
||||
|
||||
// Len returns the number of payload bytes available to read.
func (r *requestReader) Len() int {
	return len(r.data)
}
|
||||
|
||||
// Bytes returns the request's payload data.
|
||||
func (r *requestReader) Bytes() []byte {
|
||||
return r.data[:len(r.data):len(r.data)] // prevent reslicing beyond len
|
||||
}
|
||||
|
||||
// RemoteAddrPort returns the request's remote address as a [netip.AddrPort].
// The conversion can fail, e.g. for an address family the conversion does
// not support.
func (r *requestReader) RemoteAddrPort() (netip.AddrPort, error) {
	return r.raddr.ToAddrPort()
}
|
||||
|
||||
// RemoteAddr returns the request's remote address as a [rawSockaddr].
// It is more efficient than [requestReader.RemoteAddrPort] if the caller
// only needs the raw socket address.
func (r *requestReader) RemoteAddr() rawSockaddr {
	return r.raddr
}
|
||||
310
net/rioconn/request_ring.go
Normal file
310
net/rioconn/request_ring.go
Normal file
@ -0,0 +1,310 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"math"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn/winrio"
|
||||
"golang.org/x/sys/windows"
|
||||
"tailscale.com/util/winutil"
|
||||
)
|
||||
|
||||
// requestRing is a circular buffer of [request]s backed by a single
// RIO-registered memory allocation. Requests are acquired at the tail
// and released (in order) from the head.
// It is not safe for concurrent use.
type requestRing struct {
	capacity  uint32  // number of requests in the ring; always a power of two
	dataSize  uint32  // per-request data buffer length in bytes
	stride    uintptr // byte offset from one request to the next
	indexMask uint32  // masks an index to [0, capacity)

	ptr        uintptr // base address of the ring buffer allocation
	size       uintptr // size of the allocation in bytes
	buff       []byte  // a byte slice view of the allocated buffer
	largePages bool    // whether the allocation used large pages

	head, tail uint32 // monotonic counters; apply &indexMask for indexing

	// id is the RIO buffer registration handle; zero when unregistered.
	id winrio.BufferId
}
|
||||
|
||||
// When a memory buffer is registered with RIO, the virtual memory pages
// containing the buffer are locked into physical memory.
// Set hard limits to avoid excessive memory usage.
// TODO(nickkhyl): derive preferred values from system parameters?
const (
	maxRequestRingSize  = 1 << 30        // 1 GiB; arbitrary limit
	maxNumberOfRequests = 16384          // arbitrary power-of-two limit
	maxRequestDataSize  = math.MaxUint16 // IP packet size limit
)
|
||||
|
||||
// requestStride returns the byte offset from the start of one [request]
|
||||
// in a [requestRing] to the start of the next, given the per-request data buffer length.
|
||||
func requestStride(dataSize uint16) uintptr {
|
||||
return totalRequestSize(uintptr(dataSize))
|
||||
}
|
||||
|
||||
// maxRequestRingCapacity returns the maximum ring buffer capacity that fits
|
||||
// within maxBytes, given the per-request data size. The returned capacity
|
||||
// does not exceed idealCapacity unless idealCapacity is zero, in which case the
|
||||
// maximum possible value is returned.
|
||||
//
|
||||
// dataSize must be in (0, [maxRequestDataSize]], and idealCapacity, if non-zero,
|
||||
// must be a power of two. The function returns an error if the buffer cannot
|
||||
//
|
||||
// It returns an error if the parameters are invalid, if the resulting
|
||||
// capacity cannot hold at least one request, or if maxBytes exceeds
|
||||
// [maxRequestRingSize].
|
||||
func maxRequestRingCapacity(idealCapacity uint32, dataSize uint16, maxBytes uintptr) (uint32, error) {
|
||||
if idealCapacity != 0 && !isPowerOfTwo(idealCapacity) {
|
||||
return 0, fmt.Errorf("the capacity must be a power of two, got %d", idealCapacity)
|
||||
}
|
||||
capacity := min(idealCapacity, maxNumberOfRequests)
|
||||
|
||||
if dataSize == 0 || dataSize > maxRequestDataSize {
|
||||
return 0, fmt.Errorf("the data size must be in (0, %d], got %d", maxRequestDataSize, dataSize)
|
||||
}
|
||||
|
||||
stride := requestStride(dataSize)
|
||||
if maxBytes < stride {
|
||||
return 0, fmt.Errorf("cannot fit any requests within maxBytes %d", maxBytes)
|
||||
}
|
||||
if maxBytes > maxRequestRingSize {
|
||||
return 0, fmt.Errorf("maxBytes %d exceeds limit of %d", maxBytes, maxRequestRingSize)
|
||||
}
|
||||
|
||||
if capacity == 0 || uintptr(capacity)*stride > maxBytes {
|
||||
capacity = uint32(floorPowerOfTwo(maxBytes / stride))
|
||||
}
|
||||
return capacity, nil
|
||||
}
|
||||
|
||||
const (
	// seLockMemoryPrivilege is the name of the Windows privilege required to allocate large pages.
	seLockMemoryPrivilege = "SeLockMemoryPrivilege"
)
|
||||
|
||||
// newRequestRing creates a ring buffer of up to maxBytes bytes,
// with each element representing a RIO request backed by a data buffer
// of dataSize bytes.
//
// It determines the buffer capacity as the maximum power-of-two number
// of requests that fits within the allocation size limit, using large
// pages when possible.
//
// If the allocation fails due to insufficient memory, it retries
// with progressively smaller sizes until it succeeds or cannot fit
// at least one request.
//
// The returned buffer is registered with RIO and must be closed with
// [requestRing.Close] to unregister it and free the memory.
func newRequestRing(dataSize uint16, maxBytes uintptr) (_ *requestRing, err error) {
	rb := &requestRing{
		dataSize: uint32(dataSize),
		stride:   requestStride(dataSize),
	}
	// Release any partially-initialized resources (allocation and/or
	// RIO registration) if we fail below; relies on the named err result.
	defer func() {
		if err != nil {
			rb.Close()
		}
	}()

	var largePageSize uintptr // 0 means "do not use large pages"
	// The SeLockMemoryPrivilege privilege is required to allocate large pages.
	// By default, this privilege can be requested only by processes
	// running as Local System, such as the Tailscale service,
	// and is not available to regular user processes (e.g., when running tests).
	// If enabling the privilege fails, we fall back to normal pages.
	// For testing, you can grant the privilege to your user account
	// using the Local Security Policy management console (secpol.msc).
	dropPrivs, err := winutil.EnableCurrentThreadPrivilege(seLockMemoryPrivilege)
	if err == nil {
		defer dropPrivs()
		largePageSize = windows.GetLargePageMinimum()
	}

loop:
	for {
		// Recompute the best-fitting capacity each iteration, since
		// maxBytes shrinks when an allocation attempt runs out of memory.
		capacity, err := maxRequestRingCapacity(0, dataSize, maxBytes)
		if err != nil {
			// The requested parameters are invalid.
			return nil, err
		}

		rb.capacity = uint32(capacity)
		rb.indexMask = uint32(capacity - 1)
		rb.size = rb.stride * uintptr(capacity)

		var largePageFlags uint32
		if largePageSize != 0 {
			// Use large pages only if the size rounded up to a large-page
			// multiple still fits within maxBytes.
			if alignedSize := alignUp(rb.size, largePageSize); alignedSize <= maxBytes {
				largePageFlags = windows.MEM_LARGE_PAGES
				rb.size = alignedSize
			}
		}

		rb.ptr, err = windows.VirtualAlloc(
			0, // no preferred address
			rb.size,
			windows.MEM_COMMIT|windows.MEM_RESERVE|largePageFlags,
			windows.PAGE_READWRITE,
		)
		switch err {
		case nil:
			// Allocation succeeded.
			rb.buff = unsafe.Slice((*byte)(unsafe.Pointer(rb.ptr)), rb.size)
			rb.largePages = largePageFlags != 0
			break loop
		case windows.ERROR_NOT_ENOUGH_MEMORY:
			// Try again with a smaller buffer.
			maxBytes /= 2
			continue
		case windows.ERROR_NO_SYSTEM_RESOURCES, windows.ERROR_PRIVILEGE_NOT_HELD:
			// Cannot use large pages, try again without them.
			largePageSize = 0
			continue
		default:
			return nil, fmt.Errorf("failed to allocate request ring buffer: %w", err)
		}
	}

	// The actual RIO initialization is guarded by [sync.Once], and we usually
	// perform it much earlier. We check it here as well to ensure that calling
	// [winrio.RegisterPointer] won't panic (e.g., in tests).
	if err := Initialize(); err != nil {
		return nil, err
	}

	// Register the allocated buffer with RIO.
	if rb.id, err = winrio.RegisterPointer(unsafe.Pointer(unsafe.SliceData(rb.buff)), uint32(rb.size)); err != nil {
		return nil, fmt.Errorf("failed to register request ring buffer with RIO: %w", err)
	}

	// Initialize each request in the ring.
	for i := uintptr(0); i < uintptr(rb.capacity); i++ {
		_, bytesUsed, ok := unsafeMapRequest(rb.id, rb.buff, i*rb.stride, uintptr(dataSize))
		if !ok || bytesUsed != rb.stride {
			// This should never happen.
			panic("failed to map request in newly created request ring")
		}
	}
	return rb, nil
}
|
||||
|
||||
// newRequestRingWithCapacity creates a ring buffer with the specified
// power-of-two capacity and per-request data length.
//
// If allocating that many requests exceeds the maximum allowed request ring
// size or fails due to insufficient memory, it retries with progressively
// smaller sizes until it succeeds or cannot fit at least one request.
//
// The caller is responsible for calling [requestRing.Close] to free
// the allocated memory when done.
func newRequestRingWithCapacity(dataSize uint16, capacity uint32) (*requestRing, error) {
	if !isPowerOfTwo(capacity) {
		return nil, fmt.Errorf("capacity must be a power of two, got %d", capacity)
	}
	// Size the allocation exactly; newRequestRing fits as many
	// power-of-two requests as possible within this limit.
	maxSizeInBytes := requestStride(dataSize) * uintptr(capacity)
	return newRequestRing(dataSize, maxSizeInBytes)
}
|
||||
|
||||
// Cap returns the total number of requests the ring can hold.
func (rb *requestRing) Cap() uint32 {
	return rb.capacity
}
|
||||
|
||||
// Len returns the number of requests currently in use (acquired and not yet released).
// The subtraction is correct even after the monotonic counters wrap around,
// as long as fewer than 2^32 requests separate head and tail.
func (rb *requestRing) Len() uint32 {
	return rb.tail - rb.head
}
|
||||
|
||||
// IsEmpty reports whether no requests are currently in use.
|
||||
func (rb *requestRing) IsEmpty() bool {
|
||||
return rb.head == rb.tail
|
||||
}
|
||||
|
||||
// IsFull reports whether all requests are currently in use.
|
||||
func (rb *requestRing) IsFull() bool {
|
||||
return rb.Len() == rb.Cap()
|
||||
}
|
||||
|
||||
// Peek returns the next request without advancing the tail.
|
||||
// It panics if [requestRing.IsFull] reports true.
|
||||
func (rb *requestRing) Peek() *request {
|
||||
if rb.IsFull() {
|
||||
panic("ring is full")
|
||||
}
|
||||
return rb.peek()
|
||||
}
|
||||
|
||||
// Advance marks the next request as in use by advancing the tail.
|
||||
// It panics if [requestRing.IsFull] reports true.
|
||||
func (rb *requestRing) Advance() {
|
||||
if rb.IsFull() {
|
||||
panic("ring is full")
|
||||
}
|
||||
rb.tail += 1
|
||||
}
|
||||
|
||||
// Acquire returns the next [request] from the ring, advancing the tail.
|
||||
// It panics if [requestRing.IsFull] reports true.
|
||||
func (rb *requestRing) Acquire() *request {
|
||||
req := rb.Peek()
|
||||
rb.tail += 1
|
||||
return req
|
||||
}
|
||||
|
||||
// AcquireSeq yields available [request]s one by one until the ring
// runs out of unused requests or the caller stops the iteration.
// The tail advances as requests are yielded, so each yielded request
// counts as acquired even if the caller stops early.
func (rb *requestRing) AcquireSeq() iter.Seq[*request] {
	return func(yield func(req *request) bool) {
		// head + capacity is the tail value at which the ring is full.
		end := rb.head + rb.capacity
		for ; rb.tail != end; rb.tail++ {
			if !yield(rb.peek()) {
				return
			}
		}
	}
}
|
||||
|
||||
// peek returns a pointer to the request at the current tail position,
// reset and ready for reuse. It does not advance the tail and performs
// no fullness check; callers must ensure the ring is not full.
func (rb *requestRing) peek() *request {
	// Masked tail index times stride locates the request within the buffer;
	// requests were laid out at these offsets by newRequestRing.
	offset := uintptr(rb.tail&rb.indexMask) * rb.stride
	ptr := unsafe.SliceData(rb.buff[offset : offset+rb.stride])
	req := (*request)(unsafe.Pointer(ptr))
	req.Reset()
	return req
}
|
||||
|
||||
// ReleaseN marks n requests at the head of the ring as free.
|
||||
// It is a run-time error to release more requests than have been acquired.
|
||||
func (rb *requestRing) ReleaseN(n int) {
|
||||
if n < 0 {
|
||||
panic("ring: negative release count")
|
||||
}
|
||||
if uint64(n) > uint64(rb.Len()) {
|
||||
panic("ring: releasing more requests than acquired")
|
||||
}
|
||||
rb.head += uint32(n)
|
||||
}
|
||||
|
||||
// Close frees the memory allocated for the request ring.
|
||||
func (rb *requestRing) Close() error {
|
||||
if rb.id != 0 {
|
||||
winrio.DeregisterBuffer(rb.id)
|
||||
rb.id = 0
|
||||
}
|
||||
|
||||
if rb.ptr != 0 {
|
||||
if err := windows.VirtualFree(rb.ptr, 0, windows.MEM_RELEASE); err != nil {
|
||||
return fmt.Errorf("failed to free request ring buffer: %w", err)
|
||||
}
|
||||
rb.ptr = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
507
net/rioconn/request_ring_test.go
Normal file
507
net/rioconn/request_ring_test.go
Normal file
@ -0,0 +1,507 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestMaxNumberOfRequestsIsPow2(t *testing.T) {
|
||||
if !isPowerOfTwo(maxNumberOfRequests) {
|
||||
t.Fatalf("maxNumberOfRRequests %d is not a power of two", maxNumberOfRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFloorPowerOfTwo exercises floorPowerOfTwo at the zero, exact
// power-of-two, in-between, and extreme uint64 boundary values.
func TestFloorPowerOfTwo(t *testing.T) {
	tests := []struct {
		n    uint64
		want uint64
	}{
		{0, 0},
		{1, 1},
		{16, 16},
		{17, 16},
		{31, 16},
		{32, 32},
		{uint64(1 << 63), uint64(1 << 63)},
		{uint64(1<<64 - 1), uint64(1) << 63},
	}
	for _, tt := range tests {
		t.Run(fmt.Sprintf("%x", tt.n), func(t *testing.T) {
			if got := floorPowerOfTwo(tt.n); got != tt.want {
				t.Fatalf("got %d; want %d", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
func TestMaxRingBufferCapacity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idealCapacity uint32
|
||||
dataSize uint16
|
||||
maxBytes uintptr
|
||||
wantCapacity uint32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid/not-pow2",
|
||||
idealCapacity: 3,
|
||||
dataSize: 512,
|
||||
maxBytes: 65536,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid/data-length-zero",
|
||||
idealCapacity: 16,
|
||||
dataSize: 0,
|
||||
maxBytes: 65536,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid/max-bytes-too-small",
|
||||
idealCapacity: 16,
|
||||
dataSize: 512,
|
||||
maxBytes: requestStride(512) - 1, // less than one request
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid/no-clamp",
|
||||
idealCapacity: 16,
|
||||
dataSize: 512,
|
||||
maxBytes: 65536, // can fit [0; 128) requests
|
||||
wantCapacity: 16, // and we asked for 16
|
||||
},
|
||||
{
|
||||
name: "valid/exact-fit",
|
||||
idealCapacity: 16,
|
||||
dataSize: 512,
|
||||
maxBytes: requestStride(512) * 16,
|
||||
wantCapacity: 16,
|
||||
},
|
||||
{
|
||||
name: "valid/clamp-down",
|
||||
idealCapacity: 128,
|
||||
dataSize: 512,
|
||||
maxBytes: requestStride(512) * 64, // can fit only 64 requests
|
||||
wantCapacity: 64, // clamps down to 64
|
||||
},
|
||||
{
|
||||
name: "valid/max-requests",
|
||||
idealCapacity: 0, // want as many as possible
|
||||
dataSize: 512,
|
||||
maxBytes: 65536, // can fit [0; 128) requests
|
||||
wantCapacity: 64, // the max power of two that fits
|
||||
},
|
||||
{
|
||||
name: "valid/large-buffer/no-clamp",
|
||||
idealCapacity: 8192,
|
||||
dataSize: maxRequestDataSize,
|
||||
maxBytes: requestStride(maxRequestDataSize) * 8192,
|
||||
wantCapacity: 8192,
|
||||
},
|
||||
{
|
||||
name: "valid/large-buffer/clamp-down",
|
||||
idealCapacity: 8192,
|
||||
dataSize: maxRequestDataSize,
|
||||
maxBytes: requestStride(maxRequestDataSize) * 8191, // can fit only 8191
|
||||
wantCapacity: 4096, // clamps to next lower power of two
|
||||
},
|
||||
{
|
||||
name: "invalid/too-many-requests",
|
||||
idealCapacity: maxNumberOfRequests * 2,
|
||||
dataSize: 512,
|
||||
maxBytes: 1 << 30,
|
||||
wantCapacity: maxNumberOfRequests, // cannot exceed [maxNumberOfRRequests]
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := maxRequestRingCapacity(tt.idealCapacity, tt.dataSize, uintptr(tt.maxBytes))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("maxRingBufferCapacity error: got %v; want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantCapacity {
|
||||
t.Fatalf("maxRingBufferCapacity: got %v; want %v", got, tt.wantCapacity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRingBuffer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dataSize uint16
|
||||
maxSizeInBytes uintptr
|
||||
wantCapacity uint32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "small-buffer",
|
||||
dataSize: 256,
|
||||
maxSizeInBytes: requestStride(256) * 4,
|
||||
wantCapacity: 4,
|
||||
},
|
||||
{
|
||||
name: "large-buffer/small-requests",
|
||||
dataSize: 1280,
|
||||
maxSizeInBytes: requestStride(1280) * 8192,
|
||||
wantCapacity: 8192,
|
||||
},
|
||||
{
|
||||
name: "large-buffer/large-requests",
|
||||
dataSize: 65535,
|
||||
maxSizeInBytes: requestStride(65535) * 32,
|
||||
wantCapacity: 32,
|
||||
},
|
||||
{
|
||||
name: "invalid/limit-too-small",
|
||||
dataSize: 256,
|
||||
maxSizeInBytes: 100,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rb, err := newRequestRing(tt.dataSize, tt.maxSizeInBytes)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("newRingBuffer error: got %v; wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := rb.Close(); err != nil {
|
||||
t.Fatalf("ringBuffer.close() failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if gotCapacity := rb.Cap(); gotCapacity != tt.wantCapacity {
|
||||
t.Errorf("ringBuffer.Cap() = %v; want %v", gotCapacity, tt.wantCapacity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRingBufferAcquireRelease drives Acquire/ReleaseN through empty,
// partially-full, full, and wrap-around states, checking both the panic
// behavior and the resulting head/tail counters.
//
// NOTE(review): the wantDepth field is populated by the table entries but
// never asserted in this function — confirm whether a depth check is missing.
func TestRingBufferAcquireRelease(t *testing.T) {
	tests := []struct {
		name             string
		capacity         uint32
		initialHead      uint32
		initialTail      uint32
		acquires         int
		wantAcquirePanic bool
		releaseN         int
		wantReleasePanic bool
		wantHead         uint32
		wantTail         uint32
		wantDepth        int
	}{
		{
			name:             "empty/acquire-one",
			capacity:         16,
			acquires:         1,
			wantAcquirePanic: false,
			wantTail:         1,
			wantDepth:        1,
		},
		{
			name:             "empty/acquire-few",
			capacity:         16,
			acquires:         4,
			wantAcquirePanic: false,
			wantTail:         4,
			wantDepth:        4,
		},
		{
			name:             "empty/acquire-all",
			capacity:         16,
			acquires:         16,
			wantAcquirePanic: false,
			wantTail:         16,
			wantDepth:        16,
		},
		{
			name:             "empty/acquire-too-many",
			capacity:         16,
			acquires:         17, // one more than the buffer can hold
			wantAcquirePanic: true,
			wantTail:         16,
			wantDepth:        16,
		},
		{
			name:             "empty/release-one",
			capacity:         16,
			releaseN:         1,
			wantReleasePanic: true,
		},
		{
			name:             "partially-full/acquire-few",
			capacity:         16,
			initialTail:      8,
			acquires:         4,
			wantAcquirePanic: false,
			wantTail:         12,
			wantDepth:        8,
		},
		{
			name:             "partially-full/release-few",
			capacity:         16,
			initialTail:      8,
			releaseN:         4,
			wantReleasePanic: false,
			wantHead:         4,
			wantTail:         8,
			wantDepth:        4,
		},
		{
			name:             "partially-full/acquire-all/wrap-around",
			capacity:         16,
			initialHead:      4,
			initialTail:      8,
			acquires:         12,
			wantAcquirePanic: false,
			wantHead:         4,
			wantTail:         20,
			wantDepth:        16,
		},
		{
			name:             "partially-full/acquire-too-many",
			capacity:         16,
			initialHead:      4,
			initialTail:      8,
			acquires:         13, // one more than can fit
			wantAcquirePanic: true,
			wantTail:         20,
			wantHead:         4,
			wantDepth:        16,
		},
		{
			name:             "partially-full/release-too-many",
			capacity:         16,
			initialHead:      4,
			initialTail:      8,
			releaseN:         9, // one more than acquired
			wantReleasePanic: true,
			wantHead:         4,
			wantTail:         8,
		},
		{
			name:             "full/acquire-one",
			capacity:         16,
			initialTail:      16,
			acquires:         1,
			wantAcquirePanic: true,
			wantTail:         16,
		},
		{
			name:             "full/release-one",
			capacity:         16,
			initialTail:      16,
			releaseN:         1,
			wantReleasePanic: false,
			wantHead:         1,
			wantTail:         16,
		},
		{
			name:             "full/release-all",
			capacity:         16,
			initialTail:      16,
			releaseN:         16,
			wantReleasePanic: false,
			wantHead:         16,
			wantTail:         16,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rb, err := newRequestRingWithCapacity(4, tt.capacity)
			if err != nil {
				t.Fatalf("newRingBuffer failed: %v", err)
			}
			t.Cleanup(func() { rb.Close() })

			// Seed the monotonic counters to simulate prior activity.
			rb.head = uint32(tt.initialHead)
			rb.tail = uint32(tt.initialTail)

			CheckPanic(t, tt.wantAcquirePanic, func() {
				for range tt.acquires {
					_ = rb.Acquire()
				}
			})

			CheckPanic(t, tt.wantReleasePanic, func() {
				rb.ReleaseN(tt.releaseN)
			})

			if rb.head != tt.wantHead {
				t.Fatalf("rb.head = %d; want %d", rb.head, tt.wantHead)
			}
			if rb.tail != tt.wantTail {
				t.Fatalf("rb.tail = %d; want %d", rb.tail, tt.wantTail)
			}
		})
	}
}
|
||||
|
||||
func TestRingBufferAcquireSeq(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
capacity uint32
|
||||
initialHead uint32
|
||||
initialTail uint32
|
||||
wantTail uint32
|
||||
wantCount int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
capacity: 16,
|
||||
initialHead: 0,
|
||||
initialTail: 0,
|
||||
wantTail: 16,
|
||||
wantCount: 16, // 16 requests to acquire: [0; 16)
|
||||
},
|
||||
{
|
||||
name: "partially-full",
|
||||
capacity: 16,
|
||||
initialHead: 4,
|
||||
initialTail: 10,
|
||||
wantTail: 20,
|
||||
wantCount: 10, // 10 requests to acquire: [10; 15) and [0; 4)
|
||||
},
|
||||
{
|
||||
name: "full",
|
||||
capacity: 16,
|
||||
initialHead: 0,
|
||||
initialTail: 16,
|
||||
wantTail: 16,
|
||||
wantCount: 0, // nothing to acquire
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rb, err := newRequestRingWithCapacity(4, tt.capacity)
|
||||
if err != nil {
|
||||
t.Fatalf("newRingBuffer failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { rb.Close() })
|
||||
|
||||
// Initialize head and tail.
|
||||
rb.head = tt.initialHead
|
||||
rb.tail = tt.initialTail
|
||||
|
||||
// Acquire all available requests and count how many we got.
|
||||
gotCount := len(slices.Collect(rb.AcquireSeq()))
|
||||
|
||||
// Check that we got the expected count, the head didn't change,
|
||||
// and the tail advanced as expected.
|
||||
if gotCount != tt.wantCount {
|
||||
t.Fatalf("gotCount = %d; want %d", gotCount, tt.wantCount)
|
||||
}
|
||||
if rb.head != tt.initialHead {
|
||||
t.Fatalf("rb.head = %d; want %d", rb.head, tt.initialHead)
|
||||
}
|
||||
if rb.tail != tt.wantTail {
|
||||
t.Fatalf("rb.tail = %d; want %d", rb.tail, tt.wantTail)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferWrapAround(t *testing.T) {
|
||||
const capacity = 16
|
||||
rb, err := newRequestRingWithCapacity(4, capacity)
|
||||
if err != nil {
|
||||
t.Fatalf("newRingBuffer failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { rb.Close() })
|
||||
|
||||
// Acquire all requests and store pointers to them in a slice.
|
||||
requests := slices.Collect(rb.AcquireSeq())
|
||||
if len(requests) != capacity {
|
||||
t.Fatalf("acquired %d requests; want %d", len(requests), capacity)
|
||||
}
|
||||
|
||||
rb.ReleaseN(capacity) // release all
|
||||
|
||||
// Acquire again and ensure we get the same requests in the same order.
|
||||
for i, wantReq := range requests {
|
||||
if gotReq := rb.Acquire(); gotReq != wantReq {
|
||||
t.Fatalf("acquired request %d = %p; want %p", i, gotReq, wantReq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CheckPanic checks whether the given function panics or not,
// and fails the test if the result does not match wantPanic.
func CheckPanic(t *testing.T, wantPanic bool, fn func()) {
	t.Helper()
	defer func() {
		switch r := recover(); {
		case r != nil && !wantPanic:
			t.Fatalf("unexpected panic: %v", r)
		case r == nil && wantPanic:
			t.Fatal("expected panic but none occurred")
		}
	}()
	fn()
}
|
||||
|
||||
// checkProcessPrivilege reports whether the given privilege is present
|
||||
// and/or enabled in the current process token.
|
||||
func checkProcessPrivilege(privName string) (present bool, enabled bool, err error) {
|
||||
var tok windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &tok); err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
defer tok.Close()
|
||||
return checkTokenPrivilege(tok, privName)
|
||||
}
|
||||
|
||||
// checkTokenPrivilege reports whether the given privilege is present
// and/or enabled in the specified token.
func checkTokenPrivilege(token windows.Token, privName string) (present bool, enabled bool, err error) {
	// Resolve the privilege name (e.g. "SeDebugPrivilege") to its LUID
	// so it can be compared against the token's privilege entries below.
	var luid windows.LUID
	namePtr, err := windows.UTF16PtrFromString(privName)
	if err != nil {
		return false, false, err
	}
	if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
		return false, false, err
	}

	// First call with a nil buffer to learn the required size. The error
	// is deliberately ignored: this call is expected to fail while still
	// reporting the needed byte count via needed.
	var needed uint32
	_ = windows.GetTokenInformation(token, windows.TokenPrivileges, nil, 0, &needed)
	if needed == 0 {
		// NOTE(review): GetLastError here may not reflect the failed call
		// above; consider capturing that call's error and returning it.
		return false, false, windows.GetLastError()
	}

	// Second call with a buffer of the reported size to fetch the data.
	buf := make([]byte, needed)
	if err := windows.GetTokenInformation(token, windows.TokenPrivileges, &buf[0], uint32(len(buf)), &needed); err != nil {
		return false, false, err
	}

	// Reinterpret the raw buffer as a TOKEN_PRIVILEGES header followed by
	// PrivilegeCount variable-length LUID_AND_ATTRIBUTES entries.
	tp := (*windows.Tokenprivileges)(unsafe.Pointer(&buf[0]))
	count := int(tp.PrivilegeCount)

	// Build a slice over the trailing entries and scan for our LUID.
	laa := unsafe.Slice((*windows.LUIDAndAttributes)(unsafe.Pointer(&tp.Privileges[0])), count)
	for _, p := range laa {
		if p.Luid == luid {
			present = true
			enabled = (p.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0
			return present, enabled, nil
		}
	}

	// The privilege is not present in the token at all.
	return false, false, nil
}
|
||||
--- new file: net/rioconn/request_test.go (319 lines, @@ -0,0 +1,319 @@) ---
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build windows
|
||||
|
||||
package rioconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func TestUnsafeMapRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
buff []byte
|
||||
offset uintptr
|
||||
dataSize uintptr
|
||||
wantOk bool
|
||||
wantOffset uintptr
|
||||
wantBytesUsed uintptr
|
||||
}{
|
||||
{
|
||||
name: "enough-space",
|
||||
buff: makeAligned(totalRequestSize(10), unsafe.Alignof(request{})),
|
||||
dataSize: 10,
|
||||
wantOk: true,
|
||||
wantOffset: 0,
|
||||
wantBytesUsed: totalRequestSize(10),
|
||||
},
|
||||
{
|
||||
name: "enough-space/unaligned-buffer",
|
||||
buff: makeMisaligned(totalRequestSize(10)+2, unsafe.Alignof(request{}), 2),
|
||||
offset: 0,
|
||||
dataSize: 10,
|
||||
wantOk: true,
|
||||
wantOffset: 2, // the request starts at offset 2 for proper alignment
|
||||
wantBytesUsed: totalRequestSize(10) + 2, // includes padding before the request
|
||||
},
|
||||
{
|
||||
name: "enough-space/unaligned-buffer/aligned-offset",
|
||||
buff: makeMisaligned(totalRequestSize(10)+2, unsafe.Alignof(request{}), 2),
|
||||
offset: 2,
|
||||
dataSize: 10,
|
||||
wantOk: true,
|
||||
wantOffset: 2, // same as offset requested
|
||||
wantBytesUsed: totalRequestSize(10), // no extra padding needed
|
||||
},
|
||||
{
|
||||
name: "enough-space/aligned-buffer/non-zero-offset",
|
||||
buff: makeAligned(totalRequestSize(10)+16, unsafe.Alignof(request{})),
|
||||
offset: 16,
|
||||
dataSize: 10,
|
||||
wantOk: true,
|
||||
wantOffset: 16,
|
||||
wantBytesUsed: totalRequestSize(10),
|
||||
},
|
||||
{
|
||||
name: "not-enough-space/nil-buffer",
|
||||
buff: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "not-enough-space/small-buffer",
|
||||
buff: makeAligned(totalRequestSize(10)-1, unsafe.Alignof(request{})),
|
||||
dataSize: 10,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "not-enough-space/due-to-offset",
|
||||
buff: makeAligned(totalRequestSize(10), unsafe.Alignof(request{})),
|
||||
offset: 8,
|
||||
dataSize: 10,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
request, bytesUsed, ok := unsafeMapRequest(0, tt.buff, tt.offset, tt.dataSize)
|
||||
if ok != tt.wantOk {
|
||||
t.Errorf("mapRequest: ok: got %v; want %v", ok, tt.wantOk)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if bytesUsed != tt.wantBytesUsed {
|
||||
t.Errorf("mapRequest: bytesUsed: got %d; want %d", bytesUsed, tt.wantBytesUsed)
|
||||
}
|
||||
gotOffset := uintptr(unsafe.Pointer(request)) - uintptr(unsafe.Pointer(&tt.buff[0]))
|
||||
if gotOffset != tt.wantOffset {
|
||||
t.Errorf("mapRequest: offset: got %d; want %d", gotOffset, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestWriterReserve walks a request writer through its reservation
// lifecycle: initial capacity/length invariants, failed over-reservation,
// zero-length reservation, a partial reservation with writes, SetLen
// validation, and a second reservation appended after written data.
// The steps are order-dependent: each one builds on the writer state left
// by the previous one.
func TestRequestWriterReserve(t *testing.T) {
	const dataSize = 1312
	req := makeRequest(t, dataSize)
	writer := req.Writer()

	// The capacity of the writer's data buffer must match the requested data size.
	if gotCap := writer.Cap(); gotCap != dataSize {
		t.Errorf("requestWriter.Cap: got %d; want %d", gotCap, dataSize)
	}

	// Initially, the entire data buffer must be available for writing.
	if gotAvail := writer.Available(); gotAvail != dataSize {
		t.Errorf("requestWriter.Available: got %d; want %d", gotAvail, dataSize)
	}

	// And the length of the data buffer must be zero
	// since nothing has been written or reserved yet.
	if gotLen := writer.Len(); gotLen != 0 {
		t.Errorf("requestWriter.Len: got %d; want %d", gotLen, 0)
	}

	// Reserving more than the available capacity must fail.
	CheckPanic(t, true, func() { writer.Reserve(dataSize + 1) })
	// A failed reservation must not consume any capacity.
	if gotAvail := writer.Available(); gotAvail != dataSize {
		t.Errorf("requestWriter.Available after failed Reserve: got %d; want %d",
			gotAvail, dataSize)
	}

	// Reserving zero bytes must return an empty buffer.
	buf := writer.Reserve(0)
	if len(buf) != 0 {
		t.Errorf("Reserve: got buffer of length %d; want %d", len(buf), 0)
	}
	// The capacity of the returned buffer must also be zero.
	if cap(buf) != 0 {
		t.Errorf("Reserve: got buffer of capacity %d; want %d", cap(buf), 0)
	}

	// Reserving part of the available capacity should succeed.
	buf = writer.Reserve(64)
	if len(buf) != 64 {
		t.Errorf("Reserve: got buffer of length %d; want %d", len(buf), 64)
	}
	// The capacity of the returned buffer should be equal to its length
	// to prevent reslicing beyond the reserved length.
	if cap(buf) != len(buf) {
		t.Errorf("Reserve: got buffer of capacity %d; want %d", cap(buf), len(buf))
	}
	// The returned buffer should point to the start of the writer's data buffer.
	if &buf[0] != &writer.data[0] {
		t.Errorf("Reserve: got buffer starting at %p; want %p", &buf[0], &writer.data[0])
	}
	// After reserving X bytes, the length of the data buffer should be X.
	if writer.Len() != len(buf) {
		t.Errorf("requestWriter.Len after Reserve: got %d; want %d", writer.Len(), len(buf))
	}
	// And the available capacity should decrease by X.
	if gotAvail := writer.Available(); gotAvail != dataSize-len(buf) {
		t.Errorf("requestWriter.Available after Reserve: got %d; want %d",
			gotAvail, dataSize-len(buf))
	}
	// However, the total capacity of the writer's data buffer should remain unchanged.
	if cap(writer.data) != dataSize {
		t.Errorf("requestWriter.data capacity: got %d; want %d", cap(writer.data), dataSize)
	}

	// Fill part of the reserved region with known bytes for the SetLen checks below.
	bytesWritten := copy(buf, []byte("Hello World"))

	// SetLen must panic if the desired length exceeds the number of bytes
	// written or reserved so far.
	CheckPanic(t, true, func() { writer.SetLen(len(buf) + 1) })
	// Or if the desired length is negative.
	CheckPanic(t, true, func() { writer.SetLen(-1) })

	// Otherwise, it must succeed and update the length of the data buffer accordingly.
	writer.SetLen(len(buf))
	if writer.Len() != len(buf) { // no change expected
		t.Errorf("requestWriter.Len after SetLen: got %d; want %d", writer.Len(), len(buf))
	}

	// The contents of the data buffer up to the set length should be what was written.
	writer.SetLen(bytesWritten)
	if !bytes.Equal(writer.data, []byte("Hello World")) {
		t.Errorf("requestWriter.data: got %q; want %q", writer.data, "Hello World")
	}

	// Reserving more bytes should succeed and return a buffer starting immediately
	// after the previously reserved and written data.
	newBuf := writer.Reserve(10)
	if &newBuf[0] != &writer.data[bytesWritten] {
		t.Errorf("Reserve: got buffer starting at %p; want %p", &newBuf[0], &writer.data[bytesWritten])
	}
	if writer.Len() != bytesWritten+len(newBuf) {
		t.Errorf("requestWriter.Len after second Reserve: got %d; want %d",
			writer.Len(), bytesWritten+len(newBuf))
	}
}
|
||||
|
||||
func TestRequestWriterWrite(t *testing.T) {
|
||||
const dataSize = 128
|
||||
req := makeRequest(t, dataSize)
|
||||
writer := req.Writer()
|
||||
n, err := writer.Write([]byte("Hello"))
|
||||
if err != nil {
|
||||
t.Errorf("requestWriter.Write: unexpected error: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("requestWriter.Write: got %d bytes written; want %d", n, 5)
|
||||
}
|
||||
writer.Write([]byte(" World"))
|
||||
if !bytes.Equal(writer.data[:11], []byte("Hello World")) {
|
||||
t.Errorf("requestWriter.data: got %q; want %q", writer.data[:11], "Hello World")
|
||||
}
|
||||
if writer.Len() != 11 {
|
||||
t.Errorf("requestWriter.Len after Write: got %d; want %d", writer.Len(), 11)
|
||||
}
|
||||
// Writing more bytes than the available capacity should fail.
|
||||
n, err = writer.Write(make([]byte, dataSize))
|
||||
if err == nil {
|
||||
t.Error("requestWriter.Write: expected error when writing beyond capacity; got nil")
|
||||
}
|
||||
// We do not allow partial writes, so n should be zero...
|
||||
if n != 0 {
|
||||
t.Errorf("requestWriter.Write: got %d bytes written; want %d", n, 0)
|
||||
}
|
||||
// ... and the length should remain unchanged.
|
||||
if writer.Len() != 11 {
|
||||
t.Errorf("requestWriter.Len after failed Write: got %d; want %d", writer.Len(), 11)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestWriterSetAddrPort(t *testing.T) {
|
||||
const dataSize = 128
|
||||
addrPort := netip.MustParseAddrPort("192.0.2.1:1234")
|
||||
|
||||
req := makeRequest(t, dataSize)
|
||||
writer := req.Writer()
|
||||
err := writer.SetRemoteAddrPort(addrPort)
|
||||
if err != nil {
|
||||
t.Errorf("requestWriter.SetRemoteAddrPort: unexpected error: %v", err)
|
||||
}
|
||||
gotAddrPort, err := req.raddr.ToAddrPort()
|
||||
if err != nil {
|
||||
t.Errorf("request.raddr.ToAddrPort: unexpected error: %v", err)
|
||||
}
|
||||
if gotAddrPort != addrPort {
|
||||
t.Errorf("request.raddr: got %v; want %v", gotAddrPort, addrPort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestReader(t *testing.T) {
|
||||
const dataSize = 128
|
||||
req := makeRequest(t, dataSize)
|
||||
reader := req.Reader()
|
||||
if reader.Len() != 0 {
|
||||
t.Errorf("requestReader.Len: got %d; want %d", reader.Len(), 0)
|
||||
}
|
||||
if len(reader.Bytes()) != 0 {
|
||||
t.Errorf("requestReader.Bytes: got buffer of length %d; want %d", len(reader.Bytes()), 0)
|
||||
}
|
||||
|
||||
// Simulate receiving data by directly modifying the request's data buffer and length.
|
||||
req.data = req.data[:11]
|
||||
copy(req.data, []byte("Hello World"))
|
||||
|
||||
if !bytes.Equal(reader.Bytes(), []byte("Hello World")) {
|
||||
t.Errorf("requestReader.Bytes: got %q; want %q", reader.Bytes(), "Hello World")
|
||||
}
|
||||
if reader.Len() != 11 {
|
||||
t.Errorf("requestReader.Len: got %d; want %d", reader.Len(), 11)
|
||||
}
|
||||
|
||||
addrPort := netip.MustParseAddrPort("192.0.2.1:1234")
|
||||
req.raddr, _ = rawSockaddrFromAddrPort(addrPort)
|
||||
gotAddrPort, err := reader.RemoteAddrPort()
|
||||
if err != nil {
|
||||
t.Errorf("requestReader.RemoteAddrPort: unexpected error: %v", err)
|
||||
}
|
||||
if gotAddrPort != addrPort {
|
||||
t.Errorf("requestReader.RemoteAddrPort: got %v; want %v", gotAddrPort, addrPort)
|
||||
}
|
||||
}
|
||||
|
||||
func makeRequest(tb testing.TB, dataSize int) *request {
|
||||
tb.Helper()
|
||||
buff := makeAligned(totalRequestSize(uintptr(dataSize)), unsafe.Alignof(request{}))
|
||||
req, _, ok := unsafeMapRequest(0, buff, 0, uintptr(dataSize))
|
||||
if !ok {
|
||||
tb.Fatalf("failed to make request of size %d", dataSize)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// makeAligned returns a slice of n bytes such that the address of
// the first byte is aligned to the given alignment.
// Alignment must be a power of two.
func makeAligned(n, alignment uintptr) []byte {
	if n == 0 {
		return nil
	}
	// Over-allocate by alignment-1 bytes so an aligned start address is
	// guaranteed to exist within the buffer.
	buff := make([]byte, n+alignment-1)
	base := uintptr(unsafe.Pointer(&buff[0]))
	// Round base up to the next multiple of alignment, computed inline
	// for consistency with makeMisaligned. (n is already a uintptr, so
	// no conversion is needed when slicing.)
	offset := ((base + alignment - 1) &^ (alignment - 1)) - base
	return buff[offset : offset+n]
}
|
||||
|
||||
// makeMisaligned returns a slice of n bytes whose start address is
// misaligned by misalign bytes relative to alignment, such that advancing
// the start address by misalign bytes yields an alignment-aligned address.
// Alignment must be a power of two.
func makeMisaligned(n, alignment, misalign uintptr) []byte {
	if n == 0 {
		return nil
	}
	// Over-allocate so that an address lying exactly misalign bytes below
	// an alignment boundary is guaranteed to fall inside the buffer.
	raw := make([]byte, n+alignment+misalign)
	base := uintptr(unsafe.Pointer(&raw[0]))
	// Find the first alignment boundary at or above base+misalign, then
	// step back by misalign bytes to get the desired start address.
	boundary := (base + misalign + alignment - 1) &^ (alignment - 1)
	start := boundary - misalign - base
	return raw[start : start+n]
}
|
||||
Loading…
x
Reference in New Issue
Block a user