package dbplugin

import (
	"fmt"
	"reflect"
	"strings"
	"testing"
	"time"
	"unicode"

	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// TestConversionsHaveAllFields converts fully-populated requests with each
// conversion function and verifies that every field of the converted value is
// set. A field left at its zero value after conversion indicates that the
// conversion function is not copying that field.
func TestConversionsHaveAllFields(t *testing.T) {
	t.Run("initReqToProto", func(t *testing.T) {
		req := InitializeRequest{
			Config: map[string]interface{}{
				"foo": map[string]interface{}{
					"bar": "baz",
				},
			},
			VerifyConnection: true,
		}

		protoReq, err := initReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}

		assertConvertedRequestFieldsSet(t, "InitializeRequest", protoReq)
	})

	t.Run("newUserReqToProto", func(t *testing.T) {
		req := NewUserRequest{
			UsernameConfig: UsernameMetadata{
				DisplayName: "dispName",
				RoleName:    "roleName",
			},
			Statements: Statements{
				Commands: []string{
					"statement",
				},
			},
			RollbackStatements: Statements{
				Commands: []string{
					"rollback_statement",
				},
			},
			Password:   "password",
			Expiration: time.Now(),
		}

		protoReq, err := newUserReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}

		assertConvertedRequestFieldsSet(t, "NewUserRequest", protoReq)
	})

	t.Run("updateUserReqToProto", func(t *testing.T) {
		req := UpdateUserRequest{
			Username: "username",
			Password: &ChangePassword{
				NewPassword: "newpassword",
				Statements: Statements{
					Commands: []string{
						"statement",
					},
				},
			},
			Expiration: &ChangeExpiration{
				NewExpiration: time.Now(),
				Statements: Statements{
					Commands: []string{
						"statement",
					},
				},
			},
		}

		protoReq, err := updateUserReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}

		assertConvertedRequestFieldsSet(t, "UpdateUserRequest", protoReq)
	})

	t.Run("deleteUserReqToProto", func(t *testing.T) {
		req := DeleteUserRequest{
			Username: "username",
			Statements: Statements{
				Commands: []string{
					"statement",
				},
			},
		}

		protoReq, err := deleteUserReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}

		assertConvertedRequestFieldsSet(t, "DeleteUserRequest", protoReq)
	})

	t.Run("getUpdateUserRequest", func(t *testing.T) {
		req := &proto.UpdateUserRequest{
			Username: "username",
			Password: &proto.ChangePassword{
				NewPassword: "newpass",
				Statements: &proto.Statements{
					Commands: []string{
						"statement",
					},
				},
			},
			Expiration: &proto.ChangeExpiration{
				NewExpiration: timestamppb.Now(),
				Statements: &proto.Statements{
					Commands: []string{
						"statement",
					},
				},
			},
		}

		protoReq, err := getUpdateUserRequest(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}

		// This conversion goes the other way (from the proto type), so the
		// result is checked directly rather than through its Get methods.
		err = assertAllFieldsSet("proto.UpdateUserRequest", protoReq)
		if err != nil {
			t.Fatalf("%s", err)
		}
	})
}

// assertConvertedRequestFieldsSet calls every Get method on protoReq and fails
// the test if any returned value (or anything nested inside it) is unset.
// typeName is only used as a prefix in failure messages.
func assertConvertedRequestFieldsSet(t *testing.T, typeName string, protoReq interface{}) {
	t.Helper()

	values := getAllGetterValues(protoReq)
	if len(values) == 0 {
		// Probably a test failure - the protos used in these tests should have Get functions on them
		t.Fatalf("No values found from Get functions!")
	}

	for _, gtr := range values {
		err := assertAllFieldsSet(fmt.Sprintf("%s.%s", typeName, gtr.name), gtr.value)
		if err != nil {
			t.Fatalf("%s", err)
		}
	}
}

// getter pairs the name of a Get method (with the "Get" prefix removed) with
// the value that method returned.
type getter struct {
	name  string
	value interface{}
}

