From be7c31f695923b4d6f7804acb413e848ea0da3aa Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Thu, 21 Sep 2017 20:51:12 -0400 Subject: [PATCH] Fix bad rebase Apparently I can't git... --- command/format.go | 2 +- command/operator_unseal_test.go | 20 +- command/path_help_test.go | 4 +- command/server.go | 391 ++++++++++++++------------------ command/server_ha_test.go | 106 --------- command/server_test.go | 204 ++++++++++++----- command/status_test.go | 4 +- 7 files changed, 328 insertions(+), 403 deletions(-) delete mode 100644 command/server_ha_test.go diff --git a/command/format.go b/command/format.go index b8669303b8..a23268e46f 100644 --- a/command/format.go +++ b/command/format.go @@ -250,6 +250,6 @@ func OutputSealStatus(ui cli.Ui, client *api.Client, status *api.SealStatusRespo } } - ui.Output(columnOutput(out, nil)) + ui.Output(tableOutput(out, nil)) return 0 } diff --git a/command/operator_unseal_test.go b/command/operator_unseal_test.go index 9091ae40e2..e2222fc73b 100644 --- a/command/operator_unseal_test.go +++ b/command/operator_unseal_test.go @@ -1,7 +1,6 @@ package command import ( - "fmt" "io/ioutil" "strings" "testing" @@ -68,7 +67,7 @@ func TestOperatorUnsealCommand_Run(t *testing.T) { if exp := 0; code != exp { t.Errorf("expected %d to be %d", code, exp) } - expected := "Unseal Progress: 0" + expected := "0/3" combined := ui.OutputWriter.String() + ui.ErrorWriter.String() if !strings.Contains(combined, expected) { t.Errorf("expected %q to contain %q", combined, expected) @@ -86,7 +85,7 @@ func TestOperatorUnsealCommand_Run(t *testing.T) { t.Fatal(err) } - for i, key := range keys { + for _, key := range keys { ui, cmd := testOperatorUnsealCommand(t) cmd.client = client cmd.testOutput = ioutil.Discard @@ -96,14 +95,17 @@ func TestOperatorUnsealCommand_Run(t *testing.T) { key, }) if exp := 0; code != exp { - t.Errorf("expected %d to be %d", code, exp) - } - expected := fmt.Sprintf("Unseal Progress: %d", (i+1)%3) // 1, 2, 0 - combined := ui.OutputWriter.String() + ui.ErrorWriter.String() - if !strings.Contains(combined, expected) { - t.Errorf("expected %q to contain %q", combined, expected) + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) } } + + status, err := client.Sys().SealStatus() + if err != nil { + t.Fatal(err) + } + if status.Sealed { + t.Error("expected unsealed") + } }) t.Run("communication_failure", func(t *testing.T) { diff --git a/command/path_help_test.go b/command/path_help_test.go index 4e788df130..688bcf09ce 100644 --- a/command/path_help_test.go +++ b/command/path_help_test.go @@ -46,9 +46,9 @@ func TestPathHelpCommand_Run(t *testing.T) { 2, }, { - "generic", + "kv", []string{"secret/"}, - "The generic backend", + "The kv backend", 0, }, { diff --git a/command/server.go b/command/server.go index 180a2fae66..13c444dd73 100644 --- a/command/server.go +++ b/command/server.go @@ -33,7 +33,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" - "github.com/hashicorp/vault/helper/flag-slice" "github.com/hashicorp/vault/helper/gated-writer" "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/mlock" @@ -41,7 +40,6 @@ import ( "github.com/hashicorp/vault/helper/reload" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/version" @@ -51,6 +49,8 @@ var _ cli.Command = (*ServerCommand)(nil) var _ cli.CommandAutocomplete = (*ServerCommand)(nil) type ServerCommand struct { + *BaseCommand + AuditBackends map[string]audit.Factory CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory @@ -61,8 +61,6 @@ type ServerCommand struct { WaitGroup *sync.WaitGroup - meta.Meta - logGate *gatedwriter.Writer logger log.Logger @@ -84,9 +82,10 @@ type ServerCommand struct { flagDevHA bool flagDevLatency int flagDevLatencyJitter int - flagDevTransactional bool flagDevLeasedKV bool + flagDevSkipInit bool flagDevThreeNode bool + flagDevTransactional bool flagTestVerifyOnly bool } @@ -223,6 +222,13 @@ func (c *ServerCommand) Flags() *FlagSets { Hidden: true, }) + f.BoolVar(&BoolVar{ + Name: "dev-skip-init", + Target: &c.flagDevSkipInit, + Default: false, + Hidden: true, + }) + f.BoolVar(&BoolVar{ Name: "dev-three-node", Target: &c.flagDevThreeNode, @@ -252,27 +258,10 @@ func (c *ServerCommand) AutocompleteFlags() complete.Flags { } func (c *ServerCommand) Run(args []string) int { - var dev, verifyOnly, devHA, devTransactional, devLeasedKV, devThreeNode, devSkipInit bool - var configPath []string - var logLevel, devRootTokenID, devListenAddress, devPluginDir string - var devLatency, devLatencyJitter int - flags := c.Meta.FlagSet("server", meta.FlagSetDefault) - flags.BoolVar(&dev, "dev", false, "") - flags.StringVar(&devRootTokenID, "dev-root-token-id", "", "") - flags.StringVar(&devListenAddress, "dev-listen-address", "", "") - flags.StringVar(&devPluginDir, "dev-plugin-dir", "", "") - flags.StringVar(&logLevel, "log-level", "info", "") - flags.IntVar(&devLatency, "dev-latency", 0, "") - flags.IntVar(&devLatencyJitter, "dev-latency-jitter", 20, "") - flags.BoolVar(&verifyOnly, "verify-only", false, "") - flags.BoolVar(&devHA, "dev-ha", false, "") - flags.BoolVar(&devTransactional, "dev-transactional", false, "") - flags.BoolVar(&devLeasedKV, "dev-leased-kv", false, "") - flags.BoolVar(&devThreeNode, "dev-three-node", false, "") - flags.BoolVar(&devSkipInit, "dev-skip-init", false, "") - flags.Usage = func() { c.Ui.Output(c.Help()) } - flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config") - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } @@ -280,8 +269,8 @@ func (c *ServerCommand) Run(args []string) int { // start logging too early. c.logGate = &gatedwriter.Writer{Writer: colorable.NewColorable(os.Stderr)} var level int - logLevel = strings.ToLower(strings.TrimSpace(logLevel)) - switch logLevel { + c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel)) + switch c.flagLogLevel { case "trace": level = log.LevelTrace case "debug": @@ -295,7 +284,7 @@ func (c *ServerCommand) Run(args []string) int { case "err": level = log.LevelError default: - c.Ui.Output(fmt.Sprintf("Unknown log level %s", logLevel)) + c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel)) return 1 } @@ -315,24 +304,16 @@ func (c *ServerCommand) Run(args []string) int { log: os.Getenv("VAULT_GRPC_LOGGING") != "", }) - if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" && devRootTokenID == "" { - devRootTokenID = os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") - } - - if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" && devListenAddress == "" { - devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS") - } - - if devHA || devTransactional || devLeasedKV || devThreeNode { - dev = true + // Automatically enable dev mode if other dev flags are provided. + if c.flagDevHA || c.flagDevTransactional || c.flagDevLeasedKV || c.flagDevThreeNode { + c.flagDev = true } // Validation - if !dev { + if !c.flagDev { switch { - case len(configPath) == 0: - c.Ui.Output("At least one config path must be specified with -config") - flags.Usage() + case len(c.flagConfigs) == 0: + c.UI.Error("Must specify at least one config path using -config") return 1 case c.flagDevRootTokenID != "": c.UI.Warn(wrapAtLength( @@ -344,17 +325,16 @@ func (c *ServerCommand) Run(args []string) int { // Load the configuration var config *server.Config - if dev { - config = server.DevConfig(devHA, devTransactional) - if devListenAddress != "" { - config.Listeners[0].Config["address"] = devListenAddress + if c.flagDev { + config = server.DevConfig(c.flagDevHA, c.flagDevTransactional) + if c.flagDevListenAddr != "" { + config.Listeners[0].Config["address"] = c.flagDevListenAddr } } - for _, path := range configPath { + for _, path := range c.flagConfigs { current, err := server.LoadConfig(path, c.logger) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error loading configuration from %s: %s", path, err)) + c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", path, err)) return 1 } @@ -367,43 +347,45 @@ func (c *ServerCommand) Run(args []string) int { // Ensure at least one config was found. if config == nil { - c.Ui.Output("No configuration files found.") + c.UI.Output(wrapAtLength( + "No configuration files found. Please provide configurations with the " + + "-config flag. If you are supply the path to a directory, please " + + "ensure the directory contains files with the .hcl or .json " + + "extension.")) return 1 } // Ensure that a backend is provided if config.Storage == nil { - c.Ui.Output("A storage backend must be specified") + c.UI.Output("A storage backend must be specified") return 1 } // If mlockall(2) isn't supported, show a warning. We disable this // in dev because it is quite scary to see when first using Vault. - if !dev && !mlock.Supported() { - c.Ui.Output("==> WARNING: mlock not supported on this system!\n") - c.Ui.Output(" An `mlockall(2)`-like syscall to prevent memory from being") - c.Ui.Output(" swapped to disk is not supported on this system. Running") - c.Ui.Output(" Vault on an mlockall(2) enabled system is much more secure.\n") + if !c.flagDev && !mlock.Supported() { + c.UI.Warn(wrapAtLength( + "WARNING! mlock is not supported on this system! An mlockall(2)-like " + + "syscall to prevent memory from being swapped to disk is not " + + "supported on this system. For better security, only run Vault on " + + "systems where this call is supported. If you are running Vault " + + "in a Docker container, provide the IPC_LOCK cap to the container.")) } if err := c.setupTelemetry(config); err != nil { - c.Ui.Output(fmt.Sprintf("Error initializing telemetry: %s", err)) + c.UI.Error(fmt.Sprintf("Error initializing telemetry: %s", err)) return 1 } // Initialize the backend factory, exists := c.PhysicalBackends[config.Storage.Type] if !exists { - c.Ui.Output(fmt.Sprintf( - "Unknown storage type %s", - config.Storage.Type)) + c.UI.Error(fmt.Sprintf("Unknown storage type %s", config.Storage.Type)) return 1 } backend, err := factory(config.Storage.Config, c.logger) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing storage of type %s: %s", - config.Storage.Type, err)) + c.UI.Error(fmt.Sprintf("Error initializing storage of type %s: %s", config.Storage.Type, err)) return 1 } @@ -417,13 +399,13 @@ func (c *ServerCommand) Run(args []string) int { if seal != nil { err = seal.Finalize() if err != nil { - c.Ui.Error(fmt.Sprintf("Error finalizing seals: %v", err)) + c.UI.Error(fmt.Sprintf("Error finalizing seals: %v", err)) } } }() if seal == nil { - c.Ui.Error(fmt.Sprintf("Could not create seal; most likely proper Seal configuration information was not set, but no error was generated.")) + c.UI.Error(fmt.Sprintf("Could not create seal! Most likely proper Seal configuration information was not set, but no error was generated.")) return 1 } @@ -445,14 +427,13 @@ func (c *ServerCommand) Run(args []string) int { PluginDirectory: config.PluginDirectory, EnableRaw: config.EnableRawEndpoint, } - - if dev { - coreConfig.DevToken = devRootTokenID - if devLeasedKV { + if c.flagDev { + coreConfig.DevToken = c.flagDevRootTokenID + if c.flagDevLeasedKV { coreConfig.LogicalBackends["kv"] = vault.LeasedPassthroughBackendFactory } - if devPluginDir != "" { - coreConfig.PluginDirectory = devPluginDir + if c.flagDevPluginDir != "" { + coreConfig.PluginDirectory = c.flagDevPluginDir } if c.flagDevLatency > 0 { injectLatency := time.Duration(c.flagDevLatency) * time.Millisecond @@ -464,8 +445,8 @@ func (c *ServerCommand) Run(args []string) int { } } - if devThreeNode { - return c.enableThreeNodeDevCluster(coreConfig, info, infoKeys, devListenAddress) + if c.flagDevThreeNode { + return c.enableThreeNodeDevCluster(coreConfig, info, infoKeys) } var disableClustering bool @@ -475,26 +456,25 @@ func (c *ServerCommand) Run(args []string) int { if config.HAStorage != nil { factory, exists := c.PhysicalBackends[config.HAStorage.Type] if !exists { - c.Ui.Output(fmt.Sprintf( - "Unknown HA storage type %s", - config.HAStorage.Type)) + c.UI.Error(fmt.Sprintf("Unknown HA storage type %s", config.HAStorage.Type)) return 1 + } habackend, err := factory(config.HAStorage.Config, c.logger) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing HA storage of type %s: %s", - config.HAStorage.Type, err)) + c.UI.Error(fmt.Sprintf( + "Error initializing HA storage of type %s: %s", config.HAStorage.Type, err)) return 1 + } if coreConfig.HAPhysical, ok = habackend.(physical.HABackend); !ok { - c.Ui.Output("Specified HA storage does not support HA") + c.UI.Error("Specified HA storage does not support HA") return 1 } if !coreConfig.HAPhysical.HAEnabled() { - c.Ui.Output("Specified HA storage has HA support disabled; please consult documentation") + c.UI.Error("Specified HA storage has HA support disabled; please consult documentation") return 1 } @@ -529,14 +509,14 @@ func (c *ServerCommand) Run(args []string) int { if ok && coreConfig.RedirectAddr == "" { redirect, err := c.detectRedirect(detect, config) if err != nil { - c.Ui.Output(fmt.Sprintf("Error detecting redirect address: %s", err)) + c.UI.Error(fmt.Sprintf("Error detecting redirect address: %s", err)) } else if redirect == "" { - c.Ui.Output("Failed to detect redirect address.") + c.UI.Error("Failed to detect redirect address.") } else { coreConfig.RedirectAddr = redirect } } - if coreConfig.RedirectAddr == "" && dev { + if coreConfig.RedirectAddr == "" && c.flagDev { coreConfig.RedirectAddr = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"]) } @@ -551,14 +531,15 @@ func (c *ServerCommand) Run(args []string) int { switch { case coreConfig.ClusterAddr == "" && coreConfig.RedirectAddr != "": addrToUse = coreConfig.RedirectAddr - case dev: + case c.flagDev: addrToUse = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"]) default: goto CLUSTER_SYNTHESIS_COMPLETE } u, err := url.ParseRequestURI(addrToUse) if err != nil { - c.Ui.Output(fmt.Sprintf("Error parsing synthesized cluster address %s: %v", addrToUse, err)) + c.UI.Error(fmt.Sprintf( + "Error parsing synthesized cluster address %s: %v", addrToUse, err)) return 1 } host, port, err := net.SplitHostPort(u.Host) @@ -568,13 +549,14 @@ func (c *ServerCommand) Run(args []string) int { host = u.Host port = "443" } else { - c.Ui.Output(fmt.Sprintf("Error parsing redirect address: %v", err)) + c.UI.Error(fmt.Sprintf("Error parsing redirect address: %v", err)) return 1 } } nPort, err := strconv.Atoi(port) if err != nil { - c.Ui.Output(fmt.Sprintf("Error parsing synthesized address; failed to convert %q to a numeric: %v", port, err)) + c.UI.Error(fmt.Sprintf( + "Error parsing synthesized address; failed to convert %q to a numeric: %v", port, err)) return 1 } u.Host = net.JoinHostPort(host, strconv.Itoa(nPort+1)) @@ -589,8 +571,8 @@ CLUSTER_SYNTHESIS_COMPLETE: // Force https as we'll always be TLS-secured u, err := url.ParseRequestURI(coreConfig.ClusterAddr) if err != nil { - c.Ui.Output(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) - return 1 + c.UI.Error(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) + return 11 } u.Scheme = "https" coreConfig.ClusterAddr = u.String() @@ -600,7 +582,7 @@ CLUSTER_SYNTHESIS_COMPLETE: core, newCoreError := vault.NewCore(coreConfig) if newCoreError != nil { if !errwrap.ContainsType(newCoreError, new(vault.NonFatalError)) { - c.Ui.Output(fmt.Sprintf("Error initializing core: %s", newCoreError)) + c.UI.Error(fmt.Sprintf("Error initializing core: %s", newCoreError)) return 1 } } @@ -611,7 +593,7 @@ CLUSTER_SYNTHESIS_COMPLETE: // Compile server information for output later info["storage"] = config.Storage.Type - info["log level"] = logLevel + info["log level"] = c.flagLogLevel info["mlock"] = fmt.Sprintf( "supported: %v, enabled: %v", mlock.Supported(), !config.DisableMlock && mlock.Supported()) @@ -648,9 +630,7 @@ CLUSTER_SYNTHESIS_COMPLETE: for i, lnConfig := range config.Listeners { ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing listener of type %s: %s", - lnConfig.Type, err)) + c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err)) return 1 } @@ -670,16 +650,14 @@ CLUSTER_SYNTHESIS_COMPLETE: addr = addrRaw.(string) tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error resolving cluster_address: %s", - err)) + c.UI.Error(fmt.Sprintf("Error resolving cluster_address: %s", err)) return 1 } clusterAddrs = append(clusterAddrs, tcpAddr) } else { tcpAddr, ok := ln.Addr().(*net.TCPAddr) if !ok { - c.Ui.Output("Failed to parse tcp listener") + c.UI.Error("Failed to parse tcp listener") return 1 } clusterAddr := &net.TCPAddr{ @@ -737,17 +715,19 @@ CLUSTER_SYNTHESIS_COMPLETE: // Server configuration output padding := 24 sort.Strings(infoKeys) - c.Ui.Output("==> Vault server configuration:\n") + c.UI.Output("==> Vault server configuration:\n") for _, k := range infoKeys { - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "%s%s: %s", strings.Repeat(" ", padding-len(k)), strings.Title(k), info[k])) } - c.Ui.Output("") + c.UI.Output("") - if verifyOnly { + // Tests might not want to start a vault server and just want to verify + // the configuration. + if c.flagTestVerifyOnly { return 0 } @@ -761,7 +741,7 @@ CLUSTER_SYNTHESIS_COMPLETE: err = core.UnsealWithStoredKeys() if err != nil { if !errwrap.ContainsType(err, new(vault.NonFatalError)) { - c.Ui.Output(fmt.Sprintf("Error initializing core: %s", err)) + c.UI.Error(fmt.Sprintf("Error initializing core: %s", err)) return 1 } } @@ -791,18 +771,17 @@ CLUSTER_SYNTHESIS_COMPLETE: } if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, sealedFunc); err != nil { - c.Ui.Output(fmt.Sprintf("Error initializing service discovery: %v", err)) + c.UI.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) return 1 } } } // If we're in Dev mode, then initialize the core - if dev && !devSkipInit { + if c.flagDev && !c.flagDevSkipInit { init, err := c.enableDev(core, coreConfig) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing Dev mode: %s", err)) + c.UI.Error(fmt.Sprintf("Error initializing Dev mode: %s", err)) return 1 } @@ -813,44 +792,49 @@ CLUSTER_SYNTHESIS_COMPLETE: quote = "" } - c.Ui.Output(fmt.Sprint( - "==> WARNING: Dev mode is enabled!\n\n" + - "In this mode, Vault is completely in-memory and unsealed.\n" + - "Vault is configured to only have a single unseal key. The root\n" + - "token has already been authenticated with the CLI, so you can\n" + - "immediately begin using the Vault CLI.\n\n" + - "The only step you need to take is to set the following\n" + - "environment variables:\n\n" + - " " + export + " VAULT_ADDR=" + quote + "http://" + config.Listeners[0].Config["address"].(string) + quote + "\n\n" + - "The unseal key and root token are reproduced below in case you\n" + - "want to seal/unseal the Vault or play with authentication.\n", - )) + // Print the big dev mode warning! + c.UI.Warn(wrapAtLength( + "WARNING! dev mode is enabled! In this mode, Vault runs entirely " + + "in-memory and starts unsealed with a single unseal key. The root " + + "token is already authenticated to the CLI, so you can immediately " + + "begin using Vault.")) + c.UI.Warn("") + c.UI.Warn("You may need to set the following environment variable:") + c.UI.Warn("") + c.UI.Warn(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", + export, quote, "http://"+config.Listeners[0].Config["address"].(string), quote)) // Unseal key is not returned if stored shares is supported if len(init.SecretShares) > 0 { - c.Ui.Output(fmt.Sprintf( - "Unseal Key: %s", - base64.StdEncoding.EncodeToString(init.SecretShares[0]), - )) + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "The unseal key and root token are displayed below in case you want " + + "to seal/unseal the Vault or re-authenticate.")) + c.UI.Warn("") + c.UI.Warn(fmt.Sprintf("Unseal Key: %s", base64.StdEncoding.EncodeToString(init.SecretShares[0]))) } if len(init.RecoveryShares) > 0 { - c.Ui.Output(fmt.Sprintf( - "Recovery Key: %s", - base64.StdEncoding.EncodeToString(init.RecoveryShares[0]), - )) + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "The recovery key and root token are displayed below in case you want " + + "to seal/unseal the Vault or re-authenticate.")) + c.UI.Warn("") + c.UI.Warn(fmt.Sprintf("Unseal Key: %s", base64.StdEncoding.EncodeToString(init.RecoveryShares[0]))) } - c.Ui.Output(fmt.Sprintf( - "Root Token: %s\n", - init.RootToken, - )) + c.UI.Warn(fmt.Sprintf("Root Token: %s", init.RootToken)) + + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "Development mode should NOT be used in production installations!")) + c.UI.Warn("") } // Initialize the HTTP server server := &http.Server{} if err := http2.ConfigureServer(server, nil); err != nil { - c.Ui.Output(fmt.Sprintf("Error configuring server for HTTP/2: %s", err)) + c.UI.Error(fmt.Sprintf("Error configuring server for HTTP/2: %s", err)) return 1 } server.Handler = handler @@ -859,12 +843,20 @@ CLUSTER_SYNTHESIS_COMPLETE: } if newCoreError != nil { - c.Ui.Output("==> Warning:\n\nNon-fatal error during initialization; check the logs for more information.") - c.Ui.Output("") + c.UI.Warn(wrapAtLength( + "WARNING! A non-fatal error occurred during initialization. Please " + + "check the logs for more information.")) + c.UI.Warn("") } // Output the header that the server has started - c.Ui.Output("==> Vault server started! Log data will stream in below:\n") + c.UI.Output("==> Vault server started! Log data will stream in below:\n") + + // Inform any tests that the server is ready + select { + case c.startedCh <- struct{}{}: + default: + } // Release the log gate. c.logGate.Flush() @@ -887,7 +879,7 @@ CLUSTER_SYNTHESIS_COMPLETE: for !shutdownTriggered { select { case <-c.ShutdownCh: - c.Ui.Output("==> Vault shutdown triggered") + c.UI.Output("==> Vault shutdown triggered") // Stop the listners so that we don't process further client requests. c.cleanupGuard.Do(listenerCloseFunc) @@ -896,15 +888,15 @@ CLUSTER_SYNTHESIS_COMPLETE: // request forwarding listeners will also be closed (and also // waited for). if err := core.Shutdown(); err != nil { - c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) + c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err)) } shutdownTriggered = true case <-c.SighupCh: - c.Ui.Output("==> Vault reload triggered") - if err := c.Reload(c.reloadFuncsLock, c.reloadFuncs, configPath); err != nil { - c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) + c.UI.Output("==> Vault reload triggered") + if err := c.Reload(c.reloadFuncsLock, c.reloadFuncs, c.flagConfigs); err != nil { + c.UI.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) } } } @@ -1031,10 +1023,10 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig return init, nil } -func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info map[string]string, infoKeys []string, devListenAddress string) int { +func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info map[string]string, infoKeys []string) int { testCluster := vault.NewTestCluster(&testing.RuntimeT{}, base, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, - BaseListenAddress: devListenAddress, + BaseListenAddress: c.flagDevListenAddr, }) defer c.cleanupGuard.Do(testCluster.Cleanup) @@ -1063,15 +1055,15 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m // Server configuration output padding := 24 sort.Strings(infoKeys) - c.Ui.Output("==> Vault server configuration:\n") + c.UI.Output("==> Vault server configuration:\n") for _, k := range infoKeys { - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "%s%s: %s", strings.Repeat(" ", padding-len(k)), strings.Title(k), info[k])) } - c.Ui.Output("") + c.UI.Output("") for _, core := range testCluster.Cores { core.Server.Handler = vaulthttp.Handler(core.Core) @@ -1095,15 +1087,15 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m } resp, err := testCluster.Cores[0].HandleRequest(req) if err != nil { - c.Ui.Output(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err)) + c.UI.Error(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err)) return 1 } if resp == nil { - c.Ui.Output(fmt.Sprintf("nil response when creating root token with ID %s", base.DevToken)) + c.UI.Error(fmt.Sprintf("nil response when creating root token with ID %s", base.DevToken)) return 1 } if resp.Auth == nil { - c.Ui.Output(fmt.Sprintf("nil auth when creating root token with ID %s", base.DevToken)) + c.UI.Error(fmt.Sprintf("nil auth when creating root token with ID %s", base.DevToken)) return 1 } @@ -1114,7 +1106,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m req.Data = nil resp, err = testCluster.Cores[0].HandleRequest(req) if err != nil { - c.Ui.Output(fmt.Sprintf("failed to revoke initial root token: %s", err)) + c.UI.Output(fmt.Sprintf("failed to revoke initial root token: %s", err)) return 1 } } @@ -1122,37 +1114,37 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m // Set the token tokenHelper, err := c.TokenHelper() if err != nil { - c.Ui.Output(fmt.Sprintf("%v", err)) + c.UI.Error(fmt.Sprintf("Error getting token helper: %s", err)) return 1 } if err := tokenHelper.Store(testCluster.RootToken); err != nil { - c.Ui.Output(fmt.Sprintf("%v", err)) + c.UI.Error(fmt.Sprintf("Error storing in token helper: %s", err)) return 1 } if err := ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(testCluster.RootToken), 0755); err != nil { - c.Ui.Output(fmt.Sprintf("%v", err)) + c.UI.Error(fmt.Sprintf("Error writing token to tempfile: %s", err)) return 1 } - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "==> Three node dev mode is enabled\n\n" + "The unseal key and root token are reproduced below in case you\n" + "want to seal/unseal the Vault or play with authentication.\n", )) for i, key := range testCluster.BarrierKeys { - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "Unseal Key %d: %s", i+1, base64.StdEncoding.EncodeToString(key), )) } - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "\nRoot Token: %s\n", testCluster.RootToken, )) - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "\nUseful env vars:\n"+ "VAULT_TOKEN=%s\n"+ "VAULT_ADDR=%s\n"+ @@ -1163,7 +1155,13 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m )) // Output the header that the server has started - c.Ui.Output("==> Vault server started! Log data will stream in below:\n") + c.UI.Output("==> Vault server started! Log data will stream in below:\n") + + // Inform any tests that the server is ready + select { + case c.startedCh <- struct{}{}: + default: + } // Release the log gate. c.logGate.Flush() @@ -1174,7 +1172,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m for !shutdownTriggered { select { case <-c.ShutdownCh: - c.Ui.Output("==> Vault shutdown triggered") + c.UI.Output("==> Vault shutdown triggered") // Stop the listners so that we don't process further client requests. c.cleanupGuard.Do(testCluster.Cleanup) @@ -1184,17 +1182,17 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m // waited for). for _, core := range testCluster.Cores { if err := core.Shutdown(); err != nil { - c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) + c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err)) } } shutdownTriggered = true case <-c.SighupCh: - c.Ui.Output("==> Vault reload triggered") + c.UI.Output("==> Vault reload triggered") for _, core := range testCluster.Cores { if err := c.Reload(core.ReloadFuncsLock, core.ReloadFuncs, nil); err != nil { - c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) + c.UI.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) } } } @@ -1405,68 +1403,11 @@ func (c *ServerCommand) Reload(lock *sync.RWMutex, reloadFuncs *map[string][]rel } } - return reloadErrors.ErrorOrNil() -} - -func (c *ServerCommand) Synopsis() string { - return "Start a Vault server" -} - -func (c *ServerCommand) Help() string { - helpText := ` -Usage: vault server [options] - - Start a Vault server. - - This command starts a Vault server that responds to API requests. - Vault will start in a "sealed" state. The Vault must be unsealed - with "vault unseal" or the API before this server can respond to requests. - This must be done for every server. - - If the server is being started against a storage backend that is - brand new (no existing Vault data in it), it must be initialized with - "vault init" or the API first. - - -General Options: - - -config= Path to the configuration file or directory. This can - be specified multiple times. If it is a directory, - all files with a ".hcl" or ".json" suffix will be - loaded. - - -dev Enables Dev mode. In this mode, Vault is completely - in-memory and unsealed. Do not run the Dev server in - production! - - -dev-root-token-id="" If set, the root token returned in Dev mode will have - the given ID. This *only* has an effect when running - in Dev mode. Can also be specified with the - VAULT_DEV_ROOT_TOKEN_ID environment variable. - - -dev-listen-address="" If set, this overrides the normal Dev mode listen - address of "127.0.0.1:8200". Can also be specified - with the VAULT_DEV_LISTEN_ADDRESS environment - variable. - - -log-level=info Log verbosity. Defaults to "info", will be output to - stderr. Supported values: "trace", "debug", "info", - "warn", "err" -` - return strings.TrimSpace(helpText) -} - -func (c *ServerCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *ServerCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-config": complete.PredictOr(complete.PredictFiles("*.hcl"), complete.PredictFiles("*.json")), - "-dev": complete.PredictNothing, - "-dev-root-token-id": complete.PredictNothing, - "-dev-listen-address": complete.PredictNothing, - "-log-level": complete.PredictSet("trace", "debug", "info", "warn", "err"), + // Send a message that we reloaded. This prevents "guessing" sleep times + // in tests. + select { + case c.reloadedCh <- struct{}{}: + default: } return reloadErrors.ErrorOrNil() diff --git a/command/server_ha_test.go b/command/server_ha_test.go deleted file mode 100644 index a9b1188126..0000000000 --- a/command/server_ha_test.go +++ /dev/null @@ -1,106 +0,0 @@ -// +build !race - -package command - -import ( - "io/ioutil" - "os" - "strings" - "testing" - - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/physical" - "github.com/mitchellh/cli" - - physConsul "github.com/hashicorp/vault/physical/consul" -) - -// The following tests have a go-metrics/exp manager race condition -func TestServer_CommonHA(t *testing.T) { - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - PhysicalBackends: map[string]physical.Factory{ - "consul": physConsul.NewConsulBackend, - }, - } - - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - - tmpfile.WriteString(basehcl + consulhcl) - tmpfile.Close() - defer os.Remove(tmpfile.Name()) - - args := []string{"-config", tmpfile.Name(), "-verify-only", "true"} - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String()) - } - - if !strings.Contains(ui.OutputWriter.String(), "(HA available)") { - t.Fatalf("did not find HA available: %s", ui.OutputWriter.String()) - } -} - -func TestServer_GoodSeparateHA(t *testing.T) { - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - PhysicalBackends: map[string]physical.Factory{ - "consul": physConsul.NewConsulBackend, - }, - } - - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - - tmpfile.WriteString(basehcl + consulhcl + haconsulhcl) - tmpfile.Close() - defer os.Remove(tmpfile.Name()) - - args := []string{"-config", tmpfile.Name(), "-verify-only", "true"} - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String()) - } - - if !strings.Contains(ui.OutputWriter.String(), "HA Storage:") { - t.Fatalf("did not find HA Storage: %s", ui.OutputWriter.String()) - } -} - -func TestServer_BadSeparateHA(t *testing.T) { - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - PhysicalBackends: map[string]physical.Factory{ - "consul": physConsul.NewConsulBackend, - }, - } - - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - - tmpfile.WriteString(basehcl + consulhcl + badhaconsulhcl) - tmpfile.Close() - defer os.Remove(tmpfile.Name()) - - args := []string{"-config", tmpfile.Name()} - - if code := c.Run(args); code == 0 { - t.Fatalf("bad: should have gotten an error on a bad HA config") - } -} diff --git a/command/server_test.go b/command/server_test.go index 9a90239011..c15cf0596f 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -1,4 +1,5 @@ // +build !race +// The server tests have a go-metrics/exp manager race condition :(. package command @@ -7,72 +8,112 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "math/rand" + "net" "os" "strings" "sync" "testing" "time" - "github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/physical" "github.com/mitchellh/cli" + physConsul "github.com/hashicorp/vault/physical/consul" physFile "github.com/hashicorp/vault/physical/file" ) -var ( - basehcl = ` -disable_mlock = true +func testRandomPort(tb testing.TB) int { + tb.Helper() -listener "tcp" { - address = "127.0.0.1:8200" - tls_disable = "true" + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + if err != nil { + tb.Fatal(err) + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + tb.Fatal(err) + } + defer l.Close() + + return l.Addr().(*net.TCPAddr).Port } -` - consulhcl = ` +func testBaseHCL(tb testing.TB) string { + tb.Helper() + + return strings.TrimSpace(fmt.Sprintf(` + disable_mlock = true + listener "tcp" { + address = "127.0.0.1:%d" + tls_disable = "true" + } + `, testRandomPort(tb))) +} + +const ( + consulHCL = ` backend "consul" { - prefix = "foo/" - advertise_addr = "http://127.0.0.1:8200" - disable_registration = "true" + prefix = "foo/" + advertise_addr = "http://127.0.0.1:8200" + disable_registration = "true" } ` - haconsulhcl = ` + haConsulHCL = ` ha_backend "consul" { - prefix = "bar/" - redirect_addr = "http://127.0.0.1:8200" - disable_registration = "true" + prefix = "bar/" + redirect_addr = "http://127.0.0.1:8200" + disable_registration = "true" } ` - badhaconsulhcl = ` + badHAConsulHCL = ` ha_backend "file" { - path = "/dev/null" + path = "/dev/null" } ` - reloadhcl = ` + reloadHCL = ` backend "file" { - path = "/dev/null" + path = "/dev/null" } - disable_mlock = true - listener "tcp" { - address = "127.0.0.1:8203" - tls_cert_file = "TMPDIR/reload_cert.pem" - tls_key_file = "TMPDIR/reload_key.pem" + address = "127.0.0.1:8203" + tls_cert_file = "TMPDIR/reload_cert.pem" + tls_key_file = "TMPDIR/reload_key.pem" } ` ) -// The following tests have a go-metrics/exp manager race condition +func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &ServerCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + ShutdownCh: MakeShutdownCh(), + SighupCh: MakeSighupCh(), + PhysicalBackends: map[string]physical.Factory{ + "file": physFile.NewFileBackend, + "consul": physConsul.NewConsulBackend, + }, + + // These prevent us from random sleep guessing... + startedCh: make(chan struct{}, 5), + reloadedCh: make(chan struct{}, 5), + } +} + func TestServer_ReloadListener(t *testing.T) { + t.Parallel() + wd, _ := os.Getwd() wd += "/server/test-fixtures/reload/" - td, err := ioutil.TempDir("", fmt.Sprintf("vault-test-%d", rand.New(rand.NewSource(time.Now().Unix())).Int63)) + td, err := ioutil.TempDir("", "vault-test-") if err != nil { t.Fatal(err) } @@ -86,7 +127,7 @@ func TestServer_ReloadListener(t *testing.T) { inBytes, _ = ioutil.ReadFile(wd + "reload_foo.key") ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777) - relhcl := strings.Replace(reloadhcl, "TMPDIR", td, -1) + relhcl := strings.Replace(reloadHCL, "TMPDIR", td, -1) ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777) inBytes, _ = ioutil.ReadFile(wd + "reload_ca.pem") @@ -96,17 +137,8 @@ func TestServer_ReloadListener(t *testing.T) { t.Fatal("not ok when appending CA cert") } - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - ShutdownCh: MakeShutdownCh(), - SighupCh: MakeSighupCh(), - PhysicalBackends: map[string]physical.Factory{ - "file": physFile.NewFileBackend, - }, - } + ui, cmd := testServerCommand(t) + _ = ui finished := false finishedMutex := sync.Mutex{} @@ -114,7 +146,7 @@ func TestServer_ReloadListener(t *testing.T) { wg.Add(1) args := []string{"-config", td + "/reload.hcl"} go func() { - if code := c.Run(args); code != 0 { + if code := cmd.Run(args); code != 0 { t.Error("got a non-zero exit status") } finishedMutex.Lock() @@ -123,14 +155,6 @@ func TestServer_ReloadListener(t *testing.T) { wg.Done() }() - checkFinished := func() { - finishedMutex.Lock() - if finished { - t.Fatalf(fmt.Sprintf("finished early; relhcl was\n%s\nstdout was\n%s\nstderr was\n%s\n", relhcl, ui.OutputWriter.String(), ui.ErrorWriter.String())) - } - finishedMutex.Unlock() - } - testCertificateName := func(cn string) error { conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{ RootCAs: certPool, @@ -149,31 +173,95 @@ func TestServer_ReloadListener(t *testing.T) { return nil } - checkFinished() - time.Sleep(5 * time.Second) - checkFinished() + select { + case <-cmd.startedCh: + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } if err := testCertificateName("foo.example.com"); err != nil { t.Fatalf("certificate name didn't check out: %s", err) } - relhcl = strings.Replace(reloadhcl, "TMPDIR", td, -1) + relhcl = strings.Replace(reloadHCL, "TMPDIR", td, -1) inBytes, _ = ioutil.ReadFile(wd + "reload_bar.pem") ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0777) inBytes, _ = ioutil.ReadFile(wd + "reload_bar.key") ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777) ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777) - c.SighupCh <- struct{}{} - checkFinished() - time.Sleep(2 * time.Second) - checkFinished() + cmd.SighupCh <- struct{}{} + select { + case <-cmd.reloadedCh: + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } if err := testCertificateName("bar.example.com"); err != nil { t.Fatalf("certificate name didn't check out: %s", err) } - c.ShutdownCh <- struct{}{} + cmd.ShutdownCh <- struct{}{} wg.Wait() } + +func TestServer(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + contents string + exp string + code int + }{ + { + "common_ha", + testBaseHCL(t) + consulHCL, + "(HA available)", + 0, + }, + { + "separate_ha", + testBaseHCL(t) + consulHCL + haConsulHCL, + "HA Storage:", + 0, + }, + { + "bad_separate_ha", + testBaseHCL(t) + consulHCL + badHAConsulHCL, + "Specified HA storage does not support HA", + 1, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testServerCommand(t) + f, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("error creating temp dir: %v", err) + } + f.WriteString(tc.contents) + f.Close() + defer os.Remove(f.Name()) + + code := cmd.Run([]string{ + "-config", f.Name(), + "-test-verify-only", + }) + output := ui.ErrorWriter.String() + ui.OutputWriter.String() + if code != tc.code { + t.Errorf("expected %d to be %d: %s", code, tc.code, output) + } + + if !strings.Contains(output, tc.exp) { + t.Fatalf("expected %q to contain %q", output, tc.exp) + } + }) + } +} diff --git a/command/status_test.go b/command/status_test.go index 717d4afbc8..e34a72c578 100644 --- a/command/status_test.go +++ b/command/status_test.go @@ -32,14 +32,14 @@ func TestStatusCommand_Run(t *testing.T) { "unsealed", nil, false, - "Sealed: false", + "Sealed false", 0, }, { "sealed", nil, true, - "Sealed: true", + "Sealed true", 2, }, {