From f0a23e117f8b10cb34a051a00be5c924b042be1c Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Tue, 5 Sep 2023 16:28:09 -0700 Subject: [PATCH] events: Continuously verify policies (#22705) Previously, when a user initiated a websocket subscription, the access to the `sys/events/subscribe` endpoint was checked then, and only once. Now, perform continuous policy checks: * We check access to the `sys/events/subscribe` endpoint every five minutes. If this check fails, then the websocket is terminated. * Upon receiving any message, we verify that the `subscribe` capability is present for that namespace, data path, and event type. If it is not, then the message is not delivered. If the message is allowed, we cache that result for five minutes. Tests for this are in a separate enterprise PR. Documentation will be updated in another PR. Co-authored-by: Tom Proctor Co-authored-by: Nick Cabatoff --- http/events.go | 166 ++++++++++++++++++++++++++++++++++++++---- http/events_test.go | 81 ++++++++------------- http/http_test.go | 25 ++++--- vault/acl.go | 36 ++++++--- vault/capabilities.go | 40 ++++++---- 5 files changed, 248 insertions(+), 100 deletions(-) diff --git a/http/events.go b/http/events.go index 1958fdb6c1..fac54a6f1d 100644 --- a/http/events.go +++ b/http/events.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "path" + "slices" "strconv" "strings" "time" @@ -19,25 +20,35 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault/eventbus" + "github.com/patrickmn/go-cache" + "github.com/ryanuber/go-glob" "nhooyr.io/websocket" ) -type eventSubscribeArgs struct { +// webSocketRevalidationTime is how often we re-check access to the +// events that the websocket requested access to. +var webSocketRevalidationTime = 5 * time.Minute + +type eventSubscriber struct { ctx context.Context + clientToken string + capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error) logger hclog.Logger events *eventbus.EventBus namespacePatterns []string pattern string conn *websocket.Conn json bool + checkCache *cache.Cache + isRootToken bool } -// handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason -// only if the connection closes or there was an error. -func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) { - ctx := args.ctx - logger := args.logger - ch, cancel, err := args.events.SubscribeMultipleNamespaces(ctx, args.namespacePatterns, args.pattern) +// handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket +// error code and reason only if the connection closes or there was an error. +func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) { + ctx := sub.ctx + logger := sub.logger + ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern) if err != nil { logger.Info("Error subscribing", "error", err) return websocket.StatusUnsupportedData, "Error subscribing", nil @@ -50,10 +61,18 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo logger.Info("Websocket context is done, closing the connection") return websocket.StatusNormalClosure, "", nil case message := <-ch: + // Perform one last check that the message is allowed to be received. + // For example, if a new namespace was created that matches the namespace patterns, + // but the token doesn't have access to it, we don't want to accidentally send it to + // the websocket. + if !sub.allowMessageCached(message.Payload.(*logical.EventReceived)) { + continue + } + logger.Debug("Sending message to websocket", "message", message.Payload) var messageBytes []byte var messageType websocket.MessageType - if args.json { + if sub.json { var ok bool messageBytes, ok = message.Format("cloudevents-json") if !ok { @@ -69,7 +88,7 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo logger.Warn("Could not serialize websocket event", "error", err) return 0, "", err } - err = args.conn.Write(ctx, messageType, messageBytes) + err = sub.conn.Write(ctx, messageType, messageBytes) if err != nil { return 0, "", err } @@ -77,6 +96,80 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo } } +// allowMessageCached checks that the message is allowed to received by the websocket. +// It caches results for specific namespaces, data paths, and event types. +func (sub *eventSubscriber) allowMessageCached(message *logical.EventReceived) bool { + if sub.isRootToken { + // fast-path root tokens + return true + } + + messageNs := strings.Trim(message.Namespace, "/") + dataPath := "" + if message.Event.Metadata != nil { + dataPathField := message.Event.Metadata.GetFields()[logical.EventMetadataDataPath] + if dataPathField != nil { + dataPath = dataPathField.GetStringValue() + } + } + if dataPath == "" { + // Only allow root tokens to subscribe to events with no data path, for now. + return false + } + cacheKey := fmt.Sprintf("%v!%v!%v", messageNs, dataPath, message.EventType) + _, ok := sub.checkCache.Get(cacheKey) + if ok { + return true + } + + // perform the actual check and cache it if true + ok = sub.allowMessage(messageNs, dataPath, message.EventType) + if ok { + err := sub.checkCache.Add(cacheKey, ok, webSocketRevalidationTime) + if err != nil { + sub.logger.Debug("Error adding to policy check cache for websocket", "error", err) + // still return the right value, but we can't guarantee it was cached + } + } + return ok +} + +// allowMessage checks that the message is allowed to received by the websocket +func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bool { + // does this even match the requested namespaces + matchedNs := false + for _, nsPattern := range sub.namespacePatterns { + if glob.Glob(nsPattern, eventNs) { + matchedNs = true + break + } + } + if !matchedNs { + return false + } + + // next check for specific access to the namespace and event types + nsDataPath := dataPath + if eventNs != "" { + nsDataPath = path.Join(eventNs, dataPath) + } + capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath) + if err != nil { + sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs) + return false + } + if !(slices.Contains(capabilities, vault.RootCapability) || slices.Contains(capabilities, vault.SubscribeCapability)) { + return false + } + for _, pattern := range allowedEventTypes { + if glob.Glob(pattern, eventType) { + return true + } + } + // no event types matched, so return false + return false +} + func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger := core.Logger().Named("events-subscribe") @@ -85,7 +178,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler ctx := r.Context() // ACL check - _, _, err := core.CheckToken(ctx, req, false) + auth, entry, err := core.CheckToken(ctx, req, false) if err != nil { if errors.Is(err, logical.ErrPermissionDenied) { respondError(w, http.StatusForbidden, logical.ErrPermissionDenied) @@ -146,7 +239,25 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler } }() - closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), namespacePatterns, pattern, conn, json}) + // continually validate subscribe access while the websocket is running + ctx, cancelCtx := context.WithCancel(ctx) + defer cancelCtx() + go validateSubscribeAccessLoop(core, ctx, cancelCtx, req) + + sub := &eventSubscriber{ + ctx: ctx, + capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes, + logger: logger, + events: core.Events(), + namespacePatterns: namespacePatterns, + pattern: pattern, + conn: conn, + json: json, + checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime), + clientToken: auth.ClientToken, + isRootToken: entry.IsRoot(), + } + closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket() if err != nil { closeStatus = websocket.CloseStatus(err) if closeStatus == -1 { @@ -174,10 +285,37 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam newPatterns := make([]string, 0, len(patterns)+1) newPatterns = append(newPatterns, prepend) for _, pattern := range patterns { - if strings.Trim(strings.TrimSpace(pattern), "/") == "" { - continue + if strings.Trim(pattern, "/") != "" { + newPatterns = append(newPatterns, path.Join(prepend, pattern)) } - newPatterns = append(newPatterns, path.Join(prepend, pattern, "/")) } return newPatterns } + +// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in +// its namespace. If the access check ever fails, then the cancel function is called and the function returns. +func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) { + // if something breaks, default to canceling the websocket + defer cancel() + for { + _, _, err := core.CheckToken(ctx, req, false) + if err != nil { + core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err) + return + } + // wait a while and try again, but quit the loop if the context finishes early + finished := func() bool { + ticker := time.NewTicker(webSocketRevalidationTime) + defer ticker.Stop() + select { + case <-ctx.Done(): + return true + case <-ticker.C: + return false + } + }() + if finished { + return + } + } +} diff --git a/http/events_test.go b/http/events_test.go index 21e4f54ff2..fc36d68b1a 100644 --- a/http/events_test.go +++ b/http/events_test.go @@ -135,8 +135,7 @@ func TestEventsSubscribe(t *testing.T) { } } -// TestEventsSubscribeNamespaces tests the websocket endpoint for subscribing to events in multiple namespaces. -func TestEventsSubscribeNamespaces(t *testing.T) { +func TestNamespaceRootSubscriptions(t *testing.T) { core := vault.TestCoreWithConfig(t, &vault.CoreConfig{ Experiments: []string{experiments.VaultExperimentEventsAlpha1}, }) @@ -157,47 +156,27 @@ func TestEventsSubscribeNamespaces(t *testing.T) { const eventType = "abc" - namespaces := []string{ - "", - "ns1", - "ns2", - "ns1/ns13", - "ns1/ns13/ns134", - } - // send some events with the specified namespaces sendEvents := func() error { pluginInfo := &logical.EventPluginInfo{ MountPath: "secret", } - for _, namespacePath := range namespaces { - var ns *namespace.Namespace - if namespacePath == "" { - ns = namespace.RootNamespace - } else { - ns = &namespace.Namespace{ - ID: namespacePath, - Path: namespacePath, - CustomMetadata: nil, - } - } - id, err := uuid.GenerateUUID() - if err != nil { - core.Logger().Info("Error generating UUID, exiting sender", "error", err) - return err - } - err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{ - Id: id, - Metadata: nil, - EntityIds: nil, - Note: "testing", - }) - if err != nil { - core.Logger().Info("Error sending event, exiting sender", "error", err) - return err - } + ns := namespace.RootNamespace + id, err := uuid.GenerateUUID() + if err != nil { + core.Logger().Info("Error generating UUID, exiting sender", "error", err) + return err + } + err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{ + Id: id, + Metadata: nil, + EntityIds: nil, + Note: "testing", + }) + if err != nil { + core.Logger().Info("Error sending event, exiting sender", "error", err) + return err } - return nil } @@ -213,13 +192,14 @@ func TestEventsSubscribeNamespaces(t *testing.T) { namespaces []string expectedEvents int }{ - {"invalid", []string{"something"}, 1}, - {"simple wildcard", []string{"ns*"}, 5}, - {"two namespaces", []string{"ns1/ns13", "ns1/other"}, 2}, + // We only send events in the root namespace, but we test all the various patterns of namespace patterns. + {"single", []string{"something"}, 1}, + {"simple wildcard", []string{"ns*"}, 1}, + {"two namespaces", []string{"ns1/ns13", "ns1/other"}, 1}, {"no namespace", []string{""}, 1}, - {"all wildcard", []string{"*"}, 5}, - {"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 4}, - {"overlapping wildcard", []string{"ns*", "ns1"}, 5}, + {"all wildcard", []string{"*"}, 1}, + {"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 1}, + {"overlapping wildcard", []string{"ns*", "ns1"}, 1}, } for _, testCase := range testCases { @@ -246,8 +226,13 @@ func TestEventsSubscribeNamespaces(t *testing.T) { timeout := 10 * time.Second gotEvents := 0 for { - ctx, cancel := context.WithTimeout(ctx, timeout) - t.Cleanup(func() { defer cancel() }) + // if we got as many as we expect, shorten the test, so we don't waste time, + // but still allow time for "extra" events to come in and make us fail + if gotEvents == testCase.expectedEvents { + timeout = 100 * time.Millisecond + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) _, msg, err := conn.Read(ctx) if err != nil { @@ -263,12 +248,6 @@ func TestEventsSubscribeNamespaces(t *testing.T) { t.Log("event received", string(msg)) gotEvents += 1 - - // if we got as many as we expect, shorten the test, so we don't waste time, - // but still allow time for "extra" events to come in and make us fail - if gotEvents == testCase.expectedEvents { - timeout = 100 * time.Millisecond - } } assert.Equal(t, testCase.expectedEvents, gotEvents) diff --git a/http/http_test.go b/http/http_test.go index f2283eec98..1cd5d0d518 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" ) @@ -25,36 +25,40 @@ func testHttpGet(t *testing.T, token string, addr string) *http.Response { loggedToken = "" } t.Logf("Token is %s", loggedToken) - return testHttpData(t, "GET", token, addr, nil, false, 0) + return testHttpData(t, "GET", token, addr, "", nil, false, 0) } func testHttpDelete(t *testing.T, token string, addr string) *http.Response { - return testHttpData(t, "DELETE", token, addr, nil, false, 0) + return testHttpData(t, "DELETE", token, addr, "", nil, false, 0) } // Go 1.8+ clients redirect automatically which breaks our 307 standby testing func testHttpDeleteDisableRedirect(t *testing.T, token string, addr string) *http.Response { - return testHttpData(t, "DELETE", token, addr, nil, true, 0) + return testHttpData(t, "DELETE", token, addr, "", nil, true, 0) } func testHttpPostWrapped(t *testing.T, token string, addr string, body interface{}, wrapTTL time.Duration) *http.Response { - return testHttpData(t, "POST", token, addr, body, false, wrapTTL) + return testHttpData(t, "POST", token, addr, "", body, false, wrapTTL) } func testHttpPost(t *testing.T, token string, addr string, body interface{}) *http.Response { - return testHttpData(t, "POST", token, addr, body, false, 0) + return testHttpData(t, "POST", token, addr, "", body, false, 0) +} + +func testHttpPostNamespace(t *testing.T, token string, addr string, namespace string, body interface{}) *http.Response { + return testHttpData(t, "POST", token, addr, namespace, body, false, 0) } func testHttpPut(t *testing.T, token string, addr string, body interface{}) *http.Response { - return testHttpData(t, "PUT", token, addr, body, false, 0) + return testHttpData(t, "PUT", token, addr, "", body, false, 0) } // Go 1.8+ clients redirect automatically which breaks our 307 standby testing func testHttpPutDisableRedirect(t *testing.T, token string, addr string, body interface{}) *http.Response { - return testHttpData(t, "PUT", token, addr, body, true, 0) + return testHttpData(t, "PUT", token, addr, "", body, true, 0) } -func testHttpData(t *testing.T, method string, token string, addr string, body interface{}, disableRedirect bool, wrapTTL time.Duration) *http.Response { +func testHttpData(t *testing.T, method string, token string, addr string, namespace string, body interface{}, disableRedirect bool, wrapTTL time.Duration) *http.Response { bodyReader := new(bytes.Buffer) if body != nil { enc := json.NewEncoder(bodyReader) @@ -78,6 +82,9 @@ func testHttpData(t *testing.T, method string, token string, addr string, body i if wrapTTL > 0 { req.Header.Set("X-Vault-Wrap-TTL", wrapTTL.String()) } + if namespace != "" { + req.Header.Set("X-Vault-Namespace", namespace) + } if len(token) != 0 { req.Header.Set(consts.AuthHeaderName, token) diff --git a/vault/acl.go b/vault/acl.go index 7706b62ac1..b3060df1df 100644 --- a/vault/acl.go +++ b/vault/acl.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "reflect" + "slices" "sort" "strings" @@ -52,13 +53,14 @@ type AuthResults struct { } type ACLResults struct { - Allowed bool - RootPrivs bool - IsRoot bool - MFAMethods []string - ControlGroup *ControlGroup - CapabilitiesBitmap uint32 - GrantingPolicies []logical.PolicyInfo + Allowed bool + RootPrivs bool + IsRoot bool + MFAMethods []string + ControlGroup *ControlGroup + CapabilitiesBitmap uint32 + GrantingPolicies []logical.PolicyInfo + SubscribeEventTypes []string } type SentinelResults struct { @@ -274,6 +276,12 @@ func NewACL(ctx context.Context, policies []*Policy) (*ACL, error) { } } + if len(pc.Permissions.SubscribeEventTypes) > 0 { + if len(existingPerms.SubscribeEventTypes) > 0 { + existingPerms.SubscribeEventTypes = strutil.RemoveDuplicates(append(existingPerms.SubscribeEventTypes, pc.Permissions.SubscribeEventTypes...), false) + } + } + INSERT: switch { case pc.HasSegmentWildcards: @@ -286,7 +294,7 @@ func NewACL(ctx context.Context, policies []*Policy) (*ACL, error) { return a, nil } -func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities []string) { +func (a *ACL) CapabilitiesAndSubscribeEventTypes(ctx context.Context, path string) (pathCapabilities []string, subscribeEventTypes []string) { req := &logical.Request{ Path: path, // doesn't matter, but use List to trigger fallback behavior so we can @@ -296,9 +304,9 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [ res := a.AllowOperation(ctx, req, true) if res.IsRoot { - return []string{RootCapability} + return []string{RootCapability}, []string{"*"} } - + subscribeEventTypes = res.SubscribeEventTypes capabilities := res.CapabilitiesBitmap if capabilities&SudoCapabilityInt > 0 { @@ -331,9 +339,15 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [ if capabilities&DenyCapabilityInt > 0 || len(pathCapabilities) == 0 { pathCapabilities = []string{DenyCapability} } + return } +func (a *ACL) Capabilities(ctx context.Context, path string) []string { + pathCapabilities, _ := a.CapabilitiesAndSubscribeEventTypes(ctx, path) + return pathCapabilities +} + // AllowOperation is used to check if the given operation is permitted. func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheckOnly bool) (ret *ACLResults) { ret = new(ACLResults) @@ -349,6 +363,7 @@ func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheck NamespacePath: "", Type: "acl", }} + ret.SubscribeEventTypes = []string{"*"} return } op := req.Operation @@ -414,6 +429,7 @@ CHECK: // rather than policy root if capCheckOnly { ret.CapabilitiesBitmap = capabilities + ret.SubscribeEventTypes = slices.Clone(permissions.SubscribeEventTypes) return ret } diff --git a/vault/capabilities.go b/vault/capabilities.go index 2d935ad42f..8310aaf8e4 100644 --- a/vault/capabilities.go +++ b/vault/capabilities.go @@ -12,30 +12,38 @@ import ( ) // Capabilities is used to fetch the capabilities of the given token on the -// given path +// given path. func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, error) { + capabilities, _, err := c.CapabilitiesAndSubscribeEventTypes(ctx, token, path) + return capabilities, err +} + +// CapabilitiesAndSubscribeEventTypes is used to fetch the capabilities and event types that are allowed to +// be subscribed to by given token on the given path. +func (c *Core) CapabilitiesAndSubscribeEventTypes(ctx context.Context, token, path string) ([]string, []string, error) { if path == "" { - return nil, &logical.StatusBadRequest{Err: "missing path"} + return nil, nil, &logical.StatusBadRequest{Err: "missing path"} } if token == "" { - return nil, &logical.StatusBadRequest{Err: "missing token"} + return nil, nil, &logical.StatusBadRequest{Err: "missing token"} } te, err := c.tokenStore.Lookup(ctx, token) if err != nil { - return nil, err + return nil, nil, err } if te == nil { - return nil, &logical.StatusBadRequest{Err: "invalid token"} + return nil, nil, &logical.StatusBadRequest{Err: "invalid token"} } - tokenNS, err := NamespaceByID(ctx, te.NamespaceID, c) + var tokenNS *namespace.Namespace + tokenNS, err = NamespaceByID(ctx, te.NamespaceID, c) if err != nil { - return nil, err + return nil, nil, err } if tokenNS == nil { - return nil, namespace.ErrNoNamespace + return nil, nil, namespace.ErrNoNamespace } var policyCount int @@ -45,15 +53,15 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, entity, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, tokenNS, te.EntityID, te.NoIdentityPolicies) if err != nil { - return nil, err + return nil, nil, err } if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") - return nil, logical.ErrPermissionDenied + return nil, nil, logical.ErrPermissionDenied } if te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") - return nil, logical.ErrPermissionDenied + return nil, nil, logical.ErrPermissionDenied } for nsID, nsPolicies := range identityPolicies { @@ -66,14 +74,14 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, if te.InlinePolicy != "" { inlinePolicy, err := ParseACLPolicy(tokenNS, te.InlinePolicy) if err != nil { - return nil, err + return nil, nil, err } policies = append(policies, inlinePolicy) policyCount++ } if policyCount == 0 { - return []string{DenyCapability}, nil + return []string{DenyCapability}, nil, nil } // Construct the corresponding ACL object. ACL construction should be @@ -81,10 +89,10 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, tokenCtx := namespace.ContextWithNamespace(ctx, tokenNS) acl, err := c.policyStore.ACL(tokenCtx, entity, policyNames, policies...) if err != nil { - return nil, err + return nil, nil, err } - capabilities := acl.Capabilities(ctx, path) + capabilities, eventTypes := acl.CapabilitiesAndSubscribeEventTypes(ctx, path) sort.Strings(capabilities) - return capabilities, nil + return capabilities, eventTypes, nil }