feat: enforce configurable machine registration limit

Add `account.maxRegisteredMachines` config option to cap the number of registered machines. The provision handler atomically checks the limit under a mutex before creating new Link resources, returning ResourceExhausted when the cap is reached.

Introduce a Notification resource type (ephemeral namespace) so controllers can surface warnings to users. `omnictl` displays all active notifications on every command invocation. The frontend part of showing notifications will be implemented in a separate PR.

MachineStatusMetricsController creates a warning notification when the registration limit is reached and tears it down once the limit is no longer reached.

Signed-off-by: Oguz Kilcan <oguz.kilcan@siderolabs.com>
This commit is contained in:
Oguz Kilcan 2026-03-16 10:01:42 +01:00
parent 711782b00e
commit cf7d752453
No known key found for this signature in database
GPG Key ID: 372F271E3AD80BFC
24 changed files with 1261 additions and 385 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1353,6 +1353,8 @@ message MachineStatusMetricsSpec {
map<string, uint32> platforms = 6;
map<string, uint32> secure_boot_status = 7;
map<string, uint32> uki_status = 8;
uint32 registered_machines_limit = 9;
bool registration_limit_reached = 10;
}
// ClusterMetricsSpec contains metrics about the clusters in the Omni instance.
@ -1618,3 +1620,17 @@ message RotateKubernetesCASpec {}
message UpgradeRolloutSpec {
map<string, int32> machine_sets_upgrade_quota = 1;
}
// NotificationSpec describes a generic notification emitted by a controller.
message NotificationSpec {
// Type describes the severity of a notification.
enum Type {
// INFO is the zero value of the enum.
INFO = 0;
WARNING = 1;
ERROR = 2;
}
// title is a short headline for the notification.
string title = 1;
// body is the detailed message text.
string body = 2;
// type is the severity of the notification.
Type type = 3;
}

View File

