mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-05 12:26:34 +02:00
VAULT-19869: Use Custom Types for Context Keys (#23649)
* create custom type for disable-replication-status-endpoints context key make use of custom context key type in middleware function * clean up code to remove various compiler warnings unnecessary return statement if condition that is always true fix use of deprecated ioutil.NopCloser empty if block * remove unused unexported function * clean up code remove unnecessary nil check around a range expression * clean up code removed redundant return statement * use http.StatusTemporaryRedirect constant instead of literal integer * create custom type for context key for max_request_size parameter * create custom type for context key for original request path
This commit is contained in:
parent
67d743e273
commit
4e22153987
@ -7,7 +7,7 @@ import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@ -17,6 +17,7 @@ import (
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/compressutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
type bufCloser struct {
|
||||
@ -64,11 +65,11 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
|
||||
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
|
||||
var reader io.Reader = req.Body
|
||||
ctx := req.Context()
|
||||
maxRequestSize := ctx.Value("max_request_size")
|
||||
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse max_request_size from request context")
|
||||
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
|
||||
}
|
||||
if max > 0 {
|
||||
reader = io.LimitReader(req.Body, max)
|
||||
|
||||
@ -300,9 +300,9 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
|
||||
respondError(w, status, err)
|
||||
return
|
||||
}
|
||||
if origBody != nil {
|
||||
r.Body = ioutil.NopCloser(origBody)
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(origBody)
|
||||
|
||||
input := &logical.LogInput{
|
||||
Request: req,
|
||||
}
|
||||
@ -314,17 +314,16 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
|
||||
cw := newCopyResponseWriter(w)
|
||||
h.ServeHTTP(cw, r)
|
||||
data := make(map[string]interface{})
|
||||
err = jsonutil.DecodeJSON(cw.body.Bytes(), &data)
|
||||
if err != nil {
|
||||
// best effort, ignore
|
||||
}
|
||||
|
||||
// Refactoring this code, since the returned error was being ignored.
|
||||
jsonutil.DecodeJSON(cw.body.Bytes(), &data)
|
||||
|
||||
httpResp := &logical.HTTPResponse{Data: data, Headers: cw.Header()}
|
||||
input.Response = logical.HTTPResponseToLogicalResponse(httpResp)
|
||||
err = core.AuditLogger().AuditResponse(r.Context(), input)
|
||||
if err != nil {
|
||||
respondError(w, status, err)
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
@ -382,9 +381,9 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
|
||||
// if maxRequestSize < 0, no need to set context value
|
||||
// Add a size limiter if desired
|
||||
if maxRequestSize > 0 {
|
||||
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
|
||||
ctx = context.WithValue(ctx, logical.CtxKeyMaxRequestSize{}, maxRequestSize)
|
||||
}
|
||||
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
|
||||
ctx = context.WithValue(ctx, logical.CtxKeyOriginalRequestPath{}, r.URL.Path)
|
||||
r = r.WithContext(ctx)
|
||||
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
|
||||
|
||||
@ -465,7 +464,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
|
||||
h.ServeHTTP(nw, r)
|
||||
|
||||
cancelFunc()
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
@ -557,25 +555,9 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle
|
||||
|
||||
r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port)
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// stripPrefix is a helper to strip a prefix from the path. It will
|
||||
// return false from the second return value if it the prefix doesn't exist.
|
||||
func stripPrefix(prefix, path string) (string, bool) {
|
||||
if !strings.HasPrefix(path, prefix) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
path = path[len(prefix):]
|
||||
if path == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return path, true
|
||||
}
|
||||
|
||||
func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
header := w.Header()
|
||||
@ -585,12 +567,12 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
if userHeaders != nil {
|
||||
for k := range userHeaders {
|
||||
v := userHeaders.Get(k)
|
||||
header.Set(k, v)
|
||||
}
|
||||
|
||||
for k := range userHeaders {
|
||||
v := userHeaders.Get(k)
|
||||
header.Set(k, v)
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
@ -602,7 +584,6 @@ func handleUI(h http.Handler) http.Handler {
|
||||
// here.
|
||||
req.URL.Path = strings.TrimSuffix(req.URL.Path, "/")
|
||||
h.ServeHTTP(w, req)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
@ -680,8 +661,7 @@ func handleUIStub() http.Handler {
|
||||
|
||||
func handleUIRedirect() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
http.Redirect(w, req, "/ui/", 307)
|
||||
return
|
||||
http.Redirect(w, req, "/ui/", http.StatusTemporaryRedirect)
|
||||
})
|
||||
}
|
||||
|
||||
@ -730,11 +710,11 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
||||
// against an indefinite amount of data being read.
|
||||
reader := r.Body
|
||||
ctx := r.Context()
|
||||
maxRequestSize := ctx.Value("max_request_size")
|
||||
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse max_request_size from request context")
|
||||
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
|
||||
}
|
||||
if max > 0 {
|
||||
// MaxBytesReader won't do all the internal stuff it must unless it's
|
||||
@ -769,11 +749,11 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
||||
//
|
||||
// A nil map will be returned if the format is empty or invalid.
|
||||
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
|
||||
maxRequestSize := r.Context().Value("max_request_size")
|
||||
maxRequestSize := r.Context().Value(logical.CtxKeyMaxRequestSize{})
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse max_request_size from request context")
|
||||
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
|
||||
}
|
||||
if max > 0 {
|
||||
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
|
||||
@ -886,7 +866,6 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
|
||||
}
|
||||
|
||||
forwardRequest(core, w, r)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
@ -930,10 +909,8 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if header != nil {
|
||||
for k, v := range header {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
for k, v := range header {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
@ -135,7 +135,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
|
||||
|
||||
func disableReplicationStatusEndpointWrapping(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
request := r.WithContext(context.WithValue(r.Context(), "disable_replication_status_endpoints", true))
|
||||
request := r.WithContext(context.WithValue(r.Context(), logical.CtxKeyDisableReplicationStatusEndpoints{}, true))
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
})
|
||||
|
||||
@ -453,3 +453,33 @@ type CtxKeyRequestRole struct{}
|
||||
func (c CtxKeyRequestRole) String() string {
|
||||
return "request-role"
|
||||
}
|
||||
|
||||
// CtxKeyDisableReplicationStatusEndpoints is a custom type used as a key in
|
||||
// context.Context to store the value `true` when the
|
||||
// disable_replication_status_endpoints configuration parameter is set to true
|
||||
// for the listener through which a request was received.
|
||||
type CtxKeyDisableReplicationStatusEndpoints struct{}
|
||||
|
||||
// String returns a string representation of the receiver type.
|
||||
func (c CtxKeyDisableReplicationStatusEndpoints) String() string {
|
||||
return "disable-replication-status-endpoints"
|
||||
}
|
||||
|
||||
// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to
|
||||
// store the value of the max_request_size set for the listener through which
|
||||
// a request was received.
|
||||
type CtxKeyMaxRequestSize struct{}
|
||||
|
||||
// String returns a string representation of the receiver type.
|
||||
func (c CtxKeyMaxRequestSize) String() string {
|
||||
return "max_request_size"
|
||||
}
|
||||
|
||||
// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context
|
||||
// to store the original request path.
|
||||
type CtxKeyOriginalRequestPath struct{}
|
||||
|
||||
// String returns a string representation of the receiver type.
|
||||
func (c CtxKeyOriginalRequestPath) String() string {
|
||||
return "original_request_path"
|
||||
}
|
||||
|
||||
@ -338,7 +338,7 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, re
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add(consts.AuthHeaderName, rootToken)
|
||||
req = req.WithContext(context.WithValue(req.Context(), "original_request_path", req.URL.Path))
|
||||
req = req.WithContext(context.WithValue(req.Context(), logical.CtxKeyOriginalRequestPath{}, req.URL.Path))
|
||||
|
||||
statusCode, header, respBytes, err := c.ForwardRequest(req)
|
||||
if err != nil {
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/forwarding"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/cluster"
|
||||
"github.com/hashicorp/vault/vault/replication"
|
||||
"golang.org/x/net/http2"
|
||||
@ -349,7 +350,7 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
|
||||
req.URL.Path = origPath
|
||||
}()
|
||||
|
||||
req.URL.Path = req.Context().Value("original_request_path").(string)
|
||||
req.URL.Path = req.Context().Value(logical.CtxKeyOriginalRequestPath{}).(string)
|
||||
|
||||
freq, err := forwarding.GenerateForwardedRequest(req)
|
||||
if err != nil {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user