mirror of
https://github.com/siderolabs/omni.git
synced 2026-03-31 13:41:04 +02:00
feat: enforce configurable machine registration limit
Add `account.maxRegisteredMachines` config option to cap the number of registered machines. The provision handler atomically checks the limit under a mutex before creating new Link resources, returning ResourceExhausted when the cap is reached. Introduce a Notification resource type (ephemeral namespace) so controllers can surface warnings to users. `omnictl` displays all active notifications on every command invocation. Frontend part of showing notifications will be implemented in a different PR. MachineStatusMetricsController creates a warning notification when the registration limit is reached and tears it down when it's not. Signed-off-by: Oguz Kilcan <oguz.kilcan@siderolabs.com>
This commit is contained in:
parent
711782b00e
commit
cf7d752453
File diff suppressed because it is too large
Load Diff
@ -1353,6 +1353,8 @@ message MachineStatusMetricsSpec {
|
||||
map<string, uint32> platforms = 6;
|
||||
map<string, uint32> secure_boot_status = 7;
|
||||
map<string, uint32> uki_status = 8;
|
||||
uint32 registered_machines_limit = 9;
|
||||
bool registration_limit_reached = 10;
|
||||
}
|
||||
|
||||
// ClusterMetricsSpec contains metrics about the clusters in the Omni instance.
|
||||
@ -1618,3 +1620,17 @@ message RotateKubernetesCASpec {}
|
||||
message UpgradeRolloutSpec {
|
||||
map<string, int32> machine_sets_upgrade_quota = 1;
|
||||
}
|
||||
|
||||
// NotificationSpec describes a generic notification emitted by a controller.
// Notifications are surfaced to users; omnictl prints each one to stderr as
// "<severity prefix> <title>: <body>".
message NotificationSpec {
  // Type describes the severity of a notification.
  enum Type {
    // INFO is the zero value, so it is the severity of a notification that does not set one.
    INFO = 0;
    WARNING = 1;
    ERROR = 2;
  }

  // title is the short heading of the notification.
  string title = 1;
  // body is the detailed, human-readable message.
  string body = 2;
  // type is the severity of the notification.
  Type type = 3;
}
|
||||
|
||||
@ -2447,6 +2447,8 @@ func (m *MachineStatusMetricsSpec) CloneVT() *MachineStatusMetricsSpec {
|
||||
r.ConnectedMachinesCount = m.ConnectedMachinesCount
|
||||
r.AllocatedMachinesCount = m.AllocatedMachinesCount
|
||||
r.PendingMachinesCount = m.PendingMachinesCount
|
||||
r.RegisteredMachinesLimit = m.RegisteredMachinesLimit
|
||||
r.RegistrationLimitReached = m.RegistrationLimitReached
|
||||
if rhs := m.Platforms; rhs != nil {
|
||||
tmpContainer := make(map[string]uint32, len(rhs))
|
||||
for k, v := range rhs {
|
||||
@ -3109,6 +3111,25 @@ func (m *UpgradeRolloutSpec) CloneMessageVT() proto.Message {
|
||||
return m.CloneVT()
|
||||
}
|
||||
|
||||
func (m *NotificationSpec) CloneVT() *NotificationSpec {
|
||||
if m == nil {
|
||||
return (*NotificationSpec)(nil)
|
||||
}
|
||||
r := new(NotificationSpec)
|
||||
r.Title = m.Title
|
||||
r.Body = m.Body
|
||||
r.Type = m.Type
|
||||
if len(m.unknownFields) > 0 {
|
||||
r.unknownFields = make([]byte, len(m.unknownFields))
|
||||
copy(r.unknownFields, m.unknownFields)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (m *NotificationSpec) CloneMessageVT() proto.Message {
|
||||
return m.CloneVT()
|
||||
}
|
||||
|
||||
func (this *MachineSpec) EqualVT(that *MachineSpec) bool {
|
||||
if this == that {
|
||||
return true
|
||||
@ -6500,6 +6521,12 @@ func (this *MachineStatusMetricsSpec) EqualVT(that *MachineStatusMetricsSpec) bo
|
||||
return false
|
||||
}
|
||||
}
|
||||
if this.RegisteredMachinesLimit != that.RegisteredMachinesLimit {
|
||||
return false
|
||||
}
|
||||
if this.RegistrationLimitReached != that.RegistrationLimitReached {
|
||||
return false
|
||||
}
|
||||
return string(this.unknownFields) == string(that.unknownFields)
|
||||
}
|
||||
|
||||
@ -7346,6 +7373,31 @@ func (this *UpgradeRolloutSpec) EqualMessageVT(thatMsg proto.Message) bool {
|
||||
}
|
||||
return this.EqualVT(that)
|
||||
}
|
||||
func (this *NotificationSpec) EqualVT(that *NotificationSpec) bool {
|
||||
if this == that {
|
||||
return true
|
||||
} else if this == nil || that == nil {
|
||||
return false
|
||||
}
|
||||
if this.Title != that.Title {
|
||||
return false
|
||||
}
|
||||
if this.Body != that.Body {
|
||||
return false
|
||||
}
|
||||
if this.Type != that.Type {
|
||||
return false
|
||||
}
|
||||
return string(this.unknownFields) == string(that.unknownFields)
|
||||
}
|
||||
|
||||
func (this *NotificationSpec) EqualMessageVT(thatMsg proto.Message) bool {
|
||||
that, ok := thatMsg.(*NotificationSpec)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return this.EqualVT(that)
|
||||
}
|
||||
func (m *MachineSpec) MarshalVT() (dAtA []byte, err error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
@ -13986,6 +14038,21 @@ func (m *MachineStatusMetricsSpec) MarshalToSizedBufferVT(dAtA []byte) (int, err
|
||||
i -= len(m.unknownFields)
|
||||
copy(dAtA[i:], m.unknownFields)
|
||||
}
|
||||
if m.RegistrationLimitReached {
|
||||
i--
|
||||
if m.RegistrationLimitReached {
|
||||
dAtA[i] = 1
|
||||
} else {
|
||||
dAtA[i] = 0
|
||||
}
|
||||
i--
|
||||
dAtA[i] = 0x50
|
||||
}
|
||||
if m.RegisteredMachinesLimit != 0 {
|
||||
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.RegisteredMachinesLimit))
|
||||
i--
|
||||
dAtA[i] = 0x48
|
||||
}
|
||||
if len(m.UkiStatus) > 0 {
|
||||
for k := range m.UkiStatus {
|
||||
v := m.UkiStatus[k]
|
||||
@ -15665,6 +15732,58 @@ func (m *UpgradeRolloutSpec) MarshalToSizedBufferVT(dAtA []byte) (int, error) {
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *NotificationSpec) MarshalVT() (dAtA []byte, err error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
size := m.SizeVT()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalToSizedBufferVT(dAtA[:size])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *NotificationSpec) MarshalToVT(dAtA []byte) (int, error) {
|
||||
size := m.SizeVT()
|
||||
return m.MarshalToSizedBufferVT(dAtA[:size])
|
||||
}
|
||||
|
||||
func (m *NotificationSpec) MarshalToSizedBufferVT(dAtA []byte) (int, error) {
|
||||
if m == nil {
|
||||
return 0, nil
|
||||
}
|
||||
i := len(dAtA)
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if m.unknownFields != nil {
|
||||
i -= len(m.unknownFields)
|
||||
copy(dAtA[i:], m.unknownFields)
|
||||
}
|
||||
if m.Type != 0 {
|
||||
i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Type))
|
||||
i--
|
||||
dAtA[i] = 0x18
|
||||
}
|
||||
if len(m.Body) > 0 {
|
||||
i -= len(m.Body)
|
||||
copy(dAtA[i:], m.Body)
|
||||
i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Body)))
|
||||
i--
|
||||
dAtA[i] = 0x12
|
||||
}
|
||||
if len(m.Title) > 0 {
|
||||
i -= len(m.Title)
|
||||
copy(dAtA[i:], m.Title)
|
||||
i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Title)))
|
||||
i--
|
||||
dAtA[i] = 0xa
|
||||
}
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *MachineSpec) SizeVT() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
@ -18309,6 +18428,12 @@ func (m *MachineStatusMetricsSpec) SizeVT() (n int) {
|
||||
n += mapEntrySize + 1 + protohelpers.SizeOfVarint(uint64(mapEntrySize))
|
||||
}
|
||||
}
|
||||
if m.RegisteredMachinesLimit != 0 {
|
||||
n += 1 + protohelpers.SizeOfVarint(uint64(m.RegisteredMachinesLimit))
|
||||
}
|
||||
if m.RegistrationLimitReached {
|
||||
n += 2
|
||||
}
|
||||
n += len(m.unknownFields)
|
||||
return n
|
||||
}
|
||||
@ -18935,6 +19060,27 @@ func (m *UpgradeRolloutSpec) SizeVT() (n int) {
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *NotificationSpec) SizeVT() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
var l int
|
||||
_ = l
|
||||
l = len(m.Title)
|
||||
if l > 0 {
|
||||
n += 1 + l + protohelpers.SizeOfVarint(uint64(l))
|
||||
}
|
||||
l = len(m.Body)
|
||||
if l > 0 {
|
||||
n += 1 + l + protohelpers.SizeOfVarint(uint64(l))
|
||||
}
|
||||
if m.Type != 0 {
|
||||
n += 1 + protohelpers.SizeOfVarint(uint64(m.Type))
|
||||
}
|
||||
n += len(m.unknownFields)
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *MachineSpec) UnmarshalVT(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
@ -35935,6 +36081,45 @@ func (m *MachineStatusMetricsSpec) UnmarshalVT(dAtA []byte) error {
|
||||
}
|
||||
m.UkiStatus[mapkey] = mapvalue
|
||||
iNdEx = postIndex
|
||||
case 9:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field RegisteredMachinesLimit", wireType)
|
||||
}
|
||||
m.RegisteredMachinesLimit = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return protohelpers.ErrIntOverflow
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.RegisteredMachinesLimit |= uint32(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 10:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field RegistrationLimitReached", wireType)
|
||||
}
|
||||
var v int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return protohelpers.ErrIntOverflow
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
v |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
m.RegistrationLimitReached = bool(v != 0)
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
|
||||
@ -40045,3 +40230,137 @@ func (m *UpgradeRolloutSpec) UnmarshalVT(dAtA []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *NotificationSpec) UnmarshalVT(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return protohelpers.ErrIntOverflow
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: NotificationSpec: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: NotificationSpec: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Title", wireType)
|
||||
}
|
||||
var stringLen uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return protohelpers.ErrIntOverflow
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
stringLen |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
intStringLen := int(stringLen)
|
||||
if intStringLen < 0 {
|
||||
return protohelpers.ErrInvalidLength
|
||||
}
|
||||
postIndex := iNdEx + intStringLen
|
||||
if postIndex < 0 {
|
||||
return protohelpers.ErrInvalidLength
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Title = string(dAtA[iNdEx:postIndex])
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Body", wireType)
|
||||
}
|
||||
var stringLen uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return protohelpers.ErrIntOverflow
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
stringLen |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
intStringLen := int(stringLen)
|
||||
if intStringLen < 0 {
|
||||
return protohelpers.ErrInvalidLength
|
||||
}
|
||||
postIndex := iNdEx + intStringLen
|
||||
if postIndex < 0 {
|
||||
return protohelpers.ErrInvalidLength
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Body = string(dAtA[iNdEx:postIndex])
|
||||
iNdEx = postIndex
|
||||
case 3:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType)
|
||||
}
|
||||
m.Type = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return protohelpers.ErrIntOverflow
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.Type |= NotificationSpec_Type(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := protohelpers.Skip(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if (skippy < 0) || (iNdEx+skippy) < 0 {
|
||||
return protohelpers.ErrInvalidLength
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...)
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
52
client/pkg/omni/resources/omni/notification.go
Normal file
52
client/pkg/omni/resources/omni/notification.go
Normal file
@ -0,0 +1,52 @@
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package omni
|
||||
|
||||
import (
|
||||
"github.com/cosi-project/runtime/pkg/resource"
|
||||
"github.com/cosi-project/runtime/pkg/resource/meta"
|
||||
"github.com/cosi-project/runtime/pkg/resource/protobuf"
|
||||
"github.com/cosi-project/runtime/pkg/resource/typed"
|
||||
|
||||
"github.com/siderolabs/omni/client/api/omni/specs"
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources"
|
||||
)
|
||||
|
||||
// NewNotification creates new Notification resource.
|
||||
func NewNotification(id resource.ID) *Notification {
|
||||
return typed.NewResource[NotificationSpec, NotificationExtension](
|
||||
resource.NewMetadata(resources.EphemeralNamespace, NotificationType, id, resource.VersionUndefined),
|
||||
protobuf.NewResourceSpec(&specs.NotificationSpec{}),
|
||||
)
|
||||
}
|
||||
|
||||
const (
|
||||
// NotificationType is the type of the Notification resource.
|
||||
// tsgen:NotificationType
|
||||
NotificationType = resource.Type("Notifications.omni.sidero.dev")
|
||||
|
||||
// NotificationMachineRegistrationLimitID is the ID for the machine registration limit notification.
|
||||
// tsgen:NotificationMachineRegistrationLimitID
|
||||
NotificationMachineRegistrationLimitID = "machine-registration-limit"
|
||||
)
|
||||
|
||||
// Notification describes a generic notification emitted by a controller.
|
||||
type Notification = typed.Resource[NotificationSpec, NotificationExtension]
|
||||
|
||||
// NotificationSpec wraps specs.NotificationSpec.
|
||||
type NotificationSpec = protobuf.ResourceSpec[specs.NotificationSpec, *specs.NotificationSpec]
|
||||
|
||||
// NotificationExtension provides auxiliary methods for Notification resource.
|
||||
type NotificationExtension struct{}
|
||||
|
||||
// ResourceDefinition implements [typed.Extension] interface.
|
||||
func (NotificationExtension) ResourceDefinition() meta.ResourceDefinitionSpec {
|
||||
return meta.ResourceDefinitionSpec{
|
||||
Type: NotificationType,
|
||||
Aliases: []resource.Type{},
|
||||
DefaultNamespace: resources.EphemeralNamespace,
|
||||
PrintColumns: []meta.PrintColumn{},
|
||||
}
|
||||
}
|
||||
@ -79,6 +79,7 @@ func init() {
|
||||
registry.MustRegisterResource(MachineStatusSnapshotType, &MachineStatusSnapshot{})
|
||||
registry.MustRegisterResource(MachineStatusLinkType, &MachineStatusLink{})
|
||||
registry.MustRegisterResource(MachineStatusMetricsType, &MachineStatusMetrics{})
|
||||
registry.MustRegisterResource(NotificationType, &Notification{})
|
||||
registry.MustRegisterResource(MaintenanceConfigStatusType, &MaintenanceConfigStatus{})
|
||||
registry.MustRegisterResource(LoadBalancerConfigType, &LoadBalancerConfig{})
|
||||
registry.MustRegisterResource(LoadBalancerStatusType, &LoadBalancerStatus{})
|
||||
|
||||
@ -16,9 +16,11 @@ import (
|
||||
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/siderolabs/omni/client/api/omni/specs"
|
||||
"github.com/siderolabs/omni/client/pkg/client"
|
||||
"github.com/siderolabs/omni/client/pkg/client/omni"
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources"
|
||||
omnires "github.com/siderolabs/omni/client/pkg/omni/resources/omni"
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources/system"
|
||||
"github.com/siderolabs/omni/client/pkg/omnictl/config"
|
||||
"github.com/siderolabs/omni/client/pkg/version"
|
||||
@ -147,6 +149,10 @@ func WithClient(f func(ctx context.Context, client *client.Client) error, client
|
||||
return err
|
||||
}
|
||||
|
||||
if err = checkNotifications(ctx, client.Omni().State()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return f(ctx, client)
|
||||
})
|
||||
}
|
||||
@ -180,6 +186,34 @@ If you want to enable the version validation and disable this warning, set githu
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkNotifications(ctx context.Context, st state.State) error {
|
||||
notifications, err := safe.StateListAll[*omnires.Notification](ctx, st)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list notifications: %w", err)
|
||||
}
|
||||
|
||||
for n := range notifications.All() {
|
||||
spec := n.TypedSpec().Value
|
||||
|
||||
var prefix string
|
||||
|
||||
switch spec.Type {
|
||||
case specs.NotificationSpec_ERROR:
|
||||
prefix = "[ERROR]"
|
||||
case specs.NotificationSpec_WARNING:
|
||||
prefix = "[WARN]"
|
||||
case specs.NotificationSpec_INFO:
|
||||
prefix = "[INFO]"
|
||||
default:
|
||||
prefix = "[UNKNOWN]"
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s %s: %s\n", prefix, spec.Title, spec.Body) //nolint:errcheck
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkVersionWarning(sysVersion *system.SysVersion) {
|
||||
backendVersion, err := semver.ParseTolerant(sysVersion.TypedSpec().Value.BackendVersion)
|
||||
if err != nil {
|
||||
|
||||
@ -154,6 +154,7 @@ func buildRootCommand() (*cobra.Command, error) {
|
||||
rootCmdFlagBinder.StringVar("account-id", flagDescription("account.id", configSchema), &flagConfig.Account.Id)
|
||||
rootCmdFlagBinder.StringVar("name", flagDescription("account.name", configSchema), &flagConfig.Account.Name)
|
||||
rootCmdFlagBinder.StringVar("user-pilot-app-token", flagDescription("account.userPilot.appToken", configSchema), &flagConfig.Account.UserPilot.AppToken)
|
||||
rootCmdFlagBinder.Uint32Var("account-max-registered-machines", flagDescription("account.maxRegisteredMachines", configSchema), &flagConfig.Account.MaxRegisteredMachines)
|
||||
|
||||
if err := defineServiceFlags(rootCmd, rootCmdFlagBinder, flagConfig, configSchema); err != nil {
|
||||
return nil, fmt.Errorf("failed to define service flags: %w", err)
|
||||
|
||||
@ -194,6 +194,12 @@ export enum SecretRotationSpecComponent {
|
||||
KUBERNETES_CA = 2,
|
||||
}
|
||||
|
||||
/**
 * Severity of a notification. The numeric values match the `Type` enum
 * nested in the `NotificationSpec` protobuf message.
 */
export enum NotificationSpecType {
  INFO = 0,
  WARNING = 1,
  ERROR = 2,
}
|
||||
|
||||
export type MachineSpec = {
|
||||
management_address?: string
|
||||
connected?: boolean
|
||||
@ -898,6 +904,8 @@ export type MachineStatusMetricsSpec = {
|
||||
platforms?: {[key: string]: number}
|
||||
secure_boot_status?: {[key: string]: number}
|
||||
uki_status?: {[key: string]: number}
|
||||
registered_machines_limit?: number
|
||||
registration_limit_reached?: boolean
|
||||
}
|
||||
|
||||
export type ClusterMetricsSpec = {
|
||||
@ -1066,4 +1074,10 @@ export type RotateKubernetesCASpec = {
|
||||
|
||||
export type UpgradeRolloutSpec = {
|
||||
machine_sets_upgrade_quota?: {[key: string]: number}
|
||||
}
|
||||
|
||||
/**
 * Spec of a generic notification emitted by a controller.
 * Every field is optional.
 */
export type NotificationSpec = {
  // Short heading of the notification.
  title?: string
  // Detailed, human-readable message.
  body?: string
  // Severity of the notification.
  type?: NotificationSpecType
}
|
||||
@ -207,6 +207,8 @@ export const MachineStatusSnapshotType = "MachineStatusSnapshots.omni.sidero.dev
|
||||
export const MachineUpgradeStatusType = "MachineUpgradeStatuses.omni.sidero.dev";
|
||||
export const MaintenanceConfigStatusType = "MaintenanceConfigStatuses.omni.sidero.dev";
|
||||
export const NodeForceDestroyRequestType = "NodeForceDestroyRequests.omni.sidero.dev";
|
||||
export const NotificationType = "Notifications.omni.sidero.dev";
|
||||
export const NotificationMachineRegistrationLimitID = "machine-registration-limit";
|
||||
export const OngoingTaskType = "OngoingTasks.omni.sidero.dev";
|
||||
export const RedactedClusterMachineConfigType = "RedactedClusterMachineConfigs.omni.sidero.dev";
|
||||
export const RotateKubernetesCAType = "RotateKubernetesCAs.omni.sidero.dev";
|
||||
|
||||
@ -47,6 +47,7 @@ export AUTH0_DOMAIN="${AUTH0_DOMAIN}"
|
||||
export OMNI_CONFIG="${TEST_OUTPUTS_DIR}/config.yaml"
|
||||
export MAX_USERS="${MAX_USERS:-0}"
|
||||
export MAX_SERVICE_ACCOUNTS="${MAX_SERVICE_ACCOUNTS:-0}"
|
||||
export MAX_REGISTERED_MACHINES="${MAX_REGISTERED_MACHINES:-0}"
|
||||
export REGISTRY_MIRROR_FLAGS=()
|
||||
export REGISTRY_MIRROR_CONFIG=""
|
||||
export IMPORTED_CLUSTER_ARGS=()
|
||||
|
||||
@ -61,6 +61,9 @@ prepare_vault
|
||||
# Start MinIO server.
|
||||
prepare_minio access_key="access" secret_key="secret123"
|
||||
|
||||
# Set the registration limit to the total number of machines so that the registration limit is actively enforced during the test.
|
||||
export MAX_REGISTERED_MACHINES="${TOTAL_MACHINES}"
|
||||
|
||||
# Prepare omni config.
|
||||
prepare_omni_config
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
account:
|
||||
maxRegisteredMachines: ${MAX_REGISTERED_MACHINES}
|
||||
services:
|
||||
api:
|
||||
endpoint: 0.0.0.0:8099
|
||||
|
||||
@ -7,6 +7,7 @@ package omni
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"strconv"
|
||||
"sync"
|
||||
@ -24,6 +25,7 @@ import (
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources"
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources/infra"
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
|
||||
"github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/helpers"
|
||||
)
|
||||
|
||||
type nodeInfo struct {
|
||||
@ -32,6 +34,13 @@ type nodeInfo struct {
|
||||
connected bool
|
||||
}
|
||||
|
||||
// NewMachineStatusMetricsController creates a new MachineStatusMetricsController.
|
||||
func NewMachineStatusMetricsController(maxRegisteredMachines uint32) *MachineStatusMetricsController {
|
||||
return &MachineStatusMetricsController{
|
||||
maxRegisteredMachines: maxRegisteredMachines,
|
||||
}
|
||||
}
|
||||
|
||||
// MachineStatusMetricsController provides metrics based on MachineStatus.
|
||||
//
|
||||
//nolint:govet
|
||||
@ -41,6 +50,8 @@ type MachineStatusMetricsController struct {
|
||||
|
||||
metricsOnce sync.Once
|
||||
|
||||
maxRegisteredMachines uint32
|
||||
|
||||
platformNames []string
|
||||
|
||||
metricNumMachines prometheus.Gauge
|
||||
@ -79,6 +90,10 @@ func (ctrl *MachineStatusMetricsController) Outputs() []controller.Output {
|
||||
Type: omni.MachineStatusMetricsType,
|
||||
Kind: controller.OutputExclusive,
|
||||
},
|
||||
{
|
||||
Type: omni.NotificationType,
|
||||
Kind: controller.OutputShared,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,6 +176,12 @@ func (ctrl *MachineStatusMetricsController) Run(ctx context.Context, r controlle
|
||||
return err
|
||||
}
|
||||
|
||||
if ctrl.maxRegisteredMachines > 0 {
|
||||
if err = ctrl.reconcileRegistrationLimitNotification(ctx, r, metricsSpec); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
@ -169,6 +190,26 @@ func (ctrl *MachineStatusMetricsController) Run(ctx context.Context, r controlle
|
||||
}
|
||||
}
|
||||
|
||||
func (ctrl *MachineStatusMetricsController) reconcileRegistrationLimitNotification(ctx context.Context, r controller.Runtime, metricsSpec *specs.MachineStatusMetricsSpec) error {
|
||||
if metricsSpec.RegistrationLimitReached {
|
||||
return safe.WriterModify(ctx, r, omni.NewNotification(omni.NotificationMachineRegistrationLimitID),
|
||||
func(res *omni.Notification) error {
|
||||
res.TypedSpec().Value.Title = "Machine Registration Limit Reached"
|
||||
res.TypedSpec().Value.Body = fmt.Sprintf(
|
||||
"%d/%d machines registered. New machines will be rejected.",
|
||||
metricsSpec.RegisteredMachinesCount, metricsSpec.RegisteredMachinesLimit)
|
||||
res.TypedSpec().Value.Type = specs.NotificationSpec_WARNING
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
_, err := helpers.TeardownAndDestroy(ctx, r, omni.NewNotification(omni.NotificationMachineRegistrationLimitID).Metadata())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ctrl *MachineStatusMetricsController) gatherMetrics(statuses iter.Seq[*omni.MachineStatus], numPendingMachines int) *specs.MachineStatusMetricsSpec {
|
||||
platformMetrics := make(map[string]uint32, len(ctrl.platformNames))
|
||||
for _, p := range ctrl.platformNames {
|
||||
@ -245,13 +286,15 @@ func (ctrl *MachineStatusMetricsController) gatherMetrics(statuses iter.Seq[*omn
|
||||
}
|
||||
|
||||
return &specs.MachineStatusMetricsSpec{
|
||||
ConnectedMachinesCount: uint32(connectedMachines),
|
||||
RegisteredMachinesCount: uint32(machines),
|
||||
AllocatedMachinesCount: uint32(allocatedMachines),
|
||||
PendingMachinesCount: uint32(numPendingMachines),
|
||||
Platforms: platformMetrics,
|
||||
SecureBootStatus: secureBootStatusMetrics,
|
||||
UkiStatus: ukiMetrics,
|
||||
ConnectedMachinesCount: uint32(connectedMachines),
|
||||
RegisteredMachinesCount: uint32(machines),
|
||||
AllocatedMachinesCount: uint32(allocatedMachines),
|
||||
PendingMachinesCount: uint32(numPendingMachines),
|
||||
Platforms: platformMetrics,
|
||||
SecureBootStatus: secureBootStatusMetrics,
|
||||
UkiStatus: ukiMetrics,
|
||||
RegisteredMachinesLimit: ctrl.maxRegisteredMachines,
|
||||
RegistrationLimitReached: ctrl.maxRegisteredMachines > 0 && uint32(machines) >= ctrl.maxRegisteredMachines,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
// Copyright (c) 2026 Sidero Labs, Inc.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the LICENSE file.
|
||||
|
||||
package omni_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cosi-project/runtime/pkg/resource/rtestutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/siderolabs/omni/client/api/omni/specs"
|
||||
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
|
||||
omnictrl "github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/omni"
|
||||
"github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/testutils"
|
||||
)
|
||||
|
||||
func TestMachineStatusMetricsController_RegistrationLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
machineIDs []string
|
||||
maxRegistered uint32
|
||||
expectCount uint32
|
||||
expectLimit uint32
|
||||
expectLimitReached bool
|
||||
}{
|
||||
{
|
||||
name: "limit not reached",
|
||||
machineIDs: []string{"m1", "m2"},
|
||||
maxRegistered: 5,
|
||||
expectCount: 2,
|
||||
expectLimit: 5,
|
||||
expectLimitReached: false,
|
||||
},
|
||||
{
|
||||
name: "limit reached",
|
||||
machineIDs: []string{"m1", "m2"},
|
||||
maxRegistered: 2,
|
||||
expectCount: 2,
|
||||
expectLimit: 2,
|
||||
expectLimitReached: true,
|
||||
},
|
||||
{
|
||||
name: "limit exceeded",
|
||||
machineIDs: []string{"m1", "m2", "m3"},
|
||||
maxRegistered: 1,
|
||||
expectCount: 3,
|
||||
expectLimit: 1,
|
||||
expectLimitReached: true,
|
||||
},
|
||||
{
|
||||
name: "unlimited when zero",
|
||||
machineIDs: []string{"m1", "m2"},
|
||||
maxRegistered: 0,
|
||||
expectCount: 2,
|
||||
expectLimit: 0,
|
||||
expectLimitReached: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
testutils.WithRuntime(ctx, t, testutils.TestOptions{},
|
||||
func(_ context.Context, tc testutils.TestContext) {
|
||||
require.NoError(t, tc.Runtime.RegisterController(omnictrl.NewMachineStatusMetricsController(tt.maxRegistered)))
|
||||
},
|
||||
func(ctx context.Context, tc testutils.TestContext) {
|
||||
for _, id := range tt.machineIDs {
|
||||
require.NoError(t, tc.State.Create(ctx, omni.NewMachineStatus(id)))
|
||||
}
|
||||
|
||||
rtestutils.AssertResource(ctx, t, tc.State, omni.MachineStatusMetricsID, func(res *omni.MachineStatusMetrics, a *assert.Assertions) {
|
||||
a.EqualValues(tt.expectCount, res.TypedSpec().Value.RegisteredMachinesCount)
|
||||
a.EqualValues(tt.expectLimit, res.TypedSpec().Value.RegisteredMachinesLimit)
|
||||
a.Equal(tt.expectLimitReached, res.TypedSpec().Value.RegistrationLimitReached)
|
||||
})
|
||||
|
||||
if tt.expectLimitReached {
|
||||
rtestutils.AssertResource(ctx, t, tc.State, omni.NotificationMachineRegistrationLimitID, func(res *omni.Notification, a *assert.Assertions) {
|
||||
a.Equal("Machine Registration Limit Reached", res.TypedSpec().Value.Title)
|
||||
a.Contains(res.TypedSpec().Value.Body, "machines registered")
|
||||
a.Equal(specs.NotificationSpec_WARNING, res.TypedSpec().Value.Type)
|
||||
})
|
||||
} else {
|
||||
// Notification should not exist when limit is not reached.
|
||||
// Sleep briefly since there is no state change to poll on.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
rtestutils.AssertNoResource[*omni.Notification](ctx, t, tc.State, omni.NotificationMachineRegistrationLimitID)
|
||||
}
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -164,7 +164,7 @@ func NewRuntime(cfg *config.Params, talosClientFactory *talos.ClientFactory, dns
|
||||
},
|
||||
omnictrl.NewMachineCleanupController(),
|
||||
omnictrl.NewMachineStatusLinkController(linkCounterDeltaCh),
|
||||
&omnictrl.MachineStatusMetricsController{},
|
||||
omnictrl.NewMachineStatusMetricsController(cfg.Account.GetMaxRegisteredMachines()),
|
||||
omnictrl.NewVersionsController(cfg.Registries.GetImageFactoryBaseURL(), cfg.Features.GetEnableTalosPreReleaseVersions(), cfg.Registries.GetKubernetes()),
|
||||
omnictrl.NewClusterLoadBalancerController(
|
||||
cfg.Services.LoadBalancer.GetMinPort(),
|
||||
|
||||
@ -468,6 +468,7 @@ func filterAccess(ctx context.Context, access state.Access) error {
|
||||
omni.InstallationMediaType,
|
||||
omni.OngoingTaskType,
|
||||
omni.MachineStatusMetricsType,
|
||||
omni.NotificationType,
|
||||
omni.ClusterMetricsType,
|
||||
omni.ClusterStatusMetricsType,
|
||||
system.SysVersionType,
|
||||
@ -595,6 +596,7 @@ func filterAccessByType(access state.Access) error {
|
||||
omni.MachineExtensionsStatusType,
|
||||
omni.MachineExtensionsType,
|
||||
omni.MachineStatusMetricsType,
|
||||
omni.NotificationType,
|
||||
omni.UpgradeRolloutType,
|
||||
authres.AuthConfigType,
|
||||
authres.IdentityLastActiveType,
|
||||
|
||||
@ -600,15 +600,16 @@ func (s *Server) runMachineAPI(ctx context.Context) error {
|
||||
wgAddress := s.cfg.Services.Siderolink.WireGuard.GetEndpoint()
|
||||
|
||||
params := siderolink.Params{
|
||||
WireguardEndpoint: wgAddress,
|
||||
AdvertisedEndpoint: s.cfg.Services.Siderolink.WireGuard.GetAdvertisedEndpoint(),
|
||||
MachineAPIEndpoint: s.cfg.Services.MachineAPI.GetEndpoint(),
|
||||
MachineAPITLSCert: s.cfg.Services.MachineAPI.GetCertFile(),
|
||||
MachineAPITLSKey: s.cfg.Services.MachineAPI.GetKeyFile(),
|
||||
EventSinkPort: strconv.Itoa(s.cfg.Services.Siderolink.GetEventSinkPort()),
|
||||
JoinTokensMode: s.cfg.Services.Siderolink.GetJoinTokensMode(),
|
||||
UseGRPCTunnel: s.cfg.Services.Siderolink.GetUseGRPCTunnel(),
|
||||
DisableLastEndpoint: s.cfg.Services.Siderolink.GetDisableLastEndpoint(),
|
||||
WireguardEndpoint: wgAddress,
|
||||
AdvertisedEndpoint: s.cfg.Services.Siderolink.WireGuard.GetAdvertisedEndpoint(),
|
||||
MachineAPIEndpoint: s.cfg.Services.MachineAPI.GetEndpoint(),
|
||||
MachineAPITLSCert: s.cfg.Services.MachineAPI.GetCertFile(),
|
||||
MachineAPITLSKey: s.cfg.Services.MachineAPI.GetKeyFile(),
|
||||
EventSinkPort: strconv.Itoa(s.cfg.Services.Siderolink.GetEventSinkPort()),
|
||||
JoinTokensMode: s.cfg.Services.Siderolink.GetJoinTokensMode(),
|
||||
UseGRPCTunnel: s.cfg.Services.Siderolink.GetUseGRPCTunnel(),
|
||||
DisableLastEndpoint: s.cfg.Services.Siderolink.GetDisableLastEndpoint(),
|
||||
MaxRegisteredMachines: s.cfg.Account.GetMaxRegisteredMachines(),
|
||||
}
|
||||
|
||||
omniState := s.state.Default()
|
||||
|
||||
@ -1251,6 +1251,11 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client
|
||||
allowedVerbSet: readOnlyVerbSet,
|
||||
isSignatureSufficient: true,
|
||||
},
|
||||
{
|
||||
resource: omni.NewNotification(uuid.New().String()),
|
||||
allowedVerbSet: readOnlyVerbSet,
|
||||
isSignatureSufficient: true,
|
||||
},
|
||||
{
|
||||
resource: omni.NewClusterMetrics(uuid.New().String()),
|
||||
allowedVerbSet: readOnlyVerbSet,
|
||||
|
||||
@ -17,6 +17,17 @@ func (s *Account) SetId(v string) {
|
||||
s.Id = &v
|
||||
}
|
||||
|
||||
func (s *Account) GetMaxRegisteredMachines() uint32 {
|
||||
if s == nil || s.MaxRegisteredMachines == nil {
|
||||
return *new(uint32)
|
||||
}
|
||||
return *s.MaxRegisteredMachines
|
||||
}
|
||||
|
||||
func (s *Account) SetMaxRegisteredMachines(v uint32) {
|
||||
s.MaxRegisteredMachines = &v
|
||||
}
|
||||
|
||||
func (s *Account) GetName() string {
|
||||
if s == nil || s.Name == nil {
|
||||
return *new(string)
|
||||
|
||||
@ -85,6 +85,14 @@
|
||||
"userPilot": {
|
||||
"description": "UserPilot contains UserPilot-related configuration.",
|
||||
"$ref": "#/definitions/UserPilot"
|
||||
},
|
||||
"maxRegisteredMachines": {
|
||||
"description": "MaxRegisteredMachines is the maximum number of registered machines allowed. 0 means unlimited.",
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"goJSONSchema": {
|
||||
"type": "uint32"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@ -10,6 +10,10 @@ type Account struct {
|
||||
// initial setup.
|
||||
Id *string `json:"id" yaml:"id"`
|
||||
|
||||
// MaxRegisteredMachines is the maximum number of registered machines allowed. 0
|
||||
// means unlimited.
|
||||
MaxRegisteredMachines *uint32 `json:"maxRegisteredMachines,omitempty" yaml:"maxRegisteredMachines,omitempty"`
|
||||
|
||||
// Name is the human-readable name of the account.
|
||||
Name *string `json:"name" yaml:"name"`
|
||||
|
||||
|
||||
@ -113,6 +113,7 @@ func NewManager(
|
||||
state,
|
||||
params.JoinTokensMode,
|
||||
params.UseGRPCTunnel,
|
||||
params.MaxRegisteredMachines,
|
||||
),
|
||||
}
|
||||
|
||||
@ -179,15 +180,16 @@ func getJoinToken(logger *zap.Logger) (string, error) {
|
||||
|
||||
// Params are the parameters for the Manager.
|
||||
type Params struct {
|
||||
WireguardEndpoint string
|
||||
AdvertisedEndpoint string
|
||||
MachineAPIEndpoint string
|
||||
MachineAPITLSCert string
|
||||
MachineAPITLSKey string
|
||||
EventSinkPort string
|
||||
JoinTokensMode config.SiderolinkServiceJoinTokensMode
|
||||
UseGRPCTunnel bool
|
||||
DisableLastEndpoint bool
|
||||
WireguardEndpoint string
|
||||
AdvertisedEndpoint string
|
||||
MachineAPIEndpoint string
|
||||
MachineAPITLSCert string
|
||||
MachineAPITLSKey string
|
||||
EventSinkPort string
|
||||
JoinTokensMode config.SiderolinkServiceJoinTokensMode
|
||||
MaxRegisteredMachines uint32
|
||||
UseGRPCTunnel bool
|
||||
DisableLastEndpoint bool
|
||||
}
|
||||
|
||||
// NewListener creates a new listener.
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cosi-project/runtime/pkg/controller/generic"
|
||||
@ -72,12 +73,13 @@ func (pc *provisionContext) isAuthorizedSecureFlow() bool {
|
||||
}
|
||||
|
||||
// NewProvisionHandler creates a new ProvisionHandler.
|
||||
func NewProvisionHandler(logger *zap.Logger, state state.State, joinTokenMode config.SiderolinkServiceJoinTokensMode, forceWireguardOverGRPC bool) *ProvisionHandler {
|
||||
func NewProvisionHandler(logger *zap.Logger, state state.State, joinTokenMode config.SiderolinkServiceJoinTokensMode, forceWireguardOverGRPC bool, maxRegisteredMachines uint32) *ProvisionHandler {
|
||||
return &ProvisionHandler{
|
||||
logger: logger,
|
||||
state: state,
|
||||
joinTokenMode: joinTokenMode,
|
||||
forceWireguardOverGRPC: forceWireguardOverGRPC,
|
||||
maxRegisteredMachines: maxRegisteredMachines,
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,6 +90,8 @@ type ProvisionHandler struct {
|
||||
logger *zap.Logger
|
||||
state state.State
|
||||
joinTokenMode config.SiderolinkServiceJoinTokensMode
|
||||
registrationMu sync.Mutex
|
||||
maxRegisteredMachines uint32
|
||||
forceWireguardOverGRPC bool
|
||||
}
|
||||
|
||||
@ -303,7 +307,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
|
||||
return nil, status.Error(codes.PermissionDenied, "unauthorized")
|
||||
}
|
||||
|
||||
return establishLink[*siderolinkres.Link](ctx, h.logger, h.state, provisionContext, nil, nil)
|
||||
return establishLink[*siderolinkres.Link](ctx, h, provisionContext, nil, nil)
|
||||
}
|
||||
|
||||
if !provisionContext.isAuthorizedSecureFlow() {
|
||||
@ -316,7 +320,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
|
||||
// put the machine into the limbo state by creating the pending machine resource
|
||||
// the controller will then pick it up and create a wireguard peer for it
|
||||
if provisionContext.requestNodeUniqueToken == nil {
|
||||
return establishLink[*siderolinkres.PendingMachine](ctx, h.logger, h.state, provisionContext, nil, nil)
|
||||
return establishLink[*siderolinkres.PendingMachine](ctx, h, provisionContext, nil, nil)
|
||||
}
|
||||
|
||||
annotationsToAdd := []string{}
|
||||
@ -332,7 +336,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
|
||||
annotationsToAdd = append(annotationsToAdd, siderolinkres.ForceValidNodeUniqueToken)
|
||||
}
|
||||
|
||||
response, err := establishLink[*siderolinkres.Link](ctx, h.logger, h.state, provisionContext, annotationsToAdd, annotationsToRemove)
|
||||
response, err := establishLink[*siderolinkres.Link](ctx, h, provisionContext, annotationsToAdd, annotationsToRemove)
|
||||
if err != nil {
|
||||
if errors.Is(err, errUUIDConflict) {
|
||||
logger.Info("detected UUID conflict", zap.String("peer", provisionContext.request.NodePublicKey))
|
||||
@ -340,7 +344,7 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
|
||||
// link is there, but the token doesn't match and the fingerprint differs, keep the machine in the limbo state
|
||||
// mark pending machine as having the UUID conflict, PendingMachineStatus controller should inject the new UUID
|
||||
// and the machine will re-join
|
||||
return establishLink[*siderolinkres.PendingMachine](ctx, h.logger, h.state, provisionContext, []string{siderolinkres.PendingMachineUUIDConflict}, nil)
|
||||
return establishLink[*siderolinkres.PendingMachine](ctx, h, provisionContext, []string{siderolinkres.PendingMachineUUIDConflict}, nil)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
@ -349,6 +353,27 @@ func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *prov
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// createWithRegistrationLimit creates the resource, checking the registration limit for new Link resources under the registration mutex to prevent concurrent provisions from exceeding the limit.
|
||||
func (h *ProvisionHandler) createWithRegistrationLimit(ctx context.Context, r resource.Resource, provisionContext *provisionContext) error {
|
||||
isNewLink := provisionContext.link == nil && r.Metadata().Type() == siderolinkres.LinkType
|
||||
|
||||
if isNewLink && h.maxRegisteredMachines > 0 {
|
||||
h.registrationMu.Lock()
|
||||
defer h.registrationMu.Unlock()
|
||||
|
||||
links, err := safe.ReaderListAll[*siderolinkres.Link](ctx, h.state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if uint32(links.Len()) >= h.maxRegisteredMachines {
|
||||
return status.Errorf(codes.ResourceExhausted, "machine registration limit reached: %d/%d machines registered", links.Len(), h.maxRegisteredMachines)
|
||||
}
|
||||
}
|
||||
|
||||
return h.state.Create(ctx, r)
|
||||
}
|
||||
|
||||
func (h *ProvisionHandler) removePendingMachine(ctx context.Context, pendingMachine *siderolinkres.PendingMachine) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
@ -380,9 +405,11 @@ func (h *ProvisionHandler) removePendingMachine(ctx context.Context, pendingMach
|
||||
return nil
|
||||
}
|
||||
|
||||
func establishLink[T res](ctx context.Context, logger *zap.Logger, st state.State, provisionContext *provisionContext,
|
||||
annotationsToAdd []string, annotationsToRemove []string,
|
||||
func establishLink[T res](ctx context.Context, h *ProvisionHandler, provisionContext *provisionContext, annotationsToAdd []string, annotationsToRemove []string,
|
||||
) (*pb.ProvisionResponse, error) {
|
||||
logger := h.logger
|
||||
st := h.state
|
||||
|
||||
link, err := newLink[T](provisionContext, annotationsToAdd, annotationsToRemove)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -415,7 +442,7 @@ func establishLink[T res](ctx context.Context, logger *zap.Logger, st state.Stat
|
||||
}
|
||||
}
|
||||
|
||||
if err = st.Create(ctx, link); err != nil {
|
||||
if err = h.createWithRegistrationLimit(ctx, link, provisionContext); err != nil {
|
||||
if !state.IsConflictError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -122,7 +122,7 @@ func TestProvision(t *testing.T) {
|
||||
require.NoError(t, eg.Wait())
|
||||
})
|
||||
|
||||
provisionHandler := siderolink.NewProvisionHandler(logger, state, mode, false)
|
||||
provisionHandler := siderolink.NewProvisionHandler(logger, state, mode, false, 0)
|
||||
|
||||
config := siderolinkres.NewConfig()
|
||||
config.TypedSpec().Value.ServerAddress = "127.0.0.1"
|
||||
@ -745,4 +745,87 @@ func TestProvision(t *testing.T) {
|
||||
_, err = provisionHandler.Provision(ctx, request)
|
||||
require.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
})
|
||||
|
||||
t.Run("registration limit blocks new machines", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
st, _ := setup(ctx, t, config.SiderolinkServiceJoinTokensModeStrict)
|
||||
|
||||
// Override the handler with one that has a limit of 2.
|
||||
provisionHandler := siderolink.NewProvisionHandler(zaptest.NewLogger(t), st, config.SiderolinkServiceJoinTokensModeStrict, false, 2)
|
||||
|
||||
// Create 2 links to reach the limit.
|
||||
for _, id := range []string{"m1", "m2"} {
|
||||
require.NoError(t, st.Create(ctx, siderolinkres.NewLink(id, &specs.SiderolinkSpec{})))
|
||||
}
|
||||
|
||||
uniqueToken, tokenErr := jointoken.NewNodeUniqueToken(uuid.NewString(), uuid.NewString()).Encode()
|
||||
require.NoError(t, tokenErr)
|
||||
|
||||
_, err := provisionHandler.Provision(ctx, &pb.ProvisionRequest{
|
||||
NodeUuid: "m3",
|
||||
NodePublicKey: genKey(),
|
||||
TalosVersion: new("v1.9.0"),
|
||||
JoinToken: new(validToken),
|
||||
NodeUniqueToken: new(uniqueToken),
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.ResourceExhausted, status.Code(err))
|
||||
require.Contains(t, err.Error(), "2/2 machines registered")
|
||||
})
|
||||
|
||||
t.Run("registration limit allows when under", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
st, _ := setup(ctx, t, config.SiderolinkServiceJoinTokensModeStrict)
|
||||
|
||||
provisionHandler := siderolink.NewProvisionHandler(zaptest.NewLogger(t), st, config.SiderolinkServiceJoinTokensModeStrict, false, 5)
|
||||
|
||||
require.NoError(t, st.Create(ctx, siderolinkres.NewLink("m1", &specs.SiderolinkSpec{})))
|
||||
|
||||
uniqueToken, tokenErr := jointoken.NewNodeUniqueToken(uuid.NewString(), uuid.NewString()).Encode()
|
||||
require.NoError(t, tokenErr)
|
||||
|
||||
_, err := provisionHandler.Provision(ctx, &pb.ProvisionRequest{
|
||||
NodeUuid: "m2",
|
||||
NodePublicKey: genKey(),
|
||||
TalosVersion: new("v1.9.0"),
|
||||
JoinToken: new(validToken),
|
||||
NodeUniqueToken: new(uniqueToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("registration limit unlimited when zero", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*5)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
st, _ := setup(ctx, t, config.SiderolinkServiceJoinTokensModeStrict)
|
||||
|
||||
provisionHandler := siderolink.NewProvisionHandler(zaptest.NewLogger(t), st, config.SiderolinkServiceJoinTokensModeStrict, false, 0)
|
||||
|
||||
for _, id := range []string{"m1", "m2", "m3"} {
|
||||
require.NoError(t, st.Create(ctx, siderolinkres.NewLink(id, &specs.SiderolinkSpec{})))
|
||||
}
|
||||
|
||||
uniqueToken, tokenErr := jointoken.NewNodeUniqueToken(uuid.NewString(), uuid.NewString()).Encode()
|
||||
require.NoError(t, tokenErr)
|
||||
|
||||
_, err := provisionHandler.Provision(ctx, &pb.ProvisionRequest{
|
||||
NodeUuid: "m4",
|
||||
NodePublicKey: genKey(),
|
||||
TalosVersion: new("v1.9.0"),
|
||||
JoinToken: new(validToken),
|
||||
NodeUniqueToken: new(uniqueToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user