diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 8fb781233..492a2cec8 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -69,16 +69,14 @@ func main() { if err != nil { log.Fatal(err) } - mux := http.NewServeMux() - srv.Register(mux) - - log.Fatal(http.Serve(ln, mux)) + log.Fatal(http.Serve(ln, srv)) } type idpServer struct { lc *tailscale.LocalClient serverURL string // "https://foo.bar.ts.net" + lazyMux lazy.SyncValue[*http.ServeMux] lazySigningKey lazy.SyncValue[*signingKey] lazySigner lazy.SyncValue[jose.Signer] @@ -108,15 +106,6 @@ type authRequest struct { validTill time.Time } -func (s *idpServer) Register(mux *http.ServeMux) { - mux.Handle(oidcJWKSPath, http.HandlerFunc(s.serveJWKS)) - mux.Handle(oidcConfigPath, http.HandlerFunc(s.serveOpenIDConfig)) - mux.Handle("/authorize/", http.HandlerFunc(s.authorize)) - mux.Handle("/userinfo", http.HandlerFunc(s.serveUserInfo)) - mux.Handle("/token", http.HandlerFunc(s.serveToken)) - mux.Handle("/", s) -} - func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { @@ -149,15 +138,26 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, u, http.StatusFound) } +func (s *idpServer) newMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc(oidcJWKSPath, s.serveJWKS) + mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig) + mux.HandleFunc("/authorize/", s.authorize) + mux.HandleFunc("/userinfo", s.serveUserInfo) + mux.HandleFunc("/token", s.serveToken) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + io.WriteString(w, "

Tailscale OIDC IdP

") + return + } + http.Error(w, "tsidp: not found", http.StatusNotFound) + }) + return mux +} + func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("%v %v", r.Method, r.URL) - - if r.URL.Path == "/" { - io.WriteString(w, "

Tailscale OIDC IdP

") - return - } - - http.Error(w, "tsidp: not found", http.StatusNotFound) + s.lazyMux.Get(s.newMux).ServeHTTP(w, r) } func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {