From 36174bc913bd03b9448cd0ae60bc06a4eb069033 Mon Sep 17 00:00:00 2001 From: Austin Gebauer <34121980+austingebauer@users.noreply.github.com> Date: Mon, 28 Aug 2023 09:17:33 -0700 Subject: [PATCH] Fixes events subscribe for non-root namespaces (#22580) * Fixes events subscribe for non-root namespaces * Adds a test --- vault/eventbus/bus.go | 2 +- vault/eventbus/bus_test.go | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index 87fb3a0589..4e54400f7f 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -202,7 +202,7 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) { } func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pattern string) (<-chan *eventlogger.Event, context.CancelFunc, error) { - return bus.SubscribeMultipleNamespaces(ctx, []string{ns.Path}, pattern) + return bus.SubscribeMultipleNamespaces(ctx, []string{strings.Trim(ns.Path, "/")}, pattern) } func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespacePathPatterns []string, pattern string) (<-chan *eventlogger.Event, context.CancelFunc, error) { diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index 8546fb565d..54d1009c07 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -73,6 +73,50 @@ func TestBusBasics(t *testing.T) { } } +// TestSubscribeNonRootNamespace verifies that events for non-root namespaces +// aren't filtered out by the bus. +func TestSubscribeNonRootNamespace(t *testing.T) { + bus, err := NewEventBus(nil) + if err != nil { + t.Fatal(err) + } + bus.Start() + ctx := context.Background() + + eventType := logical.EventType("someType") + + ns := &namespace.Namespace{ + ID: "abc", + Path: "abc/", + } + + ch, cancel, err := bus.Subscribe(ctx, ns, string(eventType)) + if err != nil { + t.Fatal(err) + } + defer cancel() + + event, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + + err = bus.SendEventInternal(ctx, ns, nil, eventType, event) + if err != nil { + t.Error(err) + } + + timeout := time.After(1 * time.Second) + select { + case message := <-ch: + if message.Payload.(*logical.EventReceived).Event.Id != event.Id { + t.Errorf("Got unexpected message: %+v", message) + } + case <-timeout: + t.Error("Timeout waiting for message") + } +} + // TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus. func TestNamespaceFiltering(t *testing.T) { bus, err := NewEventBus(nil)