mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-14 10:37:00 +02:00
The WebSocket tests have been very flaky because we weren't able to tell when a WebSocket was fully connected and subscribed to events.
We reworked the websocket subscription code to accept the websocket only after subscribing.
This should eliminate all flakiness in these tests. 🤞 (We can follow-up in an enterprise PR to simplify some of the tests after this fix is merged.)
I ran this locally a bunch of times and with data race detection enabled, and did not see any failures.
Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
349 lines
11 KiB
Go
349 lines
11 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package http
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"path"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/hashicorp/vault/vault"
|
|
"github.com/hashicorp/vault/vault/eventbus"
|
|
"github.com/patrickmn/go-cache"
|
|
"github.com/ryanuber/go-glob"
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// webSocketRevalidationTime controls how frequently a websocket subscriber's
// access to the events it asked for is checked again. It is also the lifetime
// of a cached positive per-event access decision.
var webSocketRevalidationTime = time.Minute * 5
|
|
|
|
type eventSubscriber struct {
|
|
ctx context.Context
|
|
cancelCtx context.CancelFunc
|
|
clientToken string
|
|
logger hclog.Logger
|
|
events *eventbus.EventBus
|
|
namespacePatterns []string
|
|
pattern string
|
|
bexprFilter string
|
|
json bool
|
|
checkCache *cache.Cache
|
|
isRootToken bool
|
|
core *vault.Core
|
|
w http.ResponseWriter
|
|
r *http.Request
|
|
req *logical.Request
|
|
}
|
|
|
|
// handleEventsSubscribeWebsocket subscribes to the events, accepts the websocket connection, and then runs forever,
|
|
// serving events to the websocket connection.
|
|
func (sub *eventSubscriber) handleEventsSubscribeWebsocket() {
|
|
ctx := sub.ctx
|
|
logger := sub.logger
|
|
// subscribe before accept to avoid race conditions
|
|
ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter)
|
|
if err != nil {
|
|
logger.Info("Error subscribing", "error", err)
|
|
sub.w.WriteHeader(400)
|
|
sub.w.Write([]byte("Error subscribing"))
|
|
return
|
|
}
|
|
defer cancel()
|
|
logger.Debug("WebSocket is subscribed to messages", "namespaces", sub.namespacePatterns, "event_types", sub.pattern, "bexpr_filter", sub.bexprFilter)
|
|
|
|
conn, err := websocket.Accept(sub.w, sub.r, nil)
|
|
if err != nil {
|
|
logger.Info("Could not accept as websocket", "error", err)
|
|
respondError(sub.w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
|
|
return
|
|
}
|
|
|
|
// continually validate subscribe access while the websocket is running
|
|
// this has to be done after accepting the websocket to avoid a race condition
|
|
go sub.validateSubscribeAccessLoop()
|
|
|
|
// make sure to close the websocket
|
|
closeStatus := websocket.StatusNormalClosure
|
|
closeReason := ""
|
|
var closeErr error = nil
|
|
|
|
defer func() {
|
|
if closeErr != nil {
|
|
closeStatus = websocket.CloseStatus(err)
|
|
if closeStatus == -1 {
|
|
closeStatus = websocket.StatusInternalError
|
|
}
|
|
closeReason = fmt.Sprintf("Internal error: %v", err)
|
|
logger.Debug("Error from websocket handler", "error", err)
|
|
}
|
|
// Close() will panic if the reason is greater than this length
|
|
if len(closeReason) > 123 {
|
|
logger.Debug("Truncated close reason", "closeReason", closeReason)
|
|
closeReason = closeReason[:123]
|
|
}
|
|
err = conn.Close(closeStatus, closeReason)
|
|
if err != nil {
|
|
logger.Debug("Error closing websocket", "error", err)
|
|
}
|
|
}()
|
|
|
|
// we don't expect any incoming messages
|
|
ctx = conn.CloseRead(ctx)
|
|
// start the pinger
|
|
go func() {
|
|
for {
|
|
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
|
|
err := conn.Ping(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Info("Websocket context is done, closing the connection")
|
|
return
|
|
case message := <-ch:
|
|
// Perform one last check that the message is allowed to be received.
|
|
// For example, if a new namespace was created that matches the namespace patterns,
|
|
// but the token doesn't have access to it, we don't want to accidentally send it to
|
|
// the websocket.
|
|
if !sub.allowMessageCached(message.Payload.(*logical.EventReceived)) {
|
|
continue
|
|
}
|
|
|
|
logger.Debug("Sending message to websocket", "message", message.Payload)
|
|
var messageBytes []byte
|
|
var messageType websocket.MessageType
|
|
if sub.json {
|
|
var ok bool
|
|
messageBytes, ok = message.Format("cloudevents-json")
|
|
if !ok {
|
|
logger.Warn("Could not get cloudevents JSON format")
|
|
closeErr = errors.New("could not get cloudevents JSON format")
|
|
return
|
|
}
|
|
messageType = websocket.MessageText
|
|
} else {
|
|
messageBytes, err = proto.Marshal(message.Payload.(*logical.EventReceived))
|
|
messageType = websocket.MessageBinary
|
|
}
|
|
if err != nil {
|
|
logger.Warn("Could not serialize websocket event", "error", err)
|
|
closeErr = err
|
|
return
|
|
}
|
|
err = conn.Write(ctx, messageType, messageBytes)
|
|
if err != nil {
|
|
closeErr = err
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// allowMessageCached checks that the message is allowed to received by the websocket.
|
|
// It caches results for specific namespaces, data paths, and event types.
|
|
func (sub *eventSubscriber) allowMessageCached(message *logical.EventReceived) bool {
|
|
if sub.isRootToken {
|
|
// fast-path root tokens
|
|
return true
|
|
}
|
|
|
|
messageNs := strings.Trim(message.Namespace, "/")
|
|
dataPath := ""
|
|
if message.Event.Metadata != nil {
|
|
dataPathField := message.Event.Metadata.GetFields()[logical.EventMetadataDataPath]
|
|
if dataPathField != nil {
|
|
dataPath = dataPathField.GetStringValue()
|
|
}
|
|
}
|
|
if dataPath == "" {
|
|
// Only allow root tokens to subscribe to events with no data path, for now.
|
|
return false
|
|
}
|
|
cacheKey := fmt.Sprintf("%v!%v!%v", messageNs, dataPath, message.EventType)
|
|
_, ok := sub.checkCache.Get(cacheKey)
|
|
if ok {
|
|
return true
|
|
}
|
|
|
|
// perform the actual check and cache it if true
|
|
ok = sub.allowMessage(messageNs, dataPath, message.EventType)
|
|
if ok {
|
|
err := sub.checkCache.Add(cacheKey, ok, webSocketRevalidationTime)
|
|
if err != nil {
|
|
sub.logger.Debug("Error adding to policy check cache for websocket", "error", err)
|
|
// still return the right value, but we can't guarantee it was cached
|
|
}
|
|
}
|
|
return ok
|
|
}
|
|
|
|
// allowMessage checks that the message is allowed to received by the websocket
|
|
func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bool {
|
|
// does this even match the requested namespaces
|
|
matchedNs := false
|
|
for _, nsPattern := range sub.namespacePatterns {
|
|
if glob.Glob(nsPattern, eventNs) {
|
|
matchedNs = true
|
|
break
|
|
}
|
|
}
|
|
if !matchedNs {
|
|
return false
|
|
}
|
|
|
|
// next check for specific access to the namespace and event types
|
|
nsDataPath := dataPath
|
|
if eventNs != "" {
|
|
nsDataPath = path.Join(eventNs, dataPath)
|
|
}
|
|
capabilities, allowedEventTypes, err := sub.core.CapabilitiesAndSubscribeEventTypes(sub.ctx, sub.clientToken, nsDataPath)
|
|
if err != nil {
|
|
sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs)
|
|
return false
|
|
}
|
|
if !(slices.Contains(capabilities, vault.RootCapability) || slices.Contains(capabilities, vault.SubscribeCapability)) {
|
|
return false
|
|
}
|
|
for _, pattern := range allowedEventTypes {
|
|
if glob.Glob(pattern, eventType) {
|
|
return true
|
|
}
|
|
}
|
|
// no event types matched, so return false
|
|
return false
|
|
}
|
|
|
|
func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
logger := core.Logger().Named("events-subscribe")
|
|
logger.Debug("Got request to", "url", r.URL, "version", r.Proto)
|
|
|
|
ctx := r.Context()
|
|
|
|
// ACL check
|
|
auth, entry, err := core.CheckToken(ctx, req, false)
|
|
if err != nil {
|
|
if errors.Is(err, logical.ErrPermissionDenied) {
|
|
respondError(w, http.StatusForbidden, logical.ErrPermissionDenied)
|
|
return
|
|
}
|
|
logger.Debug("Error validating token", "error", err)
|
|
respondError(w, http.StatusInternalServerError, fmt.Errorf("error validating token"))
|
|
return
|
|
}
|
|
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
logger.Info("Could not find namespace", "error", err)
|
|
respondError(w, http.StatusInternalServerError, fmt.Errorf("could not find namespace"))
|
|
return
|
|
}
|
|
|
|
prefix := "/v1/sys/events/subscribe/"
|
|
if ns.ID != namespace.RootNamespaceID {
|
|
prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path)
|
|
}
|
|
pattern := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix))
|
|
if pattern == "" {
|
|
respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to"))
|
|
return
|
|
}
|
|
|
|
json := false
|
|
jsonRaw := r.URL.Query().Get("json")
|
|
if jsonRaw != "" {
|
|
var err error
|
|
json, err = strconv.ParseBool(jsonRaw)
|
|
if err != nil {
|
|
respondError(w, http.StatusBadRequest, fmt.Errorf("invalid parameter for JSON: %v", jsonRaw))
|
|
return
|
|
}
|
|
}
|
|
|
|
bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter"))
|
|
namespacePatterns := r.URL.Query()["namespaces"]
|
|
namespacePatterns = prependNamespacePatterns(namespacePatterns, ns)
|
|
isRoot := entry.IsRoot()
|
|
ctx, cancelCtx := context.WithCancel(ctx)
|
|
defer cancelCtx()
|
|
|
|
sub := &eventSubscriber{
|
|
ctx: ctx,
|
|
cancelCtx: cancelCtx,
|
|
logger: logger,
|
|
events: core.Events(),
|
|
namespacePatterns: namespacePatterns,
|
|
pattern: pattern,
|
|
bexprFilter: bexprFilter,
|
|
json: json,
|
|
checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime),
|
|
clientToken: auth.ClientToken,
|
|
isRootToken: isRoot,
|
|
core: core,
|
|
w: w,
|
|
r: r,
|
|
req: req,
|
|
}
|
|
sub.handleEventsSubscribeWebsocket()
|
|
})
|
|
}
|
|
|
|
// prependNamespacePatterns prepends the request namespace to the namespace patterns,
|
|
// and also adds the request namespace to the list.
|
|
func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Namespace) []string {
|
|
prepend := strings.Trim(requestNamespace.Path, "/")
|
|
newPatterns := make([]string, 0, len(patterns)+1)
|
|
newPatterns = append(newPatterns, prepend)
|
|
for _, pattern := range patterns {
|
|
if strings.Trim(pattern, "/") != "" {
|
|
newPatterns = append(newPatterns, path.Join(prepend, pattern))
|
|
}
|
|
}
|
|
return newPatterns
|
|
}
|
|
|
|
// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in
|
|
// its namespace. If the access check ever fails, then the cancel function is called and the function returns.
|
|
func (sub *eventSubscriber) validateSubscribeAccessLoop() {
|
|
// if something breaks, default to canceling the websocket
|
|
defer sub.cancelCtx()
|
|
for {
|
|
_, _, err := sub.core.CheckTokenWithLock(sub.ctx, sub.req, false)
|
|
if err != nil {
|
|
sub.core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", sub.req.Path, "error", err)
|
|
return
|
|
}
|
|
// wait a while and try again, but quit the loop if the context finishes early
|
|
finished := func() bool {
|
|
ticker := time.NewTicker(webSocketRevalidationTime)
|
|
defer ticker.Stop()
|
|
select {
|
|
case <-sub.ctx.Done():
|
|
return true
|
|
case <-ticker.C:
|
|
return false
|
|
}
|
|
}()
|
|
if finished {
|
|
return
|
|
}
|
|
}
|
|
}
|