diff --git a/builtin/logical/database/mockv5.go b/builtin/logical/database/mockv5.go index 16f5caa617..320e065703 100644 --- a/builtin/logical/database/mockv5.go +++ b/builtin/logical/database/mockv5.go @@ -6,6 +6,7 @@ import ( "time" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/sdk/database/newdbplugin" ) @@ -25,13 +26,13 @@ func New() (interface{}, error) { } // Run instantiates a MongoDB object, and runs the RPC server for the plugin -func RunV5() error { +func RunV5(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - newdbplugin.Serve(dbType.(newdbplugin.Database)) + newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) return nil } diff --git a/builtin/logical/database/version_wrapper.go b/builtin/logical/database/version_wrapper.go index e8867d87c6..f79f1aff25 100644 --- a/builtin/logical/database/version_wrapper.go +++ b/builtin/logical/database/version_wrapper.go @@ -6,6 +6,7 @@ import ( "fmt" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/helper/random" "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/newdbplugin" @@ -30,6 +31,9 @@ func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.L return dbw, nil } + merr := &multierror.Error{} + merr = multierror.Append(merr, err) + legacyDB, err := dbplugin.PluginFactory(ctx, pluginName, sys, logger) if err == nil { dbw = databaseVersionWrapper{ @@ -37,8 +41,9 @@ func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.L } return dbw, nil } + merr = multierror.Append(merr, err) - return dbw, fmt.Errorf("invalid database version") + return dbw, fmt.Errorf("invalid database version: %s", merr) } // Initialize the underlying database. This is analogous to a constructor on the database plugin object. diff --git a/builtin/logical/database/versioning_large_test.go b/builtin/logical/database/versioning_large_test.go index 3e8eff67e3..7bbe040850 100644 --- a/builtin/logical/database/versioning_large_test.go +++ b/builtin/logical/database/versioning_large_test.go @@ -265,7 +265,7 @@ func TestBackend_PluginMain_MockV5(t *testing.T) { flags := apiClientMeta.FlagSet() flags.Parse(args) - RunV5() + RunV5(apiClientMeta.GetTLSConfig()) } func assertNoRespData(t *testing.T, resp *logical.Response) { diff --git a/sdk/database/newdbplugin/plugin_client.go b/sdk/database/newdbplugin/plugin_client.go index c97c03b735..d9cac2f9ee 100644 --- a/sdk/database/newdbplugin/plugin_client.go +++ b/sdk/database/newdbplugin/plugin_client.go @@ -45,7 +45,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne pluginutil.HandshakeConfig(handshakeConfig), pluginutil.Logger(logger), pluginutil.MetadataMode(isMetadataMode), - pluginutil.AutoMTLS(true), + pluginutil.AutoMTLS(false), ) if err != nil { return nil, err diff --git a/sdk/database/newdbplugin/plugin_server.go b/sdk/database/newdbplugin/plugin_server.go index 8098364be6..0366522afd 100644 --- a/sdk/database/newdbplugin/plugin_server.go +++ b/sdk/database/newdbplugin/plugin_server.go @@ -1,6 +1,7 @@ package newdbplugin import ( + "crypto/tls" "fmt" "github.com/hashicorp/go-plugin" @@ -10,11 +11,11 @@ import ( // Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func Serve(db Database) { - plugin.Serve(ServeConfig(db)) +func Serve(db Database, tlsProvider func() (*tls.Config, error)) { + plugin.Serve(ServeConfig(db, tlsProvider)) } -func ServeConfig(db Database) *plugin.ServeConfig { +func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig { err := pluginutil.OptionallyEnableMlock() if err != nil { fmt.Println(err) @@ -34,6 +35,7 @@ func ServeConfig(db Database) *plugin.ServeConfig { HandshakeConfig: handshakeConfig, VersionedPlugins: pluginSets, GRPCServer: plugin.DefaultGRPCServer, + TLSProvider: tlsProvider, } return conf diff --git a/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/sql.go b/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/sql.go index 986631da94..39fb467a79 100644 --- a/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/sql.go +++ b/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/sql.go @@ -2,8 +2,6 @@ package credsutil import ( "context" - "fmt" - "strings" "time" "github.com/hashicorp/vault/sdk/database/dbplugin" @@ -31,46 +29,17 @@ func (scp *SQLCredentialsProducer) GenerateCredentials(ctx context.Context) (str } func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConfig) (string, error) { - username := "v" - - displayName := config.DisplayName - if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen { - displayName = displayName[:scp.DisplayNameLen] - } else if scp.DisplayNameLen == NoneLength { - displayName = "" - } - - if len(displayName) > 0 { - username = fmt.Sprintf("%s%s%s", username, scp.Separator, displayName) - } - - roleName := config.RoleName - if scp.RoleNameLen > 0 && len(roleName) > scp.RoleNameLen { - roleName = roleName[:scp.RoleNameLen] - } else if scp.RoleNameLen == NoneLength { - roleName = "" - } - - if len(roleName) > 0 { - username = fmt.Sprintf("%s%s%s", username, scp.Separator, roleName) - } - - userUUID, err := RandomAlphaNumeric(20, false) - if err != nil { - return "", err - } - - username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID) - username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix())) - if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { - username = username[:scp.UsernameLen] - } - + caseOp := KeepCase if scp.LowercaseUsername { - username = strings.ToLower(username) + caseOp = Lowercase } - - return username, nil + return GenerateUsername( + DisplayName(config.DisplayName, scp.DisplayNameLen), + RoleName(config.RoleName, scp.RoleNameLen), + Case(caseOp), + Separator(scp.Separator), + MaxLength(scp.UsernameLen), + ) } func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) { diff --git a/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/usernames.go b/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/usernames.go new file mode 100644 index 0000000000..c1e3ccb529 --- /dev/null +++ b/vendor/github.com/hashicorp/vault/sdk/database/helper/credsutil/usernames.go @@ -0,0 +1,140 @@ +package credsutil + +import ( + "fmt" + "strings" + "time" +) + +type CaseOp int + +const ( + KeepCase CaseOp = iota + Uppercase + Lowercase +) + +type usernameBuilder struct { + displayName string + roleName string + separator string + + maxLen int + caseOperation CaseOp +} + +func (ub usernameBuilder) makeUsername() (string, error) { + userUUID, err := RandomAlphaNumeric(20, false) + if err != nil { + return "", err + } + + now := fmt.Sprint(time.Now().Unix()) + + username := joinNonEmpty(ub.separator, + "v", + ub.displayName, + ub.roleName, + userUUID, + now) + username = trunc(username, ub.maxLen) + switch ub.caseOperation { + case Lowercase: + username = strings.ToLower(username) + case Uppercase: + username = strings.ToUpper(username) + } + + return username, nil +} + +type UsernameOpt func(*usernameBuilder) + +func DisplayName(dispName string, maxLength int) UsernameOpt { + return func(b *usernameBuilder) { + b.displayName = trunc(dispName, maxLength) + } +} + +func RoleName(roleName string, maxLength int) UsernameOpt { + return func(b *usernameBuilder) { + b.roleName = trunc(roleName, maxLength) + } +} + +func Separator(sep string) UsernameOpt { + return func(b *usernameBuilder) { + b.separator = sep + } +} + +func MaxLength(maxLen int) UsernameOpt { + return func(b *usernameBuilder) { + b.maxLen = maxLen + } +} + +func Case(c CaseOp) UsernameOpt { + return func(b *usernameBuilder) { + b.caseOperation = c + } +} + +func ToLower() UsernameOpt { + return Case(Lowercase) +} + +func ToUpper() UsernameOpt { + return Case(Uppercase) +} + +func GenerateUsername(opts ...UsernameOpt) (string, error) { + b := usernameBuilder{ + separator: "_", + maxLen: 100, + caseOperation: KeepCase, + } + + for _, opt := range opts { + opt(&b) + } + + return b.makeUsername() +} + +func trunc(str string, l int) string { + switch { + case l > 0: + if l > len(str) { + return str + } + return str[:l] + case l == 0: + return str + default: + return "" + } +} + +func joinNonEmpty(sep string, vals ...string) string { + if sep == "" { + return strings.Join(vals, sep) + } + switch len(vals) { + case 0: + return "" + case 1: + return vals[0] + } + builder := &strings.Builder{} + for _, val := range vals { + if val == "" { + continue + } + if builder.Len() > 0 { + builder.WriteString(sep) + } + builder.WriteString(val) + } + return builder.String() +} diff --git a/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_client.go b/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_client.go index c97c03b735..d9cac2f9ee 100644 --- a/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_client.go +++ b/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_client.go @@ -45,7 +45,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne pluginutil.HandshakeConfig(handshakeConfig), pluginutil.Logger(logger), pluginutil.MetadataMode(isMetadataMode), - pluginutil.AutoMTLS(true), + pluginutil.AutoMTLS(false), ) if err != nil { return nil, err diff --git a/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_server.go b/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_server.go index 8098364be6..0366522afd 100644 --- a/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_server.go +++ b/vendor/github.com/hashicorp/vault/sdk/database/newdbplugin/plugin_server.go @@ -1,6 +1,7 @@ package newdbplugin import ( + "crypto/tls" "fmt" "github.com/hashicorp/go-plugin" @@ -10,11 +11,11 @@ import ( // Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func Serve(db Database) { - plugin.Serve(ServeConfig(db)) +func Serve(db Database, tlsProvider func() (*tls.Config, error)) { + plugin.Serve(ServeConfig(db, tlsProvider)) } -func ServeConfig(db Database) *plugin.ServeConfig { +func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig { err := pluginutil.OptionallyEnableMlock() if err != nil { fmt.Println(err) @@ -34,6 +35,7 @@ func ServeConfig(db Database) *plugin.ServeConfig { HandshakeConfig: handshakeConfig, VersionedPlugins: pluginSets, GRPCServer: plugin.DefaultGRPCServer, + TLSProvider: tlsProvider, } return conf