vault/sdk/database/dbplugin/v5/grpc_server_test.go
John-Michael Faircloth 3565c90cf8
feature: multiplexing support for database plugins (#14033)
* feat: DB plugin multiplexing (#13734)

* WIP: start from main and get a plugin runner from core

* move MultiplexedClient map to plugin catalog
- call sys.NewPluginClient from PluginFactory
- updates to getPluginClient
- thread through isMetadataMode

* use go-plugin ClientProtocol interface
- call sys.NewPluginClient from dbplugin.NewPluginClient

* move PluginSets to dbplugin package
- export dbplugin HandshakeConfig
- small refactor of PluginCatalog.getPluginClient

* add removeMultiplexedClient; clean up on Close()
- call client.Kill from plugin catalog
- set rpcClient when muxed client exists

* add ID to dbplugin.DatabasePluginClient struct

* only create one plugin process per plugin type

* update NewPluginClient to return connection ID to sdk
- wrap grpc.ClientConn so we can inject the ID into context
- get ID from context on grpc server

* add v6 multiplexing protocol version

* WIP: backwards compat for db plugins

* Ensure locking on plugin catalog access

- Create public GetPluginClient method for plugin catalog
- rename postgres db plugin

* use the New constructor for db plugins

* grpc server: use write lock for Close and rlock for CRUD

* cleanup MultiplexedClients on Close

* remove TODO

* fix multiplexing regression with grpc server connection

* cleanup grpc server instances on close

* embed ClientProtocol in Multiplexer interface

* use PluginClientConfig arg to make NewPluginClient plugin type agnostic

* create a new plugin process for non-muxed plugins

* feat: plugin multiplexing: handle plugin client cleanup (#13896)

* use closure for plugin client cleanup

* log and return errors; add comments

* move rpcClient wrapping to core for ID injection

* refactor core plugin client and sdk

* remove unused ID method

* refactor and only wrap clientConn on multiplexed plugins

* rename structs and do not export types

* Slight refactor of system view interface

* Revert "Slight refactor of system view interface"

This reverts commit 73d420e5cd.

* Revert "Revert "Slight refactor of system view interface""

This reverts commit f75527008a.

* only provide pluginRunner arg to the internal newPluginClient method

* embed ClientProtocol in pluginClient and name logger

* Add back MLock support

* remove enableMlock arg from setupPluginCatalog

* rename plugin util interface to PluginClient

Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>

* feature: multiplexing: fix unit tests (#14007)

* fix grpc_server tests and add coverage

* update run_config tests

* add happy path test case for grpc_server ID from context

* update test helpers

* feat: multiplexing: handle v5 plugin compiled with new sdk

* add mux supported flag and increase test coverage

* set multiplexingSupport field in plugin server

* remove multiplexingSupport field in sdk

* revert postgres to non-multiplexed

* add comments on grpc server fields

* use pointer receiver on grpc server methods

* add changelog

* use pointer for grpcserver instance

* Use a gRPC server to determine if a plugin should be multiplexed

* Apply suggestions from code review

Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>

* add lock to removePluginClient

* add multiplexingSupport field to externalPlugin struct

* do not send nil to grpc MultiplexingSupport

* check err before logging

* handle locking scenario for cleanupFunc

* allow ServeConfigMultiplex to dispense v5 plugin

* reposition structs, add err check and comments

* add comment on locking for cleanupExternalPlugin

Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
2022-02-17 08:50:33 -06:00

801 lines
20 KiB
Go

package dbplugin
import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"time"
"google.golang.org/protobuf/types/known/structpb"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// invalidExpiration is midnight UTC on January 1st of year 0. It lies before
// minValidSeconds in the ptypes package, so the tests in this file use it to
// build expiration timestamps that they expect to be rejected.
var invalidExpiration = time.Date(0, time.January, 1, 0, 0, 0, 0, time.UTC)
// TestGRPCServer_Initialize runs gRPCServer.Initialize against a table of fake
// databases and verifies the returned error, gRPC status code and response.
// Each case picks its own server setup helper so that both the multiplexed and
// the non-multiplexed server are exercised.
func TestGRPCServer_Initialize(t *testing.T) {
	type setupFunc func(*testing.T, Database) (context.Context, gRPCServer)

	// fooBar returns a fresh config map on every call so that no two cases
	// share the same underlying map.
	fooBar := func() map[string]interface{} {
		return map[string]interface{}{
			"foo": "bar",
		}
	}

	cases := map[string]struct {
		db       Database
		req      *proto.InitializeRequest
		wantResp *proto.InitializeResponse
		wantErr  bool
		wantCode codes.Code
		setup    setupFunc
	}{
		"database errored": {
			db: fakeDatabase{
				initErr: errors.New("initialization error"),
			},
			req:      &proto.InitializeRequest{},
			wantResp: &proto.InitializeResponse{},
			wantErr:  true,
			wantCode: codes.Internal,
			setup:    testGrpcServer,
		},
		"newConfig can't marshal to JSON": {
			db: fakeDatabase{
				initResp: InitializeResponse{
					Config: map[string]interface{}{
						"bad-data": badJSONValue{},
					},
				},
			},
			req:      &proto.InitializeRequest{},
			wantResp: &proto.InitializeResponse{},
			wantErr:  true,
			wantCode: codes.Internal,
			setup:    testGrpcServer,
		},
		"happy path with config data for multiplexed plugin": {
			db: fakeDatabase{
				initResp: InitializeResponse{Config: fooBar()},
			},
			req:      &proto.InitializeRequest{ConfigData: marshal(t, fooBar())},
			wantResp: &proto.InitializeResponse{ConfigData: marshal(t, fooBar())},
			wantErr:  false,
			wantCode: codes.OK,
			setup:    testGrpcServer,
		},
		"happy path with config data for non-multiplexed plugin": {
			db: fakeDatabase{
				initResp: InitializeResponse{Config: fooBar()},
			},
			req:      &proto.InitializeRequest{ConfigData: marshal(t, fooBar())},
			wantResp: &proto.InitializeResponse{ConfigData: marshal(t, fooBar())},
			wantErr:  false,
			wantCode: codes.OK,
			setup:    testGrpcServerSingleImpl,
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			ctx, srv := tc.setup(t, tc.db)

			got, err := srv.Initialize(ctx, tc.req)
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.wantErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			if code := status.Code(err); code != tc.wantCode {
				t.Fatalf("Actual code: %s Expected code: %s", code, tc.wantCode)
			}

			if !reflect.DeepEqual(got, tc.wantResp) {
				t.Fatalf("Actual response: %#v\nExpected response: %#v", got, tc.wantResp)
			}
		})
	}
}
// TestCoerceFloatsToInt applies coerceFloatsToInt to a copy of each input map
// and compares the mutated copy with the expected map. The cases cover string
// values, plain ints, floats with a fractional part, and whole-number floats.
func TestCoerceFloatsToInt(t *testing.T) {
	cases := map[string]struct {
		in   map[string]interface{}
		want map[string]interface{}
	}{
		"no numbers": {
			in:   map[string]interface{}{"foo": "bar"},
			want: map[string]interface{}{"foo": "bar"},
		},
		"raw integers": {
			in:   map[string]interface{}{"foo": 42},
			want: map[string]interface{}{"foo": 42},
		},
		"floats ": {
			in:   map[string]interface{}{"foo": 42.2},
			want: map[string]interface{}{"foo": 42.2},
		},
		"floats coerced to ints": {
			in:   map[string]interface{}{"foo": float64(42)},
			want: map[string]interface{}{"foo": int64(42)},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			// Work on a copy so the table entry itself is never mutated.
			got := copyMap(tc.in)
			coerceFloatsToInt(got)

			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Fatalf("Actual: %#v\nExpected: %#v", got, tc.want)
		})
	}
}
// copyMap returns a new map holding the same key/value pairs as m. The copy
// is shallow: values are assigned as-is, not cloned. A nil or empty input
// yields an empty, non-nil map.
func copyMap(m map[string]interface{}) map[string]interface{} {
	dup := make(map[string]interface{}, len(m))
	for key := range m {
		dup[key] = m[key]
	}
	return dup
}
func TestGRPCServer_NewUser(t *testing.T) {
type testCase struct {
db Database
req *proto.NewUserRequest
expectedResp *proto.NewUserResponse
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"missing username config": {
db: fakeDatabase{},
req: &proto.NewUserRequest{},
expectedResp: &proto.NewUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"bad expiration": {
db: fakeDatabase{},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
Expiration: &timestamp.Timestamp{
Seconds: invalidExpiration.Unix(),
},
},
expectedResp: &proto.NewUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"database error": {
db: fakeDatabase{
newUserErr: errors.New("new user error"),
},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
Expiration: ptypes.TimestampNow(),
},
expectedResp: &proto.NewUserResponse{},
expectErr: true,
expectCode: codes.Internal,
},
"happy path with expiration": {
db: fakeDatabase{
newUserResp: NewUserResponse{
Username: "someuser_foo",
},
},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
Expiration: ptypes.TimestampNow(),
},
expectedResp: &proto.NewUserResponse{
Username: "someuser_foo",
},
expectErr: false,
expectCode: codes.OK,
},
"happy path without expiration": {
db: fakeDatabase{
newUserResp: NewUserResponse{
Username: "someuser_foo",
},
},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
},
expectedResp: &proto.NewUserResponse{
Username: "someuser_foo",
},
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.NewUser(idCtx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCServer_UpdateUser(t *testing.T) {
type testCase struct {
db Database
req *proto.UpdateUserRequest
expectedResp *proto.UpdateUserResponse
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"missing username": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"missing changes": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"database error": {
db: fakeDatabase{
updateUserErr: errors.New("update user error"),
},
req: &proto.UpdateUserRequest{
Username: "someuser",
Password: &proto.ChangePassword{
NewPassword: "90ughaino",
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.Internal,
},
"bad expiration date": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
Expiration: &proto.ChangeExpiration{
NewExpiration: &timestamp.Timestamp{
// Before minValidSeconds in ptypes package
Seconds: invalidExpiration.Unix(),
},
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"change password happy path": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
Password: &proto.ChangePassword{
NewPassword: "90ughaino",
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: false,
expectCode: codes.OK,
},
"change expiration happy path": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
Expiration: &proto.ChangeExpiration{
NewExpiration: ptypes.TimestampNow(),
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.UpdateUser(idCtx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
// TestGRPCServer_DeleteUser calls gRPCServer.DeleteUser on a multiplexed test
// server for each table entry and verifies the error, gRPC status code and
// response.
func TestGRPCServer_DeleteUser(t *testing.T) {
	cases := map[string]struct {
		db       Database
		req      *proto.DeleteUserRequest
		wantResp *proto.DeleteUserResponse
		wantErr  bool
		wantCode codes.Code
	}{
		"missing username": {
			db:       fakeDatabase{},
			req:      &proto.DeleteUserRequest{},
			wantResp: &proto.DeleteUserResponse{},
			wantErr:  true,
			wantCode: codes.InvalidArgument,
		},
		"database error": {
			db: fakeDatabase{
				deleteUserErr: errors.New("delete user error"),
			},
			req: &proto.DeleteUserRequest{
				Username: "someuser",
			},
			wantResp: &proto.DeleteUserResponse{},
			wantErr:  true,
			wantCode: codes.Internal,
		},
		"happy path": {
			db: fakeDatabase{},
			req: &proto.DeleteUserRequest{
				Username: "someuser",
			},
			wantResp: &proto.DeleteUserResponse{},
			wantErr:  false,
			wantCode: codes.OK,
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			ctx, srv := testGrpcServer(t, tc.db)

			got, err := srv.DeleteUser(ctx, tc.req)
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.wantErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			if code := status.Code(err); code != tc.wantCode {
				t.Fatalf("Actual code: %s Expected code: %s", code, tc.wantCode)
			}

			if !reflect.DeepEqual(got, tc.wantResp) {
				t.Fatalf("Actual response: %#v\nExpected response: %#v", got, tc.wantResp)
			}
		})
	}
}
// TestGRPCServer_Type calls gRPCServer.Type on a multiplexed test server and
// verifies the error, gRPC status code and response for a failing and a
// succeeding fake database.
func TestGRPCServer_Type(t *testing.T) {
	cases := map[string]struct {
		db       Database
		wantResp *proto.TypeResponse
		wantErr  bool
		wantCode codes.Code
	}{
		"database error": {
			db: fakeDatabase{
				typeErr: errors.New("type error"),
			},
			wantResp: &proto.TypeResponse{},
			wantErr:  true,
			wantCode: codes.Internal,
		},
		"happy path": {
			db: fakeDatabase{
				typeResp: "fake database",
			},
			wantResp: &proto.TypeResponse{
				Type: "fake database",
			},
			wantErr:  false,
			wantCode: codes.OK,
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			ctx, srv := testGrpcServer(t, tc.db)

			got, err := srv.Type(ctx, &proto.Empty{})
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.wantErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			if code := status.Code(err); code != tc.wantCode {
				t.Fatalf("Actual code: %s Expected code: %s", code, tc.wantCode)
			}

			if !reflect.DeepEqual(got, tc.wantResp) {
				t.Fatalf("Actual response: %#v\nExpected response: %#v", got, tc.wantResp)
			}
		})
	}
}
// TestGRPCServer_Close calls gRPCServer.Close for each table entry and
// verifies the error and gRPC status code. A case may also supply a check
// that inspects the server after Close; the multiplexed happy path uses it to
// confirm that the instances map is empty.
func TestGRPCServer_Close(t *testing.T) {
	type (
		setupFunc func(*testing.T, Database) (context.Context, gRPCServer)
		checkFunc func(t *testing.T, g gRPCServer)
	)

	// instancesEmpty fails the test when the server still tracks any
	// database instance.
	instancesEmpty := func(t *testing.T, g gRPCServer) {
		if len(g.instances) != 0 {
			t.Fatalf("err expected instances map to be empty")
		}
	}

	cases := map[string]struct {
		db       Database
		wantErr  bool
		wantCode codes.Code
		setup    setupFunc
		check    checkFunc
	}{
		"database error": {
			db: fakeDatabase{
				closeErr: errors.New("close error"),
			},
			wantErr:  true,
			wantCode: codes.Internal,
			setup:    testGrpcServer,
			check:    nil,
		},
		"happy path for multiplexed plugin": {
			db:       fakeDatabase{},
			wantErr:  false,
			wantCode: codes.OK,
			setup:    testGrpcServer,
			check:    instancesEmpty,
		},
		"happy path for non-multiplexed plugin": {
			db:       fakeDatabase{},
			wantErr:  false,
			wantCode: codes.OK,
			setup:    testGrpcServerSingleImpl,
			check:    nil,
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			ctx, srv := tc.setup(t, tc.db)

			_, err := srv.Close(ctx, &proto.Empty{})
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.wantErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			if code := status.Code(err); code != tc.wantCode {
				t.Fatalf("Actual code: %s Expected code: %s", code, tc.wantCode)
			}

			if tc.check == nil {
				return
			}
			tc.check(t, srv)
		})
	}
}
// TestGetMultiplexIDFromContext checks that getMultiplexIDFromContext returns
// the ID carried in the context's incoming gRPC metadata, and that it returns
// the expected error when the metadata is absent, holds more than one ID, or
// holds an empty ID.
func TestGetMultiplexIDFromContext(t *testing.T) {
	type testCase struct {
		ctx          context.Context
		expectedResp string
		expectedErr  error
	}

	tests := map[string]testCase{
		"missing plugin multiplexing metadata": {
			ctx:          context.Background(),
			expectedResp: "",
			expectedErr:  fmt.Errorf("missing plugin multiplexing metadata"),
		},
		"unexpected number of IDs in metadata": {
			ctx:          idCtx(t, "12345", "67891"),
			expectedResp: "",
			expectedErr:  fmt.Errorf("unexpected number of IDs in metadata: (2)"),
		},
		"empty multiplex ID in metadata": {
			ctx:          idCtx(t, ""),
			expectedResp: "",
			expectedErr:  fmt.Errorf("empty multiplex ID in metadata"),
		},
		"happy path, id is returned from metadata": {
			ctx:          idCtx(t, "12345"),
			expectedResp: "12345",
			expectedErr:  nil,
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			resp, err := getMultiplexIDFromContext(test.ctx)

			// Errors are compared by message rather than with
			// reflect.DeepEqual so the test does not depend on the concrete
			// error type used by the implementation.
			switch {
			case test.expectedErr == nil && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			case test.expectedErr != nil && err == nil:
				t.Fatalf("err expected, got nil")
			case test.expectedErr != nil && err.Error() != test.expectedErr.Error():
				t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
			}

			if resp != test.expectedResp {
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
			}
		})
	}
}
// testGrpcServer is a test helper for the multiplexed plugin case. It returns
// a context whose metadata carries a fixed ID, together with a gRPCServer
// whose instances map already holds db under that ID and whose factory
// function also returns db.
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
	t.Helper()

	id := "12345"
	srv := gRPCServer{
		factoryFunc: func() (interface{}, error) {
			return db, nil
		},
		instances: map[string]Database{
			id: db,
		},
	}

	return idCtx(t, id), srv
}
// testGrpcServerSingleImpl is a test helper for the non-multiplexed plugin
// case. It returns a plain background context (no ID metadata) and a
// gRPCServer whose singleImpl field is set to db.
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
	t.Helper()

	srv := gRPCServer{singleImpl: db}
	return context.Background(), srv
}
// idCtx is a test helper that returns a context whose incoming gRPC metadata
// has every given ID appended under pluginutil.MultiplexingCtxKey.
func idCtx(t *testing.T, ids ...string) context.Context {
	t.Helper()

	md := metadata.MD{}
	for i := range ids {
		md.Append(pluginutil.MultiplexingCtxKey, ids[i])
	}

	// The context is only passed through, so it does not need a timeout.
	return metadata.NewIncomingContext(context.Background(), md)
}
// marshal is a test helper that converts m into a protobuf Struct via
// mapToStruct and fails the test immediately if the conversion errors.
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
	t.Helper()

	pb, err := mapToStruct(m)
	if err == nil {
		return pb
	}

	// Fatalf stops the test, so the returned value below is never consumed.
	t.Fatalf("unable to marshal to protobuf: %s", err)
	return pb
}
// badJSONValue is a value whose JSON marshalling and unmarshalling always
// fail. Tests use it to force a JSON encoding error.
type badJSONValue struct{}

// MarshalJSON always fails and returns no data.
func (v badJSONValue) MarshalJSON() ([]byte, error) {
	err := fmt.Errorf("this cannot be marshalled to JSON")
	return nil, err
}

// UnmarshalJSON always fails, regardless of the input.
func (v badJSONValue) UnmarshalJSON(_ []byte) error {
	err := fmt.Errorf("this cannot be unmarshalled from JSON")
	return err
}
// Compile-time check that fakeDatabase implements Database.
var _ Database = fakeDatabase{}

// fakeDatabase is a Database stub for tests. Each method ignores its
// arguments and returns the canned response and/or error stored in the
// matching fields.
type fakeDatabase struct {
	// Returned by Initialize.
	initResp InitializeResponse
	initErr  error

	// Returned by NewUser.
	newUserResp NewUserResponse
	newUserErr  error

	// Returned by UpdateUser.
	updateUserResp UpdateUserResponse
	updateUserErr  error

	// Returned by DeleteUser.
	deleteUserResp DeleteUserResponse
	deleteUserErr  error

	// Returned by Type.
	typeResp string
	typeErr  error

	// Returned by Close.
	closeErr error
}

// Initialize returns the configured initResp and initErr.
func (f fakeDatabase) Initialize(_ context.Context, _ InitializeRequest) (InitializeResponse, error) {
	return f.initResp, f.initErr
}

// NewUser returns the configured newUserResp and newUserErr.
func (f fakeDatabase) NewUser(_ context.Context, _ NewUserRequest) (NewUserResponse, error) {
	return f.newUserResp, f.newUserErr
}

// UpdateUser returns the configured updateUserResp and updateUserErr.
func (f fakeDatabase) UpdateUser(_ context.Context, _ UpdateUserRequest) (UpdateUserResponse, error) {
	return f.updateUserResp, f.updateUserErr
}

// DeleteUser returns the configured deleteUserResp and deleteUserErr.
func (f fakeDatabase) DeleteUser(_ context.Context, _ DeleteUserRequest) (DeleteUserResponse, error) {
	return f.deleteUserResp, f.deleteUserErr
}

// Type returns the configured typeResp and typeErr.
func (f fakeDatabase) Type() (string, error) {
	return f.typeResp, f.typeErr
}

// Close returns the configured closeErr.
func (f fakeDatabase) Close() error {
	return f.closeErr
}
// Compile-time check that *recordingDatabase implements Database.
var _ Database = (*recordingDatabase)(nil)

// recordingDatabase is a Database for tests that counts how many times each
// method is called. When next is set, every call is forwarded to it after
// being counted; otherwise a zero-value response and a nil error are
// returned (Type returns the string "recordingDatabase").
type recordingDatabase struct {
	initializeCalls int
	newUserCalls    int
	updateUserCalls int
	deleteUserCalls int
	typeCalls       int
	closeCalls      int

	// next lets recordingDatabase act as middleware, recording the calls made
	// to another test Database implementation.
	next Database
}

// Initialize counts the call, then delegates to next if one is set.
func (r *recordingDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
	r.initializeCalls++
	if r.next != nil {
		return r.next.Initialize(ctx, req)
	}
	return InitializeResponse{}, nil
}

// NewUser counts the call, then delegates to next if one is set.
func (r *recordingDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
	r.newUserCalls++
	if r.next != nil {
		return r.next.NewUser(ctx, req)
	}
	return NewUserResponse{}, nil
}

// UpdateUser counts the call, then delegates to next if one is set.
func (r *recordingDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
	r.updateUserCalls++
	if r.next != nil {
		return r.next.UpdateUser(ctx, req)
	}
	return UpdateUserResponse{}, nil
}

// DeleteUser counts the call, then delegates to next if one is set.
func (r *recordingDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
	r.deleteUserCalls++
	if r.next != nil {
		return r.next.DeleteUser(ctx, req)
	}
	return DeleteUserResponse{}, nil
}

// Type counts the call, then delegates to next if one is set.
func (r *recordingDatabase) Type() (string, error) {
	r.typeCalls++
	if r.next != nil {
		return r.next.Type()
	}
	return "recordingDatabase", nil
}

// Close counts the call, then delegates to next if one is set.
func (r *recordingDatabase) Close() error {
	r.closeCalls++
	if r.next != nil {
		return r.next.Close()
	}
	return nil
}