mirror of
https://github.com/hashicorp/vault.git
synced 2025-11-28 14:11:10 +01:00
Aws auth fixes (#9825)
* Bring over PSIRT-37 changes from ENT * Add additional allowed headers * Already had this one * Change to string slice comma separated parsing * Add allowed_sts_header_values to read output * Only validate AWS related request headers * one per line * Import ordering * Update test * Add X-Amz-Credential * Reorder imports
This commit is contained in:
parent
cca11493ce
commit
ade448cd47
@ -17,6 +17,15 @@ import (
|
|||||||
cache "github.com/patrickmn/go-cache"
|
cache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const amzHeaderPrefix = "X-Amz-"
|
||||||
|
var defaultAllowedSTSRequestHeaders = []string{
|
||||||
|
"X-Amz-Date",
|
||||||
|
"X-Amz-Credential",
|
||||||
|
"X-Amz-Security-Token",
|
||||||
|
"X-Amz-Algorithm",
|
||||||
|
"X-Amz-Signature",
|
||||||
|
"X-Amz-SignedHeaders"}
|
||||||
|
|
||||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||||
b, err := Backend(conf)
|
b, err := Backend(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -2,9 +2,14 @@ package awsauth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/hashicorp/vault/sdk/framework"
|
"github.com/hashicorp/vault/sdk/framework"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,6 +58,11 @@ func (b *backend) pathConfigClient() *framework.Path {
|
|||||||
Default: "",
|
Default: "",
|
||||||
Description: "Value to require in the X-Vault-AWS-IAM-Server-ID request header",
|
Description: "Value to require in the X-Vault-AWS-IAM-Server-ID request header",
|
||||||
},
|
},
|
||||||
|
"allowed_sts_header_values": {
|
||||||
|
Type: framework.TypeCommaStringSlice,
|
||||||
|
Default: nil,
|
||||||
|
Description: "List of additional headers that are allowed to be in AWS STS request headers",
|
||||||
|
},
|
||||||
"max_retries": {
|
"max_retries": {
|
||||||
Type: framework.TypeInt,
|
Type: framework.TypeInt,
|
||||||
Default: aws.UseServiceDefaultRetries,
|
Default: aws.UseServiceDefaultRetries,
|
||||||
@ -136,6 +146,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
|
|||||||
"sts_region": clientConfig.STSRegion,
|
"sts_region": clientConfig.STSRegion,
|
||||||
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
|
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
|
||||||
"max_retries": clientConfig.MaxRetries,
|
"max_retries": clientConfig.MaxRetries,
|
||||||
|
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -257,6 +268,24 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
|||||||
configEntry.IAMServerIdHeaderValue = data.Get("iam_server_id_header_value").(string)
|
configEntry.IAMServerIdHeaderValue = data.Get("iam_server_id_header_value").(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
aHeadersValStr, ok := data.GetOk("allowed_sts_header_values")
|
||||||
|
if ok {
|
||||||
|
aHeadersValSl := aHeadersValStr.([]string)
|
||||||
|
for i, v := range aHeadersValSl {
|
||||||
|
aHeadersValSl[i] = textproto.CanonicalMIMEHeaderKey(v)
|
||||||
|
}
|
||||||
|
if !strutil.EquivalentSlices(configEntry.AllowedSTSHeaderValues, aHeadersValSl) {
|
||||||
|
// NOT setting changedCreds here, since this isn't really cached
|
||||||
|
configEntry.AllowedSTSHeaderValues = aHeadersValSl
|
||||||
|
changedOtherConfig = true
|
||||||
|
}
|
||||||
|
} else if req.Operation == logical.CreateOperation {
|
||||||
|
ah, ok := data.GetOk("allowed_sts_header_values")
|
||||||
|
if ok {
|
||||||
|
configEntry.AllowedSTSHeaderValues = ah.([]string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
maxRetriesInt, ok := data.GetOk("max_retries")
|
maxRetriesInt, ok := data.GetOk("max_retries")
|
||||||
if ok {
|
if ok {
|
||||||
configEntry.MaxRetries = maxRetriesInt.(int)
|
configEntry.MaxRetries = maxRetriesInt.(int)
|
||||||
@ -293,14 +322,27 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
|||||||
// Struct to hold 'aws_access_key' and 'aws_secret_key' that are required to
|
// Struct to hold 'aws_access_key' and 'aws_secret_key' that are required to
|
||||||
// interact with the AWS EC2 API.
|
// interact with the AWS EC2 API.
|
||||||
type clientConfig struct {
|
type clientConfig struct {
|
||||||
AccessKey string `json:"access_key"`
|
AccessKey string `json:"access_key"`
|
||||||
SecretKey string `json:"secret_key"`
|
SecretKey string `json:"secret_key"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
IAMEndpoint string `json:"iam_endpoint"`
|
IAMEndpoint string `json:"iam_endpoint"`
|
||||||
STSEndpoint string `json:"sts_endpoint"`
|
STSEndpoint string `json:"sts_endpoint"`
|
||||||
STSRegion string `json:"sts_region"`
|
STSRegion string `json:"sts_region"`
|
||||||
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
|
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
|
||||||
MaxRetries int `json:"max_retries"`
|
AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
|
||||||
|
MaxRetries int `json:"max_retries"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConfig) validateAllowedSTSHeaderValues(headers http.Header) error {
|
||||||
|
for k := range headers {
|
||||||
|
h := textproto.CanonicalMIMEHeaderKey(k)
|
||||||
|
if strings.HasPrefix(h, amzHeaderPrefix) &&
|
||||||
|
!strutil.StrListContains(defaultAllowedSTSRequestHeaders, h) &&
|
||||||
|
!strutil.StrListContains(c.AllowedSTSHeaderValues, h) {
|
||||||
|
return errors.New("invalid request header: " + k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathConfigClientHelpSyn = `
|
const pathConfigClientHelpSyn = `
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -43,6 +44,11 @@ const (
|
|||||||
retryWaitMax = 30 * time.Second
|
retryWaitMax = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errRequestBodyNotValid = errors.New("iam request body is invalid")
|
||||||
|
errInvalidGetCallerIdentityResponse = errors.New("body of GetCallerIdentity is invalid")
|
||||||
|
)
|
||||||
|
|
||||||
func (b *backend) pathLogin() *framework.Path {
|
func (b *backend) pathLogin() *framework.Path {
|
||||||
return &framework.Path{
|
return &framework.Path{
|
||||||
Pattern: "login$",
|
Pattern: "login$",
|
||||||
@ -1179,7 +1185,10 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return logical.ErrorResponse("error parsing iam_request_url"), nil
|
return logical.ErrorResponse("error parsing iam_request_url"), nil
|
||||||
}
|
}
|
||||||
|
if parsedUrl.RawQuery != "" {
|
||||||
|
// Should be no query parameters
|
||||||
|
return logical.ErrorResponse(logical.ErrInvalidRequest.Error()), nil
|
||||||
|
}
|
||||||
// TODO: There are two potentially valid cases we're not yet supporting that would
|
// TODO: There are two potentially valid cases we're not yet supporting that would
|
||||||
// necessitate this check being changed. First, if we support GET requests.
|
// necessitate this check being changed. First, if we support GET requests.
|
||||||
// Second if we support presigned POST requests
|
// Second if we support presigned POST requests
|
||||||
@ -1192,6 +1201,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
|||||||
return logical.ErrorResponse("failed to base64 decode iam_request_body"), nil
|
return logical.ErrorResponse("failed to base64 decode iam_request_body"), nil
|
||||||
}
|
}
|
||||||
body := string(bodyRaw)
|
body := string(bodyRaw)
|
||||||
|
if err = validateLoginIamRequestBody(body); err != nil {
|
||||||
|
return logical.ErrorResponse(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
headers := data.Get("iam_request_headers").(http.Header)
|
headers := data.Get("iam_request_headers").(http.Header)
|
||||||
if len(headers) == 0 {
|
if len(headers) == 0 {
|
||||||
@ -1213,6 +1225,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
|||||||
return logical.ErrorResponse(fmt.Sprintf("error validating %s header: %v", iamServerIdHeader, err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("error validating %s header: %v", iamServerIdHeader, err)), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = config.validateAllowedSTSHeaderValues(headers); err != nil {
|
||||||
|
return logical.ErrorResponse(err.Error()), nil
|
||||||
|
}
|
||||||
if config.STSEndpoint != "" {
|
if config.STSEndpoint != "" {
|
||||||
endpoint = config.STSEndpoint
|
endpoint = config.STSEndpoint
|
||||||
}
|
}
|
||||||
@ -1394,6 +1409,29 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the iam_request_body passed is valid for the STS request
|
||||||
|
func validateLoginIamRequestBody(body string) error {
|
||||||
|
qs, err := url.ParseQuery(body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for k, v := range qs {
|
||||||
|
switch k {
|
||||||
|
case "Action":
|
||||||
|
if len(v) != 1 || v[0] != "GetCallerIdentity" {
|
||||||
|
return errRequestBodyNotValid
|
||||||
|
}
|
||||||
|
case "Version":
|
||||||
|
// Will assume for now that future versions don't change
|
||||||
|
// the semantics
|
||||||
|
default:
|
||||||
|
// Not expecting any other values
|
||||||
|
return errRequestBodyNotValid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// These two methods (hasValuesFor*) return two bools
|
// These two methods (hasValuesFor*) return two bools
|
||||||
// The first is a hasAll, that is, does the request have all the values
|
// The first is a hasAll, that is, does the request have all the values
|
||||||
// necessary for this auth method
|
// necessary for this auth method
|
||||||
@ -1559,8 +1597,12 @@ func ensureHeaderIsSigned(signedHeaders, headerToSign string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, error) {
|
func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, error) {
|
||||||
decoder := xml.NewDecoder(strings.NewReader(response))
|
|
||||||
result := GetCallerIdentityResponse{}
|
result := GetCallerIdentityResponse{}
|
||||||
|
response = strings.TrimSpace(response)
|
||||||
|
if !strings.HasPrefix(response, "<GetCallerIdentityResponse") && !strings.HasPrefix(response, "<?xml") {
|
||||||
|
return result, errInvalidGetCallerIdentityResponse
|
||||||
|
}
|
||||||
|
decoder := xml.NewDecoder(strings.NewReader(response))
|
||||||
err := decoder.Decode(&result)
|
err := decoder.Decode(&result)
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
@ -1596,6 +1638,11 @@ func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, en
|
|||||||
if response != nil {
|
if response != nil {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
}
|
}
|
||||||
|
// Validate that the response type is XML
|
||||||
|
if ct := response.Header.Get("Content-Type"); ct != "text/xml" {
|
||||||
|
return nil, errInvalidGetCallerIdentityResponse
|
||||||
|
}
|
||||||
|
|
||||||
// we check for status code afterwards to also print out response body
|
// we check for status code afterwards to also print out response body
|
||||||
responseBody, err := ioutil.ReadAll(response.Body)
|
responseBody, err := ioutil.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -305,6 +305,19 @@ func TestBackend_pathLogin_IAMHeaders(t *testing.T) {
|
|||||||
},
|
},
|
||||||
ExpectErr: missingHeaderErr,
|
ExpectErr: missingHeaderErr,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Map-illegal-header",
|
||||||
|
Header: map[string]interface{}{
|
||||||
|
"Content-Length": "43",
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
|
||||||
|
"User-Agent": "aws-sdk-go/1.14.24 (go1.11; darwin; amd64)",
|
||||||
|
"X-Amz-Date": "20180910T203328Z",
|
||||||
|
"Authorization": "AWS4-HMAC-SHA256 Credential=AKIAJPQ466AIIQW4LPSQ/20180910/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-vault-aws-iam-server-id, Signature=cdef5819b2e97f1ff0f3e898fd2621aa03af00a4ec3e019122c20e5482534bf4",
|
||||||
|
"X-Vault-Aws-Iam-Server-Id": "VaultAcceptanceTesting",
|
||||||
|
"X-Amz-Mallory-Header": "<?xml><h4ck0r/>",
|
||||||
|
},
|
||||||
|
ExpectErr: errors.New("invalid request header: X-Amz-Mallory-Header"),
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "JSON-complete",
|
Name: "JSON-complete",
|
||||||
Header: `{
|
Header: `{
|
||||||
@ -543,7 +556,8 @@ func setupIAMTestServer() *httptest.Server {
|
|||||||
<ResponseMetadata>
|
<ResponseMetadata>
|
||||||
<RequestId>7f4fc40c-853a-11e6-8848-8d035d01eb87</RequestId>
|
<RequestId>7f4fc40c-853a-11e6-8848-8d035d01eb87</RequestId>
|
||||||
</ResponseMetadata>
|
</ResponseMetadata>
|
||||||
</GetCallerIdentityResponse>`
|
</GetCallerIdentityResponse>
|
||||||
|
`
|
||||||
|
|
||||||
auth := r.Header.Get("Authorization")
|
auth := r.Header.Get("Authorization")
|
||||||
parts := strings.Split(auth, ",")
|
parts := strings.Split(auth, ",")
|
||||||
@ -566,6 +580,7 @@ func setupIAMTestServer() *httptest.Server {
|
|||||||
if matchingCount != len(expectedAuthParts) {
|
if matchingCount != len(expectedAuthParts) {
|
||||||
responseString = "missing auth parts"
|
responseString = "missing auth parts"
|
||||||
}
|
}
|
||||||
|
w.Header().Add("Content-Type", "text/xml")
|
||||||
fmt.Fprintln(w, responseString)
|
fmt.Fprintln(w, responseString)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -66,7 +66,11 @@ capabilities, the credentials are fetched automatically.
|
|||||||
signed headers validated by AWS. This is to protect against different types of
|
signed headers validated by AWS. This is to protect against different types of
|
||||||
replay attacks, for example a signed request sent to a dev server being resent
|
replay attacks, for example a signed request sent to a dev server being resent
|
||||||
to a production server. Consider setting this to the Vault server's DNS name.
|
to a production server. Consider setting this to the Vault server's DNS name.
|
||||||
|
- `allowed_sts_header_values` `(string: "")` A comma separated list of
|
||||||
|
additional request headers permitted when providing the iam_request_headers for
|
||||||
|
an IAM based login call. In any case, a default list of headers AWS STS
|
||||||
|
expects for a GetCallerIdentity are allowed.
|
||||||
|
|
||||||
### Sample Payload
|
### Sample Payload
|
||||||
|
|
||||||
```json
|
```json
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user