diff --git a/cmd/iam-store.go b/cmd/iam-store.go index 6d025e8b0..26fe52e3a 100644 --- a/cmd/iam-store.go +++ b/cmd/iam-store.go @@ -2371,7 +2371,7 @@ func (store *IAMStoreSys) UpdateServiceAccount(ctx context.Context, accessKey st return updatedAt, err } - if len(policyBuf) > 2048 { + if len(policyBuf) > maxSVCSessionPolicySize { return updatedAt, errSessionPolicyTooLarge } diff --git a/cmd/iam.go b/cmd/iam.go index 8f6361e5a..283545ba8 100644 --- a/cmd/iam.go +++ b/cmd/iam.go @@ -78,6 +78,10 @@ const ( inheritedPolicyType = "inherited-policy" ) +const ( + maxSVCSessionPolicySize = 4096 +) + // IAMSys - config system. type IAMSys struct { // Need to keep them here to keep alignment - ref: https://golang.org/pkg/sync/atomic/#pkg-note-BUG @@ -977,7 +981,7 @@ func (sys *IAMSys) NewServiceAccount(ctx context.Context, parentUser string, gro if err != nil { return auth.Credentials{}, time.Time{}, err } - if len(policyBuf) > 2048 { + if len(policyBuf) > maxSVCSessionPolicySize { return auth.Credentials{}, time.Time{}, errSessionPolicyTooLarge } } diff --git a/cmd/sts-handlers.go b/cmd/sts-handlers.go index 2bcf9e434..fe46f6feb 100644 --- a/cmd/sts-handlers.go +++ b/cmd/sts-handlers.go @@ -22,9 +22,11 @@ import ( "context" "crypto/x509" "encoding/base64" + "encoding/json" "errors" "fmt" "net/http" + "net/url" "strconv" "strings" "time" @@ -82,8 +84,50 @@ const ( // Role Claim key roleArnClaim = "roleArn" + + // maximum supported STS session policy size + maxSTSSessionPolicySize = 2048 ) +type stsClaims map[string]interface{} + +func (c stsClaims) populateSessionPolicy(form url.Values) error { + if len(form) == 0 { + return nil + } + + sessionPolicyStr := form.Get(stsPolicy) + if len(sessionPolicyStr) == 0 { + return nil + } + + sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr))) + if err != nil { + return err + } + + // Version in policy must not be empty + if sessionPolicy.Version == "" { + return errors.New("Version cannot be empty expecting '2012-10-17'") + } + + policyBuf, err := json.Marshal(sessionPolicy) + if err != nil { + return err + } + + // https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + // https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html + // The plain text that you use for both inline and managed session + // policies shouldn't exceed maxSTSSessionPolicySize characters. + if len(policyBuf) > maxSTSSessionPolicySize { + return errSessionPolicyTooLarge + } + + c[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString(policyBuf) + return nil +} + // stsAPIHandlers implements and provides http handlers for AWS STS API. type stsAPIHandlers struct{} @@ -212,7 +256,7 @@ func getTokenSigningKey() (string, error) { func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) { ctx := newContext(r, w, "AssumeRole") - claims := make(map[string]interface{}) + claims := stsClaims{} defer logger.AuditLog(ctx, w, r, claims) // Check auth here (otherwise r.Form will have unexpected values from @@ -249,29 +293,11 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) { return } - sessionPolicyStr := r.Form.Get(stsPolicy) - // https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html - // The plain text that you use for both inline and managed session - // policies shouldn't exceed 2048 characters. - if len(sessionPolicyStr) > 2048 { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, errSessionPolicyTooLarge) + if err := claims.populateSessionPolicy(r.Form); err != nil { + writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) return } - if len(sessionPolicyStr) > 0 { - sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr))) - if err != nil { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) - return - } - - // Version in policy must not be empty - if sessionPolicy.Version == "" { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Version cannot be empty expecting '2012-10-17'")) - return - } - } - duration, err := openid.GetDefaultExpiration(r.Form.Get(stsDurationSeconds)) if err != nil { writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) @@ -288,10 +314,6 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) { return } - if len(sessionPolicyStr) > 0 { - claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr)) - } - secret, err := getTokenSigningKey() if err != nil { writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err) @@ -342,7 +364,7 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) { func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Request) { ctx := newContext(r, w, "AssumeRoleSSOCommon") - claims := make(map[string]interface{}) + claims := stsClaims{} defer logger.AuditLog(ctx, w, r, claims) // Parse the incoming form data. @@ -449,31 +471,11 @@ func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Requ claims[iamPolicyClaimNameOpenID()] = policyName } - sessionPolicyStr := r.Form.Get(stsPolicy) - // https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html - // The plain text that you use for both inline and managed session - // policies shouldn't exceed 2048 characters. - if len(sessionPolicyStr) > 2048 { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Session policy should not exceed 2048 characters")) + if err := claims.populateSessionPolicy(r.Form); err != nil { + writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) return } - if len(sessionPolicyStr) > 0 { - sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr))) - if err != nil { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) - return - } - - // Version in policy must not be empty - if sessionPolicy.Version == "" { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Invalid session policy version")) - return - } - - claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr)) - } - secret, err := getTokenSigningKey() if err != nil { writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err) @@ -612,7 +614,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithClientGrants(w http.ResponseWriter, r * func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *http.Request) { ctx := newContext(r, w, "AssumeRoleWithLDAPIdentity") - claims := make(map[string]interface{}) + claims := stsClaims{} defer logger.AuditLog(ctx, w, r, claims, stsLDAPPassword) // Parse the incoming form data. @@ -643,29 +645,11 @@ func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r * return } - sessionPolicyStr := r.Form.Get(stsPolicy) - // https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html - // The plain text that you use for both inline and managed session - // policies shouldn't exceed 2048 characters. - if len(sessionPolicyStr) > 2048 { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Session policy should not exceed 2048 characters")) + if err := claims.populateSessionPolicy(r.Form); err != nil { + writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) return } - if len(sessionPolicyStr) > 0 { - sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr))) - if err != nil { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) - return - } - - // Version in policy must not be empty - if sessionPolicy.Version == "" { - writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Version needs to be specified in session policy")) - return - } - } - if !globalIAMSys.Initialized() { writeSTSErrorResponse(ctx, w, ErrSTSIAMNotInitialized, errIAMNotInitialized) return @@ -708,10 +692,6 @@ func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r * claims[ldapAttribPrefix+attrib] = value } - if len(sessionPolicyStr) > 0 { - claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr)) - } - secret, err := getTokenSigningKey() if err != nil { writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err) diff --git a/internal/config/identity/openid/jwt.go b/internal/config/identity/openid/jwt.go index 5813cade8..89acb814b 100644 --- a/internal/config/identity/openid/jwt.go +++ b/internal/config/identity/openid/jwt.go @@ -133,7 +133,7 @@ const ( ) // Validate - validates the id_token. -func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error { +func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims map[string]interface{}) error { jp := new(jwtgo.Parser) jp.ValidMethods = []string{ "RS256", "RS384", "RS512", @@ -156,14 +156,15 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, return fmt.Errorf("Role %s does not exist", arn) } - jwtToken, err := jp.ParseWithClaims(token, &claims, keyFuncCallback) + mclaims := jwtgo.MapClaims(claims) + jwtToken, err := jp.ParseWithClaims(token, &mclaims, keyFuncCallback) if err != nil { // Re-populate the public key in-case the JWKS // pubkeys are refreshed if err = r.PopulatePublicKey(arn); err != nil { return err } - jwtToken, err = jwtgo.ParseWithClaims(token, &claims, keyFuncCallback) + jwtToken, err = jwtgo.ParseWithClaims(token, &mclaims, keyFuncCallback) if err != nil { return err } @@ -173,11 +174,11 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, return ErrTokenExpired } - if err = updateClaimsExpiry(dsecs, claims); err != nil { + if err = updateClaimsExpiry(dsecs, mclaims); err != nil { return err } - if err = r.updateUserinfoClaims(ctx, arn, accessToken, claims); err != nil { + if err = r.updateUserinfoClaims(ctx, arn, accessToken, mclaims); err != nil { return err } @@ -190,7 +191,7 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, // array of case sensitive strings. In the common special case // when there is one audience, the aud value MAY be a single // case sensitive - audValues, ok := policy.GetValuesFromClaims(claims, audClaim) + audValues, ok := policy.GetValuesFromClaims(mclaims, audClaim) if !ok { return errors.New("STS JWT Token has `aud` claim invalid, `aud` must match configured OpenID Client ID") } @@ -204,7 +205,7 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, // be included even when the authorized party is the same // as the sole audience. The azp value is a case sensitive // string containing a StringOrURI value - azpValues, ok := policy.GetValuesFromClaims(claims, azpClaim) + azpValues, ok := policy.GetValuesFromClaims(mclaims, azpClaim) if !ok { return errors.New("STS JWT Token has `azp` claim invalid, `azp` must match configured OpenID Client ID") }