package dbplugin import ( "context" "errors" "net/url" "strings" "time" metrics "github.com/armon/go-metrics" "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "google.golang.org/grpc/status" ) // /////////////////////////////////////////////////// // Tracing Middleware // /////////////////////////////////////////////////// var _ Database = databaseTracingMiddleware{} // databaseTracingMiddleware wraps a implementation of Database and executes // trace logging on function call. type databaseTracingMiddleware struct { next Database logger log.Logger } func (mw databaseTracingMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) { defer func(then time.Time) { mw.logger.Trace("initialize", "status", "finished", "verify", req.VerifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("initialize", "status", "started") return mw.next.Initialize(ctx, req) } func (mw databaseTracingMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) { defer func(then time.Time) { mw.logger.Trace("create user", "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("create user", "status", "started") return mw.next.NewUser(ctx, req) } func (mw databaseTracingMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) { defer func(then time.Time) { mw.logger.Trace("update user", "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("update user", "status", "started") return mw.next.UpdateUser(ctx, req) } func (mw databaseTracingMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) { defer func(then time.Time) { mw.logger.Trace("delete user", "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("delete user", "status", "started") return mw.next.DeleteUser(ctx, req) } func (mw databaseTracingMiddleware) Type() (string, error) { return mw.next.Type() } func (mw databaseTracingMiddleware) Close() (err error) { defer func(then time.Time) { mw.logger.Trace("close", "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("close", "status", "started") return mw.next.Close() } // /////////////////////////////////////////////////// // Metrics Middleware Domain // /////////////////////////////////////////////////// var _ Database = databaseMetricsMiddleware{} // databaseMetricsMiddleware wraps an implementation of Databases and on // function call logs metrics about this instance. type databaseMetricsMiddleware struct { next Database typeStr string } func (mw databaseMetricsMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) if err != nil { metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1) } }(time.Now()) metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) return mw.next.Initialize(ctx, req) } func (mw databaseMetricsMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) { defer func(start time.Time) { metrics.MeasureSince([]string{"database", "NewUser"}, start) metrics.MeasureSince([]string{"database", mw.typeStr, "NewUser"}, start) if err != nil { metrics.IncrCounter([]string{"database", "NewUser", "error"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser", "error"}, 1) } }(time.Now()) metrics.IncrCounter([]string{"database", "NewUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser"}, 1) return mw.next.NewUser(ctx, req) } func (mw databaseMetricsMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "UpdateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "UpdateUser"}, now) if err != nil { metrics.IncrCounter([]string{"database", "UpdateUser", "error"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser", "error"}, 1) } }(time.Now()) metrics.IncrCounter([]string{"database", "UpdateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser"}, 1) return mw.next.UpdateUser(ctx, req) } func (mw databaseMetricsMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "DeleteUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "DeleteUser"}, now) if err != nil { metrics.IncrCounter([]string{"database", "DeleteUser", "error"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser", "error"}, 1) } }(time.Now()) metrics.IncrCounter([]string{"database", "DeleteUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser"}, 1) return mw.next.DeleteUser(ctx, req) } func (mw databaseMetricsMiddleware) Type() (string, error) { return mw.next.Type() } func (mw databaseMetricsMiddleware) Close() (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Close"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now) if err != nil { metrics.IncrCounter([]string{"database", "Close", "error"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1) } }(time.Now()) metrics.IncrCounter([]string{"database", "Close"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) return mw.next.Close() } // /////////////////////////////////////////////////// // Error Sanitizer Middleware Domain // /////////////////////////////////////////////////// var _ Database = DatabaseErrorSanitizerMiddleware{} // DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and // sanitizes returned error messages type DatabaseErrorSanitizerMiddleware struct { next Database secretsFn secretsFn } type secretsFn func() map[string]string func NewDatabaseErrorSanitizerMiddleware(next Database, secrets secretsFn) DatabaseErrorSanitizerMiddleware { return DatabaseErrorSanitizerMiddleware{ next: next, secretsFn: secrets, } } func (mw DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) { resp, err = mw.next.Initialize(ctx, req) return resp, mw.sanitize(err) } func (mw DatabaseErrorSanitizerMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) { resp, err = mw.next.NewUser(ctx, req) return resp, mw.sanitize(err) } func (mw DatabaseErrorSanitizerMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) { resp, err := mw.next.UpdateUser(ctx, req) return resp, mw.sanitize(err) } func (mw DatabaseErrorSanitizerMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) { resp, err := mw.next.DeleteUser(ctx, req) return resp, mw.sanitize(err) } func (mw DatabaseErrorSanitizerMiddleware) Type() (string, error) { dbType, err := mw.next.Type() return dbType, mw.sanitize(err) } func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) { return mw.sanitize(mw.next.Close()) } // sanitize errors by removing any sensitive strings within their messages. This uses // the secretsFn to determine what fields should be sanitized. func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error { if err == nil { return nil } if errwrap.ContainsType(err, new(url.Error)) { return errors.New("unable to parse connection url") } if mw.secretsFn == nil { return err } for find, replace := range mw.secretsFn() { if find == "" { continue } // Attempt to keep the status code attached to the // error while changing the actual error message s, ok := status.FromError(err) if ok { err = status.Error(s.Code(), strings.Replace(s.Message(), find, replace, -1)) continue } err = errors.New(strings.Replace(err.Error(), find, replace, -1)) } return err }