vault/sdk/database/dbplugin/v5/grpc_server_test.go
Christopher Swenson 70278c2787
Add plugin version to GRPC interface (#17088)
Add plugin version to GRPC interface

Added a version interface in the sdk/logical so that it can be shared between all plugin types, and then wired it up to RunningVersion in the mounts, auth list, and database systems.

I've tested that this works with auth, database, and secrets plugin types, with the following logic to populate RunningVersion:

If a plugin has a PluginVersion() method implemented, then that is used
If not, and the plugin is built into the Vault binary, then the go.mod version is used
Otherwise, the it will be the empty string.
My apologies for the length of this PR.

* Placeholder backend should be external

We use a placeholder backend (previously a framework.Backend) before a
GRPC plugin is lazy-loaded. This makes us later think the plugin is a
builtin plugin.

So we added a `placeholderBackend` type that overrides the
`IsExternal()` method so that later we know that the plugin is external,
and don't give it a default builtin version.
2022-09-15 16:37:59 -07:00

837 lines
20 KiB
Go

package dbplugin
import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"time"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/protobuf/types/known/structpb"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// Before minValidSeconds in ptypes package
var invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
func TestGRPCServer_Initialize(t *testing.T) {
type testCase struct {
db Database
req *proto.InitializeRequest
expectedResp *proto.InitializeResponse
expectErr bool
expectCode codes.Code
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
}
tests := map[string]testCase{
"database errored": {
db: fakeDatabase{
initErr: errors.New("initialization error"),
},
req: &proto.InitializeRequest{},
expectedResp: &proto.InitializeResponse{},
expectErr: true,
expectCode: codes.Internal,
grpcSetupFunc: testGrpcServer,
},
"newConfig can't marshal to JSON": {
db: fakeDatabase{
initResp: InitializeResponse{
Config: map[string]interface{}{
"bad-data": badJSONValue{},
},
},
},
req: &proto.InitializeRequest{},
expectedResp: &proto.InitializeResponse{},
expectErr: true,
expectCode: codes.Internal,
grpcSetupFunc: testGrpcServer,
},
"happy path with config data for multiplexed plugin": {
db: fakeDatabase{
initResp: InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
},
req: &proto.InitializeRequest{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
}),
},
expectedResp: &proto.InitializeResponse{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
}),
},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServer,
},
"happy path with config data for non-multiplexed plugin": {
db: fakeDatabase{
initResp: InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
},
req: &proto.InitializeRequest{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
}),
},
expectedResp: &proto.InitializeResponse{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
}),
},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServerSingleImpl,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := test.grpcSetupFunc(t, test.db)
resp, err := g.Initialize(idCtx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestCoerceFloatsToInt(t *testing.T) {
type testCase struct {
input map[string]interface{}
expected map[string]interface{}
}
tests := map[string]testCase{
"no numbers": {
input: map[string]interface{}{
"foo": "bar",
},
expected: map[string]interface{}{
"foo": "bar",
},
},
"raw integers": {
input: map[string]interface{}{
"foo": 42,
},
expected: map[string]interface{}{
"foo": 42,
},
},
"floats ": {
input: map[string]interface{}{
"foo": 42.2,
},
expected: map[string]interface{}{
"foo": 42.2,
},
},
"floats coerced to ints": {
input: map[string]interface{}{
"foo": float64(42),
},
expected: map[string]interface{}{
"foo": int64(42),
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
actual := copyMap(test.input)
coerceFloatsToInt(actual)
if !reflect.DeepEqual(actual, test.expected) {
t.Fatalf("Actual: %#v\nExpected: %#v", actual, test.expected)
}
})
}
}
func copyMap(m map[string]interface{}) map[string]interface{} {
newMap := map[string]interface{}{}
for k, v := range m {
newMap[k] = v
}
return newMap
}
func TestGRPCServer_NewUser(t *testing.T) {
type testCase struct {
db Database
req *proto.NewUserRequest
expectedResp *proto.NewUserResponse
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"missing username config": {
db: fakeDatabase{},
req: &proto.NewUserRequest{},
expectedResp: &proto.NewUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"bad expiration": {
db: fakeDatabase{},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
Expiration: &timestamp.Timestamp{
Seconds: invalidExpiration.Unix(),
},
},
expectedResp: &proto.NewUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"database error": {
db: fakeDatabase{
newUserErr: errors.New("new user error"),
},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
Expiration: ptypes.TimestampNow(),
},
expectedResp: &proto.NewUserResponse{},
expectErr: true,
expectCode: codes.Internal,
},
"happy path with expiration": {
db: fakeDatabase{
newUserResp: NewUserResponse{
Username: "someuser_foo",
},
},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
Expiration: ptypes.TimestampNow(),
},
expectedResp: &proto.NewUserResponse{
Username: "someuser_foo",
},
expectErr: false,
expectCode: codes.OK,
},
"happy path without expiration": {
db: fakeDatabase{
newUserResp: NewUserResponse{
Username: "someuser_foo",
},
},
req: &proto.NewUserRequest{
UsernameConfig: &proto.UsernameConfig{
DisplayName: "dispname",
RoleName: "rolename",
},
},
expectedResp: &proto.NewUserResponse{
Username: "someuser_foo",
},
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.NewUser(idCtx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCServer_UpdateUser(t *testing.T) {
type testCase struct {
db Database
req *proto.UpdateUserRequest
expectedResp *proto.UpdateUserResponse
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"missing username": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"missing changes": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"database error": {
db: fakeDatabase{
updateUserErr: errors.New("update user error"),
},
req: &proto.UpdateUserRequest{
Username: "someuser",
Password: &proto.ChangePassword{
NewPassword: "90ughaino",
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.Internal,
},
"bad expiration date": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
Expiration: &proto.ChangeExpiration{
NewExpiration: &timestamp.Timestamp{
// Before minValidSeconds in ptypes package
Seconds: invalidExpiration.Unix(),
},
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"change password happy path": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
Password: &proto.ChangePassword{
NewPassword: "90ughaino",
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: false,
expectCode: codes.OK,
},
"change expiration happy path": {
db: fakeDatabase{},
req: &proto.UpdateUserRequest{
Username: "someuser",
Expiration: &proto.ChangeExpiration{
NewExpiration: ptypes.TimestampNow(),
},
},
expectedResp: &proto.UpdateUserResponse{},
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.UpdateUser(idCtx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCServer_DeleteUser(t *testing.T) {
type testCase struct {
db Database
req *proto.DeleteUserRequest
expectedResp *proto.DeleteUserResponse
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"missing username": {
db: fakeDatabase{},
req: &proto.DeleteUserRequest{},
expectedResp: &proto.DeleteUserResponse{},
expectErr: true,
expectCode: codes.InvalidArgument,
},
"database error": {
db: fakeDatabase{
deleteUserErr: errors.New("delete user error"),
},
req: &proto.DeleteUserRequest{
Username: "someuser",
},
expectedResp: &proto.DeleteUserResponse{},
expectErr: true,
expectCode: codes.Internal,
},
"happy path": {
db: fakeDatabase{},
req: &proto.DeleteUserRequest{
Username: "someuser",
},
expectedResp: &proto.DeleteUserResponse{},
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.DeleteUser(idCtx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCServer_Type(t *testing.T) {
type testCase struct {
db Database
expectedResp *proto.TypeResponse
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"database error": {
db: fakeDatabase{
typeErr: errors.New("type error"),
},
expectedResp: &proto.TypeResponse{},
expectErr: true,
expectCode: codes.Internal,
},
"happy path": {
db: fakeDatabase{
typeResp: "fake database",
},
expectedResp: &proto.TypeResponse{
Type: "fake database",
},
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.Type(idCtx, &proto.Empty{})
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCServer_Close(t *testing.T) {
type testCase struct {
db Database
expectErr bool
expectCode codes.Code
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
assertFunc func(t *testing.T, g gRPCServer)
}
tests := map[string]testCase{
"database error": {
db: fakeDatabase{
closeErr: errors.New("close error"),
},
expectErr: true,
expectCode: codes.Internal,
grpcSetupFunc: testGrpcServer,
assertFunc: nil,
},
"happy path for multiplexed plugin": {
db: fakeDatabase{},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServer,
assertFunc: func(t *testing.T, g gRPCServer) {
if len(g.instances) != 0 {
t.Fatalf("err expected instances map to be empty")
}
},
},
"happy path for non-multiplexed plugin": {
db: fakeDatabase{},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServerSingleImpl,
assertFunc: nil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := test.grpcSetupFunc(t, test.db)
_, err := g.Close(idCtx, &proto.Empty{})
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if test.assertFunc != nil {
test.assertFunc(t, g)
}
})
}
}
func TestGRPCServer_Version(t *testing.T) {
type testCase struct {
db Database
expectedResp string
expectErr bool
expectCode codes.Code
}
tests := map[string]testCase{
"backend that does not implement version": {
db: fakeDatabase{},
expectedResp: "",
expectErr: false,
expectCode: codes.OK,
},
"backend with version": {
db: fakeDatabaseWithVersion{
version: "v123",
},
expectedResp: "v123",
expectErr: false,
expectCode: codes.OK,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.Version(idCtx, &logical.Empty{})
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
actualCode := status.Code(err)
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if !reflect.DeepEqual(resp.PluginVersion, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
// testGrpcServer is a test helper that returns a context with an ID set in its
// metadata and a gRPCServer instance for a multiplexed plugin
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
t.Helper()
g := gRPCServer{
factoryFunc: func() (interface{}, error) {
return db, nil
},
instances: make(map[string]Database),
}
id := "12345"
idCtx := idCtx(t, id)
g.instances[id] = db
return idCtx, g
}
// testGrpcServerSingleImpl is a test helper that returns a context and a
// gRPCServer instance for a non-multiplexed plugin
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
t.Helper()
return context.Background(), gRPCServer{
singleImpl: db,
}
}
// idCtx is a test helper that will return a context with the IDs set in its
// metadata
func idCtx(t *testing.T, ids ...string) context.Context {
t.Helper()
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
md := metadata.MD{}
for _, id := range ids {
md.Append(pluginutil.MultiplexingCtxKey, id)
}
return metadata.NewIncomingContext(ctx, md)
}
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
t.Helper()
strct, err := mapToStruct(m)
if err != nil {
t.Fatalf("unable to marshal to protobuf: %s", err)
}
return strct
}
type badJSONValue struct{}
func (badJSONValue) MarshalJSON() ([]byte, error) {
return nil, fmt.Errorf("this cannot be marshalled to JSON")
}
func (badJSONValue) UnmarshalJSON([]byte) error {
return fmt.Errorf("this cannot be unmarshalled from JSON")
}
var _ Database = fakeDatabase{}
type fakeDatabase struct {
initResp InitializeResponse
initErr error
newUserResp NewUserResponse
newUserErr error
updateUserResp UpdateUserResponse
updateUserErr error
deleteUserResp DeleteUserResponse
deleteUserErr error
typeResp string
typeErr error
closeErr error
}
func (e fakeDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
return e.initResp, e.initErr
}
func (e fakeDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
return e.newUserResp, e.newUserErr
}
func (e fakeDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
return e.updateUserResp, e.updateUserErr
}
func (e fakeDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
return e.deleteUserResp, e.deleteUserErr
}
func (e fakeDatabase) Type() (string, error) {
return e.typeResp, e.typeErr
}
func (e fakeDatabase) Close() error {
return e.closeErr
}
var _ Database = &recordingDatabase{}
type recordingDatabase struct {
initializeCalls int
newUserCalls int
updateUserCalls int
deleteUserCalls int
typeCalls int
closeCalls int
// recordingDatabase can act as middleware so we can record the calls to other test Database implementations
next Database
}
func (f *recordingDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
f.initializeCalls++
if f.next == nil {
return InitializeResponse{}, nil
}
return f.next.Initialize(ctx, req)
}
func (f *recordingDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
f.newUserCalls++
if f.next == nil {
return NewUserResponse{}, nil
}
return f.next.NewUser(ctx, req)
}
func (f *recordingDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
f.updateUserCalls++
if f.next == nil {
return UpdateUserResponse{}, nil
}
return f.next.UpdateUser(ctx, req)
}
func (f *recordingDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
f.deleteUserCalls++
if f.next == nil {
return DeleteUserResponse{}, nil
}
return f.next.DeleteUser(ctx, req)
}
func (f *recordingDatabase) Type() (string, error) {
f.typeCalls++
if f.next == nil {
return "recordingDatabase", nil
}
return f.next.Type()
}
func (f *recordingDatabase) Close() error {
f.closeCalls++
if f.next == nil {
return nil
}
return f.next.Close()
}
type fakeDatabaseWithVersion struct {
version string
}
func (e fakeDatabaseWithVersion) PluginVersion() logical.PluginVersion {
return logical.PluginVersion{Version: e.version}
}
func (e fakeDatabaseWithVersion) Initialize(_ context.Context, _ InitializeRequest) (InitializeResponse, error) {
return InitializeResponse{}, nil
}
func (e fakeDatabaseWithVersion) NewUser(_ context.Context, _ NewUserRequest) (NewUserResponse, error) {
return NewUserResponse{}, nil
}
func (e fakeDatabaseWithVersion) UpdateUser(_ context.Context, _ UpdateUserRequest) (UpdateUserResponse, error) {
return UpdateUserResponse{}, nil
}
func (e fakeDatabaseWithVersion) DeleteUser(_ context.Context, _ DeleteUserRequest) (DeleteUserResponse, error) {
return DeleteUserResponse{}, nil
}
func (e fakeDatabaseWithVersion) Type() (string, error) {
return "", nil
}
func (e fakeDatabaseWithVersion) Close() error {
return nil
}
var (
_ Database = (*fakeDatabaseWithVersion)(nil)
_ logical.PluginVersioner = (*fakeDatabaseWithVersion)(nil)
)