From 4e22153987cf2f9064a1d307d46f440595afc975 Mon Sep 17 00:00:00 2001 From: Marc Boudreau Date: Fri, 13 Oct 2023 14:04:26 -0400 Subject: [PATCH] VAULT-19869: Use Custom Types for Context Keys (#23649) * create custom type for disable-replication-status-endpoints context key make use of custom context key type in middleware function * clean up code to remove various compiler warnings unnecessary return statement if condition that is always true fix use of deprecated ioutil.NopCloser empty if block * remove unused unexported function * clean up code remove unnecessary nil check around a range expression * clean up code removed redundant return statement * use http.StatusTemporaryRedirect constant instead of literal integer * create custom type for context key for max_request_size parameter * create custom type for context key for original request path --- helper/forwarding/util.go | 7 ++-- http/handler.go | 65 ++++++++++++------------------------- http/util.go | 2 +- sdk/logical/request.go | 30 +++++++++++++++++ vault/cluster_test.go | 2 +- vault/request_forwarding.go | 3 +- 6 files changed, 59 insertions(+), 50 deletions(-) diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index 74ea5abe1d..a712c11885 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -7,7 +7,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" - "errors" + "fmt" "io" "io/ioutil" "net/http" @@ -17,6 +17,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/hashicorp/vault/sdk/helper/compressutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" + "github.com/hashicorp/vault/sdk/logical" ) type bufCloser struct { @@ -64,11 +65,11 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request func GenerateForwardedRequest(req *http.Request) (*Request, error) { var reader io.Reader = req.Body ctx := req.Context() - maxRequestSize := ctx.Value("max_request_size") + maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{}) if maxRequestSize != nil { max, ok := maxRequestSize.(int64) if !ok { - return nil, errors.New("could not parse max_request_size from request context") + return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{}) } if max > 0 { reader = io.LimitReader(req.Body, max) diff --git a/http/handler.go b/http/handler.go index 731b3ba804..43a92aa9e5 100644 --- a/http/handler.go +++ b/http/handler.go @@ -300,9 +300,9 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { respondError(w, status, err) return } - if origBody != nil { - r.Body = ioutil.NopCloser(origBody) - } + + r.Body = io.NopCloser(origBody) + input := &logical.LogInput{ Request: req, } @@ -314,17 +314,16 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { cw := newCopyResponseWriter(w) h.ServeHTTP(cw, r) data := make(map[string]interface{}) - err = jsonutil.DecodeJSON(cw.body.Bytes(), &data) - if err != nil { - // best effort, ignore - } + + // Refactoring this code, since the returned error was being ignored. + jsonutil.DecodeJSON(cw.body.Bytes(), &data) + httpResp := &logical.HTTPResponse{Data: data, Headers: cw.Header()} input.Response = logical.HTTPResponseToLogicalResponse(httpResp) err = core.AuditLogger().AuditResponse(r.Context(), input) if err != nil { respondError(w, status, err) } - return }) } @@ -382,9 +381,9 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr // if maxRequestSize < 0, no need to set context value // Add a size limiter if desired if maxRequestSize > 0 { - ctx = context.WithValue(ctx, "max_request_size", maxRequestSize) + ctx = context.WithValue(ctx, logical.CtxKeyMaxRequestSize{}, maxRequestSize) } - ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) + ctx = context.WithValue(ctx, logical.CtxKeyOriginalRequestPath{}, r.URL.Path) r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -465,7 +464,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr h.ServeHTTP(nw, r) cancelFunc() - return }) } @@ -557,25 +555,9 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port) h.ServeHTTP(w, r) - return }) } -// stripPrefix is a helper to strip a prefix from the path. It will -// return false from the second return value if it the prefix doesn't exist. -func stripPrefix(prefix, path string) (string, bool) { - if !strings.HasPrefix(path, prefix) { - return "", false - } - - path = path[len(prefix):] - if path == "" { - return "", false - } - - return path, true -} - func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { header := w.Header() @@ -585,12 +567,12 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { respondError(w, http.StatusInternalServerError, err) return } - if userHeaders != nil { - for k := range userHeaders { - v := userHeaders.Get(k) - header.Set(k, v) - } + + for k := range userHeaders { + v := userHeaders.Get(k) + header.Set(k, v) } + h.ServeHTTP(w, req) }) } @@ -602,7 +584,6 @@ func handleUI(h http.Handler) http.Handler { // here. req.URL.Path = strings.TrimSuffix(req.URL.Path, "/") h.ServeHTTP(w, req) - return }) } @@ -680,8 +661,7 @@ func handleUIStub() http.Handler { func handleUIRedirect() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - http.Redirect(w, req, "/ui/", 307) - return + http.Redirect(w, req, "/ui/", http.StatusTemporaryRedirect) }) } @@ -730,11 +710,11 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // against an indefinite amount of data being read. reader := r.Body ctx := r.Context() - maxRequestSize := ctx.Value("max_request_size") + maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{}) if maxRequestSize != nil { max, ok := maxRequestSize.(int64) if !ok { - return nil, errors.New("could not parse max_request_size from request context") + return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{}) } if max > 0 { // MaxBytesReader won't do all the internal stuff it must unless it's @@ -769,11 +749,11 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // // A nil map will be returned if the format is empty or invalid. func parseFormRequest(r *http.Request) (map[string]interface{}, error) { - maxRequestSize := r.Context().Value("max_request_size") + maxRequestSize := r.Context().Value(logical.CtxKeyMaxRequestSize{}) if maxRequestSize != nil { max, ok := maxRequestSize.(int64) if !ok { - return nil, errors.New("could not parse max_request_size from request context") + return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{}) } if max > 0 { r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max)) @@ -886,7 +866,6 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle } forwardRequest(core, w, r) - return }) } @@ -930,10 +909,8 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { return } - if header != nil { - for k, v := range header { - w.Header()[k] = v - } + for k, v := range header { + w.Header()[k] = v } w.WriteHeader(statusCode) diff --git a/http/util.go b/http/util.go index bbf49951dc..2b0d7cc0d2 100644 --- a/http/util.go +++ b/http/util.go @@ -135,7 +135,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler func disableReplicationStatusEndpointWrapping(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - request := r.WithContext(context.WithValue(r.Context(), "disable_replication_status_endpoints", true)) + request := r.WithContext(context.WithValue(r.Context(), logical.CtxKeyDisableReplicationStatusEndpoints{}, true)) h.ServeHTTP(w, request) }) diff --git a/sdk/logical/request.go b/sdk/logical/request.go index 39d5bbe262..bb42db7dd3 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -453,3 +453,33 @@ type CtxKeyRequestRole struct{} func (c CtxKeyRequestRole) String() string { return "request-role" } + +// CtxKeyDisableReplicationStatusEndpoints is a custom type used as a key in +// context.Context to store the value `true` when the +// disable_replication_status_endpoints configuration parameter is set to true +// for the listener through which a request was received. +type CtxKeyDisableReplicationStatusEndpoints struct{} + +// String returns a string representation of the receiver type. +func (c CtxKeyDisableReplicationStatusEndpoints) String() string { + return "disable-replication-status-endpoints" +} + +// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to +// store the value of the max_request_size set for the listener through which +// a request was received. +type CtxKeyMaxRequestSize struct{} + +// String returns a string representation of the receiver type. +func (c CtxKeyMaxRequestSize) String() string { + return "max_request_size" +} + +// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context +// to store the original request path. +type CtxKeyOriginalRequestPath struct{} + +// String returns a string representation of the receiver type. +func (c CtxKeyOriginalRequestPath) String() string { + return "original_request_path" +} diff --git a/vault/cluster_test.go b/vault/cluster_test.go index 8e56909af8..c890582bea 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -338,7 +338,7 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, re t.Fatal(err) } req.Header.Add(consts.AuthHeaderName, rootToken) - req = req.WithContext(context.WithValue(req.Context(), "original_request_path", req.URL.Path)) + req = req.WithContext(context.WithValue(req.Context(), logical.CtxKeyOriginalRequestPath{}, req.URL.Path)) statusCode, header, respBytes, err := c.ForwardRequest(req) if err != nil { diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 1532ebe31e..440de62b34 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -21,6 +21,7 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/forwarding" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/replication" "golang.org/x/net/http2" @@ -349,7 +350,7 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro req.URL.Path = origPath }() - req.URL.Path = req.Context().Value("original_request_path").(string) + req.URL.Path = req.Context().Value(logical.CtxKeyOriginalRequestPath{}).(string) freq, err := forwarding.GenerateForwardedRequest(req) if err != nil {