mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-09 15:47:17 +02:00
ipn/ipn{server,test}: extract the LocalAPI test client and server into ipntest
In this PR, we extract the in-process LocalAPI client/server implementation from ipn/ipnserver/server_test.go into a new ipntest package to be used in high‑level black‑box tests, such as those for the tailscale CLI. Updates #15575 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
0f4f808e70
commit
f0a27066c4
@ -179,6 +179,12 @@ func contextWithActor(ctx context.Context, logf logger.Logf, c net.Conn) context
|
|||||||
return actorKey.WithValue(ctx, actorOrError{actor: actor, err: err})
|
return actorKey.WithValue(ctx, actorOrError{actor: actor, err: err})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewContextWithActorForTest returns a new context that carries the identity
|
||||||
|
// of the specified actor. It is used in tests only.
|
||||||
|
func NewContextWithActorForTest(ctx context.Context, actor ipnauth.Actor) context.Context {
|
||||||
|
return actorKey.WithValue(ctx, actorOrError{actor: actor})
|
||||||
|
}
|
||||||
|
|
||||||
// actorFromContext returns an [ipnauth.Actor] associated with ctx,
|
// actorFromContext returns an [ipnauth.Actor] associated with ctx,
|
||||||
// or an error if the context does not carry an actor's identity.
|
// or an error if the context does not carry an actor's identity.
|
||||||
func actorFromContext(ctx context.Context) (ipnauth.Actor, error) {
|
func actorFromContext(ctx context.Context) (ipnauth.Actor, error) {
|
||||||
|
42
ipn/ipnserver/server_fortest.go
Normal file
42
ipn/ipnserver/server_fortest.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package ipnserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"tailscale.com/ipn/ipnauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BlockWhileInUseByOtherForTest blocks while the actor can't connect to the server because
|
||||||
|
// the server is in use by a different actor. It is used in tests only.
|
||||||
|
func (s *Server) BlockWhileInUseByOtherForTest(ctx context.Context, actor ipnauth.Actor) error {
|
||||||
|
return s.blockWhileIdentityInUse(ctx, actor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockWhileInUseForTest blocks until the server becomes idle (no active requests),
|
||||||
|
// or the specified context is done. It returns the context's error if it is done.
|
||||||
|
// It is used in tests only.
|
||||||
|
func (s *Server) BlockWhileInUseForTest(ctx context.Context) error {
|
||||||
|
ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
busy := len(s.activeReqs) != 0
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if busy {
|
||||||
|
<-ready
|
||||||
|
}
|
||||||
|
cleanup()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTPForTest responds to a single LocalAPI HTTP request.
|
||||||
|
// The request's context carries the actor that made the request
|
||||||
|
// and can be created with [NewContextWithActorForTest].
|
||||||
|
// It is used in tests only.
|
||||||
|
func (s *Server) ServeHTTPForTest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.serveHTTP(w, r)
|
||||||
|
}
|
@ -1,76 +1,22 @@
|
|||||||
// Copyright (c) Tailscale Inc & AUTHORS
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
package ipnserver
|
package ipnserver_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"tailscale.com/client/local"
|
"tailscale.com/client/local"
|
||||||
"tailscale.com/client/tailscale"
|
|
||||||
"tailscale.com/client/tailscale/apitype"
|
|
||||||
"tailscale.com/control/controlclient"
|
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/ipn"
|
"tailscale.com/ipn"
|
||||||
"tailscale.com/ipn/ipnauth"
|
"tailscale.com/ipn/lapitest"
|
||||||
"tailscale.com/ipn/ipnlocal"
|
|
||||||
"tailscale.com/ipn/store/mem"
|
|
||||||
"tailscale.com/tsd"
|
|
||||||
"tailscale.com/tstest"
|
|
||||||
"tailscale.com/types/logger"
|
|
||||||
"tailscale.com/types/logid"
|
|
||||||
"tailscale.com/types/ptr"
|
"tailscale.com/types/ptr"
|
||||||
"tailscale.com/util/mak"
|
|
||||||
"tailscale.com/wgengine"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWaiterSet(t *testing.T) {
|
|
||||||
var s waiterSet
|
|
||||||
|
|
||||||
wantLen := func(want int, when string) {
|
|
||||||
t.Helper()
|
|
||||||
if got := len(s); got != want {
|
|
||||||
t.Errorf("%s: len = %v; want %v", when, got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
wantLen(0, "initial")
|
|
||||||
var mu sync.Mutex
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
ready, cleanup := s.add(&mu, ctx)
|
|
||||||
wantLen(1, "after add")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ready:
|
|
||||||
t.Fatal("should not be ready")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
s.wakeAll()
|
|
||||||
<-ready
|
|
||||||
|
|
||||||
wantLen(1, "after fire")
|
|
||||||
cleanup()
|
|
||||||
wantLen(0, "after cleanup")
|
|
||||||
|
|
||||||
// And again but on an already-expired ctx.
|
|
||||||
cancel()
|
|
||||||
ready, cleanup = s.add(&mu, ctx)
|
|
||||||
<-ready // shouldn't block
|
|
||||||
cleanup()
|
|
||||||
wantLen(0, "at end")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserConnectDisconnectNonWindows(t *testing.T) {
|
func TestUserConnectDisconnectNonWindows(t *testing.T) {
|
||||||
enableLogging := false
|
enableLogging := false
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
@ -78,20 +24,20 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
server := startDefaultTestIPNServer(t, ctx, enableLogging)
|
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
|
||||||
|
|
||||||
// UserA connects and starts watching the IPN bus.
|
// UserA connects and starts watching the IPN bus.
|
||||||
clientA := server.getClientAs("UserA")
|
clientA := server.ClientWithName("UserA")
|
||||||
watcherA, _ := clientA.WatchIPNBus(ctx, 0)
|
watcherA, _ := clientA.WatchIPNBus(ctx, 0)
|
||||||
|
|
||||||
// The concept of "current user" is only relevant on Windows
|
// The concept of "current user" is only relevant on Windows
|
||||||
// and it should not be set on non-Windows platforms.
|
// and it should not be set on non-Windows platforms.
|
||||||
server.checkCurrentUser(nil)
|
server.CheckCurrentUser(nil)
|
||||||
|
|
||||||
// Additionally, a different user should be able to connect and use the LocalAPI.
|
// Additionally, a different user should be able to connect and use the LocalAPI.
|
||||||
clientB := server.getClientAs("UserB")
|
clientB := server.ClientWithName("UserB")
|
||||||
if _, gotErr := clientB.Status(ctx); gotErr != nil {
|
if _, gotErr := clientB.Status(ctx); gotErr != nil {
|
||||||
t.Fatalf("Status(%q): want nil; got %v", clientB.User.Name, gotErr)
|
t.Fatalf("Status(%q): want nil; got %v", clientB.Username(), gotErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Watching the IPN bus should also work for UserB.
|
// Watching the IPN bus should also work for UserB.
|
||||||
@ -100,18 +46,18 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) {
|
|||||||
// And if we send a notification, both users should receive it.
|
// And if we send a notification, both users should receive it.
|
||||||
wantErrMessage := "test error"
|
wantErrMessage := "test error"
|
||||||
testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)}
|
testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)}
|
||||||
server.mustBackend().DebugNotify(testNotify)
|
server.Backend().DebugNotify(testNotify)
|
||||||
|
|
||||||
if n, err := watcherA.Next(); err != nil {
|
if n, err := watcherA.Next(); err != nil {
|
||||||
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.User.Name, err)
|
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.Username(), err)
|
||||||
} else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage {
|
} else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage {
|
||||||
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.User.Name, wantErrMessage, gotErrMessage)
|
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.Username(), wantErrMessage, gotErrMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
if n, err := watcherB.Next(); err != nil {
|
if n, err := watcherB.Next(); err != nil {
|
||||||
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.User.Name, err)
|
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.Username(), err)
|
||||||
} else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage {
|
} else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage {
|
||||||
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.User.Name, wantErrMessage, gotErrMessage)
|
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.Username(), wantErrMessage, gotErrMessage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,21 +66,21 @@ func TestUserConnectDisconnectOnWindows(t *testing.T) {
|
|||||||
setGOOSForTest(t, "windows")
|
setGOOSForTest(t, "windows")
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
server := startDefaultTestIPNServer(t, ctx, enableLogging)
|
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
|
||||||
|
|
||||||
client := server.getClientAs("User")
|
client := server.ClientWithName("User")
|
||||||
_, cancelWatcher := client.WatchIPNBus(ctx, 0)
|
_, cancelWatcher := client.WatchIPNBus(ctx, 0)
|
||||||
|
|
||||||
// On Windows, however, the current user should be set to the user that connected.
|
// On Windows, however, the current user should be set to the user that connected.
|
||||||
server.checkCurrentUser(client.User)
|
server.CheckCurrentUser(client.Actor)
|
||||||
|
|
||||||
// Cancel the IPN bus watcher request and wait for the server to unblock.
|
// Cancel the IPN bus watcher request and wait for the server to unblock.
|
||||||
cancelWatcher()
|
cancelWatcher()
|
||||||
server.blockWhileInUse(ctx)
|
server.BlockWhileInUse(ctx)
|
||||||
|
|
||||||
// The current user should not be set after a disconnect, as no one is
|
// The current user should not be set after a disconnect, as no one is
|
||||||
// currently using the server.
|
// currently using the server.
|
||||||
server.checkCurrentUser(nil)
|
server.CheckCurrentUser(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIPNAlreadyInUseOnWindows(t *testing.T) {
|
func TestIPNAlreadyInUseOnWindows(t *testing.T) {
|
||||||
@ -142,22 +88,22 @@ func TestIPNAlreadyInUseOnWindows(t *testing.T) {
|
|||||||
setGOOSForTest(t, "windows")
|
setGOOSForTest(t, "windows")
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
server := startDefaultTestIPNServer(t, ctx, enableLogging)
|
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
|
||||||
|
|
||||||
// UserA connects and starts watching the IPN bus.
|
// UserA connects and starts watching the IPN bus.
|
||||||
clientA := server.getClientAs("UserA")
|
clientA := server.ClientWithName("UserA")
|
||||||
clientA.WatchIPNBus(ctx, 0)
|
clientA.WatchIPNBus(ctx, 0)
|
||||||
|
|
||||||
// While UserA is connected, UserB should not be able to connect.
|
// While UserA is connected, UserB should not be able to connect.
|
||||||
clientB := server.getClientAs("UserB")
|
clientB := server.ClientWithName("UserB")
|
||||||
if _, gotErr := clientB.Status(ctx); gotErr == nil {
|
if _, gotErr := clientB.Status(ctx); gotErr == nil {
|
||||||
t.Fatalf("Status(%q): want error; got nil", clientB.User.Name)
|
t.Fatalf("Status(%q): want error; got nil", clientB.Username())
|
||||||
} else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError {
|
} else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError {
|
||||||
t.Fatalf("Status(%q): want %q; got %q", clientB.User.Name, wantError, gotErr.Error())
|
t.Fatalf("Status(%q): want %q; got %q", clientB.Username(), wantError, gotErr.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Current user should still be UserA.
|
// Current user should still be UserA.
|
||||||
server.checkCurrentUser(clientA.User)
|
server.CheckCurrentUser(clientA.Actor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSequentialOSUserSwitchingOnWindows(t *testing.T) {
|
func TestSequentialOSUserSwitchingOnWindows(t *testing.T) {
|
||||||
@ -165,22 +111,22 @@ func TestSequentialOSUserSwitchingOnWindows(t *testing.T) {
|
|||||||
setGOOSForTest(t, "windows")
|
setGOOSForTest(t, "windows")
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
server := startDefaultTestIPNServer(t, ctx, enableLogging)
|
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
|
||||||
|
|
||||||
connectDisconnectAsUser := func(name string) {
|
connectDisconnectAsUser := func(name string) {
|
||||||
// User connects and starts watching the IPN bus.
|
// User connects and starts watching the IPN bus.
|
||||||
client := server.getClientAs(name)
|
client := server.ClientWithName(name)
|
||||||
watcher, cancelWatcher := client.WatchIPNBus(ctx, 0)
|
watcher, cancelWatcher := client.WatchIPNBus(ctx, 0)
|
||||||
defer cancelWatcher()
|
defer cancelWatcher()
|
||||||
go pumpIPNBus(watcher)
|
go pumpIPNBus(watcher)
|
||||||
|
|
||||||
// It should be the current user from the LocalBackend's perspective...
|
// It should be the current user from the LocalBackend's perspective...
|
||||||
server.checkCurrentUser(client.User)
|
server.CheckCurrentUser(client.Actor)
|
||||||
// until it disconnects.
|
// until it disconnects.
|
||||||
cancelWatcher()
|
cancelWatcher()
|
||||||
server.blockWhileInUse(ctx)
|
server.BlockWhileInUse(ctx)
|
||||||
// Now, the current user should be unset.
|
// Now, the current user should be unset.
|
||||||
server.checkCurrentUser(nil)
|
server.CheckCurrentUser(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserA logs in, uses Tailscale for a bit, then logs out.
|
// UserA logs in, uses Tailscale for a bit, then logs out.
|
||||||
@ -194,11 +140,11 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
|
|||||||
setGOOSForTest(t, "windows")
|
setGOOSForTest(t, "windows")
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
server := startDefaultTestIPNServer(t, ctx, enableLogging)
|
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
|
||||||
|
|
||||||
connectDisconnectAsUser := func(name string) {
|
connectDisconnectAsUser := func(name string) {
|
||||||
// User connects and starts watching the IPN bus.
|
// User connects and starts watching the IPN bus.
|
||||||
client := server.getClientAs(name)
|
client := server.ClientWithName(name)
|
||||||
watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState)
|
watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState)
|
||||||
defer cancelWatcher()
|
defer cancelWatcher()
|
||||||
|
|
||||||
@ -206,7 +152,7 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
|
|||||||
|
|
||||||
// Get the current user from the LocalBackend's perspective
|
// Get the current user from the LocalBackend's perspective
|
||||||
// as soon as we're connected.
|
// as soon as we're connected.
|
||||||
gotUID, gotActor := server.mustBackend().CurrentUserForTest()
|
gotUID, gotActor := server.Backend().CurrentUserForTest()
|
||||||
|
|
||||||
// Wait for the first notification to arrive.
|
// Wait for the first notification to arrive.
|
||||||
// It will either be the initial state we've requested via [ipn.NotifyInitialState],
|
// It will either be the initial state we've requested via [ipn.NotifyInitialState],
|
||||||
@ -225,17 +171,17 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, our user should have been the current user since the time we connected.
|
// Otherwise, our user should have been the current user since the time we connected.
|
||||||
if gotUID != client.User.UID {
|
if gotUID != client.Actor.UserID() {
|
||||||
t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.User.UID)
|
t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.Actor.UserID())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if gotActor, ok := gotActor.(*ipnauth.TestActor); !ok || *gotActor != *client.User {
|
if hasActor := gotActor != nil; !hasActor || gotActor != client.Actor {
|
||||||
t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.User)
|
t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.Actor)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// And should still be the current user (as they're still connected)...
|
// And should still be the current user (as they're still connected)...
|
||||||
server.checkCurrentUser(client.User)
|
server.CheckCurrentUser(client.Actor)
|
||||||
}
|
}
|
||||||
|
|
||||||
numIterations := 10
|
numIterations := 10
|
||||||
@ -253,11 +199,11 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
|
|||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if err := server.blockWhileInUse(ctx); err != nil {
|
if err := server.BlockWhileInUse(ctx); err != nil {
|
||||||
t.Fatalf("blockWhileInUse: %v", err)
|
t.Fatalf("BlockUntilIdle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server.checkCurrentUser(nil)
|
server.CheckCurrentUser(nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -266,13 +212,13 @@ func TestBlockWhileIdentityInUse(t *testing.T) {
|
|||||||
setGOOSForTest(t, "windows")
|
setGOOSForTest(t, "windows")
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
server := startDefaultTestIPNServer(t, ctx, enableLogging)
|
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
|
||||||
|
|
||||||
// connectWaitDisconnectAsUser connects as a user with the specified name
|
// connectWaitDisconnectAsUser connects as a user with the specified name
|
||||||
// and keeps the IPN bus watcher alive until the context is canceled.
|
// and keeps the IPN bus watcher alive until the context is canceled.
|
||||||
// It returns a channel that is closed when done.
|
// It returns a channel that is closed when done.
|
||||||
connectWaitDisconnectAsUser := func(ctx context.Context, name string) <-chan struct{} {
|
connectWaitDisconnectAsUser := func(ctx context.Context, name string) <-chan struct{} {
|
||||||
client := server.getClientAs(name)
|
client := server.ClientWithName(name)
|
||||||
watcher, cancelWatcher := client.WatchIPNBus(ctx, 0)
|
watcher, cancelWatcher := client.WatchIPNBus(ctx, 0)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@ -301,8 +247,8 @@ func TestBlockWhileIdentityInUse(t *testing.T) {
|
|||||||
// in blockWhileIdentityInUse. But the issue also occurs during
|
// in blockWhileIdentityInUse. But the issue also occurs during
|
||||||
// the normal execution path when UserB connects to the IPN server
|
// the normal execution path when UserB connects to the IPN server
|
||||||
// while UserA is disconnecting.
|
// while UserA is disconnecting.
|
||||||
userB := server.makeTestUser("UserB", "ClientB")
|
userB := server.MakeTestActor("UserB", "ClientB")
|
||||||
server.blockWhileIdentityInUse(ctx, userB)
|
server.BlockWhileInUseByOther(ctx, userB)
|
||||||
<-userADone
|
<-userADone
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,41 +259,7 @@ func setGOOSForTest(tb testing.TB, goos string) {
|
|||||||
tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") })
|
tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") })
|
||||||
}
|
}
|
||||||
|
|
||||||
func testLogger(tb testing.TB, enableLogging bool) logger.Logf {
|
func pumpIPNBus(watcher *local.IPNBusWatcher) {
|
||||||
tb.Helper()
|
|
||||||
if enableLogging {
|
|
||||||
return tstest.WhileTestRunningLogger(tb)
|
|
||||||
}
|
|
||||||
return logger.Discard
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTestIPNServer creates a new IPN server for testing, using the specified local backend.
|
|
||||||
func newTestIPNServer(tb testing.TB, lb *ipnlocal.LocalBackend, enableLogging bool) *Server {
|
|
||||||
tb.Helper()
|
|
||||||
server := New(testLogger(tb, enableLogging), logid.PublicID{}, lb.NetMon())
|
|
||||||
server.lb.Store(lb)
|
|
||||||
return server
|
|
||||||
}
|
|
||||||
|
|
||||||
type testIPNClient struct {
|
|
||||||
tb testing.TB
|
|
||||||
*local.Client
|
|
||||||
User *ipnauth.TestActor
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *testIPNClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, context.CancelFunc) {
|
|
||||||
c.tb.Helper()
|
|
||||||
ctx, cancelWatcher := context.WithCancel(ctx)
|
|
||||||
c.tb.Cleanup(cancelWatcher)
|
|
||||||
watcher, err := c.Client.WatchIPNBus(ctx, mask)
|
|
||||||
if err != nil {
|
|
||||||
c.tb.Fatalf("WatchIPNBus(%q): %v", c.User.Name, err)
|
|
||||||
}
|
|
||||||
c.tb.Cleanup(func() { watcher.Close() })
|
|
||||||
return watcher, cancelWatcher
|
|
||||||
}
|
|
||||||
|
|
||||||
func pumpIPNBus(watcher *tailscale.IPNBusWatcher) {
|
|
||||||
for {
|
for {
|
||||||
_, err := watcher.Next()
|
_, err := watcher.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -355,206 +267,3 @@ func pumpIPNBus(watcher *tailscale.IPNBusWatcher) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type testIPNServer struct {
|
|
||||||
tb testing.TB
|
|
||||||
*Server
|
|
||||||
clientID atomic.Int64
|
|
||||||
getClient func(*ipnauth.TestActor) *local.Client
|
|
||||||
|
|
||||||
actorsMu sync.Mutex
|
|
||||||
actors map[string]*ipnauth.TestActor
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testIPNServer) getClientAs(name string) *testIPNClient {
|
|
||||||
clientID := fmt.Sprintf("Client-%d", 1+s.clientID.Add(1))
|
|
||||||
user := s.makeTestUser(name, clientID)
|
|
||||||
return &testIPNClient{
|
|
||||||
tb: s.tb,
|
|
||||||
Client: s.getClient(user),
|
|
||||||
User: user,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testIPNServer) makeTestUser(name string, clientID string) *ipnauth.TestActor {
|
|
||||||
s.actorsMu.Lock()
|
|
||||||
defer s.actorsMu.Unlock()
|
|
||||||
actor := s.actors[name]
|
|
||||||
if actor == nil {
|
|
||||||
actor = &ipnauth.TestActor{Name: name}
|
|
||||||
if envknob.GOOS() == "windows" {
|
|
||||||
// Historically, as of 2025-01-13, IPN does not distinguish between
|
|
||||||
// different users on non-Windows devices. Therefore, the UID, which is
|
|
||||||
// an [ipn.WindowsUserID], should only be populated when the actual or
|
|
||||||
// fake GOOS is Windows.
|
|
||||||
actor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actors)))
|
|
||||||
}
|
|
||||||
mak.Set(&s.actors, name, actor)
|
|
||||||
s.tb.Cleanup(func() { delete(s.actors, name) })
|
|
||||||
}
|
|
||||||
actor = ptr.To(*actor)
|
|
||||||
actor.CID = ipnauth.ClientIDFrom(clientID)
|
|
||||||
return actor
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testIPNServer) blockWhileInUse(ctx context.Context) error {
|
|
||||||
ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx)
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
busy := len(s.activeReqs) != 0
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if busy {
|
|
||||||
<-ready
|
|
||||||
}
|
|
||||||
cleanup()
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testIPNServer) checkCurrentUser(want *ipnauth.TestActor) {
|
|
||||||
s.tb.Helper()
|
|
||||||
var wantUID ipn.WindowsUserID
|
|
||||||
if want != nil {
|
|
||||||
wantUID = want.UID
|
|
||||||
}
|
|
||||||
gotUID, gotActor := s.mustBackend().CurrentUserForTest()
|
|
||||||
if gotUID != wantUID {
|
|
||||||
s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID)
|
|
||||||
}
|
|
||||||
if gotActor, ok := gotActor.(*ipnauth.TestActor); ok != (want != nil) || (want != nil && *gotActor != *want) {
|
|
||||||
s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// startTestIPNServer starts a [httptest.Server] that hosts the specified IPN server for the
|
|
||||||
// duration of the test, using the specified base context for incoming requests.
|
|
||||||
// It returns a function that creates a [local.Client] as a given [ipnauth.TestActor].
|
|
||||||
func startTestIPNServer(tb testing.TB, baseContext context.Context, server *Server) *testIPNServer {
|
|
||||||
tb.Helper()
|
|
||||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
actor, err := extractActorFromHeader(r.Header)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
tb.Errorf("extractActorFromHeader: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx := newTestContextWithActor(r.Context(), actor)
|
|
||||||
server.serveHTTP(w, r.Clone(ctx))
|
|
||||||
}))
|
|
||||||
ts.Config.Addr = "http://" + apitype.LocalAPIHost
|
|
||||||
ts.Config.BaseContext = func(_ net.Listener) context.Context { return baseContext }
|
|
||||||
ts.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(server.logf, "ipnserver: "))
|
|
||||||
ts.Start()
|
|
||||||
tb.Cleanup(ts.Close)
|
|
||||||
return &testIPNServer{
|
|
||||||
tb: tb,
|
|
||||||
Server: server,
|
|
||||||
getClient: func(actor *ipnauth.TestActor) *local.Client {
|
|
||||||
return &local.Client{Transport: newTestRoundTripper(ts, actor)}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startDefaultTestIPNServer(tb testing.TB, ctx context.Context, enableLogging bool) *testIPNServer {
|
|
||||||
tb.Helper()
|
|
||||||
lb := newLocalBackendWithTestControl(tb, newUnreachableControlClient, enableLogging)
|
|
||||||
ctx, stopServer := context.WithCancel(ctx)
|
|
||||||
tb.Cleanup(stopServer)
|
|
||||||
return startTestIPNServer(tb, ctx, newTestIPNServer(tb, lb, enableLogging))
|
|
||||||
}
|
|
||||||
|
|
||||||
type testRoundTripper struct {
|
|
||||||
transport http.RoundTripper
|
|
||||||
actor *ipnauth.TestActor
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTestRoundTripper creates a new [http.RoundTripper] that sends requests
|
|
||||||
// to the specified test server as the specified actor.
|
|
||||||
func newTestRoundTripper(ts *httptest.Server, actor *ipnauth.TestActor) *testRoundTripper {
|
|
||||||
return &testRoundTripper{
|
|
||||||
transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
var std net.Dialer
|
|
||||||
return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String())
|
|
||||||
}},
|
|
||||||
actor: actor,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const testActorHeaderName = "TS-Test-Actor"
|
|
||||||
|
|
||||||
// RoundTrip implements [http.RoundTripper] by forwarding the request to the underlying transport
|
|
||||||
// and including the test actor's identity in the request headers.
|
|
||||||
func (rt *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
|
||||||
actorJSON, err := json.Marshal(&rt.actor)
|
|
||||||
if err != nil {
|
|
||||||
// An [http.RoundTripper] must always close the request body, including on error.
|
|
||||||
if r.Body != nil {
|
|
||||||
r.Body.Close()
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
r = r.Clone(r.Context())
|
|
||||||
r.Header.Set(testActorHeaderName, string(actorJSON))
|
|
||||||
return rt.transport.RoundTrip(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractActorFromHeader extracts a test actor from the specified request headers.
|
|
||||||
func extractActorFromHeader(h http.Header) (*ipnauth.TestActor, error) {
|
|
||||||
actorJSON := h.Get(testActorHeaderName)
|
|
||||||
if actorJSON == "" {
|
|
||||||
return nil, errors.New("missing Test-Actor header")
|
|
||||||
}
|
|
||||||
actor := &ipnauth.TestActor{}
|
|
||||||
if err := json.Unmarshal([]byte(actorJSON), &actor); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid Test-Actor header: %v", err)
|
|
||||||
}
|
|
||||||
return actor, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type newControlClientFn func(tb testing.TB, opts controlclient.Options) controlclient.Client
|
|
||||||
|
|
||||||
func newLocalBackendWithTestControl(tb testing.TB, newControl newControlClientFn, enableLogging bool) *ipnlocal.LocalBackend {
|
|
||||||
tb.Helper()
|
|
||||||
|
|
||||||
sys := tsd.NewSystem()
|
|
||||||
store := &mem.Store{}
|
|
||||||
sys.Set(store)
|
|
||||||
|
|
||||||
logf := testLogger(tb, enableLogging)
|
|
||||||
e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get())
|
|
||||||
if err != nil {
|
|
||||||
tb.Fatalf("NewFakeUserspaceEngine: %v", err)
|
|
||||||
}
|
|
||||||
tb.Cleanup(e.Close)
|
|
||||||
sys.Set(e)
|
|
||||||
|
|
||||||
b, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
|
|
||||||
if err != nil {
|
|
||||||
tb.Fatalf("NewLocalBackend: %v", err)
|
|
||||||
}
|
|
||||||
tb.Cleanup(b.Shutdown)
|
|
||||||
b.DisablePortMapperForTest()
|
|
||||||
|
|
||||||
b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) {
|
|
||||||
return newControl(tb, opts), nil
|
|
||||||
})
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUnreachableControlClient(tb testing.TB, opts controlclient.Options) controlclient.Client {
|
|
||||||
tb.Helper()
|
|
||||||
opts.ServerURL = "https://127.0.0.1:1"
|
|
||||||
cc, err := controlclient.New(opts)
|
|
||||||
if err != nil {
|
|
||||||
tb.Fatal(err)
|
|
||||||
}
|
|
||||||
return cc
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTestContextWithActor returns a new context that carries the identity
|
|
||||||
// of the specified actor and can be used for testing.
|
|
||||||
// It can be retrieved with [actorFromContext].
|
|
||||||
func newTestContextWithActor(ctx context.Context, actor ipnauth.Actor) context.Context {
|
|
||||||
return actorKey.WithValue(ctx, actorOrError{actor: actor})
|
|
||||||
}
|
|
||||||
|
46
ipn/ipnserver/waiterset_test.go
Normal file
46
ipn/ipnserver/waiterset_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package ipnserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWaiterSet(t *testing.T) {
|
||||||
|
var s waiterSet
|
||||||
|
|
||||||
|
wantLen := func(want int, when string) {
|
||||||
|
t.Helper()
|
||||||
|
if got := len(s); got != want {
|
||||||
|
t.Errorf("%s: len = %v; want %v", when, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wantLen(0, "initial")
|
||||||
|
var mu sync.Mutex
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ready, cleanup := s.add(&mu, ctx)
|
||||||
|
wantLen(1, "after add")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ready:
|
||||||
|
t.Fatal("should not be ready")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
s.wakeAll()
|
||||||
|
<-ready
|
||||||
|
|
||||||
|
wantLen(1, "after fire")
|
||||||
|
cleanup()
|
||||||
|
wantLen(0, "after cleanup")
|
||||||
|
|
||||||
|
// And again but on an already-expired ctx.
|
||||||
|
cancel()
|
||||||
|
ready, cleanup = s.add(&mu, ctx)
|
||||||
|
<-ready // shouldn't block
|
||||||
|
cleanup()
|
||||||
|
wantLen(0, "at end")
|
||||||
|
}
|
63
ipn/lapitest/backend.go
Normal file
63
ipn/lapitest/backend.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lapitest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/control/controlclient"
|
||||||
|
"tailscale.com/ipn/ipnlocal"
|
||||||
|
"tailscale.com/ipn/store/mem"
|
||||||
|
"tailscale.com/types/logid"
|
||||||
|
"tailscale.com/wgengine"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewBackend returns a new [ipnlocal.LocalBackend] for testing purposes.
|
||||||
|
// It fails the test if the specified options are invalid or if the backend cannot be created.
|
||||||
|
func NewBackend(tb testing.TB, opts ...Option) *ipnlocal.LocalBackend {
|
||||||
|
tb.Helper()
|
||||||
|
options, err := newOptions(tb, opts...)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("NewBackend: %v", err)
|
||||||
|
}
|
||||||
|
return newBackend(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBackend(opts *options) *ipnlocal.LocalBackend {
|
||||||
|
tb := opts.TB()
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
sys := opts.Sys()
|
||||||
|
if _, ok := sys.StateStore.GetOK(); !ok {
|
||||||
|
sys.Set(&mem.Store{})
|
||||||
|
}
|
||||||
|
|
||||||
|
e, err := wgengine.NewFakeUserspaceEngine(opts.Logf(), sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get())
|
||||||
|
if err != nil {
|
||||||
|
opts.tb.Fatalf("NewFakeUserspaceEngine: %v", err)
|
||||||
|
}
|
||||||
|
tb.Cleanup(e.Close)
|
||||||
|
sys.Set(e)
|
||||||
|
|
||||||
|
b, err := ipnlocal.NewLocalBackend(opts.Logf(), logid.PublicID{}, sys, 0)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("NewLocalBackend: %v", err)
|
||||||
|
}
|
||||||
|
tb.Cleanup(b.Shutdown)
|
||||||
|
b.DisablePortMapperForTest()
|
||||||
|
b.SetControlClientGetterForTesting(opts.MakeControlClient)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnreachableControlClient is a [NewControlFn] that creates
|
||||||
|
// a new [controlclient.Client] for an unreachable control server.
|
||||||
|
func NewUnreachableControlClient(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) {
|
||||||
|
tb.Helper()
|
||||||
|
opts.ServerURL = "https://127.0.0.1:1"
|
||||||
|
cc, err := controlclient.New(opts)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
return cc, nil
|
||||||
|
}
|
71
ipn/lapitest/client.go
Normal file
71
ipn/lapitest/client.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lapitest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/client/local"
|
||||||
|
"tailscale.com/ipn"
|
||||||
|
"tailscale.com/ipn/ipnauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client wraps a [local.Client] for testing purposes.
|
||||||
|
// It can be created using [Server.Client], [Server.ClientWithName],
|
||||||
|
// or [Server.ClientFor] and sends requests as the specified actor
|
||||||
|
// to the associated [Server].
|
||||||
|
type Client struct {
|
||||||
|
tb testing.TB
|
||||||
|
// Client is the underlying [local.Client] wrapped by the test client.
|
||||||
|
// It is configured to send requests to the test server on behalf of the actor.
|
||||||
|
*local.Client
|
||||||
|
// Actor represents the user on whose behalf this client is making requests.
|
||||||
|
// The server uses it to determine the client's identity and permissions.
|
||||||
|
// The test can mutate the user to alter the actor's identity or permissions
|
||||||
|
// before making a new request. It is typically an [ipnauth.TestActor],
|
||||||
|
// unless the [Client] was created with s specific actor using [Server.ClientFor].
|
||||||
|
Actor ipnauth.Actor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Username returns username of the client's owner.
|
||||||
|
func (c *Client) Username() string {
|
||||||
|
c.tb.Helper()
|
||||||
|
name, err := c.Actor.Username()
|
||||||
|
if err != nil {
|
||||||
|
c.tb.Fatalf("Client.Username: %v", err)
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchIPNBus is like [local.Client.WatchIPNBus] but returns a [local.IPNBusWatcher]
|
||||||
|
// that is closed when the test ends and a cancel function that stops the watcher.
|
||||||
|
// It fails the test if the underlying WatchIPNBus returns an error.
|
||||||
|
func (c *Client) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*local.IPNBusWatcher, context.CancelFunc) {
|
||||||
|
c.tb.Helper()
|
||||||
|
ctx, cancelWatcher := context.WithCancel(ctx)
|
||||||
|
c.tb.Cleanup(cancelWatcher)
|
||||||
|
watcher, err := c.Client.WatchIPNBus(ctx, mask)
|
||||||
|
name, _ := c.Actor.Username()
|
||||||
|
if err != nil {
|
||||||
|
c.tb.Fatalf("Client.WatchIPNBus(%q): %v", name, err)
|
||||||
|
}
|
||||||
|
c.tb.Cleanup(func() { watcher.Close() })
|
||||||
|
return watcher, cancelWatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSequentialName generates a unique sequential name based on the given prefix and number n.
|
||||||
|
// It uses a base-26 encoding to create names like "User-A", "User-B", ..., "User-Z", "User-AA", etc.
|
||||||
|
func generateSequentialName(prefix string, n int) string {
|
||||||
|
n++
|
||||||
|
name := ""
|
||||||
|
const numLetters = 'Z' - 'A' + 1
|
||||||
|
for n > 0 {
|
||||||
|
n--
|
||||||
|
remainder := byte(n % numLetters)
|
||||||
|
name = string([]byte{'A' + remainder}) + name
|
||||||
|
n = n / numLetters
|
||||||
|
}
|
||||||
|
return prefix + "-" + name
|
||||||
|
}
|
80
ipn/lapitest/example_test.go
Normal file
80
ipn/lapitest/example_test.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lapitest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/ipn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientServer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a server and two clients.
|
||||||
|
// Both clients represent the same user to make this work across platforms.
|
||||||
|
// On Windows we've been restricting the API usage to a single user at a time.
|
||||||
|
// While we're planning on changing this once a better permission model is in place,
|
||||||
|
// this test is currently limited to a single user (but more than one client is fine).
|
||||||
|
// Alternatively, we could override GOOS via envknobs to test as if we're
|
||||||
|
// on a different platform, but that would make the test depend on global state, etc.
|
||||||
|
s := NewServer(t, WithLogging(false))
|
||||||
|
c1 := s.ClientWithName("User-A")
|
||||||
|
c2 := s.ClientWithName("User-A")
|
||||||
|
|
||||||
|
// Start watching the IPN bus as the second client.
|
||||||
|
w2, _ := c2.WatchIPNBus(context.Background(), ipn.NotifyInitialPrefs)
|
||||||
|
|
||||||
|
// We're supposed to get a notification about the initial prefs,
|
||||||
|
// and WantRunning should be false.
|
||||||
|
n, err := w2.Next()
|
||||||
|
for ; err == nil; n, err = w2.Next() {
|
||||||
|
if n.Prefs == nil {
|
||||||
|
// Ignore non-prefs notifications.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n.Prefs.WantRunning() {
|
||||||
|
t.Errorf("WantRunning(initial): got %v, want false", n.Prefs.WantRunning())
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IPNBusWatcher.Next failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now send an EditPrefs request from the first client to set WantRunning to true.
|
||||||
|
change := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true}
|
||||||
|
gotPrefs, err := c1.EditPrefs(ctx, change)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EditPrefs failed: %v", err)
|
||||||
|
}
|
||||||
|
if !gotPrefs.WantRunning {
|
||||||
|
t.Fatalf("EditPrefs.WantRunning: got %v, want true", gotPrefs.WantRunning)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can check the backend directly to see if the prefs were set correctly.
|
||||||
|
if gotWantRunning := s.Backend().Prefs().WantRunning(); !gotWantRunning {
|
||||||
|
t.Fatalf("Backend.Prefs.WantRunning: got %v, want true", gotWantRunning)
|
||||||
|
}
|
||||||
|
|
||||||
|
// And can also wait for the second client with an IPN bus watcher to receive the notification
|
||||||
|
// about the prefs change.
|
||||||
|
n, err = w2.Next()
|
||||||
|
for ; err == nil; n, err = w2.Next() {
|
||||||
|
if n.Prefs == nil {
|
||||||
|
// Ignore non-prefs notifications.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !n.Prefs.WantRunning() {
|
||||||
|
t.Fatalf("WantRunning(changed): got %v, want true", n.Prefs.WantRunning())
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IPNBusWatcher.Next failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
170
ipn/lapitest/opts.go
Normal file
170
ipn/lapitest/opts.go
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lapitest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/control/controlclient"
|
||||||
|
"tailscale.com/ipn/ipnlocal"
|
||||||
|
"tailscale.com/tsd"
|
||||||
|
"tailscale.com/tstest"
|
||||||
|
"tailscale.com/types/lazy"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Option is any optional configuration that can be passed to [NewServer] or [NewBackend].
|
||||||
|
type Option interface {
|
||||||
|
apply(*options) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// options is the merged result of all applied [Option]s.
|
||||||
|
type options struct {
|
||||||
|
tb testing.TB
|
||||||
|
ctx lazy.SyncValue[context.Context]
|
||||||
|
logf lazy.SyncValue[logger.Logf]
|
||||||
|
sys lazy.SyncValue[*tsd.System]
|
||||||
|
newCC lazy.SyncValue[NewControlFn]
|
||||||
|
backend lazy.SyncValue[*ipnlocal.LocalBackend]
|
||||||
|
}
|
||||||
|
|
||||||
|
// newOptions returns a new [options] struct with the specified [Option]s applied.
|
||||||
|
func newOptions(tb testing.TB, opts ...Option) (*options, error) {
|
||||||
|
options := &options{tb: tb}
|
||||||
|
for _, opt := range opts {
|
||||||
|
if err := opt.apply(options); err != nil {
|
||||||
|
return nil, fmt.Errorf("lapitest: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return options, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TB returns the owning [*testing.T] or [*testing.B].
|
||||||
|
func (o *options) TB() testing.TB {
|
||||||
|
return o.tb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context returns the base context to be used by the server.
|
||||||
|
func (o *options) Context() context.Context {
|
||||||
|
return o.ctx.Get(context.Background)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logf returns the [logger.Logf] to be used for logging.
|
||||||
|
func (o *options) Logf() logger.Logf {
|
||||||
|
return o.logf.Get(func() logger.Logf { return logger.Discard })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sys returns the [tsd.System] that contains subsystems to be used
|
||||||
|
// when creating a new [ipnlocal.LocalBackend].
|
||||||
|
func (o *options) Sys() *tsd.System {
|
||||||
|
return o.sys.Get(func() *tsd.System { return tsd.NewSystem() })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend returns the [ipnlocal.LocalBackend] to be used by the server.
|
||||||
|
// If a backend is provided via [WithBackend], it is used as-is.
|
||||||
|
// Otherwise, a new backend is created with the the [options] in o.
|
||||||
|
func (o *options) Backend() *ipnlocal.LocalBackend {
|
||||||
|
return o.backend.Get(func() *ipnlocal.LocalBackend { return newBackend(o) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeControlClient returns a new [controlclient.Client] to be used by newly
|
||||||
|
// created [ipnlocal.LocalBackend]s. It is only used if no backend is provided
|
||||||
|
// via [WithBackend].
|
||||||
|
func (o *options) MakeControlClient(opts controlclient.Options) (controlclient.Client, error) {
|
||||||
|
newCC := o.newCC.Get(func() NewControlFn { return NewUnreachableControlClient })
|
||||||
|
return newCC(o.tb, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type loggingOption struct{ enableLogging bool }
|
||||||
|
|
||||||
|
// WithLogging returns an [Option] that enables or disables logging.
|
||||||
|
func WithLogging(enableLogging bool) Option {
|
||||||
|
return loggingOption{enableLogging: enableLogging}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o loggingOption) apply(opts *options) error {
|
||||||
|
var logf logger.Logf
|
||||||
|
if o.enableLogging {
|
||||||
|
logf = tstest.WhileTestRunningLogger(opts.tb)
|
||||||
|
} else {
|
||||||
|
logf = logger.Discard
|
||||||
|
}
|
||||||
|
if !opts.logf.Set(logf) {
|
||||||
|
return errors.New("logging already configured")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextOption struct{ ctx context.Context }
|
||||||
|
|
||||||
|
// WithContext returns an [Option] that sets the base context to be used by the [Server].
|
||||||
|
func WithContext(ctx context.Context) Option {
|
||||||
|
return contextOption{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o contextOption) apply(opts *options) error {
|
||||||
|
if !opts.ctx.Set(o.ctx) {
|
||||||
|
return errors.New("context already configured")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sysOption struct{ sys *tsd.System }
|
||||||
|
|
||||||
|
// WithSys returns an [Option] that sets the [tsd.System] to be used
|
||||||
|
// when creating a new [ipnlocal.LocalBackend].
|
||||||
|
func WithSys(sys *tsd.System) Option {
|
||||||
|
return sysOption{sys: sys}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o sysOption) apply(opts *options) error {
|
||||||
|
if !opts.sys.Set(o.sys) {
|
||||||
|
return errors.New("tsd.System already configured")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type backendOption struct{ backend *ipnlocal.LocalBackend }
|
||||||
|
|
||||||
|
// WithBackend returns an [Option] that configures the server to use the specified
|
||||||
|
// [ipnlocal.LocalBackend] instead of creating a new one.
|
||||||
|
// It is mutually exclusive with [WithControlClient].
|
||||||
|
func WithBackend(backend *ipnlocal.LocalBackend) Option {
|
||||||
|
return backendOption{backend: backend}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o backendOption) apply(opts *options) error {
|
||||||
|
if _, ok := opts.backend.Peek(); ok {
|
||||||
|
return errors.New("backend cannot be set when control client is already set")
|
||||||
|
}
|
||||||
|
if !opts.backend.Set(o.backend) {
|
||||||
|
return errors.New("backend already set")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewControlFn is any function that creates a new [controlclient.Client]
|
||||||
|
// with the specified options.
|
||||||
|
type NewControlFn func(tb testing.TB, opts controlclient.Options) (controlclient.Client, error)
|
||||||
|
|
||||||
|
// WithControlClient returns an option that specifies a function to be used
|
||||||
|
// by the [ipnlocal.LocalBackend] when creating a new [controlclient.Client].
|
||||||
|
// It is mutually exclusive with [WithBackend] and is only used if no backend
|
||||||
|
// has been provided.
|
||||||
|
func WithControlClient(newControl NewControlFn) Option {
|
||||||
|
return newControl
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fn NewControlFn) apply(opts *options) error {
|
||||||
|
if _, ok := opts.backend.Peek(); ok {
|
||||||
|
return errors.New("control client cannot be set when backend is already set")
|
||||||
|
}
|
||||||
|
if !opts.newCC.Set(fn) {
|
||||||
|
return errors.New("control client already set")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
324
ipn/lapitest/server.go
Normal file
324
ipn/lapitest/server.go
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package lapitest provides utilities for black-box testing of LocalAPI ([ipnserver]).
|
||||||
|
package lapitest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/client/local"
|
||||||
|
"tailscale.com/client/tailscale/apitype"
|
||||||
|
"tailscale.com/envknob"
|
||||||
|
"tailscale.com/ipn"
|
||||||
|
"tailscale.com/ipn/ipnauth"
|
||||||
|
"tailscale.com/ipn/ipnlocal"
|
||||||
|
"tailscale.com/ipn/ipnserver"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
"tailscale.com/types/logid"
|
||||||
|
"tailscale.com/types/ptr"
|
||||||
|
"tailscale.com/util/mak"
|
||||||
|
"tailscale.com/util/rands"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Server is an in-process LocalAPI server that can be used in end-to-end tests.
|
||||||
|
type Server struct {
|
||||||
|
tb testing.TB
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelCtx context.CancelFunc
|
||||||
|
|
||||||
|
lb *ipnlocal.LocalBackend
|
||||||
|
ipnServer *ipnserver.Server
|
||||||
|
|
||||||
|
// mu protects the following fields.
|
||||||
|
mu sync.Mutex
|
||||||
|
started bool
|
||||||
|
httpServer *httptest.Server
|
||||||
|
actorsByName map[string]*ipnauth.TestActor
|
||||||
|
lastClientID int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnstartedServer returns a new [Server] with the specified options without starting it.
|
||||||
|
func NewUnstartedServer(tb testing.TB, opts ...Option) *Server {
|
||||||
|
tb.Helper()
|
||||||
|
options, err := newOptions(tb, opts...)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("invalid options: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{tb: tb, lb: options.Backend()}
|
||||||
|
s.ctx, s.cancelCtx = context.WithCancel(options.Context())
|
||||||
|
s.ipnServer = newUnstartedIPNServer(options)
|
||||||
|
s.httpServer = httptest.NewUnstartedServer(http.HandlerFunc(s.serveHTTP))
|
||||||
|
s.httpServer.Config.Addr = "http://" + apitype.LocalAPIHost
|
||||||
|
s.httpServer.Config.BaseContext = func(_ net.Listener) context.Context { return s.ctx }
|
||||||
|
s.httpServer.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(options.Logf(), "lapitest: "))
|
||||||
|
tb.Cleanup(s.Close)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer starts and returns a new [Server] with the specified options.
|
||||||
|
func NewServer(tb testing.TB, opts ...Option) *Server {
|
||||||
|
tb.Helper()
|
||||||
|
server := NewUnstartedServer(tb, opts...)
|
||||||
|
server.Start()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the server from [NewUnstartedServer].
|
||||||
|
func (s *Server) Start() {
|
||||||
|
s.tb.Helper()
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if !s.started && s.httpServer != nil {
|
||||||
|
s.httpServer.Start()
|
||||||
|
s.started = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend returns the underlying [ipnlocal.LocalBackend].
|
||||||
|
func (s *Server) Backend() *ipnlocal.LocalBackend {
|
||||||
|
s.tb.Helper()
|
||||||
|
return s.lb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client returns a new [Client] configured for making requests to the server
|
||||||
|
// as a new [ipnauth.TestActor] with a unique username and [ipnauth.ClientID].
|
||||||
|
func (s *Server) Client() *Client {
|
||||||
|
s.tb.Helper()
|
||||||
|
user := s.MakeTestActor("", "") // generate a unique username and client ID
|
||||||
|
return s.ClientFor(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientWithName returns a new [Client] configured for making requests to the server
|
||||||
|
// as a new [ipnauth.TestActor] with the specified name and a unique [ipnauth.ClientID].
|
||||||
|
func (s *Server) ClientWithName(name string) *Client {
|
||||||
|
s.tb.Helper()
|
||||||
|
user := s.MakeTestActor(name, "") // generate a unique client ID
|
||||||
|
return s.ClientFor(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientFor returns a new [Client] configured for making requests to the server
|
||||||
|
// as the specified actor.
|
||||||
|
func (s *Server) ClientFor(actor ipnauth.Actor) *Client {
|
||||||
|
s.tb.Helper()
|
||||||
|
client := &Client{
|
||||||
|
tb: s.tb,
|
||||||
|
Actor: actor,
|
||||||
|
}
|
||||||
|
client.Client = &local.Client{Transport: newRoundTripper(client, s.httpServer)}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeTestActor returns a new [ipnauth.TestActor] with the specified name and client ID.
|
||||||
|
// If the name is empty, a unique sequential name is generated. Likewise,
|
||||||
|
// if clientID is empty, a unique sequential client ID is generated.
|
||||||
|
func (s *Server) MakeTestActor(name string, clientID string) *ipnauth.TestActor {
|
||||||
|
s.tb.Helper()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Generate a unique sequential name if the provided name is empty.
|
||||||
|
if name == "" {
|
||||||
|
n := len(s.actorsByName)
|
||||||
|
name = generateSequentialName("User", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientID == "" {
|
||||||
|
s.lastClientID += 1
|
||||||
|
clientID = fmt.Sprintf("Client-%d", s.lastClientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new base actor if one doesn't already exist for the given name.
|
||||||
|
baseActor := s.actorsByName[name]
|
||||||
|
if baseActor == nil {
|
||||||
|
baseActor = &ipnauth.TestActor{Name: name}
|
||||||
|
if envknob.GOOS() == "windows" {
|
||||||
|
// Historically, as of 2025-04-15, IPN does not distinguish between
|
||||||
|
// different users on non-Windows devices. Therefore, the UID, which is
|
||||||
|
// an [ipn.WindowsUserID], should only be populated when the actual or
|
||||||
|
// fake GOOS is Windows.
|
||||||
|
baseActor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actorsByName)))
|
||||||
|
}
|
||||||
|
mak.Set(&s.actorsByName, name, baseActor)
|
||||||
|
s.tb.Cleanup(func() { delete(s.actorsByName, name) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a shallow copy of the base actor and assign it the new client ID.
|
||||||
|
actor := ptr.To(*baseActor)
|
||||||
|
actor.CID = ipnauth.ClientIDFrom(clientID)
|
||||||
|
return actor
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockWhileInUse blocks until the server becomes idle (no active requests),
|
||||||
|
// or the context is done. It returns the context's error if it is done.
|
||||||
|
// It is used in tests only.
|
||||||
|
func (s *Server) BlockWhileInUse(ctx context.Context) error {
|
||||||
|
s.tb.Helper()
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.httpServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.ipnServer.BlockWhileInUseForTest(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockWhileInUseByOther blocks while the specified actor can't connect to the server
|
||||||
|
// due to another actor being connected.
|
||||||
|
// It is used in tests only.
|
||||||
|
func (s *Server) BlockWhileInUseByOther(ctx context.Context, actor ipnauth.Actor) error {
|
||||||
|
s.tb.Helper()
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.httpServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.ipnServer.BlockWhileInUseByOtherForTest(ctx, actor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckCurrentUser fails the test if the current user does not match the expected user.
|
||||||
|
// It is only used on Windows and will be removed as we progress on tailscale/corp#18342.
|
||||||
|
func (s *Server) CheckCurrentUser(want ipnauth.Actor) {
|
||||||
|
s.tb.Helper()
|
||||||
|
var wantUID ipn.WindowsUserID
|
||||||
|
if want != nil {
|
||||||
|
wantUID = want.UserID()
|
||||||
|
}
|
||||||
|
lb := s.Backend()
|
||||||
|
if lb == nil {
|
||||||
|
s.tb.Fatalf("Backend: nil")
|
||||||
|
}
|
||||||
|
gotUID, gotActor := lb.CurrentUserForTest()
|
||||||
|
if gotUID != wantUID {
|
||||||
|
s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID)
|
||||||
|
}
|
||||||
|
if hasActor := gotActor != nil; hasActor != (want != nil) || (want != nil && gotActor != want) {
|
||||||
|
s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
actor, err := getActorForRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
s.tb.Errorf("getActorForRequest: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := ipnserver.NewContextWithActorForTest(r.Context(), actor)
|
||||||
|
s.ipnServer.ServeHTTPForTest(w, r.Clone(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the server and blocks until all outstanding requests on this server have completed.
|
||||||
|
func (s *Server) Close() {
|
||||||
|
s.tb.Helper()
|
||||||
|
s.mu.Lock()
|
||||||
|
server := s.httpServer
|
||||||
|
s.httpServer = nil
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if server != nil {
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
s.cancelCtx()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUnstartedIPNServer returns a new [ipnserver.Server] that exposes
|
||||||
|
// the specified [ipnlocal.LocalBackend] via LocalAPI, but does not start it.
|
||||||
|
// The opts carry additional configuration options.
|
||||||
|
func newUnstartedIPNServer(opts *options) *ipnserver.Server {
|
||||||
|
opts.TB().Helper()
|
||||||
|
lb := opts.Backend()
|
||||||
|
server := ipnserver.New(opts.Logf(), logid.PublicID{}, lb.NetMon())
|
||||||
|
server.SetLocalBackend(lb)
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
// roundTripper is a [http.RoundTripper] that sends requests to a [Server]
|
||||||
|
// on behalf of the [Client] who owns it.
|
||||||
|
type roundTripper struct {
|
||||||
|
client *Client
|
||||||
|
transport http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRoundTripper returns a new [http.RoundTripper] that sends requests
|
||||||
|
// to the specified server as the specified client.
|
||||||
|
func newRoundTripper(client *Client, server *httptest.Server) http.RoundTripper {
|
||||||
|
return &roundTripper{
|
||||||
|
client: client,
|
||||||
|
transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
var std net.Dialer
|
||||||
|
return std.DialContext(ctx, network, server.Listener.Addr().(*net.TCPAddr).String())
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestIDHeaderName is the name of the header used to pass request IDs
|
||||||
|
// between the client and server. It is used to associate requests with their actors.
|
||||||
|
const requestIDHeaderName = "TS-Request-ID"
|
||||||
|
|
||||||
|
// RoundTrip implements [http.RoundTripper] by sending the request to the [ipnserver.Server]
|
||||||
|
// on behalf of the owning [Client]. It registers each request for the duration
|
||||||
|
// of the call and associates it with the actor sending the request.
|
||||||
|
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
reqID, unregister := registerRequest(rt.client.Actor)
|
||||||
|
defer unregister()
|
||||||
|
r = r.Clone(r.Context())
|
||||||
|
r.Header.Set(requestIDHeaderName, reqID)
|
||||||
|
return rt.transport.RoundTrip(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getActorForRequest returns the actor for a given request.
|
||||||
|
// It returns an error if the request is not associated with an actor,
|
||||||
|
// such as when it wasn't sent by a [roundTripper].
|
||||||
|
func getActorForRequest(r *http.Request) (ipnauth.Actor, error) {
|
||||||
|
reqID := r.Header.Get(requestIDHeaderName)
|
||||||
|
if reqID == "" {
|
||||||
|
return nil, fmt.Errorf("missing %s header", requestIDHeaderName)
|
||||||
|
}
|
||||||
|
actor, ok := getActorByRequestID(reqID)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown request: %s", reqID)
|
||||||
|
}
|
||||||
|
return actor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
inFlightRequestsMu sync.Mutex
|
||||||
|
inFlightRequests map[string]ipnauth.Actor
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerRequest associates a request with the specified actor and returns a unique request ID
|
||||||
|
// which can be used to retrieve the actor later. The returned function unregisters the request.
|
||||||
|
func registerRequest(actor ipnauth.Actor) (requestID string, unregister func()) {
|
||||||
|
inFlightRequestsMu.Lock()
|
||||||
|
defer inFlightRequestsMu.Unlock()
|
||||||
|
for {
|
||||||
|
requestID = rands.HexString(16)
|
||||||
|
if _, ok := inFlightRequests[requestID]; !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mak.Set(&inFlightRequests, requestID, actor)
|
||||||
|
return requestID, func() {
|
||||||
|
inFlightRequestsMu.Lock()
|
||||||
|
defer inFlightRequestsMu.Unlock()
|
||||||
|
delete(inFlightRequests, requestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getActorByRequestID returns the actor associated with the specified request ID.
|
||||||
|
// It returns the actor and true if found, or nil and false if not.
|
||||||
|
func getActorByRequestID(requestID string) (ipnauth.Actor, bool) {
|
||||||
|
inFlightRequestsMu.Lock()
|
||||||
|
defer inFlightRequestsMu.Unlock()
|
||||||
|
actor, ok := inFlightRequests[requestID]
|
||||||
|
return actor, ok
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user