diff --git a/http/logical.go b/http/logical.go index 33cdb6538e..5340cc6ecc 100644 --- a/http/logical.go +++ b/http/logical.go @@ -155,7 +155,7 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl return } - if resp.WrapInfo.Token != "" { + if resp.WrapInfo != nil && resp.WrapInfo.Token != "" { httpResp = logical.HTTPResponse{ WrapInfo: &logical.HTTPWrapInfo{ Token: resp.WrapInfo.Token, diff --git a/logical/response.go b/logical/response.go index 6c310d991c..88dd9934aa 100644 --- a/logical/response.go +++ b/logical/response.go @@ -66,7 +66,7 @@ type Response struct { warnings []string // Information for wrapping the response in a cubbyhole - WrapInfo WrapInfo + WrapInfo *WrapInfo } func init() { diff --git a/vault/request_handling.go b/vault/request_handling.go index 5ae6366564..bc4d9bd790 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -78,7 +78,7 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err // We are wrapping if there is anything to wrap (not a nil response) and a // TTL was specified for the token, plus if cubbyhole is mounted (which // will be the case normally) - wrapping := cubbyholeMounted && resp != nil && resp.WrapInfo.TTL != 0 + wrapping := cubbyholeMounted && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.TTL != 0 // If we are wrapping, the first part happens before auditing so that // resp.WrapInfo.Token can contain the HMAC'd wrapping token ID in the diff --git a/vault/router.go b/vault/router.go index e0ed222b3f..70dd94533a 100644 --- a/vault/router.go +++ b/vault/router.go @@ -256,15 +256,25 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l // If either of the request or response requested wrapping, ensure that // the lowest value is what ends up in the response. switch { - case req.WrapTTL == 0 && resp.WrapInfo.TTL == 0: - case req.WrapTTL != 0 && resp.WrapInfo.TTL != 0: + case req.WrapTTL == 0 && (resp.WrapInfo == nil || resp.WrapInfo.TTL == 0): + // Neither defines it, so do nothing + + case req.WrapTTL != 0 && (resp.WrapInfo != nil && resp.WrapInfo.TTL != 0): + // Both define, so use the lowest if req.WrapTTL < resp.WrapInfo.TTL { resp.WrapInfo.TTL = req.WrapTTL } + case req.WrapTTL != 0: - resp.WrapInfo.TTL = req.WrapTTL - // Only case left is that only resp defines it, which doesn't need to - // be explicitly handled + // Response wrap info doesn't exist, or its TTL is zero, so set + // it to the request TTL + resp.WrapInfo = &logical.WrapInfo{ + TTL: req.WrapTTL, + } + + default: + // Only case left is that only resp defines it, which doesn't + // need to be explicitly handled } } diff --git a/vault/router_test.go b/vault/router_test.go index 6f73eee02a..feaa4648b8 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -39,7 +39,7 @@ func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, er } if n.WrapTTL != 0 { - n.Response.WrapInfo.TTL = n.WrapTTL + n.Response.WrapInfo = &logical.WrapInfo{TTL: n.WrapTTL} } return n.Response, nil @@ -432,7 +432,7 @@ func TestRouter_Wrapping(t *testing.T) { if resp == nil { t.Fatalf("bad: %v", resp) } - if resp.WrapInfo.TTL != time.Duration(15*time.Second) { + if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) { t.Fatalf("bad: %#v", resp) } @@ -450,7 +450,7 @@ func TestRouter_Wrapping(t *testing.T) { if resp == nil { t.Fatalf("bad: %v", resp) } - if resp.WrapInfo.TTL != time.Duration(15*time.Second) { + if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) { t.Fatalf("bad: %#v", resp) } @@ -469,7 +469,7 @@ func TestRouter_Wrapping(t *testing.T) { if resp == nil { t.Fatalf("bad: %v", resp) } - if resp.WrapInfo.TTL != time.Duration(10*time.Second) { + if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) { t.Fatalf("bad: %#v", resp) } @@ -488,7 +488,7 @@ func TestRouter_Wrapping(t *testing.T) { if resp == nil { t.Fatalf("bad: %v", resp) } - if resp.WrapInfo.TTL != time.Duration(10*time.Second) { + if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) { t.Fatalf("bad: %#v", resp) } }