diff --git a/cmd/omni/main.go b/cmd/omni/main.go index 8229eeff..a0e4c0e4 100644 --- a/cmd/omni/main.go +++ b/cmd/omni/main.go @@ -44,6 +44,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/user" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" "github.com/siderolabs/omni/internal/pkg/features" "github.com/siderolabs/omni/internal/pkg/siderolink" "github.com/siderolabs/omni/internal/version" @@ -235,7 +236,7 @@ func runWithState(logger *zap.Logger) func(context.Context, state.State, *virtua return fmt.Errorf("failed to update features config resources: %w", err) } - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, authres.Enabled(authConfig)) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: authres.Enabled(authConfig)}) handler, err := backend.NewFrontendHandler(rootCmdArgs.frontendDst, logger) if err != nil { diff --git a/internal/backend/grpc/configs_test.go b/internal/backend/grpc/configs_test.go index 0087dcdb..d9ab5fa0 100644 --- a/internal/backend/grpc/configs_test.go +++ b/internal/backend/grpc/configs_test.go @@ -45,6 +45,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) //go:embed testdata/admin-kubeconfig.yaml @@ -193,11 +194,11 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string { md = metadata.New(nil) } - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, true) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: true}) msg := message.NewGRPC(md, info.FullMethod) - ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg) + ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: msg}) if r := md.Get("role"); len(r) > 0 { var parsed role.Role @@ -207,7 +208,7 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string { return nil, err } - ctx = context.WithValue(ctx, auth.RoleContextKey{}, parsed) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: parsed}) } return handler(ctx, req) diff --git a/internal/backend/grpc/management.go b/internal/backend/grpc/management.go index 52ea0e54..4651b545 100644 --- a/internal/backend/grpc/management.go +++ b/internal/backend/grpc/management.go @@ -60,6 +60,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" "github.com/siderolabs/omni/internal/pkg/siderolink" ) @@ -939,9 +940,10 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster return nil, err } - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } newRole, err := role.Max(userRole, clusterRole) @@ -953,7 +955,7 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster return ctx, nil } - return context.WithValue(ctx, auth.RoleContextKey{}, newRole), nil + return ctxstore.WithValue(ctx, auth.RoleContextKey{Role: newRole}), nil } func handleError(err error) error { diff --git a/internal/backend/grpc/router/talos_backend.go b/internal/backend/grpc/router/talos_backend.go index 407961ae..8d6ca089 100644 --- a/internal/backend/grpc/router/talos_backend.go +++ b/internal/backend/grpc/router/talos_backend.go @@ -22,6 +22,7 @@ import ( "github.com/siderolabs/omni/internal/backend/dns" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" "github.com/siderolabs/omni/internal/pkg/grpcutil" ) @@ -78,9 +79,8 @@ func (backend *TalosBackend) GetConnection(ctx context.Context, fullMethodName s // we can't use regular gRPC server interceptors here, as proxy interface is a bit different // prepare context values for the verifier - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, backend.authEnabled) - msg := message.NewGRPC(md, fullMethodName) - ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: backend.authEnabled}) + ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, fullMethodName)}) grpcutil.SetShouldLog(ctx, "talos-backend") diff --git a/internal/backend/k8sproxy/k8sproxy.go b/internal/backend/k8sproxy/k8sproxy.go index cc3202c9..9e1900cc 100644 --- a/internal/backend/k8sproxy/k8sproxy.go +++ b/internal/backend/k8sproxy/k8sproxy.go @@ -17,7 +17,7 @@ import ( ) // clusterContextKey is a type for cluster name. -type clusterContextKey struct{} +type clusterContextKey struct{ ClusterName string } // Handler implements the HTTP reverse proxy for Kubernetes clusters. // diff --git a/internal/backend/k8sproxy/middleware.go b/internal/backend/k8sproxy/middleware.go index 5df7f3e2..6c84d7e3 100644 --- a/internal/backend/k8sproxy/middleware.go +++ b/internal/backend/k8sproxy/middleware.go @@ -18,6 +18,8 @@ import ( grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "go.uber.org/zap" "k8s.io/client-go/transport" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) const authorizationHeader = "Authorization" @@ -107,7 +109,7 @@ func AuthorizeRequest(next http.Handler, keyFunc KeyProvider, clusterUUIDResolve } // clone the request before modifying it - req = req.WithContext(context.WithValue(ctx, clusterContextKey{}, clusterName)) + req = req.WithContext(ctxstore.WithValue(ctx, clusterContextKey{ClusterName: clusterName})) // clean all headers which are going to be overridden req.Header.Del(authorizationHeader) diff --git a/internal/backend/k8sproxy/middleware_test.go b/internal/backend/k8sproxy/middleware_test.go index 55e7f6c4..e55b664e 100644 --- a/internal/backend/k8sproxy/middleware_test.go +++ b/internal/backend/k8sproxy/middleware_test.go @@ -26,6 +26,7 @@ import ( "k8s.io/client-go/transport" "github.com/siderolabs/omni/internal/backend/k8sproxy" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var mockClusterUUIDResolver = func(_ context.Context, clusterID resource.ID) (string, error) { @@ -276,9 +277,9 @@ func TestAuthorize(t *testing.T) { assert.Equal(t, tc.expectedImpersonateGroups, receivedReq.Header.Values(transport.ImpersonateGroupHeader)) assert.Nil(t, receivedReq.Header.Values("Authorization")) - v, ok := receivedReq.Context().Value(k8sproxy.ClusterContextKey{}).(string) + v, ok := ctxstore.Value[k8sproxy.ClusterContextKey](receivedReq.Context()) //nolint:contextcheck assert.True(t, ok) - assert.Equal(t, tc.expectedCluster, v) + assert.Equal(t, tc.expectedCluster, v.ClusterName) }) } } diff --git a/internal/backend/k8sproxy/multiplex.go b/internal/backend/k8sproxy/multiplex.go index d6c1e427..12b7f7cd 100644 --- a/internal/backend/k8sproxy/multiplex.go +++ b/internal/backend/k8sproxy/multiplex.go @@ -24,6 +24,7 @@ import ( "github.com/siderolabs/omni/client/api/common" "github.com/siderolabs/omni/internal/backend/runtime" "github.com/siderolabs/omni/internal/backend/runtime/kubernetes" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // multiplexer provides an http.RoundTripper which selects the cluster based on the request context. @@ -74,12 +75,12 @@ func newMultiplexer() *multiplexer { // RoundTrip implements http.RoundTripper interface. func (m *multiplexer) RoundTrip(req *http.Request) (*http.Response, error) { - clusterName, ok := req.Context().Value(clusterContextKey{}).(string) + clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context()) if !ok { return nil, errors.New("cluster name not found in request context") } - rt, err := m.getRT(req.Context(), clusterName) + rt, err := m.getRT(req.Context(), clusterNameVal.ClusterName) if err != nil { return nil, err } diff --git a/internal/backend/k8sproxy/proxy.go b/internal/backend/k8sproxy/proxy.go index bfb52ffc..3a841820 100644 --- a/internal/backend/k8sproxy/proxy.go +++ b/internal/backend/k8sproxy/proxy.go @@ -13,6 +13,7 @@ import ( "go.uber.org/zap" "github.com/siderolabs/omni/internal/backend/logging" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // proxyHandler implements the HTTP reverse proxy. @@ -44,14 +45,14 @@ func newProxyHandler(m *multiplexer, logger *zap.Logger) *proxyHandler { // director sets the target URL for the reverse proxy. func (p *proxyHandler) director(req *http.Request) { - clusterName, ok := req.Context().Value(clusterContextKey{}).(string) + clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context()) if !ok { ctxzap.Error(req.Context(), "cluster name not found in request context") return } - connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterName) + connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterNameVal.ClusterName) if err != nil { ctxzap.Error(req.Context(), "failed to get cluster connector", zap.Error(err)) diff --git a/internal/backend/runtime/omni/omni_test.go b/internal/backend/runtime/omni/omni_test.go index 3a80db8d..5ecf083a 100644 --- a/internal/backend/runtime/omni/omni_test.go +++ b/internal/backend/runtime/omni/omni_test.go @@ -32,6 +32,7 @@ import ( "github.com/siderolabs/omni/internal/backend/workloadproxy" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // using whitelisted for external API access type. @@ -72,7 +73,7 @@ func (suite *OmniRuntimeSuite) SetupTest() { suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute) // disable auth in the context - suite.ctx = context.WithValue(suite.ctx, auth.EnabledAuthContextKey{}, false) + suite.ctx = ctxstore.WithValue(suite.ctx, auth.EnabledAuthContextKey{Enabled: false}) var err error diff --git a/internal/backend/runtime/omni/state_access.go b/internal/backend/runtime/omni/state_access.go index e8e33e2b..597edd06 100644 --- a/internal/backend/runtime/omni/state_access.go +++ b/internal/backend/runtime/omni/state_access.go @@ -28,6 +28,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var ( @@ -257,7 +258,7 @@ func checkForRole(ctx context.Context, st state.State, access state.Access, clus if clusterRole != role.None && (!requireAll || (requireAll && matchesAll)) { // override the role in the context with the computed role for this cluster - ctx = context.WithValue(ctx, auth.RoleContextKey{}, clusterRole) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: clusterRole}) } } diff --git a/internal/backend/runtime/omni/virtual/state.go b/internal/backend/runtime/omni/virtual/state.go index 65299642..c4195772 100644 --- a/internal/backend/runtime/omni/virtual/state.go +++ b/internal/backend/runtime/omni/virtual/state.go @@ -25,6 +25,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/accesspolicy" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // State is a virtual state implementation which provides virtual resources. @@ -161,16 +162,17 @@ func (v *State) validateKind(kind resource.Kind) error { } func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) { - identity, _ := ctx.Value(auth.IdentityContextKey{}).(string) //nolint:errcheck + identityVal, _ := ctxstore.Value[auth.IdentityContextKey](ctx) - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } user := virtual.NewCurrentUser() - user.TypedSpec().Value.Identity = identity + user.TypedSpec().Value.Identity = identityVal.Identity user.TypedSpec().Value.Role = string(userRole) version, err := resource.ParseVersion("1") @@ -184,9 +186,10 @@ func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) { } func (v *State) permissions(ctx context.Context) (*virtual.Permissions, error) { - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } permissions := virtual.NewPermissions() diff --git a/internal/backend/workloadproxy/accessvalidator.go b/internal/backend/workloadproxy/accessvalidator.go index 71d39a73..40eefbee 100644 --- a/internal/backend/workloadproxy/accessvalidator.go +++ b/internal/backend/workloadproxy/accessvalidator.go @@ -23,6 +23,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/accesspolicy" "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // RoleProvider provides the current actor's role for a cluster. @@ -119,10 +120,10 @@ func (p *PGPAccessValidator) ValidateAccess(ctx context.Context, publicKeyID, pu return parseErr } - ctx = context.WithValue(ctx, auth.RoleContextKey{}, publicKeyRole) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: publicKeyRole}) } - ctx = context.WithValue(ctx, auth.IdentityContextKey{}, publicKey.TypedSpec().Value.GetIdentity().GetEmail()) + ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: publicKey.TypedSpec().Value.GetIdentity().GetEmail()}) accessRole, err := p.roleProvider.RoleForCluster(ctx, clusterID) if err != nil { diff --git a/internal/pkg/auth/accesspolicy/cluster.go b/internal/pkg/auth/accesspolicy/cluster.go index 4ef8b1b4..4873a940 100644 --- a/internal/pkg/auth/accesspolicy/cluster.go +++ b/internal/pkg/auth/accesspolicy/cluster.go @@ -18,13 +18,15 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // RoleForCluster returns the role of the current user for the given cluster, and whether the role matches all clusters. func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.Role, bool, error) { - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } ctx = actor.MarkContextAsInternalActor(ctx) @@ -38,12 +40,12 @@ func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.R return role.None, false, err } - identityStr, identityExists := ctx.Value(auth.IdentityContextKey{}).(string) + identityVal, identityExists := ctxstore.Value[auth.IdentityContextKey](ctx) if !identityExists { return userRole, false, nil } - identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityStr).Metadata()) + identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityVal.Identity).Metadata()) if err != nil { if state.IsNotFoundError(err) { return userRole, false, nil diff --git a/internal/pkg/auth/actor/actor.go b/internal/pkg/auth/actor/actor.go index 75c777fb..adde498b 100644 --- a/internal/pkg/auth/actor/actor.go +++ b/internal/pkg/auth/actor/actor.go @@ -6,17 +6,23 @@ // Package actor implements the context marking for internal/external actors. package actor -import "context" +import ( + "context" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) // internalActorContextKey is the key for internal actor context. type internalActorContextKey struct{} // MarkContextAsInternalActor returns a new derived context from the given context, marked as an internal actor. func MarkContextAsInternalActor(ctx context.Context) context.Context { - return context.WithValue(ctx, internalActorContextKey{}, struct{}{}) + return ctxstore.WithValue(ctx, internalActorContextKey{}) } // ContextIsInternalActor returns true if the given context is marked as an internal actor. func ContextIsInternalActor(ctx context.Context) bool { - return ctx.Value(internalActorContextKey{}) != nil + _, ok := ctxstore.Value[internalActorContextKey](ctx) + + return ok } diff --git a/internal/pkg/auth/check.go b/internal/pkg/auth/check.go index 54197093..8a9f50d0 100644 --- a/internal/pkg/auth/check.go +++ b/internal/pkg/auth/check.go @@ -14,6 +14,7 @@ import ( "google.golang.org/grpc/status" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var ( @@ -81,19 +82,19 @@ func WithVerifiedEmail() CheckOption { // // The returned error can be checked against ErrUnauthenticated and ErrUnauthorized. func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) { - authEnabled, ok := ctx.Value(EnabledAuthContextKey{}).(bool) + authVal, ok := ctxstore.Value[EnabledAuthContextKey](ctx) if !ok { return CheckResult{}, fmt.Errorf("%w: auth configuration not found in context", ErrUnauthenticated) } - if !authEnabled { + if !authVal.Enabled { return CheckResult{ AuthEnabled: false, }, nil } result := CheckResult{ - AuthEnabled: authEnabled, + AuthEnabled: authVal.Enabled, } opts := DefaultCheckOptions() @@ -108,22 +109,25 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) { } if opts.VerifiedEmail { - email, ok := ctx.Value(VerifiedEmailContextKey{}).(string) + emailVal, ok := ctxstore.Value[VerifiedEmailContextKey](ctx) if !ok { return CheckResult{}, fmt.Errorf("%w: missing verified email", ErrUnauthenticated) } - result.VerifiedEmail = email + result.VerifiedEmail = emailVal.Email } - ctxRole, ctxRoleExists := ctx.Value(RoleContextKey{}).(role.Role) - if !ctxRoleExists { - ctxRole = role.None + ctxRole := role.None + ctxRoleExists := false + + if val, ok := ctxstore.Value[RoleContextKey](ctx); ok { + ctxRole = val.Role + ctxRoleExists = true } result.Role = ctxRole - // RoleContextKey{} is set on the context only when there is a valid signature, so we can rely on this. + // RoleContextKey is set on the context only when there is a valid signature, so we can rely on this. result.HasValidSignature = ctxRoleExists if opts.ValidSignature && !result.HasValidSignature { @@ -137,12 +141,12 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) { } } - if identity, ok := ctx.Value(IdentityContextKey{}).(string); ok { - result.Identity = identity + if val, ok := ctxstore.Value[IdentityContextKey](ctx); ok { + result.Identity = val.Identity } - if userID, ok := ctx.Value(UserIDContextKey{}).(string); ok { - result.UserID = userID + if val, ok := ctxstore.Value[UserIDContextKey](ctx); ok { + result.UserID = val.UserID } return result, nil diff --git a/internal/pkg/auth/check_test.go b/internal/pkg/auth/check_test.go index 4ab4d047..93bb8eac 100644 --- a/internal/pkg/auth/check_test.go +++ b/internal/pkg/auth/check_test.go @@ -13,6 +13,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) func TestCheck(t *testing.T) { @@ -30,18 +31,20 @@ func TestCheck(t *testing.T) { }, { name: "auth disabled", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - false, + auth.EnabledAuthContextKey{ + Enabled: false, + }, ), }, { name: "not authenticated, no requirements", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), want: auth.CheckResult{ AuthEnabled: true, @@ -50,44 +53,49 @@ func TestCheck(t *testing.T) { }, { name: "not authenticated, verified email", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), opts: []auth.CheckOption{auth.WithVerifiedEmail()}, errorIs: auth.ErrUnauthenticated, }, { name: "not authenticated, none role", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), opts: []auth.CheckOption{auth.WithValidSignature(true)}, errorIs: auth.ErrUnauthenticated, }, { name: "not authenticated, operator role", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), opts: []auth.CheckOption{auth.WithRole(role.Operator)}, errorIs: auth.ErrUnauthenticated, }, { name: "verified email", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.VerifiedEmailContextKey{}, - "user@example.com", + auth.VerifiedEmailContextKey{ + Email: "user@example.com", + }, ), opts: []auth.CheckOption{auth.WithVerifiedEmail()}, want: auth.CheckResult{ @@ -98,14 +106,16 @@ func TestCheck(t *testing.T) { }, { name: "role okay", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.Operator, + auth.RoleContextKey{ + Role: role.Operator, + }, ), opts: []auth.CheckOption{auth.WithRole(role.Operator)}, want: auth.CheckResult{ @@ -116,36 +126,42 @@ func TestCheck(t *testing.T) { }, { name: "role mismatch", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.Operator, + auth.RoleContextKey{ + Role: role.Operator, + }, ), opts: []auth.CheckOption{auth.WithRole(role.Admin)}, errorIs: auth.ErrUnauthorized, }, { name: "role and verified email", - ctx: context.WithValue( - context.WithValue( - context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( + ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.Operator, + auth.RoleContextKey{ + Role: role.Operator, + }, ), - auth.VerifiedEmailContextKey{}, - "user@example.com", + auth.VerifiedEmailContextKey{ + Email: "user@example.com", + }, ), - auth.IdentityContextKey{}, - "user2@example.com", + auth.IdentityContextKey{ + Identity: "user2@example.com", + }, ), opts: []auth.CheckOption{auth.WithRole(role.Operator), auth.WithVerifiedEmail()}, want: auth.CheckResult{ @@ -158,14 +174,16 @@ func TestCheck(t *testing.T) { }, { name: "valid signature", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.None, + auth.RoleContextKey{ + Role: role.None, + }, ), opts: []auth.CheckOption{}, want: auth.CheckResult{ @@ -176,14 +194,16 @@ func TestCheck(t *testing.T) { }, { name: "missing signature", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.VerifiedEmailContextKey{}, - "me@example.com", + auth.VerifiedEmailContextKey{ + Email: "me@example.com", + }, ), opts: []auth.CheckOption{auth.WithValidSignature(true)}, errorIs: auth.ErrUnauthenticated, diff --git a/internal/pkg/auth/context.go b/internal/pkg/auth/context.go index 42eb3485..c5b57475 100644 --- a/internal/pkg/auth/context.go +++ b/internal/pkg/auth/context.go @@ -5,20 +5,26 @@ package auth -// EnabledAuthContextKey is the context key for enabled authentication. Value has the type bool. -type EnabledAuthContextKey struct{} +import ( + "github.com/siderolabs/go-api-signature/pkg/message" + + "github.com/siderolabs/omni/internal/pkg/auth/role" +) + +// EnabledAuthContextKey is the context key for enabled authentication. +type EnabledAuthContextKey struct{ Enabled bool } // GRPCMessageContextKey is the context key for the GRPC message. It is only set if authentication is enabled. -type GRPCMessageContextKey struct{} +type GRPCMessageContextKey struct{ Message *message.GRPC } -// VerifiedEmailContextKey is the context key for the verified email address. Value has the type string. -type VerifiedEmailContextKey struct{} +// VerifiedEmailContextKey is the context key for the verified email address. +type VerifiedEmailContextKey struct{ Email string } // UserIDContextKey is the context key for the user ID. Value has the type string. -type UserIDContextKey struct{} +type UserIDContextKey struct{ UserID string } // RoleContextKey is the context key for the role. Value has the type role.Role. -type RoleContextKey struct{} +type RoleContextKey struct{ Role role.Role } -// IdentityContextKey is the context key for the user identity. Value has the type string. -type IdentityContextKey struct{} +// IdentityContextKey is the context key for the user identity. +type IdentityContextKey struct{ Identity string } diff --git a/internal/pkg/auth/handler/auth_config.go b/internal/pkg/auth/handler/auth_config.go index 0969ef44..4992da5e 100644 --- a/internal/pkg/auth/handler/auth_config.go +++ b/internal/pkg/auth/handler/auth_config.go @@ -6,12 +6,12 @@ package handler import ( - "context" "net/http" "go.uber.org/zap" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // AuthConfig represents the configuration for the auth config interceptor. @@ -31,7 +31,7 @@ func NewAuthConfig(handler http.Handler, enabled bool, logger *zap.Logger) *Auth } func (c *AuthConfig) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - ctx := context.WithValue(request.Context(), auth.EnabledAuthContextKey{}, c.enabled) + ctx := ctxstore.WithValue(request.Context(), auth.EnabledAuthContextKey{Enabled: c.enabled}) request = request.WithContext(ctx) c.next.ServeHTTP(writer, request) diff --git a/internal/pkg/auth/handler/handler_test.go b/internal/pkg/auth/handler/handler_test.go index 6a308d84..cb35895f 100644 --- a/internal/pkg/auth/handler/handler_test.go +++ b/internal/pkg/auth/handler/handler_test.go @@ -20,6 +20,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/handler" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) func testHandler(t *testing.T, authEnabled bool) { @@ -130,25 +131,25 @@ func testHandler(t *testing.T, authEnabled bool) { t.Fatal("timeout") } - ctxAuthEnabled, ok := reqCtx.Value(auth.EnabledAuthContextKey{}).(bool) + ctxAuthEnabledVal, ok := ctxstore.Value[auth.EnabledAuthContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, authEnabled, ctxAuthEnabled) + assert.Equal(t, authEnabled, ctxAuthEnabledVal.Enabled) if !tc.verifyContext { return } - ctxUserID, ok := reqCtx.Value(auth.UserIDContextKey{}).(string) + ctxUserIDVal, ok := ctxstore.Value[auth.UserIDContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, "user-id", ctxUserID) + assert.Equal(t, "user-id", ctxUserIDVal.UserID) - ctxRole, ok := reqCtx.Value(auth.RoleContextKey{}).(role.Role) + ctxRoleVal, ok := ctxstore.Value[auth.RoleContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, role.Operator, ctxRole) + assert.Equal(t, role.Operator, ctxRoleVal.Role) - ctxIdentity, ok := reqCtx.Value(auth.IdentityContextKey{}).(string) + ctxIdentityVal, ok := ctxstore.Value[auth.IdentityContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, "user@example.com", ctxIdentity) + assert.Equal(t, "user@example.com", ctxIdentityVal.Identity) }) } } diff --git a/internal/pkg/auth/handler/signature.go b/internal/pkg/auth/handler/signature.go index 2850a5ab..5ede724e 100644 --- a/internal/pkg/auth/handler/signature.go +++ b/internal/pkg/auth/handler/signature.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errInvalidSignature = errors.New("invalid signature") @@ -102,9 +103,9 @@ func (s *Signature) intercept(request *http.Request) (*http.Request, error) { Set("authenticator.identity", authenticator.Identity). Set("authenticator.role", string(authenticator.Role)) - ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity) - ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID) - ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role) + ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity}) + ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID}) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role}) return request.WithContext(ctx), nil } diff --git a/internal/pkg/auth/interceptor/auth_config.go b/internal/pkg/auth/interceptor/auth_config.go index 6ef62bb2..8b1a91c2 100644 --- a/internal/pkg/auth/interceptor/auth_config.go +++ b/internal/pkg/auth/interceptor/auth_config.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/metadata" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // AuthConfig represents the configuration for the auth config interceptor. @@ -53,7 +54,7 @@ func (c *AuthConfig) Stream() grpc.StreamServerInterceptor { } func (c *AuthConfig) intercept(ctx context.Context, method string) context.Context { - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, c.enabled) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: c.enabled}) if !c.enabled { return ctx @@ -64,7 +65,5 @@ func (c *AuthConfig) intercept(ctx context.Context, method string) context.Conte md = metadata.New(nil) } - msg := message.NewGRPC(md, method) - - return context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg) + return ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, method)}) } diff --git a/internal/pkg/auth/interceptor/jwt.go b/internal/pkg/auth/interceptor/jwt.go index 7c08fff0..fbae7f84 100644 --- a/internal/pkg/auth/interceptor/jwt.go +++ b/internal/pkg/auth/interceptor/jwt.go @@ -20,6 +20,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/auth0" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errGRPCInvalidJWT = status.Error(codes.Unauthenticated, "invalid jwt") @@ -66,12 +67,12 @@ func (i *JWT) Stream() grpc.StreamServerInterceptor { } func (i *JWT) intercept(ctx context.Context) (context.Context, error) { - msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC) + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") } - claims, err := msg.VerifyJWT(ctx, i.jwtVerifier) + claims, err := msgVal.Message.VerifyJWT(ctx, i.jwtVerifier) if errors.Is(err, message.ErrNotFound) { // missing jwt, pass it through return ctx, nil } @@ -90,7 +91,7 @@ func (i *JWT) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidJWT } - ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, claims.VerifiedEmail) + ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: claims.VerifiedEmail}) return ctx, nil } diff --git a/internal/pkg/auth/interceptor/saml.go b/internal/pkg/auth/interceptor/saml.go index 485cc89b..3188f0fc 100644 --- a/internal/pkg/auth/interceptor/saml.go +++ b/internal/pkg/auth/interceptor/saml.go @@ -15,7 +15,6 @@ import ( "github.com/cosi-project/runtime/pkg/state" "github.com/crewjam/saml" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/siderolabs/go-api-signature/pkg/message" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -25,6 +24,7 @@ import ( authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errGRPCInvalidSAML = status.Error(codes.Unauthenticated, "invalid session") @@ -71,12 +71,12 @@ func (i *SAML) Stream() grpc.StreamServerInterceptor { } func (i *SAML) intercept(ctx context.Context) (context.Context, error) { - msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC) + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") } - values := msg.Metadata.Get(auth.SamlSessionHeaderKey) + values := msgVal.Message.Metadata.Get(auth.SamlSessionHeaderKey) if len(values) == 0 { return ctx, nil } @@ -86,7 +86,7 @@ func (i *SAML) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidSAML } - ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, session.TypedSpec().Value.Email) + ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: session.TypedSpec().Value.Email}) return ctx, nil } diff --git a/internal/pkg/auth/interceptor/signature.go b/internal/pkg/auth/interceptor/signature.go index 2872d472..ec896e43 100644 --- a/internal/pkg/auth/interceptor/signature.go +++ b/internal/pkg/auth/interceptor/signature.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errGRPCInvalidSignature = status.Error(codes.Unauthenticated, "invalid signature") @@ -64,12 +65,12 @@ func (i *Signature) Stream() grpc.StreamServerInterceptor { } func (i *Signature) intercept(ctx context.Context) (context.Context, error) { - msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC) + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") } - signature, err := msg.Signature() + signature, err := msgVal.Message.Signature() if errors.Is(err, message.ErrNotFound) { // missing signature, pass it through grpc_ctxtags.Extract(ctx). Set("authenticator.user_id", ""). @@ -100,7 +101,7 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidSignature } - err = msg.VerifySignature(authenticator.Verifier) + err = msgVal.Message.VerifySignature(authenticator.Verifier) if err != nil { i.logger.Info("failed to verify message", zap.Error(err)) @@ -112,9 +113,9 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) { Set("authenticator.identity", authenticator.Identity). Set("authenticator.role", string(authenticator.Role)) - ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID) - ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity) - ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role) + ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID}) + ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity}) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role}) return ctx, nil } diff --git a/internal/pkg/ctxstore/ctxstore.go b/internal/pkg/ctxstore/ctxstore.go new file mode 100644 index 00000000..9a24548f --- /dev/null +++ b/internal/pkg/ctxstore/ctxstore.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +// Package ctxstore provides a way to store values in the context with the key based on the type. +package ctxstore + +import "context" + +// phantomKey represent key based on type. The cool thing about this empty struct, +// is that two instances of phantomKey with the different type are different, while +// two instances of phantomKey with the same type are the same. This is useful for +// creating a unique key for each type. It also helps to avoid collision with other +// keys in the context. +// +// It also does not allocate when used as a key in the context (aka converted to any). +// Same goes for int and struct containing single int field with value below 256. +// Same goes for bool and struct containing single bool field. +type phantomKey[T any] struct{} + +// WithValue creates a new context with the value. Key is based on the type of the value. +func WithValue[T any](ctx context.Context, val T) context.Context { + return context.WithValue(ctx, phantomKey[T]{}, val) +} + +// Value returns the value from the context. Key is based on the type of the value. +func Value[T any](ctx context.Context) (T, bool) { + value := ctx.Value(phantomKey[T]{}) + if value == nil { + return *new(T), false + } + + return value.(T), true //nolint:forcetypeassert +} diff --git a/internal/pkg/ctxstore/ctxstore_test.go b/internal/pkg/ctxstore/ctxstore_test.go new file mode 100644 index 00000000..9297e300 --- /dev/null +++ b/internal/pkg/ctxstore/ctxstore_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package ctxstore_test + +import ( + "context" + "runtime" + "testing" + + "github.com/siderolabs/gen/pair" + "github.com/stretchr/testify/assert" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) + +func TestWithValue(t *testing.T) { + ctx := ctxstore.WithValue(context.Background(), "value1") + ctx = ctxstore.WithValue(ctx, 42) + ctx = ctxstore.WithValue(ctx, true) + + type ( + customString string + stringAlias = string + ) + + var cs customString + + assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[string](ctx))) + assert.Equal(t, pair.MakePair(42, true), pair.MakePair(ctxstore.Value[int](ctx))) + assert.Equal(t, pair.MakePair(true, true), pair.MakePair(ctxstore.Value[bool](ctx))) + assert.Equal(t, pair.MakePair(0.0, false), pair.MakePair(ctxstore.Value[float64](ctx))) + assert.Equal(t, pair.MakePair(cs, false), pair.MakePair(ctxstore.Value[customString](ctx))) + assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[stringAlias](ctx))) +} + +func BenchmarkWithValue(b *testing.B) { + b.ReportAllocs() + + type ( + emtpyStruct struct{} + myStruct[T any] struct{ Val T } + ) + + b.Run("empty struct", func(b *testing.B) { benchmarkFor(b, emtpyStruct{}) }) + b.Run("small int", func(b *testing.B) { benchmarkFor(b, 42) }) + b.Run("small int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 42}) }) + b.Run("normal int", func(b *testing.B) { benchmarkFor(b, 424242) }) + b.Run("normal int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 424242}) }) + b.Run("bool", func(b *testing.B) { benchmarkFor(b, true) }) + b.Run("bool inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[bool]{Val: true}) }) + b.Run("string", func(b *testing.B) { benchmarkFor(b, "value") }) + b.Run("string inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[string]{Val: "value"}) }) +} + +func benchmarkFor[T any](b *testing.B, value T) { + b.ReportAllocs() + + var ( + ok bool + result T + ) + + for range b.N { + ctx := ctxstore.WithValue(context.Background(), value) + + result, ok = ctxstore.Value[T](ctx) + if !ok { + b.Fatal("unexpected") + } + } + + runtime.KeepAlive(result) +}