chore: rework auth.* keys, add ctxstore package

Using so-called phantom types we can use the types themselves as keys directly without loosing performance.
You no longer need to remember which type was attached to the thing you passed in context and can look up
all fields access directly.

Part of #37

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
This commit is contained in:
Dmitriy Matrenichev 2024-07-15 16:21:22 +03:00
parent 76263e12a4
commit 4cfc0e6dd0
No known key found for this signature in database
GPG Key ID: 94B473337258BFD5
27 changed files with 315 additions and 148 deletions

View File

@ -44,6 +44,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/user"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
"github.com/siderolabs/omni/internal/pkg/features"
"github.com/siderolabs/omni/internal/pkg/siderolink"
"github.com/siderolabs/omni/internal/version"
@ -235,7 +236,7 @@ func runWithState(logger *zap.Logger) func(context.Context, state.State, *virtua
return fmt.Errorf("failed to update features config resources: %w", err)
}
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, authres.Enabled(authConfig))
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: authres.Enabled(authConfig)})
handler, err := backend.NewFrontendHandler(rootCmdArgs.frontendDst, logger)
if err != nil {

View File

@ -45,6 +45,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
//go:embed testdata/admin-kubeconfig.yaml
@ -193,11 +194,11 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string {
md = metadata.New(nil)
}
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, true)
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: true})
msg := message.NewGRPC(md, info.FullMethod)
ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: msg})
if r := md.Get("role"); len(r) > 0 {
var parsed role.Role
@ -207,7 +208,7 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string {
return nil, err
}
ctx = context.WithValue(ctx, auth.RoleContextKey{}, parsed)
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: parsed})
}
return handler(ctx, req)

View File

@ -60,6 +60,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
"github.com/siderolabs/omni/internal/pkg/siderolink"
)
@ -939,9 +940,10 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster
return nil, err
}
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}
newRole, err := role.Max(userRole, clusterRole)
@ -953,7 +955,7 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster
return ctx, nil
}
return context.WithValue(ctx, auth.RoleContextKey{}, newRole), nil
return ctxstore.WithValue(ctx, auth.RoleContextKey{Role: newRole}), nil
}
func handleError(err error) error {

View File

@ -22,6 +22,7 @@ import (
"github.com/siderolabs/omni/internal/backend/dns"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
"github.com/siderolabs/omni/internal/pkg/grpcutil"
)
@ -78,9 +79,8 @@ func (backend *TalosBackend) GetConnection(ctx context.Context, fullMethodName s
// we can't use regular gRPC server interceptors here, as proxy interface is a bit different
// prepare context values for the verifier
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, backend.authEnabled)
msg := message.NewGRPC(md, fullMethodName)
ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: backend.authEnabled})
ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, fullMethodName)})
grpcutil.SetShouldLog(ctx, "talos-backend")

View File

@ -17,7 +17,7 @@ import (
)
// clusterContextKey is a type for cluster name.
type clusterContextKey struct{}
type clusterContextKey struct{ ClusterName string }
// Handler implements the HTTP reverse proxy for Kubernetes clusters.
//

View File

@ -18,6 +18,8 @@ import (
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.uber.org/zap"
"k8s.io/client-go/transport"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
const authorizationHeader = "Authorization"
@ -107,7 +109,7 @@ func AuthorizeRequest(next http.Handler, keyFunc KeyProvider, clusterUUIDResolve
}
// clone the request before modifying it
req = req.WithContext(context.WithValue(ctx, clusterContextKey{}, clusterName))
req = req.WithContext(ctxstore.WithValue(ctx, clusterContextKey{ClusterName: clusterName}))
// clean all headers which are going to be overridden
req.Header.Del(authorizationHeader)

View File

@ -26,6 +26,7 @@ import (
"k8s.io/client-go/transport"
"github.com/siderolabs/omni/internal/backend/k8sproxy"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var mockClusterUUIDResolver = func(_ context.Context, clusterID resource.ID) (string, error) {
@ -276,9 +277,9 @@ func TestAuthorize(t *testing.T) {
assert.Equal(t, tc.expectedImpersonateGroups, receivedReq.Header.Values(transport.ImpersonateGroupHeader))
assert.Nil(t, receivedReq.Header.Values("Authorization"))
v, ok := receivedReq.Context().Value(k8sproxy.ClusterContextKey{}).(string)
v, ok := ctxstore.Value[k8sproxy.ClusterContextKey](receivedReq.Context()) //nolint:contextcheck
assert.True(t, ok)
assert.Equal(t, tc.expectedCluster, v)
assert.Equal(t, tc.expectedCluster, v.ClusterName)
})
}
}

