From 9d31716ab978ff75d7ab591ea0d4d4f29d75cd6a Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Wed, 12 Feb 2020 14:20:22 -0800 Subject: [PATCH] Support processing parameters sent as a URL-encoded form (#8325) --- http/handler.go | 75 ++++++++++++++++++++++++++++++++++++++- http/handler_test.go | 64 +++++++++++++++++++++++++++++++++ http/logical.go | 58 +++++++++++++++++++++++++----- http/logical_test.go | 25 +++++++++++++ http/sys_generate_root.go | 4 +-- http/sys_init.go | 2 +- http/sys_raft.go | 2 +- http/sys_rekey.go | 6 ++-- http/sys_seal.go | 2 +- 9 files changed, 220 insertions(+), 18 deletions(-) diff --git a/http/handler.go b/http/handler.go index 8ae1b225ed..d07f4169ea 100644 --- a/http/handler.go +++ b/http/handler.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "io/ioutil" + "mime" "net" "net/http" "net/textproto" @@ -566,7 +567,7 @@ func parseQuery(values url.Values) map[string]interface{} { return nil } -func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) { +func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) { // Limit the maximum number of bytes to MaxRequestSize to protect // against an indefinite amount of data being read. reader := r.Body @@ -598,6 +599,44 @@ func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out return nil, err } +// parseFormRequest parses values from a form POST. +// +// A nil map will be returned if the format is empty or invalid. +func parseFormRequest(r *http.Request) (map[string]interface{}, error) { + maxRequestSize := r.Context().Value("max_request_size") + if maxRequestSize != nil { + max, ok := maxRequestSize.(int64) + if !ok { + return nil, errors.New("could not parse max_request_size from request context") + } + if max > 0 { + r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max)) + } + } + if err := r.ParseForm(); err != nil { + return nil, err + } + + var data map[string]interface{} + + if len(r.PostForm) != 0 { + data = make(map[string]interface{}, len(r.PostForm)) + for k, v := range r.PostForm { + switch len(v) { + case 0: + case 1: + data[k] = v[0] + default: + // Almost anywhere taking in a string list can take in comma + // separated values, and really this is super niche anyways + data[k] = strings.Join(v, ",") + } + } + } + + return data, nil +} + // handleRequestForwarding determines whether to forward a request or not, // falling back on the older behavior of redirecting the client func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler { @@ -960,6 +999,40 @@ func parseMFAHeader(req *logical.Request) error { return nil } +// isForm tries to determine whether the request should be +// processed as a form or as JSON. +// +// Virtually all existing use cases have assumed processing as JSON, +// and there has not been a Content-Type requirement in the API. In order to +// maintain backwards compatibility, this will err on the side of JSON. +// The request will be considered a form only if: +// +// 1. The content type is "application/x-www-form-urlencoded" +// 2. The start of the request doesn't look like JSON. For this test we +// we expect the body to begin with { or [, ignoring leading whitespace. +func isForm(head []byte, contentType string) bool { + contentType, _, err := mime.ParseMediaType(contentType) + + if err != nil || contentType != "application/x-www-form-urlencoded" { + return false + } + + // Look for the start of JSON or not-JSON, skipping any insignificant + // whitespace (per https://tools.ietf.org/html/rfc7159#section-2). + for _, c := range head { + switch c { + case ' ', '\t', '\n', '\r': + continue + case '[', '{': // JSON + return false + default: // not JSON + return true + } + } + + return true +} + func respondError(w http.ResponseWriter, status int, err error) { logical.RespondError(w, status, err) } diff --git a/http/handler_test.go b/http/handler_test.go index 31079ef9cf..96cee3b387 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -2,11 +2,14 @@ package http import ( "context" + "crypto/tls" "encoding/json" "errors" + "io/ioutil" "net/http" "net/http/httptest" "net/textproto" + "net/url" "reflect" "strings" "testing" @@ -676,3 +679,64 @@ func testNonPrintable(t *testing.T, disable bool) { testResponseStatus(t, resp, 400) } } + +func TestHandler_Parse_Form(t *testing.T) { + cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + vault.TestWaitActive(t, core) + + c := cleanhttp.DefaultClient() + c.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: cluster.RootCAs, + }, + } + + values := url.Values{ + "zip": []string{"zap"}, + "abc": []string{"xyz"}, + "multi": []string{"first", "second"}, + "empty": []string{}, + } + req, err := http.NewRequest("POST", cores[0].Client.Address()+"/v1/secret/foo", nil) + if err != nil { + t.Fatal(err) + } + req.Body = ioutil.NopCloser(strings.NewReader(values.Encode())) + req.Header.Set("x-vault-token", cluster.RootToken) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 204 { + t.Fatalf("bad response: %#v\nrequest was: %#v\nurl was: %#v", *resp, *req, req.URL) + } + + client := cores[0].Client + client.SetToken(cluster.RootToken) + + apiResp, err := client.Logical().Read("secret/foo") + if err != nil { + t.Fatal(err) + } + if apiResp == nil { + t.Fatal("api resp is nil") + } + expected := map[string]interface{}{ + "zip": "zap", + "abc": "xyz", + "multi": "first,second", + } + if diff := deep.Equal(expected, apiResp.Data); diff != nil { + t.Fatal(diff) + } +} diff --git a/http/logical.go b/http/logical.go index f198ff403e..25572a0189 100644 --- a/http/logical.go +++ b/http/logical.go @@ -1,6 +1,7 @@ package http import ( + "bufio" "encoding/base64" "encoding/json" "fmt" @@ -20,6 +21,24 @@ import ( "go.uber.org/atomic" ) +// bufferedReader can be used to replace a request body with a buffered +// version. The Close method invokes the original Closer. +type bufferedReader struct { + *bufio.Reader + rOrig io.ReadCloser +} + +func newBufferedReader(r io.ReadCloser) *bufferedReader { + return &bufferedReader{ + Reader: bufio.NewReader(r), + rOrig: r, + } +} + +func (b *bufferedReader) Close() error { + return b.rOrig.Close() +} + func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { ns, err := namespace.FromContext(r.Context()) if err != nil { @@ -71,16 +90,37 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. case "POST", "PUT": op = logical.UpdateOperation - // Parse the request if we can - if op == logical.UpdateOperation { - // If we are uploading a snapshot we don't want to parse it. Instead - // we will simply add the HTTP request to the logical request object - // for later consumption. - if path == "sys/storage/raft/snapshot" || path == "sys/storage/raft/snapshot-force" { - passHTTPReq = true - origBody = r.Body + + // Buffer the request body in order to allow us to peek at the beginning + // without consuming it. This approach involves no copying. + bufferedBody := newBufferedReader(r.Body) + r.Body = bufferedBody + + // If we are uploading a snapshot we don't want to parse it. Instead + // we will simply add the HTTP request to the logical request object + // for later consumption. + if path == "sys/storage/raft/snapshot" || path == "sys/storage/raft/snapshot-force" { + passHTTPReq = true + origBody = r.Body + } else { + // Sample the first bytes to determine whether this should be parsed as + // a form or as JSON. The amount to look ahead (512 bytes) is arbitrary + // but extremely tolerant (i.e. allowing 511 bytes of leading whitespace + // and an incorrect content-type). + head, err := bufferedBody.Peek(512) + if err != nil && err != bufio.ErrBufferFull && err != io.EOF { + return nil, nil, http.StatusBadRequest, err + } + + if isForm(head, r.Header.Get("Content-Type")) { + formData, err := parseFormRequest(r) + if err != nil { + return nil, nil, http.StatusBadRequest, fmt.Errorf("error parsing form data: %w", err) + } + + data = formData } else { - origBody, err = parseRequest(perfStandby, r, w, &data) + origBody, err = parseJSONRequest(perfStandby, r, w, &data) if err == io.EOF { data = nil err = nil diff --git a/http/logical_test.go b/http/logical_test.go index df385bd05e..8f836dc49b 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -437,3 +437,28 @@ func TestLogical_Audit_invalidWrappingToken(t *testing.T) { } } } + +func TestLogical_ShouldParseForm(t *testing.T) { + const formCT = "application/x-www-form-urlencoded" + + tests := map[string]struct { + prefix string + contentType string + isForm bool + }{ + "JSON": {`{"a":42}`, formCT, false}, + "JSON 2": {`[42]`, formCT, false}, + "JSON w/leading space": {" \n\n\r\t [42] ", formCT, false}, + "Form": {"a=42&b=dog", formCT, true}, + "Form w/wrong CT": {"a=42&b=dog", "application/json", false}, + } + + for name, test := range tests { + isForm := isForm([]byte(test.prefix), test.contentType) + + if isForm != test.isForm { + t.Fatalf("%s fail: expected isForm %t, got %t", name, test.isForm, isForm) + } + } + +} diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index dae751a469..12d829d78f 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { // Parse the request var req GenerateRootInitRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { respondError(w, http.StatusBadRequest, err) return } @@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Parse the request var req GenerateRootUpdateRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_init.go b/http/sys_init.go index ca77ada911..b21e5363ea 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -39,7 +39,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) // Parse the request var req InitRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_raft.go b/http/sys_raft.go index d72517eb69..c36f87310d 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -26,7 +26,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler { func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) { // Parse the request var req JoinRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_rekey.go b/http/sys_rekey.go index eb8760f927..d1cec653a6 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { // Parse the request var req RekeyRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } @@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Parse the request var req RekeyUpdateRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } @@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { // Parse the request var req RekeyVerificationUpdateRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_seal.go b/http/sys_seal.go index 1cf520c098..a13573addd 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { // Parse the request var req UnsealRequest - if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { + if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return }