From ad62b32ff051bcee9b2d3df71c1a91d19986d4cf Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 30 Sep 2016 00:06:40 -0400 Subject: [PATCH] Rejig where the reload functions live --- cli/commands.go | 6 ++---- command/server.go | 18 ++++++++++++++---- command/server/config.go | 3 --- command/server/listener.go | 7 ++++--- command/server/listener_atlas.go | 3 ++- command/server/listener_tcp.go | 4 +++- command/server_test.go | 6 ++---- vault/core.go | 20 ++++++++++++++++++++ 8 files changed, 47 insertions(+), 20 deletions(-) diff --git a/cli/commands.go b/cli/commands.go index afd8611221..5297ea43bf 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -5,7 +5,6 @@ import ( auditFile "github.com/hashicorp/vault/builtin/audit/file" auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" - "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/version" credAppId "github.com/hashicorp/vault/builtin/credential/app-id" @@ -87,9 +86,8 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { "ssh": ssh.Factory, "rabbitmq": rabbitmq.Factory, }, - ShutdownCh: command.MakeShutdownCh(), - SighupCh: command.MakeSighupCh(), - ReloadFuncs: map[string][]server.ReloadFunc{}, + ShutdownCh: command.MakeShutdownCh(), + SighupCh: command.MakeSighupCh(), }, nil }, diff --git a/command/server.go b/command/server.go index 1274027b6b..3c35d40e6d 100644 --- a/command/server.go +++ b/command/server.go @@ -54,7 +54,8 @@ type ServerCommand struct { logger log.Logger - ReloadFuncs map[string][]server.ReloadFunc + reloadFuncsLock *sync.RWMutex + reloadFuncs *map[string][]vault.ReloadFunc } func (c *ServerCommand) Run(args []string) int { @@ -338,6 +339,10 @@ func (c *ServerCommand) Run(args []string) int { } } + // Copy the reload funcs pointers back + c.reloadFuncs = coreConfig.ReloadFuncs + c.reloadFuncsLock = coreConfig.ReloadFuncsLock + // Compile server information for output later info["backend"] = config.Backend.Type info["log level"] = logLevel @@ -374,6 +379,7 @@ func (c *ServerCommand) Run(args []string) int { clusterAddrs := []*net.TCPAddr{} // Initialize the listeners + c.reloadFuncsLock.Lock() lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { if lnConfig.Type == "atlas" { @@ -396,9 +402,9 @@ func (c *ServerCommand) Run(args []string) int { lns = append(lns, ln) if reloadFunc != nil { - relSlice := c.ReloadFuncs["listener|"+lnConfig.Type] + relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type] relSlice = append(relSlice, reloadFunc) - c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice + (*c.reloadFuncs)["listener|"+lnConfig.Type] = relSlice } if !disableClustering && lnConfig.Type == "tcp" { @@ -440,6 +446,7 @@ func (c *ServerCommand) Run(args []string) int { "%s (%s)", lnConfig.Type, strings.Join(propsList, ", ")) } + c.reloadFuncsLock.Unlock() if !disableClustering { if c.logger.IsTrace() { c.logger.Trace("cluster listener addresses synthesized", "cluster_addresses", clusterAddrs) @@ -855,11 +862,14 @@ func (c *ServerCommand) Reload(configPath []string) error { return retErr } + c.reloadFuncsLock.RLock() + defer c.reloadFuncsLock.RUnlock() + var reloadErrors *multierror.Error // Call reload on the listeners. This will call each listener with each // config block, but they verify the address. for _, lnConfig := range config.Listeners { - for _, relFunc := range c.ReloadFuncs["listener|"+lnConfig.Type] { + for _, relFunc := range (*c.reloadFuncs)["listener|"+lnConfig.Type] { if err := relFunc(lnConfig.Config); err != nil { retErr := fmt.Errorf("Error encountered reloading configuration: %s", err) reloadErrors = multierror.Append(retErr) diff --git a/command/server/config.go b/command/server/config.go index 9830128542..54321f889c 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -17,9 +17,6 @@ import ( "github.com/hashicorp/hcl/hcl/ast" ) -// ReloadFunc are functions that are called when a reload is requested. -type ReloadFunc func(map[string]string) error - // Config is the configuration for the vault server. type Config struct { Listeners []*Listener `hcl:"-"` diff --git a/command/server/listener.go b/command/server/listener.go index fd5fecbc71..2aa731d252 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -12,10 +12,11 @@ import ( "sync" "github.com/hashicorp/vault/helper/tlsutil" + "github.com/hashicorp/vault/vault" ) // ListenerFactory is the factory function to create a listener. -type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error) +type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) // BuiltinListeners is the list of built-in listener types. var BuiltinListeners = map[string]ListenerFactory{ @@ -25,7 +26,7 @@ var BuiltinListeners = map[string]ListenerFactory{ // NewListener creates a new listener of the given type with the given // configuration. The type is looked up in the BuiltinListeners map. -func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) { +func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) { f, ok := BuiltinListeners[t] if !ok { return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t) @@ -37,7 +38,7 @@ func NewListener(t string, config map[string]string, logger io.Writer) (net.List func listenerWrapTLS( ln net.Listener, props map[string]string, - config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { + config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) { props["tls"] = "disabled" if v, ok := config["tls_disable"]; ok { diff --git a/command/server/listener_atlas.go b/command/server/listener_atlas.go index 749cc3333b..c000474e7a 100644 --- a/command/server/listener_atlas.go +++ b/command/server/listener_atlas.go @@ -5,6 +5,7 @@ import ( "net" "github.com/hashicorp/scada-client/scada" + "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/version" ) @@ -26,7 +27,7 @@ func (s *SCADAListener) Addr() net.Addr { return s.ln.Addr() } -func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) { +func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) { scadaConfig := &scada.Config{ Service: "vault", Version: version.GetVersion().VersionNumber(), diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index c351263384..9435c25103 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -4,9 +4,11 @@ import ( "io" "net" "time" + + "github.com/hashicorp/vault/vault" ) -func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) { +func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) { addr, ok := config["address"] if !ok { addr = "127.0.0.1:8200" diff --git a/command/server_test.go b/command/server_test.go index 24c8deed58..c78ee6b5b0 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -14,7 +14,6 @@ import ( "testing" "time" - "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/meta" "github.com/mitchellh/cli" ) @@ -183,9 +182,8 @@ func TestServer_ReloadListener(t *testing.T) { Meta: meta.Meta{ Ui: ui, }, - ShutdownCh: MakeShutdownCh(), - SighupCh: MakeSighupCh(), - ReloadFuncs: map[string][]server.ReloadFunc{}, + ShutdownCh: MakeShutdownCh(), + SighupCh: MakeSighupCh(), } finished := false diff --git a/vault/core.go b/vault/core.go index 99c4fd5f81..21c152a3e2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -89,6 +89,9 @@ var ( manualStepDownSleepPeriod = 10 * time.Second ) +// ReloadFunc are functions that are called when a reload is requested. +type ReloadFunc func(map[string]string) error + // NonFatalError is an error that can be returned during NewCore that should be // displayed but not cause a program exit type NonFatalError struct { @@ -242,6 +245,12 @@ type Core struct { // cachingDisabled indicates whether caches are disabled cachingDisabled bool + // reloadFuncs is a map containing reload functions + reloadFuncs map[string][]ReloadFunc + + // reloadFuncsLock controlls access to the funcs + reloadFuncsLock sync.RWMutex + // // Cluster information // @@ -322,6 +331,9 @@ type CoreConfig struct { MaxLeaseTTL time.Duration `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` ClusterName string `json:"cluster_name" structs:"cluster_name" mapstructure:"cluster_name"` + + ReloadFuncs *map[string][]ReloadFunc + ReloadFuncsLock *sync.RWMutex } // NewCore is used to construct a new core @@ -415,6 +427,14 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.ha = conf.HAPhysical } + // We create the funcs here, then populate the given config with it so that + // the caller can share state + conf.ReloadFuncsLock = &c.reloadFuncsLock + c.reloadFuncsLock.Lock() + c.reloadFuncs = make(map[string][]ReloadFunc) + c.reloadFuncsLock.Unlock() + conf.ReloadFuncs = &c.reloadFuncs + // Setup the backends logicalBackends := make(map[string]logical.Factory) for k, f := range conf.LogicalBackends {