View File

@ -24,6 +24,7 @@ import (
"github.com/siderolabs/omni/client/api/common"
"github.com/siderolabs/omni/internal/backend/runtime"
"github.com/siderolabs/omni/internal/backend/runtime/kubernetes"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// multiplexer provides an http.RoundTripper which selects the cluster based on the request context.
@ -74,12 +75,12 @@ func newMultiplexer() *multiplexer {
// RoundTrip implements http.RoundTripper interface.
func (m *multiplexer) RoundTrip(req *http.Request) (*http.Response, error) {
clusterName, ok := req.Context().Value(clusterContextKey{}).(string)
clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context())
if !ok {
return nil, errors.New("cluster name not found in request context")
}
rt, err := m.getRT(req.Context(), clusterName)
rt, err := m.getRT(req.Context(), clusterNameVal.ClusterName)
if err != nil {
return nil, err
}

View File

@ -13,6 +13,7 @@ import (
"go.uber.org/zap"
"github.com/siderolabs/omni/internal/backend/logging"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// proxyHandler implements the HTTP reverse proxy.
@ -44,14 +45,14 @@ func newProxyHandler(m *multiplexer, logger *zap.Logger) *proxyHandler {
// director sets the target URL for the reverse proxy.
func (p *proxyHandler) director(req *http.Request) {
clusterName, ok := req.Context().Value(clusterContextKey{}).(string)
clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context())
if !ok {
ctxzap.Error(req.Context(), "cluster name not found in request context")
return
}
connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterName)
connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterNameVal.ClusterName)
if err != nil {
ctxzap.Error(req.Context(), "failed to get cluster connector", zap.Error(err))

View File

@ -32,6 +32,7 @@ import (
"github.com/siderolabs/omni/internal/backend/workloadproxy"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// using whitelisted for external API access type.
@ -72,7 +73,7 @@ func (suite *OmniRuntimeSuite) SetupTest() {
suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute)
// disable auth in the context
suite.ctx = context.WithValue(suite.ctx, auth.EnabledAuthContextKey{}, false)
suite.ctx = ctxstore.WithValue(suite.ctx, auth.EnabledAuthContextKey{Enabled: false})
var err error

View File

@ -28,6 +28,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var (
@ -257,7 +258,7 @@ func checkForRole(ctx context.Context, st state.State, access state.Access, clus
if clusterRole != role.None && (!requireAll || (requireAll && matchesAll)) {
// override the role in the context with the computed role for this cluster
ctx = context.WithValue(ctx, auth.RoleContextKey{}, clusterRole)
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: clusterRole})
}
}

View File

@ -25,6 +25,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/accesspolicy"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// State is a virtual state implementation which provides virtual resources.
@ -161,16 +162,17 @@ func (v *State) validateKind(kind resource.Kind) error {
}
func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) {
identity, _ := ctx.Value(auth.IdentityContextKey{}).(string) //nolint:errcheck
identityVal, _ := ctxstore.Value[auth.IdentityContextKey](ctx)
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}
user := virtual.NewCurrentUser()
user.TypedSpec().Value.Identity = identity
user.TypedSpec().Value.Identity = identityVal.Identity
user.TypedSpec().Value.Role = string(userRole)
version, err := resource.ParseVersion("1")
@ -184,9 +186,10 @@ func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) {
}
func (v *State) permissions(ctx context.Context) (*virtual.Permissions, error) {
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}
permissions := virtual.NewPermissions()

View File

@ -23,6 +23,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/accesspolicy"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// RoleProvider provides the current actor's role for a cluster.
@ -119,10 +120,10 @@ func (p *PGPAccessValidator) ValidateAccess(ctx context.Context, publicKeyID, pu
return parseErr
}
ctx = context.WithValue(ctx, auth.RoleContextKey{}, publicKeyRole)
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: publicKeyRole})
}
ctx = context.WithValue(ctx, auth.IdentityContextKey{}, publicKey.TypedSpec().Value.GetIdentity().GetEmail())
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: publicKey.TypedSpec().Value.GetIdentity().GetEmail()})
accessRole, err := p.roleProvider.RoleForCluster(ctx, clusterID)
if err != nil {

View File

@ -18,13 +18,15 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// RoleForCluster returns the role of the current user for the given cluster, and whether the role matches all clusters.
func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.Role, bool, error) {
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}
ctx = actor.MarkContextAsInternalActor(ctx)
@ -38,12 +40,12 @@ func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.R
return role.None, false, err
}
identityStr, identityExists := ctx.Value(auth.IdentityContextKey{}).(string)
identityVal, identityExists := ctxstore.Value[auth.IdentityContextKey](ctx)
if !identityExists {
return userRole, false, nil
}
identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityStr).Metadata())
identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityVal.Identity).Metadata())
if err != nil {
if state.IsNotFoundError(err) {
return userRole, false, nil

View File

@ -6,17 +6,23 @@
// Package actor implements the context marking for internal/external actors.
package actor
import "context"
import (
"context"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// internalActorContextKey is the key for internal actor context.
type internalActorContextKey struct{}
// MarkContextAsInternalActor returns a new derived context from the given context, marked as an internal actor.
func MarkContextAsInternalActor(ctx context.Context) context.Context {
return context.WithValue(ctx, internalActorContextKey{}, struct{}{})
return ctxstore.WithValue(ctx, internalActorContextKey{})
}
// ContextIsInternalActor returns true if the given context is marked as an internal actor.
func ContextIsInternalActor(ctx context.Context) bool {
return ctx.Value(internalActorContextKey{}) != nil
_, ok := ctxstore.Value[internalActorContextKey](ctx)
return ok
}

View File

@ -14,6 +14,7 @@ import (
"google.golang.org/grpc/status"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var (
@ -81,19 +82,19 @@ func WithVerifiedEmail() CheckOption {
//
// The returned error can be checked against ErrUnauthenticated and ErrUnauthorized.
func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) {
authEnabled, ok := ctx.Value(EnabledAuthContextKey{}).(bool)
authVal, ok := ctxstore.Value[EnabledAuthContextKey](ctx)
if !ok {
return CheckResult{}, fmt.Errorf("%w: auth configuration not found in context", ErrUnauthenticated)
}
if !authEnabled {
if !authVal.Enabled {
return CheckResult{
AuthEnabled: false,
}, nil
}
result := CheckResult{
AuthEnabled: authEnabled,
AuthEnabled: authVal.Enabled,
}
opts := DefaultCheckOptions()
@ -108,22 +109,25 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) {
}
if opts.VerifiedEmail {
email, ok := ctx.Value(VerifiedEmailContextKey{}).(string)
emailVal, ok := ctxstore.Value[VerifiedEmailContextKey](ctx)
if !ok {
return CheckResult{}, fmt.Errorf("%w: missing verified email", ErrUnauthenticated)
}
result.VerifiedEmail = email
result.VerifiedEmail = emailVal.Email
}
ctxRole, ctxRoleExists := ctx.Value(RoleContextKey{}).(role.Role)
if !ctxRoleExists {
ctxRole = role.None
ctxRole := role.None
ctxRoleExists := false
if val, ok := ctxstore.Value[RoleContextKey](ctx); ok {
ctxRole = val.Role
ctxRoleExists = true
}
result.Role = ctxRole
// RoleContextKey{} is set on the context only when there is a valid signature, so we can rely on this.
// RoleContextKey is set on the context only when there is a valid signature, so we can rely on this.
result.HasValidSignature = ctxRoleExists
if opts.ValidSignature && !result.HasValidSignature {
@ -137,12 +141,12 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) {
}
}
if identity, ok := ctx.Value(IdentityContextKey{}).(string); ok {
result.Identity = identity
if val, ok := ctxstore.Value[IdentityContextKey](ctx); ok {
result.Identity = val.Identity
}
if userID, ok := ctx.Value(UserIDContextKey{}).(string); ok {
result.UserID = userID
if val, ok := ctxstore.Value[UserIDContextKey](ctx); ok {
result.UserID = val.UserID
}
return result, nil

View File

@ -13,6 +13,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
func TestCheck(t *testing.T) {
@ -30,18 +31,20 @@ func TestCheck(t *testing.T) {
},
{
name: "auth disabled",
ctx: context.WithValue(
ctx: ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
false,
auth.EnabledAuthContextKey{
Enabled: false,
},
),
},
{
name: "not authenticated, no requirements",
ctx: context.WithValue(
ctx: ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
want: auth.CheckResult{
AuthEnabled: true,
@ -50,44 +53,49 @@ func TestCheck(t *testing.T) {
},
{
name: "not authenticated, verified email",
ctx: context.WithValue(
ctx: ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
opts: []auth.CheckOption{auth.WithVerifiedEmail()},
errorIs: auth.ErrUnauthenticated,
},
{
name: "not authenticated, none role",
ctx: context.WithValue(
ctx: ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
opts: []auth.CheckOption{auth.WithValidSignature(true)},
errorIs: auth.ErrUnauthenticated,
},
{
name: "not authenticated, operator role",
ctx: context.WithValue(
ctx: ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
opts: []auth.CheckOption{auth.WithRole(role.Operator)},
errorIs: auth.ErrUnauthenticated,
},
{
name: "verified email",
ctx: context.WithValue(
context.WithValue(
ctx: ctxstore.WithValue(
ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
auth.VerifiedEmailContextKey{},
"user@example.com",
auth.VerifiedEmailContextKey{
Email: "user@example.com",
},
),
opts: []auth.CheckOption{auth.WithVerifiedEmail()},
want: auth.CheckResult{
@ -98,14 +106,16 @@ func TestCheck(t *testing.T) {
},
{
name: "role okay",
ctx: context.WithValue(
context.WithValue(
ctx: ctxstore.WithValue(
ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
auth.RoleContextKey{},
role.Operator,
auth.RoleContextKey{
Role: role.Operator,
},
),
opts: []auth.CheckOption{auth.WithRole(role.Operator)},
want: auth.CheckResult{
@ -116,36 +126,42 @@ func TestCheck(t *testing.T) {
},
{
name: "role mismatch",
ctx: context.WithValue(
context.WithValue(
ctx: ctxstore.WithValue(
ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
auth.RoleContextKey{},
role.Operator,
auth.RoleContextKey{
Role: role.Operator,
},
),
opts: []auth.CheckOption{auth.WithRole(role.Admin)},
errorIs: auth.ErrUnauthorized,
},
{
name: "role and verified email",
ctx: context.WithValue(
context.WithValue(
context.WithValue(
context.WithValue(
ctx: ctxstore.WithValue(
ctxstore.WithValue(
ctxstore.WithValue(
ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
auth.RoleContextKey{},
role.Operator,
auth.RoleContextKey{
Role: role.Operator,
},
),
auth.VerifiedEmailContextKey{},
"user@example.com",
auth.VerifiedEmailContextKey{
Email: "user@example.com",
},
),
auth.IdentityContextKey{},
"user2@example.com",
auth.IdentityContextKey{
Identity: "user2@example.com",
},
),
opts: []auth.CheckOption{auth.WithRole(role.Operator), auth.WithVerifiedEmail()},
want: auth.CheckResult{
@ -158,14 +174,16 @@ func TestCheck(t *testing.T) {
},
{
name: "valid signature",
ctx: context.WithValue(
context.WithValue(
ctx: ctxstore.WithValue(
ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
auth.RoleContextKey{},
role.None,
auth.RoleContextKey{
Role: role.None,
},
),
opts: []auth.CheckOption{},
want: auth.CheckResult{
@ -176,14 +194,16 @@ func TestCheck(t *testing.T) {
},
{
name: "missing signature",
ctx: context.WithValue(
context.WithValue(
ctx: ctxstore.WithValue(
ctxstore.WithValue(
context.Background(),
auth.EnabledAuthContextKey{},
true,
auth.EnabledAuthContextKey{
Enabled: true,
},
),
auth.VerifiedEmailContextKey{},
"me@example.com",
auth.VerifiedEmailContextKey{
Email: "me@example.com",
},
),
opts: []auth.CheckOption{auth.WithValidSignature(true)},
errorIs: auth.ErrUnauthenticated,

View File

@ -5,20 +5,26 @@
package auth
// EnabledAuthContextKey is the context key for enabled authentication. Value has the type bool.
type EnabledAuthContextKey struct{}
import (
"github.com/siderolabs/go-api-signature/pkg/message"
"github.com/siderolabs/omni/internal/pkg/auth/role"
)
// EnabledAuthContextKey is the context key for enabled authentication.
type EnabledAuthContextKey struct{ Enabled bool }
// GRPCMessageContextKey is the context key for the GRPC message. It is only set if authentication is enabled.
type GRPCMessageContextKey struct{}
type GRPCMessageContextKey struct{ Message *message.GRPC }
// VerifiedEmailContextKey is the context key for the verified email address. Value has the type string.
type VerifiedEmailContextKey struct{}
// VerifiedEmailContextKey is the context key for the verified email address.
type VerifiedEmailContextKey struct{ Email string }
// UserIDContextKey is the context key for the user ID. Value has the type string.
type UserIDContextKey struct{}
type UserIDContextKey struct{ UserID string }
// RoleContextKey is the context key for the role. Value has the type role.Role.
type RoleContextKey struct{}
type RoleContextKey struct{ Role role.Role }
// IdentityContextKey is the context key for the user identity. Value has the type string.
type IdentityContextKey struct{}
// IdentityContextKey is the context key for the user identity.
type IdentityContextKey struct{ Identity string }

View File

@ -6,12 +6,12 @@
package handler
import (
"context"
"net/http"
"go.uber.org/zap"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// AuthConfig represents the configuration for the auth config interceptor.
@ -31,7 +31,7 @@ func NewAuthConfig(handler http.Handler, enabled bool, logger *zap.Logger) *Auth
}
func (c *AuthConfig) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
ctx := context.WithValue(request.Context(), auth.EnabledAuthContextKey{}, c.enabled)
ctx := ctxstore.WithValue(request.Context(), auth.EnabledAuthContextKey{Enabled: c.enabled})
request = request.WithContext(ctx)
c.next.ServeHTTP(writer, request)

View File

@ -20,6 +20,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/handler"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
func testHandler(t *testing.T, authEnabled bool) {
@ -130,25 +131,25 @@ func testHandler(t *testing.T, authEnabled bool) {
t.Fatal("timeout")
}
ctxAuthEnabled, ok := reqCtx.Value(auth.EnabledAuthContextKey{}).(bool)
ctxAuthEnabledVal, ok := ctxstore.Value[auth.EnabledAuthContextKey](reqCtx) //nolint:contextcheck
require.True(t, ok)
assert.Equal(t, authEnabled, ctxAuthEnabled)
assert.Equal(t, authEnabled, ctxAuthEnabledVal.Enabled)
if !tc.verifyContext {
return
}
ctxUserID, ok := reqCtx.Value(auth.UserIDContextKey{}).(string)
ctxUserIDVal, ok := ctxstore.Value[auth.UserIDContextKey](reqCtx) //nolint:contextcheck
require.True(t, ok)
assert.Equal(t, "user-id", ctxUserID)
assert.Equal(t, "user-id", ctxUserIDVal.UserID)
ctxRole, ok := reqCtx.Value(auth.RoleContextKey{}).(role.Role)
ctxRoleVal, ok := ctxstore.Value[auth.RoleContextKey](reqCtx) //nolint:contextcheck
require.True(t, ok)
assert.Equal(t, role.Operator, ctxRole)
assert.Equal(t, role.Operator, ctxRoleVal.Role)
ctxIdentity, ok := reqCtx.Value(auth.IdentityContextKey{}).(string)
ctxIdentityVal, ok := ctxstore.Value[auth.IdentityContextKey](reqCtx) //nolint:contextcheck
require.True(t, ok)
assert.Equal(t, "user@example.com", ctxIdentity)
assert.Equal(t, "user@example.com", ctxIdentityVal.Identity)
})
}
}

View File

@ -15,6 +15,7 @@ import (
"go.uber.org/zap"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var errInvalidSignature = errors.New("invalid signature")
@ -102,9 +103,9 @@ func (s *Signature) intercept(request *http.Request) (*http.Request, error) {
Set("authenticator.identity", authenticator.Identity).
Set("authenticator.role", string(authenticator.Role))
ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity)
ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID)
ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role)
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity})
ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID})
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role})
return request.WithContext(ctx), nil
}

View File

@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
// AuthConfig represents the configuration for the auth config interceptor.
@ -53,7 +54,7 @@ func (c *AuthConfig) Stream() grpc.StreamServerInterceptor {
}
func (c *AuthConfig) intercept(ctx context.Context, method string) context.Context {
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, c.enabled)
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: c.enabled})
if !c.enabled {
return ctx
@ -64,7 +65,5 @@ func (c *AuthConfig) intercept(ctx context.Context, method string) context.Conte
md = metadata.New(nil)
}
msg := message.NewGRPC(md, method)
return context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
return ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, method)})
}

View File

@ -20,6 +20,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/auth0"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var errGRPCInvalidJWT = status.Error(codes.Unauthenticated, "invalid jwt")
@ -66,12 +67,12 @@ func (i *JWT) Stream() grpc.StreamServerInterceptor {
}
func (i *JWT) intercept(ctx context.Context) (context.Context, error) {
msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC)
msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx)
if !ok {
return nil, status.Error(codes.Internal, "missing or invalid message in context")
}
claims, err := msg.VerifyJWT(ctx, i.jwtVerifier)
claims, err := msgVal.Message.VerifyJWT(ctx, i.jwtVerifier)
if errors.Is(err, message.ErrNotFound) { // missing jwt, pass it through
return ctx, nil
}
@ -90,7 +91,7 @@ func (i *JWT) intercept(ctx context.Context) (context.Context, error) {
return nil, errGRPCInvalidJWT
}
ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, claims.VerifiedEmail)
ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: claims.VerifiedEmail})
return ctx, nil
}

View File

@ -15,7 +15,6 @@ import (
"github.com/cosi-project/runtime/pkg/state"
"github.com/crewjam/saml"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/siderolabs/go-api-signature/pkg/message"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -25,6 +24,7 @@ import (
authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var errGRPCInvalidSAML = status.Error(codes.Unauthenticated, "invalid session")
@ -71,12 +71,12 @@ func (i *SAML) Stream() grpc.StreamServerInterceptor {
}
func (i *SAML) intercept(ctx context.Context) (context.Context, error) {
msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC)
msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx)
if !ok {
return nil, status.Error(codes.Internal, "missing or invalid message in context")
}
values := msg.Metadata.Get(auth.SamlSessionHeaderKey)
values := msgVal.Message.Metadata.Get(auth.SamlSessionHeaderKey)
if len(values) == 0 {
return ctx, nil
}
@ -86,7 +86,7 @@ func (i *SAML) intercept(ctx context.Context) (context.Context, error) {
return nil, errGRPCInvalidSAML
}
ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, session.TypedSpec().Value.Email)
ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: session.TypedSpec().Value.Email})
return ctx, nil
}

View File

@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/status"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
var errGRPCInvalidSignature = status.Error(codes.Unauthenticated, "invalid signature")
@ -64,12 +65,12 @@ func (i *Signature) Stream() grpc.StreamServerInterceptor {
}
func (i *Signature) intercept(ctx context.Context) (context.Context, error) {
msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC)
msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx)
if !ok {
return nil, status.Error(codes.Internal, "missing or invalid message in context")
}
signature, err := msg.Signature()
signature, err := msgVal.Message.Signature()
if errors.Is(err, message.ErrNotFound) { // missing signature, pass it through
grpc_ctxtags.Extract(ctx).
Set("authenticator.user_id", "").
@ -100,7 +101,7 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) {
return nil, errGRPCInvalidSignature
}
err = msg.VerifySignature(authenticator.Verifier)
err = msgVal.Message.VerifySignature(authenticator.Verifier)
if err != nil {
i.logger.Info("failed to verify message", zap.Error(err))
@ -112,9 +113,9 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) {
Set("authenticator.identity", authenticator.Identity).
Set("authenticator.role", string(authenticator.Role))
ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID)
ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity)
ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role)
ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID})
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity})
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role})
return ctx, nil
}

View File

@ -0,0 +1,35 @@
// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
// Package ctxstore provides a way to store values in the context with the key based on the type.
package ctxstore
import "context"
// phantomKey represent key based on type. The cool thing about this empty struct,
// is that two instances of phantomKey with the different type are different, while
// two instances of phantomKey with the same type are the same. This is useful for
// creating a unique key for each type. It also helps to avoid collision with other
// keys in the context.
//
// It also does not allocate when used as a key in the context (aka converted to any).
// Same goes for int and struct containing single int field with value below 256.
// Same goes for bool and struct containing single bool field.
type phantomKey[T any] struct{}
// WithValue creates a new context with the value. Key is based on the type of the value.
func WithValue[T any](ctx context.Context, val T) context.Context {
return context.WithValue(ctx, phantomKey[T]{}, val)
}
// Value returns the value from the context. Key is based on the type of the value.
func Value[T any](ctx context.Context) (T, bool) {
value := ctx.Value(phantomKey[T]{})
if value == nil {
return *new(T), false
}
return value.(T), true //nolint:forcetypeassert
}

View File

@ -0,0 +1,76 @@
// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package ctxstore_test
import (
"context"
"runtime"
"testing"
"github.com/siderolabs/gen/pair"
"github.com/stretchr/testify/assert"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
func TestWithValue(t *testing.T) {
ctx := ctxstore.WithValue(context.Background(), "value1")
ctx = ctxstore.WithValue(ctx, 42)
ctx = ctxstore.WithValue(ctx, true)
type (
customString string
stringAlias = string
)
var cs customString
assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[string](ctx)))
assert.Equal(t, pair.MakePair(42, true), pair.MakePair(ctxstore.Value[int](ctx)))
assert.Equal(t, pair.MakePair(true, true), pair.MakePair(ctxstore.Value[bool](ctx)))
assert.Equal(t, pair.MakePair(0.0, false), pair.MakePair(ctxstore.Value[float64](ctx)))
assert.Equal(t, pair.MakePair(cs, false), pair.MakePair(ctxstore.Value[customString](ctx)))
assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[stringAlias](ctx)))
}
func BenchmarkWithValue(b *testing.B) {
b.ReportAllocs()
type (
emtpyStruct struct{}
myStruct[T any] struct{ Val T }
)
b.Run("empty struct", func(b *testing.B) { benchmarkFor(b, emtpyStruct{}) })
b.Run("small int", func(b *testing.B) { benchmarkFor(b, 42) })
b.Run("small int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 42}) })
b.Run("normal int", func(b *testing.B) { benchmarkFor(b, 424242) })
b.Run("normal int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 424242}) })
b.Run("bool", func(b *testing.B) { benchmarkFor(b, true) })
b.Run("bool inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[bool]{Val: true}) })
b.Run("string", func(b *testing.B) { benchmarkFor(b, "value") })
b.Run("string inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[string]{Val: "value"}) })
}
func benchmarkFor[T any](b *testing.B, value T) {
b.ReportAllocs()
var (
ok bool
result T
)
for range b.N {
ctx := ctxstore.WithValue(context.Background(), value)
result, ok = ctxstore.Value[T](ctx)
if !ok {
b.Fatal("unexpected")
}
}
runtime.KeepAlive(result)
}