mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-09 08:07:01 +02:00
This is part 1 of 4 for renaming the `newdbplugin` package. This copies the existing package to the new location but keeps the current one in place so we can migrate the existing references over more easily.
693 lines
16 KiB
Go
693 lines
16 KiB
Go
package dbplugin
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
|
|
"github.com/golang/protobuf/ptypes"
|
|
"github.com/golang/protobuf/ptypes/timestamp"
|
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
var (
	// invalidExpiration is a timestamp in year 0 (UTC). It is before
	// minValidSeconds in the ptypes package, and the "bad expiration" test
	// cases in this file use its Unix seconds to build a timestamp that the
	// server is expected to reject.
	invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
)
|
|
|
|
func TestGRPCServer_Initialize(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.InitializeRequest
|
|
expectedResp *proto.InitializeResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"database errored": {
|
|
db: fakeDatabase{
|
|
initErr: errors.New("initialization error"),
|
|
},
|
|
req: &proto.InitializeRequest{},
|
|
expectedResp: &proto.InitializeResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"newConfig can't marshal to JSON": {
|
|
db: fakeDatabase{
|
|
initResp: InitializeResponse{
|
|
Config: map[string]interface{}{
|
|
"bad-data": badJSONValue{},
|
|
},
|
|
},
|
|
},
|
|
req: &proto.InitializeRequest{},
|
|
expectedResp: &proto.InitializeResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path with config data": {
|
|
db: fakeDatabase{
|
|
initResp: InitializeResponse{
|
|
Config: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
},
|
|
req: &proto.InitializeRequest{
|
|
ConfigData: marshal(t, map[string]interface{}{
|
|
"foo": "bar",
|
|
}),
|
|
},
|
|
expectedResp: &proto.InitializeResponse{
|
|
ConfigData: marshal(t, map[string]interface{}{
|
|
"foo": "bar",
|
|
}),
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
g := gRPCServer{
|
|
impl: test.db,
|
|
}
|
|
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
|
|
resp, err := g.Initialize(ctx, test.req)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCoerceFloatsToInt(t *testing.T) {
|
|
type testCase struct {
|
|
input map[string]interface{}
|
|
expected map[string]interface{}
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"no numbers": {
|
|
input: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
"raw integers": {
|
|
input: map[string]interface{}{
|
|
"foo": 42,
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": 42,
|
|
},
|
|
},
|
|
"floats ": {
|
|
input: map[string]interface{}{
|
|
"foo": 42.2,
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": 42.2,
|
|
},
|
|
},
|
|
"floats coerced to ints": {
|
|
input: map[string]interface{}{
|
|
"foo": float64(42),
|
|
},
|
|
expected: map[string]interface{}{
|
|
"foo": int64(42),
|
|
},
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
actual := copyMap(test.input)
|
|
coerceFloatsToInt(actual)
|
|
if !reflect.DeepEqual(actual, test.expected) {
|
|
t.Fatalf("Actual: %#v\nExpected: %#v", actual, test.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// copyMap returns a shallow copy of m: a new map holding the same keys and
// values. Adding, removing, or replacing top-level entries in the result does
// not affect m. A nil or empty input yields an empty, non-nil map.
func copyMap(m map[string]interface{}) map[string]interface{} {
	// The final size is known up front, so allocate it once.
	newMap := make(map[string]interface{}, len(m))
	for k, v := range m {
		newMap[k] = v
	}
	return newMap
}
|
|
|
|
func TestGRPCServer_NewUser(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.NewUserRequest
|
|
expectedResp *proto.NewUserResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing username config": {
|
|
db: fakeDatabase{},
|
|
req: &proto.NewUserRequest{},
|
|
expectedResp: &proto.NewUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"bad expiration": {
|
|
db: fakeDatabase{},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
Expiration: ×tamp.Timestamp{
|
|
Seconds: invalidExpiration.Unix(),
|
|
},
|
|
},
|
|
expectedResp: &proto.NewUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
newUserErr: errors.New("new user error"),
|
|
},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
Expiration: ptypes.TimestampNow(),
|
|
},
|
|
expectedResp: &proto.NewUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path with expiration": {
|
|
db: fakeDatabase{
|
|
newUserResp: NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
Expiration: ptypes.TimestampNow(),
|
|
},
|
|
expectedResp: &proto.NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
"happy path without expiration": {
|
|
db: fakeDatabase{
|
|
newUserResp: NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
},
|
|
req: &proto.NewUserRequest{
|
|
UsernameConfig: &proto.UsernameConfig{
|
|
DisplayName: "dispname",
|
|
RoleName: "rolename",
|
|
},
|
|
},
|
|
expectedResp: &proto.NewUserResponse{
|
|
Username: "someuser_foo",
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
g := gRPCServer{
|
|
impl: test.db,
|
|
}
|
|
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
|
|
resp, err := g.NewUser(ctx, test.req)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_UpdateUser(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.UpdateUserRequest
|
|
expectedResp *proto.UpdateUserResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing username": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"missing changes": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
updateUserErr: errors.New("update user error"),
|
|
},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Password: &proto.ChangePassword{
|
|
NewPassword: "90ughaino",
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"bad expiration date": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Expiration: &proto.ChangeExpiration{
|
|
NewExpiration: ×tamp.Timestamp{
|
|
// Before minValidSeconds in ptypes package
|
|
Seconds: invalidExpiration.Unix(),
|
|
},
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"change password happy path": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Password: &proto.ChangePassword{
|
|
NewPassword: "90ughaino",
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
"change expiration happy path": {
|
|
db: fakeDatabase{},
|
|
req: &proto.UpdateUserRequest{
|
|
Username: "someuser",
|
|
Expiration: &proto.ChangeExpiration{
|
|
NewExpiration: ptypes.TimestampNow(),
|
|
},
|
|
},
|
|
expectedResp: &proto.UpdateUserResponse{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
g := gRPCServer{
|
|
impl: test.db,
|
|
}
|
|
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
|
|
resp, err := g.UpdateUser(ctx, test.req)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_DeleteUser(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
req *proto.DeleteUserRequest
|
|
expectedResp *proto.DeleteUserResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"missing username": {
|
|
db: fakeDatabase{},
|
|
req: &proto.DeleteUserRequest{},
|
|
expectedResp: &proto.DeleteUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.InvalidArgument,
|
|
},
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
deleteUserErr: errors.New("delete user error"),
|
|
},
|
|
req: &proto.DeleteUserRequest{
|
|
Username: "someuser",
|
|
},
|
|
expectedResp: &proto.DeleteUserResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path": {
|
|
db: fakeDatabase{},
|
|
req: &proto.DeleteUserRequest{
|
|
Username: "someuser",
|
|
},
|
|
expectedResp: &proto.DeleteUserResponse{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
g := gRPCServer{
|
|
impl: test.db,
|
|
}
|
|
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
|
|
resp, err := g.DeleteUser(ctx, test.req)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_Type(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
expectedResp *proto.TypeResponse
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
typeErr: errors.New("type error"),
|
|
},
|
|
expectedResp: &proto.TypeResponse{},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path": {
|
|
db: fakeDatabase{
|
|
typeResp: "fake database",
|
|
},
|
|
expectedResp: &proto.TypeResponse{
|
|
Type: "fake database",
|
|
},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
g := gRPCServer{
|
|
impl: test.db,
|
|
}
|
|
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
|
|
resp, err := g.Type(ctx, &proto.Empty{})
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
|
|
if !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGRPCServer_Close(t *testing.T) {
|
|
type testCase struct {
|
|
db Database
|
|
expectErr bool
|
|
expectCode codes.Code
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"database error": {
|
|
db: fakeDatabase{
|
|
closeErr: errors.New("close error"),
|
|
},
|
|
expectErr: true,
|
|
expectCode: codes.Internal,
|
|
},
|
|
"happy path": {
|
|
db: fakeDatabase{},
|
|
expectErr: false,
|
|
expectCode: codes.OK,
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
g := gRPCServer{
|
|
impl: test.db,
|
|
}
|
|
|
|
// Context doesn't need to timeout since this is just passed through
|
|
ctx := context.Background()
|
|
|
|
_, err := g.Close(ctx, &proto.Empty{})
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
actualCode := status.Code(err)
|
|
if actualCode != test.expectCode {
|
|
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
|
|
t.Helper()
|
|
|
|
strct, err := mapToStruct(m)
|
|
if err != nil {
|
|
t.Fatalf("unable to marshal to protobuf: %s", err)
|
|
}
|
|
return strct
|
|
}
|
|
|
|
// badJSONValue is a test value whose JSON methods always fail. The tests place
// it in a config map to exercise the server's handling of data that cannot be
// marshalled.
type badJSONValue struct{}

// MarshalJSON always returns a nil byte slice and a non-nil error.
func (badJSONValue) MarshalJSON() ([]byte, error) {
	return nil, fmt.Errorf("this cannot be marshalled to JSON")
}

// UnmarshalJSON ignores its input and always returns a non-nil error.
func (badJSONValue) UnmarshalJSON([]byte) error {
	return fmt.Errorf("this cannot be unmarshalled from JSON")
}
|
|
|
|
var _ Database = fakeDatabase{}
|
|
|
|
type fakeDatabase struct {
|
|
initResp InitializeResponse
|
|
initErr error
|
|
|
|
newUserResp NewUserResponse
|
|
newUserErr error
|
|
|
|
updateUserResp UpdateUserResponse
|
|
updateUserErr error
|
|
|
|
deleteUserResp DeleteUserResponse
|
|
deleteUserErr error
|
|
|
|
typeResp string
|
|
typeErr error
|
|
|
|
closeErr error
|
|
}
|
|
|
|
func (e fakeDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
|
|
return e.initResp, e.initErr
|
|
}
|
|
|
|
func (e fakeDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
|
|
return e.newUserResp, e.newUserErr
|
|
}
|
|
|
|
func (e fakeDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
|
|
return e.updateUserResp, e.updateUserErr
|
|
}
|
|
|
|
func (e fakeDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
|
|
return e.deleteUserResp, e.deleteUserErr
|
|
}
|
|
|
|
func (e fakeDatabase) Type() (string, error) {
|
|
return e.typeResp, e.typeErr
|
|
}
|
|
|
|
func (e fakeDatabase) Close() error {
|
|
return e.closeErr
|
|
}
|
|
|
|
var _ Database = &recordingDatabase{}
|
|
|
|
type recordingDatabase struct {
|
|
initializeCalls int
|
|
newUserCalls int
|
|
updateUserCalls int
|
|
deleteUserCalls int
|
|
typeCalls int
|
|
closeCalls int
|
|
|
|
// recordingDatabase can act as middleware so we can record the calls to other test Database implementations
|
|
next Database
|
|
}
|
|
|
|
func (f *recordingDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
|
|
f.initializeCalls++
|
|
if f.next == nil {
|
|
return InitializeResponse{}, nil
|
|
}
|
|
return f.next.Initialize(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
|
|
f.newUserCalls++
|
|
if f.next == nil {
|
|
return NewUserResponse{}, nil
|
|
}
|
|
return f.next.NewUser(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
|
|
f.updateUserCalls++
|
|
if f.next == nil {
|
|
return UpdateUserResponse{}, nil
|
|
}
|
|
return f.next.UpdateUser(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
|
|
f.deleteUserCalls++
|
|
if f.next == nil {
|
|
return DeleteUserResponse{}, nil
|
|
}
|
|
return f.next.DeleteUser(ctx, req)
|
|
}
|
|
|
|
func (f *recordingDatabase) Type() (string, error) {
|
|
f.typeCalls++
|
|
if f.next == nil {
|
|
return "recordingDatabase", nil
|
|
}
|
|
return f.next.Type()
|
|
}
|
|
|
|
func (f *recordingDatabase) Close() error {
|
|
f.closeCalls++
|
|
if f.next == nil {
|
|
return nil
|
|
}
|
|
return f.next.Close()
|
|
}
|