From 3fa5c76cba4b24a3859c8e10264f17c70e9316a9 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Tue, 14 Nov 2023 23:53:20 -0800 Subject: [PATCH] cmd/tsidp: fix tsnet listener Signed-off-by: Maisem Ali --- cmd/tsidp/tsidp.go | 57 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 5c62eea4d..8a2f9e310 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -56,6 +56,8 @@ func main() { lc *tailscale.LocalClient st *ipnstate.Status err error + + lns []net.Listener ) if *flagUseLocalTailscaled { lc = &tailscale.LocalClient{} @@ -63,6 +65,23 @@ func main() { if err != nil { log.Fatalf("getting status: %v", err) } + portStr := fmt.Sprint(*flagPort) + anySuccess := false + for _, ip := range st.TailscaleIPs { + ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr)) + if err != nil { + log.Printf("failed to listen on %v: %v", ip, err) + continue + } + anySuccess = true + ln = tls.NewListener(ln, &tls.Config{ + GetCertificate: lc.GetCertificate, + }) + lns = append(lns, ln) + } + if !anySuccess { + log.Fatalf("failed to listen on any of %v", st.TailscaleIPs) + } } else { ts := &tsnet.Server{ Hostname: "idp", @@ -78,34 +97,38 @@ func main() { if err != nil { log.Fatalf("getting local client: %v", err) } + ln, err := ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort)) + if err != nil { + log.Fatal(err) + } + lns = append(lns, ln) } srv := &idpServer{ - lc: lc, - serverURL: fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort), + lc: lc, } + if *flagPort != 443 { + srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort) + } else { + srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, ".")) + } + log.Printf("Running tsidp at %s ...", srv.serverURL) if *flagLocalPort != -1 { + log.Printf("Also running tsidp at %s ...", srv.loopbackURL) srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort) - go func() { - ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort)) - if err != nil { - log.Fatal(err) - } - log.Printf("Also running tsidp at %s ...", srv.loopbackURL) - http.Serve(ln, srv) - }() + ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort)) + if err != nil { + log.Fatal(err) + } + lns = append(lns, ln) } - ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", st.TailscaleIPs[0], *flagPort)) - if err != nil { - log.Fatal(err) + for _, ln := range lns { + go http.Serve(ln, srv) } - ln = tls.NewListener(ln, &tls.Config{ - GetCertificate: lc.GetCertificate, - }) - log.Fatal(http.Serve(ln, srv)) + select {} } type idpServer struct {