diff --git a/command/server.go b/command/server.go index 19f97b9a74..df05bc59a7 100644 --- a/command/server.go +++ b/command/server.go @@ -13,6 +13,7 @@ import ( "sort" "strconv" "strings" + "sync" "syscall" "time" @@ -43,6 +44,8 @@ type ServerCommand struct { ShutdownCh chan struct{} SighupCh chan struct{} + WaitGroup *sync.WaitGroup + meta.Meta logger *log.Logger @@ -308,31 +311,6 @@ func (c *ServerCommand) Run(args []string) int { } } - // If the backend supports service discovery, run service discovery - if coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled() { - sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery) - if ok { - activeFunc := func() bool { - if isLeader, _, err := core.Leader(); err == nil { - return isLeader - } - return false - } - - sealedFunc := func() bool { - if sealed, err := core.Sealed(); err == nil { - return sealed - } - return true - } - - if err := sd.RunServiceDiscovery(c.ShutdownCh, coreConfig.AdvertiseAddr, activeFunc, sealedFunc); err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) - return 1 - } - } - } - // Initialize the listeners lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { @@ -392,6 +370,37 @@ func (c *ServerCommand) Run(args []string) int { return 0 } + // Perform service discovery registrations and initialization of + // HTTP server after the verifyOnly check. + + // Instantiate the wait group + c.WaitGroup = &sync.WaitGroup{} + + // If the backend supports service discovery, run service discovery + if coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled() { + sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery) + if ok { + activeFunc := func() bool { + if isLeader, _, err := core.Leader(); err == nil { + return isLeader + } + return false + } + + sealedFunc := func() bool { + if sealed, err := core.Sealed(); err == nil { + return sealed + } + return true + } + + if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.AdvertiseAddr, activeFunc, sealedFunc); err != nil { + c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) + return 1 + } + } + } + // Initialize the HTTP server server := &http.Server{} server.Handler = vaulthttp.Handler(core) @@ -412,6 +421,7 @@ func (c *ServerCommand) Run(args []string) int { // Wait for shutdown shutdownTriggered := false + for !shutdownTriggered { select { case <-c.ShutdownCh: @@ -428,6 +438,8 @@ func (c *ServerCommand) Run(args []string) int { } } + // Wait for dependent goroutines to complete + c.WaitGroup.Wait() return 0 } @@ -746,10 +758,8 @@ func MakeShutdownCh() chan struct{} { shutdownCh := make(chan os.Signal, 4) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { - for { - <-shutdownCh - resultCh <- struct{}{} - } + <-shutdownCh + close(resultCh) }() return resultCh } diff --git a/physical/consul.go b/physical/consul.go index dc10e5741a..5f88a31941 100644 --- a/physical/consul.go +++ b/physical/consul.go @@ -416,17 +416,23 @@ func (c *ConsulBackend) checkDuration() time.Duration { return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) } -func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { +func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { if err := c.setAdvertiseAddr(advertiseAddr); err != nil { return err } - go c.runEventDemuxer(shutdownCh, advertiseAddr, activeFunc, sealedFunc) + // 'server' command will wait for the below goroutine to complete + waitGroup.Add(1) + + go c.runEventDemuxer(waitGroup, shutdownCh, advertiseAddr, activeFunc, sealedFunc) return nil } -func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { +func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { + // This defer statement should be executed last. So push it first. + defer waitGroup.Done() + // Fire the reconcileTimer immediately upon starting the event demuxer reconcileTimer := time.NewTimer(0) defer reconcileTimer.Stop() @@ -450,8 +456,8 @@ func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAdd var checkLock int64 var registeredServiceID string var serviceRegLock int64 -shutdown: - for { + + for !shutdown { select { case <-c.notifyActiveCh: // Run reconcile immediately upon active state change notification @@ -507,7 +513,6 @@ shutdown: case <-shutdownCh: c.logger.Printf("[INFO]: physical/consul: Shutting down consul backend") shutdown = true - break shutdown } } diff --git a/physical/consul_test.go b/physical/consul_test.go index 0687027e60..f382caaa63 100644 --- a/physical/consul_test.go +++ b/physical/consul_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "os" "reflect" + "sync" "testing" "time" @@ -195,7 +196,8 @@ func TestConsul_newConsulBackend(t *testing.T) { } var shutdownCh ShutdownChannel - if err := c.RunServiceDiscovery(shutdownCh, test.advertiseAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil { + waitGroup := &sync.WaitGroup{} + if err := c.RunServiceDiscovery(waitGroup, shutdownCh, test.advertiseAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil { t.Fatalf("bad: %v", err) } diff --git a/physical/physical.go b/physical/physical.go index ff74c9827d..9e96beb6d8 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -3,6 +3,7 @@ package physical import ( "fmt" "log" + "sync" ) const DefaultParallelOperations = 128 @@ -71,7 +72,7 @@ type ServiceDiscovery interface { // Run executes any background service discovery tasks until the // shutdown channel is closed. - RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error + RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error } type Lock interface {