@ -2447,6 +2447,8 @@ func (m *MachineStatusMetricsSpec) CloneVT() *MachineStatusMetricsSpec {
r.ConnectedMachinesCount = m.ConnectedMachinesCount
r.AllocatedMachinesCount = m.AllocatedMachinesCount
r.PendingMachinesCount = m.PendingMachinesCount
r.RegisteredMachinesLimit = m.RegisteredMachinesLimit
r.RegistrationLimitReached = m.RegistrationLimitReached
if rhs := m.Platforms; rhs != nil {
tmpContainer := make(map[string]uint32, len(rhs))
for k, v := range rhs {
@ -3109,6 +3111,25 @@ func (m *UpgradeRolloutSpec) CloneMessageVT() proto.Message {
return m.CloneVT()
}
// CloneVT returns a deep copy of the message, or nil when the receiver is nil.
func (m *NotificationSpec) CloneVT() *NotificationSpec {
	if m == nil {
		return nil
	}

	clone := &NotificationSpec{
		Title: m.Title,
		Body:  m.Body,
		Type:  m.Type,
	}

	// Unknown fields are duplicated so the clone does not share a backing array with the source.
	if n := len(m.unknownFields); n > 0 {
		clone.unknownFields = make([]byte, n)
		copy(clone.unknownFields, m.unknownFields)
	}

	return clone
}
// CloneMessageVT returns the result of CloneVT as a proto.Message.
//
// For a nil receiver CloneVT returns a nil *NotificationSpec, so the returned interface
// holds a typed nil pointer and does not compare equal to a nil interface.
func (m *NotificationSpec) CloneMessageVT() proto.Message {
return m.CloneVT()
}
func (this *MachineSpec) EqualVT(that *MachineSpec) bool {
if this == that {
return true
@ -6500,6 +6521,12 @@ func (this *MachineStatusMetricsSpec) EqualVT(that *MachineStatusMetricsSpec) bo
return false
}
}
if this.RegisteredMachinesLimit != that.RegisteredMachinesLimit {
return false
}
if this.RegistrationLimitReached != that.RegistrationLimitReached {
return false
}
return string(this.unknownFields) == string(that.unknownFields)
}
@ -7346,6 +7373,31 @@ func (this *UpgradeRolloutSpec) EqualMessageVT(thatMsg proto.Message) bool {
}
return this.EqualVT(that)
}
// EqualVT reports whether both messages carry the same title, body, type and unknown fields.
// Two nil messages are equal; a nil message never equals a non-nil one.
func (m *NotificationSpec) EqualVT(that *NotificationSpec) bool {
	switch {
	case m == that:
		return true
	case m == nil, that == nil:
		return false
	}

	return m.Title == that.Title &&
		m.Body == that.Body &&
		m.Type == that.Type &&
		string(m.unknownFields) == string(that.unknownFields)
}
// EqualMessageVT reports whether thatMsg is a *NotificationSpec that is equal to the receiver.
// A message of any other type compares unequal.
func (m *NotificationSpec) EqualMessageVT(thatMsg proto.Message) bool {
	if other, ok := thatMsg.(*NotificationSpec); ok {
		return m.EqualVT(other)
	}

	return false
}
func (m *MachineSpec) MarshalVT() (dAtA []byte, err error) {
if m == nil {
return nil, nil
@ -13986,6 +14038,21 @@ func (m *MachineStatusMetricsSpec) MarshalToSizedBufferVT(dAtA []byte) (int, err
i -= len(m.unknownFields)
copy(dAtA[i:], m.unknownFields)
}
if m.RegistrationLimitReached {
i--
if m.RegistrationLimitReached {
dAtA[i] = 1
} else {
dAtA[i] = 0
}
i--
dAtA[i] = 0x50
}
if m.RegisteredMachinesLimit != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.RegisteredMachinesLimit))
i--
dAtA[i] = 0x48
}
if len(m.UkiStatus) > 0 {
for k := range m.UkiStatus {
v := m.UkiStatus[k]
@ -15665,6 +15732,58 @@ func (m *UpgradeRolloutSpec) MarshalToSizedBufferVT(dAtA []byte) (int, error) {
return len(dAtA) - i, nil
}
// MarshalVT encodes the message into a newly allocated buffer of SizeVT() bytes.
// A nil receiver yields a nil slice and a nil error.
func (m *NotificationSpec) MarshalVT() (dAtA []byte, err error) {
if m == nil {
return nil, nil
}
size := m.SizeVT()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBufferVT(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
// MarshalToVT encodes the message into the first SizeVT() bytes of dAtA and returns the
// number of bytes written. dAtA must have capacity for at least SizeVT() bytes, otherwise
// the slice expression panics.
func (m *NotificationSpec) MarshalToVT(dAtA []byte) (int, error) {
size := m.SizeVT()
return m.MarshalToSizedBufferVT(dAtA[:size])
}
// MarshalToSizedBufferVT encodes the message into the tail of dAtA and returns the number
// of bytes written. The buffer is filled back to front: unknown fields first, then the
// fields in descending field-number order, so the finished bytes read in ascending order.
// Fields holding their zero value are not written. A nil receiver writes nothing.
func (m *NotificationSpec) MarshalToSizedBufferVT(dAtA []byte) (int, error) {
if m == nil {
return 0, nil
}
// i is the write cursor; it starts at the end of the buffer and moves towards the front.
i := len(dAtA)
_ = i
var l int
_ = l
if m.unknownFields != nil {
i -= len(m.unknownFields)
copy(dAtA[i:], m.unknownFields)
}
if m.Type != 0 {
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Type))
i--
// 0x18 is the tag for field 3 with the varint wire type.
dAtA[i] = 0x18
}
if len(m.Body) > 0 {
i -= len(m.Body)
copy(dAtA[i:], m.Body)
i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Body)))
i--
// 0x12 is the tag for field 2 with the length-delimited wire type.
dAtA[i] = 0x12
}
if len(m.Title) > 0 {
i -= len(m.Title)
copy(dAtA[i:], m.Title)
i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Title)))
i--
// 0xa is the tag for field 1 with the length-delimited wire type.
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *MachineSpec) SizeVT() (n int) {
if m == nil {
return 0
@ -18309,6 +18428,12 @@ func (m *MachineStatusMetricsSpec) SizeVT() (n int) {
n += mapEntrySize + 1 + protohelpers.SizeOfVarint(uint64(mapEntrySize))
}
}
if m.RegisteredMachinesLimit != 0 {
n += 1 + protohelpers.SizeOfVarint(uint64(m.RegisteredMachinesLimit))
}
if m.RegistrationLimitReached {
n += 2
}
n += len(m.unknownFields)
return n
}
@ -18935,6 +19060,27 @@ func (m *UpgradeRolloutSpec) SizeVT() (n int) {
return n
}
// SizeVT returns the number of bytes the encoded message occupies; a nil message has size 0.
func (m *NotificationSpec) SizeVT() (n int) {
	if m == nil {
		return 0
	}

	// A non-empty string costs one tag byte, a varint length prefix and the payload itself.
	if titleLen := len(m.Title); titleLen > 0 {
		n += 1 + titleLen + protohelpers.SizeOfVarint(uint64(titleLen))
	}

	if bodyLen := len(m.Body); bodyLen > 0 {
		n += 1 + bodyLen + protohelpers.SizeOfVarint(uint64(bodyLen))
	}

	// A non-zero enum value costs one tag byte plus its varint encoding.
	if m.Type != 0 {
		n += 1 + protohelpers.SizeOfVarint(uint64(m.Type))
	}

	return n + len(m.unknownFields)
}
func (m *MachineSpec) UnmarshalVT(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
@ -35935,6 +36081,45 @@ func (m *MachineStatusMetricsSpec) UnmarshalVT(dAtA []byte) error {
}
m.UkiStatus[mapkey] = mapvalue
iNdEx = postIndex
case 9:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field RegisteredMachinesLimit", wireType)
}
m.RegisteredMachinesLimit = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.RegisteredMachinesLimit |= uint32(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 10:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field RegistrationLimitReached", wireType)
}
var v int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.RegistrationLimitReached = bool(v != 0)
default:
iNdEx = preIndex
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
@ -40045,3 +40230,137 @@ func (m *UpgradeRolloutSpec) UnmarshalVT(dAtA []byte) error {
}
return nil
}
// UnmarshalVT decodes the protobuf wire-format bytes in dAtA into m.
//
// Fields that do not appear in dAtA are left untouched, and fields with an unrecognised
// field number are appended verbatim to m.unknownFields. An error is returned for truncated
// input, a varint longer than 64 bits, an invalid length, or a wire type that does not
// match the field being decoded.
func (m *NotificationSpec) UnmarshalVT(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
// Remember where the field starts so an unknown field can be re-read including its tag.
preIndex := iNdEx
// Decode the field tag as a varint, 7 bits per byte.
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
// The low three bits of the tag are the wire type, the remaining bits the field number.
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: NotificationSpec: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: NotificationSpec: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
// Title: length-delimited string.
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Title", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
// Reject lengths that overflow int or run past the end of the input.
intStringLen := int(stringLen)
if intStringLen < 0 {
return protohelpers.ErrInvalidLength
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return protohelpers.ErrInvalidLength
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Title = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
// Body: length-delimited string, decoded the same way as Title.
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Body", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return protohelpers.ErrInvalidLength
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return protohelpers.ErrInvalidLength
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Body = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 3:
// Type: varint enum value; the field is cleared before the new value is accumulated.
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType)
}
m.Type = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return protohelpers.ErrIntOverflow
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Type |= NotificationSpec_Type(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
// Unknown field: rewind to the tag, measure the whole field and keep its raw bytes.
iNdEx = preIndex
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return protohelpers.ErrInvalidLength
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...)
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}

