From 4cfc0e6dd0bf45767bcbd17eb813544153d0beed Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Mon, 15 Jul 2024 16:21:22 +0300 Subject: [PATCH] chore: rework auth.* keys, add `ctxstore` package Using so-called phantom types we can use the types themselves as keys directly without loosing performance. You no longer need to remember which type was attached to the thing you passed in context and can look up all fields access directly. Part of #37 Signed-off-by: Dmitriy Matrenichev --- cmd/omni/main.go | 3 +- internal/backend/grpc/configs_test.go | 7 +- internal/backend/grpc/management.go | 10 +- internal/backend/grpc/router/talos_backend.go | 6 +- internal/backend/k8sproxy/k8sproxy.go | 2 +- internal/backend/k8sproxy/middleware.go | 4 +- internal/backend/k8sproxy/middleware_test.go | 5 +- internal/backend/k8sproxy/multiplex.go | 5 +- internal/backend/k8sproxy/proxy.go | 5 +- internal/backend/runtime/omni/omni_test.go | 3 +- internal/backend/runtime/omni/state_access.go | 3 +- .../backend/runtime/omni/virtual/state.go | 19 +-- .../backend/workloadproxy/accessvalidator.go | 5 +- internal/pkg/auth/accesspolicy/cluster.go | 12 +- internal/pkg/auth/actor/actor.go | 12 +- internal/pkg/auth/check.go | 30 ++-- internal/pkg/auth/check_test.go | 134 ++++++++++-------- internal/pkg/auth/context.go | 24 ++-- internal/pkg/auth/handler/auth_config.go | 4 +- internal/pkg/auth/handler/handler_test.go | 17 +-- internal/pkg/auth/handler/signature.go | 7 +- internal/pkg/auth/interceptor/auth_config.go | 7 +- internal/pkg/auth/interceptor/jwt.go | 7 +- internal/pkg/auth/interceptor/saml.go | 8 +- internal/pkg/auth/interceptor/signature.go | 13 +- internal/pkg/ctxstore/ctxstore.go | 35 +++++ internal/pkg/ctxstore/ctxstore_test.go | 76 ++++++++++ 27 files changed, 315 insertions(+), 148 deletions(-) create mode 100644 internal/pkg/ctxstore/ctxstore.go create mode 100644 internal/pkg/ctxstore/ctxstore_test.go diff --git a/cmd/omni/main.go b/cmd/omni/main.go index 8229eeff..a0e4c0e4 100644 --- a/cmd/omni/main.go +++ b/cmd/omni/main.go @@ -44,6 +44,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/user" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" "github.com/siderolabs/omni/internal/pkg/features" "github.com/siderolabs/omni/internal/pkg/siderolink" "github.com/siderolabs/omni/internal/version" @@ -235,7 +236,7 @@ func runWithState(logger *zap.Logger) func(context.Context, state.State, *virtua return fmt.Errorf("failed to update features config resources: %w", err) } - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, authres.Enabled(authConfig)) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: authres.Enabled(authConfig)}) handler, err := backend.NewFrontendHandler(rootCmdArgs.frontendDst, logger) if err != nil { diff --git a/internal/backend/grpc/configs_test.go b/internal/backend/grpc/configs_test.go index 0087dcdb..d9ab5fa0 100644 --- a/internal/backend/grpc/configs_test.go +++ b/internal/backend/grpc/configs_test.go @@ -45,6 +45,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) //go:embed testdata/admin-kubeconfig.yaml @@ -193,11 +194,11 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string { md = metadata.New(nil) } - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, true) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: true}) msg := message.NewGRPC(md, info.FullMethod) - ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg) + ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: msg}) if r := md.Get("role"); len(r) > 0 { var parsed role.Role @@ -207,7 +208,7 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string { return nil, err } - ctx = context.WithValue(ctx, auth.RoleContextKey{}, parsed) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: parsed}) } return handler(ctx, req) diff --git a/internal/backend/grpc/management.go b/internal/backend/grpc/management.go index 52ea0e54..4651b545 100644 --- a/internal/backend/grpc/management.go +++ b/internal/backend/grpc/management.go @@ -60,6 +60,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" "github.com/siderolabs/omni/internal/pkg/siderolink" ) @@ -939,9 +940,10 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster return nil, err } - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } newRole, err := role.Max(userRole, clusterRole) @@ -953,7 +955,7 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster return ctx, nil } - return context.WithValue(ctx, auth.RoleContextKey{}, newRole), nil + return ctxstore.WithValue(ctx, auth.RoleContextKey{Role: newRole}), nil } func handleError(err error) error { diff --git a/internal/backend/grpc/router/talos_backend.go b/internal/backend/grpc/router/talos_backend.go index 407961ae..8d6ca089 100644 --- a/internal/backend/grpc/router/talos_backend.go +++ b/internal/backend/grpc/router/talos_backend.go @@ -22,6 +22,7 @@ import ( "github.com/siderolabs/omni/internal/backend/dns" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" "github.com/siderolabs/omni/internal/pkg/grpcutil" ) @@ -78,9 +79,8 @@ func (backend *TalosBackend) GetConnection(ctx context.Context, fullMethodName s // we can't use regular gRPC server interceptors here, as proxy interface is a bit different // prepare context values for the verifier - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, backend.authEnabled) - msg := message.NewGRPC(md, fullMethodName) - ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: backend.authEnabled}) + ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, fullMethodName)}) grpcutil.SetShouldLog(ctx, "talos-backend") diff --git a/internal/backend/k8sproxy/k8sproxy.go b/internal/backend/k8sproxy/k8sproxy.go index cc3202c9..9e1900cc 100644 --- a/internal/backend/k8sproxy/k8sproxy.go +++ b/internal/backend/k8sproxy/k8sproxy.go @@ -17,7 +17,7 @@ import ( ) // clusterContextKey is a type for cluster name. -type clusterContextKey struct{} +type clusterContextKey struct{ ClusterName string } // Handler implements the HTTP reverse proxy for Kubernetes clusters. // diff --git a/internal/backend/k8sproxy/middleware.go b/internal/backend/k8sproxy/middleware.go index 5df7f3e2..6c84d7e3 100644 --- a/internal/backend/k8sproxy/middleware.go +++ b/internal/backend/k8sproxy/middleware.go @@ -18,6 +18,8 @@ import ( grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "go.uber.org/zap" "k8s.io/client-go/transport" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) const authorizationHeader = "Authorization" @@ -107,7 +109,7 @@ func AuthorizeRequest(next http.Handler, keyFunc KeyProvider, clusterUUIDResolve } // clone the request before modifying it - req = req.WithContext(context.WithValue(ctx, clusterContextKey{}, clusterName)) + req = req.WithContext(ctxstore.WithValue(ctx, clusterContextKey{ClusterName: clusterName})) // clean all headers which are going to be overridden req.Header.Del(authorizationHeader) diff --git a/internal/backend/k8sproxy/middleware_test.go b/internal/backend/k8sproxy/middleware_test.go index 55e7f6c4..e55b664e 100644 --- a/internal/backend/k8sproxy/middleware_test.go +++ b/internal/backend/k8sproxy/middleware_test.go @@ -26,6 +26,7 @@ import ( "k8s.io/client-go/transport" "github.com/siderolabs/omni/internal/backend/k8sproxy" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var mockClusterUUIDResolver = func(_ context.Context, clusterID resource.ID) (string, error) { @@ -276,9 +277,9 @@ func TestAuthorize(t *testing.T) { assert.Equal(t, tc.expectedImpersonateGroups, receivedReq.Header.Values(transport.ImpersonateGroupHeader)) assert.Nil(t, receivedReq.Header.Values("Authorization")) - v, ok := receivedReq.Context().Value(k8sproxy.ClusterContextKey{}).(string) + v, ok := ctxstore.Value[k8sproxy.ClusterContextKey](receivedReq.Context()) //nolint:contextcheck assert.True(t, ok) - assert.Equal(t, tc.expectedCluster, v) + assert.Equal(t, tc.expectedCluster, v.ClusterName) }) } } diff --git a/internal/backend/k8sproxy/multiplex.go b/internal/backend/k8sproxy/multiplex.go index d6c1e427..12b7f7cd 100644 --- a/internal/backend/k8sproxy/multiplex.go +++ b/internal/backend/k8sproxy/multiplex.go @@ -24,6 +24,7 @@ import ( "github.com/siderolabs/omni/client/api/common" "github.com/siderolabs/omni/internal/backend/runtime" "github.com/siderolabs/omni/internal/backend/runtime/kubernetes" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // multiplexer provides an http.RoundTripper which selects the cluster based on the request context. @@ -74,12 +75,12 @@ func newMultiplexer() *multiplexer { // RoundTrip implements http.RoundTripper interface. func (m *multiplexer) RoundTrip(req *http.Request) (*http.Response, error) { - clusterName, ok := req.Context().Value(clusterContextKey{}).(string) + clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context()) if !ok { return nil, errors.New("cluster name not found in request context") } - rt, err := m.getRT(req.Context(), clusterName) + rt, err := m.getRT(req.Context(), clusterNameVal.ClusterName) if err != nil { return nil, err } diff --git a/internal/backend/k8sproxy/proxy.go b/internal/backend/k8sproxy/proxy.go index bfb52ffc..3a841820 100644 --- a/internal/backend/k8sproxy/proxy.go +++ b/internal/backend/k8sproxy/proxy.go @@ -13,6 +13,7 @@ import ( "go.uber.org/zap" "github.com/siderolabs/omni/internal/backend/logging" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // proxyHandler implements the HTTP reverse proxy. @@ -44,14 +45,14 @@ func newProxyHandler(m *multiplexer, logger *zap.Logger) *proxyHandler { // director sets the target URL for the reverse proxy. func (p *proxyHandler) director(req *http.Request) { - clusterName, ok := req.Context().Value(clusterContextKey{}).(string) + clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context()) if !ok { ctxzap.Error(req.Context(), "cluster name not found in request context") return } - connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterName) + connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterNameVal.ClusterName) if err != nil { ctxzap.Error(req.Context(), "failed to get cluster connector", zap.Error(err)) diff --git a/internal/backend/runtime/omni/omni_test.go b/internal/backend/runtime/omni/omni_test.go index 3a80db8d..5ecf083a 100644 --- a/internal/backend/runtime/omni/omni_test.go +++ b/internal/backend/runtime/omni/omni_test.go @@ -32,6 +32,7 @@ import ( "github.com/siderolabs/omni/internal/backend/workloadproxy" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // using whitelisted for external API access type. @@ -72,7 +73,7 @@ func (suite *OmniRuntimeSuite) SetupTest() { suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute) // disable auth in the context - suite.ctx = context.WithValue(suite.ctx, auth.EnabledAuthContextKey{}, false) + suite.ctx = ctxstore.WithValue(suite.ctx, auth.EnabledAuthContextKey{Enabled: false}) var err error diff --git a/internal/backend/runtime/omni/state_access.go b/internal/backend/runtime/omni/state_access.go index e8e33e2b..597edd06 100644 --- a/internal/backend/runtime/omni/state_access.go +++ b/internal/backend/runtime/omni/state_access.go @@ -28,6 +28,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var ( @@ -257,7 +258,7 @@ func checkForRole(ctx context.Context, st state.State, access state.Access, clus if clusterRole != role.None && (!requireAll || (requireAll && matchesAll)) { // override the role in the context with the computed role for this cluster - ctx = context.WithValue(ctx, auth.RoleContextKey{}, clusterRole) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: clusterRole}) } } diff --git a/internal/backend/runtime/omni/virtual/state.go b/internal/backend/runtime/omni/virtual/state.go index 65299642..c4195772 100644 --- a/internal/backend/runtime/omni/virtual/state.go +++ b/internal/backend/runtime/omni/virtual/state.go @@ -25,6 +25,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/accesspolicy" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // State is a virtual state implementation which provides virtual resources. @@ -161,16 +162,17 @@ func (v *State) validateKind(kind resource.Kind) error { } func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) { - identity, _ := ctx.Value(auth.IdentityContextKey{}).(string) //nolint:errcheck + identityVal, _ := ctxstore.Value[auth.IdentityContextKey](ctx) - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } user := virtual.NewCurrentUser() - user.TypedSpec().Value.Identity = identity + user.TypedSpec().Value.Identity = identityVal.Identity user.TypedSpec().Value.Role = string(userRole) version, err := resource.ParseVersion("1") @@ -184,9 +186,10 @@ func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) { } func (v *State) permissions(ctx context.Context) (*virtual.Permissions, error) { - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } permissions := virtual.NewPermissions() diff --git a/internal/backend/workloadproxy/accessvalidator.go b/internal/backend/workloadproxy/accessvalidator.go index 71d39a73..40eefbee 100644 --- a/internal/backend/workloadproxy/accessvalidator.go +++ b/internal/backend/workloadproxy/accessvalidator.go @@ -23,6 +23,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/accesspolicy" "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // RoleProvider provides the current actor's role for a cluster. @@ -119,10 +120,10 @@ func (p *PGPAccessValidator) ValidateAccess(ctx context.Context, publicKeyID, pu return parseErr } - ctx = context.WithValue(ctx, auth.RoleContextKey{}, publicKeyRole) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: publicKeyRole}) } - ctx = context.WithValue(ctx, auth.IdentityContextKey{}, publicKey.TypedSpec().Value.GetIdentity().GetEmail()) + ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: publicKey.TypedSpec().Value.GetIdentity().GetEmail()}) accessRole, err := p.roleProvider.RoleForCluster(ctx, clusterID) if err != nil { diff --git a/internal/pkg/auth/accesspolicy/cluster.go b/internal/pkg/auth/accesspolicy/cluster.go index 4ef8b1b4..4873a940 100644 --- a/internal/pkg/auth/accesspolicy/cluster.go +++ b/internal/pkg/auth/accesspolicy/cluster.go @@ -18,13 +18,15 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // RoleForCluster returns the role of the current user for the given cluster, and whether the role matches all clusters. func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.Role, bool, error) { - userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role) - if !userRoleExists { - userRole = role.None + userRole := role.None + + if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok { + userRole = val.Role } ctx = actor.MarkContextAsInternalActor(ctx) @@ -38,12 +40,12 @@ func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.R return role.None, false, err } - identityStr, identityExists := ctx.Value(auth.IdentityContextKey{}).(string) + identityVal, identityExists := ctxstore.Value[auth.IdentityContextKey](ctx) if !identityExists { return userRole, false, nil } - identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityStr).Metadata()) + identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityVal.Identity).Metadata()) if err != nil { if state.IsNotFoundError(err) { return userRole, false, nil diff --git a/internal/pkg/auth/actor/actor.go b/internal/pkg/auth/actor/actor.go index 75c777fb..adde498b 100644 --- a/internal/pkg/auth/actor/actor.go +++ b/internal/pkg/auth/actor/actor.go @@ -6,17 +6,23 @@ // Package actor implements the context marking for internal/external actors. package actor -import "context" +import ( + "context" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) // internalActorContextKey is the key for internal actor context. type internalActorContextKey struct{} // MarkContextAsInternalActor returns a new derived context from the given context, marked as an internal actor. func MarkContextAsInternalActor(ctx context.Context) context.Context { - return context.WithValue(ctx, internalActorContextKey{}, struct{}{}) + return ctxstore.WithValue(ctx, internalActorContextKey{}) } // ContextIsInternalActor returns true if the given context is marked as an internal actor. func ContextIsInternalActor(ctx context.Context) bool { - return ctx.Value(internalActorContextKey{}) != nil + _, ok := ctxstore.Value[internalActorContextKey](ctx) + + return ok } diff --git a/internal/pkg/auth/check.go b/internal/pkg/auth/check.go index 54197093..8a9f50d0 100644 --- a/internal/pkg/auth/check.go +++ b/internal/pkg/auth/check.go @@ -14,6 +14,7 @@ import ( "google.golang.org/grpc/status" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var ( @@ -81,19 +82,19 @@ func WithVerifiedEmail() CheckOption { // // The returned error can be checked against ErrUnauthenticated and ErrUnauthorized. func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) { - authEnabled, ok := ctx.Value(EnabledAuthContextKey{}).(bool) + authVal, ok := ctxstore.Value[EnabledAuthContextKey](ctx) if !ok { return CheckResult{}, fmt.Errorf("%w: auth configuration not found in context", ErrUnauthenticated) } - if !authEnabled { + if !authVal.Enabled { return CheckResult{ AuthEnabled: false, }, nil } result := CheckResult{ - AuthEnabled: authEnabled, + AuthEnabled: authVal.Enabled, } opts := DefaultCheckOptions() @@ -108,22 +109,25 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) { } if opts.VerifiedEmail { - email, ok := ctx.Value(VerifiedEmailContextKey{}).(string) + emailVal, ok := ctxstore.Value[VerifiedEmailContextKey](ctx) if !ok { return CheckResult{}, fmt.Errorf("%w: missing verified email", ErrUnauthenticated) } - result.VerifiedEmail = email + result.VerifiedEmail = emailVal.Email } - ctxRole, ctxRoleExists := ctx.Value(RoleContextKey{}).(role.Role) - if !ctxRoleExists { - ctxRole = role.None + ctxRole := role.None + ctxRoleExists := false + + if val, ok := ctxstore.Value[RoleContextKey](ctx); ok { + ctxRole = val.Role + ctxRoleExists = true } result.Role = ctxRole - // RoleContextKey{} is set on the context only when there is a valid signature, so we can rely on this. + // RoleContextKey is set on the context only when there is a valid signature, so we can rely on this. result.HasValidSignature = ctxRoleExists if opts.ValidSignature && !result.HasValidSignature { @@ -137,12 +141,12 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) { } } - if identity, ok := ctx.Value(IdentityContextKey{}).(string); ok { - result.Identity = identity + if val, ok := ctxstore.Value[IdentityContextKey](ctx); ok { + result.Identity = val.Identity } - if userID, ok := ctx.Value(UserIDContextKey{}).(string); ok { - result.UserID = userID + if val, ok := ctxstore.Value[UserIDContextKey](ctx); ok { + result.UserID = val.UserID } return result, nil diff --git a/internal/pkg/auth/check_test.go b/internal/pkg/auth/check_test.go index 4ab4d047..93bb8eac 100644 --- a/internal/pkg/auth/check_test.go +++ b/internal/pkg/auth/check_test.go @@ -13,6 +13,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) func TestCheck(t *testing.T) { @@ -30,18 +31,20 @@ func TestCheck(t *testing.T) { }, { name: "auth disabled", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - false, + auth.EnabledAuthContextKey{ + Enabled: false, + }, ), }, { name: "not authenticated, no requirements", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), want: auth.CheckResult{ AuthEnabled: true, @@ -50,44 +53,49 @@ func TestCheck(t *testing.T) { }, { name: "not authenticated, verified email", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), opts: []auth.CheckOption{auth.WithVerifiedEmail()}, errorIs: auth.ErrUnauthenticated, }, { name: "not authenticated, none role", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), opts: []auth.CheckOption{auth.WithValidSignature(true)}, errorIs: auth.ErrUnauthenticated, }, { name: "not authenticated, operator role", - ctx: context.WithValue( + ctx: ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), opts: []auth.CheckOption{auth.WithRole(role.Operator)}, errorIs: auth.ErrUnauthenticated, }, { name: "verified email", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.VerifiedEmailContextKey{}, - "user@example.com", + auth.VerifiedEmailContextKey{ + Email: "user@example.com", + }, ), opts: []auth.CheckOption{auth.WithVerifiedEmail()}, want: auth.CheckResult{ @@ -98,14 +106,16 @@ func TestCheck(t *testing.T) { }, { name: "role okay", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.Operator, + auth.RoleContextKey{ + Role: role.Operator, + }, ), opts: []auth.CheckOption{auth.WithRole(role.Operator)}, want: auth.CheckResult{ @@ -116,36 +126,42 @@ func TestCheck(t *testing.T) { }, { name: "role mismatch", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.Operator, + auth.RoleContextKey{ + Role: role.Operator, + }, ), opts: []auth.CheckOption{auth.WithRole(role.Admin)}, errorIs: auth.ErrUnauthorized, }, { name: "role and verified email", - ctx: context.WithValue( - context.WithValue( - context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( + ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.Operator, + auth.RoleContextKey{ + Role: role.Operator, + }, ), - auth.VerifiedEmailContextKey{}, - "user@example.com", + auth.VerifiedEmailContextKey{ + Email: "user@example.com", + }, ), - auth.IdentityContextKey{}, - "user2@example.com", + auth.IdentityContextKey{ + Identity: "user2@example.com", + }, ), opts: []auth.CheckOption{auth.WithRole(role.Operator), auth.WithVerifiedEmail()}, want: auth.CheckResult{ @@ -158,14 +174,16 @@ func TestCheck(t *testing.T) { }, { name: "valid signature", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.RoleContextKey{}, - role.None, + auth.RoleContextKey{ + Role: role.None, + }, ), opts: []auth.CheckOption{}, want: auth.CheckResult{ @@ -176,14 +194,16 @@ func TestCheck(t *testing.T) { }, { name: "missing signature", - ctx: context.WithValue( - context.WithValue( + ctx: ctxstore.WithValue( + ctxstore.WithValue( context.Background(), - auth.EnabledAuthContextKey{}, - true, + auth.EnabledAuthContextKey{ + Enabled: true, + }, ), - auth.VerifiedEmailContextKey{}, - "me@example.com", + auth.VerifiedEmailContextKey{ + Email: "me@example.com", + }, ), opts: []auth.CheckOption{auth.WithValidSignature(true)}, errorIs: auth.ErrUnauthenticated, diff --git a/internal/pkg/auth/context.go b/internal/pkg/auth/context.go index 42eb3485..c5b57475 100644 --- a/internal/pkg/auth/context.go +++ b/internal/pkg/auth/context.go @@ -5,20 +5,26 @@ package auth -// EnabledAuthContextKey is the context key for enabled authentication. Value has the type bool. -type EnabledAuthContextKey struct{} +import ( + "github.com/siderolabs/go-api-signature/pkg/message" + + "github.com/siderolabs/omni/internal/pkg/auth/role" +) + +// EnabledAuthContextKey is the context key for enabled authentication. +type EnabledAuthContextKey struct{ Enabled bool } // GRPCMessageContextKey is the context key for the GRPC message. It is only set if authentication is enabled. -type GRPCMessageContextKey struct{} +type GRPCMessageContextKey struct{ Message *message.GRPC } -// VerifiedEmailContextKey is the context key for the verified email address. Value has the type string. -type VerifiedEmailContextKey struct{} +// VerifiedEmailContextKey is the context key for the verified email address. +type VerifiedEmailContextKey struct{ Email string } // UserIDContextKey is the context key for the user ID. Value has the type string. -type UserIDContextKey struct{} +type UserIDContextKey struct{ UserID string } // RoleContextKey is the context key for the role. Value has the type role.Role. -type RoleContextKey struct{} +type RoleContextKey struct{ Role role.Role } -// IdentityContextKey is the context key for the user identity. Value has the type string. -type IdentityContextKey struct{} +// IdentityContextKey is the context key for the user identity. +type IdentityContextKey struct{ Identity string } diff --git a/internal/pkg/auth/handler/auth_config.go b/internal/pkg/auth/handler/auth_config.go index 0969ef44..4992da5e 100644 --- a/internal/pkg/auth/handler/auth_config.go +++ b/internal/pkg/auth/handler/auth_config.go @@ -6,12 +6,12 @@ package handler import ( - "context" "net/http" "go.uber.org/zap" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // AuthConfig represents the configuration for the auth config interceptor. @@ -31,7 +31,7 @@ func NewAuthConfig(handler http.Handler, enabled bool, logger *zap.Logger) *Auth } func (c *AuthConfig) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - ctx := context.WithValue(request.Context(), auth.EnabledAuthContextKey{}, c.enabled) + ctx := ctxstore.WithValue(request.Context(), auth.EnabledAuthContextKey{Enabled: c.enabled}) request = request.WithContext(ctx) c.next.ServeHTTP(writer, request) diff --git a/internal/pkg/auth/handler/handler_test.go b/internal/pkg/auth/handler/handler_test.go index 6a308d84..cb35895f 100644 --- a/internal/pkg/auth/handler/handler_test.go +++ b/internal/pkg/auth/handler/handler_test.go @@ -20,6 +20,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/handler" "github.com/siderolabs/omni/internal/pkg/auth/role" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) func testHandler(t *testing.T, authEnabled bool) { @@ -130,25 +131,25 @@ func testHandler(t *testing.T, authEnabled bool) { t.Fatal("timeout") } - ctxAuthEnabled, ok := reqCtx.Value(auth.EnabledAuthContextKey{}).(bool) + ctxAuthEnabledVal, ok := ctxstore.Value[auth.EnabledAuthContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, authEnabled, ctxAuthEnabled) + assert.Equal(t, authEnabled, ctxAuthEnabledVal.Enabled) if !tc.verifyContext { return } - ctxUserID, ok := reqCtx.Value(auth.UserIDContextKey{}).(string) + ctxUserIDVal, ok := ctxstore.Value[auth.UserIDContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, "user-id", ctxUserID) + assert.Equal(t, "user-id", ctxUserIDVal.UserID) - ctxRole, ok := reqCtx.Value(auth.RoleContextKey{}).(role.Role) + ctxRoleVal, ok := ctxstore.Value[auth.RoleContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, role.Operator, ctxRole) + assert.Equal(t, role.Operator, ctxRoleVal.Role) - ctxIdentity, ok := reqCtx.Value(auth.IdentityContextKey{}).(string) + ctxIdentityVal, ok := ctxstore.Value[auth.IdentityContextKey](reqCtx) //nolint:contextcheck require.True(t, ok) - assert.Equal(t, "user@example.com", ctxIdentity) + assert.Equal(t, "user@example.com", ctxIdentityVal.Identity) }) } } diff --git a/internal/pkg/auth/handler/signature.go b/internal/pkg/auth/handler/signature.go index 2850a5ab..5ede724e 100644 --- a/internal/pkg/auth/handler/signature.go +++ b/internal/pkg/auth/handler/signature.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errInvalidSignature = errors.New("invalid signature") @@ -102,9 +103,9 @@ func (s *Signature) intercept(request *http.Request) (*http.Request, error) { Set("authenticator.identity", authenticator.Identity). Set("authenticator.role", string(authenticator.Role)) - ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity) - ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID) - ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role) + ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity}) + ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID}) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role}) return request.WithContext(ctx), nil } diff --git a/internal/pkg/auth/interceptor/auth_config.go b/internal/pkg/auth/interceptor/auth_config.go index 6ef62bb2..8b1a91c2 100644 --- a/internal/pkg/auth/interceptor/auth_config.go +++ b/internal/pkg/auth/interceptor/auth_config.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/metadata" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) // AuthConfig represents the configuration for the auth config interceptor. @@ -53,7 +54,7 @@ func (c *AuthConfig) Stream() grpc.StreamServerInterceptor { } func (c *AuthConfig) intercept(ctx context.Context, method string) context.Context { - ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, c.enabled) + ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: c.enabled}) if !c.enabled { return ctx @@ -64,7 +65,5 @@ func (c *AuthConfig) intercept(ctx context.Context, method string) context.Conte md = metadata.New(nil) } - msg := message.NewGRPC(md, method) - - return context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg) + return ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, method)}) } diff --git a/internal/pkg/auth/interceptor/jwt.go b/internal/pkg/auth/interceptor/jwt.go index 7c08fff0..fbae7f84 100644 --- a/internal/pkg/auth/interceptor/jwt.go +++ b/internal/pkg/auth/interceptor/jwt.go @@ -20,6 +20,7 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/auth0" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errGRPCInvalidJWT = status.Error(codes.Unauthenticated, "invalid jwt") @@ -66,12 +67,12 @@ func (i *JWT) Stream() grpc.StreamServerInterceptor { } func (i *JWT) intercept(ctx context.Context) (context.Context, error) { - msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC) + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") } - claims, err := msg.VerifyJWT(ctx, i.jwtVerifier) + claims, err := msgVal.Message.VerifyJWT(ctx, i.jwtVerifier) if errors.Is(err, message.ErrNotFound) { // missing jwt, pass it through return ctx, nil } @@ -90,7 +91,7 @@ func (i *JWT) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidJWT } - ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, claims.VerifiedEmail) + ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: claims.VerifiedEmail}) return ctx, nil } diff --git a/internal/pkg/auth/interceptor/saml.go b/internal/pkg/auth/interceptor/saml.go index 485cc89b..3188f0fc 100644 --- a/internal/pkg/auth/interceptor/saml.go +++ b/internal/pkg/auth/interceptor/saml.go @@ -15,7 +15,6 @@ import ( "github.com/cosi-project/runtime/pkg/state" "github.com/crewjam/saml" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/siderolabs/go-api-signature/pkg/message" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -25,6 +24,7 @@ import ( authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errGRPCInvalidSAML = status.Error(codes.Unauthenticated, "invalid session") @@ -71,12 +71,12 @@ func (i *SAML) Stream() grpc.StreamServerInterceptor { } func (i *SAML) intercept(ctx context.Context) (context.Context, error) { - msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC) + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") } - values := msg.Metadata.Get(auth.SamlSessionHeaderKey) + values := msgVal.Message.Metadata.Get(auth.SamlSessionHeaderKey) if len(values) == 0 { return ctx, nil } @@ -86,7 +86,7 @@ func (i *SAML) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidSAML } - ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, session.TypedSpec().Value.Email) + ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: session.TypedSpec().Value.Email}) return ctx, nil } diff --git a/internal/pkg/auth/interceptor/signature.go b/internal/pkg/auth/interceptor/signature.go index 2872d472..ec896e43 100644 --- a/internal/pkg/auth/interceptor/signature.go +++ b/internal/pkg/auth/interceptor/signature.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/siderolabs/omni/internal/pkg/auth" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) var errGRPCInvalidSignature = status.Error(codes.Unauthenticated, "invalid signature") @@ -64,12 +65,12 @@ func (i *Signature) Stream() grpc.StreamServerInterceptor { } func (i *Signature) intercept(ctx context.Context) (context.Context, error) { - msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC) + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") } - signature, err := msg.Signature() + signature, err := msgVal.Message.Signature() if errors.Is(err, message.ErrNotFound) { // missing signature, pass it through grpc_ctxtags.Extract(ctx). Set("authenticator.user_id", ""). @@ -100,7 +101,7 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidSignature } - err = msg.VerifySignature(authenticator.Verifier) + err = msgVal.Message.VerifySignature(authenticator.Verifier) if err != nil { i.logger.Info("failed to verify message", zap.Error(err)) @@ -112,9 +113,9 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) { Set("authenticator.identity", authenticator.Identity). Set("authenticator.role", string(authenticator.Role)) - ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID) - ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity) - ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role) + ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID}) + ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity}) + ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role}) return ctx, nil } diff --git a/internal/pkg/ctxstore/ctxstore.go b/internal/pkg/ctxstore/ctxstore.go new file mode 100644 index 00000000..9a24548f --- /dev/null +++ b/internal/pkg/ctxstore/ctxstore.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +// Package ctxstore provides a way to store values in the context with the key based on the type. +package ctxstore + +import "context" + +// phantomKey represent key based on type. The cool thing about this empty struct, +// is that two instances of phantomKey with the different type are different, while +// two instances of phantomKey with the same type are the same. This is useful for +// creating a unique key for each type. It also helps to avoid collision with other +// keys in the context. +// +// It also does not allocate when used as a key in the context (aka converted to any). +// Same goes for int and struct containing single int field with value below 256. +// Same goes for bool and struct containing single bool field. +type phantomKey[T any] struct{} + +// WithValue creates a new context with the value. Key is based on the type of the value. +func WithValue[T any](ctx context.Context, val T) context.Context { + return context.WithValue(ctx, phantomKey[T]{}, val) +} + +// Value returns the value from the context. Key is based on the type of the value. +func Value[T any](ctx context.Context) (T, bool) { + value := ctx.Value(phantomKey[T]{}) + if value == nil { + return *new(T), false + } + + return value.(T), true //nolint:forcetypeassert +} diff --git a/internal/pkg/ctxstore/ctxstore_test.go b/internal/pkg/ctxstore/ctxstore_test.go new file mode 100644 index 00000000..9297e300 --- /dev/null +++ b/internal/pkg/ctxstore/ctxstore_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package ctxstore_test + +import ( + "context" + "runtime" + "testing" + + "github.com/siderolabs/gen/pair" + "github.com/stretchr/testify/assert" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) + +func TestWithValue(t *testing.T) { + ctx := ctxstore.WithValue(context.Background(), "value1") + ctx = ctxstore.WithValue(ctx, 42) + ctx = ctxstore.WithValue(ctx, true) + + type ( + customString string + stringAlias = string + ) + + var cs customString + + assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[string](ctx))) + assert.Equal(t, pair.MakePair(42, true), pair.MakePair(ctxstore.Value[int](ctx))) + assert.Equal(t, pair.MakePair(true, true), pair.MakePair(ctxstore.Value[bool](ctx))) + assert.Equal(t, pair.MakePair(0.0, false), pair.MakePair(ctxstore.Value[float64](ctx))) + assert.Equal(t, pair.MakePair(cs, false), pair.MakePair(ctxstore.Value[customString](ctx))) + assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[stringAlias](ctx))) +} + +func BenchmarkWithValue(b *testing.B) { + b.ReportAllocs() + + type ( + emtpyStruct struct{} + myStruct[T any] struct{ Val T } + ) + + b.Run("empty struct", func(b *testing.B) { benchmarkFor(b, emtpyStruct{}) }) + b.Run("small int", func(b *testing.B) { benchmarkFor(b, 42) }) + b.Run("small int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 42}) }) + b.Run("normal int", func(b *testing.B) { benchmarkFor(b, 424242) }) + b.Run("normal int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 424242}) }) + b.Run("bool", func(b *testing.B) { benchmarkFor(b, true) }) + b.Run("bool inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[bool]{Val: true}) }) + b.Run("string", func(b *testing.B) { benchmarkFor(b, "value") }) + b.Run("string inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[string]{Val: "value"}) }) +} + +func benchmarkFor[T any](b *testing.B, value T) { + b.ReportAllocs() + + var ( + ok bool + result T + ) + + for range b.N { + ctx := ctxstore.WithValue(context.Background(), value) + + result, ok = ctxstore.Value[T](ctx) + if !ok { + b.Fatal("unexpected") + } + } + + runtime.KeepAlive(result) +}