diff --git a/cmd/tta/bypass_linux.go b/cmd/tta/bypass_linux.go new file mode 100644 index 000000000..868cd716f --- /dev/null +++ b/cmd/tta/bypass_linux.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/unix" + "tailscale.com/net/netmon" +) + +// bypassControlFunc is set as net.Dialer.Control so that sockets dialed by +// TTA bypass tailscaled's policy routing. Without it, sockets opened before +// tailscaled installs an exit-node route would have their packets rerouted +// via the exit node when the route is later installed, breaking the +// existing connection. +// +// We bind the socket to the default route's interface (typically the VM's +// LAN-facing NIC) rather than relying on the bypass fwmark. The fwmark +// approach is conditional on tailscaled having configured SO_MARK-based +// policy routing; binding to the underlying interface is unconditional. +func bypassControlFunc(network, address string, c syscall.RawConn) error { + ifc, err := netmon.DefaultRouteInterface() + if err != nil { + return fmt.Errorf("netmon.DefaultRouteInterface: %w", err) + } + var sockErr error + if err := c.Control(func(fd uintptr) { + sockErr = unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifc) + }); err != nil { + return err + } + if sockErr != nil { + return fmt.Errorf("setting SO_BINDTODEVICE on %q: %w", ifc, sockErr) + } + return nil +} diff --git a/cmd/tta/bypass_other.go b/cmd/tta/bypass_other.go new file mode 100644 index 000000000..e6b453f49 --- /dev/null +++ b/cmd/tta/bypass_other.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package main + +import "syscall" + +// bypassControlFunc is a no-op on non-Linux platforms; SO_MARK is a Linux +// concept and exit-node routing only matters here for Linux VMs in vmtest. 
// bypassControlFunc is the non-Linux stand-in for the net.Dialer.Control
// hook used by connect. It does nothing and always returns nil: the Linux
// build binds the socket to the default-route interface (SO_BINDTODEVICE)
// to sidestep exit-node policy routing, and that only matters for Linux
// VMs in vmtest.
func bypassControlFunc(network, address string, c syscall.RawConn) error {
	return nil
}
+func (e *Env) SetExitNode(client, exitNode *Node) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var ip netip.Addr + if exitNode != nil { + st, err := exitNode.agent.Status(ctx) + if err != nil { + e.t.Fatalf("SetExitNode: status for %s: %v", exitNode.name, err) + } + if len(st.Self.TailscaleIPs) == 0 { + e.t.Fatalf("SetExitNode: %s has no Tailscale IPs", exitNode.name) + } + ip = st.Self.TailscaleIPs[0] + } + + if _, err := client.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: "", + ExitNodeIP: ip, + }, + ExitNodeIDSet: true, + ExitNodeIPSet: true, + }); err != nil { + e.t.Fatalf("SetExitNode(%s -> %v): %v", client.name, exitNode, err) + } + if exitNode == nil { + e.t.Logf("[%s] cleared exit node", client.name) + } else { + e.t.Logf("[%s] using exit node %s (%v)", client.name, exitNode.name, ip) + } +} + // ApproveRoutes tells the test control server to approve subnet routes // for the given node. The routes should be CIDR strings. func (e *Env) ApproveRoutes(n *Node, routes ...string) { diff --git a/tstest/natlab/vmtest/vmtest_test.go b/tstest/natlab/vmtest/vmtest_test.go index 89e5a022f..c2d0329f7 100644 --- a/tstest/natlab/vmtest/vmtest_test.go +++ b/tstest/natlab/vmtest/vmtest_test.go @@ -127,3 +127,106 @@ func testSiteToSite(t *testing.T, srOS vmtest.OSImage) { t.Fatalf("source IP not preserved: expected %q in response, got %q", backendAIP, body) } } + +// TestInterNetworkTCP verifies that vnet routes raw TCP between simulated +// networks: a non-Tailscale VM on one NAT'd LAN can reach a webserver on a +// different network using a 1:1 NAT, and the webserver sees the client's +// network's WAN IP as the source (post-NAT). 
+func TestInterNetworkTCP(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + webWAN = "5.0.0.1" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet()) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + env.Start() + + body := env.HTTPGet(client, fmt.Sprintf("http://%s:8080/", webWAN)) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + t.Fatalf("unexpected response: %q", body) + } + if !strings.Contains(body, "from "+clientWAN) { + t.Fatalf("expected source %q in response, got %q", clientWAN, body) + } +} + +// TestExitNode verifies that switching the client's exit node setting between +// off, exit1, and exit2 correctly routes the client's internet traffic. +// +// Topology: each of the client and the two exit nodes lives behind its own NAT +// with a unique WAN IP, and a webserver lives on yet another network using a +// 1:1 NAT so it's reachable from the simulated internet at a stable address. +// The webserver echoes the source IP of incoming requests, so we can tell +// which network's NAT the client's traffic egressed through: +// - off: source is the client's network WAN IP. +// - exit1: source is exit1's network WAN IP. +// - exit2: source is exit2's network WAN IP. 
+func TestExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + exit1WAN = "2.0.0.1" + exit2WAN = "3.0.0.1" + webWAN = "5.0.0.1" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + exit1Net := env.AddNetwork(exit1WAN, "192.168.2.1/24", vnet.EasyNAT) + exit2Net := env.AddNetwork(exit2WAN, "192.168.3.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + exit1 := env.AddNode("exit1", exit1Net, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + exit2 := env.AddNode("exit2", exit2Net, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + env.Start() + env.ApproveRoutes(exit1, "0.0.0.0/0", "::/0") + env.ApproveRoutes(exit2, "0.0.0.0/0", "::/0") + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + tests := []struct { + name string // subtest name + exit *vmtest.Node + wantSrc string + }{ + {"off", nil, clientWAN}, + {"exit1", exit1, exit1WAN}, + {"exit2", exit2, exit2WAN}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + env.SetExitNode(client, tt.exit) + body := env.HTTPGet(client, webURL) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + t.Fatalf("unexpected webserver response: %q", body) + } + if !strings.Contains(body, "from "+tt.wantSrc) { + t.Fatalf("expected source %q in response, got %q", tt.wantSrc, body) + } + }) + } +} diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index f7ecdb828..f3c1108ed 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -1161,6 +1161,23 @@ func (s *Server) handleEthernetFrameFromVM(packetRaw []byte) error { return nil } +// routeTCPPacket forwards a TCP packet to the network owning the +// 
destination IP (looked up by WAN IP). Used for inter-network TCP +// forwarding so guest VM TCP stacks talk end-to-end through vnet's +// packet-level NAT. +func (s *Server) routeTCPPacket(tp TCPPacket) { + dstIP := tp.Dst.Addr() + netw, ok := s.networkByWAN.Lookup(dstIP) + if !ok { + if dstIP.IsPrivate() { + return + } + log.Printf("no network to route TCP packet for %v", tp.Dst) + return + } + netw.HandleTCPPacket(tp) +} + func (s *Server) routeUDPPacket(up UDPPacket) { // Find which network owns this based on the destination IP // and all the known networks' wan IPs. @@ -1397,6 +1414,65 @@ func (n *network) nodeByIP(ip netip.Addr) (node *node, ok bool) { return node, ok } +// HandleTCPPacket handles a TCP packet arriving from the simulated +// internet, addressed to the network's WAN IP. It NATs the destination +// back to a LAN node and writes the rewritten packet onto the LAN. +func (n *network) HandleTCPPacket(p TCPPacket) { + buf, err := n.serializedTCPPacket(p.Src, p.Dst, p.TCP, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.wanInterfaceID, + }, buf) + if p.Dst.Addr().Is4() && n.breakWAN4 { + return + } + dst := n.doNATIn(p.Src, p.Dst) + if !dst.IsValid() { + n.logf("Warning: NAT dropped TCP packet; no mapping for %v=>%v", p.Src, p.Dst) + return + } + p.Dst = dst + buf, err = n.serializedTCPPacket(p.Src, p.Dst, p.TCP, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.lanInterfaceID, + }, buf) + n.WriteTCPPacketNoNAT(p) +} + +// WriteTCPPacketNoNAT writes a TCP packet to the network without doing +// any NAT translation. The src/dst in p must already be in their final +// form for the LAN. 
+func (n *network) WriteTCPPacketNoNAT(p TCPPacket) { + node, ok := n.nodeByIP(p.Dst.Addr()) + if !ok { + n.logf("no node for dest IP %v in TCP packet %v=>%v", p.Dst.Addr(), p.Src, p.Dst) + return + } + eth := &layers.Ethernet{ + SrcMAC: n.mac.HWAddr(), + DstMAC: node.macForNet(n).HWAddr(), + } + ethRaw, err := n.serializedTCPPacket(p.Src, p.Dst, p.TCP, eth) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.writeEth(ethRaw) +} + // WriteUDPPacketNoNAT writes a UDP packet to the network, without // doing any NAT translation. // @@ -1446,6 +1522,27 @@ func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetwork panic("invalid src IP") } +// serializedTCPPacket serializes a TCP packet with the given src/dst, +// using the provided TCP layer (its flags, seq/ack, window, options, +// and payload are preserved; only the src/dst ports are overwritten). +// +// If eth is non-nil, it is used as the Ethernet layer, otherwise the +// Ethernet layer is omitted. +func (n *network) serializedTCPPacket(src, dst netip.AddrPort, tcp *layers.TCP, eth *layers.Ethernet) ([]byte, error) { + ip := mkIPLayer(layers.IPProtocolTCP, src.Addr(), dst.Addr()) + // Copy the TCP layer with new ports and a zeroed checksum so + // gopacket recomputes it against the new IP pseudo-header. + newTCP := *tcp + newTCP.SrcPort = layers.TCPPort(src.Port()) + newTCP.DstPort = layers.TCPPort(dst.Port()) + newTCP.Checksum = 0 + payload := gopacket.Payload(tcp.Payload) + if eth == nil { + return mkPacket(ip, &newTCP, payload) + } + return mkPacket(eth, ip, &newTCP, payload) +} + // serializedUDPPacket serializes a UDP packet with the given source and // destination IP:port pairs, and payload. // @@ -1517,6 +1614,19 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { return } + // Inter-network TCP forwarding: a guest VM is sending TCP to another + // simulated network's WAN IP. 
Apply egress NAT (rewriting src) and + // hand the packet off to the destination network for ingress NAT and + // LAN delivery, so the two guest TCP stacks talk end-to-end. + if toForward && flow.dst.Is4() { + if tcp, ok := packet.Layer(layers.LayerTypeTCP).(*layers.TCP); ok { + if _, ok := n.s.networkByWAN.Lookup(flow.dst); ok { + n.handleTCPPacketForRouter(tcp, flow) + return + } + } + } + if flow.src.Is6() && flow.src.IsLinkLocalUnicast() && !flow.dst.IsLinkLocalUnicast() { // Don't log. return @@ -1531,6 +1641,54 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { n.logf("router got unknown packet: %v", packet) } +// handleTCPPacketForRouter handles a TCP packet from a LAN node that +// targets another simulated network's WAN IP. It rewrites src via the +// local NAT, then routes the packet to the destination network where +// HandleTCPPacket rewrites dst and delivers it to the LAN. +func (n *network) handleTCPPacketForRouter(tcp *layers.TCP, flow ipSrcDst) { + if flow.dst.Is4() && n.breakWAN4 { + return + } + src := netip.AddrPortFrom(flow.src, uint16(tcp.SrcPort)) + dst := netip.AddrPortFrom(flow.dst, uint16(tcp.DstPort)) + + buf, err := n.serializedTCPPacket(src, dst, tcp, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.lanInterfaceID, + }, buf) + + lanSrc := src + src = n.doNATOut(src, dst) + if !src.IsValid() { + n.logf("warning: NAT dropped TCP packet; no NAT out mapping for %v=>%v", lanSrc, dst) + return + } + buf, err = n.serializedTCPPacket(src, dst, tcp, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.wanInterfaceID, + }, buf) + + n.s.routeTCPPacket(TCPPacket{ + Src: src, + Dst: dst, + 
TCP: tcp, + }) +} + func (n *network) handleUDPPacketForRouter(ep EthernetPacket, udp *layers.UDP, toForward bool, flow ipSrcDst) { packet := ep.gp srcIP, dstIP := flow.src, flow.dst @@ -2320,6 +2478,17 @@ type UDPPacket struct { Payload []byte // everything after UDP header } +// TCPPacket is a TCP packet flowing through vnet's NAT, used for +// packet-level TCP forwarding between simulated networks. Unlike UDP +// (which only needs ports + payload), TCP carries flags, sequence +// numbers, and options that must be preserved end-to-end so the guest +// VM kernels' TCP state machines stay in sync. +type TCPPacket struct { + Src netip.AddrPort + Dst netip.AddrPort + TCP *layers.TCP // full parsed TCP layer (header + options + payload) +} + func (s *Server) WriteStartingBanner(w io.Writer) { fmt.Fprintf(w, "vnet serving clients:\n")