Delay salt initialization for audit backends

This commit is contained in:
Jeff Mitchell 2017-05-23 20:36:20 -04:00
parent 41d4c69b54
commit dd26071875
18 changed files with 382 additions and 105 deletions

View File

@ -25,15 +25,21 @@ type Backend interface {
// GetHash is used to return the given data with the backend's hash, // GetHash is used to return the given data with the backend's hash,
// so that a caller can determine if a value in the audit log matches // so that a caller can determine if a value in the audit log matches
// an expected plaintext value // an expected plaintext value
GetHash(string) string GetHash(string) (string, error)
// Reload is called on SIGHUP for supporting backends. // Reload is called on SIGHUP for supporting backends.
Reload() error Reload() error
// Invalidate is called for path invalidation
Invalidate()
} }
type BackendConfig struct { type BackendConfig struct {
// The salt that should be used for any secret obfuscation // The view to store the salt
Salt *salt.Salt SaltView logical.Storage
// The salt config that should be used for any secret obfuscation
SaltConfig *salt.Config
// Config is the opaque user configuration provided when mounting // Config is the opaque user configuration provided when mounting
Config map[string]string Config map[string]string

View File

@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/SermoDigital/jose/jws" "github.com/SermoDigital/jose/jws"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/mitchellh/copystructure" "github.com/mitchellh/copystructure"
) )
@ -14,6 +16,7 @@ import (
type AuditFormatWriter interface { type AuditFormatWriter interface {
WriteRequest(io.Writer, *AuditRequestEntry) error WriteRequest(io.Writer, *AuditRequestEntry) error
WriteResponse(io.Writer, *AuditResponseEntry) error WriteResponse(io.Writer, *AuditResponseEntry) error
Salt() (*salt.Salt, error)
} }
// AuditFormatter implements the Formatter interface, and allows the underlying // AuditFormatter implements the Formatter interface, and allows the underlying
@ -41,6 +44,11 @@ func (f *AuditFormatter) FormatRequest(
return fmt.Errorf("no format writer specified") return fmt.Errorf("no format writer specified")
} }
salt, err := f.Salt()
if err != nil {
return errwrap.Wrapf("error fetching salt: {{err}}", err)
}
if !config.Raw { if !config.Raw {
// Before we copy the structure we must nil out some data // Before we copy the structure we must nil out some data
// otherwise we will cause reflection to panic and die // otherwise we will cause reflection to panic and die
@ -70,7 +78,7 @@ func (f *AuditFormatter) FormatRequest(
// Hash any sensitive information // Hash any sensitive information
if auth != nil { if auth != nil {
if err := Hash(config.Salt, auth); err != nil { if err := Hash(salt, auth); err != nil {
return err return err
} }
} }
@ -80,7 +88,7 @@ func (f *AuditFormatter) FormatRequest(
if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" {
clientTokenAccessor = req.ClientTokenAccessor clientTokenAccessor = req.ClientTokenAccessor
} }
if err := Hash(config.Salt, req); err != nil { if err := Hash(salt, req); err != nil {
return err return err
} }
if clientTokenAccessor != "" { if clientTokenAccessor != "" {
@ -152,6 +160,11 @@ func (f *AuditFormatter) FormatResponse(
return fmt.Errorf("no format writer specified") return fmt.Errorf("no format writer specified")
} }
salt, err := f.Salt()
if err != nil {
return errwrap.Wrapf("error fetching salt: {{err}}", err)
}
if !config.Raw { if !config.Raw {
// Before we copy the structure we must nil out some data // Before we copy the structure we must nil out some data
// otherwise we will cause reflection to panic and die // otherwise we will cause reflection to panic and die
@ -195,7 +208,7 @@ func (f *AuditFormatter) FormatResponse(
if !config.HMACAccessor && auth.Accessor != "" { if !config.HMACAccessor && auth.Accessor != "" {
accessor = auth.Accessor accessor = auth.Accessor
} }
if err := Hash(config.Salt, auth); err != nil { if err := Hash(salt, auth); err != nil {
return err return err
} }
if accessor != "" { if accessor != "" {
@ -208,7 +221,7 @@ func (f *AuditFormatter) FormatResponse(
if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" {
clientTokenAccessor = req.ClientTokenAccessor clientTokenAccessor = req.ClientTokenAccessor
} }
if err := Hash(config.Salt, req); err != nil { if err := Hash(salt, req); err != nil {
return err return err
} }
if clientTokenAccessor != "" { if clientTokenAccessor != "" {
@ -224,7 +237,7 @@ func (f *AuditFormatter) FormatResponse(
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
wrappedAccessor = resp.WrapInfo.WrappedAccessor wrappedAccessor = resp.WrapInfo.WrappedAccessor
} }
if err := Hash(config.Salt, resp); err != nil { if err := Hash(salt, resp); err != nil {
return err return err
} }
if accessor != "" { if accessor != "" {

View File

@ -4,12 +4,15 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"github.com/hashicorp/vault/helper/salt"
) )
// JSONFormatWriter is an AuditFormatWriter implementation that structures data into // JSONFormatWriter is an AuditFormatWriter implementation that structures data into
// a JSON format. // a JSON format.
type JSONFormatWriter struct { type JSONFormatWriter struct {
Prefix string Prefix string
SaltFunc func() (*salt.Salt, error)
} }
func (f *JSONFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { func (f *JSONFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error {
@ -43,3 +46,7 @@ func (f *JSONFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry)
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
return enc.Encode(resp) return enc.Encode(resp)
} }
func (f *JSONFormatWriter) Salt() (*salt.Salt, error) {
return f.SaltFunc()
}

View File

@ -15,6 +15,13 @@ import (
) )
func TestFormatJSON_formatRequest(t *testing.T) { func TestFormatJSON_formatRequest(t *testing.T) {
salter, err := salt.NewSalt(nil, nil)
if err != nil {
t.Fatal(err)
}
saltFunc := func() (*salt.Salt, error) {
return salter, nil
}
cases := map[string]struct { cases := map[string]struct {
Auth *logical.Auth Auth *logical.Auth
Req *logical.Request Req *logical.Request
@ -67,12 +74,10 @@ func TestFormatJSON_formatRequest(t *testing.T) {
formatter := AuditFormatter{ formatter := AuditFormatter{
AuditFormatWriter: &JSONFormatWriter{ AuditFormatWriter: &JSONFormatWriter{
Prefix: tc.Prefix, Prefix: tc.Prefix,
SaltFunc: saltFunc,
}, },
} }
salter, _ := salt.NewSalt(nil, nil) config := FormatterConfig{}
config := FormatterConfig{
Salt: salter,
}
if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil {
t.Fatalf("bad: %s\nerr: %s", name, err) t.Fatalf("bad: %s\nerr: %s", name, err)
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/hashicorp/vault/helper/salt"
"github.com/jefferai/jsonx" "github.com/jefferai/jsonx"
) )
@ -12,6 +13,7 @@ import (
// a XML format. // a XML format.
type JSONxFormatWriter struct { type JSONxFormatWriter struct {
Prefix string Prefix string
SaltFunc func() (*salt.Salt, error)
} }
func (f *JSONxFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { func (f *JSONxFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error {
@ -65,3 +67,7 @@ func (f *JSONxFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry)
_, err = w.Write(xmlBytes) _, err = w.Write(xmlBytes)
return err return err
} }
func (f *JSONxFormatWriter) Salt() (*salt.Salt, error) {
return f.SaltFunc()
}

View File

@ -13,6 +13,13 @@ import (
) )
func TestFormatJSONx_formatRequest(t *testing.T) { func TestFormatJSONx_formatRequest(t *testing.T) {
salter, err := salt.NewSalt(nil, nil)
if err != nil {
t.Fatal(err)
}
saltFunc := func() (*salt.Salt, error) {
return salter, nil
}
cases := map[string]struct { cases := map[string]struct {
Auth *logical.Auth Auth *logical.Auth
Req *logical.Request Req *logical.Request
@ -68,11 +75,10 @@ func TestFormatJSONx_formatRequest(t *testing.T) {
formatter := AuditFormatter{ formatter := AuditFormatter{
AuditFormatWriter: &JSONxFormatWriter{ AuditFormatWriter: &JSONxFormatWriter{
Prefix: tc.Prefix, Prefix: tc.Prefix,
SaltFunc: saltFunc,
}, },
} }
salter, _ := salt.NewSalt(nil, nil)
config := FormatterConfig{ config := FormatterConfig{
Salt: salter,
OmitTime: true, OmitTime: true,
} }
if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil {

View File

@ -10,6 +10,8 @@ import (
) )
type noopFormatWriter struct { type noopFormatWriter struct {
salt *salt.Salt
SaltFunc func() (*salt.Salt, error)
} }
func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error { func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error {
@ -20,11 +22,20 @@ func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) err
return nil return nil
} }
func TestFormatRequestErrors(t *testing.T) { func (n *noopFormatWriter) Salt() (*salt.Salt, error) {
salter, _ := salt.NewSalt(nil, nil) if n.salt != nil {
config := FormatterConfig{ return n.salt, nil
Salt: salter,
} }
var err error
n.salt, err = salt.NewSalt(nil, nil)
if err != nil {
return nil, err
}
return n.salt, nil
}
func TestFormatRequestErrors(t *testing.T) {
config := FormatterConfig{}
formatter := AuditFormatter{ formatter := AuditFormatter{
AuditFormatWriter: &noopFormatWriter{}, AuditFormatWriter: &noopFormatWriter{},
} }
@ -38,10 +49,7 @@ func TestFormatRequestErrors(t *testing.T) {
} }
func TestFormatResponseErrors(t *testing.T) { func TestFormatResponseErrors(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil) config := FormatterConfig{}
config := FormatterConfig{
Salt: salter,
}
formatter := AuditFormatter{ formatter := AuditFormatter{
AuditFormatWriter: &noopFormatWriter{}, AuditFormatWriter: &noopFormatWriter{},
} }

View File

@ -3,7 +3,6 @@ package audit
import ( import (
"io" "io"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@ -19,7 +18,6 @@ type Formatter interface {
type FormatterConfig struct { type FormatterConfig struct {
Raw bool Raw bool
Salt *salt.Salt
HMACAccessor bool HMACAccessor bool
// This should only ever be used in a testing context // This should only ever be used in a testing context

View File

@ -8,12 +8,16 @@ import (
"sync" "sync"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
func Factory(conf *audit.BackendConfig) (audit.Backend, error) { func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
if conf.Salt == nil { if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt") return nil, fmt.Errorf("nil salt config")
}
if conf.SaltView == nil {
return nil, fmt.Errorf("nil salt view")
} }
path, ok := conf.Config["file_path"] path, ok := conf.Config["file_path"]
@ -67,9 +71,10 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
b := &Backend{ b := &Backend{
path: path, path: path,
mode: mode, mode: mode,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
formatConfig: audit.FormatterConfig{ formatConfig: audit.FormatterConfig{
Raw: logRaw, Raw: logRaw,
Salt: conf.Salt,
HMACAccessor: hmacAccessor, HMACAccessor: hmacAccessor,
}, },
} }
@ -78,10 +83,12 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
case "json": case "json":
b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
Prefix: conf.Config["prefix"], Prefix: conf.Config["prefix"],
SaltFunc: b.Salt,
} }
case "jsonx": case "jsonx":
b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{
Prefix: conf.Config["prefix"], Prefix: conf.Config["prefix"],
SaltFunc: b.Salt,
} }
} }
@ -109,10 +116,39 @@ type Backend struct {
fileLock sync.RWMutex fileLock sync.RWMutex
f *os.File f *os.File
mode os.FileMode mode os.FileMode
saltMutex sync.RWMutex
salt *salt.Salt
saltConfig *salt.Config
saltView logical.Storage
} }
func (b *Backend) GetHash(data string) string { func (b *Backend) Salt() (*salt.Salt, error) {
return audit.HashString(b.formatConfig.Salt, data) b.saltMutex.RLock()
if b.salt != nil {
defer b.saltMutex.RUnlock()
return b.salt, nil
}
b.saltMutex.RUnlock()
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
if b.salt != nil {
return b.salt, nil
}
salt, err := salt.NewSalt(b.saltView, b.saltConfig)
if err != nil {
return nil, err
}
b.salt = salt
return salt, nil
}
func (b *Backend) GetHash(data string) (string, error) {
salt, err := b.Salt()
if err != nil {
return "", err
}
return audit.HashString(salt, data), nil
} }
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
@ -189,3 +225,9 @@ func (b *Backend) Reload() error {
return b.open() return b.open()
} }
func (b *Backend) Invalidate() {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil
}

View File

@ -9,11 +9,10 @@ import (
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
) )
func TestAuditFile_fileModeNew(t *testing.T) { func TestAuditFile_fileModeNew(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil)
modeStr := "0777" modeStr := "0777"
mode, err := strconv.ParseUint(modeStr, 8, 32) mode, err := strconv.ParseUint(modeStr, 8, 32)
@ -28,7 +27,8 @@ func TestAuditFile_fileModeNew(t *testing.T) {
} }
_, err = Factory(&audit.BackendConfig{ _, err = Factory(&audit.BackendConfig{
Salt: salter, SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Config: config, Config: config,
}) })
if err != nil { if err != nil {
@ -45,8 +45,6 @@ func TestAuditFile_fileModeNew(t *testing.T) {
} }
func TestAuditFile_fileModeExisting(t *testing.T) { func TestAuditFile_fileModeExisting(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil)
f, err := ioutil.TempFile("", "test") f, err := ioutil.TempFile("", "test")
if err != nil { if err != nil {
t.Fatalf("Failure to create test file.") t.Fatalf("Failure to create test file.")
@ -68,8 +66,9 @@ func TestAuditFile_fileModeExisting(t *testing.T) {
} }
_, err = Factory(&audit.BackendConfig{ _, err = Factory(&audit.BackendConfig{
Salt: salter,
Config: config, Config: config,
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -11,12 +11,16 @@ import (
multierror "github.com/hashicorp/go-multierror" multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
func Factory(conf *audit.BackendConfig) (audit.Backend, error) { func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
if conf.Salt == nil { if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt passed in") return nil, fmt.Errorf("nil salt config")
}
if conf.SaltView == nil {
return nil, fmt.Errorf("nil salt view")
} }
address, ok := conf.Config["address"] address, ok := conf.Config["address"]
@ -75,11 +79,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
b := &Backend{ b := &Backend{
connection: conn, connection: conn,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
formatConfig: audit.FormatterConfig{ formatConfig: audit.FormatterConfig{
Raw: logRaw, Raw: logRaw,
Salt: conf.Salt,
HMACAccessor: hmacAccessor, HMACAccessor: hmacAccessor,
}, },
writeDuration: writeDuration, writeDuration: writeDuration,
address: address, address: address,
socketType: socketType, socketType: socketType,
@ -89,10 +95,12 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
case "json": case "json":
b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
Prefix: conf.Config["prefix"], Prefix: conf.Config["prefix"],
SaltFunc: b.Salt,
} }
case "jsonx": case "jsonx":
b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{
Prefix: conf.Config["prefix"], Prefix: conf.Config["prefix"],
SaltFunc: b.Salt,
} }
} }
@ -111,10 +119,19 @@ type Backend struct {
socketType string socketType string
sync.Mutex sync.Mutex
saltMutex sync.RWMutex
salt *salt.Salt
saltConfig *salt.Config
saltView logical.Storage
} }
func (b *Backend) GetHash(data string) string { func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(b.formatConfig.Salt, data) salt, err := b.Salt()
if err != nil {
return "", err
}
return audit.HashString(salt, data), nil
} }
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
@ -198,3 +215,29 @@ func (b *Backend) Reload() error {
return err return err
} }
func (b *Backend) Salt() (*salt.Salt, error) {
b.saltMutex.RLock()
if b.salt != nil {
defer b.saltMutex.RUnlock()
return b.salt, nil
}
b.saltMutex.RUnlock()
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
if b.salt != nil {
return b.salt, nil
}
salt, err := salt.NewSalt(b.saltView, b.saltConfig)
if err != nil {
return nil, err
}
b.salt = salt
return salt, nil
}
func (b *Backend) Invalidate() {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil
}

View File

@ -4,15 +4,20 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"github.com/hashicorp/go-syslog" "github.com/hashicorp/go-syslog"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
func Factory(conf *audit.BackendConfig) (audit.Backend, error) { func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
if conf.Salt == nil { if conf.SaltConfig == nil {
return nil, fmt.Errorf("Nil salt passed in") return nil, fmt.Errorf("nil salt config")
}
if conf.SaltView == nil {
return nil, fmt.Errorf("nil salt view")
} }
// Get facility or default to AUTH // Get facility or default to AUTH
@ -65,9 +70,10 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
b := &Backend{ b := &Backend{
logger: logger, logger: logger,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
formatConfig: audit.FormatterConfig{ formatConfig: audit.FormatterConfig{
Raw: logRaw, Raw: logRaw,
Salt: conf.Salt,
HMACAccessor: hmacAccessor, HMACAccessor: hmacAccessor,
}, },
} }
@ -76,10 +82,12 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
case "json": case "json":
b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
Prefix: conf.Config["prefix"], Prefix: conf.Config["prefix"],
SaltFunc: b.Salt,
} }
case "jsonx": case "jsonx":
b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{
Prefix: conf.Config["prefix"], Prefix: conf.Config["prefix"],
SaltFunc: b.Salt,
} }
} }
@ -92,10 +100,19 @@ type Backend struct {
formatter audit.AuditFormatter formatter audit.AuditFormatter
formatConfig audit.FormatterConfig formatConfig audit.FormatterConfig
saltMutex sync.RWMutex
salt *salt.Salt
saltConfig *salt.Config
saltView logical.Storage
} }
func (b *Backend) GetHash(data string) string { func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(b.formatConfig.Salt, data) salt, err := b.Salt()
if err != nil {
return "", err
}
return audit.HashString(salt, data), nil
} }
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
@ -123,3 +140,29 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *lo
func (b *Backend) Reload() error { func (b *Backend) Reload() error {
return nil return nil
} }
func (b *Backend) Salt() (*salt.Salt, error) {
b.saltMutex.RLock()
if b.salt != nil {
defer b.saltMutex.RUnlock()
return b.salt, nil
}
b.saltMutex.RUnlock()
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
if b.salt != nil {
return b.salt, nil
}
salt, err := salt.NewSalt(b.saltView, b.saltConfig)
if err != nil {
return nil, err
}
b.salt = salt
return salt, nil
}
func (b *Backend) Invalidate() {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil
}

View File

@ -368,16 +368,15 @@ func (c *Core) newAuditBackend(entry *MountEntry, view logical.Storage, conf map
if !ok { if !ok {
return nil, fmt.Errorf("unknown backend type: %s", entry.Type) return nil, fmt.Errorf("unknown backend type: %s", entry.Type)
} }
salter, err := salt.NewSalt(view, &salt.Config{ saltConfig := &salt.Config{
HMAC: sha256.New, HMAC: sha256.New,
HMACType: "hmac-sha256", HMACType: "hmac-sha256",
}) Location: salt.DefaultLocation,
if err != nil {
return nil, fmt.Errorf("core: unable to generate salt: %v", err)
} }
be, err := f(&audit.BackendConfig{ be, err := f(&audit.BackendConfig{
Salt: salter, SaltView: view,
SaltConfig: saltConfig,
Config: conf, Config: conf,
}) })
if err != nil { if err != nil {
@ -474,20 +473,25 @@ func (a *AuditBroker) GetHash(name string, input string) (string, error) {
return "", fmt.Errorf("unknown audit backend %s", name) return "", fmt.Errorf("unknown audit backend %s", name)
} }
return be.backend.GetHash(input), nil return be.backend.GetHash(input)
} }
// LogRequest is used to ensure all the audit backends have an opportunity to // LogRequest is used to ensure all the audit backends have an opportunity to
// log the given request and that *at least one* succeeds. // log the given request and that *at least one* succeeds.
func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, headersConfig *AuditedHeadersConfig, outerErr error) (retErr error) { func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, headersConfig *AuditedHeadersConfig, outerErr error) (ret error) {
defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now())
a.RLock() a.RLock()
defer a.RUnlock() defer a.RUnlock()
var retErr *multierror.Error
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r)
retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log")) retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log"))
} }
ret = retErr.ErrorOrNil()
}() }()
// All logged requests must have an identifier // All logged requests must have an identifier
@ -506,36 +510,46 @@ func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, heade
anyLogged := false anyLogged := false
for name, be := range a.backends { for name, be := range a.backends {
req.Headers = nil req.Headers = nil
req.Headers = headersConfig.ApplyConfig(headers, be.backend.GetHash) transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash)
if thErr != nil {
a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr)
continue
}
req.Headers = transHeaders
start := time.Now() start := time.Now()
err := be.backend.LogRequest(auth, req, outerErr) lrErr := be.backend.LogRequest(auth, req, outerErr)
metrics.MeasureSince([]string{"audit", name, "log_request"}, start) metrics.MeasureSince([]string{"audit", name, "log_request"}, start)
if err != nil { if lrErr != nil {
a.logger.Error("audit: backend failed to log request", "backend", name, "error", err) a.logger.Error("audit: backend failed to log request", "backend", name, "error", lrErr)
} else { } else {
anyLogged = true anyLogged = true
} }
} }
if !anyLogged && len(a.backends) > 0 { if !anyLogged && len(a.backends) > 0 {
retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request")) retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request"))
return
} }
return nil
return retErr.ErrorOrNil()
} }
// LogResponse is used to ensure all the audit backends have an opportunity to // LogResponse is used to ensure all the audit backends have an opportunity to
// log the given response and that *at least one* succeeds. // log the given response and that *at least one* succeeds.
func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request,
resp *logical.Response, headersConfig *AuditedHeadersConfig, err error) (reterr error) { resp *logical.Response, headersConfig *AuditedHeadersConfig, err error) (ret error) {
defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now())
a.RLock() a.RLock()
defer a.RUnlock() defer a.RUnlock()
var retErr *multierror.Error
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r)
reterr = fmt.Errorf("panic generating audit log") retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log"))
} }
ret = retErr.ErrorOrNil()
}() }()
headers := req.Headers headers := req.Headers
@ -547,19 +561,35 @@ func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request,
anyLogged := false anyLogged := false
for name, be := range a.backends { for name, be := range a.backends {
req.Headers = nil req.Headers = nil
req.Headers = headersConfig.ApplyConfig(headers, be.backend.GetHash) transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash)
if thErr != nil {
a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr)
continue
}
req.Headers = transHeaders
start := time.Now() start := time.Now()
err := be.backend.LogResponse(auth, req, resp, err) lrErr := be.backend.LogResponse(auth, req, resp, err)
metrics.MeasureSince([]string{"audit", name, "log_response"}, start) metrics.MeasureSince([]string{"audit", name, "log_response"}, start)
if err != nil { if lrErr != nil {
a.logger.Error("audit: backend failed to log response", "backend", name, "error", err) a.logger.Error("audit: backend failed to log response", "backend", name, "error", lrErr)
} else { } else {
anyLogged = true anyLogged = true
} }
} }
if !anyLogged && len(a.backends) > 0 { if !anyLogged && len(a.backends) > 0 {
return fmt.Errorf("no audit backend succeeded in logging the response") retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the response"))
}
return retErr.ErrorOrNil()
}
func (a *AuditBroker) Invalidate(key string) {
// For now we ignore the key as this would only apply to salts. We just
// sort of brute force it on each one.
a.Lock()
defer a.Unlock()
for _, be := range a.backends {
be.backend.Invalidate()
} }
return nil
} }

