diff --git a/net/rioconn/guard.go b/net/rioconn/guard.go new file mode 100644 index 000000000..bbf7b1eb3 --- /dev/null +++ b/net/rioconn/guard.go @@ -0,0 +1,90 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "sync" + "sync/atomic" +) + +// guard prevents new operations from starting after Close +// while allowing in-flight operations to complete. +type guard struct { + state atomic.Int64 + done chan struct{} + closeOnce sync.Once +} + +const ( + // state layout: + // bit 62: closed flag (1 == Close called) + // bits 0–61: in-flight operation count + // + // We avoid using bit 63 (the sign bit) so that valid states remain + // non-negative and the counter has a large positive range. guardCountMask + // isolates the counter bits and is used to detect unbalanced Release calls + // or counter overflow (which would wrap into the closed bit). + guardClosedBit = int64(1) << 62 + guardCountMask = guardClosedBit - 1 +) + +func newGuard() *guard { + return &guard{done: make(chan struct{})} +} + +// Acquire attempts to acquire a lease for an operation. +// +// If it reports true, the caller may proceed and must call +// [guard.Release] when done. Otherwise, it must not proceed. +// +// Acquire fails if [guard.Close] has already been called. +func (g *guard) Acquire() bool { + n := g.state.Add(1) + if n&guardClosedBit == 0 { + return true + } + g.decrementAndSignal() + return false +} + +// Release releases a lease acquired by [guard.Acquire]. +// It is a run-time error to call Release without a matching Acquire. +func (g *guard) Release() { + g.decrementAndSignal() +} + +func (g *guard) decrementAndSignal() { + n := g.state.Add(-1) + if n < 0 || n == guardCountMask { + panic("unbalanced Release call") + } + if n == guardClosedBit { + g.closeOnce.Do(func() { close(g.done) }) + } +} + +// Close prevents future Acquire calls from succeeding. +func (g *guard) Close() { + g.state.Or(guardClosedBit) +} + +// IsClosed reports whether Close has been called. +func (g *guard) IsClosed() bool { + return g.state.Load()&guardClosedBit != 0 +} + +// Wait blocks until all in-flight operations have called Release. +// It is a run-time error to call Wait before Close. +func (g *guard) Wait() { + state := g.state.Load() + if state&guardClosedBit == 0 { + panic("Wait called before Close") + } + if state == guardClosedBit { + return + } + <-g.done +} diff --git a/net/rioconn/guard_test.go b/net/rioconn/guard_test.go new file mode 100644 index 000000000..04c276278 --- /dev/null +++ b/net/rioconn/guard_test.go @@ -0,0 +1,111 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package rioconn + +import ( + "runtime" + "sync" + "testing" +) + +func TestGuardCloseAndWait(t *testing.T) { + t.Parallel() + + g := newGuard() + g.Close() + g.Wait() +} + +func TestGuardAcquireReleaseCloseAndWait(t *testing.T) { + t.Parallel() + + g := newGuard() + if !g.Acquire() { + t.Fatal("Acquire failed") + } + g.Release() + g.Close() + g.Wait() +} + +func TestGuardAcquireCloseReleaseAndWait(t *testing.T) { + t.Parallel() + + g := newGuard() + if !g.Acquire() { + t.Fatal("Acquire failed") + } + g.Close() + if g.Acquire() { + t.Fatal("Acquire succeeded after Close") + } + g.Release() + if g.Acquire() { + t.Fatal("Acquire succeeded after Release following Close") + } + g.Wait() +} + +func TestGuardConcurrentUse(t *testing.T) { + t.Parallel() + + const N = 1000 + g := newGuard() + + var wg sync.WaitGroup + wg.Add(N) + for range N { + go func() { + wg.Done() + if !g.Acquire() { + return + } + runtime.Gosched() + g.Release() + }() + } + wg.Wait() // wait for all goroutines to start + + g.Close() + g.Wait() +} + +func TestReleaseWithoutAcquire(t *testing.T) { + t.Parallel() + + g := newGuard() + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on Release without Acquire") + } + }() + g.Release() +} + +func TestReleaseWithoutAcquireAfterClose(t *testing.T) { + t.Parallel() + + g := newGuard() + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on Release without Acquire") + } + }() + g.Close() + g.Release() +} + +func TestWaitBeforeClose(t *testing.T) { + t.Parallel() + + g := newGuard() + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on Wait before Close") + } + }() + g.Wait() +}