diff --git a/build_dist.sh b/build_dist.sh index f11d4aae2..fed37c264 100755 --- a/build_dist.sh +++ b/build_dist.sh @@ -41,7 +41,7 @@ while [ "$#" -gt 1 ]; do fi shift ldflags="$ldflags -w -s" - tags="${tags:+$tags,}ts_omit_aws,ts_omit_bird,ts_omit_tap,ts_omit_kube,ts_omit_completion,ts_omit_ssh,ts_omit_wakeonlan,ts_omit_capture,ts_omit_relayserver,ts_omit_taildrop" + tags="${tags:+$tags,}ts_omit_aws,ts_omit_bird,ts_omit_tap,ts_omit_kube,ts_omit_completion,ts_omit_ssh,ts_omit_wakeonlan,ts_omit_capture,ts_omit_relayserver,ts_omit_taildrop,ts_omit_tpm" ;; --box) if [ ! -z "${TAGS:-}" ]; then diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 4cc4a8d46..544fe9089 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -135,6 +135,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/google/go-cmp/cmp/internal/flags from github.com/google/go-cmp/cmp+ github.com/google/go-cmp/cmp/internal/function from github.com/google/go-cmp/cmp 💣 github.com/google/go-cmp/cmp/internal/value from github.com/google/go-cmp/cmp + github.com/google/go-tpm/legacy/tpm2 from github.com/google/go-tpm/tpm2/transport+ + github.com/google/go-tpm/tpm2 from tailscale.com/feature/tpm + github.com/google/go-tpm/tpm2/transport from github.com/google/go-tpm/tpm2/transport/linuxtpm+ + L github.com/google/go-tpm/tpm2/transport/linuxtpm from tailscale.com/feature/tpm + W github.com/google/go-tpm/tpm2/transport/windowstpm from tailscale.com/feature/tpm + github.com/google/go-tpm/tpmutil from github.com/google/go-tpm/legacy/tpm2+ + W 💣 github.com/google/go-tpm/tpmutil/tbs from github.com/google/go-tpm/legacy/tpm2+ github.com/google/gofuzz from k8s.io/apimachinery/pkg/apis/meta/v1+ github.com/google/gofuzz/bytesource from github.com/google/gofuzz L github.com/google/nftables from tailscale.com/util/linuxfw @@ -813,6 +820,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/feature/relayserver from tailscale.com/feature/condregister tailscale.com/feature/taildrop from tailscale.com/feature/condregister L tailscale.com/feature/tap from tailscale.com/feature/condregister + tailscale.com/feature/tpm from tailscale.com/feature/condregister tailscale.com/feature/wakeonlan from tailscale.com/feature/condregister tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal @@ -832,9 +840,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/ipn/store/kubestore from tailscale.com/cmd/k8s-operator+ tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ tailscale.com/k8s-operator from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/api-proxy from tailscale.com/cmd/k8s-operator tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1 tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+ - tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/sessionrecording from tailscale.com/k8s-operator/api-proxy tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording @@ -937,7 +946,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/clientmetric from tailscale.com/cmd/k8s-operator+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ tailscale.com/util/cmpver from tailscale.com/clientupdate+ - tailscale.com/util/ctxkey from tailscale.com/cmd/k8s-operator+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ 💣 tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics+ tailscale.com/util/dnsname from tailscale.com/appc+ diff --git a/cmd/tailscale/cli/set.go b/cmd/tailscale/cli/set.go index 37db252ad..f4ea674ec 100644 --- a/cmd/tailscale/cli/set.go +++ b/cmd/tailscale/cli/set.go @@ -83,7 +83,7 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { setf.BoolVar(&setArgs.advertiseConnector, "advertise-connector", false, "offer to be an app connector for domain specific internet traffic for the tailnet") setf.BoolVar(&setArgs.updateCheck, "update-check", true, "notify about available Tailscale updates") setf.BoolVar(&setArgs.updateApply, "auto-update", false, "automatically update to the latest available version") - setf.BoolVar(&setArgs.postureChecking, "posture-checking", false, hidden+"allow management plane to gather device posture information") + setf.BoolVar(&setArgs.postureChecking, "posture-checking", false, "allow management plane to gather device posture information") setf.BoolVar(&setArgs.runWebClient, "webclient", false, "expose the web interface for managing this node over Tailscale at port 5252") setf.StringVar(&setArgs.relayServerPort, "relay-server-port", "", hidden+"UDP port number (0 will pick a random unused port) for the relay server to bind to, on all interfaces, or empty string to disable relay server functionality") diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 329c00e93..c5d5a7b2d 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -109,6 +109,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L 💣 github.com/godbus/dbus/v5 from tailscale.com/net/dns+ github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + github.com/google/go-tpm/legacy/tpm2 from github.com/google/go-tpm/tpm2/transport+ + github.com/google/go-tpm/tpm2 from tailscale.com/feature/tpm + github.com/google/go-tpm/tpm2/transport from github.com/google/go-tpm/tpm2/transport/linuxtpm+ + L github.com/google/go-tpm/tpm2/transport/linuxtpm from tailscale.com/feature/tpm + W github.com/google/go-tpm/tpm2/transport/windowstpm from tailscale.com/feature/tpm + github.com/google/go-tpm/tpmutil from github.com/google/go-tpm/legacy/tpm2+ + W 💣 github.com/google/go-tpm/tpmutil/tbs from github.com/google/go-tpm/legacy/tpm2+ L github.com/google/nftables from tailscale.com/util/linuxfw L 💣 github.com/google/nftables/alignedbuff from github.com/google/nftables/xt L 💣 github.com/google/nftables/binaryutil from github.com/google/nftables+ @@ -271,6 +278,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/feature/relayserver from tailscale.com/feature/condregister tailscale.com/feature/taildrop from tailscale.com/feature/condregister L tailscale.com/feature/tap from tailscale.com/feature/condregister + tailscale.com/feature/tpm from tailscale.com/feature/condregister tailscale.com/feature/wakeonlan from tailscale.com/feature/condregister tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 1c5236123..4b0dc95f9 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -573,7 +573,7 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID if ms, ok := sys.MagicSock.GetOK(); ok { debugMux.HandleFunc("/debug/magicsock", ms.ServeHTTPDebug) } - go runDebugServer(debugMux, args.debug) + go runDebugServer(logf, debugMux, args.debug) } ns, err := newNetstack(logf, sys) @@ -819,12 +819,20 @@ func servePrometheusMetrics(w http.ResponseWriter, r *http.Request) { clientmetric.WritePrometheusExpositionFormat(w) } -func runDebugServer(mux *http.ServeMux, addr string) { +func runDebugServer(logf logger.Logf, mux *http.ServeMux, addr string) { + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("debug server: %v", err) + } + if strings.HasSuffix(addr, ":0") { + // Log kernel-selected port number so integration tests + // can find it portably. + logf("DEBUG-ADDR=%v", ln.Addr()) + } srv := &http.Server{ - Addr: addr, Handler: mux, } - if err := srv.ListenAndServe(); err != nil { + if err := srv.Serve(ln); err != nil { log.Fatal(err) } } diff --git a/cmd/tsidp/README.md b/cmd/tsidp/README.md index 29ce089df..61a81e8ae 100644 --- a/cmd/tsidp/README.md +++ b/cmd/tsidp/README.md @@ -35,7 +35,7 @@ ```bash docker run -d \ - --name `tsidp` \ + --name tsidp \ -p 443:443 \ -e TS_AUTHKEY=YOUR_TAILSCALE_AUTHKEY \ -e TS_HOSTNAME=idp \ diff --git a/feature/condregister/maybe_tpm.go b/feature/condregister/maybe_tpm.go new file mode 100644 index 000000000..caa57fef1 --- /dev/null +++ b/feature/condregister/maybe_tpm.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_tpm + +package condregister + +import _ "tailscale.com/feature/tpm" diff --git a/feature/taildrop/integration_test.go b/feature/taildrop/integration_test.go new file mode 100644 index 000000000..46768bb31 --- /dev/null +++ b/feature/taildrop/integration_test.go @@ -0,0 +1,170 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "testing" + "time" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" +) + +// TODO(bradfitz): add test where control doesn't send tailcfg.CapabilityFileSharing +// and verify that we get the "file sharing not enabled by Tailscale admin" error. + +// TODO(bradfitz): add test between different users with the peercap to permit that? + +func TestTaildropIntegration(t *testing.T) { + tstest.Parallel(t) + controlOpt := integration.ConfigureControl(func(s *testcontrol.Server) { + s.AllNodesSameUser = true // required for Taildrop + }) + env := integration.NewTestEnv(t, controlOpt) + + // Create two nodes: + n1 := integration.NewTestNode(t, env) + d1 := n1.StartDaemon() + + n2 := integration.NewTestNode(t, env) + d2 := n2.StartDaemon() + + n1.AwaitListening() + t.Logf("n1 is listening") + n2.AwaitListening() + t.Logf("n2 is listening") + n1.MustUp() + t.Logf("n1 is up") + n2.MustUp() + t.Logf("n2 is up") + n1.AwaitRunning() + t.Logf("n1 is running") + n2.AwaitRunning() + t.Logf("n2 is running") + + var peerStableID tailcfg.StableNodeID + + if err := tstest.WaitFor(5*time.Second, func() error { + st := n1.MustStatus() + if len(st.Peer) == 0 { + return errors.New("no peers") + } + if len(st.Peer) > 1 { + return fmt.Errorf("got %d peers; want 1", len(st.Peer)) + } + peer := st.Peer[st.Peers()[0]] + peerStableID = peer.ID + if peer.ID == st.Self.ID { + return errors.New("peer is self") + } + + if len(st.TailscaleIPs) == 0 { + return errors.New("no Tailscale IPs") + } + + return nil + }); err != nil { + t.Fatal(err) + } + + const timeout = 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + c1 := n1.LocalClient() + c2 := n2.LocalClient() + + wantNoWaitingFiles := func(c *local.Client) { + t.Helper() + files, err := c.WaitingFiles(ctx) + if err != nil { + t.Fatalf("WaitingFiles: %v", err) + } + if len(files) != 0 { + t.Fatalf("WaitingFiles: got %d files; want 0", len(files)) + } + } + + // Verify c2 has no files. + wantNoWaitingFiles(c2) + + gotFile := make(chan bool, 1) + go func() { + v, err := c2.AwaitWaitingFiles(t.Context(), timeout) + if err != nil { + return + } + if len(v) != 0 { + gotFile <- true + } + }() + + fileContents := []byte("hello world this is a file") + + n2ID := n2.MustStatus().Self.ID + t.Logf("n2 self.ID = %q; n1's peer[0].ID = %q", n2ID, peerStableID) + t.Logf("Doing PushFile ...") + err := c1.PushFile(ctx, n2.MustStatus().Self.ID, int64(len(fileContents)), "test.txt", bytes.NewReader(fileContents)) + if err != nil { + t.Fatalf("PushFile from n1->n2: %v", err) + } + t.Logf("PushFile done") + + select { + case <-gotFile: + t.Logf("n2 saw AwaitWaitingFiles wake up") + case <-ctx.Done(): + t.Fatalf("n2 timeout waiting for AwaitWaitingFiles") + } + + files, err := c2.WaitingFiles(ctx) + if err != nil { + t.Fatalf("c2.WaitingFiles: %v", err) + } + if len(files) != 1 { + t.Fatalf("c2.WaitingFiles: got %d files; want 1", len(files)) + } + got := files[0] + want := apitype.WaitingFile{ + Name: "test.txt", + Size: int64(len(fileContents)), + } + if got != want { + t.Fatalf("c2.WaitingFiles: got %+v; want %+v", got, want) + } + + // Download the file. + rc, size, err := c2.GetWaitingFile(ctx, got.Name) + if err != nil { + t.Fatalf("c2.GetWaitingFile: %v", err) + } + if size != int64(len(fileContents)) { + t.Fatalf("c2.GetWaitingFile: got size %d; want %d", size, len(fileContents)) + } + gotBytes, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("c2.GetWaitingFile: %v", err) + } + if !bytes.Equal(gotBytes, fileContents) { + t.Fatalf("c2.GetWaitingFile: got %q; want %q", gotBytes, fileContents) + } + + // Now delete it. + if err := c2.DeleteWaitingFile(ctx, got.Name); err != nil { + t.Fatalf("c2.DeleteWaitingFile: %v", err) + } + wantNoWaitingFiles(c2) + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) +} diff --git a/feature/taildrop/localapi.go b/feature/taildrop/localapi.go index ce812514e..067a51f91 100644 --- a/feature/taildrop/localapi.go +++ b/feature/taildrop/localapi.go @@ -365,6 +365,7 @@ func serveFiles(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { return } ctx := r.Context() + var wfs []apitype.WaitingFile if s := r.FormValue("waitsec"); s != "" && s != "0" { d, err := strconv.Atoi(s) if err != nil { @@ -375,11 +376,18 @@ func serveFiles(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { var cancel context.CancelFunc ctx, cancel = context.WithDeadline(ctx, deadline) defer cancel() - } - wfs, err := lb.AwaitWaitingFiles(ctx) - if err != nil && ctx.Err() == nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + wfs, err = lb.AwaitWaitingFiles(ctx) + if err != nil && ctx.Err() == nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } else { + var err error + wfs, err = lb.WaitingFiles() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(wfs) diff --git a/feature/tpm/tpm.go b/feature/tpm/tpm.go new file mode 100644 index 000000000..18e56ae89 --- /dev/null +++ b/feature/tpm/tpm.go @@ -0,0 +1,83 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tpm implements support for TPM 2.0 devices. +package tpm + +import ( + "slices" + "sync" + + "github.com/google/go-tpm/tpm2" + "github.com/google/go-tpm/tpm2/transport" + "tailscale.com/feature" + "tailscale.com/hostinfo" + "tailscale.com/tailcfg" +) + +var infoOnce = sync.OnceValue(info) + +func init() { + feature.Register("tpm") + hostinfo.RegisterHostinfoNewHook(func(hi *tailcfg.Hostinfo) { + hi.TPM = infoOnce() + }) +} + +//lint:ignore U1000 used in Linux and Windows builds only +func infoFromCapabilities(tpm transport.TPM) *tailcfg.TPMInfo { + info := new(tailcfg.TPMInfo) + toStr := func(s *string) func(*tailcfg.TPMInfo, uint32) { + return func(info *tailcfg.TPMInfo, value uint32) { + *s += propToString(value) + } + } + for _, cap := range []struct { + prop tpm2.TPMPT + apply func(info *tailcfg.TPMInfo, value uint32) + }{ + {tpm2.TPMPTManufacturer, toStr(&info.Manufacturer)}, + {tpm2.TPMPTVendorString1, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString2, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString3, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString4, toStr(&info.Vendor)}, + {tpm2.TPMPTRevision, func(info *tailcfg.TPMInfo, value uint32) { info.SpecRevision = int(value) }}, + {tpm2.TPMPTVendorTPMType, func(info *tailcfg.TPMInfo, value uint32) { info.Model = int(value) }}, + {tpm2.TPMPTFirmwareVersion1, func(info *tailcfg.TPMInfo, value uint32) { info.FirmwareVersion += uint64(value) << 32 }}, + {tpm2.TPMPTFirmwareVersion2, func(info *tailcfg.TPMInfo, value uint32) { info.FirmwareVersion += uint64(value) }}, + } { + resp, err := tpm2.GetCapability{ + Capability: tpm2.TPMCapTPMProperties, + Property: uint32(cap.prop), + PropertyCount: 1, + }.Execute(tpm) + if err != nil { + continue + } + props, err := resp.CapabilityData.Data.TPMProperties() + if err != nil { + continue + } + if len(props.TPMProperty) == 0 { + continue + } + cap.apply(info, props.TPMProperty[0].Value) + } + return info +} + +// propToString converts TPM_PT property value, which is a uint32, into a +// string of up to 4 ASCII characters. This encoding applies only to some +// properties, see +// https://trustedcomputinggroup.org/resource/tpm-library-specification/ Part +// 2, section 6.13. +func propToString(v uint32) string { + chars := []byte{ + byte(v >> 24), + byte(v >> 16), + byte(v >> 8), + byte(v), + } + // Delete any non-printable ASCII characters. + return string(slices.DeleteFunc(chars, func(b byte) bool { return b < ' ' || b > '~' })) +} diff --git a/feature/tpm/tpm_linux.go b/feature/tpm/tpm_linux.go new file mode 100644 index 000000000..a90c0e153 --- /dev/null +++ b/feature/tpm/tpm_linux.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "github.com/google/go-tpm/tpm2/transport/linuxtpm" + "tailscale.com/tailcfg" +) + +func info() *tailcfg.TPMInfo { + t, err := linuxtpm.Open("/dev/tpm0") + if err != nil { + return nil + } + defer t.Close() + return infoFromCapabilities(t) +} diff --git a/feature/tpm/tpm_other.go b/feature/tpm/tpm_other.go new file mode 100644 index 000000000..ba7c67621 --- /dev/null +++ b/feature/tpm/tpm_other.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows + +package tpm + +import "tailscale.com/tailcfg" + +func info() *tailcfg.TPMInfo { + return nil +} diff --git a/feature/tpm/tpm_test.go b/feature/tpm/tpm_test.go new file mode 100644 index 000000000..fc0fc178c --- /dev/null +++ b/feature/tpm/tpm_test.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import "testing" + +func TestPropToString(t *testing.T) { + for prop, want := range map[uint32]string{ + 0: "", + 0x4D534654: "MSFT", + 0x414D4400: "AMD", + 0x414D440D: "AMD", + } { + if got := propToString(prop); got != want { + t.Errorf("propToString(0x%x): got %q, want %q", prop, got, want) + } + } +} diff --git a/feature/tpm/tpm_windows.go b/feature/tpm/tpm_windows.go new file mode 100644 index 000000000..578d687af --- /dev/null +++ b/feature/tpm/tpm_windows.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "github.com/google/go-tpm/tpm2/transport/windowstpm" + "tailscale.com/tailcfg" +) + +func info() *tailcfg.TPMInfo { + t, err := windowstpm.Open() + if err != nil { + return nil + } + defer t.Close() + return infoFromCapabilities(t) +} diff --git a/go.mod b/go.mod index 0c1224cf1..f346b1e40 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/golangci/golangci-lint v1.57.1 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.2 + github.com/google/go-tpm v0.9.4 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 8c8da8d14..bdbae11bb 100644 --- a/go.sum +++ b/go.sum @@ -486,6 +486,10 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= +github.com/google/go-tpm v0.9.4 h1:awZRf9FwOeTunQmHoDYSHJps3ie6f1UlhS1fOdPEt1I= +github.com/google/go-tpm v0.9.4/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 95fe22641..b2998d11c 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -258,7 +258,7 @@ type LocalBackend struct { // We intend to relax this in the future and only require holding b.mu when replacing it, // but that requires a better (strictly ordered?) state machine and better management // of [LocalBackend]'s own state that is not tied to the node context. - currentNodeAtomic atomic.Pointer[localNodeContext] + currentNodeAtomic atomic.Pointer[nodeBackend] conf *conffile.Config // latest parsed config, or nil if not in declarative mode pm *profileManager // mu guards access @@ -519,7 +519,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo captiveCancel: nil, // so that we start checkCaptivePortalLoop when Running needsCaptiveDetection: make(chan bool), } - b.currentNodeAtomic.Store(newLocalNodeContext()) + b.currentNodeAtomic.Store(newNodeBackend()) mConn.SetNetInfoCallback(b.setNetInfo) if sys.InitialConfig != nil { @@ -594,12 +594,12 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo func (b *LocalBackend) Clock() tstime.Clock { return b.clock } func (b *LocalBackend) Sys() *tsd.System { return b.sys } -func (b *LocalBackend) currentNode() *localNodeContext { +func (b *LocalBackend) currentNode() *nodeBackend { if v := b.currentNodeAtomic.Load(); v != nil || !testenv.InTest() { return v } // Auto-init one in tests for LocalBackend created without the NewLocalBackend constructor... - v := newLocalNodeContext() + v := newNodeBackend() b.currentNodeAtomic.CompareAndSwap(nil, v) return b.currentNodeAtomic.Load() } @@ -1463,15 +1463,30 @@ func (b *LocalBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { return b.currentNode().PeerCaps(src) } -func (b *localNodeContext) AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView { - b.mu.Lock() - defer b.mu.Unlock() - ret := base - if b.netMap == nil { - return ret +// AppendMatchingPeers returns base with all peers that match pred appended. +// +// It acquires b.mu to read the netmap but releases it before calling pred. +func (nb *nodeBackend) AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView { + var peers []tailcfg.NodeView + + nb.mu.Lock() + if nb.netMap != nil { + // All fields on b.netMap are immutable, so this is + // safe to copy and use outside the lock. + peers = nb.netMap.Peers } - for _, peer := range b.netMap.Peers { - if pred(peer) { + nb.mu.Unlock() + + ret := base + for _, peer := range peers { + // The peers in b.netMap don't contain updates made via + // UpdateNetmapDelta. So only use PeerView in b.netMap for its NodeID, + // and then look up the latest copy in b.peers which is updated in + // response to UpdateNetmapDelta edits. + nb.mu.Lock() + peer, ok := nb.peers[peer.ID()] + nb.mu.Unlock() + if ok && pred(peer) { ret = append(ret, peer) } } @@ -1480,21 +1495,21 @@ func (b *localNodeContext) AppendMatchingPeers(base []tailcfg.NodeView, pred fun // PeerCaps returns the capabilities that remote src IP has to // ths current node. -func (b *localNodeContext) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { - b.mu.Lock() - defer b.mu.Unlock() - return b.peerCapsLocked(src) +func (nb *nodeBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.peerCapsLocked(src) } -func (b *localNodeContext) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { - if b.netMap == nil { +func (nb *nodeBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { + if nb.netMap == nil { return nil } - filt := b.filterAtomic.Load() + filt := nb.filterAtomic.Load() if filt == nil { return nil } - addrs := b.netMap.GetAddresses() + addrs := nb.netMap.GetAddresses() for i := range addrs.Len() { a := addrs.At(i) if !a.IsSingleIP() { @@ -1508,8 +1523,8 @@ func (b *localNodeContext) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { return nil } -func (b *localNodeContext) GetFilterForTest() *filter.Filter { - return b.filterAtomic.Load() +func (nb *nodeBackend) GetFilterForTest() *filter.Filter { + return nb.filterAtomic.Load() } // SetControlClientStatus is the callback invoked by the control client whenever it posts a new status. @@ -2019,14 +2034,14 @@ func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bo return true } -func (c *localNodeContext) netMapWithPeers() *netmap.NetworkMap { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil { +func (nb *nodeBackend) netMapWithPeers() *netmap.NetworkMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { return nil } - nm := ptr.To(*c.netMap) // shallow clone - nm.Peers = slicesx.MapValues(c.peers) + nm := ptr.To(*nb.netMap) // shallow clone + nm.Peers = slicesx.MapValues(nb.peers) slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) }) @@ -2063,10 +2078,10 @@ func (b *LocalBackend) pickNewAutoExitNode() { b.send(ipn.Notify{Prefs: &newPrefs}) } -func (c *localNodeContext) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil || len(c.peers) == 0 { +func (nb *nodeBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil || len(nb.peers) == 0 { return false } @@ -2078,7 +2093,7 @@ func (c *localNodeContext) UpdateNetmapDelta(muts []netmap.NodeMutation) (handle for _, m := range muts { n, ok := mutableNodes[m.NodeIDBeingMutated()] if !ok { - nv, ok := c.peers[m.NodeIDBeingMutated()] + nv, ok := nb.peers[m.NodeIDBeingMutated()] if !ok { // TODO(bradfitz): unexpected metric? return false @@ -2089,7 +2104,7 @@ func (c *localNodeContext) UpdateNetmapDelta(muts []netmap.NodeMutation) (handle m.Apply(n) } for nid, n := range mutableNodes { - c.peers[nid] = n.View() + nb.peers[nid] = n.View() } return true } @@ -2250,10 +2265,10 @@ func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { return b.currentNode().PeersForTest() } -func (b *localNodeContext) PeersForTest() []tailcfg.NodeView { - b.mu.Lock() - defer b.mu.Unlock() - ret := slicesx.MapValues(b.peers) +func (nb *nodeBackend) PeersForTest() []tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + ret := slicesx.MapValues(nb.peers) slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) }) @@ -2532,12 +2547,12 @@ var invalidPacketFilterWarnable = health.Register(&health.Warnable{ // b.mu must be held. func (b *LocalBackend) updateFilterLocked(prefs ipn.PrefsView) { // TODO(nickkhyl) split this into two functions: - // - (*localNodeContext).RebuildFilters() (normalFilter, jailedFilter *filter.Filter, changed bool), + // - (*nodeBackend).RebuildFilters() (normalFilter, jailedFilter *filter.Filter, changed bool), // which would return packet filters for the current state and whether they changed since the last call. // - (*LocalBackend).updateFilters(), which would use the above to update the engine with the new filters, // notify b.sshServer, etc. // - // For this, we would need to plumb a few more things into the [localNodeContext]. Most importantly, + // For this, we would need to plumb a few more things into the [nodeBackend]. Most importantly, // the current [ipn.PrefsView]), but also maybe also a b.logf and a b.health? // // NOTE(danderson): keep change detection as the first thing in @@ -2823,8 +2838,8 @@ func (b *LocalBackend) setFilter(f *filter.Filter) { b.e.SetFilter(f) } -func (c *localNodeContext) setFilter(f *filter.Filter) { - c.filterAtomic.Store(f) +func (nb *nodeBackend) setFilter(f *filter.Filter) { + nb.filterAtomic.Store(f) } var removeFromDefaultRoute = []netip.Prefix{ @@ -3886,7 +3901,7 @@ func (b *LocalBackend) parseWgStatusLocked(s *wgengine.Status) (ret ipn.EngineSt // in Hostinfo. When the user preferences currently request "shields up" // mode, all inbound connections are refused, so services are not reported. // Otherwise, shouldUploadServices respects NetMap.CollectServices. -// TODO(nickkhyl): move this into [localNodeContext]? +// TODO(nickkhyl): move this into [nodeBackend]? func (b *LocalBackend) shouldUploadServices() bool { b.mu.Lock() defer b.mu.Unlock() @@ -4758,10 +4773,10 @@ func (b *LocalBackend) NetMap() *netmap.NetworkMap { return b.currentNode().NetMap() } -func (c *localNodeContext) NetMap() *netmap.NetworkMap { - c.mu.Lock() - defer c.mu.Unlock() - return c.netMap +func (nb *nodeBackend) NetMap() *netmap.NetworkMap { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.netMap } func (b *LocalBackend) isEngineBlocked() bool { @@ -5003,10 +5018,10 @@ func shouldUseOneCGNATRoute(logf logger.Logf, mon *netmon.Monitor, controlKnobs return false } -func (c *localNodeContext) dnsConfigForNetmap(prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config { - c.mu.Lock() - defer c.mu.Unlock() - return dnsConfigForNetmap(c.netMap, c.peers, prefs, selfExpired, logf, versionOS) +func (nb *nodeBackend) dnsConfigForNetmap(prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config { + nb.mu.Lock() + defer nb.mu.Unlock() + return dnsConfigForNetmap(nb.netMap, nb.peers, prefs, selfExpired, logf, versionOS) } // dnsConfigForNetmap returns a *dns.Config for the given netmap, @@ -5041,6 +5056,8 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg. !nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs4) dcfg.OnlyIPv6 = selfV6Only + wantAAAA := nm.AllCaps.Contains(tailcfg.NodeAttrMagicDNSPeerAAAA) + // Populate MagicDNS records. We do this unconditionally so that // quad-100 can always respond to MagicDNS queries, even if the OS // isn't configured to make MagicDNS resolution truly @@ -5077,7 +5094,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg. // https://github.com/tailscale/tailscale/issues/1152 // tracks adding the right capability reporting to // enable AAAA in MagicDNS. - if addr.Addr().Is6() && have4 { + if addr.Addr().Is6() && have4 && !wantAAAA { continue } ips = append(ips, addr.Addr()) @@ -6129,12 +6146,12 @@ func (b *LocalBackend) setAutoExitNodeIDLockedOnEntry(unlock unlockOnce) (newPre return newPrefs } -func (c *localNodeContext) SetNetMap(nm *netmap.NetworkMap) { - c.mu.Lock() - defer c.mu.Unlock() - c.netMap = nm - c.updateNodeByAddrLocked() - c.updatePeersLocked() +func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { + nb.mu.Lock() + defer nb.mu.Unlock() + nb.netMap = nm + nb.updateNodeByAddrLocked() + nb.updatePeersLocked() } // setNetMapLocked updates the LocalBackend state to reflect the newly @@ -6209,25 +6226,25 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { b.driveNotifyCurrentSharesLocked() } -func (b *localNodeContext) updateNodeByAddrLocked() { - nm := b.netMap +func (nb *nodeBackend) updateNodeByAddrLocked() { + nm := nb.netMap if nm == nil { - b.nodeByAddr = nil + nb.nodeByAddr = nil return } // Update the nodeByAddr index. - if b.nodeByAddr == nil { - b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} + if nb.nodeByAddr == nil { + nb.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} } // First pass, mark everything unwanted. - for k := range b.nodeByAddr { - b.nodeByAddr[k] = 0 + for k := range nb.nodeByAddr { + nb.nodeByAddr[k] = 0 } addNode := func(n tailcfg.NodeView) { for _, ipp := range n.Addresses().All() { if ipp.IsSingleIP() { - b.nodeByAddr[ipp.Addr()] = n.ID() + nb.nodeByAddr[ipp.Addr()] = n.ID() } } } @@ -6238,34 +6255,34 @@ func (b *localNodeContext) updateNodeByAddrLocked() { addNode(p) } // Third pass, actually delete the unwanted items. - for k, v := range b.nodeByAddr { + for k, v := range nb.nodeByAddr { if v == 0 { - delete(b.nodeByAddr, k) + delete(nb.nodeByAddr, k) } } } -func (b *localNodeContext) updatePeersLocked() { - nm := b.netMap +func (nb *nodeBackend) updatePeersLocked() { + nm := nb.netMap if nm == nil { - b.peers = nil + nb.peers = nil return } // First pass, mark everything unwanted. - for k := range b.peers { - b.peers[k] = tailcfg.NodeView{} + for k := range nb.peers { + nb.peers[k] = tailcfg.NodeView{} } // Second pass, add everything wanted. for _, p := range nm.Peers { - mak.Set(&b.peers, p.ID(), p) + mak.Set(&nb.peers, p.ID(), p) } // Third pass, remove deleted things. - for k, v := range b.peers { + for k, v := range nb.peers { if !v.Valid() { - delete(b.peers, k) + delete(nb.peers, k) } } } @@ -6652,14 +6669,14 @@ func (b *LocalBackend) TestOnlyPublicKeys() (machineKey key.MachinePublic, nodeK // PeerHasCap reports whether the peer with the given Tailscale IP addresses // contains the given capability string, with any value(s). -func (b *localNodeContext) PeerHasCap(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.peerHasCapLocked(addr, wantCap) +func (nb *nodeBackend) PeerHasCap(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.peerHasCapLocked(addr, wantCap) } -func (b *localNodeContext) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { - return b.peerCapsLocked(addr).HasCapability(wantCap) +func (nb *nodeBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { + return nb.peerCapsLocked(addr).HasCapability(wantCap) } // SetDNS adds a DNS record for the given domain name & TXT record @@ -6722,16 +6739,16 @@ func peerAPIURL(ip netip.Addr, port uint16) string { return fmt.Sprintf("http://%v", netip.AddrPortFrom(ip, port)) } -func (c *localNodeContext) PeerHasPeerAPI(p tailcfg.NodeView) bool { - return c.PeerAPIBase(p) != "" +func (nb *nodeBackend) PeerHasPeerAPI(p tailcfg.NodeView) bool { + return nb.PeerAPIBase(p) != "" } // PeerAPIBase returns the "http://ip:port" URL base to reach peer's PeerAPI, // or the empty string if the peer is invalid or doesn't support PeerAPI. -func (c *localNodeContext) PeerAPIBase(p tailcfg.NodeView) string { - c.mu.Lock() - nm := c.netMap - c.mu.Unlock() +func (nb *nodeBackend) PeerAPIBase(p tailcfg.NodeView) string { + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() return peerAPIBase(nm, p) } @@ -6972,10 +6989,10 @@ func exitNodeCanProxyDNS(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg return "", false } -func (c *localNodeContext) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - return exitNodeCanProxyDNS(c.netMap, c.peers, exitNodeID) +func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + return exitNodeCanProxyDNS(nb.netMap, nb.peers, exitNodeID) } // wireguardExitNodeDNSResolvers returns the DNS resolvers to use for a @@ -7396,7 +7413,7 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err // down, so no need to do any work. return nil } - b.currentNodeAtomic.Store(newLocalNodeContext()) + b.currentNodeAtomic.Store(newNodeBackend()) b.setNetMapLocked(nil) // Reset netmap. b.updateFilterLocked(ipn.PrefsView{}) // Reset the NetworkMap in the engine @@ -8086,7 +8103,7 @@ func (b *LocalBackend) startAutoUpdate(logPrefix string) (retErr error) { // rules that require a source IP to have a certain node capability. // // TODO(bradfitz): optimize this later if/when it matters. -// TODO(nickkhyl): move this into [localNodeContext] along with [LocalBackend.updateFilterLocked]. +// TODO(nickkhyl): move this into [nodeBackend] along with [LocalBackend.updateFilterLocked]. func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool { if cap == "" { // Shouldn't happen, but just in case. diff --git a/ipn/ipnlocal/local_node_context.go b/ipn/ipnlocal/node_backend.go similarity index 65% rename from ipn/ipnlocal/local_node_context.go rename to ipn/ipnlocal/node_backend.go index 871880893..415c32ccf 100644 --- a/ipn/ipnlocal/local_node_context.go +++ b/ipn/ipnlocal/node_backend.go @@ -18,29 +18,29 @@ import ( "tailscale.com/wgengine/filter" ) -// localNodeContext holds the [LocalBackend]'s context tied to a local node (usually the current one). +// nodeBackend is node-specific [LocalBackend] state. It is usually the current node. // // Its exported methods are safe for concurrent use, but the struct is not a snapshot of state at a given moment; // its state can change between calls. For example, asking for the same value (e.g., netmap or prefs) twice // may return different results. Returned values are immutable and safe for concurrent use. // -// If both the [LocalBackend]'s internal mutex and the [localNodeContext] mutex must be held at the same time, +// If both the [LocalBackend]'s internal mutex and the [nodeBackend] mutex must be held at the same time, // the [LocalBackend] mutex must be acquired first. See the comment on the [LocalBackend] field for more details. // -// Two pointers to different [localNodeContext] instances represent different local nodes. -// However, there's currently a bug where a new [localNodeContext] might not be created +// Two pointers to different [nodeBackend] instances represent different local nodes. +// However, there's currently a bug where a new [nodeBackend] might not be created // during an implicit node switch (see tailscale/corp#28014). // In the future, we might want to include at least the following in this struct (in addition to the current fields). // However, not everything should be exported or otherwise made available to the outside world (e.g. [ipnext] extensions, // peer API handlers, etc.). -// - [ipn.State]: when the LocalBackend switches to a different [localNodeContext], it can update the state of the old one. +// - [ipn.State]: when the LocalBackend switches to a different [nodeBackend], it can update the state of the old one. // - [ipn.LoginProfileView] and [ipn.Prefs]: we should update them when the [profileManager] reports changes to them. // In the future, [profileManager] (and the corresponding methods of the [LocalBackend]) can be made optional, // and something else could be used to set them once or update them as needed. // - [tailcfg.HostinfoView]: it includes certain fields that are tied to the current profile/node/prefs. We should also // update to build it once instead of mutating it in twelvety different places. -// - [filter.Filter] (normal and jailed, along with the filterHash): the localNodeContext could have a method to (re-)build +// - [filter.Filter] (normal and jailed, along with the filterHash): the nodeBackend could have a method to (re-)build // the filter for the current netmap/prefs (see [LocalBackend.updateFilterLocked]), and it needs to track the current // filters and their hash. // - Fields related to a requested or required (re-)auth: authURL, authURLTime, authActor, keyExpired, etc. @@ -51,7 +51,7 @@ import ( // It should not include any fields used by specific features that don't belong in [LocalBackend]. // Even if they're tied to the local node, instead of moving them here, we should extract the entire feature // into a separate package and have it install proper hooks. -type localNodeContext struct { +type nodeBackend struct { // filterAtomic is a stateful packet filter. Immutable once created, but can be // replaced with a new one. filterAtomic atomic.Pointer[filter.Filter] @@ -71,33 +71,33 @@ type localNodeContext struct { // peers is the set of current peers and their current values after applying // delta node mutations as they come in (with mu held). The map values can be // given out to callers, but the map itself can be mutated in place (with mu held) - // and must not escape the [localNodeContext]. + // and must not escape the [nodeBackend]. peers map[tailcfg.NodeID]tailcfg.NodeView // nodeByAddr maps nodes' own addresses (excluding subnet routes) to node IDs. - // It is mutated in place (with mu held) and must not escape the [localNodeContext]. + // It is mutated in place (with mu held) and must not escape the [nodeBackend]. nodeByAddr map[netip.Addr]tailcfg.NodeID } -func newLocalNodeContext() *localNodeContext { - cn := &localNodeContext{} +func newNodeBackend() *nodeBackend { + cn := &nodeBackend{} // Default filter blocks everything and logs nothing. noneFilter := filter.NewAllowNone(logger.Discard, &netipx.IPSet{}) cn.filterAtomic.Store(noneFilter) return cn } -func (c *localNodeContext) Self() tailcfg.NodeView { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil { +func (nb *nodeBackend) Self() tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { return tailcfg.NodeView{} } - return c.netMap.SelfNode + return nb.netMap.SelfNode } -func (c *localNodeContext) SelfUserID() tailcfg.UserID { - self := c.Self() +func (nb *nodeBackend) SelfUserID() tailcfg.UserID { + self := nb.Self() if !self.Valid() { return 0 } @@ -105,59 +105,59 @@ func (c *localNodeContext) SelfUserID() tailcfg.UserID { } // SelfHasCap reports whether the specified capability was granted to the self node in the most recent netmap. -func (c *localNodeContext) SelfHasCap(wantCap tailcfg.NodeCapability) bool { - return c.SelfHasCapOr(wantCap, false) +func (nb *nodeBackend) SelfHasCap(wantCap tailcfg.NodeCapability) bool { + return nb.SelfHasCapOr(wantCap, false) } -// SelfHasCapOr is like [localNodeContext.SelfHasCap], but returns the specified default value +// SelfHasCapOr is like [nodeBackend.SelfHasCap], but returns the specified default value // if the netmap is not available yet. -func (c *localNodeContext) SelfHasCapOr(wantCap tailcfg.NodeCapability, def bool) bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil { +func (nb *nodeBackend) SelfHasCapOr(wantCap tailcfg.NodeCapability, def bool) bool { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { return def } - return c.netMap.AllCaps.Contains(wantCap) + return nb.netMap.AllCaps.Contains(wantCap) } -func (c *localNodeContext) NetworkProfile() ipn.NetworkProfile { - c.mu.Lock() - defer c.mu.Unlock() +func (nb *nodeBackend) NetworkProfile() ipn.NetworkProfile { + nb.mu.Lock() + defer nb.mu.Unlock() return ipn.NetworkProfile{ // These are ok to call with nil netMap. - MagicDNSName: c.netMap.MagicDNSSuffix(), - DomainName: c.netMap.DomainName(), + MagicDNSName: nb.netMap.MagicDNSSuffix(), + DomainName: nb.netMap.DomainName(), } } // TODO(nickkhyl): update it to return a [tailcfg.DERPMapView]? -func (c *localNodeContext) DERPMap() *tailcfg.DERPMap { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil { +func (nb *nodeBackend) DERPMap() *tailcfg.DERPMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { return nil } - return c.netMap.DERPMap + return nb.netMap.DERPMap } -func (c *localNodeContext) NodeByAddr(ip netip.Addr) (_ tailcfg.NodeID, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - nid, ok := c.nodeByAddr[ip] +func (nb *nodeBackend) NodeByAddr(ip netip.Addr) (_ tailcfg.NodeID, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + nid, ok := nb.nodeByAddr[ip] return nid, ok } -func (c *localNodeContext) NodeByKey(k key.NodePublic) (_ tailcfg.NodeID, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil { +func (nb *nodeBackend) NodeByKey(k key.NodePublic) (_ tailcfg.NodeID, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { return 0, false } - if self := c.netMap.SelfNode; self.Valid() && self.Key() == k { + if self := nb.netMap.SelfNode; self.Valid() && self.Key() == k { return self.ID(), true } // TODO(bradfitz,nickkhyl): add nodeByKey like nodeByAddr instead of walking peers. - for _, n := range c.peers { + for _, n := range nb.peers { if n.Key() == k { return n.ID(), true } @@ -165,17 +165,17 @@ func (c *localNodeContext) NodeByKey(k key.NodePublic) (_ tailcfg.NodeID, ok boo return 0, false } -func (c *localNodeContext) PeerByID(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - n, ok := c.peers[id] +func (nb *nodeBackend) PeerByID(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + n, ok := nb.peers[id] return n, ok } -func (c *localNodeContext) UserByID(id tailcfg.UserID) (_ tailcfg.UserProfileView, ok bool) { - c.mu.Lock() - nm := c.netMap - c.mu.Unlock() +func (nb *nodeBackend) UserByID(id tailcfg.UserID) (_ tailcfg.UserProfileView, ok bool) { + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() if nm == nil { return tailcfg.UserProfileView{}, false } @@ -184,10 +184,10 @@ func (c *localNodeContext) UserByID(id tailcfg.UserID) (_ tailcfg.UserProfileVie } // Peers returns all the current peers in an undefined order. -func (c *localNodeContext) Peers() []tailcfg.NodeView { - c.mu.Lock() - defer c.mu.Unlock() - return slicesx.MapValues(c.peers) +func (nb *nodeBackend) Peers() []tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + return slicesx.MapValues(nb.peers) } // unlockedNodesPermitted reports whether any peer with theUnsignedPeerAPIOnly bool set true has any of its allowed IPs @@ -195,13 +195,13 @@ func (c *localNodeContext) Peers() []tailcfg.NodeView { // // TODO(nickkhyl): It is here temporarily until we can move the whole [LocalBackend.updateFilterLocked] here, // but change it so it builds and returns a filter for the current netmap/prefs instead of re-configuring the engine filter. -// Something like (*localNodeContext).RebuildFilters() (filter, jailedFilter *filter.Filter, changed bool) perhaps? -func (c *localNodeContext) unlockedNodesPermitted(packetFilter []filter.Match) bool { - c.mu.Lock() - defer c.mu.Unlock() - return packetFilterPermitsUnlockedNodes(c.peers, packetFilter) +// Something like (*nodeBackend).RebuildFilters() (filter, jailedFilter *filter.Filter, changed bool) perhaps? +func (nb *nodeBackend) unlockedNodesPermitted(packetFilter []filter.Match) bool { + nb.mu.Lock() + defer nb.mu.Unlock() + return packetFilterPermitsUnlockedNodes(nb.peers, packetFilter) } -func (c *localNodeContext) filter() *filter.Filter { - return c.filterAtomic.Load() +func (nb *nodeBackend) filter() *filter.Filter { + return nb.filterAtomic.Load() } diff --git a/ipn/ipnlocal/taildrop.go b/ipn/ipnlocal/taildrop.go index 17ca40926..d8113d219 100644 --- a/ipn/ipnlocal/taildrop.go +++ b/ipn/ipnlocal/taildrop.go @@ -194,8 +194,8 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { if !p.Valid() || p.Hostinfo().OS() == "tvOS" { return false } - if self != p.User() { - return false + if self == p.User() { + return true } if p.Addresses().Len() != 0 && cn.PeerHasCap(p.Addresses().At(0).Addr(), tailcfg.PeerCapabilityFileSharingTarget) { // Explicitly noted in the netmap ACL caps as a target. diff --git a/logtail/logtail.go b/logtail/logtail.go index a617397f9..b355addd2 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -15,9 +15,7 @@ import ( "log" mrand "math/rand/v2" "net/http" - "net/netip" "os" - "regexp" "runtime" "slices" "strconv" @@ -29,7 +27,6 @@ import ( "tailscale.com/envknob" "tailscale.com/net/netmon" "tailscale.com/net/sockstats" - "tailscale.com/net/tsaddr" "tailscale.com/tstime" tslogger "tailscale.com/types/logger" "tailscale.com/types/logid" @@ -833,8 +830,6 @@ func (l *Logger) Logf(format string, args ...any) { fmt.Fprintf(l, format, args...) } -var obscureIPs = envknob.RegisterBool("TS_OBSCURE_LOGGED_IPS") - // Write logs an encoded JSON blob. // // If the []byte passed to Write is not an encoded JSON blob, @@ -859,10 +854,6 @@ func (l *Logger) Write(buf []byte) (int, error) { } } - if obscureIPs() { - buf = redactIPs(buf) - } - l.writeLock.Lock() defer l.writeLock.Unlock() @@ -871,40 +862,6 @@ func (l *Logger) Write(buf []byte) (int, error) { return inLen, err } -var ( - regexMatchesIPv6 = regexp.MustCompile(`([0-9a-fA-F]{1,4}):([0-9a-fA-F]{1,4}):([0-9a-fA-F:]{1,4})*`) - regexMatchesIPv4 = regexp.MustCompile(`(\d{1,3})\.(\d{1,3})\.\d{1,3}\.\d{1,3}`) -) - -// redactIPs is a helper function used in Write() to redact IPs (other than tailscale IPs). -// This function takes a log line as a byte slice and -// uses regex matching to parse and find IP addresses. Based on if the IP address is IPv4 or -// IPv6, it parses and replaces the end of the addresses with an "x". This function returns the -// log line with the IPs redacted. -func redactIPs(buf []byte) []byte { - out := regexMatchesIPv6.ReplaceAllFunc(buf, func(b []byte) []byte { - ip, err := netip.ParseAddr(string(b)) - if err != nil || tsaddr.IsTailscaleIP(ip) { - return b // don't change this one - } - - prefix := bytes.Split(b, []byte(":")) - return bytes.Join(append(prefix[:2], []byte("x")), []byte(":")) - }) - - out = regexMatchesIPv4.ReplaceAllFunc(out, func(b []byte) []byte { - ip, err := netip.ParseAddr(string(b)) - if err != nil || tsaddr.IsTailscaleIP(ip) { - return b // don't change this one - } - - prefix := bytes.Split(b, []byte(".")) - return bytes.Join(append(prefix[:2], []byte("x.x")), []byte(".")) - }) - - return []byte(out) -} - var ( openBracketV = []byte("[v") v1 = []byte("[v1] ") diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index 3ea630406..b8c46c448 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/go-json-experiment/json/jsontext" - "tailscale.com/envknob" "tailscale.com/tstest" "tailscale.com/tstime" "tailscale.com/util/must" @@ -316,85 +315,6 @@ func TestLoggerWriteResult(t *testing.T) { t.Errorf("mismatch.\n got: %#q\nwant: %#q", back, want) } } -func TestRedact(t *testing.T) { - envknob.Setenv("TS_OBSCURE_LOGGED_IPS", "true") - tests := []struct { - in string - want string - }{ - // tests for ipv4 addresses - { - "120.100.30.47", - "120.100.x.x", - }, - { - "192.167.0.1/65", - "192.167.x.x/65", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 10.0.0.222:41641 mtu=1360 tx=d81a8a35a0ce", - "node [5Btdd] d:e89a3384f526d251 now using 10.0.x.x:41641 mtu=1360 tx=d81a8a35a0ce", - }, - //tests for ipv6 addresses - { - "2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "2001:0db8:x", - }, - { - "2345:0425:2CA1:0000:0000:0567:5673:23b5", - "2345:0425:x", - }, - { - "2601:645:8200:edf0::c9de/64", - "2601:645:x/64", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 2051:0000:140F::875B:131C mtu=1360 tx=d81a8a35a0ce", - "node [5Btdd] d:e89a3384f526d251 now using 2051:0000:x mtu=1360 tx=d81a8a35a0ce", - }, - { - "2601:645:8200:edf0::c9de/64 2601:645:8200:edf0:1ce9:b17d:71f5:f6a3/64", - "2601:645:x/64 2601:645:x/64", - }, - //tests for tailscale ip addresses - { - "100.64.5.6", - "100.64.5.6", - }, - { - "fd7a:115c:a1e0::/96", - "fd7a:115c:a1e0::/96", - }, - //tests for ipv6 and ipv4 together - { - "192.167.0.1 2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "192.167.x.x 2001:0db8:x", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 10.0.0.222:41641 mtu=1360 tx=d81a8a35a0ce 2345:0425:2CA1::0567:5673:23b5", - "node [5Btdd] d:e89a3384f526d251 now using 10.0.x.x:41641 mtu=1360 tx=d81a8a35a0ce 2345:0425:x", - }, - { - "100.64.5.6 2091:0db8:85a3:0000:0000:8a2e:0370:7334", - "100.64.5.6 2091:0db8:x", - }, - { - "192.167.0.1 120.100.30.47 2041:0000:140F::875B:131B", - "192.167.x.x 120.100.x.x 2041:0000:x", - }, - { - "fd7a:115c:a1e0::/96 192.167.0.1 2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "fd7a:115c:a1e0::/96 192.167.x.x 2001:0db8:x", - }, - } - - for _, tt := range tests { - gotBuf := redactIPs([]byte(tt.in)) - if string(gotBuf) != tt.want { - t.Errorf("for %q,\n got: %#q\nwant: %#q\n", tt.in, gotBuf, tt.want) - } - } -} func TestAppendMetadata(t *testing.T) { var l Logger diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index ada0df8fc..11a0d0830 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -160,7 +160,8 @@ type CapabilityVersion int // - 113: 2025-01-20: Client communicates to control whether funnel is enabled by sending Hostinfo.IngressEnabled (#14688) // - 114: 2025-01-30: NodeAttrMaxKeyDuration CapMap defined, clients might use it (no tailscaled code change) (#14829) // - 115: 2025-03-07: Client understands DERPRegion.NoMeasureNoHome. -const CurrentCapabilityVersion CapabilityVersion = 115 +// - 116: 2025-05-05: Client serves MagicDNS "AAAA" if NodeAttrMagicDNSPeerAAAA set on self node +const CurrentCapabilityVersion CapabilityVersion = 116 // ID is an integer ID for a user, node, or login allocated by the // control plane. @@ -875,10 +876,37 @@ type Hostinfo struct { // explicitly declared by a node. Location *Location `json:",omitempty"` + TPM *TPMInfo `json:",omitempty"` // TPM device metadata, if available + // NOTE: any new fields containing pointers in this type // require changes to Hostinfo.Equal. } +// TPMInfo contains information about a TPM 2.0 device present on a node. +// All fields are read from TPM_CAP_TPM_PROPERTIES, see Part 2, section 6.13 of +// https://trustedcomputinggroup.org/resource/tpm-library-specification/. +type TPMInfo struct { + // Manufacturer is a 4-letter code from section 4.1 of + // https://trustedcomputinggroup.org/resource/vendor-id-registry/, + // for example "MSFT" for Microsoft. + // Read from TPM_PT_MANUFACTURER. + Manufacturer string `json:",omitempty"` + // Vendor is a vendor ID string, up to 16 characters. + // Read from TPM_PT_VENDOR_STRING_*. + Vendor string `json:",omitempty"` + // Model is a vendor-defined TPM model. + // Read from TPM_PT_VENDOR_TPM_TYPE. + Model int `json:",omitempty"` + // FirmwareVersion is the version number of the firmware. + // Read from TPM_PT_FIRMWARE_VERSION_*. + FirmwareVersion uint64 `json:",omitempty"` + // SpecRevision is the TPM 2.0 spec revision encoded as a single number. All + // revisions can be found at + // https://trustedcomputinggroup.org/resource/tpm-library-specification/. + // Before revision 184, TCG used the "01.83" format for revision 183. + SpecRevision int `json:",omitempty"` +} + // ServiceName is the name of a service, of the form `svc:dns-label`. Services // represent some kind of application provided for users of the tailnet with a // MagicDNS name and possibly dedicated IP addresses. Currently (2024-01-21), @@ -2466,6 +2494,10 @@ const ( // NodeAttrRelayClient permits the node to act as an underlay UDP relay // client. There are no expected values for this key in NodeCapMap. NodeAttrRelayClient NodeCapability = "relay:client" + + // NodeAttrMagicDNSPeerAAAA is a capability that tells the node's MagicDNS + // server to answer AAAA queries about its peers. See tailscale/tailscale#1152. + NodeAttrMagicDNSPeerAAAA NodeCapability = "magicdns-aaaa" ) // SetDNSRequest is a request to add a DNS record. diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 3952f5f47..2c7941d51 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -141,6 +141,9 @@ func (src *Hostinfo) Clone() *Hostinfo { if dst.Location != nil { dst.Location = ptr.To(*src.Location) } + if dst.TPM != nil { + dst.TPM = ptr.To(*src.TPM) + } return dst } @@ -184,6 +187,7 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { AppConnector opt.Bool ServicesHash string Location *Location + TPM *TPMInfo }{}) // Clone makes a deep copy of NetInfo. diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index dd81af5d6..079162a15 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -68,6 +68,7 @@ func TestHostinfoEqual(t *testing.T) { "AppConnector", "ServicesHash", "Location", + "TPM", } if have := fieldsOf(reflect.TypeFor[Hostinfo]()); !reflect.DeepEqual(have, hiHandles) { t.Errorf("Hostinfo.Equal check might be out of sync\nfields: %q\nhandled: %q\n", diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index f8f9f865c..c76654887 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -301,7 +301,9 @@ func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.User func (v HostinfoView) AppConnector() opt.Bool { return v.ж.AppConnector } func (v HostinfoView) ServicesHash() string { return v.ж.ServicesHash } func (v HostinfoView) Location() LocationView { return v.ж.Location.View() } -func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } +func (v HostinfoView) TPM() views.ValuePointer[TPMInfo] { return views.ValuePointerOf(v.ж.TPM) } + +func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HostinfoViewNeedsRegeneration = Hostinfo(struct { @@ -343,6 +345,7 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { AppConnector opt.Bool ServicesHash string Location *Location + TPM *TPMInfo }{}) // View returns a read-only view of NetInfo. diff --git a/tool/gocross/gocross.go b/tool/gocross/gocross.go index 8011c1095..d14ea0388 100644 --- a/tool/gocross/gocross.go +++ b/tool/gocross/gocross.go @@ -15,9 +15,9 @@ import ( "fmt" "os" "path/filepath" + "runtime/debug" "tailscale.com/atomicfile" - "tailscale.com/version" ) func main() { @@ -28,8 +28,19 @@ func main() { // any time. switch os.Args[1] { case "gocross-version": - fmt.Println(version.GetMeta().GitCommit) - os.Exit(0) + bi, ok := debug.ReadBuildInfo() + if !ok { + fmt.Fprintln(os.Stderr, "failed getting build info") + os.Exit(1) + } + for _, s := range bi.Settings { + if s.Key == "vcs.revision" { + fmt.Println(s.Value) + os.Exit(0) + } + } + fmt.Fprintln(os.Stderr, "did not find vcs.revision in build info") + os.Exit(1) case "is-gocross": // This subcommand exits with an error code when called on a // regular go binary, so it can be used to detect when `go` is @@ -85,9 +96,9 @@ func main() { path := filepath.Join(toolchain, "bin") + string(os.PathListSeparator) + os.Getenv("PATH") env.Set("PATH", path) - debug("Input: %s\n", formatArgv(os.Args)) - debug("Command: %s\n", formatArgv(newArgv)) - debug("Set the following flags/envvars:\n%s\n", env.Diff()) + debugf("Input: %s\n", formatArgv(os.Args)) + debugf("Command: %s\n", formatArgv(newArgv)) + debugf("Set the following flags/envvars:\n%s\n", env.Diff()) args = newArgv if err := env.Apply(); err != nil { @@ -103,7 +114,7 @@ func main() { //go:embed gocross-wrapper.sh var wrapperScript []byte -func debug(format string, args ...any) { +func debugf(format string, args ...any) { debug := os.Getenv("GOCROSS_DEBUG") var ( out *os.File diff --git a/tool/gocross/gocross_test.go b/tool/gocross/gocross_test.go new file mode 100644 index 000000000..82afd268c --- /dev/null +++ b/tool/gocross/gocross_test.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + BadDeps: map[string]string{ + "tailscale.com/tailcfg": "circular dependency via go generate", + "tailscale.com/version": "circular dependency via go generate", + }, + }.Check(t) +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index f97598075..1880b62b1 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1101,13 +1101,33 @@ type FunnelOption interface { funnelOption() } -type funnelOnly int +type funnelOnly struct{} func (funnelOnly) funnelOption() {} // FunnelOnly configures the listener to only respond to connections from Tailscale Funnel. // The local tailnet will not be able to connect to the listener. -func FunnelOnly() FunnelOption { return funnelOnly(1) } +func FunnelOnly() FunnelOption { return funnelOnly{} } + +type funnelTLSConfig struct{ conf *tls.Config } + +func (f funnelTLSConfig) funnelOption() {} + +// FunnelTLSConfig configures the TLS configuration for [Server.ListenFunnel] +// +// This is rarely needed but can permit requiring client certificates, specific +// ciphers suites, etc. +// +// The provided conf should at least be able to get a certificate, setting +// GetCertificate, Certificates or GetConfigForClient appropriately. +// The most common configuration is to set GetCertificate to +// Server.LocalClient's GetCertificate method. +// +// Unless [FunnelOnly] is also used, the configuration is also used for +// in-tailnet connections that don't arrive over Funnel. +func FunnelTLSConfig(conf *tls.Config) FunnelOption { + return funnelTLSConfig{conf: conf} +} // ListenFunnel announces on the public internet using Tailscale Funnel. // @@ -1140,6 +1160,26 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L return nil, err } + // Process, validate opts. + lnOn := listenOnBoth + var tlsConfig *tls.Config + for _, opt := range opts { + switch v := opt.(type) { + case funnelTLSConfig: + if v.conf == nil { + return nil, errors.New("invalid nil FunnelTLSConfig") + } + tlsConfig = v.conf + case funnelOnly: + lnOn = listenOnFunnel + default: + return nil, fmt.Errorf("unknown opts FunnelOption type %T", v) + } + } + if tlsConfig == nil { + tlsConfig = &tls.Config{GetCertificate: s.getCert} + } + ctx := context.Background() st, err := s.Up(ctx) if err != nil { @@ -1177,19 +1217,11 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L } // Start a funnel listener. - lnOn := listenOnBoth - for _, opt := range opts { - if _, ok := opt.(funnelOnly); ok { - lnOn = listenOnFunnel - } - } ln, err := s.listen(network, addr, lnOn) if err != nil { return nil, err } - return tls.NewListener(ln, &tls.Config{ - GetCertificate: s.getCert, - }), nil + return tls.NewListener(ln, tlsConfig), nil } type listenOn string diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index 9df536971..d64bfbbd9 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -33,6 +33,7 @@ import ( "time" "go4.org/mem" + "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/ipn" @@ -64,61 +65,151 @@ var ( // as a last ditch place to report errors. var MainError syncs.AtomicValue[error] -// CleanupBinaries cleans up any resources created by calls to BinaryDir, TailscaleBinary, or TailscaledBinary. -// It should be called from TestMain after all tests have completed. -func CleanupBinaries() { - buildOnce.Do(func() {}) - if binDir != "" { - os.RemoveAll(binDir) +// Binaries contains the paths to the tailscale and tailscaled binaries. +type Binaries struct { + Dir string + Tailscale BinaryInfo + Tailscaled BinaryInfo +} + +// BinaryInfo describes a tailscale or tailscaled binary. +type BinaryInfo struct { + Path string // abs path to tailscale or tailscaled binary + Size int64 + + // FD and FDmu are set on Unix to efficiently copy the binary to a new + // test's automatically-cleaned-up temp directory. + FD *os.File // for Unix (macOS, Linux, ...) + FDMu sync.Locker + + // Contents is used on Windows instead of FD to copy the binary between + // test directories. (On Windows you can't keep an FD open while an earlier + // test's temp directories are deleted.) + // This burns some memory and costs more in I/O, but oh well. + Contents []byte +} + +func (b BinaryInfo) CopyTo(dir string) (BinaryInfo, error) { + ret := b + ret.Path = filepath.Join(dir, path.Base(b.Path)) + + switch runtime.GOOS { + case "linux": + // TODO(bradfitz): be fancy and use linkat with AT_EMPTY_PATH to avoid + // copying? I couldn't get it to work, though. + // For now, just do the same thing as every other Unix and copy + // the binary. + fallthrough + case "darwin", "freebsd", "openbsd", "netbsd": + f, err := os.OpenFile(ret.Path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o755) + if err != nil { + return BinaryInfo{}, err + } + b.FDMu.Lock() + b.FD.Seek(0, 0) + size, err := io.Copy(f, b.FD) + b.FDMu.Unlock() + if err != nil { + f.Close() + return BinaryInfo{}, fmt.Errorf("copying %q: %w", b.Path, err) + } + if size != b.Size { + f.Close() + return BinaryInfo{}, fmt.Errorf("copy %q: size mismatch: %d != %d", b.Path, size, b.Size) + } + if err := f.Close(); err != nil { + return BinaryInfo{}, err + } + return ret, nil + case "windows": + return ret, os.WriteFile(ret.Path, b.Contents, 0o755) + default: + return BinaryInfo{}, fmt.Errorf("unsupported OS %q", runtime.GOOS) } } -// BinaryDir returns a directory containing test tailscale and tailscaled binaries. -// If any test calls BinaryDir, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func BinaryDir(tb testing.TB) string { +// GetBinaries create a temp directory using tb and builds (or copies previously +// built) cmd/tailscale and cmd/tailscaled binaries into that directory. +// +// It fails tb if the build or binary copies fail. +func GetBinaries(tb testing.TB) *Binaries { + dir := tb.TempDir() buildOnce.Do(func() { - binDir, buildErr = buildTestBinaries() + buildErr = buildTestBinaries(dir) }) if buildErr != nil { tb.Fatal(buildErr) } - return binDir -} - -// TailscaleBinary returns the path to the test tailscale binary. -// If any test calls TailscaleBinary, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func TailscaleBinary(tb testing.TB) string { - return filepath.Join(BinaryDir(tb), "tailscale"+exe()) -} - -// TailscaledBinary returns the path to the test tailscaled binary. -// If any test calls TailscaleBinary, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func TailscaledBinary(tb testing.TB) string { - return filepath.Join(BinaryDir(tb), "tailscaled"+exe()) + if binariesCache.Dir == dir { + return binariesCache + } + ts, err := binariesCache.Tailscale.CopyTo(dir) + if err != nil { + tb.Fatalf("copying tailscale binary: %v", err) + } + tsd, err := binariesCache.Tailscaled.CopyTo(dir) + if err != nil { + tb.Fatalf("copying tailscaled binary: %v", err) + } + return &Binaries{ + Dir: dir, + Tailscale: ts, + Tailscaled: tsd, + } } var ( - buildOnce sync.Once - buildErr error - binDir string + buildOnce sync.Once + buildErr error + binariesCache *Binaries ) // buildTestBinaries builds tailscale and tailscaled. -// It returns the dir containing the binaries. -func buildTestBinaries() (string, error) { - bindir, err := os.MkdirTemp("", "") - if err != nil { - return "", err +// On success, it initializes [binariesCache]. +func buildTestBinaries(dir string) error { + getBinaryInfo := func(name string) (BinaryInfo, error) { + bi := BinaryInfo{Path: filepath.Join(dir, name+exe())} + fi, err := os.Stat(bi.Path) + if err != nil { + return BinaryInfo{}, fmt.Errorf("stat %q: %v", bi.Path, err) + } + bi.Size = fi.Size() + + switch runtime.GOOS { + case "windows": + bi.Contents, err = os.ReadFile(bi.Path) + if err != nil { + return BinaryInfo{}, fmt.Errorf("read %q: %v", bi.Path, err) + } + default: + bi.FD, err = os.OpenFile(bi.Path, os.O_RDONLY, 0) + if err != nil { + return BinaryInfo{}, fmt.Errorf("open %q: %v", bi.Path, err) + } + bi.FDMu = new(sync.Mutex) + // Note: bi.FD is copied around between tests but never closed, by + // design. It will be closed when the process exits, and that will + // close the inode that we're copying the bytes from for each test. + } + return bi, nil } - err = build(bindir, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") + err := build(dir, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") if err != nil { - os.RemoveAll(bindir) - return "", err + return err } - return bindir, nil + b := &Binaries{ + Dir: dir, + } + b.Tailscale, err = getBinaryInfo("tailscale") + if err != nil { + return err + } + b.Tailscaled, err = getBinaryInfo("tailscaled") + if err != nil { + return err + } + binariesCache = b + return nil } func build(outDir string, targets ...string) error { @@ -436,14 +527,16 @@ func NewTestEnv(t testing.TB, opts ...TestEnvOpt) *TestEnv { derpMap := RunDERPAndSTUN(t, logger.Discard, "127.0.0.1") logc := new(LogCatcher) control := &testcontrol.Server{ + Logf: logger.WithPrefix(t.Logf, "testcontrol: "), DERPMap: derpMap, } control.HTTPTestServer = httptest.NewUnstartedServer(control) trafficTrap := new(trafficTrap) + binaries := GetBinaries(t) e := &TestEnv{ t: t, - cli: TailscaleBinary(t), - daemon: TailscaledBinary(t), + cli: binaries.Tailscale.Path, + daemon: binaries.Tailscaled.Path, LogCatcher: logc, LogCatcherServer: httptest.NewServer(logc), Control: control, @@ -484,6 +577,7 @@ type TestNode struct { mu sync.Mutex onLogLine []func([]byte) + lc *local.Client } // NewTestNode allocates a temp directory for a new test node. @@ -500,14 +594,18 @@ func NewTestNode(t *testing.T, env *TestEnv) *TestNode { env: env, dir: dir, sockFile: sockFile, - stateFile: filepath.Join(dir, "tailscale.state"), + stateFile: filepath.Join(dir, "tailscaled.state"), // matches what cmd/tailscaled uses } - // Look for a data race. Once we see the start marker, start logging the rest. + // Look for a data race or panic. + // Once we see the start marker, start logging the rest. var sawRace bool var sawPanic bool n.addLogLineHook(func(line []byte) { lineB := mem.B(line) + if mem.Contains(lineB, mem.S("DEBUG-ADDR=")) { + t.Log(strings.TrimSpace(string(line))) + } if mem.Contains(lineB, mem.S("WARNING: DATA RACE")) { sawRace = true } @@ -522,6 +620,20 @@ func NewTestNode(t *testing.T, env *TestEnv) *TestNode { return n } +func (n *TestNode) LocalClient() *local.Client { + n.mu.Lock() + defer n.mu.Unlock() + if n.lc == nil { + tr := &http.Transport{} + n.lc = &local.Client{ + Socket: n.sockFile, + UseSocketOnly: true, + } + n.env.t.Cleanup(tr.CloseIdleConnections) + } + return n.lc +} + func (n *TestNode) diskPrefs() *ipn.Prefs { t := n.env.t t.Helper() @@ -658,6 +770,27 @@ func (d *Daemon) MustCleanShutdown(t testing.TB) { } } +// awaitTailscaledRunnable tries to run `tailscaled --version` until it +// works. This is an unsatisfying workaround for ETXTBSY we were seeing +// on GitHub Actions that aren't understood. It's not clear what's holding +// a writable fd to tailscaled after `go install` completes. +// See https://github.com/tailscale/tailscale/issues/15868. +func (n *TestNode) awaitTailscaledRunnable() error { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(10*time.Second, func() error { + out, err := exec.Command(n.env.daemon, "--version").CombinedOutput() + if err == nil { + return nil + } + t.Logf("error running tailscaled --version: %v, %s", err, out) + return err + }); err != nil { + return fmt.Errorf("gave up trying to run tailscaled: %v", err) + } + return nil +} + // StartDaemon starts the node's tailscaled, failing if it fails to start. // StartDaemon ensures that the process will exit when the test completes. func (n *TestNode) StartDaemon() *Daemon { @@ -666,11 +799,17 @@ func (n *TestNode) StartDaemon() *Daemon { func (n *TestNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { t := n.env.t + + if err := n.awaitTailscaledRunnable(); err != nil { + t.Fatalf("awaitTailscaledRunnable: %v", err) + } + cmd := exec.Command(n.env.daemon) cmd.Args = append(cmd.Args, - "--state="+n.stateFile, + "--statedir="+n.dir, "--socket="+n.sockFile, "--socks5-server=localhost:0", + "--debug=localhost:0", ) if *verboseTailscaled { cmd.Args = append(cmd.Args, "-verbose=2") @@ -684,7 +823,6 @@ func (n *TestNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { cmd.Args = append(cmd.Args, "--config="+n.configFile) } cmd.Env = append(os.Environ(), - "TS_CONTROL_IS_PLAINTEXT_HTTP=1", "TS_DEBUG_PERMIT_HTTP_C2N=1", "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, "HTTP_PROXY="+n.env.TrafficTrapServer.URL, diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 0da2e6086..90cc7e443 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -49,7 +49,6 @@ func TestMain(m *testing.M) { os.Setenv("TS_DISABLE_UPNP", "true") flag.Parse() v := m.Run() - CleanupBinaries() if v != 0 { os.Exit(v) } @@ -278,15 +277,20 @@ func TestOneNodeUpAuth(t *testing.T) { t.Logf("Running up --login-server=%s ...", env.ControlURL()) cmd := n1.Tailscale("up", "--login-server="+env.ControlURL()) - var authCountAtomic int32 + var authCountAtomic atomic.Int32 cmd.Stdout = &authURLParserWriter{fn: func(urlStr string) error { + t.Logf("saw auth URL %q", urlStr) if env.Control.CompleteAuth(urlStr) { - atomic.AddInt32(&authCountAtomic, 1) + if authCountAtomic.Add(1) > 1 { + err := errors.New("completed multple auth URLs") + t.Error(err) + return err + } t.Logf("completed auth path %s", urlStr) return nil } err := fmt.Errorf("Failed to complete auth path to %q", urlStr) - t.Log(err) + t.Error(err) return err }} cmd.Stderr = cmd.Stdout @@ -297,7 +301,7 @@ func TestOneNodeUpAuth(t *testing.T) { n1.AwaitRunning() - if n := atomic.LoadInt32(&authCountAtomic); n != 1 { + if n := authCountAtomic.Load(); n != 1 { t.Errorf("Auth URLs completed = %d; want 1", n) } diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 52b96fe4d..71205f897 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -55,6 +55,10 @@ type Server struct { MagicDNSDomain string HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + // AllNodesSameUser, if true, makes all created nodes + // belong to the same user. + AllNodesSameUser bool + // ExplicitBaseURL or HTTPTestServer must be set. ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL @@ -96,9 +100,9 @@ type Server struct { logins map[key.NodePublic]*tailcfg.Login updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath - nodeKeyAuthed map[key.NodePublic]bool // key => true once authenticated - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + nodeKeyAuthed set.Set[key.NodePublic] + msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. } // BaseURL returns the server's base URL, without trailing slash. @@ -522,6 +526,10 @@ func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) return u, s.logins[nodeKey] } id := tailcfg.UserID(len(s.users) + 1) + if s.AllNodesSameUser { + id = 123 + } + s.logf("Created user %v for node %s", id, nodeKey) loginName := fmt.Sprintf("user-%d@%s", id, domain) displayName := fmt.Sprintf("User %d", id) login := &tailcfg.Login{ @@ -582,10 +590,8 @@ func (s *Server) CompleteAuth(authPathOrURL string) bool { if ap.nodeKey.IsZero() { panic("zero AuthPath.NodeKey") } - if s.nodeKeyAuthed == nil { - s.nodeKeyAuthed = map[key.NodePublic]bool{} - } - s.nodeKeyAuthed[ap.nodeKey] = true + s.nodeKeyAuthed.Make() + s.nodeKeyAuthed.Add(ap.nodeKey) ap.CompleteSuccessfully() return true } @@ -645,36 +651,40 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. if s.nodes == nil { s.nodes = map[key.NodePublic]*tailcfg.Node{} } - + _, ok := s.nodes[nk] machineAuthorized := true // TODO: add Server.RequireMachineAuth + if !ok { - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) - v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) + nodeID := len(s.nodes) + 1 + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(nodeID>>8), uint8(nodeID)), 32) + v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) - allowedIPs := []netip.Prefix{ - v4Prefix, - v6Prefix, - } - - s.nodes[nk] = &tailcfg.Node{ - ID: tailcfg.NodeID(user.ID), - StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), - User: user.ID, - Machine: mkey, - Key: req.NodeKey, - MachineAuthorized: machineAuthorized, - Addresses: allowedIPs, - AllowedIPs: allowedIPs, - Hostinfo: req.Hostinfo.View(), - Name: req.Hostinfo.Hostname, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityHTTPS, - tailcfg.NodeAttrFunnel, - tailcfg.CapabilityFunnelPorts + "?ports=8080,443", - }, + allowedIPs := []netip.Prefix{ + v4Prefix, + v6Prefix, + } + node := &tailcfg.Node{ + ID: tailcfg.NodeID(nodeID), + StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(nodeID))), + User: user.ID, + Machine: mkey, + Key: req.NodeKey, + MachineAuthorized: machineAuthorized, + Addresses: allowedIPs, + AllowedIPs: allowedIPs, + Hostinfo: req.Hostinfo.View(), + Name: req.Hostinfo.Hostname, + Capabilities: []tailcfg.NodeCapability{ + tailcfg.CapabilityHTTPS, + tailcfg.NodeAttrFunnel, + tailcfg.CapabilityFileSharing, + tailcfg.CapabilityFunnelPorts + "?ports=8080,443", + }, + } + s.nodes[nk] = node } requireAuth := s.RequireAuth - if requireAuth && s.nodeKeyAuthed[nk] { + if requireAuth && s.nodeKeyAuthed.Contains(nk) { requireAuth = false } allExpired := s.allExpired @@ -951,7 +961,6 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, node.CapMap = nodeCapMap node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP) - user, _ := s.getUser(nk) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) dns := s.DNSConfig if dns != nil && s.MagicDNSDomain != "" { @@ -1013,7 +1022,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, }) res.UserProfiles = s.allUserProfiles() - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(node.ID>>8), uint8(node.ID)), 32) v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) res.Node.Addresses = []netip.Prefix{ diff --git a/tstest/integration/vms/harness_test.go b/tstest/integration/vms/harness_test.go index 1e080414d..256227d6c 100644 --- a/tstest/integration/vms/harness_test.go +++ b/tstest/integration/vms/harness_test.go @@ -134,11 +134,12 @@ func newHarness(t *testing.T) *Harness { loginServer := fmt.Sprintf("http://%s", ln.Addr()) t.Logf("loginServer: %s", loginServer) + binaries := integration.GetBinaries(t) h := &Harness{ pubKey: string(pubkey), - binaryDir: integration.BinaryDir(t), - cli: integration.TailscaleBinary(t), - daemon: integration.TailscaledBinary(t), + binaryDir: binaries.Dir, + cli: binaries.Tailscale.Path, + daemon: binaries.Tailscaled.Path, signer: signer, loginServerURL: loginServer, cs: cs, diff --git a/tstest/integration/vms/vms_test.go b/tstest/integration/vms/vms_test.go index 6d73a3f78..f71f2bdbf 100644 --- a/tstest/integration/vms/vms_test.go +++ b/tstest/integration/vms/vms_test.go @@ -28,7 +28,6 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/sync/semaphore" "tailscale.com/tstest" - "tailscale.com/tstest/integration" "tailscale.com/types/logger" ) @@ -51,13 +50,6 @@ var ( }() ) -func TestMain(m *testing.M) { - flag.Parse() - v := m.Run() - integration.CleanupBinaries() - os.Exit(v) -} - func TestDownloadImages(t *testing.T) { if !*runVMTests { t.Skip("not running integration tests (need --run-vm-tests)") diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 0f411521b..b87298c61 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -1710,55 +1710,43 @@ func (n *nftablesRunner) AddSNATRule() error { return nil } +func delMatchSubnetRouteMarkMasqRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error { + + rule, err := createMatchSubnetRouteMarkRule(table, chain, Masq) + if err != nil { + return fmt.Errorf("create match subnet route mark rule: %w", err) + } + + SNATRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find SNAT rule v4: %w", err) + } + + if SNATRule != nil { + _ = conn.DelRule(SNATRule) + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del SNAT rule: %w", err) + } + + return nil +} + // DelSNATRule removes the netfilter rule to SNAT traffic destined for // local subnets. An error is returned if the rule does not exist. func (n *nftablesRunner) DelSNATRule() error { conn := n.conn - hexTSFwmarkMask := getTailscaleFwmarkMask() - hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() - - exprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: hexTSFwmarkMask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: hexTSSubnetRouteMark, - }, - &expr.Counter{}, - &expr.Masq{}, - } - for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { - return fmt.Errorf("get postrouting chain v4: %w", err) + return fmt.Errorf("get postrouting chain: %w", err) } - - rule := &nftables.Rule{ - Table: table.Nat, - Chain: chain, - Exprs: exprs, - } - - SNATRule, err := findRule(conn, rule) + err = delMatchSubnetRouteMarkMasqRule(conn, table.Nat, chain) if err != nil { - return fmt.Errorf("find SNAT rule v4: %w", err) + return err } - - if SNATRule != nil { - _ = conn.DelRule(SNATRule) - } - } - - if err := conn.Flush(); err != nil { - return fmt.Errorf("flush del SNAT rule: %w", err) } return nil diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index 712a7b939..6fb180ed6 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" "runtime" + "slices" "strings" "testing" @@ -24,21 +25,21 @@ import ( "tailscale.com/types/logger" ) +func toAnySlice[T any](s []T) []any { + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + // nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing // users to make sense of large byte literals more easily. func nfdump(b []byte) string { var buf bytes.Buffer - i := 0 - for ; i < len(b); i += 4 { - // TODO: show printable characters as ASCII - fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", - b[i], - b[i+1], - b[i+2], - b[i+3]) - } - for ; i < len(b); i++ { - fmt.Fprintf(&buf, "%02x ", b[i]) + for c := range slices.Chunk(b, 4) { + format := strings.Repeat("%02x ", len(c)) + fmt.Fprintf(&buf, format+"\n", toAnySlice(c)...) } return buf.String() } @@ -75,7 +76,7 @@ func linediff(a, b string) string { return buf.String() } -func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { +func newTestConn(t *testing.T, want [][]byte, reply [][]netlink.Message) *nftables.Conn { conn, err := nftables.New(nftables.WithTestDial( func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { @@ -96,7 +97,13 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { } want = want[1:] } - return req, nil + // no reply for batch end message + if len(want) == 0 { + return nil, nil + } + rep := reply[0] + reply = reply[1:] + return rep, nil })) if err != nil { t.Fatal(err) @@ -120,7 +127,7 @@ func TestInsertHookRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -160,7 +167,7 @@ func TestInsertLoopbackRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -196,7 +203,7 @@ func TestInsertLoopbackRuleV6(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) tableV6 := testConn.AddTable(&nftables.Table{ Family: protoV6, Name: "ts-filter-test", @@ -232,7 +239,7 @@ func TestAddReturnChromeOSVMRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -264,7 +271,7 @@ func TestAddDropCGNATRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -296,7 +303,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -328,7 +335,7 @@ func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -360,7 +367,7 @@ func TestAddAcceptOutgoingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -392,7 +399,7 @@ func TestAddAcceptIncomingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -420,11 +427,11 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { // nft add chain ip ts-nat-test ts-postrouting-test { type nat hook postrouting priority 100; } []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x03\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"), // nft add rule ip ts-nat-test ts-postrouting-test meta mark & 0x00ff0000 == 0x00040000 counter masquerade - []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xd8\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"), // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-nat-test", @@ -436,7 +443,46 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { Hooknum: nftables.ChainHookPostrouting, Priority: nftables.ChainPriorityNATSource, }) - err := addMatchSubnetRouteMarkRule(testConn, table, chain, Accept) + err := addMatchSubnetRouteMarkRule(testConn, table, chain, Masq) + if err != nil { + t.Fatal(err) + } +} + +func TestDelMatchSubnetRouteMarkMasqRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + reply := [][]netlink.Message{ + nil, + {{Header: netlink.Header{Length: 0x128, Type: 0xa06, Flags: 0x802, Sequence: 0xa213d55d, PID: 0x11e79}, Data: []uint8{0x2, 0x0, 0x0, 0x8c, 0xd, 0x0, 0x1, 0x0, 0x6e, 0x61, 0x74, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x0, 0x0, 0x0, 0x0, 0x18, 0x0, 0x2, 0x0, 0x74, 0x73, 0x2d, 0x70, 0x6f, 0x73, 0x74, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0xe0, 0x0, 0x4, 0x0, 0x24, 0x0, 0x1, 0x0, 0x9, 0x0, 0x1, 0x0, 0x6d, 0x65, 0x74, 0x61, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x2, 0x0, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x3, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x4c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x62, 0x69, 0x74, 0x77, 0x69, 0x73, 0x65, 0x0, 0x3c, 0x0, 0x2, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x4, 0x8, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x4, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0xff, 0x0, 0x0, 0xc, 0x0, 0x5, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2c, 0x0, 0x1, 0x0, 0x8, 0x0, 0x1, 0x0, 0x63, 0x6d, 0x70, 0x0, 0x20, 0x0, 0x2, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x3, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x4, 0x0, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x1, 0x0, 0x9, 0x0, 0x1, 0x0, 0x6d, 0x61, 0x73, 0x71, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x2, 0x0}}}, + {{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x311fdccb, PID: 0x11e79}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, + {{Header: netlink.Header{Length: 0x24, Type: 0x2, Flags: 0x100, Sequence: 0x311fdccb, PID: 0x11e79}, Data: []uint8{0x0, 0x0, 0x0, 0x0, 0x48, 0x0, 0x0, 0x0, 0x8, 0xa, 0x5, 0x0, 0xcb, 0xdc, 0x1f, 0x31, 0x79, 0x1e, 0x1, 0x0}}}, + } + want := [][]byte{ + // get rules in nat-test table ts-postrouting-test chain + []byte("\x02\x00\x00\x00\x0d\x00\x01\x00\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00"), + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft delete rule ip nat-test ts-postrouting-test handle 4 + []byte("\x02\x00\x00\x00\x0d\x00\x01\x00\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x0c\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x04"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + conn := newTestConn(t, want, reply) + + table := &nftables.Table{ + Family: proto, + Name: "nat-test", + } + chain := &nftables.Chain{ + Name: "ts-postrouting-test", + Table: table, + Type: nftables.ChainTypeNAT, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource, + } + + err := delMatchSubnetRouteMarkMasqRule(conn, table, chain) if err != nil { t.Fatal(err) } @@ -456,7 +502,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 31bf66b2b..7df46f76c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "errors" "expvar" "fmt" @@ -316,7 +317,11 @@ type Conn struct { // by node key, node ID, and discovery key. peerMap peerMap - // discoInfo is the state for an active DiscoKey. + // relayManager manages allocation and handshaking of + // [tailscale.com/net/udprelay.Server] endpoints. + relayManager relayManager + + // discoInfo is the state for an active peer DiscoKey. discoInfo map[key.DiscoPublic]*discoInfo // netInfoFunc is a callback that provides a tailcfg.NetInfo when @@ -1624,6 +1629,27 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, geneveVNI *uint32, dstKey ke c.mu.Unlock() return false, errConnClosed } + var di *discoInfo + switch { + case isRelayHandshakeMsg: + var ok bool + di, ok = c.relayManager.discoInfo(dstDisco) + if !ok { + c.mu.Unlock() + return false, errors.New("unknown relay server") + } + case c.peerMap.knownPeerDiscoKey(dstDisco): + di = c.discoInfoForKnownPeerLocked(dstDisco) + default: + // This is an attempt to send to an unknown peer that is not a relay + // server. This can happen when a call to the current function, which is + // often via a new goroutine, races with applying a change in the + // netmap, e.g. the associated peer(s) for dstDisco goes away. + c.mu.Unlock() + return false, errors.New("unknown peer") + } + c.mu.Unlock() + pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters. if geneveVNI != nil { gh := packet.GeneveHeader{ @@ -1640,23 +1666,6 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, geneveVNI *uint32, dstKey ke } pkt = append(pkt, disco.Magic...) pkt = c.discoPublic.AppendTo(pkt) - var di *discoInfo - if !isRelayHandshakeMsg { - di = c.discoInfoLocked(dstDisco) - } else { - // c.discoInfoLocked() caches [*discoInfo] for dstDisco. It assumes that - // dstDisco is a known Tailscale peer, and will be cleaned around - // network map changes. In the case of a relay handshake message, - // dstDisco belongs to a relay server with a disco key that is - // discovered at endpoint allocation time or [disco.CallMeMaybeVia] - // reception time. There is no clear ending to its lifetime, so we - // can't cache with the same strategy. Instead, generate the shared - // key on the fly for now. - di = &discoInfo{ - sharedKey: c.discoPrivate.Shared(dstDisco), - } - } - c.mu.Unlock() if isDERP { metricSendDiscoDERP.Add(1) @@ -1707,6 +1716,45 @@ const ( discoRXPathRawSocket discoRXPath = "raw socket" ) +const discoHeaderLen = len(disco.Magic) + key.DiscoPublicRawLen + +// isDiscoMaybeGeneve reports whether msg is a Tailscale Disco protocol +// message, and if true, whether it is encapsulated by a Geneve header. +// +// isGeneveEncap is only relevant when isDiscoMsg is true. +// +// Naked Disco, Geneve followed by Disco, and naked WireGuard can be confidently +// distinguished based on the following: +// 1. [disco.Magic] is sufficiently non-overlapping with a Geneve protocol +// field value of [packet.GeneveProtocolDisco]. +// 2. [disco.Magic] is sufficiently non-overlapping with the first 4 bytes of +// a WireGuard packet. +// 3. [packet.GeneveHeader] with a Geneve protocol field value of +// [packet.GeneveProtocolDisco] is sufficiently non-overlapping with the +// first 4 bytes of a WireGuard packet. +func isDiscoMaybeGeneve(msg []byte) (isDiscoMsg bool, isGeneveEncap bool) { + if len(msg) < discoHeaderLen { + return false, false + } + if string(msg[:len(disco.Magic)]) == disco.Magic { + return true, false + } + if len(msg) < packet.GeneveFixedHeaderLength+discoHeaderLen { + return false, false + } + if msg[0]&0xC0 != 0 || // version bits that we always transmit as 0s + msg[1]&0x3F != 0 || // reserved bits that we always transmit as 0s + binary.BigEndian.Uint16(msg[2:4]) != packet.GeneveProtocolDisco || + msg[7] != 0 { // reserved byte that we always transmit as 0 + return false, false + } + msg = msg[packet.GeneveFixedHeaderLength:] + if string(msg[:len(disco.Magic)]) == disco.Magic { + return true, true + } + return false, false +} + // handleDiscoMessage handles a discovery message and reports whether // msg was a Tailscale inter-node discovery message. // @@ -1722,18 +1770,28 @@ const ( // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { - const headerLen = len(disco.Magic) + key.DiscoPublicRawLen - if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { - return false + isDiscoMsg, isGeneveEncap := isDiscoMaybeGeneve(msg) + if !isDiscoMsg { + return } + var geneve packet.GeneveHeader + if isGeneveEncap { + err := geneve.Decode(msg) + if err != nil { + // Decode only returns an error when 'msg' is too short, and + // 'isGeneveEncap' indicates it's a sufficient length. + c.logf("[unexpected] geneve header decoding error: %v", err) + return + } + msg = msg[packet.GeneveFixedHeaderLength:] + } + // The control bit should only be set for relay handshake messages + // terminating on or originating from a UDP relay server. We have yet to + // open the encrypted payload to determine the [disco.MessageType], but + // we assert it should be handshake-related. + shouldBeRelayHandshakeMsg := isGeneveEncap && geneve.Control - // If the first four parts are the prefix of disco.Magic - // (0x5453f09f) then it's definitely not a valid WireGuard - // packet (which starts with little-endian uint32 1, 2, 3, 4). - // Use naked returns for all following paths. - isDiscoMsg = true - - sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):headerLen])) + sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):discoHeaderLen])) c.mu.Lock() defer c.mu.Unlock() @@ -1750,7 +1808,20 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } - if !c.peerMap.knownPeerDiscoKey(sender) { + var di *discoInfo + switch { + case shouldBeRelayHandshakeMsg: + var ok bool + di, ok = c.relayManager.discoInfo(sender) + if !ok { + if debugDisco() { + c.logf("magicsock: disco: ignoring disco-looking relay handshake frame, no active handshakes with key %v over VNI %d", sender.ShortString(), geneve.VNI) + } + return + } + case c.peerMap.knownPeerDiscoKey(sender): + di = c.discoInfoForKnownPeerLocked(sender) + default: metricRecvDiscoBadPeer.Add(1) if debugDisco() { c.logf("magicsock: disco: ignoring disco-looking frame, don't know of key %v", sender.ShortString()) @@ -1759,7 +1830,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } isDERP := src.Addr() == tailcfg.DerpMagicIPAddr - if !isDERP { + if !isDERP && !shouldBeRelayHandshakeMsg { // Record receive time for UDP transport packets. pi, ok := c.peerMap.byIPPort[src] if ok { @@ -1767,17 +1838,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } } - // We're now reasonably sure we're expecting communication from - // this peer, do the heavy crypto lifting to see what they want. - // - // From here on, peerNode and de are non-nil. + // We're now reasonably sure we're expecting communication from 'sender', + // do the heavy crypto lifting to see what they want. - di := c.discoInfoLocked(sender) - - sealedBox := msg[headerLen:] + sealedBox := msg[discoHeaderLen:] payload, ok := di.sharedKey.Open(sealedBox) if !ok { - // This might be have been intended for a previous + // This might have been intended for a previous // disco key. When we restart we get a new disco key // and old packets might've still been in flight (or // scheduled). This is particularly the case for LANs @@ -1820,6 +1887,19 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke metricRecvDiscoUDP.Add(1) } + if shouldBeRelayHandshakeMsg { + challenge, ok := dm.(*disco.BindUDPRelayEndpointChallenge) + if !ok { + // We successfully parsed the disco message, but it wasn't a + // challenge. We should never receive other message types + // from a relay server with the Geneve header control bit set. + c.logf("[unexpected] %T packets should not come from a relay server with Geneve control bit set", dm) + return + } + c.relayManager.handleBindUDPRelayEndpointChallenge(challenge, di, src, geneve.VNI) + return + } + switch dm := dm.(type) { case *disco.Ping: metricRecvDiscoPing.Add(1) @@ -1835,18 +1915,28 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } return true }) - case *disco.CallMeMaybe: + case *disco.CallMeMaybe, *disco.CallMeMaybeVia: + var via *disco.CallMeMaybeVia + isVia := false + msgType := "CallMeMaybe" + cmm, ok := dm.(*disco.CallMeMaybe) + if !ok { + via = dm.(*disco.CallMeMaybeVia) + msgType = "CallMeMaybeVia" + isVia = true + } + metricRecvDiscoCallMeMaybe.Add(1) if !isDERP || derpNodeSrc.IsZero() { - // CallMeMaybe messages should only come via DERP. - c.logf("[unexpected] CallMeMaybe packets should only come via DERP") + // CallMeMaybe{Via} messages should only come via DERP. + c.logf("[unexpected] %s packets should only come via DERP", msgType) return } nodeKey := derpNodeSrc ep, ok := c.peerMap.endpointForNodeKey(nodeKey) if !ok { metricRecvDiscoCallMeMaybeBadNode.Add(1) - c.logf("magicsock: disco: ignoring CallMeMaybe from %v; %v is unknown", sender.ShortString(), derpNodeSrc.ShortString()) + c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString()) return } epDisco := ep.disco.Load() @@ -1855,14 +1945,23 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } if epDisco.key != di.discoKey { metricRecvDiscoCallMeMaybeBadDisco.Add(1) - c.logf("[unexpected] CallMeMaybe from peer via DERP whose netmap discokey != disco source") + c.logf("[unexpected] %s from peer via DERP whose netmap discokey != disco source", msgType) return } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", - c.discoShort, epDisco.short, - ep.publicKey.ShortString(), derpStr(src.String()), - len(dm.MyNumber)) - go ep.handleCallMeMaybe(dm) + if isVia { + c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", + c.discoShort, epDisco.short, via.ServerDisco.ShortString(), + ep.publicKey.ShortString(), derpStr(src.String()), + len(via.AddrPorts)) + c.relayManager.handleCallMeMaybeVia(via) + } else { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", + c.discoShort, epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + len(cmm.MyNumber)) + go ep.handleCallMeMaybe(cmm) + } + } return } @@ -2034,10 +2133,15 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) { } } -// discoInfoLocked returns the previous or new discoInfo for k. +// discoInfoForKnownPeerLocked returns the previous or new discoInfo for k. +// +// Callers must only pass key.DiscoPublic's that are present in and +// lifetime-managed via [Conn].peerMap. UDP relay server disco keys are discovered +// at relay endpoint allocation time or [disco.CallMeMaybeVia] reception time +// and therefore must never pass through this method. // // c.mu must be held. -func (c *Conn) discoInfoLocked(k key.DiscoPublic) *discoInfo { +func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { di, ok := c.discoInfo[k] if !ok { di = &discoInfo{ diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index f50f21f56..1a899ea22 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -3155,3 +3155,165 @@ func TestNetworkDownSendErrors(t *testing.T) { t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String()) } } + +func Test_isDiscoMaybeGeneve(t *testing.T) { + discoPub := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 30: 30, 31: 31})) + nakedDisco := make([]byte, 0, 512) + nakedDisco = append(nakedDisco, disco.Magic...) + nakedDisco = discoPub.AppendTo(nakedDisco) + + geneveEncapDisco := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err := gh.Encode(geneveEncapDisco) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDisco[packet.GeneveFixedHeaderLength:], nakedDisco) + + nakedWireGuardInitiation := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardInitiation, device.MessageInitiationType) + nakedWireGuardResponse := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardResponse, device.MessageResponseType) + nakedWireGuardCookieReply := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardCookieReply, device.MessageCookieReplyType) + nakedWireGuardTransport := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardTransport, device.MessageTransportType) + + geneveEncapWireGuard := make([]byte, packet.GeneveFixedHeaderLength+len(nakedWireGuardInitiation)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolWireGuard, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapWireGuard) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapWireGuard[packet.GeneveFixedHeaderLength:], nakedWireGuardInitiation) + + geneveEncapDiscoNonZeroGeneveVersion := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 1, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVersion) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDiscoNonZeroGeneveVersion[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveReservedBits := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveReservedBits) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveReservedBits[1] |= 0x3F + copy(geneveEncapDiscoNonZeroGeneveReservedBits[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveVNILSB := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVNILSB) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveVNILSB[7] |= 0xFF + copy(geneveEncapDiscoNonZeroGeneveVNILSB[packet.GeneveFixedHeaderLength:], nakedDisco) + + tests := []struct { + name string + msg []byte + wantIsDiscoMsg bool + wantIsGeneveEncap bool + }{ + { + name: "naked disco", + msg: nakedDisco, + wantIsDiscoMsg: true, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco", + msg: geneveEncapDisco, + wantIsDiscoMsg: true, + wantIsGeneveEncap: true, + }, + { + name: "geneve encap disco nonzero geneve version", + msg: geneveEncapDiscoNonZeroGeneveVersion, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve reserved bits", + msg: geneveEncapDiscoNonZeroGeneveReservedBits, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve vni lsb", + msg: geneveEncapDiscoNonZeroGeneveVNILSB, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap wireguard", + msg: geneveEncapWireGuard, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Initiation type", + msg: nakedWireGuardInitiation, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Response type", + msg: nakedWireGuardResponse, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Cookie Reply type", + msg: nakedWireGuardCookieReply, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Transport type", + msg: nakedWireGuardTransport, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsDiscoMsg, gotIsGeneveEncap := isDiscoMaybeGeneve(tt.msg) + if gotIsDiscoMsg != tt.wantIsDiscoMsg { + t.Errorf("isDiscoMaybeGeneve() gotIsDiscoMsg = %v, want %v", gotIsDiscoMsg, tt.wantIsDiscoMsg) + } + if gotIsGeneveEncap != tt.wantIsGeneveEncap { + t.Errorf("isDiscoMaybeGeneve() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap) + } + }) + } +} diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go new file mode 100644 index 000000000..bf737b078 --- /dev/null +++ b/wgengine/magicsock/relaymanager.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "sync" + + "tailscale.com/disco" + "tailscale.com/types/key" +) + +// relayManager manages allocation and handshaking of +// [tailscale.com/net/udprelay.Server] endpoints. The zero value is ready for +// use. +type relayManager struct { + mu sync.Mutex // guards the following fields + discoInfoByServerDisco map[key.DiscoPublic]*discoInfo +} + +func (h *relayManager) initLocked() { + if h.discoInfoByServerDisco != nil { + return + } + h.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo) +} + +// discoInfo returns a [*discoInfo] for 'serverDisco' if there is an +// active/ongoing handshake with it, otherwise it returns nil, false. +func (h *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) { + h.mu.Lock() + defer h.mu.Unlock() + h.initLocked() + di, ok := h.discoInfoByServerDisco[serverDisco] + return di, ok +} + +func (h *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) { + h.mu.Lock() + defer h.mu.Unlock() + h.initLocked() + // TODO(jwhited): implement +} + +func (h *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) { + h.mu.Lock() + defer h.mu.Unlock() + h.initLocked() + // TODO(jwhited): implement +}