mirror of
https://github.com/siderolabs/omni.git
synced 2026-05-05 06:36:12 +02:00
chore: rework auth.* keys, add ctxstore package
Using so-called phantom types we can use the types themselves as keys directly without loosing performance. You no longer need to remember which type was attached to the thing you passed in context and can look up all fields access directly. Part of #37 Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
This commit is contained in:
parent
76263e12a4
commit
4cfc0e6dd0
@ -44,6 +44,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/user"
|
||||
"github.com/siderolabs/omni/internal/pkg/config"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
"github.com/siderolabs/omni/internal/pkg/features"
|
||||
"github.com/siderolabs/omni/internal/pkg/siderolink"
|
||||
"github.com/siderolabs/omni/internal/version"
|
||||
@ -235,7 +236,7 @@ func runWithState(logger *zap.Logger) func(context.Context, state.State, *virtua
|
||||
return fmt.Errorf("failed to update features config resources: %w", err)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, authres.Enabled(authConfig))
|
||||
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: authres.Enabled(authConfig)})
|
||||
|
||||
handler, err := backend.NewFrontendHandler(rootCmdArgs.frontendDst, logger)
|
||||
if err != nil {
|
||||
|
||||
@ -45,6 +45,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/config"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
//go:embed testdata/admin-kubeconfig.yaml
|
||||
@ -193,11 +194,11 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string {
|
||||
md = metadata.New(nil)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, true)
|
||||
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: true})
|
||||
|
||||
msg := message.NewGRPC(md, info.FullMethod)
|
||||
|
||||
ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
|
||||
ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: msg})
|
||||
|
||||
if r := md.Get("role"); len(r) > 0 {
|
||||
var parsed role.Role
|
||||
@ -207,7 +208,7 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.RoleContextKey{}, parsed)
|
||||
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: parsed})
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
|
||||
@ -60,6 +60,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/config"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
"github.com/siderolabs/omni/internal/pkg/siderolink"
|
||||
)
|
||||
|
||||
@ -939,9 +940,10 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
|
||||
if !userRoleExists {
|
||||
userRole = role.None
|
||||
userRole := role.None
|
||||
|
||||
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
|
||||
userRole = val.Role
|
||||
}
|
||||
|
||||
newRole, err := role.Max(userRole, clusterRole)
|
||||
@ -953,7 +955,7 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, auth.RoleContextKey{}, newRole), nil
|
||||
return ctxstore.WithValue(ctx, auth.RoleContextKey{Role: newRole}), nil
|
||||
}
|
||||
|
||||
func handleError(err error) error {
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/backend/dns"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
"github.com/siderolabs/omni/internal/pkg/grpcutil"
|
||||
)
|
||||
|
||||
@ -78,9 +79,8 @@ func (backend *TalosBackend) GetConnection(ctx context.Context, fullMethodName s
|
||||
// we can't use regular gRPC server interceptors here, as proxy interface is a bit different
|
||||
|
||||
// prepare context values for the verifier
|
||||
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, backend.authEnabled)
|
||||
msg := message.NewGRPC(md, fullMethodName)
|
||||
ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
|
||||
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: backend.authEnabled})
|
||||
ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, fullMethodName)})
|
||||
|
||||
grpcutil.SetShouldLog(ctx, "talos-backend")
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
// clusterContextKey is a type for cluster name.
|
||||
type clusterContextKey struct{}
|
||||
type clusterContextKey struct{ ClusterName string }
|
||||
|
||||
// Handler implements the HTTP reverse proxy for Kubernetes clusters.
|
||||
//
|
||||
|
||||
@ -18,6 +18,8 @@ import (
|
||||
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
|
||||
"go.uber.org/zap"
|
||||
"k8s.io/client-go/transport"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
const authorizationHeader = "Authorization"
|
||||
@ -107,7 +109,7 @@ func AuthorizeRequest(next http.Handler, keyFunc KeyProvider, clusterUUIDResolve
|
||||
}
|
||||
|
||||
// clone the request before modifying it
|
||||
req = req.WithContext(context.WithValue(ctx, clusterContextKey{}, clusterName))
|
||||
req = req.WithContext(ctxstore.WithValue(ctx, clusterContextKey{ClusterName: clusterName}))
|
||||
|
||||
// clean all headers which are going to be overridden
|
||||
req.Header.Del(authorizationHeader)
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"k8s.io/client-go/transport"
|
||||
|
||||
"github.com/siderolabs/omni/internal/backend/k8sproxy"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var mockClusterUUIDResolver = func(_ context.Context, clusterID resource.ID) (string, error) {
|
||||
@ -276,9 +277,9 @@ func TestAuthorize(t *testing.T) {
|
||||
assert.Equal(t, tc.expectedImpersonateGroups, receivedReq.Header.Values(transport.ImpersonateGroupHeader))
|
||||
assert.Nil(t, receivedReq.Header.Values("Authorization"))
|
||||
|
||||
v, ok := receivedReq.Context().Value(k8sproxy.ClusterContextKey{}).(string)
|
||||
v, ok := ctxstore.Value[k8sproxy.ClusterContextKey](receivedReq.Context()) //nolint:contextcheck
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.expectedCluster, v)
|
||||
assert.Equal(t, tc.expectedCluster, v.ClusterName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/siderolabs/omni/client/api/common"
|
||||
"github.com/siderolabs/omni/internal/backend/runtime"
|
||||
"github.com/siderolabs/omni/internal/backend/runtime/kubernetes"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// multiplexer provides an http.RoundTripper which selects the cluster based on the request context.
|
||||
@ -74,12 +75,12 @@ func newMultiplexer() *multiplexer {
|
||||
|
||||
// RoundTrip implements http.RoundTripper interface.
|
||||
func (m *multiplexer) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
clusterName, ok := req.Context().Value(clusterContextKey{}).(string)
|
||||
clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context())
|
||||
if !ok {
|
||||
return nil, errors.New("cluster name not found in request context")
|
||||
}
|
||||
|
||||
rt, err := m.getRT(req.Context(), clusterName)
|
||||
rt, err := m.getRT(req.Context(), clusterNameVal.ClusterName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/siderolabs/omni/internal/backend/logging"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// proxyHandler implements the HTTP reverse proxy.
|
||||
@ -44,14 +45,14 @@ func newProxyHandler(m *multiplexer, logger *zap.Logger) *proxyHandler {
|
||||
|
||||
// director sets the target URL for the reverse proxy.
|
||||
func (p *proxyHandler) director(req *http.Request) {
|
||||
clusterName, ok := req.Context().Value(clusterContextKey{}).(string)
|
||||
clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context())
|
||||
if !ok {
|
||||
ctxzap.Error(req.Context(), "cluster name not found in request context")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterName)
|
||||
connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterNameVal.ClusterName)
|
||||
if err != nil {
|
||||
ctxzap.Error(req.Context(), "failed to get cluster connector", zap.Error(err))
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/backend/workloadproxy"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// using whitelisted for external API access type.
|
||||
@ -72,7 +73,7 @@ func (suite *OmniRuntimeSuite) SetupTest() {
|
||||
suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
|
||||
// disable auth in the context
|
||||
suite.ctx = context.WithValue(suite.ctx, auth.EnabledAuthContextKey{}, false)
|
||||
suite.ctx = ctxstore.WithValue(suite.ctx, auth.EnabledAuthContextKey{Enabled: false})
|
||||
|
||||
var err error
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/config"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -257,7 +258,7 @@ func checkForRole(ctx context.Context, st state.State, access state.Access, clus
|
||||
|
||||
if clusterRole != role.None && (!requireAll || (requireAll && matchesAll)) {
|
||||
// override the role in the context with the computed role for this cluster
|
||||
ctx = context.WithValue(ctx, auth.RoleContextKey{}, clusterRole)
|
||||
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: clusterRole})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/accesspolicy"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// State is a virtual state implementation which provides virtual resources.
|
||||
@ -161,16 +162,17 @@ func (v *State) validateKind(kind resource.Kind) error {
|
||||
}
|
||||
|
||||
func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) {
|
||||
identity, _ := ctx.Value(auth.IdentityContextKey{}).(string) //nolint:errcheck
|
||||
identityVal, _ := ctxstore.Value[auth.IdentityContextKey](ctx)
|
||||
|
||||
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
|
||||
if !userRoleExists {
|
||||
userRole = role.None
|
||||
userRole := role.None
|
||||
|
||||
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
|
||||
userRole = val.Role
|
||||
}
|
||||
|
||||
user := virtual.NewCurrentUser()
|
||||
|
||||
user.TypedSpec().Value.Identity = identity
|
||||
user.TypedSpec().Value.Identity = identityVal.Identity
|
||||
user.TypedSpec().Value.Role = string(userRole)
|
||||
|
||||
version, err := resource.ParseVersion("1")
|
||||
@ -184,9 +186,10 @@ func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) {
|
||||
}
|
||||
|
||||
func (v *State) permissions(ctx context.Context) (*virtual.Permissions, error) {
|
||||
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
|
||||
if !userRoleExists {
|
||||
userRole = role.None
|
||||
userRole := role.None
|
||||
|
||||
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
|
||||
userRole = val.Role
|
||||
}
|
||||
|
||||
permissions := virtual.NewPermissions()
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/accesspolicy"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// RoleProvider provides the current actor's role for a cluster.
|
||||
@ -119,10 +120,10 @@ func (p *PGPAccessValidator) ValidateAccess(ctx context.Context, publicKeyID, pu
|
||||
return parseErr
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.RoleContextKey{}, publicKeyRole)
|
||||
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: publicKeyRole})
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.IdentityContextKey{}, publicKey.TypedSpec().Value.GetIdentity().GetEmail())
|
||||
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: publicKey.TypedSpec().Value.GetIdentity().GetEmail()})
|
||||
|
||||
accessRole, err := p.roleProvider.RoleForCluster(ctx, clusterID)
|
||||
if err != nil {
|
||||
|
||||
@ -18,13 +18,15 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// RoleForCluster returns the role of the current user for the given cluster, and whether the role matches all clusters.
|
||||
func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.Role, bool, error) {
|
||||
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
|
||||
if !userRoleExists {
|
||||
userRole = role.None
|
||||
userRole := role.None
|
||||
|
||||
if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
|
||||
userRole = val.Role
|
||||
}
|
||||
|
||||
ctx = actor.MarkContextAsInternalActor(ctx)
|
||||
@ -38,12 +40,12 @@ func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.R
|
||||
return role.None, false, err
|
||||
}
|
||||
|
||||
identityStr, identityExists := ctx.Value(auth.IdentityContextKey{}).(string)
|
||||
identityVal, identityExists := ctxstore.Value[auth.IdentityContextKey](ctx)
|
||||
if !identityExists {
|
||||
return userRole, false, nil
|
||||
}
|
||||
|
||||
identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityStr).Metadata())
|
||||
identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityVal.Identity).Metadata())
|
||||
if err != nil {
|
||||
if state.IsNotFoundError(err) {
|
||||
return userRole, false, nil
|
||||
|
||||
@ -6,17 +6,23 @@
|
||||
// Package actor implements the context marking for internal/external actors.
|
||||
package actor
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// internalActorContextKey is the key for internal actor context.
|
||||
type internalActorContextKey struct{}
|
||||
|
||||
// MarkContextAsInternalActor returns a new derived context from the given context, marked as an internal actor.
|
||||
func MarkContextAsInternalActor(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, internalActorContextKey{}, struct{}{})
|
||||
return ctxstore.WithValue(ctx, internalActorContextKey{})
|
||||
}
|
||||
|
||||
// ContextIsInternalActor returns true if the given context is marked as an internal actor.
|
||||
func ContextIsInternalActor(ctx context.Context) bool {
|
||||
return ctx.Value(internalActorContextKey{}) != nil
|
||||
_, ok := ctxstore.Value[internalActorContextKey](ctx)
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -81,19 +82,19 @@ func WithVerifiedEmail() CheckOption {
|
||||
//
|
||||
// The returned error can be checked against ErrUnauthenticated and ErrUnauthorized.
|
||||
func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) {
|
||||
authEnabled, ok := ctx.Value(EnabledAuthContextKey{}).(bool)
|
||||
authVal, ok := ctxstore.Value[EnabledAuthContextKey](ctx)
|
||||
if !ok {
|
||||
return CheckResult{}, fmt.Errorf("%w: auth configuration not found in context", ErrUnauthenticated)
|
||||
}
|
||||
|
||||
if !authEnabled {
|
||||
if !authVal.Enabled {
|
||||
return CheckResult{
|
||||
AuthEnabled: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := CheckResult{
|
||||
AuthEnabled: authEnabled,
|
||||
AuthEnabled: authVal.Enabled,
|
||||
}
|
||||
|
||||
opts := DefaultCheckOptions()
|
||||
@ -108,22 +109,25 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) {
|
||||
}
|
||||
|
||||
if opts.VerifiedEmail {
|
||||
email, ok := ctx.Value(VerifiedEmailContextKey{}).(string)
|
||||
emailVal, ok := ctxstore.Value[VerifiedEmailContextKey](ctx)
|
||||
if !ok {
|
||||
return CheckResult{}, fmt.Errorf("%w: missing verified email", ErrUnauthenticated)
|
||||
}
|
||||
|
||||
result.VerifiedEmail = email
|
||||
result.VerifiedEmail = emailVal.Email
|
||||
}
|
||||
|
||||
ctxRole, ctxRoleExists := ctx.Value(RoleContextKey{}).(role.Role)
|
||||
if !ctxRoleExists {
|
||||
ctxRole = role.None
|
||||
ctxRole := role.None
|
||||
ctxRoleExists := false
|
||||
|
||||
if val, ok := ctxstore.Value[RoleContextKey](ctx); ok {
|
||||
ctxRole = val.Role
|
||||
ctxRoleExists = true
|
||||
}
|
||||
|
||||
result.Role = ctxRole
|
||||
|
||||
// RoleContextKey{} is set on the context only when there is a valid signature, so we can rely on this.
|
||||
// RoleContextKey is set on the context only when there is a valid signature, so we can rely on this.
|
||||
result.HasValidSignature = ctxRoleExists
|
||||
|
||||
if opts.ValidSignature && !result.HasValidSignature {
|
||||
@ -137,12 +141,12 @@ func Check(ctx context.Context, opt ...CheckOption) (CheckResult, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if identity, ok := ctx.Value(IdentityContextKey{}).(string); ok {
|
||||
result.Identity = identity
|
||||
if val, ok := ctxstore.Value[IdentityContextKey](ctx); ok {
|
||||
result.Identity = val.Identity
|
||||
}
|
||||
|
||||
if userID, ok := ctx.Value(UserIDContextKey{}).(string); ok {
|
||||
result.UserID = userID
|
||||
if val, ok := ctxstore.Value[UserIDContextKey](ctx); ok {
|
||||
result.UserID = val.UserID
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
func TestCheck(t *testing.T) {
|
||||
@ -30,18 +31,20 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "auth disabled",
|
||||
ctx: context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
false,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: false,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "not authenticated, no requirements",
|
||||
ctx: context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
want: auth.CheckResult{
|
||||
AuthEnabled: true,
|
||||
@ -50,44 +53,49 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "not authenticated, verified email",
|
||||
ctx: context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithVerifiedEmail()},
|
||||
errorIs: auth.ErrUnauthenticated,
|
||||
},
|
||||
{
|
||||
name: "not authenticated, none role",
|
||||
ctx: context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithValidSignature(true)},
|
||||
errorIs: auth.ErrUnauthenticated,
|
||||
},
|
||||
{
|
||||
name: "not authenticated, operator role",
|
||||
ctx: context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithRole(role.Operator)},
|
||||
errorIs: auth.ErrUnauthenticated,
|
||||
},
|
||||
{
|
||||
name: "verified email",
|
||||
ctx: context.WithValue(
|
||||
context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
auth.VerifiedEmailContextKey{},
|
||||
"user@example.com",
|
||||
auth.VerifiedEmailContextKey{
|
||||
Email: "user@example.com",
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithVerifiedEmail()},
|
||||
want: auth.CheckResult{
|
||||
@ -98,14 +106,16 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "role okay",
|
||||
ctx: context.WithValue(
|
||||
context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
auth.RoleContextKey{},
|
||||
role.Operator,
|
||||
auth.RoleContextKey{
|
||||
Role: role.Operator,
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithRole(role.Operator)},
|
||||
want: auth.CheckResult{
|
||||
@ -116,36 +126,42 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "role mismatch",
|
||||
ctx: context.WithValue(
|
||||
context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
auth.RoleContextKey{},
|
||||
role.Operator,
|
||||
auth.RoleContextKey{
|
||||
Role: role.Operator,
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithRole(role.Admin)},
|
||||
errorIs: auth.ErrUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "role and verified email",
|
||||
ctx: context.WithValue(
|
||||
context.WithValue(
|
||||
context.WithValue(
|
||||
context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
auth.RoleContextKey{},
|
||||
role.Operator,
|
||||
auth.RoleContextKey{
|
||||
Role: role.Operator,
|
||||
},
|
||||
),
|
||||
auth.VerifiedEmailContextKey{},
|
||||
"user@example.com",
|
||||
auth.VerifiedEmailContextKey{
|
||||
Email: "user@example.com",
|
||||
},
|
||||
),
|
||||
auth.IdentityContextKey{},
|
||||
"user2@example.com",
|
||||
auth.IdentityContextKey{
|
||||
Identity: "user2@example.com",
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithRole(role.Operator), auth.WithVerifiedEmail()},
|
||||
want: auth.CheckResult{
|
||||
@ -158,14 +174,16 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "valid signature",
|
||||
ctx: context.WithValue(
|
||||
context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
auth.RoleContextKey{},
|
||||
role.None,
|
||||
auth.RoleContextKey{
|
||||
Role: role.None,
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{},
|
||||
want: auth.CheckResult{
|
||||
@ -176,14 +194,16 @@ func TestCheck(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "missing signature",
|
||||
ctx: context.WithValue(
|
||||
context.WithValue(
|
||||
ctx: ctxstore.WithValue(
|
||||
ctxstore.WithValue(
|
||||
context.Background(),
|
||||
auth.EnabledAuthContextKey{},
|
||||
true,
|
||||
auth.EnabledAuthContextKey{
|
||||
Enabled: true,
|
||||
},
|
||||
),
|
||||
auth.VerifiedEmailContextKey{},
|
||||
"me@example.com",
|
||||
auth.VerifiedEmailContextKey{
|
||||
Email: "me@example.com",
|
||||
},
|
||||
),
|
||||
opts: []auth.CheckOption{auth.WithValidSignature(true)},
|
||||
errorIs: auth.ErrUnauthenticated,
|
||||
|
||||
@ -5,20 +5,26 @@
|
||||
|
||||
package auth
|
||||
|
||||
// EnabledAuthContextKey is the context key for enabled authentication. Value has the type bool.
|
||||
type EnabledAuthContextKey struct{}
|
||||
import (
|
||||
"github.com/siderolabs/go-api-signature/pkg/message"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
)
|
||||
|
||||
// EnabledAuthContextKey is the context key for enabled authentication.
|
||||
type EnabledAuthContextKey struct{ Enabled bool }
|
||||
|
||||
// GRPCMessageContextKey is the context key for the GRPC message. It is only set if authentication is enabled.
|
||||
type GRPCMessageContextKey struct{}
|
||||
type GRPCMessageContextKey struct{ Message *message.GRPC }
|
||||
|
||||
// VerifiedEmailContextKey is the context key for the verified email address. Value has the type string.
|
||||
type VerifiedEmailContextKey struct{}
|
||||
// VerifiedEmailContextKey is the context key for the verified email address.
|
||||
type VerifiedEmailContextKey struct{ Email string }
|
||||
|
||||
// UserIDContextKey is the context key for the user ID. Value has the type string.
|
||||
type UserIDContextKey struct{}
|
||||
type UserIDContextKey struct{ UserID string }
|
||||
|
||||
// RoleContextKey is the context key for the role. Value has the type role.Role.
|
||||
type RoleContextKey struct{}
|
||||
type RoleContextKey struct{ Role role.Role }
|
||||
|
||||
// IdentityContextKey is the context key for the user identity. Value has the type string.
|
||||
type IdentityContextKey struct{}
|
||||
// IdentityContextKey is the context key for the user identity.
|
||||
type IdentityContextKey struct{ Identity string }
|
||||
|
||||
@ -6,12 +6,12 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// AuthConfig represents the configuration for the auth config interceptor.
|
||||
@ -31,7 +31,7 @@ func NewAuthConfig(handler http.Handler, enabled bool, logger *zap.Logger) *Auth
|
||||
}
|
||||
|
||||
func (c *AuthConfig) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
ctx := context.WithValue(request.Context(), auth.EnabledAuthContextKey{}, c.enabled)
|
||||
ctx := ctxstore.WithValue(request.Context(), auth.EnabledAuthContextKey{Enabled: c.enabled})
|
||||
request = request.WithContext(ctx)
|
||||
|
||||
c.next.ServeHTTP(writer, request)
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/handler"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/role"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
func testHandler(t *testing.T, authEnabled bool) {
|
||||
@ -130,25 +131,25 @@ func testHandler(t *testing.T, authEnabled bool) {
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
ctxAuthEnabled, ok := reqCtx.Value(auth.EnabledAuthContextKey{}).(bool)
|
||||
ctxAuthEnabledVal, ok := ctxstore.Value[auth.EnabledAuthContextKey](reqCtx) //nolint:contextcheck
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, authEnabled, ctxAuthEnabled)
|
||||
assert.Equal(t, authEnabled, ctxAuthEnabledVal.Enabled)
|
||||
|
||||
if !tc.verifyContext {
|
||||
return
|
||||
}
|
||||
|
||||
ctxUserID, ok := reqCtx.Value(auth.UserIDContextKey{}).(string)
|
||||
ctxUserIDVal, ok := ctxstore.Value[auth.UserIDContextKey](reqCtx) //nolint:contextcheck
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user-id", ctxUserID)
|
||||
assert.Equal(t, "user-id", ctxUserIDVal.UserID)
|
||||
|
||||
ctxRole, ok := reqCtx.Value(auth.RoleContextKey{}).(role.Role)
|
||||
ctxRoleVal, ok := ctxstore.Value[auth.RoleContextKey](reqCtx) //nolint:contextcheck
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, role.Operator, ctxRole)
|
||||
assert.Equal(t, role.Operator, ctxRoleVal.Role)
|
||||
|
||||
ctxIdentity, ok := reqCtx.Value(auth.IdentityContextKey{}).(string)
|
||||
ctxIdentityVal, ok := ctxstore.Value[auth.IdentityContextKey](reqCtx) //nolint:contextcheck
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user@example.com", ctxIdentity)
|
||||
assert.Equal(t, "user@example.com", ctxIdentityVal.Identity)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var errInvalidSignature = errors.New("invalid signature")
|
||||
@ -102,9 +103,9 @@ func (s *Signature) intercept(request *http.Request) (*http.Request, error) {
|
||||
Set("authenticator.identity", authenticator.Identity).
|
||||
Set("authenticator.role", string(authenticator.Role))
|
||||
|
||||
ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity)
|
||||
ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID)
|
||||
ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role)
|
||||
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity})
|
||||
ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID})
|
||||
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role})
|
||||
|
||||
return request.WithContext(ctx), nil
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
// AuthConfig represents the configuration for the auth config interceptor.
|
||||
@ -53,7 +54,7 @@ func (c *AuthConfig) Stream() grpc.StreamServerInterceptor {
|
||||
}
|
||||
|
||||
func (c *AuthConfig) intercept(ctx context.Context, method string) context.Context {
|
||||
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, c.enabled)
|
||||
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: c.enabled})
|
||||
|
||||
if !c.enabled {
|
||||
return ctx
|
||||
@ -64,7 +65,5 @@ func (c *AuthConfig) intercept(ctx context.Context, method string) context.Conte
|
||||
md = metadata.New(nil)
|
||||
}
|
||||
|
||||
msg := message.NewGRPC(md, method)
|
||||
|
||||
return context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
|
||||
return ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, method)})
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/auth0"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var errGRPCInvalidJWT = status.Error(codes.Unauthenticated, "invalid jwt")
|
||||
@ -66,12 +67,12 @@ func (i *JWT) Stream() grpc.StreamServerInterceptor {
|
||||
}
|
||||
|
||||
func (i *JWT) intercept(ctx context.Context) (context.Context, error) {
|
||||
msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC)
|
||||
msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing or invalid message in context")
|
||||
}
|
||||
|
||||
claims, err := msg.VerifyJWT(ctx, i.jwtVerifier)
|
||||
claims, err := msgVal.Message.VerifyJWT(ctx, i.jwtVerifier)
|
||||
if errors.Is(err, message.ErrNotFound) { // missing jwt, pass it through
|
||||
return ctx, nil
|
||||
}
|
||||
@ -90,7 +91,7 @@ func (i *JWT) intercept(ctx context.Context) (context.Context, error) {
|
||||
return nil, errGRPCInvalidJWT
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, claims.VerifiedEmail)
|
||||
ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: claims.VerifiedEmail})
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
@ -15,7 +15,6 @@ import (
|
||||
"github.com/cosi-project/runtime/pkg/state"
|
||||
"github.com/crewjam/saml"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"github.com/siderolabs/go-api-signature/pkg/message"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -25,6 +24,7 @@ import (
|
||||
authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/auth/actor"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var errGRPCInvalidSAML = status.Error(codes.Unauthenticated, "invalid session")
|
||||
@ -71,12 +71,12 @@ func (i *SAML) Stream() grpc.StreamServerInterceptor {
|
||||
}
|
||||
|
||||
func (i *SAML) intercept(ctx context.Context) (context.Context, error) {
|
||||
msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC)
|
||||
msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing or invalid message in context")
|
||||
}
|
||||
|
||||
values := msg.Metadata.Get(auth.SamlSessionHeaderKey)
|
||||
values := msgVal.Message.Metadata.Get(auth.SamlSessionHeaderKey)
|
||||
if len(values) == 0 {
|
||||
return ctx, nil
|
||||
}
|
||||
@ -86,7 +86,7 @@ func (i *SAML) intercept(ctx context.Context) (context.Context, error) {
|
||||
return nil, errGRPCInvalidSAML
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, auth.VerifiedEmailContextKey{}, session.TypedSpec().Value.Email)
|
||||
ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: session.TypedSpec().Value.Email})
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/auth"
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
var errGRPCInvalidSignature = status.Error(codes.Unauthenticated, "invalid signature")
|
||||
@ -64,12 +65,12 @@ func (i *Signature) Stream() grpc.StreamServerInterceptor {
|
||||
}
|
||||
|
||||
func (i *Signature) intercept(ctx context.Context) (context.Context, error) {
|
||||
msg, ok := ctx.Value(auth.GRPCMessageContextKey{}).(*message.GRPC)
|
||||
msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing or invalid message in context")
|
||||
}
|
||||
|
||||
signature, err := msg.Signature()
|
||||
signature, err := msgVal.Message.Signature()
|
||||
if errors.Is(err, message.ErrNotFound) { // missing signature, pass it through
|
||||
grpc_ctxtags.Extract(ctx).
|
||||
Set("authenticator.user_id", "").
|
||||
@ -100,7 +101,7 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) {
|
||||
return nil, errGRPCInvalidSignature
|
||||
}
|
||||
|
||||
err = msg.VerifySignature(authenticator.Verifier)
|
||||
err = msgVal.Message.VerifySignature(authenticator.Verifier)
|
||||
if err != nil {
|
||||
i.logger.Info("failed to verify message", zap.Error(err))
|
||||
|
||||
@ -112,9 +113,9 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) {
|
||||
Set("authenticator.identity", authenticator.Identity).
|
||||
Set("authenticator.role", string(authenticator.Role))
|
||||
|
||||
ctx = context.WithValue(ctx, auth.UserIDContextKey{}, authenticator.UserID)
|
||||
ctx = context.WithValue(ctx, auth.IdentityContextKey{}, authenticator.Identity)
|
||||
ctx = context.WithValue(ctx, auth.RoleContextKey{}, authenticator.Role)
|
||||
ctx = ctxstore.WithValue(ctx, auth.UserIDContextKey{UserID: authenticator.UserID})
|
||||
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: authenticator.Identity})
|
||||
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: authenticator.Role})
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
35
internal/pkg/ctxstore/ctxstore.go
Normal file
35
internal/pkg/ctxstore/ctxstore.go
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2024 Sidero Labs, Inc.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the LICENSE file.
|
||||
|
||||
// Package ctxstore provides a way to store values in the context with the key based on the type.
|
||||
package ctxstore
|
||||
|
||||
import "context"
|
||||
|
||||
// phantomKey represent key based on type. The cool thing about this empty struct,
|
||||
// is that two instances of phantomKey with the different type are different, while
|
||||
// two instances of phantomKey with the same type are the same. This is useful for
|
||||
// creating a unique key for each type. It also helps to avoid collision with other
|
||||
// keys in the context.
|
||||
//
|
||||
// It also does not allocate when used as a key in the context (aka converted to any).
|
||||
// Same goes for int and struct containing single int field with value below 256.
|
||||
// Same goes for bool and struct containing single bool field.
|
||||
type phantomKey[T any] struct{}
|
||||
|
||||
// WithValue creates a new context with the value. Key is based on the type of the value.
|
||||
func WithValue[T any](ctx context.Context, val T) context.Context {
|
||||
return context.WithValue(ctx, phantomKey[T]{}, val)
|
||||
}
|
||||
|
||||
// Value returns the value from the context. Key is based on the type of the value.
|
||||
func Value[T any](ctx context.Context) (T, bool) {
|
||||
value := ctx.Value(phantomKey[T]{})
|
||||
if value == nil {
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
return value.(T), true //nolint:forcetypeassert
|
||||
}
|
||||
76
internal/pkg/ctxstore/ctxstore_test.go
Normal file
76
internal/pkg/ctxstore/ctxstore_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright (c) 2024 Sidero Labs, Inc.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the LICENSE file.
|
||||
|
||||
package ctxstore_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/siderolabs/gen/pair"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
||||
)
|
||||
|
||||
func TestWithValue(t *testing.T) {
|
||||
ctx := ctxstore.WithValue(context.Background(), "value1")
|
||||
ctx = ctxstore.WithValue(ctx, 42)
|
||||
ctx = ctxstore.WithValue(ctx, true)
|
||||
|
||||
type (
|
||||
customString string
|
||||
stringAlias = string
|
||||
)
|
||||
|
||||
var cs customString
|
||||
|
||||
assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[string](ctx)))
|
||||
assert.Equal(t, pair.MakePair(42, true), pair.MakePair(ctxstore.Value[int](ctx)))
|
||||
assert.Equal(t, pair.MakePair(true, true), pair.MakePair(ctxstore.Value[bool](ctx)))
|
||||
assert.Equal(t, pair.MakePair(0.0, false), pair.MakePair(ctxstore.Value[float64](ctx)))
|
||||
assert.Equal(t, pair.MakePair(cs, false), pair.MakePair(ctxstore.Value[customString](ctx)))
|
||||
assert.Equal(t, pair.MakePair("value1", true), pair.MakePair(ctxstore.Value[stringAlias](ctx)))
|
||||
}
|
||||
|
||||
func BenchmarkWithValue(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
type (
|
||||
emtpyStruct struct{}
|
||||
myStruct[T any] struct{ Val T }
|
||||
)
|
||||
|
||||
b.Run("empty struct", func(b *testing.B) { benchmarkFor(b, emtpyStruct{}) })
|
||||
b.Run("small int", func(b *testing.B) { benchmarkFor(b, 42) })
|
||||
b.Run("small int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 42}) })
|
||||
b.Run("normal int", func(b *testing.B) { benchmarkFor(b, 424242) })
|
||||
b.Run("normal int inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[int]{Val: 424242}) })
|
||||
b.Run("bool", func(b *testing.B) { benchmarkFor(b, true) })
|
||||
b.Run("bool inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[bool]{Val: true}) })
|
||||
b.Run("string", func(b *testing.B) { benchmarkFor(b, "value") })
|
||||
b.Run("string inside struct", func(b *testing.B) { benchmarkFor(b, myStruct[string]{Val: "value"}) })
|
||||
}
|
||||
|
||||
func benchmarkFor[T any](b *testing.B, value T) {
|
||||
b.ReportAllocs()
|
||||
|
||||
var (
|
||||
ok bool
|
||||
result T
|
||||
)
|
||||
|
||||
for range b.N {
|
||||
ctx := ctxstore.WithValue(context.Background(), value)
|
||||
|
||||
result, ok = ctxstore.Value[T](ctx)
|
||||
if !ok {
|
||||
b.Fatal("unexpected")
|
||||
}
|
||||
}
|
||||
|
||||
runtime.KeepAlive(result)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user