mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-10 16:47:01 +02:00
* feat: DB plugin multiplexing (#13734) * WIP: start from main and get a plugin runner from core * move MultiplexedClient map to plugin catalog - call sys.NewPluginClient from PluginFactory - updates to getPluginClient - thread through isMetadataMode * use go-plugin ClientProtocol interface - call sys.NewPluginClient from dbplugin.NewPluginClient * move PluginSets to dbplugin package - export dbplugin HandshakeConfig - small refactor of PluginCatalog.getPluginClient * add removeMultiplexedClient; clean up on Close() - call client.Kill from plugin catalog - set rpcClient when muxed client exists * add ID to dbplugin.DatabasePluginClient struct * only create one plugin process per plugin type * update NewPluginClient to return connection ID to sdk - wrap grpc.ClientConn so we can inject the ID into context - get ID from context on grpc server * add v6 multiplexing protocol version * WIP: backwards compat for db plugins * Ensure locking on plugin catalog access - Create public GetPluginClient method for plugin catalog - rename postgres db plugin * use the New constructor for db plugins * grpc server: use write lock for Close and rlock for CRUD * cleanup MultiplexedClients on Close * remove TODO * fix multiplexing regression with grpc server connection * cleanup grpc server instances on close * embed ClientProtocol in Multiplexer interface * use PluginClientConfig arg to make NewPluginClient plugin type agnostic * create a new plugin process for non-muxed plugins * feat: plugin multiplexing: handle plugin client cleanup (#13896) * use closure for plugin client cleanup * log and return errors; add comments * move rpcClient wrapping to core for ID injection * refactor core plugin client and sdk * remove unused ID method * refactor and only wrap clientConn on multiplexed plugins * rename structs and do not export types * Slight refactor of system view interface * Revert "Slight refactor of system view interface" This reverts commit73d420e5cd
. * Revert "Revert "Slight refactor of system view interface"" This reverts commit f75527008a
. * only provide pluginRunner arg to the internal newPluginClient method * embed ClientProtocol in pluginClient and name logger * Add back MLock support * remove enableMlock arg from setupPluginCatalog * rename plugin util interface to PluginClient Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com> * feature: multiplexing: fix unit tests (#14007) * fix grpc_server tests and add coverage * update run_config tests * add happy path test case for grpc_server ID from context * update test helpers * feat: multiplexing: handle v5 plugin compiled with new sdk * add mux supported flag and increase test coverage * set multiplexingSupport field in plugin server * remove multiplexingSupport field in sdk * revert postgres to non-multiplexed * add comments on grpc server fields * use pointer receiver on grpc server methods * add changelog * use pointer for grpcserver instance * Use a gRPC server to determine if a plugin should be multiplexed * Apply suggestions from code review Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com> * add lock to removePluginClient * add multiplexingSupport field to externalPlugin struct * do not send nil to grpc MultiplexingSupport * check err before logging * handle locking scenario for cleanupFunc * allow ServeConfigMultiplex to dispense v5 plugin * reposition structs, add err check and comments * add comment on locking for cleanupExternalPlugin Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com> Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
801 lines
20 KiB
Go
801 lines
20 KiB
Go
package dbplugin
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
|
|
"github.com/golang/protobuf/ptypes"
|
|
"github.com/golang/protobuf/ptypes/timestamp"
|
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// invalidExpiration is midnight UTC on January 1 of year 0, which is earlier
// than the minimum timestamp (minValidSeconds) the ptypes package accepts. The
// tests below use it to drive the "bad expiration" validation paths.
var invalidExpiration = time.Date(0, time.January, 1, 0, 0, 0, 0, time.UTC)
|
|
|
|
func TestGRPCServer_Initialize(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.InitializeRequest
|
|
expectedResp *proto.InitializeResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"database errored": {
|
|
db: fakeDatabase{
|
|
initErr: errors.New("initialization error"),
|
|
},
|
|
req: &proto.InitializeRequest{},
|
|
expectedResp: &proto.InitializeResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
grpcSetupFunc: testGrpcServer,
|
|
},
|
|
"newConfig can't marshal to JSON": {
|
|
db: fakeDatabase{
|
|
initResp: InitializeResponse{
|
|
Config: map[string]interface{}{
|
|
"bad-data": badJSONValue{},
|
|
},
|
|
},
|
|
},
|
|
req: &proto.InitializeRequest{},
|
|
expectedResp: &proto.InitializeResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
grpcSetupFunc: testGrpcServer,
|
|
},
|
|
"happy path with config data for multiplexed plugin": {
|
|
db: fakeDatabase{
|
|
initResp: InitializeResponse{
|
|
Config: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
},
|
|
req: &proto.InitializeRequest{
|
|
ConfigData: marshal(t, map[string]interface{}{
|
|
"foo": "bar",
|
|
}),
|
|
},
|
|
expectedResp: &proto.InitializeResponse{
|
|
ConfigData: marshal(t, map[string]interface{}{
|
|
"foo": "bar",
|
|
}),
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
grpcSetupFunc: testGrpcServer,
|
|
},
|
|
"happy path with config data for non-multiplexed plugin": {
|
|
db: fakeDatabase{
|
|
initResp: InitializeResponse{
|
|
Config: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
},
|
|
req: &proto.InitializeRequest{
|
|
ConfigData: marshal(t, map[string]interface{}{
|
|
"foo": "bar",
|
|
}),
|
|
},
|
|
expectedResp: &proto.InitializeResponse{
|
|
ConfigData: marshal(t, map[string]interface{}{
|
|
"foo": "bar",
|
|
}),
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
grpcSetupFunc: testGrpcServerSingleImpl,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
idCtx, g := test.grpcSetupFunc(t, test.db)
|
|
resp, err := g.Initialize(idCtx, test.req)
|
|
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCoerceFloatsToInt(t *testing.T) {
|
|
type testCase struct {
|
|
input map[string]interface{}
|
|
expected map[string]interface{}
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"no numbers": {
|
|
input: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
"raw integers": {
|
|
input: map[string]interface{}{
|
|
"foo": 42,
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": 42,
|
|
},
|
|
},
|
|
"floats ": {
|
|
input: map[string]interface{}{
|
|
"foo": 42.2,
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": 42.2,
|
|
},
|
|
},
|
|
"floats coerced to ints": {
|
|
input: map[string]interface{}{
|
|
"foo": float64(42),
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": int64(42),
|
|
},
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
actual := copyMap(test.input)
|
|
coerceFloatsToInt(actual)
|
|
if !reflect.DeepEqual(actual, test.expected) {
|
|
t.Fatalf("Actual: %#v\nExpected: %#v", actual, test.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// copyMap returns a shallow copy of m: keys and values are copied into a new
// map, but the values themselves are not deep-copied. A nil m yields an empty,
// non-nil map.
func copyMap(m map[string]interface{}) map[string]interface{} {
	// The final size is known up front, so allocate it once instead of growing.
	newMap := make(map[string]interface{}, len(m))
	for k, v := range m {
		newMap[k] = v
	}
	return newMap
}
|
|
|
|
func TestGRPCServer_NewUser(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.NewUserRequest
|
|
expectedResp *proto.NewUserResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing username config": {
|
|
db: fakeDatabase{},
|
|
req: &proto.NewUserRequest{},
|
|
expectedResp: &proto.NewUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"bad expiration": {
|
|
db: fakeDatabase{},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
Expiration: ×tamp.Timestamp{
|
|
Seconds: invalidExpiration.Unix(),
|
|
},
|
|
},
|
|
expectedResp: &proto.NewUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
newUserErr: errors.New("new user error"),
|
|
},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
Expiration: ptypes.TimestampNow(),
|
|
},
|
|
expectedResp: &proto.NewUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path with expiration": {
|
|
db: fakeDatabase{
|
|
newUserResp: NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
Expiration: ptypes.TimestampNow(),
|
|
},
|
|
expectedResp: &proto.NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
"happy path without expiration": {
|
|
db: fakeDatabase{
|
|
newUserResp: NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
},
|
|
expectedResp: &proto.NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
idCtx, g := testGrpcServer(t, test.db)
|
|
resp, err := g.NewUser(idCtx, test.req)
|
|
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_UpdateUser(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.UpdateUserRequest
|
|
expectedResp *proto.UpdateUserResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing username": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"missing changes": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
updateUserErr: errors.New("update user error"),
|
|
},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Password: &proto.ChangePassword{
|
|
NewPassword: "90ughaino",
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"bad expiration date": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Expiration: &proto.ChangeExpiration{
|
|
NewExpiration: ×tamp.Timestamp{
|
|
// Before minValidSeconds in ptypes package
|
|
Seconds: invalidExpiration.Unix(),
|
|
},
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"change password happy path": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Password: &proto.ChangePassword{
|
|
NewPassword: "90ughaino",
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
"change expiration happy path": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Expiration: &proto.ChangeExpiration{
|
|
NewExpiration: ptypes.TimestampNow(),
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
idCtx, g := testGrpcServer(t, test.db)
|
|
resp, err := g.UpdateUser(idCtx, test.req)
|
|
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_DeleteUser(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.DeleteUserRequest
|
|
expectedResp *proto.DeleteUserResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing username": {
|
|
db: fakeDatabase{},
|
|
req: &proto.DeleteUserRequest{},
|
|
expectedResp: &proto.DeleteUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
deleteUserErr: errors.New("delete user error"),
|
|
},
|
|
req: &proto.DeleteUserRequest{
|
|
Username: "someuser",
|
|
},
|
|
expectedResp: &proto.DeleteUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path": {
|
|
db: fakeDatabase{},
|
|
req: &proto.DeleteUserRequest{
|
|
Username: "someuser",
|
|
},
|
|
expectedResp: &proto.DeleteUserResponse{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
idCtx, g := testGrpcServer(t, test.db)
|
|
resp, err := g.DeleteUser(idCtx, test.req)
|
|
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_Type(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
expectedResp *proto.TypeResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
typeErr: errors.New("type error"),
|
|
},
|
|
expectedResp: &proto.TypeResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path": {
|
|
db: fakeDatabase{
|
|
typeResp: "fake database",
|
|
},
|
|
expectedResp: &proto.TypeResponse{
|
|
Type: "fake database",
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
idCtx, g := testGrpcServer(t, test.db)
|
|
resp, err := g.Type(idCtx, &proto.Empty{})
|
|
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_Close(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
|
|
assertFunc func(t *testing.T, g gRPCServer)
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
closeErr: errors.New("close error"),
|
|
},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
grpcSetupFunc: testGrpcServer,
|
|
assertFunc: nil,
|
|
},
|
|
"happy path for multiplexed plugin": {
|
|
db: fakeDatabase{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
grpcSetupFunc: testGrpcServer,
|
|
assertFunc: func(t *testing.T, g gRPCServer) {
|
|
if len(g.instances) != 0 {
|
|
t.Fatalf("err expected instances map to be empty")
|
|
}
|
|
},
|
|
},
|
|
"happy path for non-multiplexed plugin": {
|
|
db: fakeDatabase{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
grpcSetupFunc: testGrpcServerSingleImpl,
|
|
assertFunc: nil,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
idCtx, g := test.grpcSetupFunc(t, test.db)
|
|
_, err := g.Close(idCtx, &proto.Empty{})
|
|
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if test.assertFunc != nil {
|
|
test.assertFunc(t, g)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetMultiplexIDFromContext(t *testing.T) {
|
|
type testCase struct {
|
|
ctx context.Context
|
|
expectedResp string
|
|
expectedErr error
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing plugin multiplexing metadata": {
|
|
ctx: context.Background(),
|
|
expectedResp: "",
|
|
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
|
|
},
|
|
"unexpected number of IDs in metadata": {
|
|
ctx: idCtx(t, "12345", "67891"),
|
|
expectedResp: "",
|
|
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
|
|
},
|
|
"empty multiplex ID in metadata": {
|
|
ctx: idCtx(t, ""),
|
|
expectedResp: "",
|
|
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
|
|
},
|
|
"happy path, id is returned from metadata": {
|
|
ctx: idCtx(t, "12345"),
|
|
expectedResp: "12345",
|
|
expectedErr: nil,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
resp, err := getMultiplexIDFromContext(test.ctx)
|
|
|
|
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
} else if !reflect.DeepEqual(err, test.expectedErr) {
|
|
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
|
|
}
|
|
|
|
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testGrpcServer is a test helper that returns a context with an ID set in its
|
|
// metadata and a gRPCServer instance for a multiplexed plugin
|
|
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
|
|
t.Helper()
|
|
g := gRPCServer{
|
|
factoryFunc: func() (interface{}, error) {
|
|
return db, nil
|
|
},
|
|
instances: make(map[string]Database),
|
|
}
|
|
|
|
id := "12345"
|
|
idCtx := idCtx(t, id)
|
|
g.instances[id] = db
|
|
|
|
return idCtx, g
|
|
}
|
|
|
|
// testGrpcServerSingleImpl is a test helper that returns a context and a
|
|
// gRPCServer instance for a non-multiplexed plugin
|
|
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
|
|
t.Helper()
|
|
return context.Background(), gRPCServer{
|
|
singleImpl: db,
|
|
}
|
|
}
|
|
|
|
// idCtx is a test helper that will return a context with the IDs set in its
|
|
// metadata
|
|
func idCtx(t *testing.T, ids ...string) context.Context {
|
|
t.Helper()
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
md := metadata.MD{}
|
|
for _, id := range ids {
|
|
md.Append(pluginutil.MultiplexingCtxKey, id)
|
|
}
|
|
return metadata.NewIncomingContext(ctx, md)
|
|
}
|
|
|
|
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
|
|
t.Helper()
|
|
|
|
strct, err := mapToStruct(m)
|
|
if err != nil {
|
|
t.Fatalf("unable to marshal to protobuf: %s", err)
|
|
}
|
|
return strct
|
|
}
|
|
|
|
// badJSONValue is a value whose JSON marshaling and unmarshaling always fail.
// The tests use it to force config-encoding errors.
type badJSONValue struct{}

// MarshalJSON always fails, returning a nil payload and an error.
func (v badJSONValue) MarshalJSON() ([]byte, error) {
	err := fmt.Errorf("this cannot be marshalled to JSON")
	return nil, err
}

// UnmarshalJSON always fails, ignoring its input.
func (v badJSONValue) UnmarshalJSON([]byte) error {
	err := fmt.Errorf("this cannot be unmarshalled from JSON")
	return err
}
|
|
|
|
var _ Database = fakeDatabase{}
|
|
|
|
type fakeDatabase struct {
|
|
initResp InitializeResponse
|
|
initErr error
|
|
|
|
newUserResp NewUserResponse
|
|
newUserErr error
|
|
|
|
updateUserResp UpdateUserResponse
|
|
updateUserErr error
|
|
|
|
deleteUserResp DeleteUserResponse
|
|
deleteUserErr error
|
|
|
|
typeResp string
|
|
typeErr error
|
|
|
|
closeErr error
|
|
}
|
|
|
|
func (e fakeDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
|
|
return e.initResp, e.initErr
|
|
}
|
|
|
|
func (e fakeDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
|
|
return e.newUserResp, e.newUserErr
|
|
}
|
|
|
|
func (e fakeDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
|
|
return e.updateUserResp, e.updateUserErr
|
|
}
|
|
|
|
func (e fakeDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
|
|
return e.deleteUserResp, e.deleteUserErr
|
|
}
|
|
|
|
func (e fakeDatabase) Type() (string, error) {
|
|
return e.typeResp, e.typeErr
|
|
}
|
|
|
|
func (e fakeDatabase) Close() error {
|
|
return e.closeErr
|
|
}
|
|
|
|
var _ Database = &recordingDatabase{}
|
|
|
|
type recordingDatabase struct {
|
|
initializeCalls int
|
|
newUserCalls int
|
|
updateUserCalls int
|
|
deleteUserCalls int
|
|
typeCalls int
|
|
closeCalls int
|
|
|
|
// recordingDatabase can act as middleware so we can record the calls to other test Database implementations
|
|
next Database
|
|
}
|
|
|
|
func (f *recordingDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
|
|
f.initializeCalls++
|
|
if f.next == nil {
|
|
return InitializeResponse{}, nil
|
|
}
|
|
return f.next.Initialize(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
|
|
f.newUserCalls++
|
|
if f.next == nil {
|
|
return NewUserResponse{}, nil
|
|
}
|
|
return f.next.NewUser(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
|
|
f.updateUserCalls++
|
|
if f.next == nil {
|
|
return UpdateUserResponse{}, nil
|
|
}
|
|
return f.next.UpdateUser(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
|
|
f.deleteUserCalls++
|
|
if f.next == nil {
|
|
return DeleteUserResponse{}, nil
|
|
}
|
|
return f.next.DeleteUser(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) Type() (string, error) {
|
|
f.typeCalls++
|
|
if f.next == nil {
|
|
return "recordingDatabase", nil
|
|
}
|
|
return f.next.Type()
|
|
}
|
|
|
|
func (f *recordingDatabase) Close() error {
|
|
f.closeCalls++
|
|
if f.next == nil {
|
|
return nil
|
|
}
|
|
return f.next.Close()
|
|
}
|