From 4d217becab9eda2ba51a8af50f5883e09bbe9c41 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Tue, 8 Mar 2022 10:33:24 -0600 Subject: [PATCH] plugin multiplexing: add catalog test coverage (#14398) * plugin client and plugin catalog tests * add v5 plugin cases and more checks * improve err msg * refactor tests; fix test err msg --- .../dbplugin/v5/plugin_client_test.go | 146 ++++++++++++++++++ vault/plugin_catalog_test.go | 98 ++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 sdk/database/dbplugin/v5/plugin_client_test.go diff --git a/sdk/database/dbplugin/v5/plugin_client_test.go b/sdk/database/dbplugin/v5/plugin_client_test.go new file mode 100644 index 0000000000..0ff8309f10 --- /dev/null +++ b/sdk/database/dbplugin/v5/plugin_client_test.go @@ -0,0 +1,146 @@ +package dbplugin + +import ( + "context" + "errors" + "reflect" + "testing" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/helper/wrapping" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" +) + +func TestNewPluginClient(t *testing.T) { + type testCase struct { + config pluginutil.PluginClientConfig + pluginClient pluginutil.PluginClient + expectedResp *DatabasePluginClient + expectedErr error + } + + tests := map[string]testCase{ + "happy path": { + config: testPluginClientConfig(), + pluginClient: &fakePluginClient{ + connResp: nil, + dispenseResp: gRPCClient{client: fakeClient{}}, + dispenseErr: nil, + }, + expectedResp: &DatabasePluginClient{ + client: &fakePluginClient{ + connResp: nil, + dispenseResp: gRPCClient{client: fakeClient{}}, + dispenseErr: nil, + }, + Database: gRPCClient{proto.NewDatabaseClient(nil), context.Context(nil)}, + }, + expectedErr: nil, + }, + "dispense error": { + config: testPluginClientConfig(), + pluginClient: &fakePluginClient{ + connResp: nil, + dispenseResp: gRPCClient{}, + dispenseErr: errors.New("dispense error"), + }, + expectedResp: nil, + expectedErr: errors.New("dispense error"), + }, + "error unsupported client type": { + config: testPluginClientConfig(), + pluginClient: &fakePluginClient{ + connResp: nil, + dispenseResp: nil, + dispenseErr: nil, + }, + expectedResp: nil, + expectedErr: errors.New("unsupported client type"), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + + mockWrapper := new(mockRunnerUtil) + mockWrapper.On("NewPluginClient", ctx, mock.Anything). + Return(test.pluginClient, nil) + defer mockWrapper.AssertNumberOfCalls(t, "NewPluginClient", 1) + + resp, err := NewPluginClient(ctx, mockWrapper, test.config) + if test.expectedErr != nil && err == nil { + t.Fatalf("err expected, got nil") + } + if test.expectedErr == nil && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if test.expectedErr == nil && !reflect.DeepEqual(resp, test.expectedResp) { + t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) + } + }) + } +} + +func testPluginClientConfig() pluginutil.PluginClientConfig { + return pluginutil.PluginClientConfig{ + Name: "test-plugin", + PluginSets: PluginSets, + PluginType: consts.PluginTypeDatabase, + HandshakeConfig: HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: true, + } +} + +var _ pluginutil.PluginClient = &fakePluginClient{} + +type fakePluginClient struct { + connResp grpc.ClientConnInterface + + dispenseResp interface{} + dispenseErr error +} + +func (f *fakePluginClient) Conn() grpc.ClientConnInterface { + return nil +} + +func (f *fakePluginClient) Dispense(name string) (interface{}, error) { + return f.dispenseResp, f.dispenseErr +} + +func (f *fakePluginClient) Ping() error { + return nil +} + +func (f *fakePluginClient) Close() error { + return nil +} + +var _ pluginutil.RunnerUtil = &mockRunnerUtil{} + +type mockRunnerUtil struct { + mock.Mock +} + +func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { + args := m.Called(ctx, config) + return args.Get(0).(pluginutil.PluginClient), args.Error(1) +} + +func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { + args := m.Called(ctx, data, ttl, jwt) + return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1) +} + +func (m *mockRunnerUtil) MlockEnabled() bool { + args := m.Called() + return args.Bool(0) +} diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index 2fb138c2f0..00925a87c6 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -10,6 +10,9 @@ import ( "sort" "testing" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/plugins/database/postgresql" + v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" @@ -180,3 +183,98 @@ func TestPluginCatalog_List(t *testing.T) { } } } + +func TestPluginCatalog_NewPluginClient(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + core.pluginCatalog.directory = sym + + if extPlugins := len(core.pluginCatalog.externalPlugins); extPlugins != 0 { + t.Fatalf("expected externalPlugins map to be of len 0 but got %d", extPlugins) + } + + // register plugins + TestAddTestPlugin(t, core, "mux-postgres", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_PostgresMultiplexed", []string{}, "") + TestAddTestPlugin(t, core, "single-postgres-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "") + TestAddTestPlugin(t, core, "single-postgres-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "") + + // run plugins + if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil { + t.Fatal(err) + } + if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil { + t.Fatal(err) + } + if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-1")); err != nil { + t.Fatal(err) + } + if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-2")); err != nil { + t.Fatal(err) + } + + externalPlugins := core.pluginCatalog.externalPlugins + if len(externalPlugins) != 3 { + t.Fatalf("expected externalPlugins map to be of len 3 but got %d", len(externalPlugins)) + } + + // check connections map + expectedLen := 2 + if len(externalPlugins["mux-postgres"].connections) != expectedLen { + t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections)) + } + expectedLen = 1 + if len(externalPlugins["single-postgres-1"].connections) != expectedLen { + t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections)) + } + if len(externalPlugins["single-postgres-2"].connections) != expectedLen { + t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections)) + } + + // check multiplexing support + if !externalPlugins["mux-postgres"].multiplexingSupport { + t.Fatalf("expected external plugin to be multiplexed") + } + if externalPlugins["single-postgres-1"].multiplexingSupport { + t.Fatalf("expected external plugin to be non-multiplexed") + } + if externalPlugins["single-postgres-2"].multiplexingSupport { + t.Fatalf("expected external plugin to be non-multiplexed") + } +} + +func TestPluginCatalog_PluginMain_Postgres(t *testing.T) { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + dbType, err := postgresql.New() + if err != nil { + t.Fatalf("Failed to initialize postgres: %s", err) + } + + v5.Serve(dbType.(v5.Database)) +} + +func TestPluginCatalog_PluginMain_PostgresMultiplexed(t *testing.T) { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + v5.ServeMultiplex(postgresql.New) +} + +func testPluginClientConfig(pluginName string) pluginutil.PluginClientConfig { + return pluginutil.PluginClientConfig{ + Name: pluginName, + PluginType: consts.PluginTypeDatabase, + PluginSets: v5.PluginSets, + HandshakeConfig: v5.HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, + } +}