From e77f79b317d412de7b2aa8db0c6775f696fdabc7 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 17 Mar 2015 17:15:23 -0700 Subject: [PATCH] logical/framework: rollback support --- logical/framework/backend.go | 47 +++++++++++++++++++++++++++++++ logical/framework/backend_test.go | 33 ++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/logical/framework/backend.go b/logical/framework/backend.go index 7aad4eda2c..70c09f477b 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -8,6 +8,7 @@ import ( "sync" "text/template" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/logical" "github.com/mitchellh/go-wordwrap" ) @@ -30,6 +31,11 @@ type Backend struct { // For prefix match, append '*' as a suffix. PathsRoot []string + // Rollback is called when a WAL entry (see wal.go) has to be rolled + // back. It is called with the data from the entry. Boolean true should + // be returned on success. Errors should just be logged. + Rollback func(data interface{}) bool + once sync.Once pathsRe []*regexp.Regexp } @@ -82,6 +88,12 @@ type OperationFunc func(*logical.Request, *FieldData) (*logical.Response, error) // logical.Backend impl. func (b *Backend) HandleRequest(req *logical.Request) (*logical.Response, error) { + // Rollbacks are treated outside of the normal request cycle since + // the path doesn't matter for them. + if req.Operation == logical.RollbackOperation { + return b.handleRollback(req) + } + // Find the matching route path, captures := b.route(req.Path) if path == nil { @@ -167,6 +179,41 @@ func (b *Backend) route(path string) (*Path, map[string]string) { return nil, nil } +func (b *Backend) handleRollback( + req *logical.Request) (*logical.Response, error) { + if b.Rollback == nil { + return nil, logical.ErrUnsupportedOperation + } + + var merr error + keys, err := ListWAL(req.Storage) + if err != nil { + merr = multierror.Append(merr, err) + goto RESPOND_ROLLBACK + } + + for _, k := range keys { + data, err := GetWAL(req.Storage, k) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + + if b.Rollback(data) { + if err := DeleteWAL(req.Storage, k); err != nil { + merr = multierror.Append(merr, err) + } + } + } + +RESPOND_ROLLBACK: + if merr == nil { + return nil, nil + } + + return logical.ErrorResponse(merr.Error()), nil +} + func (p *Path) helpCallback(req *logical.Request, data *FieldData) (*logical.Response, error) { var tplData pathTemplateData tplData.Request = req.Path diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index c6cae67ef0..ef94121e0b 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -2,6 +2,7 @@ package framework import ( "reflect" + "sync/atomic" "testing" "github.com/hashicorp/vault/logical" @@ -135,6 +136,38 @@ func TestBackendHandleRequest_help(t *testing.T) { } } +func TestBackendHandleRequest_rollback(t *testing.T) { + var called uint32 + callback := func(data interface{}) bool { + if data == "foo" { + atomic.AddUint32(&called, 1) + } + + return true + } + + b := &Backend{ + Rollback: callback, + } + + storage := new(logical.InmemStorage) + if _, err := PutWAL(storage, "foo"); err != nil { + t.Fatalf("err: %s", err) + } + + _, err := b.HandleRequest(&logical.Request{ + Operation: logical.RollbackOperation, + Path: "", + Storage: storage, + }) + if err != nil { + t.Fatalf("err: %s", err) + } + if v := atomic.LoadUint32(&called); v != 1 { + t.Fatalf("bad: %#v", v) + } +} + func TestBackendHandleRequest_unsupportedOperation(t *testing.T) { callback := func(req *logical.Request, data *FieldData) (*logical.Response, error) { return &logical.Response{