diff --git a/audit/audit.go b/audit/audit.go index dffa8eee54..b96391c5ca 100644 --- a/audit/audit.go +++ b/audit/audit.go @@ -25,15 +25,21 @@ type Backend interface { // GetHash is used to return the given data with the backend's hash, // so that a caller can determine if a value in the audit log matches // an expected plaintext value - GetHash(string) string + GetHash(string) (string, error) // Reload is called on SIGHUP for supporting backends. Reload() error + + // Invalidate is called for path invalidation + Invalidate() } type BackendConfig struct { - // The salt that should be used for any secret obfuscation - Salt *salt.Salt + // The view to store the salt + SaltView logical.Storage + + // The salt config that should be used for any secret obfuscation + SaltConfig *salt.Config // Config is the opaque user configuration provided when mounting Config map[string]string diff --git a/audit/format.go b/audit/format.go index 919da125e4..773d0ad3f7 100644 --- a/audit/format.go +++ b/audit/format.go @@ -7,6 +7,8 @@ import ( "time" "github.com/SermoDigital/jose/jws" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" "github.com/mitchellh/copystructure" ) @@ -14,6 +16,7 @@ import ( type AuditFormatWriter interface { WriteRequest(io.Writer, *AuditRequestEntry) error WriteResponse(io.Writer, *AuditResponseEntry) error + Salt() (*salt.Salt, error) } // AuditFormatter implements the Formatter interface, and allows the underlying @@ -41,6 +44,11 @@ func (f *AuditFormatter) FormatRequest( return fmt.Errorf("no format writer specified") } + salt, err := f.Salt() + if err != nil { + return errwrap.Wrapf("error fetching salt: {{err}}", err) + } + if !config.Raw { // Before we copy the structure we must nil out some data // otherwise we will cause reflection to panic and die @@ -70,7 +78,7 @@ func (f *AuditFormatter) FormatRequest( // Hash any sensitive information if auth != nil { - if err := Hash(config.Salt, auth); err != nil { + if err := Hash(salt, auth); err != nil { return err } } @@ -80,7 +88,7 @@ func (f *AuditFormatter) FormatRequest( if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { clientTokenAccessor = req.ClientTokenAccessor } - if err := Hash(config.Salt, req); err != nil { + if err := Hash(salt, req); err != nil { return err } if clientTokenAccessor != "" { @@ -152,6 +160,11 @@ func (f *AuditFormatter) FormatResponse( return fmt.Errorf("no format writer specified") } + salt, err := f.Salt() + if err != nil { + return errwrap.Wrapf("error fetching salt: {{err}}", err) + } + if !config.Raw { // Before we copy the structure we must nil out some data // otherwise we will cause reflection to panic and die @@ -195,7 +208,7 @@ func (f *AuditFormatter) FormatResponse( if !config.HMACAccessor && auth.Accessor != "" { accessor = auth.Accessor } - if err := Hash(config.Salt, auth); err != nil { + if err := Hash(salt, auth); err != nil { return err } if accessor != "" { @@ -208,7 +221,7 @@ func (f *AuditFormatter) FormatResponse( if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { clientTokenAccessor = req.ClientTokenAccessor } - if err := Hash(config.Salt, req); err != nil { + if err := Hash(salt, req); err != nil { return err } if clientTokenAccessor != "" { @@ -224,7 +237,7 @@ func (f *AuditFormatter) FormatResponse( if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { wrappedAccessor = resp.WrapInfo.WrappedAccessor } - if err := Hash(config.Salt, resp); err != nil { + if err := Hash(salt, resp); err != nil { return err } if accessor != "" { diff --git a/audit/format_json.go b/audit/format_json.go index 9e200f0032..0a5c9d90bd 100644 --- a/audit/format_json.go +++ b/audit/format_json.go @@ -4,12 +4,15 @@ import ( "encoding/json" "fmt" "io" + + "github.com/hashicorp/vault/helper/salt" ) // JSONFormatWriter is an AuditFormatWriter implementation that structures data into // a JSON format. type JSONFormatWriter struct { - Prefix string + Prefix string + SaltFunc func() (*salt.Salt, error) } func (f *JSONFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { @@ -43,3 +46,7 @@ func (f *JSONFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry) enc := json.NewEncoder(w) return enc.Encode(resp) } + +func (f *JSONFormatWriter) Salt() (*salt.Salt, error) { + return f.SaltFunc() +} diff --git a/audit/format_json_test.go b/audit/format_json_test.go index 21bb647856..4155dbb18a 100644 --- a/audit/format_json_test.go +++ b/audit/format_json_test.go @@ -15,6 +15,13 @@ import ( ) func TestFormatJSON_formatRequest(t *testing.T) { + salter, err := salt.NewSalt(nil, nil) + if err != nil { + t.Fatal(err) + } + saltFunc := func() (*salt.Salt, error) { + return salter, nil + } cases := map[string]struct { Auth *logical.Auth Req *logical.Request @@ -66,13 +73,11 @@ func TestFormatJSON_formatRequest(t *testing.T) { var buf bytes.Buffer formatter := AuditFormatter{ AuditFormatWriter: &JSONFormatWriter{ - Prefix: tc.Prefix, + Prefix: tc.Prefix, + SaltFunc: saltFunc, }, } - salter, _ := salt.NewSalt(nil, nil) - config := FormatterConfig{ - Salt: salter, - } + config := FormatterConfig{} if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { t.Fatalf("bad: %s\nerr: %s", name, err) } diff --git a/audit/format_jsonx.go b/audit/format_jsonx.go index cc6cc956be..792e5524c3 100644 --- a/audit/format_jsonx.go +++ b/audit/format_jsonx.go @@ -5,13 +5,15 @@ import ( "fmt" "io" + "github.com/hashicorp/vault/helper/salt" "github.com/jefferai/jsonx" ) // JSONxFormatWriter is an AuditFormatWriter implementation that structures data into // a XML format. type JSONxFormatWriter struct { - Prefix string + Prefix string + SaltFunc func() (*salt.Salt, error) } func (f *JSONxFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { @@ -65,3 +67,7 @@ func (f *JSONxFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry) _, err = w.Write(xmlBytes) return err } + +func (f *JSONxFormatWriter) Salt() (*salt.Salt, error) { + return f.SaltFunc() +} diff --git a/audit/format_jsonx_test.go b/audit/format_jsonx_test.go index 8d4fe4ba29..a0cc3a191d 100644 --- a/audit/format_jsonx_test.go +++ b/audit/format_jsonx_test.go @@ -13,6 +13,13 @@ import ( ) func TestFormatJSONx_formatRequest(t *testing.T) { + salter, err := salt.NewSalt(nil, nil) + if err != nil { + t.Fatal(err) + } + saltFunc := func() (*salt.Salt, error) { + return salter, nil + } cases := map[string]struct { Auth *logical.Auth Req *logical.Request @@ -67,12 +74,11 @@ func TestFormatJSONx_formatRequest(t *testing.T) { var buf bytes.Buffer formatter := AuditFormatter{ AuditFormatWriter: &JSONxFormatWriter{ - Prefix: tc.Prefix, + Prefix: tc.Prefix, + SaltFunc: saltFunc, }, } - salter, _ := salt.NewSalt(nil, nil) config := FormatterConfig{ - Salt: salter, OmitTime: true, } if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { diff --git a/audit/format_test.go b/audit/format_test.go index 6a6425b3a4..5390229db2 100644 --- a/audit/format_test.go +++ b/audit/format_test.go @@ -10,6 +10,8 @@ import ( ) type noopFormatWriter struct { + salt *salt.Salt + SaltFunc func() (*salt.Salt, error) } func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error { @@ -20,11 +22,20 @@ func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) err return nil } -func TestFormatRequestErrors(t *testing.T) { - salter, _ := salt.NewSalt(nil, nil) - config := FormatterConfig{ - Salt: salter, +func (n *noopFormatWriter) Salt() (*salt.Salt, error) { + if n.salt != nil { + return n.salt, nil } + var err error + n.salt, err = salt.NewSalt(nil, nil) + if err != nil { + return nil, err + } + return n.salt, nil +} + +func TestFormatRequestErrors(t *testing.T) { + config := FormatterConfig{} formatter := AuditFormatter{ AuditFormatWriter: &noopFormatWriter{}, } @@ -38,10 +49,7 @@ func TestFormatRequestErrors(t *testing.T) { } func TestFormatResponseErrors(t *testing.T) { - salter, _ := salt.NewSalt(nil, nil) - config := FormatterConfig{ - Salt: salter, - } + config := FormatterConfig{} formatter := AuditFormatter{ AuditFormatWriter: &noopFormatWriter{}, } diff --git a/audit/formatter.go b/audit/formatter.go index 318bd1bc59..3c1748f53a 100644 --- a/audit/formatter.go +++ b/audit/formatter.go @@ -3,7 +3,6 @@ package audit import ( "io" - "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) @@ -19,7 +18,6 @@ type Formatter interface { type FormatterConfig struct { Raw bool - Salt *salt.Salt HMACAccessor bool // This should only ever be used in a testing context diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index cc2cfe5540..0b05b0a3d3 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -8,12 +8,16 @@ import ( "sync" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) func Factory(conf *audit.BackendConfig) (audit.Backend, error) { - if conf.Salt == nil { - return nil, fmt.Errorf("nil salt") + if conf.SaltConfig == nil { + return nil, fmt.Errorf("nil salt config") + } + if conf.SaltView == nil { + return nil, fmt.Errorf("nil salt view") } path, ok := conf.Config["file_path"] @@ -65,11 +69,12 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { } b := &Backend{ - path: path, - mode: mode, + path: path, + mode: mode, + saltConfig: conf.SaltConfig, + saltView: conf.SaltView, formatConfig: audit.FormatterConfig{ Raw: logRaw, - Salt: conf.Salt, HMACAccessor: hmacAccessor, }, } @@ -77,11 +82,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { switch format { case "json": b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ - Prefix: conf.Config["prefix"], + Prefix: conf.Config["prefix"], + SaltFunc: b.Salt, } case "jsonx": b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ - Prefix: conf.Config["prefix"], + Prefix: conf.Config["prefix"], + SaltFunc: b.Salt, } } @@ -109,10 +116,39 @@ type Backend struct { fileLock sync.RWMutex f *os.File mode os.FileMode + + saltMutex sync.RWMutex + salt *salt.Salt + saltConfig *salt.Config + saltView logical.Storage } -func (b *Backend) GetHash(data string) string { - return audit.HashString(b.formatConfig.Salt, data) +func (b *Backend) Salt() (*salt.Salt, error) { + b.saltMutex.RLock() + if b.salt != nil { + defer b.saltMutex.RUnlock() + return b.salt, nil + } + b.saltMutex.RUnlock() + b.saltMutex.Lock() + defer b.saltMutex.Unlock() + if b.salt != nil { + return b.salt, nil + } + salt, err := salt.NewSalt(b.saltView, b.saltConfig) + if err != nil { + return nil, err + } + b.salt = salt + return salt, nil +} + +func (b *Backend) GetHash(data string) (string, error) { + salt, err := b.Salt() + if err != nil { + return "", err + } + return audit.HashString(salt, data), nil } func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { @@ -189,3 +225,9 @@ func (b *Backend) Reload() error { return b.open() } + +func (b *Backend) Invalidate() { + b.saltMutex.Lock() + defer b.saltMutex.Unlock() + b.salt = nil +} diff --git a/builtin/audit/file/backend_test.go b/builtin/audit/file/backend_test.go index 0a1a8c7751..643e467c9b 100644 --- a/builtin/audit/file/backend_test.go +++ b/builtin/audit/file/backend_test.go @@ -9,11 +9,10 @@ import ( "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/salt" + "github.com/hashicorp/vault/logical" ) func TestAuditFile_fileModeNew(t *testing.T) { - salter, _ := salt.NewSalt(nil, nil) - modeStr := "0777" mode, err := strconv.ParseUint(modeStr, 8, 32) @@ -28,8 +27,9 @@ func TestAuditFile_fileModeNew(t *testing.T) { } _, err = Factory(&audit.BackendConfig{ - Salt: salter, - Config: config, + SaltConfig: &salt.Config{}, + SaltView: &logical.InmemStorage{}, + Config: config, }) if err != nil { t.Fatal(err) @@ -45,8 +45,6 @@ func TestAuditFile_fileModeNew(t *testing.T) { } func TestAuditFile_fileModeExisting(t *testing.T) { - salter, _ := salt.NewSalt(nil, nil) - f, err := ioutil.TempFile("", "test") if err != nil { t.Fatalf("Failure to create test file.") @@ -68,8 +66,9 @@ func TestAuditFile_fileModeExisting(t *testing.T) { } _, err = Factory(&audit.BackendConfig{ - Salt: salter, - Config: config, + Config: config, + SaltConfig: &salt.Config{}, + SaltView: &logical.InmemStorage{}, }) if err != nil { t.Fatal(err) diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index 91e701ed03..0507af3c2b 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -11,12 +11,16 @@ import ( multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/parseutil" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) func Factory(conf *audit.BackendConfig) (audit.Backend, error) { - if conf.Salt == nil { - return nil, fmt.Errorf("nil salt passed in") + if conf.SaltConfig == nil { + return nil, fmt.Errorf("nil salt config") + } + if conf.SaltView == nil { + return nil, fmt.Errorf("nil salt view") } address, ok := conf.Config["address"] @@ -75,11 +79,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { b := &Backend{ connection: conn, + saltConfig: conf.SaltConfig, + saltView: conf.SaltView, formatConfig: audit.FormatterConfig{ Raw: logRaw, - Salt: conf.Salt, HMACAccessor: hmacAccessor, }, + writeDuration: writeDuration, address: address, socketType: socketType, @@ -88,11 +94,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { switch format { case "json": b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ - Prefix: conf.Config["prefix"], + Prefix: conf.Config["prefix"], + SaltFunc: b.Salt, } case "jsonx": b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ - Prefix: conf.Config["prefix"], + Prefix: conf.Config["prefix"], + SaltFunc: b.Salt, } } @@ -111,10 +119,19 @@ type Backend struct { socketType string sync.Mutex + + saltMutex sync.RWMutex + salt *salt.Salt + saltConfig *salt.Config + saltView logical.Storage } -func (b *Backend) GetHash(data string) string { - return audit.HashString(b.formatConfig.Salt, data) +func (b *Backend) GetHash(data string) (string, error) { + salt, err := b.Salt() + if err != nil { + return "", err + } + return audit.HashString(salt, data), nil } func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { @@ -198,3 +215,29 @@ func (b *Backend) Reload() error { return err } + +func (b *Backend) Salt() (*salt.Salt, error) { + b.saltMutex.RLock() + if b.salt != nil { + defer b.saltMutex.RUnlock() + return b.salt, nil + } + b.saltMutex.RUnlock() + b.saltMutex.Lock() + defer b.saltMutex.Unlock() + if b.salt != nil { + return b.salt, nil + } + salt, err := salt.NewSalt(b.saltView, b.saltConfig) + if err != nil { + return nil, err + } + b.salt = salt + return salt, nil +} + +func (b *Backend) Invalidate() { + b.saltMutex.Lock() + defer b.saltMutex.Unlock() + b.salt = nil +} diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index 4b1912f67e..22c39d4409 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -4,15 +4,20 @@ import ( "bytes" "fmt" "strconv" + "sync" "github.com/hashicorp/go-syslog" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) func Factory(conf *audit.BackendConfig) (audit.Backend, error) { - if conf.Salt == nil { - return nil, fmt.Errorf("Nil salt passed in") + if conf.SaltConfig == nil { + return nil, fmt.Errorf("nil salt config") + } + if conf.SaltView == nil { + return nil, fmt.Errorf("nil salt view") } // Get facility or default to AUTH @@ -64,10 +69,11 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { } b := &Backend{ - logger: logger, + logger: logger, + saltConfig: conf.SaltConfig, + saltView: conf.SaltView, formatConfig: audit.FormatterConfig{ Raw: logRaw, - Salt: conf.Salt, HMACAccessor: hmacAccessor, }, } @@ -75,11 +81,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { switch format { case "json": b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ - Prefix: conf.Config["prefix"], + Prefix: conf.Config["prefix"], + SaltFunc: b.Salt, } case "jsonx": b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ - Prefix: conf.Config["prefix"], + Prefix: conf.Config["prefix"], + SaltFunc: b.Salt, } } @@ -92,10 +100,19 @@ type Backend struct { formatter audit.AuditFormatter formatConfig audit.FormatterConfig + + saltMutex sync.RWMutex + salt *salt.Salt + saltConfig *salt.Config + saltView logical.Storage } -func (b *Backend) GetHash(data string) string { - return audit.HashString(b.formatConfig.Salt, data) +func (b *Backend) GetHash(data string) (string, error) { + salt, err := b.Salt() + if err != nil { + return "", err + } + return audit.HashString(salt, data), nil } func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { @@ -123,3 +140,29 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *lo func (b *Backend) Reload() error { return nil } + +func (b *Backend) Salt() (*salt.Salt, error) { + b.saltMutex.RLock() + if b.salt != nil { + defer b.saltMutex.RUnlock() + return b.salt, nil + } + b.saltMutex.RUnlock() + b.saltMutex.Lock() + defer b.saltMutex.Unlock() + if b.salt != nil { + return b.salt, nil + } + salt, err := salt.NewSalt(b.saltView, b.saltConfig) + if err != nil { + return nil, err + } + b.salt = salt + return salt, nil +} + +func (b *Backend) Invalidate() { + b.saltMutex.Lock() + defer b.saltMutex.Unlock() + b.salt = nil +} diff --git a/vault/audit.go b/vault/audit.go index 939184369c..25943da694 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -368,17 +368,16 @@ func (c *Core) newAuditBackend(entry *MountEntry, view logical.Storage, conf map if !ok { return nil, fmt.Errorf("unknown backend type: %s", entry.Type) } - salter, err := salt.NewSalt(view, &salt.Config{ + saltConfig := &salt.Config{ HMAC: sha256.New, HMACType: "hmac-sha256", - }) - if err != nil { - return nil, fmt.Errorf("core: unable to generate salt: %v", err) + Location: salt.DefaultLocation, } be, err := f(&audit.BackendConfig{ - Salt: salter, - Config: conf, + SaltView: view, + SaltConfig: saltConfig, + Config: conf, }) if err != nil { return nil, err @@ -474,20 +473,25 @@ func (a *AuditBroker) GetHash(name string, input string) (string, error) { return "", fmt.Errorf("unknown audit backend %s", name) } - return be.backend.GetHash(input), nil + return be.backend.GetHash(input) } // LogRequest is used to ensure all the audit backends have an opportunity to // log the given request and that *at least one* succeeds. -func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, headersConfig *AuditedHeadersConfig, outerErr error) (retErr error) { +func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, headersConfig *AuditedHeadersConfig, outerErr error) (ret error) { defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) a.RLock() defer a.RUnlock() + + var retErr *multierror.Error + defer func() { if r := recover(); r != nil { a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log")) } + + ret = retErr.ErrorOrNil() }() // All logged requests must have an identifier @@ -506,36 +510,46 @@ func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, heade anyLogged := false for name, be := range a.backends { req.Headers = nil - req.Headers = headersConfig.ApplyConfig(headers, be.backend.GetHash) + transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash) + if thErr != nil { + a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr) + continue + } + req.Headers = transHeaders start := time.Now() - err := be.backend.LogRequest(auth, req, outerErr) + lrErr := be.backend.LogRequest(auth, req, outerErr) metrics.MeasureSince([]string{"audit", name, "log_request"}, start) - if err != nil { - a.logger.Error("audit: backend failed to log request", "backend", name, "error", err) + if lrErr != nil { + a.logger.Error("audit: backend failed to log request", "backend", name, "error", lrErr) } else { anyLogged = true } } if !anyLogged && len(a.backends) > 0 { retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request")) - return } - return nil + + return retErr.ErrorOrNil() } // LogResponse is used to ensure all the audit backends have an opportunity to // log the given response and that *at least one* succeeds. func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, - resp *logical.Response, headersConfig *AuditedHeadersConfig, err error) (reterr error) { + resp *logical.Response, headersConfig *AuditedHeadersConfig, err error) (ret error) { defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) a.RLock() defer a.RUnlock() + + var retErr *multierror.Error + defer func() { if r := recover(); r != nil { a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) - reterr = fmt.Errorf("panic generating audit log") + retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log")) } + + ret = retErr.ErrorOrNil() }() headers := req.Headers @@ -547,19 +561,35 @@ func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, anyLogged := false for name, be := range a.backends { req.Headers = nil - req.Headers = headersConfig.ApplyConfig(headers, be.backend.GetHash) + transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash) + if thErr != nil { + a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr) + continue + } + req.Headers = transHeaders start := time.Now() - err := be.backend.LogResponse(auth, req, resp, err) + lrErr := be.backend.LogResponse(auth, req, resp, err) metrics.MeasureSince([]string{"audit", name, "log_response"}, start) - if err != nil { - a.logger.Error("audit: backend failed to log response", "backend", name, "error", err) + if lrErr != nil { + a.logger.Error("audit: backend failed to log response", "backend", name, "error", lrErr) } else { anyLogged = true } } if !anyLogged && len(a.backends) > 0 { - return fmt.Errorf("no audit backend succeeded in logging the response") + retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the response")) + } + + return retErr.ErrorOrNil() +} + +func (a *AuditBroker) Invalidate(key string) { + // For now we ignore the key as this would only apply to salts. We just + // sort of brute force it on each one. + a.Lock() + defer a.Unlock() + for _, be := range a.backends { + be.backend.Invalidate() } - return nil } diff --git a/vault/audit_test.go b/vault/audit_test.go index 5e97da86f4..76eba563ae 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -3,6 +3,8 @@ package vault import ( "fmt" "reflect" + "strings" + "sync" "testing" "time" @@ -13,6 +15,7 @@ import ( "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" log "github.com/mgutz/logxi/v1" "github.com/mitchellh/copystructure" @@ -31,6 +34,9 @@ type NoopAudit struct { RespReq []*logical.Request Resp []*logical.Response RespErrs []error + + salt *salt.Salt + saltMutex sync.RWMutex } func (n *NoopAudit) LogRequest(a *logical.Auth, r *logical.Request, err error) error { @@ -49,14 +55,44 @@ func (n *NoopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical return n.RespErr } -func (n *NoopAudit) GetHash(data string) string { - return n.Config.Salt.GetIdentifiedHMAC(data) +func (n *NoopAudit) Salt() (*salt.Salt, error) { + n.saltMutex.RLock() + if n.salt != nil { + defer n.saltMutex.RUnlock() + return n.salt, nil + } + n.saltMutex.RUnlock() + n.saltMutex.Lock() + defer n.saltMutex.Unlock() + if n.salt != nil { + return n.salt, nil + } + salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig) + if err != nil { + return nil, err + } + n.salt = salt + return salt, nil +} + +func (n *NoopAudit) GetHash(data string) (string, error) { + salt, err := n.Salt() + if err != nil { + return "", err + } + return salt.GetIdentifiedHMAC(data), nil } func (n *NoopAudit) Reload() error { return nil } +func (n *NoopAudit) Invalidate() { + n.saltMutex.Lock() + defer n.saltMutex.Unlock() + n.salt = nil +} + func TestCore_EnableAudit(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { @@ -508,7 +544,7 @@ func TestAuditBroker_LogResponse(t *testing.T) { t.Fatalf("Bad: %#v", a.Resp[0]) } if !reflect.DeepEqual(a.RespErrs[0], respErr) { - t.Fatalf("Bad: %#v", a.RespErrs[0]) + t.Fatalf("Expected\n%v\nGot\n%#v", respErr, a.RespErrs[0]) } } @@ -522,7 +558,7 @@ func TestAuditBroker_LogResponse(t *testing.T) { // Should FAIL work with both failing backends a2.RespErr = fmt.Errorf("failed") err = b.LogResponse(auth, req, resp, headersConf, respErr) - if err.Error() != "no audit backend succeeded in logging the response" { + if !strings.Contains(err.Error(), "no audit backend succeeded in logging the response") { t.Fatalf("err: %v", err) } } diff --git a/vault/audited_headers.go b/vault/audited_headers.go index 781c035626..1e1a11b0bd 100644 --- a/vault/audited_headers.go +++ b/vault/audited_headers.go @@ -88,7 +88,7 @@ func (a *AuditedHeadersConfig) remove(header string) error { // ApplyConfig returns a map of approved headers and their values, either // hmac'ed or plaintext -func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) string) (result map[string][]string) { +func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) (string, error)) (result map[string][]string, retErr error) { // Grab a read lock a.RLock() defer a.RUnlock() @@ -110,7 +110,11 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc // Optionally hmac the values if settings.HMAC { for i, el := range hVals { - hVals[i] = hashFunc(el) + hVal, err := hashFunc(el) + if err != nil { + return nil, err + } + hVals[i] = hVal } } @@ -118,7 +122,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc } } - return + return result, nil } // Initalize the headers config by loading from the barrier view diff --git a/vault/audited_headers_test.go b/vault/audited_headers_test.go index 5e82ec71dc..93225cf62e 100644 --- a/vault/audited_headers_test.go +++ b/vault/audited_headers_test.go @@ -166,9 +166,12 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { "Content-Type": []string{"json"}, } - hashFunc := func(s string) string { return "hashed" } + hashFunc := func(s string) (string, error) { return "hashed", nil } - result := conf.ApplyConfig(reqHeaders, hashFunc) + result, err := conf.ApplyConfig(reqHeaders, hashFunc) + if err != nil { + t.Fatal(err) + } expected := map[string][]string{ "x-test-header": []string{"foo"}, @@ -214,7 +217,7 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { b.Fatal(err) } - hashFunc := func(s string) string { return salter.GetIdentifiedHMAC(s) } + hashFunc := func(s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil } // Reset the timer since we did a lot above b.ResetTimer() diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 0a84c04af7..536b4fa392 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1254,13 +1254,11 @@ func TestSystemBackend_auditHash(t *testing.T) { Key: "salt", Value: []byte("foo"), }) - var err error - config.Salt, err = salt.NewSalt(view, &salt.Config{ + config.SaltView = view + config.SaltConfig = &salt.Config{ HMAC: sha256.New, HMACType: "hmac-sha256", - }) - if err != nil { - t.Fatalf("error getting new salt: %v", err) + Location: salt.DefaultLocation, } return &NoopAudit{ Config: config, diff --git a/vault/testing.go b/vault/testing.go index a8c1f16bdc..49e792fb46 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -14,6 +14,7 @@ import ( "os" "os/exec" "path/filepath" + "sync" "testing" "time" @@ -111,14 +112,11 @@ func testCoreConfig(t testing.TB, physicalBackend physical.Backend, logger log.L Key: "salt", Value: []byte("foo"), }) - var err error - config.Salt, err = salt.NewSalt(view, &salt.Config{ + config.SaltConfig = &salt.Config{ HMAC: sha256.New, HMACType: "hmac-sha256", - }) - if err != nil { - t.Fatalf("error getting new salt: %v", err) } + config.SaltView = view return &noopAudit{ Config: config, }, nil @@ -442,11 +440,17 @@ func AddTestLogicalBackend(name string, factory logical.Factory) error { } type noopAudit struct { - Config *audit.BackendConfig + Config *audit.BackendConfig + salt *salt.Salt + saltMutex sync.RWMutex } -func (n *noopAudit) GetHash(data string) string { - return n.Config.Salt.GetIdentifiedHMAC(data) +func (n *noopAudit) GetHash(data string) (string, error) { + salt, err := n.Salt() + if err != nil { + return "", err + } + return salt.GetIdentifiedHMAC(data), nil } func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error { @@ -461,6 +465,32 @@ func (n *noopAudit) Reload() error { return nil } +func (n *noopAudit) Invalidate() { + n.saltMutex.Lock() + defer n.saltMutex.Unlock() + n.salt = nil +} + +func (n *noopAudit) Salt() (*salt.Salt, error) { + n.saltMutex.RLock() + if n.salt != nil { + defer n.saltMutex.RUnlock() + return n.salt, nil + } + n.saltMutex.RUnlock() + n.saltMutex.Lock() + defer n.saltMutex.Unlock() + if n.salt != nil { + return n.salt, nil + } + salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig) + if err != nil { + return nil, err + } + n.salt = salt + return salt, nil +} + type rawHTTP struct{} func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) {