events: Continuously verify policies (#22705)

Previously, when a user initiated a websocket subscription,
the access to the `sys/events/subscribe` endpoint was checked then,
and only once.

Now, perform continuous policy checks:

* We check access to the `sys/events/subscribe` endpoint every five
  minutes. If this check fails, then the websocket is terminated.
* Upon receiving any message, we verify that the `subscribe`
  capability is present for that namespace, data path, and event type.
  If it is not, then the message is not delivered. If the message is
  allowed, we cache that result for five minutes.

Tests for this are in a separate enterprise PR.

Documentation will be updated in another PR.

Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
Christopher Swenson 2023-09-05 16:28:09 -07:00 committed by GitHub
parent 545b6e4eae
commit f0a23e117f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 248 additions and 100 deletions

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"path" "path"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -19,25 +20,35 @@ import (
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/eventbus" "github.com/hashicorp/vault/vault/eventbus"
"github.com/patrickmn/go-cache"
"github.com/ryanuber/go-glob"
"nhooyr.io/websocket" "nhooyr.io/websocket"
) )
type eventSubscribeArgs struct { // webSocketRevalidationTime is how often we re-check access to the
// events that the websocket requested access to.
var webSocketRevalidationTime = 5 * time.Minute
type eventSubscriber struct {
ctx context.Context ctx context.Context
clientToken string
capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error)
logger hclog.Logger logger hclog.Logger
events *eventbus.EventBus events *eventbus.EventBus
namespacePatterns []string namespacePatterns []string
pattern string pattern string
conn *websocket.Conn conn *websocket.Conn
json bool json bool
checkCache *cache.Cache
isRootToken bool
} }
// handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason // handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket
// only if the connection closes or there was an error. // error code and reason only if the connection closes or there was an error.
func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) { func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) {
ctx := args.ctx ctx := sub.ctx
logger := args.logger logger := sub.logger
ch, cancel, err := args.events.SubscribeMultipleNamespaces(ctx, args.namespacePatterns, args.pattern) ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern)
if err != nil { if err != nil {
logger.Info("Error subscribing", "error", err) logger.Info("Error subscribing", "error", err)
return websocket.StatusUnsupportedData, "Error subscribing", nil return websocket.StatusUnsupportedData, "Error subscribing", nil
@ -50,10 +61,18 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo
logger.Info("Websocket context is done, closing the connection") logger.Info("Websocket context is done, closing the connection")
return websocket.StatusNormalClosure, "", nil return websocket.StatusNormalClosure, "", nil
case message := <-ch: case message := <-ch:
// Perform one last check that the message is allowed to be received.
// For example, if a new namespace was created that matches the namespace patterns,
// but the token doesn't have access to it, we don't want to accidentally send it to
// the websocket.
if !sub.allowMessageCached(message.Payload.(*logical.EventReceived)) {
continue
}
logger.Debug("Sending message to websocket", "message", message.Payload) logger.Debug("Sending message to websocket", "message", message.Payload)
var messageBytes []byte var messageBytes []byte
var messageType websocket.MessageType var messageType websocket.MessageType
if args.json { if sub.json {
var ok bool var ok bool
messageBytes, ok = message.Format("cloudevents-json") messageBytes, ok = message.Format("cloudevents-json")
if !ok { if !ok {
@ -69,7 +88,7 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo
logger.Warn("Could not serialize websocket event", "error", err) logger.Warn("Could not serialize websocket event", "error", err)
return 0, "", err return 0, "", err
} }
err = args.conn.Write(ctx, messageType, messageBytes) err = sub.conn.Write(ctx, messageType, messageBytes)
if err != nil { if err != nil {
return 0, "", err return 0, "", err
} }
@ -77,6 +96,80 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo
} }
} }
// allowMessageCached checks that the message is allowed to received by the websocket.
// It caches results for specific namespaces, data paths, and event types.
func (sub *eventSubscriber) allowMessageCached(message *logical.EventReceived) bool {
if sub.isRootToken {
// fast-path root tokens
return true
}
messageNs := strings.Trim(message.Namespace, "/")
dataPath := ""
if message.Event.Metadata != nil {
dataPathField := message.Event.Metadata.GetFields()[logical.EventMetadataDataPath]
if dataPathField != nil {
dataPath = dataPathField.GetStringValue()
}
}
if dataPath == "" {
// Only allow root tokens to subscribe to events with no data path, for now.
return false
}
cacheKey := fmt.Sprintf("%v!%v!%v", messageNs, dataPath, message.EventType)
_, ok := sub.checkCache.Get(cacheKey)
if ok {
return true
}
// perform the actual check and cache it if true
ok = sub.allowMessage(messageNs, dataPath, message.EventType)
if ok {
err := sub.checkCache.Add(cacheKey, ok, webSocketRevalidationTime)
if err != nil {
sub.logger.Debug("Error adding to policy check cache for websocket", "error", err)
// still return the right value, but we can't guarantee it was cached
}
}
return ok
}
// allowMessage checks that the message is allowed to received by the websocket
func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bool {
// does this even match the requested namespaces
matchedNs := false
for _, nsPattern := range sub.namespacePatterns {
if glob.Glob(nsPattern, eventNs) {
matchedNs = true
break
}
}
if !matchedNs {
return false
}
// next check for specific access to the namespace and event types
nsDataPath := dataPath
if eventNs != "" {
nsDataPath = path.Join(eventNs, dataPath)
}
capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath)
if err != nil {
sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs)
return false
}
if !(slices.Contains(capabilities, vault.RootCapability) || slices.Contains(capabilities, vault.SubscribeCapability)) {
return false
}
for _, pattern := range allowedEventTypes {
if glob.Glob(pattern, eventType) {
return true
}
}
// no event types matched, so return false
return false
}
func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler { func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger := core.Logger().Named("events-subscribe") logger := core.Logger().Named("events-subscribe")
@ -85,7 +178,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
ctx := r.Context() ctx := r.Context()
// ACL check // ACL check
_, _, err := core.CheckToken(ctx, req, false) auth, entry, err := core.CheckToken(ctx, req, false)
if err != nil { if err != nil {
if errors.Is(err, logical.ErrPermissionDenied) { if errors.Is(err, logical.ErrPermissionDenied) {
respondError(w, http.StatusForbidden, logical.ErrPermissionDenied) respondError(w, http.StatusForbidden, logical.ErrPermissionDenied)
@ -146,7 +239,25 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
} }
}() }()
closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), namespacePatterns, pattern, conn, json}) // continually validate subscribe access while the websocket is running
ctx, cancelCtx := context.WithCancel(ctx)
defer cancelCtx()
go validateSubscribeAccessLoop(core, ctx, cancelCtx, req)
sub := &eventSubscriber{
ctx: ctx,
capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes,
logger: logger,
events: core.Events(),
namespacePatterns: namespacePatterns,
pattern: pattern,
conn: conn,
json: json,
checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime),
clientToken: auth.ClientToken,
isRootToken: entry.IsRoot(),
}
closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket()
if err != nil { if err != nil {
closeStatus = websocket.CloseStatus(err) closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 { if closeStatus == -1 {
@ -174,10 +285,37 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam
newPatterns := make([]string, 0, len(patterns)+1) newPatterns := make([]string, 0, len(patterns)+1)
newPatterns = append(newPatterns, prepend) newPatterns = append(newPatterns, prepend)
for _, pattern := range patterns { for _, pattern := range patterns {
if strings.Trim(strings.TrimSpace(pattern), "/") == "" { if strings.Trim(pattern, "/") != "" {
continue newPatterns = append(newPatterns, path.Join(prepend, pattern))
} }
newPatterns = append(newPatterns, path.Join(prepend, pattern, "/"))
} }
return newPatterns return newPatterns
} }
// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in
// its namespace. If the access check ever fails, then the cancel function is called and the function returns.
func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) {
// if something breaks, default to canceling the websocket
defer cancel()
for {
_, _, err := core.CheckToken(ctx, req, false)
if err != nil {
core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err)
return
}
// wait a while and try again, but quit the loop if the context finishes early
finished := func() bool {
ticker := time.NewTicker(webSocketRevalidationTime)
defer ticker.Stop()
select {
case <-ctx.Done():
return true
case <-ticker.C:
return false
}
}()
if finished {
return
}
}
}

