Aws auth fixes (#9825)

* Bring over PSIRT-37 changes from ENT

* Add additional allowed headers

* Already had this one

* Change to string slice comma separated parsing

* Add allowed_sts_header_values to read output

* Only validate AWS related request headers

* one per line

* Import ordering

* Update test

* Add X-Amz-Credential

* Reorder imports
This commit is contained in:
Scott Miller 2020-08-25 17:37:59 -05:00 committed by GitHub
parent cca11493ce
commit ade448cd47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 129 additions and 12 deletions

View File

@ -17,6 +17,15 @@ import (
cache "github.com/patrickmn/go-cache" cache "github.com/patrickmn/go-cache"
) )
const amzHeaderPrefix = "X-Amz-"
var defaultAllowedSTSRequestHeaders = []string{
"X-Amz-Date",
"X-Amz-Credential",
"X-Amz-Security-Token",
"X-Amz-Algorithm",
"X-Amz-Signature",
"X-Amz-SignedHeaders"}
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf) b, err := Backend(conf)
if err != nil { if err != nil {

View File

@ -2,9 +2,14 @@ package awsauth
import ( import (
"context" "context"
"errors"
"net/http"
"net/textproto"
"strings"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@ -53,6 +58,11 @@ func (b *backend) pathConfigClient() *framework.Path {
Default: "", Default: "",
Description: "Value to require in the X-Vault-AWS-IAM-Server-ID request header", Description: "Value to require in the X-Vault-AWS-IAM-Server-ID request header",
}, },
"allowed_sts_header_values": {
Type: framework.TypeCommaStringSlice,
Default: nil,
Description: "List of additional headers that are allowed to be in AWS STS request headers",
},
"max_retries": { "max_retries": {
Type: framework.TypeInt, Type: framework.TypeInt,
Default: aws.UseServiceDefaultRetries, Default: aws.UseServiceDefaultRetries,
@ -136,6 +146,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"sts_region": clientConfig.STSRegion, "sts_region": clientConfig.STSRegion,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue, "iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries, "max_retries": clientConfig.MaxRetries,
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
}, },
}, nil }, nil
} }
@ -257,6 +268,24 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
configEntry.IAMServerIdHeaderValue = data.Get("iam_server_id_header_value").(string) configEntry.IAMServerIdHeaderValue = data.Get("iam_server_id_header_value").(string)
} }
aHeadersValStr, ok := data.GetOk("allowed_sts_header_values")
if ok {
aHeadersValSl := aHeadersValStr.([]string)
for i, v := range aHeadersValSl {
aHeadersValSl[i] = textproto.CanonicalMIMEHeaderKey(v)
}
if !strutil.EquivalentSlices(configEntry.AllowedSTSHeaderValues, aHeadersValSl) {
// NOT setting changedCreds here, since this isn't really cached
configEntry.AllowedSTSHeaderValues = aHeadersValSl
changedOtherConfig = true
}
} else if req.Operation == logical.CreateOperation {
ah, ok := data.GetOk("allowed_sts_header_values")
if ok {
configEntry.AllowedSTSHeaderValues = ah.([]string)
}
}
maxRetriesInt, ok := data.GetOk("max_retries") maxRetriesInt, ok := data.GetOk("max_retries")
if ok { if ok {
configEntry.MaxRetries = maxRetriesInt.(int) configEntry.MaxRetries = maxRetriesInt.(int)
@ -293,14 +322,27 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
// Struct to hold 'aws_access_key' and 'aws_secret_key' that are required to // Struct to hold 'aws_access_key' and 'aws_secret_key' that are required to
// interact with the AWS EC2 API. // interact with the AWS EC2 API.
type clientConfig struct { type clientConfig struct {
AccessKey string `json:"access_key"` AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"` SecretKey string `json:"secret_key"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"` IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"` STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"` STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
MaxRetries int `json:"max_retries"` AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
MaxRetries int `json:"max_retries"`
}
func (c *clientConfig) validateAllowedSTSHeaderValues(headers http.Header) error {
for k := range headers {
h := textproto.CanonicalMIMEHeaderKey(k)
if strings.HasPrefix(h, amzHeaderPrefix) &&
!strutil.StrListContains(defaultAllowedSTSRequestHeaders, h) &&
!strutil.StrListContains(c.AllowedSTSHeaderValues, h) {
return errors.New("invalid request header: " + k)
}
}
return nil
} }
const pathConfigClientHelpSyn = ` const pathConfigClientHelpSyn = `

View File

@ -7,6 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"encoding/xml" "encoding/xml"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -43,6 +44,11 @@ const (
retryWaitMax = 30 * time.Second retryWaitMax = 30 * time.Second
) )
var (
errRequestBodyNotValid = errors.New("iam request body is invalid")
errInvalidGetCallerIdentityResponse = errors.New("body of GetCallerIdentity is invalid")
)
func (b *backend) pathLogin() *framework.Path { func (b *backend) pathLogin() *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "login$", Pattern: "login$",
@ -1179,7 +1185,10 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
if err != nil { if err != nil {
return logical.ErrorResponse("error parsing iam_request_url"), nil return logical.ErrorResponse("error parsing iam_request_url"), nil
} }
if parsedUrl.RawQuery != "" {
// Should be no query parameters
return logical.ErrorResponse(logical.ErrInvalidRequest.Error()), nil
}
// TODO: There are two potentially valid cases we're not yet supporting that would // TODO: There are two potentially valid cases we're not yet supporting that would
// necessitate this check being changed. First, if we support GET requests. // necessitate this check being changed. First, if we support GET requests.
// Second if we support presigned POST requests // Second if we support presigned POST requests
@ -1192,6 +1201,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("failed to base64 decode iam_request_body"), nil return logical.ErrorResponse("failed to base64 decode iam_request_body"), nil
} }
body := string(bodyRaw) body := string(bodyRaw)
if err = validateLoginIamRequestBody(body); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
headers := data.Get("iam_request_headers").(http.Header) headers := data.Get("iam_request_headers").(http.Header)
if len(headers) == 0 { if len(headers) == 0 {
@ -1213,6 +1225,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(fmt.Sprintf("error validating %s header: %v", iamServerIdHeader, err)), nil return logical.ErrorResponse(fmt.Sprintf("error validating %s header: %v", iamServerIdHeader, err)), nil
} }
} }
if err = config.validateAllowedSTSHeaderValues(headers); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
if config.STSEndpoint != "" { if config.STSEndpoint != "" {
endpoint = config.STSEndpoint endpoint = config.STSEndpoint
} }
@ -1394,6 +1409,29 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
}, nil }, nil
} }
// Validate that the iam_request_body passed is valid for the STS request
func validateLoginIamRequestBody(body string) error {
qs, err := url.ParseQuery(body)
if err != nil {
return err
}
for k, v := range qs {
switch k {
case "Action":
if len(v) != 1 || v[0] != "GetCallerIdentity" {
return errRequestBodyNotValid
}
case "Version":
// Will assume for now that future versions don't change
// the semantics
default:
// Not expecting any other values
return errRequestBodyNotValid
}
}
return nil
}
// These two methods (hasValuesFor*) return two bools // These two methods (hasValuesFor*) return two bools
// The first is a hasAll, that is, does the request have all the values // The first is a hasAll, that is, does the request have all the values
// necessary for this auth method // necessary for this auth method
@ -1559,8 +1597,12 @@ func ensureHeaderIsSigned(signedHeaders, headerToSign string) error {
} }
func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, error) { func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, error) {
decoder := xml.NewDecoder(strings.NewReader(response))
result := GetCallerIdentityResponse{} result := GetCallerIdentityResponse{}
response = strings.TrimSpace(response)
if !strings.HasPrefix(response, "<GetCallerIdentityResponse") && !strings.HasPrefix(response, "<?xml") {
return result, errInvalidGetCallerIdentityResponse
}
decoder := xml.NewDecoder(strings.NewReader(response))
err := decoder.Decode(&result) err := decoder.Decode(&result)
return result, err return result, err
} }
@ -1596,6 +1638,11 @@ func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, en
if response != nil { if response != nil {
defer response.Body.Close() defer response.Body.Close()
} }
// Validate that the response type is XML
if ct := response.Header.Get("Content-Type"); ct != "text/xml" {
return nil, errInvalidGetCallerIdentityResponse
}
// we check for status code afterwards to also print out response body // we check for status code afterwards to also print out response body
responseBody, err := ioutil.ReadAll(response.Body) responseBody, err := ioutil.ReadAll(response.Body)
if err != nil { if err != nil {

View File

@ -305,6 +305,19 @@ func TestBackend_pathLogin_IAMHeaders(t *testing.T) {
}, },
ExpectErr: missingHeaderErr, ExpectErr: missingHeaderErr,
}, },
{
Name: "Map-illegal-header",
Header: map[string]interface{}{
"Content-Length": "43",
"Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
"User-Agent": "aws-sdk-go/1.14.24 (go1.11; darwin; amd64)",
"X-Amz-Date": "20180910T203328Z",
"Authorization": "AWS4-HMAC-SHA256 Credential=AKIAJPQ466AIIQW4LPSQ/20180910/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-vault-aws-iam-server-id, Signature=cdef5819b2e97f1ff0f3e898fd2621aa03af00a4ec3e019122c20e5482534bf4",
"X-Vault-Aws-Iam-Server-Id": "VaultAcceptanceTesting",
"X-Amz-Mallory-Header": "<?xml><h4ck0r/>",
},
ExpectErr: errors.New("invalid request header: X-Amz-Mallory-Header"),
},
{ {
Name: "JSON-complete", Name: "JSON-complete",
Header: `{ Header: `{
@ -543,7 +556,8 @@ func setupIAMTestServer() *httptest.Server {
<ResponseMetadata> <ResponseMetadata>
<RequestId>7f4fc40c-853a-11e6-8848-8d035d01eb87</RequestId> <RequestId>7f4fc40c-853a-11e6-8848-8d035d01eb87</RequestId>
</ResponseMetadata> </ResponseMetadata>
</GetCallerIdentityResponse>` </GetCallerIdentityResponse>
`
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
parts := strings.Split(auth, ",") parts := strings.Split(auth, ",")
@ -566,6 +580,7 @@ func setupIAMTestServer() *httptest.Server {
if matchingCount != len(expectedAuthParts) { if matchingCount != len(expectedAuthParts) {
responseString = "missing auth parts" responseString = "missing auth parts"
} }
w.Header().Add("Content-Type", "text/xml")
fmt.Fprintln(w, responseString) fmt.Fprintln(w, responseString)
})) }))
} }

View File

@ -66,7 +66,11 @@ capabilities, the credentials are fetched automatically.
signed headers validated by AWS. This is to protect against different types of signed headers validated by AWS. This is to protect against different types of
replay attacks, for example a signed request sent to a dev server being resent replay attacks, for example a signed request sent to a dev server being resent
to a production server. Consider setting this to the Vault server's DNS name. to a production server. Consider setting this to the Vault server's DNS name.
- `allowed_sts_header_values` `(string: "")` A comma separated list of
additional request headers permitted when providing the iam_request_headers for
an IAM based login call. In any case, a default list of headers AWS STS
expects for a GetCallerIdentity are allowed.
### Sample Payload ### Sample Payload
```json ```json