// getAllGetterValues invokes every exported method on value whose name starts
// with "Get" and returns the first result of each call. Methods that require
// arguments or return nothing are skipped because they cannot be invoked as
// simple getters. A nil value yields no getters.
func getAllGetterValues(value interface{}) (values []getter) {
	// reflect.TypeOf(nil) returns a nil Type, on which NumMethod would panic.
	if value == nil {
		return nil
	}

	typ := reflect.TypeOf(value)
	val := reflect.ValueOf(value)

	for i := 0; i < typ.NumMethod(); i++ {
		method := typ.Method(i)
		if !strings.HasPrefix(method.Name, "Get") {
			continue
		}

		// For a method obtained from a concrete type, the receiver counts as
		// the first input. Anything beyond that cannot be called with no args.
		if method.Type.NumIn() != 1 || method.Type.NumOut() == 0 {
			continue
		}

		resp := val.Method(i).Call(nil)
		values = append(values, getter{
			name:  strings.TrimPrefix(method.Name, "Get"),
			value: resp[0].Interface(),
		})
	}
	return values
}

// Ensures the assertion works properly
func TestAssertAllFieldsSet(t *testing.T) {
	type testCase struct {
		value     interface{}
		expectErr bool
	}

	tests := map[string]testCase{
		"zero int": {
			value:     0,
			expectErr: true,
		},
		"non-zero int": {
			value:     1,
			expectErr: false,
		},
		"zero float64": {
			value:     0.0,
			expectErr: true,
		},
		"non-zero float64": {
			value:     1.0,
			expectErr: false,
		},
		"empty string": {
			value:     "",
			expectErr: true,
		},
		"true boolean": {
			value:     true,
			expectErr: false,
		},
		"false boolean": { // False is an exception to the "is zero" rule
			value:     false,
			expectErr: false,
		},
		"blank struct": {
			value:     struct{}{},
			expectErr: true,
		},
		"non-blank but empty struct": {
			value: struct {
				str string
			}{
				str: "",
			},
			expectErr: true,
		},
		"non-empty string": {
			value:     "foo",
			expectErr: false,
		},
		"non-empty struct": {
			value: struct {
				str string
			}{
				str: "foo",
			},
			expectErr: false,
		},
		"empty nested struct": {
			value: struct {
				Str       string
				Substruct struct {
					Substr string
				}
			}{
				Str: "foo",
				Substruct: struct {
					Substr string
				}{}, // Empty sub-field
			},
			expectErr: true,
		},
		"filled nested struct": {
			value: struct {
				str       string
				substruct struct {
					substr string
				}
			}{
				str: "foo",
				substruct: struct {
					substr string
				}{
					substr: "sub-foo",
				},
			},
			expectErr: false,
		},
		"nil map": {
			value:     map[string]string(nil),
			expectErr: true,
		},
		"empty map": {
			value:     map[string]string{},
			expectErr: true,
		},
		"filled map": {
			value: map[string]string{
				"foo": "bar",
				"int": "42",
			},
			expectErr: false,
		},
		"map with empty string value": {
			value: map[string]string{
				"foo": "",
			},
			expectErr: true,
		},
		"nested map with empty string value": {
			value: map[string]interface{}{
				"bar": "baz",
				"foo": map[string]interface{}{
					"subfoo": "",
				},
			},
			expectErr: true,
		},
		"nil slice": {
			value:     []string(nil),
			expectErr: true,
		},
		"empty slice": {
			value:     []string{},
			expectErr: true,
		},
		"filled slice": {
			value: []string{
				"foo",
			},
			expectErr: false,
		},
		"slice with empty string value": {
			value: []string{
				"",
			},
			expectErr: true,
		},
		"empty structpb": {
			value:     newStructPb(t, map[string]interface{}{}),
			expectErr: true,
		},
		"filled structpb": {
			value: newStructPb(t, map[string]interface{}{
				"foo": "bar",
				"int": 42,
			}),
			expectErr: false,
		},

		"pointer to zero int": {
			value:     intPtr(0),
			expectErr: true,
		},
		"pointer to non-zero int": {
			value:     intPtr(1),
			expectErr: false,
		},
		"pointer to zero float64": {
			value:     float64Ptr(0.0),
			expectErr: true,
		},
		"pointer to non-zero float64": {
			value:     float64Ptr(1.0),
			expectErr: false,
		},
		"pointer to nil string": {
			value:     new(string),
			expectErr: true,
		},
		"pointer to non-nil string": {
			value:     strPtr("foo"),
			expectErr: false,
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			err := assertAllFieldsSet("", test.value)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}
		})
	}
}

