omni/internal/pkg/auth/handler/handler_test.go
Edward Sammut Alessi d3ae77c0cc
chore: bump copyright to 2026
Bump copyright for conformance to 2026

Signed-off-by: Edward Sammut Alessi <edward.sammutalessi@siderolabs.com>
2026-01-21 15:30:49 +01:00

174 lines
4.1 KiB
Go

// Copyright (c) 2026 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package handler_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/siderolabs/go-api-signature/pkg/message"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/handler"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// testHandler exercises the auth HTTP handler chain with auth either enabled
// or disabled.
//
// It starts an httptest server whose core handler is wrapped with
// handler.NewSignature and handler.NewAuthConfig, sends one GET request per
// test case, and checks the response status code plus the auth-related values
// found in the request context that reached the core handler.
func testHandler(t *testing.T, authEnabled bool) {
	// The core handler hands the request context back to the test goroutine.
	// A buffer of one lets the handler return without waiting for the receiver.
	ctxCh := make(chan context.Context, 1)

	coreHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		ctxCh <- r.Context()
	})

	logger := zaptest.NewLogger(t)

	// Always resolves to the same fixed authenticator; the assertions at the
	// end of each subtest compare against these exact values.
	authenticatorFunc := func(context.Context, string) (*auth.Authenticator, error) { //nolint:unparam
		return &auth.Authenticator{
			Verifier: mockSignerVerifier{},
			Identity: "user@example.com",
			UserID:   "user-id",
			Role:     role.Operator,
		}, nil
	}

	// testServer builds a server where the signature handler sits inside the
	// auth config handler.
	testServer := func(signatureRequired message.SignatureRequiredCheckFunc) *httptest.Server {
		wrapWithAuth := func(h http.Handler) http.Handler {
			signatureHandler := handler.NewSignature(h, authenticatorFunc, logger, message.WithSignatureRequiredCheck(signatureRequired))

			return handler.NewAuthConfig(signatureHandler, authEnabled, logger)
		}

		return httptest.NewServer(wrapWithAuth(coreHandler))
	}

	ctx := t.Context()

	type testCase struct { //nolint:govet
		name            string
		uri             string
		signRequest     bool   // sign the outgoing request before sending it
		appendSignature []byte // passed to mockSignerVerifier; presumably appended to the produced signature
		verifyContext   bool   // additionally assert user ID, role and identity in the request context
		expectedCode    int
		public          bool // when set, the signature-required check reports false
	}

	var testCases []testCase

	if authEnabled {
		testCases = []testCase{
			{
				name:         "no signature",
				expectedCode: http.StatusOK,
				uri:          "/ok",
				public:       true,
			},
			{
				name:          "correct signature",
				expectedCode:  http.StatusOK,
				signRequest:   true,
				verifyContext: true,
				uri:           "/ok",
			},
			{
				name:            "broken signature",
				expectedCode:    http.StatusUnauthorized,
				signRequest:     true,
				appendSignature: []byte("broken"),
				uri:             "/fail",
			},
		}
	} else {
		testCases = []testCase{
			{
				name:         "no signature",
				expectedCode: http.StatusOK,
				uri:          "/",
			},
		}
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// A signature is required only when auth is enabled and the case
			// is not marked public.
			ts := testServer(func() (bool, error) {
				return authEnabled && !tc.public, nil
			})
			defer ts.Close()

			req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+tc.uri, nil)
			require.NoError(t, err)

			if tc.signRequest {
				var msg *message.HTTP

				msg, err = message.NewHTTP(req)
				require.NoError(t, err)

				require.NoError(t, msg.Sign("foo@example.com", mockSignerVerifier{tc.appendSignature}))
			}

			resp, err := http.DefaultClient.Do(req)
			require.NoError(t, err)

			// Release the response body (and its connection) on every exit
			// path. Registered after ts.Close, so it runs before the server
			// is shut down.
			defer func() {
				assert.NoError(t, resp.Body.Close())
			}()

			require.Equal(t, tc.expectedCode, resp.StatusCode)

			// For non-OK cases nothing is read from ctxCh, so stop here.
			if tc.expectedCode != http.StatusOK {
				return
			}

			var reqCtx context.Context

			select {
			case reqCtx = <-ctxCh:
			case <-time.After(time.Second):
				t.Fatal("timeout")
			}

			ctxAuthEnabledVal, ok := ctxstore.Value[auth.EnabledAuthContextKey](reqCtx) //nolint:contextcheck
			require.True(t, ok)
			assert.Equal(t, authEnabled, ctxAuthEnabledVal.Enabled)

			if !tc.verifyContext {
				return
			}

			ctxUserIDVal, ok := ctxstore.Value[auth.UserIDContextKey](reqCtx) //nolint:contextcheck
			require.True(t, ok)
			assert.Equal(t, "user-id", ctxUserIDVal.UserID)

			ctxRoleVal, ok := ctxstore.Value[auth.RoleContextKey](reqCtx) //nolint:contextcheck
			require.True(t, ok)
			assert.Equal(t, role.Operator, ctxRoleVal.Role)

			ctxIdentityVal, ok := ctxstore.Value[auth.IdentityContextKey](reqCtx) //nolint:contextcheck
			require.True(t, ok)
			assert.Equal(t, "user@example.com", ctxIdentityVal.Identity)
		})
	}
}
// TestHandler runs the shared handler scenarios twice as subtests: first with
// auth enabled, then with auth disabled.
func TestHandler(t *testing.T) {
	modes := []struct {
		name        string
		authEnabled bool
	}{
		{name: "AuthEnabled", authEnabled: true},
		{name: "AuthDisabled", authEnabled: false},
	}

	for _, mode := range modes {
		t.Run(mode.name, func(t *testing.T) {
			testHandler(t, mode.authEnabled)
		})
	}
}