Update Type() to return an error

This commit is contained in:
Brian Kassouf 2017-04-12 16:41:06 -07:00
parent f2401c0128
commit 03e2bcbc79
6 changed files with 26 additions and 19 deletions

View File

@ -162,5 +162,5 @@ as secret backends, including but not limited to:
cassandra, msslq, mysql, postgres cassandra, msslq, mysql, postgres
After mounting this backend, configure it using the endpoints within After mounting this backend, configure it using the endpoints within
the "database/dbs/" path. the "database/config/" path.
` `

View File

@ -52,10 +52,11 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn
return nil, err return nil, err
} }
// We should have a Greeter now! This feels like a normal interface // We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection. // implementation but is in fact over an RPC connection.
databaseRPC := raw.(*databasePluginRPCClient) databaseRPC := raw.(*databasePluginRPCClient)
// Wrap RPC implimentation in DatabasePluginClient
return &DatabasePluginClient{ return &DatabasePluginClient{
client: client, client: client,
databasePluginRPCClient: databaseRPC, databasePluginRPCClient: databaseRPC,
@ -70,12 +71,11 @@ type databasePluginRPCClient struct {
client *rpc.Client client *rpc.Client
} }
func (dr *databasePluginRPCClient) Type() string { func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string var dbType string
//TODO: catch error err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType) return fmt.Sprintf("plugin-%s", dbType), err
} }
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {

View File

@ -18,7 +18,7 @@ type databaseTracingMiddleware struct {
typeStr string typeStr string
} }
func (mw *databaseTracingMiddleware) Type() string { func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type() return mw.next.Type()
} }
@ -87,7 +87,7 @@ type databaseMetricsMiddleware struct {
typeStr string typeStr string
} }
func (mw *databaseMetricsMiddleware) Type() string { func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type() return mw.next.Type()
} }

View File

@ -2,6 +2,7 @@ package dbplugin
import ( import (
"errors" "errors"
"fmt"
"net/rpc" "net/rpc"
"time" "time"
@ -16,7 +17,7 @@ var (
// DatabaseType is the interface that all database objects must implement. // DatabaseType is the interface that all database objects must implement.
type DatabaseType interface { type DatabaseType interface {
Type() string Type() (string, error)
CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error RevokeUser(statements Statements, username string) error
@ -52,16 +53,21 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log
return nil, err return nil, err
} }
typeStr, err := db.Type()
if err != nil {
return nil, fmt.Errorf("error getting plugin type: %s", err)
}
// Wrap with metrics middleware // Wrap with metrics middleware
db = &databaseMetricsMiddleware{ db = &databaseMetricsMiddleware{
next: db, next: db,
typeStr: db.Type(), typeStr: typeStr,
} }
// Wrap with tracing middleware // Wrap with tracing middleware
db = &databaseTracingMiddleware{ db = &databaseTracingMiddleware{
next: db, next: db,
typeStr: db.Type(), typeStr: typeStr,
logger: logger, logger: logger,
} }

View File

@ -19,7 +19,7 @@ type mockPlugin struct {
users map[string][]string users map[string][]string
} }
func (m *mockPlugin) Type() string { return "mock" } func (m *mockPlugin) Type() (string, error) { return "mock", nil }
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
err = errors.New("err") err = errors.New("err")
if usernamePrefix == "" || expiration.IsZero() { if usernamePrefix == "" || expiration.IsZero() {
@ -59,7 +59,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
delete(m.users, username) delete(m.users, username)
return nil return nil
} }
func (m *mockPlugin) Initialize(conf map[string]interface{}) error { func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
err := errors.New("err") err := errors.New("err")
if len(conf) != 1 { if len(conf) != 1 {
return err return err
@ -108,7 +108,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1, "test": 1,
} }
err = dbRaw.Initialize(connectionDetails) err = dbRaw.Initialize(connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -133,7 +133,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1, "test": 1,
} }
err = db.Initialize(connectionDetails) err = db.Initialize(connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -167,7 +167,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(connectionDetails) err = db.Initialize(connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -196,7 +196,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(connectionDetails) err = db.Initialize(connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@ -42,8 +42,9 @@ type databasePluginRPCServer struct {
} }
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
*resp = ds.impl.Type() var err error
return nil *resp, err = ds.impl.Type()
return err
} }
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {