diff --git a/http/handler.go b/http/handler.go index efaf3375a2..e6ec8b9735 100644 --- a/http/handler.go +++ b/http/handler.go @@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) { status = http.StatusServiceUnavailable } + // Allow HTTPCoded error passthrough to specify a code + if t, ok := err.(logical.HTTPCodedError); ok { + status = t.Code() + } + w.Header().Add("Content-Type", "application/json") w.WriteHeader(status) diff --git a/http/handler_test.go b/http/handler_test.go index a38771c31d..266c09a0af 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -1,10 +1,13 @@ package http import ( + "errors" "net/http" + "net/http/httptest" "reflect" "testing" + "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" ) @@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) { } testResponseStatus(t, resp, 503) } + +func TestHandler_error(t *testing.T) { + w := httptest.NewRecorder() + + respondError(w, 500, errors.New("Test Error")) + + if w.Code != 500 { + t.Fatalf("expected 500, got %d", w.Code) + } + + // The code inside of the error should override + // the argument to respondError + w2 := httptest.NewRecorder() + e := logical.CodedError(403, "error text") + + respondError(w2, 500, e) + + if w2.Code != 403 { + t.Fatalf("expected 403, got %d", w2.Code) + } + + // vault.ErrSealed is a special case + w3 := httptest.NewRecorder() + + respondError(w3, 400, vault.ErrSealed) + + if w3.Code != 503 { + t.Fatalf("expected 503, got %d", w3.Code) + } + +} diff --git a/http/sys_mount.go b/http/sys_mount.go index 61a53f0738..1d3264b1e9 100644 --- a/http/sys_mount.go +++ b/http/sys_mount.go @@ -131,6 +131,7 @@ func handleSysMount( "description": req.Description, }, })) + if err != nil { respondError(w, http.StatusInternalServerError, err) return @@ -149,6 +150,7 @@ func handleSysUnmount( Path: "sys/mounts/" + path, Connection: getConnection(r), })) + if err != nil { respondError(w, http.StatusInternalServerError, err) return diff --git a/logical/error.go b/logical/error.go new file mode 100644 index 0000000000..e32ef007be --- /dev/null +++ b/logical/error.go @@ -0,0 +1,24 @@ +package logical + +type HTTPCodedError interface { + Error() string + Code() int +} + +func CodedError(c int, s string) HTTPCodedError { + return &codedError{s,c} +} + +type codedError struct { + s string + code int +} + +func (e *codedError) Error() string { + return e.s +} + +func (e *codedError) Code() int { + return e.code +} + diff --git a/logical/request.go b/logical/request.go index 92a69903bf..0b123c8119 100644 --- a/logical/request.go +++ b/logical/request.go @@ -145,6 +145,6 @@ var ( // ErrInvalidRequest is returned if the request is invalid ErrInvalidRequest = errors.New("invalid request") - // ErrPermissionDeneid is returned if the client is not authorized + // ErrPermissionDenied is returned if the client is not authorized ErrPermissionDenied = errors.New("permission denied") ) diff --git a/logical/response.go b/logical/response.go index d87a8f7c96..8af0b3de51 100644 --- a/logical/response.go +++ b/logical/response.go @@ -13,7 +13,7 @@ const ( // avoided like the HTTPContentType. The value must be a byte slice. HTTPRawBody = "http_raw_body" - // HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType. + // HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType. // This can only be specified for non-secrets, and should should be similarly // avoided like the HTTPContentType. The value must be an integer. HTTPStatusCode = "http_status_code" diff --git a/vault/logical_system.go b/vault/logical_system.go index 12d521b4a6..a428f10598 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -371,9 +371,21 @@ func (b *SystemBackend) handleMount( // Attempt mount if err := b.Core.mount(me); err != nil { b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err) + return handleError(err) + } + + return nil, nil +} + +// used to intercept an HTTPCodedError so it goes back to callee +func handleError( + err error) (*logical.Response, error) { + switch err.(type) { + case logical.HTTPCodedError: + return logical.ErrorResponse(err.Error()), err + default: return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest } - return nil, nil } // handleUnmount is used to unmount a path @@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount( // Attempt unmount if err := b.Core.unmount(suffix); err != nil { b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil @@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount( // Attempt remount if err := b.Core.remount(fromPath, toPath); err != nil { b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil @@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew( resp, err := b.Core.expiration.Renew(leaseID, increment) if err != nil { b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return resp, err } @@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke( // Invoke the expiration manager directly if err := b.Core.expiration.Revoke(leaseID); err != nil { b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix( // Invoke the expiration manager directly if err := b.Core.expiration.RevokePrefix(prefix); err != nil { b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth( // Attempt enabling if err := b.Core.enableCredential(me); err != nil { b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth( // Attempt disable if err := b.Core.disableCredential(suffix); err != nil { b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead( policy, err := b.Core.policy.GetPolicy(name) if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } if policy == nil { @@ -567,7 +579,7 @@ func (b *SystemBackend) handlePolicySet( // Validate the rules parse parse, err := Parse(rules) if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } // Override the name @@ -575,7 +587,7 @@ func (b *SystemBackend) handlePolicySet( // Update the policy if err := b.Core.policy.SetPolicy(parse); err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if err := b.Core.policy.DeletePolicy(name); err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit( // Attempt enabling if err := b.Core.enableAudit(me); err != nil { b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit( // Attempt disable if err := b.Core.disableAudit(path); err != nil { b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead( entry, err := b.Core.barrier.Get(path) if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } if entry == nil { return nil, nil @@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete( } if err := b.Core.barrier.Delete(path); err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } return nil, nil } @@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate( newTerm, err := b.Core.barrier.Rotate() if err != nil { b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err) - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + return handleError(err) } b.Backend.Logger().Printf("[INFO] sys: installed new encryption key") diff --git a/vault/mount.go b/vault/mount.go index 3a85845252..b5c768b580 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -22,7 +22,7 @@ const ( // barrier view for the backends. backendBarrierPrefix = "logical/" - // systemBarrierPrefix is sthe prefix used for the + // systemBarrierPrefix is the prefix used for the // system logical backend. systemBarrierPrefix = "sys/" ) @@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error { me.Path += "/" } - // Prevent protected paths from being unmounted + // Prevent protected paths from being mounted for _, p := range protectedMounts { if strings.HasPrefix(me.Path, p) { - return fmt.Errorf("cannot mount '%s'", me.Path) + return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path)) } } // Verify there is no conflicting mount if match := c.router.MatchingMount(me.Path); match != "" { - return fmt.Errorf("existing mount at '%s'", match) + return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match)) } // Generate a new UUID and view