vault/audit/headers_test.go
Chris Capurso 69411d7925
VAULT-30108: Include User-Agent header in audit requests by default (#28596)
* include user-agent header in audit by default

* add user-agent audit tests

* update audit default headers docs

* add changelog entry

* remove temp changes from TestAuditedHeadersConfig_ApplyConfig

* more TestAuditedHeadersConfig_ApplyConfig fixes

* add some test comments

* verify type assertions in TestAudit_Headers

* more type assertion checks
2024-10-07 10:02:17 -04:00

655 lines
18 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"context"
"encoding/json"
"errors"
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// mockStorage is a struct that is used to mock barrier storage.
// Entries written with Put are kept in memory as their JSON encoding and are
// decoded again by Get; List and Delete are no-ops.
type mockStorage struct {
mock.Mock
// v holds the JSON-encoded storage entries, keyed by entry key.
v map[string][]byte
}
// List implements List from BarrierStorage interface.
// This mock never reports any keys: it always returns a nil slice and a nil
// error, whatever prefix is supplied.
// ignore-nil-nil-function-check.
func (m *mockStorage) List(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
// Get implements Get from BarrierStorage interface.
// A key that was never stored yields a nil entry and a nil error; otherwise
// the stored JSON is decoded and returned together with any decoding error.
// ignore-nil-nil-function-check.
func (m *mockStorage) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
	raw, found := m.v[key]
	if found {
		var stored *logical.StorageEntry
		decodeErr := json.Unmarshal(raw, &stored)
		return stored, decodeErr
	}

	return nil, nil
}
// Put implements Put from BarrierStorage interface.
// The entry is serialised to JSON and kept in memory under its key, replacing
// any earlier value. If serialisation fails the error is returned and nothing
// is stored.
func (m *mockStorage) Put(_ context.Context, entry *logical.StorageEntry) error {
	encoded, marshalErr := json.Marshal(entry)
	if marshalErr == nil {
		m.v[entry.Key] = encoded
	}

	return marshalErr
}
// Delete implements Delete from BarrierStorage interface.
// It is a no-op in this mock: nothing is removed from v and nil is always
// returned.
func (m *mockStorage) Delete(_ context.Context, _ string) error {
return nil
}
// newMockStorage returns a mockStorage whose in-memory entry map is empty but
// initialised, so Put can be called on it straight away.
func newMockStorage(t *testing.T) *mockStorage {
	t.Helper()

	storage := new(mockStorage)
	storage.v = map[string][]byte{}

	return storage
}
// mockAuditedHeadersConfig builds a HeadersConfig with no configured headers
// that is backed by a fresh in-memory mockStorage.
func mockAuditedHeadersConfig(t *testing.T) *HeadersConfig {
	cfg := new(HeadersConfig)
	cfg.headerSettings = map[string]*headerSettings{}
	cfg.view = newMockStorage(t)

	return cfg
}
// TestAuditedHeadersConfig_CRUD adds headers to a fresh config and then
// removes them again; the helpers check the in-memory settings and the
// persisted entry after every change.
func TestAuditedHeadersConfig_CRUD(t *testing.T) {
	t.Parallel()

	conf := mockAuditedHeadersConfig(t)

	// Order matters: the removal checks rely on the headers added first.
	steps := []func(*testing.T, *HeadersConfig){testAddHeaders, testRemoveHeaders}
	for _, step := range steps {
		step(t, conf)
	}
}
// testAddHeaders adds one plain and one HMAC'd header to conf. After each
// addition it verifies the in-memory settings (looked up under the lower-cased
// header name) and the JSON persisted in the config's view.
func testAddHeaders(t *testing.T, conf *HeadersConfig) {
	t.Helper()

	ctx := context.Background()

	// persisted reads the stored headers entry back from the view and decodes it.
	persisted := func() map[string]*headerSettings {
		entry, getErr := conf.view.Get(ctx, auditedHeadersEntry)
		if getErr != nil {
			t.Fatalf("Could not retrieve headers entry from config: %s", getErr)
		}
		if entry == nil {
			t.Fatal("nil value")
		}

		decoded := make(map[string]*headerSettings)
		if decodeErr := entry.DecodeJSON(&decoded); decodeErr != nil {
			t.Fatalf("Error decoding header view: %s", decodeErr)
		}

		return decoded
	}

	// First header: stored without HMAC.
	if err := conf.Add(ctx, "X-Test-Header", false); err != nil {
		t.Fatalf("Error when adding header to config: %s", err)
	}

	plain, found := conf.headerSettings["x-test-header"]
	if !found {
		t.Fatal("Expected header to be found in config")
	}
	if plain.HMAC {
		t.Fatal("Expected HMAC to be set to false, got true")
	}

	want := map[string]*headerSettings{
		"x-test-header": {HMAC: false},
	}
	if got := persisted(); !reflect.DeepEqual(got, want) {
		t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", want, got)
	}

	// Second header: stored with HMAC enabled.
	if err := conf.Add(ctx, "X-Vault-Header", true); err != nil {
		t.Fatalf("Error when adding header to config: %s", err)
	}

	hashed, found := conf.headerSettings["x-vault-header"]
	if !found {
		t.Fatal("Expected header to be found in config")
	}
	if !hashed.HMAC {
		t.Fatal("Expected HMAC to be set to true, got false")
	}

	// The persisted entry must now contain both headers.
	want["x-vault-header"] = &headerSettings{HMAC: true}
	if got := persisted(); !reflect.DeepEqual(got, want) {
		t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", want, got)
	}
}
// testRemoveHeaders removes the two headers added by testAddHeaders (the
// second one under a differently-cased name). After each removal it verifies
// that the header is gone from the in-memory settings and from the entry
// persisted in the config's view.
func testRemoveHeaders(t *testing.T, conf *HeadersConfig) {
	t.Helper()

	err := conf.Remove(context.Background(), "X-Test-Header")
	if err != nil {
		t.Fatalf("Error when removing header from config: %s", err)
	}

	// Settings are keyed by the lower-cased header name, so the lookup must
	// use that form: a mixed-case key could never be present and would make
	// this check pass vacuously.
	_, ok := conf.headerSettings["x-test-header"]
	if ok {
		t.Fatal("Expected header to not be found in config")
	}

	out, err := conf.view.Get(context.Background(), auditedHeadersEntry)
	if err != nil {
		t.Fatalf("Could not retrieve headers entry from config: %s", err)
	}
	if out == nil {
		t.Fatal("nil value")
	}

	headers := make(map[string]*headerSettings)
	err = out.DecodeJSON(&headers)
	if err != nil {
		t.Fatalf("Error decoding header view: %s", err)
	}

	// Only the HMAC'd header should remain in storage.
	expected := map[string]*headerSettings{
		"x-vault-header": {
			HMAC: true,
		},
	}

	if !reflect.DeepEqual(headers, expected) {
		t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
	}

	// Remove the remaining header using a name whose casing differs from the
	// one it was added with.
	err = conf.Remove(context.Background(), "x-VaulT-Header")
	if err != nil {
		t.Fatalf("Error when removing header from config: %s", err)
	}

	_, ok = conf.headerSettings["x-vault-header"]
	if ok {
		t.Fatal("Expected header to not be found in config")
	}

	out, err = conf.view.Get(context.Background(), auditedHeadersEntry)
	if err != nil {
		t.Fatalf("Could not retrieve headers entry from config: %s", err)
	}
	if out == nil {
		t.Fatal("nil value")
	}

	headers = make(map[string]*headerSettings)
	err = out.DecodeJSON(&headers)
	if err != nil {
		t.Fatalf("Error decoding header view: %s", err)
	}

	// Storage should now hold an empty (but present) header map.
	expected = make(map[string]*headerSettings)
	if !reflect.DeepEqual(headers, expected) {
		t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
	}
}
// TestAuditedHeadersConfig_ApplyConfig checks that ApplyConfig returns only
// the configured headers, under lower-cased names. Non-HMAC values are
// expected unchanged, HMAC'd values are expected to start with
// "hmac-sha256:", and the caller's request header map must not be modified.
func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	conf := mockAuditedHeadersConfig(t)

	// Headers are configured with odd casing on purpose.
	require.NoError(t, conf.Add(ctx, "X-TesT-Header", false))
	require.NoError(t, conf.Add(ctx, "X-Vault-HeAdEr", true))

	reqHeaders := map[string][]string{
		"X-Test-Header":  {"foo"},
		"X-Vault-Header": {"bar", "bar"},
		"Content-Type":   {"json"},
	}

	result, err := conf.ApplyConfig(ctx, reqHeaders, &testSalter{})
	if err != nil {
		t.Fatal(err)
	}

	// An expected value equal to hmacPrefix means "any value with this prefix".
	const hmacPrefix = "hmac-sha256:"
	expected := map[string][]string{
		"x-test-header":  {"foo"},
		"x-vault-header": {hmacPrefix, hmacPrefix},
	}

	if len(result) != len(expected) {
		t.Fatalf("Expected headers count did not match actual count: Expected count %d\n Got %d\n", len(expected), len(result))
	}

	for name, got := range result {
		want := expected[name]
		if len(want) != len(got) {
			t.Fatalf("Expected header values count did not match actual values count: Expected count: %d\n Got %d\n", len(want), len(got))
		}

		for idx := range want {
			switch {
			case want[idx] != hmacPrefix:
				// Plain header: the value must match exactly.
				if want[idx] != got[idx] {
					t.Fatalf("Expected headers did not match actual: Expected %#v\n Got %#v\n", want[idx], got[idx])
				}
			case !strings.HasPrefix(got[idx], want[idx]):
				// HMAC'd header: only the prefix is predictable.
				t.Fatalf("Expected headers did not match actual: Expected %#v...\n Got %#v\n", want[idx], got[idx])
			}
		}
	}

	// Make sure we didn't edit the reqHeaders map
	untouched := map[string][]string{
		"X-Test-Header":  {"foo"},
		"X-Vault-Header": {"bar", "bar"},
		"Content-Type":   {"json"},
	}
	if !reflect.DeepEqual(reqHeaders, untouched) {
		t.Fatalf("Req headers were changed, expected %#v\n got %#v", untouched, reqHeaders)
	}
}
// TestAuditedHeadersConfig_ApplyConfig_NoRequestHeaders tests the case where
// there are no headers in the request, supplied either as a nil map or as an
// empty map. In both cases ApplyConfig is expected to succeed and return a
// non-nil, empty result.
func TestAuditedHeadersConfig_ApplyConfig_NoRequestHeaders(t *testing.T) {
	t.Parallel()

	conf := mockAuditedHeadersConfig(t)

	err := conf.Add(context.Background(), "X-TesT-Header", false)
	require.NoError(t, err)
	err = conf.Add(context.Background(), "X-Vault-HeAdEr", true)
	require.NoError(t, err)

	salter := &testSalter{}

	// Test sending in nil headers first.
	result, err := conf.ApplyConfig(context.Background(), nil, salter)
	require.NoError(t, err)
	require.NotNil(t, result)
	// With no request headers there is nothing to audit, same as the
	// empty-map case below.
	require.Len(t, result, 0)

	// Then an empty, non-nil map.
	result, err = conf.ApplyConfig(context.Background(), map[string][]string{}, salter)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Len(t, result, 0)
}
// TestAuditedHeadersConfig_ApplyConfig_NoConfiguredHeaders tests the case
// where the request carries headers but none are configured for auditing:
// the result must be empty and the request headers must be left unmodified.
func TestAuditedHeadersConfig_ApplyConfig_NoConfiguredHeaders(t *testing.T) {
	t.Parallel()

	// newRequestHeaders returns a fresh copy of the request header fixture.
	newRequestHeaders := func() map[string][]string {
		return map[string][]string{
			"X-Test-Header":  {"foo"},
			"X-Vault-Header": {"bar", "bar"},
			"Content-Type":   {"json"},
		}
	}

	conf := mockAuditedHeadersConfig(t)
	reqHeaders := newRequestHeaders()

	result, err := conf.ApplyConfig(context.Background(), reqHeaders, &testSalter{})
	if err != nil {
		t.Fatal(err)
	}

	if n := len(result); n != 0 {
		t.Fatalf("Expected no headers but actually got: %d\n", n)
	}

	// Make sure we didn't edit the reqHeaders map
	if pristine := newRequestHeaders(); !reflect.DeepEqual(reqHeaders, pristine) {
		t.Fatalf("Req headers were changed, expected %#v\n got %#v", pristine, reqHeaders)
	}
}
// FailingSalter is an implementation of the Salter interface where the Salt
// method always returns an error.
// Tests pass it to ApplyConfig to check that an error is reported when no
// salt can be obtained.
type FailingSalter struct{}
// Salt never yields a salt: every call returns a nil salt together with a
// fixed "testing error".
func (*FailingSalter) Salt(_ context.Context) (*salt.Salt, error) {
	errSalt := errors.New("testing error")

	return nil, errSalt
}
// TestAuditedHeadersConfig_ApplyConfig_HashStringError tests the case where
// hashing a header value fails: with an HMAC'd header configured and a salter
// that always errors, ApplyConfig must return an error instead of a map of
// headers.
func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	conf := mockAuditedHeadersConfig(t)

	require.NoError(t, conf.Add(ctx, "X-TesT-Header", false))
	// HMAC is enabled for this header, which is what requires the salter.
	require.NoError(t, conf.Add(ctx, "X-Vault-HeAdEr", true))

	reqHeaders := map[string][]string{
		"X-Test-Header":  {"foo"},
		"X-Vault-Header": {"bar", "bar"},
		"Content-Type":   {"json"},
	}

	if _, err := conf.ApplyConfig(ctx, reqHeaders, &FailingSalter{}); err == nil {
		t.Fatal("expected error from ApplyConfig")
	}
}
// BenchmarkAuditedHeaderConfig_ApplyConfig measures ApplyConfig for a request
// carrying one plain audited header, one HMAC'd audited header and one header
// that is not audited.
func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) {
	// The settings are populated directly rather than through Add, so no
	// storage view is needed. Add and Invalidate store header names
	// lower-cased, so the keys set here must already be lower-case; mixed-case
	// keys would most likely never match a request header, leaving the HMAC
	// path out of the measurement.
	conf := &HeadersConfig{
		headerSettings: map[string]*headerSettings{
			"x-test-header":  {HMAC: false},
			"x-vault-header": {HMAC: true},
		},
		view: nil,
	}

	reqHeaders := map[string][]string{
		"X-Test-Header":  {"foo"},
		"X-Vault-Header": {"bar", "bar"},
		"Content-Type":   {"json"},
	}

	ctx := context.Background()
	salter := &testSalter{}

	// Reset the timer so the setup above is not measured.
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		_, err := conf.ApplyConfig(ctx, reqHeaders, salter)
		require.NoError(b, err)
	}
}
// TestAuditedHeaders_auditedHeadersKey is used to check the key we use to handle
// invalidation doesn't change when we weren't expecting it to.
func TestAuditedHeaders_auditedHeadersKey(t *testing.T) {
t.Parallel()
require.Equal(t, "audited-headers-config/audited-headers", AuditedHeadersKey())
}
// TestAuditedHeaders_NewAuditedHeadersConfig checks that the constructor for
// HeadersConfig returns an error (and no config) when given a nil view, and
// succeeds when given a usable one.
func TestAuditedHeaders_NewAuditedHeadersConfig(t *testing.T) {
	t.Parallel()

	// A nil view must be rejected.
	rejected, rejectErr := NewHeadersConfig(nil)
	require.Error(t, rejectErr)
	require.Nil(t, rejected)

	// A real (mock) view must be accepted.
	accepted, acceptErr := NewHeadersConfig(newMockStorage(t))
	require.NoError(t, acceptErr)
	require.NotNil(t, accepted)
}
// TestAuditedHeaders_invalidate ensures that Invalidate reloads the headers on
// HeadersConfig from the view/storage, that the reloaded set sits alongside
// the default headers, and that stored names are found under their
// lower-cased form.
func TestAuditedHeaders_invalidate(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	view := newMockStorage(t)

	ahc, err := NewHeadersConfig(view)
	require.NoError(t, err)
	require.Len(t, ahc.headerSettings, 0)

	// storeAndInvalidate writes the given headers to storage and then asks the
	// config to reload itself from the view.
	storeAndInvalidate := func(stored map[string]*headerSettings) {
		raw, marshalErr := json.Marshal(stored)
		require.NoError(t, marshalErr)

		putErr := view.Put(ctx, &logical.StorageEntry{Key: auditedHeadersEntry, Value: raw})
		require.NoError(t, putErr)

		require.NoError(t, ahc.Invalidate(ctx))
	}

	// Store a single header and check we now see it after invalidation.
	storeAndInvalidate(map[string]*headerSettings{"x-magic-header": {}})
	require.Equal(t, len(ahc.DefaultHeaders())+1, len(ahc.headerSettings)) // (defaults + 1).
	_, hasMagic := ahc.headerSettings["x-magic-header"]
	require.True(t, hasMagic)

	// Do it again with more headers and random casing.
	storeAndInvalidate(map[string]*headerSettings{
		"x-magic-header":           {},
		"x-even-MORE-magic-header": {},
	})
	require.Equal(t, len(ahc.DefaultHeaders())+2, len(ahc.headerSettings)) // (defaults + 2 new headers)
	_, hasMagic = ahc.headerSettings["x-magic-header"]
	require.True(t, hasMagic)
	_, hasMoreMagic := ahc.headerSettings["x-even-more-magic-header"]
	require.True(t, hasMoreMagic)
}
// TestAuditedHeaders_invalidate_nil_view ensures that we invalidate the headers
// correctly (clearing everything except the defaults) when we get nil for the
// storage entry from the view.
func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
	t.Parallel()

	view := newMockStorage(t)
	ahc, err := NewHeadersConfig(view)
	require.NoError(t, err)
	require.Len(t, ahc.headerSettings, 0)

	// Store some data using the view.
	fakeHeaders1 := map[string]*headerSettings{"x-magic-header": {}}
	fakeBytes1, err := json.Marshal(fakeHeaders1)
	require.NoError(t, err)
	err = view.Put(context.Background(), &logical.StorageEntry{Key: auditedHeadersEntry, Value: fakeBytes1})
	require.NoError(t, err)

	// Invalidate and check we now see the header we stored
	err = ahc.Invalidate(context.Background())
	require.NoError(t, err)
	require.Equal(t, len(ahc.DefaultHeaders())+1, len(ahc.headerSettings)) // defaults + 1
	_, ok := ahc.headerSettings["x-magic-header"]
	require.True(t, ok)

	// Swap out the view with a fresh, empty mockStorage. Its Get returns a nil
	// entry (and nil error) for any key that was never Put, so no expectation
	// has to be registered on the embedded mock.Mock; Get never consults it.
	// This should mean we end up just clearing the headers (no errors).
	ahc.view = newMockStorage(t)

	// Invalidate should clear out the existing headers without error
	err = ahc.Invalidate(context.Background())
	require.NoError(t, err)
	require.Equal(t, len(ahc.DefaultHeaders()), len(ahc.headerSettings)) // defaults
}
// TestAuditedHeaders_invalidate_bad_data ensures that Invalidate returns an
// error when the data held in storage cannot be parsed as a headers config.
func TestAuditedHeaders_invalidate_bad_data(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	view := newMockStorage(t)

	ahc, err := NewHeadersConfig(view)
	require.NoError(t, err)
	require.Len(t, ahc.headerSettings, 0)

	// Store some bad data using the view: a JSON string rather than a JSON
	// object of header settings.
	badBytes, err := json.Marshal("i am bad")
	require.NoError(t, err)
	require.NoError(t, view.Put(ctx, &logical.StorageEntry{Key: auditedHeadersEntry, Value: badBytes}))

	// Invalidate should fail with a parse error.
	invalidateErr := ahc.Invalidate(ctx)
	require.Error(t, invalidateErr)
	require.ErrorContains(t, invalidateErr, "failed to parse config")
}
// TestAuditedHeaders_header checks we can return a copy of settings associated with
// an existing header, and we also know when a header wasn't found.
func TestAuditedHeaders_header(t *testing.T) {
t.Parallel()
view := newMockStorage(t)
ahc, err := NewHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0)
err = ahc.Add(context.Background(), "juan", true)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 1)
s, ok := ahc.Header("juan")
require.True(t, ok)
require.Equal(t, true, s.HMAC)
s, ok = ahc.Header("x-magic-token")
require.False(t, ok)
}
// TestAuditedHeaders_headers checks we are able to return a copy of the existing
// configured headers.
func TestAuditedHeaders_headers(t *testing.T) {
t.Parallel()
view := newMockStorage(t)
ahc, err := NewHeadersConfig(view)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0)
err = ahc.Add(context.Background(), "juan", true)
require.NoError(t, err)
err = ahc.Add(context.Background(), "john", false)
require.NoError(t, err)
require.Len(t, ahc.headerSettings, 2)
s := ahc.Headers()
require.Len(t, s, 2)
require.Equal(t, true, s["juan"].HMAC)
require.Equal(t, false, s["john"].HMAC)
}
// TestAuditedHeaders_invalidate_defaults checks that the 'default' headers
// (x-correlation-id and user-agent) are present after invalidation, and that
// when they were loaded from storage their stored settings are not overwritten
// with our defaults.
func TestAuditedHeaders_invalidate_defaults(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	view := newMockStorage(t)

	ahc, err := NewHeadersConfig(view)
	require.NoError(t, err)
	require.Len(t, ahc.headerSettings, 0)

	// storeAndInvalidate writes the given headers to storage and then asks the
	// config to reload itself from the view.
	storeAndInvalidate := func(stored map[string]*headerSettings) {
		raw, marshalErr := json.Marshal(stored)
		require.NoError(t, marshalErr)

		putErr := view.Put(ctx, &logical.StorageEntry{Key: auditedHeadersEntry, Value: raw})
		require.NoError(t, putErr)

		require.NoError(t, ahc.Invalidate(ctx))
	}

	defaultNames := []string{"x-correlation-id", "user-agent"}

	// Store a single non-default header; the defaults must appear alongside it
	// without HMAC.
	storeAndInvalidate(map[string]*headerSettings{"x-magic-header": {}})
	require.Equal(t, len(ahc.DefaultHeaders())+1, len(ahc.headerSettings)) // (defaults + 1 new header)
	_, hasMagic := ahc.headerSettings["x-magic-header"]
	require.True(t, hasMagic)
	for _, name := range defaultNames {
		settings, found := ahc.headerSettings[name]
		require.True(t, found)
		require.False(t, settings.HMAC)
	}

	// Add correlation ID and user-agent specifically with HMAC and make sure it doesn't get blasted away.
	storeAndInvalidate(map[string]*headerSettings{
		"x-magic-header": {},
		"X-Correlation-ID": {
			HMAC: true,
		},
		"User-Agent": {
			HMAC: true,
		},
	})

	// Three headers were stored, but two of them are defaults, so only one
	// header is added on top of the defaults.
	require.Equal(t, len(ahc.DefaultHeaders())+1, len(ahc.headerSettings))
	_, hasMagic = ahc.headerSettings["x-magic-header"]
	require.True(t, hasMagic)
	for _, name := range defaultNames {
		settings, found := ahc.headerSettings[name]
		require.True(t, found)
		require.True(t, settings.HMAC)
	}
}