diff --git a/audit/audit.go b/audit/audit.go index adbce1344b..9c704042b9 100644 --- a/audit/audit.go +++ b/audit/audit.go @@ -32,4 +32,4 @@ type BackendConfig struct { } // Factory is the factory function to create an audit backend. -type Factory func(BackendConfig) (Backend, error) +type Factory func(*BackendConfig) (Backend, error) diff --git a/audit/hashstructure.go b/audit/hashstructure.go index 65c1eaf6cc..65fa333b34 100644 --- a/audit/hashstructure.go +++ b/audit/hashstructure.go @@ -16,7 +16,7 @@ import ( // // The structure is modified in-place. func Hash(salter *salt.Salt, raw interface{}) error { - fn := salter.GetHMAC + fn := salter.GetIdentifiedHMAC switch s := raw.(type) { case *logical.Auth: @@ -86,17 +86,6 @@ func HashStructure(s interface{}, cb HashCallback) (interface{}, error) { // a value. type HashCallback func(string) string -// HashSHA1 returns a HashCallback that hashes data with SHA1 and -// with an optional salt. If salt is a blank string, no salt is used. -/* -func HashSHA1(salt string) HashCallback { - return func(v string) (string, error) { - hashed := sha1.Sum([]byte(v + salt)) - return "sha1:" + hex.EncodeToString(hashed[:]), nil - } -} -*/ - // hashWalker implements interfaces for the reflectwalk package // (github.com/mitchellh/reflectwalk) that can be used to automatically // replace primitives with a hashed value. diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 7dbc8e9ce5..1b652cf44a 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -1,11 +1,13 @@ package audit import ( + "crypto/sha256" "fmt" "reflect" "testing" "time" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" "github.com/mitchellh/copystructure" ) @@ -88,7 +90,7 @@ func TestHash(t *testing.T) { }{ { &logical.Auth{ClientToken: "foo"}, - &logical.Auth{ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"}, + &logical.Auth{ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a"}, }, { &logical.Request{ @@ -98,7 +100,7 @@ func TestHash(t *testing.T) { }, &logical.Request{ Data: map[string]interface{}{ - "foo": "sha1:62cdb7020ff920e5aa642c3d4066950dd1f01f4d", + "foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", }, }, }, @@ -110,7 +112,7 @@ func TestHash(t *testing.T) { }, &logical.Response{ Data: map[string]interface{}{ - "foo": "sha1:62cdb7020ff920e5aa642c3d4066950dd1f01f4d", + "foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", }, }, }, @@ -133,14 +135,22 @@ func TestHash(t *testing.T) { IssueTime: now, }, - ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", + ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a", }, }, } + localSalt, err := salt.NewSalt(nil, &salt.Config{ + HMAC: sha256.New, + HMACType: "hmac-sha256", + StaticSalt: "foo", + }) + if err != nil { + t.Fatalf("Error instantiating salt: %s", err) + } for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - if err := Hash(tc.Input); err != nil { + if err := Hash(localSalt, tc.Input); err != nil { t.Fatalf("err: %s\n\n%s", err, input) } if !reflect.DeepEqual(tc.Input, tc.Output) { @@ -176,8 +186,8 @@ func TestHashWalker(t *testing.T) { } for _, tc := range cases { - output, err := HashStructure(tc.Input, func(string) (string, error) { - return replaceText, nil + output, err := HashStructure(tc.Input, func(string) string { + return replaceText }) if err != nil { t.Fatalf("err: %s\n\n%#v", err, tc.Input) @@ -187,14 +197,3 @@ func TestHashWalker(t *testing.T) { } } } - -func TestHashSHA1(t *testing.T) { - fn := HashSHA1("") - result, err := fn("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if result != "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33" { - t.Fatalf("bad: %#v", result) - } -} diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index 37eb676625..9ee4843cb9 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -13,7 +13,7 @@ import ( "github.com/mitchellh/copystructure" ) -func Factory(conf audit.BackendConfig) (audit.Backend, error) { +func Factory(conf *audit.BackendConfig) (audit.Backend, error) { if conf.Salt == nil { return nil, fmt.Errorf("Nil salt passed in") } diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index 7878561b18..a93df385e7 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -12,7 +12,7 @@ import ( "github.com/mitchellh/copystructure" ) -func Factory(conf audit.BackendConfig) (audit.Backend, error) { +func Factory(conf *audit.BackendConfig) (audit.Backend, error) { if conf.Salt == nil { return nil, fmt.Errorf("Nil salt passed in") } diff --git a/helper/salt/salt.go b/helper/salt/salt.go index 198f0d0d83..b7ca0731c9 100644 --- a/helper/salt/salt.go +++ b/helper/salt/salt.go @@ -27,6 +27,7 @@ type Salt struct { salt string generated bool hmac hash.Hash + hmacType string } type HashFunc func([]byte) []byte @@ -44,6 +45,14 @@ type Config struct { // HMAC allows specification of a hash function to use for // the HMAC helpers HMAC func() hash.Hash + + // String prepended to HMAC strings for identification. + // Required if using HMAC + HMACType string + + // A static string to use if set. If not set, one will be + // generated and persisted. This value will *not* be persisted. + StaticSalt string } // NewSalt creates a new salt based on the configuration @@ -64,35 +73,49 @@ func NewSalt(view logical.Storage, config *Config) (*Salt, error) { config: config, } - // Look for the salt - raw, err := view.Get(config.Location) - if err != nil { - return nil, fmt.Errorf("failed to read salt: %v", err) - } + var raw *logical.StorageEntry + var err error + if config.StaticSalt != "" { + s.salt = config.StaticSalt + } else { + if view != nil { + // Look for the salt + raw, err = view.Get(config.Location) + if err != nil { + return nil, fmt.Errorf("failed to read salt: %v", err) + } - // Restore the salt if it exists - if raw != nil { - s.salt = string(raw.Value) + // Restore the salt if it exists + if raw != nil { + s.salt = string(raw.Value) + } + } } // Generate a new salt if necessary if s.salt == "" { s.salt = uuid.GenerateUUID() s.generated = true - raw = &logical.StorageEntry{ - Key: config.Location, - Value: []byte(s.salt), - } - if err := view.Put(raw); err != nil { - return nil, fmt.Errorf("failed to persist salt: %v", err) + if view != nil { + raw = &logical.StorageEntry{ + Key: config.Location, + Value: []byte(s.salt), + } + if err := view.Put(raw); err != nil { + return nil, fmt.Errorf("failed to persist salt: %v", err) + } } } if config.HMAC != nil { + if len(config.HMACType) == 0 { + return nil, fmt.Errorf("HMACType must be defined") + } s.hmac = hmac.New(config.HMAC, []byte(s.salt)) if s.hmac == nil { return nil, fmt.Errorf("failed to instantiate HMAC function") } + s.hmacType = config.HMACType } return s, nil @@ -104,7 +127,7 @@ func (s *Salt) SaltID(id string) string { return SaltID(s.salt, id, s.config.HashFunc) } -// SaltIDandHMAC is used to apply a salt and hash function to an ID to make sure +// GetHMAC is used to apply a salt and hash function to an ID to make sure // it is not reversible, with an additional HMAC func (s *Salt) GetHMAC(id string) string { if s.hmac == nil { @@ -112,7 +135,19 @@ func (s *Salt) GetHMAC(id string) string { } s.hmac.Reset() s.hmac.Write([]byte(id)) - return string(s.hmac.Sum(nil)) + return hex.EncodeToString(s.hmac.Sum(nil)) +} + +// GetIdentifiedHMAC is used to apply a salt and hash function to an ID to make sure +// it is not reversible, with an additional HMAC, and ID prepended +func (s *Salt) GetIdentifiedHMAC(id string) string { + if s.hmac == nil { + return "" + } + s.hmac.Reset() + s.hmac.Write([]byte(id)) + + return s.hmacType + ":" + hex.EncodeToString(s.hmac.Sum(nil)) } // DidGenerate returns if the underlying salt value was generated diff --git a/vault/audit.go b/vault/audit.go index 50dd615257..b9d044be96 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -209,11 +209,12 @@ func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]s salter, err := salt.NewSalt(view, &salt.Config{ HashFunc: salt.SHA256Hash, HMAC: sha256.New, + HMACType: "hmac-sha256", }) if err != nil { return nil, fmt.Errorf("[ERR] core: unable to generate salt: %v", err) } - return f(audit.BackendConfig{ + return f(&audit.BackendConfig{ Salt: salter, Config: conf, }) diff --git a/vault/audit_test.go b/vault/audit_test.go index d7c78ed62b..502805e301 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -15,6 +15,7 @@ import ( ) type NoopAudit struct { + Config *audit.BackendConfig ReqErr error ReqAuth []*logical.Auth Req []*logical.Request @@ -44,8 +45,10 @@ func (n *NoopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical func TestCore_EnableAudit(t *testing.T) { c, key, _ := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { - return &NoopAudit{}, nil + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil } me := &MountEntry{ @@ -66,8 +69,10 @@ func TestCore_EnableAudit(t *testing.T) { AuditBackends: make(map[string]audit.Factory), DisableMlock: true, } - conf.AuditBackends["noop"] = func(map[string]string) (audit.Backend, error) { - return &NoopAudit{}, nil + conf.AuditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil } c2, err := NewCore(conf) if err != nil { @@ -94,8 +99,10 @@ func TestCore_EnableAudit(t *testing.T) { func TestCore_DisableAudit(t *testing.T) { c, key, _ := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { - return &NoopAudit{}, nil + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil } err := c.disableAudit("foo") diff --git a/vault/core_test.go b/vault/core_test.go index 0d59ae1982..c418f07baa 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -841,7 +841,10 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { // Create a noop audit backend noop := &NoopAudit{} c, _, root := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + noop = &NoopAudit{ + Config: config, + } return noop, nil } @@ -920,7 +923,10 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { return noopBack, nil } - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + noop = &NoopAudit{ + Config: config, + } return noop, nil } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 08ddef7042..fc71beefbd 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -521,8 +521,10 @@ func TestSystemBackend_policyCRUD(t *testing.T) { func TestSystemBackend_enableAudit(t *testing.T) { c, b, _ := testCoreSystemBackend(t) - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { - return &NoopAudit{}, nil + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil } req := logical.TestRequest(t, logical.WriteOperation, "audit/foo") @@ -552,8 +554,10 @@ func TestSystemBackend_enableAudit_invalid(t *testing.T) { func TestSystemBackend_auditTable(t *testing.T) { c, b, _ := testCoreSystemBackend(t) - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { - return &NoopAudit{}, nil + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil } req := logical.TestRequest(t, logical.WriteOperation, "audit/foo") @@ -586,8 +590,10 @@ func TestSystemBackend_auditTable(t *testing.T) { func TestSystemBackend_disableAudit(t *testing.T) { c, b, _ := testCoreSystemBackend(t) - c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) { - return &NoopAudit{}, nil + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil } req := logical.TestRequest(t, logical.WriteOperation, "audit/foo") diff --git a/vault/testing.go b/vault/testing.go index 24a1f5da47..4fb1c54b4a 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -57,8 +57,10 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F // TestCore returns a pure in-memory, uninitialized core for testing. func TestCore(t *testing.T) *Core { noopAudits := map[string]audit.Factory{ - "noop": func(audit.BackendConfig) (audit.Backend, error) { - return new(noopAudit), nil + "noop": func(config *audit.BackendConfig) (audit.Backend, error) { + return &noopAudit{ + Config: config, + }, nil }, } noopBackends := make(map[string]logical.Factory) @@ -240,7 +242,9 @@ func AddTestLogicalBackend(name string, factory logical.Factory) error { return nil } -type noopAudit struct{} +type noopAudit struct { + Config *audit.BackendConfig +} func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error { return nil