tailscale/control/controlclient/controlclient_test.go
James Tucker a96ef432cf control/controlclient,ipn/ipnlocal: replace State enum with boolean flags
Remove the State enum (StateNew, StateNotAuthenticated, etc.) from
controlclient and replace it with two explicit boolean fields:
- LoginFinished: indicates successful authentication
- Synced: indicates we've received at least one netmap

This makes the state more composable and easier to reason about, as
multiple conditions can be true independently rather than being
encoded in a single enum value.

The State enum was originally intended as the state machine for the
whole client, but that abstraction moved to ipn.Backend long ago.
This change continues moving away from the legacy state machine by
representing state as a combination of independent facts.

Also adds test helpers in ipnlocal that check independent, observable
facts (hasValidNetMap, needsLogin, etc.) rather than relying on
derived state enums, making tests more robust.

Updates #12639

Signed-off-by: James Tucker <james@tailscale.com>
2025-11-14 16:18:34 -08:00

437 lines
11 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package controlclient
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"reflect"
"sync/atomic"
"testing"
"time"
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/net/bakedroots"
"tailscale.com/net/connectproxy"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/tstest/tlstest"
"tailscale.com/tstime"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/persist"
"tailscale.com/util/eventbus/eventbustest"
)
// fieldsOf returns the names of the fields of struct type t, in declaration
// order, leaving out blank ("_") fields. The result is nil when no field
// qualifies. t must be a struct type, since reflect's NumField panics for
// any other kind.
func fieldsOf(t reflect.Type) (fields []string) {
	for idx, total := 0, t.NumField(); idx < total; idx++ {
		fieldName := t.Field(idx).Name
		if fieldName == "_" {
			// Blank fields are padding/markers, not real state.
			continue
		}
		fields = append(fields, fieldName)
	}
	return fields
}
// TestStatusEqual checks Status.Equal against a small table of nil,
// zero-valued and differing Status pairs. It also fails if the field list of
// Status no longer matches the list recorded here, as a reminder to revisit
// Equal whenever the struct changes.
func TestStatusEqual(t *testing.T) {
	// Drift guard: the names and order below must mirror the Status struct.
	knownFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
	statusFields := fieldsOf(reflect.TypeFor[Status]())
	if !reflect.DeepEqual(statusFields, knownFields) {
		t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
			statusFields, knownFields)
	}

	type equalCase struct {
		a, b *Status
		want bool
	}
	cases := []equalCase{
		{a: &Status{}, b: nil, want: false},
		{a: nil, b: &Status{}, want: false},
		{a: nil, b: nil, want: true},
		{a: &Status{}, b: &Status{}, want: true},
		{
			a:    &Status{},
			b:    &Status{LoggedIn: true, Persist: new(persist.Persist).View()},
			want: false,
		},
	}
	for idx := range cases {
		tc := cases[idx]
		if got := tc.a.Equal(tc.b); got != tc.want {
			t.Errorf("%d. Equal = %v; want %v", idx, got, tc.want)
		}
	}
}
// TestCanSkipStatus exercises [canSkipStatus] with a table of Status pairs.
// The only pair expected to be skippable is the one where s1 and s2 each
// carry nothing but a NetMap; every other combination must not be skipped.
// It also fails if the field list of Status no longer matches the list this
// table was written against.
func TestCanSkipStatus(t *testing.T) {
	zero := new(Status)
	firstMap := &netmap.NetworkMap{}
	secondMap := &netmap.NetworkMap{}

	type skipCase struct {
		name   string
		s1, s2 *Status
		want   bool
	}
	cases := []skipCase{
		{name: "nil-s2", s1: zero, s2: nil, want: false},
		{name: "equal", s1: zero, s2: zero, want: false},
		{
			name: "s1-error",
			s1:   &Status{Err: io.EOF, NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-url",
			s1:   &Status{URL: "foo", NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-persist-diff",
			s1:   &Status{Persist: new(persist.Persist).View(), NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-login-finished-diff",
			s1:   &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-login-finished",
			s1:   &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-synced-diff",
			s1:   &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-no-netmap1",
			s1:   &Status{NetMap: nil},
			s2:   &Status{NetMap: secondMap},
			want: false,
		},
		{
			name: "s1-no-netmap2",
			s1:   &Status{NetMap: firstMap},
			s2:   &Status{NetMap: nil},
			want: false,
		},
		{
			name: "skip",
			s1:   &Status{NetMap: firstMap},
			s2:   &Status{NetMap: secondMap},
			want: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := canSkipStatus(tc.s1, tc.s2)
			if got == tc.want {
				return
			}
			t.Errorf("canSkipStatus = %v, want %v", got, tc.want)
		})
	}

	// Drift guard: a new Status field needs a matching case in the table above.
	handled := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
	actual := fieldsOf(reflect.TypeFor[Status]())
	if !reflect.DeepEqual(actual, handled) {
		t.Errorf("Status fields = %q; this code was only written to handle fields %q", actual, handled)
	}
}
// TestRetryableErrors checks which errors are classified as retryable by
// isRetryableErrorForTest. The table covers the package's sentinel errors,
// wrapped variants of them, and HTTP response errors: status 429, 500, 502,
// 503 and 504 are expected to be retryable, while the arbitrary status 1234
// is not. Each case runs as a subtest named after the error's message.
func TestRetryableErrors(t *testing.T) {
	errorTests := []struct {
		err  error
		want bool
	}{
		{errNoNoiseClient, true},
		{errNoNodeKey, true},
		{fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true},
		{fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true},
		{fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true},
		{errBadHTTPResponse(429, "too may requests"), true},
		{errBadHTTPResponse(500, "internal server eror"), true},
		{errBadHTTPResponse(502, "bad gateway"), true},
		{errBadHTTPResponse(503, "service unavailable"), true},
		{errBadHTTPResponse(504, "gateway timeout"), true},
		{errBadHTTPResponse(1234, "random error"), false},
	}
	for _, tt := range errorTests {
		t.Run(tt.err.Error(), func(t *testing.T) {
			// Report the classification that was actually observed as "got";
			// the input error is shown separately so failures are readable.
			if got := isRetryableErrorForTest(tt.err); got != tt.want {
				t.Fatalf("retryable(%v) = %v, want %v", tt.err, got, tt.want)
			}
		})
	}
}
// retryableForTest is the method set the tests look for on an error to learn
// whether it reports itself as retryable.
type retryableForTest interface {
	Retryable() bool
}

// isRetryableErrorForTest reports whether err, or an error it wraps,
// implements retryableForTest and returns true from Retryable. Errors that
// do not expose Retryable anywhere in their chain (including nil) yield false.
func isRetryableErrorForTest(err error) bool {
	var target retryableForTest
	found := errors.As(err, &target)
	return found && target.Retryable()
}
// liveNetworkTest is the -live-network-test command-line flag (default
// false). TestDirectProxyManual skips itself unless it is set.
var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
// TestDirectProxyManual is a manually-run test: it builds a Direct client
// pointed at https://controlplane.tailscale.com, attempts an ephemeral login
// with a 5 second timeout, and logs the URL that TryLogin returns. It is
// skipped unless --live-network-test is passed.
func TestDirectProxyManual(t *testing.T) {
	if enabled := *liveNetworkTest; !enabled {
		t.Skip("skipping without --live-network-test")
	}

	bus := eventbustest.NewBus(t)

	netDialer := new(tsdial.Dialer)
	netDialer.SetNetMon(netmon.NewStatic())
	netDialer.SetBus(bus)

	direct, err := NewDirect(Options{
		Persist: persist.Persist{},
		// A fresh machine key is generated on every call.
		GetMachinePrivateKey: func() (key.MachinePrivate, error) {
			return key.NewMachine(), nil
		},
		ServerURL: "https://controlplane.tailscale.com",
		Clock:     tstime.StdClock{},
		Hostinfo: &tailcfg.Hostinfo{
			BackendLogID: "test-backend-log-id",
		},
		DiscoPublicKey: key.NewDisco().Public(),
		Logf:           t.Logf,
		HealthTracker:  health.NewTracker(bus),
		PopBrowserURL: func(browserURL string) {
			t.Logf("PopBrowserURL: %q", browserURL)
		},
		Dialer:       netDialer,
		ControlKnobs: &controlknobs.Knobs{},
		Bus:          bus,
	})
	if err != nil {
		t.Fatalf("NewDirect: %v", err)
	}

	const loginTimeout = 5 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), loginTimeout)
	defer cancel()

	loginURL, err := direct.TryLogin(ctx, LoginEphemeral)
	if err != nil {
		t.Fatalf("TryLogin: %v", err)
	}
	t.Logf("URL: %q", loginURL)
}
// TestHTTPSNoProxy runs the shared testHTTPS scenario with the proxy
// disabled, i.e. connecting to the test control plane over HTTPS directly.
func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
// TestHTTPSWithProxy verifies we can connect to the control plane via
// an HTTPS proxy.
func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
// testHTTPS performs an auth-key login against an in-process testcontrol
// server reached over TLS, either directly or, when withProxy is true,
// through an in-process HTTPS CONNECT proxy that requires basic auth.
//
// Nothing leaves the machine: DNS lookups for controlplane.tstest and
// proxy.tstest are answered by a test hook with fake IPs, and the system
// dialer hook maps those fake IPs onto the two local TLS listeners. The test
// expects TryLogin to succeed with an empty URL and, with the proxy enabled,
// exactly one successful proxied dial to have been counted.
func testHTTPS(t *testing.T, withProxy bool) {
	const (
		requiredAuthKey = "hunter2"
		someUsername    = "testuser"
		somePassword    = "testpass"
		fakeControlIP   = "1.2.3.4"
		fakeProxyIP     = "5.6.7.8"
	)

	bakedroots.ResetForTest(t, tlstest.TestRootCA())
	bus := eventbustest.NewBus(t)

	// Two loopback TLS listeners: one for the control server, one for the proxy.
	controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
	if err != nil {
		t.Fatal(err)
	}
	defer controlLn.Close()

	proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig())
	if err != nil {
		t.Fatal(err)
	}
	defer proxyLn.Close()

	controlSrv := &http.Server{
		Handler: &testcontrol.Server{
			Logf:           tstest.WhileTestRunningLogger(t),
			RequireAuthKey: requiredAuthKey,
		},
		ErrorLog: logger.StdLogger(t.Logf),
	}
	go controlSrv.Serve(controlLn)

	netDialer := new(tsdial.Dialer)
	netDialer.SetNetMon(netmon.NewStatic())
	netDialer.SetBus(bus)
	// Redirect dials to the fake IPs onto the local listeners; any other
	// destination is an error.
	netDialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
		}
		var backend string
		switch host {
		case fakeControlIP:
			backend = controlLn.Addr().String()
		case fakeProxyIP:
			backend = proxyLn.Addr().String()
		default:
			return nil, fmt.Errorf("unexpected dial to %q", addr)
		}
		var nd net.Dialer
		return nd.DialContext(ctx, network, backend)
	})

	direct, err := NewDirect(Options{
		Persist: persist.Persist{},
		GetMachinePrivateKey: func() (key.MachinePrivate, error) {
			return key.NewMachine(), nil
		},
		AuthKey:   requiredAuthKey,
		ServerURL: "https://controlplane.tstest",
		Clock:     tstime.StdClock{},
		Hostinfo: &tailcfg.Hostinfo{
			BackendLogID: "test-backend-log-id",
		},
		DiscoPublicKey: key.NewDisco().Public(),
		Logf:           t.Logf,
		HealthTracker:  health.NewTracker(bus),
		PopBrowserURL: func(browserURL string) {
			t.Logf("PopBrowserURL: %q", browserURL)
		},
		Dialer: netDialer,
		Bus:    bus,
	})
	if err != nil {
		t.Fatalf("NewDirect: %v", err)
	}

	// Answer DNS from a fixed table. The proxy name may only be looked up
	// when the proxy is in use.
	direct.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
		if host == "controlplane.tstest" {
			return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
		}
		if host == "proxy.tstest" {
			if withProxy {
				return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
			}
			t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
			return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
		}
		t.Errorf("unexpected DNS query for %q", host)
		return []netip.Addr{}, nil
	}

	// proxyReqs counts successful backend dials made by the CONNECT proxy.
	var proxyReqs atomic.Int64
	if withProxy {
		transport := direct.httpc.Transport.(*http.Transport)
		transport.Proxy = func(req *http.Request) (*url.URL, error) {
			t.Logf("using proxy for %q", req.URL)
			return &url.URL{
				Scheme: "https",
				Host:   "proxy.tstest:443",
				User:   url.UserPassword(someUsername, somePassword),
			}, nil
		}
		proxySrv := &http.Server{
			Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
		}
		go proxySrv.Serve(proxyLn)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	loginURL, err := direct.TryLogin(ctx, LoginEphemeral)
	if err != nil {
		t.Fatalf("TryLogin: %v", err)
	}
	// The login uses an auth key, so no URL is expected back.
	if loginURL != "" {
		t.Errorf("got URL %q, want empty", loginURL)
	}

	if !withProxy {
		return
	}
	const wantProxyReqs int64 = 1
	if got := proxyReqs.Load(); got != wantProxyReqs {
		t.Errorf("proxy CONNECT requests = %d; want %d", got, wantProxyReqs)
	}
}
// connectProxyTo returns an HTTP handler that acts as a CONNECT proxy for
// tests. A request is accepted only if its RequestURI equals target and its
// Proxy-Authorization header carries the basic credentials testuser/testpass;
// otherwise the mismatch is reported through t.Errorf and the request is
// answered with 400 (bad target) or 401 (bad auth).
//
// Accepted requests are served by a connectproxy.Handler whose dialer ignores
// the requested address and always connects to backendAddrPort. Every
// successful dial increments reqs.
func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
	const (
		wantUser = "testuser"
		wantPass = "testpass"
	)

	// dialBackend connects to the fixed backend and counts successes.
	dialBackend := func(ctx context.Context, network, addr string) (net.Conn, error) {
		var nd net.Dialer
		conn, err := nd.DialContext(ctx, network, backendAddrPort)
		if err != nil {
			return conn, err
		}
		reqs.Add(1)
		return conn, nil
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.RequestURI; got != target {
			t.Errorf("invalid CONNECT request to %q; want %q", got, target)
			http.Error(w, "bad target", http.StatusBadRequest)
			return
		}

		// Request.BasicAuth only reads the Authorization header, so copy the
		// proxy credentials over to reuse its parser.
		r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization"))
		user, pass, ok := r.BasicAuth()
		if authorized := ok && user == wantUser && pass == wantPass; !authorized {
			t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, wantUser, wantPass)
			http.Error(w, "bad auth", http.StatusUnauthorized)
			return
		}

		proxy := &connectproxy.Handler{
			Dial: dialBackend,
			Logf: t.Logf,
		}
		proxy.ServeHTTP(w, r)
	})
}