mirror of
https://github.com/hashicorp/vault.git
synced 2025-12-16 06:51:23 +01:00
feature: multiplexing support for database plugins (#14033)
* feat: DB plugin multiplexing (#13734) * WIP: start from main and get a plugin runner from core * move MultiplexedClient map to plugin catalog - call sys.NewPluginClient from PluginFactory - updates to getPluginClient - thread through isMetadataMode * use go-plugin ClientProtocol interface - call sys.NewPluginClient from dbplugin.NewPluginClient * move PluginSets to dbplugin package - export dbplugin HandshakeConfig - small refactor of PluginCatalog.getPluginClient * add removeMultiplexedClient; clean up on Close() - call client.Kill from plugin catalog - set rpcClient when muxed client exists * add ID to dbplugin.DatabasePluginClient struct * only create one plugin process per plugin type * update NewPluginClient to return connection ID to sdk - wrap grpc.ClientConn so we can inject the ID into context - get ID from context on grpc server * add v6 multiplexing protocol version * WIP: backwards compat for db plugins * Ensure locking on plugin catalog access - Create public GetPluginClient method for plugin catalog - rename postgres db plugin * use the New constructor for db plugins * grpc server: use write lock for Close and rlock for CRUD * cleanup MultiplexedClients on Close * remove TODO * fix multiplexing regression with grpc server connection * cleanup grpc server instances on close * embed ClientProtocol in Multiplexer interface * use PluginClientConfig arg to make NewPluginClient plugin type agnostic * create a new plugin process for non-muxed plugins * feat: plugin multiplexing: handle plugin client cleanup (#13896) * use closure for plugin client cleanup * log and return errors; add comments * move rpcClient wrapping to core for ID injection * refactor core plugin client and sdk * remove unused ID method * refactor and only wrap clientConn on multiplexed plugins * rename structs and do not export types * Slight refactor of system view interface * Revert "Slight refactor of system view interface" This reverts commit 73d420e5cd2f0415e000c5a9284ea72a58016dd6. * Revert "Revert "Slight refactor of system view interface"" This reverts commit f75527008a1db06d04a23e04c3059674be8adb5f. * only provide pluginRunner arg to the internal newPluginClient method * embed ClientProtocol in pluginClient and name logger * Add back MLock support * remove enableMlock arg from setupPluginCatalog * rename plugin util interface to PluginClient Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com> * feature: multiplexing: fix unit tests (#14007) * fix grpc_server tests and add coverage * update run_config tests * add happy path test case for grpc_server ID from context * update test helpers * feat: multiplexing: handle v5 plugin compiled with new sdk * add mux supported flag and increase test coverage * set multiplexingSupport field in plugin server * remove multiplexingSupport field in sdk * revert postgres to non-multiplexed * add comments on grpc server fields * use pointer receiver on grpc server methods * add changelog * use pointer for grpcserver instance * Use a gRPC server to determine if a plugin should be multiplexed * Apply suggestions from code review Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com> * add lock to removePluginClient * add multiplexingSupport field to externalPlugin struct * do not send nil to grpc MultiplexingSupport * check err before logging * handle locking scenario for cleanupFunc * allow ServeConfigMultiplex to dispense v5 plugin * reposition structs, add err check and comments * add comment on locking for cleanupExternalPlugin Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com> Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
This commit is contained in:
parent
8b36f650c1
commit
3565c90cf8
1
Makefile
1
Makefile
@ -194,6 +194,7 @@ proto: bootstrap
|
|||||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto
|
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto
|
||||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto
|
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto
|
||||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto
|
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto
|
||||||
|
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/helper/pluginutil/*.proto
|
||||||
|
|
||||||
# No additional sed expressions should be added to this list. Going forward
|
# No additional sed expressions should be added to this list. Going forward
|
||||||
# we should just use the variable names choosen by protobuf. These are left
|
# we should just use the variable names choosen by protobuf. These are left
|
||||||
|
|||||||
@ -110,6 +110,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type databaseBackend struct {
|
type databaseBackend struct {
|
||||||
|
// connections holds configured database connections by config name
|
||||||
connections map[string]*dbPluginInstance
|
connections map[string]*dbPluginInstance
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
|
|||||||
@ -329,6 +329,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||||||
}
|
}
|
||||||
config.ConnectionDetails = initResp.Config
|
config.ConnectionDetails = initResp.Config
|
||||||
|
|
||||||
|
b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName)
|
||||||
|
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
@ -365,6 +367,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||||||
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
|
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(resp.Warnings) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
3
changelog/14033.txt
Normal file
3
changelog/14033.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
```release-note:feature
|
||||||
|
**Database plugin multiplexing**: manage multiple database connections with a single plugin process
|
||||||
|
```
|
||||||
@ -150,8 +150,8 @@ type registry struct {
|
|||||||
logicalBackends map[string]logical.Factory
|
logicalBackends map[string]logical.Factory
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the BuiltinFactory func for a particular backend plugin
|
// Get returns the Factory func for a particular backend plugin from the
|
||||||
// from the plugins map.
|
// plugins map.
|
||||||
func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
|
func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
|
||||||
switch pluginType {
|
switch pluginType {
|
||||||
case consts.PluginTypeCredential:
|
case consts.PluginTypeCredential:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: helper/forwarding/types.proto
|
// source: helper/forwarding/types.proto
|
||||||
|
|
||||||
package forwarding
|
package forwarding
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: helper/identity/mfa/types.proto
|
// source: helper/identity/mfa/types.proto
|
||||||
|
|
||||||
package mfa
|
package mfa
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: helper/identity/types.proto
|
// source: helper/identity/types.proto
|
||||||
|
|
||||||
package identity
|
package identity
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: helper/storagepacker/types.proto
|
// source: helper/storagepacker/types.proto
|
||||||
|
|
||||||
package storagepacker
|
package storagepacker
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: physical/raft/types.proto
|
// source: physical/raft/types.proto
|
||||||
|
|
||||||
package raft
|
package raft
|
||||||
|
|||||||
@ -48,7 +48,6 @@ var (
|
|||||||
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
|
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// New implements builtinplugins.BuiltinFactory
|
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
db := new()
|
db := new()
|
||||||
// Wrap the plugin with middleware to sanitize errors
|
// Wrap the plugin with middleware to sanitize errors
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: sdk/database/dbplugin/database.proto
|
// source: sdk/database/dbplugin/database.proto
|
||||||
|
|
||||||
package dbplugin
|
package dbplugin
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/hashicorp/go-plugin"
|
"github.com/hashicorp/go-plugin"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,14 +13,17 @@ import (
|
|||||||
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
||||||
// This prevents users from executing bad plugins or executing a plugin
|
// This prevents users from executing bad plugins or executing a plugin
|
||||||
// directory. It is a UX feature, not a security feature.
|
// directory. It is a UX feature, not a security feature.
|
||||||
var handshakeConfig = plugin.HandshakeConfig{
|
var HandshakeConfig = plugin.HandshakeConfig{
|
||||||
ProtocolVersion: 5,
|
|
||||||
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
|
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
|
||||||
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
|
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Factory is the factory function to create a dbplugin Database.
|
||||||
|
type Factory func() (interface{}, error)
|
||||||
|
|
||||||
type GRPCDatabasePlugin struct {
|
type GRPCDatabasePlugin struct {
|
||||||
Impl Database
|
FactoryFunc Factory
|
||||||
|
Impl Database
|
||||||
|
|
||||||
// Embeding this will disable the netRPC protocol
|
// Embeding this will disable the netRPC protocol
|
||||||
plugin.NetRPCUnsupportedPlugin
|
plugin.NetRPCUnsupportedPlugin
|
||||||
@ -31,7 +35,25 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
||||||
proto.RegisterDatabaseServer(s, gRPCServer{impl: d.Impl})
|
var server gRPCServer
|
||||||
|
|
||||||
|
if d.Impl != nil {
|
||||||
|
server = gRPCServer{singleImpl: d.Impl}
|
||||||
|
} else {
|
||||||
|
// multiplexing is supported
|
||||||
|
server = gRPCServer{
|
||||||
|
factoryFunc: d.FactoryFunc,
|
||||||
|
instances: make(map[string]Database),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiplexing is enabled for this plugin, register the server so we
|
||||||
|
// can tell the client in Vault.
|
||||||
|
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
|
||||||
|
Supported: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
proto.RegisterDatabaseServer(s, &server)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,24 +3,113 @@ package dbplugin
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ proto.DatabaseServer = gRPCServer{}
|
var _ proto.DatabaseServer = &gRPCServer{}
|
||||||
|
|
||||||
type gRPCServer struct {
|
type gRPCServer struct {
|
||||||
proto.UnimplementedDatabaseServer
|
proto.UnimplementedDatabaseServer
|
||||||
|
|
||||||
impl Database
|
// holds the non-multiplexed Database
|
||||||
|
// when this is set the plugin does not support multiplexing
|
||||||
|
singleImpl Database
|
||||||
|
|
||||||
|
// instances holds the multiplexed Databases
|
||||||
|
instances map[string]Database
|
||||||
|
factoryFunc func() (interface{}, error)
|
||||||
|
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMultiplexIDFromContext(ctx context.Context) (string, error) {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("missing plugin multiplexing metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplexIDs := md[pluginutil.MultiplexingCtxKey]
|
||||||
|
if len(multiplexIDs) != 1 {
|
||||||
|
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplexID := multiplexIDs[0]
|
||||||
|
if multiplexID == "" {
|
||||||
|
return "", fmt.Errorf("empty multiplex ID in metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
return multiplexID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
|
||||||
|
g.Lock()
|
||||||
|
defer g.Unlock()
|
||||||
|
|
||||||
|
if g.singleImpl != nil {
|
||||||
|
return g.singleImpl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := getMultiplexIDFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if db, ok := g.instances[id]; ok {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := g.factoryFunc()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
database := db.(Database)
|
||||||
|
g.instances[id] = database
|
||||||
|
|
||||||
|
return database, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDatabaseInternal returns the database but does not hold a lock
|
||||||
|
func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) {
|
||||||
|
if g.singleImpl != nil {
|
||||||
|
return g.singleImpl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := getMultiplexIDFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if db, ok := g.instances[id]; ok {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("no database instance found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDatabase holds a read lock and returns the database
|
||||||
|
func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) {
|
||||||
|
g.RLock()
|
||||||
|
impl, err := g.getDatabaseInternal(ctx)
|
||||||
|
g.RUnlock()
|
||||||
|
return impl, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the database plugin
|
// Initialize the database plugin
|
||||||
func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
|
func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
|
||||||
|
impl, err := g.getOrCreateDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
rawConfig := structToMap(request.ConfigData)
|
rawConfig := structToMap(request.ConfigData)
|
||||||
|
|
||||||
dbReq := InitializeRequest{
|
dbReq := InitializeRequest{
|
||||||
@ -28,7 +117,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq
|
|||||||
VerifyConnection: request.VerifyConnection,
|
VerifyConnection: request.VerifyConnection,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbResp, err := g.impl.Initialize(ctx, dbReq)
|
dbResp, err := impl.Initialize(ctx, dbReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
|
return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
|
||||||
}
|
}
|
||||||
@ -45,7 +134,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
|
func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
|
||||||
if req.GetUsernameConfig() == nil {
|
if req.GetUsernameConfig() == nil {
|
||||||
return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
|
return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
|
||||||
}
|
}
|
||||||
@ -60,6 +149,11 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
|
|||||||
expiration = exp
|
expiration = exp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl, err := g.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
dbReq := NewUserRequest{
|
dbReq := NewUserRequest{
|
||||||
UsernameConfig: UsernameMetadata{
|
UsernameConfig: UsernameMetadata{
|
||||||
DisplayName: req.GetUsernameConfig().GetDisplayName(),
|
DisplayName: req.GetUsernameConfig().GetDisplayName(),
|
||||||
@ -71,7 +165,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
|
|||||||
RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
|
RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
|
||||||
}
|
}
|
||||||
|
|
||||||
dbResp, err := g.impl.NewUser(ctx, dbReq)
|
dbResp, err := impl.NewUser(ctx, dbReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
|
return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
|
||||||
}
|
}
|
||||||
@ -82,7 +176,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
|
func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
|
||||||
if req.GetUsername() == "" {
|
if req.GetUsername() == "" {
|
||||||
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
||||||
}
|
}
|
||||||
@ -92,7 +186,12 @@ func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest
|
|||||||
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error())
|
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = g.impl.UpdateUser(ctx, dbReq)
|
impl, err := g.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = impl.UpdateUser(ctx, dbReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
|
return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
|
||||||
}
|
}
|
||||||
@ -144,7 +243,7 @@ func hasChange(dbReq UpdateUserRequest) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
|
func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
|
||||||
if req.GetUsername() == "" {
|
if req.GetUsername() == "" {
|
||||||
return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
||||||
}
|
}
|
||||||
@ -153,15 +252,25 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest
|
|||||||
Statements: getStatementsFromProto(req.GetStatements()),
|
Statements: getStatementsFromProto(req.GetStatements()),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := g.impl.DeleteUser(ctx, dbReq)
|
impl, err := g.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = impl.DeleteUser(ctx, dbReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
|
return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
|
||||||
}
|
}
|
||||||
return &proto.DeleteUserResponse{}, nil
|
return &proto.DeleteUserResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
|
func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
|
||||||
t, err := g.impl.Type()
|
impl, err := g.getOrCreateDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := impl.Type()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
|
return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
|
||||||
}
|
}
|
||||||
@ -172,11 +281,29 @@ func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeRespon
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
||||||
err := g.impl.Close()
|
g.Lock()
|
||||||
|
defer g.Unlock()
|
||||||
|
|
||||||
|
impl, err := g.getDatabaseInternal(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = impl.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err)
|
return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if g.singleImpl == nil {
|
||||||
|
// only cleanup instances map when multiplexing is supported
|
||||||
|
id, err := getMultiplexIDFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
delete(g.instances, id)
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.Empty{}, nil
|
return &proto.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,9 @@ import (
|
|||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
"github.com/golang/protobuf/ptypes/timestamp"
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,11 +24,12 @@ var invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
|
|||||||
|
|
||||||
func TestGRPCServer_Initialize(t *testing.T) {
|
func TestGRPCServer_Initialize(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
db Database
|
db Database
|
||||||
req *proto.InitializeRequest
|
req *proto.InitializeRequest
|
||||||
expectedResp *proto.InitializeResponse
|
expectedResp *proto.InitializeResponse
|
||||||
expectErr bool
|
expectErr bool
|
||||||
expectCode codes.Code
|
expectCode codes.Code
|
||||||
|
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := map[string]testCase{
|
tests := map[string]testCase{
|
||||||
@ -34,10 +37,11 @@ func TestGRPCServer_Initialize(t *testing.T) {
|
|||||||
db: fakeDatabase{
|
db: fakeDatabase{
|
||||||
initErr: errors.New("initialization error"),
|
initErr: errors.New("initialization error"),
|
||||||
},
|
},
|
||||||
req: &proto.InitializeRequest{},
|
req: &proto.InitializeRequest{},
|
||||||
expectedResp: &proto.InitializeResponse{},
|
expectedResp: &proto.InitializeResponse{},
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
expectCode: codes.Internal,
|
expectCode: codes.Internal,
|
||||||
|
grpcSetupFunc: testGrpcServer,
|
||||||
},
|
},
|
||||||
"newConfig can't marshal to JSON": {
|
"newConfig can't marshal to JSON": {
|
||||||
db: fakeDatabase{
|
db: fakeDatabase{
|
||||||
@ -47,12 +51,13 @@ func TestGRPCServer_Initialize(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
req: &proto.InitializeRequest{},
|
req: &proto.InitializeRequest{},
|
||||||
expectedResp: &proto.InitializeResponse{},
|
expectedResp: &proto.InitializeResponse{},
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
expectCode: codes.Internal,
|
expectCode: codes.Internal,
|
||||||
|
grpcSetupFunc: testGrpcServer,
|
||||||
},
|
},
|
||||||
"happy path with config data": {
|
"happy path with config data for multiplexed plugin": {
|
||||||
db: fakeDatabase{
|
db: fakeDatabase{
|
||||||
initResp: InitializeResponse{
|
initResp: InitializeResponse{
|
||||||
Config: map[string]interface{}{
|
Config: map[string]interface{}{
|
||||||
@ -70,21 +75,39 @@ func TestGRPCServer_Initialize(t *testing.T) {
|
|||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
expectCode: codes.OK,
|
expectCode: codes.OK,
|
||||||
|
grpcSetupFunc: testGrpcServer,
|
||||||
|
},
|
||||||
|
"happy path with config data for non-multiplexed plugin": {
|
||||||
|
db: fakeDatabase{
|
||||||
|
initResp: InitializeResponse{
|
||||||
|
Config: map[string]interface{}{
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
req: &proto.InitializeRequest{
|
||||||
|
ConfigData: marshal(t, map[string]interface{}{
|
||||||
|
"foo": "bar",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
expectedResp: &proto.InitializeResponse{
|
||||||
|
ConfigData: marshal(t, map[string]interface{}{
|
||||||
|
"foo": "bar",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
expectCode: codes.OK,
|
||||||
|
grpcSetupFunc: testGrpcServerSingleImpl,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
g := gRPCServer{
|
idCtx, g := test.grpcSetupFunc(t, test.db)
|
||||||
impl: test.db,
|
resp, err := g.Initialize(idCtx, test.req)
|
||||||
}
|
|
||||||
|
|
||||||
// Context doesn't need to timeout since this is just passed through
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
resp, err := g.Initialize(ctx, test.req)
|
|
||||||
if test.expectErr && err == nil {
|
if test.expectErr && err == nil {
|
||||||
t.Fatalf("err expected, got nil")
|
t.Fatalf("err expected, got nil")
|
||||||
}
|
}
|
||||||
@ -252,14 +275,9 @@ func TestGRPCServer_NewUser(t *testing.T) {
|
|||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
g := gRPCServer{
|
idCtx, g := testGrpcServer(t, test.db)
|
||||||
impl: test.db,
|
resp, err := g.NewUser(idCtx, test.req)
|
||||||
}
|
|
||||||
|
|
||||||
// Context doesn't need to timeout since this is just passed through
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
resp, err := g.NewUser(ctx, test.req)
|
|
||||||
if test.expectErr && err == nil {
|
if test.expectErr && err == nil {
|
||||||
t.Fatalf("err expected, got nil")
|
t.Fatalf("err expected, got nil")
|
||||||
}
|
}
|
||||||
@ -362,14 +380,9 @@ func TestGRPCServer_UpdateUser(t *testing.T) {
|
|||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
g := gRPCServer{
|
idCtx, g := testGrpcServer(t, test.db)
|
||||||
impl: test.db,
|
resp, err := g.UpdateUser(idCtx, test.req)
|
||||||
}
|
|
||||||
|
|
||||||
// Context doesn't need to timeout since this is just passed through
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
resp, err := g.UpdateUser(ctx, test.req)
|
|
||||||
if test.expectErr && err == nil {
|
if test.expectErr && err == nil {
|
||||||
t.Fatalf("err expected, got nil")
|
t.Fatalf("err expected, got nil")
|
||||||
}
|
}
|
||||||
@ -430,14 +443,9 @@ func TestGRPCServer_DeleteUser(t *testing.T) {
|
|||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
g := gRPCServer{
|
idCtx, g := testGrpcServer(t, test.db)
|
||||||
impl: test.db,
|
resp, err := g.DeleteUser(idCtx, test.req)
|
||||||
}
|
|
||||||
|
|
||||||
// Context doesn't need to timeout since this is just passed through
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
resp, err := g.DeleteUser(ctx, test.req)
|
|
||||||
if test.expectErr && err == nil {
|
if test.expectErr && err == nil {
|
||||||
t.Fatalf("err expected, got nil")
|
t.Fatalf("err expected, got nil")
|
||||||
}
|
}
|
||||||
@ -488,14 +496,9 @@ func TestGRPCServer_Type(t *testing.T) {
|
|||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
g := gRPCServer{
|
idCtx, g := testGrpcServer(t, test.db)
|
||||||
impl: test.db,
|
resp, err := g.Type(idCtx, &proto.Empty{})
|
||||||
}
|
|
||||||
|
|
||||||
// Context doesn't need to timeout since this is just passed through
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
resp, err := g.Type(ctx, &proto.Empty{})
|
|
||||||
if test.expectErr && err == nil {
|
if test.expectErr && err == nil {
|
||||||
t.Fatalf("err expected, got nil")
|
t.Fatalf("err expected, got nil")
|
||||||
}
|
}
|
||||||
@ -517,9 +520,11 @@ func TestGRPCServer_Type(t *testing.T) {
|
|||||||
|
|
||||||
func TestGRPCServer_Close(t *testing.T) {
|
func TestGRPCServer_Close(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
db Database
|
db Database
|
||||||
expectErr bool
|
expectErr bool
|
||||||
expectCode codes.Code
|
expectCode codes.Code
|
||||||
|
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
|
||||||
|
assertFunc func(t *testing.T, g gRPCServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := map[string]testCase{
|
tests := map[string]testCase{
|
||||||
@ -527,26 +532,36 @@ func TestGRPCServer_Close(t *testing.T) {
|
|||||||
db: fakeDatabase{
|
db: fakeDatabase{
|
||||||
closeErr: errors.New("close error"),
|
closeErr: errors.New("close error"),
|
||||||
},
|
},
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
expectCode: codes.Internal,
|
expectCode: codes.Internal,
|
||||||
|
grpcSetupFunc: testGrpcServer,
|
||||||
|
assertFunc: nil,
|
||||||
},
|
},
|
||||||
"happy path": {
|
"happy path for multiplexed plugin": {
|
||||||
db: fakeDatabase{},
|
db: fakeDatabase{},
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
expectCode: codes.OK,
|
expectCode: codes.OK,
|
||||||
|
grpcSetupFunc: testGrpcServer,
|
||||||
|
assertFunc: func(t *testing.T, g gRPCServer) {
|
||||||
|
if len(g.instances) != 0 {
|
||||||
|
t.Fatalf("err expected instances map to be empty")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"happy path for non-multiplexed plugin": {
|
||||||
|
db: fakeDatabase{},
|
||||||
|
expectErr: false,
|
||||||
|
expectCode: codes.OK,
|
||||||
|
grpcSetupFunc: testGrpcServerSingleImpl,
|
||||||
|
assertFunc: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
g := gRPCServer{
|
idCtx, g := test.grpcSetupFunc(t, test.db)
|
||||||
impl: test.db,
|
_, err := g.Close(idCtx, &proto.Empty{})
|
||||||
}
|
|
||||||
|
|
||||||
// Context doesn't need to timeout since this is just passed through
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err := g.Close(ctx, &proto.Empty{})
|
|
||||||
if test.expectErr && err == nil {
|
if test.expectErr && err == nil {
|
||||||
t.Fatalf("err expected, got nil")
|
t.Fatalf("err expected, got nil")
|
||||||
}
|
}
|
||||||
@ -558,10 +573,105 @@ func TestGRPCServer_Close(t *testing.T) {
|
|||||||
if actualCode != test.expectCode {
|
if actualCode != test.expectCode {
|
||||||
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if test.assertFunc != nil {
|
||||||
|
test.assertFunc(t, g)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetMultiplexIDFromContext(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
ctx context.Context
|
||||||
|
expectedResp string
|
||||||
|
expectedErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]testCase{
|
||||||
|
"missing plugin multiplexing metadata": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
expectedResp: "",
|
||||||
|
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
|
||||||
|
},
|
||||||
|
"unexpected number of IDs in metadata": {
|
||||||
|
ctx: idCtx(t, "12345", "67891"),
|
||||||
|
expectedResp: "",
|
||||||
|
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
|
||||||
|
},
|
||||||
|
"empty multiplex ID in metadata": {
|
||||||
|
ctx: idCtx(t, ""),
|
||||||
|
expectedResp: "",
|
||||||
|
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
|
||||||
|
},
|
||||||
|
"happy path, id is returned from metadata": {
|
||||||
|
ctx: idCtx(t, "12345"),
|
||||||
|
expectedResp: "12345",
|
||||||
|
expectedErr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
resp, err := getMultiplexIDFromContext(test.ctx)
|
||||||
|
|
||||||
|
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
|
||||||
|
t.Fatalf("err expected, got nil")
|
||||||
|
} else if !reflect.DeepEqual(err, test.expectedErr) {
|
||||||
|
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
|
||||||
|
t.Fatalf("no error expected, got: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||||
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testGrpcServer is a test helper that returns a context with an ID set in its
|
||||||
|
// metadata and a gRPCServer instance for a multiplexed plugin
|
||||||
|
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
|
||||||
|
t.Helper()
|
||||||
|
g := gRPCServer{
|
||||||
|
factoryFunc: func() (interface{}, error) {
|
||||||
|
return db, nil
|
||||||
|
},
|
||||||
|
instances: make(map[string]Database),
|
||||||
|
}
|
||||||
|
|
||||||
|
id := "12345"
|
||||||
|
idCtx := idCtx(t, id)
|
||||||
|
g.instances[id] = db
|
||||||
|
|
||||||
|
return idCtx, g
|
||||||
|
}
|
||||||
|
|
||||||
|
// testGrpcServerSingleImpl is a test helper that returns a context and a
|
||||||
|
// gRPCServer instance for a non-multiplexed plugin
|
||||||
|
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
|
||||||
|
t.Helper()
|
||||||
|
return context.Background(), gRPCServer{
|
||||||
|
singleImpl: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// idCtx is a test helper that will return a context with the IDs set in its
|
||||||
|
// metadata
|
||||||
|
func idCtx(t *testing.T, ids ...string) context.Context {
|
||||||
|
t.Helper()
|
||||||
|
// Context doesn't need to timeout since this is just passed through
|
||||||
|
ctx := context.Background()
|
||||||
|
md := metadata.MD{}
|
||||||
|
for _, id := range ids {
|
||||||
|
md.Append(pluginutil.MultiplexingCtxKey, id)
|
||||||
|
}
|
||||||
|
return metadata.NewIncomingContext(ctx, md)
|
||||||
|
}
|
||||||
|
|
||||||
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
|
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@ -3,19 +3,14 @@ package dbplugin
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/hashicorp/go-hclog"
|
|
||||||
plugin "github.com/hashicorp/go-plugin"
|
plugin "github.com/hashicorp/go-plugin"
|
||||||
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close
|
|
||||||
// method to also call Kill() on the plugin.Client.
|
|
||||||
type DatabasePluginClient struct {
|
type DatabasePluginClient struct {
|
||||||
client *plugin.Client
|
client pluginutil.PluginClient
|
||||||
sync.Mutex
|
|
||||||
|
|
||||||
Database
|
Database
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -23,42 +18,31 @@ type DatabasePluginClient struct {
|
|||||||
// and kill the plugin.
|
// and kill the plugin.
|
||||||
func (dc *DatabasePluginClient) Close() error {
|
func (dc *DatabasePluginClient) Close() error {
|
||||||
err := dc.Database.Close()
|
err := dc.Database.Close()
|
||||||
dc.client.Kill()
|
dc.client.Close()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pluginSets is the map of plugins we can dispense.
|
||||||
|
var PluginSets = map[int]plugin.PluginSet{
|
||||||
|
5: {
|
||||||
|
"database": &GRPCDatabasePlugin{},
|
||||||
|
},
|
||||||
|
6: {
|
||||||
|
"database": &GRPCDatabasePlugin{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// NewPluginClient returns a databaseRPCClient with a connection to a running
|
// NewPluginClient returns a databaseRPCClient with a connection to a running
|
||||||
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
|
// plugin.
|
||||||
// plugin is killed on call of Close().
|
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) {
|
||||||
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) {
|
pluginClient, err := sys.NewPluginClient(ctx, config)
|
||||||
// pluginSets is the map of plugins we can dispense.
|
|
||||||
pluginSets := map[int]plugin.PluginSet{
|
|
||||||
5: {
|
|
||||||
"database": new(GRPCDatabasePlugin),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := pluginRunner.RunConfig(ctx,
|
|
||||||
pluginutil.Runner(sys),
|
|
||||||
pluginutil.PluginSets(pluginSets),
|
|
||||||
pluginutil.HandshakeConfig(handshakeConfig),
|
|
||||||
pluginutil.Logger(logger),
|
|
||||||
pluginutil.MetadataMode(isMetadataMode),
|
|
||||||
pluginutil.AutoMTLS(true),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect via RPC
|
|
||||||
rpcClient, err := client.Client()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request the plugin
|
// Request the plugin
|
||||||
raw, err := rpcClient.Dispense("database")
|
raw, err := pluginClient.Dispense("database")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -66,16 +50,19 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||||||
// We should have a database type now. This feels like a normal interface
|
// We should have a database type now. This feels like a normal interface
|
||||||
// implementation but is in fact over an RPC connection.
|
// implementation but is in fact over an RPC connection.
|
||||||
var db Database
|
var db Database
|
||||||
switch raw.(type) {
|
switch c := raw.(type) {
|
||||||
case gRPCClient:
|
case gRPCClient:
|
||||||
db = raw.(gRPCClient)
|
// This is an abstraction leak from go-plugin but it is necessary in
|
||||||
|
// order to enable multiplexing on multiplexed plugins
|
||||||
|
c.client = proto.NewDatabaseClient(pluginClient.Conn())
|
||||||
|
|
||||||
|
db = c
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported client type")
|
return nil, errors.New("unsupported client type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap RPC implementation in DatabasePluginClient
|
|
||||||
return &DatabasePluginClient{
|
return &DatabasePluginClient{
|
||||||
client: client,
|
client: pluginClient,
|
||||||
Database: db,
|
Database: db,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,8 +40,17 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||||||
transport = "builtin"
|
transport = "builtin"
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
config := pluginutil.PluginClientConfig{
|
||||||
|
Name: pluginName,
|
||||||
|
PluginType: consts.PluginTypeDatabase,
|
||||||
|
PluginSets: PluginSets,
|
||||||
|
HandshakeConfig: HandshakeConfig,
|
||||||
|
Logger: namedLogger,
|
||||||
|
IsMetadataMode: false,
|
||||||
|
AutoMTLS: true,
|
||||||
|
}
|
||||||
// create a DatabasePluginClient instance
|
// create a DatabasePluginClient instance
|
||||||
db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false)
|
db, err = NewPluginClient(ctx, sys, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -59,6 +68,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
|
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
logger.Debug("got database plugin instance", "type", typeStr)
|
||||||
|
|
||||||
// Wrap with metrics middleware
|
// Wrap with metrics middleware
|
||||||
db = &databaseMetricsMiddleware{
|
db = &databaseMetricsMiddleware{
|
||||||
|
|||||||
@ -31,7 +31,49 @@ func ServeConfig(db Database) *plugin.ServeConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conf := &plugin.ServeConfig{
|
conf := &plugin.ServeConfig{
|
||||||
HandshakeConfig: handshakeConfig,
|
HandshakeConfig: HandshakeConfig,
|
||||||
|
VersionedPlugins: pluginSets,
|
||||||
|
GRPCServer: plugin.DefaultGRPCServer,
|
||||||
|
}
|
||||||
|
|
||||||
|
return conf
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServeMultiplex(factory Factory) {
|
||||||
|
plugin.Serve(ServeConfigMultiplex(factory))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig {
|
||||||
|
err := pluginutil.OptionallyEnableMlock()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := factory()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
database := db.(Database)
|
||||||
|
|
||||||
|
// pluginSets is the map of plugins we can dispense.
|
||||||
|
pluginSets := map[int]plugin.PluginSet{
|
||||||
|
5: {
|
||||||
|
"database": &GRPCDatabasePlugin{
|
||||||
|
Impl: database,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
6: {
|
||||||
|
"database": &GRPCDatabasePlugin{
|
||||||
|
FactoryFunc: factory,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conf := &plugin.ServeConfig{
|
||||||
|
HandshakeConfig: HandshakeConfig,
|
||||||
VersionedPlugins: pluginSets,
|
VersionedPlugins: pluginSets,
|
||||||
GRPCServer: plugin.DefaultGRPCServer,
|
GRPCServer: plugin.DefaultGRPCServer,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: sdk/database/dbplugin/v5/proto/database.proto
|
// source: sdk/database/dbplugin/v5/proto/database.proto
|
||||||
|
|
||||||
package proto
|
package proto
|
||||||
|
|||||||
47
sdk/helper/pluginutil/multiplexing.go
Normal file
47
sdk/helper/pluginutil/multiplexing.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package pluginutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
codes "google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PluginMultiplexingServerImpl struct {
|
||||||
|
UnimplementedPluginMultiplexingServer
|
||||||
|
|
||||||
|
Supported bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, req *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
|
||||||
|
return &MultiplexingSupportResponse{
|
||||||
|
Supported: pm.Supported,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) {
|
||||||
|
if cc == nil {
|
||||||
|
return false, fmt.Errorf("client connection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := new(MultiplexingSupportRequest)
|
||||||
|
resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
// If the server does not implement the multiplexing server then we can
|
||||||
|
// assume it is not multiplexed
|
||||||
|
if status.Code(err) == codes.Unimplemented {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
// Somehow got a nil response, assume not multiplexed
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Supported, nil
|
||||||
|
}
|
||||||
213
sdk/helper/pluginutil/multiplexing.pb.go
Normal file
213
sdk/helper/pluginutil/multiplexing.pb.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.27.1
|
||||||
|
// protoc v3.19.4
|
||||||
|
// source: sdk/helper/pluginutil/multiplexing.proto
|
||||||
|
|
||||||
|
package pluginutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
type MultiplexingSupportRequest struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportRequest) Reset() {
|
||||||
|
*x = MultiplexingSupportRequest{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*MultiplexingSupportRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use MultiplexingSupportRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*MultiplexingSupportRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultiplexingSupportResponse struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
Supported bool `protobuf:"varint,1,opt,name=supported,proto3" json:"supported,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportResponse) Reset() {
|
||||||
|
*x = MultiplexingSupportResponse{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*MultiplexingSupportResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use MultiplexingSupportResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*MultiplexingSupportResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MultiplexingSupportResponse) GetSupported() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Supported
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var File_sdk_helper_pluginutil_multiplexing_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
|
var file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = []byte{
|
||||||
|
0x0a, 0x28, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, 0x72, 0x2f, 0x70, 0x6c, 0x75,
|
||||||
|
0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65,
|
||||||
|
0x78, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x70, 0x6c, 0x75, 0x67,
|
||||||
|
0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
|
||||||
|
0x69, 0x6e, 0x67, 0x22, 0x1c, 0x0a, 0x1a, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
|
||||||
|
0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||||
|
0x74, 0x22, 0x3b, 0x0a, 0x1b, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e,
|
||||||
|
0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||||
|
0x12, 0x1c, 0x0a, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20,
|
||||||
|
0x01, 0x28, 0x08, 0x52, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x32, 0x97,
|
||||||
|
0x01, 0x0a, 0x12, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c,
|
||||||
|
0x65, 0x78, 0x69, 0x6e, 0x67, 0x12, 0x80, 0x01, 0x0a, 0x13, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70,
|
||||||
|
0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x33, 0x2e,
|
||||||
|
0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69,
|
||||||
|
0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65,
|
||||||
|
0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||||
|
0x73, 0x74, 0x1a, 0x34, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e,
|
||||||
|
0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c,
|
||||||
|
0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74,
|
||||||
|
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68,
|
||||||
|
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70,
|
||||||
|
0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65,
|
||||||
|
0x72, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x62, 0x06, 0x70, 0x72,
|
||||||
|
0x6f, 0x74, 0x6f, 0x33,
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce sync.Once
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = file_sdk_helper_pluginutil_multiplexing_proto_rawDesc
|
||||||
|
)
|
||||||
|
|
||||||
|
func file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP() []byte {
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce.Do(func() {
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = protoimpl.X.CompressGZIP(file_sdk_helper_pluginutil_multiplexing_proto_rawDescData)
|
||||||
|
})
|
||||||
|
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescData
|
||||||
|
}
|
||||||
|
|
||||||
|
var file_sdk_helper_pluginutil_multiplexing_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||||
|
var file_sdk_helper_pluginutil_multiplexing_proto_goTypes = []interface{}{
|
||||||
|
(*MultiplexingSupportRequest)(nil), // 0: pluginutil.multiplexing.MultiplexingSupportRequest
|
||||||
|
(*MultiplexingSupportResponse)(nil), // 1: pluginutil.multiplexing.MultiplexingSupportResponse
|
||||||
|
}
|
||||||
|
var file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = []int32{
|
||||||
|
0, // 0: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:input_type -> pluginutil.multiplexing.MultiplexingSupportRequest
|
||||||
|
1, // 1: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:output_type -> pluginutil.multiplexing.MultiplexingSupportResponse
|
||||||
|
1, // [1:2] is the sub-list for method output_type
|
||||||
|
0, // [0:1] is the sub-list for method input_type
|
||||||
|
0, // [0:0] is the sub-list for extension type_name
|
||||||
|
0, // [0:0] is the sub-list for extension extendee
|
||||||
|
0, // [0:0] is the sub-list for field type_name
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { file_sdk_helper_pluginutil_multiplexing_proto_init() }
|
||||||
|
func file_sdk_helper_pluginutil_multiplexing_proto_init() {
|
||||||
|
if File_sdk_helper_pluginutil_multiplexing_proto != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !protoimpl.UnsafeEnabled {
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*MultiplexingSupportRequest); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*MultiplexingSupportResponse); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: file_sdk_helper_pluginutil_multiplexing_proto_rawDesc,
|
||||||
|
NumEnums: 0,
|
||||||
|
NumMessages: 2,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 1,
|
||||||
|
},
|
||||||
|
GoTypes: file_sdk_helper_pluginutil_multiplexing_proto_goTypes,
|
||||||
|
DependencyIndexes: file_sdk_helper_pluginutil_multiplexing_proto_depIdxs,
|
||||||
|
MessageInfos: file_sdk_helper_pluginutil_multiplexing_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_sdk_helper_pluginutil_multiplexing_proto = out.File
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = nil
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_goTypes = nil
|
||||||
|
file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = nil
|
||||||
|
}
|
||||||
13
sdk/helper/pluginutil/multiplexing.proto
Normal file
13
sdk/helper/pluginutil/multiplexing.proto
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
package pluginutil.multiplexing;
|
||||||
|
|
||||||
|
option go_package = "github.com/hashicorp/vault/sdk/helper/pluginutil";
|
||||||
|
|
||||||
|
message MultiplexingSupportRequest {}
|
||||||
|
message MultiplexingSupportResponse {
|
||||||
|
bool supported = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
service PluginMultiplexing {
|
||||||
|
rpc MultiplexingSupport(MultiplexingSupportRequest) returns (MultiplexingSupportResponse);
|
||||||
|
}
|
||||||
101
sdk/helper/pluginutil/multiplexing_grpc.pb.go
Normal file
101
sdk/helper/pluginutil/multiplexing_grpc.pb.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
|
|
||||||
|
package pluginutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
codes "google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
// Requires gRPC-Go v1.32.0 or later.
|
||||||
|
const _ = grpc.SupportPackageIsVersion7
|
||||||
|
|
||||||
|
// PluginMultiplexingClient is the client API for PluginMultiplexing service.
|
||||||
|
//
|
||||||
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
|
type PluginMultiplexingClient interface {
|
||||||
|
MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type pluginMultiplexingClient struct {
|
||||||
|
cc grpc.ClientConnInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPluginMultiplexingClient(cc grpc.ClientConnInterface) PluginMultiplexingClient {
|
||||||
|
return &pluginMultiplexingClient{cc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pluginMultiplexingClient) MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) {
|
||||||
|
out := new(MultiplexingSupportResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PluginMultiplexingServer is the server API for PluginMultiplexing service.
|
||||||
|
// All implementations must embed UnimplementedPluginMultiplexingServer
|
||||||
|
// for forward compatibility
|
||||||
|
type PluginMultiplexingServer interface {
|
||||||
|
MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error)
|
||||||
|
mustEmbedUnimplementedPluginMultiplexingServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedPluginMultiplexingServer must be embedded to have forward compatible implementations.
|
||||||
|
type UnimplementedPluginMultiplexingServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedPluginMultiplexingServer) MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method MultiplexingSupport not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedPluginMultiplexingServer) mustEmbedUnimplementedPluginMultiplexingServer() {}
|
||||||
|
|
||||||
|
// UnsafePluginMultiplexingServer may be embedded to opt out of forward compatibility for this service.
|
||||||
|
// Use of this interface is not recommended, as added methods to PluginMultiplexingServer will
|
||||||
|
// result in compilation errors.
|
||||||
|
type UnsafePluginMultiplexingServer interface {
|
||||||
|
mustEmbedUnimplementedPluginMultiplexingServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterPluginMultiplexingServer(s grpc.ServiceRegistrar, srv PluginMultiplexingServer) {
|
||||||
|
s.RegisterService(&PluginMultiplexing_ServiceDesc, srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _PluginMultiplexing_MultiplexingSupport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(MultiplexingSupportRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, req.(*MultiplexingSupportRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PluginMultiplexing_ServiceDesc is the grpc.ServiceDesc for PluginMultiplexing service.
|
||||||
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
|
// and not to be introspected or modified (even as a copy)
|
||||||
|
var PluginMultiplexing_ServiceDesc = grpc.ServiceDesc{
|
||||||
|
ServiceName: "pluginutil.multiplexing.PluginMultiplexing",
|
||||||
|
HandlerType: (*PluginMultiplexingServer)(nil),
|
||||||
|
Methods: []grpc.MethodDesc{
|
||||||
|
{
|
||||||
|
MethodName: "MultiplexingSupport",
|
||||||
|
Handler: _PluginMultiplexing_MultiplexingSupport_Handler,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Streams: []grpc.StreamDesc{},
|
||||||
|
Metadata: "sdk/helper/pluginutil/multiplexing.proto",
|
||||||
|
}
|
||||||
@ -9,9 +9,21 @@ import (
|
|||||||
|
|
||||||
log "github.com/hashicorp/go-hclog"
|
log "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/go-plugin"
|
"github.com/hashicorp/go-plugin"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
"github.com/hashicorp/vault/sdk/version"
|
"github.com/hashicorp/vault/sdk/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PluginClientConfig struct {
|
||||||
|
Name string
|
||||||
|
PluginType consts.PluginType
|
||||||
|
PluginSets map[int]plugin.PluginSet
|
||||||
|
HandshakeConfig plugin.HandshakeConfig
|
||||||
|
Logger log.Logger
|
||||||
|
IsMetadataMode bool
|
||||||
|
AutoMTLS bool
|
||||||
|
MLock bool
|
||||||
|
}
|
||||||
|
|
||||||
type runConfig struct {
|
type runConfig struct {
|
||||||
// Provided by PluginRunner
|
// Provided by PluginRunner
|
||||||
command string
|
command string
|
||||||
@ -21,12 +33,9 @@ type runConfig struct {
|
|||||||
// Initialized with what's in PluginRunner.Env, but can be added to
|
// Initialized with what's in PluginRunner.Env, but can be added to
|
||||||
env []string
|
env []string
|
||||||
|
|
||||||
wrapper RunnerUtil
|
wrapper RunnerUtil
|
||||||
pluginSets map[int]plugin.PluginSet
|
|
||||||
hs plugin.HandshakeConfig
|
PluginClientConfig
|
||||||
logger log.Logger
|
|
||||||
isMetadataMode bool
|
|
||||||
autoMTLS bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) {
|
func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) {
|
||||||
@ -34,19 +43,19 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||||||
cmd.Env = append(cmd.Env, rc.env...)
|
cmd.Env = append(cmd.Env, rc.env...)
|
||||||
|
|
||||||
// Add the mlock setting to the ENV of the plugin
|
// Add the mlock setting to the ENV of the plugin
|
||||||
if rc.wrapper != nil && rc.wrapper.MlockEnabled() {
|
if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) {
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
|
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
|
||||||
}
|
}
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version))
|
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version))
|
||||||
|
|
||||||
if rc.isMetadataMode {
|
if rc.IsMetadataMode {
|
||||||
rc.logger = rc.logger.With("metadata", "true")
|
rc.Logger = rc.Logger.With("metadata", "true")
|
||||||
}
|
}
|
||||||
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.isMetadataMode)
|
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode)
|
||||||
cmd.Env = append(cmd.Env, metadataEnv)
|
cmd.Env = append(cmd.Env, metadataEnv)
|
||||||
|
|
||||||
var clientTLSConfig *tls.Config
|
var clientTLSConfig *tls.Config
|
||||||
if !rc.autoMTLS && !rc.isMetadataMode {
|
if !rc.AutoMTLS && !rc.IsMetadataMode {
|
||||||
// Get a CA TLS Certificate
|
// Get a CA TLS Certificate
|
||||||
certBytes, key, err := generateCert()
|
certBytes, key, err := generateCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -76,17 +85,17 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
clientConfig := &plugin.ClientConfig{
|
clientConfig := &plugin.ClientConfig{
|
||||||
HandshakeConfig: rc.hs,
|
HandshakeConfig: rc.HandshakeConfig,
|
||||||
VersionedPlugins: rc.pluginSets,
|
VersionedPlugins: rc.PluginSets,
|
||||||
Cmd: cmd,
|
Cmd: cmd,
|
||||||
SecureConfig: secureConfig,
|
SecureConfig: secureConfig,
|
||||||
TLSConfig: clientTLSConfig,
|
TLSConfig: clientTLSConfig,
|
||||||
Logger: rc.logger,
|
Logger: rc.Logger,
|
||||||
AllowedProtocols: []plugin.Protocol{
|
AllowedProtocols: []plugin.Protocol{
|
||||||
plugin.ProtocolNetRPC,
|
plugin.ProtocolNetRPC,
|
||||||
plugin.ProtocolGRPC,
|
plugin.ProtocolGRPC,
|
||||||
},
|
},
|
||||||
AutoMTLS: rc.autoMTLS,
|
AutoMTLS: rc.AutoMTLS,
|
||||||
}
|
}
|
||||||
return clientConfig, nil
|
return clientConfig, nil
|
||||||
}
|
}
|
||||||
@ -117,31 +126,37 @@ func Runner(wrapper RunnerUtil) RunOpt {
|
|||||||
|
|
||||||
func PluginSets(pluginSets map[int]plugin.PluginSet) RunOpt {
|
func PluginSets(pluginSets map[int]plugin.PluginSet) RunOpt {
|
||||||
return func(rc *runConfig) {
|
return func(rc *runConfig) {
|
||||||
rc.pluginSets = pluginSets
|
rc.PluginSets = pluginSets
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandshakeConfig(hs plugin.HandshakeConfig) RunOpt {
|
func HandshakeConfig(hs plugin.HandshakeConfig) RunOpt {
|
||||||
return func(rc *runConfig) {
|
return func(rc *runConfig) {
|
||||||
rc.hs = hs
|
rc.HandshakeConfig = hs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Logger(logger log.Logger) RunOpt {
|
func Logger(logger log.Logger) RunOpt {
|
||||||
return func(rc *runConfig) {
|
return func(rc *runConfig) {
|
||||||
rc.logger = logger
|
rc.Logger = logger
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func MetadataMode(isMetadataMode bool) RunOpt {
|
func MetadataMode(isMetadataMode bool) RunOpt {
|
||||||
return func(rc *runConfig) {
|
return func(rc *runConfig) {
|
||||||
rc.isMetadataMode = isMetadataMode
|
rc.IsMetadataMode = isMetadataMode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutoMTLS(autoMTLS bool) RunOpt {
|
func AutoMTLS(autoMTLS bool) RunOpt {
|
||||||
return func(rc *runConfig) {
|
return func(rc *runConfig) {
|
||||||
rc.autoMTLS = autoMTLS
|
rc.AutoMTLS = autoMTLS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MLock(mlock bool) RunOpt {
|
||||||
|
return func(rc *runConfig) {
|
||||||
|
rc.MLock = mlock
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -38,19 +38,21 @@ func TestMakeConfig(t *testing.T) {
|
|||||||
args: []string{"foo", "bar"},
|
args: []string{"foo", "bar"},
|
||||||
sha256: []byte("some_sha256"),
|
sha256: []byte("some_sha256"),
|
||||||
env: []string{"initial=true"},
|
env: []string{"initial=true"},
|
||||||
pluginSets: map[int]plugin.PluginSet{
|
PluginClientConfig: PluginClientConfig{
|
||||||
1: {
|
PluginSets: map[int]plugin.PluginSet{
|
||||||
"bogus": nil,
|
1: {
|
||||||
|
"bogus": nil,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
HandshakeConfig: plugin.HandshakeConfig{
|
||||||
|
ProtocolVersion: 1,
|
||||||
|
MagicCookieKey: "magic_cookie_key",
|
||||||
|
MagicCookieValue: "magic_cookie_value",
|
||||||
|
},
|
||||||
|
Logger: hclog.NewNullLogger(),
|
||||||
|
IsMetadataMode: true,
|
||||||
|
AutoMTLS: false,
|
||||||
},
|
},
|
||||||
hs: plugin.HandshakeConfig{
|
|
||||||
ProtocolVersion: 1,
|
|
||||||
MagicCookieKey: "magic_cookie_key",
|
|
||||||
MagicCookieValue: "magic_cookie_value",
|
|
||||||
},
|
|
||||||
logger: hclog.NewNullLogger(),
|
|
||||||
isMetadataMode: true,
|
|
||||||
autoMTLS: false,
|
|
||||||
},
|
},
|
||||||
|
|
||||||
responseWrapInfoTimes: 0,
|
responseWrapInfoTimes: 0,
|
||||||
@ -97,19 +99,21 @@ func TestMakeConfig(t *testing.T) {
|
|||||||
args: []string{"foo", "bar"},
|
args: []string{"foo", "bar"},
|
||||||
sha256: []byte("some_sha256"),
|
sha256: []byte("some_sha256"),
|
||||||
env: []string{"initial=true"},
|
env: []string{"initial=true"},
|
||||||
pluginSets: map[int]plugin.PluginSet{
|
PluginClientConfig: PluginClientConfig{
|
||||||
1: {
|
PluginSets: map[int]plugin.PluginSet{
|
||||||
"bogus": nil,
|
1: {
|
||||||
|
"bogus": nil,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
HandshakeConfig: plugin.HandshakeConfig{
|
||||||
|
ProtocolVersion: 1,
|
||||||
|
MagicCookieKey: "magic_cookie_key",
|
||||||
|
MagicCookieValue: "magic_cookie_value",
|
||||||
|
},
|
||||||
|
Logger: hclog.NewNullLogger(),
|
||||||
|
IsMetadataMode: false,
|
||||||
|
AutoMTLS: false,
|
||||||
},
|
},
|
||||||
hs: plugin.HandshakeConfig{
|
|
||||||
ProtocolVersion: 1,
|
|
||||||
MagicCookieKey: "magic_cookie_key",
|
|
||||||
MagicCookieValue: "magic_cookie_value",
|
|
||||||
},
|
|
||||||
logger: hclog.NewNullLogger(),
|
|
||||||
isMetadataMode: false,
|
|
||||||
autoMTLS: false,
|
|
||||||
},
|
},
|
||||||
|
|
||||||
responseWrapInfo: &wrapping.ResponseWrapInfo{
|
responseWrapInfo: &wrapping.ResponseWrapInfo{
|
||||||
@ -161,19 +165,21 @@ func TestMakeConfig(t *testing.T) {
|
|||||||
args: []string{"foo", "bar"},
|
args: []string{"foo", "bar"},
|
||||||
sha256: []byte("some_sha256"),
|
sha256: []byte("some_sha256"),
|
||||||
env: []string{"initial=true"},
|
env: []string{"initial=true"},
|
||||||
pluginSets: map[int]plugin.PluginSet{
|
PluginClientConfig: PluginClientConfig{
|
||||||
1: {
|
PluginSets: map[int]plugin.PluginSet{
|
||||||
"bogus": nil,
|
1: {
|
||||||
|
"bogus": nil,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
HandshakeConfig: plugin.HandshakeConfig{
|
||||||
|
ProtocolVersion: 1,
|
||||||
|
MagicCookieKey: "magic_cookie_key",
|
||||||
|
MagicCookieValue: "magic_cookie_value",
|
||||||
|
},
|
||||||
|
Logger: hclog.NewNullLogger(),
|
||||||
|
IsMetadataMode: true,
|
||||||
|
AutoMTLS: true,
|
||||||
},
|
},
|
||||||
hs: plugin.HandshakeConfig{
|
|
||||||
ProtocolVersion: 1,
|
|
||||||
MagicCookieKey: "magic_cookie_key",
|
|
||||||
MagicCookieValue: "magic_cookie_value",
|
|
||||||
},
|
|
||||||
logger: hclog.NewNullLogger(),
|
|
||||||
isMetadataMode: true,
|
|
||||||
autoMTLS: true,
|
|
||||||
},
|
},
|
||||||
|
|
||||||
responseWrapInfoTimes: 0,
|
responseWrapInfoTimes: 0,
|
||||||
@ -220,19 +226,21 @@ func TestMakeConfig(t *testing.T) {
|
|||||||
args: []string{"foo", "bar"},
|
args: []string{"foo", "bar"},
|
||||||
sha256: []byte("some_sha256"),
|
sha256: []byte("some_sha256"),
|
||||||
env: []string{"initial=true"},
|
env: []string{"initial=true"},
|
||||||
pluginSets: map[int]plugin.PluginSet{
|
PluginClientConfig: PluginClientConfig{
|
||||||
1: {
|
PluginSets: map[int]plugin.PluginSet{
|
||||||
"bogus": nil,
|
1: {
|
||||||
|
"bogus": nil,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
HandshakeConfig: plugin.HandshakeConfig{
|
||||||
|
ProtocolVersion: 1,
|
||||||
|
MagicCookieKey: "magic_cookie_key",
|
||||||
|
MagicCookieValue: "magic_cookie_value",
|
||||||
|
},
|
||||||
|
Logger: hclog.NewNullLogger(),
|
||||||
|
IsMetadataMode: false,
|
||||||
|
AutoMTLS: true,
|
||||||
},
|
},
|
||||||
hs: plugin.HandshakeConfig{
|
|
||||||
ProtocolVersion: 1,
|
|
||||||
MagicCookieKey: "magic_cookie_key",
|
|
||||||
MagicCookieValue: "magic_cookie_value",
|
|
||||||
},
|
|
||||||
logger: hclog.NewNullLogger(),
|
|
||||||
isMetadataMode: false,
|
|
||||||
autoMTLS: true,
|
|
||||||
},
|
},
|
||||||
|
|
||||||
responseWrapInfoTimes: 0,
|
responseWrapInfoTimes: 0,
|
||||||
@ -329,6 +337,11 @@ type mockRunnerUtil struct {
|
|||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) {
|
||||||
|
args := m.Called(ctx, config)
|
||||||
|
return args.Get(0).(PluginClient), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||||
args := m.Called(ctx, data, ttl, jwt)
|
args := m.Called(ctx, data, ttl, jwt)
|
||||||
return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)
|
return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
plugin "github.com/hashicorp/go-plugin"
|
plugin "github.com/hashicorp/go-plugin"
|
||||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
"github.com/hashicorp/vault/sdk/helper/wrapping"
|
"github.com/hashicorp/vault/sdk/helper/wrapping"
|
||||||
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Looker defines the plugin Lookup function that looks into the plugin catalog
|
// Looker defines the plugin Lookup function that looks into the plugin catalog
|
||||||
@ -21,6 +22,7 @@ type Looker interface {
|
|||||||
// configuration and wrapping data in a response wrapped token.
|
// configuration and wrapping data in a response wrapped token.
|
||||||
// logical.SystemView implementations satisfy this interface.
|
// logical.SystemView implementations satisfy this interface.
|
||||||
type RunnerUtil interface {
|
type RunnerUtil interface {
|
||||||
|
NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error)
|
||||||
ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
|
ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
|
||||||
MlockEnabled() bool
|
MlockEnabled() bool
|
||||||
}
|
}
|
||||||
@ -31,17 +33,25 @@ type LookRunnerUtil interface {
|
|||||||
RunnerUtil
|
RunnerUtil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PluginClient interface {
|
||||||
|
Conn() grpc.ClientConnInterface
|
||||||
|
plugin.ClientProtocol
|
||||||
|
}
|
||||||
|
|
||||||
|
const MultiplexingCtxKey string = "multiplex_id"
|
||||||
|
|
||||||
// PluginRunner defines the metadata needed to run a plugin securely with
|
// PluginRunner defines the metadata needed to run a plugin securely with
|
||||||
// go-plugin.
|
// go-plugin.
|
||||||
type PluginRunner struct {
|
type PluginRunner struct {
|
||||||
Name string `json:"name" structs:"name"`
|
Name string `json:"name" structs:"name"`
|
||||||
Type consts.PluginType `json:"type" structs:"type"`
|
Type consts.PluginType `json:"type" structs:"type"`
|
||||||
Command string `json:"command" structs:"command"`
|
Command string `json:"command" structs:"command"`
|
||||||
Args []string `json:"args" structs:"args"`
|
Args []string `json:"args" structs:"args"`
|
||||||
Env []string `json:"env" structs:"env"`
|
Env []string `json:"env" structs:"env"`
|
||||||
Sha256 []byte `json:"sha256" structs:"sha256"`
|
Sha256 []byte `json:"sha256" structs:"sha256"`
|
||||||
Builtin bool `json:"builtin" structs:"builtin"`
|
Builtin bool `json:"builtin" structs:"builtin"`
|
||||||
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
|
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
|
||||||
|
MultiplexingSupport bool `json:"multiplexing_support" structs:"multiplexing_support"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and
|
// Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: sdk/logical/identity.proto
|
// source: sdk/logical/identity.proto
|
||||||
|
|
||||||
package logical
|
package logical
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: sdk/logical/plugin.proto
|
// source: sdk/logical/plugin.proto
|
||||||
|
|
||||||
package logical
|
package logical
|
||||||
|
|||||||
@ -56,6 +56,10 @@ type SystemView interface {
|
|||||||
// name. Returns a PluginRunner or an error if a plugin can not be found.
|
// name. Returns a PluginRunner or an error if a plugin can not be found.
|
||||||
LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error)
|
LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error)
|
||||||
|
|
||||||
|
// NewPluginClient returns a client for managing the lifecycle of plugin
|
||||||
|
// processes
|
||||||
|
NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error)
|
||||||
|
|
||||||
// MlockEnabled returns the configuration setting for enabling mlock on
|
// MlockEnabled returns the configuration setting for enabling mlock on
|
||||||
// plugins.
|
// plugins.
|
||||||
MlockEnabled() bool
|
MlockEnabled() bool
|
||||||
@ -152,6 +156,10 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState {
|
|||||||
return d.ReplicationStateVal
|
return d.ReplicationStateVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d StaticSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
||||||
|
return nil, errors.New("NewPluginClient is not implemented in StaticSystemView")
|
||||||
|
}
|
||||||
|
|
||||||
func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||||
return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView")
|
return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -99,6 +99,10 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st
|
|||||||
return info, nil
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
||||||
|
return nil, fmt.Errorf("cannot call NewPluginClient from a plugin backend")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) {
|
func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) {
|
||||||
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
|
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: sdk/plugin/pb/backend.proto
|
// source: sdk/plugin/pb/backend.proto
|
||||||
|
|
||||||
package pb
|
package pb
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: vault/activity/activity_log.proto
|
// source: vault/activity/activity_log.proto
|
||||||
|
|
||||||
package activity
|
package activity
|
||||||
|
|||||||
@ -215,6 +215,22 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string
|
|||||||
return resp.WrapInfo, nil
|
return resp.WrapInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d dynamicSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
||||||
|
if d.core == nil {
|
||||||
|
return nil, fmt.Errorf("system view core is nil")
|
||||||
|
}
|
||||||
|
if d.core.pluginCatalog == nil {
|
||||||
|
return nil, fmt.Errorf("system view core plugin catalog is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := d.core.pluginCatalog.NewPluginClient(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
// LookupPlugin looks for a plugin with the given name in the plugin catalog. It
|
// LookupPlugin looks for a plugin with the given name in the plugin catalog. It
|
||||||
// returns a PluginRunner or an error if no plugin was found.
|
// returns a PluginRunner or an error if no plugin was found.
|
||||||
func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) {
|
func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) {
|
||||||
|
|||||||
@ -12,6 +12,8 @@ import (
|
|||||||
|
|
||||||
log "github.com/hashicorp/go-hclog"
|
log "github.com/hashicorp/go-hclog"
|
||||||
multierror "github.com/hashicorp/go-multierror"
|
multierror "github.com/hashicorp/go-multierror"
|
||||||
|
plugin "github.com/hashicorp/go-plugin"
|
||||||
|
"github.com/hashicorp/go-secure-stdlib/base62"
|
||||||
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
|
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||||
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
@ -19,6 +21,8 @@ import (
|
|||||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
backendplugin "github.com/hashicorp/vault/sdk/plugin"
|
backendplugin "github.com/hashicorp/vault/sdk/plugin"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -35,21 +39,62 @@ type PluginCatalog struct {
|
|||||||
builtinRegistry BuiltinRegistry
|
builtinRegistry BuiltinRegistry
|
||||||
catalogView *BarrierView
|
catalogView *BarrierView
|
||||||
directory string
|
directory string
|
||||||
|
logger log.Logger
|
||||||
|
|
||||||
|
// externalPlugins holds plugin process connections by plugin name
|
||||||
|
//
|
||||||
|
// This allows plugins that suppport multiplexing to use a single grpc
|
||||||
|
// connection to communicate with multiple "backends". Each backend
|
||||||
|
// configuration using the same plugin will be routed to the existing
|
||||||
|
// plugin process.
|
||||||
|
externalPlugins map[string]*externalPlugin
|
||||||
|
mlockPlugins bool
|
||||||
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// externalPlugin holds client connections for multiplexed and
|
||||||
|
// non-multiplexed plugin processes
|
||||||
|
type externalPlugin struct {
|
||||||
|
// name is the plugin name
|
||||||
|
name string
|
||||||
|
|
||||||
|
// connections holds client connections by ID
|
||||||
|
connections map[string]*pluginClient
|
||||||
|
|
||||||
|
multiplexingSupport bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// pluginClient represents a connection to a plugin process
|
||||||
|
type pluginClient struct {
|
||||||
|
logger log.Logger
|
||||||
|
|
||||||
|
// id is the connection ID
|
||||||
|
id string
|
||||||
|
|
||||||
|
// client handles the lifecycle of a plugin process
|
||||||
|
// multiplexed plugins share the same client
|
||||||
|
client *plugin.Client
|
||||||
|
clientConn grpc.ClientConnInterface
|
||||||
|
cleanupFunc func() error
|
||||||
|
|
||||||
|
plugin.ClientProtocol
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Core) setupPluginCatalog(ctx context.Context) error {
|
func (c *Core) setupPluginCatalog(ctx context.Context) error {
|
||||||
c.pluginCatalog = &PluginCatalog{
|
c.pluginCatalog = &PluginCatalog{
|
||||||
builtinRegistry: c.builtinRegistry,
|
builtinRegistry: c.builtinRegistry,
|
||||||
catalogView: NewBarrierView(c.barrier, pluginCatalogPath),
|
catalogView: NewBarrierView(c.barrier, pluginCatalogPath),
|
||||||
directory: c.pluginDirectory,
|
directory: c.pluginDirectory,
|
||||||
|
logger: c.logger,
|
||||||
|
mlockPlugins: c.enableMlock,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run upgrade if untyped plugins exist
|
// Run upgrade if untyped plugins exist
|
||||||
err := c.pluginCatalog.UpgradePlugins(ctx, c.logger)
|
err := c.pluginCatalog.UpgradePlugins(ctx, c.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("error while upgrading plugin storage", "error", err)
|
c.logger.Error("error while upgrading plugin storage", "error", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.logger.IsInfo() {
|
if c.logger.IsInfo() {
|
||||||
@ -59,14 +104,205 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type pluginClientConn struct {
|
||||||
|
*grpc.ClientConn
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ grpc.ClientConnInterface = &pluginClientConn{}
|
||||||
|
|
||||||
|
func (d *pluginClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
|
||||||
|
// Inject ID to the context
|
||||||
|
md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id)
|
||||||
|
idCtx := metadata.NewOutgoingContext(ctx, md)
|
||||||
|
|
||||||
|
return d.ClientConn.Invoke(idCtx, method, args, reply, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *pluginClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||||
|
// Inject ID to the context
|
||||||
|
md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id)
|
||||||
|
idCtx := metadata.NewOutgoingContext(ctx, md)
|
||||||
|
|
||||||
|
return d.ClientConn.NewStream(idCtx, desc, method, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pluginClient) Conn() grpc.ClientConnInterface {
|
||||||
|
return p.clientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close calls the plugin client's cleanupFunc to do any necessary cleanup on
|
||||||
|
// the plugin client and the PluginCatalog. This implements the
|
||||||
|
// plugin.ClientProtocol interface.
|
||||||
|
func (p *pluginClient) Close() error {
|
||||||
|
p.logger.Debug("cleaning up plugin client connection", "id", p.id)
|
||||||
|
return p.cleanupFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExternalPlugin will kill plugin processes and perform any necessary
|
||||||
|
// cleanup on the externalPlugins map for multiplexed and non-multiplexed
|
||||||
|
// plugins. This should be called with the write lock held.
|
||||||
|
func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error {
|
||||||
|
extPlugin, ok := c.externalPlugins[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("plugin client not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
pluginClient := extPlugin.connections[id]
|
||||||
|
|
||||||
|
delete(extPlugin.connections, id)
|
||||||
|
if !extPlugin.multiplexingSupport {
|
||||||
|
pluginClient.client.Kill()
|
||||||
|
|
||||||
|
if len(extPlugin.connections) == 0 {
|
||||||
|
delete(c.externalPlugins, name)
|
||||||
|
}
|
||||||
|
} else if len(extPlugin.connections) == 0 {
|
||||||
|
pluginClient.client.Kill()
|
||||||
|
delete(c.externalPlugins, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin {
|
||||||
|
if extPlugin, ok := c.externalPlugins[pluginName]; ok {
|
||||||
|
return extPlugin
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.newExternalPlugin(pluginName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin {
|
||||||
|
if c.externalPlugins == nil {
|
||||||
|
c.externalPlugins = make(map[string]*externalPlugin)
|
||||||
|
}
|
||||||
|
|
||||||
|
extPlugin := &externalPlugin{
|
||||||
|
connections: make(map[string]*pluginClient),
|
||||||
|
name: pluginName,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.externalPlugins[pluginName] = extPlugin
|
||||||
|
return extPlugin
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPluginClient returns a client for managing the lifecycle of a plugin
|
||||||
|
// process
|
||||||
|
func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (*pluginClient, error) {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
if config.Name == "" {
|
||||||
|
return nil, fmt.Errorf("no name provided for plugin")
|
||||||
|
}
|
||||||
|
if config.PluginType == consts.PluginTypeUnknown {
|
||||||
|
return nil, fmt.Errorf("no plugin type provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
pluginRunner, err := c.get(ctx, config.Name, config.PluginType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to lookup plugin: %w", err)
|
||||||
|
}
|
||||||
|
if pluginRunner == nil {
|
||||||
|
return nil, fmt.Errorf("no plugin found")
|
||||||
|
}
|
||||||
|
pc, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||||
|
return pc, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newPluginClient returns a client for managing the lifecycle of a plugin
|
||||||
|
// process. Callers should have the write lock held.
|
||||||
|
func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*pluginClient, error) {
|
||||||
|
if pluginRunner == nil {
|
||||||
|
return nil, fmt.Errorf("no plugin found")
|
||||||
|
}
|
||||||
|
|
||||||
|
extPlugin := c.getExternalPlugin(pluginRunner.Name)
|
||||||
|
id, err := base62.Random(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pc := &pluginClient{
|
||||||
|
id: id,
|
||||||
|
logger: c.logger.Named(pluginRunner.Name),
|
||||||
|
cleanupFunc: func() error {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
return c.cleanupExternalPlugin(pluginRunner.Name, id)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pluginRunner.MultiplexingSupport || len(extPlugin.connections) == 0 {
|
||||||
|
c.logger.Debug("spawning a new plugin process", "plugin_name", pluginRunner.Name, "id", id)
|
||||||
|
client, err := pluginRunner.RunConfig(ctx,
|
||||||
|
pluginutil.PluginSets(config.PluginSets),
|
||||||
|
pluginutil.HandshakeConfig(config.HandshakeConfig),
|
||||||
|
pluginutil.Logger(config.Logger),
|
||||||
|
pluginutil.MetadataMode(config.IsMetadataMode),
|
||||||
|
pluginutil.MLock(c.mlockPlugins),
|
||||||
|
|
||||||
|
// NewPluginClient only supports AutoMTLS today
|
||||||
|
pluginutil.AutoMTLS(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pc.client = client
|
||||||
|
} else {
|
||||||
|
c.logger.Debug("returning existing plugin client for multiplexed plugin", "id", id)
|
||||||
|
|
||||||
|
// get the first client, since they are all the same
|
||||||
|
for k := range extPlugin.connections {
|
||||||
|
pc.client = extPlugin.connections[k].client
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if pc.client == nil {
|
||||||
|
return nil, fmt.Errorf("plugin client is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the protocol client for this connection.
|
||||||
|
// Subsequent calls to this will return the same client.
|
||||||
|
rpcClient, err := pc.client.Client()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn := rpcClient.(*plugin.GRPCClient).Conn
|
||||||
|
|
||||||
|
if pluginRunner.MultiplexingSupport {
|
||||||
|
// Wrap rpcClient with our implementation so that we can inject the
|
||||||
|
// ID into the context
|
||||||
|
pc.clientConn = &pluginClientConn{
|
||||||
|
ClientConn: clientConn,
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pc.clientConn = clientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
pc.ClientProtocol = rpcClient
|
||||||
|
|
||||||
|
extPlugin.connections[id] = pc
|
||||||
|
extPlugin.name = pluginRunner.Name
|
||||||
|
extPlugin.multiplexingSupport = pluginRunner.MultiplexingSupport
|
||||||
|
|
||||||
|
return extPlugin.connections[id], nil
|
||||||
|
}
|
||||||
|
|
||||||
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
|
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
|
||||||
// type. It will first attempt to run as a database plugin then a backend
|
// type and if it supports multiplexing. It will first attempt to run as a
|
||||||
// plugin. Both of these will be run in metadata mode.
|
// database plugin then a backend plugin. Both of these will be run in metadata
|
||||||
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
|
// mode.
|
||||||
|
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, bool, error) {
|
||||||
merr := &multierror.Error{}
|
merr := &multierror.Error{}
|
||||||
err := isDatabasePlugin(ctx, plugin)
|
multiplexingSupport, err := c.isDatabasePlugin(ctx, plugin)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return consts.PluginTypeDatabase, nil
|
return consts.PluginTypeDatabase, multiplexingSupport, nil
|
||||||
}
|
}
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
|
|
||||||
@ -75,7 +311,7 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
err := client.Setup(ctx, &logical.BackendConfig{})
|
err := client.Setup(ctx, &logical.BackendConfig{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return consts.PluginTypeUnknown, err
|
return consts.PluginTypeUnknown, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
backendType := client.Type()
|
backendType := client.Type()
|
||||||
@ -83,9 +319,9 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||||||
|
|
||||||
switch backendType {
|
switch backendType {
|
||||||
case logical.TypeCredential:
|
case logical.TypeCredential:
|
||||||
return consts.PluginTypeCredential, nil
|
return consts.PluginTypeCredential, false, nil
|
||||||
case logical.TypeLogical:
|
case logical.TypeLogical:
|
||||||
return consts.PluginTypeSecrets, nil
|
return consts.PluginTypeSecrets, false, nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
@ -102,29 +338,55 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||||||
"error", merr.Error())
|
"error", merr.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return consts.PluginTypeUnknown, nil
|
return consts.PluginTypeUnknown, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isDatabasePlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error {
|
// isDatabasePlugin returns true if the plugin supports multiplexing. An error
|
||||||
|
// is returned if the plugin is not a database plugin.
|
||||||
|
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (bool, error) {
|
||||||
merr := &multierror.Error{}
|
merr := &multierror.Error{}
|
||||||
// Attempt to run as database V5 plugin
|
config := pluginutil.PluginClientConfig{
|
||||||
v5Client, err := v5.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
|
Name: pluginRunner.Name,
|
||||||
|
PluginSets: v5.PluginSets,
|
||||||
|
PluginType: consts.PluginTypeDatabase,
|
||||||
|
HandshakeConfig: v5.HandshakeConfig,
|
||||||
|
Logger: log.NewNullLogger(),
|
||||||
|
IsMetadataMode: true,
|
||||||
|
AutoMTLS: true,
|
||||||
|
}
|
||||||
|
// Attempt to run as database V5 or V6 multiplexed plugin
|
||||||
|
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
// At this point the pluginRunner does not know if multiplexing is
|
||||||
|
// supported or not. So we need to ask the plugin client itself.
|
||||||
|
multiplexingSupport, err := pluginutil.MultiplexingSupported(ctx, v5Client.clientConn)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
// Close the client and cleanup the plugin process
|
// Close the client and cleanup the plugin process
|
||||||
v5Client.Close()
|
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id)
|
||||||
return nil
|
if err != nil {
|
||||||
|
c.logger.Error("error closing plugin client", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return multiplexingSupport, nil
|
||||||
}
|
}
|
||||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err))
|
||||||
|
|
||||||
v4Client, err := v4.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
|
v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Close the client and cleanup the plugin process
|
// Close the client and cleanup the plugin process
|
||||||
v4Client.Close()
|
err = v4Client.Close()
|
||||||
return nil
|
if err != nil {
|
||||||
|
c.logger.Error("error closing plugin client", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err))
|
||||||
|
|
||||||
return merr.ErrorOrNil()
|
return false, merr.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePlugins will loop over all the plugins of unknown type and attempt to
|
// UpdatePlugins will loop over all the plugins of unknown type and attempt to
|
||||||
@ -170,7 +432,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
|
|||||||
cmdOld := plugin.Command
|
cmdOld := plugin.Command
|
||||||
plugin.Command = filepath.Join(c.directory, plugin.Command)
|
plugin.Command = filepath.Join(c.directory, plugin.Command)
|
||||||
|
|
||||||
pluginType, err := c.getPluginTypeFromUnknown(ctx, logger, plugin)
|
pluginType, multiplexingSupport, err := c.getPluginTypeFromUnknown(ctx, logger, plugin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
|
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
|
||||||
continue
|
continue
|
||||||
@ -181,7 +443,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upgrade the storage
|
// Upgrade the storage
|
||||||
err = c.setInternal(ctx, pluginName, pluginType, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
|
err = c.setInternal(ctx, pluginName, pluginType, multiplexingSupport, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
|
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
|
||||||
continue
|
continue
|
||||||
@ -269,10 +531,14 @@ func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.
|
|||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
return c.setInternal(ctx, name, pluginType, command, args, env, sha256)
|
// During plugin registration, we can't know if a plugin is multiplexed or
|
||||||
|
// not until we run it. So we set it to false here. Once started, we ask
|
||||||
|
// the plugin if it is multiplexed and set this value accordingly.
|
||||||
|
multiplexingSupport := false
|
||||||
|
return c.setInternal(ctx, name, pluginType, multiplexingSupport, command, args, env, sha256)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error {
|
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, multiplexingSupport bool, command string, args []string, env []string, sha256 []byte) error {
|
||||||
// Best effort check to make sure the command isn't breaking out of the
|
// Best effort check to make sure the command isn't breaking out of the
|
||||||
// configured plugin directory.
|
// configured plugin directory.
|
||||||
commandFull := filepath.Join(c.directory, command)
|
commandFull := filepath.Join(c.directory, command)
|
||||||
@ -294,15 +560,16 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
|
|||||||
// entryTmp should only be used for the below type check, it uses the
|
// entryTmp should only be used for the below type check, it uses the
|
||||||
// full command instead of the relative command.
|
// full command instead of the relative command.
|
||||||
entryTmp := &pluginutil.PluginRunner{
|
entryTmp := &pluginutil.PluginRunner{
|
||||||
Name: name,
|
Name: name,
|
||||||
Command: commandFull,
|
Command: commandFull,
|
||||||
Args: args,
|
Args: args,
|
||||||
Env: env,
|
Env: env,
|
||||||
Sha256: sha256,
|
Sha256: sha256,
|
||||||
Builtin: false,
|
Builtin: false,
|
||||||
|
MultiplexingSupport: multiplexingSupport,
|
||||||
}
|
}
|
||||||
|
|
||||||
pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
|
pluginType, multiplexingSupport, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -312,13 +579,14 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
|
|||||||
}
|
}
|
||||||
|
|
||||||
entry := &pluginutil.PluginRunner{
|
entry := &pluginutil.PluginRunner{
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: pluginType,
|
Type: pluginType,
|
||||||
Command: command,
|
Command: command,
|
||||||
Args: args,
|
Args: args,
|
||||||
Env: env,
|
Env: env,
|
||||||
Sha256: sha256,
|
Sha256: sha256,
|
||||||
Builtin: false,
|
Builtin: false,
|
||||||
|
MultiplexingSupport: multiplexingSupport,
|
||||||
}
|
}
|
||||||
|
|
||||||
buf, err := json.Marshal(entry)
|
buf, err := json.Marshal(entry)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.27.1
|
// protoc-gen-go v1.27.1
|
||||||
// protoc v3.17.3
|
// protoc v3.19.4
|
||||||
// source: vault/request_forwarding_service.proto
|
// source: vault/request_forwarding_service.proto
|
||||||
|
|
||||||
package vault
|
package vault
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user