vault/sdk/database/dbplugin/v5/middleware.go
Michael Golowka a69ee0f65a
DBPW - Copy newdbplugin package to dbplugin/v5 (#10151)
This is part 1 of 4 for renaming the `newdbplugin` package. This copies the existing package to the new location but keeps the current one in place so we can migrate the existing references over more easily.
2020-10-15 13:20:12 -06:00

275 lines
8.4 KiB
Go

package dbplugin
import (
	"context"
	"errors"
	"net/url"
	"sort"
	"strings"
	"time"

	metrics "github.com/armon/go-metrics"
	"github.com/hashicorp/errwrap"
	log "github.com/hashicorp/go-hclog"
	"google.golang.org/grpc/status"
)
// ///////////////////////////////////////////////////
// Tracing Middleware
// ///////////////////////////////////////////////////
// Compile-time assertion that databaseTracingMiddleware satisfies Database.
var _ Database = databaseTracingMiddleware{}
// databaseTracingMiddleware wraps an implementation of Database and writes
// trace-level log entries around the Initialize, NewUser, UpdateUser,
// DeleteUser and Close calls it forwards. Type is forwarded without logging.
type databaseTracingMiddleware struct {
// next is the wrapped Database that every call is delegated to.
next Database
// logger receives the "started"/"finished" trace entries.
logger log.Logger
}
// Initialize forwards the request to the wrapped Database. It writes an
// "initialize" trace entry with status "started" before delegating and one with
// status "finished" afterwards, the latter carrying req.VerifyConnection, the
// returned error and the elapsed time.
func (mw databaseTracingMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
	start := time.Now()
	// The closure reads the named result err, so it reports whatever the
	// wrapped call ultimately returned.
	defer func() {
		mw.logger.Trace("initialize",
			"status", "finished",
			"verify", req.VerifyConnection,
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("initialize", "status", "started")

	resp, err = mw.next.Initialize(ctx, req)
	return resp, err
}
// NewUser forwards the request to the wrapped Database. It writes a
// "create user" trace entry with status "started" before delegating and one
// with status "finished" afterwards, the latter carrying the returned error and
// the elapsed time.
func (mw databaseTracingMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
	start := time.Now()
	// The closure reads the named result err, so it reports whatever the
	// wrapped call ultimately returned.
	defer func() {
		mw.logger.Trace("create user",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("create user", "status", "started")

	resp, err = mw.next.NewUser(ctx, req)
	return resp, err
}
// UpdateUser forwards the request to the wrapped Database. It writes an
// "update user" trace entry with status "started" before delegating and one
// with status "finished" afterwards, the latter carrying the returned error and
// the elapsed time.
func (mw databaseTracingMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) {
	start := time.Now()
	// The closure reads the named result err, so it reports whatever the
	// wrapped call ultimately returned.
	defer func() {
		mw.logger.Trace("update user",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("update user", "status", "started")

	resp, err = mw.next.UpdateUser(ctx, req)
	return resp, err
}
// DeleteUser forwards the request to the wrapped Database. It writes a
// "delete user" trace entry with status "started" before delegating and one
// with status "finished" afterwards, the latter carrying the returned error and
// the elapsed time.
func (mw databaseTracingMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) {
	start := time.Now()
	// The closure reads the named result err, so it reports whatever the
	// wrapped call ultimately returned.
	defer func() {
		mw.logger.Trace("delete user",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("delete user", "status", "started")

	resp, err = mw.next.DeleteUser(ctx, req)
	return resp, err
}
// Type returns whatever the wrapped Database reports as its type. No trace
// entries are written for this call.
func (mw databaseTracingMiddleware) Type() (string, error) {
	dbType, err := mw.next.Type()
	return dbType, err
}
// Close forwards to the wrapped Database. It writes a "close" trace entry with
// status "started" before delegating and one with status "finished"
// afterwards, the latter carrying the returned error and the elapsed time.
func (mw databaseTracingMiddleware) Close() (err error) {
	start := time.Now()
	// The closure reads the named result err, so it reports whatever the
	// wrapped call ultimately returned.
	defer func() {
		mw.logger.Trace("close",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("close", "status", "started")

	err = mw.next.Close()
	return err
}
// ///////////////////////////////////////////////////
// Metrics Middleware Domain
// ///////////////////////////////////////////////////
// Compile-time assertion that databaseMetricsMiddleware satisfies Database.
var _ Database = databaseMetricsMiddleware{}
// databaseMetricsMiddleware wraps an implementation of Database and, on each
// Initialize, NewUser, UpdateUser, DeleteUser and Close call, emits call
// counters, error counters and timings through the go-metrics package. Type is
// forwarded without emitting metrics.
type databaseMetricsMiddleware struct {
// next is the wrapped Database that every call is delegated to.
next Database
// typeStr is inserted as a key segment in the per-type metric names, e.g.
// {"database", typeStr, "Initialize"}.
typeStr string
}
// Initialize forwards the request to the wrapped Database and records metrics
// for the call under both the generic key {"database", "Initialize"} and the
// per-type key {"database", typeStr, "Initialize"}: a counter incremented
// before delegating, a duration sample once the call returns, and an
// additional "error" counter when the call returns a non-nil error.
func (mw databaseMetricsMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
	const op = "Initialize"

	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", op}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, op}, start)

		// Error counters are only touched for failed calls.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", op, "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, op, "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", op}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, op}, 1)

	resp, err = mw.next.Initialize(ctx, req)
	return resp, err
}
// NewUser forwards the request to the wrapped Database and records metrics for
// the call under both the generic key {"database", "NewUser"} and the per-type
// key {"database", typeStr, "NewUser"}: a counter incremented before
// delegating, a duration sample once the call returns, and an additional
// "error" counter when the call returns a non-nil error.
func (mw databaseMetricsMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
	const op = "NewUser"

	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", op}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, op}, start)

		// Error counters are only touched for failed calls.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", op, "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, op, "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", op}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, op}, 1)

	resp, err = mw.next.NewUser(ctx, req)
	return resp, err
}
// UpdateUser forwards the request to the wrapped Database and records metrics
// for the call under both the generic key {"database", "UpdateUser"} and the
// per-type key {"database", typeStr, "UpdateUser"}: a counter incremented
// before delegating, a duration sample once the call returns, and an
// additional "error" counter when the call returns a non-nil error.
func (mw databaseMetricsMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) {
	const op = "UpdateUser"

	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", op}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, op}, start)

		// Error counters are only touched for failed calls.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", op, "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, op, "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", op}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, op}, 1)

	resp, err = mw.next.UpdateUser(ctx, req)
	return resp, err
}
// DeleteUser forwards the request to the wrapped Database and records metrics
// for the call under both the generic key {"database", "DeleteUser"} and the
// per-type key {"database", typeStr, "DeleteUser"}: a counter incremented
// before delegating, a duration sample once the call returns, and an
// additional "error" counter when the call returns a non-nil error.
func (mw databaseMetricsMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) {
	const op = "DeleteUser"

	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", op}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, op}, start)

		// Error counters are only touched for failed calls.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", op, "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, op, "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", op}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, op}, 1)

	resp, err = mw.next.DeleteUser(ctx, req)
	return resp, err
}
// Type returns whatever the wrapped Database reports as its type. No metrics
// are emitted for this call.
func (mw databaseMetricsMiddleware) Type() (string, error) {
	dbType, err := mw.next.Type()
	return dbType, err
}
// Close forwards to the wrapped Database and records metrics for the call
// under both the generic key {"database", "Close"} and the per-type key
// {"database", typeStr, "Close"}: a counter incremented before delegating, a
// duration sample once the call returns, and an additional "error" counter
// when the call returns a non-nil error.
func (mw databaseMetricsMiddleware) Close() (err error) {
	const op = "Close"

	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", op}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, op}, start)

		// Error counters are only touched for failed calls.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", op, "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, op, "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", op}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, op}, 1)

	err = mw.next.Close()
	return err
}
// ///////////////////////////////////////////////////
// Error Sanitizer Middleware Domain
// ///////////////////////////////////////////////////
// Compile-time assertion that DatabaseErrorSanitizerMiddleware satisfies Database.
var _ Database = DatabaseErrorSanitizerMiddleware{}
// DatabaseErrorSanitizerMiddleware wraps an implementation of Database and
// sanitizes the error returned by every call before handing it back to the
// caller.
type DatabaseErrorSanitizerMiddleware struct {
// next is the wrapped Database that every call is delegated to.
next Database
// secretsFn supplies the find/replace pairs applied to error messages. It may
// be nil, in which case no string replacement is performed.
secretsFn secretsFn
}
// secretsFn returns the strings to redact from error messages: each map key is
// a substring to search for and its value is the text substituted in its place.
type secretsFn func() map[string]string
// NewDatabaseErrorSanitizerMiddleware returns a DatabaseErrorSanitizerMiddleware
// that delegates to next and uses secrets to obtain the find/replace pairs
// applied to returned error messages. Both arguments are stored as given;
// secrets may be nil.
func NewDatabaseErrorSanitizerMiddleware(next Database, secrets secretsFn) DatabaseErrorSanitizerMiddleware {
	var mw DatabaseErrorSanitizerMiddleware
	mw.next = next
	mw.secretsFn = secrets
	return mw
}
// Initialize forwards the request to the wrapped Database and returns its
// response unchanged together with the sanitized form of its error.
func (mw DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
	resp, err = mw.next.Initialize(ctx, req)
	err = mw.sanitize(err)
	return resp, err
}
// NewUser forwards the request to the wrapped Database and returns its
// response unchanged together with the sanitized form of its error.
func (mw DatabaseErrorSanitizerMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
	resp, err = mw.next.NewUser(ctx, req)
	err = mw.sanitize(err)
	return resp, err
}
// UpdateUser forwards the request to the wrapped Database and returns its
// response unchanged together with the sanitized form of its error.
func (mw DatabaseErrorSanitizerMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
	out, callErr := mw.next.UpdateUser(ctx, req)
	callErr = mw.sanitize(callErr)
	return out, callErr
}
// DeleteUser forwards the request to the wrapped Database and returns its
// response unchanged together with the sanitized form of its error.
func (mw DatabaseErrorSanitizerMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
	out, callErr := mw.next.DeleteUser(ctx, req)
	callErr = mw.sanitize(callErr)
	return out, callErr
}
// Type returns the type string reported by the wrapped Database together with
// the sanitized form of its error.
func (mw DatabaseErrorSanitizerMiddleware) Type() (string, error) {
	name, callErr := mw.next.Type()
	callErr = mw.sanitize(callErr)
	return name, callErr
}
// Close forwards to the wrapped Database and returns the sanitized form of its
// error.
func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) {
	err = mw.next.Close()
	return mw.sanitize(err)
}
// sanitize returns a copy of err with sensitive strings removed from its
// message. The strings to remove, and what to put in their place, come from
// secretsFn (map key = text to find, map value = replacement).
//
// Behavior:
//   - A nil err yields nil.
//   - If err contains a *url.Error, a fixed "unable to parse connection url"
//     error is returned instead of the original message.
//   - If secretsFn is nil, or it yields no non-empty keys, err is returned
//     unchanged.
//   - Otherwise a new error is built from the redacted message. When err
//     carries a gRPC status, the status code is kept on the new error; if not,
//     a plain error is returned.
//
// Replacements are applied longest key first (ties broken lexically). Map
// iteration order is unspecified in Go, so without an explicit order a short
// secret that is a substring of a longer one could be replaced first and leave
// the remainder of the longer secret visible in the message.
func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
	if err == nil {
		return nil
	}
	if errwrap.ContainsType(err, new(url.Error)) {
		return errors.New("unable to parse connection url")
	}
	if mw.secretsFn == nil {
		return err
	}

	secrets := mw.secretsFn()

	// Collect the usable keys; an empty search string is skipped.
	finds := make([]string, 0, len(secrets))
	for find := range secrets {
		if find != "" {
			finds = append(finds, find)
		}
	}
	if len(finds) == 0 {
		// Nothing to redact: hand back the original error untouched.
		return err
	}

	sort.Slice(finds, func(i, j int) bool {
		if len(finds[i]) != len(finds[j]) {
			return len(finds[i]) > len(finds[j])
		}
		return finds[i] < finds[j]
	})

	// Attempt to keep the status code attached to the error while changing
	// the actual error message.
	s, isStatus := status.FromError(err)
	msg := err.Error()
	if isStatus {
		msg = s.Message()
	}

	for _, find := range finds {
		msg = strings.ReplaceAll(msg, find, secrets[find])
	}

	if isStatus {
		return status.Error(s.Code(), msg)
	}
	return errors.New(msg)
}