From 678a3bf88a316448dbb0ffdbaeb25ad54ac1256b Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 14 Nov 2023 18:46:40 -0800 Subject: [PATCH] cmd/tsidp: fix HTTP mux layering Change-Id: I08fd8206fc8e6b405a2f2b03028f5f17866f3442 Signed-off-by: Brad Fitzpatrick --- cmd/tsidp/tsidp.go | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 8fb781233..492a2cec8 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -69,16 +69,14 @@ func main() { if err != nil { log.Fatal(err) } - mux := http.NewServeMux() - srv.Register(mux) - - log.Fatal(http.Serve(ln, mux)) + log.Fatal(http.Serve(ln, srv)) } type idpServer struct { lc *tailscale.LocalClient serverURL string // "https://foo.bar.ts.net" + lazyMux lazy.SyncValue[*http.ServeMux] lazySigningKey lazy.SyncValue[*signingKey] lazySigner lazy.SyncValue[jose.Signer] @@ -108,15 +106,6 @@ type authRequest struct { validTill time.Time } -func (s *idpServer) Register(mux *http.ServeMux) { - mux.Handle(oidcJWKSPath, http.HandlerFunc(s.serveJWKS)) - mux.Handle(oidcConfigPath, http.HandlerFunc(s.serveOpenIDConfig)) - mux.Handle("/authorize/", http.HandlerFunc(s.authorize)) - mux.Handle("/userinfo", http.HandlerFunc(s.serveUserInfo)) - mux.Handle("/token", http.HandlerFunc(s.serveToken)) - mux.Handle("/", s) -} - func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { @@ -149,15 +138,26 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, u, http.StatusFound) } +func (s *idpServer) newMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc(oidcJWKSPath, s.serveJWKS) + mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig) + mux.HandleFunc("/authorize/", s.authorize) + mux.HandleFunc("/userinfo", s.serveUserInfo) + mux.HandleFunc("/token", s.serveToken) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + io.WriteString(w, "

Tailscale OIDC IdP

") + return + } + http.Error(w, "tsidp: not found", http.StatusNotFound) + }) + return mux +} + func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("%v %v", r.Method, r.URL) - - if r.URL.Path == "/" { - io.WriteString(w, "

Tailscale OIDC IdP

") - return - } - - http.Error(w, "tsidp: not found", http.StatusNotFound) + s.lazyMux.Get(s.newMux).ServeHTTP(w, r) } func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {