diff --git a/vault/barrier_view.go b/vault/barrier_view.go index 4829b7b0d1..18ebb70118 100644 --- a/vault/barrier_view.go +++ b/vault/barrier_view.go @@ -108,3 +108,18 @@ func ScanView(view *BarrierView, cb func(path string)) error { } return nil } + +// CollectKeys is used to collect all the keys in a view +func CollectKeys(view *BarrierView) ([]string, error) { + // Accumulate the keys + var existing []string + cb := func(path string) { + existing = append(existing, path) + } + + // Scan for all the keys + if err := ScanView(view, cb); err != nil { + return nil, err + } + return existing, nil +} diff --git a/vault/barrier_view_test.go b/vault/barrier_view_test.go index 71fe94e23c..f3ec596dbf 100644 --- a/vault/barrier_view_test.go +++ b/vault/barrier_view_test.go @@ -183,3 +183,37 @@ func TestBarrierView_Scan(t *testing.T) { t.Fatalf("out: %v expect: %v", out, expect) } } + +func TestBarrierView_CollectKeys(t *testing.T) { + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "view/") + + expect := []string{} + ent := []*logical.StorageEntry{ + &logical.StorageEntry{Key: "foo", Value: []byte("test")}, + &logical.StorageEntry{Key: "zip", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/bar", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/zap", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/bar/baz", Value: []byte("test")}, + &logical.StorageEntry{Key: "foo/bar/zoo", Value: []byte("test")}, + } + + for _, e := range ent { + expect = append(expect, e.Key) + if err := view.Put(e); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Collect the keys + out, err := CollectKeys(view) + if err != nil { + t.Fatalf("err: %v", err) + } + + sort.Strings(out) + sort.Strings(expect) + if !reflect.DeepEqual(out, expect) { + t.Fatalf("out: %v expect: %v", out, expect) + } +} diff --git a/vault/expiration.go b/vault/expiration.go index 09e1bcbd9b..0f69f76c90 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -90,13 +90,8 @@ func (m *ExpirationManager) Restore() error { defer m.pendingLock.Unlock() // Accumulate existing leases - var existing []string - cb := func(path string) { - existing = append(existing, path) - } - - // Scan for all the leases - if err := ScanView(m.view, cb); err != nil { + existing, err := CollectKeys(m.view) + if err != nil { return fmt.Errorf("failed to scan for leases: %v", err) } @@ -184,14 +179,9 @@ func (m *ExpirationManager) RevokePrefix(prefix string) error { } // Accumulate existing leases - var existing []string - cb := func(path string) { - existing = append(existing, path) - } - - // Scan for all the leases in the prefix sub := m.view.SubView(prefix) - if err := ScanView(sub, cb); err != nil { + existing, err := CollectKeys(sub) + if err != nil { return fmt.Errorf("failed to scan for leases: %v", err) } diff --git a/vault/policy.go b/vault/policy.go index f0f335feb0..178623cfa1 100644 --- a/vault/policy.go +++ b/vault/policy.go @@ -27,6 +27,7 @@ var ( type Policy struct { Name string `hcl:"name"` Paths []*PathPolicy `hcl:"path,expand"` + Raw string } // PathPolicy represents a policy for a path in the namespace @@ -40,7 +41,7 @@ type PathPolicy struct { // the ACL func Parse(rules string) (*Policy, error) { // Decode the rules - p := &Policy{} + p := &Policy{Raw: rules} if err := hcl.Decode(p, rules); err != nil { return nil, fmt.Errorf("Failed to parse ACL rules: %v", err) }