diff --git a/http/handler.go b/http/handler.go index 2a8fc9f933..37e769aa5e 100644 --- a/http/handler.go +++ b/http/handler.go @@ -413,7 +413,36 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr r = newR case strings.HasPrefix(r.URL.Path, "/ui"), r.URL.Path == "/robots.txt", r.URL.Path == "/": - default: + // RFC 5785 + case strings.HasPrefix(r.URL.Path, "/.well-known/"): + standby, err := core.Standby() + if err != nil { + core.Logger().Warn("error resolving standby status handling .well-known path", "error", err) + } else if standby { + respondStandby(core, w, r.URL) + cancelFunc() + return + } else { + redir, err := core.GetWellKnownRedirect(r.Context(), r.URL.Path) + if err != nil { + core.Logger().Warn("error resolving potential API redirect", "error", err) + } else { + if redir != "" { + dest := url.URL{ + Path: redir, + RawQuery: r.URL.RawQuery, + } + w.Header().Set("Location", dest.String()) + if r.Method == http.MethodGet || r.Proto == "HTTP/1.0" { + w.WriteHeader(http.StatusFound) + } else { + w.WriteHeader(http.StatusTemporaryRedirect) + } + cancelFunc() + return + } + } + } respondError(nw, http.StatusNotFound, nil) cancelFunc() return diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index a4ec6483d8..510366add4 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -111,6 +111,12 @@ type ExtendedSystemView interface { // APILockShouldBlockRequest returns whether a namespace for the requested // mount is locked and should be blocked APILockShouldBlockRequest() (bool, error) + + // Register a redirect from .well-known/src to dest, where dest is a subpath of the mount. An error + // is returned if that source path is already taken + RequestWellKnownRedirect(ctx context.Context, src, dest string) error + // Deregister a specific redirect. Returns true if that redirect source was found + DeregisterWellKnownRedirect(ctx context.Context, src string) bool } type PasswordGenerator func() (password string, err error) diff --git a/vault/core.go b/vault/core.go index 33d345d902..73a29f03cc 100644 --- a/vault/core.go +++ b/vault/core.go @@ -18,6 +18,7 @@ import ( "net/http" "net/url" "os" + paths "path" "path/filepath" "runtime" "slices" @@ -132,6 +133,8 @@ const ( "disable Vault from using it. To disable Vault from using it,\n" + "set the `disable_mlock` configuration option in your configuration\n" + "file." + + WellKnownPrefix = "/.well-known/" ) var ( @@ -692,6 +695,7 @@ type Core struct { // If any role based quota (LCQ or RLQ) is enabled, don't track lease counts by role impreciseLeaseRoleTracking bool + WellKnownRedirects *wellKnownRedirectRegistry // RFC 5785 // Config value for "detect_deadlocks". detectDeadlocks []string } @@ -1039,6 +1043,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { rollbackMountPathMetrics: conf.MetricSink.TelemetryConsts.RollbackMetricsIncludeMountPoint, numRollbackWorkers: conf.NumRollbackWorkers, impreciseLeaseRoleTracking: conf.ImpreciseLeaseRoleTracking, + WellKnownRedirects: NewWellKnownRedirects(), detectDeadlocks: detectDeadlocks, } @@ -4226,6 +4231,22 @@ func (c *Core) Events() *eventbus.EventBus { return c.events } +func (c *Core) GetWellKnownRedirect(ctx context.Context, path string) (string, error) { + if c.WellKnownRedirects == nil { + return "", nil + } + path = strings.TrimPrefix(path, WellKnownPrefix) + redir, remaining := c.WellKnownRedirects.Find(path) + if redir != nil { + dest, err := redir.Destination(remaining) + if err != nil { + return "", err + } + return paths.Join("/v1", dest), nil + } + return "", nil +} + func (c *Core) DetectStateLockDeadlocks() bool { if _, ok := c.stateLock.(*locking.DeadlockRWMutex); ok { return true diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 0d3547743b..9b03d3303a 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -39,6 +39,8 @@ type extendedSystemView interface { SudoPrivilege(context.Context, string, string) bool } +var _ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil) + type extendedSystemViewImpl struct { dynamicSystemView } @@ -150,6 +152,14 @@ func (e extendedSystemViewImpl) APILockShouldBlockRequest() (bool, error) { return false, nil } +func (e extendedSystemViewImpl) RequestWellKnownRedirect(ctx context.Context, src, dest string) error { + return e.core.WellKnownRedirects.TryRegister(ctx, e.core, e.mountEntry.UUID, src, dest) +} + +func (e extendedSystemViewImpl) DeregisterWellKnownRedirect(ctx context.Context, src string) bool { + return e.core.WellKnownRedirects.DeregisterSource(e.mountEntry.UUID, src) +} + func (d dynamicSystemView) DefaultLeaseTTL() time.Duration { def, _ := d.fetchTTLs() return def diff --git a/vault/external_tests/router/router_ext_test.go b/vault/external_tests/router/router_ext_test.go index c4a0c269c3..b770382d7e 100644 --- a/vault/external_tests/router/router_ext_test.go +++ b/vault/external_tests/router/router_ext_test.go @@ -4,8 +4,13 @@ package router import ( + "context" + "net/http" "testing" + "github.com/hashicorp/vault/helper/testhelpers" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/testhelpers/minimal" "github.com/hashicorp/vault/sdk/logical" @@ -83,3 +88,59 @@ func TestRouter_UnmountRollbackIsntFatal(t *testing.T) { cluster.EnsureCoresSealed(t) cluster.UnsealCores(t) } + +func TestWellKnownRedirect_HA(t *testing.T) { + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + DisablePerformanceStandby: true, + LogicalBackends: map[string]logical.Factory{ + "noop": func(_ context.Context, _ *logical.BackendConfig) (logical.Backend, error) { + return &vault.NoopBackend{ + RequestHandler: func(context.Context, *logical.Request) (*logical.Response, error) { + // Return something for any request + return &logical.Response{ + Data: map[string]interface{}{ + "good": "very", + }, + }, nil + }, + }, nil + }, + }, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + testhelpers.WaitForActiveNodeAndStandbys(t, cluster) + active := testhelpers.DeriveActiveCore(t, cluster) + standbys := testhelpers.DeriveStandbyCores(t, cluster) + standby := standbys[0].Client + + if err := active.Client.Sys().Mount("noop", &api.MountInput{ + Type: "noop", + }); err != nil { + t.Fatalf("failed to mount PKI: %v", err) + } + + resp, err := active.Client.Logical().Read("sys/mounts") + if err != nil { + t.Fatalf("failed to fetch new mount: %v", err) + } + var mountUUID string + for k, m := range resp.Data { + if k == "noop/" { + mountUUID = m.(map[string]interface{})["uuid"].(string) + break + } + } + + if err := active.Core.WellKnownRedirects.TryRegister(context.Background(), active.Core, mountUUID, "foo", "bar"); err != nil { + t.Fatal(err) + } + + standby.SetCheckRedirect(nil) + resp2, err := standby.RawRequest(standby.NewRequest(http.MethodGet, "/.well-known/foo/baz")) + if err != nil { + t.Fatal(err) + } else if resp2.StatusCode != http.StatusOK { + t.Fatal("did not get expected response from noop backend after redirect") + } +} diff --git a/vault/mount.go b/vault/mount.go index 5f1f85bcb8..b72e22a054 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -950,6 +950,8 @@ func (c *Core) unmountInternal(ctx context.Context, path string, updateStorage b } } + c.WellKnownRedirects.DeregisterMount(entry.UUID) + if c.logger.IsInfo() { c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path) } diff --git a/vault/router_test.go b/vault/router_test.go index d023b35a37..a91c4ad002 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -4,10 +4,13 @@ package vault import ( + "context" "reflect" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" @@ -631,3 +634,53 @@ func TestParseUnauthenticatedPaths_Error(t *testing.T) { } } } + +func TestWellKnownRedirectMatching(t *testing.T) { + a := assert.New(t) + // inputs + redirs := map[string]string{ + "foo": "v1/one-path", + "bar/baz": "v1/two-paths", + "baz/": "v1/trailing-slash", + } + + tests := map[string]struct { + expected string + mismatch bool + }{ + "foo": {"/v1/one-path", false}, + "foof": {"", true}, + "foo/extra": {"/v1/one-path/extra", false}, + "bar/baz": {"/v1/two-paths", false}, + "bar/baz/extra": {"/v1/two-paths/extra", false}, + "baz": {"/v1/trailing-slash", false}, + "baz/extra": {"/v1/trailing-slash/extra", false}, + } + apiRedir := NewWellKnownRedirects() + for s, d := range redirs { + if err := apiRedir.TryRegister(context.Background(), nil, "my-mount", s, d); err != nil { + t.Fatal(err) + } + } + + for k, x := range tests { + t.Run(k, func(t *testing.T) { + v, s := apiRedir.Find(k) + if x.mismatch && v != nil { + t.Fail() + } else if !x.mismatch && v == nil { + t.Fail() + } else if !x.mismatch { + d, err := v.Destination(s) + if err != nil { + t.Fatal(err) + } + a.Equal(x.expected, d) + } + }) + } + + if found := apiRedir.DeregisterSource("my-mount", "bar/baz"); !found { + t.Fail() + } +} diff --git a/vault/well_known_redirect.go b/vault/well_known_redirect.go new file mode 100644 index 0000000000..027a4c28c3 --- /dev/null +++ b/vault/well_known_redirect.go @@ -0,0 +1,139 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package vault + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + "sync" + + "github.com/armon/go-radix" +) + +type wellKnownRedirect struct { + c *Core + mountUUID string + prefix string + isPrefixMatch bool +} + +type wellKnownRedirectRegistry struct { + lock sync.Mutex + paths *radix.Tree +} + +func NewWellKnownRedirects() *wellKnownRedirectRegistry { + return &wellKnownRedirectRegistry{ + paths: radix.New(), + } +} + +// Attempt to register a mapping from /.well-known/_src_ to /v1/_mount-path_/_dest_ +func (reg *wellKnownRedirectRegistry) TryRegister(ctx context.Context, core *Core, mountUUID, src, dest string) error { + if strings.HasPrefix(dest, "/") { + return errors.New("redirect targets must be relative") + } + src = strings.TrimSuffix(src, "/") + reg.lock.Lock() + defer reg.lock.Unlock() + _, _, found := reg.paths.LongestPrefix(src) + if found { + return fmt.Errorf("api redirect conflict for %s", src) + } + reg.paths.Insert(src, &wellKnownRedirect{ + c: core, + mountUUID: mountUUID, + prefix: dest, + }) + return nil +} + +// Find any relevant redirects for a given source path +func (reg *wellKnownRedirectRegistry) Find(path string) (*wellKnownRedirect, string) { + s, a, found := reg.paths.LongestPrefix(path) + if found { + remaining := strings.TrimPrefix(path, s) + if len(remaining) > 0 { + switch remaining[0] { + case '/': + remaining = remaining[1:] + case '?': + default: + // This isn't an exact path match + return nil, "" + } + } + return a.(*wellKnownRedirect), remaining + } + return nil, "" +} + +// Remove all redirects for a given mount +func (reg *wellKnownRedirectRegistry) DeregisterMount(mountUuid string) { + reg.lock.Lock() + defer reg.lock.Unlock() + + var toDelete []string + reg.paths.Walk(func(k string, v interface{}) bool { + r := v.(*wellKnownRedirect) + if r.mountUUID == mountUuid { + toDelete = append(toDelete, k) + } + return false + }) + for _, d := range toDelete { + reg.paths.Delete(d) + } +} + +// Remove a specific redirect for a mount +func (reg *wellKnownRedirectRegistry) DeregisterSource(mountUuid, src string) bool { + reg.lock.Lock() + defer reg.lock.Unlock() + var found bool + reg.paths.Walk(func(k string, v interface{}) bool { + r := v.(*wellKnownRedirect) + if r.mountUUID == mountUuid && k == src { + found = true + reg.paths.Delete(k) + return true + } + return false + }) + return found +} + +// Construct the full destination of the redirect, including any remaining path past the src +func (a *wellKnownRedirect) Destination(remaining string) (string, error) { + var destPath string + if a.c == nil { + // Just for testing + destPath = a.prefix + } else { + m := a.c.router.MatchingMountByUUID(a.mountUUID) + + if m == nil { + return "", fmt.Errorf("cannot find backend with uuid: %s", a.mountUUID) + } + var err error + destPath, err = url.JoinPath(m.Namespace().Path, m.Path, a.prefix) + if err != nil { + return "", err + } + } + + u := url.URL{ + Path: destPath + "/", + } + r, err := url.Parse(remaining) + if err != nil { + return "", err + } + dest := u.ResolveReference(r) + dest.Path = strings.TrimSuffix(dest.Path, "/") + return dest.String(), nil +}