View File

@ -0,0 +1,52 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package omni
import (
"github.com/cosi-project/runtime/pkg/resource"
"github.com/cosi-project/runtime/pkg/resource/meta"
"github.com/cosi-project/runtime/pkg/resource/protobuf"
"github.com/cosi-project/runtime/pkg/resource/typed"
"github.com/siderolabs/omni/client/api/omni/specs"
"github.com/siderolabs/omni/client/pkg/omni/resources"
)
// NewNotification creates a Notification resource with the given ID in the ephemeral
// namespace, with an undefined version and an empty spec.
func NewNotification(id resource.ID) *Notification {
	md := resource.NewMetadata(resources.EphemeralNamespace, NotificationType, id, resource.VersionUndefined)
	spec := protobuf.NewResourceSpec(&specs.NotificationSpec{})

	return typed.NewResource[NotificationSpec, NotificationExtension](md, spec)
}
const (
// NotificationType is the type of the Notification resource.
// tsgen:NotificationType
NotificationType = resource.Type("Notifications.omni.sidero.dev")
// NotificationMachineRegistrationLimitID is the ID for the machine registration limit notification.
// MachineStatusMetricsController writes and tears down the notification stored under this ID.
// tsgen:NotificationMachineRegistrationLimitID
NotificationMachineRegistrationLimitID = "machine-registration-limit"
)
// Notification describes a generic notification emitted by a controller.
type Notification = typed.Resource[NotificationSpec, NotificationExtension]
// NotificationSpec wraps specs.NotificationSpec.
type NotificationSpec = protobuf.ResourceSpec[specs.NotificationSpec, *specs.NotificationSpec]
// NotificationExtension provides auxiliary methods for Notification resource.
type NotificationExtension struct{}
// ResourceDefinition implements [typed.Extension] interface.
//
// It describes the Notification resource type, whose default namespace is the ephemeral one.
func (NotificationExtension) ResourceDefinition() meta.ResourceDefinitionSpec {
	var def meta.ResourceDefinitionSpec

	def.Type = NotificationType
	def.DefaultNamespace = resources.EphemeralNamespace

	// Aliases and PrintColumns are set to empty, non-nil slices.
	def.Aliases = []resource.Type{}
	def.PrintColumns = []meta.PrintColumn{}

	return def
}

View File

