DBPW - Revert AutoMTLS (#10065)

This commit is contained in:
Michael Golowka 2020-09-30 17:08:37 -06:00 committed by GitHub
parent f49fa06237
commit 9978ba802f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 171 additions and 52 deletions

View File

@ -6,6 +6,7 @@ import (
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
)
@ -25,13 +26,13 @@ func New() (interface{}, error) {
}
// Run instantiates a MongoDB object, and runs the RPC server for the plugin
func RunV5() error {
func RunV5(apiTLSConfig *api.TLSConfig) error {
dbType, err := New()
if err != nil {
return err
}
newdbplugin.Serve(dbType.(newdbplugin.Database))
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
return nil
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/random"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
@ -30,6 +31,9 @@ func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.L
return dbw, nil
}
merr := &multierror.Error{}
merr = multierror.Append(merr, err)
legacyDB, err := dbplugin.PluginFactory(ctx, pluginName, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
@ -37,8 +41,9 @@ func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.L
}
return dbw, nil
}
merr = multierror.Append(merr, err)
return dbw, fmt.Errorf("invalid database version")
return dbw, fmt.Errorf("invalid database version: %s", merr)
}
// Initialize the underlying database. This is analogous to a constructor on the database plugin object.

View File

@ -265,7 +265,7 @@ func TestBackend_PluginMain_MockV5(t *testing.T) {
flags := apiClientMeta.FlagSet()
flags.Parse(args)
RunV5()
RunV5(apiClientMeta.GetTLSConfig())
}
func assertNoRespData(t *testing.T, resp *logical.Response) {

View File

@ -45,7 +45,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
pluginutil.HandshakeConfig(handshakeConfig),
pluginutil.Logger(logger),
pluginutil.MetadataMode(isMetadataMode),
pluginutil.AutoMTLS(true),
pluginutil.AutoMTLS(false),
)
if err != nil {
return nil, err

View File

@ -1,6 +1,7 @@
package newdbplugin
import (
"crypto/tls"
"fmt"
"github.com/hashicorp/go-plugin"
@ -10,11 +11,11 @@ import (
// Serve is called from within a plugin and wraps the provided
// Database implementation in a databasePluginRPCServer object and starts a
// RPC server.
func Serve(db Database) {
plugin.Serve(ServeConfig(db))
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database) *plugin.ServeConfig {
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
err := pluginutil.OptionallyEnableMlock()
if err != nil {
fmt.Println(err)
@ -34,6 +35,7 @@ func ServeConfig(db Database) *plugin.ServeConfig {
HandshakeConfig: handshakeConfig,
VersionedPlugins: pluginSets,
GRPCServer: plugin.DefaultGRPCServer,
TLSProvider: tlsProvider,
}
return conf

View File

@ -2,8 +2,6 @@ package credsutil
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/sdk/database/dbplugin"
@ -31,46 +29,17 @@ func (scp *SQLCredentialsProducer) GenerateCredentials(ctx context.Context) (str
}
func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConfig) (string, error) {
username := "v"
displayName := config.DisplayName
if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen {
displayName = displayName[:scp.DisplayNameLen]
} else if scp.DisplayNameLen == NoneLength {
displayName = ""
}
if len(displayName) > 0 {
username = fmt.Sprintf("%s%s%s", username, scp.Separator, displayName)
}
roleName := config.RoleName
if scp.RoleNameLen > 0 && len(roleName) > scp.RoleNameLen {
roleName = roleName[:scp.RoleNameLen]
} else if scp.RoleNameLen == NoneLength {
roleName = ""
}
if len(roleName) > 0 {
username = fmt.Sprintf("%s%s%s", username, scp.Separator, roleName)
}
userUUID, err := RandomAlphaNumeric(20, false)
if err != nil {
return "", err
}
username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID)
username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix()))
if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
username = username[:scp.UsernameLen]
}
caseOp := KeepCase
if scp.LowercaseUsername {
username = strings.ToLower(username)
caseOp = Lowercase
}
return username, nil
return GenerateUsername(
DisplayName(config.DisplayName, scp.DisplayNameLen),
RoleName(config.RoleName, scp.RoleNameLen),
Case(caseOp),
Separator(scp.Separator),
MaxLength(scp.UsernameLen),
)
}
func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) {

View File

@ -0,0 +1,140 @@
package credsutil
import (
"fmt"
"strings"
"time"
)
type CaseOp int
const (
KeepCase CaseOp = iota
Uppercase
Lowercase
)
type usernameBuilder struct {
displayName string
roleName string
separator string
maxLen int
caseOperation CaseOp
}
func (ub usernameBuilder) makeUsername() (string, error) {
userUUID, err := RandomAlphaNumeric(20, false)
if err != nil {
return "", err
}
now := fmt.Sprint(time.Now().Unix())
username := joinNonEmpty(ub.separator,
"v",
ub.displayName,
ub.roleName,
userUUID,
now)
username = trunc(username, ub.maxLen)
switch ub.caseOperation {
case Lowercase:
username = strings.ToLower(username)
case Uppercase:
username = strings.ToUpper(username)
}
return username, nil
}
type UsernameOpt func(*usernameBuilder)
func DisplayName(dispName string, maxLength int) UsernameOpt {
return func(b *usernameBuilder) {
b.displayName = trunc(dispName, maxLength)
}
}
func RoleName(roleName string, maxLength int) UsernameOpt {
return func(b *usernameBuilder) {
b.roleName = trunc(roleName, maxLength)
}
}
func Separator(sep string) UsernameOpt {
return func(b *usernameBuilder) {
b.separator = sep
}
}
func MaxLength(maxLen int) UsernameOpt {
return func(b *usernameBuilder) {
b.maxLen = maxLen
}
}
func Case(c CaseOp) UsernameOpt {
return func(b *usernameBuilder) {
b.caseOperation = c
}
}
func ToLower() UsernameOpt {
return Case(Lowercase)
}
func ToUpper() UsernameOpt {
return Case(Uppercase)
}
func GenerateUsername(opts ...UsernameOpt) (string, error) {
b := usernameBuilder{
separator: "_",
maxLen: 100,
caseOperation: KeepCase,
}
for _, opt := range opts {
opt(&b)
}
return b.makeUsername()
}
func trunc(str string, l int) string {
switch {
case l > 0:
if l > len(str) {
return str
}
return str[:l]
case l == 0:
return str
default:
return ""
}
}
func joinNonEmpty(sep string, vals ...string) string {
if sep == "" {
return strings.Join(vals, sep)
}
switch len(vals) {
case 0:
return ""
case 1:
return vals[0]
}
builder := &strings.Builder{}
for _, val := range vals {
if val == "" {
continue
}
if builder.Len() > 0 {
builder.WriteString(sep)
}
builder.WriteString(val)
}
return builder.String()
}

View File

@ -45,7 +45,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
pluginutil.HandshakeConfig(handshakeConfig),
pluginutil.Logger(logger),
pluginutil.MetadataMode(isMetadataMode),
pluginutil.AutoMTLS(true),
pluginutil.AutoMTLS(false),
)
if err != nil {
return nil, err

View File

@ -1,6 +1,7 @@
package newdbplugin
import (
"crypto/tls"
"fmt"
"github.com/hashicorp/go-plugin"
@ -10,11 +11,11 @@ import (
// Serve is called from within a plugin and wraps the provided
// Database implementation in a databasePluginRPCServer object and starts a
// RPC server.
func Serve(db Database) {
plugin.Serve(ServeConfig(db))
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database) *plugin.ServeConfig {
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
err := pluginutil.OptionallyEnableMlock()
if err != nil {
fmt.Println(err)
@ -34,6 +35,7 @@ func ServeConfig(db Database) *plugin.ServeConfig {
HandshakeConfig: handshakeConfig,
VersionedPlugins: pluginSets,
GRPCServer: plugin.DefaultGRPCServer,
TLSProvider: tlsProvider,
}
return conf