VAULT-19869: Use Custom Types for Context Keys (#23649)

* create custom type for disable-replication-status-endpoints context key
make use of custom context key type in middleware function

* clean up code to remove various compiler warnings
unnecessary return statement
if condition that is always true
fix use of deprecated ioutil.NopCloser
empty if block

* remove unused unexported function

* clean up code
remove unnecessary nil check around a range expression

* clean up code
removed redundant return statement

* use http.StatusTemporaryRedirect constant instead of literal integer

* create custom type for context key for max_request_size parameter

* create custom type for context key for original request path
This commit is contained in:
Marc Boudreau 2023-10-13 14:04:26 -04:00 committed by GitHub
parent 67d743e273
commit 4e22153987
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 50 deletions

View File

@ -7,7 +7,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
@ -17,6 +17,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/hashicorp/vault/sdk/helper/compressutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical"
)
type bufCloser struct {
@ -64,11 +65,11 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
var reader io.Reader = req.Body
ctx := req.Context()
maxRequestSize := ctx.Value("max_request_size")
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
}
if max > 0 {
reader = io.LimitReader(req.Body, max)

View File

@ -300,9 +300,9 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
respondError(w, status, err)
return
}
if origBody != nil {
r.Body = ioutil.NopCloser(origBody)
}
r.Body = io.NopCloser(origBody)
input := &logical.LogInput{
Request: req,
}
@ -314,17 +314,16 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
cw := newCopyResponseWriter(w)
h.ServeHTTP(cw, r)
data := make(map[string]interface{})
err = jsonutil.DecodeJSON(cw.body.Bytes(), &data)
if err != nil {
// best effort, ignore
}
// Refactoring this code, since the returned error was being ignored.
jsonutil.DecodeJSON(cw.body.Bytes(), &data)
httpResp := &logical.HTTPResponse{Data: data, Headers: cw.Header()}
input.Response = logical.HTTPResponseToLogicalResponse(httpResp)
err = core.AuditLogger().AuditResponse(r.Context(), input)
if err != nil {
respondError(w, status, err)
}
return
})
}
@ -382,9 +381,9 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
// if maxRequestSize < 0, no need to set context value
// Add a size limiter if desired
if maxRequestSize > 0 {
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
ctx = context.WithValue(ctx, logical.CtxKeyMaxRequestSize{}, maxRequestSize)
}
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
ctx = context.WithValue(ctx, logical.CtxKeyOriginalRequestPath{}, r.URL.Path)
r = r.WithContext(ctx)
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
@ -465,7 +464,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
h.ServeHTTP(nw, r)
cancelFunc()
return
})
}
@ -557,25 +555,9 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle
r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port)
h.ServeHTTP(w, r)
return
})
}
// stripPrefix is a helper to strip a prefix from the path. It will
// return false from the second return value if it the prefix doesn't exist.
func stripPrefix(prefix, path string) (string, bool) {
if !strings.HasPrefix(path, prefix) {
return "", false
}
path = path[len(prefix):]
if path == "" {
return "", false
}
return path, true
}
func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
header := w.Header()
@ -585,12 +567,12 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler {
respondError(w, http.StatusInternalServerError, err)
return
}
if userHeaders != nil {
for k := range userHeaders {
v := userHeaders.Get(k)
header.Set(k, v)
}
for k := range userHeaders {
v := userHeaders.Get(k)
header.Set(k, v)
}
h.ServeHTTP(w, req)
})
}
@ -602,7 +584,6 @@ func handleUI(h http.Handler) http.Handler {
// here.
req.URL.Path = strings.TrimSuffix(req.URL.Path, "/")
h.ServeHTTP(w, req)
return
})
}
@ -680,8 +661,7 @@ func handleUIStub() http.Handler {
func handleUIRedirect() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, "/ui/", 307)
return
http.Redirect(w, req, "/ui/", http.StatusTemporaryRedirect)
})
}
@ -730,11 +710,11 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
// against an indefinite amount of data being read.
reader := r.Body
ctx := r.Context()
maxRequestSize := ctx.Value("max_request_size")
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
}
if max > 0 {
// MaxBytesReader won't do all the internal stuff it must unless it's
@ -769,11 +749,11 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
//
// A nil map will be returned if the format is empty or invalid.
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
maxRequestSize := r.Context().Value("max_request_size")
maxRequestSize := r.Context().Value(logical.CtxKeyMaxRequestSize{})
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
}
if max > 0 {
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
@ -886,7 +866,6 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
}
forwardRequest(core, w, r)
return
})
}
@ -930,10 +909,8 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
return
}
if header != nil {
for k, v := range header {
w.Header()[k] = v
}
for k, v := range header {
w.Header()[k] = v
}
w.WriteHeader(statusCode)

View File

@ -135,7 +135,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
func disableReplicationStatusEndpointWrapping(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request := r.WithContext(context.WithValue(r.Context(), "disable_replication_status_endpoints", true))
request := r.WithContext(context.WithValue(r.Context(), logical.CtxKeyDisableReplicationStatusEndpoints{}, true))
h.ServeHTTP(w, request)
})

View File

@ -453,3 +453,33 @@ type CtxKeyRequestRole struct{}
func (c CtxKeyRequestRole) String() string {
return "request-role"
}
// CtxKeyDisableReplicationStatusEndpoints is a custom type used as a key in
// context.Context to store the value `true` when the
// disable_replication_status_endpoints configuration parameter is set to true
// for the listener through which a request was received.
type CtxKeyDisableReplicationStatusEndpoints struct{}
// String returns a string representation of the receiver type.
func (c CtxKeyDisableReplicationStatusEndpoints) String() string {
return "disable-replication-status-endpoints"
}
// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to
// store the value of the max_request_size set for the listener through which
// a request was received.
type CtxKeyMaxRequestSize struct{}
// String returns a string representation of the receiver type.
func (c CtxKeyMaxRequestSize) String() string {
return "max_request_size"
}
// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context
// to store the original request path.
type CtxKeyOriginalRequestPath struct{}
// String returns a string representation of the receiver type.
func (c CtxKeyOriginalRequestPath) String() string {
return "original_request_path"
}

View File

@ -338,7 +338,7 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, re
t.Fatal(err)
}
req.Header.Add(consts.AuthHeaderName, rootToken)
req = req.WithContext(context.WithValue(req.Context(), "original_request_path", req.URL.Path))
req = req.WithContext(context.WithValue(req.Context(), logical.CtxKeyOriginalRequestPath{}, req.URL.Path))
statusCode, header, respBytes, err := c.ForwardRequest(req)
if err != nil {

View File

@ -21,6 +21,7 @@ import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/forwarding"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/cluster"
"github.com/hashicorp/vault/vault/replication"
"golang.org/x/net/http2"
@ -349,7 +350,7 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
req.URL.Path = origPath
}()
req.URL.Path = req.Context().Value("original_request_path").(string)
req.URL.Path = req.Context().Value(logical.CtxKeyOriginalRequestPath{}).(string)
freq, err := forwarding.GenerateForwardedRequest(req)
if err != nil {