// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package pluginutil import ( "context" "fmt" "reflect" "testing" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) func TestMultiplexingSupported(t *testing.T) { type args struct { ctx context.Context cc grpc.ClientConnInterface name string } type testCase struct { name string args args env string want bool wantErr bool } tests := []testCase{ { name: "multiplexing is supported if plugin is not opted out", args: args{ ctx: context.Background(), cc: &MockClientConnInterfaceNoop{}, name: "plugin", }, env: "", want: true, }, { name: "multiplexing is not supported if plugin is opted out", args: args{ ctx: context.Background(), cc: &MockClientConnInterfaceNoop{}, name: "optedOutPlugin", }, env: "optedOutPlugin", want: false, }, { name: "multiplexing is not supported if plugin among one of the opted out", args: args{ ctx: context.Background(), cc: &MockClientConnInterfaceNoop{}, name: "optedOutPlugin", }, env: "firstPlugin,optedOutPlugin,otherPlugin", want: false, }, { name: "multiplexing is supported if different plugin is opted out", args: args{ ctx: context.Background(), cc: &MockClientConnInterfaceNoop{}, name: "plugin", }, env: "optedOutPlugin", want: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Setenv(PluginMultiplexingOptOut, tt.env) got, err := MultiplexingSupported(tt.args.ctx, tt.args.cc, tt.args.name) if (err != nil) != tt.wantErr { t.Errorf("MultiplexingSupported() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("MultiplexingSupported() got = %v, want %v", got, tt.want) } }) } } func TestGetMultiplexIDFromContext(t *testing.T) { type testCase struct { ctx context.Context expectedResp string expectedErr error } tests := map[string]testCase{ "missing plugin multiplexing metadata": { ctx: context.Background(), expectedResp: "", expectedErr: fmt.Errorf("missing plugin multiplexing metadata"), }, "unexpected number of IDs in metadata": { ctx: idCtx(t, "12345", "67891"), expectedResp: "", expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"), }, "empty multiplex ID in metadata": { ctx: idCtx(t, ""), expectedResp: "", expectedErr: fmt.Errorf("empty multiplex ID in metadata"), }, "happy path, id is returned from metadata": { ctx: idCtx(t, "12345"), expectedResp: "12345", expectedErr: nil, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { resp, err := GetMultiplexIDFromContext(test.ctx) if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil { t.Fatalf("err expected, got nil") } else if !reflect.DeepEqual(err, test.expectedErr) { t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr) } if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil { t.Fatalf("no error expected, got: %s", err) } if !reflect.DeepEqual(resp, test.expectedResp) { t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) } }) } } // idCtx is a test helper that will return a context with the IDs set in its // metadata func idCtx(t *testing.T, ids ...string) context.Context { // Context doesn't need to timeout since this is just passed through ctx := context.Background() md := metadata.MD{} for _, id := range ids { md.Append(MultiplexingCtxKey, id) } return metadata.NewIncomingContext(ctx, md) } type MockClientConnInterfaceNoop struct{} func (m *MockClientConnInterfaceNoop) Invoke(_ context.Context, _ string, _ interface{}, reply interface{}, _ ...grpc.CallOption) error { reply.(*MultiplexingSupportResponse).Supported = true return nil } func (m *MockClientConnInterfaceNoop) NewStream(_ context.Context, _ *grpc.StreamDesc, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) { return nil, nil }