@ -79,6 +79,7 @@ func init() {
registry.MustRegisterResource(MachineStatusSnapshotType, &MachineStatusSnapshot{})
registry.MustRegisterResource(MachineStatusLinkType, &MachineStatusLink{})
registry.MustRegisterResource(MachineStatusMetricsType, &MachineStatusMetrics{})
registry.MustRegisterResource(NotificationType, &Notification{})
registry.MustRegisterResource(MaintenanceConfigStatusType, &MaintenanceConfigStatus{})
registry.MustRegisterResource(LoadBalancerConfigType, &LoadBalancerConfig{})
registry.MustRegisterResource(LoadBalancerStatusType, &LoadBalancerStatus{})

View File

@ -16,9 +16,11 @@ import (
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
"go.uber.org/zap"
"github.com/siderolabs/omni/client/api/omni/specs"
"github.com/siderolabs/omni/client/pkg/client"
"github.com/siderolabs/omni/client/pkg/client/omni"
"github.com/siderolabs/omni/client/pkg/omni/resources"
omnires "github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/client/pkg/omni/resources/system"
"github.com/siderolabs/omni/client/pkg/omnictl/config"
"github.com/siderolabs/omni/client/pkg/version"
@ -147,6 +149,10 @@ func WithClient(f func(ctx context.Context, client *client.Client) error, client
return err
}
if err = checkNotifications(ctx, client.Omni().State()); err != nil {
return err
}
return f(ctx, client)
})
}
@ -180,6 +186,34 @@ If you want to enable the version validation and disable this warning, set githu
return nil
}
// checkNotifications lists every Notification resource in st and prints each one to stderr
// as "<severity prefix> <title>: <body>".
//
// A listing failure is returned to the caller; failures while writing to stderr are ignored.
//
// NOTE(review): WithClient returns this error before running the command, so a backend that
// cannot serve Notification resources makes every command fail — confirm that is intended.
func checkNotifications(ctx context.Context, st state.State) error {
	list, err := safe.StateListAll[*omnires.Notification](ctx, st)
	if err != nil {
		return fmt.Errorf("failed to list notifications: %w", err)
	}

	severityPrefixes := map[specs.NotificationSpec_Type]string{
		specs.NotificationSpec_ERROR:   "[ERROR]",
		specs.NotificationSpec_WARNING: "[WARN]",
		specs.NotificationSpec_INFO:    "[INFO]",
	}

	for notification := range list.All() {
		value := notification.TypedSpec().Value

		// Severities missing from the table are still printed, under a generic prefix.
		label, known := severityPrefixes[value.Type]
		if !known {
			label = "[UNKNOWN]"
		}

		fmt.Fprintf(os.Stderr, "%s %s: %s\n", label, value.Title, value.Body) //nolint:errcheck
	}

	return nil
}
func checkVersionWarning(sysVersion *system.SysVersion) {
backendVersion, err := semver.ParseTolerant(sysVersion.TypedSpec().Value.BackendVersion)
if err != nil {

View File

@ -154,6 +154,7 @@ func buildRootCommand() (*cobra.Command, error) {
rootCmdFlagBinder.StringVar("account-id", flagDescription("account.id", configSchema), &flagConfig.Account.Id)
rootCmdFlagBinder.StringVar("name", flagDescription("account.name", configSchema), &flagConfig.Account.Name)
rootCmdFlagBinder.StringVar("user-pilot-app-token", flagDescription("account.userPilot.appToken", configSchema), &flagConfig.Account.UserPilot.AppToken)
rootCmdFlagBinder.Uint32Var("account-max-registered-machines", flagDescription("account.maxRegisteredMachines", configSchema), &flagConfig.Account.MaxRegisteredMachines)
if err := defineServiceFlags(rootCmd, rootCmdFlagBinder, flagConfig, configSchema); err != nil {
return nil, fmt.Errorf("failed to define service flags: %w", err)

View File

@ -194,6 +194,12 @@ export enum SecretRotationSpecComponent {
KUBERNETES_CA = 2,
}
// NotificationSpecType is the severity of a notification; the values match the
// NotificationSpec.Type enum of the protobuf definition.
export enum NotificationSpecType {
INFO = 0,
WARNING = 1,
ERROR = 2,
}
export type MachineSpec = {
management_address?: string
connected?: boolean
@ -898,6 +904,8 @@ export type MachineStatusMetricsSpec = {
platforms?: {[key: string]: number}
secure_boot_status?: {[key: string]: number}
uki_status?: {[key: string]: number}
registered_machines_limit?: number
registration_limit_reached?: boolean
}
export type ClusterMetricsSpec = {
@ -1066,4 +1074,10 @@ export type RotateKubernetesCASpec = {
export type UpgradeRolloutSpec = {
machine_sets_upgrade_quota?: {[key: string]: number}
}
// NotificationSpec describes a generic notification emitted by a controller.
// Every field is optional in this type.
export type NotificationSpec = {
title?: string
body?: string
type?: NotificationSpecType
}

View File

@ -207,6 +207,8 @@ export const MachineStatusSnapshotType = "MachineStatusSnapshots.omni.sidero.dev
export const MachineUpgradeStatusType = "MachineUpgradeStatuses.omni.sidero.dev";
export const MaintenanceConfigStatusType = "MaintenanceConfigStatuses.omni.sidero.dev";
export const NodeForceDestroyRequestType = "NodeForceDestroyRequests.omni.sidero.dev";
export const NotificationType = "Notifications.omni.sidero.dev";
export const NotificationMachineRegistrationLimitID = "machine-registration-limit";
export const OngoingTaskType = "OngoingTasks.omni.sidero.dev";
export const RedactedClusterMachineConfigType = "RedactedClusterMachineConfigs.omni.sidero.dev";
export const RotateKubernetesCAType = "RotateKubernetesCAs.omni.sidero.dev";

View File

@ -47,6 +47,7 @@ export AUTH0_DOMAIN="${AUTH0_DOMAIN}"
export OMNI_CONFIG="${TEST_OUTPUTS_DIR}/config.yaml"
export MAX_USERS="${MAX_USERS:-0}"
export MAX_SERVICE_ACCOUNTS="${MAX_SERVICE_ACCOUNTS:-0}"
export MAX_REGISTERED_MACHINES="${MAX_REGISTERED_MACHINES:-0}"
export REGISTRY_MIRROR_FLAGS=()
export REGISTRY_MIRROR_CONFIG=""
export IMPORTED_CLUSTER_ARGS=()

View File

@ -61,6 +61,9 @@ prepare_vault
# Start MinIO server.
prepare_minio access_key="access" secret_key="secret123"
# Set the registration limit to the total number of machines so that the registration limit is actively enforced during the test.
export MAX_REGISTERED_MACHINES="${TOTAL_MACHINES}"
# Prepare omni config.
prepare_omni_config

View File

@ -1,3 +1,5 @@
account:
maxRegisteredMachines: ${MAX_REGISTERED_MACHINES}
services:
api:
endpoint: 0.0.0.0:8099

View File

@ -7,6 +7,7 @@ package omni
import (
"context"
"fmt"
"iter"
"strconv"
"sync"
@ -24,6 +25,7 @@ import (
"github.com/siderolabs/omni/client/pkg/omni/resources"
"github.com/siderolabs/omni/client/pkg/omni/resources/infra"
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/helpers"
)
type nodeInfo struct {
@ -32,6 +34,13 @@ type nodeInfo struct {
connected bool
}
// NewMachineStatusMetricsController creates a new MachineStatusMetricsController.
//
// maxRegisteredMachines is the registration limit the controller reports in the metrics;
// the limit notification is only reconciled when the value is greater than zero.
func NewMachineStatusMetricsController(maxRegisteredMachines uint32) *MachineStatusMetricsController {
	ctrl := new(MachineStatusMetricsController)
	ctrl.maxRegisteredMachines = maxRegisteredMachines

	return ctrl
}
// MachineStatusMetricsController provides metrics based on ClusterStatus.
//
//nolint:govet
@ -41,6 +50,8 @@ type MachineStatusMetricsController struct {
metricsOnce sync.Once
maxRegisteredMachines uint32
platformNames []string
metricNumMachines prometheus.Gauge
@ -79,6 +90,10 @@ func (ctrl *MachineStatusMetricsController) Outputs() []controller.Output {
Type: omni.MachineStatusMetricsType,
Kind: controller.OutputExclusive,
},
{
Type: omni.NotificationType,
Kind: controller.OutputShared,
},
}
}
@ -161,6 +176,12 @@ func (ctrl *MachineStatusMetricsController) Run(ctx context.Context, r controlle
return err
}
if ctrl.maxRegisteredMachines > 0 {
if err = ctrl.reconcileRegistrationLimitNotification(ctx, r, metricsSpec); err != nil {
return err
}
}
select {
case <-ctx.Done():
return nil
@ -169,6 +190,26 @@ func (ctrl *MachineStatusMetricsController) Run(ctx context.Context, r controlle
}
}
// reconcileRegistrationLimitNotification keeps the machine registration limit notification
// in sync with metricsSpec: it writes a warning notification while the limit is reached and
// tears the notification down otherwise.
func (ctrl *MachineStatusMetricsController) reconcileRegistrationLimitNotification(ctx context.Context, r controller.Runtime, metricsSpec *specs.MachineStatusMetricsSpec) error {
	notification := omni.NewNotification(omni.NotificationMachineRegistrationLimitID)

	if !metricsSpec.RegistrationLimitReached {
		_, err := helpers.TeardownAndDestroy(ctx, r, notification.Metadata())

		return err
	}

	return safe.WriterModify(ctx, r, notification, func(res *omni.Notification) error {
		spec := res.TypedSpec().Value

		spec.Title = "Machine Registration Limit Reached"
		spec.Body = fmt.Sprintf(
			"%d/%d machines registered. New machines will be rejected.",
			metricsSpec.RegisteredMachinesCount, metricsSpec.RegisteredMachinesLimit)
		spec.Type = specs.NotificationSpec_WARNING

		return nil
	})
}
func (ctrl *MachineStatusMetricsController) gatherMetrics(statuses iter.Seq[*omni.MachineStatus], numPendingMachines int) *specs.MachineStatusMetricsSpec {
platformMetrics := make(map[string]uint32, len(ctrl.platformNames))
for _, p := range ctrl.platformNames {
@ -245,13 +286,15 @@ func (ctrl *MachineStatusMetricsController) gatherMetrics(statuses iter.Seq[*omn
}
return &specs.MachineStatusMetricsSpec{
ConnectedMachinesCount: uint32(connectedMachines),
RegisteredMachinesCount: uint32(machines),
AllocatedMachinesCount: uint32(allocatedMachines),
PendingMachinesCount: uint32(numPendingMachines),
Platforms: platformMetrics,
SecureBootStatus: secureBootStatusMetrics,
UkiStatus: ukiMetrics,
ConnectedMachinesCount: uint32(connectedMachines),
RegisteredMachinesCount: uint32(machines),
AllocatedMachinesCount: uint32(allocatedMachines),
PendingMachinesCount: uint32(numPendingMachines),
Platforms: platformMetrics,
SecureBootStatus: secureBootStatusMetrics,
UkiStatus: ukiMetrics,
RegisteredMachinesLimit: ctrl.maxRegisteredMachines,
RegistrationLimitReached: ctrl.maxRegisteredMachines > 0 && uint32(machines) >= ctrl.maxRegisteredMachines,
}
}

View File

@ -0,0 +1,104 @@
// Copyright (c) 2026 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package omni_test
import (
"context"
"testing"
"time"
"github.com/cosi-project/runtime/pkg/resource/rtestutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/siderolabs/omni/client/api/omni/specs"
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
omnictrl "github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/omni"
"github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/testutils"
)
// TestMachineStatusMetricsController_RegistrationLimit checks that the controller reports the
// registered machine count, the configured limit and the limit-reached flag in
// MachineStatusMetrics, and that the registration limit notification exists only when the
// limit is expected to be reached.
func TestMachineStatusMetricsController_RegistrationLimit(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
machineIDs []string
maxRegistered uint32
expectCount uint32
expectLimit uint32
expectLimitReached bool
}{
{
name: "limit not reached",
machineIDs: []string{"m1", "m2"},
maxRegistered: 5,
expectCount: 2,
expectLimit: 5,
expectLimitReached: false,
},
{
name: "limit reached",
machineIDs: []string{"m1", "m2"},
maxRegistered: 2,
expectCount: 2,
expectLimit: 2,
expectLimitReached: true,
},
{
name: "limit exceeded",
machineIDs: []string{"m1", "m2", "m3"},
maxRegistered: 1,
expectCount: 3,
expectLimit: 1,
expectLimitReached: true,
},
{
name: "unlimited when zero",
machineIDs: []string{"m1", "m2"},
maxRegistered: 0,
expectCount: 2,
expectLimit: 0,
expectLimitReached: false,
},
} {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
t.Cleanup(cancel)
testutils.WithRuntime(ctx, t, testutils.TestOptions{},
// First callback: register the controller under test with this case's limit.
func(_ context.Context, tc testutils.TestContext) {
require.NoError(t, tc.Runtime.RegisterController(omnictrl.NewMachineStatusMetricsController(tt.maxRegistered)))
},
// Second callback: create the machine statuses and assert on the controller outputs.
func(ctx context.Context, tc testutils.TestContext) {
for _, id := range tt.machineIDs {
require.NoError(t, tc.State.Create(ctx, omni.NewMachineStatus(id)))
}
rtestutils.AssertResource(ctx, t, tc.State, omni.MachineStatusMetricsID, func(res *omni.MachineStatusMetrics, a *assert.Assertions) {
a.EqualValues(tt.expectCount, res.TypedSpec().Value.RegisteredMachinesCount)
a.EqualValues(tt.expectLimit, res.TypedSpec().Value.RegisteredMachinesLimit)
a.Equal(tt.expectLimitReached, res.TypedSpec().Value.RegistrationLimitReached)
})
if tt.expectLimitReached {
rtestutils.AssertResource(ctx, t, tc.State, omni.NotificationMachineRegistrationLimitID, func(res *omni.Notification, a *assert.Assertions) {
a.Equal("Machine Registration Limit Reached", res.TypedSpec().Value.Title)
a.Contains(res.TypedSpec().Value.Body, "machines registered")
a.Equal(specs.NotificationSpec_WARNING, res.TypedSpec().Value.Type)
})
} else {
// Notification should not exist when limit is not reached.
// Sleep briefly since there is no state change to poll on.
time.Sleep(500 * time.Millisecond)
rtestutils.AssertNoResource[*omni.Notification](ctx, t, tc.State, omni.NotificationMachineRegistrationLimitID)
}
},
)
})
}
}

View File

@ -164,7 +164,7 @@ func NewRuntime(cfg *config.Params, talosClientFactory *talos.ClientFactory, dns
},
omnictrl.NewMachineCleanupController(),
omnictrl.NewMachineStatusLinkController(linkCounterDeltaCh),
&omnictrl.MachineStatusMetricsController{},
omnictrl.NewMachineStatusMetricsController(cfg.Account.GetMaxRegisteredMachines()),
omnictrl.NewVersionsController(cfg.Registries.GetImageFactoryBaseURL(), cfg.Features.GetEnableTalosPreReleaseVersions(), cfg.Registries.GetKubernetes()),
omnictrl.NewClusterLoadBalancerController(
cfg.Services.LoadBalancer.GetMinPort(),

View File

@ -468,6 +468,7 @@ func filterAccess(ctx context.Context, access state.Access) error {
omni.InstallationMediaType,
omni.OngoingTaskType,
omni.MachineStatusMetricsType,
omni.NotificationType,
omni.ClusterMetricsType,
omni.ClusterStatusMetricsType,
system.SysVersionType,
@ -595,6 +596,7 @@ func filterAccessByType(access state.Access) error {
omni.MachineExtensionsStatusType,
omni.MachineExtensionsType,
omni.MachineStatusMetricsType,
omni.NotificationType,
omni.UpgradeRolloutType,
authres.AuthConfigType,
authres.IdentityLastActiveType,

View File

@ -600,15 +600,16 @@ func (s *Server) runMachineAPI(ctx context.Context) error {
wgAddress := s.cfg.Services.Siderolink.WireGuard.GetEndpoint()
params := siderolink.Params{
WireguardEndpoint: wgAddress,
AdvertisedEndpoint: s.cfg.Services.Siderolink.WireGuard.GetAdvertisedEndpoint(),
MachineAPIEndpoint: s.cfg.Services.MachineAPI.GetEndpoint(),
MachineAPITLSCert: s.cfg.Services.MachineAPI.GetCertFile(),
MachineAPITLSKey: s.cfg.Services.MachineAPI.GetKeyFile(),
EventSinkPort: strconv.Itoa(s.cfg.Services.Siderolink.GetEventSinkPort()),
JoinTokensMode: s.cfg.Services.Siderolink.GetJoinTokensMode(),
UseGRPCTunnel: s.cfg.Services.Siderolink.GetUseGRPCTunnel(),
DisableLastEndpoint: s.cfg.Services.Siderolink.GetDisableLastEndpoint(),
WireguardEndpoint: wgAddress,
AdvertisedEndpoint: s.cfg.Services.Siderolink.WireGuard.GetAdvertisedEndpoint(),
MachineAPIEndpoint: s.cfg.Services.MachineAPI.GetEndpoint(),
MachineAPITLSCert: s.cfg.Services.MachineAPI.GetCertFile(),
MachineAPITLSKey: s.cfg.Services.MachineAPI.GetKeyFile(),
EventSinkPort: strconv.Itoa(s.cfg.Services.Siderolink.GetEventSinkPort()),
JoinTokensMode: s.cfg.Services.Siderolink.GetJoinTokensMode(),
UseGRPCTunnel: s.cfg.Services.Siderolink.GetUseGRPCTunnel(),
DisableLastEndpoint: s.cfg.Services.Siderolink.GetDisableLastEndpoint(),
MaxRegisteredMachines: s.cfg.Account.GetMaxRegisteredMachines(),
}
omniState := s.state.Default()

View File

@ -1251,6 +1251,11 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client
allowedVerbSet: readOnlyVerbSet,
isSignatureSufficient: true,
},
{
resource: omni.NewNotification(uuid.New().String()),
allowedVerbSet: readOnlyVerbSet,
isSignatureSufficient: true,
},
{
resource: omni.NewClusterMetrics(uuid.New().String()),
allowedVerbSet: readOnlyVerbSet,

View File

@ -17,6 +17,17 @@ func (s *Account) SetId(v string) {
s.Id = &v
}
// GetMaxRegisteredMachines returns the configured maximum number of registered machines,
// or 0 when the receiver or the field is nil.
func (s *Account) GetMaxRegisteredMachines() uint32 {
	if s != nil && s.MaxRegisteredMachines != nil {
		return *s.MaxRegisteredMachines
	}

	return 0
}
// SetMaxRegisteredMachines stores a pointer to a copy of v in MaxRegisteredMachines.
// The receiver must not be nil.
func (s *Account) SetMaxRegisteredMachines(v uint32) {
s.MaxRegisteredMachines = &v
}
func (s *Account) GetName() string {
if s == nil || s.Name == nil {
return *new(string)

View File

@ -85,6 +85,14 @@
"userPilot": {
"description": "UserPilot contains UserPilot-related configuration.",
"$ref": "#/definitions/UserPilot"
},
"maxRegisteredMachines": {
"description": "MaxRegisteredMachines is the maximum number of registered machines allowed. 0 means unlimited.",
"type": "integer",
"minimum": 0,
"goJSONSchema": {
"type": "uint32"
}
}
}
},

View File

@ -10,6 +10,10 @@ type Account struct {
// initial setup.
Id *string `json:"id" yaml:"id"`
// MaxRegisteredMachines is the maximum number of registered machines allowed. 0
// means unlimited.
MaxRegisteredMachines *uint32 `json:"maxRegisteredMachines,omitempty" yaml:"maxRegisteredMachines,omitempty"`
// Name is the human-readable name of the account.
Name *string `json:"name" yaml:"name"`

View File

@ -113,6 +113,7 @@ func NewManager(
state,
params.JoinTokensMode,
params.UseGRPCTunnel,
params.MaxRegisteredMachines,
),
}
@ -179,15 +180,16 @@ func getJoinToken(logger *zap.Logger) (string, error) {
// Params are the parameters for the Manager.
type Params struct {
WireguardEndpoint string
AdvertisedEndpoint string
MachineAPIEndpoint string
MachineAPITLSCert string
MachineAPITLSKey string
EventSinkPort string
JoinTokensMode config.SiderolinkServiceJoinTokensMode
UseGRPCTunnel bool
DisableLastEndpoint bool
WireguardEndpoint string
AdvertisedEndpoint string
MachineAPIEndpoint string
MachineAPITLSCert string
MachineAPITLSKey string
EventSinkPort string
JoinTokensMode config.SiderolinkServiceJoinTokensMode
MaxRegisteredMachines uint32
UseGRPCTunnel bool
DisableLastEndpoint bool
}
// NewListener creates a new listener.

View File

@ -12,6 +12,7 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/cosi-project/runtime/pkg/controller/generic"
@ -72,12 +73,13 @@ func (pc *provisionContext) isAuthorizedSecureFlow() bool {
}
// NewProvisionHandler creates a new ProvisionHandler.
func NewProvisionHandler(logger *zap.Logger, state state.State, joinTokenMode config.SiderolinkServiceJoinTokensMode, forceWireguardOverGRPC bool) *ProvisionHandler {
func NewProvisionHandler(logger *zap.Logger, state state.State, joinTokenMode config.SiderolinkServiceJoinTokensMode, forceWireguardOverGRPC bool, maxRegisteredMachines uint32) *ProvisionHandler {
return &ProvisionHandler{
logger: logger,
state: state,
joinTokenMode: joinTokenMode,
forceWireguardOverGRPC: forceWireguardOverGRPC,
maxRegisteredMachines: maxRegisteredMachines,
}
}
@ -88,6 +90,8 @@ type ProvisionHandler struct {
logger *zap.Logger
state state.State
joinTokenMode config.SiderolinkServiceJoinTokensMode
registrationMu sync.Mutex
maxRegisteredMachines uint32
forceWireguardOverGRPC bool
}
@ -303,7 +307,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
return nil, status.Error(codes.PermissionDenied, "unauthorized")
}
return establishLink[*siderolinkres.Link](ctx, h.logger, h.state, provisionContext, nil, nil)
return establishLink[*siderolinkres.Link](ctx, h, provisionContext, nil, nil)
}
if !provisionContext.isAuthorizedSecureFlow() {
@ -316,7 +320,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
// put the machine into the limbo state by creating the pending machine resource
// the controller will then pick it up and create a wireguard peer for it
if provisionContext.requestNodeUniqueToken == nil {
return establishLink[*siderolinkres.PendingMachine](ctx, h.logger, h.state, provisionContext, nil, nil)
return establishLink[*siderolinkres.PendingMachine](ctx, h, provisionContext, nil, nil)
}
annotationsToAdd := []string{}
@ -332,7 +336,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
annotationsToAdd = append(annotationsToAdd, siderolinkres.ForceValidNodeUniqueToken)
}
response, err := establishLink[*siderolinkres.Link](ctx, h.logger, h.state, provisionContext, annotationsToAdd, annotationsToRemove)
response, err := establishLink[*siderolinkres.Link](ctx, h, provisionContext, annotationsToAdd, annotationsToRemove)
if err != nil {
if errors.Is(err, errUUIDConflict) {
logger.Info("detected UUID conflict", zap.String("peer", provisionContext.request.NodePublicKey))
@ -340,7 +344,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
// link is there, but the token doesn't match and the fingerprint differs, keep the machine in the limbo state
// mark pending machine as having the UUID conflict, PendingMachineStatus controller should inject the new UUID
// and the machine will re-join
return establishLink[*siderolinkres.PendingMachine](ctx, h.logger, h.state, provisionContext, []string{siderolinkres.PendingMachineUUIDConflict}, nil)
return establishLink[*siderolinkres.PendingMachine](ctx, h, provisionContext, []string{siderolinkres.PendingMachineUUIDConflict}, nil)
}
return nil, err
@ -349,6 +353,27 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
return response, nil
}
// createWithRegistrationLimit creates r in the handler's state, enforcing the
// configured machine registration limit for new Link resources.
//
// The limit applies only when all of the following hold:
//   - provisionContext.link is nil (treated here as "this is a new link"),
//   - r is of type siderolinkres.LinkType,
//   - h.maxRegisteredMachines is non-zero (zero disables the limit).
//
// In that case the count of existing Link resources is read and the resource
// is created while holding registrationMu, so that concurrent provisions going
// through this handler cannot both pass the check and exceed the limit.
//
// It returns a gRPC ResourceExhausted status error when the limit is already
// reached, the error from listing links if the listing fails, or whatever
// h.state.Create returns (callers inspect it for conflict errors).
func (h *ProvisionHandler) createWithRegistrationLimit(ctx context.Context, r resource.Resource, provisionContext *provisionContext) error {
	isNewLink := provisionContext.link == nil && r.Metadata().Type() == siderolinkres.LinkType

	// Anything that is not a new Link, or an unlimited configuration, bypasses
	// the check and does not need to serialize on the registration mutex.
	if !isNewLink || h.maxRegisteredMachines == 0 {
		return h.state.Create(ctx, r)
	}

	// Hold the lock across both the count and the create so the check-then-act
	// sequence is atomic with respect to other provisions on this handler.
	h.registrationMu.Lock()
	defer h.registrationMu.Unlock()

	links, err := safe.ReaderListAll[*siderolinkres.Link](ctx, h.state)
	if err != nil {
		return err
	}

	registered := links.Len()

	// Compare in uint64: narrowing the int count to uint32 could wrap around
	// for very large values and incorrectly let the check pass.
	if uint64(registered) >= uint64(h.maxRegisteredMachines) {
		return status.Errorf(codes.ResourceExhausted, "machine registration limit reached: %d/%d machines registered", registered, h.maxRegisteredMachines)
	}

	return h.state.Create(ctx, r)
}
func (h *ProvisionHandler) removePendingMachine(ctx context.Context, pendingMachine *siderolinkres.PendingMachine) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
@ -380,9 +405,11 @@ func (h *ProvisionHandler) removePendingMachine(ctx context.Context, pendingMach
return nil
}
func establishLink[T res](ctx context.Context, logger *zap.Logger, st state.State, provisionContext *provisionContext,
annotationsToAdd []string, annotationsToRemove []string,
func establishLink[T res](ctx context.Context, h *ProvisionHandler, provisionContext *provisionContext, annotationsToAdd []string, annotationsToRemove []string,
) (*pb.ProvisionResponse, error) {
logger := h.logger
st := h.state
link, err := newLink[T](provisionContext, annotationsToAdd, annotationsToRemove)
if err != nil {
return nil, err
@ -415,7 +442,7 @@ func establishLink[T res](ctx context.Context, logger *zap.Logger, st state.Stat
}
}
if err = st.Create(ctx, link); err != nil {
if err = h.createWithRegistrationLimit(ctx, link, provisionContext); err != nil {
if !state.IsConflictError(err) {
return nil, err
}

View File

@ -122,7 +122,7 @@ func TestProvision(t *testing.T) {
require.NoError(t, eg.Wait())
})
provisionHandler := siderolink.NewProvisionHandler(logger, state, mode, false)
provisionHandler := siderolink.NewProvisionHandler(logger, state, mode, false, 0)
config := siderolinkres.NewConfig()
config.TypedSpec().Value.ServerAddress = "127.0.0.1"
@ -745,4 +745,87 @@ func TestProvision(t *testing.T) {
_, err = provisionHandler.Provision(ctx, request)
require.Equal(t, codes.PermissionDenied, status.Code(err))
})
// Subtest: with a registration limit of 2 and two Link resources already present in the state,
// provisioning one more machine must be rejected with a ResourceExhausted status.
t.Run("registration limit blocks new machines", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
t.Cleanup(cancel)
// Only the state returned by setup is used; the second return value is discarded.
st, _ := setup(ctx, t, config.SiderolinkServiceJoinTokensModeStrict)
// Override the handler with one that has a limit of 2.
provisionHandler := siderolink.NewProvisionHandler(zaptest.NewLogger(t), st, config.SiderolinkServiceJoinTokensModeStrict, false, 2)
// Create 2 links to reach the limit.
for _, id := range []string{"m1", "m2"} {
require.NoError(t, st.Create(ctx, siderolinkres.NewLink(id, &specs.SiderolinkSpec{})))
}
// The request below carries both a join token and a freshly generated node unique token.
uniqueToken, tokenErr := jointoken.NewNodeUniqueToken(uuid.NewString(), uuid.NewString()).Encode()
require.NoError(t, tokenErr)
// "m3" has no pre-created Link, so this is a registration of a third machine.
_, err := provisionHandler.Provision(ctx, &pb.ProvisionRequest{
NodeUuid: "m3",
NodePublicKey: genKey(),
TalosVersion: new("v1.9.0"),
JoinToken: new(validToken),
NodeUniqueToken: new(uniqueToken),
})
// The error must carry the ResourceExhausted code and report the current count against the limit.
require.Error(t, err)
require.Equal(t, codes.ResourceExhausted, status.Code(err))
require.Contains(t, err.Error(), "2/2 machines registered")
})
// Subtest: with a registration limit of 5 and only one Link resource present in the state,
// provisioning another machine must succeed.
t.Run("registration limit allows when under", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
t.Cleanup(cancel)
// Only the state returned by setup is used; the second return value is discarded.
st, _ := setup(ctx, t, config.SiderolinkServiceJoinTokensModeStrict)
// Use a dedicated handler whose limit (5) is above the number of links created below.
provisionHandler := siderolink.NewProvisionHandler(zaptest.NewLogger(t), st, config.SiderolinkServiceJoinTokensModeStrict, false, 5)
// One existing link: 1 of 5 slots used.
require.NoError(t, st.Create(ctx, siderolinkres.NewLink("m1", &specs.SiderolinkSpec{})))
uniqueToken, tokenErr := jointoken.NewNodeUniqueToken(uuid.NewString(), uuid.NewString()).Encode()
require.NoError(t, tokenErr)
// "m2" has no pre-created Link; the request is expected to pass the limit check.
_, err := provisionHandler.Provision(ctx, &pb.ProvisionRequest{
NodeUuid: "m2",
NodePublicKey: genKey(),
TalosVersion: new("v1.9.0"),
JoinToken: new(validToken),
NodeUniqueToken: new(uniqueToken),
})
require.NoError(t, err)
})
// Subtest: a registration limit of 0 means "no limit" — provisioning must succeed
// even though several Link resources already exist in the state.
t.Run("registration limit unlimited when zero", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
t.Cleanup(cancel)
// Only the state returned by setup is used; the second return value is discarded.
st, _ := setup(ctx, t, config.SiderolinkServiceJoinTokensModeStrict)
// Handler constructed with a limit of 0.
provisionHandler := siderolink.NewProvisionHandler(zaptest.NewLogger(t), st, config.SiderolinkServiceJoinTokensModeStrict, false, 0)
// Pre-create three links so that any non-zero small limit would have been hit.
for _, id := range []string{"m1", "m2", "m3"} {
require.NoError(t, st.Create(ctx, siderolinkres.NewLink(id, &specs.SiderolinkSpec{})))
}
uniqueToken, tokenErr := jointoken.NewNodeUniqueToken(uuid.NewString(), uuid.NewString()).Encode()
require.NoError(t, tokenErr)
// "m4" has no pre-created Link; with the limit disabled the request must not be rejected.
_, err := provisionHandler.Provision(ctx, &pb.ProvisionRequest{
NodeUuid: "m4",
NodePublicKey: genKey(),
TalosVersion: new("v1.9.0"),
JoinToken: new(validToken),
NodeUniqueToken: new(uniqueToken),
})
require.NoError(t, err)
})
}