From d57b58193a4a86684e6d2ffaf598b6e63ccccec7 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 19 Feb 2026 08:32:39 -0600 Subject: [PATCH] net/rioconn: add guard, a reference-counting/rundown-protection-like synchronization primitive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A RIO connection depends on unmanaged resources such as the socket handle, RIO buffer registrations, and heap-allocated memory. We should protect these resources so we don’t free them while there are in-flight calls that still use them, and prevent new operations from starting after the resources have been released. In this commit, we introduce a guard, a synchronization primitive that does exactly that and will be used in subsequent commits. Updates tailscale/corp#8610 Signed-off-by: Nick Khyl --- net/rioconn/guard.go | 90 +++++++++++++++++++++++++++++++ net/rioconn/guard_test.go | 111 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+) create mode 100644 net/rioconn/guard.go create mode 100644 net/rioconn/guard_test.go diff --git a/net/rioconn/guard.go b/net/rioconn/guard.go new file mode 100644 index 000000000..bbf7b1eb3 --- /dev/null +++ b/net/rioconn/guard.go @@ -0,0 +1,90 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "sync" + "sync/atomic" +) + +// guard prevents new operations from starting after Close +// while allowing in-flight operations to complete. +type guard struct { + state atomic.Int64 + done chan struct{} + closeOnce sync.Once +} + +const ( + // state layout: + // bit 62: closed flag (1 == Close called) + // bits 0–61: in-flight operation count + // + // We avoid using bit 63 (the sign bit) so that valid states remain + // non-negative and the counter has a large positive range. guardCountMask + // isolates the counter bits and is used to detect unbalanced Release calls + // or counter overflow (which would wrap into the closed bit). + guardClosedBit = int64(1) << 62 + guardCountMask = guardClosedBit - 1 +) + +func newGuard() *guard { + return &guard{done: make(chan struct{})} +} + +// Acquire attempts to acquire a lease for an operation. +// +// If it reports true, the caller may proceed and must call +// [guard.Release] when done. Otherwise, it must not proceed. +// +// Acquire fails if [guard.Close] has already been called. +func (g *guard) Acquire() bool { + n := g.state.Add(1) + if n&guardClosedBit == 0 { + return true + } + g.decrementAndSignal() + return false +} + +// Release releases a lease acquired by [guard.Acquire]. +// It is a run-time error to call Release without a matching Acquire. +func (g *guard) Release() { + g.decrementAndSignal() +} + +func (g *guard) decrementAndSignal() { + n := g.state.Add(-1) + if n < 0 || n == guardCountMask { + panic("unbalanced Release call") + } + if n == guardClosedBit { + g.closeOnce.Do(func() { close(g.done) }) + } +} + +// Close prevents future Acquire calls from succeeding. +func (g *guard) Close() { + g.state.Or(guardClosedBit) +} + +// IsClosed reports whether Close has been called. +func (g *guard) IsClosed() bool { + return g.state.Load()&guardClosedBit != 0 +} + +// Wait blocks until all in-flight operations have called Release. +// It is a run-time error to call Wait before Close. +func (g *guard) Wait() { + state := g.state.Load() + if state&guardClosedBit == 0 { + panic("Wait called before Close") + } + if state == guardClosedBit { + return + } + <-g.done +} diff --git a/net/rioconn/guard_test.go b/net/rioconn/guard_test.go new file mode 100644 index 000000000..04c276278 --- /dev/null +++ b/net/rioconn/guard_test.go @@ -0,0 +1,111 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "runtime" + "sync" + "testing" +) + +func TestGuardCloseAndWait(t *testing.T) { + t.Parallel() + + g := newGuard() + g.Close() + g.Wait() +} + +func TestGuardAcquireReleaseCloseAndWait(t *testing.T) { + t.Parallel() + + g := newGuard() + if !g.Acquire() { + t.Fatal("Acquire failed") + } + g.Release() + g.Close() + g.Wait() +} + +func TestGuardAcquireCloseReleaseAndWait(t *testing.T) { + t.Parallel() + + g := newGuard() + if !g.Acquire() { + t.Fatal("Acquire failed") + } + g.Close() + if g.Acquire() { + t.Fatal("Acquire succeeded after Close") + } + g.Release() + if g.Acquire() { + t.Fatal("Acquire succeeded after Release following Close") + } + g.Wait() +} + +func TestGuardConcurrentUse(t *testing.T) { + t.Parallel() + + const N = 1000 + g := newGuard() + + var wg sync.WaitGroup + wg.Add(N) + for range N { + go func() { + wg.Done() + if !g.Acquire() { + return + } + runtime.Gosched() + g.Release() + }() + } + wg.Wait() // wait for all goroutines to start + + g.Close() + g.Wait() +} + +func TestReleaseWithoutAcquire(t *testing.T) { + t.Parallel() + + g := newGuard() + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on Release without Acquire") + } + }() + g.Release() +} + +func TestReleaseWithoutAcquireAfterClose(t *testing.T) { + t.Parallel() + + g := newGuard() + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on Release without Acquire") + } + }() + g.Close() + g.Release() +} + +func TestWaitBeforeClose(t *testing.T) { + t.Parallel() + + g := newGuard() + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on Wait before Close") + } + }() + g.Wait() +}