diff --git a/internal/app/apid/main.go b/internal/app/apid/main.go index 776580020..a3a25f9c6 100644 --- a/internal/app/apid/main.go +++ b/internal/app/apid/main.go @@ -7,8 +7,12 @@ package apid import ( "context" "flag" + "fmt" "log" + "os/signal" "regexp" + "syscall" + "time" "github.com/cosi-project/runtime/api/v1alpha1" "github.com/cosi-project/runtime/pkg/state" @@ -46,21 +50,30 @@ func runDebugServer(ctx context.Context) { // Main is the entrypoint of apid. func Main() { + if err := apidMain(); err != nil { + log.Fatal(err) + } +} + +func apidMain() error { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer cancel() + log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds | log.Ltime) rbacEnabled = flag.Bool("enable-rbac", false, "enable RBAC for Talos API") flag.Parse() - go runDebugServer(context.TODO()) + go runDebugServer(ctx) if err := startup.RandSeed(); err != nil { - log.Fatalf("failed to seed RNG: %v", err) + return fmt.Errorf("failed to seed RNG: %w", err) } runtimeConn, err := grpc.Dial("unix://"+constants.APIRuntimeSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - log.Fatalf("failed to dial runtime connection: %v", err) + return fmt.Errorf("failed to dial runtime connection: %w", err) } stateClient := v1alpha1.NewStateClient(runtimeConn) @@ -68,17 +81,17 @@ func Main() { tlsConfig, err := provider.NewTLSConfig(resources) if err != nil { - log.Fatalf("failed to create remote certificate provider: %+v", err) + return fmt.Errorf("failed to create remote certificate provider: %w", err) } serverTLSConfig, err := tlsConfig.ServerConfig() if err != nil { - log.Fatalf("failed to create OS-level TLS configuration: %v", err) + return fmt.Errorf("failed to create OS-level TLS configuration: %w", err) } clientTLSConfig, err := tlsConfig.ClientConfig() if err != nil { - log.Fatalf("failed to create client TLS config: %v", err) + return fmt.Errorf("failed to create client TLS config: %w", err) } backendFactory := apidbackend.NewAPIDFactory(clientTLSConfig) @@ -109,9 +122,22 @@ func Main() { // register future pattern: method should have suffix "Stream" router.RegisterStreamedRegex("Stream$") - var errGroup errgroup.Group + networkListener, err := factory.NewListener( + factory.Port(constants.ApidPort), + ) + if err != nil { + return fmt.Errorf("error creating listner: %w", err) + } - errGroup.Go(func() error { + socketListener, err := factory.NewListener( + factory.Network("unix"), + factory.SocketPath(constants.APISocketPath), + ) + if err != nil { + return fmt.Errorf("error creating listner: %w", err) + } + + networkServer := func() *grpc.Server { mode := authz.Disabled if *rbacEnabled { mode = authz.Enabled @@ -122,9 +148,8 @@ func Main() { Logger: log.New(log.Writer(), "apid/authz/injector/http ", log.Flags()).Printf, } - return factory.ListenAndServe( + return factory.NewServer( router, - factory.Port(constants.ApidPort), factory.WithDefaultLog(), factory.ServerOptions( grpc.Creds( @@ -140,18 +165,16 @@ func Main() { factory.WithUnaryInterceptor(injector.UnaryInterceptor()), factory.WithStreamInterceptor(injector.StreamInterceptor()), ) - }) + }() - errGroup.Go(func() error { + socketServer := func() *grpc.Server { injector := &authz.Injector{ Mode: authz.MetadataOnly, Logger: log.New(log.Writer(), "apid/authz/injector/unix ", log.Flags()).Printf, } - return factory.ListenAndServe( + return factory.NewServer( router, - factory.Network("unix"), - factory.SocketPath(constants.APISocketPath), factory.WithDefaultLog(), factory.ServerOptions( grpc.CustomCodec(proxy.Codec()), //nolint:staticcheck @@ -164,9 +187,29 @@ func Main() { factory.WithUnaryInterceptor(injector.UnaryInterceptor()), factory.WithStreamInterceptor(injector.StreamInterceptor()), ) + }() + + errGroup, ctx := errgroup.WithContext(ctx) + + errGroup.Go(func() error { + return networkServer.Serve(networkListener) }) - if err := errGroup.Wait(); err != nil { - log.Fatalf("listen: %v", err) - } + errGroup.Go(func() error { + return socketServer.Serve(socketListener) + }) + + errGroup.Go(func() error { + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownCancel() + + factory.ServerGracefulStop(networkServer, shutdownCtx) + factory.ServerGracefulStop(socketServer, shutdownCtx) + + return nil + }) + + return errGroup.Wait() } diff --git a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go index 1f70fe16e..b947e4789 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go @@ -400,7 +400,7 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru LeaveEtcd, ).Append( "stopServices", - StopServicesForUpgrade, + StopServicesEphemeral, ).Append( "unmountUser", UnmountUserDisks, @@ -421,9 +421,6 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru ).Append( "upgrade", Upgrade, - ).Append( - "stopEverything", - StopAllServices, ).Append( "mountBoot", MountBootPartition, @@ -433,6 +430,9 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru ).Append( "unmountBoot", UnmountBootPartition, + ).Append( + "stopEverything", + StopAllServices, ).Append( "reboot", Reboot, @@ -453,8 +453,8 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList { ) default: phases = phases.Append( - "stopEverything", - StopAllServices, + "stopServices", + StopServicesEphemeral, ).Append( "unmountUser", UnmountUserDisks, @@ -481,6 +481,9 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList { enableKexec, "unmountBoot", UnmountBootPartition, + ).Append( + "stopEverything", + StopAllServices, ) } diff --git a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go index 918ef1ead..6e00534a2 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go @@ -796,11 +796,11 @@ func StartAllServices(seq runtime.Sequence, data interface{}) (runtime.TaskExecu }, "startAllServices" } -// StopServicesForUpgrade represents the StopServicesForUpgrade task. -func StopServicesForUpgrade(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) { +// StopServicesEphemeral represents the StopServicesEphemeral task. +func StopServicesEphemeral(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) { return func(ctx context.Context, logger *log.Logger, r runtime.Runtime) (err error) { // stopping 'cri' service stops everything which depends on it (kubelet, etcd, ...) - return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd") + return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd", "trustd") }, "stopServicesForUpgrade" } diff --git a/internal/app/machined/pkg/system/services/machined.go b/internal/app/machined/pkg/system/services/machined.go index 32b0e6b68..5dda0bae0 100644 --- a/internal/app/machined/pkg/system/services/machined.go +++ b/internal/app/machined/pkg/system/services/machined.go @@ -10,6 +10,7 @@ import ( "log" "os" "path/filepath" + "time" v1alpha1server "github.com/talos-systems/talos/internal/app/machined/internal/server/v1alpha1" "github.com/talos-systems/talos/internal/app/machined/pkg/runtime" @@ -134,8 +135,6 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter return err } - defer server.Stop() - go func() { //nolint:errcheck server.Serve(listener) @@ -143,6 +142,11 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownCancel() + + factory.ServerGracefulStop(server, shutdownCtx) + return nil } diff --git a/pkg/grpc/factory/factory.go b/pkg/grpc/factory/factory.go index ff14b1969..1d41fe9bf 100644 --- a/pkg/grpc/factory/factory.go +++ b/pkg/grpc/factory/factory.go @@ -5,6 +5,7 @@ package factory import ( + "context" "crypto/tls" "errors" "fmt" @@ -257,3 +258,22 @@ func ListenAndServe(r Registrator, setters ...Option) (err error) { return server.Serve(listener) } + +// ServerGracefulStop the server with a timeout. +// +// Core gRPC doesn't support timeouts. +func ServerGracefulStop(server *grpc.Server, shutdownCtx context.Context) { //nolint:revive + stopped := make(chan struct{}) + + go func() { + server.GracefulStop() + close(stopped) + }() + + select { + case <-shutdownCtx.Done(): + server.Stop() + case <-stopped: + server.Stop() + } +}