diff --git a/cmd/tailproxy/proxy.go b/cmd/tailproxy/proxy.go new file mode 100644 index 000000000..1bcaa7859 --- /dev/null +++ b/cmd/tailproxy/proxy.go @@ -0,0 +1,107 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net" + "slices" + "strings" + "sync" + + "github.com/inetaf/tcpproxy" + "tailscale.com/client/tailscale" + "tailscale.com/net/netutil" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/must" +) + +type proxyGrantRule struct { + AllowedHosts []dnsname.FQDN +} + +func handleConn(ctx context.Context, c net.Conn, lc *tailscale.LocalClient, dialCtx func(context.Context, string, string) (net.Conn, error)) { + addrPortStr := c.LocalAddr().String() + _, port, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr) + c.Close() + return + } + who, err := lc.WhoIs(ctx, c.RemoteAddr().String()) + if err != nil { + c.Close() + log.Printf("tcpSNIHandler.Handle: WhoIs: %v", err) + return + } + rules, err := tailcfg.UnmarshalCapJSON[proxyGrantRule](who.CapMap, "maisem.com/tailproxy") + if err != nil { + c.Close() + log.Printf("tcpSNIHandler.Handle: UnmarshalCapJSON: %v", err) + return + } + + var p tcpproxy.Proxy + p.ListenFunc = func(net, laddr string) (net.Listener, error) { + return netutil.NewOneConnListener(c, nil), nil + } + p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) { + sniFQDN, err := dnsname.ToFQDN(sniName) + if err != nil { + log.Printf("tcpSNIHandler.Handle: ToFQDN: %v", err) + return nil, false + } + for _, rule := range rules { + if slices.ContainsFunc(rule.AllowedHosts, func(fqdn dnsname.FQDN) bool { + return fqdn == "*" || fqdn.Contains(sniFQDN) + }) { + log.Printf("tcpSNIHandler.Handle: %s is allowed", sniName) + return &tcpproxy.DialProxy{ + Addr: net.JoinHostPort(sniName, port), + DialContext: dialCtx, + }, true + } + } + log.Printf("tcpSNIHandler.Handle: %s is not allowed", sniName) + return nil, false + }) + p.Start() +} + +func main() { + var ( + ports = flag.String("ports", "443", "comma-separated list of ports to proxy") + hostname = flag.String("hostname", "", "Hostname to register the service under") + ) + flag.Parse() + + ctx := context.Background() + s := &tsnet.Server{ + Hostname: *hostname, + Logf: logger.Discard, + } + must.Get(s.Up(ctx)) + var wg sync.WaitGroup + log.Printf("Listening on ports: %s", *ports) + for _, port := range strings.Split(*ports, ",") { + wg.Add(1) + ln := must.Get(s.Listen("tcp", ":"+port)) + lc := must.Get(s.LocalClient()) + go func() { + defer wg.Done() + for { + c, err := ln.Accept() + if err != nil { + continue + } + fmt.Println("Accepted connection") + go handleConn(ctx, c, lc, s.Dial) + } + }() + } + wg.Wait() +}