diff --git a/client/api/omni/specs/omni.pb.go b/client/api/omni/specs/omni.pb.go index 374d6cc6..a48213e7 100644 --- a/client/api/omni/specs/omni.pb.go +++ b/client/api/omni/specs/omni.pb.go @@ -7242,6 +7242,7 @@ type ClusterMachineRequestStatusSpec struct { MachineUuid string `protobuf:"bytes,2,opt,name=machine_uuid,json=machineUuid,proto3" json:"machine_uuid,omitempty"` ProviderId string `protobuf:"bytes,3,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"` Stage ClusterMachineRequestStatusSpec_Stage `protobuf:"varint,4,opt,name=stage,proto3,enum=specs.ClusterMachineRequestStatusSpec_Stage" json:"stage,omitempty"` + Error string `protobuf:"bytes,5,opt,name=error,proto3" json:"error,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -7304,6 +7305,13 @@ func (x *ClusterMachineRequestStatusSpec) GetStage() ClusterMachineRequestStatus return ClusterMachineRequestStatusSpec_UNKNOWN } +func (x *ClusterMachineRequestStatusSpec) GetError() string { + if x != nil { + return x.Error + } + return "" +} + // InfraMachineConfigSpec is the spec of the user-owned InfraMachineConfig resource. type InfraMachineConfigSpec struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -11755,13 +11763,14 @@ const file_omni_specs_omni_proto_rawDesc = "" + "\x02id\x18\x01 \x01(\tR\x02id\x12'\n" + "\x0fnum_diagnostics\x18\x02 \x01(\rR\x0enumDiagnostics\"L\n" + "\x1dMachineRequestSetPressureSpec\x12+\n" + - "\x11required_machines\x18\x01 \x01(\rR\x10requiredMachines\"\xa7\x02\n" + + "\x11required_machines\x18\x01 \x01(\rR\x10requiredMachines\"\xbd\x02\n" + "\x1fClusterMachineRequestStatusSpec\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x12!\n" + "\fmachine_uuid\x18\x02 \x01(\tR\vmachineUuid\x12\x1f\n" + "\vprovider_id\x18\x03 \x01(\tR\n" + "providerId\x12B\n" + - "\x05stage\x18\x04 \x01(\x0e2,.specs.ClusterMachineRequestStatusSpec.StageR\x05stage\"d\n" + + "\x05stage\x18\x04 \x01(\x0e2,.specs.ClusterMachineRequestStatusSpec.StageR\x05stage\x12\x14\n" + + "\x05error\x18\x05 \x01(\tR\x05error\"d\n" + "\x05Stage\x12\v\n" + "\aUNKNOWN\x10\x00\x12\v\n" + "\aPENDING\x10\x01\x12\x10\n" + diff --git a/client/api/omni/specs/omni.proto b/client/api/omni/specs/omni.proto index 1a1f1fe0..13bfe7ef 100644 --- a/client/api/omni/specs/omni.proto +++ b/client/api/omni/specs/omni.proto @@ -1446,6 +1446,7 @@ message ClusterMachineRequestStatusSpec { string machine_uuid = 2; string provider_id = 3; Stage stage = 4; + string error = 5; } // InfraMachineConfigSpec is the spec of the user-owned InfraMachineConfig resource. diff --git a/client/api/omni/specs/omni_vtproto.pb.go b/client/api/omni/specs/omni_vtproto.pb.go index f57022be..6b68db1f 100644 --- a/client/api/omni/specs/omni_vtproto.pb.go +++ b/client/api/omni/specs/omni_vtproto.pb.go @@ -2694,6 +2694,7 @@ func (m *ClusterMachineRequestStatusSpec) CloneVT() *ClusterMachineRequestStatus r.MachineUuid = m.MachineUuid r.ProviderId = m.ProviderId r.Stage = m.Stage + r.Error = m.Error if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -6928,6 +6929,9 @@ func (this *ClusterMachineRequestStatusSpec) EqualVT(that *ClusterMachineRequest if this.Stage != that.Stage { return false } + if this.Error != that.Error { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -14868,6 +14872,13 @@ func (m *ClusterMachineRequestStatusSpec) MarshalToSizedBufferVT(dAtA []byte) (i i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.Error) > 0 { + i -= len(m.Error) + copy(dAtA[i:], m.Error) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Error))) + i-- + dAtA[i] = 0x2a + } if m.Stage != 0 { i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Stage)) i-- @@ -19172,6 +19183,10 @@ func (m *ClusterMachineRequestStatusSpec) SizeVT() (n int) { if m.Stage != 0 { n += 1 + protohelpers.SizeOfVarint(uint64(m.Stage)) } + l = len(m.Error) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -38089,6 +38104,38 @@ func (m *ClusterMachineRequestStatusSpec) UnmarshalVT(dAtA []byte) error { break } } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/client/pkg/infra/controllers/provision.go b/client/pkg/infra/controllers/provision.go index f824c864..f391806f 100644 --- a/client/pkg/infra/controllers/provision.go +++ b/client/pkg/infra/controllers/provision.go @@ -6,6 +6,7 @@ package controllers import ( "context" + "errors" "fmt" "slices" "strings" @@ -18,7 +19,6 @@ import ( "github.com/cosi-project/runtime/pkg/safe" "github.com/cosi-project/runtime/pkg/state" "github.com/siderolabs/gen/optional" - "github.com/siderolabs/gen/xerrors" "go.uber.org/zap" "github.com/siderolabs/omni/client/api/omni/specs" @@ -316,10 +316,13 @@ func (ctrl *ProvisionController[T]) reconcileRunning(ctx context.Context, r cont machineRequestStatus.TypedSpec().Value.Error = "" machineRequestStatus.TypedSpec().Value.Stage = specs.MachineRequestStatusSpec_PROVISIONING + // Pass a copy to the step so mutations beyond Id and LabelMachineInfraID don't leak into our state. + mrsCopy := machineRequestStatus.DeepCopy().(*infra.MachineRequestStatus) //nolint:forcetypeassert,errcheck + if err = safe.WriterModify(ctx, r, res.(T), func(st T) error { //nolint:forcetypeassert,errcheck err = step.Run(ctx, logger, provision.NewContext( machineRequest, - machineRequestStatus, + mrsCopy, st, connectionParams, ctrl.imageFactory, @@ -329,7 +332,7 @@ func (ctrl *ProvisionController[T]) reconcileRunning(ctx context.Context, r cont st.Metadata().Annotations().Set(currentStepAnnotation, step.Name()) if err != nil { - if !xerrors.TypeIs[*controller.RequeueError](err) { + if _, ok := errors.AsType[*controller.RequeueError](err); !ok { //nolint:errcheck return err } @@ -340,8 +343,16 @@ func (ctrl *ProvisionController[T]) reconcileRunning(ctx context.Context, r cont }); err != nil { logger.Error("machine provision failed", zap.Error(err), zap.String("step", step.Name())) - machineRequestStatus.TypedSpec().Value.Error = err.Error() - machineRequestStatus.TypedSpec().Value.Stage = specs.MachineRequestStatusSpec_FAILED + if writeErr := safe.WriterModify(ctx, r, machineRequestStatus, func(res *infra.MachineRequestStatus) error { + applyStepMutations(res, mrsCopy) + + res.TypedSpec().Value.Error = err.Error() + res.TypedSpec().Value.Stage = specs.MachineRequestStatusSpec_FAILED + + return nil + }); writeErr != nil { + return writeErr + } return controller.NewRequeueError(err, time.Minute) } @@ -351,6 +362,12 @@ func (ctrl *ProvisionController[T]) reconcileRunning(ctx context.Context, r cont ) error { res.TypedSpec().Value = machineRequestStatus.TypedSpec().Value + applyStepMutations(res, mrsCopy) + + if reqErr, ok := errors.AsType[*controller.RequeueError](requeueError); ok && reqErr.Err() != nil { + res.TypedSpec().Value.Error = reqErr.Err().Error() + } + return nil }); err != nil { return err @@ -377,6 +394,15 @@ func (ctrl *ProvisionController[T]) reconcileRunning(ctx context.Context, r cont return nil } +// applyStepMutations copies the fields a provision step is allowed to set on MachineRequestStatus from src to dst. +func applyStepMutations(dst, src *infra.MachineRequestStatus) { + dst.TypedSpec().Value.Id = src.TypedSpec().Value.Id + + if infraID, ok := src.Metadata().Labels().Get(omni.LabelMachineInfraID); ok { + dst.Metadata().Labels().Set(omni.LabelMachineInfraID, infraID) + } +} + func (ctrl *ProvisionController[T]) removePatches(ctx context.Context, r controller.QRuntime, requestID string) (bool, error) { patches, err := safe.ReaderListAll[*infra.ConfigPatchRequest](ctx, r, state.WithLabelQuery( resource.LabelEqual(omni.LabelInfraProviderID, ctrl.providerID), diff --git a/client/pkg/infra/infra_test.go b/client/pkg/infra/infra_test.go index ea78700a..d7031111 100644 --- a/client/pkg/infra/infra_test.go +++ b/client/pkg/infra/infra_test.go @@ -13,6 +13,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "testing" "time" @@ -295,21 +296,9 @@ func TestInfra(t *testing.T) { ch: provisionChannel, } - state := setupInfra(ctx, t, p, tt.options...) + st := setupInfra(ctx, t, p, tt.options...) - providerJoinConfig := siderolink.NewProviderJoinConfig(providerID) - providerJoinConfig.TypedSpec().Value.JoinToken = "abcd" - - providerJoinConfig.Metadata().Labels().Set(omni.LabelInfraProviderID, providerID) - - require.NoError(t, state.Create(ctx, providerJoinConfig)) - - siderolinkAPIConfig := siderolink.NewAPIConfig() - siderolinkAPIConfig.TypedSpec().Value.MachineApiAdvertisedUrl = "http://127.0.0.1:8099" - siderolinkAPIConfig.TypedSpec().Value.LogsPort = 8092 - siderolinkAPIConfig.TypedSpec().Value.EventsPort = 8091 - - require.NoError(t, state.Create(ctx, siderolinkAPIConfig)) + createSiderolinkConfigs(ctx, t, st) customLabel := "custom" customValue := "hello" @@ -320,9 +309,9 @@ func TestInfra(t *testing.T) { patchID := machineRequest.Metadata().ID() - require.NoError(t, state.Create(ctx, machineRequest)) + require.NoError(t, st.Create(ctx, machineRequest)) - rtestutils.AssertResources(ctx, t, state, []string{machineRequest.Metadata().ID()}, func(machineRequestStatus *infrares.MachineRequestStatus, assert *assert.Assertions) { + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(machineRequestStatus *infrares.MachineRequestStatus, assert *assert.Assertions) { val, ok := machineRequestStatus.Metadata().Labels().Get(omni.LabelInfraProviderID) assert.True(ok) @@ -337,36 +326,36 @@ func TestInfra(t *testing.T) { require.True(t, channel.SendWithContext(ctx, provisionChannel, struct{}{})) - rtestutils.AssertResources(ctx, t, state, []string{machineRequest.Metadata().ID()}, func(machineRequestStatus *infrares.MachineRequestStatus, assert *assert.Assertions) { + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(machineRequestStatus *infrares.MachineRequestStatus, assert *assert.Assertions) { assert.Equal(specs.MachineRequestStatusSpec_PROVISIONED, machineRequestStatus.TypedSpec().Value.Stage) }) - rtestutils.AssertResources(ctx, t, state, []string{patchID}, func(r *infrares.ConfigPatchRequest, assert *assert.Assertions) { + rtestutils.AssertResources(ctx, t, st, []string{patchID}, func(r *infrares.ConfigPatchRequest, assert *assert.Assertions) { data, err := r.TypedSpec().Value.GetUncompressedData() assert.NoError(err) assert.EqualValues([]byte("machine: {}"), data.Data()) }) - rtestutils.AssertResources(ctx, t, state, []string{machineRequest.Metadata().ID()}, func(testResource *TestResource, assert *assert.Assertions) { + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(testResource *TestResource, assert *assert.Assertions) { assert.True(testResource.TypedSpec().Value.Connected) }) require.NotNil(t, p.getMachine(machineRequest.Metadata().ID())) - rtestutils.Destroy[*infrares.MachineRequest](ctx, t, state, []string{machineRequest.Metadata().ID()}) + rtestutils.Destroy[*infrares.MachineRequest](ctx, t, st, []string{machineRequest.Metadata().ID()}) - rtestutils.AssertNoResource[*infrares.MachineRequestStatus](ctx, t, state, machineRequest.Metadata().ID()) - rtestutils.AssertNoResource[*TestResource](ctx, t, state, machineRequest.Metadata().ID()) + rtestutils.AssertNoResource[*infrares.MachineRequestStatus](ctx, t, st, machineRequest.Metadata().ID()) + rtestutils.AssertNoResource[*TestResource](ctx, t, st, machineRequest.Metadata().ID()) require.Nil(t, p.getMachine(machineRequest.Metadata().ID())) - rtestutils.AssertNoResource[*infrares.ConfigPatchRequest](ctx, t, state, patchID) + rtestutils.AssertNoResource[*infrares.ConfigPatchRequest](ctx, t, st, patchID) }) } } -func setupInfra(ctx context.Context, t *testing.T, p *provisioner, opts ...infra.Option) state.State { +func setupInfra(ctx context.Context, t *testing.T, p provision.Provisioner[*TestResource], opts ...infra.Option) state.State { state := state.WrapCore(namespaced.NewState(inmem.Build)) resourceRegistry := registry.NewResourceRegistry(state) @@ -414,3 +403,163 @@ func setupInfra(ctx context.Context, t *testing.T, p *provisioner, opts ...infra return state } + +func createSiderolinkConfigs(ctx context.Context, t *testing.T, st state.State) { + providerJoinConfig := siderolink.NewProviderJoinConfig(providerID) + providerJoinConfig.TypedSpec().Value.JoinToken = "abcd" + providerJoinConfig.Metadata().Labels().Set(omni.LabelInfraProviderID, providerID) + + require.NoError(t, st.Create(ctx, providerJoinConfig)) + + siderolinkAPIConfig := siderolink.NewAPIConfig() + siderolinkAPIConfig.TypedSpec().Value.MachineApiAdvertisedUrl = "http://127.0.0.1:8099" + siderolinkAPIConfig.TypedSpec().Value.LogsPort = 8092 + siderolinkAPIConfig.TypedSpec().Value.EventsPort = 8091 + + require.NoError(t, st.Create(ctx, siderolinkAPIConfig)) +} + +// stepProvisioner is a Provisioner whose ProvisionSteps are configurable per test. +type stepProvisioner struct { + steps []provision.Step[*TestResource] +} + +func (p *stepProvisioner) ProvisionSteps() []provision.Step[*TestResource] { + return p.steps +} + +func (p *stepProvisioner) Deprovision(context.Context, *zap.Logger, *TestResource, *infrares.MachineRequest) error { + return nil +} + +func TestProvisionStepFailurePersistsError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + t.Cleanup(cancel) + + p := &stepProvisioner{ + steps: []provision.Step[*TestResource]{ + provision.NewStep("fail", func(context.Context, *zap.Logger, provision.Context[*TestResource]) error { + return errors.New("permanent failure") + }), + }, + } + + st := setupInfra(ctx, t, p) + createSiderolinkConfigs(ctx, t, st) + + machineRequest := infrares.NewMachineRequest("fail-test") + machineRequest.Metadata().Labels().Set(omni.LabelInfraProviderID, providerID) + + require.NoError(t, st.Create(ctx, machineRequest)) + + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(mrs *infrares.MachineRequestStatus, assert *assert.Assertions) { + assert.Equal(specs.MachineRequestStatusSpec_FAILED, mrs.TypedSpec().Value.Stage) + assert.Equal("permanent failure", mrs.TypedSpec().Value.Error) + }) +} + +func TestProvisionStepRequeuePersistsError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + t.Cleanup(cancel) + + var allowSuccess atomic.Bool + + p := &stepProvisioner{ + steps: []provision.Step[*TestResource]{ + provision.NewStep("retry-then-succeed", func(context.Context, *zap.Logger, provision.Context[*TestResource]) error { + if !allowSuccess.Load() { + return provision.NewRetryErrorf(500*time.Millisecond, "transient failure") + } + + return nil + }), + }, + } + + st := setupInfra(ctx, t, p) + createSiderolinkConfigs(ctx, t, st) + + machineRequest := infrares.NewMachineRequest("requeue-test") + machineRequest.Metadata().Labels().Set(omni.LabelInfraProviderID, providerID) + + require.NoError(t, st.Create(ctx, machineRequest)) + + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(mrs *infrares.MachineRequestStatus, assert *assert.Assertions) { + assert.Equal(specs.MachineRequestStatusSpec_PROVISIONING, mrs.TypedSpec().Value.Stage) + assert.Equal("transient failure", mrs.TypedSpec().Value.Error) + }) + + allowSuccess.Store(true) + + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(mrs *infrares.MachineRequestStatus, assert *assert.Assertions) { + assert.Equal(specs.MachineRequestStatusSpec_PROVISIONED, mrs.TypedSpec().Value.Stage) + assert.Empty(mrs.TypedSpec().Value.Error) + }) +} + +func TestProvisionStepMutationsRestricted(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + t.Cleanup(cancel) + + const ( + allowedUUID = "good-uuid" + allowedInfraID = "good-infra-id" + forbiddenLabel = "evil-label" + ) + + block := make(chan struct{}) + + t.Cleanup(func() { close(block) }) + + p := &stepProvisioner{ + steps: []provision.Step[*TestResource]{ + provision.NewStep("mutate", func(_ context.Context, _ *zap.Logger, pctx provision.Context[*TestResource]) error { + pctx.SetMachineUUID(allowedUUID) + pctx.SetMachineInfraID(allowedInfraID) + + // Direct mutations beyond the two helper methods must NOT propagate. + pctx.MachineRequestStatus.TypedSpec().Value.Status = "evil status" + pctx.MachineRequestStatus.TypedSpec().Value.Stage = specs.MachineRequestStatusSpec_FAILED + pctx.MachineRequestStatus.Metadata().Labels().Set(forbiddenLabel, "yes") + + return nil + }), + provision.NewStep("block", func(ctx context.Context, _ *zap.Logger, _ provision.Context[*TestResource]) error { + select { + case <-block: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }), + }, + } + + st := setupInfra(ctx, t, p) + createSiderolinkConfigs(ctx, t, st) + + machineRequest := infrares.NewMachineRequest("mutate-test") + machineRequest.Metadata().Labels().Set(omni.LabelInfraProviderID, providerID) + + require.NoError(t, st.Create(ctx, machineRequest)) + + rtestutils.AssertResources(ctx, t, st, []string{machineRequest.Metadata().ID()}, func(mrs *infrares.MachineRequestStatus, assert *assert.Assertions) { + assert.Equal(allowedUUID, mrs.TypedSpec().Value.Id) + + infraID, ok := mrs.Metadata().Labels().Get(omni.LabelMachineInfraID) + assert.True(ok) + assert.Equal(allowedInfraID, infraID) + + assert.NotEqual("evil status", mrs.TypedSpec().Value.Status) + assert.Equal(specs.MachineRequestStatusSpec_PROVISIONING, mrs.TypedSpec().Value.Stage) + + _, hasForbidden := mrs.Metadata().Labels().Get(forbiddenLabel) + assert.False(hasForbidden) + }) +} diff --git a/frontend/src/api/omni/specs/omni.pb.ts b/frontend/src/api/omni/specs/omni.pb.ts index dd974d17..0cb81264 100644 --- a/frontend/src/api/omni/specs/omni.pb.ts +++ b/frontend/src/api/omni/specs/omni.pb.ts @@ -981,6 +981,7 @@ export type ClusterMachineRequestStatusSpec = { machine_uuid?: string provider_id?: string stage?: ClusterMachineRequestStatusSpecStage + error?: string } export type InfraMachineConfigSpec = { diff --git a/internal/backend/runtime/omni/controllers/omni/cluster_machine_request_status.go b/internal/backend/runtime/omni/controllers/omni/cluster_machine_request_status.go index f954b483..8c0f1e7a 100644 --- a/internal/backend/runtime/omni/controllers/omni/cluster_machine_request_status.go +++ b/internal/backend/runtime/omni/controllers/omni/cluster_machine_request_status.go @@ -76,13 +76,13 @@ func NewClusterMachineRequestStatusController() *ClusterMachineRequestStatusCont clusterMachineRequestStatus.TypedSpec().Value.MachineUuid = machineRequestStatus.TypedSpec().Value.Id clusterMachineRequestStatus.TypedSpec().Value.Status = machineRequestStatus.TypedSpec().Value.Status + clusterMachineRequestStatus.TypedSpec().Value.Error = machineRequestStatus.TypedSpec().Value.Error switch machineRequestStatus.TypedSpec().Value.Stage { case specs.MachineRequestStatusSpec_UNKNOWN: clusterMachineRequestStatus.TypedSpec().Value.Stage = specs.ClusterMachineRequestStatusSpec_PENDING case specs.MachineRequestStatusSpec_PROVISIONING: clusterMachineRequestStatus.TypedSpec().Value.Stage = specs.ClusterMachineRequestStatusSpec_PROVISIONING - case specs.MachineRequestStatusSpec_PROVISIONED: clusterMachineRequestStatus.TypedSpec().Value.Stage = specs.ClusterMachineRequestStatusSpec_PROVISIONED