Improve sts header parsing (#3013)

This commit is contained in:
Joel Thompson 2017-07-18 09:51:45 -04:00 committed by Jeff Mitchell
parent 6f26cea0ab
commit 88910d0b1c
4 changed files with 85 additions and 9 deletions

View File

@ -1519,7 +1519,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if resp == nil || resp.Auth == nil || resp.IsError() { if resp == nil || resp.Auth == nil || resp.IsError() {
t.Errorf("bad: expected valid login: resp:%#v", resp) t.Fatalf("bad: expected valid login: resp:%#v", resp)
} }
renewReq := &logical.Request{ renewReq := &logical.Request{

View File

@ -10,6 +10,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@ -1086,14 +1087,12 @@ func (b *backend) pathLoginUpdateIam(
if headersB64 == "" { if headersB64 == "" {
return logical.ErrorResponse("missing iam_request_headers"), nil return logical.ErrorResponse("missing iam_request_headers"), nil
} }
headersJson, err := base64.StdEncoding.DecodeString(headersB64) headers, err := parseIamRequestHeaders(headersB64)
if err != nil { if err != nil {
return logical.ErrorResponse("failed to base64 decode iam_request_headers"), nil return logical.ErrorResponse(fmt.Sprintf("Error parsing iam_request_headers: %v", err)), nil
} }
var headers http.Header if headers == nil {
err = jsonutil.DecodeJSON(headersJson, &headers) return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to JSON decode iam_request_headers %q: %v", headersJson, err)), nil
} }
config, err := b.lockedClientConfigEntry(req.Storage) config, err := b.lockedClientConfigEntry(req.Storage)
@ -1399,6 +1398,37 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
return result, err return result, err
} }
func parseIamRequestHeaders(headersB64 string) (http.Header, error) {
headersJson, err := base64.StdEncoding.DecodeString(headersB64)
if err != nil {
return nil, fmt.Errorf("failed to base64 decode iam_request_headers")
}
var headersDecoded map[string]interface{}
err = jsonutil.DecodeJSON(headersJson, &headersDecoded)
if err != nil {
return nil, fmt.Errorf("failed to JSON decode iam_request_headers %q: %v", headersJson, err)
}
headers := make(http.Header)
for k, v := range headersDecoded {
switch typedValue := v.(type) {
case string:
headers.Add(k, typedValue)
case []interface{}:
for _, individualVal := range typedValue {
switch possibleStrVal := individualVal.(type) {
case string:
headers.Add(k, possibleStrVal)
default:
return nil, fmt.Errorf("header %q contains value %q that has type %s, not string", k, individualVal, reflect.TypeOf(individualVal))
}
}
default:
return nil, fmt.Errorf("header %q value %q has type %s, not string or []interface", k, typedValue, reflect.TypeOf(v))
}
}
return headers, nil
}
func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) { func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy // NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
// The protection against this is that this method will only call the endpoint specified in the // The protection against this is that this method will only call the endpoint specified in the

View File

@ -1,8 +1,12 @@
package awsauth package awsauth
import ( import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"testing" "testing"
) )
@ -143,3 +147,43 @@ func TestBackend_validateVaultHeaderValue(t *testing.T) {
t.Errorf("did NOT validate valid POST request with split Authorization header: %v", err) t.Errorf("did NOT validate valid POST request with split Authorization header: %v", err)
} }
} }
func TestBackend_pathLogin_parseIamRequestHeaders(t *testing.T) {
testIamParser := func(headers interface{}, expectedHeaders http.Header) error {
headersJson, err := json.Marshal(headers)
if err != nil {
return fmt.Errorf("unable to JSON encode headers: %v", err)
}
headersB64 := base64.StdEncoding.EncodeToString(headersJson)
parsedHeaders, err := parseIamRequestHeaders(headersB64)
if err != nil {
return fmt.Errorf("error parsing encoded headers: %v", err)
}
if parsedHeaders == nil {
return fmt.Errorf("nil result from parsing headers")
}
if !reflect.DeepEqual(parsedHeaders, expectedHeaders) {
return fmt.Errorf("parsed headers not equal to input headers")
}
return nil
}
headersGoStyle := http.Header{
"Header1": []string{"Value1"},
"Header2": []string{"Value2"},
}
headersMixedType := map[string]interface{}{
"Header1": "Value1",
"Header2": []string{"Value2"},
}
err := testIamParser(headersGoStyle, headersGoStyle)
if err != nil {
t.Errorf("error parsing go-style headers: %v", err)
}
err = testIamParser(headersMixedType, headersGoStyle)
if err != nil {
t.Errorf("error parsing mixed-style headers: %v", err)
}
}

View File

@ -1872,8 +1872,10 @@ The response will be in JSON. For example:
<li> <li>
<span class="param">iam_request_headers</span> <span class="param">iam_request_headers</span>
<span class="param-flags">required</span> <span class="param-flags">required</span>
Base64-encoded, JSON-serialized representation of the HTTP request Base64-encoded, JSON-serialized representation of the
headers. The JSON serialization assumes that each header key maps to an sts:GetCallerIdentity HTTP request
headers. The JSON serialization assumes that each header key maps to
either a string value or an
array of string values (though the length of that array will probably array of string values (though the length of that array will probably
only be one). If the `iam_server_id_header_value` is configured in Vault only be one). If the `iam_server_id_header_value` is configured in Vault
for the aws auth mount, then the headers must include the for the aws auth mount, then the headers must include the