View File

@ -3,6 +3,8 @@ package vault
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"sync"
"testing" "testing"
"time" "time"
@ -13,6 +15,7 @@ import (
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
"github.com/mitchellh/copystructure" "github.com/mitchellh/copystructure"
@ -31,6 +34,9 @@ type NoopAudit struct {
RespReq []*logical.Request RespReq []*logical.Request
Resp []*logical.Response Resp []*logical.Response
RespErrs []error RespErrs []error
salt *salt.Salt
saltMutex sync.RWMutex
} }
func (n *NoopAudit) LogRequest(a *logical.Auth, r *logical.Request, err error) error { func (n *NoopAudit) LogRequest(a *logical.Auth, r *logical.Request, err error) error {
@ -49,14 +55,44 @@ func (n *NoopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical
return n.RespErr return n.RespErr
} }
func (n *NoopAudit) GetHash(data string) string { func (n *NoopAudit) Salt() (*salt.Salt, error) {
return n.Config.Salt.GetIdentifiedHMAC(data) n.saltMutex.RLock()
if n.salt != nil {
defer n.saltMutex.RUnlock()
return n.salt, nil
}
n.saltMutex.RUnlock()
n.saltMutex.Lock()
defer n.saltMutex.Unlock()
if n.salt != nil {
return n.salt, nil
}
salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig)
if err != nil {
return nil, err
}
n.salt = salt
return salt, nil
}
func (n *NoopAudit) GetHash(data string) (string, error) {
salt, err := n.Salt()
if err != nil {
return "", err
}
return salt.GetIdentifiedHMAC(data), nil
} }
func (n *NoopAudit) Reload() error { func (n *NoopAudit) Reload() error {
return nil return nil
} }
func (n *NoopAudit) Invalidate() {
n.saltMutex.Lock()
defer n.saltMutex.Unlock()
n.salt = nil
}
func TestCore_EnableAudit(t *testing.T) { func TestCore_EnableAudit(t *testing.T) {
c, keys, _ := TestCoreUnsealed(t) c, keys, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
@ -508,7 +544,7 @@ func TestAuditBroker_LogResponse(t *testing.T) {
t.Fatalf("Bad: %#v", a.Resp[0]) t.Fatalf("Bad: %#v", a.Resp[0])
} }
if !reflect.DeepEqual(a.RespErrs[0], respErr) { if !reflect.DeepEqual(a.RespErrs[0], respErr) {
t.Fatalf("Bad: %#v", a.RespErrs[0]) t.Fatalf("Expected\n%v\nGot\n%#v", respErr, a.RespErrs[0])
} }
} }
@ -522,7 +558,7 @@ func TestAuditBroker_LogResponse(t *testing.T) {
// Should FAIL work with both failing backends // Should FAIL work with both failing backends
a2.RespErr = fmt.Errorf("failed") a2.RespErr = fmt.Errorf("failed")
err = b.LogResponse(auth, req, resp, headersConf, respErr) err = b.LogResponse(auth, req, resp, headersConf, respErr)
if err.Error() != "no audit backend succeeded in logging the response" { if !strings.Contains(err.Error(), "no audit backend succeeded in logging the response") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
} }

