diff --git a/ipn/ipnserver/actor.go b/ipn/ipnserver/actor.go index 9c203fc5f..dd40924bb 100644 --- a/ipn/ipnserver/actor.go +++ b/ipn/ipnserver/actor.go @@ -179,6 +179,12 @@ func contextWithActor(ctx context.Context, logf logger.Logf, c net.Conn) context return actorKey.WithValue(ctx, actorOrError{actor: actor, err: err}) } +// NewContextWithActorForTest returns a new context that carries the identity +// of the specified actor. It is used in tests only. +func NewContextWithActorForTest(ctx context.Context, actor ipnauth.Actor) context.Context { + return actorKey.WithValue(ctx, actorOrError{actor: actor}) +} + // actorFromContext returns an [ipnauth.Actor] associated with ctx, // or an error if the context does not carry an actor's identity. func actorFromContext(ctx context.Context) (ipnauth.Actor, error) { diff --git a/ipn/ipnserver/server_fortest.go b/ipn/ipnserver/server_fortest.go new file mode 100644 index 000000000..9aab3b276 --- /dev/null +++ b/ipn/ipnserver/server_fortest.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "net/http" + + "tailscale.com/ipn/ipnauth" +) + +// BlockWhileInUseByOtherForTest blocks while the actor can't connect to the server because +// the server is in use by a different actor. It is used in tests only. +func (s *Server) BlockWhileInUseByOtherForTest(ctx context.Context, actor ipnauth.Actor) error { + return s.blockWhileIdentityInUse(ctx, actor) +} + +// BlockWhileInUseForTest blocks until the server becomes idle (no active requests), +// or the specified context is done. It returns the context's error if it is done. +// It is used in tests only. +func (s *Server) BlockWhileInUseForTest(ctx context.Context) error { + ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) + + s.mu.Lock() + busy := len(s.activeReqs) != 0 + s.mu.Unlock() + + if busy { + <-ready + } + cleanup() + return ctx.Err() +} + +// ServeHTTPForTest responds to a single LocalAPI HTTP request. +// The request's context carries the actor that made the request +// and can be created with [NewContextWithActorForTest]. +// It is used in tests only. +func (s *Server) ServeHTTPForTest(w http.ResponseWriter, r *http.Request) { + s.serveHTTP(w, r) +} diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index 9340fd1c6..903cb6b73 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -1,76 +1,22 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package ipnserver +package ipnserver_test import ( "context" - "encoding/json" - "errors" - "fmt" - "net" - "net/http" - "net/http/httptest" "runtime" "strconv" "sync" - "sync/atomic" "testing" "tailscale.com/client/local" - "tailscale.com/client/tailscale" - "tailscale.com/client/tailscale/apitype" - "tailscale.com/control/controlclient" "tailscale.com/envknob" "tailscale.com/ipn" - "tailscale.com/ipn/ipnauth" - "tailscale.com/ipn/ipnlocal" - "tailscale.com/ipn/store/mem" - "tailscale.com/tsd" - "tailscale.com/tstest" - "tailscale.com/types/logger" - "tailscale.com/types/logid" + "tailscale.com/ipn/lapitest" "tailscale.com/types/ptr" - "tailscale.com/util/mak" - "tailscale.com/wgengine" ) -func TestWaiterSet(t *testing.T) { - var s waiterSet - - wantLen := func(want int, when string) { - t.Helper() - if got := len(s); got != want { - t.Errorf("%s: len = %v; want %v", when, got, want) - } - } - wantLen(0, "initial") - var mu sync.Mutex - ctx, cancel := context.WithCancel(context.Background()) - - ready, cleanup := s.add(&mu, ctx) - wantLen(1, "after add") - - select { - case <-ready: - t.Fatal("should not be ready") - default: - } - s.wakeAll() - <-ready - - wantLen(1, "after fire") - cleanup() - wantLen(0, "after cleanup") - - // And again but on an already-expired ctx. - cancel() - ready, cleanup = s.add(&mu, ctx) - <-ready // shouldn't block - cleanup() - wantLen(0, "at end") -} - func TestUserConnectDisconnectNonWindows(t *testing.T) { enableLogging := false if runtime.GOOS == "windows" { @@ -78,20 +24,20 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) { } ctx := context.Background() - server := startDefaultTestIPNServer(t, ctx, enableLogging) + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) // UserA connects and starts watching the IPN bus. - clientA := server.getClientAs("UserA") + clientA := server.ClientWithName("UserA") watcherA, _ := clientA.WatchIPNBus(ctx, 0) // The concept of "current user" is only relevant on Windows // and it should not be set on non-Windows platforms. - server.checkCurrentUser(nil) + server.CheckCurrentUser(nil) // Additionally, a different user should be able to connect and use the LocalAPI. - clientB := server.getClientAs("UserB") + clientB := server.ClientWithName("UserB") if _, gotErr := clientB.Status(ctx); gotErr != nil { - t.Fatalf("Status(%q): want nil; got %v", clientB.User.Name, gotErr) + t.Fatalf("Status(%q): want nil; got %v", clientB.Username(), gotErr) } // Watching the IPN bus should also work for UserB. @@ -100,18 +46,18 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) { // And if we send a notification, both users should receive it. wantErrMessage := "test error" testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)} - server.mustBackend().DebugNotify(testNotify) + server.Backend().DebugNotify(testNotify) if n, err := watcherA.Next(); err != nil { - t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.User.Name, err) + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.Username(), err) } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { - t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.User.Name, wantErrMessage, gotErrMessage) + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.Username(), wantErrMessage, gotErrMessage) } if n, err := watcherB.Next(); err != nil { - t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.User.Name, err) + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.Username(), err) } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { - t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.User.Name, wantErrMessage, gotErrMessage) + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.Username(), wantErrMessage, gotErrMessage) } } @@ -120,21 +66,21 @@ func TestUserConnectDisconnectOnWindows(t *testing.T) { setGOOSForTest(t, "windows") ctx := context.Background() - server := startDefaultTestIPNServer(t, ctx, enableLogging) + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) - client := server.getClientAs("User") + client := server.ClientWithName("User") _, cancelWatcher := client.WatchIPNBus(ctx, 0) // On Windows, however, the current user should be set to the user that connected. - server.checkCurrentUser(client.User) + server.CheckCurrentUser(client.Actor) // Cancel the IPN bus watcher request and wait for the server to unblock. cancelWatcher() - server.blockWhileInUse(ctx) + server.BlockWhileInUse(ctx) // The current user should not be set after a disconnect, as no one is // currently using the server. - server.checkCurrentUser(nil) + server.CheckCurrentUser(nil) } func TestIPNAlreadyInUseOnWindows(t *testing.T) { @@ -142,22 +88,22 @@ func TestIPNAlreadyInUseOnWindows(t *testing.T) { setGOOSForTest(t, "windows") ctx := context.Background() - server := startDefaultTestIPNServer(t, ctx, enableLogging) + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) // UserA connects and starts watching the IPN bus. - clientA := server.getClientAs("UserA") + clientA := server.ClientWithName("UserA") clientA.WatchIPNBus(ctx, 0) // While UserA is connected, UserB should not be able to connect. - clientB := server.getClientAs("UserB") + clientB := server.ClientWithName("UserB") if _, gotErr := clientB.Status(ctx); gotErr == nil { - t.Fatalf("Status(%q): want error; got nil", clientB.User.Name) + t.Fatalf("Status(%q): want error; got nil", clientB.Username()) } else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError { - t.Fatalf("Status(%q): want %q; got %q", clientB.User.Name, wantError, gotErr.Error()) + t.Fatalf("Status(%q): want %q; got %q", clientB.Username(), wantError, gotErr.Error()) } // Current user should still be UserA. - server.checkCurrentUser(clientA.User) + server.CheckCurrentUser(clientA.Actor) } func TestSequentialOSUserSwitchingOnWindows(t *testing.T) { @@ -165,22 +111,22 @@ func TestSequentialOSUserSwitchingOnWindows(t *testing.T) { setGOOSForTest(t, "windows") ctx := context.Background() - server := startDefaultTestIPNServer(t, ctx, enableLogging) + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) connectDisconnectAsUser := func(name string) { // User connects and starts watching the IPN bus. - client := server.getClientAs(name) + client := server.ClientWithName(name) watcher, cancelWatcher := client.WatchIPNBus(ctx, 0) defer cancelWatcher() go pumpIPNBus(watcher) // It should be the current user from the LocalBackend's perspective... - server.checkCurrentUser(client.User) + server.CheckCurrentUser(client.Actor) // until it disconnects. cancelWatcher() - server.blockWhileInUse(ctx) + server.BlockWhileInUse(ctx) // Now, the current user should be unset. - server.checkCurrentUser(nil) + server.CheckCurrentUser(nil) } // UserA logs in, uses Tailscale for a bit, then logs out. @@ -194,11 +140,11 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) { setGOOSForTest(t, "windows") ctx := context.Background() - server := startDefaultTestIPNServer(t, ctx, enableLogging) + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) connectDisconnectAsUser := func(name string) { // User connects and starts watching the IPN bus. - client := server.getClientAs(name) + client := server.ClientWithName(name) watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState) defer cancelWatcher() @@ -206,7 +152,7 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) { // Get the current user from the LocalBackend's perspective // as soon as we're connected. - gotUID, gotActor := server.mustBackend().CurrentUserForTest() + gotUID, gotActor := server.Backend().CurrentUserForTest() // Wait for the first notification to arrive. // It will either be the initial state we've requested via [ipn.NotifyInitialState], @@ -225,17 +171,17 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) { } // Otherwise, our user should have been the current user since the time we connected. - if gotUID != client.User.UID { - t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.User.UID) + if gotUID != client.Actor.UserID() { + t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.Actor.UserID()) return } - if gotActor, ok := gotActor.(*ipnauth.TestActor); !ok || *gotActor != *client.User { - t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.User) + if hasActor := gotActor != nil; !hasActor || gotActor != client.Actor { + t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.Actor) return } // And should still be the current user (as they're still connected)... - server.checkCurrentUser(client.User) + server.CheckCurrentUser(client.Actor) } numIterations := 10 @@ -253,11 +199,11 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) { } wg.Wait() - if err := server.blockWhileInUse(ctx); err != nil { - t.Fatalf("blockWhileInUse: %v", err) + if err := server.BlockWhileInUse(ctx); err != nil { + t.Fatalf("BlockUntilIdle: %v", err) } - server.checkCurrentUser(nil) + server.CheckCurrentUser(nil) } } @@ -266,13 +212,13 @@ func TestBlockWhileIdentityInUse(t *testing.T) { setGOOSForTest(t, "windows") ctx := context.Background() - server := startDefaultTestIPNServer(t, ctx, enableLogging) + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) // connectWaitDisconnectAsUser connects as a user with the specified name // and keeps the IPN bus watcher alive until the context is canceled. // It returns a channel that is closed when done. connectWaitDisconnectAsUser := func(ctx context.Context, name string) <-chan struct{} { - client := server.getClientAs(name) + client := server.ClientWithName(name) watcher, cancelWatcher := client.WatchIPNBus(ctx, 0) done := make(chan struct{}) @@ -301,8 +247,8 @@ func TestBlockWhileIdentityInUse(t *testing.T) { // in blockWhileIdentityInUse. But the issue also occurs during // the normal execution path when UserB connects to the IPN server // while UserA is disconnecting. - userB := server.makeTestUser("UserB", "ClientB") - server.blockWhileIdentityInUse(ctx, userB) + userB := server.MakeTestActor("UserB", "ClientB") + server.BlockWhileInUseByOther(ctx, userB) <-userADone } } @@ -313,41 +259,7 @@ func setGOOSForTest(tb testing.TB, goos string) { tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") }) } -func testLogger(tb testing.TB, enableLogging bool) logger.Logf { - tb.Helper() - if enableLogging { - return tstest.WhileTestRunningLogger(tb) - } - return logger.Discard -} - -// newTestIPNServer creates a new IPN server for testing, using the specified local backend. -func newTestIPNServer(tb testing.TB, lb *ipnlocal.LocalBackend, enableLogging bool) *Server { - tb.Helper() - server := New(testLogger(tb, enableLogging), logid.PublicID{}, lb.NetMon()) - server.lb.Store(lb) - return server -} - -type testIPNClient struct { - tb testing.TB - *local.Client - User *ipnauth.TestActor -} - -func (c *testIPNClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, context.CancelFunc) { - c.tb.Helper() - ctx, cancelWatcher := context.WithCancel(ctx) - c.tb.Cleanup(cancelWatcher) - watcher, err := c.Client.WatchIPNBus(ctx, mask) - if err != nil { - c.tb.Fatalf("WatchIPNBus(%q): %v", c.User.Name, err) - } - c.tb.Cleanup(func() { watcher.Close() }) - return watcher, cancelWatcher -} - -func pumpIPNBus(watcher *tailscale.IPNBusWatcher) { +func pumpIPNBus(watcher *local.IPNBusWatcher) { for { _, err := watcher.Next() if err != nil { @@ -355,206 +267,3 @@ func pumpIPNBus(watcher *tailscale.IPNBusWatcher) { } } } - -type testIPNServer struct { - tb testing.TB - *Server - clientID atomic.Int64 - getClient func(*ipnauth.TestActor) *local.Client - - actorsMu sync.Mutex - actors map[string]*ipnauth.TestActor -} - -func (s *testIPNServer) getClientAs(name string) *testIPNClient { - clientID := fmt.Sprintf("Client-%d", 1+s.clientID.Add(1)) - user := s.makeTestUser(name, clientID) - return &testIPNClient{ - tb: s.tb, - Client: s.getClient(user), - User: user, - } -} - -func (s *testIPNServer) makeTestUser(name string, clientID string) *ipnauth.TestActor { - s.actorsMu.Lock() - defer s.actorsMu.Unlock() - actor := s.actors[name] - if actor == nil { - actor = &ipnauth.TestActor{Name: name} - if envknob.GOOS() == "windows" { - // Historically, as of 2025-01-13, IPN does not distinguish between - // different users on non-Windows devices. Therefore, the UID, which is - // an [ipn.WindowsUserID], should only be populated when the actual or - // fake GOOS is Windows. - actor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actors))) - } - mak.Set(&s.actors, name, actor) - s.tb.Cleanup(func() { delete(s.actors, name) }) - } - actor = ptr.To(*actor) - actor.CID = ipnauth.ClientIDFrom(clientID) - return actor -} - -func (s *testIPNServer) blockWhileInUse(ctx context.Context) error { - ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) - - s.mu.Lock() - busy := len(s.activeReqs) != 0 - s.mu.Unlock() - - if busy { - <-ready - } - cleanup() - return ctx.Err() -} - -func (s *testIPNServer) checkCurrentUser(want *ipnauth.TestActor) { - s.tb.Helper() - var wantUID ipn.WindowsUserID - if want != nil { - wantUID = want.UID - } - gotUID, gotActor := s.mustBackend().CurrentUserForTest() - if gotUID != wantUID { - s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID) - } - if gotActor, ok := gotActor.(*ipnauth.TestActor); ok != (want != nil) || (want != nil && *gotActor != *want) { - s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want) - } -} - -// startTestIPNServer starts a [httptest.Server] that hosts the specified IPN server for the -// duration of the test, using the specified base context for incoming requests. -// It returns a function that creates a [local.Client] as a given [ipnauth.TestActor]. -func startTestIPNServer(tb testing.TB, baseContext context.Context, server *Server) *testIPNServer { - tb.Helper() - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actor, err := extractActorFromHeader(r.Header) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - tb.Errorf("extractActorFromHeader: %v", err) - return - } - ctx := newTestContextWithActor(r.Context(), actor) - server.serveHTTP(w, r.Clone(ctx)) - })) - ts.Config.Addr = "http://" + apitype.LocalAPIHost - ts.Config.BaseContext = func(_ net.Listener) context.Context { return baseContext } - ts.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(server.logf, "ipnserver: ")) - ts.Start() - tb.Cleanup(ts.Close) - return &testIPNServer{ - tb: tb, - Server: server, - getClient: func(actor *ipnauth.TestActor) *local.Client { - return &local.Client{Transport: newTestRoundTripper(ts, actor)} - }, - } -} - -func startDefaultTestIPNServer(tb testing.TB, ctx context.Context, enableLogging bool) *testIPNServer { - tb.Helper() - lb := newLocalBackendWithTestControl(tb, newUnreachableControlClient, enableLogging) - ctx, stopServer := context.WithCancel(ctx) - tb.Cleanup(stopServer) - return startTestIPNServer(tb, ctx, newTestIPNServer(tb, lb, enableLogging)) -} - -type testRoundTripper struct { - transport http.RoundTripper - actor *ipnauth.TestActor -} - -// newTestRoundTripper creates a new [http.RoundTripper] that sends requests -// to the specified test server as the specified actor. -func newTestRoundTripper(ts *httptest.Server, actor *ipnauth.TestActor) *testRoundTripper { - return &testRoundTripper{ - transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - var std net.Dialer - return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String()) - }}, - actor: actor, - } -} - -const testActorHeaderName = "TS-Test-Actor" - -// RoundTrip implements [http.RoundTripper] by forwarding the request to the underlying transport -// and including the test actor's identity in the request headers. -func (rt *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - actorJSON, err := json.Marshal(&rt.actor) - if err != nil { - // An [http.RoundTripper] must always close the request body, including on error. - if r.Body != nil { - r.Body.Close() - } - return nil, err - } - - r = r.Clone(r.Context()) - r.Header.Set(testActorHeaderName, string(actorJSON)) - return rt.transport.RoundTrip(r) -} - -// extractActorFromHeader extracts a test actor from the specified request headers. -func extractActorFromHeader(h http.Header) (*ipnauth.TestActor, error) { - actorJSON := h.Get(testActorHeaderName) - if actorJSON == "" { - return nil, errors.New("missing Test-Actor header") - } - actor := &ipnauth.TestActor{} - if err := json.Unmarshal([]byte(actorJSON), &actor); err != nil { - return nil, fmt.Errorf("invalid Test-Actor header: %v", err) - } - return actor, nil -} - -type newControlClientFn func(tb testing.TB, opts controlclient.Options) controlclient.Client - -func newLocalBackendWithTestControl(tb testing.TB, newControl newControlClientFn, enableLogging bool) *ipnlocal.LocalBackend { - tb.Helper() - - sys := tsd.NewSystem() - store := &mem.Store{} - sys.Set(store) - - logf := testLogger(tb, enableLogging) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) - if err != nil { - tb.Fatalf("NewFakeUserspaceEngine: %v", err) - } - tb.Cleanup(e.Close) - sys.Set(e) - - b, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0) - if err != nil { - tb.Fatalf("NewLocalBackend: %v", err) - } - tb.Cleanup(b.Shutdown) - b.DisablePortMapperForTest() - - b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { - return newControl(tb, opts), nil - }) - return b -} - -func newUnreachableControlClient(tb testing.TB, opts controlclient.Options) controlclient.Client { - tb.Helper() - opts.ServerURL = "https://127.0.0.1:1" - cc, err := controlclient.New(opts) - if err != nil { - tb.Fatal(err) - } - return cc -} - -// newTestContextWithActor returns a new context that carries the identity -// of the specified actor and can be used for testing. -// It can be retrieved with [actorFromContext]. -func newTestContextWithActor(ctx context.Context, actor ipnauth.Actor) context.Context { - return actorKey.WithValue(ctx, actorOrError{actor: actor}) -} diff --git a/ipn/ipnserver/waiterset_test.go b/ipn/ipnserver/waiterset_test.go new file mode 100644 index 000000000..b7d5ea144 --- /dev/null +++ b/ipn/ipnserver/waiterset_test.go @@ -0,0 +1,46 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. + cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/lapitest/backend.go b/ipn/lapitest/backend.go new file mode 100644 index 000000000..ddf48fb28 --- /dev/null +++ b/ipn/lapitest/backend.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "testing" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/store/mem" + "tailscale.com/types/logid" + "tailscale.com/wgengine" +) + +// NewBackend returns a new [ipnlocal.LocalBackend] for testing purposes. +// It fails the test if the specified options are invalid or if the backend cannot be created. +func NewBackend(tb testing.TB, opts ...Option) *ipnlocal.LocalBackend { + tb.Helper() + options, err := newOptions(tb, opts...) + if err != nil { + tb.Fatalf("NewBackend: %v", err) + } + return newBackend(options) +} + +func newBackend(opts *options) *ipnlocal.LocalBackend { + tb := opts.TB() + tb.Helper() + + sys := opts.Sys() + if _, ok := sys.StateStore.GetOK(); !ok { + sys.Set(&mem.Store{}) + } + + e, err := wgengine.NewFakeUserspaceEngine(opts.Logf(), sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + opts.tb.Fatalf("NewFakeUserspaceEngine: %v", err) + } + tb.Cleanup(e.Close) + sys.Set(e) + + b, err := ipnlocal.NewLocalBackend(opts.Logf(), logid.PublicID{}, sys, 0) + if err != nil { + tb.Fatalf("NewLocalBackend: %v", err) + } + tb.Cleanup(b.Shutdown) + b.DisablePortMapperForTest() + b.SetControlClientGetterForTesting(opts.MakeControlClient) + return b +} + +// NewUnreachableControlClient is a [NewControlFn] that creates +// a new [controlclient.Client] for an unreachable control server. +func NewUnreachableControlClient(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) { + tb.Helper() + opts.ServerURL = "https://127.0.0.1:1" + cc, err := controlclient.New(opts) + if err != nil { + tb.Fatal(err) + } + return cc, nil +} diff --git a/ipn/lapitest/client.go b/ipn/lapitest/client.go new file mode 100644 index 000000000..6d22e938b --- /dev/null +++ b/ipn/lapitest/client.go @@ -0,0 +1,71 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "testing" + + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" +) + +// Client wraps a [local.Client] for testing purposes. +// It can be created using [Server.Client], [Server.ClientWithName], +// or [Server.ClientFor] and sends requests as the specified actor +// to the associated [Server]. +type Client struct { + tb testing.TB + // Client is the underlying [local.Client] wrapped by the test client. + // It is configured to send requests to the test server on behalf of the actor. + *local.Client + // Actor represents the user on whose behalf this client is making requests. + // The server uses it to determine the client's identity and permissions. + // The test can mutate the user to alter the actor's identity or permissions + // before making a new request. It is typically an [ipnauth.TestActor], + // unless the [Client] was created with s specific actor using [Server.ClientFor]. + Actor ipnauth.Actor +} + +// Username returns username of the client's owner. +func (c *Client) Username() string { + c.tb.Helper() + name, err := c.Actor.Username() + if err != nil { + c.tb.Fatalf("Client.Username: %v", err) + } + return name +} + +// WatchIPNBus is like [local.Client.WatchIPNBus] but returns a [local.IPNBusWatcher] +// that is closed when the test ends and a cancel function that stops the watcher. +// It fails the test if the underlying WatchIPNBus returns an error. +func (c *Client) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*local.IPNBusWatcher, context.CancelFunc) { + c.tb.Helper() + ctx, cancelWatcher := context.WithCancel(ctx) + c.tb.Cleanup(cancelWatcher) + watcher, err := c.Client.WatchIPNBus(ctx, mask) + name, _ := c.Actor.Username() + if err != nil { + c.tb.Fatalf("Client.WatchIPNBus(%q): %v", name, err) + } + c.tb.Cleanup(func() { watcher.Close() }) + return watcher, cancelWatcher +} + +// generateSequentialName generates a unique sequential name based on the given prefix and number n. +// It uses a base-26 encoding to create names like "User-A", "User-B", ..., "User-Z", "User-AA", etc. +func generateSequentialName(prefix string, n int) string { + n++ + name := "" + const numLetters = 'Z' - 'A' + 1 + for n > 0 { + n-- + remainder := byte(n % numLetters) + name = string([]byte{'A' + remainder}) + name + n = n / numLetters + } + return prefix + "-" + name +} diff --git a/ipn/lapitest/example_test.go b/ipn/lapitest/example_test.go new file mode 100644 index 000000000..57479199a --- /dev/null +++ b/ipn/lapitest/example_test.go @@ -0,0 +1,80 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "testing" + + "tailscale.com/ipn" +) + +func TestClientServer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create a server and two clients. + // Both clients represent the same user to make this work across platforms. + // On Windows we've been restricting the API usage to a single user at a time. + // While we're planning on changing this once a better permission model is in place, + // this test is currently limited to a single user (but more than one client is fine). + // Alternatively, we could override GOOS via envknobs to test as if we're + // on a different platform, but that would make the test depend on global state, etc. + s := NewServer(t, WithLogging(false)) + c1 := s.ClientWithName("User-A") + c2 := s.ClientWithName("User-A") + + // Start watching the IPN bus as the second client. + w2, _ := c2.WatchIPNBus(context.Background(), ipn.NotifyInitialPrefs) + + // We're supposed to get a notification about the initial prefs, + // and WantRunning should be false. + n, err := w2.Next() + for ; err == nil; n, err = w2.Next() { + if n.Prefs == nil { + // Ignore non-prefs notifications. + continue + } + if n.Prefs.WantRunning() { + t.Errorf("WantRunning(initial): got %v, want false", n.Prefs.WantRunning()) + } + break + } + if err != nil { + t.Fatalf("IPNBusWatcher.Next failed: %v", err) + } + + // Now send an EditPrefs request from the first client to set WantRunning to true. + change := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + gotPrefs, err := c1.EditPrefs(ctx, change) + if err != nil { + t.Fatalf("EditPrefs failed: %v", err) + } + if !gotPrefs.WantRunning { + t.Fatalf("EditPrefs.WantRunning: got %v, want true", gotPrefs.WantRunning) + } + + // We can check the backend directly to see if the prefs were set correctly. + if gotWantRunning := s.Backend().Prefs().WantRunning(); !gotWantRunning { + t.Fatalf("Backend.Prefs.WantRunning: got %v, want true", gotWantRunning) + } + + // And can also wait for the second client with an IPN bus watcher to receive the notification + // about the prefs change. + n, err = w2.Next() + for ; err == nil; n, err = w2.Next() { + if n.Prefs == nil { + // Ignore non-prefs notifications. + continue + } + if !n.Prefs.WantRunning() { + t.Fatalf("WantRunning(changed): got %v, want true", n.Prefs.WantRunning()) + } + break + } + if err != nil { + t.Fatalf("IPNBusWatcher.Next failed: %v", err) + } +} diff --git a/ipn/lapitest/opts.go b/ipn/lapitest/opts.go new file mode 100644 index 000000000..6eb1594da --- /dev/null +++ b/ipn/lapitest/opts.go @@ -0,0 +1,170 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "errors" + "fmt" + "testing" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" +) + +// Option is any optional configuration that can be passed to [NewServer] or [NewBackend]. +type Option interface { + apply(*options) error +} + +// options is the merged result of all applied [Option]s. +type options struct { + tb testing.TB + ctx lazy.SyncValue[context.Context] + logf lazy.SyncValue[logger.Logf] + sys lazy.SyncValue[*tsd.System] + newCC lazy.SyncValue[NewControlFn] + backend lazy.SyncValue[*ipnlocal.LocalBackend] +} + +// newOptions returns a new [options] struct with the specified [Option]s applied. +func newOptions(tb testing.TB, opts ...Option) (*options, error) { + options := &options{tb: tb} + for _, opt := range opts { + if err := opt.apply(options); err != nil { + return nil, fmt.Errorf("lapitest: %w", err) + } + } + return options, nil +} + +// TB returns the owning [*testing.T] or [*testing.B]. +func (o *options) TB() testing.TB { + return o.tb +} + +// Context returns the base context to be used by the server. +func (o *options) Context() context.Context { + return o.ctx.Get(context.Background) +} + +// Logf returns the [logger.Logf] to be used for logging. +func (o *options) Logf() logger.Logf { + return o.logf.Get(func() logger.Logf { return logger.Discard }) +} + +// Sys returns the [tsd.System] that contains subsystems to be used +// when creating a new [ipnlocal.LocalBackend]. +func (o *options) Sys() *tsd.System { + return o.sys.Get(func() *tsd.System { return tsd.NewSystem() }) +} + +// Backend returns the [ipnlocal.LocalBackend] to be used by the server. +// If a backend is provided via [WithBackend], it is used as-is. +// Otherwise, a new backend is created with the the [options] in o. +func (o *options) Backend() *ipnlocal.LocalBackend { + return o.backend.Get(func() *ipnlocal.LocalBackend { return newBackend(o) }) +} + +// MakeControlClient returns a new [controlclient.Client] to be used by newly +// created [ipnlocal.LocalBackend]s. It is only used if no backend is provided +// via [WithBackend]. +func (o *options) MakeControlClient(opts controlclient.Options) (controlclient.Client, error) { + newCC := o.newCC.Get(func() NewControlFn { return NewUnreachableControlClient }) + return newCC(o.tb, opts) +} + +type loggingOption struct{ enableLogging bool } + +// WithLogging returns an [Option] that enables or disables logging. +func WithLogging(enableLogging bool) Option { + return loggingOption{enableLogging: enableLogging} +} + +func (o loggingOption) apply(opts *options) error { + var logf logger.Logf + if o.enableLogging { + logf = tstest.WhileTestRunningLogger(opts.tb) + } else { + logf = logger.Discard + } + if !opts.logf.Set(logf) { + return errors.New("logging already configured") + } + return nil +} + +type contextOption struct{ ctx context.Context } + +// WithContext returns an [Option] that sets the base context to be used by the [Server]. +func WithContext(ctx context.Context) Option { + return contextOption{ctx: ctx} +} + +func (o contextOption) apply(opts *options) error { + if !opts.ctx.Set(o.ctx) { + return errors.New("context already configured") + } + return nil +} + +type sysOption struct{ sys *tsd.System } + +// WithSys returns an [Option] that sets the [tsd.System] to be used +// when creating a new [ipnlocal.LocalBackend]. +func WithSys(sys *tsd.System) Option { + return sysOption{sys: sys} +} + +func (o sysOption) apply(opts *options) error { + if !opts.sys.Set(o.sys) { + return errors.New("tsd.System already configured") + } + return nil +} + +type backendOption struct{ backend *ipnlocal.LocalBackend } + +// WithBackend returns an [Option] that configures the server to use the specified +// [ipnlocal.LocalBackend] instead of creating a new one. +// It is mutually exclusive with [WithControlClient]. +func WithBackend(backend *ipnlocal.LocalBackend) Option { + return backendOption{backend: backend} +} + +func (o backendOption) apply(opts *options) error { + if _, ok := opts.backend.Peek(); ok { + return errors.New("backend cannot be set when control client is already set") + } + if !opts.backend.Set(o.backend) { + return errors.New("backend already set") + } + return nil +} + +// NewControlFn is any function that creates a new [controlclient.Client] +// with the specified options. +type NewControlFn func(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) + +// WithControlClient returns an option that specifies a function to be used +// by the [ipnlocal.LocalBackend] when creating a new [controlclient.Client]. +// It is mutually exclusive with [WithBackend] and is only used if no backend +// has been provided. +func WithControlClient(newControl NewControlFn) Option { + return newControl +} + +func (fn NewControlFn) apply(opts *options) error { + if _, ok := opts.backend.Peek(); ok { + return errors.New("control client cannot be set when backend is already set") + } + if !opts.newCC.Set(fn) { + return errors.New("control client already set") + } + return nil +} diff --git a/ipn/lapitest/server.go b/ipn/lapitest/server.go new file mode 100644 index 000000000..d477dc182 --- /dev/null +++ b/ipn/lapitest/server.go @@ -0,0 +1,324 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lapitest provides utilities for black-box testing of LocalAPI ([ipnserver]). +package lapitest + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnserver" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/types/ptr" + "tailscale.com/util/mak" + "tailscale.com/util/rands" +) + +// A Server is an in-process LocalAPI server that can be used in end-to-end tests. +type Server struct { + tb testing.TB + + ctx context.Context + cancelCtx context.CancelFunc + + lb *ipnlocal.LocalBackend + ipnServer *ipnserver.Server + + // mu protects the following fields. + mu sync.Mutex + started bool + httpServer *httptest.Server + actorsByName map[string]*ipnauth.TestActor + lastClientID int +} + +// NewUnstartedServer returns a new [Server] with the specified options without starting it. +func NewUnstartedServer(tb testing.TB, opts ...Option) *Server { + tb.Helper() + options, err := newOptions(tb, opts...) + if err != nil { + tb.Fatalf("invalid options: %v", err) + } + + s := &Server{tb: tb, lb: options.Backend()} + s.ctx, s.cancelCtx = context.WithCancel(options.Context()) + s.ipnServer = newUnstartedIPNServer(options) + s.httpServer = httptest.NewUnstartedServer(http.HandlerFunc(s.serveHTTP)) + s.httpServer.Config.Addr = "http://" + apitype.LocalAPIHost + s.httpServer.Config.BaseContext = func(_ net.Listener) context.Context { return s.ctx } + s.httpServer.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(options.Logf(), "lapitest: ")) + tb.Cleanup(s.Close) + return s +} + +// NewServer starts and returns a new [Server] with the specified options. +func NewServer(tb testing.TB, opts ...Option) *Server { + tb.Helper() + server := NewUnstartedServer(tb, opts...) + server.Start() + return server +} + +// Start starts the server from [NewUnstartedServer]. +func (s *Server) Start() { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if !s.started && s.httpServer != nil { + s.httpServer.Start() + s.started = true + } +} + +// Backend returns the underlying [ipnlocal.LocalBackend]. +func (s *Server) Backend() *ipnlocal.LocalBackend { + s.tb.Helper() + return s.lb +} + +// Client returns a new [Client] configured for making requests to the server +// as a new [ipnauth.TestActor] with a unique username and [ipnauth.ClientID]. +func (s *Server) Client() *Client { + s.tb.Helper() + user := s.MakeTestActor("", "") // generate a unique username and client ID + return s.ClientFor(user) +} + +// ClientWithName returns a new [Client] configured for making requests to the server +// as a new [ipnauth.TestActor] with the specified name and a unique [ipnauth.ClientID]. +func (s *Server) ClientWithName(name string) *Client { + s.tb.Helper() + user := s.MakeTestActor(name, "") // generate a unique client ID + return s.ClientFor(user) +} + +// ClientFor returns a new [Client] configured for making requests to the server +// as the specified actor. +func (s *Server) ClientFor(actor ipnauth.Actor) *Client { + s.tb.Helper() + client := &Client{ + tb: s.tb, + Actor: actor, + } + client.Client = &local.Client{Transport: newRoundTripper(client, s.httpServer)} + return client +} + +// MakeTestActor returns a new [ipnauth.TestActor] with the specified name and client ID. +// If the name is empty, a unique sequential name is generated. Likewise, +// if clientID is empty, a unique sequential client ID is generated. +func (s *Server) MakeTestActor(name string, clientID string) *ipnauth.TestActor { + s.tb.Helper() + + s.mu.Lock() + defer s.mu.Unlock() + + // Generate a unique sequential name if the provided name is empty. + if name == "" { + n := len(s.actorsByName) + name = generateSequentialName("User", n) + } + + if clientID == "" { + s.lastClientID += 1 + clientID = fmt.Sprintf("Client-%d", s.lastClientID) + } + + // Create a new base actor if one doesn't already exist for the given name. + baseActor := s.actorsByName[name] + if baseActor == nil { + baseActor = &ipnauth.TestActor{Name: name} + if envknob.GOOS() == "windows" { + // Historically, as of 2025-04-15, IPN does not distinguish between + // different users on non-Windows devices. Therefore, the UID, which is + // an [ipn.WindowsUserID], should only be populated when the actual or + // fake GOOS is Windows. + baseActor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actorsByName))) + } + mak.Set(&s.actorsByName, name, baseActor) + s.tb.Cleanup(func() { delete(s.actorsByName, name) }) + } + + // Create a shallow copy of the base actor and assign it the new client ID. + actor := ptr.To(*baseActor) + actor.CID = ipnauth.ClientIDFrom(clientID) + return actor +} + +// BlockWhileInUse blocks until the server becomes idle (no active requests), +// or the context is done. It returns the context's error if it is done. +// It is used in tests only. +func (s *Server) BlockWhileInUse(ctx context.Context) error { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer == nil { + return nil + } + return s.ipnServer.BlockWhileInUseForTest(ctx) +} + +// BlockWhileInUseByOther blocks while the specified actor can't connect to the server +// due to another actor being connected. +// It is used in tests only. +func (s *Server) BlockWhileInUseByOther(ctx context.Context, actor ipnauth.Actor) error { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer == nil { + return nil + } + return s.ipnServer.BlockWhileInUseByOtherForTest(ctx, actor) +} + +// CheckCurrentUser fails the test if the current user does not match the expected user. +// It is only used on Windows and will be removed as we progress on tailscale/corp#18342. +func (s *Server) CheckCurrentUser(want ipnauth.Actor) { + s.tb.Helper() + var wantUID ipn.WindowsUserID + if want != nil { + wantUID = want.UserID() + } + lb := s.Backend() + if lb == nil { + s.tb.Fatalf("Backend: nil") + } + gotUID, gotActor := lb.CurrentUserForTest() + if gotUID != wantUID { + s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID) + } + if hasActor := gotActor != nil; hasActor != (want != nil) || (want != nil && gotActor != want) { + s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want) + } +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { + actor, err := getActorForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + s.tb.Errorf("getActorForRequest: %v", err) + return + } + ctx := ipnserver.NewContextWithActorForTest(r.Context(), actor) + s.ipnServer.ServeHTTPForTest(w, r.Clone(ctx)) +} + +// Close shuts down the server and blocks until all outstanding requests on this server have completed. +func (s *Server) Close() { + s.tb.Helper() + s.mu.Lock() + server := s.httpServer + s.httpServer = nil + s.mu.Unlock() + + if server != nil { + server.Close() + } + s.cancelCtx() +} + +// newUnstartedIPNServer returns a new [ipnserver.Server] that exposes +// the specified [ipnlocal.LocalBackend] via LocalAPI, but does not start it. +// The opts carry additional configuration options. +func newUnstartedIPNServer(opts *options) *ipnserver.Server { + opts.TB().Helper() + lb := opts.Backend() + server := ipnserver.New(opts.Logf(), logid.PublicID{}, lb.NetMon()) + server.SetLocalBackend(lb) + return server +} + +// roundTripper is a [http.RoundTripper] that sends requests to a [Server] +// on behalf of the [Client] who owns it. +type roundTripper struct { + client *Client + transport http.RoundTripper +} + +// newRoundTripper returns a new [http.RoundTripper] that sends requests +// to the specified server as the specified client. +func newRoundTripper(client *Client, server *httptest.Server) http.RoundTripper { + return &roundTripper{ + client: client, + transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var std net.Dialer + return std.DialContext(ctx, network, server.Listener.Addr().(*net.TCPAddr).String()) + }}, + } +} + +// requestIDHeaderName is the name of the header used to pass request IDs +// between the client and server. It is used to associate requests with their actors. +const requestIDHeaderName = "TS-Request-ID" + +// RoundTrip implements [http.RoundTripper] by sending the request to the [ipnserver.Server] +// on behalf of the owning [Client]. It registers each request for the duration +// of the call and associates it with the actor sending the request. +func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + reqID, unregister := registerRequest(rt.client.Actor) + defer unregister() + r = r.Clone(r.Context()) + r.Header.Set(requestIDHeaderName, reqID) + return rt.transport.RoundTrip(r) +} + +// getActorForRequest returns the actor for a given request. +// It returns an error if the request is not associated with an actor, +// such as when it wasn't sent by a [roundTripper]. +func getActorForRequest(r *http.Request) (ipnauth.Actor, error) { + reqID := r.Header.Get(requestIDHeaderName) + if reqID == "" { + return nil, fmt.Errorf("missing %s header", requestIDHeaderName) + } + actor, ok := getActorByRequestID(reqID) + if !ok { + return nil, fmt.Errorf("unknown request: %s", reqID) + } + return actor, nil +} + +var ( + inFlightRequestsMu sync.Mutex + inFlightRequests map[string]ipnauth.Actor +) + +// registerRequest associates a request with the specified actor and returns a unique request ID +// which can be used to retrieve the actor later. The returned function unregisters the request. +func registerRequest(actor ipnauth.Actor) (requestID string, unregister func()) { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + for { + requestID = rands.HexString(16) + if _, ok := inFlightRequests[requestID]; !ok { + break + } + } + mak.Set(&inFlightRequests, requestID, actor) + return requestID, func() { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + delete(inFlightRequests, requestID) + } +} + +// getActorByRequestID returns the actor associated with the specified request ID. +// It returns the actor and true if found, or nil and false if not. +func getActorByRequestID(requestID string) (ipnauth.Actor, bool) { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + actor, ok := inFlightRequests[requestID] + return actor, ok +}