diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 480c65d7d..f79f64573 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/binary" + "encoding/hex" "encoding/json" "encoding/pem" "flag" @@ -16,6 +17,7 @@ import ( "log" "net/http" "net/netip" + "net/url" "os" "strings" "sync" @@ -23,12 +25,15 @@ import ( "github.com/golang-jwt/jwt" "gopkg.in/square/go-jose.v2" "tailscale.com/client/tailscale" + "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" "tailscale.com/types/key" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/views" + "tailscale.com/util/mak" "tailscale.com/util/must" ) @@ -71,9 +76,12 @@ type idpServer struct { lc *tailscale.LocalClient serverURL string // "https://foo.bar.ts.net" - oidcSignerInitOnce sync.Once - oidcSignerLazy jose.Signer - oidcSignerError error + lazySigningKey lazy.SyncValue[*signingKey] + lazySigner lazy.SyncValue[jose.Signer] + + mu sync.Mutex // guards the fields below + + code map[string]*apitype.WhoIsResponse // code -> whois } func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -97,11 +105,32 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.URL.Path == "/authorize" { - redir := r.URL.Query().Get("redirect_uri") + who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - http.Redirect(w, r, redir, http.StatusFound) + code := must.Get(readHex()) + + s.mu.Lock() + mak.Set(&s.code, code, who) + s.mu.Unlock() + + q := make(url.Values) + q.Set("code", code) + q.Set("state", r.URL.Query().Get("state")) + u := r.URL.Query().Get("redirect_uri") + "?" + q.Encode() + log.Printf("Redirecting to %q", u) + + http.Redirect(w, r, u, http.StatusFound) return } + + if r.URL.Path == "/token" { + + } http.Error(w, "tsidp: not found", http.StatusNotFound) } @@ -111,24 +140,44 @@ const ( ) func (s *idpServer) oidcSigner() (jose.Signer, error) { - s.oidcSignerInitOnce.Do(s.oidcSignerInit) - return s.oidcSignerLazy, s.oidcSignerError + return s.lazySigner.GetErr(func() (jose.Signer, error) { + sk, err := s.oidcPrivateKey() + if err != nil { + return nil, err + } + return jose.NewSigner(jose.SigningKey{ + Algorithm: jose.RS256, + Key: sk.k, + }, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{ + jose.HeaderType: "JWT", + "kid": fmt.Sprint(sk.kid), + }}) + }) } -func (s *idpServer) oidcSignerInit() { - id, k := s.oidcPrivateKey() - s.oidcSignerLazy, s.oidcSignerError = jose.NewSigner(jose.SigningKey{ - Algorithm: jose.RS256, - Key: k, - }, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{ - jose.HeaderType: "JWT", - "kid": fmt.Sprint(id), - }}) -} - -func (s *idpServer) oidcPrivateKey() (id uint64, k *rsa.PrivateKey) { - id, k = mustGenRSAKey(2048) - return +func (s *idpServer) oidcPrivateKey() (*signingKey, error) { + return s.lazySigningKey.GetErr(func() (*signingKey, error) { + var sk signingKey + b, err := os.ReadFile("oidc-key.json") + if err == nil { + if err := sk.UnmarshalJSON(b); err == nil { + return &sk, nil + } else { + log.Printf("Error unmarshaling key: %v", err) + } + } + id, k := mustGenRSAKey(2048) + sk.k = k + sk.kid = id + b, err = sk.MarshalJSON() + if err != nil { + log.Fatalf("Error marshaling key: %v", err) + } + if err := os.WriteFile("oidc-key.json", b, 0600); err != nil { + log.Fatalf("Error writing key: %v", err) + } + return &sk, nil + }) } func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error { @@ -136,16 +185,21 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error { return tsweb.Error(404, "", nil) } w.Header().Set("Content-Type", "application/json") - id, k := s.oidcPrivateKey() + sk, err := s.oidcPrivateKey() + if err != nil { + return tsweb.Error(500, err.Error(), err) + } // TODO(maisem): maybe only marshal this once and reuse? // TODO(maisem): implement key rotation. - if err := json.NewEncoder(w).Encode(jose.JSONWebKeySet{ + je := json.NewEncoder(w) + je.SetIndent("", " ") + if err := je.Encode(jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - Key: k.Public(), + Key: sk.k.Public(), Algorithm: string(jose.RS256), Use: "sig", - KeyID: fmt.Sprint(id), + KeyID: fmt.Sprint(sk.kid), }, }, }); err != nil { @@ -189,7 +243,7 @@ type tailscaleClaims struct { var ( openIDSupportedClaims = views.SliceOf([]string{ // Standard claims, these correspond to fields in jwt.Claims. - "sub", "aud", "exp", "iat", "iss", "jti", "nbf", + "sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email", // Tailscale claims, these correspond to fields in tailscaleClaims. "key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid", @@ -215,8 +269,10 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er return tsweb.Error(404, "", nil) } w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(io.MultiWriter(w, os.Stderr)).Encode(openIDProviderMetadata{ - Issuer: s.serverURL + "/", + je := json.NewEncoder(io.MultiWriter(w, os.Stderr)) + je.SetIndent("", " ") + if err := je.Encode(openIDProviderMetadata{ + Issuer: s.serverURL, JWKS_URI: s.serverURL + oidcJWKSPath, UserInfoEndpoint: s.serverURL + "/userinfo", AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add / suffix @@ -247,6 +303,14 @@ func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) { return } +func readHex() (string, error) { + var proxyCred [16]byte + if _, err := crand.Read(proxyCred[:]); err != nil { + return "", err + } + return hex.EncodeToString(proxyCred[:]), nil +} + // readUint64 reads from r until 8 bytes represent a non-zero uint64. func readUint64(r io.Reader) (uint64, error) { for { @@ -267,31 +331,41 @@ type rsaPrivateKeyJSONWrapper struct { ID uint64 } -func marshalKeyJSON(k *rsa.PrivateKey, kid uint64) ([]byte, error) { +type signingKey struct { + k *rsa.PrivateKey + kid uint64 +} + +func (sk *signingKey) MarshalJSON() ([]byte, error) { b := pem.Block{ Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(k), + Bytes: x509.MarshalPKCS1PrivateKey(sk.k), } bts := pem.EncodeToMemory(&b) return json.Marshal(rsaPrivateKeyJSONWrapper{ Key: base64.URLEncoding.EncodeToString(bts), - ID: kid, + ID: sk.kid, }) } -func unmarshalKeyJSON(b []byte) (*rsa.PrivateKey, uint64, error) { +func (sk *signingKey) UnmarshalJSON(b []byte) error { var wrapper rsaPrivateKeyJSONWrapper if err := json.Unmarshal(b, &wrapper); err != nil { - return nil, 0, err + return err } if len(wrapper.Key) == 0 { - return nil, 0, nil + return nil } b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key) if err != nil { - return nil, 0, err + return err } blk, _ := pem.Decode(b64dec) k, err := x509.ParsePKCS1PrivateKey(blk.Bytes) - return k, wrapper.ID, err + if err != nil { + return err + } + sk.k = k + sk.kid = wrapper.ID + return nil }