Add service discovery to init command

This commit is contained in:
vishalnayak 2016-07-20 15:38:53 -04:00
parent 95597a4a9e
commit d22204914d
2 changed files with 131 additions and 17 deletions

View File

@ -45,7 +45,7 @@ type InitStatusResponse struct {
} }
type InitResponse struct { type InitResponse struct {
Keys []string Keys []string `json:"keys"`
RecoveryKeys []string `json:"recovery_keys"` RecoveryKeys []string `json:"recovery_keys"`
RootToken string `json:"root_token"` RootToken string `json:"root_token"`
} }

View File

@ -2,8 +2,11 @@ package command
import ( import (
"fmt" "fmt"
"os"
"runtime"
"strings" "strings"
consulapi "github.com/hashicorp/consul/api"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/meta"
@ -18,6 +21,7 @@ func (c *InitCommand) Run(args []string) int {
var threshold, shares, storedShares, recoveryThreshold, recoveryShares int var threshold, shares, storedShares, recoveryThreshold, recoveryShares int
var pgpKeys, recoveryPgpKeys pgpkeys.PubKeyFilesFlag var pgpKeys, recoveryPgpKeys pgpkeys.PubKeyFilesFlag
var check bool var check bool
var auto string
flags := c.Meta.FlagSet("init", meta.FlagSetDefault) flags := c.Meta.FlagSet("init", meta.FlagSetDefault)
flags.Usage = func() { c.Ui.Error(c.Help()) } flags.Usage = func() { c.Ui.Error(c.Help()) }
flags.IntVar(&shares, "key-shares", 5, "") flags.IntVar(&shares, "key-shares", 5, "")
@ -28,10 +32,128 @@ func (c *InitCommand) Run(args []string) int {
flags.IntVar(&recoveryThreshold, "recovery-threshold", 3, "") flags.IntVar(&recoveryThreshold, "recovery-threshold", 3, "")
flags.Var(&recoveryPgpKeys, "recovery-pgp-keys", "") flags.Var(&recoveryPgpKeys, "recovery-pgp-keys", "")
flags.BoolVar(&check, "check", false, "") flags.BoolVar(&check, "check", false, "")
flags.StringVar(&auto, "auto", "", "")
if err := flags.Parse(args); err != nil { if err := flags.Parse(args); err != nil {
return 1 return 1
} }
initRequest := &api.InitRequest{
SecretShares: shares,
SecretThreshold: threshold,
StoredShares: storedShares,
PGPKeys: pgpKeys,
RecoveryShares: recoveryShares,
RecoveryThreshold: recoveryThreshold,
RecoveryPGPKeys: recoveryPgpKeys,
}
// If running in 'auto' mode, run service discovery based on environment
// variables of Consul.
if auto != "" {
// Create configuration for Consul
consulConfig := consulapi.DefaultConfig()
// Create a client to communicate with Consul
consulClient, err := consulapi.NewClient(consulConfig)
if err != nil {
c.Ui.Error(fmt.Sprintf("failed to create Consul client:%v", err))
return 1
}
var uninitializedVaults []string
var initializedVault string
// Query the nodes belonging to the cluster
if services, _, err := consulClient.Catalog().Service(auto, "", &consulapi.QueryOptions{AllowStale: true}); err == nil {
Loop:
for _, service := range services {
vaultAddress := fmt.Sprintf("%s://%s:%d", consulConfig.Scheme, service.ServiceAddress, service.ServicePort)
// Set VAULT_ADDR to the discovered node
os.Setenv(api.EnvVaultAddress, vaultAddress)
// Create a client to communicate with the discovered node
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing client: %s", err))
return 1
}
// Check the initialization status of the discovered node
inited, err := client.Sys().InitStatus()
switch {
case err != nil:
c.Ui.Error(fmt.Sprintf("Error checking initialization status of discovered node: %s err:%s", vaultAddress, err))
return 1
case inited:
// One of the nodes in the cluster is initialized. Break out.
initializedVault = vaultAddress
break Loop
default:
// Vault is uninitialized.
uninitializedVaults = append(uninitializedVaults, vaultAddress)
}
}
}
export := "export"
quote := "'"
if runtime.GOOS == "windows" {
export = "set"
quote = ""
}
if initializedVault != "" {
c.Ui.Output(fmt.Sprintf("Discovered an initialized Vault node at '%s'\n", initializedVault))
c.Ui.Output("Set the following environment variable to operate on the discovered Vault:\n")
c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%shttp://%s%s", export, quote, initializedVault, quote))
return 0
}
switch len(uninitializedVaults) {
case 0:
c.Ui.Error(fmt.Sprintf("Failed to discover Vault nodes under the service name '%s'", auto))
return 1
case 1:
// There was only one node found in the Vault cluster and it
// was uninitialized.
// Set the VAULT_ADDR to the discovered node. This will ensure
// that the client created will operate on the discovered node.
os.Setenv(api.EnvVaultAddress, uninitializedVaults[0])
// Let the client know that initialization is perfomed on the
// discovered node.
c.Ui.Output(fmt.Sprintf("Discovered Vault at '%s'\n", uninitializedVaults[0]))
// Attempt initializing it
ret := c.runInit(check, initRequest)
// Regardless of success or failure, instruct client to update VAULT_ADDR
c.Ui.Output("Set the following environment variable to operate on the discovered Vault:\n")
c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%shttp://%s%s", export, quote, uninitializedVaults[0], quote))
return ret
default:
// If more than one Vault node were discovered, print out all of them,
// requiring the client to update VAULT_ADDR and to run init again.
c.Ui.Output(fmt.Sprintf("Discovered more than one uninitialized Vaults under the service name '%s'\n", auto))
c.Ui.Output("To initialize all Vaults, set any *one* of the following and run 'vault init':")
// Print valid commands to make setting the variables easier
for _, vaultNode := range uninitializedVaults {
c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%shttp://%s%s", export, quote, vaultNode, quote))
}
return 0
}
}
return c.runInit(check, initRequest)
}
func (c *InitCommand) runInit(check bool, initRequest *api.InitRequest) int {
client, err := c.Client() client, err := c.Client()
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
@ -43,15 +165,7 @@ func (c *InitCommand) Run(args []string) int {
return c.checkStatus(client) return c.checkStatus(client)
} }
resp, err := client.Sys().Init(&api.InitRequest{ resp, err := client.Sys().Init(initRequest)
SecretShares: shares,
SecretThreshold: threshold,
StoredShares: storedShares,
PGPKeys: pgpKeys,
RecoveryShares: recoveryShares,
RecoveryThreshold: recoveryThreshold,
RecoveryPGPKeys: recoveryPgpKeys,
})
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
"Error initializing Vault: %s", err)) "Error initializing Vault: %s", err))
@ -67,7 +181,7 @@ func (c *InitCommand) Run(args []string) int {
c.Ui.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken)) c.Ui.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken))
if storedShares < 1 { if initRequest.StoredShares < 1 {
c.Ui.Output(fmt.Sprintf( c.Ui.Output(fmt.Sprintf(
"\n"+ "\n"+
"Vault initialized with %d keys and a key threshold of %d. Please\n"+ "Vault initialized with %d keys and a key threshold of %d. Please\n"+
@ -76,10 +190,10 @@ func (c *InitCommand) Run(args []string) int {
"to unseal it again.\n\n"+ "to unseal it again.\n\n"+
"Vault does not store the master key. Without at least %d keys,\n"+ "Vault does not store the master key. Without at least %d keys,\n"+
"your Vault will remain permanently sealed.", "your Vault will remain permanently sealed.",
shares, initRequest.SecretShares,
threshold, initRequest.SecretThreshold,
threshold, initRequest.SecretThreshold,
threshold, initRequest.SecretThreshold,
)) ))
} else { } else {
c.Ui.Output( c.Ui.Output(
@ -92,8 +206,8 @@ func (c *InitCommand) Run(args []string) int {
"\n"+ "\n"+
"Recovery key initialized with %d keys and a key threshold of %d. Please\n"+ "Recovery key initialized with %d keys and a key threshold of %d. Please\n"+
"securely distribute the above keys.", "securely distribute the above keys.",
recoveryShares, initRequest.RecoveryShares,
recoveryThreshold, initRequest.RecoveryThreshold,
)) ))
} }