From 2e790526f760c890ad892fffd165ac27ab0dd9b4 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Thu, 28 Jul 2022 21:17:57 +0400 Subject: [PATCH] refactor: make apid stop gracefully and be stopped late This fixes apid and machined shutdown sequences to do graceful stop of gRPC server with timeout. Also sequences are restructured to stop apid/machined as late as possible allowing access to the node while the long sequence is running (e.g. upgrade or reset). Signed-off-by: Andrey Smirnov --- internal/app/apid/main.go | 79 ++++++++++++++----- .../runtime/v1alpha1/v1alpha1_sequencer.go | 15 ++-- .../v1alpha1/v1alpha1_sequencer_tasks.go | 6 +- .../machined/pkg/system/services/machined.go | 8 +- pkg/grpc/factory/factory.go | 20 +++++ 5 files changed, 99 insertions(+), 29 deletions(-) diff --git a/internal/app/apid/main.go b/internal/app/apid/main.go index 776580020..a3a25f9c6 100644 --- a/internal/app/apid/main.go +++ b/internal/app/apid/main.go @@ -7,8 +7,12 @@ package apid import ( "context" "flag" + "fmt" "log" + "os/signal" "regexp" + "syscall" + "time" "github.com/cosi-project/runtime/api/v1alpha1" "github.com/cosi-project/runtime/pkg/state" @@ -46,21 +50,30 @@ func runDebugServer(ctx context.Context) { // Main is the entrypoint of apid. func Main() { + if err := apidMain(); err != nil { + log.Fatal(err) + } +} + +func apidMain() error { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer cancel() + log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds | log.Ltime) rbacEnabled = flag.Bool("enable-rbac", false, "enable RBAC for Talos API") flag.Parse() - go runDebugServer(context.TODO()) + go runDebugServer(ctx) if err := startup.RandSeed(); err != nil { - log.Fatalf("failed to seed RNG: %v", err) + return fmt.Errorf("failed to seed RNG: %w", err) } runtimeConn, err := grpc.Dial("unix://"+constants.APIRuntimeSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - log.Fatalf("failed to dial runtime connection: %v", err) + return fmt.Errorf("failed to dial runtime connection: %w", err) } stateClient := v1alpha1.NewStateClient(runtimeConn) @@ -68,17 +81,17 @@ func Main() { tlsConfig, err := provider.NewTLSConfig(resources) if err != nil { - log.Fatalf("failed to create remote certificate provider: %+v", err) + return fmt.Errorf("failed to create remote certificate provider: %w", err) } serverTLSConfig, err := tlsConfig.ServerConfig() if err != nil { - log.Fatalf("failed to create OS-level TLS configuration: %v", err) + return fmt.Errorf("failed to create OS-level TLS configuration: %w", err) } clientTLSConfig, err := tlsConfig.ClientConfig() if err != nil { - log.Fatalf("failed to create client TLS config: %v", err) + return fmt.Errorf("failed to create client TLS config: %w", err) } backendFactory := apidbackend.NewAPIDFactory(clientTLSConfig) @@ -109,9 +122,22 @@ func Main() { // register future pattern: method should have suffix "Stream" router.RegisterStreamedRegex("Stream$") - var errGroup errgroup.Group + networkListener, err := factory.NewListener( + factory.Port(constants.ApidPort), + ) + if err != nil { + return fmt.Errorf("error creating listner: %w", err) + } - errGroup.Go(func() error { + socketListener, err := factory.NewListener( + factory.Network("unix"), + factory.SocketPath(constants.APISocketPath), + ) + if err != nil { + return fmt.Errorf("error creating listner: %w", err) + } + + networkServer := func() *grpc.Server { mode := authz.Disabled if *rbacEnabled { mode = authz.Enabled @@ -122,9 +148,8 @@ func Main() { Logger: log.New(log.Writer(), "apid/authz/injector/http ", log.Flags()).Printf, } - return factory.ListenAndServe( + return factory.NewServer( router, - factory.Port(constants.ApidPort), factory.WithDefaultLog(), factory.ServerOptions( grpc.Creds( @@ -140,18 +165,16 @@ func Main() { factory.WithUnaryInterceptor(injector.UnaryInterceptor()), factory.WithStreamInterceptor(injector.StreamInterceptor()), ) - }) + }() - errGroup.Go(func() error { + socketServer := func() *grpc.Server { injector := &authz.Injector{ Mode: authz.MetadataOnly, Logger: log.New(log.Writer(), "apid/authz/injector/unix ", log.Flags()).Printf, } - return factory.ListenAndServe( + return factory.NewServer( router, - factory.Network("unix"), - factory.SocketPath(constants.APISocketPath), factory.WithDefaultLog(), factory.ServerOptions( grpc.CustomCodec(proxy.Codec()), //nolint:staticcheck @@ -164,9 +187,29 @@ func Main() { factory.WithUnaryInterceptor(injector.UnaryInterceptor()), factory.WithStreamInterceptor(injector.StreamInterceptor()), ) + }() + + errGroup, ctx := errgroup.WithContext(ctx) + + errGroup.Go(func() error { + return networkServer.Serve(networkListener) }) - if err := errGroup.Wait(); err != nil { - log.Fatalf("listen: %v", err) - } + errGroup.Go(func() error { + return socketServer.Serve(socketListener) + }) + + errGroup.Go(func() error { + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownCancel() + + factory.ServerGracefulStop(networkServer, shutdownCtx) + factory.ServerGracefulStop(socketServer, shutdownCtx) + + return nil + }) + + return errGroup.Wait() } diff --git a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go index 1f70fe16e..b947e4789 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer.go @@ -400,7 +400,7 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru LeaveEtcd, ).Append( "stopServices", - StopServicesForUpgrade, + StopServicesEphemeral, ).Append( "unmountUser", UnmountUserDisks, @@ -421,9 +421,6 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru ).Append( "upgrade", Upgrade, - ).Append( - "stopEverything", - StopAllServices, ).Append( "mountBoot", MountBootPartition, @@ -433,6 +430,9 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru ).Append( "unmountBoot", UnmountBootPartition, + ).Append( + "stopEverything", + StopAllServices, ).Append( "reboot", Reboot, @@ -453,8 +453,8 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList { ) default: phases = phases.Append( - "stopEverything", - StopAllServices, + "stopServices", + StopServicesEphemeral, ).Append( "unmountUser", UnmountUserDisks, @@ -481,6 +481,9 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList { enableKexec, "unmountBoot", UnmountBootPartition, + ).Append( + "stopEverything", + StopAllServices, ) } diff --git a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go index 918ef1ead..6e00534a2 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_sequencer_tasks.go @@ -796,11 +796,11 @@ func StartAllServices(seq runtime.Sequence, data interface{}) (runtime.TaskExecu }, "startAllServices" } -// StopServicesForUpgrade represents the StopServicesForUpgrade task. -func StopServicesForUpgrade(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) { +// StopServicesEphemeral represents the StopServicesEphemeral task. +func StopServicesEphemeral(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) { return func(ctx context.Context, logger *log.Logger, r runtime.Runtime) (err error) { // stopping 'cri' service stops everything which depends on it (kubelet, etcd, ...) - return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd") + return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd", "trustd") }, "stopServicesForUpgrade" } diff --git a/internal/app/machined/pkg/system/services/machined.go b/internal/app/machined/pkg/system/services/machined.go index 32b0e6b68..5dda0bae0 100644 --- a/internal/app/machined/pkg/system/services/machined.go +++ b/internal/app/machined/pkg/system/services/machined.go @@ -10,6 +10,7 @@ import ( "log" "os" "path/filepath" + "time" v1alpha1server "github.com/talos-systems/talos/internal/app/machined/internal/server/v1alpha1" "github.com/talos-systems/talos/internal/app/machined/pkg/runtime" @@ -134,8 +135,6 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter return err } - defer server.Stop() - go func() { //nolint:errcheck server.Serve(listener) @@ -143,6 +142,11 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownCancel() + + factory.ServerGracefulStop(server, shutdownCtx) + return nil } diff --git a/pkg/grpc/factory/factory.go b/pkg/grpc/factory/factory.go index ff14b1969..1d41fe9bf 100644 --- a/pkg/grpc/factory/factory.go +++ b/pkg/grpc/factory/factory.go @@ -5,6 +5,7 @@ package factory import ( + "context" "crypto/tls" "errors" "fmt" @@ -257,3 +258,22 @@ func ListenAndServe(r Registrator, setters ...Option) (err error) { return server.Serve(listener) } + +// ServerGracefulStop the server with a timeout. +// +// Core gRPC doesn't support timeouts. +func ServerGracefulStop(server *grpc.Server, shutdownCtx context.Context) { //nolint:revive + stopped := make(chan struct{}) + + go func() { + server.GracefulStop() + close(stopped) + }() + + select { + case <-shutdownCtx.Done(): + server.Stop() + case <-stopped: + server.Stop() + } +}