Events API uses consistent error codes (#19246)

This commit is contained in:
Tom Proctor 2023-02-20 16:24:27 +00:00 committed by GitHub
parent e416190238
commit 4c11d090cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 14 deletions

View File

@ -84,7 +84,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
_, _, err := core.CheckToken(ctx, req, false)
if err != nil {
if errors.Is(err, logical.ErrPermissionDenied) {
respondError(w, http.StatusUnauthorized, logical.ErrPermissionDenied)
respondError(w, http.StatusForbidden, logical.ErrPermissionDenied)
return
}
logger.Debug("Error validating token", "error", err)

View File

@ -10,6 +10,7 @@ import (
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
@ -34,7 +35,7 @@ func TestEventsSubscribe(t *testing.T) {
stop := atomic.Bool{}
eventType := "abc"
const eventType = "abc"
// send some events
go func() {
@ -43,7 +44,10 @@ func TestEventsSubscribe(t *testing.T) {
if err != nil {
core.Logger().Info("Error generating UUID, exiting sender", "error", err)
}
err = core.Events().SendInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, nil, logical.EventType(eventType), &logical.EventData{
pluginInfo := &logical.EventPluginInfo{
MountPath: "secret",
}
err = core.Events().SendInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
@ -60,9 +64,7 @@ func TestEventsSubscribe(t *testing.T) {
stop.Store(true)
})
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancelFunc)
ctx := context.Background()
wsAddr := strings.Replace(addr, "http", "ws", 1)
testCases := []struct {
@ -71,14 +73,6 @@ func TestEventsSubscribe(t *testing.T) {
for _, testCase := range testCases {
url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=%v", wsAddr, eventType, testCase.json)
// check that the connection fails if we don't have a token
_, _, err := websocket.Dial(ctx, url, nil)
if err == nil {
t.Error("Expected websocket error but got none")
} else if !strings.HasSuffix(err.Error(), "401") {
t.Errorf("Expected 401 websocket but got %v", err)
}
conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPHeader: http.Header{"x-vault-token": []string{token}},
})
@ -101,3 +95,65 @@ func TestEventsSubscribe(t *testing.T) {
}
}
}
// TestEventsSubscribeAuth tests that unauthenticated and unauthorized subscriptions
// fail correctly.
func TestEventsSubscribeAuth(t *testing.T) {
core := vault.TestCore(t)
ln, addr := TestServer(t, core)
defer ln.Close()
// unseal the core
keys, root := vault.TestCoreInit(t, core)
for _, key := range keys {
_, err := core.Unseal(key)
if err != nil {
t.Fatal(err)
}
}
var nonPrivilegedToken string
// Fetch a valid non privileged token.
{
config := api.DefaultConfig()
config.Address = addr
client, err := api.NewClient(config)
if err != nil {
t.Fatal(err)
}
client.SetToken(root)
secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{Policies: []string{"default"}})
if err != nil {
t.Fatal(err)
}
if secret.Auth.ClientToken == "" {
t.Fatal("Failed to fetch a non privileged token")
}
nonPrivilegedToken = secret.Auth.ClientToken
}
ctx := context.Background()
wsAddr := strings.Replace(addr, "http", "ws", 1)
// Get a 403 with no token.
_, resp, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", nil)
if err == nil {
t.Error("Expected websocket error but got none")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected 403 but got %+v", resp)
}
// Get a 403 with a non privileged token.
_, resp, err = websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", &websocket.DialOptions{
HTTPHeader: http.Header{"x-vault-token": []string{nonPrivilegedToken}},
})
if err == nil {
t.Error("Expected websocket error but got none")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected 403 but got %+v", resp)
}
}