mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-15 11:07:00 +02:00
Previously, when a user initiated a websocket subscription, access to the `sys/events/subscribe` endpoint was checked only once, at subscription time. Now, we perform continuous policy checks: * We check access to the `sys/events/subscribe` endpoint every five minutes. If this check fails, the websocket is terminated. * Upon receiving any message, we verify that the `subscribe` capability is present for that namespace, data path, and event type. If it is not, the message is not delivered. If the message is allowed, we cache that result for five minutes. Tests for this are in a separate enterprise PR. Documentation will be updated in another PR. Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com> Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
460 lines
13 KiB
Go
460 lines
13 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package http
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/audit"
|
|
"github.com/hashicorp/vault/helper/experiments"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/hashicorp/vault/vault"
|
|
"github.com/hashicorp/vault/vault/cluster"
|
|
"github.com/stretchr/testify/assert"
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// TestEventsSubscribe tests the websocket endpoint for subscribing to events
|
|
// by generating some events.
|
|
func TestEventsSubscribe(t *testing.T) {
|
|
core := vault.TestCoreWithConfig(t, &vault.CoreConfig{
|
|
Experiments: []string{experiments.VaultExperimentEventsAlpha1},
|
|
})
|
|
|
|
ln, addr := TestServer(t, core)
|
|
defer ln.Close()
|
|
|
|
// unseal the core
|
|
keys, token := vault.TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
_, err := core.Unseal(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
stop := atomic.Bool{}
|
|
|
|
const eventType = "abc"
|
|
|
|
// send some events
|
|
go func() {
|
|
for !stop.Load() {
|
|
id, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
core.Logger().Info("Error generating UUID, exiting sender", "error", err)
|
|
}
|
|
pluginInfo := &logical.EventPluginInfo{
|
|
MountPath: "secret",
|
|
}
|
|
err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
|
|
Id: id,
|
|
Metadata: nil,
|
|
EntityIds: nil,
|
|
Note: "testing",
|
|
})
|
|
if err != nil {
|
|
core.Logger().Info("Error sending event, exiting sender", "error", err)
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
stop.Store(true)
|
|
})
|
|
|
|
ctx := context.Background()
|
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
|
|
|
testCases := []struct {
|
|
json bool
|
|
}{{true}, {false}}
|
|
|
|
for _, testCase := range testCases {
|
|
url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=%v", wsAddr, eventType, testCase.json)
|
|
conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
|
HTTPHeader: http.Header{"x-vault-token": []string{token}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(func() {
|
|
conn.Close(websocket.StatusNormalClosure, "")
|
|
})
|
|
|
|
_, msg, err := conn.Read(ctx)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if testCase.json {
|
|
event := map[string]interface{}{}
|
|
err = json.Unmarshal(msg, &event)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Log(string(msg))
|
|
data := event["data"].(map[string]interface{})
|
|
if actualType := data["event_type"].(string); actualType != eventType {
|
|
t.Fatalf("Expeced event type %s, got %s", eventType, actualType)
|
|
}
|
|
pluginInfo, ok := data["plugin_info"].(map[string]interface{})
|
|
if !ok || pluginInfo == nil {
|
|
t.Fatalf("No plugin_info object: %v", data)
|
|
}
|
|
mountPath, ok := pluginInfo["mount_path"].(string)
|
|
if !ok || mountPath != "secret" {
|
|
t.Fatalf("Wrong mount_path: %v", data)
|
|
}
|
|
innerEvent := data["event"].(map[string]interface{})
|
|
if innerEvent["id"].(string) != event["id"].(string) {
|
|
t.Fatalf("IDs don't match, expected %s, got %s", innerEvent["id"].(string), event["id"].(string))
|
|
}
|
|
if innerEvent["note"].(string) != "testing" {
|
|
t.Fatalf("Expected 'testing', got %s", innerEvent["note"].(string))
|
|
}
|
|
|
|
checkRequiredCloudEventsFields(t, event)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNamespaceRootSubscriptions(t *testing.T) {
|
|
core := vault.TestCoreWithConfig(t, &vault.CoreConfig{
|
|
Experiments: []string{experiments.VaultExperimentEventsAlpha1},
|
|
})
|
|
|
|
ln, addr := TestServer(t, core)
|
|
defer ln.Close()
|
|
|
|
// unseal the core
|
|
keys, token := vault.TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
_, err := core.Unseal(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
stop := atomic.Bool{}
|
|
|
|
const eventType = "abc"
|
|
|
|
// send some events with the specified namespaces
|
|
sendEvents := func() error {
|
|
pluginInfo := &logical.EventPluginInfo{
|
|
MountPath: "secret",
|
|
}
|
|
ns := namespace.RootNamespace
|
|
id, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
core.Logger().Info("Error generating UUID, exiting sender", "error", err)
|
|
return err
|
|
}
|
|
err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{
|
|
Id: id,
|
|
Metadata: nil,
|
|
EntityIds: nil,
|
|
Note: "testing",
|
|
})
|
|
if err != nil {
|
|
core.Logger().Info("Error sending event, exiting sender", "error", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
stop.Store(true)
|
|
})
|
|
|
|
ctx := context.Background()
|
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
namespaces []string
|
|
expectedEvents int
|
|
}{
|
|
// We only send events in the root namespace, but we test all the various patterns of namespace patterns.
|
|
{"single", []string{"something"}, 1},
|
|
{"simple wildcard", []string{"ns*"}, 1},
|
|
{"two namespaces", []string{"ns1/ns13", "ns1/other"}, 1},
|
|
{"no namespace", []string{""}, 1},
|
|
{"all wildcard", []string{"*"}, 1},
|
|
{"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 1},
|
|
{"overlapping wildcard", []string{"ns*", "ns1"}, 1},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
extra := ""
|
|
for _, ns := range testCase.namespaces {
|
|
extra += "&namespaces=" + ns
|
|
}
|
|
url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=true%v", wsAddr, eventType, extra)
|
|
conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
|
HTTPHeader: http.Header{"x-vault-token": []string{token}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(func() {
|
|
conn.Close(websocket.StatusNormalClosure, "")
|
|
})
|
|
err = sendEvents()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// CI is sometimes slow, so this timeout is high initially
|
|
timeout := 10 * time.Second
|
|
gotEvents := 0
|
|
for {
|
|
// if we got as many as we expect, shorten the test, so we don't waste time,
|
|
// but still allow time for "extra" events to come in and make us fail
|
|
if gotEvents == testCase.expectedEvents {
|
|
timeout = 100 * time.Millisecond
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
t.Cleanup(cancel)
|
|
|
|
_, msg, err := conn.Read(ctx)
|
|
if err != nil {
|
|
t.Log("error reading from connection", err)
|
|
break
|
|
}
|
|
|
|
event := map[string]interface{}{}
|
|
err = json.Unmarshal(msg, &event)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log("event received", string(msg))
|
|
gotEvents += 1
|
|
}
|
|
|
|
assert.Equal(t, testCase.expectedEvents, gotEvents)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNamespacePrepend(t *testing.T) {
|
|
testCases := []struct {
|
|
requestNs string
|
|
patterns []string
|
|
result []string
|
|
}{
|
|
{"", []string{"ns*"}, []string{"", "ns*"}},
|
|
{"ns1", []string{"ns*"}, []string{"ns1", "ns1/ns*"}},
|
|
{"ns1", []string{"ns1*"}, []string{"ns1", "ns1/ns1*"}},
|
|
{"ns1", []string{"ns1/*"}, []string{"ns1", "ns1/ns1/*"}},
|
|
{"", []string{"ns1/ns13", "ns1/other"}, []string{"", "ns1/ns13", "ns1/other"}},
|
|
{"ns1", []string{"ns1/ns13", "ns1/other"}, []string{"ns1", "ns1/ns1/ns13", "ns1/ns1/other"}},
|
|
{"", []string{""}, []string{""}},
|
|
{"", nil, []string{""}},
|
|
{"ns1", []string{""}, []string{"ns1"}},
|
|
{"ns1", []string{"", ""}, []string{"ns1"}},
|
|
{"ns1", []string{"ns1"}, []string{"ns1", "ns1/ns1"}},
|
|
{"", []string{"*"}, []string{"", "*"}},
|
|
{"ns1", []string{"*"}, []string{"ns1", "ns1/*"}},
|
|
{"", []string{"ns1/ns13*", "ns2"}, []string{"", "ns1/ns13*", "ns2"}},
|
|
{"ns1", []string{"ns1/ns13*", "ns2"}, []string{"ns1", "ns1/ns1/ns13*", "ns1/ns2"}},
|
|
{"", []string{"ns*", "ns1"}, []string{"", "ns*", "ns1"}},
|
|
{"ns1", []string{"ns*", "ns1"}, []string{"ns1", "ns1/ns*", "ns1/ns1"}},
|
|
{"ns1", []string{"ns1*", "ns1"}, []string{"ns1", "ns1/ns1*", "ns1/ns1"}},
|
|
{"ns1", []string{"ns1/*", "ns1"}, []string{"ns1", "ns1/ns1/*", "ns1/ns1"}},
|
|
}
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.requestNs+" "+strings.Join(testCase.patterns, " "), func(t *testing.T) {
|
|
result := prependNamespacePatterns(testCase.patterns, &namespace.Namespace{ID: testCase.requestNs, Path: testCase.requestNs})
|
|
assert.Equal(t, testCase.result, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// checkRequiredCloudEventsFields reports a test error for each of the
// attributes "id", "source", "specversion" and "type" that is missing from
// event, is not a string, or is an empty string. It records failures with
// t.Errorf, so checking continues through all attributes.
func checkRequiredCloudEventsFields(t *testing.T, event map[string]interface{}) {
	t.Helper()
	required := []string{"id", "source", "specversion", "type"}
	for _, name := range required {
		raw, present := event[name]
		if !present {
			t.Errorf("Missing attribute %s", name)
			continue
		}
		value, isString := raw.(string)
		switch {
		case !isString:
			t.Errorf("Expected %s to be string but got %T", name, raw)
		case value == "":
			t.Errorf("%s was empty string", name)
		}
	}
}
|
|
|
|
// TestEventsSubscribeAuth tests that unauthenticated and unauthorized subscriptions
|
|
// fail correctly.
|
|
func TestEventsSubscribeAuth(t *testing.T) {
|
|
core := vault.TestCore(t)
|
|
ln, addr := TestServer(t, core)
|
|
defer ln.Close()
|
|
|
|
// unseal the core
|
|
keys, root := vault.TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
_, err := core.Unseal(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
var nonPrivilegedToken string
|
|
// Fetch a valid non privileged token.
|
|
{
|
|
config := api.DefaultConfig()
|
|
config.Address = addr
|
|
|
|
client, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.SetToken(root)
|
|
|
|
secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{Policies: []string{"default"}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret.Auth.ClientToken == "" {
|
|
t.Fatal("Failed to fetch a non privileged token")
|
|
}
|
|
nonPrivilegedToken = secret.Auth.ClientToken
|
|
}
|
|
|
|
ctx := context.Background()
|
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
|
|
|
// Get a 403 with no token.
|
|
_, resp, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", nil)
|
|
if err == nil {
|
|
t.Error("Expected websocket error but got none")
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected 403 but got %+v", resp)
|
|
}
|
|
|
|
// Get a 403 with a non privileged token.
|
|
_, resp, err = websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", &websocket.DialOptions{
|
|
HTTPHeader: http.Header{"x-vault-token": []string{nonPrivilegedToken}},
|
|
})
|
|
if err == nil {
|
|
t.Error("Expected websocket error but got none")
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected 403 but got %+v", resp)
|
|
}
|
|
}
|
|
|
|
func TestCanForwardEventConnections(t *testing.T) {
|
|
// Run again with in-memory network
|
|
inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, hclog.New(&hclog.LoggerOptions{
|
|
Mutex: &sync.Mutex{},
|
|
Level: hclog.Trace,
|
|
Name: "inmem-cluster",
|
|
}))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster := vault.NewTestCluster(t, &vault.CoreConfig{
|
|
Experiments: []string{experiments.VaultExperimentEventsAlpha1},
|
|
AuditBackends: map[string]audit.Factory{
|
|
"nop": corehelpers.NoopAuditFactory(nil),
|
|
},
|
|
}, &vault.TestClusterOptions{
|
|
ClusterLayers: inmemCluster,
|
|
})
|
|
cores := testCluster.Cores
|
|
testCluster.Start()
|
|
defer testCluster.Cleanup()
|
|
|
|
rootToken := testCluster.RootToken
|
|
|
|
// Wait for core to become active
|
|
vault.TestWaitActiveForwardingReady(t, cores[0].Core)
|
|
|
|
// Test forwarding a request. Since we're going directly from core to core
|
|
// with no fallback we know that if it worked, request handling is working
|
|
c := cores[1]
|
|
standby, err := c.Standby()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !standby {
|
|
t.Fatal("expected core to be standby")
|
|
}
|
|
|
|
// We need to call Leader as that refreshes the connection info
|
|
isLeader, _, _, err := c.Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if isLeader {
|
|
t.Fatal("core should not be leader")
|
|
}
|
|
corehelpers.RetryUntil(t, 5*time.Second, func() error {
|
|
state := c.ActiveNodeReplicationState()
|
|
if state == 0 {
|
|
return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state)
|
|
}
|
|
return nil
|
|
})
|
|
|
|
req, err := http.NewRequest("GET", "https://pushit.real.good:9281/v1/sys/events/subscribe/xyz?json=true", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req = req.WithContext(namespace.RootContext(req.Context()))
|
|
req.Header.Add(consts.AuthHeaderName, rootToken)
|
|
|
|
resp := httptest.NewRecorder()
|
|
forwardRequest(cores[1].Core, resp, req)
|
|
|
|
header := resp.Header()
|
|
if header == nil {
|
|
t.Fatal("err: expected at least a Location header")
|
|
}
|
|
if !strings.HasPrefix(header.Get("Location"), "wss://") {
|
|
t.Fatalf("bad location: %s", header.Get("Location"))
|
|
}
|
|
|
|
// test forwarding requests to each core
|
|
handled := 0
|
|
forwarded := 0
|
|
for _, c := range cores {
|
|
resp := httptest.NewRecorder()
|
|
fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handled++
|
|
})
|
|
handleRequestForwarding(c.Core, fakeHandler).ServeHTTP(resp, req)
|
|
header := resp.Header()
|
|
if header == nil {
|
|
continue
|
|
}
|
|
if strings.HasPrefix(header.Get("Location"), "wss://") {
|
|
forwarded++
|
|
}
|
|
}
|
|
if handled != 1 && forwarded != 2 {
|
|
t.Fatalf("Expected 1 core to handle the request and 2 to forward")
|
|
}
|
|
}
|