diff --git a/audit/entry_formatter.go b/audit/entry_formatter.go index 2bfa86f134..7a459efe24 100644 --- a/audit/entry_formatter.go +++ b/audit/entry_formatter.go @@ -30,7 +30,7 @@ var ( // NewEntryFormatter should be used to create an EntryFormatter. // Accepted options: WithPrefix. -func NewEntryFormatter(config FormatterConfig, salter Salter, opt ...Option) (*EntryFormatter, error) { +func NewEntryFormatter(config FormatterConfig, salter Salter, headersConfig HeaderFormatter, opt ...Option) (*EntryFormatter, error) { const op = "audit.NewEntryFormatter" if salter == nil { @@ -48,19 +48,20 @@ func NewEntryFormatter(config FormatterConfig, salter Salter, opt ...Option) (*E } return &EntryFormatter{ - salter: salter, - config: config, - prefix: opts.withPrefix, + salter: salter, + config: config, + headersConfig: headersConfig, + prefix: opts.withPrefix, }, nil } // Reopen is a no-op for the formatter node. -func (_ *EntryFormatter) Reopen() error { +func (*EntryFormatter) Reopen() error { return nil } // Type describes the type of this node (formatter). -func (_ *EntryFormatter) Type() eventlogger.NodeType { +func (*EntryFormatter) Type() eventlogger.NodeType { return eventlogger.NodeTypeFormatter } @@ -145,11 +146,6 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput return nil, errors.New("salt func not configured") } - s, err := f.salter.Salt(ctx) - if err != nil { - return nil, fmt.Errorf("error fetching salt: %w", err) - } - // Set these to the input values at first auth := in.Auth req := in.Request @@ -163,12 +159,13 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput } if !f.config.Raw { - auth, err = HashAuth(s, auth, f.config.HMACAccessor) + var err error + auth, err = HashAuth(ctx, f.salter, auth, f.config.HMACAccessor) if err != nil { return nil, err } - req, err = HashRequest(s, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) + req, err = HashRequest(ctx, f.salter, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) if err != nil { return nil, err } @@ -277,11 +274,6 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu return nil, errors.New("salt func not configured") } - s, err := f.salter.Salt(ctx) - if err != nil { - return nil, fmt.Errorf("error fetching salt: %w", err) - } - // Set these to the input values at first auth, req, resp := in.Auth, in.Request, in.Response if auth == nil { @@ -314,17 +306,18 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu respData = resp.Data } } else { - auth, err = HashAuth(s, auth, f.config.HMACAccessor) + var err error + auth, err = HashAuth(ctx, f.salter, auth, f.config.HMACAccessor) if err != nil { return nil, err } - req, err = HashRequest(s, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) + req, err = HashRequest(ctx, f.salter, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) if err != nil { return nil, err } - resp, err = HashResponse(s, resp, f.config.HMACAccessor, in.NonHMACRespDataKeys, elideListResponseData) + resp, err = HashResponse(ctx, f.salter, resp, f.config.HMACAccessor, in.NonHMACRespDataKeys, elideListResponseData) if err != nil { return nil, err } diff --git a/audit/entry_formatter_test.go b/audit/entry_formatter_test.go index 465579a40c..235022889c 100644 --- a/audit/entry_formatter_test.go +++ b/audit/entry_formatter_test.go @@ -127,7 +127,7 @@ func TestNewEntryFormatter(t *testing.T) { cfg, err := NewFormatterConfig(tc.Options...) require.NoError(t, err) - f, err := NewEntryFormatter(cfg, ss, tc.Options...) + f, err := NewEntryFormatter(cfg, ss, nil, tc.Options...) switch { case tc.IsErrorExpected: @@ -150,7 +150,7 @@ func TestEntryFormatter_Reopen(t *testing.T) { cfg, err := NewFormatterConfig() require.NoError(t, err) - f, err := NewEntryFormatter(cfg, ss) + f, err := NewEntryFormatter(cfg, ss, nil) require.NoError(t, err) require.NotNil(t, f) require.NoError(t, f.Reopen()) @@ -162,7 +162,7 @@ func TestEntryFormatter_Type(t *testing.T) { cfg, err := NewFormatterConfig() require.NoError(t, err) - f, err := NewEntryFormatter(cfg, ss) + f, err := NewEntryFormatter(cfg, ss, nil) require.NoError(t, err) require.NotNil(t, f) require.Equal(t, eventlogger.NodeTypeFormatter, f.Type()) @@ -305,7 +305,7 @@ func TestEntryFormatter_Process(t *testing.T) { cfg, err := NewFormatterConfig(WithFormat(tc.RequiredFormat.String())) require.NoError(t, err) - f, err := NewEntryFormatter(cfg, ss) + f, err := NewEntryFormatter(cfg, ss, nil) require.NoError(t, err) require.NotNil(t, f) @@ -366,13 +366,13 @@ func BenchmarkAuditFileSink_Process(b *testing.B) { }, } - ctx := namespace.RootContext(nil) + ctx := namespace.RootContext(context.Background()) // Create the formatter node. cfg, err := NewFormatterConfig() require.NoError(b, err) ss := newStaticSalt(b) - formatter, err := NewEntryFormatter(cfg, ss) + formatter, err := NewEntryFormatter(cfg, ss, nil) require.NoError(b, err) require.NotNil(b, formatter) diff --git a/audit/entry_formatter_writer.go b/audit/entry_formatter_writer.go index 55861f7746..f1a5c1be81 100644 --- a/audit/entry_formatter_writer.go +++ b/audit/entry_formatter_writer.go @@ -96,7 +96,7 @@ func NewTemporaryFormatter(requiredFormat, prefix string) (*EntryFormatterWriter return nil, err } - eventFormatter, err := NewEntryFormatter(cfg, &nonPersistentSalt{}, WithPrefix(prefix)) + eventFormatter, err := NewEntryFormatter(cfg, &nonPersistentSalt{}, nil, WithPrefix(prefix)) if err != nil { return nil, err } diff --git a/audit/entry_formatter_writer_test.go b/audit/entry_formatter_writer_test.go index cced2f24b9..8fabf10178 100644 --- a/audit/entry_formatter_writer_test.go +++ b/audit/entry_formatter_writer_test.go @@ -127,7 +127,7 @@ func TestNewEntryFormatterWriter(t *testing.T) { var f Formatter if !tc.UseNilFormatter { - tempFormatter, err := NewEntryFormatter(cfg, s) + tempFormatter, err := NewEntryFormatter(cfg, s, nil) require.NoError(t, err) require.NotNil(t, tempFormatter) f = tempFormatter @@ -189,9 +189,10 @@ func TestEntryFormatter_FormatRequest(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() + ss := newStaticSalt(t) cfg, err := NewFormatterConfig() require.NoError(t, err) - f, err := NewEntryFormatter(cfg, newStaticSalt(t)) + f, err := NewEntryFormatter(cfg, ss, nil) require.NoError(t, err) var ctx context.Context @@ -255,9 +256,10 @@ func TestEntryFormatter_FormatResponse(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() + ss := newStaticSalt(t) cfg, err := NewFormatterConfig() require.NoError(t, err) - f, err := NewEntryFormatter(cfg, newStaticSalt(t)) + f, err := NewEntryFormatter(cfg, ss, nil) require.NoError(t, err) var ctx context.Context @@ -359,7 +361,7 @@ func TestElideListResponses(t *testing.T) { formatResponse := func(t *testing.T, config FormatterConfig, operation logical.Operation, inputData map[string]interface{}, ) { - f, err := NewEntryFormatter(config, &tfw) + f, err := NewEntryFormatter(config, &tfw, nil) require.NoError(t, err) formatter, err := NewEntryFormatterWriter(config, f, &tfw) require.NoError(t, err) diff --git a/audit/hashstructure.go b/audit/hashstructure.go index cd4f8085d1..e9cd8ca203 100644 --- a/audit/hashstructure.go +++ b/audit/hashstructure.go @@ -4,13 +4,13 @@ package audit import ( + "context" "encoding/json" "errors" "reflect" "time" "github.com/hashicorp/go-secure-stdlib/strutil" - "github.com/hashicorp/vault/sdk/helper/salt" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/copystructure" @@ -18,17 +18,27 @@ import ( ) // HashString hashes the given opaque string and returns it -func HashString(salter *salt.Salt, data string) string { - return salter.GetIdentifiedHMAC(data) +func HashString(ctx context.Context, salter Salter, data string) (string, error) { + salt, err := salter.Salt(ctx) + if err != nil { + return "", err + } + + return salt.GetIdentifiedHMAC(data), nil } // HashAuth returns a hashed copy of the logical.Auth input. -func HashAuth(salter *salt.Salt, in *logical.Auth, HMACAccessor bool) (*logical.Auth, error) { +func HashAuth(ctx context.Context, salter Salter, in *logical.Auth, HMACAccessor bool) (*logical.Auth, error) { if in == nil { return nil, nil } - fn := salter.GetIdentifiedHMAC + salt, err := salter.Salt(ctx) + if err != nil { + return nil, err + } + + fn := salt.GetIdentifiedHMAC auth := *in if auth.ClientToken != "" { @@ -41,12 +51,17 @@ func HashAuth(salter *salt.Salt, in *logical.Auth, HMACAccessor bool) (*logical. } // HashRequest returns a hashed copy of the logical.Request input. -func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Request, error) { +func HashRequest(ctx context.Context, salter Salter, in *logical.Request, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Request, error) { if in == nil { return nil, nil } - fn := salter.GetIdentifiedHMAC + salt, err := salter.Salt(ctx) + if err != nil { + return nil, err + } + + fn := salt.GetIdentifiedHMAC req := *in if req.Auth != nil { @@ -55,7 +70,7 @@ func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonH return nil, err } - req.Auth, err = HashAuth(salter, cp.(*logical.Auth), HMACAccessor) + req.Auth, err = HashAuth(ctx, salter, cp.(*logical.Auth), HMACAccessor) if err != nil { return nil, err } @@ -84,11 +99,11 @@ func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonH return &req, nil } -func hashMap(fn func(string) string, data map[string]interface{}, nonHMACDataKeys []string) error { +func hashMap(hashFunc HashCallback, data map[string]interface{}, nonHMACDataKeys []string) error { for k, v := range data { if o, ok := v.(logical.OptMarshaler); ok { marshaled, err := o.MarshalJSONWithOptions(&logical.MarshalOptions{ - ValueHasher: fn, + ValueHasher: hashFunc, }) if err != nil { return err @@ -97,22 +112,21 @@ func hashMap(fn func(string) string, data map[string]interface{}, nonHMACDataKey } } - return HashStructure(data, fn, nonHMACDataKeys) + return HashStructure(data, hashFunc, nonHMACDataKeys) } // HashResponse returns a hashed copy of the logical.Request input. -func HashResponse( - salter *salt.Salt, - in *logical.Response, - HMACAccessor bool, - nonHMACDataKeys []string, - elideListResponseData bool, -) (*logical.Response, error) { +func HashResponse(ctx context.Context, salter Salter, in *logical.Response, HMACAccessor bool, nonHMACDataKeys []string, elideListResponseData bool) (*logical.Response, error) { if in == nil { return nil, nil } - fn := salter.GetIdentifiedHMAC + salt, err := salter.Salt(ctx) + if err != nil { + return nil, err + } + + fn := salt.GetIdentifiedHMAC resp := *in if resp.Auth != nil { @@ -121,7 +135,7 @@ func HashResponse( return nil, err } - resp.Auth, err = HashAuth(salter, cp.(*logical.Auth), HMACAccessor) + resp.Auth, err = HashAuth(ctx, salter, cp.(*logical.Auth), HMACAccessor) if err != nil { return nil, err } @@ -154,7 +168,7 @@ func HashResponse( if resp.WrapInfo != nil { var err error - resp.WrapInfo, err = HashWrapInfo(salter, resp.WrapInfo, HMACAccessor) + resp.WrapInfo, err = hashWrapInfo(fn, resp.WrapInfo, HMACAccessor) if err != nil { return nil, err } @@ -163,22 +177,21 @@ func HashResponse( return &resp, nil } -// HashWrapInfo returns a hashed copy of the wrapping.ResponseWrapInfo input. -func HashWrapInfo(salter *salt.Salt, in *wrapping.ResponseWrapInfo, HMACAccessor bool) (*wrapping.ResponseWrapInfo, error) { +// hashWrapInfo returns a hashed copy of the wrapping.ResponseWrapInfo input. +func hashWrapInfo(hashFunc HashCallback, in *wrapping.ResponseWrapInfo, HMACAccessor bool) (*wrapping.ResponseWrapInfo, error) { if in == nil { return nil, nil } - fn := salter.GetIdentifiedHMAC wrapinfo := *in - wrapinfo.Token = fn(wrapinfo.Token) + wrapinfo.Token = hashFunc(wrapinfo.Token) if HMACAccessor { - wrapinfo.Accessor = fn(wrapinfo.Accessor) + wrapinfo.Accessor = hashFunc(wrapinfo.Accessor) if wrapinfo.WrappedAccessor != "" { - wrapinfo.WrappedAccessor = fn(wrapinfo.WrappedAccessor) + wrapinfo.WrappedAccessor = hashFunc(wrapinfo.WrappedAccessor) } } diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index c65931f7c5..0ab55d6eec 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -98,20 +98,32 @@ func TestCopy_response(t *testing.T) { } } -func TestHashString(t *testing.T) { +// TestSalter is a structure that implements the Salter interface in a trivial +// manner. +type TestSalter struct{} + +// Salt returns a salt.Salt pointer based on dummy data stored in an in-memory +// storage instance. +func (*TestSalter) Salt(ctx context.Context) (*salt.Salt, error) { inmemStorage := &logical.InmemStorage{} inmemStorage.Put(context.Background(), &logical.StorageEntry{ Key: "salt", Value: []byte("foo"), }) - localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ + + return salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ HMAC: sha256.New, HMACType: "hmac-sha256", }) +} + +func TestHashString(t *testing.T) { + salter := &TestSalter{} + + out, err := HashString(context.Background(), salter, "foo") if err != nil { t.Fatalf("Error instantiating salt: %s", err) } - out := HashString(localSalt, "foo") if out != "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a" { t.Fatalf("err: HashString output did not match expected") } @@ -152,16 +164,10 @@ func TestHashAuth(t *testing.T) { Key: "salt", Value: []byte("foo"), }) - localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ - HMAC: sha256.New, - HMACType: "hmac-sha256", - }) - if err != nil { - t.Fatalf("Error instantiating salt: %s", err) - } + salter := &TestSalter{} for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - out, err := HashAuth(localSalt, tc.Input, tc.HMACAccessor) + out, err := HashAuth(context.Background(), salter, tc.Input, tc.HMACAccessor) if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } @@ -216,16 +222,10 @@ func TestHashRequest(t *testing.T) { Key: "salt", Value: []byte("foo"), }) - localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ - HMAC: sha256.New, - HMACType: "hmac-sha256", - }) - if err != nil { - t.Fatalf("Error instantiating salt: %s", err) - } + salter := &TestSalter{} for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - out, err := HashRequest(localSalt, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) + out, err := HashRequest(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } @@ -287,16 +287,10 @@ func TestHashResponse(t *testing.T) { Key: "salt", Value: []byte("foo"), }) - localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ - HMAC: sha256.New, - HMACType: "hmac-sha256", - }) - if err != nil { - t.Fatalf("Error instantiating salt: %s", err) - } + salter := &TestSalter{} for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - out, err := HashResponse(localSalt, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys, false) + out, err := HashResponse(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys, false) if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } diff --git a/audit/types.go b/audit/types.go index ddb6a6af95..e4fee94b85 100644 --- a/audit/types.go +++ b/audit/types.go @@ -85,11 +85,21 @@ type Writer interface { WriteResponse(io.Writer, *ResponseEntry) error } +// HeaderFormatter is an interface defining the methods of the +// vault.AuditedHeadersConfig structure needed in this package. +type HeaderFormatter interface { + // ApplyConfig returns a map of header values that consists of the + // intersection of the provided set of header values with a configured + // set of headers and will hash headers that have been configured as such. + ApplyConfig(context.Context, map[string][]string, Salter) (map[string][]string, error) +} + // EntryFormatter should be used to format audit requests and responses. type EntryFormatter struct { - salter Salter - config FormatterConfig - prefix string + salter Salter + headersConfig HeaderFormatter + config FormatterConfig + prefix string } // EntryFormatterWriter should be used to format and write out audit requests and responses. @@ -255,6 +265,9 @@ type nonPersistentSalt struct{} // sink information to different backends such as logs, file, databases, // or other external services. type Backend interface { + // Salter interface must be implemented by anything implementing Backend. + Salter + // LogRequest is used to synchronously log a request. This is done after the // request is authorized but before the request is executed. The arguments // MUST not be modified in any way. They should be deep copied if this is @@ -273,11 +286,6 @@ type Backend interface { // operation on creation, which is currently disallowed.) LogTestMessage(context.Context, *logical.LogInput, map[string]string) error - // GetHash is used to return the given data with the backend's hash, - // so that a caller can determine if a value in the audit log matches - // an expected plaintext value - GetHash(context.Context, string) (string, error) - // Reload is called on SIGHUP for supporting backends. Reload(context.Context) error @@ -305,4 +313,4 @@ type BackendConfig struct { } // Factory is the factory function to create an audit backend. -type Factory func(context.Context, *BackendConfig, bool) (Backend, error) +type Factory func(context.Context, *BackendConfig, bool, HeaderFormatter) (Backend, error) diff --git a/audit/writer_json_test.go b/audit/writer_json_test.go index 7b78aca8cb..56852066c8 100644 --- a/audit/writer_json_test.go +++ b/audit/writer_json_test.go @@ -100,7 +100,7 @@ func TestFormatJSON_formatRequest(t *testing.T) { var buf bytes.Buffer cfg, err := NewFormatterConfig() require.NoError(t, err) - f, err := NewEntryFormatter(cfg, ss) + f, err := NewEntryFormatter(cfg, ss, nil) require.NoError(t, err) formatter := EntryFormatterWriter{ Formatter: f, diff --git a/audit/writer_jsonx_test.go b/audit/writer_jsonx_test.go index 05eb7c677a..17d8b6ff01 100644 --- a/audit/writer_jsonx_test.go +++ b/audit/writer_jsonx_test.go @@ -119,7 +119,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) { WithFormat(JSONxFormat.String()), ) require.NoError(t, err) - f, err := NewEntryFormatter(cfg, tempStaticSalt) + f, err := NewEntryFormatter(cfg, tempStaticSalt, nil) require.NoError(t, err) writer := &JSONxWriter{Prefix: tc.Prefix} formatter, err := NewEntryFormatterWriter(cfg, f, writer) diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index 1ab970ed7a..fe15e0b044 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -22,7 +22,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool) (audit.Backend, error) { +func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { if conf.SaltConfig == nil { return nil, fmt.Errorf("nil salt config") } @@ -131,7 +131,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool b.salt.Store((*salt.Salt)(nil)) // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithPrefix(conf.Config["prefix"])) + f, err := audit.NewEntryFormatter(b.formatConfig, b, headersConfig, audit.WithPrefix(conf.Config["prefix"])) if err != nil { return nil, fmt.Errorf("error creating formatter: %w", err) } @@ -253,15 +253,6 @@ func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { return newSalt, nil } -func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { - salt, err := b.Salt(ctx) - if err != nil { - return "", err - } - - return audit.HashString(salt, data), nil -} - func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { var writer io.Writer switch b.path { diff --git a/builtin/audit/file/backend_test.go b/builtin/audit/file/backend_test.go index a9ef8cb67d..0996a51a2e 100644 --- a/builtin/audit/file/backend_test.go +++ b/builtin/audit/file/backend_test.go @@ -35,7 +35,7 @@ func TestAuditFile_fileModeNew(t *testing.T) { SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, Config: config, - }, false) + }, false, nil) if err != nil { t.Fatal(err) } @@ -74,7 +74,7 @@ func TestAuditFile_fileModeExisting(t *testing.T) { Config: config, SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - }, false) + }, false, nil) if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func TestAuditFile_fileMode0000(t *testing.T) { Config: config, SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - }, false) + }, false, nil) if err != nil { t.Fatal(err) } @@ -148,7 +148,7 @@ func TestAuditFile_EventLogger_fileModeNew(t *testing.T) { SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, Config: config, - }, true) + }, true, nil) if err != nil { t.Fatal(err) } @@ -170,7 +170,7 @@ func BenchmarkAuditFile_request(b *testing.B) { Config: config, SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - }, false) + }, false, nil) if err != nil { b.Fatal(err) } diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index bc9f444076..48946dd66f 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -21,7 +21,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool) (audit.Backend, error) { +func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { if conf.SaltConfig == nil { return nil, fmt.Errorf("nil salt config") } @@ -108,7 +108,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool } // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(b.formatConfig, b) + f, err := audit.NewEntryFormatter(b.formatConfig, b, headersConfig) if err != nil { return nil, fmt.Errorf("error creating formatter: %w", err) } @@ -177,14 +177,6 @@ type Backend struct { var _ audit.Backend = (*Backend)(nil) -func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { - salt, err := b.Salt(ctx) - if err != nil { - return "", err - } - return audit.HashString(salt, data), nil -} - func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { var buf bytes.Buffer if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index 9dde55afc7..bb773b82e2 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -18,7 +18,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool) (audit.Backend, error) { +func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { if conf.SaltConfig == nil { return nil, fmt.Errorf("nil salt config") } @@ -102,7 +102,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool } // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithPrefix(conf.Config["prefix"])) + f, err := audit.NewEntryFormatter(b.formatConfig, b, headersConfig, audit.WithPrefix(conf.Config["prefix"])) if err != nil { return nil, fmt.Errorf("error creating formatter: %w", err) } @@ -166,14 +166,6 @@ type Backend struct { var _ audit.Backend = (*Backend)(nil) -func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { - salt, err := b.Salt(ctx) - if err != nil { - return "", err - } - return audit.HashString(salt, data), nil -} - func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { var buf bytes.Buffer if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { diff --git a/helper/testhelpers/corehelpers/corehelpers.go b/helper/testhelpers/corehelpers/corehelpers.go index 144a196c88..4e8783ff32 100644 --- a/helper/testhelpers/corehelpers/corehelpers.go +++ b/helper/testhelpers/corehelpers/corehelpers.go @@ -252,7 +252,7 @@ func NewNoopAudit(config map[string]string) (*NoopAudit, error) { return nil, err } - f, err := audit.NewEntryFormatter(cfg, n) + f, err := audit.NewEntryFormatter(cfg, n, nil) if err != nil { return nil, fmt.Errorf("error creating formatter: %w", err) } @@ -268,7 +268,7 @@ func NewNoopAudit(config map[string]string) (*NoopAudit, error) { } func NoopAuditFactory(records **[][]byte) audit.Factory { - return func(_ context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + return func(_ context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { n, err := NewNoopAudit(config.Config) if err != nil { return nil, err diff --git a/http/logical_test.go b/http/logical_test.go index 90eac9469d..e0f9935df8 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -482,7 +482,7 @@ func TestLogical_Audit_invalidWrappingToken(t *testing.T) { noop := corehelpers.TestNoopAudit(t, nil) c, _, root := vault.TestCoreUnsealedWithConfig(t, &vault.CoreConfig{ AuditBackends: map[string]audit.Factory{ - "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { return noop, nil }, }, diff --git a/vault/audit.go b/vault/audit.go index 51b35cd56a..8c9254cf50 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -486,7 +486,7 @@ func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logi SaltView: view, SaltConfig: saltConfig, Config: conf, - }, c.IsExperimentEnabled(experiments.VaultExperimentCoreAuditEventsAlpha1)) + }, c.IsExperimentEnabled(experiments.VaultExperimentCoreAuditEventsAlpha1), c.auditedHeaders) if err != nil { return nil, err } diff --git a/vault/audit_broker.go b/vault/audit_broker.go index 6d8a4ad14f..7fe47b61e7 100644 --- a/vault/audit_broker.go +++ b/vault/audit_broker.go @@ -129,7 +129,7 @@ func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (s return "", fmt.Errorf("unknown audit backend %q", name) } - return be.backend.GetHash(ctx, input) + return audit.HashString(ctx, be.backend, input) } // LogRequest is used to ensure all the audit backends have an opportunity to @@ -182,7 +182,7 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, head anyLogged := false for name, be := range a.backends { in.Request.Headers = nil - transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash) + transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend) if thErr != nil { a.logger.Error("backend failed to include headers", "backend", name, "error", thErr) continue @@ -247,7 +247,7 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, hea anyLogged := false for name, be := range a.backends { in.Request.Headers = nil - transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash) + transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend) if thErr != nil { a.logger.Error("backend failed to include headers", "backend", name, "error", thErr) continue diff --git a/vault/audit_test.go b/vault/audit_test.go index d7349a63b8..8242002c6c 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -27,7 +27,7 @@ import ( func TestAudit_ReadOnlyViewDuringMount(t *testing.T) { c, _, _ := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { err := config.SaltView.Put(ctx, &logical.StorageEntry{ Key: "bar", Value: []byte("baz"), @@ -36,7 +36,7 @@ func TestAudit_ReadOnlyViewDuringMount(t *testing.T) { t.Fatalf("expected a read-only error") } factory := corehelpers.NoopAuditFactory(nil) - return factory(ctx, config, false) + return factory(ctx, config, false, nil) } me := &MountEntry{ @@ -103,7 +103,7 @@ func TestCore_EnableAudit(t *testing.T) { func TestCore_EnableAudit_MixedFailures(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = corehelpers.NoopAuditFactory(nil) - c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { return nil, fmt.Errorf("failing enabling") } @@ -152,7 +152,7 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) { func TestCore_EnableAudit_Local(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = corehelpers.NoopAuditFactory(nil) - c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { return nil, fmt.Errorf("failing enabling") } diff --git a/vault/audited_headers.go b/vault/audited_headers.go index 70c9f467ee..22f40b22ff 100644 --- a/vault/audited_headers.go +++ b/vault/audited_headers.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/sdk/logical" ) @@ -92,7 +93,7 @@ func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error // ApplyConfig returns a map of approved headers and their values, either // hmac'ed or plaintext -func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, hashFunc func(context.Context, string) (string, error)) (result map[string][]string, retErr error) { +func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, salter audit.Salter) (result map[string][]string, retErr error) { // Grab a read lock a.RLock() defer a.RUnlock() @@ -114,7 +115,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[stri // Optionally hmac the values if settings.HMAC { for i, el := range hVals { - hVal, err := hashFunc(ctx, el) + hVal, err := audit.HashString(ctx, salter, el) if err != nil { return nil, err } diff --git a/vault/audited_headers_test.go b/vault/audited_headers_test.go index 940197e997..9988124b11 100644 --- a/vault/audited_headers_test.go +++ b/vault/audited_headers_test.go @@ -5,7 +5,9 @@ package vault import ( "context" + "errors" "reflect" + "strings" "testing" "github.com/hashicorp/vault/sdk/helper/salt" @@ -169,6 +171,12 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { } } +type TestSalter struct{} + +func (*TestSalter) Salt(ctx context.Context) (*salt.Salt, error) { + return salt.NewSalt(ctx, nil, nil) +} + func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { conf := mockAuditedHeadersConfig(t) @@ -181,20 +189,40 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { "Content-Type": {"json"}, } - hashFunc := func(ctx context.Context, s string) (string, error) { return "hashed", nil } + salter := &TestSalter{} - result, err := conf.ApplyConfig(context.Background(), reqHeaders, hashFunc) + result, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) if err != nil { t.Fatal(err) } expected := map[string][]string{ "x-test-header": {"foo"}, - "x-vault-header": {"hashed", "hashed"}, + "x-vault-header": {"hmac-sha256:", "hmac-sha256:"}, } - if !reflect.DeepEqual(result, expected) { - t.Fatalf("Expected headers did not match actual: Expected %#v\n Got %#v\n", expected, result) + if len(expected) != len(result) { + t.Fatalf("Expected headers count did not match actual count: Expected count %d\n Got %d\n", len(expected), len(result)) + } + + for resultKey, resultValues := range result { + expectedValues := expected[resultKey] + + if len(expectedValues) != len(resultValues) { + t.Fatalf("Expected header values count did not match actual values count: Expected count: %d\n Got %d\n", len(expectedValues), len(resultValues)) + } + + for i, e := range expectedValues { + if e == "hmac-sha256:" { + if !strings.HasPrefix(resultValues[i], e) { + t.Fatalf("Expected headers did not match actual: Expected %#v...\n Got %#v\n", e, resultValues[i]) + } + } else { + if e != resultValues[i] { + t.Fatalf("Expected headers did not match actual: Expected %#v\n Got %#v\n", e, resultValues[i]) + } + } + } } // Make sure we didn't edit the reqHeaders map @@ -209,6 +237,91 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { } } +// TestAuditedHeadersConfig_ApplyConfig_NoHeaders tests the case where there are +// no headers in the request. +func TestAuditedHeadersConfig_ApplyConfig_NoRequestHeaders(t *testing.T) { + conf := mockAuditedHeadersConfig(t) + + conf.add(context.Background(), "X-TesT-Header", false) + conf.add(context.Background(), "X-Vault-HeAdEr", true) + + reqHeaders := map[string][]string{} + + salter := &TestSalter{} + + result, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) + if err != nil { + t.Fatal(err) + } + + if len(result) != 0 { + t.Fatalf("Expected no headers but actually got: %d\n", len(result)) + } +} + +func TestAuditedHeadersConfig_ApplyConfig_NoConfiguredHeaders(t *testing.T) { + conf := mockAuditedHeadersConfig(t) + + reqHeaders := map[string][]string{ + "X-Test-Header": {"foo"}, + "X-Vault-Header": {"bar", "bar"}, + "Content-Type": {"json"}, + } + + salter := &TestSalter{} + + result, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) + if err != nil { + t.Fatal(err) + } + + if len(result) != 0 { + t.Fatalf("Expected no headers but actually got: %d\n", len(result)) + } + + // Make sure we didn't edit the reqHeaders map + reqHeadersCopy := map[string][]string{ + "X-Test-Header": {"foo"}, + "X-Vault-Header": {"bar", "bar"}, + "Content-Type": {"json"}, + } + + if !reflect.DeepEqual(reqHeaders, reqHeadersCopy) { + t.Fatalf("Req headers were changed, expected %#v\n got %#v", reqHeadersCopy, reqHeaders) + } +} + +// FailingSalter is an implementation of the Salter interface where the Salt +// method always returns an error. +type FailingSalter struct{} + +// Salt always returns an error. +func (s *FailingSalter) Salt(context.Context) (*salt.Salt, error) { + return nil, errors.New("testing error") +} + +// TestAuditedHeadersConfig_ApplyConfig_HashStringError tests the case where +// an error is returned from HashString instead of a map of headers. +func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) { + conf := mockAuditedHeadersConfig(t) + + conf.add(context.Background(), "X-TesT-Header", false) + conf.add(context.Background(), "X-Vault-HeAdEr", true) + + reqHeaders := map[string][]string{ + "X-Test-Header": {"foo"}, + "X-Vault-Header": {"bar", "bar"}, + "Content-Type": {"json"}, + } + + salter := &FailingSalter{} + + _, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) + if err == nil { + t.Fatal("expected error from ApplyConfig") + } +} + func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { conf := &AuditedHeadersConfig{ Headers: make(map[string]*auditedHeaderSettings), @@ -226,16 +339,11 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { "Content-Type": {"json"}, } - salter, err := salt.NewSalt(context.Background(), nil, nil) - if err != nil { - b.Fatal(err) - } - - hashFunc := func(ctx context.Context, s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil } + salter := &TestSalter{} // Reset the timer since we did a lot above b.ResetTimer() for i := 0; i < b.N; i++ { - conf.ApplyConfig(context.Background(), reqHeaders, hashFunc) + conf.ApplyConfig(context.Background(), reqHeaders, salter) } } diff --git a/vault/core_test.go b/vault/core_test.go index 306afea483..a44eea3d7b 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -1137,7 +1137,7 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { // Create a noop audit backend noop := &corehelpers.NoopAudit{} c, _, root := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { noop = &corehelpers.NoopAudit{ Config: config, } @@ -1201,7 +1201,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { // Create a noop audit backend var noop *corehelpers.NoopAudit c, _, root := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { noop = &corehelpers.NoopAudit{ Config: config, } @@ -1323,7 +1323,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { return noopBack, nil } - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { noop = &corehelpers.NoopAudit{ Config: config, } diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index 74d6ab713d..6877ccdabb 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -61,7 +61,7 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { "totp": totp.Factory, }, AuditBackends: map[string]audit.Factory{ - "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { return noop, nil }, }, diff --git a/vault/mount_test.go b/vault/mount_test.go index 3b2b96d083..fca1a151d8 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -724,7 +724,7 @@ func TestDefaultMountTable(t *testing.T) { func TestCore_MountTable_UpgradeToTyped(t *testing.T) { c, _, _ := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { return &corehelpers.NoopAudit{ Config: config, }, nil