// assertAllFieldsSet returns an error if val, or any value reachable from it
// through pointers, interfaces, exported struct fields, map values, or slice
// elements, is a zero value. Booleans are always accepted. name is used as the
// path prefix in the returned error message.
func assertAllFieldsSet(name string, val interface{}) error {
	if val == nil {
		return fmt.Errorf("value is nil")
	}

	rVal := reflect.ValueOf(val)
	return assertAllFieldsSetValue(name, rVal)
}

// assertAllFieldsSetValue is the recursive worker behind assertAllFieldsSet,
// operating on an already-reflected value.
func assertAllFieldsSetValue(name string, rVal reflect.Value) error {
	// All booleans are allowed - we don't have a way of differentiating between
	// an intentional false and a missing value
	if rVal.Kind() == reflect.Bool {
		return nil
	}

	// Primitives fall through here. This also rejects nil pointers, nil
	// interfaces, nil maps and nil slices, so Elem() below is always valid.
	if rVal.IsZero() {
		return fmt.Errorf("%s is zero", name)
	}

	switch rVal.Kind() {
	case reflect.Ptr, reflect.Interface:
		return assertAllFieldsSetValue(name, rVal.Elem())
	case reflect.Struct:
		return assertAllFieldsSetStruct(name, rVal)
	case reflect.Map:
		// A non-nil but empty map is not "zero", so it needs its own check.
		if rVal.Len() == 0 {
			return fmt.Errorf("%s (map type) is empty", name)
		}

		iter := rVal.MapRange()
		for iter.Next() {
			k := iter.Key()
			v := iter.Value()

			// %v keeps the path readable for non-string keys as well.
			err := assertAllFieldsSetValue(fmt.Sprintf("%s[%v]", name, k), v)
			if err != nil {
				return err
			}
		}
	case reflect.Slice:
		// A non-nil but empty slice is not "zero", so it needs its own check.
		if rVal.Len() == 0 {
			return fmt.Errorf("%s (slice type) is empty", name)
		}
		for i := 0; i < rVal.Len(); i++ {
			sliceVal := rVal.Index(i)
			err := assertAllFieldsSetValue(fmt.Sprintf("%s[%d]", name, i), sliceVal)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

// assertAllFieldsSetStruct checks a struct value. Protobuf timestamps are
// treated as a single value and rejected when they convert to the zero
// time.Time; for every other struct each exported field is checked
// recursively and unexported fields are ignored.
func assertAllFieldsSetStruct(name string, rVal reflect.Value) error {
	switch rVal.Type() {
	case reflect.TypeOf((*timestamppb.Timestamp)(nil)).Elem():
		// Rebuild the timestamp from its two data fields instead of copying
		// the message out of rVal: copying a generated message by value is
		// flagged by `go vet` (copylocks), and Interface() would panic for a
		// value reached through an unexported field.
		ts := &timestamppb.Timestamp{
			Seconds: rVal.FieldByName("Seconds").Int(),
			Nanos:   int32(rVal.FieldByName("Nanos").Int()),
		}
		if ts.AsTime().IsZero() {
			return fmt.Errorf("%s is zero", name)
		}
		return nil
	default:
		for i := 0; i < rVal.NumField(); i++ {
			field := rVal.Field(i)
			fieldName := rVal.Type().Field(i)

			// Skip fields that aren't exported
			if unicode.IsLower([]rune(fieldName.Name)[0]) {
				continue
			}

			err := assertAllFieldsSetValue(fmt.Sprintf("%s.%s", name, fieldName.Name), field)
			if err != nil {
				return err
			}
		}
		return nil
	}
}

// intPtr returns a pointer to a copy of i.
func intPtr(i int) *int {
	return &i
}

// float64Ptr returns a pointer to a copy of f.
func float64Ptr(f float64) *float64 {
	return &f
}

// strPtr returns a pointer to a copy of str.
func strPtr(str string) *string {
	return &str
}

// newStructPb converts m into a structpb.Struct, failing the test immediately
// if the conversion returns an error.
func newStructPb(t *testing.T, m map[string]interface{}) *structpb.Struct {
	t.Helper()

	s, err := structpb.NewStruct(m)
	if err != nil {
		t.Fatalf("Failed to convert map to struct: %s", err)
	}
	return s
}