View File

@ -88,7 +88,7 @@ func (a *AuditedHeadersConfig) remove(header string) error {
// ApplyConfig returns a map of approved headers and their values, either // ApplyConfig returns a map of approved headers and their values, either
// hmac'ed or plaintext // hmac'ed or plaintext
func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) string) (result map[string][]string) { func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) (string, error)) (result map[string][]string, retErr error) {
// Grab a read lock // Grab a read lock
a.RLock() a.RLock()
defer a.RUnlock() defer a.RUnlock()
@ -110,7 +110,11 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc
// Optionally hmac the values // Optionally hmac the values
if settings.HMAC { if settings.HMAC {
for i, el := range hVals { for i, el := range hVals {
hVals[i] = hashFunc(el) hVal, err := hashFunc(el)
if err != nil {
return nil, err
}
hVals[i] = hVal
} }
} }
@ -118,7 +122,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc
} }
} }
return return result, nil
} }
// Initalize the headers config by loading from the barrier view // Initalize the headers config by loading from the barrier view

View File

@ -166,9 +166,12 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
"Content-Type": []string{"json"}, "Content-Type": []string{"json"},
} }
hashFunc := func(s string) string { return "hashed" } hashFunc := func(s string) (string, error) { return "hashed", nil }
result := conf.ApplyConfig(reqHeaders, hashFunc) result, err := conf.ApplyConfig(reqHeaders, hashFunc)
if err != nil {
t.Fatal(err)
}
expected := map[string][]string{ expected := map[string][]string{
"x-test-header": []string{"foo"}, "x-test-header": []string{"foo"},
@ -214,7 +217,7 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
hashFunc := func(s string) string { return salter.GetIdentifiedHMAC(s) } hashFunc := func(s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil }
// Reset the timer since we did a lot above // Reset the timer since we did a lot above
b.ResetTimer() b.ResetTimer()

View File

@ -1254,13 +1254,11 @@ func TestSystemBackend_auditHash(t *testing.T) {
Key: "salt", Key: "salt",
Value: []byte("foo"), Value: []byte("foo"),
}) })
var err error config.SaltView = view
config.Salt, err = salt.NewSalt(view, &salt.Config{ config.SaltConfig = &salt.Config{
HMAC: sha256.New, HMAC: sha256.New,
HMACType: "hmac-sha256", HMACType: "hmac-sha256",
}) Location: salt.DefaultLocation,
if err != nil {
t.Fatalf("error getting new salt: %v", err)
} }
return &NoopAudit{ return &NoopAudit{
Config: config, Config: config,

View File

@ -14,6 +14,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sync"
"testing" "testing"
"time" "time"
@ -111,14 +112,11 @@ func testCoreConfig(t testing.TB, physicalBackend physical.Backend, logger log.L
Key: "salt", Key: "salt",
Value: []byte("foo"), Value: []byte("foo"),
}) })
var err error config.SaltConfig = &salt.Config{
config.Salt, err = salt.NewSalt(view, &salt.Config{
HMAC: sha256.New, HMAC: sha256.New,
HMACType: "hmac-sha256", HMACType: "hmac-sha256",
})
if err != nil {
t.Fatalf("error getting new salt: %v", err)
} }
config.SaltView = view
return &noopAudit{ return &noopAudit{
Config: config, Config: config,
}, nil }, nil
@ -443,10 +441,16 @@ func AddTestLogicalBackend(name string, factory logical.Factory) error {
type noopAudit struct { type noopAudit struct {
Config *audit.BackendConfig Config *audit.BackendConfig
salt *salt.Salt
saltMutex sync.RWMutex
} }
func (n *noopAudit) GetHash(data string) string { func (n *noopAudit) GetHash(data string) (string, error) {
return n.Config.Salt.GetIdentifiedHMAC(data) salt, err := n.Salt()
if err != nil {
return "", err
}
return salt.GetIdentifiedHMAC(data), nil
} }
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error { func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {
@ -461,6 +465,32 @@ func (n *noopAudit) Reload() error {
return nil return nil
} }
func (n *noopAudit) Invalidate() {
n.saltMutex.Lock()
defer n.saltMutex.Unlock()
n.salt = nil
}
func (n *noopAudit) Salt() (*salt.Salt, error) {
n.saltMutex.RLock()
if n.salt != nil {
defer n.saltMutex.RUnlock()
return n.salt, nil
}
n.saltMutex.RUnlock()
n.saltMutex.Lock()
defer n.saltMutex.Unlock()
if n.salt != nil {
return n.salt, nil
}
salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig)
if err != nil {
return nil, err
}
n.salt = salt
return salt, nil
}
type rawHTTP struct{} type rawHTTP struct{}
func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) { func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) {