VAULT-28255: Fix namespaced redirects (#27660)

* handle namespaced events redirects

* full test:

* changelog

* lint
This commit is contained in:
miagilepner 2024-07-03 10:08:39 +02:00 committed by GitHub
parent fc19a9ce9c
commit 9e299c2896
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 55 additions and 22 deletions

3
changelog/27660.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
core (enterprise): Fix HTTP redirects in namespaces to use the correct path and (in the case of event subscriptions) the correct URI scheme.
```

View File

@ -116,7 +116,7 @@ var (
"/v1/sys/wrapping/wrap", "/v1/sys/wrapping/wrap",
} }
websocketRawPaths = []string{ websocketRawPaths = []string{
"/v1/sys/events/subscribe", "sys/events/subscribe",
} }
oidcProtectedPathRegex = regexp.MustCompile(`^identity/oidc/provider/\w(([\w-.]+)?\w)?/userinfo$`) oidcProtectedPathRegex = regexp.MustCompile(`^identity/oidc/provider/\w(([\w-.]+)?\w)?/userinfo$`)
) )
@ -128,9 +128,7 @@ func init() {
"!sys/storage/raft/snapshot-auto/config", "!sys/storage/raft/snapshot-auto/config",
}) })
websocketPaths.AddPaths(websocketRawPaths) websocketPaths.AddPaths(websocketRawPaths)
for _, path := range websocketRawPaths { alwaysRedirectPaths.AddPaths(websocketRawPaths)
alwaysRedirectPaths.AddPaths([]string{strings.TrimPrefix(path, "/v1/")})
}
} }
type HandlerAnchor struct{} type HandlerAnchor struct{}
@ -434,7 +432,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
} else if standby && !perfStandby { } else if standby && !perfStandby {
// Standby nodes, not performance standbys, don't start plugins // Standby nodes, not performance standbys, don't start plugins
// so registration can not happen, instead redirect to active // so registration can not happen, instead redirect to active
respondStandby(core, w, r.URL) respondStandby(core, w, r)
cancelFunc() cancelFunc()
return return
} else { } else {
@ -909,7 +907,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
respondError(w, http.StatusBadRequest, err) respondError(w, http.StatusBadRequest, err)
return return
} }
path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) path := trimPath(ns, r.URL.Path)
if !perfStandbyAlwaysForwardPaths.HasPath(path) && !alwaysRedirectPaths.HasPath(path) { if !perfStandbyAlwaysForwardPaths.HasPath(path) && !alwaysRedirectPaths.HasPath(path) {
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
return return
@ -946,14 +944,14 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
if r.Header.Get(vault.IntNoForwardingHeaderName) != "" { if r.Header.Get(vault.IntNoForwardingHeaderName) != "" {
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }
if r.Header.Get(NoRequestForwardingHeaderName) != "" { if r.Header.Get(NoRequestForwardingHeaderName) != "" {
// Forwarding explicitly disabled, fall back to previous behavior // Forwarding explicitly disabled, fall back to previous behavior
core.Logger().Debug("handleRequestForwarding: forwarding disabled by client request") core.Logger().Debug("handleRequestForwarding: forwarding disabled by client request")
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }
@ -962,10 +960,25 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
respondError(w, http.StatusBadRequest, err) respondError(w, http.StatusBadRequest, err)
return return
} }
path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) path := trimPath(ns, r.URL.Path)
if alwaysRedirectPaths.HasPath(path) { redirect := alwaysRedirectPaths.HasPath(path)
// websocket paths are special, because they can contain a namespace
// in front of them. This isn't an issue on perf standbys where the
// namespace manager will know all the namespaces, so we will have
// already extracted it from the path. But regular standbys don't have
// knowledge of the namespaces, so we need
// to add an extra check
if !redirect && !core.PerfStandby() {
for _, websocketPath := range websocketRawPaths {
if strings.Contains(path, websocketPath) {
redirect = true
break
}
}
}
if redirect {
core.Logger().Trace("cannot forward request (path included in always redirect paths), falling back to redirection to standby") core.Logger().Trace("cannot forward request (path included in always redirect paths), falling back to redirection to standby")
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }
@ -981,7 +994,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
} }
// Fall back to redirection // Fall back to redirection
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }
@ -1045,7 +1058,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l
return resp, false, false return resp, false, false
} }
if errwrap.Contains(err, consts.ErrStandby.Error()) { if errwrap.Contains(err, consts.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL) respondStandby(core, w, rawReq)
return resp, false, false return resp, false, false
} }
if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) { if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) {
@ -1094,7 +1107,8 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l
} }
// respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby
func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { func respondStandby(core *vault.Core, w http.ResponseWriter, r *http.Request) {
reqURL := r.URL
// Request the leader address // Request the leader address
_, redirectAddr, _, err := core.Leader() _, redirectAddr, _, err := core.Leader()
if err != nil { if err != nil {
@ -1131,8 +1145,13 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
RawQuery: reqURL.RawQuery, RawQuery: reqURL.RawQuery,
} }
ctx := r.Context()
ns, err := namespace.FromContext(ctx)
if err != nil {
respondError(w, http.StatusBadRequest, err)
}
// WebSockets schemas are ws or wss // WebSockets schemas are ws or wss
if websocketPaths.HasPath(reqURL.Path) { if websocketPaths.HasPath(trimPath(ns, reqURL.Path)) {
if finalURL.Scheme == "http" { if finalURL.Scheme == "http" {
finalURL.Scheme = "ws" finalURL.Scheme = "ws"
} else { } else {
@ -1140,6 +1159,11 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
} }
} }
originalPath, ok := logical.ContextOriginalRequestPathValue(ctx)
if ok {
finalURL.Path = originalPath
}
// Ensure there is a scheme, default to https // Ensure there is a scheme, default to https
if finalURL.Scheme == "" { if finalURL.Scheme == "" {
finalURL.Scheme = "https" finalURL.Scheme = "https"
@ -1391,3 +1415,8 @@ func respondOIDCPermissionDenied(w http.ResponseWriter) {
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
enc.Encode(oidcResponse) enc.Encode(oidcResponse)
} }
// trimPath removes the /v1/ prefix and the namespace from the path
func trimPath(ns *namespace.Namespace, path string) string {
return ns.TrimmedPath(path[len("/v1/"):])
}

View File

@ -40,7 +40,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) {
respondError(w, http.StatusNotFound, errors.New("Missing /v1/ prefix in path. Use vault path-help command to retrieve API help for paths")) respondError(w, http.StatusNotFound, errors.New("Missing /v1/ prefix in path. Use vault path-help command to retrieve API help for paths"))
return return
} }
path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) path := trimPath(ns, r.URL.Path)
req := &logical.Request{ req := &logical.Request{
Operation: logical.HelpOperation, Operation: logical.HelpOperation,

View File

@ -50,8 +50,7 @@ func buildLogicalRequestNoAuth(perfStandby bool, ra *vault.RouterAccess, w http.
if err != nil { if err != nil {
return nil, nil, http.StatusBadRequest, nil return nil, nil, http.StatusBadRequest, nil
} }
path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) path := trimPath(ns, r.URL.Path)
var data map[string]interface{} var data map[string]interface{}
var origBody io.ReadCloser var origBody io.ReadCloser
var passHTTPReq bool var passHTTPReq bool
@ -361,11 +360,13 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
respondError(w, http.StatusInternalServerError, err) respondError(w, http.StatusInternalServerError, err)
return return
} }
trimmedPath := trimPath(ns, r.URL.Path)
nsPath := ns.Path nsPath := ns.Path
if ns.ID == namespace.RootNamespaceID { if ns.ID == namespace.RootNamespaceID {
nsPath = "" nsPath = ""
} }
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("/v1/%ssys/events/subscribe/", nsPath)) { if websocketPaths.HasPath(trimmedPath) {
handler := entHandleEventsSubscribe(core, req) handler := entHandleEventsSubscribe(core, req)
if handler != nil { if handler != nil {
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)

View File

@ -20,7 +20,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
standby, _ := core.Standby() standby, _ := core.Standby()
if standby { if standby {
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }
@ -155,7 +155,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
standby, _ := core.Standby() standby, _ := core.Standby()
if standby { if standby {
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }
@ -228,7 +228,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
standby, _ := core.Standby() standby, _ := core.Standby()
if standby { if standby {
respondStandby(core, w, r.URL) respondStandby(core, w, r)
return return
} }