diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 492a2cec8..a3e867265 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -18,6 +18,7 @@ import ( "net/netip" "net/url" "os" + "strconv" "strings" "sync" "time" @@ -80,16 +81,18 @@ type idpServer struct { lazySigningKey lazy.SyncValue[*signingKey] lazySigner lazy.SyncValue[jose.Signer] - mu sync.Mutex // guards the fields below - code map[string]*authRequest - accessToken map[string]*authRequest + mu sync.Mutex // guards the fields below + code map[string]*authRequest // keyed by random hex + accessToken map[string]*authRequest // keyed by random hex } type authRequest struct { - // requesterNodeID is the node who requested the auth (say synology), not the node - // who is being authenticated. - // String form of tailcfg.NodeID - requesterNodeID string + // rpNodeID is the NodeID of the relying party (who requested the auth, such + // as Proxmox or Synology), not the user node who is being authenticated. + rpNodeID tailcfg.NodeID + + // clientID is the "client_id" sent in the authorized request. + clientID string // nonce presented in the request. nonce string @@ -114,15 +117,21 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { return } - nodeID := strings.TrimPrefix(r.URL.Path, "/authorize/") + rpNodeID, ok := parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/")) + if !ok { + http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest) + return + } uq := r.URL.Query() + code := rands.HexString(32) ar := &authRequest{ - requesterNodeID: nodeID, - nonce: uq.Get("nonce"), - remoteUser: who, - redirectURI: uq.Get("redirect_uri"), + rpNodeID: rpNodeID, + nonce: uq.Get("nonce"), + remoteUser: who, + redirectURI: uq.Get("redirect_uri"), + clientID: uq.Get("client_id"), } s.mu.Lock() @@ -184,7 +193,7 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: invalid token", http.StatusBadRequest) return } - if ar.requesterNodeID != who.Node.ID.String() { + if ar.rpNodeID != who.Node.ID { http.Error(w, "tsidp: token for different node", http.StatusForbidden) return } @@ -245,13 +254,15 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { } s.mu.Lock() ar, ok := s.code[code] - delete(s.code, code) + if ok { + delete(s.code, code) + } s.mu.Unlock() if !ok { http.Error(w, "tsidp: code not found", http.StatusBadRequest) return } - if ar.requesterNodeID != caller.Node.ID.String() { + if ar.rpNodeID != caller.Node.ID { http.Error(w, "tsidp: token for different node", http.StatusForbidden) return } @@ -280,7 +291,7 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { _, tcd, _ := strings.Cut(n.Name(), ".") tsClaims := tailscaleClaims{ Claims: jwt.Claims{ - Audience: jwt.Audience{"unused"}, + Audience: jwt.Audience{ar.clientID}, Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), ID: jti, IssuedAt: jwt.NewNumericDate(now), @@ -479,7 +490,7 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - authorizeEndpoint := fmt.Sprintf("%s/authorize/%s", s.serverURL, who.Node.ID.String()) + authorizeEndpoint := fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) w.Header().Set("Content-Type", "application/json") je := json.NewEncoder(w) @@ -573,3 +584,19 @@ func (sk *signingKey) UnmarshalJSON(b []byte) error { sk.kid = wrapper.ID return nil } + +// parseID takes a string input and returns a typed IntID T and true, or a zero +// value and false if the input is unhandled syntax or out of a valid range. +func parseID[T ~int64](input string) (_ T, ok bool) { + if input == "" { + return 0, false + } + i, err := strconv.ParseInt(input, 10, 64) + if err != nil { + return 0, false + } + if i < 0 { + return 0, false + } + return T(i), true +}