vault/http/events_test.go
Christopher Swenson f0a23e117f
events: Continuously verify policies (#22705)
Previously, when a user initiated a websocket subscription,
the access to the `sys/events/subscribe` endpoint was checked then,
and only once.

Now, perform continuous policy checks:

* We check access to the `sys/events/subscribe` endpoint every five
  minutes. If this check fails, then the websocket is terminated.
* Upon receiving any message, we verify that the `subscribe`
  capability is present for that namespace, data path, and event type.
  If it is not, then the message is not delivered. If the message is
  allowed, we cache that result for five minutes.

Tests for this are in a separate enterprise PR.

Documentation will be updated in another PR.

Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
2023-09-05 16:28:09 -07:00

460 lines
13 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package http
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/experiments"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/cluster"
"github.com/stretchr/testify/assert"
"nhooyr.io/websocket"
)
// TestEventsSubscribe tests the websocket endpoint for subscribing to events
// by generating some events.
//
// A background goroutine sends an event of type eventType in the root
// namespace every 100ms until the test is cleaned up. For each output format
// (JSON and non-JSON) a websocket subscription is opened and one message is
// read; JSON messages are additionally checked for the event type, plugin
// info, matching IDs, note, and the required CloudEvents attributes.
func TestEventsSubscribe(t *testing.T) {
	core := vault.TestCoreWithConfig(t, &vault.CoreConfig{
		Experiments: []string{experiments.VaultExperimentEventsAlpha1},
	})

	ln, addr := TestServer(t, core)
	defer ln.Close()

	// unseal the core
	keys, token := vault.TestCoreInit(t, core)
	for _, key := range keys {
		_, err := core.Unseal(key)
		if err != nil {
			t.Fatal(err)
		}
	}

	const eventType = "abc"

	// send some events until the test is cleaned up
	var stop atomic.Bool
	var senderDone sync.WaitGroup
	senderDone.Add(1)
	go func() {
		defer senderDone.Done()
		for !stop.Load() {
			id, err := uuid.GenerateUUID()
			if err != nil {
				core.Logger().Info("Error generating UUID, exiting sender", "error", err)
				return
			}
			pluginInfo := &logical.EventPluginInfo{
				MountPath: "secret",
			}
			err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
				Id:        id,
				Metadata:  nil,
				EntityIds: nil,
				Note:      "testing",
			})
			if err != nil {
				core.Logger().Info("Error sending event, exiting sender", "error", err)
				return
			}
			time.Sleep(100 * time.Millisecond)
		}
	}()
	// Stop the sender and wait for it to exit so it does not keep using the
	// core after the test has finished.
	t.Cleanup(func() {
		stop.Store(true)
		senderDone.Wait()
	})

	ctx := context.Background()
	wsAddr := strings.Replace(addr, "http", "ws", 1)

	testCases := []struct {
		json bool
	}{{true}, {false}}

	for _, testCase := range testCases {
		t.Run(fmt.Sprintf("json=%v", testCase.json), func(t *testing.T) {
			url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=%v", wsAddr, eventType, testCase.json)
			conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
				HTTPHeader: http.Header{"x-vault-token": []string{token}},
			})
			if err != nil {
				t.Fatal(err)
			}
			t.Cleanup(func() {
				conn.Close(websocket.StatusNormalClosure, "")
			})

			// Bound the read so a missing event fails the test instead of hanging it.
			readCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
			defer cancel()
			_, msg, err := conn.Read(readCtx)
			if err != nil {
				t.Fatal(err)
			}
			if !testCase.json {
				return
			}

			event := map[string]interface{}{}
			if err := json.Unmarshal(msg, &event); err != nil {
				t.Fatal(err)
			}
			t.Log(string(msg))

			// Use checked assertions so a malformed message fails the test
			// rather than panicking.
			data, ok := event["data"].(map[string]interface{})
			if !ok {
				t.Fatalf("No data object: %v", event)
			}
			if actualType, _ := data["event_type"].(string); actualType != eventType {
				t.Fatalf("Expected event type %s, got %s", eventType, actualType)
			}
			pluginInfo, ok := data["plugin_info"].(map[string]interface{})
			if !ok || pluginInfo == nil {
				t.Fatalf("No plugin_info object: %v", data)
			}
			if mountPath, ok := pluginInfo["mount_path"].(string); !ok || mountPath != "secret" {
				t.Fatalf("Wrong mount_path: %v", data)
			}
			innerEvent, ok := data["event"].(map[string]interface{})
			if !ok {
				t.Fatalf("No event object: %v", data)
			}
			innerID, _ := innerEvent["id"].(string)
			outerID, _ := event["id"].(string)
			if innerID != outerID {
				t.Fatalf("IDs don't match, expected %s, got %s", innerID, outerID)
			}
			if note, _ := innerEvent["note"].(string); note != "testing" {
				t.Fatalf("Expected 'testing', got %s", note)
			}
			checkRequiredCloudEventsFields(t, event)
		})
	}
}
// TestNamespaceRootSubscriptions opens subscriptions with a variety of
// namespace patterns, sends a single event in the root namespace for each
// subscription, and checks how many events the subscription receives.
func TestNamespaceRootSubscriptions(t *testing.T) {
	core := vault.TestCoreWithConfig(t, &vault.CoreConfig{
		Experiments: []string{experiments.VaultExperimentEventsAlpha1},
	})

	ln, addr := TestServer(t, core)
	defer ln.Close()

	// unseal the core
	keys, token := vault.TestCoreInit(t, core)
	for _, key := range keys {
		_, err := core.Unseal(key)
		if err != nil {
			t.Fatal(err)
		}
	}

	const eventType = "abc"

	// sendEvents sends a single event of type eventType in the root namespace.
	sendEvents := func() error {
		pluginInfo := &logical.EventPluginInfo{
			MountPath: "secret",
		}
		ns := namespace.RootNamespace
		id, err := uuid.GenerateUUID()
		if err != nil {
			core.Logger().Info("Error generating UUID, exiting sender", "error", err)
			return err
		}
		err = core.Events().SendEventInternal(namespace.RootContext(context.Background()), ns, pluginInfo, eventType, &logical.EventData{
			Id:        id,
			Metadata:  nil,
			EntityIds: nil,
			Note:      "testing",
		})
		if err != nil {
			core.Logger().Info("Error sending event, exiting sender", "error", err)
			return err
		}
		return nil
	}

	ctx := context.Background()
	wsAddr := strings.Replace(addr, "http", "ws", 1)

	testCases := []struct {
		name           string
		namespaces     []string
		expectedEvents int
	}{
		// We only send events in the root namespace, but subscribe using many
		// different namespace patterns; every case expects to receive that one
		// root-namespace event.
		{"single", []string{"something"}, 1},
		{"simple wildcard", []string{"ns*"}, 1},
		{"two namespaces", []string{"ns1/ns13", "ns1/other"}, 1},
		{"no namespace", []string{""}, 1},
		{"all wildcard", []string{"*"}, 1},
		{"mixed wildcard", []string{"ns1/ns13*", "ns2"}, 1},
		{"overlapping wildcard", []string{"ns*", "ns1"}, 1},
	}

	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			extra := ""
			for _, ns := range testCase.namespaces {
				extra += "&namespaces=" + ns
			}
			url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=true%v", wsAddr, eventType, extra)
			conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
				HTTPHeader: http.Header{"x-vault-token": []string{token}},
			})
			if err != nil {
				t.Fatal(err)
			}
			t.Cleanup(func() {
				conn.Close(websocket.StatusNormalClosure, "")
			})

			err = sendEvents()
			if err != nil {
				t.Fatal(err)
			}

			// CI is sometimes slow, so this timeout is high initially
			timeout := 10 * time.Second
			gotEvents := 0
			for {
				// if we got as many as we expect, shorten the test, so we don't waste time,
				// but still allow time for "extra" events to come in and make us fail
				if gotEvents == testCase.expectedEvents {
					timeout = 100 * time.Millisecond
				}
				// Release each per-read context as soon as the read returns instead
				// of piling the cancel funcs up until test cleanup.
				readCtx, cancel := context.WithTimeout(ctx, timeout)
				_, msg, err := conn.Read(readCtx)
				cancel()
				if err != nil {
					t.Log("error reading from connection", err)
					break
				}
				event := map[string]interface{}{}
				err = json.Unmarshal(msg, &event)
				if err != nil {
					t.Fatal(err)
				}
				t.Log("event received", string(msg))
				gotEvents += 1
			}
			assert.Equal(t, testCase.expectedEvents, gotEvents)
		})
	}
}
func TestNamespacePrepend(t *testing.T) {
testCases := []struct {
requestNs string
patterns []string
result []string
}{
{"", []string{"ns*"}, []string{"", "ns*"}},
{"ns1", []string{"ns*"}, []string{"ns1", "ns1/ns*"}},
{"ns1", []string{"ns1*"}, []string{"ns1", "ns1/ns1*"}},
{"ns1", []string{"ns1/*"}, []string{"ns1", "ns1/ns1/*"}},
{"", []string{"ns1/ns13", "ns1/other"}, []string{"", "ns1/ns13", "ns1/other"}},
{"ns1", []string{"ns1/ns13", "ns1/other"}, []string{"ns1", "ns1/ns1/ns13", "ns1/ns1/other"}},
{"", []string{""}, []string{""}},
{"", nil, []string{""}},
{"ns1", []string{""}, []string{"ns1"}},
{"ns1", []string{"", ""}, []string{"ns1"}},
{"ns1", []string{"ns1"}, []string{"ns1", "ns1/ns1"}},
{"", []string{"*"}, []string{"", "*"}},
{"ns1", []string{"*"}, []string{"ns1", "ns1/*"}},
{"", []string{"ns1/ns13*", "ns2"}, []string{"", "ns1/ns13*", "ns2"}},
{"ns1", []string{"ns1/ns13*", "ns2"}, []string{"ns1", "ns1/ns1/ns13*", "ns1/ns2"}},
{"", []string{"ns*", "ns1"}, []string{"", "ns*", "ns1"}},
{"ns1", []string{"ns*", "ns1"}, []string{"ns1", "ns1/ns*", "ns1/ns1"}},
{"ns1", []string{"ns1*", "ns1"}, []string{"ns1", "ns1/ns1*", "ns1/ns1"}},
{"ns1", []string{"ns1/*", "ns1"}, []string{"ns1", "ns1/ns1/*", "ns1/ns1"}},
}
for _, testCase := range testCases {
t.Run(testCase.requestNs+" "+strings.Join(testCase.patterns, " "), func(t *testing.T) {
result := prependNamespacePatterns(testCase.patterns, &namespace.Namespace{ID: testCase.requestNs, Path: testCase.requestNs})
assert.Equal(t, testCase.result, result)
})
}
}
// checkRequiredCloudEventsFields reports a (non-fatal) test error for each of
// the attributes "id", "source", "specversion" and "type" that is missing from
// event, is not a string, or is the empty string.
func checkRequiredCloudEventsFields(t *testing.T, event map[string]interface{}) {
	t.Helper()
	required := [...]string{"id", "source", "specversion", "type"}
	for _, name := range required {
		raw, present := event[name]
		if !present {
			t.Errorf("Missing attribute %s", name)
			continue
		}
		switch s, isString := raw.(string); {
		case !isString:
			t.Errorf("Expected %s to be string but got %T", name, raw)
		case s == "":
			t.Errorf("%s was empty string", name)
		}
	}
}
// TestEventsSubscribeAuth tests that unauthenticated and unauthorized subscriptions
// fail correctly: dialing the subscribe endpoint with no token, or with a token
// that only has the "default" policy, must be rejected with a 403.
func TestEventsSubscribeAuth(t *testing.T) {
	core := vault.TestCore(t)
	ln, addr := TestServer(t, core)
	defer ln.Close()

	// unseal the core
	keys, root := vault.TestCoreInit(t, core)
	for _, key := range keys {
		_, err := core.Unseal(key)
		if err != nil {
			t.Fatal(err)
		}
	}

	var nonPrivilegedToken string
	// Fetch a valid non privileged token.
	{
		config := api.DefaultConfig()
		config.Address = addr

		client, err := api.NewClient(config)
		if err != nil {
			t.Fatal(err)
		}
		client.SetToken(root)

		secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{Policies: []string{"default"}})
		if err != nil {
			t.Fatal(err)
		}
		// Guard against a nil secret/auth block so a bad response fails the
		// test instead of panicking.
		if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
			t.Fatal("Failed to fetch a non privileged token")
		}
		nonPrivilegedToken = secret.Auth.ClientToken
	}

	ctx := context.Background()
	wsAddr := strings.Replace(addr, "http", "ws", 1)

	// expectForbidden dials the subscribe endpoint with opts and asserts that the
	// websocket handshake is rejected with a 403.
	expectForbidden := func(opts *websocket.DialOptions) {
		t.Helper()
		conn, resp, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", opts)
		if err == nil {
			t.Error("Expected websocket error but got none")
			// The dial unexpectedly succeeded; don't leak the connection.
			conn.Close(websocket.StatusNormalClosure, "")
		}
		if resp == nil || resp.StatusCode != http.StatusForbidden {
			t.Errorf("Expected 403 but got %+v", resp)
		}
	}

	// Get a 403 with no token.
	expectForbidden(nil)

	// Get a 403 with a non privileged token.
	expectForbidden(&websocket.DialOptions{
		HTTPHeader: http.Header{"x-vault-token": []string{nonPrivilegedToken}},
	})
}
// TestCanForwardEventConnections checks that event subscription requests are
// handled locally only by the active node, while standby nodes answer with a
// redirect whose Location is a wss:// URL.
func TestCanForwardEventConnections(t *testing.T) {
	// Build a 3-node test cluster on an in-memory cluster network layer.
	inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, hclog.New(&hclog.LoggerOptions{
		Mutex: &sync.Mutex{},
		Level: hclog.Trace,
		Name:  "inmem-cluster",
	}))
	if err != nil {
		t.Fatal(err)
	}
	testCluster := vault.NewTestCluster(t, &vault.CoreConfig{
		Experiments: []string{experiments.VaultExperimentEventsAlpha1},
		AuditBackends: map[string]audit.Factory{
			"nop": corehelpers.NoopAuditFactory(nil),
		},
	}, &vault.TestClusterOptions{
		ClusterLayers: inmemCluster,
	})
	cores := testCluster.Cores
	testCluster.Start()
	defer testCluster.Cleanup()

	rootToken := testCluster.RootToken

	// Wait for core to become active
	vault.TestWaitActiveForwardingReady(t, cores[0].Core)

	// Sanity-check that cores[1] is a standby that does not consider itself leader.
	c := cores[1]
	standby, err := c.Standby()
	if err != nil {
		t.Fatal(err)
	}
	if !standby {
		t.Fatal("expected core to be standby")
	}

	// We need to call Leader as that refreshes the connection info
	isLeader, _, _, err := c.Leader()
	if err != nil {
		t.Fatal(err)
	}
	if isLeader {
		t.Fatal("core should not be leader")
	}

	// The per-core loop below expects every standby to forward, so wait until
	// each of them has heard from the active node, not just cores[1].
	for _, standbyCore := range cores[1:] {
		corehelpers.RetryUntil(t, 5*time.Second, func() error {
			// Refresh leader connection info on each attempt.
			if _, _, _, err := standbyCore.Leader(); err != nil {
				return err
			}
			state := standbyCore.ActiveNodeReplicationState()
			if state == 0 {
				return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state)
			}
			return nil
		})
	}

	req, err := http.NewRequest("GET", "https://pushit.real.good:9281/v1/sys/events/subscribe/xyz?json=true", nil)
	if err != nil {
		t.Fatal(err)
	}
	req = req.WithContext(namespace.RootContext(req.Context()))
	req.Header.Add(consts.AuthHeaderName, rootToken)

	// Forwarding directly through a standby must yield a wss:// redirect.
	// (ResponseRecorder.Header never returns nil, so only Location is checked.)
	resp := httptest.NewRecorder()
	forwardRequest(cores[1].Core, resp, req)
	if location := resp.Header().Get("Location"); !strings.HasPrefix(location, "wss://") {
		t.Fatalf("bad location: %s", location)
	}

	// test forwarding requests to each core
	handled := 0
	forwarded := 0
	for _, node := range cores {
		rec := httptest.NewRecorder()
		fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			handled++
		})
		handleRequestForwarding(node.Core, fakeHandler).ServeHTTP(rec, req)
		if strings.HasPrefix(rec.Header().Get("Location"), "wss://") {
			forwarded++
		}
	}
	// Both counts must match: exactly the active node handles the request and
	// both standbys redirect it.
	if handled != 1 || forwarded != 2 {
		t.Fatalf("Expected 1 core to handle the request and 2 to forward, got %d handled and %d forwarded", handled, forwarded)
	}
}