diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 17114cd8bd..bcede8e29d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -9,13 +9,37 @@ import ( log "github.com/mgutz/logxi/v1" + "github.com/hashicorp/errwrap" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" ) const databaseConfigPath = "database/config/" +type dbPluginInstance struct { + sync.RWMutex + dbplugin.Database + + id string + name string + closed bool +} + +func (dbi *dbPluginInstance) Close() error { + dbi.Lock() + defer dbi.Unlock() + + if dbi.closed { + return nil + } + dbi.closed = true + + return dbi.Database.Close() +} + func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { b := Backend(conf) if err := b.Setup(ctx, conf); err != nil { @@ -42,6 +66,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { pathRoles(&b), pathCredsCreate(&b), pathResetConnection(&b), + pathRotateCredentials(&b), }, Secrets: []*framework.Secret{ @@ -53,72 +78,22 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]dbplugin.Database) + b.connections = make(map[string]*dbPluginInstance) return &b } type databaseBackend struct { - connections map[string]dbplugin.Database + connections map[string]*dbPluginInstance logger log.Logger *framework.Backend sync.RWMutex } -// closeAllDBs closes all connections from all database types -func (b *databaseBackend) closeAllDBs(ctx context.Context) { - b.Lock() - defer b.Unlock() - - for _, db := range b.connections { - db.Close() - } - - b.connections = make(map[string]dbplugin.Database) -} - -// This function is used to retrieve a database object either from the cached -// connection map. The caller of this function needs to hold the backend's read -// lock. -func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { - db, ok := b.connections[name] - return db, ok -} - -// This function creates a new db object from the stored configuration and -// caches it in the connections map. The caller of this function needs to hold -// the backend's write lock -func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) { - db, ok := b.connections[name] - if ok { - return db, nil - } - - config, err := b.DatabaseConfig(ctx, s, name) - if err != nil { - return nil, err - } - - db, err = dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger) - if err != nil { - return nil, err - } - - err = db.Initialize(ctx, config.ConnectionDetails, true) - if err != nil { - db.Close() - return nil, err - } - - b.connections[name] = db - - return db, nil -} - func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) { entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration: %s", err) + return nil, errwrap.Wrapf("failed to read connection configuration: {{err}}", err) } if entry == nil { return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) @@ -144,7 +119,7 @@ type upgradeStatements struct { type upgradeCheck struct { // This json tag has a typo in it, the new version does not. This // necessitates this upgrade logic. - Statements upgradeStatements `json:"statments"` + Statements *upgradeStatements `json:"statments,omitempty"` } func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) { @@ -166,48 +141,140 @@ func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName return nil, err } - empty := upgradeCheck{} - if upgradeCh != empty { - result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements - result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements - result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements - result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements + switch { + case upgradeCh.Statements != nil: + var stmts dbplugin.Statements + if upgradeCh.Statements.CreationStatements != "" { + stmts.Creation = []string{upgradeCh.Statements.CreationStatements} + } + if upgradeCh.Statements.RevocationStatements != "" { + stmts.Revocation = []string{upgradeCh.Statements.RevocationStatements} + } + if upgradeCh.Statements.RollbackStatements != "" { + stmts.Rollback = []string{upgradeCh.Statements.RollbackStatements} + } + if upgradeCh.Statements.RenewStatements != "" { + stmts.Renewal = []string{upgradeCh.Statements.RenewStatements} + } + result.Statements = stmts } + // For backwards compatibility, copy the values back into the string form + // of the fields + result.Statements = dbutil.StatementCompatibilityHelper(result.Statements) + return &result, nil } func (b *databaseBackend) invalidate(ctx context.Context, key string) { - b.Lock() - defer b.Unlock() - switch { case strings.HasPrefix(key, databaseConfigPath): name := strings.TrimPrefix(key, databaseConfigPath) - b.clearConnection(name) + b.ClearConnection(name) } } -// clearConnection closes the database connection and -// removes it from the b.connections map. -func (b *databaseBackend) clearConnection(name string) { +func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, name string) (*dbPluginInstance, error) { + b.RLock() + unlockFunc := b.RUnlock + defer func() { unlockFunc() }() + db, ok := b.connections[name] if ok { + return db, nil + } + + // Upgrade lock + b.RUnlock() + b.Lock() + unlockFunc = b.Unlock + + db, ok = b.connections[name] + if ok { + return db, nil + } + + config, err := b.DatabaseConfig(ctx, s, name) + if err != nil { + return nil, err + } + + dbp, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger) + if err != nil { + return nil, err + } + + _, err = dbp.Init(ctx, config.ConnectionDetails, true) + if err != nil { + dbp.Close() + return nil, err + } + + id, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + db = &dbPluginInstance{ + Database: dbp, + name: name, + id: id, + } + + b.connections[name] = db + return db, nil +} + +// ClearConnection closes the database connection and +// removes it from the b.connections map. +func (b *databaseBackend) ClearConnection(name string) error { + b.Lock() + defer b.Unlock() + return b.clearConnection(name) +} + +func (b *databaseBackend) clearConnection(name string) error { + db, ok := b.connections[name] + if ok { + // Ignore error here since the database client is always killed db.Close() delete(b.connections, name) } + return nil } -func (b *databaseBackend) closeIfShutdown(name string, err error) { +func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) { // Plugin has shutdown, close it so next call can reconnect. switch err { case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: - b.Lock() - b.clearConnection(name) - b.Unlock() + // Put this in a goroutine so that requests can run with the read or write lock + // and simply defer the unlock. Since we are attaching the instance and matching + // the id in the conneciton map, we can safely do this. + go func() { + b.Lock() + defer b.Unlock() + db.Close() + + // Ensure we are deleting the correct connection + mapDB, ok := b.connections[db.name] + if ok && db.id == mapDB.id { + delete(b.connections, db.name) + } + }() } } +// closeAllDBs closes all connections from all database types +func (b *databaseBackend) closeAllDBs(ctx context.Context) { + b.Lock() + defer b.Unlock() + + for _, db := range b.connections { + db.Close() + } + b.connections = make(map[string]*dbPluginInstance) +} + const backendHelp = ` The database backend supports using many different databases as secret backends, including but not limited to: diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 07e462069c..3b7a7aa74d 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -7,6 +7,7 @@ import ( "log" "os" "reflect" + "strings" "sync" "testing" "time" @@ -27,6 +28,7 @@ var ( ) func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) { + t.Helper() if os.Getenv("PG_URL") != "" { return func() {}, os.Getenv("PG_URL") } @@ -64,7 +66,7 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac }) if err != nil || (resp != nil && resp.IsError()) { // It's likely not up and running yet, so return error and try again - return fmt.Errorf("err:%s resp:%#v\n", err, resp) + return fmt.Errorf("err:%#v resp:%#v", err, resp) } if resp == nil { t.Fatal("expected warning") @@ -123,13 +125,18 @@ func TestBackend_RoleUpgrade(t *testing.T) { storage := &logical.InmemStorage{} backend := &databaseBackend{} - roleEnt := &roleEntry{ + roleExpected := &roleEntry{ Statements: dbplugin.Statements{ CreationStatements: "test", + Creation: []string{"test"}, }, } - entry, err := logical.StorageEntryJSON("role/test", roleEnt) + entry, err := logical.StorageEntryJSON("role/test", &roleEntry{ + Statements: dbplugin.Statements{ + CreationStatements: "test", + }, + }) if err != nil { t.Fatal(err) } @@ -142,8 +149,8 @@ func TestBackend_RoleUpgrade(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(role, roleEnt) { - t.Fatalf("bad role %#v", role) + if !reflect.DeepEqual(role, roleExpected) { + t.Fatalf("bad role %#v, %#v", role, roleExpected) } // Upgrade case @@ -161,8 +168,8 @@ func TestBackend_RoleUpgrade(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(role, roleEnt) { - t.Fatalf("bad role %#v", role) + if !reflect.DeepEqual(role, roleExpected) { + t.Fatalf("bad role %#v, %#v", role, roleExpected) } } @@ -206,7 +213,8 @@ func TestBackend_config_connection(t *testing.T) { "connection_details": map[string]interface{}{ "connection_url": "sample_connection_url", }, - "allowed_roles": []string{"*"}, + "allowed_roles": []string{"*"}, + "root_credentials_rotate_statements": []string{}, } configReq.Operation = logical.ReadOperation resp, err = b.HandleRequest(context.Background(), configReq) @@ -233,6 +241,55 @@ func TestBackend_config_connection(t *testing.T) { } } +func TestBackend_BadConnectionString(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(context.Background(), config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + cleanup, _ := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + respCheck := func(req *logical.Request) { + t.Helper() + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected error, resp:%#v", resp) + } + err = resp.Error() + if strings.Contains(err.Error(), "localhost") { + t.Fatalf("error should not contain connection info") + } + } + + // Configure a connection + data := map[string]interface{}{ + "connection_url": "postgresql://:pw@[localhost", + "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + respCheck(req) + + time.Sleep(1 * time.Second) +} + func TestBackend_basic(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() @@ -388,7 +445,6 @@ func TestBackend_basic(t *testing.T) { if testCredsExist(t, credsResp, connURL) { t.Fatalf("Creds should not exist") } - } func TestBackend_connectionCrud(t *testing.T) { @@ -467,7 +523,8 @@ func TestBackend_connectionCrud(t *testing.T) { "connection_details": map[string]interface{}{ "connection_url": connURL, }, - "allowed_roles": []string{"plugin-role-test"}, + "allowed_roles": []string{"plugin-role-test"}, + "root_credentials_rotate_statements": []string{}, } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(context.Background(), req) @@ -602,15 +659,15 @@ func TestBackend_roleCrud(t *testing.T) { } expected := dbplugin.Statements{ - CreationStatements: testRole, - RevocationStatements: defaultRevocationSQL, + Creation: []string{strings.TrimSpace(testRole)}, + Revocation: []string{strings.TrimSpace(defaultRevocationSQL)}, } actual := dbplugin.Statements{ - CreationStatements: resp.Data["creation_statements"].(string), - RevocationStatements: resp.Data["revocation_statements"].(string), - RollbackStatements: resp.Data["rollback_statements"].(string), - RenewStatements: resp.Data["renew_statements"].(string), + Creation: resp.Data["creation_statements"].([]string), + Revocation: resp.Data["revocation_statements"].([]string), + Rollback: resp.Data["rollback_statements"].([]string), + Renewal: resp.Data["renew_statements"].([]string), } if !reflect.DeepEqual(expected, actual) { diff --git a/builtin/logical/database/dbplugin/database.pb.go b/builtin/logical/database/dbplugin/database.pb.go index c4c4101968..31ae9c5c46 100644 --- a/builtin/logical/database/dbplugin/database.pb.go +++ b/builtin/logical/database/dbplugin/database.pb.go @@ -1,5 +1,6 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. +// Code generated by protoc-gen-go. // source: builtin/logical/database/dbplugin/database.proto +// DO NOT EDIT! /* Package dbplugin is a generated protocol buffer package. @@ -9,13 +10,17 @@ It is generated from these files: It has these top-level messages: InitializeRequest + InitRequest CreateUserRequest RenewUserRequest RevokeUserRequest + RotateRootCredentialsRequest Statements UsernameConfig + InitResponse CreateUserResponse TypeResponse + RotateRootCredentialsResponse Empty */ package dbplugin @@ -65,6 +70,30 @@ func (m *InitializeRequest) GetVerifyConnection() bool { return false } +type InitRequest struct { + Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"` +} + +func (m *InitRequest) Reset() { *m = InitRequest{} } +func (m *InitRequest) String() string { return proto.CompactTextString(m) } +func (*InitRequest) ProtoMessage() {} +func (*InitRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *InitRequest) GetConfig() []byte { + if m != nil { + return m.Config + } + return nil +} + +func (m *InitRequest) GetVerifyConnection() bool { + if m != nil { + return m.VerifyConnection + } + return false +} + type CreateUserRequest struct { Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"` @@ -74,7 +103,7 @@ type CreateUserRequest struct { func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} } func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) } func (*CreateUserRequest) ProtoMessage() {} -func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } +func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } func (m *CreateUserRequest) GetStatements() *Statements { if m != nil { @@ -106,7 +135,7 @@ type RenewUserRequest struct { func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} } func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) } func (*RenewUserRequest) ProtoMessage() {} -func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } +func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } func (m *RenewUserRequest) GetStatements() *Statements { if m != nil { @@ -137,7 +166,7 @@ type RevokeUserRequest struct { func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} } func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) } func (*RevokeUserRequest) ProtoMessage() {} -func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } +func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } func (m *RevokeUserRequest) GetStatements() *Statements { if m != nil { @@ -153,17 +182,41 @@ func (m *RevokeUserRequest) GetUsername() string { return "" } +type RotateRootCredentialsRequest struct { + Statements []string `protobuf:"bytes,1,rep,name=statements" json:"statements,omitempty"` +} + +func (m *RotateRootCredentialsRequest) Reset() { *m = RotateRootCredentialsRequest{} } +func (m *RotateRootCredentialsRequest) String() string { return proto.CompactTextString(m) } +func (*RotateRootCredentialsRequest) ProtoMessage() {} +func (*RotateRootCredentialsRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } + +func (m *RotateRootCredentialsRequest) GetStatements() []string { + if m != nil { + return m.Statements + } + return nil +} + type Statements struct { - CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` + // DEPRECATED, will be removed in 0.12 + CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` + // DEPRECATED, will be removed in 0.12 RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"` - RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` - RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` + // DEPRECATED, will be removed in 0.12 + RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` + // DEPRECATED, will be removed in 0.12 + RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` + Creation []string `protobuf:"bytes,5,rep,name=creation" json:"creation,omitempty"` + Revocation []string `protobuf:"bytes,6,rep,name=revocation" json:"revocation,omitempty"` + Rollback []string `protobuf:"bytes,7,rep,name=rollback" json:"rollback,omitempty"` + Renewal []string `protobuf:"bytes,8,rep,name=renewal" json:"renewal,omitempty"` } func (m *Statements) Reset() { *m = Statements{} } func (m *Statements) String() string { return proto.CompactTextString(m) } func (*Statements) ProtoMessage() {} -func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } +func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } func (m *Statements) GetCreationStatements() string { if m != nil { @@ -193,6 +246,34 @@ func (m *Statements) GetRenewStatements() string { return "" } +func (m *Statements) GetCreation() []string { + if m != nil { + return m.Creation + } + return nil +} + +func (m *Statements) GetRevocation() []string { + if m != nil { + return m.Revocation + } + return nil +} + +func (m *Statements) GetRollback() []string { + if m != nil { + return m.Rollback + } + return nil +} + +func (m *Statements) GetRenewal() []string { + if m != nil { + return m.Renewal + } + return nil +} + type UsernameConfig struct { DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"` RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"` @@ -201,7 +282,7 @@ type UsernameConfig struct { func (m *UsernameConfig) Reset() { *m = UsernameConfig{} } func (m *UsernameConfig) String() string { return proto.CompactTextString(m) } func (*UsernameConfig) ProtoMessage() {} -func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } +func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } func (m *UsernameConfig) GetDisplayName() string { if m != nil { @@ -217,6 +298,22 @@ func (m *UsernameConfig) GetRoleName() string { return "" } +type InitResponse struct { + Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` +} + +func (m *InitResponse) Reset() { *m = InitResponse{} } +func (m *InitResponse) String() string { return proto.CompactTextString(m) } +func (*InitResponse) ProtoMessage() {} +func (*InitResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } + +func (m *InitResponse) GetConfig() []byte { + if m != nil { + return m.Config + } + return nil +} + type CreateUserResponse struct { Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"` Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"` @@ -225,7 +322,7 @@ type CreateUserResponse struct { func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} } func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) } func (*CreateUserResponse) ProtoMessage() {} -func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } +func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{9} } func (m *CreateUserResponse) GetUsername() string { if m != nil { @@ -248,7 +345,7 @@ type TypeResponse struct { func (m *TypeResponse) Reset() { *m = TypeResponse{} } func (m *TypeResponse) String() string { return proto.CompactTextString(m) } func (*TypeResponse) ProtoMessage() {} -func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } +func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{10} } func (m *TypeResponse) GetType() string { if m != nil { @@ -257,23 +354,43 @@ func (m *TypeResponse) GetType() string { return "" } +type RotateRootCredentialsResponse struct { + Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` +} + +func (m *RotateRootCredentialsResponse) Reset() { *m = RotateRootCredentialsResponse{} } +func (m *RotateRootCredentialsResponse) String() string { return proto.CompactTextString(m) } +func (*RotateRootCredentialsResponse) ProtoMessage() {} +func (*RotateRootCredentialsResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{11} } + +func (m *RotateRootCredentialsResponse) GetConfig() []byte { + if m != nil { + return m.Config + } + return nil +} + type Empty struct { } func (m *Empty) Reset() { *m = Empty{} } func (m *Empty) String() string { return proto.CompactTextString(m) } func (*Empty) ProtoMessage() {} -func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } +func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{12} } func init() { proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest") + proto.RegisterType((*InitRequest)(nil), "dbplugin.InitRequest") proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest") proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest") proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest") + proto.RegisterType((*RotateRootCredentialsRequest)(nil), "dbplugin.RotateRootCredentialsRequest") proto.RegisterType((*Statements)(nil), "dbplugin.Statements") proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig") + proto.RegisterType((*InitResponse)(nil), "dbplugin.InitResponse") proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse") proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse") + proto.RegisterType((*RotateRootCredentialsResponse)(nil), "dbplugin.RotateRootCredentialsResponse") proto.RegisterType((*Empty)(nil), "dbplugin.Empty") } @@ -292,8 +409,10 @@ type DatabaseClient interface { CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) - Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) + RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error) + Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) + Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) } type databaseClient struct { @@ -340,9 +459,18 @@ func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, return out, nil } -func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { - out := new(Empty) - err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) +func (c *databaseClient) RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error) { + out := new(RotateRootCredentialsResponse) + err := grpc.Invoke(ctx, "/dbplugin.Database/RotateRootCredentials", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) { + out := new(InitResponse) + err := grpc.Invoke(ctx, "/dbplugin.Database/Init", in, out, c.cc, opts...) if err != nil { return nil, err } @@ -358,6 +486,15 @@ func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.Call return out, nil } +func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // Server API for Database service type DatabaseServer interface { @@ -365,8 +502,10 @@ type DatabaseServer interface { CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) RenewUser(context.Context, *RenewUserRequest) (*Empty, error) RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error) - Initialize(context.Context, *InitializeRequest) (*Empty, error) + RotateRootCredentials(context.Context, *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) + Init(context.Context, *InitRequest) (*InitResponse, error) Close(context.Context, *Empty) (*Empty, error) + Initialize(context.Context, *InitializeRequest) (*Empty, error) } func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) { @@ -445,20 +584,38 @@ func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func return interceptor(ctx, in, info, handler) } -func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(InitializeRequest) +func _Database_RotateRootCredentials_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RotateRootCredentialsRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DatabaseServer).Initialize(ctx, in) + return srv.(DatabaseServer).RotateRootCredentials(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/dbplugin.Database/Initialize", + FullMethod: "/dbplugin.Database/RotateRootCredentials", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) + return srv.(DatabaseServer).RotateRootCredentials(ctx, req.(*RotateRootCredentialsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InitRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Init(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Init", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Init(ctx, req.(*InitRequest)) } return interceptor(ctx, in, info, handler) } @@ -481,6 +638,24 @@ func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(inte return interceptor(ctx, in, info, handler) } +func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InitializeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Initialize(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Initialize", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) + } + return interceptor(ctx, in, info, handler) +} + var _Database_serviceDesc = grpc.ServiceDesc{ ServiceName: "dbplugin.Database", HandlerType: (*DatabaseServer)(nil), @@ -502,13 +677,21 @@ var _Database_serviceDesc = grpc.ServiceDesc{ Handler: _Database_RevokeUser_Handler, }, { - MethodName: "Initialize", - Handler: _Database_Initialize_Handler, + MethodName: "RotateRootCredentials", + Handler: _Database_RotateRootCredentials_Handler, + }, + { + MethodName: "Init", + Handler: _Database_Init_Handler, }, { MethodName: "Close", Handler: _Database_Close_Handler, }, + { + MethodName: "Initialize", + Handler: _Database_Initialize_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "builtin/logical/database/dbplugin/database.proto", @@ -517,40 +700,49 @@ var _Database_serviceDesc = grpc.ServiceDesc{ func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 548 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e, - 0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f, - 0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e, - 0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2, - 0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6, - 0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c, - 0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e, - 0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3, - 0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1, - 0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f, - 0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49, - 0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35, - 0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20, - 0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6, - 0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34, - 0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0, - 0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce, - 0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68, - 0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0, - 0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8, - 0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1, - 0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0, - 0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a, - 0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab, - 0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c, - 0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4, - 0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36, - 0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a, - 0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf, - 0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67, - 0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b, - 0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d, - 0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6, - 0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56, - 0x94, 0x05, 0x00, 0x00, + // 690 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x55, 0x41, 0x4f, 0xdb, 0x4a, + 0x10, 0x96, 0x93, 0x00, 0xc9, 0x80, 0x80, 0xec, 0x03, 0x64, 0xf9, 0xf1, 0xde, 0x43, 0x3e, 0xf0, + 0x40, 0x95, 0xe2, 0x0a, 0x5a, 0xb5, 0xe2, 0xd0, 0xaa, 0x0a, 0x55, 0x55, 0xa9, 0xe2, 0xb0, 0xc0, + 0xad, 0x12, 0xda, 0x38, 0x43, 0xba, 0xc2, 0xf1, 0xba, 0xde, 0x0d, 0x34, 0xfd, 0x03, 0xed, 0xcf, + 0xe8, 0x4f, 0xe9, 0xb1, 0x3f, 0xab, 0xf2, 0xda, 0x6b, 0x6f, 0x62, 0x28, 0x07, 0xda, 0x9b, 0x67, + 0xe6, 0xfb, 0x66, 0xbe, 0x9d, 0x9d, 0x59, 0xc3, 0xe3, 0xc1, 0x84, 0x47, 0x8a, 0xc7, 0x41, 0x24, + 0x46, 0x3c, 0x64, 0x51, 0x30, 0x64, 0x8a, 0x0d, 0x98, 0xc4, 0x60, 0x38, 0x48, 0xa2, 0xc9, 0x88, + 0xc7, 0xa5, 0xa7, 0x97, 0xa4, 0x42, 0x09, 0xd2, 0x36, 0x01, 0xef, 0xbf, 0x91, 0x10, 0xa3, 0x08, + 0x03, 0xed, 0x1f, 0x4c, 0x2e, 0x03, 0xc5, 0xc7, 0x28, 0x15, 0x1b, 0x27, 0x39, 0xd4, 0x7f, 0x0f, + 0xdd, 0xb7, 0x31, 0x57, 0x9c, 0x45, 0xfc, 0x33, 0x52, 0xfc, 0x38, 0x41, 0xa9, 0xc8, 0x16, 0x2c, + 0x86, 0x22, 0xbe, 0xe4, 0x23, 0xd7, 0xd9, 0x71, 0xf6, 0x56, 0x68, 0x61, 0x91, 0x47, 0xd0, 0xbd, + 0xc6, 0x94, 0x5f, 0x4e, 0x2f, 0x42, 0x11, 0xc7, 0x18, 0x2a, 0x2e, 0x62, 0xb7, 0xb1, 0xe3, 0xec, + 0xb5, 0xe9, 0x7a, 0x1e, 0xe8, 0x97, 0xfe, 0xa3, 0x86, 0xeb, 0xf8, 0x14, 0x96, 0xb3, 0xec, 0xbf, + 0x33, 0xaf, 0xff, 0xc3, 0x81, 0x6e, 0x3f, 0x45, 0xa6, 0xf0, 0x5c, 0x62, 0x6a, 0x52, 0x3f, 0x01, + 0x90, 0x8a, 0x29, 0x1c, 0x63, 0xac, 0xa4, 0x4e, 0xbf, 0x7c, 0xb0, 0xd1, 0x33, 0x7d, 0xe8, 0x9d, + 0x96, 0x31, 0x6a, 0xe1, 0xc8, 0x2b, 0x58, 0x9b, 0x48, 0x4c, 0x63, 0x36, 0xc6, 0x8b, 0x42, 0x59, + 0x43, 0x53, 0xdd, 0x8a, 0x7a, 0x5e, 0x00, 0xfa, 0x3a, 0x4e, 0x57, 0x27, 0x33, 0x36, 0x39, 0x02, + 0xc0, 0x4f, 0x09, 0x4f, 0x99, 0x16, 0xdd, 0xd4, 0x6c, 0xaf, 0x97, 0xb7, 0xbd, 0x67, 0xda, 0xde, + 0x3b, 0x33, 0x6d, 0xa7, 0x16, 0xda, 0xff, 0xe6, 0xc0, 0x3a, 0xc5, 0x18, 0x6f, 0x1e, 0x7e, 0x12, + 0x0f, 0xda, 0x46, 0x98, 0x3e, 0x42, 0x87, 0x96, 0xf6, 0x83, 0x24, 0x22, 0x74, 0x29, 0x5e, 0x8b, + 0x2b, 0xfc, 0xa3, 0x12, 0xfd, 0x17, 0xb0, 0x4d, 0x45, 0x06, 0xa5, 0x42, 0xa8, 0x7e, 0x8a, 0x43, + 0x8c, 0xb3, 0x99, 0x94, 0xa6, 0xe2, 0xbf, 0x73, 0x15, 0x9b, 0x7b, 0x1d, 0x3b, 0xb7, 0xff, 0xbd, + 0x01, 0x50, 0x95, 0x25, 0x01, 0xfc, 0x15, 0x66, 0x23, 0xc2, 0x45, 0x7c, 0x31, 0xa7, 0xb4, 0x43, + 0x89, 0x09, 0x59, 0x84, 0x43, 0xd8, 0x4c, 0xf1, 0x5a, 0x84, 0x35, 0x4a, 0x2e, 0x74, 0xa3, 0x0a, + 0xce, 0x56, 0x49, 0x45, 0x14, 0x0d, 0x58, 0x78, 0x65, 0x53, 0x9a, 0x79, 0x15, 0x13, 0xb2, 0x08, + 0xfb, 0xb0, 0x9e, 0x66, 0xd7, 0x6d, 0xa3, 0x5b, 0x1a, 0xbd, 0xa6, 0xfd, 0xa7, 0x33, 0xcd, 0x32, + 0x32, 0xdd, 0x05, 0x7d, 0xdc, 0xd2, 0xce, 0x9a, 0x51, 0xe9, 0x71, 0x17, 0xf3, 0x66, 0x54, 0x9e, + 0x8c, 0x6b, 0x8a, 0xbb, 0x4b, 0x39, 0xd7, 0xd8, 0xc4, 0x85, 0x25, 0x5d, 0x8a, 0x45, 0x6e, 0x5b, + 0x87, 0x8c, 0xe9, 0x9f, 0xc0, 0xea, 0xec, 0xa8, 0x93, 0x1d, 0x58, 0x3e, 0xe6, 0x32, 0x89, 0xd8, + 0xf4, 0x24, 0xbb, 0xb3, 0xbc, 0x7b, 0xb6, 0x2b, 0xab, 0x44, 0x45, 0x84, 0x27, 0xd6, 0x95, 0x1a, + 0xdb, 0xdf, 0x85, 0x95, 0x7c, 0xf7, 0x65, 0x22, 0x62, 0x89, 0x77, 0x2d, 0xbf, 0xff, 0x0e, 0x88, + 0xbd, 0xce, 0x05, 0xda, 0x1e, 0x16, 0x67, 0x6e, 0x9e, 0x3d, 0x68, 0x27, 0x4c, 0xca, 0x1b, 0x91, + 0x0e, 0x4d, 0x55, 0x63, 0xfb, 0x3e, 0xac, 0x9c, 0x4d, 0x13, 0x2c, 0xf3, 0x10, 0x68, 0xa9, 0x69, + 0x62, 0x72, 0xe8, 0x6f, 0xff, 0x19, 0xfc, 0x73, 0xc7, 0xb0, 0xdd, 0x23, 0x75, 0x09, 0x16, 0x5e, + 0x8f, 0x13, 0x35, 0x3d, 0xf8, 0xd2, 0x82, 0xf6, 0x71, 0xf1, 0xe6, 0x92, 0x00, 0x5a, 0x59, 0x49, + 0xb2, 0x56, 0x6d, 0x80, 0x46, 0x79, 0x5b, 0x95, 0x63, 0x46, 0xd3, 0x1b, 0x80, 0xea, 0xc4, 0xe4, + 0xef, 0x0a, 0x55, 0x7b, 0xd6, 0xbc, 0xed, 0xdb, 0x83, 0x45, 0xa2, 0xe7, 0xd0, 0x29, 0x9f, 0x0f, + 0xe2, 0x55, 0xd0, 0xf9, 0x37, 0xc5, 0x9b, 0x97, 0x96, 0x3d, 0x09, 0xd5, 0x5a, 0xdb, 0x12, 0x6a, + 0xcb, 0x5e, 0xe7, 0x7e, 0x80, 0xcd, 0x5b, 0xdb, 0x47, 0x76, 0xad, 0x34, 0xbf, 0x58, 0x66, 0xef, + 0xff, 0x7b, 0x71, 0xc5, 0xf9, 0x9e, 0x42, 0x2b, 0x1b, 0x21, 0xb2, 0x59, 0x11, 0xac, 0xdf, 0x89, + 0xdd, 0xdf, 0x99, 0x49, 0xdb, 0x87, 0x85, 0x7e, 0x24, 0xe4, 0x2d, 0x37, 0x52, 0x3b, 0xcb, 0x4b, + 0x80, 0xea, 0xf7, 0x67, 0xf7, 0xa1, 0xf6, 0x53, 0xac, 0x71, 0xfd, 0xe6, 0xd7, 0x86, 0x33, 0x58, + 0xd4, 0xef, 0xe7, 0xe1, 0xcf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa7, 0x13, 0xfe, 0x55, 0xa5, 0x07, + 0x00, 0x00, } diff --git a/builtin/logical/database/dbplugin/database.proto b/builtin/logical/database/dbplugin/database.proto index d5e7d4068f..52d7c8c228 100644 --- a/builtin/logical/database/dbplugin/database.proto +++ b/builtin/logical/database/dbplugin/database.proto @@ -4,6 +4,12 @@ package dbplugin; import "google/protobuf/timestamp.proto"; message InitializeRequest { + option deprecated = true; + bytes config = 1; + bool verify_connection = 2; +} + +message InitRequest { bytes config = 1; bool verify_connection = 2; } @@ -25,11 +31,24 @@ message RevokeUserRequest { string username = 2; } +message RotateRootCredentialsRequest { + repeated string statements = 1; +} + message Statements { + // DEPRECATED, will be removed in 0.12 string creation_statements = 1; + // DEPRECATED, will be removed in 0.12 string revocation_statements = 2; + // DEPRECATED, will be removed in 0.12 string rollback_statements = 3; + // DEPRECATED, will be removed in 0.12 string renew_statements = 4; + + repeated string creation = 5; + repeated string revocation = 6; + repeated string rollback = 7; + repeated string renewal = 8; } message UsernameConfig { @@ -37,22 +56,35 @@ message UsernameConfig { string RoleName = 2; } +message InitResponse { + bytes config = 1; +} + message CreateUserResponse { string username = 1; string password = 2; } message TypeResponse { - string type = 1; + string type = 1; +} + +message RotateRootCredentialsResponse { + bytes config = 1; } message Empty {} service Database { - rpc Type(Empty) returns (TypeResponse); - rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); - rpc RenewUser(RenewUserRequest) returns (Empty); - rpc RevokeUser(RevokeUserRequest) returns (Empty); - rpc Initialize(InitializeRequest) returns (Empty); - rpc Close(Empty) returns (Empty); + rpc Type(Empty) returns (TypeResponse); + rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); + rpc RenewUser(RenewUserRequest) returns (Empty); + rpc RevokeUser(RevokeUserRequest) returns (Empty); + rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse); + rpc Init(InitRequest) returns (InitResponse); + rpc Close(Empty) returns (Empty); + + rpc Initialize(InitializeRequest) returns (Empty) { + option deprecated = true; + }; } diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index c8bbdf61d5..36a7558e5b 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -2,8 +2,14 @@ package dbplugin import ( "context" + "errors" + "net/url" + "strings" + "sync" "time" + "github.com/hashicorp/errwrap" + metrics "github.com/armon/go-metrics" log "github.com/mgutz/logxi/v1" ) @@ -51,13 +57,27 @@ func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements return mw.next.RevokeUser(ctx, statements, username) } -func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { +func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) { + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.RotateRootCredentials(ctx, statements) +} + +func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := mw.Init(ctx, conf, verifyConnection) + return err +} + +func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) { defer func(then time.Time) { mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport) - return mw.next.Initialize(ctx, conf, verifyConnection) + return mw.next.Init(ctx, conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { @@ -131,7 +151,28 @@ func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements return mw.next.RevokeUser(ctx, statements, username) } -func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { +func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1) + return mw.next.RotateRootCredentials(ctx, statements) +} + +func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := mw.Init(ctx, conf, verifyConnection) + return err +} + +func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) @@ -144,7 +185,7 @@ func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[st metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) - return mw.next.Initialize(ctx, conf, verifyConnection) + return mw.next.Init(ctx, conf, verifyConnection) } func (mw *databaseMetricsMiddleware) Close() (err error) { @@ -162,3 +203,76 @@ func (mw *databaseMetricsMiddleware) Close() (err error) { metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) return mw.next.Close() } + +// ---- Error Sanitizer Middleware Domain ---- + +// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and +// sanitizes returned error messages +type DatabaseErrorSanitizerMiddleware struct { + l sync.RWMutex + next Database + secretsFn func() map[string]interface{} +} + +func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware { + return &DatabaseErrorSanitizerMiddleware{ + next: next, + secretsFn: secretsFn, + } +} + +func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) { + dbType, err := mw.next.Type() + return dbType, mw.sanitize(err) +} + +func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { + username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration) + return username, password, mw.sanitize(err) +} + +func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { + return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration)) +} + +func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { + return mw.sanitize(mw.next.RevokeUser(ctx, statements, username)) +} + +func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) { + conf, err = mw.next.RotateRootCredentials(ctx, statements) + return conf, mw.sanitize(err) +} + +func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := mw.Init(ctx, conf, verifyConnection) + return err +} + +func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) { + saveConf, err = mw.next.Init(ctx, conf, verifyConnection) + return saveConf, mw.sanitize(err) +} + +func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) { + return mw.sanitize(mw.next.Close()) +} + +// sanitize +func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error { + if err == nil { + return nil + } + if errwrap.ContainsType(err, new(url.Error)) { + return errors.New("unable to parse connection url") + } + if mw.secretsFn != nil { + for k, v := range mw.secretsFn() { + if k == "" { + continue + } + err = errors.New(strings.Replace(err.Error(), k, v.(string), -1)) + } + } + return err +} diff --git a/builtin/logical/database/dbplugin/grpc_transport.go b/builtin/logical/database/dbplugin/grpc_transport.go index 735d9e5b88..1b5267e8f3 100644 --- a/builtin/logical/database/dbplugin/grpc_transport.go +++ b/builtin/logical/database/dbplugin/grpc_transport.go @@ -7,6 +7,8 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/vault/helper/pluginutil" @@ -61,16 +63,51 @@ func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*E return &Empty{}, err } -func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { - config := map[string]interface{}{} +func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) { + resp, err := s.impl.RotateRootCredentials(ctx, req.Statements) + if err != nil { + return nil, err + } + + respConfig, err := json.Marshal(resp) + if err != nil { + return nil, err + } + + return &RotateRootCredentialsResponse{ + Config: respConfig, + }, err +} + +func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { + _, err := s.Init(ctx, &InitRequest{ + Config: req.Config, + VerifyConnection: req.VerifyConnection, + }) + return &Empty{}, err +} + +func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) { + config := map[string]interface{}{} err := json.Unmarshal(req.Config, &config) if err != nil { return nil, err } - err = s.impl.Initialize(ctx, config, req.VerifyConnection) - return &Empty{}, err + resp, err := s.impl.Init(ctx, config, req.VerifyConnection) + if err != nil { + return nil, err + } + + respConfig, err := json.Marshal(resp) + if err != nil { + return nil, err + } + + return &InitResponse{ + Config: respConfig, + }, err } func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) { @@ -87,7 +124,7 @@ type gRPCClient struct { doneCtx context.Context } -func (c gRPCClient) Type() (string, error) { +func (c *gRPCClient) Type() (string, error) { resp, err := c.client.Type(c.doneCtx, &Empty{}) if err != nil { return "", err @@ -96,7 +133,7 @@ func (c gRPCClient) Type() (string, error) { return resp.Type, err } -func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (c *gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { t, err := ptypes.TimestampProto(expiration) if err != nil { return "", "", err @@ -172,10 +209,40 @@ func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, user return nil } -func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error { - configRaw, err := json.Marshal(config) +func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) { + ctx, cancel := context.WithCancel(ctx) + quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx) + defer close(quitCh) + defer cancel() + + resp, err := c.client.RotateRootCredentials(ctx, &RotateRootCredentialsRequest{ + Statements: statements, + }) + if err != nil { - return err + if c.doneCtx.Err() != nil { + return nil, ErrPluginShutdown + } + + return nil, err + } + + if err := json.Unmarshal(resp.Config, &conf); err != nil { + return nil, err + } + + return conf, nil +} + +func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := c.Init(ctx, conf, verifyConnection) + return err +} + +func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { + configRaw, err := json.Marshal(conf) + if err != nil { + return nil, err } ctx, cancel := context.WithCancel(ctx) @@ -183,19 +250,33 @@ func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface defer close(quitCh) defer cancel() - _, err = c.client.Initialize(ctx, &InitializeRequest{ + resp, err := c.client.Init(ctx, &InitRequest{ Config: configRaw, VerifyConnection: verifyConnection, }) if err != nil { - if c.doneCtx.Err() != nil { - return ErrPluginShutdown + // Fall back to old call if not implemented + grpcStatus, ok := status.FromError(err) + if ok && grpcStatus.Code() == codes.Unimplemented { + _, err = c.client.Initialize(ctx, &InitializeRequest{ + Config: configRaw, + VerifyConnection: verifyConnection, + }) + if err == nil { + return conf, nil + } } - return err + if c.doneCtx.Err() != nil { + return nil, ErrPluginShutdown + } + return nil, err } - return nil + if err := json.Unmarshal(resp.Config, &conf); err != nil { + return nil, err + } + return conf, nil } func (c *gRPCClient) Close() error { diff --git a/builtin/logical/database/dbplugin/netrpc_transport.go b/builtin/logical/database/dbplugin/netrpc_transport.go index 6f6f3a5bfe..25cbc97967 100644 --- a/builtin/logical/database/dbplugin/netrpc_transport.go +++ b/builtin/logical/database/dbplugin/netrpc_transport.go @@ -2,8 +2,10 @@ package dbplugin import ( "context" + "encoding/json" "fmt" "net/rpc" + "strings" "time" ) @@ -37,8 +39,28 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *str return err } +func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredentialsRequestRPC, resp *RotateRootCredentialsResponse) error { + config, err := ds.impl.RotateRootCredentials(context.Background(), args.Statements) + if err != nil { + return err + } + resp.Config, err = json.Marshal(config) + return err +} + func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error { - err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection) + return ds.Init(&InitRequestRPC{ + Config: args.Config, + VerifyConnection: args.VerifyConnection, + }, &InitResponse{}) +} + +func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error { + config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection) + if err != nil { + return err + } + resp.Config, err = json.Marshal(config) return err } @@ -81,9 +103,7 @@ func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements State Expiration: expiration, } - err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) - - return err + return dr.client.Call("Plugin.RenewUser", req, &struct{}{}) } func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error { @@ -92,26 +112,55 @@ func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Stat Username: username, } - err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + return dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) +} - return err +func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, statements []string) (saveConf map[string]interface{}, err error) { + req := RotateRootCredentialsRequestRPC{ + Statements: statements, + } + + var resp RotateRootCredentialsResponse + err = dr.client.Call("Plugin.RotateRootCredentials", req, &resp) + + err = json.Unmarshal(resp.Config, &saveConf) + return saveConf, err } func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error { - req := InitializeRequestRPC{ + _, err := dr.Init(nil, conf, verifyConnection) + return err +} + +func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) { + req := InitRequestRPC{ Config: conf, VerifyConnection: verifyConnection, } - err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) + var resp InitResponse + err = dr.client.Call("Plugin.Init", req, &resp) + if err != nil { + if strings.Contains(err.Error(), "can't find method Plugin.Init") { + req := InitializeRequestRPC{ + Config: conf, + VerifyConnection: verifyConnection, + } - return err + err = dr.client.Call("Plugin.Initialize", req, &struct{}{}) + if err == nil { + return conf, nil + } + } + return nil, err + } + + err = json.Unmarshal(resp.Config, &saveConf) + return saveConf, err } func (dr *databasePluginRPCClient) Close() error { - err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) - - return err + return dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) } // ---- RPC Request Args Domain ---- @@ -121,6 +170,11 @@ type InitializeRequestRPC struct { VerifyConnection bool } +type InitRequestRPC struct { + Config map[string]interface{} + VerifyConnection bool +} + type CreateUserRequestRPC struct { Statements Statements UsernameConfig UsernameConfig @@ -137,3 +191,7 @@ type RevokeUserRequestRPC struct { Statements Statements Username string } + +type RotateRootCredentialsRequestRPC struct { + Statements []string +} diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index f776049a6b..a447e24262 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -8,6 +8,7 @@ import ( "google.golang.org/grpc" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" log "github.com/mgutz/logxi/v1" @@ -20,8 +21,13 @@ type Database interface { RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error RevokeUser(ctx context.Context, statements Statements, username string) error - Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error + RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error) + + Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error) Close() error + + // DEPRECATED, will be removed in 0.12 + Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error) } // PluginFactory is used to build plugin database types. It wraps the database @@ -40,7 +46,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu // from the pluginRunner. Then cast it to a Database. dbRaw, err := pluginRunner.BuiltinFactory() if err != nil { - return nil, fmt.Errorf("error getting plugin type: %s", err) + return nil, errwrap.Wrapf("error initializing plugin: {{err}}", err) } var ok bool @@ -71,7 +77,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu typeStr, err := db.Type() if err != nil { - return nil, fmt.Errorf("error getting plugin type: %s", err) + return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err) } // Wrap with metrics middleware @@ -113,7 +119,11 @@ type DatabasePlugin struct { } func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { - return &databasePluginRPCServer{impl: d.impl}, nil + impl := &DatabaseErrorSanitizerMiddleware{ + next: d.impl, + } + + return &databasePluginRPCServer{impl: impl}, nil } func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { @@ -121,7 +131,11 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e } func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { - RegisterDatabaseServer(s, &gRPCServer{impl: d.impl}) + impl := &DatabaseErrorSanitizerMiddleware{ + next: d.impl, + } + + RegisterDatabaseServer(s, &gRPCServer{impl: impl}) return nil } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 1547fcf908..f33f553e30 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -61,6 +61,17 @@ func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statement delete(m.users, username) return nil } +func (m *mockPlugin) RotateRootCredentials(_ context.Context, statements []string) (map[string]interface{}, error) { + return nil, nil +} +func (m *mockPlugin) Init(_ context.Context, conf map[string]interface{}, _ bool) (map[string]interface{}, error) { + err := errors.New("err") + if len(conf) != 1 { + return nil, err + } + + return conf, nil +} func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error { err := errors.New("err") if len(conf) != 1 { @@ -132,7 +143,7 @@ func TestPlugin_NetRPC_Main(t *testing.T) { plugin.Serve(serveConf) } -func TestPlugin_Initialize(t *testing.T) { +func TestPlugin_Init(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() @@ -145,7 +156,7 @@ func TestPlugin_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(context.Background(), connectionDetails, true) + _, err = dbRaw.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -170,7 +181,7 @@ func TestPlugin_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -209,7 +220,7 @@ func TestPlugin_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -243,7 +254,7 @@ func TestPlugin_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -272,7 +283,7 @@ func TestPlugin_RevokeUser(t *testing.T) { } // Test the code is still compatible with an old netRPC plugin -func TestPlugin_NetRPC_Initialize(t *testing.T) { +func TestPlugin_NetRPC_Init(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() @@ -285,7 +296,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(context.Background(), connectionDetails, true) + _, err = dbRaw.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -310,7 +321,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -349,7 +360,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -383,7 +394,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f486788bbd..e81766e12a 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/fatih/structs" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -24,6 +25,8 @@ type DatabaseConfig struct { // by each database type. ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"` + + RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"` } // pathResetConnection configures a path to reset a plugin. @@ -55,16 +58,13 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { return logical.ErrorResponse(respErrEmptyName), nil } - // Grab the mutex lock - b.Lock() - defer b.Unlock() - // Close plugin and delete the entry in the connections cache. - b.clearConnection(name) + if err := b.ClearConnection(name); err != nil { + return nil, err + } // Execute plugin again, we don't need the object so throw away. - _, err := b.createDBObj(ctx, req.Storage, name) - if err != nil { + if _, err := b.GetConnection(ctx, req.Storage, name); err != nil { return nil, err } @@ -103,6 +103,14 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { allowed to get creds from this database connection. If empty no roles are allowed. If "*" all roles are allowed.`, }, + + "root_rotation_statements": &framework.FieldSchema{ + Type: framework.TypeStringSlice, + Description: `Specifies the database statements to be executed + to rotate the root user's credentials. See the plugin's API + page for more information on support and formatting for this + parameter.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -179,16 +187,8 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return nil, errors.New("failed to delete connection configuration") } - b.Lock() - defer b.Unlock() - - if _, ok := b.connections[name]; ok { - err = b.connections[name].Close() - if err != nil { - return nil, err - } - - delete(b.connections, name) + if err := b.ClearConnection(name); err != nil { + return nil, err } return nil, nil @@ -210,8 +210,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } verifyConnection := data.Get("verify_connection").(bool) - allowedRoles := data.Get("allowed_roles").([]string) + rootRotationStatements := data.Get("root_rotation_statements").([]string) // Remove these entries from the data before we store it keyed under // ConnectionDetails. @@ -219,35 +219,45 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { delete(data.Raw, "plugin_name") delete(data.Raw, "allowed_roles") delete(data.Raw, "verify_connection") + delete(data.Raw, "root_rotation_statements") - config := &DatabaseConfig{ - ConnectionDetails: data.Raw, - PluginName: pluginName, - AllowedRoles: allowedRoles, - } - - db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger) + // Create a database plugin and initialize it. This instance is not + // going to be used and is initialized just to ensure all parameters + // are valid and the connection is verified, if requested. + db, err := dbplugin.PluginFactory(ctx, pluginName, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } - - err = db.Initialize(ctx, config.ConnectionDetails, verifyConnection) + connDetails, err := db.Init(ctx, data.Raw, verifyConnection) if err != nil { db.Close() return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } - // Grab the mutex lock b.Lock() defer b.Unlock() // Close and remove the old connection b.clearConnection(name) - // Save the new connection - b.connections[name] = db + id, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + b.connections[name] = &dbPluginInstance{ + Database: db, + name: name, + id: id, + } // Store it + config := &DatabaseConfig{ + ConnectionDetails: connDetails, + PluginName: pluginName, + AllowedRoles: allowedRoles, + RootCredentialsRotateStatements: rootRotationStatements, + } entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) if err != nil { return nil, err diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index b32929348a..06328a895b 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -54,26 +54,15 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return nil, logical.ErrPermissionDenied } - // Grab the read lock - b.RLock() - unlockFunc := b.RUnlock - // Get the Database object - db, ok := b.getDBObj(role.DBName) - if !ok { - // Upgrade lock - b.RUnlock() - b.Lock() - unlockFunc = b.Unlock - - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err) - } + db, err := b.GetConnection(ctx, req.Storage, role.DBName) + if err != nil { + return nil, err } + db.RLock() + defer db.RUnlock() + ttl := b.System().DefaultLeaseTTL() if role.DefaultTTL != 0 { ttl = role.DefaultTTL @@ -96,8 +85,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Create the user username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration) if err != nil { - unlockFunc() - b.closeIfShutdown(role.DBName, err) + b.CloseIfShutdown(db, err) return nil, err } @@ -109,8 +97,6 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { "role": name, }) resp.Secret.TTL = ttl - - unlockFunc() return resp, nil } } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 4762082f9c..7fe30f2eb9 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -36,26 +36,26 @@ func pathRoles(b *databaseBackend) *framework.Path { Description: "Name of the database this role acts on.", }, "creation_statements": { - Type: framework.TypeString, + Type: framework.TypeStringSlice, Description: `Specifies the database statements executed to create and configure a user. See the plugin's API page for more information on support and formatting for this parameter.`, }, "revocation_statements": { - Type: framework.TypeString, + Type: framework.TypeStringSlice, Description: `Specifies the database statements to be executed to revoke a user. See the plugin's API page for more information on support and formatting for this parameter.`, }, "renew_statements": { - Type: framework.TypeString, + Type: framework.TypeStringSlice, Description: `Specifies the database statements to be executed to renew a user. Not every plugin type will support this functionality. See the plugin's API page for more information on support and formatting for this parameter. `, }, "rollback_statements": { - Type: framework.TypeString, + Type: framework.TypeStringSlice, Description: `Specifies the database statements to be executed rollback a create operation in the event of an error. Not every plugin type will support this functionality. See the plugin's @@ -109,10 +109,10 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc { return &logical.Response{ Data: map[string]interface{}{ "db_name": role.DBName, - "creation_statements": role.Statements.CreationStatements, - "revocation_statements": role.Statements.RevocationStatements, - "rollback_statements": role.Statements.RollbackStatements, - "renew_statements": role.Statements.RenewStatements, + "creation_statements": role.Statements.Creation, + "revocation_statements": role.Statements.Revocation, + "rollback_statements": role.Statements.Rollback, + "renew_statements": role.Statements.Renewal, "default_ttl": role.DefaultTTL.Seconds(), "max_ttl": role.MaxTTL.Seconds(), }, @@ -144,10 +144,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { } // Get statements - creationStmts := data.Get("creation_statements").(string) - revocationStmts := data.Get("revocation_statements").(string) - rollbackStmts := data.Get("rollback_statements").(string) - renewStmts := data.Get("renew_statements").(string) + creationStmts := data.Get("creation_statements").([]string) + revocationStmts := data.Get("revocation_statements").([]string) + rollbackStmts := data.Get("rollback_statements").([]string) + renewStmts := data.Get("renew_statements").([]string) // Get TTLs defaultTTLRaw := data.Get("default_ttl").(int) @@ -156,10 +156,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { maxTTL := time.Duration(maxTTLRaw) * time.Second statements := dbplugin.Statements{ - CreationStatements: creationStmts, - RevocationStatements: revocationStmts, - RollbackStatements: rollbackStmts, - RenewStatements: renewStmts, + Creation: creationStmts, + Revocation: revocationStmts, + Rollback: rollbackStmts, + Renewal: renewStmts, } // Store it @@ -181,10 +181,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name"` + Statements dbplugin.Statements `json:"statements"` + DefaultTTL time.Duration `json:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/path_rotate_credentials.go b/builtin/logical/database/path_rotate_credentials.go new file mode 100644 index 0000000000..e8598f291e --- /dev/null +++ b/builtin/logical/database/path_rotate_credentials.go @@ -0,0 +1,80 @@ +package database + +import ( + "context" + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRotateCredentials(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "rotate-root/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this database connection", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRotateCredentialsUpdate(), + }, + + HelpSynopsis: pathCredsCreateReadHelpSyn, + HelpDescription: pathCredsCreateReadHelpDesc, + } +} + +func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc { + return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse(respErrEmptyName), nil + } + + config, err := b.DatabaseConfig(ctx, req.Storage, name) + if err != nil { + return nil, err + } + + db, err := b.GetConnection(ctx, req.Storage, name) + if err != nil { + return nil, err + } + + // Take the write lock instead of read since we are updating the + // connection + db.Lock() + defer db.Unlock() + + connectionDetails, err := db.RotateRootCredentials(ctx, config.RootCredentialsRotateStatements) + if err != nil { + return nil, err + } + + config.ConnectionDetails = connectionDetails + entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(ctx, entry); err != nil { + return nil, err + } + + if err := b.ClearConnection(name); err != nil { + return nil, err + } + + return nil, nil + } +} + +const pathRotateCredentialsUpdateHelpSyn = ` +Request to rotate the root credentials for a certain database connection. +` + +const pathRotateCredentialsUpdateHelpDesc = ` +This path attempts to rotate the root credentials for the given database. +` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 55f6220be1..754ff0e260 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -48,37 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { return nil, err } - // Grab the read lock - b.RLock() - unlockFunc := b.RUnlock - // Get the Database object - db, ok := b.getDBObj(role.DBName) - if !ok { - // Upgrade lock - b.RUnlock() - b.Lock() - unlockFunc = b.Unlock - - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err) - } + db, err := b.GetConnection(ctx, req.Storage, role.DBName) + if err != nil { + return nil, err } + db.RLock() + defer db.RUnlock() + // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { err := db.RenewUser(ctx, role.Statements, username, expireTime) if err != nil { - unlockFunc() - b.closeIfShutdown(role.DBName, err) + b.CloseIfShutdown(db, err) return nil, err } } - - unlockFunc() return resp, nil } } @@ -107,33 +93,19 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"]) } - // Grab the read lock - b.RLock() - unlockFunc := b.RUnlock - // Get our connection - db, ok := b.getDBObj(role.DBName) - if !ok { - // Upgrade lock - b.RUnlock() - b.Lock() - unlockFunc = b.Unlock - - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err) - } - } - - if err := db.RevokeUser(ctx, role.Statements, username); err != nil { - unlockFunc() - b.closeIfShutdown(role.DBName, err) + db, err := b.GetConnection(ctx, req.Storage, role.DBName) + if err != nil { return nil, err } - unlockFunc() + db.RLock() + defer db.RUnlock() + + if err := db.RevokeUser(ctx, role.Statements, username); err != nil { + b.CloseIfShutdown(db, err) + return nil, err + } return resp, nil } } diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 221784e0fc..05ad662619 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -11,27 +11,34 @@ import ( "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins" - "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" ) const ( - defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` - defaultUserDeletionCQL = `DROP USER '{{username}}';` - cassandraTypeName = "cassandra" + defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultUserDeletionCQL = `DROP USER '{{username}}';` + defaultRootCredentialRotationCQL = `ALTER USER {{username}} WITH PASSWORD '{{password}}';` + cassandraTypeName = "cassandra" ) var _ dbplugin.Database = &Cassandra{} // Cassandra is an implementation of Database interface type Cassandra struct { - connutil.ConnectionProducer + *cassandraConnectionProducer credsutil.CredentialsProducer } // New returns a new Cassandra instance func New() (interface{}, error) { + db := new() + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) + + return dbType, nil +} + +func new() *Cassandra { connProducer := &cassandraConnectionProducer{} connProducer.Type = cassandraTypeName @@ -42,12 +49,10 @@ func New() (interface{}, error) { Separator: "_", } - dbType := &Cassandra{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return &Cassandra{ + cassandraConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - - return dbType, nil } // Run instantiates a Cassandra object, and runs the RPC server for the plugin @@ -57,7 +62,7 @@ func Run(apiTLSConfig *api.TLSConfig) error { return err } - plugins.Serve(dbType.(*Cassandra), apiTLSConfig) + plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig) return nil } @@ -83,19 +88,22 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen c.Lock() defer c.Unlock() + statements = dbutil.StatementCompatibilityHelper(statements) + // Get the connection session, err := c.getConnection(ctx) if err != nil { return "", "", err } - creationCQL := statements.CreationStatements - if creationCQL == "" { - creationCQL = defaultUserCreationCQL + creationCQL := statements.Creation + if len(creationCQL) == 0 { + creationCQL = []string{defaultUserCreationCQL} } - rollbackCQL := statements.RollbackStatements - if rollbackCQL == "" { - rollbackCQL = defaultUserDeletionCQL + + rollbackCQL := statements.Rollback + if len(rollbackCQL) == 0 { + rollbackCQL = []string{defaultUserDeletionCQL} } username, err = c.GenerateUsername(usernameConfig) @@ -112,28 +120,32 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen } // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - err = session.Query(dbutil.QueryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - session.Query(dbutil.QueryHelper(query, map[string]string{ - "username": username, - })).Exec() + for _, stmt := range creationCQL { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err = session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + if err != nil { + for _, stmt := range rollbackCQL { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + })).Exec() + } + } + return "", "", err } - return "", "", err } } @@ -152,29 +164,79 @@ func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statemen c.Lock() defer c.Unlock() + statements = dbutil.StatementCompatibilityHelper(statements) + session, err := c.getConnection(ctx) if err != nil { return err } - revocationCQL := statements.RevocationStatements - if revocationCQL == "" { - revocationCQL = defaultUserDeletionCQL + revocationCQL := statements.Revocation + if len(revocationCQL) == 0 { + revocationCQL = []string{defaultUserDeletionCQL} } var result *multierror.Error - for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue + for _, stmt := range revocationCQL { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err := session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + })).Exec() + + result = multierror.Append(result, err) } - - err := session.Query(dbutil.QueryHelper(query, map[string]string{ - "username": username, - })).Exec() - - result = multierror.Append(result, err) } return result.ErrorOrNil() } + +func (c *Cassandra) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { + // Grab the lock + c.Lock() + defer c.Unlock() + + session, err := c.getConnection(ctx) + if err != nil { + return nil, err + } + + rotateCQL := statements + if len(rotateCQL) == 0 { + rotateCQL = []string{defaultRootCredentialRotationCQL} + } + + password, err := c.GeneratePassword() + if err != nil { + return nil, err + } + + var result *multierror.Error + for _, stmt := range rotateCQL { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err := session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": c.Username, + "password": password, + })).Exec() + + result = multierror.Append(result, err) + } + } + + err = result.ErrorOrNil() + if err != nil { + return nil, err + } + + c.rawConfig["password"] = password + return c.rawConfig, nil +} diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go index 0409f22441..5bb8d9c2cd 100644 --- a/plugins/database/cassandra/cassandra_test.go +++ b/plugins/database/cassandra/cassandra_test.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gocql/gocql" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" dockertest "gopkg.in/ory-am/dockertest.v3" ) @@ -60,7 +61,7 @@ func prepareCassandraTestContainer(t *testing.T) (func(), string, int) { session, err := clusterConfig.CreateSession() if err != nil { - return fmt.Errorf("error creating session: %s", err) + return errwrap.Wrapf("error creating session: {{err}}", err) } defer session.Close() return nil @@ -86,16 +87,13 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) - connProducer := db.ConnectionProducer.(*cassandraConnectionProducer) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -113,7 +111,7 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": "4", } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -134,15 +132,14 @@ func TestCassandra_CreateUser(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testCassandraRole, + Creation: []string{testCassandraRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -175,15 +172,14 @@ func TestMyCassandra_RenewUser(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testCassandraRole, + Creation: []string{testCassandraRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -221,15 +217,14 @@ func TestCassandra_RevokeUser(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testCassandraRole, + Creation: []string{testCassandraRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -268,7 +263,7 @@ func testCredsExist(t testing.TB, address string, port int, username, password s session, err := clusterConfig.CreateSession() if err != nil { - return fmt.Errorf("error creating session: %s", err) + return errwrap.Wrapf("error creating session: {{err}}", err) } defer session.Close() return nil diff --git a/plugins/database/cassandra/connection_producer.go b/plugins/database/cassandra/connection_producer.go index ff4cae79f0..700f963fe2 100644 --- a/plugins/database/cassandra/connection_producer.go +++ b/plugins/database/cassandra/connection_producer.go @@ -11,6 +11,7 @@ import ( "github.com/mitchellh/mapstructure" "github.com/gocql/gocql" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/tlsutil" @@ -37,6 +38,7 @@ type cassandraConnectionProducer struct { certificate string privateKey string issuingCA string + rawConfig map[string]interface{} Initialized bool Type string @@ -45,12 +47,19 @@ type cassandraConnectionProducer struct { } func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := c.Init(ctx, conf, verifyConnection) + return err +} + +func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { c.Lock() defer c.Unlock() + c.rawConfig = conf + err := mapstructure.WeakDecode(conf, c) if err != nil { - return err + return nil, err } if c.ConnectTimeoutRaw == nil { @@ -58,16 +67,16 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s } c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) if err != nil { - return fmt.Errorf("invalid connect_timeout: %s", err) + return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err) } switch { case len(c.Hosts) == 0: - return fmt.Errorf("hosts cannot be empty") + return nil, fmt.Errorf("hosts cannot be empty") case len(c.Username) == 0: - return fmt.Errorf("username cannot be empty") + return nil, fmt.Errorf("username cannot be empty") case len(c.Password) == 0: - return fmt.Errorf("password cannot be empty") + return nil, fmt.Errorf("password cannot be empty") } var certBundle *certutil.CertBundle @@ -76,11 +85,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s case len(c.PemJSON) != 0: parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) if err != nil { - return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err) + return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err) } certBundle, err = parsedCertBundle.ToCertBundle() if err != nil { - return fmt.Errorf("Error marshaling PEM information: %s", err) + return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) } c.certificate = certBundle.Certificate c.privateKey = certBundle.PrivateKey @@ -90,11 +99,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s case len(c.PemBundle) != 0: parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) if err != nil { - return fmt.Errorf("Error parsing the given PEM information: %s", err) + return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err) } certBundle, err = parsedCertBundle.ToCertBundle() if err != nil { - return fmt.Errorf("Error marshaling PEM information: %s", err) + return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) } c.certificate = certBundle.Certificate c.privateKey = certBundle.PrivateKey @@ -108,11 +117,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s if verifyConnection { if _, err := c.Connection(ctx); err != nil { - return fmt.Errorf("error verifying connection: %s", err) + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) } } - return nil + return conf, nil } func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) { @@ -186,12 +195,12 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { parsedCertBundle, err := certBundle.ToParsedCertBundle() if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err) } tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err) } tlsConfig.InsecureSkipVerify = c.InsecureTLS @@ -215,7 +224,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { session, err := clusterConfig.CreateSession() if err != nil { - return nil, fmt.Errorf("error creating session: %s", err) + return nil, errwrap.Wrapf("error creating session: {{err}}", err) } // Set consistency @@ -231,8 +240,16 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { // Verify the info err = session.Query(`LIST ALL`).Exec() if err != nil { - return nil, fmt.Errorf("error validating connection info: %s", err) + return nil, errwrap.Wrapf("error validating connection info: {{err}}", err) } return session, nil } + +func (c *cassandraConnectionProducer) secretValues() map[string]interface{} { + return map[string]interface{}{ + c.Password: "[password]", + c.PemBundle: "[pem_bundle]", + c.PemJSON: "[pem_json]", + } +} diff --git a/plugins/database/cassandra/test-fixtures/cassandra.yaml b/plugins/database/cassandra/test-fixtures/cassandra.yaml index 82aa35dd4d..e8535107c8 100644 --- a/plugins/database/cassandra/test-fixtures/cassandra.yaml +++ b/plugins/database/cassandra/test-fixtures/cassandra.yaml @@ -572,7 +572,7 @@ ssl_storage_port: 7001 # # Setting listen_address to 0.0.0.0 is always wrong. # -listen_address: 172.17.0.5 +listen_address: 172.17.0.2 # Set listen_address OR listen_interface, not both. Interfaces must correspond # to a single address, IP aliasing is not supported. diff --git a/plugins/database/hana/hana.go b/plugins/database/hana/hana.go index 5411505c8b..1fdafe77ad 100644 --- a/plugins/database/hana/hana.go +++ b/plugins/database/hana/hana.go @@ -3,6 +3,7 @@ package hana import ( "context" "database/sql" + "errors" "fmt" "strings" "time" @@ -23,7 +24,7 @@ const ( // HANA is an implementation of Database interface type HANA struct { - connutil.ConnectionProducer + *connutil.SQLConnectionProducer credsutil.CredentialsProducer } @@ -31,6 +32,14 @@ var _ dbplugin.Database = &HANA{} // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { + db := new() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + + return dbType, nil +} + +func new() *HANA { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = hanaTypeName @@ -41,12 +50,10 @@ func New() (interface{}, error) { Separator: "_", } - dbType := &HANA{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return &HANA{ + SQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - - return dbType, nil } // Run instantiates a HANA object, and runs the RPC server for the plugin @@ -56,7 +63,7 @@ func Run(apiTLSConfig *api.TLSConfig) error { return err } - plugins.Serve(dbType.(*HANA), apiTLSConfig) + plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig) return nil } @@ -82,13 +89,15 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u h.Lock() defer h.Unlock() + statements = dbutil.StatementCompatibilityHelper(statements) + // Get the connection db, err := h.getConnection(ctx) if err != nil { return "", "", err } - if statements.CreationStatements == "" { + if len(statements.Creation) == 0 { return "", "", dbutil.ErrEmptyCreationStatement } @@ -127,23 +136,25 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } + for _, stmt := range statements.Creation { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expirationStr, - })) - if err != nil { - return "", "", err - } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return "", "", err + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return "", "", err + } } } @@ -157,6 +168,8 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u // Renewing hana user just means altering user's valid until property func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { + statements = dbutil.StatementCompatibilityHelper(statements) + // Get connection db, err := h.getConnection(ctx) if err != nil { @@ -197,8 +210,10 @@ func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, us // Revoking hana user will deactivate user and try to perform a soft drop func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { + statements = dbutil.StatementCompatibilityHelper(statements) + // default revoke will be a soft drop on user - if statements.RevocationStatements == "" { + if len(statements.Revocation) == 0 { return h.revokeUserDefault(ctx, username) } @@ -216,30 +231,27 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } + for _, stmt := range statements.Revocation { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return err + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return err + } } } - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil + return tx.Commit() } func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { @@ -284,3 +296,8 @@ func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { return nil } + +// RotateRootCredentials is not currently supported on HANA +func (h *HANA) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { + return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine") +} diff --git a/plugins/database/hana/hana_test.go b/plugins/database/hana/hana_test.go index 01b5194776..cb352520bd 100644 --- a/plugins/database/hana/hana_test.go +++ b/plugins/database/hana/hana_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/plugins/helper/database/connutil" ) func TestHANA_Initialize(t *testing.T) { @@ -23,16 +22,13 @@ func TestHANA_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*HANA) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initialized") } @@ -53,10 +49,8 @@ func TestHANA_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*HANA) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -73,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) { } statements := dbplugin.Statements{ - CreationStatements: testHANARole, + Creation: []string{testHANARole}, } username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) @@ -96,16 +90,14 @@ func TestHANA_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*HANA) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testHANARole, + Creation: []string{testHANARole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -139,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - statements.RevocationStatements = testHANADrop + statements.Revocation = []string{testHANADrop} err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/mongodb/connection_producer.go b/plugins/database/mongodb/connection_producer.go index 88d14e6e04..a4d394f9e7 100644 --- a/plugins/database/mongodb/connection_producer.go +++ b/plugins/database/mongodb/connection_producer.go @@ -14,7 +14,9 @@ import ( "sync" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/mitchellh/mapstructure" "gopkg.in/mgo.v2" @@ -25,28 +27,43 @@ import ( type mongoDBConnectionProducer struct { ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` Initialized bool + RawConfig map[string]interface{} Type string session *mgo.Session safe *mgo.Safe sync.Mutex } -// Initialize parses connection configuration. func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := c.Init(ctx, conf, verifyConnection) + return err +} + +// Initialize parses connection configuration. +func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { c.Lock() defer c.Unlock() + c.RawConfig = conf + err := mapstructure.WeakDecode(conf, c) if err != nil { - return err + return nil, err } if len(c.ConnectionURL) == 0 { - return fmt.Errorf("connection_url cannot be empty") + return nil, fmt.Errorf("connection_url cannot be empty") } + c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{ + "username": c.Username, + "password": c.Password, + }) + if c.WriteConcern != "" { input := c.WriteConcern @@ -60,13 +77,13 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str concern := &mgo.Safe{} err = json.Unmarshal([]byte(input), concern) if err != nil { - return fmt.Errorf("error mashalling write_concern: %s", err) + return nil, errwrap.Wrapf("error mashalling write_concern: {{err}}", err) } // Guard against empty, non-nil mgo.Safe object; we don't want to pass that // into mgo.SetSafe in Connection(). if (mgo.Safe{} == *concern) { - return fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields") + return nil, fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields") } c.safe = concern } @@ -77,15 +94,15 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str if verifyConnection { if _, err := c.Connection(ctx); err != nil { - return fmt.Errorf("error verifying connection: %s", err) + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) } if err := c.session.Ping(); err != nil { - return fmt.Errorf("error verifying connection: %s", err) + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) } } - return nil + return conf, nil } // Connection creates or returns an existing a database connection. If the session fails @@ -203,3 +220,9 @@ func parseMongoURL(rawURL string) (*mgo.DialInfo, error) { return &info, nil } + +func (c *mongoDBConnectionProducer) secretValues() map[string]interface{} { + return map[string]interface{}{ + c.Password: "[password]", + } +} diff --git a/plugins/database/mongodb/mongodb.go b/plugins/database/mongodb/mongodb.go index b0e0a26208..61ca9c51bf 100644 --- a/plugins/database/mongodb/mongodb.go +++ b/plugins/database/mongodb/mongodb.go @@ -2,6 +2,7 @@ package mongodb import ( "context" + "errors" "io" "strings" "time" @@ -14,7 +15,6 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/plugins" - "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" "gopkg.in/mgo.v2" @@ -24,7 +24,7 @@ const mongoDBTypeName = "mongodb" // MongoDB is an implementation of Database interface type MongoDB struct { - connutil.ConnectionProducer + *mongoDBConnectionProducer credsutil.CredentialsProducer } @@ -32,6 +32,12 @@ var _ dbplugin.Database = &MongoDB{} // New returns a new MongoDB instance func New() (interface{}, error) { + db := new() + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) + return dbType, nil +} + +func new() *MongoDB { connProducer := &mongoDBConnectionProducer{} connProducer.Type = mongoDBTypeName @@ -42,11 +48,10 @@ func New() (interface{}, error) { Separator: "-", } - dbType := &MongoDB{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return &MongoDB{ + mongoDBConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - return dbType, nil } // Run instantiates a MongoDB object, and runs the RPC server for the plugin @@ -88,7 +93,9 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements m.Lock() defer m.Unlock() - if statements.CreationStatements == "" { + statements = dbutil.StatementCompatibilityHelper(statements) + + if len(statements.Creation) == 0 { return "", "", dbutil.ErrEmptyCreationStatement } @@ -109,7 +116,7 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements // Unmarshal statements.CreationStatements into mongodbRoles var mongoCS mongoDBStatement - err = json.Unmarshal([]byte(statements.CreationStatements), &mongoCS) + err = json.Unmarshal([]byte(statements.Creation[0]), &mongoCS) if err != nil { return "", "", err } @@ -158,15 +165,22 @@ func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, // RevokeUser drops the specified user from the authentication database. If none is provided // in the revocation statement, the default "admin" authentication database will be assumed. func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { + statements = dbutil.StatementCompatibilityHelper(statements) + session, err := m.getConnection(ctx) if err != nil { return err } // If no revocation statements provided, pass in empty JSON - revocationStatement := statements.RevocationStatements - if revocationStatement == "" { + var revocationStatement string + switch len(statements.Revocation) { + case 0: revocationStatement = `{}` + case 1: + revocationStatement = statements.Revocation[0] + default: + return fmt.Errorf("expected 0 or 1 revocation statements, got %d", len(statements.Revocation)) } // Unmarshal revocation statements into mongodbRoles @@ -186,7 +200,7 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements switch { case err == nil, err == mgo.ErrNotFound: case err == io.EOF, strings.Contains(err.Error(), "EOF"): - if err := m.ConnectionProducer.Close(); err != nil { + if err := m.Close(); err != nil { return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) } session, err := m.getConnection(ctx) @@ -203,3 +217,8 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements return nil } + +// RotateRootCredentials is not currently supported on MongoDB +func (m *MongoDB) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { + return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine") +} diff --git a/plugins/database/mongodb/mongodb_test.go b/plugins/database/mongodb/mongodb_test.go index a6895e1814..9bb7d9d034 100644 --- a/plugins/database/mongodb/mongodb_test.go +++ b/plugins/database/mongodb/mongodb_test.go @@ -73,19 +73,13 @@ func TestMongoDB_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer) - - err = db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initialized") } @@ -103,18 +97,14 @@ func TestMongoDB_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - err = db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testMongoDBRole, + Creation: []string{testMongoDBRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -141,18 +131,14 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { "write_concern": testMongoDBWriteConcern, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - err = db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testMongoDBRole, + Creation: []string{testMongoDBRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -178,18 +164,14 @@ func TestMongoDB_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - err = db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testMongoDBRole, + Creation: []string{testMongoDBRole}, } usernameConfig := dbplugin.UsernameConfig{ diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index c0de592de5..84f7e1462a 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -3,11 +3,13 @@ package mssql import ( "context" "database/sql" + "errors" "fmt" "strings" "time" _ "github.com/denisenkom/go-mssqldb" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" @@ -23,11 +25,19 @@ var _ dbplugin.Database = &MSSQL{} // MSSQL is an implementation of Database interface type MSSQL struct { - connutil.ConnectionProducer + *connutil.SQLConnectionProducer credsutil.CredentialsProducer } func New() (interface{}, error) { + db := new() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + + return dbType, nil +} + +func new() *MSSQL { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = msSQLTypeName @@ -38,12 +48,10 @@ func New() (interface{}, error) { Separator: "-", } - dbType := &MSSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return &MSSQL{ + SQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - - return dbType, nil } // Run instantiates a MSSQL object, and runs the RPC server for the plugin @@ -53,7 +61,7 @@ func Run(apiTLSConfig *api.TLSConfig) error { return err } - plugins.Serve(dbType.(*MSSQL), apiTLSConfig) + plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig) return nil } @@ -79,13 +87,15 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, m.Lock() defer m.Unlock() + statements = dbutil.StatementCompatibilityHelper(statements) + // Get the connection db, err := m.getConnection(ctx) if err != nil { return "", "", err } - if statements.CreationStatements == "" { + if len(statements.Creation) == 0 { return "", "", dbutil.ErrEmptyCreationStatement } @@ -112,23 +122,25 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } + for _, stmt := range statements.Creation { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expirationStr, - })) - if err != nil { - return "", "", err - } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return "", "", err + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return "", "", err + } } } @@ -150,7 +162,9 @@ func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, u // then kill pending connections from that user, and finally drop the user and login from the // database instance. func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { - if statements.RevocationStatements == "" { + statements = dbutil.StatementCompatibilityHelper(statements) + + if len(statements.Revocation) == 0 { return m.revokeUserDefault(ctx, username) } @@ -168,21 +182,23 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } + for _, stmt := range statements.Revocation { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return err + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return err + } } } @@ -283,10 +299,10 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { // can't drop if not all database users are dropped if rows.Err() != nil { - return fmt.Errorf("could not generate sql statements for all rows: %s", rows.Err()) + return errwrap.Wrapf("could not generate sql statements for all rows: {{err}}", rows.Err()) } if lastStmtError != nil { - return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + return errwrap.Wrapf("could not perform all sql statements: {{err}}", lastStmtError) } // Drop this login @@ -302,6 +318,70 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { return nil } +func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { + m.Lock() + defer m.Unlock() + + if len(m.Username) == 0 || len(m.Password) == 0 { + return nil, errors.New("username and password are required to rotate") + } + + rotateStatents := statements + if len(rotateStatents) == 0 { + rotateStatents = []string{rotateRootCredentialsSQL} + } + + db, err := m.getConnection(ctx) + if err != nil { + return nil, err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + tx.Rollback() + }() + + password, err := m.GeneratePassword() + if err != nil { + return nil, err + } + + for _, stmt := range rotateStatents { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "username": m.Username, + "password": password, + })) + if err != nil { + return nil, err + } + + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return nil, err + } + } + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + if err := db.Close(); err != nil { + return nil, err + } + + m.RawConfig["password"] = password + return m.RawConfig, nil +} + const dropUserSQL = ` USE [%s] IF EXISTS @@ -322,3 +402,7 @@ BEGIN DROP LOGIN [%s] END ` + +const rotateRootCredentialsSQL = ` +ALTER LOGIN [%s] WITH PASSWORD = '%s' +` diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index a1610810f2..f2ad54c993 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/plugins/helper/database/connutil" ) var ( @@ -28,16 +27,13 @@ func TestMSSQL_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*MSSQL) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -52,7 +48,7 @@ func TestMSSQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -68,9 +64,8 @@ func TestMSSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*MSSQL) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -87,7 +82,7 @@ func TestMSSQL_CreateUser(t *testing.T) { } statements := dbplugin.Statements{ - CreationStatements: testMSSQLRole, + Creation: []string{testMSSQLRole}, } username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) @@ -110,15 +105,14 @@ func TestMSSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*MSSQL) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testMSSQLRole, + Creation: []string{testMSSQLRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -155,7 +149,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { } // Test custom revoke statement - statements.RevocationStatements = testMSSQLDrop + statements.Revocation = []string{testMSSQLDrop} err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 38c928c35a..00fe475045 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -3,6 +3,7 @@ package mysql import ( "context" "database/sql" + "errors" "strings" "time" @@ -21,6 +22,11 @@ const ( REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; DROP USER '{{name}}'@'%' ` + + defaultMySQLRotateRootCredentialsSQL = ` + ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}'; + ` + mySQLTypeName = "mysql" ) @@ -34,32 +40,38 @@ var ( var _ dbplugin.Database = &MySQL{} type MySQL struct { - connutil.ConnectionProducer + *connutil.SQLConnectionProducer credsutil.CredentialsProducer } // New implements builtinplugins.BuiltinFactory func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) { return func() (interface{}, error) { - connProducer := &connutil.SQLConnectionProducer{} - connProducer.Type = mySQLTypeName - - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: displayNameLen, - RoleNameLen: roleNameLen, - UsernameLen: usernameLen, - Separator: "-", - } - - dbType := &MySQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } + db := new(displayNameLen, roleNameLen, usernameLen) + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) return dbType, nil } } +func new(displayNameLen, roleNameLen, usernameLen int) *MySQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: displayNameLen, + RoleNameLen: roleNameLen, + UsernameLen: usernameLen, + Separator: "-", + } + + return &MySQL{ + SQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } +} + // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run(apiTLSConfig *api.TLSConfig) error { return runCommon(false, apiTLSConfig) @@ -82,7 +94,7 @@ func runCommon(legacy bool, apiTLSConfig *api.TLSConfig) error { return err } - plugins.Serve(dbType.(*MySQL), apiTLSConfig) + plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig) return nil } @@ -105,13 +117,15 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, m.Lock() defer m.Unlock() + statements = dbutil.StatementCompatibilityHelper(statements) + // Get the connection db, err := m.getConnection(ctx) if err != nil { return "", "", err } - if statements.CreationStatements == "" { + if len(statements.Creation) == 0 { return "", "", dbutil.ErrEmptyCreationStatement } @@ -138,38 +152,40 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - query = dbutil.QueryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expirationStr, - }) - - stmt, err := tx.PrepareContext(ctx, query) - if err != nil { - // If the error code we get back is Error 1295: This command is not - // supported in the prepared statement protocol yet, we will execute - // the statement without preparing it. This allows the caller to - // manually prepare statements, as well as run other not yet - // prepare supported commands. If there is no error when running we - // will continue to the next statement. - if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 { - _, err = tx.ExecContext(ctx, query) - if err != nil { - return "", "", err - } + for _, stmt := range statements.Creation { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { continue } + query = dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + }) - return "", "", err - } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return "", "", err + stmt, err := tx.PrepareContext(ctx, query) + if err != nil { + // If the error code we get back is Error 1295: This command is not + // supported in the prepared statement protocol yet, we will execute + // the statement without preparing it. This allows the caller to + // manually prepare statements, as well as run other not yet + // prepare supported commands. If there is no error when running we + // will continue to the next statement. + if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 { + _, err = tx.ExecContext(ctx, query) + if err != nil { + return "", "", err + } + continue + } + + return "", "", err + } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return "", "", err + } } } @@ -191,16 +207,18 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, m.Lock() defer m.Unlock() + statements = dbutil.StatementCompatibilityHelper(statements) + // Get the connection db, err := m.getConnection(ctx) if err != nil { return err } - revocationStmts := statements.RevocationStatements + revocationStmts := statements.Revocation // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmts == "" { - revocationStmts = defaultMysqlRevocationStmts + if len(revocationStmts) == 0 { + revocationStmts = []string{defaultMysqlRevocationStmts} } // Start a transaction @@ -210,21 +228,22 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, } defer tx.Rollback() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } + for _, stmt := range revocationStmts { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - // This is not a prepared statement because not all commands are supported - // 1295: This command is not supported in the prepared statement protocol yet - // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ - query = strings.Replace(query, "{{name}}", username, -1) - _, err = tx.ExecContext(ctx, query) - if err != nil { - return err + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.ExecContext(ctx, query) + if err != nil { + return err + } } - } // Commit the transaction @@ -234,3 +253,67 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, return nil } + +func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { + m.Lock() + defer m.Unlock() + + if len(m.Username) == 0 || len(m.Password) == 0 { + return nil, errors.New("username and password are required to rotate") + } + + rotateStatents := statements + if len(rotateStatents) == 0 { + rotateStatents = []string{defaultMySQLRotateRootCredentialsSQL} + } + + db, err := m.getConnection(ctx) + if err != nil { + return nil, err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + tx.Rollback() + }() + + password, err := m.GeneratePassword() + if err != nil { + return nil, err + } + + for _, stmt := range rotateStatents { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "username": m.Username, + "password": password, + })) + if err != nil { + return nil, err + } + + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return nil, err + } + } + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + if err := db.Close(); err != nil { + return nil, err + } + + m.RawConfig["password"] = password + return m.RawConfig, nil +} diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index e92a0ebdc7..ff746bb680 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" dockertest "gopkg.in/ory-am/dockertest.v3" ) @@ -104,17 +104,13 @@ func TestMySQL_Initialize(t *testing.T) { "connection_url": connURL, } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -129,7 +125,7 @@ func TestMySQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -143,11 +139,8 @@ func TestMySQL_CreateUser(t *testing.T) { "connection_url": connURL, } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -164,7 +157,7 @@ func TestMySQL_CreateUser(t *testing.T) { } statements := dbplugin.Statements{ - CreationStatements: testMySQLRoleWildCard, + Creation: []string{testMySQLRoleWildCard}, } username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) @@ -187,7 +180,7 @@ func TestMySQL_CreateUser(t *testing.T) { } // Test with a manually prepare statement - statements.CreationStatements = testMySQLRolePreparedStmt + statements.Creation = []string{testMySQLRolePreparedStmt} username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { @@ -208,11 +201,8 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { "connection_url": connURL, } - f := New(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -229,7 +219,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { } statements := dbplugin.Statements{ - CreationStatements: testMySQLRoleWildCard, + Creation: []string{testMySQLRoleWildCard}, } username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) @@ -252,6 +242,42 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { } } +func TestMySQL_RotateRootCredentials(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connURL = strings.Replace(connURL, "root:secret", `{{username}}:{{password}}`, -1) + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "username": "root", + "password": "secret", + } + + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !db.Initialized { + t.Fatal("Database should be initalized") + } + + newConf, err := db.RotateRootCredentials(context.Background(), nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if newConf["password"] == "secret" { + t.Fatal("password was not updated") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + func TestMySQL_RevokeUser(t *testing.T) { cleanup, connURL := prepareMySQLTestContainer(t) defer cleanup() @@ -260,17 +286,14 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testMySQLRoleWildCard, + Creation: []string{testMySQLRoleWildCard}, } usernameConfig := dbplugin.UsernameConfig{ @@ -297,7 +320,7 @@ func TestMySQL_RevokeUser(t *testing.T) { t.Fatal("Credentials were not revoked") } - statements.CreationStatements = testMySQLRoleWildCard + statements.Creation = []string{testMySQLRoleWildCard} username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) @@ -308,7 +331,7 @@ func TestMySQL_RevokeUser(t *testing.T) { } // Test custom revoke statements - statements.RevocationStatements = testMySQLRevocationSQL + statements.Revocation = []string{testMySQLRevocationSQL} err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index f2e20d3f45..c56f9ed02d 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -3,10 +3,12 @@ package postgresql import ( "context" "database/sql" + "errors" "fmt" "strings" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" @@ -15,13 +17,15 @@ import ( "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/lib/pq" - _ "github.com/lib/pq" ) const ( - postgreSQLTypeName string = "postgres" - defaultPostgresRenewSQL = ` + postgreSQLTypeName = "postgres" + defaultPostgresRenewSQL = ` ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; +` + defaultPostgresRotateRootCredentialsSQL = ` +ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; ` ) @@ -29,6 +33,13 @@ var _ dbplugin.Database = &PostgreSQL{} // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { + db := new() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + return dbType, nil +} + +func new() *PostgreSQL { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = postgreSQLTypeName @@ -39,12 +50,12 @@ func New() (interface{}, error) { Separator: "-", } - dbType := &PostgreSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + db := &PostgreSQL{ + SQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - return dbType, nil + return db } // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin @@ -54,13 +65,13 @@ func Run(apiTLSConfig *api.TLSConfig) error { return err } - plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig) + plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig) return nil } type PostgreSQL struct { - connutil.ConnectionProducer + *connutil.SQLConnectionProducer credsutil.CredentialsProducer } @@ -78,7 +89,9 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { } func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { - if statements.CreationStatements == "" { + statements = dbutil.StatementCompatibilityHelper(statements) + + if len(statements.Creation) == 0 { return "", "", dbutil.ErrEmptyCreationStatement } @@ -105,7 +118,6 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme db, err := p.getConnection(ctx) if err != nil { return "", "", err - } // Start a transaction @@ -120,25 +132,25 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme // Return the secret // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expirationStr, - })) - if err != nil { - return "", "", err - - } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return "", "", err + for _, stmt := range statements.Creation { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return "", "", err + } } } @@ -155,9 +167,11 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen p.Lock() defer p.Unlock() - renewStmts := statements.RenewStatements - if renewStmts == "" { - renewStmts = defaultPostgresRenewSQL + statements = dbutil.StatementCompatibilityHelper(statements) + + renewStmts := statements.Renewal + if len(renewStmts) == 0 { + renewStmts = []string{defaultPostgresRenewSQL} } db, err := p.getConnection(ctx) @@ -178,30 +192,28 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen return err } - for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - "expiration": expirationStr, - })) - if err != nil { - return err - } + for _, stmt := range renewStmts { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + "expiration": expirationStr, + })) + if err != nil { + return err + } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return err + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return err + } } } - if err := tx.Commit(); err != nil { - return err - } - - return nil + return tx.Commit() } func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { @@ -209,14 +221,16 @@ func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Stateme p.Lock() defer p.Unlock() - if statements.RevocationStatements == "" { + statements = dbutil.StatementCompatibilityHelper(statements) + + if len(statements.Revocation) == 0 { return p.defaultRevokeUser(ctx, username) } - return p.customRevokeUser(ctx, username, statements.RevocationStatements) + return p.customRevokeUser(ctx, username, statements.Revocation) } -func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error { +func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { db, err := p.getConnection(ctx) if err != nil { return err @@ -230,30 +244,28 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } + for _, stmt := range revocationStmts { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ - "name": username, - })) - if err != nil { - return err - } - defer stmt.Close() + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { - return err + if _, err := stmt.ExecContext(ctx); err != nil { + return err + } } } - if err := tx.Commit(); err != nil { - return err - } - - return nil + return tx.Commit() } func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { @@ -354,10 +366,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err // can't drop if not all privileges are revoked if rows.Err() != nil { - return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err()) } if lastStmtError != nil { - return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError) } // Drop this user @@ -373,3 +385,68 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err return nil } + +func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { + p.Lock() + defer p.Unlock() + + if len(p.Username) == 0 || len(p.Password) == 0 { + return nil, errors.New("username and password are required to rotate") + } + + rotateStatents := statements + if len(rotateStatents) == 0 { + rotateStatents = []string{defaultPostgresRotateRootCredentialsSQL} + } + + db, err := p.getConnection(ctx) + if err != nil { + return nil, err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + tx.Rollback() + }() + + password, err := p.GeneratePassword() + if err != nil { + return nil, err + } + + for _, stmt := range rotateStatents { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + "username": p.Username, + "password": password, + })) + if err != nil { + return nil, err + } + + defer stmt.Close() + if _, err := stmt.ExecContext(ctx); err != nil { + return nil, err + } + } + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + // Close the database connection to ensure no new connections come in + if err := db.Close(); err != nil { + return nil, err + } + + p.RawConfig["password"] = password + return p.RawConfig, nil +} diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 97ff926753..4b83060271 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/plugins/helper/database/connutil" dockertest "gopkg.in/ory-am/dockertest.v3" ) @@ -68,17 +67,13 @@ func TestPostgreSQL_Initialize(t *testing.T) { "max_open_connections": 5, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) - - connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -93,7 +88,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(context.Background(), connectionDetails, true) + _, err = db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -108,9 +103,8 @@ func TestPostgreSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -127,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } statements := dbplugin.Statements{ - CreationStatements: testPostgresRole, + Creation: []string{testPostgresRole}, } username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) @@ -139,7 +133,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - statements.CreationStatements = testPostgresReadOnlyRole + statements.Creation = []string{testPostgresReadOnlyRole} username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) @@ -161,15 +155,14 @@ func TestPostgreSQL_RenewUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testPostgresRole, + Creation: []string{testPostgresRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -197,7 +190,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { if err = testCredsExist(t, connURL, username, password); err != nil { t.Fatalf("Could not connect with new credentials: %s", err) } - statements.RenewStatements = defaultPostgresRenewSQL + statements.Renewal = []string{defaultPostgresRenewSQL} username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) @@ -221,6 +214,46 @@ func TestPostgreSQL_RenewUser(t *testing.T) { } +func TestPostgreSQL_RotateRootCredentials(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + "username": "postgres", + "password": "secret", + } + + db := new() + + connProducer := db.SQLConnectionProducer + + _, err := db.Init(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + newConf, err := db.RotateRootCredentials(context.Background(), nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if newConf["password"] == "secret" { + t.Fatal("password was not updated") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + func TestPostgreSQL_RevokeUser(t *testing.T) { cleanup, connURL := preparePostgresTestContainer(t) defer cleanup() @@ -229,15 +262,14 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) - err := db.Initialize(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } statements := dbplugin.Statements{ - CreationStatements: testPostgresRole, + Creation: []string{testPostgresRole}, } usernameConfig := dbplugin.UsernameConfig{ @@ -274,7 +306,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } // Test custom revoke statements - statements.RevocationStatements = defaultPostgresRevocationSQL + statements.Revocation = []string{defaultPostgresRevocationSQL} err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) @@ -286,6 +318,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } func testCredsExist(t testing.TB, connURL, username, password string) error { + t.Helper() // Log in with the new creds connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1) db, err := sql.Open("postgres", connURL) diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go index 7cf23c5c3e..45f6fa0ad7 100644 --- a/plugins/helper/database/connutil/connutil.go +++ b/plugins/helper/database/connutil/connutil.go @@ -15,8 +15,11 @@ var ( // connections and is used in all the builtin database types. type ConnectionProducer interface { Close() error - Initialize(context.Context, map[string]interface{}, bool) error + Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error) Connection(context.Context) (interface{}, error) sync.Locker + + // DEPRECATED, will be removed in 0.12 + Initialize(context.Context, map[string]interface{}, bool) error } diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index ec351381bd..38685d0be4 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -8,18 +8,25 @@ import ( "sync" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/parseutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/mitchellh/mapstructure" ) +var _ ConnectionProducer = &SQLConnectionProducer{} + // SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type SQLConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"` + MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"` + Username string `json:"username" mapstructure:"username" structs:"username"` + Password string `json:"password" mapstructure:"password" structs:"password"` Type string + RawConfig map[string]interface{} maxConnectionLifetime time.Duration Initialized bool db *sql.DB @@ -27,18 +34,30 @@ type SQLConnectionProducer struct { } func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := c.Init(ctx, conf, verifyConnection) + return err +} + +func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { c.Lock() defer c.Unlock() - err := mapstructure.WeakDecode(conf, c) + c.RawConfig = conf + + err := mapstructure.WeakDecode(conf, &c) if err != nil { - return err + return nil, err } if len(c.ConnectionURL) == 0 { - return fmt.Errorf("connection_url cannot be empty") + return nil, fmt.Errorf("connection_url cannot be empty") } + c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{ + "username": c.Username, + "password": c.Password, + }) + if c.MaxOpenConnections == 0 { c.MaxOpenConnections = 2 } @@ -55,7 +74,7 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string] c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) if err != nil { - return fmt.Errorf("invalid max_connection_lifetime: %s", err) + return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err) } // Set initialized to true at this point since all fields are set, @@ -64,15 +83,15 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string] if verifyConnection { if _, err := c.Connection(ctx); err != nil { - return fmt.Errorf("error verifying connection: %s", err) + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) } if err := c.db.PingContext(ctx); err != nil { - return fmt.Errorf("error verifying connection: %s", err) + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) } } - return nil + return c.RawConfig, nil } func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { @@ -123,6 +142,12 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er return c.db, nil } +func (c *SQLConnectionProducer) SecretValues() map[string]interface{} { + return map[string]interface{}{ + c.Password: "[password]", + } +} + // Close attempts to close the connection func (c *SQLConnectionProducer) Close() error { // Grab the write lock diff --git a/plugins/helper/database/dbutil/dbutil.go b/plugins/helper/database/dbutil/dbutil.go index e80273b7fb..42257053ce 100644 --- a/plugins/helper/database/dbutil/dbutil.go +++ b/plugins/helper/database/dbutil/dbutil.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "strings" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" ) var ( @@ -18,3 +20,33 @@ func QueryHelper(tpl string, data map[string]string) string { return tpl } + +// StatementCompatibilityHelper will populate the statements fields to support +// compatibility +func StatementCompatibilityHelper(statements dbplugin.Statements) dbplugin.Statements { + switch { + case len(statements.Creation) > 0 && len(statements.CreationStatements) == 0: + statements.CreationStatements = strings.Join(statements.Creation, ";") + case len(statements.CreationStatements) > 0: + statements.Creation = []string{statements.CreationStatements} + } + switch { + case len(statements.Revocation) > 0 && len(statements.RevocationStatements) == 0: + statements.RevocationStatements = strings.Join(statements.Revocation, ";") + case len(statements.RevocationStatements) > 0: + statements.Revocation = []string{statements.RevocationStatements} + } + switch { + case len(statements.Renewal) > 0 && len(statements.RenewStatements) == 0: + statements.RenewStatements = strings.Join(statements.Renewal, ";") + case len(statements.RenewStatements) > 0: + statements.Renewal = []string{statements.RenewStatements} + } + switch { + case len(statements.Rollback) > 0 && len(statements.RollbackStatements) == 0: + statements.RollbackStatements = strings.Join(statements.Rollback, ";") + case len(statements.RollbackStatements) > 0: + statements.Rollback = []string{statements.RollbackStatements} + } + return statements +} diff --git a/plugins/helper/database/dbutil/dbutil_test.go b/plugins/helper/database/dbutil/dbutil_test.go new file mode 100644 index 0000000000..4d239a6efb --- /dev/null +++ b/plugins/helper/database/dbutil/dbutil_test.go @@ -0,0 +1,62 @@ +package dbutil + +import ( + "reflect" + "testing" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" +) + +func TestStatementCompatibilityHelper(t *testing.T) { + const ( + creationStatement = "creation" + renewStatement = "renew" + revokeStatement = "revoke" + rollbackStatement = "rollback" + ) + + expectedStatements := dbplugin.Statements{ + Creation: []string{creationStatement}, + Rollback: []string{rollbackStatement}, + Revocation: []string{revokeStatement}, + Renewal: []string{renewStatement}, + CreationStatements: creationStatement, + RenewStatements: renewStatement, + RollbackStatements: rollbackStatement, + RevocationStatements: revokeStatement, + } + + statements1 := dbplugin.Statements{ + CreationStatements: creationStatement, + RenewStatements: renewStatement, + RollbackStatements: rollbackStatement, + RevocationStatements: revokeStatement, + } + + if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements1)) { + t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements1) + } + + statements2 := dbplugin.Statements{ + Creation: []string{creationStatement}, + Rollback: []string{rollbackStatement}, + Revocation: []string{revokeStatement}, + Renewal: []string{renewStatement}, + } + + if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements2)) { + t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements2) + } + + statements3 := dbplugin.Statements{ + CreationStatements: creationStatement, + } + expectedStatements3 := dbplugin.Statements{ + Creation: []string{creationStatement}, + CreationStatements: creationStatement, + } + if !reflect.DeepEqual(expectedStatements3, StatementCompatibilityHelper(statements3)) { + t.Fatalf("mismatch: %#v, %#v", expectedStatements3, statements3) + } + +}