vault/sdk/database/dbplugin/v5/grpc_server.go
John-Michael Faircloth 07927e036c
feature: secrets/auth plugin multiplexing (#14946)
* enable registering backend muxed plugins in plugin catalog

* set the sysview on the pluginconfig to allow enabling secrets/auth plugins

* store backend instances in map

* store single implementations in the instances map

cleanup instance map and ensure we don't deadlock

* fix system backend unit tests

move GetMultiplexIDFromContext to pluginutil package

fix pluginutil test

fix dbplugin ut

* return error(s) if we can't get the plugin client

update comments

* refactor/move GetMultiplexIDFromContext test

* add changelog

* remove unnecessary field on pluginClient

* add unit tests to PluginCatalog for secrets/auth plugins

* fix comment

* return pluginClient from TestRunTestPlugin

* add multiplexed backend test

* honor metadatamode value in newbackend pluginconfig

* check that connection exists on cleanup

* add automtls to secrets/auth plugins

* don't remove apiclientmeta parsing

* use formatting directive for fmt.Errorf

* fix ut: remove tls provider func

* remove tlsproviderfunc from backend plugin tests

* use env var to prevent test plugin from running as a unit test

* WIP: remove lazy loading

* move non lazy loaded backend to new package

* use version wrapper for backend plugin factory

* remove backendVersionWrapper type

* implement getBackendPluginType for plugin catalog

* handle backend plugin v4 registration

* add plugin automtls env guard

* modify plugin factory to determine the backend to use

* remove old pluginsets from v5 and log pid in plugin catalog

* add reload mechanism via context

* re-add v3 and v4 to pluginset

* call cleanup from reload if non-muxed

* move v5 backend code to new package

* use context reload for ErrPluginShutdown case

* add wrapper on v5 backend

* fix run config UTs

* fix unit tests

- use v4/v5 mapping for plugin versions
- fix test build err
- add reload method on fakePluginClient
- add multiplexed cases for integration tests

* remove comment and update AutoMTLS field in test

* remove comment

* remove errwrap and unused context

* only support metadatamode false for v5 backend plugins

* update plugin catalog errors

* use const for env variables

* rename locks and remove unused

* remove unneeded nil check

* improvements based on staticcheck recommendations

* use const for single implementation string

* use const for context key

* use info default log level

* move pid to pluginClient struct

* remove v3 and v4 from multiplexed plugin set

* return from reload when non-multiplexed

* update automtls env string

* combine getBackend and getBrokeredClient

* update comments for plugin reload, Backend return val and log

* revert Backend return type

* allow non-muxed plugins to serve v5

* move v5 code to existing sdk plugin package

* do next export sdk fields now that we have removed extra plugin pkg

* set TLSProvider in ServeMultiplex for backwards compat

* use bool to flag multiplexing support on grpc backend server

* revert userpass main.go

* refactor plugin sdk

- update comments
- make use of multiplexing boolean and single implementation ID const

* update comment and use multierr

* attempt v4 if dispense fails on getPluginTypeForUnknown

* update comments on sdk plugin backend
2022-08-29 21:42:26 -05:00

315 lines
8.0 KiB
Go

package dbplugin
import (
"context"
"fmt"
"sync"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Compile-time check that gRPCServer satisfies the generated gRPC service
// interface.
var _ proto.DatabaseServer = (*gRPCServer)(nil)

// gRPCServer adapts Database implementations to proto.DatabaseServer.
//
// It operates in one of two modes:
//   - non-multiplexed: singleImpl is set and serves every request;
//   - multiplexed: singleImpl is nil, and instances maps the multiplex ID
//     carried in each request context to its own Database, created on
//     demand via factoryFunc.
//
// The embedded RWMutex is taken around every access to instances.
type gRPCServer struct {
	proto.UnimplementedDatabaseServer

	// singleImpl is the only Database of a plugin that does not support
	// multiplexing. When non-nil, instances is never consulted.
	singleImpl Database

	// instances holds one Database per multiplex ID when multiplexing is
	// supported.
	instances map[string]Database

	// factoryFunc builds a fresh Database for a multiplex ID that has no
	// instance yet; its result is asserted to Database by the caller.
	factoryFunc func() (interface{}, error)

	sync.RWMutex
}
// getOrCreateDatabase returns the Database that should serve ctx, creating it
// on first use.
//
// A non-multiplexed server (singleImpl set) always returns singleImpl.
// Otherwise the multiplex ID is read from ctx and looked up in g.instances;
// if no instance exists yet, factoryFunc is invoked and the result is stored
// under that ID. The write lock is held for the entire call so two concurrent
// requests for the same ID cannot both create an instance.
//
// Errors from reading the multiplex ID or from factoryFunc are returned
// unchanged. An error is also returned if factoryFunc yields a value that
// does not implement Database.
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
	g.Lock()
	defer g.Unlock()

	if g.singleImpl != nil {
		return g.singleImpl, nil
	}

	id, err := pluginutil.GetMultiplexIDFromContext(ctx)
	if err != nil {
		return nil, err
	}

	if db, ok := g.instances[id]; ok {
		return db, nil
	}

	raw, err := g.factoryFunc()
	if err != nil {
		return nil, err
	}

	// factoryFunc is untyped. Check the assertion rather than letting a
	// misconfigured factory panic and take down the whole plugin process.
	database, ok := raw.(Database)
	if !ok {
		return nil, fmt.Errorf("database factory returned unexpected type %T", raw)
	}

	// Assigning into a nil map panics. Allocate lazily in case the server was
	// constructed without an initialized instances map.
	if g.instances == nil {
		g.instances = make(map[string]Database)
	}
	g.instances[id] = database

	return database, nil
}
// getDatabaseInternal looks up the Database serving ctx without creating one.
// It performs no locking itself; the caller is expected to already hold g's
// read or write lock.
//
// A non-multiplexed server always yields singleImpl. A multiplexed server
// yields the instance registered under the context's multiplex ID, or an
// error when no instance has been registered for that ID.
func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) {
	if impl := g.singleImpl; impl != nil {
		return impl, nil
	}

	id, err := pluginutil.GetMultiplexIDFromContext(ctx)
	if err != nil {
		return nil, err
	}

	db, found := g.instances[id]
	if !found {
		return nil, fmt.Errorf("no database instance found")
	}
	return db, nil
}
// getDatabase is the read-locked wrapper around getDatabaseInternal. It returns
// the existing Database for ctx and never creates a new one.
func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) {
	g.RLock()
	defer g.RUnlock()
	return g.getDatabaseInternal(ctx)
}
// Initialize obtains (creating if necessary) the Database for the request
// context and hands it the decoded configuration. The configuration returned
// by the Database is re-encoded into the response.
//
// Lookup/creation errors are returned as-is. Failures from the Database or
// from encoding its returned config are reported as codes.Internal.
func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
	impl, err := g.getOrCreateDatabase(ctx)
	if err != nil {
		return nil, err
	}

	initResp, err := impl.Initialize(ctx, InitializeRequest{
		Config:           structToMap(request.ConfigData),
		VerifyConnection: request.VerifyConnection,
	})
	if err != nil {
		return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
	}

	configData, err := mapToStruct(initResp.Config)
	if err != nil {
		return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to marshal new config to JSON: %s", err)
	}
	return &proto.InitializeResponse{ConfigData: configData}, nil
}
// NewUser validates the request, converts it to the SDK-level NewUserRequest
// and asks the Database for the request context to create the user.
//
// A missing username config or an unparsable expiration yields
// codes.InvalidArgument. An absent expiration is forwarded as the zero
// time.Time. Database failures are reported as codes.Internal.
func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
	usernameCfg := req.GetUsernameConfig()
	if usernameCfg == nil {
		return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
	}

	var expiration time.Time
	if ts := req.GetExpiration(); ts != nil {
		parsed, err := ptypes.Timestamp(ts)
		if err != nil {
			return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "unable to parse expiration date: %s", err)
		}
		expiration = parsed
	}

	impl, err := g.getDatabase(ctx)
	if err != nil {
		return nil, err
	}

	created, err := impl.NewUser(ctx, NewUserRequest{
		UsernameConfig: UsernameMetadata{
			DisplayName: usernameCfg.GetDisplayName(),
			RoleName:    usernameCfg.GetRoleName(),
		},
		CredentialType:     CredentialType(req.GetCredentialType()),
		Password:           req.GetPassword(),
		PublicKey:          req.GetPublicKey(),
		Expiration:         expiration,
		Statements:         getStatementsFromProto(req.GetStatements()),
		RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
	})
	if err != nil {
		return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
	}
	return &proto.NewUserResponse{Username: created.Username}, nil
}
// UpdateUser validates the request, converts it to the SDK-level
// UpdateUserRequest via getUpdateUserRequest, and forwards it to the Database
// serving the request context.
//
// Returns codes.InvalidArgument when the username is empty or the request
// cannot be converted (including when it carries no changes), and
// codes.Internal when the Database fails the update. Errors looking up the
// Database are returned unchanged.
func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
	if req.GetUsername() == "" {
		return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
	}

	dbReq, err := getUpdateUserRequest(req)
	if err != nil {
		// Use status.Error, not status.Errorf: the message is not a format
		// string, and any '%' it contains must be passed through verbatim.
		return &proto.UpdateUserResponse{}, status.Error(codes.InvalidArgument, err.Error())
	}

	impl, err := g.getDatabase(ctx)
	if err != nil {
		return nil, err
	}

	if _, err := impl.UpdateUser(ctx, dbReq); err != nil {
		return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
	}
	return &proto.UpdateUserResponse{}, nil
}
// getUpdateUserRequest translates the wire-level UpdateUserRequest into the
// SDK-level UpdateUserRequest.
//
// Each optional change (password, public key, expiration) is carried over
// only when it sets a new value. Statements attached to an otherwise empty
// change are discarded. An error is returned if the new expiration cannot be
// parsed, or if the resulting request contains no change at all.
func getUpdateUserRequest(req *proto.UpdateUserRequest) (UpdateUserRequest, error) {
	dbReq := UpdateUserRequest{
		Username:       req.GetUsername(),
		CredentialType: CredentialType(req.GetCredentialType()),
	}

	if pw := req.GetPassword(); pw != nil && pw.GetNewPassword() != "" {
		dbReq.Password = &ChangePassword{
			NewPassword: pw.GetNewPassword(),
			Statements:  getStatementsFromProto(pw.GetStatements()),
		}
	}

	if pk := req.GetPublicKey(); pk != nil && len(pk.GetNewPublicKey()) > 0 {
		dbReq.PublicKey = &ChangePublicKey{
			NewPublicKey: pk.GetNewPublicKey(),
			Statements:   getStatementsFromProto(pk.GetStatements()),
		}
	}

	if exp := req.GetExpiration(); exp != nil && exp.GetNewExpiration() != nil {
		newExpiration, err := ptypes.Timestamp(exp.GetNewExpiration())
		if err != nil {
			return UpdateUserRequest{}, fmt.Errorf("unable to parse new expiration: %w", err)
		}
		dbReq.Expiration = &ChangeExpiration{
			NewExpiration: newExpiration,
			Statements:    getStatementsFromProto(exp.GetStatements()),
		}
	}

	if !hasChange(dbReq) {
		return UpdateUserRequest{}, fmt.Errorf("update user request has no changes")
	}
	return dbReq, nil
}
// hasChange reports whether dbReq requests at least one real modification:
// a non-empty new password, a non-empty new public key, or a non-zero new
// expiration.
func hasChange(dbReq UpdateUserRequest) bool {
	passwordChanged := dbReq.Password != nil && dbReq.Password.NewPassword != ""
	publicKeyChanged := dbReq.PublicKey != nil && len(dbReq.PublicKey.NewPublicKey) > 0
	expirationChanged := dbReq.Expiration != nil && !dbReq.Expiration.NewExpiration.IsZero()
	return passwordChanged || publicKeyChanged || expirationChanged
}
// DeleteUser asks the Database serving the request context to remove the
// named user, forwarding any statements supplied in the request.
//
// An empty username yields codes.InvalidArgument. Database failures are
// reported as codes.Internal. Lookup errors are returned unchanged.
func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
	username := req.GetUsername()
	if username == "" {
		return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
	}

	deleteReq := DeleteUserRequest{
		Username:   username,
		Statements: getStatementsFromProto(req.GetStatements()),
	}

	impl, err := g.getDatabase(ctx)
	if err != nil {
		return nil, err
	}

	if _, err := impl.DeleteUser(ctx, deleteReq); err != nil {
		return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
	}
	return &proto.DeleteUserResponse{}, nil
}
// Type returns the value reported by Database.Type() for the Database serving
// the request context. That Database is created first if it does not exist
// yet. Failures from Database.Type() are reported as codes.Internal.
func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
	impl, err := g.getOrCreateDatabase(ctx)
	if err != nil {
		return nil, err
	}

	dbType, err := impl.Type()
	if err != nil {
		return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
	}
	return &proto.TypeResponse{Type: dbType}, nil
}
// Close closes the Database serving the request context.
//
// For a multiplexed server, the instance is also evicted from g.instances so
// that a later Initialize or Type call for the same ID builds a new one. If
// the Database's Close fails, it is reported as codes.Internal and the entry
// is left in place. The write lock is held for the whole call, including the
// Database's own Close.
func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
	g.Lock()
	defer g.Unlock()

	impl, err := g.getDatabaseInternal(ctx)
	if err != nil {
		return nil, err
	}

	if closeErr := impl.Close(); closeErr != nil {
		return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", closeErr)
	}

	// A non-multiplexed server keeps no per-ID bookkeeping, so nothing to evict.
	if g.singleImpl != nil {
		return &proto.Empty{}, nil
	}

	id, err := pluginutil.GetMultiplexIDFromContext(ctx)
	if err != nil {
		return nil, err
	}
	delete(g.instances, id)

	return &proto.Empty{}, nil
}
// getStatementsFromProto copies the commands of a wire-level Statements
// message into the SDK Statements type. A nil message yields the zero value.
func getStatementsFromProto(protoStmts *proto.Statements) (statements Statements) {
	if protoStmts != nil {
		statements.Commands = protoStmts.GetCommands()
	}
	return statements
}