diff --git a/cmd/lopower/lopower.go b/cmd/lopower/lopower.go index faccf3a1e..6c975d37b 100644 --- a/cmd/lopower/lopower.go +++ b/cmd/lopower/lopower.go @@ -27,6 +27,8 @@ import ( "sync/atomic" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" qrcode "github.com/skip2/go-qrcode" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" @@ -47,6 +49,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/syncs" "tailscale.com/tsnet" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/must" @@ -76,6 +79,11 @@ type config struct { V6CIDR netip.Prefix } +// IsLocalIP reports whether ip is one of the local IPs. +func (c *config) IsLocalIP(ip netip.Addr) bool { + return ip.IsValid() && (ip == c.V4 || ip == c.V6) +} + type Peer struct { PrivKey key.NodePrivate // e.g. proxy client's V4 netip.Addr @@ -168,6 +176,7 @@ func newLP(ctx context.Context) *lpServer { Errorf: logf, } lp := &lpServer{ + ctx: ctx, dir: *confDir, readCh: make(chan *stack.PacketBuffer, 16), } @@ -197,8 +206,9 @@ type lpServer struct { tsnet *tsnet.Server d *device.Device ns *stack.Stack + ctx context.Context // canceled on shutdown linkEP *channel.Endpoint - readCh chan *stack.PacketBuffer + readCh chan *stack.PacketBuffer // from gvisor/dns server => out to network // protocolConns tracks the number of active connections for each connection. // It is used to add and remove protocol addresses from netstack as needed. @@ -553,6 +563,7 @@ func (t *nsTUN) Close() error { return nil } +// Read reads packets from gvisor (or the DNS server) to send out to the network. func (t *nsTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { select { case <-t.closeCh: @@ -578,12 +589,18 @@ func (t *nsTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { func (t *nsTUN) Write(buffs [][]byte, offset int) (int, error) { var pkt packet.Parsed for _, buff := range buffs { - pkt.Decode(buff[offset:]) + raw := buff[offset:] + pkt.Decode(raw) packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(slices.Clone(buff[offset:])), + Payload: buffer.MakeWithData(slices.Clone(raw)), }) if *verbosePackets { - log.Printf("[v] nsTUN.Write (in): % 02x", buff[offset:]) + log.Printf("[v] nsTUN.Write (in): % 02x", raw) + } + if pkt.IPProto == ipproto.UDP && pkt.Dst.Port() == 53 && t.lp.c.IsLocalIP(pkt.Dst.Addr()) { + // Handle DNS queries before sending to gvisor. + t.lp.handleDNSUDPQuery(raw) + continue } if pkt.IPVersion == 4 { t.lp.linkEP.InjectInbound(ipv4.ProtocolNumber, packetBuf) @@ -618,6 +635,119 @@ func (lp *lpServer) startTSNet(ctx context.Context) { } } +// caller owns the raw memory. +func (lp *lpServer) handleDNSUDPQuery(raw []byte) { + var pkt packet.Parsed + pkt.Decode(raw) + if pkt.IPProto != ipproto.UDP || pkt.Dst.Port() != 53 || !lp.c.IsLocalIP(pkt.Dst.Addr()) { + panic("caller error") + } + m, ok := lp.tsnet.Sys().DNSManager.GetOK() + if !ok { + log.Printf("DNSManager.Get: not ready") + return + } + dnsRes, err := m.Query(context.Background(), pkt.Payload(), "udp", pkt.Src) + if err != nil { + log.Printf("DNS query error: %v", err) + return + } + + ipLayer := mkIPLayer(layers.IPProtocolUDP, pkt.Dst.Addr(), pkt.Src.Addr()) + udpLayer := &layers.UDP{ + SrcPort: 53, + DstPort: layers.UDPPort(pkt.Src.Port()), + } + + resPkt, err := mkPacket(ipLayer, udpLayer, gopacket.Payload(dnsRes)) + if err != nil { + log.Printf("mkPacket: %v", err) + return + } + pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(resPkt), + }) + select { + case lp.readCh <- pktBuf: + case <-lp.ctx.Done(): + } +} + +type serializableNetworkLayer interface { + gopacket.SerializableLayer + gopacket.NetworkLayer +} + +func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetworkLayer { + if src.Is4() { + return &layers.IPv4{ + Protocol: proto, + SrcIP: src.AsSlice(), + DstIP: dst.AsSlice(), + } + } + if src.Is6() { + return &layers.IPv6{ + NextHeader: proto, + SrcIP: src.AsSlice(), + DstIP: dst.AsSlice(), + } + } + panic("invalid src IP") +} + +// mkPacket is a serializes a number of layers into a packet. +// +// It's a convenience wrapper around gopacket.SerializeLayers +// that does some things automatically: +// +// * layers.IPv4/IPv6 Version is set to 4/6 if not already set +// * layers.IPv4/IPv6 TTL/HopLimit is set to 64 if not already set +// * the TCP/UDP/ICMPv6 checksum is set based on the network layer +// +// The provided layers in ll must be sorted from lowest (e.g. *layers.Ethernet) +// to highest. (Depending on the need, the first layer will be either *layers.Ethernet +// or *layers.IPv4/IPv6). +func mkPacket(ll ...gopacket.SerializableLayer) ([]byte, error) { + var nl gopacket.NetworkLayer + for _, la := range ll { + switch la := la.(type) { + case *layers.IPv4: + nl = la + if la.Version == 0 { + la.Version = 4 + } + if la.TTL == 0 { + la.TTL = 64 + } + case *layers.IPv6: + nl = la + if la.Version == 0 { + la.Version = 6 + } + if la.HopLimit == 0 { + la.HopLimit = 64 + } + } + } + for _, la := range ll { + switch la := la.(type) { + case *layers.TCP: + la.SetNetworkLayerForChecksum(nl) + case *layers.UDP: + la.SetNetworkLayerForChecksum(nl) + case *layers.ICMPv6: + la.SetNetworkLayerForChecksum(nl) + } + } + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if err := gopacket.SerializeLayers(buf, opts, ll...); err != nil { + return nil, fmt.Errorf("serializing packet: %v", err) + } + return buf.Bytes(), nil +} + func main() { flag.Parse() log.Printf("lopower starting")