diff --git a/cli/commands.go b/cli/commands.go index 9eee004aec..61ba681720 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -6,6 +6,7 @@ import ( auditFile "github.com/hashicorp/vault/builtin/audit/file" auditSocket "github.com/hashicorp/vault/builtin/audit/socket" auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" + "github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/version" credAppId "github.com/hashicorp/vault/builtin/credential/app-id" @@ -18,6 +19,23 @@ import ( credRadius "github.com/hashicorp/vault/builtin/credential/radius" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + physAzure "github.com/hashicorp/vault/physical/azure" + physCassandra "github.com/hashicorp/vault/physical/cassandra" + physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb" + physConsul "github.com/hashicorp/vault/physical/consul" + physCouchDB "github.com/hashicorp/vault/physical/couchdb" + physDynamoDB "github.com/hashicorp/vault/physical/dynamodb" + physEtcd "github.com/hashicorp/vault/physical/etcd" + physFile "github.com/hashicorp/vault/physical/file" + physGCS "github.com/hashicorp/vault/physical/gcs" + physInmem "github.com/hashicorp/vault/physical/inmem" + physMSSQL "github.com/hashicorp/vault/physical/mssql" + physMySQL "github.com/hashicorp/vault/physical/mysql" + physPostgreSQL "github.com/hashicorp/vault/physical/postgresql" + physS3 "github.com/hashicorp/vault/physical/s3" + physSwift "github.com/hashicorp/vault/physical/swift" + physZooKeeper "github.com/hashicorp/vault/physical/zookeeper" + "github.com/hashicorp/vault/builtin/logical/aws" "github.com/hashicorp/vault/builtin/logical/cassandra" "github.com/hashicorp/vault/builtin/logical/consul" @@ -63,7 +81,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { }, nil }, "server": func() (cli.Command, error) { - return &command.ServerCommand{ + c := &command.ServerCommand{ Meta: *metaPtr, AuditBackends: map[string]audit.Factory{ "file": auditFile.Factory, @@ -98,9 +116,36 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { "totp": totp.Factory, "plugin": plugin.Factory, }, + ShutdownCh: command.MakeShutdownCh(), SighupCh: command.MakeSighupCh(), - }, nil + } + + c.PhysicalBackends = map[string]physical.Factory{ + "azure": physAzure.NewAzureBackend, + "cassandra": physCassandra.NewCassandraBackend, + "cockroachdb": physCockroachDB.NewCockroachDBBackend, + "consul": physConsul.NewConsulBackend, + "couchdb": physCouchDB.NewCouchDBBackend, + "couchdb_transactional": physCouchDB.NewTransactionalCouchDBBackend, + "dynamodb": physDynamoDB.NewDynamoDBBackend, + "etcd": physEtcd.NewEtcdBackend, + "file": physFile.NewFileBackend, + "file_transactional": physFile.NewTransactionalFileBackend, + "gcs": physGCS.NewGCSBackend, + "inmem": physInmem.NewInmem, + "inmem_ha": physInmem.NewInmemHA, + "inmem_transactional": physInmem.NewTransactionalInmem, + "inmem_transactional_ha": physInmem.NewTransactionalInmemHA, + "mssql": physMSSQL.NewMSSQLBackend, + "mysql": physMySQL.NewMySQLBackend, + "postgresql": physPostgreSQL.NewPostgreSQLBackend, + "s3": physS3.NewS3Backend, + "swift": physSwift.NewSwiftBackend, + "zookeeper": physZooKeeper.NewZooKeeperBackend, + } + + return c, nil }, "ssh": func() (cli.Command, error) { diff --git a/command/init.go b/command/init.go index 4c638dc6eb..42002043f4 100644 --- a/command/init.go +++ b/command/init.go @@ -11,7 +11,7 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/consul" ) // InitCommand is a Command that initializes a new Vault server. @@ -36,7 +36,7 @@ func (c *InitCommand) Run(args []string) int { flags.Var(&recoveryPgpKeys, "recovery-pgp-keys", "") flags.BoolVar(&check, "check", false, "") flags.BoolVar(&auto, "auto", false, "") - flags.StringVar(&consulServiceName, "consul-service", physical.DefaultServiceName, "") + flags.StringVar(&consulServiceName, "consul-service", consul.DefaultServiceName, "") if err := flags.Parse(args); err != nil { return 1 } diff --git a/command/server.go b/command/server.go index c453a8e5a6..96bcfcfbae 100644 --- a/command/server.go +++ b/command/server.go @@ -53,6 +53,7 @@ type ServerCommand struct { AuditBackends map[string]audit.Factory CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory + PhysicalBackends map[string]physical.Factory ShutdownCh chan struct{} SighupCh chan struct{} @@ -204,8 +205,14 @@ func (c *ServerCommand) Run(args []string) int { } // Initialize the backend - backend, err := physical.NewBackend( - config.Storage.Type, c.logger, config.Storage.Config) + factory, exists := c.PhysicalBackends[config.Storage.Type] + if !exists { + c.Ui.Output(fmt.Sprintf( + "Unknown storage type %s", + config.Storage.Type)) + return 1 + } + backend, err := factory(config.Storage.Config, c.logger) if err != nil { c.Ui.Output(fmt.Sprintf( "Error initializing storage of type %s: %s", @@ -266,8 +273,14 @@ func (c *ServerCommand) Run(args []string) int { // Initialize the separate HA storage backend, if it exists var ok bool if config.HAStorage != nil { - habackend, err := physical.NewBackend( - config.HAStorage.Type, c.logger, config.HAStorage.Config) + factory, exists := c.PhysicalBackends[config.HAStorage.Type] + if !exists { + c.Ui.Output(fmt.Sprintf( + "Unknown HA storage type %s", + config.HAStorage.Type)) + return 1 + } + habackend, err := factory(config.HAStorage.Config, c.logger) if err != nil { c.Ui.Output(fmt.Sprintf( "Error initializing HA storage of type %s: %s", diff --git a/command/server_ha_test.go b/command/server_ha_test.go index 5562191eb5..a9b1188126 100644 --- a/command/server_ha_test.go +++ b/command/server_ha_test.go @@ -9,7 +9,10 @@ import ( "testing" "github.com/hashicorp/vault/meta" + "github.com/hashicorp/vault/physical" "github.com/mitchellh/cli" + + physConsul "github.com/hashicorp/vault/physical/consul" ) // The following tests have a go-metrics/exp manager race condition @@ -19,6 +22,9 @@ func TestServer_CommonHA(t *testing.T) { Meta: meta.Meta{ Ui: ui, }, + PhysicalBackends: map[string]physical.Factory{ + "consul": physConsul.NewConsulBackend, + }, } tmpfile, err := ioutil.TempFile("", "") @@ -47,6 +53,9 @@ func TestServer_GoodSeparateHA(t *testing.T) { Meta: meta.Meta{ Ui: ui, }, + PhysicalBackends: map[string]physical.Factory{ + "consul": physConsul.NewConsulBackend, + }, } tmpfile, err := ioutil.TempFile("", "") @@ -75,6 +84,9 @@ func TestServer_BadSeparateHA(t *testing.T) { Meta: meta.Meta{ Ui: ui, }, + PhysicalBackends: map[string]physical.Factory{ + "consul": physConsul.NewConsulBackend, + }, } tmpfile, err := ioutil.TempFile("", "") diff --git a/command/server_test.go b/command/server_test.go index 9c37e4c575..9a90239011 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -15,7 +15,10 @@ import ( "time" "github.com/hashicorp/vault/meta" + "github.com/hashicorp/vault/physical" "github.com/mitchellh/cli" + + physFile "github.com/hashicorp/vault/physical/file" ) var ( @@ -100,6 +103,9 @@ func TestServer_ReloadListener(t *testing.T) { }, ShutdownCh: MakeShutdownCh(), SighupCh: MakeSighupCh(), + PhysicalBackends: map[string]physical.Factory{ + "file": physFile.NewFileBackend, + }, } finished := false diff --git a/http/logical_test.go b/http/logical_test.go index bbbd892966..e4101a50bf 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" "github.com/hashicorp/vault/vault" ) @@ -83,10 +84,13 @@ func TestLogical_StandbyRedirect(t *testing.T) { // Create an HA Vault logger := logformat.NewVaultLogger(log.LevelTrace) - inmha := physical.NewInmemHA(logger) + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } conf := &vault.CoreConfig{ Physical: inmha, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: addr1, DisableMlock: true, } @@ -108,7 +112,7 @@ func TestLogical_StandbyRedirect(t *testing.T) { // Create a second HA Vault conf2 := &vault.CoreConfig{ Physical: inmha, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: addr2, DisableMlock: true, } diff --git a/logical/testing/testing.go b/logical/testing/testing.go index b2072ea06c..ca52cddd33 100644 --- a/logical/testing/testing.go +++ b/logical/testing/testing.go @@ -15,7 +15,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" "github.com/hashicorp/vault/vault" ) @@ -136,8 +136,14 @@ func Test(tt TestT, c TestCase) { // Create an in-memory Vault core logger := logformat.NewVaultLogger(log.LevelTrace) + phys, err := inmem.NewInmem(nil, logger) + if err != nil { + tt.Fatal(err) + return + } + core, err := vault.NewCore(&vault.CoreConfig{ - Physical: physical.NewInmem(logger), + Physical: phys, LogicalBackends: map[string]logical.Factory{ "test": func(conf *logical.BackendConfig) (logical.Backend, error) { if c.Backend != nil { diff --git a/physical/azure.go b/physical/azure/azure.go similarity index 90% rename from physical/azure.go rename to physical/azure/azure.go index 3bb6827641..f938ae46f0 100644 --- a/physical/azure.go +++ b/physical/azure/azure.go @@ -1,4 +1,4 @@ -package physical +package azure import ( "encoding/base64" @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" ) // MaxBlobSize at this time @@ -27,13 +28,13 @@ var MaxBlobSize = 1024 * 1024 * 4 type AzureBackend struct { container *storage.Container logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -// newAzureBackend constructs an Azure backend using a pre-existing +// NewAzureBackend constructs an Azure backend using a pre-existing // bucket. Credentials can be provided to the backend, sourced // from the environment, AWS credential files or by IAM role. -func newAzureBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewAzureBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { name := os.Getenv("AZURE_BLOB_CONTAINER") if name == "" { name = conf["container"] @@ -88,13 +89,13 @@ func newAzureBackend(conf map[string]string, logger log.Logger) (Backend, error) a := &AzureBackend{ container: container, logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } return a, nil } // Put is used to insert or update an entry -func (a *AzureBackend) Put(entry *Entry) error { +func (a *AzureBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"azure", "put"}, time.Now()) if len(entry.Value) >= MaxBlobSize { @@ -120,7 +121,7 @@ func (a *AzureBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (a *AzureBackend) Get(key string) (*Entry, error) { +func (a *AzureBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"azure", "get"}, time.Now()) a.permitPool.Acquire() @@ -145,7 +146,7 @@ func (a *AzureBackend) Get(key string) (*Entry, error) { defer reader.Close() data, err := ioutil.ReadAll(reader) - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: data, } diff --git a/physical/azure_test.go b/physical/azure/azure_test.go similarity index 83% rename from physical/azure_test.go rename to physical/azure/azure_test.go index 5d37781793..eb0c510892 100644 --- a/physical/azure_test.go +++ b/physical/azure/azure_test.go @@ -1,4 +1,4 @@ -package physical +package azure import ( "fmt" @@ -8,6 +8,7 @@ import ( cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" storage "github.com/Azure/azure-sdk-for-go/storage" @@ -30,11 +31,11 @@ func TestAzureBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - backend, err := NewBackend("azure", logger, map[string]string{ + backend, err := NewAzureBackend(map[string]string{ "container": name, "accountName": accountName, "accountKey": accountKey, - }) + }, logger) defer func() { blobService := cleanupClient.GetBlobService() @@ -46,6 +47,6 @@ func TestAzureBackend(t *testing.T) { t.Fatalf("err: %s", err) } - testBackend(t, backend) - testBackend_ListPrefix(t, backend) + physical.ExerciseBackend(t, backend) + physical.ExerciseBackend_ListPrefix(t, backend) } diff --git a/physical/cassandra.go b/physical/cassandra/cassandra.go similarity index 94% rename from physical/cassandra.go rename to physical/cassandra/cassandra.go index 5d0a43d800..493e156fa8 100644 --- a/physical/cassandra.go +++ b/physical/cassandra/cassandra.go @@ -1,4 +1,4 @@ -package physical +package cassandra import ( "crypto/tls" @@ -14,6 +14,7 @@ import ( "github.com/armon/go-metrics" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/physical" ) // CassandraBackend is a physical backend that stores data in Cassandra. @@ -24,9 +25,9 @@ type CassandraBackend struct { logger log.Logger } -// newCassandraBackend constructs a Cassandra backend using a pre-existing +// NewCassandraBackend constructs a Cassandra backend using a pre-existing // keyspace and table. -func newCassandraBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewCassandraBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { splitArray := func(v string) []string { return strings.FieldsFunc(v, func(r rune) bool { return r == ',' @@ -230,7 +231,7 @@ func (c *CassandraBackend) bucketName(name string) string { // bucket returns all the prefix buckets the key should be stored at func (c *CassandraBackend) buckets(key string) []string { - vals := append([]string{""}, prefixes(key)...) + vals := append([]string{""}, physical.Prefixes(key)...) for i, v := range vals { vals[i] = c.bucketName(v) } @@ -244,7 +245,7 @@ func (c *CassandraBackend) bucket(key string) string { } // Put is used to insert or update an entry -func (c *CassandraBackend) Put(entry *Entry) error { +func (c *CassandraBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"cassandra", "put"}, time.Now()) // Execute inserts to each key prefix simultaneously @@ -265,7 +266,7 @@ func (c *CassandraBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (c *CassandraBackend) Get(key string) (*Entry, error) { +func (c *CassandraBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"cassandra", "get"}, time.Now()) v := []byte(nil) @@ -278,7 +279,7 @@ func (c *CassandraBackend) Get(key string) (*Entry, error) { return nil, err } - return &Entry{ + return &physical.Entry{ Key: key, Value: v, }, nil diff --git a/physical/cassandra_test.go b/physical/cassandra/cassandra_test.go similarity index 91% rename from physical/cassandra_test.go rename to physical/cassandra/cassandra_test.go index 02294336a8..4e7ef4ab28 100644 --- a/physical/cassandra_test.go +++ b/physical/cassandra/cassandra_test.go @@ -1,4 +1,4 @@ -package physical +package cassandra import ( "fmt" @@ -10,6 +10,7 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v3" ) @@ -24,16 +25,17 @@ func TestCassandraBackend(t *testing.T) { // Run vault tests logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("cassandra", logger, map[string]string{ + b, err := NewCassandraBackend(map[string]string{ "hosts": hosts, - "protocol_version": "3"}) + "protocol_version": "3", + }, logger) if err != nil { t.Fatalf("Failed to create new backend: %v", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } func TestCassandraBackendBuckets(t *testing.T) { diff --git a/physical/cockroachdb.go b/physical/cockroachdb/cockroachdb.go similarity index 88% rename from physical/cockroachdb.go rename to physical/cockroachdb/cockroachdb.go index b904858cd1..1765e68d59 100644 --- a/physical/cockroachdb.go +++ b/physical/cockroachdb/cockroachdb.go @@ -1,4 +1,4 @@ -package physical +package cockroachdb import ( "database/sql" @@ -12,6 +12,7 @@ import ( "github.com/cockroachdb/cockroach-go/crdb" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" // CockroachDB uses the Postgres SQL driver @@ -26,12 +27,12 @@ type CockroachDBBackend struct { rawStatements map[string]string statements map[string]*sql.Stmt logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -// newCockroachDBBackend constructs a CockroachDB backend using the given +// NewCockroachDBBackend constructs a CockroachDB backend using the given // API client, server address, credentials, and database. -func newCockroachDBBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the CockroachDB credentials to perform read/write operations. connURL, ok := conf["connection_url"] if !ok || connURL == "" { @@ -83,7 +84,7 @@ func newCockroachDBBackend(conf map[string]string, logger log.Logger) (Backend, }, statements: make(map[string]*sql.Stmt), logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } // Prepare all the statements required @@ -106,7 +107,7 @@ func (c *CockroachDBBackend) prepare(name, query string) error { } // Put is used to insert or update an entry. -func (c *CockroachDBBackend) Put(entry *Entry) error { +func (c *CockroachDBBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"cockroachdb", "put"}, time.Now()) c.permitPool.Acquire() @@ -120,7 +121,7 @@ func (c *CockroachDBBackend) Put(entry *Entry) error { } // Get is used to fetch and entry. -func (c *CockroachDBBackend) Get(key string) (*Entry, error) { +func (c *CockroachDBBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"cockroachdb", "get"}, time.Now()) c.permitPool.Acquire() @@ -135,7 +136,7 @@ func (c *CockroachDBBackend) Get(key string) (*Entry, error) { return nil, err } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: result, } @@ -194,7 +195,7 @@ func (c *CockroachDBBackend) List(prefix string) ([]string, error) { } // Transaction is used to run multiple entries via a transaction -func (c *CockroachDBBackend) Transaction(txns []TxnEntry) error { +func (c *CockroachDBBackend) Transaction(txns []physical.TxnEntry) error { defer metrics.MeasureSince([]string{"cockroachdb", "transaction"}, time.Now()) if len(txns) == 0 { return nil @@ -208,7 +209,7 @@ func (c *CockroachDBBackend) Transaction(txns []TxnEntry) error { }) } -func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []TxnEntry) error { +func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []physical.TxnEntry) error { deleteStmt, err := tx.Prepare(c.rawStatements["delete"]) if err != nil { return err @@ -220,9 +221,9 @@ func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []TxnEntry) error { for _, op := range txns { switch op.Operation { - case DeleteOperation: + case physical.DeleteOperation: _, err = deleteStmt.Exec(op.Entry.Key) - case PutOperation: + case physical.PutOperation: _, err = putStmt.Exec(op.Entry.Key, op.Entry.Value) default: return fmt.Errorf("%q is not a supported transaction operation", op.Operation) diff --git a/physical/cockroachdb_test.go b/physical/cockroachdb/cockroachdb_test.go similarity index 87% rename from physical/cockroachdb_test.go rename to physical/cockroachdb/cockroachdb_test.go index 35e186f5b0..35bcecf746 100644 --- a/physical/cockroachdb_test.go +++ b/physical/cockroachdb/cockroachdb_test.go @@ -1,4 +1,4 @@ -package physical +package cockroachdb import ( "database/sql" @@ -9,6 +9,7 @@ import ( dockertest "gopkg.in/ory-am/dockertest.v3" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" _ "github.com/lib/pq" @@ -73,10 +74,10 @@ func TestCockroachDBBackend(t *testing.T) { // Run vault tests logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("cockroachdb", logger, map[string]string{ + b, err := NewCockroachDBBackend(map[string]string{ "connection_url": connURL, "table": table, - }) + }, logger) if err != nil { t.Fatalf("Failed to create new backend: %v", err) @@ -86,14 +87,14 @@ func TestCockroachDBBackend(t *testing.T) { truncate(t, b) }() - testBackend(t, b) + physical.ExerciseBackend(t, b) truncate(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend_ListPrefix(t, b) truncate(t, b) - testTransactionalBackend(t, b) + physical.ExerciseTransactionalBackend(t, b) } -func truncate(t *testing.T, b Backend) { +func truncate(t *testing.T, b physical.Backend) { crdb := b.(*CockroachDBBackend) _, err := crdb.client.Exec("TRUNCATE TABLE " + crdb.table) if err != nil { diff --git a/physical/consul.go b/physical/consul/consul.go similarity index 95% rename from physical/consul.go rename to physical/consul/consul.go index c77e8de618..8256808b1a 100644 --- a/physical/consul.go +++ b/physical/consul/consul.go @@ -1,4 +1,4 @@ -package physical +package consul import ( "errors" @@ -28,6 +28,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/helper/tlsutil" + "github.com/hashicorp/vault/physical" ) const ( @@ -72,7 +73,7 @@ type ConsulBackend struct { logger log.Logger client *api.Client kv *api.KV - permitPool *PermitPool + permitPool *physical.PermitPool serviceLock sync.RWMutex redirectHost string redirectPort int64 @@ -86,9 +87,9 @@ type ConsulBackend struct { notifySealedCh chan notifyEvent } -// newConsulBackend constructs a Consul backend using the given API client +// NewConsulBackend constructs a Consul backend using the given API client // and the prefix in the KV store. -func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the path in Consul path, ok := conf["path"] if !ok { @@ -227,7 +228,7 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error logger: logger, client: client, kv: client.KV(), - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), serviceName: service, serviceTags: strutil.ParseDedupLowercaseAndSortStrings(tags, ","), checkTimeout: checkTimeout, @@ -295,7 +296,7 @@ func setupTLSConfig(conf map[string]string) (*tls.Config, error) { } // Used to run multiple entries via a transaction -func (c *ConsulBackend) Transaction(txns []TxnEntry) error { +func (c *ConsulBackend) Transaction(txns []physical.TxnEntry) error { if len(txns) == 0 { return nil } @@ -307,9 +308,9 @@ func (c *ConsulBackend) Transaction(txns []TxnEntry) error { Key: c.path + op.Entry.Key, } switch op.Operation { - case DeleteOperation: + case physical.DeleteOperation: cop.Verb = api.KVDelete - case PutOperation: + case physical.PutOperation: cop.Verb = api.KVSet cop.Value = op.Entry.Value default: @@ -339,7 +340,7 @@ func (c *ConsulBackend) Transaction(txns []TxnEntry) error { } // Put is used to insert or update an entry -func (c *ConsulBackend) Put(entry *Entry) error { +func (c *ConsulBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"consul", "put"}, time.Now()) c.permitPool.Acquire() @@ -355,7 +356,7 @@ func (c *ConsulBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (c *ConsulBackend) Get(key string) (*Entry, error) { +func (c *ConsulBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"consul", "get"}, time.Now()) c.permitPool.Acquire() @@ -375,7 +376,7 @@ func (c *ConsulBackend) Get(key string) (*Entry, error) { if pair == nil { return nil, nil } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: pair.Value, } @@ -418,7 +419,7 @@ func (c *ConsulBackend) List(prefix string) ([]string, error) { } // Lock is used for mutual exclusion based on the given key. -func (c *ConsulBackend) LockWith(key, value string) (Lock, error) { +func (c *ConsulBackend) LockWith(key, value string) (physical.Lock, error) { // Create the lock opts := &api.LockOptions{ Key: c.path + key, @@ -525,7 +526,7 @@ func (c *ConsulBackend) checkDuration() time.Duration { return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) } -func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { +func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction) (err error) { if err := c.setRedirectAddr(redirectAddr); err != nil { return err } @@ -538,7 +539,7 @@ func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownC return nil } -func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { +func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction) { // This defer statement should be executed last. So push it first. defer waitGroup.Done() @@ -655,7 +656,7 @@ func (c *ConsulBackend) serviceID() string { // without any locks held and can be run concurrently, therefore no changes // to ConsulBackend can be made in this method (i.e. wtb const receiver for // compiler enforced safety). -func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc activeFunction, sealedFunc sealedFunction) (serviceID string, err error) { +func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction) (serviceID string, err error) { // Query vault Core for its current state active := activeFunc() sealed := sealedFunc() diff --git a/physical/consul_test.go b/physical/consul/consul_test.go similarity index 94% rename from physical/consul_test.go rename to physical/consul/consul_test.go index 59b129435e..4d3230c754 100644 --- a/physical/consul_test.go +++ b/physical/consul/consul_test.go @@ -1,4 +1,4 @@ -package physical +package consul import ( "fmt" @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" dockertest "gopkg.in/ory-am/dockertest.v2" ) @@ -37,7 +38,7 @@ func testConsulBackend(t *testing.T) *ConsulBackend { func testConsulBackendConfig(t *testing.T, conf *consulConf) *ConsulBackend { logger := logformat.NewVaultLogger(log.LevelTrace) - be, err := newConsulBackend(*conf, logger) + be, err := NewConsulBackend(*conf, logger) if err != nil { t.Fatalf("Expected Consul to initialize: %v", err) } @@ -57,7 +58,7 @@ func testConsul_testConsulBackend(t *testing.T) { } } -func testActiveFunc(activePct float64) activeFunction { +func testActiveFunc(activePct float64) physical.ActiveFunction { return func() bool { var active bool standbyProb := rand.Float64() @@ -68,7 +69,7 @@ func testActiveFunc(activePct float64) activeFunction { } } -func testSealedFunc(sealedPct float64) sealedFunction { +func testSealedFunc(sealedPct float64) physical.SealedFunction { return func() bool { var sealed bool unsealedProb := rand.Float64() @@ -94,7 +95,7 @@ func TestConsul_ServiceTags(t *testing.T) { } logger := logformat.NewVaultLogger(log.LevelTrace) - be, err := newConsulBackend(consulConfig, logger) + be, err := NewConsulBackend(consulConfig, logger) if err != nil { t.Fatal(err) } @@ -182,7 +183,7 @@ func TestConsul_newConsulBackend(t *testing.T) { for _, test := range tests { logger := logformat.NewVaultLogger(log.LevelTrace) - be, err := newConsulBackend(test.consulConfig, logger) + be, err := NewConsulBackend(test.consulConfig, logger) if test.fail { if err == nil { t.Fatalf(`Expected config "%s" to fail`, test.name) @@ -206,7 +207,7 @@ func TestConsul_newConsulBackend(t *testing.T) { } } - var shutdownCh ShutdownChannel + var shutdownCh physical.ShutdownChannel waitGroup := &sync.WaitGroup{} if err := c.RunServiceDiscovery(waitGroup, shutdownCh, test.redirectAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil { t.Fatalf("bad: %v", err) @@ -411,18 +412,18 @@ func TestConsulBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("consul", logger, map[string]string{ + b, err := NewConsulBackend(map[string]string{ "address": conf.Address, "path": randPath, "max_parallel": "256", "token": conf.Token, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } func TestConsulHABackend(t *testing.T) { @@ -452,23 +453,23 @@ func TestConsulHABackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("consul", logger, map[string]string{ + b, err := NewConsulBackend(map[string]string{ "address": conf.Address, "path": randPath, "max_parallel": "-1", "token": conf.Token, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - ha, ok := b.(HABackend) + ha, ok := b.(physical.HABackend) if !ok { t.Fatalf("consul does not implement HABackend") } - testHABackend(t, ha, ha) + physical.ExerciseHABackend(t, ha, ha) - detect, ok := b.(RedirectDetect) + detect, ok := b.(physical.RedirectDetect) if !ok { t.Fatalf("consul does not implement RedirectDetect") } diff --git a/physical/couchdb.go b/physical/couchdb/couchdb.go similarity index 87% rename from physical/couchdb.go rename to physical/couchdb/couchdb.go index 96552256e2..e7f945f118 100644 --- a/physical/couchdb.go +++ b/physical/couchdb/couchdb.go @@ -1,4 +1,4 @@ -package physical +package couchdb import ( "bytes" @@ -15,6 +15,7 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) @@ -22,7 +23,7 @@ import ( type CouchDBBackend struct { logger log.Logger client *couchDBClient - permitPool *PermitPool + permitPool *physical.PermitPool } type couchDBClient struct { @@ -84,7 +85,7 @@ func (m *couchDBClient) put(e couchDBEntry) error { return err } -func (m *couchDBClient) get(key string) (*Entry, error) { +func (m *couchDBClient) get(key string) (*physical.Entry, error) { req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", m.endpoint, url.PathEscape(key)), nil) if err != nil { return nil, err @@ -183,23 +184,23 @@ func buildCouchDBBackend(conf map[string]string, logger log.Logger) (*CouchDBBac Client: cleanhttp.DefaultPooledClient(), }, logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), }, nil } -func newCouchDBBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewCouchDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { return buildCouchDBBackend(conf, logger) } type couchDBEntry struct { - Entry *Entry `json:"entry"` - Rev string `json:"_rev,omitempty"` - ID string `json:"_id"` - Deleted *bool `json:"_deleted,omitempty"` + Entry *physical.Entry `json:"entry"` + Rev string `json:"_rev,omitempty"` + ID string `json:"_id"` + Deleted *bool `json:"_deleted,omitempty"` } // Put is used to insert or update an entry -func (m *CouchDBBackend) Put(entry *Entry) error { +func (m *CouchDBBackend) Put(entry *physical.Entry) error { m.permitPool.Acquire() defer m.permitPool.Release() @@ -207,7 +208,7 @@ func (m *CouchDBBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (m *CouchDBBackend) Get(key string) (*Entry, error) { +func (m *CouchDBBackend) Get(key string) (*physical.Entry, error) { m.permitPool.Acquire() defer m.permitPool.Release() @@ -258,12 +259,12 @@ type TransactionalCouchDBBackend struct { CouchDBBackend } -func newTransactionalCouchDBBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewTransactionalCouchDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { backend, err := buildCouchDBBackend(conf, logger) if err != nil { return nil, err } - backend.permitPool = NewPermitPool(1) + backend.permitPool = physical.NewPermitPool(1) return &TransactionalCouchDBBackend{ CouchDBBackend: *backend, @@ -271,14 +272,14 @@ func newTransactionalCouchDBBackend(conf map[string]string, logger log.Logger) ( } // GetInternal is used to fetch an entry -func (m *CouchDBBackend) GetInternal(key string) (*Entry, error) { +func (m *CouchDBBackend) GetInternal(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"couchdb", "get"}, time.Now()) return m.client.get(key) } // PutInternal is used to insert or update an entry -func (m *CouchDBBackend) PutInternal(entry *Entry) error { +func (m *CouchDBBackend) PutInternal(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"couchdb", "put"}, time.Now()) revision, _ := m.client.rev(url.PathEscape(entry.Key)) diff --git a/physical/couchdb_test.go b/physical/couchdb/couchdb_test.go similarity index 90% rename from physical/couchdb_test.go rename to physical/couchdb/couchdb_test.go index f524641061..de4d05d501 100644 --- a/physical/couchdb_test.go +++ b/physical/couchdb/couchdb_test.go @@ -1,4 +1,4 @@ -package physical +package couchdb import ( "fmt" @@ -10,6 +10,7 @@ import ( "time" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v3" ) @@ -20,17 +21,17 @@ func TestCouchDBBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("couchdb", logger, map[string]string{ + b, err := NewCouchDBBackend(map[string]string{ "endpoint": endpoint, "username": username, "password": password, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } func TestTransactionalCouchDBBackend(t *testing.T) { @@ -39,17 +40,17 @@ func TestTransactionalCouchDBBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("couchdb_transactional", logger, map[string]string{ + b, err := NewTransactionalCouchDBBackend(map[string]string{ "endpoint": endpoint, "username": username, "password": password, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } func prepareCouchdbDBTestContainer(t *testing.T) (cleanup func(), retAddress, username, password string) { diff --git a/physical/dynamodb.go b/physical/dynamodb/dynamodb.go similarity index 96% rename from physical/dynamodb.go rename to physical/dynamodb/dynamodb.go index 9bfdc6123b..c0b3f3e8c2 100644 --- a/physical/dynamodb.go +++ b/physical/dynamodb/dynamodb.go @@ -1,4 +1,4 @@ -package physical +package dynamodb import ( "fmt" @@ -25,6 +25,7 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/awsutil" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/physical" ) const ( @@ -76,7 +77,7 @@ type DynamoDBBackend struct { recovery bool logger log.Logger haEnabled bool - permitPool *PermitPool + permitPool *physical.PermitPool } // DynamoDBRecord is the representation of a vault entry in @@ -110,9 +111,9 @@ type DynamoDBLockRecord struct { Expires int64 } -// newDynamoDBBackend constructs a DynamoDB backend. If the +// NewDynamoDBBackend constructs a DynamoDB backend. If the // configured DynamoDB table does not exist, it creates it. -func newDynamoDBBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewDynamoDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { table := os.Getenv("AWS_DYNAMODB_TABLE") if table == "" { table = conf["table"] @@ -231,7 +232,7 @@ func newDynamoDBBackend(conf map[string]string, logger log.Logger) (Backend, err return &DynamoDBBackend{ table: table, client: client, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), recovery: recoveryModeBool, haEnabled: haEnabledBool, logger: logger, @@ -239,7 +240,7 @@ func newDynamoDBBackend(conf map[string]string, logger log.Logger) (Backend, err } // Put is used to insert or update an entry -func (d *DynamoDBBackend) Put(entry *Entry) error { +func (d *DynamoDBBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"dynamodb", "put"}, time.Now()) record := DynamoDBRecord{ @@ -257,7 +258,7 @@ func (d *DynamoDBBackend) Put(entry *Entry) error { }, }} - for _, prefix := range prefixes(entry.Key) { + for _, prefix := range physical.Prefixes(entry.Key) { record = DynamoDBRecord{ Path: recordPathForVaultKey(prefix), Key: fmt.Sprintf("%s/", recordKeyForVaultKey(prefix)), @@ -277,7 +278,7 @@ func (d *DynamoDBBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (d *DynamoDBBackend) Get(key string) (*Entry, error) { +func (d *DynamoDBBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"dynamodb", "get"}, time.Now()) d.permitPool.Acquire() @@ -303,7 +304,7 @@ func (d *DynamoDBBackend) Get(key string) (*Entry, error) { return nil, err } - return &Entry{ + return &physical.Entry{ Key: vaultKey(record), Value: record.Value, }, nil @@ -323,7 +324,7 @@ func (d *DynamoDBBackend) Delete(key string) error { }} // clean up now empty 'folders' - prefixes := prefixes(key) + prefixes := physical.Prefixes(key) sort.Sort(sort.Reverse(sort.StringSlice(prefixes))) for _, prefix := range prefixes { hasChildren, err := d.hasChildren(prefix) @@ -422,7 +423,7 @@ func (d *DynamoDBBackend) hasChildren(prefix string) (bool, error) { } // LockWith is used for mutual exclusion based on the given key. -func (d *DynamoDBBackend) LockWith(key, value string) (Lock, error) { +func (d *DynamoDBBackend) LockWith(key, value string) (physical.Lock, error) { identity, err := uuid.GenerateUUID() if err != nil { return nil, err @@ -774,15 +775,3 @@ func unescapeEmptyPath(s string) string { } return s } - -// prefixes returns all parent 'folders' for a given -// vault key. -// e.g. for 'foo/bar/baz', it returns ['foo', 'foo/bar'] -func prefixes(s string) []string { - components := strings.Split(s, "/") - result := []string{} - for i := 1; i < len(components); i++ { - result = append(result, strings.Join(components[:i], "/")) - } - return result -} diff --git a/physical/dynamodb_test.go b/physical/dynamodb/dynamodb_test.go similarity index 94% rename from physical/dynamodb_test.go rename to physical/dynamodb/dynamodb_test.go index daac8c873f..426f23fcae 100644 --- a/physical/dynamodb_test.go +++ b/physical/dynamodb/dynamodb_test.go @@ -1,4 +1,4 @@ -package physical +package dynamodb import ( "fmt" @@ -9,6 +9,7 @@ import ( "time" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v3" @@ -49,20 +50,20 @@ func TestDynamoDBBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("dynamodb", logger, map[string]string{ + b, err := NewDynamoDBBackend(map[string]string{ "access_key": creds.AccessKeyID, "secret_key": creds.SecretAccessKey, "session_token": creds.SessionToken, "table": table, "region": region, "endpoint": endpoint, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } func TestDynamoDBHABackend(t *testing.T) { @@ -95,30 +96,30 @@ func TestDynamoDBHABackend(t *testing.T) { }() logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("dynamodb", logger, map[string]string{ + b, err := NewDynamoDBBackend(map[string]string{ "access_key": creds.AccessKeyID, "secret_key": creds.SecretAccessKey, "session_token": creds.SessionToken, "table": table, "region": region, "endpoint": endpoint, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - ha, ok := b.(HABackend) + ha, ok := b.(physical.HABackend) if !ok { t.Fatalf("dynamodb does not implement HABackend") } - testHABackend(t, ha, ha) + physical.ExerciseHABackend(t, ha, ha) testDynamoDBLockTTL(t, ha) } // Similar to testHABackend, but using internal implementation details to // trigger the lock failure scenario by setting the lock renew period for one // of the locks to a higher value than the lock TTL. -func testDynamoDBLockTTL(t *testing.T, ha HABackend) { +func testDynamoDBLockTTL(t *testing.T, ha physical.HABackend) { // Set much smaller lock times to speed up the test. lockTTL := time.Second * 3 renewInterval := time.Second * 1 diff --git a/physical/etcd.go b/physical/etcd/etcd.go similarity index 92% rename from physical/etcd.go rename to physical/etcd/etcd.go index 01a928d407..5d9c26da96 100644 --- a/physical/etcd.go +++ b/physical/etcd/etcd.go @@ -1,4 +1,4 @@ -package physical +package etcd import ( "context" @@ -10,6 +10,7 @@ import ( "github.com/coreos/etcd/client" "github.com/coreos/go-semver/semver" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) @@ -22,11 +23,11 @@ var ( EtcdLockHeldError = errors.New("lock already held") EtcdLockNotHeldError = errors.New("lock not held") EtcdSemaphoreKeyRemovedError = errors.New("semaphore key removed before lock aquisition") - EtcdVersionUnknow = errors.New("etcd: unknown API version") + EtcdVersionUnknown = errors.New("etcd: unknown API version") ) -// newEtcdBackend constructs a etcd backend using a given machine address. -func newEtcdBackend(conf map[string]string, logger log.Logger) (Backend, error) { +// NewEtcdBackend constructs a etcd backend using a given machine address. +func NewEtcdBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { var ( apiVersion string ok bool @@ -75,7 +76,7 @@ func newEtcdBackend(conf map[string]string, logger log.Logger) (Backend, error) } return newEtcd3Backend(conf, logger) default: - return nil, EtcdVersionUnknow + return nil, EtcdVersionUnknown } } diff --git a/physical/etcd2.go b/physical/etcd/etcd2.go similarity index 97% rename from physical/etcd2.go rename to physical/etcd/etcd2.go index 4ef4b08c79..4e08615dcd 100644 --- a/physical/etcd2.go +++ b/physical/etcd/etcd2.go @@ -1,4 +1,4 @@ -package physical +package etcd import ( "context" @@ -14,6 +14,7 @@ import ( metrics "github.com/armon/go-metrics" "github.com/coreos/etcd/client" "github.com/coreos/etcd/pkg/transport" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) @@ -49,12 +50,12 @@ const ( type Etcd2Backend struct { path string kAPI client.KeysAPI - permitPool *PermitPool + permitPool *physical.PermitPool logger log.Logger haEnabled bool } -func newEtcd2Backend(conf map[string]string, logger log.Logger) (Backend, error) { +func newEtcd2Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the etcd path form the configuration. path, ok := conf["path"] if !ok { @@ -110,7 +111,7 @@ func newEtcd2Backend(conf map[string]string, logger log.Logger) (Backend, error) return &Etcd2Backend{ path: path, kAPI: kAPI, - permitPool: NewPermitPool(DefaultParallelOperations), + permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), logger: logger, haEnabled: haEnabledBool, }, nil @@ -169,7 +170,7 @@ func newEtcdV2Client(conf map[string]string) (client.Client, error) { } // Put is used to insert or update an entry. -func (c *Etcd2Backend) Put(entry *Entry) error { +func (c *Etcd2Backend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"etcd", "put"}, time.Now()) value := base64.StdEncoding.EncodeToString(entry.Value) @@ -181,7 +182,7 @@ func (c *Etcd2Backend) Put(entry *Entry) error { } // Get is used to fetch an entry. -func (c *Etcd2Backend) Get(key string) (*Entry, error) { +func (c *Etcd2Backend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"etcd", "get"}, time.Now()) c.permitPool.Acquire() @@ -206,7 +207,7 @@ func (c *Etcd2Backend) Get(key string) (*Entry, error) { } // Construct and return a new entry. - return &Entry{ + return &physical.Entry{ Key: key, Value: value, }, nil @@ -290,7 +291,7 @@ func (b *Etcd2Backend) nodePathLock(key string) string { } // Lock is used for mutual exclusion based on the given key. -func (c *Etcd2Backend) LockWith(key, value string) (Lock, error) { +func (c *Etcd2Backend) LockWith(key, value string) (physical.Lock, error) { return &Etcd2Lock{ kAPI: c.kAPI, value: value, diff --git a/physical/etcd3.go b/physical/etcd/etcd3.go similarity index 94% rename from physical/etcd3.go rename to physical/etcd/etcd3.go index daf5015918..04944e59f4 100644 --- a/physical/etcd3.go +++ b/physical/etcd/etcd3.go @@ -1,4 +1,4 @@ -package physical +package etcd import ( "errors" @@ -15,6 +15,7 @@ import ( "github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/pkg/transport" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "golang.org/x/net/context" ) @@ -27,7 +28,7 @@ type EtcdBackend struct { path string haEnabled bool - permitPool *PermitPool + permitPool *physical.PermitPool etcd *clientv3.Client } @@ -41,7 +42,7 @@ const ( ) // newEtcd3Backend constructs a etcd3 backend. -func newEtcd3Backend(conf map[string]string, logger log.Logger) (Backend, error) { +func newEtcd3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the etcd path form the configuration. path, ok := conf["path"] if !ok { @@ -133,13 +134,13 @@ func newEtcd3Backend(conf map[string]string, logger log.Logger) (Backend, error) return &EtcdBackend{ path: path, etcd: etcd, - permitPool: NewPermitPool(DefaultParallelOperations), + permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), logger: logger, haEnabled: haEnabledBool, }, nil } -func (c *EtcdBackend) Put(entry *Entry) error { +func (c *EtcdBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"etcd", "put"}, time.Now()) c.permitPool.Acquire() @@ -151,7 +152,7 @@ func (c *EtcdBackend) Put(entry *Entry) error { return err } -func (c *EtcdBackend) Get(key string) (*Entry, error) { +func (c *EtcdBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"etcd", "get"}, time.Now()) c.permitPool.Acquire() @@ -170,7 +171,7 @@ func (c *EtcdBackend) Get(key string) (*Entry, error) { if len(resp.Kvs) > 1 { return nil, errors.New("unexpected number of keys from a get request") } - return &Entry{ + return &physical.Entry{ Key: key, Value: resp.Kvs[0].Value, }, nil @@ -242,7 +243,7 @@ type EtcdLock struct { } // Lock is used for mutual exclusion based on the given key. -func (c *EtcdBackend) LockWith(key, value string) (Lock, error) { +func (c *EtcdBackend) LockWith(key, value string) (physical.Lock, error) { session, err := concurrency.NewSession(c.etcd, concurrency.WithTTL(etcd3LockTimeoutInSeconds)) if err != nil { return nil, err diff --git a/physical/etcd3_test.go b/physical/etcd/etcd3_test.go similarity index 66% rename from physical/etcd3_test.go rename to physical/etcd/etcd3_test.go index 0724091ad6..fbd842da1e 100644 --- a/physical/etcd3_test.go +++ b/physical/etcd/etcd3_test.go @@ -1,4 +1,4 @@ -package physical +package etcd import ( "fmt" @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) @@ -18,20 +19,20 @@ func TestEtcd3Backend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("etcd", logger, map[string]string{ + b, err := NewEtcdBackend(map[string]string{ "path": fmt.Sprintf("/vault-%d", time.Now().Unix()), "etcd_api": "3", - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) - ha, ok := b.(HABackend) + ha, ok := b.(physical.HABackend) if !ok { t.Fatalf("etcd3 does not implement HABackend") } - testHABackend(t, ha, ha) + physical.ExerciseHABackend(t, ha, ha) } diff --git a/physical/etcd_test.go b/physical/etcd/etcd_test.go similarity index 82% rename from physical/etcd_test.go rename to physical/etcd/etcd_test.go index adddac2b44..d5c30bb6fc 100644 --- a/physical/etcd_test.go +++ b/physical/etcd/etcd_test.go @@ -1,4 +1,4 @@ -package physical +package etcd import ( "fmt" @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "github.com/coreos/etcd/client" @@ -52,19 +53,19 @@ func TestEtcdBackend(t *testing.T) { // need to provide it explicitly. logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("etcd", logger, map[string]string{ + b, err := NewEtcdBackend(map[string]string{ "path": randPath, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) - ha, ok := b.(HABackend) + ha, ok := b.(physical.HABackend) if !ok { t.Fatalf("etcd does not implement HABackend") } - testHABackend(t, ha, ha) + physical.ExerciseHABackend(t, ha, ha) } diff --git a/physical/file.go b/physical/file/file.go similarity index 84% rename from physical/file.go rename to physical/file/file.go index cf22b83aa0..df05dba64c 100644 --- a/physical/file.go +++ b/physical/file/file.go @@ -1,4 +1,4 @@ -package physical +package file import ( "encoding/json" @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/physical" ) // FileBackend is a physical backend that stores data on disk @@ -26,15 +27,15 @@ type FileBackend struct { sync.RWMutex path string logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } type TransactionalFileBackend struct { FileBackend } -// newFileBackend constructs a FileBackend using the given directory -func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) { +// NewFileBackend constructs a FileBackend using the given directory +func NewFileBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { path, ok := conf["path"] if !ok { return nil, fmt.Errorf("'path' must be set") @@ -43,11 +44,11 @@ func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) return &FileBackend{ path: path, logger: logger, - permitPool: NewPermitPool(DefaultParallelOperations), + permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), }, nil } -func newTransactionalFileBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewTransactionalFileBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { path, ok := conf["path"] if !ok { return nil, fmt.Errorf("'path' must be set") @@ -58,7 +59,7 @@ func newTransactionalFileBackend(conf map[string]string, logger log.Logger) (Bac FileBackend: FileBackend{ path: path, logger: logger, - permitPool: NewPermitPool(1), + permitPool: physical.NewPermitPool(1), }, }, nil } @@ -132,7 +133,7 @@ func (b *FileBackend) cleanupLogicalPath(path string) error { return nil } -func (b *FileBackend) Get(k string) (*Entry, error) { +func (b *FileBackend) Get(k string) (*physical.Entry, error) { b.permitPool.Acquire() defer b.permitPool.Release() @@ -142,7 +143,7 @@ func (b *FileBackend) Get(k string) (*Entry, error) { return b.GetInternal(k) } -func (b *FileBackend) GetInternal(k string) (*Entry, error) { +func (b *FileBackend) GetInternal(k string) (*physical.Entry, error) { if err := b.validatePath(k); err != nil { return nil, err } @@ -162,7 +163,7 @@ func (b *FileBackend) GetInternal(k string) (*Entry, error) { return nil, err } - var entry Entry + var entry physical.Entry if err := jsonutil.DecodeJSONFromReader(f, &entry); err != nil { return nil, err } @@ -170,7 +171,7 @@ func (b *FileBackend) GetInternal(k string) (*Entry, error) { return &entry, nil } -func (b *FileBackend) Put(entry *Entry) error { +func (b *FileBackend) Put(entry *physical.Entry) error { b.permitPool.Acquire() defer b.permitPool.Release() @@ -180,7 +181,7 @@ func (b *FileBackend) Put(entry *Entry) error { return b.PutInternal(entry) } -func (b *FileBackend) PutInternal(entry *Entry) error { +func (b *FileBackend) PutInternal(entry *physical.Entry) error { if err := b.validatePath(entry.Key); err != nil { return err } @@ -272,12 +273,12 @@ func (b *FileBackend) validatePath(path string) error { return nil } -func (b *TransactionalFileBackend) Transaction(txns []TxnEntry) error { +func (b *TransactionalFileBackend) Transaction(txns []physical.TxnEntry) error { b.permitPool.Acquire() defer b.permitPool.Release() b.Lock() defer b.Unlock() - return genericTransactionHandler(b, txns) + return physical.GenericTransactionHandler(b, txns) } diff --git a/physical/file_test.go b/physical/file/file_test.go similarity index 90% rename from physical/file_test.go rename to physical/file/file_test.go index a405c5bbf2..6438e213ca 100644 --- a/physical/file_test.go +++ b/physical/file/file_test.go @@ -1,4 +1,4 @@ -package physical +package file import ( "encoding/json" @@ -9,6 +9,7 @@ import ( "testing" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) @@ -21,9 +22,9 @@ func TestFileBackend_Base64URLEncoding(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("file", logger, map[string]string{ + b, err := NewFileBackend(map[string]string{ "path": backendPath, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } @@ -39,7 +40,7 @@ func TestFileBackend_Base64URLEncoding(t *testing.T) { // Create a storage entry without base64 encoding the file name rawFullPath := filepath.Join(backendPath, "_foo") - e := &Entry{Key: "foo", Value: []byte("test")} + e := &physical.Entry{Key: "foo", Value: []byte("test")} f, err := os.OpenFile( rawFullPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, @@ -140,9 +141,9 @@ func TestFileBackend_ValidatePath(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("file", logger, map[string]string{ + b, err := NewFileBackend(map[string]string{ "path": dir, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } @@ -164,13 +165,13 @@ func TestFileBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("file", logger, map[string]string{ + b, err := NewFileBackend(map[string]string{ "path": dir, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } diff --git a/physical/gcs.go b/physical/gcs/gcs.go similarity index 91% rename from physical/gcs.go rename to physical/gcs/gcs.go index e4d418753d..27125b4716 100644 --- a/physical/gcs.go +++ b/physical/gcs/gcs.go @@ -1,4 +1,4 @@ -package physical +package gcs import ( "fmt" @@ -10,6 +10,7 @@ import ( "time" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "cloud.google.com/go/storage" @@ -24,15 +25,14 @@ import ( type GCSBackend struct { bucketName string client *storage.Client - permitPool *PermitPool + permitPool *physical.PermitPool logger log.Logger } -// newGCSBackend constructs a Google Cloud Storage backend using a pre-existing +// NewGCSBackend constructs a Google Cloud Storage backend using a pre-existing // bucket. Credentials can be provided to the backend, sourced // from environment variables or a service account file -func newGCSBackend(conf map[string]string, logger log.Logger) (Backend, error) { - +func NewGCSBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { bucketName := os.Getenv("GOOGLE_STORAGE_BUCKET") if bucketName == "" { @@ -81,7 +81,7 @@ func newGCSBackend(conf map[string]string, logger log.Logger) (Backend, error) { g := GCSBackend{ bucketName: bucketName, client: client, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), logger: logger, } @@ -89,7 +89,7 @@ func newGCSBackend(conf map[string]string, logger log.Logger) (Backend, error) { } // Put is used to insert or update an entry -func (g *GCSBackend) Put(entry *Entry) error { +func (g *GCSBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"gcs", "put"}, time.Now()) bucket := g.client.Bucket(g.bucketName) @@ -105,7 +105,7 @@ func (g *GCSBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (g *GCSBackend) Get(key string) (*Entry, error) { +func (g *GCSBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"gcs", "get"}, time.Now()) bucket := g.client.Bucket(g.bucketName) @@ -127,7 +127,7 @@ func (g *GCSBackend) Get(key string) (*Entry, error) { return nil, fmt.Errorf("error reading object '%v': '%v'", key, err) } - ent := Entry{ + ent := physical.Entry{ Key: key, Value: value, } diff --git a/physical/gcs_test.go b/physical/gcs/gcs_test.go similarity index 82% rename from physical/gcs_test.go rename to physical/gcs/gcs_test.go index 23c4d3aff9..9a602fc104 100644 --- a/physical/gcs_test.go +++ b/physical/gcs/gcs_test.go @@ -1,4 +1,4 @@ -package physical +package gcs import ( "fmt" @@ -11,14 +11,15 @@ import ( "cloud.google.com/go/storage" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" "golang.org/x/net/context" "google.golang.org/api/iterator" "google.golang.org/api/option" ) -var ConsistencyDelays = delays{ - beforeList: 5 * time.Second, - beforeGet: 0 * time.Second, +var ConsistencyDelays = physical.Delays{ + BeforeList: 5 * time.Second, + BeforeGet: 0 * time.Second, } func TestGCSBackend(t *testing.T) { @@ -54,7 +55,7 @@ func TestGCSBackend(t *testing.T) { defer func() { objects_it := bucket.Objects(context.Background(), nil) - time.Sleep(ConsistencyDelays.beforeList) + time.Sleep(ConsistencyDelays.BeforeList) // have to delete all objects before deleting bucket for { objAttrs, err := objects_it.Next() @@ -71,7 +72,7 @@ func TestGCSBackend(t *testing.T) { } // not a list operation, but google lists to make sure the bucket is empty on delete - time.Sleep(ConsistencyDelays.beforeList) + time.Sleep(ConsistencyDelays.BeforeList) err := bucket.Delete(context.Background()) if err != nil { t.Fatalf("error deleting bucket '%s': '%v'", bucketName, err) @@ -80,16 +81,16 @@ func TestGCSBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("gcs", logger, map[string]string{ + b, err := NewGCSBackend(map[string]string{ "bucket": bucketName, "credentials_file": credentialsFile, - }) + }, logger) if err != nil { t.Fatalf("error creating google cloud storage backend: '%s'", err) } - testEventuallyConsistentBackend(t, b, ConsistencyDelays) - testEventuallyConsistentBackend_ListPrefix(t, b, ConsistencyDelays) + physical.ExerciseEventuallyConsistentBackend(t, b, ConsistencyDelays) + physical.ExerciseEventuallyConsistentBackend_ListPrefix(t, b, ConsistencyDelays) } diff --git a/physical/cache_test.go b/physical/inmem/cache_test.go similarity index 75% rename from physical/cache_test.go rename to physical/inmem/cache_test.go index 151cf99f9a..c771f03920 100644 --- a/physical/cache_test.go +++ b/physical/inmem/cache_test.go @@ -1,32 +1,39 @@ -package physical +package inmem import ( "testing" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) func TestCache(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - inm := NewInmem(logger) - cache := NewCache(inm, 0, logger) - testBackend(t, cache) - testBackend_ListPrefix(t, cache) + inm, err := NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + cache := physical.NewCache(inm, 0, logger) + physical.ExerciseBackend(t, cache) + physical.ExerciseBackend_ListPrefix(t, cache) } func TestCache_Purge(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - inm := NewInmem(logger) - cache := NewCache(inm, 0, logger) + inm, err := NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + cache := physical.NewCache(inm, 0, logger) - ent := &Entry{ + ent := &physical.Entry{ Key: "foo", Value: []byte("bar"), } - err := cache.Put(ent) + err = cache.Put(ent) if err != nil { t.Fatalf("err: %v", err) } @@ -59,21 +66,24 @@ func TestCache_Purge(t *testing.T) { func TestCache_IgnoreCore(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - inm := NewInmem(logger) - cache := NewCache(inm, 0, logger) + inm, err := NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } - var ent *Entry - var err error + cache := physical.NewCache(inm, 0, logger) + + var ent *physical.Entry // First try normal handling - ent = &Entry{ + ent = &physical.Entry{ Key: "foo", Value: []byte("bar"), } if err := cache.Put(ent); err != nil { t.Fatal(err) } - ent = &Entry{ + ent = &physical.Entry{ Key: "foo", Value: []byte("foobar"), } @@ -89,14 +99,14 @@ func TestCache_IgnoreCore(t *testing.T) { } // Now try core path - ent = &Entry{ + ent = &physical.Entry{ Key: "core/foo", Value: []byte("bar"), } if err := cache.Put(ent); err != nil { t.Fatal(err) } - ent = &Entry{ + ent = &physical.Entry{ Key: "core/foo", Value: []byte("foobar"), } @@ -112,7 +122,7 @@ func TestCache_IgnoreCore(t *testing.T) { } // Now make sure looked-up values aren't added - ent = &Entry{ + ent = &physical.Entry{ Key: "core/zip", Value: []byte("zap"), } @@ -126,7 +136,7 @@ func TestCache_IgnoreCore(t *testing.T) { if string(ent.Value) != "zap" { t.Fatal("expected non-cached value") } - ent = &Entry{ + ent = &physical.Entry{ Key: "core/zip", Value: []byte("zipzap"), } diff --git a/physical/inmem.go b/physical/inmem/inmem.go similarity index 73% rename from physical/inmem.go rename to physical/inmem/inmem.go index 47f18ebc4a..d4f92019cc 100644 --- a/physical/inmem.go +++ b/physical/inmem/inmem.go @@ -1,9 +1,10 @@ -package physical +package inmem import ( "strings" "sync" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "github.com/armon/go-radix" @@ -15,7 +16,7 @@ import ( type InmemBackend struct { sync.RWMutex root *radix.Tree - permitPool *PermitPool + permitPool *physical.PermitPool logger log.Logger } @@ -24,30 +25,30 @@ type TransactionalInmemBackend struct { } // NewInmem constructs a new in-memory backend -func NewInmem(logger log.Logger) *InmemBackend { +func NewInmem(_ map[string]string, logger log.Logger) (physical.Backend, error) { in := &InmemBackend{ root: radix.New(), - permitPool: NewPermitPool(DefaultParallelOperations), + permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), logger: logger, } - return in + return in, nil } // Basically for now just creates a permit pool of size 1 so only one operation // can run at a time -func NewTransactionalInmem(logger log.Logger) *TransactionalInmemBackend { +func NewTransactionalInmem(_ map[string]string, logger log.Logger) (physical.Backend, error) { in := &TransactionalInmemBackend{ InmemBackend: InmemBackend{ root: radix.New(), - permitPool: NewPermitPool(1), + permitPool: physical.NewPermitPool(1), logger: logger, }, } - return in + return in, nil } // Put is used to insert or update an entry -func (i *InmemBackend) Put(entry *Entry) error { +func (i *InmemBackend) Put(entry *physical.Entry) error { i.permitPool.Acquire() defer i.permitPool.Release() @@ -57,13 +58,13 @@ func (i *InmemBackend) Put(entry *Entry) error { return i.PutInternal(entry) } -func (i *InmemBackend) PutInternal(entry *Entry) error { +func (i *InmemBackend) PutInternal(entry *physical.Entry) error { i.root.Insert(entry.Key, entry) return nil } // Get is used to fetch an entry -func (i *InmemBackend) Get(key string) (*Entry, error) { +func (i *InmemBackend) Get(key string) (*physical.Entry, error) { i.permitPool.Acquire() defer i.permitPool.Release() @@ -73,9 +74,9 @@ func (i *InmemBackend) Get(key string) (*Entry, error) { return i.GetInternal(key) } -func (i *InmemBackend) GetInternal(key string) (*Entry, error) { +func (i *InmemBackend) GetInternal(key string) (*physical.Entry, error) { if raw, ok := i.root.Get(key); ok { - return raw.(*Entry), nil + return raw.(*physical.Entry), nil } return nil, nil } @@ -131,12 +132,12 @@ func (i *InmemBackend) ListInternal(prefix string) ([]string, error) { } // Implements the transaction interface -func (t *TransactionalInmemBackend) Transaction(txns []TxnEntry) error { +func (t *TransactionalInmemBackend) Transaction(txns []physical.TxnEntry) error { t.permitPool.Acquire() defer t.permitPool.Release() t.Lock() defer t.Unlock() - return genericTransactionHandler(t, txns) + return physical.GenericTransactionHandler(t, txns) } diff --git a/physical/inmem_ha.go b/physical/inmem/inmem_ha.go similarity index 81% rename from physical/inmem_ha.go rename to physical/inmem/inmem_ha.go index bc691c59a9..d322da229e 100644 --- a/physical/inmem_ha.go +++ b/physical/inmem/inmem_ha.go @@ -1,14 +1,15 @@ -package physical +package inmem import ( "fmt" "sync" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) type InmemHABackend struct { - Backend + physical.Backend locks map[string]string l sync.Mutex cond *sync.Cond @@ -16,23 +17,31 @@ type InmemHABackend struct { } type TransactionalInmemHABackend struct { - Transactional + physical.Transactional InmemHABackend } // NewInmemHA constructs a new in-memory HA backend. This is only for testing. -func NewInmemHA(logger log.Logger) *InmemHABackend { +func NewInmemHA(_ map[string]string, logger log.Logger) (physical.Backend, error) { + be, err := NewInmem(nil, logger) + if err != nil { + return nil, err + } + in := &InmemHABackend{ - Backend: NewInmem(logger), + Backend: be, locks: make(map[string]string), logger: logger, } in.cond = sync.NewCond(&in.l) - return in + return in, nil } -func NewTransactionalInmemHA(logger log.Logger) *TransactionalInmemHABackend { - transInmem := NewTransactionalInmem(logger) +func NewTransactionalInmemHA(_ map[string]string, logger log.Logger) (physical.Backend, error) { + transInmem, err := NewTransactionalInmem(nil, logger) + if err != nil { + return nil, err + } inmemHA := InmemHABackend{ Backend: transInmem, locks: make(map[string]string), @@ -41,14 +50,14 @@ func NewTransactionalInmemHA(logger log.Logger) *TransactionalInmemHABackend { in := &TransactionalInmemHABackend{ InmemHABackend: inmemHA, - Transactional: transInmem, + Transactional: transInmem.(physical.Transactional), } in.cond = sync.NewCond(&in.l) - return in + return in, nil } // LockWith is used for mutual exclusion based on the given key. -func (i *InmemHABackend) LockWith(key, value string) (Lock, error) { +func (i *InmemHABackend) LockWith(key, value string) (physical.Lock, error) { l := &InmemLock{ in: i, key: key, diff --git a/physical/inmem/inmem_ha_test.go b/physical/inmem/inmem_ha_test.go new file mode 100644 index 0000000000..8288595945 --- /dev/null +++ b/physical/inmem/inmem_ha_test.go @@ -0,0 +1,19 @@ +package inmem + +import ( + "testing" + + "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" + log "github.com/mgutz/logxi/v1" +) + +func TestInmemHA(t *testing.T) { + logger := logformat.NewVaultLogger(log.LevelTrace) + + inm, err := NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + physical.ExerciseHABackend(t, inm.(physical.HABackend), inm.(physical.HABackend)) +} diff --git a/physical/inmem/inmem_test.go b/physical/inmem/inmem_test.go new file mode 100644 index 0000000000..998061ba92 --- /dev/null +++ b/physical/inmem/inmem_test.go @@ -0,0 +1,20 @@ +package inmem + +import ( + "testing" + + "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" + log "github.com/mgutz/logxi/v1" +) + +func TestInmem(t *testing.T) { + logger := logformat.NewVaultLogger(log.LevelTrace) + + inm, err := NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + physical.ExerciseBackend(t, inm) + physical.ExerciseBackend_ListPrefix(t, inm) +} diff --git a/physical/physical_view_test.go b/physical/inmem/physical_view_test.go similarity index 76% rename from physical/physical_view_test.go rename to physical/inmem/physical_view_test.go index dbfed6efcc..719642acaf 100644 --- a/physical/physical_view_test.go +++ b/physical/inmem/physical_view_test.go @@ -1,26 +1,30 @@ -package physical +package inmem import ( "testing" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) func TestPhysicalView_impl(t *testing.T) { - var _ Backend = new(View) + var _ physical.Backend = new(physical.View) } -func newInmemTestBackend() *InmemBackend { +func newInmemTestBackend() (physical.Backend, error) { logger := logformat.NewVaultLogger(log.LevelTrace) - return NewInmem(logger) + return NewInmem(nil, logger) } func TestPhysicalView_BadKeysKeys(t *testing.T) { - backend := newInmemTestBackend() - view := NewView(backend, "foo/") + backend, err := newInmemTestBackend() + if err != nil { + t.Fatal(err) + } + view := physical.NewView(backend, "foo/") - _, err := view.List("../") + _, err = view.List("../") if err == nil { t.Fatalf("expected error") } @@ -35,7 +39,7 @@ func TestPhysicalView_BadKeysKeys(t *testing.T) { t.Fatalf("expected error") } - le := &Entry{ + le := &physical.Entry{ Key: "../foo", Value: []byte("test"), } @@ -46,11 +50,15 @@ func TestPhysicalView_BadKeysKeys(t *testing.T) { } func TestPhysicalView(t *testing.T) { - backend := newInmemTestBackend() - view := NewView(backend, "foo/") + backend, err := newInmemTestBackend() + if err != nil { + t.Fatal(err) + } + + view := physical.NewView(backend, "foo/") // Write a key outside of foo/ - entry := &Entry{Key: "test", Value: []byte("test")} + entry := &physical.Entry{Key: "test", Value: []byte("test")} if err := backend.Put(entry); err != nil { t.Fatalf("bad: %v", err) } diff --git a/physical/transactions_test.go b/physical/inmem/transactions_test.go similarity index 63% rename from physical/transactions_test.go rename to physical/inmem/transactions_test.go index ab5d02bb7d..5565fbe35a 100644 --- a/physical/transactions_test.go +++ b/physical/inmem/transactions_test.go @@ -1,4 +1,4 @@ -package physical +package inmem import ( "fmt" @@ -8,6 +8,7 @@ import ( radix "github.com/armon/go-radix" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) @@ -16,11 +17,11 @@ type faultyPseudo struct { faultyPaths map[string]struct{} } -func (f *faultyPseudo) Get(key string) (*Entry, error) { +func (f *faultyPseudo) Get(key string) (*physical.Entry, error) { return f.underlying.Get(key) } -func (f *faultyPseudo) Put(entry *Entry) error { +func (f *faultyPseudo) Put(entry *physical.Entry) error { return f.underlying.Put(entry) } @@ -28,14 +29,14 @@ func (f *faultyPseudo) Delete(key string) error { return f.underlying.Delete(key) } -func (f *faultyPseudo) GetInternal(key string) (*Entry, error) { +func (f *faultyPseudo) GetInternal(key string) (*physical.Entry, error) { if _, ok := f.faultyPaths[key]; ok { return nil, fmt.Errorf("fault") } return f.underlying.GetInternal(key) } -func (f *faultyPseudo) PutInternal(entry *Entry) error { +func (f *faultyPseudo) PutInternal(entry *physical.Entry) error { if _, ok := f.faultyPaths[entry.Key]; ok { return fmt.Errorf("fault") } @@ -53,21 +54,21 @@ func (f *faultyPseudo) List(prefix string) ([]string, error) { return f.underlying.List(prefix) } -func (f *faultyPseudo) Transaction(txns []TxnEntry) error { +func (f *faultyPseudo) Transaction(txns []physical.TxnEntry) error { f.underlying.permitPool.Acquire() defer f.underlying.permitPool.Release() f.underlying.Lock() defer f.underlying.Unlock() - return genericTransactionHandler(f, txns) + return physical.GenericTransactionHandler(f, txns) } func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo { out := &faultyPseudo{ underlying: InmemBackend{ root: radix.New(), - permitPool: NewPermitPool(1), + permitPool: physical.NewPermitPool(1), logger: logger, }, faultyPaths: make(map[string]struct{}, len(faultyPaths)), @@ -81,22 +82,22 @@ func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo { func TestPseudo_Basic(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) p := newFaultyPseudo(logger, nil) - testBackend(t, p) - testBackend_ListPrefix(t, p) + physical.ExerciseBackend(t, p) + physical.ExerciseBackend_ListPrefix(t, p) } func TestPseudo_SuccessfulTransaction(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) p := newFaultyPseudo(logger, nil) - testTransactionalBackend(t, p) + physical.ExerciseTransactionalBackend(t, p) } func TestPseudo_FailedTransaction(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) p := newFaultyPseudo(logger, []string{"zip"}) - txns := setupTransactions(t, p) + txns := physical.SetupTestingTransactions(t, p) if err := p.Transaction(txns); err == nil { t.Fatal("expected error during transaction") } @@ -142,67 +143,3 @@ func TestPseudo_FailedTransaction(t *testing.T) { t.Fatal("values did not rollback correctly") } } - -func setupTransactions(t *testing.T, b Backend) []TxnEntry { - // Add a few keys so that we test rollback with deletion - if err := b.Put(&Entry{ - Key: "foo", - Value: []byte("bar"), - }); err != nil { - t.Fatal(err) - } - if err := b.Put(&Entry{ - Key: "zip", - Value: []byte("zap"), - }); err != nil { - t.Fatal(err) - } - if err := b.Put(&Entry{ - Key: "deleteme", - }); err != nil { - t.Fatal(err) - } - if err := b.Put(&Entry{ - Key: "deleteme2", - }); err != nil { - t.Fatal(err) - } - - txns := []TxnEntry{ - TxnEntry{ - Operation: PutOperation, - Entry: &Entry{ - Key: "foo", - Value: []byte("bar2"), - }, - }, - TxnEntry{ - Operation: DeleteOperation, - Entry: &Entry{ - Key: "deleteme", - }, - }, - TxnEntry{ - Operation: PutOperation, - Entry: &Entry{ - Key: "foo", - Value: []byte("bar3"), - }, - }, - TxnEntry{ - Operation: DeleteOperation, - Entry: &Entry{ - Key: "deleteme2", - }, - }, - TxnEntry{ - Operation: PutOperation, - Entry: &Entry{ - Key: "zip", - Value: []byte("zap3"), - }, - }, - } - - return txns -} diff --git a/physical/inmem_ha_test.go b/physical/inmem_ha_test.go deleted file mode 100644 index 102f85b027..0000000000 --- a/physical/inmem_ha_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package physical - -import ( - "testing" - - "github.com/hashicorp/vault/helper/logformat" - log "github.com/mgutz/logxi/v1" -) - -func TestInmemHA(t *testing.T) { - logger := logformat.NewVaultLogger(log.LevelTrace) - - inm := NewInmemHA(logger) - testHABackend(t, inm, inm) -} diff --git a/physical/inmem_test.go b/physical/inmem_test.go deleted file mode 100644 index 7c3c788822..0000000000 --- a/physical/inmem_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package physical - -import ( - "testing" - - "github.com/hashicorp/vault/helper/logformat" - log "github.com/mgutz/logxi/v1" -) - -func TestInmem(t *testing.T) { - logger := logformat.NewVaultLogger(log.LevelTrace) - - inm := NewInmem(logger) - testBackend(t, inm) - testBackend_ListPrefix(t, inm) -} diff --git a/physical/mssql.go b/physical/mssql/mssql.go similarity index 89% rename from physical/mssql.go rename to physical/mssql/mssql.go index c9d50e1ddd..16228d624e 100644 --- a/physical/mssql.go +++ b/physical/mssql/mssql.go @@ -1,4 +1,4 @@ -package physical +package mssql import ( "database/sql" @@ -12,18 +12,19 @@ import ( _ "github.com/denisenkom/go-mssqldb" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" ) -type MsSQLBackend struct { +type MSSQLBackend struct { dbTable string client *sql.DB statements map[string]*sql.Stmt logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -func newMsSQLBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { username, ok := conf["username"] if !ok { username = "" @@ -51,7 +52,7 @@ func newMsSQLBackend(conf map[string]string, logger log.Logger) (Backend, error) logger.Debug("mysql: max_parallel set", "max_parallel", maxParInt) } } else { - maxParInt = DefaultParallelOperations + maxParInt = physical.DefaultParallelOperations } database, ok := conf["database"] @@ -131,12 +132,12 @@ func newMsSQLBackend(conf map[string]string, logger log.Logger) (Backend, error) return nil, fmt.Errorf("failed to create mssql table: %v", err) } - m := &MsSQLBackend{ + m := &MSSQLBackend{ dbTable: dbTable, client: db, statements: make(map[string]*sql.Stmt), logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } statements := map[string]string{ @@ -156,7 +157,7 @@ func newMsSQLBackend(conf map[string]string, logger log.Logger) (Backend, error) return m, nil } -func (m *MsSQLBackend) prepare(name, query string) error { +func (m *MSSQLBackend) prepare(name, query string) error { stmt, err := m.client.Prepare(query) if err != nil { return fmt.Errorf("failed to prepare '%s': %v", name, err) @@ -167,7 +168,7 @@ func (m *MsSQLBackend) prepare(name, query string) error { return nil } -func (m *MsSQLBackend) Put(entry *Entry) error { +func (m *MSSQLBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"mssql", "put"}, time.Now()) m.permitPool.Acquire() @@ -181,7 +182,7 @@ func (m *MsSQLBackend) Put(entry *Entry) error { return nil } -func (m *MsSQLBackend) Get(key string) (*Entry, error) { +func (m *MSSQLBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"mssql", "get"}, time.Now()) m.permitPool.Acquire() @@ -197,7 +198,7 @@ func (m *MsSQLBackend) Get(key string) (*Entry, error) { return nil, err } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: result, } @@ -205,7 +206,7 @@ func (m *MsSQLBackend) Get(key string) (*Entry, error) { return ent, nil } -func (m *MsSQLBackend) Delete(key string) error { +func (m *MSSQLBackend) Delete(key string) error { defer metrics.MeasureSince([]string{"mssql", "delete"}, time.Now()) m.permitPool.Acquire() @@ -219,7 +220,7 @@ func (m *MsSQLBackend) Delete(key string) error { return nil } -func (m *MsSQLBackend) List(prefix string) ([]string, error) { +func (m *MSSQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mssql", "list"}, time.Now()) m.permitPool.Acquire() diff --git a/physical/mssql_test.go b/physical/mssql/mssql_test.go similarity index 77% rename from physical/mssql_test.go rename to physical/mssql/mssql_test.go index 11f4684ea4..7e1446e94d 100644 --- a/physical/mssql_test.go +++ b/physical/mssql/mssql_test.go @@ -1,16 +1,17 @@ -package physical +package mssql import ( "os" "testing" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" _ "github.com/denisenkom/go-mssqldb" ) -func TestMsSQLBackend(t *testing.T) { +func TestMSSQLBackend(t *testing.T) { server := os.Getenv("MSSQL_SERVER") if server == "" { t.SkipNow() @@ -32,27 +33,26 @@ func TestMsSQLBackend(t *testing.T) { // Run vault tests logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("mssql", logger, map[string]string{ + b, err := NewMSSQLBackend(map[string]string{ "server": server, "database": database, "table": table, "username": username, "password": password, - }) + }, logger) if err != nil { t.Fatalf("Failed to create new backend: %v", err) } defer func() { - mssql := b.(*MsSQLBackend) + mssql := b.(*MSSQLBackend) _, err := mssql.client.Exec("DROP TABLE " + mssql.dbTable) if err != nil { t.Fatalf("Failed to drop table: %v", err) } }() - testBackend(t, b) - testBackend_ListPrefix(t, b) - + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } diff --git a/physical/mysql.go b/physical/mysql/mysql.go similarity index 92% rename from physical/mysql.go rename to physical/mysql/mysql.go index affc4d1783..87daa9a461 100644 --- a/physical/mysql.go +++ b/physical/mysql/mysql.go @@ -1,4 +1,4 @@ -package physical +package mysql import ( "crypto/tls" @@ -18,6 +18,7 @@ import ( mysql "github.com/go-sql-driver/mysql" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" ) // Unreserved tls key @@ -31,12 +32,12 @@ type MySQLBackend struct { client *sql.DB statements map[string]*sql.Stmt logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -// newMySQLBackend constructs a MySQL backend using the given API client and +// NewMySQLBackend constructs a MySQL backend using the given API client and // server address and credential for accessing mysql database. -func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { var err error // Get the MySQL credentials to perform read/write operations. @@ -77,7 +78,7 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) logger.Debug("mysql: max_parallel set", "max_parallel", maxParInt) } } else { - maxParInt = DefaultParallelOperations + maxParInt = physical.DefaultParallelOperations } dsnParams := url.Values{} @@ -117,7 +118,7 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) client: db, statements: make(map[string]*sql.Stmt), logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } // Prepare all the statements required @@ -148,7 +149,7 @@ func (m *MySQLBackend) prepare(name, query string) error { } // Put is used to insert or update an entry. -func (m *MySQLBackend) Put(entry *Entry) error { +func (m *MySQLBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) m.permitPool.Acquire() @@ -162,7 +163,7 @@ func (m *MySQLBackend) Put(entry *Entry) error { } // Get is used to fetch and entry. -func (m *MySQLBackend) Get(key string) (*Entry, error) { +func (m *MySQLBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) m.permitPool.Acquire() @@ -177,7 +178,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { return nil, err } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: result, } diff --git a/physical/mysql_test.go b/physical/mysql/mysql_test.go similarity index 83% rename from physical/mysql_test.go rename to physical/mysql/mysql_test.go index 1eabd9f18c..ecf8431416 100644 --- a/physical/mysql_test.go +++ b/physical/mysql/mysql_test.go @@ -1,10 +1,11 @@ -package physical +package mysql import ( "os" "testing" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" _ "github.com/go-sql-driver/mysql" @@ -32,13 +33,13 @@ func TestMySQLBackend(t *testing.T) { // Run vault tests logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("mysql", logger, map[string]string{ + b, err := NewMySQLBackend(map[string]string{ "address": address, "database": database, "table": table, "username": username, "password": password, - }) + }, logger) if err != nil { t.Fatalf("Failed to create new backend: %v", err) @@ -52,7 +53,6 @@ func TestMySQLBackend(t *testing.T) { } }() - testBackend(t, b) - testBackend_ListPrefix(t, b) - + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } diff --git a/physical/physical.go b/physical/physical.go index 16a7e6645a..088a86b995 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -1,7 +1,7 @@ package physical import ( - "fmt" + "strings" "sync" log "github.com/mgutz/logxi/v1" @@ -70,8 +70,8 @@ type RedirectDetect interface { } // Callback signatures for RunServiceDiscovery -type activeFunction func() bool -type sealedFunction func() bool +type ActiveFunction func() bool +type SealedFunction func() bool // ServiceDiscovery is an optional interface that an HABackend can implement. // If they do, the state of a backend is advertised to the service discovery @@ -89,7 +89,7 @@ type ServiceDiscovery interface { // Run executes any background service discovery tasks until the // shutdown channel is closed. - RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error + RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc ActiveFunction, sealedFunc SealedFunction) error } type Lock interface { @@ -115,50 +115,6 @@ type Entry struct { // Factory is the factory function to create a physical backend. type Factory func(config map[string]string, logger log.Logger) (Backend, error) -// NewBackend returns a new backend with the given type and configuration. -// The backend is looked up in the builtinBackends variable. -func NewBackend(t string, logger log.Logger, conf map[string]string) (Backend, error) { - f, ok := builtinBackends[t] - if !ok { - return nil, fmt.Errorf("unknown physical backend type: %s", t) - } - return f(conf, logger) -} - -// BuiltinBackends is the list of built-in physical backends that can -// be used with NewBackend. -var builtinBackends = map[string]Factory{ - "inmem": func(_ map[string]string, logger log.Logger) (Backend, error) { - return NewInmem(logger), nil - }, - "inmem_transactional": func(_ map[string]string, logger log.Logger) (Backend, error) { - return NewTransactionalInmem(logger), nil - }, - "inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) { - return NewInmemHA(logger), nil - }, - "inmem_transactional_ha": func(_ map[string]string, logger log.Logger) (Backend, error) { - return NewTransactionalInmemHA(logger), nil - }, - "file_transactional": newTransactionalFileBackend, - "consul": newConsulBackend, - "zookeeper": newZookeeperBackend, - "file": newFileBackend, - "s3": newS3Backend, - "azure": newAzureBackend, - "dynamodb": newDynamoDBBackend, - "etcd": newEtcdBackend, - "mssql": newMsSQLBackend, - "mysql": newMySQLBackend, - "postgresql": newPostgreSQLBackend, - "cockroachdb": newCockroachDBBackend, - "couchdb": newCouchDBBackend, - "couchdb_transactional": newTransactionalCouchDBBackend, - "swift": newSwiftBackend, - "gcs": newGCSBackend, - "cassandra": newCassandraBackend, -} - // PermitPool is used to limit maximum outstanding requests type PermitPool struct { sem chan int @@ -184,3 +140,15 @@ func (c *PermitPool) Acquire() { func (c *PermitPool) Release() { <-c.sem } + +// Prefixes is a shared helper function returns all parent 'folders' for a +// given vault key. +// e.g. for 'foo/bar/baz', it returns ['foo', 'foo/bar'] +func Prefixes(s string) []string { + components := strings.Split(s, "/") + result := []string{} + for i := 1; i < len(components); i++ { + result = append(result, strings.Join(components[:i], "/")) + } + return result +} diff --git a/physical/postgresql.go b/physical/postgresql/postgresql.go similarity index 90% rename from physical/postgresql.go rename to physical/postgresql/postgresql.go index 631be10743..cb35782df8 100644 --- a/physical/postgresql.go +++ b/physical/postgresql/postgresql.go @@ -1,4 +1,4 @@ -package physical +package postgresql import ( "database/sql" @@ -8,6 +8,7 @@ import ( "time" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "github.com/armon/go-metrics" @@ -24,12 +25,12 @@ type PostgreSQLBackend struct { delete_query string list_query string logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -// newPostgreSQLBackend constructs a PostgreSQL backend using the given +// NewPostgreSQLBackend constructs a PostgreSQL backend using the given // API client, server address, credentials, and database. -func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, error) { +func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the PostgreSQL credentials to perform read/write operations. connURL, ok := conf["connection_url"] if !ok || connURL == "" { @@ -54,7 +55,7 @@ func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, e logger.Debug("postgres: max_parallel set", "max_parallel", maxParInt) } } else { - maxParInt = DefaultParallelOperations + maxParInt = physical.DefaultParallelOperations } // Create PostgreSQL handle for the database. @@ -93,7 +94,7 @@ func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, e "UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table + " WHERE parent_path LIKE $1 || '%'", logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } return m, nil @@ -124,7 +125,7 @@ func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { } // Put is used to insert or update an entry. -func (m *PostgreSQLBackend) Put(entry *Entry) error { +func (m *PostgreSQLBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now()) m.permitPool.Acquire() @@ -140,7 +141,7 @@ func (m *PostgreSQLBackend) Put(entry *Entry) error { } // Get is used to fetch and entry. -func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) { +func (m *PostgreSQLBackend) Get(fullPath string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now()) m.permitPool.Acquire() @@ -157,7 +158,7 @@ func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) { return nil, err } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: result, } diff --git a/physical/postgresql_test.go b/physical/postgresql/postgresql_test.go similarity index 78% rename from physical/postgresql_test.go rename to physical/postgresql/postgresql_test.go index 5cdaaa02de..940d0e253a 100644 --- a/physical/postgresql_test.go +++ b/physical/postgresql/postgresql_test.go @@ -1,10 +1,11 @@ -package physical +package postgresql import ( "os" "testing" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" _ "github.com/lib/pq" @@ -24,11 +25,10 @@ func TestPostgreSQLBackend(t *testing.T) { // Run vault tests logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("postgresql", logger, map[string]string{ + b, err := NewPostgreSQLBackend(map[string]string{ "connection_url": connURL, "table": table, - }) - + }, logger) if err != nil { t.Fatalf("Failed to create new backend: %v", err) } @@ -41,7 +41,6 @@ func TestPostgreSQLBackend(t *testing.T) { } }() - testBackend(t, b) - testBackend_ListPrefix(t, b) - + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } diff --git a/physical/s3.go b/physical/s3/s3.go similarity index 92% rename from physical/s3.go rename to physical/s3/s3.go index 13df06bc4f..7118e7da14 100644 --- a/physical/s3.go +++ b/physical/s3/s3.go @@ -1,4 +1,4 @@ -package physical +package s3 import ( "bytes" @@ -22,6 +22,7 @@ import ( cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/awsutil" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/physical" ) // S3Backend is a physical backend that stores data @@ -30,14 +31,13 @@ type S3Backend struct { bucket string client *s3.S3 logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -// newS3Backend constructs a S3 backend using a pre-existing +// NewS3Backend constructs a S3 backend using a pre-existing // bucket. Credentials can be provided to the backend, sourced // from the environment, AWS credential files or by IAM role. -func newS3Backend(conf map[string]string, logger log.Logger) (Backend, error) { - +func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) { bucket := os.Getenv("AWS_S3_BUCKET") if bucket == "" { bucket = conf["bucket"] @@ -116,13 +116,13 @@ func newS3Backend(conf map[string]string, logger log.Logger) (Backend, error) { client: s3conn, bucket: bucket, logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } return s, nil } // Put is used to insert or update an entry -func (s *S3Backend) Put(entry *Entry) error { +func (s *S3Backend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"s3", "put"}, time.Now()) s.permitPool.Acquire() @@ -142,7 +142,7 @@ func (s *S3Backend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (s *S3Backend) Get(key string) (*Entry, error) { +func (s *S3Backend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"s3", "get"}, time.Now()) s.permitPool.Acquire() @@ -172,7 +172,7 @@ func (s *S3Backend) Get(key string) (*Entry, error) { return nil, err } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: data, } diff --git a/physical/s3_test.go b/physical/s3/s3_test.go similarity index 91% rename from physical/s3_test.go rename to physical/s3/s3_test.go index 7191cfec44..dbe4c93339 100644 --- a/physical/s3_test.go +++ b/physical/s3/s3_test.go @@ -1,4 +1,4 @@ -package physical +package s3 import ( "fmt" @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/helper/awsutil" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "github.com/aws/aws-sdk-go/aws" @@ -81,14 +82,13 @@ func TestS3Backend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) // This uses the same logic to find the AWS credentials as we did at the beginning of the test - b, err := NewBackend("s3", logger, map[string]string{ + b, err := NewS3Backend(map[string]string{ "bucket": bucket, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) - + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } diff --git a/physical/swift.go b/physical/swift/swift.go similarity index 91% rename from physical/swift.go rename to physical/swift/swift.go index cff664e6fd..30d7e66e44 100644 --- a/physical/swift.go +++ b/physical/swift/swift.go @@ -1,4 +1,4 @@ -package physical +package swift import ( "fmt" @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/physical" "github.com/ncw/swift" ) @@ -23,14 +24,13 @@ type SwiftBackend struct { container string client *swift.Connection logger log.Logger - permitPool *PermitPool + permitPool *physical.PermitPool } -// newSwiftBackend constructs a Swift backend using a pre-existing +// NewSwiftBackend constructs a Swift backend using a pre-existing // container. Credentials can be provided to the backend, sourced // from the environment. -func newSwiftBackend(conf map[string]string, logger log.Logger) (Backend, error) { - +func NewSwiftBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { var ok bool username := os.Getenv("OS_USERNAME") @@ -117,13 +117,13 @@ func newSwiftBackend(conf map[string]string, logger log.Logger) (Backend, error) client: &c, container: container, logger: logger, - permitPool: NewPermitPool(maxParInt), + permitPool: physical.NewPermitPool(maxParInt), } return s, nil } // Put is used to insert or update an entry -func (s *SwiftBackend) Put(entry *Entry) error { +func (s *SwiftBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"swift", "put"}, time.Now()) s.permitPool.Acquire() @@ -139,7 +139,7 @@ func (s *SwiftBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (s *SwiftBackend) Get(key string) (*Entry, error) { +func (s *SwiftBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"swift", "get"}, time.Now()) s.permitPool.Acquire() @@ -162,7 +162,7 @@ func (s *SwiftBackend) Get(key string) (*Entry, error) { if err != nil { return nil, err } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: data, } diff --git a/physical/swift_test.go b/physical/swift/swift_test.go similarity index 90% rename from physical/swift_test.go rename to physical/swift/swift_test.go index 2da37f043e..5aa2ec9581 100644 --- a/physical/swift_test.go +++ b/physical/swift/swift_test.go @@ -1,4 +1,4 @@ -package physical +package swift import ( "fmt" @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" "github.com/ncw/swift" ) @@ -66,7 +67,7 @@ func TestSwiftBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("swift", logger, map[string]string{ + b, err := NewSwiftBackend(map[string]string{ "username": username, "password": password, "container": container, @@ -74,12 +75,11 @@ func TestSwiftBackend(t *testing.T) { "project": project, "domain": domain, "project-domain": projectDomain, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) - + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } diff --git a/physical/physical_test.go b/physical/testing.go similarity index 86% rename from physical/physical_test.go rename to physical/testing.go index 6921a89d49..1b434a2a31 100644 --- a/physical/physical_test.go +++ b/physical/testing.go @@ -5,30 +5,9 @@ import ( "sort" "testing" "time" - - "github.com/hashicorp/vault/helper/logformat" - log "github.com/mgutz/logxi/v1" ) -func testNewBackend(t *testing.T) { - logger := logformat.NewVaultLogger(log.LevelTrace) - - _, err := NewBackend("foobar", logger, nil) - if err == nil { - t.Fatalf("expected error") - } - - b, err := NewBackend("inmem", logger, nil) - if err != nil { - t.Fatalf("err: %v", err) - } - - if b == nil { - t.Fatalf("expected backend") - } -} - -func testBackend(t *testing.T, b Backend) { +func ExerciseBackend(t *testing.T, b Backend) { // Should be empty keys, err := b.List("") if err != nil { @@ -216,7 +195,7 @@ func testBackend(t *testing.T, b Backend) { } } -func testBackend_ListPrefix(t *testing.T, b Backend) { +func ExerciseBackend_ListPrefix(t *testing.T, b Backend) { e1 := &Entry{Key: "foo", Value: []byte("test")} e2 := &Entry{Key: "foo/bar", Value: []byte("test")} e3 := &Entry{Key: "foo/bar/baz", Value: []byte("test")} @@ -286,7 +265,7 @@ func testBackend_ListPrefix(t *testing.T, b Backend) { } } -func testHABackend(t *testing.T, b HABackend, b2 HABackend) { +func ExerciseHABackend(t *testing.T, b HABackend, b2 HABackend) { // Get the lock lock, err := b.LockWith("foo", "bar") if err != nil { @@ -362,13 +341,12 @@ func testHABackend(t *testing.T, b HABackend, b2 HABackend) { lock2.Unlock() } -type delays struct { - beforeGet time.Duration - beforeList time.Duration +type Delays struct { + BeforeGet time.Duration + BeforeList time.Duration } -func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { - +func ExerciseEventuallyConsistentBackend(t *testing.T, b Backend, d Delays) { // no delay required: nothing written to bucket // Should be empty keys, err := b.List("") @@ -403,7 +381,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { } // Get should work - time.Sleep(d.beforeGet) + time.Sleep(d.BeforeGet) out, err = b.Get("foo") if err != nil { t.Fatalf("err: %v", err) @@ -413,7 +391,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { } // List should not be empty - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("") if err != nil { t.Fatalf("err: %v", err) @@ -432,7 +410,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { } // Should be empty - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("") if err != nil { t.Fatalf("err: %v", err) @@ -442,7 +420,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { } // Get should fail - time.Sleep(d.beforeGet) + time.Sleep(d.BeforeGet) out, err = b.Get("foo") if err != nil { t.Fatalf("err: %v", err) @@ -470,7 +448,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { t.Fatalf("err: %v", err) } - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("") if err != nil { t.Fatalf("err: %v", err) @@ -490,7 +468,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { } // Get should return the child - time.Sleep(d.beforeGet) + time.Sleep(d.BeforeGet) out, err = b.Get("foo/bar") if err != nil { t.Fatalf("err: %v", err) @@ -511,7 +489,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { t.Fatalf("failed to remove nested secret: %v", err) } - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("foo/") if err != nil { t.Fatalf("err: %v", err) @@ -539,7 +517,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { t.Fatalf("err: %v", err) } - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("") if err != nil { t.Fatalf("err: %v", err) @@ -557,7 +535,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { t.Fatalf("err: %v", err) } - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("") if err != nil { t.Fatalf("err: %v", err) @@ -567,7 +545,7 @@ func testEventuallyConsistentBackend(t *testing.T, b Backend, d delays) { } } -func testEventuallyConsistentBackend_ListPrefix(t *testing.T, b Backend, d delays) { +func ExerciseEventuallyConsistentBackend_ListPrefix(t *testing.T, b Backend, d Delays) { e1 := &Entry{Key: "foo", Value: []byte("test")} e2 := &Entry{Key: "foo/bar", Value: []byte("test")} e3 := &Entry{Key: "foo/bar/baz", Value: []byte("test")} @@ -586,7 +564,7 @@ func testEventuallyConsistentBackend_ListPrefix(t *testing.T, b Backend, d delay } // Scan the root - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err := b.List("") if err != nil { t.Fatalf("err: %v", err) @@ -603,7 +581,7 @@ func testEventuallyConsistentBackend_ListPrefix(t *testing.T, b Backend, d delay } // Scan foo/ - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("foo/") if err != nil { t.Fatalf("err: %v", err) @@ -620,7 +598,7 @@ func testEventuallyConsistentBackend_ListPrefix(t *testing.T, b Backend, d delay } // Scan foo/bar/ - time.Sleep(d.beforeList) + time.Sleep(d.BeforeList) keys, err = b.List("foo/bar/") if err != nil { t.Fatalf("err: %v", err) @@ -635,13 +613,13 @@ func testEventuallyConsistentBackend_ListPrefix(t *testing.T, b Backend, d delay } -func testTransactionalBackend(t *testing.T, b Backend) { +func ExerciseTransactionalBackend(t *testing.T, b Backend) { tb, ok := b.(Transactional) if !ok { t.Fatal("Not a transactional backend") } - txns := setupTransactions(t, b) + txns := SetupTestingTransactions(t, b) if err := tb.Transaction(txns); err != nil { t.Fatal(err) @@ -688,3 +666,67 @@ func testTransactionalBackend(t *testing.T, b Backend) { t.Fatal("updates did not apply correctly") } } + +func SetupTestingTransactions(t *testing.T, b Backend) []TxnEntry { + // Add a few keys so that we test rollback with deletion + if err := b.Put(&Entry{ + Key: "foo", + Value: []byte("bar"), + }); err != nil { + t.Fatal(err) + } + if err := b.Put(&Entry{ + Key: "zip", + Value: []byte("zap"), + }); err != nil { + t.Fatal(err) + } + if err := b.Put(&Entry{ + Key: "deleteme", + }); err != nil { + t.Fatal(err) + } + if err := b.Put(&Entry{ + Key: "deleteme2", + }); err != nil { + t.Fatal(err) + } + + txns := []TxnEntry{ + TxnEntry{ + Operation: PutOperation, + Entry: &Entry{ + Key: "foo", + Value: []byte("bar2"), + }, + }, + TxnEntry{ + Operation: DeleteOperation, + Entry: &Entry{ + Key: "deleteme", + }, + }, + TxnEntry{ + Operation: PutOperation, + Entry: &Entry{ + Key: "foo", + Value: []byte("bar3"), + }, + }, + TxnEntry{ + Operation: DeleteOperation, + Entry: &Entry{ + Key: "deleteme2", + }, + }, + TxnEntry{ + Operation: PutOperation, + Entry: &Entry{ + Key: "zip", + Value: []byte("zap3"), + }, + }, + } + + return txns +} diff --git a/physical/transactions.go b/physical/transactions.go index b9ddffa902..f8668d2be4 100644 --- a/physical/transactions.go +++ b/physical/transactions.go @@ -27,7 +27,7 @@ type PseudoTransactional interface { } // Implements the transaction interface -func genericTransactionHandler(t PseudoTransactional, txns []TxnEntry) (retErr error) { +func GenericTransactionHandler(t PseudoTransactional, txns []TxnEntry) (retErr error) { rollbackStack := make([]TxnEntry, 0, len(txns)) var dirty bool diff --git a/physical/zookeeper.go b/physical/zookeeper/zookeeper.go similarity index 87% rename from physical/zookeeper.go rename to physical/zookeeper/zookeeper.go index b772cf8878..8ecc0d6eac 100644 --- a/physical/zookeeper.go +++ b/physical/zookeeper/zookeeper.go @@ -1,4 +1,4 @@ -package physical +package zookeeper import ( "fmt" @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" metrics "github.com/armon/go-metrics" @@ -22,20 +23,20 @@ const ( ZKNodeFilePrefix = "_" ) -// ZookeeperBackend is a physical backend that stores data at specific -// prefix within Zookeeper. It is used in production situations as +// ZooKeeperBackend is a physical backend that stores data at specific +// prefix within ZooKeeper. It is used in production situations as // it allows Vault to run on multiple machines in a highly-available manner. -type ZookeeperBackend struct { +type ZooKeeperBackend struct { path string client *zk.Conn acl []zk.ACL logger log.Logger } -// newZookeeperBackend constructs a Zookeeper backend using the given API client +// NewZooKeeperBackend constructs a ZooKeeper backend using the given API client // and the prefix in the KV store. -func newZookeeperBackend(conf map[string]string, logger log.Logger) (Backend, error) { - // Get the path in Zookeeper +func NewZooKeeperBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { + // Get the path in ZooKeeper path, ok := conf["path"] if !ok { path = "vault/" @@ -114,12 +115,12 @@ func newZookeeperBackend(conf map[string]string, logger log.Logger) (Backend, er if useAddAuth { err = client.AddAuth(schema, []byte(owner)) if err != nil { - return nil, fmt.Errorf("Zookeeper rejected authentication information provided at auth_info: %v", err) + return nil, fmt.Errorf("ZooKeeper rejected authentication information provided at auth_info: %v", err) } } // Setup the backend - c := &ZookeeperBackend{ + c := &ZooKeeperBackend{ path: path, client: client, acl: acl, @@ -131,7 +132,7 @@ func newZookeeperBackend(conf map[string]string, logger log.Logger) (Backend, er // ensurePath is used to create each node in the path hierarchy. // We avoid calling this optimistically, and invoke it when we get // an error during an operation -func (c *ZookeeperBackend) ensurePath(path string, value []byte) error { +func (c *ZooKeeperBackend) ensurePath(path string, value []byte) error { nodes := strings.Split(path, "/") fullPath := "" for index, node := range nodes { @@ -161,7 +162,7 @@ func (c *ZookeeperBackend) ensurePath(path string, value []byte) error { // cleanupLogicalPath is used to remove all empty nodes, begining with deepest one, // aborting on first non-empty one, up to top-level node. -func (c *ZookeeperBackend) cleanupLogicalPath(path string) error { +func (c *ZooKeeperBackend) cleanupLogicalPath(path string) error { nodes := strings.Split(path, "/") for i := len(nodes) - 1; i > 0; i-- { fullPath := c.path + strings.Join(nodes[:i], "/") @@ -192,12 +193,12 @@ func (c *ZookeeperBackend) cleanupLogicalPath(path string) error { } // nodePath returns an zk path based on the given key. -func (c *ZookeeperBackend) nodePath(key string) string { +func (c *ZooKeeperBackend) nodePath(key string) string { return filepath.Join(c.path, filepath.Dir(key), ZKNodeFilePrefix+filepath.Base(key)) } // Put is used to insert or update an entry -func (c *ZookeeperBackend) Put(entry *Entry) error { +func (c *ZooKeeperBackend) Put(entry *physical.Entry) error { defer metrics.MeasureSince([]string{"zookeeper", "put"}, time.Now()) // Attempt to set the full path @@ -212,7 +213,7 @@ func (c *ZookeeperBackend) Put(entry *Entry) error { } // Get is used to fetch an entry -func (c *ZookeeperBackend) Get(key string) (*Entry, error) { +func (c *ZooKeeperBackend) Get(key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"zookeeper", "get"}, time.Now()) // Attempt to read the full path @@ -231,7 +232,7 @@ func (c *ZookeeperBackend) Get(key string) (*Entry, error) { if value == nil { return nil, nil } - ent := &Entry{ + ent := &physical.Entry{ Key: key, Value: value, } @@ -239,7 +240,7 @@ func (c *ZookeeperBackend) Get(key string) (*Entry, error) { } // Delete is used to permanently delete an entry -func (c *ZookeeperBackend) Delete(key string) error { +func (c *ZooKeeperBackend) Delete(key string) error { defer metrics.MeasureSince([]string{"zookeeper", "delete"}, time.Now()) if key == "" { @@ -262,7 +263,7 @@ func (c *ZookeeperBackend) Delete(key string) error { // List is used ot list all the keys under a given // prefix, up to the next prefix. -func (c *ZookeeperBackend) List(prefix string) ([]string, error) { +func (c *ZooKeeperBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"zookeeper", "list"}, time.Now()) // Query the children at the full path @@ -310,8 +311,8 @@ func (c *ZookeeperBackend) List(prefix string) ([]string, error) { } // LockWith is used for mutual exclusion based on the given key. -func (c *ZookeeperBackend) LockWith(key, value string) (Lock, error) { - l := &ZookeeperHALock{ +func (c *ZooKeeperBackend) LockWith(key, value string) (physical.Lock, error) { + l := &ZooKeeperHALock{ in: c, key: key, value: value, @@ -321,13 +322,13 @@ func (c *ZookeeperBackend) LockWith(key, value string) (Lock, error) { // HAEnabled indicates whether the HA functionality should be exposed. // Currently always returns true. -func (c *ZookeeperBackend) HAEnabled() bool { +func (c *ZooKeeperBackend) HAEnabled() bool { return true } -// ZookeeperHALock is a Zookeeper Lock implementation for the HABackend -type ZookeeperHALock struct { - in *ZookeeperBackend +// ZooKeeperHALock is a ZooKeeper Lock implementation for the HABackend +type ZooKeeperHALock struct { + in *ZooKeeperBackend key string value string @@ -337,7 +338,7 @@ type ZookeeperHALock struct { zkLock *zk.Lock } -func (i *ZookeeperHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { +func (i *ZooKeeperHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { i.localLock.Lock() defer i.localLock.Unlock() if i.held { @@ -379,7 +380,7 @@ func (i *ZookeeperHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) return i.leaderCh, nil } -func (i *ZookeeperHALock) attemptLock(lockpath string, didLock chan struct{}, failLock chan error, releaseCh chan bool) { +func (i *ZooKeeperHALock) attemptLock(lockpath string, didLock chan struct{}, failLock chan error, releaseCh chan bool) { // Wait to acquire the lock in ZK lock := zk.NewLock(i.in.client, lockpath, i.in.acl) err := lock.Lock() @@ -407,7 +408,7 @@ func (i *ZookeeperHALock) attemptLock(lockpath string, didLock chan struct{}, fa } } -func (i *ZookeeperHALock) monitorLock(lockeventCh <-chan zk.Event, leaderCh chan struct{}) { +func (i *ZooKeeperHALock) monitorLock(lockeventCh <-chan zk.Event, leaderCh chan struct{}) { for { select { case event := <-lockeventCh: @@ -432,7 +433,7 @@ func (i *ZookeeperHALock) monitorLock(lockeventCh <-chan zk.Event, leaderCh chan } } -func (i *ZookeeperHALock) Unlock() error { +func (i *ZooKeeperHALock) Unlock() error { i.localLock.Lock() defer i.localLock.Unlock() if !i.held { @@ -444,7 +445,7 @@ func (i *ZookeeperHALock) Unlock() error { return nil } -func (i *ZookeeperHALock) Value() (bool, string, error) { +func (i *ZooKeeperHALock) Value() (bool, string, error) { lockpath := i.in.nodePath(i.key) value, _, err := i.in.client.Get(lockpath) return (value != nil), string(value), err diff --git a/physical/zookeeper_test.go b/physical/zookeeper/zookeeper_test.go similarity index 80% rename from physical/zookeeper_test.go rename to physical/zookeeper/zookeeper_test.go index b9969aed27..a85c27ccd8 100644 --- a/physical/zookeeper_test.go +++ b/physical/zookeeper/zookeeper_test.go @@ -1,4 +1,4 @@ -package physical +package zookeeper import ( "fmt" @@ -7,12 +7,13 @@ import ( "time" "github.com/hashicorp/vault/helper/logformat" + "github.com/hashicorp/vault/physical" log "github.com/mgutz/logxi/v1" "github.com/samuel/go-zookeeper/zk" ) -func TestZookeeperBackend(t *testing.T) { +func TestZooKeeperBackend(t *testing.T) { addr := os.Getenv("ZOOKEEPER_ADDR") if addr == "" { t.SkipNow() @@ -45,19 +46,19 @@ func TestZookeeperBackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("zookeeper", logger, map[string]string{ + b, err := NewZooKeeperBackend(map[string]string{ "address": addr + "," + addr, "path": randPath, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - testBackend(t, b) - testBackend_ListPrefix(t, b) + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) } -func TestZookeeperHABackend(t *testing.T) { +func TestZooKeeperHABackend(t *testing.T) { addr := os.Getenv("ZOOKEEPER_ADDR") if addr == "" { t.SkipNow() @@ -85,17 +86,17 @@ func TestZookeeperHABackend(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) - b, err := NewBackend("zookeeper", logger, map[string]string{ + b, err := NewZooKeeperBackend(map[string]string{ "address": addr + "," + addr, "path": randPath, - }) + }, logger) if err != nil { t.Fatalf("err: %s", err) } - ha, ok := b.(HABackend) + ha, ok := b.(physical.HABackend) if !ok { t.Fatalf("zookeeper does not implement HABackend") } - testHABackend(t, ha, ha) + physical.ExerciseHABackend(t, ha, ha) } diff --git a/vault/barrier_aes_gcm_test.go b/vault/barrier_aes_gcm_test.go index 1303b72435..ef0fe38daa 100644 --- a/vault/barrier_aes_gcm_test.go +++ b/vault/barrier_aes_gcm_test.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" log "github.com/mgutz/logxi/v1" ) @@ -16,8 +17,10 @@ var ( // mockBarrier returns a physical backend, security barrier, and master key func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -31,8 +34,10 @@ func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) { } func TestAESGCMBarrier_Basic(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -41,8 +46,10 @@ func TestAESGCMBarrier_Basic(t *testing.T) { } func TestAESGCMBarrier_Rotate(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -51,8 +58,10 @@ func TestAESGCMBarrier_Rotate(t *testing.T) { } func TestAESGCMBarrier_Upgrade(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b1, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -65,8 +74,10 @@ func TestAESGCMBarrier_Upgrade(t *testing.T) { } func TestAESGCMBarrier_Upgrade_Rekey(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b1, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -79,8 +90,10 @@ func TestAESGCMBarrier_Upgrade_Rekey(t *testing.T) { } func TestAESGCMBarrier_Rekey(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -91,8 +104,10 @@ func TestAESGCMBarrier_Rekey(t *testing.T) { // Test an upgrade from the old (0.1) barrier/init to the new // core/keyring style func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -171,8 +186,10 @@ func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) { // Verify data sent through is encrypted func TestAESGCMBarrier_Confidential(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -209,8 +226,10 @@ func TestAESGCMBarrier_Confidential(t *testing.T) { // Verify data sent through cannot be tampered with func TestAESGCMBarrier_Integrity(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -245,8 +264,10 @@ func TestAESGCMBarrier_Integrity(t *testing.T) { // Verify data sent through cannot be moved func TestAESGCMBarrier_MoveIntegrityV1(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -287,8 +308,10 @@ func TestAESGCMBarrier_MoveIntegrityV1(t *testing.T) { } func TestAESGCMBarrier_MoveIntegrityV2(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -329,8 +352,10 @@ func TestAESGCMBarrier_MoveIntegrityV2(t *testing.T) { } func TestAESGCMBarrier_UpgradeV1toV2(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -382,8 +407,10 @@ func TestAESGCMBarrier_UpgradeV1toV2(t *testing.T) { } func TestEncrypt_Unique(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -410,8 +437,10 @@ func TestEncrypt_Unique(t *testing.T) { } func TestInitialize_KeyLength(t *testing.T) { - - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) @@ -441,7 +470,13 @@ func TestInitialize_KeyLength(t *testing.T) { } func TestEncrypt_BarrierEncryptor(t *testing.T) { - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatalf("err: %v", err) + } + if err != nil { + t.Fatalf("err: %v", err) + } b, err := NewAESGCMBarrier(inm) if err != nil { t.Fatalf("err: %v", err) diff --git a/vault/cluster_test.go b/vault/cluster_test.go index 2d59b87dff..a2711217b7 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" log "github.com/mgutz/logxi/v1" ) @@ -43,9 +44,17 @@ func TestClusterHAFetching(t *testing.T) { redirect := "http://127.0.0.1:8200" + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } c, err := NewCore(&CoreConfig{ - Physical: physical.NewInmemHA(logger), - HAPhysical: physical.NewInmemHA(logger), + Physical: inm, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirect, DisableMlock: true, }) diff --git a/vault/core_test.go b/vault/core_test.go index aab33111be..b940254d31 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" log "github.com/mgutz/logxi/v1" ) @@ -23,12 +24,17 @@ var ( func TestNewCore_badRedirectAddr(t *testing.T) { logger = logformat.NewVaultLogger(log.LevelTrace) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + conf := &CoreConfig{ RedirectAddr: "127.0.0.1:8200", - Physical: physical.NewInmem(logger), + Physical: inm, DisableMlock: true, } - _, err := NewCore(conf) + _, err = NewCore(conf) if err == nil { t.Fatal("should error") } @@ -974,12 +980,19 @@ func TestCore_Standby_Seal(t *testing.T) { // Create the first core and initialize it logger = logformat.NewVaultLogger(log.LevelTrace) - inm := physical.NewInmem(logger) - inmha := physical.NewInmemHA(logger) + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + redirectOriginal := "http://127.0.0.1:8200" core, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal, DisableMlock: true, }) @@ -1021,7 +1034,7 @@ func TestCore_Standby_Seal(t *testing.T) { redirectOriginal2 := "http://127.0.0.1:8500" core2, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal2, DisableMlock: true, }) @@ -1085,12 +1098,19 @@ func TestCore_StepDown(t *testing.T) { // Create the first core and initialize it logger = logformat.NewVaultLogger(log.LevelTrace) - inm := physical.NewInmem(logger) - inmha := physical.NewInmemHA(logger) + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + redirectOriginal := "http://127.0.0.1:8200" core, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal, DisableMlock: true, }) @@ -1132,7 +1152,7 @@ func TestCore_StepDown(t *testing.T) { redirectOriginal2 := "http://127.0.0.1:8500" core2, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal2, DisableMlock: true, }) @@ -1276,12 +1296,19 @@ func TestCore_CleanLeaderPrefix(t *testing.T) { // Create the first core and initialize it logger = logformat.NewVaultLogger(log.LevelTrace) - inm := physical.NewInmem(logger) - inmha := physical.NewInmemHA(logger) + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + redirectOriginal := "http://127.0.0.1:8200" core, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal, DisableMlock: true, }) @@ -1350,7 +1377,7 @@ func TestCore_CleanLeaderPrefix(t *testing.T) { redirectOriginal2 := "http://127.0.0.1:8500" core2, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal2, DisableMlock: true, }) @@ -1438,14 +1465,27 @@ func TestCore_CleanLeaderPrefix(t *testing.T) { func TestCore_Standby(t *testing.T) { logger = logformat.NewVaultLogger(log.LevelTrace) - inmha := physical.NewInmemHA(logger) - testCore_Standby_Common(t, inmha, inmha) + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + + testCore_Standby_Common(t, inmha, inmha.(physical.HABackend)) } func TestCore_Standby_SeparateHA(t *testing.T) { logger = logformat.NewVaultLogger(log.LevelTrace) - testCore_Standby_Common(t, physical.NewInmemHA(logger), physical.NewInmemHA(logger)) + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha2, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + + testCore_Standby_Common(t, inmha, inmha2.(physical.HABackend)) } func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.HABackend) { @@ -1604,18 +1644,18 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical. t.Fatalf("Bad advertise: %v, orig is %v", advertise, redirectOriginal2) } - if inm.(*physical.InmemHABackend) == inmha.(*physical.InmemHABackend) { - lockSize := inm.(*physical.InmemHABackend).LockMapSize() + if inm.(*inmem.InmemHABackend) == inmha.(*inmem.InmemHABackend) { + lockSize := inm.(*inmem.InmemHABackend).LockMapSize() if lockSize == 0 { t.Fatalf("locks not used with only one HA backend") } } else { - lockSize := inmha.(*physical.InmemHABackend).LockMapSize() + lockSize := inmha.(*inmem.InmemHABackend).LockMapSize() if lockSize == 0 { t.Fatalf("locks not used with expected HA backend") } - lockSize = inm.(*physical.InmemHABackend).LockMapSize() + lockSize = inm.(*inmem.InmemHABackend).LockMapSize() if lockSize != 0 { t.Fatalf("locks used with unexpected HA backend") } @@ -2015,12 +2055,19 @@ func TestCore_Standby_Rotate(t *testing.T) { // Create the first core and initialize it logger = logformat.NewVaultLogger(log.LevelTrace) - inm := physical.NewInmem(logger) - inmha := physical.NewInmemHA(logger) + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + redirectOriginal := "http://127.0.0.1:8200" core, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal, DisableMlock: true, }) @@ -2041,7 +2088,7 @@ func TestCore_Standby_Rotate(t *testing.T) { redirectOriginal2 := "http://127.0.0.1:8500" core2, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal2, DisableMlock: true, }) diff --git a/vault/expiration_test.go b/vault/expiration_test.go index ad50adf68a..339c5ad8d7 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -2,7 +2,6 @@ package vault import ( "fmt" - "os" "reflect" "sort" "strings" @@ -15,6 +14,7 @@ import ( "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" log "github.com/mgutz/logxi/v1" ) @@ -241,16 +241,19 @@ func TestExpiration_Tidy(t *testing.T) { } } +// To avoid pulling in deps for all users of the package, don't leave these +// uncommented in the public tree +/* func BenchmarkExpiration_Restore_Etcd(b *testing.B) { addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR") randPath := fmt.Sprintf("vault-%d/", time.Now().Unix()) logger := logformat.NewVaultLogger(log.LevelTrace) - physicalBackend, err := physical.NewBackend("etcd", logger, map[string]string{ + physicalBackend, err := physEtcd.NewEtcdBackend(map[string]string{ "address": addr, "path": randPath, "max_parallel": "256", - }) + }, logger) if err != nil { b.Fatalf("err: %s", err) } @@ -263,21 +266,26 @@ func BenchmarkExpiration_Restore_Consul(b *testing.B) { randPath := fmt.Sprintf("vault-%d/", time.Now().Unix()) logger := logformat.NewVaultLogger(log.LevelTrace) - physicalBackend, err := physical.NewBackend("consul", logger, map[string]string{ + physicalBackend, err := physConsul.NewConsulBackend(map[string]string{ "address": addr, "path": randPath, "max_parallel": "256", - }) + }, logger) if err != nil { b.Fatalf("err: %s", err) } benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases } +*/ func BenchmarkExpiration_Restore_InMem(b *testing.B) { logger := logformat.NewVaultLogger(log.LevelTrace) - benchmarkExpirationBackend(b, physical.NewInmem(logger), 100000) // 100,000 Leases + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + b.Fatal(err) + } + benchmarkExpirationBackend(b, inm, 100000) // 100,000 Leases } func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, numLeases int) { diff --git a/vault/init_test.go b/vault/init_test.go index 38d95e44bb..91d691d00e 100644 --- a/vault/init_test.go +++ b/vault/init_test.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" ) func TestCore_Init(t *testing.T) { @@ -25,7 +25,10 @@ func TestCore_Init(t *testing.T) { func testCore_NewTestCore(t *testing.T, seal Seal) (*Core, *CoreConfig) { logger := logformat.NewVaultLogger(log.LevelTrace) - inm := physical.NewInmem(logger) + inm, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } conf := &CoreConfig{ Physical: inm, DisableMlock: true, diff --git a/vault/rekey_test.go b/vault/rekey_test.go index c463325fe4..e6453ad138 100644 --- a/vault/rekey_test.go +++ b/vault/rekey_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/physical/inmem" ) func TestCore_Rekey_Lifecycle(t *testing.T) { @@ -372,12 +373,19 @@ func TestCore_Standby_Rekey(t *testing.T) { // Create the first core and initialize it logger := logformat.NewVaultLogger(log.LevelTrace) - inm := physical.NewInmem(logger) - inmha := physical.NewInmemHA(logger) + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + inmha, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + redirectOriginal := "http://127.0.0.1:8200" core, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal, DisableMlock: true, DisableCache: true, @@ -399,7 +407,7 @@ func TestCore_Standby_Rekey(t *testing.T) { redirectOriginal2 := "http://127.0.0.1:8500" core2, err := NewCore(&CoreConfig{ Physical: inm, - HAPhysical: inmha, + HAPhysical: inmha.(physical.HABackend), RedirectAddr: redirectOriginal2, DisableMlock: true, DisableCache: true, diff --git a/vault/testing.go b/vault/testing.go index ca4df78f00..9d2ee25fe5 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -41,6 +41,8 @@ import ( "github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/physical" "github.com/mitchellh/go-testing-interface" + + physInmem "github.com/hashicorp/vault/physical/inmem" ) // This file contains a number of methods that are useful for unit @@ -96,7 +98,10 @@ func TestCoreNewSeal(t testing.T) *Core { // specified seal for testing. func TestCoreWithSeal(t testing.T, testSeal Seal) *Core { logger := logformat.NewVaultLogger(log.LevelTrace) - physicalBackend := physical.NewInmem(logger) + physicalBackend, err := physInmem.NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } conf := testCoreConfig(t, physicalBackend, logger) @@ -1083,10 +1088,17 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te } if coreConfig.Physical == nil { - coreConfig.Physical = physical.NewInmem(logger) + coreConfig.Physical, err = physInmem.NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } } if coreConfig.HAPhysical == nil { - coreConfig.HAPhysical = physical.NewInmemHA(logger) + haPhys, err := physInmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + coreConfig.HAPhysical = haPhys.(physical.HABackend) } c1, err := NewCore(coreConfig)