mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-14 18:47:01 +02:00
The WebSocket tests have been very flaky because we weren't able to tell when a WebSocket was fully connected and subscribed to events.
We reworked the websocket subscription code to accept the websocket only after subscribing.
This should eliminate all flakiness in these tests. 🤞 (We can follow up in an enterprise PR to simplify some of the tests after this fix is merged.)
I ran this locally a bunch of times and with data race detection enabled, and did not see any failures.
Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
411 lines
12 KiB
Go
411 lines
12 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package http
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/audit"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/hashicorp/vault/vault"
|
|
"github.com/hashicorp/vault/vault/cluster"
|
|
"github.com/stretchr/testify/assert"
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// TestEventsSubscribe tests the websocket endpoint for subscribing to events
|
|
// by generating some events.
|
|
func TestEventsSubscribe(t *testing.T) {
|
|
core := vault.TestCoreWithConfig(t, &vault.CoreConfig{})
|
|
ln, addr := TestServer(t, core)
|
|
defer ln.Close()
|
|
|
|
// unseal the core
|
|
keys, token := vault.TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
_, err := core.Unseal(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
const eventType = "abc"
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
// send some events
|
|
sendEvents := func() error {
|
|
id, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
pluginInfo := &logical.EventPluginInfo{
|
|
MountPath: "secret",
|
|
}
|
|
err = core.Events().SendEventInternal(namespace.RootContext(ctx), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
|
|
Id: id,
|
|
Metadata: nil,
|
|
EntityIds: nil,
|
|
Note: "testing",
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
|
|
|
testCases := []struct {
|
|
json bool
|
|
}{{true}, {false}}
|
|
|
|
for _, testCase := range testCases {
|
|
location := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?namespaces=ns1&namespaces=ns*&json=%v", wsAddr, eventType, testCase.json)
|
|
conn, _, err := websocket.Dial(ctx, location, &websocket.DialOptions{
|
|
HTTPHeader: http.Header{"x-vault-token": []string{token}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(func() {
|
|
conn.Close(websocket.StatusNormalClosure, "")
|
|
})
|
|
|
|
err = sendEvents()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
_, msg, err := conn.Read(ctx)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if testCase.json {
|
|
event := map[string]interface{}{}
|
|
err = json.Unmarshal(msg, &event)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Log(string(msg))
|
|
data := event["data"].(map[string]interface{})
|
|
if actualType := data["event_type"].(string); actualType != eventType {
|
|
t.Fatalf("Expeced event type %s, got %s", eventType, actualType)
|
|
}
|
|
pluginInfo, ok := data["plugin_info"].(map[string]interface{})
|
|
if !ok || pluginInfo == nil {
|
|
t.Fatalf("No plugin_info object: %v", data)
|
|
}
|
|
mountPath, ok := pluginInfo["mount_path"].(string)
|
|
if !ok || mountPath != "secret" {
|
|
t.Fatalf("Wrong mount_path: %v", data)
|
|
}
|
|
innerEvent := data["event"].(map[string]interface{})
|
|
if innerEvent["id"].(string) != event["id"].(string) {
|
|
t.Fatalf("IDs don't match, expected %s, got %s", innerEvent["id"].(string), event["id"].(string))
|
|
}
|
|
if innerEvent["note"].(string) != "testing" {
|
|
t.Fatalf("Expected 'testing', got %s", innerEvent["note"].(string))
|
|
}
|
|
|
|
checkRequiredCloudEventsFields(t, event)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestBexprFilters tests that go-bexpr filters are used to filter events.
|
|
func TestBexprFilters(t *testing.T) {
|
|
core := vault.TestCoreWithConfig(t, &vault.CoreConfig{})
|
|
ln, addr := TestServer(t, core)
|
|
defer ln.Close()
|
|
|
|
// unseal the core
|
|
keys, token := vault.TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
_, err := core.Unseal(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
sendEvents := func(ctx context.Context, eventTypes ...string) error {
|
|
for _, eventType := range eventTypes {
|
|
pluginInfo := &logical.EventPluginInfo{
|
|
MountPath: "secret",
|
|
}
|
|
ns := namespace.RootNamespace
|
|
id := eventType
|
|
err := core.Events().SendEventInternal(namespace.RootContext(ctx), ns, pluginInfo, logical.EventType(eventType), &logical.EventData{
|
|
Id: id,
|
|
Metadata: nil,
|
|
EntityIds: nil,
|
|
Note: "testing",
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
|
bexprFilter := url.QueryEscape("event_type == abc")
|
|
|
|
location := fmt.Sprintf("%s/v1/sys/events/subscribe/*?json=true&filter=%s", wsAddr, bexprFilter)
|
|
conn, _, err := websocket.Dial(ctx, location, &websocket.DialOptions{
|
|
HTTPHeader: http.Header{"x-vault-token": []string{token}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer conn.Close(websocket.StatusNormalClosure, "")
|
|
|
|
err = sendEvents(ctx, "abc", "def", "xyz")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// read until we time out
|
|
seen := map[string]bool{}
|
|
done := false
|
|
for !done {
|
|
done = func() bool {
|
|
readCtx, readCancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
defer readCancel()
|
|
_, msg, err := conn.Read(readCtx)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
event := map[string]interface{}{}
|
|
err = json.Unmarshal(msg, &event)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return true
|
|
}
|
|
seen[event["id"].(string)] = true
|
|
return false
|
|
}()
|
|
}
|
|
// we should only get the "abc" messages
|
|
assert.Len(t, seen, 1)
|
|
assert.Contains(t, seen, "abc")
|
|
}
|
|
|
|
func TestNamespacePrepend(t *testing.T) {
|
|
testCases := []struct {
|
|
requestNs string
|
|
patterns []string
|
|
result []string
|
|
}{
|
|
{"", []string{"ns*"}, []string{"", "ns*"}},
|
|
{"ns1", []string{"ns*"}, []string{"ns1", "ns1/ns*"}},
|
|
{"ns1", []string{"ns1*"}, []string{"ns1", "ns1/ns1*"}},
|
|
{"ns1", []string{"ns1/*"}, []string{"ns1", "ns1/ns1/*"}},
|
|
{"", []string{"ns1/ns13", "ns1/other"}, []string{"", "ns1/ns13", "ns1/other"}},
|
|
{"ns1", []string{"ns1/ns13", "ns1/other"}, []string{"ns1", "ns1/ns1/ns13", "ns1/ns1/other"}},
|
|
{"", []string{""}, []string{""}},
|
|
{"", nil, []string{""}},
|
|
{"ns1", []string{""}, []string{"ns1"}},
|
|
{"ns1", []string{"", ""}, []string{"ns1"}},
|
|
{"ns1", []string{"ns1"}, []string{"ns1", "ns1/ns1"}},
|
|
{"", []string{"*"}, []string{"", "*"}},
|
|
{"ns1", []string{"*"}, []string{"ns1", "ns1/*"}},
|
|
{"", []string{"ns1/ns13*", "ns2"}, []string{"", "ns1/ns13*", "ns2"}},
|
|
{"ns1", []string{"ns1/ns13*", "ns2"}, []string{"ns1", "ns1/ns1/ns13*", "ns1/ns2"}},
|
|
{"", []string{"ns*", "ns1"}, []string{"", "ns*", "ns1"}},
|
|
{"ns1", []string{"ns*", "ns1"}, []string{"ns1", "ns1/ns*", "ns1/ns1"}},
|
|
{"ns1", []string{"ns1*", "ns1"}, []string{"ns1", "ns1/ns1*", "ns1/ns1"}},
|
|
{"ns1", []string{"ns1/*", "ns1"}, []string{"ns1", "ns1/ns1/*", "ns1/ns1"}},
|
|
}
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.requestNs+" "+strings.Join(testCase.patterns, " "), func(t *testing.T) {
|
|
result := prependNamespacePatterns(testCase.patterns, &namespace.Namespace{ID: testCase.requestNs, Path: testCase.requestNs})
|
|
assert.Equal(t, testCase.result, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// checkRequiredCloudEventsFields reports a non-fatal test error for each of
// the attributes "id", "source", "specversion" and "type" that is missing from
// event, is not a string, or is an empty string.
func checkRequiredCloudEventsFields(t *testing.T, event map[string]interface{}) {
	t.Helper()
	required := []string{"id", "source", "specversion", "type"}
	for _, name := range required {
		raw, present := event[name]
		if !present {
			t.Errorf("Missing attribute %s", name)
			continue
		}
		s, isString := raw.(string)
		switch {
		case !isString:
			t.Errorf("Expected %s to be string but got %T", name, raw)
		case s == "":
			t.Errorf("%s was empty string", name)
		}
	}
}
|
|
|
|
// TestEventsSubscribeAuth tests that unauthenticated and unauthorized subscriptions
|
|
// fail correctly.
|
|
func TestEventsSubscribeAuth(t *testing.T) {
|
|
core := vault.TestCore(t)
|
|
ln, addr := TestServer(t, core)
|
|
defer ln.Close()
|
|
|
|
// unseal the core
|
|
keys, root := vault.TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
_, err := core.Unseal(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
var nonPrivilegedToken string
|
|
// Fetch a valid non privileged token.
|
|
{
|
|
config := api.DefaultConfig()
|
|
config.Address = addr
|
|
|
|
client, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.SetToken(root)
|
|
|
|
secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{Policies: []string{"default"}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret.Auth.ClientToken == "" {
|
|
t.Fatal("Failed to fetch a non privileged token")
|
|
}
|
|
nonPrivilegedToken = secret.Auth.ClientToken
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
|
|
|
// Get a 403 with no token.
|
|
_, resp, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", nil)
|
|
if err == nil {
|
|
t.Error("Expected websocket error but got none")
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected 403 but got %+v", resp)
|
|
}
|
|
|
|
// Get a 403 with a non privileged token.
|
|
_, resp, err = websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", &websocket.DialOptions{
|
|
HTTPHeader: http.Header{"x-vault-token": []string{nonPrivilegedToken}},
|
|
})
|
|
if err == nil {
|
|
t.Error("Expected websocket error but got none")
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected 403 but got %+v", resp)
|
|
}
|
|
}
|
|
|
|
func TestCanForwardEventConnections(t *testing.T) {
|
|
// Run again with in-memory network
|
|
inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, hclog.New(&hclog.LoggerOptions{
|
|
Mutex: &sync.Mutex{},
|
|
Level: hclog.Trace,
|
|
Name: "inmem-cluster",
|
|
}))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster := vault.NewTestCluster(t, &vault.CoreConfig{
|
|
AuditBackends: map[string]audit.Factory{
|
|
"nop": corehelpers.NoopAuditFactory(nil),
|
|
},
|
|
}, &vault.TestClusterOptions{
|
|
ClusterLayers: inmemCluster,
|
|
})
|
|
cores := testCluster.Cores
|
|
testCluster.Start()
|
|
defer testCluster.Cleanup()
|
|
|
|
rootToken := testCluster.RootToken
|
|
|
|
// Wait for core to become active
|
|
vault.TestWaitActiveForwardingReady(t, cores[0].Core)
|
|
|
|
// Test forwarding a request. Since we're going directly from core to core
|
|
// with no fallback we know that if it worked, request handling is working
|
|
c := cores[1]
|
|
standby, err := c.Standby()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !standby {
|
|
t.Fatal("expected core to be standby")
|
|
}
|
|
|
|
// We need to call Leader as that refreshes the connection info
|
|
isLeader, _, _, err := c.Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if isLeader {
|
|
t.Fatal("core should not be leader")
|
|
}
|
|
corehelpers.RetryUntil(t, 5*time.Second, func() error {
|
|
state := c.ActiveNodeReplicationState()
|
|
if state == 0 {
|
|
return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state)
|
|
}
|
|
return nil
|
|
})
|
|
|
|
req, err := http.NewRequest("GET", "https://pushit.real.good:9281/v1/sys/events/subscribe/xyz?json=true", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req = req.WithContext(namespace.RootContext(req.Context()))
|
|
req.Header.Add(consts.AuthHeaderName, rootToken)
|
|
|
|
resp := httptest.NewRecorder()
|
|
forwardRequest(cores[1].Core, resp, req)
|
|
|
|
header := resp.Header()
|
|
if header == nil {
|
|
t.Fatal("err: expected at least a Location header")
|
|
}
|
|
if !strings.HasPrefix(header.Get("Location"), "wss://") {
|
|
t.Fatalf("bad location: %s", header.Get("Location"))
|
|
}
|
|
|
|
// test forwarding requests to each core
|
|
handled := 0
|
|
forwarded := 0
|
|
for _, c := range cores {
|
|
resp := httptest.NewRecorder()
|
|
fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handled++
|
|
})
|
|
handleRequestForwarding(c.Core, fakeHandler).ServeHTTP(resp, req)
|
|
header := resp.Header()
|
|
if header == nil {
|
|
continue
|
|
}
|
|
if strings.HasPrefix(header.Get("Location"), "wss://") {
|
|
forwarded++
|
|
}
|
|
}
|
|
if handled != 1 && forwarded != 2 {
|
|
t.Fatalf("Expected 1 core to handle the request and 2 to forward")
|
|
}
|
|
}
|