View File

@ -135,8 +135,7 @@ func TestEventsSubscribe(t *testing.T) {
} }
} }
// TestEventsSubscribeNamespaces tests the websocket endpoint for subscribing to events in multiple namespaces. func TestNamespaceRootSubscriptions(t *testing.T) {
func TestEventsSubscribeNamespaces(t *testing.T) {
core := vault.TestCoreWithConfig(t, &vault.CoreConfig{ core := vault.TestCoreWithConfig(t, &vault.CoreConfig{
Experiments: []string{experiments.VaultExperimentEventsAlpha1}, Experiments: []string{experiments.VaultExperimentEventsAlpha1},
}) })
@ -157,47 +156,27 @@ func TestEventsSubscribeNamespaces(t *testing.T) {
const eventType = "abc" const eventType = "abc"
namespaces := []string{
"",
"ns1",
"ns2",
"ns1/ns13",
"ns1/ns13/ns134",
}
// send some events with the specified namespaces // send some events with the specified namespaces
sendEvents := func() error { sendEvents := func() error {
pluginInfo := &logical.EventPluginInfo{ pluginInfo := &logical.EventPluginInfo{
MountPath: "secret", MountPath: "secret",
} }
for _, namespacePath := range namespaces { ns := namespace.RootNamespace
var ns *namespace.Namespace id, err := uuid.GenerateUUID()
if namespacePath == "" { if err != nil {
ns = namespace.RootNamespace core.Logger().Info("Error generating UUID, exiting sender", "error", err)
} else { return err
ns = &namespace.Namespace{ }
ID: namespacePath, err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{
Path: namespacePath, Id: id,
CustomMetadata: nil, Metadata: nil,
} EntityIds: nil,
} Note: "testing",
id, err := uuid.GenerateUUID() })
if err != nil { if err != nil {
core.Logger().Info("Error generating UUID, exiting sender", "error", err) core.Logger().Info("Error sending event, exiting sender", "error", err)
return err return err
}
err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
Note: "testing",
})
if err != nil {
core.Logger().Info("Error sending event, exiting sender", "error", err)
return err
}
} }
return nil return nil
} }
@ -213,13 +192,14 @@ func TestEventsSubscribeNamespaces(t *testing.T) {
namespaces []string namespaces []string
expectedEvents int expectedEvents int
}{ }{
{"invalid", []string{"something"}, 1}, // We only send events in the root namespace, but we test all the various patterns of namespace patterns.
{"simple wildcard", []string{"ns*"}, 5}, {"single", []string{"something"}, 1},
{"two namespaces", []string{"ns1/ns13", "ns1/other"}, 2}, {"simple wildcard", []string{"ns*"}, 1},
{"two namespaces", []string{"ns1/ns13", "ns1/other"}, 1},
{"no namespace", []string{""}, 1}, {"no namespace", []string{""}, 1},
{"all wildcard", []string{"*"}, 5}, {"all wildcard", []string{"*"}, 1},
{"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 4}, {"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 1},
{"overlapping wildcard", []string{"ns*", "ns1"}, 5}, {"overlapping wildcard", []string{"ns*", "ns1"}, 1},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
@ -246,8 +226,13 @@ func TestEventsSubscribeNamespaces(t *testing.T) {
timeout := 10 * time.Second timeout := 10 * time.Second
gotEvents := 0 gotEvents := 0
for { for {
ctx, cancel := context.WithTimeout(ctx, timeout) // if we got as many as we expect, shorten the test, so we don't waste time,
t.Cleanup(func() { defer cancel() }) // but still allow time for "extra" events to come in and make us fail
if gotEvents == testCase.expectedEvents {
timeout = 100 * time.Millisecond
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
t.Cleanup(cancel)
_, msg, err := conn.Read(ctx) _, msg, err := conn.Read(ctx)
if err != nil { if err != nil {
@ -263,12 +248,6 @@ func TestEventsSubscribeNamespaces(t *testing.T) {
t.Log("event received", string(msg)) t.Log("event received", string(msg))
gotEvents += 1 gotEvents += 1
// if we got as many as we expect, shorten the test, so we don't waste time,
// but still allow time for "extra" events to come in and make us fail
if gotEvents == testCase.expectedEvents {
timeout = 100 * time.Millisecond
}
} }
assert.Equal(t, testCase.expectedEvents, gotEvents) assert.Equal(t, testCase.expectedEvents, gotEvents)

View File

@ -14,7 +14,7 @@ import (
"testing" "testing"
"time" "time"
cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/jsonutil"
) )
@ -25,36 +25,40 @@ func testHttpGet(t *testing.T, token string, addr string) *http.Response {
loggedToken = "<empty>" loggedToken = "<empty>"
} }
t.Logf("Token is %s", loggedToken) t.Logf("Token is %s", loggedToken)
return testHttpData(t, "GET", token, addr, nil, false, 0) return testHttpData(t, "GET", token, addr, "", nil, false, 0)
} }
func testHttpDelete(t *testing.T, token string, addr string) *http.Response { func testHttpDelete(t *testing.T, token string, addr string) *http.Response {
return testHttpData(t, "DELETE", token, addr, nil, false, 0) return testHttpData(t, "DELETE", token, addr, "", nil, false, 0)
} }
// Go 1.8+ clients redirect automatically which breaks our 307 standby testing // Go 1.8+ clients redirect automatically which breaks our 307 standby testing
func testHttpDeleteDisableRedirect(t *testing.T, token string, addr string) *http.Response { func testHttpDeleteDisableRedirect(t *testing.T, token string, addr string) *http.Response {
return testHttpData(t, "DELETE", token, addr, nil, true, 0) return testHttpData(t, "DELETE", token, addr, "", nil, true, 0)
} }
func testHttpPostWrapped(t *testing.T, token string, addr string, body interface{}, wrapTTL time.Duration) *http.Response { func testHttpPostWrapped(t *testing.T, token string, addr string, body interface{}, wrapTTL time.Duration) *http.Response {
return testHttpData(t, "POST", token, addr, body, false, wrapTTL) return testHttpData(t, "POST", token, addr, "", body, false, wrapTTL)
} }
func testHttpPost(t *testing.T, token string, addr string, body interface{}) *http.Response { func testHttpPost(t *testing.T, token string, addr string, body interface{}) *http.Response {
return testHttpData(t, "POST", token, addr, body, false, 0) return testHttpData(t, "POST", token, addr, "", body, false, 0)
}
func testHttpPostNamespace(t *testing.T, token string, addr string, namespace string, body interface{}) *http.Response {
return testHttpData(t, "POST", token, addr, namespace, body, false, 0)
} }
func testHttpPut(t *testing.T, token string, addr string, body interface{}) *http.Response { func testHttpPut(t *testing.T, token string, addr string, body interface{}) *http.Response {
return testHttpData(t, "PUT", token, addr, body, false, 0) return testHttpData(t, "PUT", token, addr, "", body, false, 0)
} }
// Go 1.8+ clients redirect automatically which breaks our 307 standby testing // Go 1.8+ clients redirect automatically which breaks our 307 standby testing
func testHttpPutDisableRedirect(t *testing.T, token string, addr string, body interface{}) *http.Response { func testHttpPutDisableRedirect(t *testing.T, token string, addr string, body interface{}) *http.Response {
return testHttpData(t, "PUT", token, addr, body, true, 0) return testHttpData(t, "PUT", token, addr, "", body, true, 0)
} }
func testHttpData(t *testing.T, method string, token string, addr string, body interface{}, disableRedirect bool, wrapTTL time.Duration) *http.Response { func testHttpData(t *testing.T, method string, token string, addr string, namespace string, body interface{}, disableRedirect bool, wrapTTL time.Duration) *http.Response {
bodyReader := new(bytes.Buffer) bodyReader := new(bytes.Buffer)
if body != nil { if body != nil {
enc := json.NewEncoder(bodyReader) enc := json.NewEncoder(bodyReader)
@ -78,6 +82,9 @@ func testHttpData(t *testing.T, method string, token string, addr string, body i
if wrapTTL > 0 { if wrapTTL > 0 {
req.Header.Set("X-Vault-Wrap-TTL", wrapTTL.String()) req.Header.Set("X-Vault-Wrap-TTL", wrapTTL.String())
} }
if namespace != "" {
req.Header.Set("X-Vault-Namespace", namespace)
}
if len(token) != 0 { if len(token) != 0 {
req.Header.Set(consts.AuthHeaderName, token) req.Header.Set(consts.AuthHeaderName, token)

View File

@ -7,6 +7,7 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"slices"
"sort" "sort"
"strings" "strings"
@ -52,13 +53,14 @@ type AuthResults struct {
} }
type ACLResults struct { type ACLResults struct {
Allowed bool Allowed bool
RootPrivs bool RootPrivs bool
IsRoot bool IsRoot bool
MFAMethods []string MFAMethods []string
ControlGroup *ControlGroup ControlGroup *ControlGroup
CapabilitiesBitmap uint32 CapabilitiesBitmap uint32
GrantingPolicies []logical.PolicyInfo GrantingPolicies []logical.PolicyInfo
SubscribeEventTypes []string
} }
type SentinelResults struct { type SentinelResults struct {
@ -274,6 +276,12 @@ func NewACL(ctx context.Context, policies []*Policy) (*ACL, error) {
} }
} }
if len(pc.Permissions.SubscribeEventTypes) > 0 {
if len(existingPerms.SubscribeEventTypes) > 0 {
existingPerms.SubscribeEventTypes = strutil.RemoveDuplicates(append(existingPerms.SubscribeEventTypes, pc.Permissions.SubscribeEventTypes...), false)
}
}
INSERT: INSERT:
switch { switch {
case pc.HasSegmentWildcards: case pc.HasSegmentWildcards:
@ -286,7 +294,7 @@ func NewACL(ctx context.Context, policies []*Policy) (*ACL, error) {
return a, nil return a, nil
} }
func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities []string) { func (a *ACL) CapabilitiesAndSubscribeEventTypes(ctx context.Context, path string) (pathCapabilities []string, subscribeEventTypes []string) {
req := &logical.Request{ req := &logical.Request{
Path: path, Path: path,
// doesn't matter, but use List to trigger fallback behavior so we can // doesn't matter, but use List to trigger fallback behavior so we can
@ -296,9 +304,9 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [
res := a.AllowOperation(ctx, req, true) res := a.AllowOperation(ctx, req, true)
if res.IsRoot { if res.IsRoot {
return []string{RootCapability} return []string{RootCapability}, []string{"*"}
} }
subscribeEventTypes = res.SubscribeEventTypes
capabilities := res.CapabilitiesBitmap capabilities := res.CapabilitiesBitmap
if capabilities&SudoCapabilityInt > 0 { if capabilities&SudoCapabilityInt > 0 {
@ -331,9 +339,15 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [
if capabilities&DenyCapabilityInt > 0 || len(pathCapabilities) == 0 { if capabilities&DenyCapabilityInt > 0 || len(pathCapabilities) == 0 {
pathCapabilities = []string{DenyCapability} pathCapabilities = []string{DenyCapability}
} }
return return
} }
func (a *ACL) Capabilities(ctx context.Context, path string) []string {
pathCapabilities, _ := a.CapabilitiesAndSubscribeEventTypes(ctx, path)
return pathCapabilities
}
// AllowOperation is used to check if the given operation is permitted. // AllowOperation is used to check if the given operation is permitted.
func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheckOnly bool) (ret *ACLResults) { func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheckOnly bool) (ret *ACLResults) {
ret = new(ACLResults) ret = new(ACLResults)
@ -349,6 +363,7 @@ func (a *ACL) AllowOperation(ctx context.Context, req *logical.Request, capCheck
NamespacePath: "", NamespacePath: "",
Type: "acl", Type: "acl",
}} }}
ret.SubscribeEventTypes = []string{"*"}
return return
} }
op := req.Operation op := req.Operation
@ -414,6 +429,7 @@ CHECK:
// rather than policy root // rather than policy root
if capCheckOnly { if capCheckOnly {
ret.CapabilitiesBitmap = capabilities ret.CapabilitiesBitmap = capabilities
ret.SubscribeEventTypes = slices.Clone(permissions.SubscribeEventTypes)
return ret return ret
} }

View File

@ -12,30 +12,38 @@ import (
) )
// Capabilities is used to fetch the capabilities of the given token on the // Capabilities is used to fetch the capabilities of the given token on the
// given path // given path.
func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, error) { func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, error) {
capabilities, _, err := c.CapabilitiesAndSubscribeEventTypes(ctx, token, path)
return capabilities, err
}
// CapabilitiesAndSubscribeEventTypes is used to fetch the capabilities and event types that are allowed to
// be subscribed to by given token on the given path.
func (c *Core) CapabilitiesAndSubscribeEventTypes(ctx context.Context, token, path string) ([]string, []string, error) {
if path == "" { if path == "" {
return nil, &logical.StatusBadRequest{Err: "missing path"} return nil, nil, &logical.StatusBadRequest{Err: "missing path"}
} }
if token == "" { if token == "" {
return nil, &logical.StatusBadRequest{Err: "missing token"} return nil, nil, &logical.StatusBadRequest{Err: "missing token"}
} }
te, err := c.tokenStore.Lookup(ctx, token) te, err := c.tokenStore.Lookup(ctx, token)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if te == nil { if te == nil {
return nil, &logical.StatusBadRequest{Err: "invalid token"} return nil, nil, &logical.StatusBadRequest{Err: "invalid token"}
} }
tokenNS, err := NamespaceByID(ctx, te.NamespaceID, c) var tokenNS *namespace.Namespace
tokenNS, err = NamespaceByID(ctx, te.NamespaceID, c)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if tokenNS == nil { if tokenNS == nil {
return nil, namespace.ErrNoNamespace return nil, nil, namespace.ErrNoNamespace
} }
var policyCount int var policyCount int
@ -45,15 +53,15 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string,
entity, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, tokenNS, te.EntityID, te.NoIdentityPolicies) entity, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, tokenNS, te.EntityID, te.NoIdentityPolicies)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if entity != nil && entity.Disabled { if entity != nil && entity.Disabled {
c.logger.Warn("permission denied as the entity on the token is disabled") c.logger.Warn("permission denied as the entity on the token is disabled")
return nil, logical.ErrPermissionDenied return nil, nil, logical.ErrPermissionDenied
} }
if te.EntityID != "" && entity == nil { if te.EntityID != "" && entity == nil {
c.logger.Warn("permission denied as the entity on the token is invalid") c.logger.Warn("permission denied as the entity on the token is invalid")
return nil, logical.ErrPermissionDenied return nil, nil, logical.ErrPermissionDenied
} }
for nsID, nsPolicies := range identityPolicies { for nsID, nsPolicies := range identityPolicies {
@ -66,14 +74,14 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string,
if te.InlinePolicy != "" { if te.InlinePolicy != "" {
inlinePolicy, err := ParseACLPolicy(tokenNS, te.InlinePolicy) inlinePolicy, err := ParseACLPolicy(tokenNS, te.InlinePolicy)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
policies = append(policies, inlinePolicy) policies = append(policies, inlinePolicy)
policyCount++ policyCount++
} }
if policyCount == 0 { if policyCount == 0 {
return []string{DenyCapability}, nil return []string{DenyCapability}, nil, nil
} }
// Construct the corresponding ACL object. ACL construction should be // Construct the corresponding ACL object. ACL construction should be
@ -81,10 +89,10 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string,
tokenCtx := namespace.ContextWithNamespace(ctx, tokenNS) tokenCtx := namespace.ContextWithNamespace(ctx, tokenNS)
acl, err := c.policyStore.ACL(tokenCtx, entity, policyNames, policies...) acl, err := c.policyStore.ACL(tokenCtx, entity, policyNames, policies...)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
capabilities := acl.Capabilities(ctx, path) capabilities, eventTypes := acl.CapabilitiesAndSubscribeEventTypes(ctx, path)
sort.Strings(capabilities) sort.Strings(capabilities)
return capabilities, nil return capabilities, eventTypes, nil
} }