diff --git a/vault/core.go b/vault/core.go index 755b368303..8cbfdfdf53 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1193,9 +1193,15 @@ func (c *Core) Seal(token string) error { func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr error) { defer metrics.MeasureSince([]string{"core", "seal-internal"}, time.Now()) + var unlocked bool + defer func() { + if !unlocked { + c.stateLock.RUnlock() + } + }() + if req == nil { retErr = multierror.Append(retErr, errors.New("nil request to seal")) - c.stateLock.RUnlock() return retErr } @@ -1207,14 +1213,12 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr if c.standby { c.logger.Error("vault cannot seal when in standby mode; please restart instead") retErr = multierror.Append(retErr, errors.New("vault cannot seal when in standby mode; please restart instead")) - c.stateLock.RUnlock() return retErr } acl, te, entity, identityPolicies, err := c.fetchACLTokenEntryAndEntity(ctx, req) if err != nil { retErr = multierror.Append(retErr, err) - c.stateLock.RUnlock() return retErr } @@ -1242,20 +1246,17 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { c.logger.Error("failed to audit request", "request_path", req.Path, "error", err) retErr = multierror.Append(retErr, errors.New("failed to audit request, cannot continue")) - c.stateLock.RUnlock() return retErr } if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } if te != nil && te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } @@ -1266,13 +1267,11 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr if err != nil { c.logger.Error("failed to use token", "error", err) retErr = multierror.Append(retErr, ErrInternalError) - c.stateLock.RUnlock() return retErr } if te == nil { // Token is no longer valid retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } } @@ -1282,7 +1281,6 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr RootPrivsRequired: true, }) if !authResults.Allowed { - c.stateLock.RUnlock() retErr = multierror.Append(retErr, authResults.Error) if authResults.Error.ErrorOrNil() == nil || authResults.DeniedError { retErr = multierror.Append(retErr, logical.ErrPermissionDenied) @@ -1304,6 +1302,7 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr } // Unlock; sealing will grab the lock when needed + unlocked = true c.stateLock.RUnlock() sealErr := c.sealInternal() diff --git a/vault/external_tests/misc/recover_from_panic_test.go b/vault/external_tests/misc/recover_from_panic_test.go new file mode 100644 index 0000000000..157a4e9694 --- /dev/null +++ b/vault/external_tests/misc/recover_from_panic_test.go @@ -0,0 +1,49 @@ +package token + +import ( + "testing" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" +) + +// Tests the regression in +// https://github.com/hashicorp/vault/pull/6920 +func TestRecoverFromPanic(t *testing.T) { + logger := hclog.New(nil) + + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "noop": vault.NoopBackendFactory, + }, + EnableRaw: true, + Logger: logger, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0] + vault.TestWaitActive(t, core.Core) + client := core.Client + + err := client.Sys().Mount("noop", &api.MountInput{ + Type: "noop", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Read("noop/panic") + if err == nil { + t.Fatal("expected error") + } + + // This will deadlock the test if we hit the condition + cluster.EnsureCoresSealed(t) +} diff --git a/vault/ha.go b/vault/ha.go index 39fb74d48f..fd2ca1b486 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -207,6 +207,7 @@ func (c *Core) StepDown(httpCtx context.Context, req *logical.Request) (retErr e c.stateLock.RLock() defer c.stateLock.RUnlock() + if c.Sealed() { return nil } @@ -261,14 +262,12 @@ func (c *Core) StepDown(httpCtx context.Context, req *logical.Request) (retErr e if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } if te != nil && te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } diff --git a/vault/request_handling.go b/vault/request_handling.go index 41aac1a309..6d45d8e163 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -385,18 +385,12 @@ func (c *Core) HandleRequest(httpCtx context.Context, req *logical.Request) (res func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.Request, doLocking bool) (resp *logical.Response, err error) { if doLocking { c.stateLock.RLock() - } - unlockFunc := func() { - if doLocking { - c.stateLock.RUnlock() - } + defer c.stateLock.RUnlock() } if c.Sealed() { - unlockFunc() return nil, consts.ErrSealed } if c.standby && !c.perfStandby { - unlockFunc() return nil, consts.ErrStandby } @@ -412,7 +406,6 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R ns, err := namespace.FromContext(httpCtx) if err != nil { cancel() - unlockFunc() return nil, errwrap.Wrapf("could not parse namespace from http context: {{err}}", err) } ctx = namespace.ContextWithNamespace(ctx, ns) @@ -421,7 +414,6 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R req.SetTokenEntry(nil) cancel() - unlockFunc() return resp, err } diff --git a/vault/router_test.go b/vault/router_test.go index 29f9ae0a7f..b20b69894f 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -1,116 +1,15 @@ package vault import ( - "context" - "fmt" "reflect" "strings" - "sync" "testing" - "time" - log "github.com/hashicorp/go-hclog" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" ) -type HandlerFunc func(context.Context, *logical.Request) (*logical.Response, error) - -type NoopBackend struct { - sync.Mutex - - Root []string - Login []string - Paths []string - Requests []*logical.Request - Response *logical.Response - RequestHandler HandlerFunc - Invalidations []string - DefaultLeaseTTL time.Duration - MaxLeaseTTL time.Duration - BackendType logical.BackendType -} - -func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { - if req.TokenEntry() != nil { - panic("got a non-nil TokenEntry") - } - - var err error - resp := n.Response - if n.RequestHandler != nil { - resp, err = n.RequestHandler(ctx, req) - } - - n.Lock() - defer n.Unlock() - - requestCopy := *req - n.Paths = append(n.Paths, req.Path) - n.Requests = append(n.Requests, &requestCopy) - if req.Storage == nil { - return nil, fmt.Errorf("missing view") - } - - return resp, err -} - -func (n *NoopBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { - return false, false, nil -} - -func (n *NoopBackend) SpecialPaths() *logical.Paths { - return &logical.Paths{ - Root: n.Root, - Unauthenticated: n.Login, - } -} - -func (n *NoopBackend) System() logical.SystemView { - defaultLeaseTTLVal := time.Hour * 24 - maxLeaseTTLVal := time.Hour * 24 * 32 - if n.DefaultLeaseTTL > 0 { - defaultLeaseTTLVal = n.DefaultLeaseTTL - } - - if n.MaxLeaseTTL > 0 { - maxLeaseTTLVal = n.MaxLeaseTTL - } - - return logical.StaticSystemView{ - DefaultLeaseTTLVal: defaultLeaseTTLVal, - MaxLeaseTTLVal: maxLeaseTTLVal, - } -} - -func (n *NoopBackend) Cleanup(ctx context.Context) { - // noop -} - -func (n *NoopBackend) InvalidateKey(ctx context.Context, k string) { - n.Invalidations = append(n.Invalidations, k) -} - -func (n *NoopBackend) Setup(ctx context.Context, config *logical.BackendConfig) error { - return nil -} - -func (n *NoopBackend) Logger() log.Logger { - return log.NewNullLogger() -} - -func (n *NoopBackend) Initialize(ctx context.Context) error { - return nil -} - -func (n *NoopBackend) Type() logical.BackendType { - if n.BackendType == logical.TypeUnknown { - return logical.TypeLogical - } - return n.BackendType -} - func TestRouter_Mount(t *testing.T) { r := NewRouter() _, barrier, _ := mockBarrier(t) diff --git a/vault/router_testing.go b/vault/router_testing.go new file mode 100644 index 0000000000..bc287806ed --- /dev/null +++ b/vault/router_testing.go @@ -0,0 +1,115 @@ +package vault + +import ( + "context" + "fmt" + "sync" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/logical" +) + +type RouterTestHandlerFunc func(context.Context, *logical.Request) (*logical.Response, error) + +type NoopBackend struct { + sync.Mutex + + Root []string + Login []string + Paths []string + Requests []*logical.Request + Response *logical.Response + RequestHandler RouterTestHandlerFunc + Invalidations []string + DefaultLeaseTTL time.Duration + MaxLeaseTTL time.Duration + BackendType logical.BackendType +} + +func NoopBackendFactory(_ context.Context, _ *logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{}, nil +} + +func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { + if req.TokenEntry() != nil { + panic("got a non-nil TokenEntry") + } + + var err error + resp := n.Response + if n.RequestHandler != nil { + resp, err = n.RequestHandler(ctx, req) + } + + n.Lock() + defer n.Unlock() + + requestCopy := *req + n.Paths = append(n.Paths, req.Path) + n.Requests = append(n.Requests, &requestCopy) + if req.Storage == nil { + return nil, fmt.Errorf("missing view") + } + + if req.Path == "panic" { + panic("as you command") + } + + return resp, err +} + +func (n *NoopBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { + return false, false, nil +} + +func (n *NoopBackend) SpecialPaths() *logical.Paths { + return &logical.Paths{ + Root: n.Root, + Unauthenticated: n.Login, + } +} + +func (n *NoopBackend) System() logical.SystemView { + defaultLeaseTTLVal := time.Hour * 24 + maxLeaseTTLVal := time.Hour * 24 * 32 + if n.DefaultLeaseTTL > 0 { + defaultLeaseTTLVal = n.DefaultLeaseTTL + } + + if n.MaxLeaseTTL > 0 { + maxLeaseTTLVal = n.MaxLeaseTTL + } + + return logical.StaticSystemView{ + DefaultLeaseTTLVal: defaultLeaseTTLVal, + MaxLeaseTTLVal: maxLeaseTTLVal, + } +} + +func (n *NoopBackend) Cleanup(ctx context.Context) { + // noop +} + +func (n *NoopBackend) InvalidateKey(ctx context.Context, k string) { + n.Invalidations = append(n.Invalidations, k) +} + +func (n *NoopBackend) Setup(ctx context.Context, config *logical.BackendConfig) error { + return nil +} + +func (n *NoopBackend) Logger() log.Logger { + return log.NewNullLogger() +} + +func (n *NoopBackend) Initialize(ctx context.Context) error { + return nil +} + +func (n *NoopBackend) Type() logical.BackendType { + if n.BackendType == logical.TypeUnknown { + return logical.TypeLogical + } + return n.BackendType +}