net/rioconn: add guard, a reference-counting/rundown-protection-like synchronization primitive

A RIO connection depends on unmanaged resources such as the socket handle, RIO buffer registrations,
and heap-allocated memory. We should protect these resources so we don’t free them while there are in-flight
calls that still use them, and prevent new operations from starting after the resources have been released.

In this commit, we introduce a guard, a synchronization primitive that does exactly that and will be used
in subsequent commits.

Updates tailscale/corp#8610

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2026-02-19 08:32:39 -06:00
parent e79262bced
commit d57b58193a
No known key found for this signature in database
2 changed files with 201 additions and 0 deletions

90
net/rioconn/guard.go Normal file
View File

@ -0,0 +1,90 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"sync"
"sync/atomic"
)
// guard prevents new operations from starting after Close
// while allowing in-flight operations to complete.
type guard struct {
state atomic.Int64
done chan struct{}
closeOnce sync.Once
}
const (
// state layout:
// bit 62: closed flag (1 == Close called)
// bits 061: in-flight operation count
//
// We avoid using bit 63 (the sign bit) so that valid states remain
// non-negative and the counter has a large positive range. guardCountMask
// isolates the counter bits and is used to detect unbalanced Release calls
// or counter overflow (which would wrap into the closed bit).
guardClosedBit = int64(1) << 62
guardCountMask = guardClosedBit - 1
)
func newGuard() *guard {
return &guard{done: make(chan struct{})}
}
// Acquire attempts to acquire a lease for an operation.
//
// If it reports true, the caller may proceed and must call
// [guard.Release] when done. Otherwise, it must not proceed.
//
// Acquire fails if [guard.Close] has already been called.
func (g *guard) Acquire() bool {
n := g.state.Add(1)
if n&guardClosedBit == 0 {
return true
}
g.decrementAndSignal()
return false
}
// Release releases a lease acquired by [guard.Acquire].
// It is a run-time error to call Release without a matching Acquire.
func (g *guard) Release() {
g.decrementAndSignal()
}
func (g *guard) decrementAndSignal() {
n := g.state.Add(-1)
if n < 0 || n == guardCountMask {
panic("unbalanced Release call")
}
if n == guardClosedBit {
g.closeOnce.Do(func() { close(g.done) })
}
}
// Close prevents future Acquire calls from succeeding.
func (g *guard) Close() {
g.state.Or(guardClosedBit)
}
// IsClosed reports whether Close has been called.
func (g *guard) IsClosed() bool {
return g.state.Load()&guardClosedBit != 0
}
// Wait blocks until all in-flight operations have called Release.
// It is a run-time error to call Wait before Close.
func (g *guard) Wait() {
state := g.state.Load()
if state&guardClosedBit == 0 {
panic("Wait called before Close")
}
if state == guardClosedBit {
return
}
<-g.done
}

111
net/rioconn/guard_test.go Normal file
View File

@ -0,0 +1,111 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package rioconn
import (
"runtime"
"sync"
"testing"
)
func TestGuardCloseAndWait(t *testing.T) {
t.Parallel()
g := newGuard()
g.Close()
g.Wait()
}
func TestGuardAcquireReleaseCloseAndWait(t *testing.T) {
t.Parallel()
g := newGuard()
if !g.Acquire() {
t.Fatal("Acquire failed")
}
g.Release()
g.Close()
g.Wait()
}
func TestGuardAcquireCloseReleaseAndWait(t *testing.T) {
t.Parallel()
g := newGuard()
if !g.Acquire() {
t.Fatal("Acquire failed")
}
g.Close()
if g.Acquire() {
t.Fatal("Acquire succeeded after Close")
}
g.Release()
if g.Acquire() {
t.Fatal("Acquire succeeded after Release following Close")
}
g.Wait()
}
func TestGuardConcurrentUse(t *testing.T) {
t.Parallel()
const N = 1000
g := newGuard()
var wg sync.WaitGroup
wg.Add(N)
for range N {
go func() {
wg.Done()
if !g.Acquire() {
return
}
runtime.Gosched()
g.Release()
}()
}
wg.Wait() // wait for all goroutines to start
g.Close()
g.Wait()
}
func TestReleaseWithoutAcquire(t *testing.T) {
t.Parallel()
g := newGuard()
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on Release without Acquire")
}
}()
g.Release()
}
func TestReleaseWithoutAcquireAfterClose(t *testing.T) {
t.Parallel()
g := newGuard()
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on Release without Acquire")
}
}()
g.Close()
g.Release()
}
func TestWaitBeforeClose(t *testing.T) {
t.Parallel()
g := newGuard()
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on Wait before Close")
}
}()
g.Wait()
}