mirror of
https://github.com/hashicorp/vault.git
synced 2025-11-18 17:21:13 +01:00
Improve sts header parsing (#3013)
This commit is contained in:
parent
6f26cea0ab
commit
88910d0b1c
@ -1519,7 +1519,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if resp == nil || resp.Auth == nil || resp.IsError() {
|
if resp == nil || resp.Auth == nil || resp.IsError() {
|
||||||
t.Errorf("bad: expected valid login: resp:%#v", resp)
|
t.Fatalf("bad: expected valid login: resp:%#v", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
renewReq := &logical.Request{
|
renewReq := &logical.Request{
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -1086,14 +1087,12 @@ func (b *backend) pathLoginUpdateIam(
|
|||||||
if headersB64 == "" {
|
if headersB64 == "" {
|
||||||
return logical.ErrorResponse("missing iam_request_headers"), nil
|
return logical.ErrorResponse("missing iam_request_headers"), nil
|
||||||
}
|
}
|
||||||
headersJson, err := base64.StdEncoding.DecodeString(headersB64)
|
headers, err := parseIamRequestHeaders(headersB64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return logical.ErrorResponse("failed to base64 decode iam_request_headers"), nil
|
return logical.ErrorResponse(fmt.Sprintf("Error parsing iam_request_headers: %v", err)), nil
|
||||||
}
|
}
|
||||||
var headers http.Header
|
if headers == nil {
|
||||||
err = jsonutil.DecodeJSON(headersJson, &headers)
|
return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil
|
||||||
if err != nil {
|
|
||||||
return logical.ErrorResponse(fmt.Sprintf("failed to JSON decode iam_request_headers %q: %v", headersJson, err)), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := b.lockedClientConfigEntry(req.Storage)
|
config, err := b.lockedClientConfigEntry(req.Storage)
|
||||||
@ -1399,6 +1398,37 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
|
|||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseIamRequestHeaders(headersB64 string) (http.Header, error) {
|
||||||
|
headersJson, err := base64.StdEncoding.DecodeString(headersB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to base64 decode iam_request_headers")
|
||||||
|
}
|
||||||
|
var headersDecoded map[string]interface{}
|
||||||
|
err = jsonutil.DecodeJSON(headersJson, &headersDecoded)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to JSON decode iam_request_headers %q: %v", headersJson, err)
|
||||||
|
}
|
||||||
|
headers := make(http.Header)
|
||||||
|
for k, v := range headersDecoded {
|
||||||
|
switch typedValue := v.(type) {
|
||||||
|
case string:
|
||||||
|
headers.Add(k, typedValue)
|
||||||
|
case []interface{}:
|
||||||
|
for _, individualVal := range typedValue {
|
||||||
|
switch possibleStrVal := individualVal.(type) {
|
||||||
|
case string:
|
||||||
|
headers.Add(k, possibleStrVal)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("header %q contains value %q that has type %s, not string", k, individualVal, reflect.TypeOf(individualVal))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("header %q value %q has type %s, not string or []interface", k, typedValue, reflect.TypeOf(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return headers, nil
|
||||||
|
}
|
||||||
|
|
||||||
func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
|
func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
|
||||||
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
|
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
|
||||||
// The protection against this is that this method will only call the endpoint specified in the
|
// The protection against this is that this method will only call the endpoint specified in the
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
package awsauth
|
package awsauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -143,3 +147,43 @@ func TestBackend_validateVaultHeaderValue(t *testing.T) {
|
|||||||
t.Errorf("did NOT validate valid POST request with split Authorization header: %v", err)
|
t.Errorf("did NOT validate valid POST request with split Authorization header: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackend_pathLogin_parseIamRequestHeaders(t *testing.T) {
|
||||||
|
testIamParser := func(headers interface{}, expectedHeaders http.Header) error {
|
||||||
|
headersJson, err := json.Marshal(headers)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to JSON encode headers: %v", err)
|
||||||
|
}
|
||||||
|
headersB64 := base64.StdEncoding.EncodeToString(headersJson)
|
||||||
|
|
||||||
|
parsedHeaders, err := parseIamRequestHeaders(headersB64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error parsing encoded headers: %v", err)
|
||||||
|
}
|
||||||
|
if parsedHeaders == nil {
|
||||||
|
return fmt.Errorf("nil result from parsing headers")
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(parsedHeaders, expectedHeaders) {
|
||||||
|
return fmt.Errorf("parsed headers not equal to input headers")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
headersGoStyle := http.Header{
|
||||||
|
"Header1": []string{"Value1"},
|
||||||
|
"Header2": []string{"Value2"},
|
||||||
|
}
|
||||||
|
headersMixedType := map[string]interface{}{
|
||||||
|
"Header1": "Value1",
|
||||||
|
"Header2": []string{"Value2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := testIamParser(headersGoStyle, headersGoStyle)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error parsing go-style headers: %v", err)
|
||||||
|
}
|
||||||
|
err = testIamParser(headersMixedType, headersGoStyle)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error parsing mixed-style headers: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -1872,8 +1872,10 @@ The response will be in JSON. For example:
|
|||||||
<li>
|
<li>
|
||||||
<span class="param">iam_request_headers</span>
|
<span class="param">iam_request_headers</span>
|
||||||
<span class="param-flags">required</span>
|
<span class="param-flags">required</span>
|
||||||
Base64-encoded, JSON-serialized representation of the HTTP request
|
Base64-encoded, JSON-serialized representation of the
|
||||||
headers. The JSON serialization assumes that each header key maps to an
|
sts:GetCallerIdentity HTTP request
|
||||||
|
headers. The JSON serialization assumes that each header key maps to
|
||||||
|
either a string value or an
|
||||||
array of string values (though the length of that array will probably
|
array of string values (though the length of that array will probably
|
||||||
only be one). If the `iam_server_id_header_value` is configured in Vault
|
only be one). If the `iam_server_id_header_value` is configured in Vault
|
||||||
for the aws auth mount, then the headers must include the
|
for the aws auth mount, then the headers must include the
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user