diff --git a/http/events.go b/http/events.go index 1958fdb6c1..fac54a6f1d 100644 --- a/http/events.go +++ b/http/events.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "path" + "slices" "strconv" "strings" "time" @@ -19,25 +20,35 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault/eventbus" + "github.com/patrickmn/go-cache" + "github.com/ryanuber/go-glob" "nhooyr.io/websocket" ) -type eventSubscribeArgs struct { +// webSocketRevalidationTime is how often we re-check access to the +// events that the websocket requested access to. +var webSocketRevalidationTime = 5 * time.Minute + +type eventSubscriber struct { ctx context.Context + clientToken string + capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error) logger hclog.Logger events *eventbus.EventBus namespacePatterns []string pattern string conn *websocket.Conn json bool + checkCache *cache.Cache + isRootToken bool } -// handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason -// only if the connection closes or there was an error. -func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) { - ctx := args.ctx - logger := args.logger - ch, cancel, err := args.events.SubscribeMultipleNamespaces(ctx, args.namespacePatterns, args.pattern) +// handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket +// error code and reason only if the connection closes or there was an error. +func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) { + ctx := sub.ctx + logger := sub.logger + ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern) if err != nil { logger.Info("Error subscribing", "error", err) return websocket.StatusUnsupportedData, "Error subscribing", nil @@ -50,10 +61,18 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo logger.Info("Websocket context is done, closing the connection") return websocket.StatusNormalClosure, "", nil case message := <-ch: + // Perform one last check that the message is allowed to be received. + // For example, if a new namespace was created that matches the namespace patterns, + // but the token doesn't have access to it, we don't want to accidentally send it to + // the websocket. + if !sub.allowMessageCached(message.Payload.(*logical.EventReceived)) { + continue + } + logger.Debug("Sending message to websocket", "message", message.Payload) var messageBytes []byte var messageType websocket.MessageType - if args.json { + if sub.json { var ok bool messageBytes, ok = message.Format("cloudevents-json") if !ok { @@ -69,7 +88,7 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo logger.Warn("Could not serialize websocket event", "error", err) return 0, "", err } - err = args.conn.Write(ctx, messageType, messageBytes) + err = sub.conn.Write(ctx, messageType, messageBytes) if err != nil { return 0, "", err } @@ -77,6 +96,80 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo } } +// allowMessageCached checks that the message is allowed to received by the websocket. +// It caches results for specific namespaces, data paths, and event types. +func (sub *eventSubscriber) allowMessageCached(message *logical.EventReceived) bool { + if sub.isRootToken { + // fast-path root tokens + return true + } + + messageNs := strings.Trim(message.Namespace, "/") + dataPath := "" + if message.Event.Metadata != nil { + dataPathField := message.Event.Metadata.GetFields()[logical.EventMetadataDataPath] + if dataPathField != nil { + dataPath = dataPathField.GetStringValue() + } + } + if dataPath == "" { + // Only allow root tokens to subscribe to events with no data path, for now. + return false + } + cacheKey := fmt.Sprintf("%v!%v!%v", messageNs, dataPath, message.EventType) + _, ok := sub.checkCache.Get(cacheKey) + if ok { + return true + } + + // perform the actual check and cache it if true + ok = sub.allowMessage(messageNs, dataPath, message.EventType) + if ok { + err := sub.checkCache.Add(cacheKey, ok, webSocketRevalidationTime) + if err != nil { + sub.logger.Debug("Error adding to policy check cache for websocket", "error", err) + // still return the right value, but we can't guarantee it was cached + } + } + return ok +} + +// allowMessage checks that the message is allowed to received by the websocket +func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bool { + // does this even match the requested namespaces + matchedNs := false + for _, nsPattern := range sub.namespacePatterns { + if glob.Glob(nsPattern, eventNs) { + matchedNs = true + break + } + } + if !matchedNs { + return false + } + + // next check for specific access to the namespace and event types + nsDataPath := dataPath + if eventNs != "" { + nsDataPath = path.Join(eventNs, dataPath) + } + capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath) + if err != nil { + sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs) + return false + } + if !(slices.Contains(capabilities, vault.RootCapability) || slices.Contains(capabilities, vault.SubscribeCapability)) { + return false + } + for _, pattern := range allowedEventTypes { + if glob.Glob(pattern, eventType) { + return true + } + } + // no event types matched, so return false + return false +} + func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger := core.Logger().Named("events-subscribe") @@ -85,7 +178,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler ctx := r.Context() // ACL check - _, _, err := core.CheckToken(ctx, req, false) + auth, entry, err := core.CheckToken(ctx, req, false) if err != nil { if errors.Is(err, logical.ErrPermissionDenied) { respondError(w, http.StatusForbidden, logical.ErrPermissionDenied) @@ -146,7 +239,25 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler } }() - closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), namespacePatterns, pattern, conn, json}) + // continually validate subscribe access while the websocket is running + ctx, cancelCtx := context.WithCancel(ctx) + defer cancelCtx() + go validateSubscribeAccessLoop(core, ctx, cancelCtx, req) + + sub := &eventSubscriber{ + ctx: ctx, + capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes, + logger: logger, + events: core.Events(), + namespacePatterns: namespacePatterns, + pattern: pattern, + conn: conn, + json: json, + checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime), + clientToken: auth.ClientToken, + isRootToken: entry.IsRoot(), + } + closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket() if err != nil { closeStatus = websocket.CloseStatus(err) if closeStatus == -1 { @@ -174,10 +285,37 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam newPatterns := make([]string, 0, len(patterns)+1) newPatterns = append(newPatterns, prepend) for _, pattern := range patterns { - if strings.Trim(strings.TrimSpace(pattern), "/") == "" { - continue + if strings.Trim(pattern, "/") != "" { + newPatterns = append(newPatterns, path.Join(prepend, pattern)) } - newPatterns = append(newPatterns, path.Join(prepend, pattern, "/")) } return newPatterns } + +// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in +// its namespace. If the access check ever fails, then the cancel function is called and the function returns. +func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) { + // if something breaks, default to canceling the websocket + defer cancel() + for { + _, _, err := core.CheckToken(ctx, req, false) + if err != nil { + core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err) + return + } + // wait a while and try again, but quit the loop if the context finishes early + finished := func() bool { + ticker := time.NewTicker(webSocketRevalidationTime) + defer ticker.Stop() + select { + case <-ctx.Done(): + return true + case <-ticker.C: + return false + } + }() + if finished { + return + } + } +} diff --git a/http/events_test.go b/http/events_test.go index 21e4f54ff2..fc36d68b1a 100644 --- a/http/events_test.go +++ b/http/events_test.go @@ -135,8 +135,7 @@ func TestEventsSubscribe(t *testing.T) { } } -// TestEventsSubscribeNamespaces tests the websocket endpoint for subscribing to events in multiple namespaces. -func TestEventsSubscribeNamespaces(t *testing.T) { +func TestNamespaceRootSubscriptions(t *testing.T) { core := vault.TestCoreWithConfig(t, &vault.CoreConfig{ Experiments: []string{experiments.VaultExperimentEventsAlpha1}, }) @@ -157,47 +156,27 @@ func TestEventsSubscribeNamespaces(t *testing.T) { const eventType = "abc" - namespaces := []string{ - "", - "ns1", - "ns2", - "ns1/ns13", - "ns1/ns13/ns134", - } - // send some events with the specified namespaces sendEvents := func() error { pluginInfo := &logical.EventPluginInfo{ MountPath: "secret", } - for _, namespacePath := range namespaces { - var ns *namespace.Namespace - if namespacePath == "" { - ns = namespace.RootNamespace - } else { - ns = &namespace.Namespace{ - ID: namespacePath, - Path: namespacePath, - CustomMetadata: nil, - } - } - id, err := uuid.GenerateUUID() - if err != nil { - core.Logger().Info("Error generating UUID, exiting sender", "error", err) - return err - } - err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{ - Id: id, - Metadata: nil, - EntityIds: nil, - Note: "testing", - }) - if err != nil { - core.Logger().Info("Error sending event, exiting sender", "error", err) - return err - } + ns := namespace.RootNamespace + id, err := uuid.GenerateUUID() + if err != nil { + core.Logger().Info("Error generating UUID, exiting sender", "error", err) + return err + } + err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{ + Id: id, + Metadata: nil, + EntityIds: nil, + Note: "testing", + }) + if err != nil { + core.Logger().Info("Error sending event, exiting sender", "error", err) + return err } - return nil } @@ -213,13 +192,14 @@ func TestEventsSubscribeNamespaces(t *testing.T) { namespaces []string expectedEvents int }{ - {"invalid", []string{"something"}, 1}, - {"simple wildcard", []string{"ns*"}, 5}, - {"two namespaces", []string{"ns1/ns13", "ns1/other"}, 2}, + // We only send events in the root namespace, but we test all the various patterns of namespace patterns. + {"single", []string{"something"}, 1}, + {"simple wildcard", []string{"ns*"}, 1}, + {"two namespaces", []string{"ns1/ns13", "ns1/other"}, 1}, {"no namespace", []string{""}, 1}, - {"all wildcard", []string{"*"}, 5}, - {"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 4}, - {"overlapping wildcard", []string{"ns*", "ns1"}, 5}, + {"all wildcard", []string{"*"}, 1}, + {"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 1}, + {"overlapping wildcard", []string{"ns*", "ns1"}, 1}, } for _, testCase := range testCases { @@ -246,8 +226,13 @@ func TestEventsSubscribeNamespaces(t *testing.T) { timeout := 10 * time.Second gotEvents := 0 for { - ctx, cancel := context.WithTimeout(ctx, timeout) - t.Cleanup(func() { defer cancel() }) + // if we got as many as we expect, shorten the test, so we don't waste time, + // but still allow time for "extra" events to come in and make us fail + if gotEvents == testCase.expectedEvents { + timeout = 100 * time.Millisecond + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) _, msg, err := conn.Read(ctx) if err != nil { @@ -263,12 +248,6 @@ func TestEventsSubscribeNamespaces(t *testing.T) { t.Log("event received", string(msg)) gotEvents += 1 - - // if we got as many as we expect, shorten the test, so we don't waste time, - // but still allow time for "extra" events to come in and make us fail - if gotEvents == testCase.expectedEvents { - timeout = 100 * time.Millisecond - } } assert.Equal(t, testCase.expectedEvents, gotEvents) diff --git a/http/http_test.go b/http/http_test.go index f2283eec98..1cd5d0d518 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" ) @@ -25,36 +25,40 @@ func testHttpGet(t *testing.T, token string, addr string) *http.Response { loggedToken = "" } t.Logf("Token is %s", loggedToken) - return testHttpData(t, "GET", token, addr, nil, false, 0) + return testHttpData(t, "GET", token, addr, "", nil, false, 0) } func testHttpDelete(t *testing.T, token string, addr string) *http.Response { - return testHttpData(t, "DELETE", token, addr, nil, false, 0) + return testHttpData(t, "DELETE", token, addr, "", nil, false, 0) } // Go 1.8+ clients redirect automatically which breaks our 307 standby testing func testHttpDeleteDisableRedirect(t *testing.T, token string, addr string) *http.Response { - return testHttpData(t, "DELETE", token, addr, nil, true, 0) + return testHttpData(t, "DELETE", token, addr, "", nil, true, 0) } func testHttpPostWrapped(t *testing.T, token string, addr string, body interface{}, wrapTTL time.Duration) *http.Response { - return testHttpData(t, "POST", token, addr, body, false, wrapTTL) + return testHttpData(t, "POST", token, addr, "", body, false, wrapTTL) } func testHttpPost(t *testing.T, token string, addr string, body interface{}) *http.Response { - return testHttpData(t, "POST", token, addr, body, false, 0) + return testHttpData(t, "POST", token, addr, "", body, false, 0) +} + +func testHttpPostNamespace(t *testing.T, token string, addr string, namespace string, body interface{}) *http.Response { + return testHttpData(t, "POST", token, addr, namespace, body, false, 0) } func testHttpPut(t *testing.T, token string, addr string, body interface{}) *http.Response { - return testHttpData(t, "PUT", token, addr, body, false, 0) + return testHttpData(t, "PUT", token, addr, "", body, false, 0) } // Go 1.8+ clients redirect automatically which breaks our 307 standby testing func testHttpPutDisableRedirect(t *testing.T, token string, addr string, body interface{}) *http.Response { - return testHttpData(t, "PUT", token, addr, body, true, 0) + return testHttpData(t, "PUT", token, addr, "", body, true, 0) } -func testHttpData(t *testing.T, method string, token string, addr string, body interface{}, disableRedirect bool, wrapTTL time.Duration) *http.Response { +func testHttpData(t *testing.T, method string, token string, addr string, namespace string, body interface{}, disableRedirect bool, wrapTTL time.Duration) *http.Response { bodyReader := new(bytes.Buffer) if body != nil { enc := json.NewEncoder(bodyReader) @@ -78,6 +82,9 @@ func testHttpData(t *testing.T, method string, token string, addr string, body i if wrapTTL > 0 { req.Header.Set("X-Vault-Wrap-TTL", wrapTTL.String()) } + if namespace != "" { + req.Header.Set("X-Vault-Namespace", namespace) + } if len(token) != 0 { req.Header.Set(consts.AuthHeaderName, token) diff --git a/vault/acl.go b/vault/acl.go index 7706b62ac1..b3060df1df 100644 --- a/vault/acl.go +++ b/vault/acl.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "reflect" + "slices" "sort" "strings" @@ -52,13 +53,14 @@ type AuthResults struct { } type ACLResults struct { - Allowed bool - RootPrivs bool - IsRoot bool - MFAMethods []string - ControlGroup *ControlGroup - CapabilitiesBitmap uint32 - GrantingPolicies []logical.PolicyInfo + Allowed bool + RootPrivs bool + IsRoot bool + MFAMethods []string + ControlGroup *ControlGroup + CapabilitiesBitmap uint32 + GrantingPolicies []logical.PolicyInfo + SubscribeEventTypes []string } type SentinelResults struct { @@ -274,6 +276,12 @@ func NewACL(ctx context.Context, policies []*Policy) (*ACL, error) { } } + if len(pc.Permissions.SubscribeEventTypes) > 0 { + if len(existingPerms.SubscribeEventTypes) > 0 { + existingPerms.SubscribeEventTypes = strutil.RemoveDuplicates(append(existingPerms.SubscribeEventTypes, pc.Permissions.SubscribeEventTypes...), false) + } + } + INSERT: switch { case pc.HasSegmentWildcards: @@ -286,7 +294,7 @@ func NewACL(ctx context.Context, policies []*Policy) (*ACL, error) { return a, nil } -func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities []string) { +func (a *ACL) CapabilitiesAndSubscribeEventTypes(ctx context.Context, path string) (pathCapabilities []string, subscribeEventTypes []string) { req := &logical.Request{ Path: path, // doesn't matter, but use List to trigger fallback behavior so we can @@ -296,9 +304,9 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [ res := a.AllowOperation(ctx, req, true) if res.IsRoot { - return []string{RootCapability} + return []string{RootCapability}, []string{"*"} } - + subscribeEventTypes = res.SubscribeEventTypes capabilities := res.CapabilitiesBitmap if capabilities&SudoCapabilityInt > 0 { @@ -331,9 +339,15 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [ if capabilities&DenyCapabilityInt > 0 || len(pathCapabilities) == 0 { pathCapabilities = []string{DenyCapability} } + return } +func (a *ACL) Capabilities(ctx context.Context, path string) []string { + pathCapabilities, _ := a.CapabilitiesAndSubscribeEventTypes(ctx, path) + return pathCapabilities +} + // AllowOperation is used to check if the given operation is permitted. func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheckOnly bool) (ret *ACLResults) { ret = new(ACLResults) @@ -349,6 +363,7 @@ func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheck NamespacePath: "", Type: "acl", }} + ret.SubscribeEventTypes = []string{"*"} return } op := req.Operation @@ -414,6 +429,7 @@ CHECK: // rather than policy root if capCheckOnly { ret.CapabilitiesBitmap = capabilities + ret.SubscribeEventTypes = slices.Clone(permissions.SubscribeEventTypes) return ret } diff --git a/vault/capabilities.go b/vault/capabilities.go index 2d935ad42f..8310aaf8e4 100644 --- a/vault/capabilities.go +++ b/vault/capabilities.go @@ -12,30 +12,38 @@ import ( ) // Capabilities is used to fetch the capabilities of the given token on the -// given path +// given path. func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, error) { + capabilities, _, err := c.CapabilitiesAndSubscribeEventTypes(ctx, token, path) + return capabilities, err +} + +// CapabilitiesAndSubscribeEventTypes is used to fetch the capabilities and event types that are allowed to +// be subscribed to by given token on the given path. +func (c *Core) CapabilitiesAndSubscribeEventTypes(ctx context.Context, token, path string) ([]string, []string, error) { if path == "" { - return nil, &logical.StatusBadRequest{Err: "missing path"} + return nil, nil, &logical.StatusBadRequest{Err: "missing path"} } if token == "" { - return nil, &logical.StatusBadRequest{Err: "missing token"} + return nil, nil, &logical.StatusBadRequest{Err: "missing token"} } te, err := c.tokenStore.Lookup(ctx, token) if err != nil { - return nil, err + return nil, nil, err } if te == nil { - return nil, &logical.StatusBadRequest{Err: "invalid token"} + return nil, nil, &logical.StatusBadRequest{Err: "invalid token"} } - tokenNS, err := NamespaceByID(ctx, te.NamespaceID, c) + var tokenNS *namespace.Namespace + tokenNS, err = NamespaceByID(ctx, te.NamespaceID, c) if err != nil { - return nil, err + return nil, nil, err } if tokenNS == nil { - return nil, namespace.ErrNoNamespace + return nil, nil, namespace.ErrNoNamespace } var policyCount int @@ -45,15 +53,15 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, entity, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, tokenNS, te.EntityID, te.NoIdentityPolicies) if err != nil { - return nil, err + return nil, nil, err } if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") - return nil, logical.ErrPermissionDenied + return nil, nil, logical.ErrPermissionDenied } if te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") - return nil, logical.ErrPermissionDenied + return nil, nil, logical.ErrPermissionDenied } for nsID, nsPolicies := range identityPolicies { @@ -66,14 +74,14 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, if te.InlinePolicy != "" { inlinePolicy, err := ParseACLPolicy(tokenNS, te.InlinePolicy) if err != nil { - return nil, err + return nil, nil, err } policies = append(policies, inlinePolicy) policyCount++ } if policyCount == 0 { - return []string{DenyCapability}, nil + return []string{DenyCapability}, nil, nil } // Construct the corresponding ACL object. ACL construction should be @@ -81,10 +89,10 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, tokenCtx := namespace.ContextWithNamespace(ctx, tokenNS) acl, err := c.policyStore.ACL(tokenCtx, entity, policyNames, policies...) if err != nil { - return nil, err + return nil, nil, err } - capabilities := acl.Capabilities(ctx, path) + capabilities, eventTypes := acl.CapabilitiesAndSubscribeEventTypes(ctx, path) sort.Strings(capabilities) - return capabilities, nil + return capabilities, eventTypes, nil }