diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index cc8e339333..b827310f0e 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -80,6 +80,8 @@ func TestCopy_response(t *testing.T) { } func TestHash(t *testing.T) { + now := time.Now().UTC() + cases := []struct { Input interface{} Output interface{} @@ -116,6 +118,24 @@ func TestHash(t *testing.T) { "foo", "foo", }, + { + &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + Lease: 1 * time.Hour, + LeaseIssue: now, + }, + + ClientToken: "foo", + }, + &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + Lease: 1 * time.Hour, + LeaseIssue: now, + }, + + ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", + }, + }, } for _, tc := range cases { diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 14dbf8212a..2b673271e3 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -61,7 +61,7 @@ func (b *backend) pathLogin( Policies: matched.Entry.Policies, DisplayName: matched.Entry.DisplayName, Metadata: map[string]string{ - "cert_name": matched.Entry.Name, + "cert_name": matched.Entry.Name, "common_name": connState.PeerCertificates[0].Subject.CommonName, }, LeaseOptions: logical.LeaseOptions{ @@ -187,5 +187,5 @@ func (b *backend) pathLoginRenew( return nil, nil } - return framework.LeaseExtend(cert.Lease, 0)(req, d) + return framework.LeaseExtend(cert.Lease, 0, false)(req, d) } diff --git a/builtin/credential/ldap/path_login.go b/builtin/credential/ldap/path_login.go index 5b77f5772e..ad566b771a 100644 --- a/builtin/credential/ldap/path_login.go +++ b/builtin/credential/ldap/path_login.go @@ -77,7 +77,7 @@ func (b *backend) pathLoginRenew( return logical.ErrorResponse("policies have changed, revoking login"), nil } - return framework.LeaseExtend(1*time.Hour, 0)(req, d) + return framework.LeaseExtend(1*time.Hour, 0, false)(req, d) } const pathLoginSyn = ` diff --git a/builtin/credential/userpass/path_login.go b/builtin/credential/userpass/path_login.go index 54b0df1c39..7e427d825c 100644 --- a/builtin/credential/userpass/path_login.go +++ b/builtin/credential/userpass/path_login.go @@ -68,7 +68,7 @@ func (b *backend) pathLoginRenew( return nil, nil } - return framework.LeaseExtend(1*time.Hour, 0)(req, d) + return framework.LeaseExtend(1*time.Hour, 0, false)(req, d) } const pathLoginSyn = ` diff --git a/builtin/logical/aws/secret_access_keys.go b/builtin/logical/aws/secret_access_keys.go index 9ddb4d9a36..5d03b312dc 100644 --- a/builtin/logical/aws/secret_access_keys.go +++ b/builtin/logical/aws/secret_access_keys.go @@ -115,7 +115,7 @@ func (b *backend) secretAccessKeysRenew( lease = &configLease{Lease: 1 * time.Hour} } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) return f(req, d) } diff --git a/builtin/logical/consul/secret_token.go b/builtin/logical/consul/secret_token.go index 06679b1313..e6e83de6b0 100644 --- a/builtin/logical/consul/secret_token.go +++ b/builtin/logical/consul/secret_token.go @@ -26,7 +26,7 @@ func secretToken() *framework.Secret { DefaultDuration: DefaultLeaseDuration, DefaultGracePeriod: DefaultGracePeriod, - Renew: framework.LeaseExtend(1*time.Hour, 0), + Renew: framework.LeaseExtend(0, 0, true), Revoke: secretTokenRevoke, } } diff --git a/builtin/logical/mysql/secret_creds.go b/builtin/logical/mysql/secret_creds.go index 5bed159764..c60beb0add 100644 --- a/builtin/logical/mysql/secret_creds.go +++ b/builtin/logical/mysql/secret_creds.go @@ -44,7 +44,7 @@ func (b *backend) secretCredsRenew( lease = &configLease{Lease: 1 * time.Hour} } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) return f(req, d) } diff --git a/builtin/logical/postgresql/path_role_create.go b/builtin/logical/postgresql/path_role_create.go index 497b9a6c63..7d74148512 100644 --- a/builtin/logical/postgresql/path_role_create.go +++ b/builtin/logical/postgresql/path_role_create.go @@ -2,7 +2,6 @@ package postgresql import ( "fmt" - "math/rand" "time" "github.com/hashicorp/vault/logical" @@ -51,10 +50,15 @@ func (b *backend) pathRoleCreateRead( lease = &configLease{Lease: 1 * time.Hour} } - // Generate the username, password and expiration - username := fmt.Sprintf( - "vault-%s-%d-%d", - req.DisplayName, time.Now().Unix(), rand.Int31n(10000)) + // Generate the username, password and expiration. PG limits user to 63 characters + displayName := req.DisplayName + if len(displayName) > 26 { + displayName = displayName[:26] + } + username := fmt.Sprintf("%s-%s", displayName, generateUUID()) + if len(username) > 63 { + username = username[:63] + } password := generateUUID() expiration := time.Now().UTC(). Add(lease.Lease + time.Duration((float64(lease.Lease) * 0.1))). diff --git a/builtin/logical/postgresql/secret_creds.go b/builtin/logical/postgresql/secret_creds.go index 00714eb6f7..204daa073c 100644 --- a/builtin/logical/postgresql/secret_creds.go +++ b/builtin/logical/postgresql/secret_creds.go @@ -58,7 +58,7 @@ func (b *backend) secretCredsRenew( lease = &configLease{Lease: 1 * time.Hour} } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) resp, err := f(req, d) if err != nil { return nil, err diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index b63f903b13..d719d42c64 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -15,11 +15,13 @@ func Backend() *framework.Backend { PathsSpecial: &logical.Paths{ Root: []string{ "keys/*", + "raw/*", }, }, Paths: []*framework.Path{ pathKeys(), + pathRaw(), pathEncrypt(), pathDecrypt(), }, diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index c4c7c3d435..c03ab0e326 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -21,10 +21,27 @@ func TestBackend_basic(t *testing.T) { Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test"), testAccStepReadPolicy(t, "test", false), + testAccStepReadRaw(t, "test", false), testAccStepEncrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData), testAccStepDeletePolicy(t, "test"), testAccStepReadPolicy(t, "test", true), + testAccStepReadRaw(t, "test", true), + }, + }) +} + +func TestBackend_upsert(t *testing.T) { + decryptData := make(map[string]interface{}) + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testAccStepReadPolicy(t, "test", true), + testAccStepEncrypt(t, "test", testPlaintext, decryptData), + testAccStepReadPolicy(t, "test", false), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepDeletePolicy(t, "test"), + testAccStepReadPolicy(t, "test", true), }, }) } @@ -65,6 +82,43 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicalte return err } + if d.Name != name { + return fmt.Errorf("bad: %#v", d) + } + if d.CipherMode != "aes-gcm" { + return fmt.Errorf("bad: %#v", d) + } + // Should NOT get a key back + if d.Key != nil { + return fmt.Errorf("bad: %#v", d) + } + return nil + }, + } +} + +func testAccStepReadRaw(t *testing.T, name string, expectNone bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "raw/" + name, + Check: func(resp *logical.Response) error { + if resp == nil && !expectNone { + return fmt.Errorf("missing response") + } else if expectNone { + if resp != nil { + return fmt.Errorf("response when expecting none") + } + return nil + } + var d struct { + Name string `mapstructure:"name"` + Key []byte `mapstructure:"key"` + CipherMode string `mapstructure:"cipher_mode"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if d.Name != name { return fmt.Errorf("bad: %#v", d) } diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index d30d72f3b4..761af7e352 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" + "fmt" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -56,7 +57,10 @@ func pathEncryptWrite( // Error if invalid policy if p == nil { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + p, err = generatePolicy(req.Storage, name) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest + } } // Guard against a potentially invalid cipher-mode diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index b856964689..f83ae4037e 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -45,6 +45,41 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) { return p, nil } +// generatePolicy is used to create a new named policy with +// a randomly generated key +func generatePolicy(storage logical.Storage, name string) (*Policy, error) { + // Create the policy object + p := &Policy{ + Name: name, + CipherMode: "aes-gcm", + } + + // Generate a 256bit key + p.Key = make([]byte, 32) + _, err := rand.Read(p.Key) + if err != nil { + return nil, err + } + + // Encode the policy + buf, err := p.Serialize() + if err != nil { + return nil, err + } + + // Write the policy into storage + err = storage.Put(&logical.StorageEntry{ + Key: "policy/" + name, + Value: buf, + }) + if err != nil { + return nil, err + } + + // Return the policy + return p, nil +} + func pathKeys() *framework.Path { return &framework.Path{ Pattern: `keys/(?P\w+)`, @@ -79,34 +114,9 @@ func pathPolicyWrite( return nil, nil } - // Create the policy object - p := &Policy{ - Name: name, - CipherMode: "aes-gcm", - } - - // Generate a 256bit key - p.Key = make([]byte, 32) - _, err = rand.Read(p.Key) - if err != nil { - return nil, err - } - - // Encode the policy - buf, err := p.Serialize() - if err != nil { - return nil, err - } - - // Write the policy into storage - err = req.Storage.Put(&logical.StorageEntry{ - Key: "policy/" + name, - Value: buf, - }) - if err != nil { - return nil, err - } - return nil, nil + // Generate the policy + _, err = generatePolicy(req.Storage, name) + return nil, err } func pathPolicyRead( @@ -124,7 +134,6 @@ func pathPolicyRead( resp := &logical.Response{ Data: map[string]interface{}{ "name": p.Name, - "key": p.Key, "cipher_mode": p.CipherMode, }, } diff --git a/builtin/logical/transit/path_raw.go b/builtin/logical/transit/path_raw.go new file mode 100644 index 0000000000..ebe411a571 --- /dev/null +++ b/builtin/logical/transit/path_raw.go @@ -0,0 +1,54 @@ +package transit + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRaw() *framework.Path { + return &framework.Path{ + Pattern: `raw/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: pathRawRead, + }, + + HelpSynopsis: pathPolicyHelpSyn, + HelpDescription: pathPolicyHelpDesc, + } +} + +func pathRawRead( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if p == nil { + return nil, nil + } + + // Return the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "name": p.Name, + "key": p.Key, + "cipher_mode": p.CipherMode, + }, + } + return resp, nil +} + +const pathRawHelpSyn = `Fetch raw keys for named encrption keys` + +const pathRawHelpDesc = ` +This path is used to get the underlying encryption keys used for the +named keys that are available. +` diff --git a/cli/commands.go b/cli/commands.go index d007a96ba9..d115f4a4fb 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -2,6 +2,8 @@ package cli import ( "os" + "os/signal" + "syscall" auditFile "github.com/hashicorp/vault/builtin/audit/file" auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" @@ -70,11 +72,12 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { "transit": transit.Factory, "mysql": mysql.Factory, }, + ShutdownCh: makeShutdownCh(), }, nil }, - "help": func() (cli.Command, error) { - return &command.HelpCommand{ + "path-help": func() (cli.Command, error) { + return &command.PathHelpCommand{ Meta: meta, }, nil }, @@ -270,3 +273,20 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { }, } } + +// makeShutdownCh returns a channel that can be used for shutdown +// notifications for commands. This channel will send a message for every +// interrupt or SIGTERM received. +func makeShutdownCh() <-chan struct{} { + resultCh := make(chan struct{}) + + signalCh := make(chan os.Signal, 4) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + go func() { + for { + <-signalCh + resultCh <- struct{}{} + } + }() + return resultCh +} diff --git a/cli/help.go b/cli/help.go index 6c3e63d8b2..b614212c4f 100644 --- a/cli/help.go +++ b/cli/help.go @@ -13,14 +13,14 @@ import ( // HelpFunc is a cli.HelpFunc that can is used to output the help for Vault. func HelpFunc(commands map[string]cli.CommandFactory) string { commonNames := map[string]struct{}{ - "delete": struct{}{}, - "help": struct{}{}, - "read": struct{}{}, - "renew": struct{}{}, - "revoke": struct{}{}, - "write": struct{}{}, - "server": struct{}{}, - "status": struct{}{}, + "delete": struct{}{}, + "path-help": struct{}{}, + "read": struct{}{}, + "renew": struct{}{}, + "revoke": struct{}{}, + "write": struct{}{}, + "server": struct{}{}, + "status": struct{}{}, } // Determine the maximum key length, and classify based on type diff --git a/command/auth.go b/command/auth.go index 130dd7d278..f31b2a358a 100644 --- a/command/auth.go +++ b/command/auth.go @@ -114,6 +114,15 @@ func (c *AuthCommand) Run(args []string) int { return 0 } + // Warn if the VAULT_TOKEN environment variable is set, as that will take + // precedence + if os.Getenv("VAULT_TOKEN") != "" { + c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") + c.Ui.Output(" The environment variable takes precedence over the value") + c.Ui.Output(" set by the auth command. Either update the value of the") + c.Ui.Output(" environment variable or unset it to use the new token.\n") + } + var vars map[string]string if len(args) > 0 { builder := kvbuilder.Builder{Stdin: os.Stdin} diff --git a/command/format.go b/command/format.go index d3fad141a4..0b6aa7d968 100644 --- a/command/format.go +++ b/command/format.go @@ -53,6 +53,16 @@ func outputFormatTable(ui cli.Ui, s *api.Secret, whitespace bool) int { "lease_renewable %s %s", config.Delim, strconv.FormatBool(s.Renewable))) } + if s.Auth != nil { + input = append(input, fmt.Sprintf("token %s %s", config.Delim, s.Auth.ClientToken)) + input = append(input, fmt.Sprintf("token_duration %s %d", config.Delim, s.Auth.LeaseDuration)) + input = append(input, fmt.Sprintf("token_renewable %s %v", config.Delim, s.Auth.Renewable)) + input = append(input, fmt.Sprintf("token_policies %s %v", config.Delim, s.Auth.Policies)) + for k, v := range s.Auth.Metadata { + input = append(input, fmt.Sprintf("token_meta_%s %s %#v", k, config.Delim, v)) + } + } + for k, v := range s.Data { input = append(input, fmt.Sprintf("%s %s %v", k, config.Delim, v)) } diff --git a/command/help.go b/command/path_help.go similarity index 69% rename from command/help.go rename to command/path_help.go index f832f07bc6..792ea9e3a4 100644 --- a/command/help.go +++ b/command/path_help.go @@ -5,12 +5,12 @@ import ( "strings" ) -// HelpCommand is a Command that lists the mounts. -type HelpCommand struct { +// PathHelpCommand is a Command that lists the mounts. +type PathHelpCommand struct { Meta } -func (c *HelpCommand) Run(args []string) int { +func (c *PathHelpCommand) Run(args []string) int { flags := c.Meta.FlagSet("help", FlagSetDefault) flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { @@ -35,8 +35,15 @@ func (c *HelpCommand) Run(args []string) int { help, err := client.Help(path) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading help: %s", err)) + if strings.Contains(err.Error(), "Vault is sealed") { + c.Ui.Error(`Error: Vault is sealed. + +The path-help command requires the Vault to be unsealed so that +mount points of secret backends are known.`) + } else { + c.Ui.Error(fmt.Sprintf( + "Error reading help: %s", err)) + } return 1 } @@ -44,13 +51,13 @@ func (c *HelpCommand) Run(args []string) int { return 0 } -func (c *HelpCommand) Synopsis() string { +func (c *PathHelpCommand) Synopsis() string { return "Look up the help for a path" } -func (c *HelpCommand) Help() string { +func (c *PathHelpCommand) Help() string { helpText := ` -Usage: vault help [options] path +Usage: vault path-help [options] path Look up the help for a path. @@ -58,6 +65,9 @@ Usage: vault help [options] path providers provide built-in help. This command looks up and outputs that help. + The command requires that the Vault be unsealed, because otherwise + the mount points of the backends are unknown. + General Options: -address=addr The address of the Vault server. diff --git a/command/help_test.go b/command/path_help_test.go similarity index 95% rename from command/help_test.go rename to command/path_help_test.go index c4facc0ca8..faec9723d9 100644 --- a/command/help_test.go +++ b/command/path_help_test.go @@ -14,7 +14,7 @@ func TestHelp(t *testing.T) { defer ln.Close() ui := new(cli.MockUi) - c := &HelpCommand{ + c := &PathHelpCommand{ Meta: Meta{ ClientToken: token, Ui: ui, diff --git a/command/read.go b/command/read.go index 8f94bcba8f..983ae56bb3 100644 --- a/command/read.go +++ b/command/read.go @@ -22,12 +22,16 @@ func (c *ReadCommand) Run(args []string) int { } args = flags.Args() - if len(args) < 1 || len(args) > 2 { - c.Ui.Error("read expects one or two arguments") + if len(args) != 1 { + c.Ui.Error("read expects one argument") flags.Usage() return 1 } + path := args[0] + if path[0] == '/' { + path = path[1:] + } client, err := c.Client() if err != nil { @@ -98,7 +102,7 @@ Read Options: -format=table The format for output. By default it is a whitespace- delimited table. This can also be json. - -field=field If included, the raw value of the specified field + -field=field If included, the raw value of the specified field will be output raw to stdout. ` diff --git a/command/server.go b/command/server.go index 2796b22cda..f3db676d76 100644 --- a/command/server.go +++ b/command/server.go @@ -32,6 +32,7 @@ type ServerCommand struct { CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory + ShutdownCh <-chan struct{} Meta } @@ -154,7 +155,7 @@ func (c *ServerCommand) Run(args []string) int { "immediately begin using the Vault CLI.\n\n"+ "The only step you need to take is to set the following\n"+ "environment variables:\n\n"+ - " export VAULT_ADDR='http://127.0.0.1:8200'\n"+ + " export VAULT_ADDR='http://127.0.0.1:8200'\n\n"+ "The unseal key and root token are reproduced below in case you\n"+ "want to seal/unseal the Vault or play with authentication.\n\n"+ "Unseal Key: %s\nRoot Token: %s\n", @@ -237,7 +238,14 @@ func (c *ServerCommand) Run(args []string) int { // Release the log gate. logGate.Flush() - <-make(chan struct{}) + // Wait for shutdown + select { + case <-c.ShutdownCh: + c.Ui.Output("==> Vault shutdown triggered") + if err := core.Shutdown(); err != nil { + c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err)) + } + } return 0 } @@ -407,8 +415,8 @@ General Options: specified multiple times. If it is a directory, all files with a ".hcl" or ".json" suffix will be loaded. - -dev Enables Dev mode. In this mode, Vault is completely - in-memory and unsealed. Do not run the Dev server in + -dev Enables Dev mode. In this mode, Vault is completely + in-memory and unsealed. Do not run the Dev server in production! -log-level=info Log verbosity. Defaults to "info", will be outputted diff --git a/command/token_create.go b/command/token_create.go index 85107e3fde..21d85d92b5 100644 --- a/command/token_create.go +++ b/command/token_create.go @@ -15,12 +15,14 @@ type TokenCreateCommand struct { } func (c *TokenCreateCommand) Run(args []string) int { + var format string var displayName, lease string var orphan bool var metadata map[string]string var numUses int var policies []string flags := c.Meta.FlagSet("mount", FlagSetDefault) + flags.StringVar(&format, "format", "table", "") flags.StringVar(&displayName, "display-name", "", "") flags.StringVar(&lease, "lease", "", "") flags.BoolVar(&orphan, "orphan", false, "") @@ -61,8 +63,7 @@ func (c *TokenCreateCommand) Run(args []string) int { return 2 } - c.Ui.Output(secret.Auth.ClientToken) - return 0 + return OutputSecret(c.Ui, format, secret) } func (c *TokenCreateCommand) Synopsis() string { @@ -121,6 +122,10 @@ Token Options: -use-limit=5 The number of times this token can be used until it is automatically revoked. + + -format=table The format for output. By default it is a whitespace- + delimited table. This can also be json. + ` return strings.TrimSpace(helpText) } diff --git a/command/token_create_test.go b/command/token_create_test.go index 2b659165c5..93482bad1e 100644 --- a/command/token_create_test.go +++ b/command/token_create_test.go @@ -1,6 +1,7 @@ package command import ( + "strings" "testing" "github.com/hashicorp/vault/http" @@ -27,4 +28,10 @@ func TestTokenCreate(t *testing.T) { if code := c.Run(args); code != 0 { t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) } + + // Ensure we get lease info + output := ui.OutputWriter.String() + if !strings.Contains(output, "token_duration") { + t.Fatalf("bad: %#v", output) + } } diff --git a/command/write.go b/command/write.go index fefa23e72d..ba0fdd823d 100644 --- a/command/write.go +++ b/command/write.go @@ -19,15 +19,18 @@ type WriteCommand struct { func (c *WriteCommand) Run(args []string) int { var format string + var force bool flags := c.Meta.FlagSet("write", FlagSetDefault) flags.StringVar(&format, "format", "table", "") + flags.BoolVar(&force, "force", false, "") + flags.BoolVar(&force, "f", false, "") flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { return 1 } args = flags.Args() - if len(args) < 2 { + if len(args) < 2 && !force { c.Ui.Error("write expects at least two arguments") flags.Usage() return 1 @@ -117,6 +120,12 @@ General Options: not recommended. This is especially not recommended for unsealing a vault. +Write Options: + + -f | -force Force the write to continue without any data values + specified. This allows writing to keys that do not + need or expect any fields to be specified. + ` return strings.TrimSpace(helpText) } diff --git a/command/write_test.go b/command/write_test.go index ce570f3730..51774e3c0b 100644 --- a/command/write_test.go +++ b/command/write_test.go @@ -246,3 +246,26 @@ func TestWrite_Output(t *testing.T) { t.Fatalf("bad: %s", string(ui.OutputWriter.Bytes())) } } + +func TestWrite_force(t *testing.T) { + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := http.TestServer(t, core) + defer ln.Close() + + ui := new(cli.MockUi) + c := &WriteCommand{ + Meta: Meta{ + ClientToken: token, + Ui: ui, + }, + } + + args := []string{ + "-address", addr, + "-force", + "sys/rotate", + } + if code := c.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } +} diff --git a/http/sys_mount.go b/http/sys_mount.go index 36499cfabf..68eae02e41 100644 --- a/http/sys_mount.go +++ b/http/sys_mount.go @@ -13,7 +13,7 @@ func handleSysMounts(core *vault.Core) http.Handler { switch r.Method { case "GET": handleSysListMounts(core).ServeHTTP(w, r) - case "POST": + case "PUT", "POST": fallthrough case "DELETE": handleSysMountUnmount(core, w, r) @@ -27,8 +27,7 @@ func handleSysMounts(core *vault.Core) http.Handler { func handleSysRemount(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.Method { - case "POST": - case "PUT": + case "PUT", "POST": default: respondError(w, http.StatusMethodNotAllowed, nil) return @@ -80,7 +79,7 @@ func handleSysListMounts(core *vault.Core) http.Handler { func handleSysMountUnmount(core *vault.Core, w http.ResponseWriter, r *http.Request) { switch r.Method { - case "POST": + case "PUT", "POST": case "DELETE": default: respondError(w, http.StatusMethodNotAllowed, nil) @@ -100,7 +99,7 @@ func handleSysMountUnmount(core *vault.Core, w http.ResponseWriter, r *http.Requ } switch r.Method { - case "POST": + case "PUT", "POST": handleSysMount(core, w, r, path) case "DELETE": handleSysUnmount(core, w, r, path) diff --git a/http/sys_mount_test.go b/http/sys_mount_test.go index c91cd66024..9e7a943405 100644 --- a/http/sys_mount_test.go +++ b/http/sys_mount_test.go @@ -76,6 +76,22 @@ func TestSysMount(t *testing.T) { } } +func TestSysMount_put(t *testing.T) { + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + resp := testHttpPut(t, addr+"/v1/sys/mounts/foo", map[string]interface{}{ + "type": "generic", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + + // The TestSysMount test tests the thing is actually created. See that test + // for more info. +} + func TestSysRemount(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) diff --git a/logical/framework/backend.go b/logical/framework/backend.go index b53c9f21b4..042c0e70e3 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -373,6 +373,8 @@ func (t FieldType) Zero() interface{} { return false case TypeMap: return map[string]interface{}{} + case TypeDurationSecond: + return 0 default: panic("unknown type: " + t.String()) } diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index dee5d9419d..4058d194de 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -215,7 +215,7 @@ func TestBackendHandleRequest_renew(t *testing.T) { func TestBackendHandleRequest_renewExtend(t *testing.T) { secret := &Secret{ Type: "foo", - Renew: LeaseExtend(0, 0), + Renew: LeaseExtend(0, 0, false), DefaultDuration: 5 * time.Minute, } b := &Backend{ @@ -508,6 +508,16 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) { &FieldSchema{Type: TypeString}, "", }, + + "default duration set": { + &FieldSchema{Type: TypeDurationSecond, Default: 60}, + 60, + }, + + "default duration not set": { + &FieldSchema{Type: TypeDurationSecond}, + 0, + }, } for name, tc := range cases { diff --git a/logical/framework/field_data.go b/logical/framework/field_data.go index e8255c8c5a..40d1ac182a 100644 --- a/logical/framework/field_data.go +++ b/logical/framework/field_data.go @@ -2,6 +2,9 @@ package framework import ( "fmt" + "strconv" + "strings" + "time" "github.com/mitchellh/mapstructure" ) @@ -64,13 +67,7 @@ func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) { } switch schema.Type { - case TypeBool: - fallthrough - case TypeInt: - fallthrough - case TypeMap: - fallthrough - case TypeString: + case TypeBool, TypeInt, TypeMap, TypeDurationSecond, TypeString: return d.getPrimitive(k, schema) default: return nil, false, @@ -114,6 +111,38 @@ func (d *FieldData) getPrimitive( } return result, true, nil + + case TypeDurationSecond: + var result int + switch inp := raw.(type) { + case int: + result = inp + case float32: + result = int(inp) + case float64: + result = int(inp) + case string: + // Look for a suffix otherwise its a plain second value + if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") { + dur, err := time.ParseDuration(inp) + if err != nil { + return nil, true, err + } + result = int(dur.Seconds()) + } else { + // Plain integer + val, err := strconv.ParseInt(inp, 10, 64) + if err != nil { + return nil, true, err + } + result = int(val) + } + + default: + return nil, false, fmt.Errorf("invalid input '%v'", raw) + } + return result, true, nil + default: panic(fmt.Sprintf("Unknown type: %s", schema.Type)) } diff --git a/logical/framework/field_data_test.go b/logical/framework/field_data_test.go index 000ded72a5..e6a32f8cb1 100644 --- a/logical/framework/field_data_test.go +++ b/logical/framework/field_data_test.go @@ -91,6 +91,50 @@ func TestFieldDataGet(t *testing.T) { "child": true, }, }, + + "duration type, string value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": "42", + }, + "foo", + 42, + }, + + "duration type, string duration value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": "42m", + }, + "foo", + 2520, + }, + + "duration type, int value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": 42, + }, + "foo", + 42, + }, + + "duration type, float value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": 42.0, + }, + "foo", + 42, + }, } for name, tc := range cases { diff --git a/logical/framework/field_type.go b/logical/framework/field_type.go index a02b77bcd8..d9d0ef3d24 100644 --- a/logical/framework/field_type.go +++ b/logical/framework/field_type.go @@ -9,6 +9,10 @@ const ( TypeInt TypeBool TypeMap + + // TypeDurationSecond represent as seconds, this can be either an + // integer or go duration format string (e.g. 24h) + TypeDurationSecond ) func (t FieldType) String() string { @@ -21,6 +25,8 @@ func (t FieldType) String() string { return "bool" case TypeMap: return "map" + case TypeDurationSecond: + return "duration (sec)" default: return "unknown type" } diff --git a/logical/framework/lease.go b/logical/framework/lease.go index 7203d516da..4ba250d26f 100644 --- a/logical/framework/lease.go +++ b/logical/framework/lease.go @@ -13,63 +13,61 @@ import ( // setting it to 2 hours forces a renewal within the next 2 hours again. // // maxSession is the maximum session length allowed since the original -// issue time. If this is zero, it is ignored,. -func LeaseExtend(max, maxSession time.Duration) OperationFunc { +// issue time. If this is zero, it is ignored. +// +// maxFromLease controls if the maximum renewal period comes from the existing +// lease. This means the value of `max` will be replaced with the existing +// lease duration. +func LeaseExtend(max, maxSession time.Duration, maxFromLease bool) OperationFunc { return func(req *logical.Request, data *FieldData) (*logical.Response, error) { lease := detectLease(req) if lease == nil { return nil, fmt.Errorf("no lease options for request") } + // Check if we should limit max + if maxFromLease { + max = lease.Lease + } + + // Sanity check the desired increment + switch { + // Protect against negative leases + case lease.LeaseIncrement < 0: + return logical.ErrorResponse( + "increment must be greater than 0"), logical.ErrInvalidRequest + + // If no lease increment, or too large of an increment, use the max + case max > 0 && lease.LeaseIncrement == 0, max > 0 && lease.LeaseIncrement > max: + lease.LeaseIncrement = max + } + + // Get the current time now := time.Now().UTC() // Check if we're passed the issue limit var maxSessionTime time.Time if maxSession > 0 { maxSessionTime = lease.LeaseIssue.Add(maxSession) - if maxSessionTime.Sub(now) <= 0 { + if maxSessionTime.Before(now) { return logical.ErrorResponse(fmt.Sprintf( "lease can only be renewed up to %s past original issue", maxSession)), logical.ErrInvalidRequest } } - // Protect against negative leases - if lease.LeaseIncrement < 0 { - return logical.ErrorResponse( - "increment must be greater than 0"), logical.ErrInvalidRequest - } - - // If the lease is zero, then assume max - if lease.LeaseIncrement == 0 { - lease.LeaseIncrement = max - } - - // If the increment is greater than the amount of time we have left - // on our session, set it to that. - if !maxSessionTime.IsZero() { - diff := maxSessionTime.Sub(lease.ExpirationTime()) - if diff < lease.LeaseIncrement { - lease.LeaseIncrement = diff - } + // The new lease is the minimum of the requested LeaseIncrement + // or the maxSessionTime + requestedLease := now.Add(lease.LeaseIncrement) + if !maxSessionTime.IsZero() && requestedLease.After(maxSessionTime) { + requestedLease = maxSessionTime } // Determine the requested lease - newLease := lease.IncrementedLease(lease.LeaseIncrement) - - if max > 0 { - // Determine if the requested lease is too long - maxExpiration := now.Add(max) - newExpiration := now.Add(newLease) - if newExpiration.Sub(maxExpiration) > 0 { - // The new expiration is past the max expiration. In this - // case, admit the longest lease we can. - newLease = maxExpiration.Sub(lease.ExpirationTime()) - } - } + newLeaseDuration := requestedLease.Sub(now) // Set the lease - lease.Lease = newLease + lease.Lease = newLeaseDuration return &logical.Response{Auth: req.Auth, Secret: req.Secret}, nil } } @@ -80,6 +78,5 @@ func detectLease(req *logical.Request) *logical.LeaseOptions { } else if req.Secret != nil { return &req.Secret.LeaseOptions } - return nil } diff --git a/logical/framework/lease_test.go b/logical/framework/lease_test.go index ad2fec97da..f22ce798d6 100644 --- a/logical/framework/lease_test.go +++ b/logical/framework/lease_test.go @@ -11,11 +11,12 @@ func TestLeaseExtend(t *testing.T) { now := time.Now().UTC().Round(time.Hour) cases := map[string]struct { - Max time.Duration - MaxSession time.Duration - Request time.Duration - Result time.Duration - Error bool + Max time.Duration + MaxSession time.Duration + Request time.Duration + Result time.Duration + MaxFromLease bool + Error bool }{ "valid request, good bounds": { Max: 30 * time.Hour, @@ -62,20 +63,26 @@ func TestLeaseExtend(t *testing.T) { Request: -7 * time.Hour, Error: true, }, + + "max form lease, request too large": { + Request: 10 * time.Hour, + MaxFromLease: true, + Result: time.Hour, + }, } for name, tc := range cases { req := &logical.Request{ Auth: &logical.Auth{ LeaseOptions: logical.LeaseOptions{ - Lease: 1 * time.Second, + Lease: 1 * time.Hour, LeaseIssue: now, LeaseIncrement: tc.Request, }, }, } - callback := LeaseExtend(tc.Max, tc.MaxSession) + callback := LeaseExtend(tc.Max, tc.MaxSession, tc.MaxFromLease) resp, err := callback(req, nil) if (err != nil) != tc.Error { t.Fatalf("bad: %s\nerr: %s", name, err) diff --git a/logical/lease.go b/logical/lease.go index 834242fd26..32ccb65978 100644 --- a/logical/lease.go +++ b/logical/lease.go @@ -46,22 +46,8 @@ func (l *LeaseOptions) LeaseTotal() time.Duration { // ExpirationTime computes the time until expiration including the grace period func (l *LeaseOptions) ExpirationTime() time.Time { var expireTime time.Time - if !l.LeaseIssue.IsZero() && l.Lease > 0 { - expireTime = l.LeaseIssue.UTC().Add(l.LeaseTotal()) + if l.LeaseEnabled() { + expireTime = time.Now().UTC().Add(l.LeaseTotal()) } - return expireTime } - -// IncrementedLease returns the lease duration that would need to set -// in order to increment the _current_ lease by the given duration -// if the auth were re-issued right now. -func (l *LeaseOptions) IncrementedLease(inc time.Duration) time.Duration { - var result time.Duration - expireTime := l.ExpirationTime() - if expireTime.IsZero() { - return result - } - - return expireTime.Add(inc).Sub(time.Now().UTC()) -} diff --git a/logical/lease_test.go b/logical/lease_test.go index ba74c06426..02916bc817 100644 --- a/logical/lease_test.go +++ b/logical/lease_test.go @@ -5,17 +5,6 @@ import ( "time" ) -func TestLeaseOptionsIncrementedLease(t *testing.T) { - var l LeaseOptions - l.Lease = 1 * time.Second - l.LeaseIssue = time.Now().UTC() - - actual := l.IncrementedLease(1 * time.Second) - if actual > 3*time.Second || actual < 1*time.Second { - t.Fatalf("bad: %s", actual) - } -} - func TestLeaseOptionsLeaseTotal(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour @@ -66,12 +55,11 @@ func TestLeaseOptionsLeaseTotal_negGrace(t *testing.T) { func TestLeaseOptionsExpirationTime(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour - l.LeaseIssue = time.Now().UTC() - actual := l.ExpirationTime() - expected := l.LeaseIssue.Add(l.Lease) - if !actual.Equal(expected) { - t.Fatalf("bad: %s", actual) + limit := time.Now().UTC().Add(time.Hour) + exp := l.ExpirationTime() + if exp.Before(limit) { + t.Fatalf("bad: %s", exp) } } @@ -79,11 +67,10 @@ func TestLeaseOptionsExpirationTime_grace(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour l.LeaseGracePeriod = 30 * time.Minute - l.LeaseIssue = time.Now().UTC() + limit := time.Now().UTC().Add(time.Hour + 30*time.Minute) actual := l.ExpirationTime() - expected := l.LeaseIssue.Add(l.Lease + l.LeaseGracePeriod) - if !actual.Equal(expected) { + if actual.Before(limit) { t.Fatalf("bad: %s", actual) } } @@ -92,11 +79,10 @@ func TestLeaseOptionsExpirationTime_graceNegative(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour l.LeaseGracePeriod = -1 * 30 * time.Minute - l.LeaseIssue = time.Now().UTC() + limit := time.Now().UTC().Add(time.Hour) actual := l.ExpirationTime() - expected := l.LeaseIssue.Add(l.Lease) - if !actual.Equal(expected) { + if actual.Before(limit) { t.Fatalf("bad: %s", actual) } } diff --git a/physical/mysql.go b/physical/mysql.go new file mode 100644 index 0000000000..5db23b9384 --- /dev/null +++ b/physical/mysql.go @@ -0,0 +1,175 @@ +package physical + +import ( + "database/sql" + "fmt" + "sort" + "strings" + "time" + + "github.com/armon/go-metrics" + _ "github.com/go-sql-driver/mysql" +) + +// MySQLBackend is a physical backend that stores data +// within MySQL database. +type MySQLBackend struct { + dbTable string + client *sql.DB + statements map[string]*sql.Stmt +} + +// newMySQLBackend constructs a MySQL backend using the given API client and +// server address and credential for accessing mysql database. +func newMySQLBackend(conf map[string]string) (Backend, error) { + // Get the MySQL credentials to perform read/write operations. + username, ok := conf["username"] + if !ok || username == "" { + return nil, fmt.Errorf("missing username") + } + password, ok := conf["password"] + if !ok || username == "" { + return nil, fmt.Errorf("missing password") + } + + // Get or set MySQL server address. Defaults to localhost and default port(3306) + address, ok := conf["address"] + if !ok { + address = "127.0.0.1:3306" + } + + // Get the MySQL database and table details. + database, ok := conf["database"] + if !ok { + database = "vault" + } + table, ok := conf["table"] + if !ok { + table = "vault" + } + dbTable := database + "." + table + + // Create MySQL handle for the database. + dsn := username + ":" + password + "@tcp(" + address + ")/" + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, fmt.Errorf("failed to connect to mysql: %v", err) + } + + // Create the required database if it doesn't exists. + if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + database); err != nil { + return nil, fmt.Errorf("failed to create mysql database: %v", err) + } + + // Create the required table if it doesn't exists. + create_query := "CREATE TABLE IF NOT EXISTS " + dbTable + + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" + if _, err := db.Exec(create_query); err != nil { + return nil, fmt.Errorf("failed to create mysql table: %v", err) + } + + // Setup the backend. + m := &MySQLBackend{ + dbTable: dbTable, + client: db, + statements: make(map[string]*sql.Stmt), + } + + // Prepare all the statements required + statements := map[string]string{ + "put": "INSERT INTO " + dbTable + + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)", + "get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?", + "delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?", + "list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?", + } + for name, query := range statements { + if err := m.prepare(name, query); err != nil { + return nil, err + } + } + return m, nil +} + +// prepare is a helper to prepare a query for future execution +func (m *MySQLBackend) prepare(name, query string) error { + stmt, err := m.client.Prepare(query) + if err != nil { + return fmt.Errorf("failed to prepare '%s': %v", name, err) + } + m.statements[name] = stmt + return nil +} + +// Put is used to insert or update an entry. +func (m *MySQLBackend) Put(entry *Entry) error { + defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) + + _, err := m.statements["put"].Exec(entry.Key, entry.Value) + if err != nil { + return err + } + return nil +} + +// Get is used to fetch and entry. +func (m *MySQLBackend) Get(key string) (*Entry, error) { + defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) + + var result []byte + err := m.statements["get"].QueryRow(key).Scan(&result) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + + ent := &Entry{ + Key: key, + Value: result, + } + return ent, nil +} + +// Delete is used to permanently delete an entry +func (m *MySQLBackend) Delete(key string) error { + defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) + + _, err := m.statements["delete"].Exec(key) + if err != nil { + return err + } + return nil +} + +// List is used to list all the keys under a given +// prefix, up to the next prefix. +func (m *MySQLBackend) List(prefix string) ([]string, error) { + defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) + + // Add the % wildcard to the prefix to do the prefix search + likePrefix := prefix + "%" + rows, err := m.statements["list"].Query(likePrefix) + + var keys []string + for rows.Next() { + var key string + err = rows.Scan(&key) + if err != nil { + return nil, fmt.Errorf("failed to scan rows: %v", err) + } + + key = strings.TrimPrefix(key, prefix) + if i := strings.Index(key, "/"); i == -1 { + // Add objects only from the current 'folder' + keys = append(keys, key) + } else if i != -1 { + // Add truncated 'folder' paths + keys = appendIfMissing(keys, string(key[:i+1])) + } + } + + sort.Strings(keys) + return keys, nil +} diff --git a/physical/mysql_test.go b/physical/mysql_test.go new file mode 100644 index 0000000000..a28fb1441f --- /dev/null +++ b/physical/mysql_test.go @@ -0,0 +1,53 @@ +package physical + +import ( + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" +) + +func TestMySQLBackend(t *testing.T) { + address := os.Getenv("MYSQL_ADDR") + if address == "" { + t.SkipNow() + } + + database := os.Getenv("MYSQL_DB") + if database == "" { + database = "test" + } + + table := os.Getenv("MYSQL_TABLE") + if table == "" { + table = "test" + } + + username := os.Getenv("MYSQL_USERNAME") + password := os.Getenv("MYSQL_PASSWORD") + + // Run vault tests + b, err := NewBackend("mysql", map[string]string{ + "address": address, + "database": database, + "table": table, + "username": username, + "password": password, + }) + + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + + defer func() { + mysql := b.(*MySQLBackend) + _, err := mysql.client.Exec("DROP TABLE " + mysql.dbTable) + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } + }() + + testBackend(t, b) + testBackend_ListPrefix(t, b) + +} diff --git a/physical/physical.go b/physical/physical.go index b992c10102..bd307e8dea 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -84,4 +84,5 @@ var BuiltinBackends = map[string]Factory{ "file": newFileBackend, "s3": newS3Backend, "etcd": newEtcdBackend, + "mysql": newMySQLBackend, } diff --git a/vault/core.go b/vault/core.go index 723fe23796..3c67f63fad 100644 --- a/vault/core.go +++ b/vault/core.go @@ -328,6 +328,21 @@ func NewCore(conf *CoreConfig) (*Core, error) { return c, nil } +// Shutdown is invoked when the Vault instance is about to be terminated. It +// should not be accessible as part of an API call as it will cause an availability +// problem. It is only used to gracefully quit in the case of HA so that failover +// happens as quickly as possible. +func (c *Core) Shutdown() error { + c.stateLock.Lock() + defer c.stateLock.Unlock() + if c.sealed { + return nil + } + + // Seal the Vault, causes a leader stepdown + return c.sealInternal() +} + // HandleRequest is used to handle a new incoming request func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) { c.stateLock.RLock() @@ -413,8 +428,9 @@ func (c *Core) handleRequest(req *logical.Request) (*logical.Response, error) { } // Only the token store is allowed to return an auth block, for any - // other request this is an internal error - if resp != nil && resp.Auth != nil { + // other request this is an internal error. We exclude renewal of a token, + // since it does not need to be re-registered + if resp != nil && resp.Auth != nil && !strings.HasPrefix(req.Path, "auth/token/renew/") { if !strings.HasPrefix(req.Path, "auth/token/") { c.logger.Printf( "[ERR] core: unexpected Auth response for non-token backend "+ @@ -929,6 +945,14 @@ func (c *Core) Seal(token string) error { return err } + // Seal the Vault + return c.sealInternal() +} + +// sealInternal is an internal method used to seal the vault. +// It does not do any authorization checking. The stateLock must +// be held prior to calling. +func (c *Core) sealInternal() error { // Enable that we are sealed to prevent furthur transactions c.sealed = true diff --git a/vault/core_test.go b/vault/core_test.go index c27f03460d..4b48c59acd 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -348,6 +348,17 @@ func TestCore_SealUnseal(t *testing.T) { } } +// Attempt to shutdown after unseal +func TestCore_Shutdown(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + if err := c.Shutdown(); err != nil { + t.Fatalf("err: %v", err) + } + if sealed, err := c.Sealed(); err != nil || !sealed { + t.Fatalf("err: %v", err) + } +} + // Attempt to seal bad token func TestCore_Seal_BadToken(t *testing.T) { c, _, _ := TestCoreUnsealed(t) @@ -1368,6 +1379,55 @@ func TestCore_RenewSameLease(t *testing.T) { } } +// Renew of a token should not create a new lease +func TestCore_RenewToken_SingleRegister(t *testing.T) { + c, _, root := TestCoreUnsealed(t) + + // Create a new token + req := &logical.Request{ + Operation: logical.WriteOperation, + Path: "auth/token/create", + Data: map[string]interface{}{ + "lease": "1h", + }, + ClientToken: root, + } + resp, err := c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + newClient := resp.Auth.ClientToken + + // Renew the token + req = logical.TestRequest(t, logical.WriteOperation, "auth/token/renew/"+newClient) + req.ClientToken = newClient + resp, err = c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Revoke using the renew prefix + req = logical.TestRequest(t, logical.WriteOperation, "sys/revoke-prefix/auth/token/renew/") + req.ClientToken = root + resp, err = c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Verify our token is still valid (e.g. we did not get invalided by the revoke) + req = logical.TestRequest(t, logical.ReadOperation, "auth/token/lookup/"+newClient) + req.ClientToken = newClient + resp, err = c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Verify the token exists + if resp.Data["id"] != newClient { + t.Fatalf("bad: %#v", resp.Data) + } +} + // Based on bug GH-203, attempt to disable a credential backend with leased secrets func TestCore_EnableDisableCred_WithLease(t *testing.T) { // Create a badass credential backend that always logs in as armon diff --git a/vault/expiration.go b/vault/expiration.go index 601722cf80..3df1738ed2 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -337,7 +337,6 @@ func (m *ExpirationManager) RenewToken(source string, token string, // Attach the ClientToken resp.Auth.ClientToken = token resp.Auth.LeaseIncrement = 0 - resp.Auth.LeaseIssue = time.Now().UTC() // Update the lease entry le.Auth = resp.Auth @@ -366,9 +365,6 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons return "", err } - // Setup some of the fields on auth - resp.Secret.LeaseIssue = time.Now().UTC() - // Create a lease entry le := leaseEntry{ LeaseID: path.Join(req.Path, generateUUID()), @@ -376,7 +372,7 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons Path: req.Path, Data: resp.Data, Secret: resp.Secret, - IssueTime: resp.Secret.LeaseIssue, + IssueTime: time.Now().UTC(), ExpireTime: resp.Secret.ExpirationTime(), } @@ -403,16 +399,13 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) error { defer metrics.MeasureSince([]string{"expire", "register-auth"}, time.Now()) - // Setup some of the fields on auth - auth.LeaseIssue = time.Now().UTC() - // Create a lease entry le := leaseEntry{ LeaseID: path.Join(source, m.tokenStore.SaltID(auth.ClientToken)), ClientToken: auth.ClientToken, Auth: auth, Path: source, - IssueTime: auth.LeaseIssue, + IssueTime: time.Now().UTC(), ExpireTime: auth.ExpirationTime(), } @@ -642,7 +635,7 @@ func (l *leaseEntry) encode() ([]byte, error) { func (le *leaseEntry) renewable() error { // If there is no entry, cannot review if le == nil || le.ExpireTime.IsZero() { - return fmt.Errorf("lease not found") + return fmt.Errorf("lease not found or lease is not renewable") } // Determine if the lease is expired diff --git a/vault/expiration_test.go b/vault/expiration_test.go index f7b581be25..0fc4484bd8 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -152,7 +152,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { // Should not be able to renew, no expiration _, err = exp.RenewToken("auth/github/login", root.ID, 0) - if err.Error() != "lease not found" { + if err.Error() != "lease not found or lease is not renewable" { t.Fatalf("err: %v", err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 9769a882bd..6673226d0a 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -110,7 +110,7 @@ func NewSystemBackend(core *Core) logical.Backend { Description: strings.TrimSpace(sysHelp["lease_id"][0]), }, "increment": &framework.FieldSchema{ - Type: framework.TypeInt, + Type: framework.TypeDurationSecond, Description: strings.TrimSpace(sysHelp["increment"][0]), }, }, diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index b3ba308b62..33749ae14f 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -181,7 +181,7 @@ func TestSystemBackend_renew(t *testing.T) { // Attempt renew req2 := logical.TestRequest(t, logical.WriteOperation, "renew/"+resp.Secret.LeaseID) - req2.Data["increment"] = 100 + req2.Data["increment"] = "100s" resp2, err := b.HandleRequest(req2) if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) @@ -202,7 +202,7 @@ func TestSystemBackend_renew_invalidID(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp.Data["error"] != "lease not found" { + if resp.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } @@ -250,7 +250,7 @@ func TestSystemBackend_revoke(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp3.Data["error"] != "lease not found" { + if resp3.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } @@ -312,7 +312,7 @@ func TestSystemBackend_revokePrefix(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp3.Data["error"] != "lease not found" { + if resp3.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } diff --git a/vault/token_store.go b/vault/token_store.go index b3ddc889c0..03f278c83c 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -85,7 +85,10 @@ func NewTokenStore(c *Core) (*TokenStore, error) { // Setup the framework endpoints t.Backend = &framework.Backend{ - AuthRenew: framework.LeaseExtend(0, 0), + // Allow a token lease to be extended indefinitely, but each time for only + // as much as the original lease allowed for. If the lease has a 1 hour expiration, + // it can only be extended up to another hour each time this means. + AuthRenew: framework.LeaseExtend(0, 0, true), PathsSpecial: &logical.Paths{ Root: []string{ @@ -208,7 +211,7 @@ func NewTokenStore(c *Core) (*TokenStore, error) { Description: "Token to renew", }, "increment": &framework.FieldSchema{ - Type: framework.TypeInt, + Type: framework.TypeDurationSecond, Description: "The desired increment in seconds to the token expiration", }, }, diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 76b2e97b8e..2abea5cbf9 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -820,8 +820,9 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { // Get the original expire time to compare originalExpire := auth.ExpirationTime() + beforeRenew := time.Now().UTC() req := logical.TestRequest(t, logical.WriteOperation, "renew/"+root.ID) - req.Data["increment"] = "3600" + req.Data["increment"] = "3600s" resp, err := ts.HandleRequest(req) if err != nil { t.Fatalf("err: %v %v", err, resp) @@ -829,9 +830,11 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { // Get the new expire time newExpire := resp.Auth.ExpirationTime() - expireDiff := newExpire.Sub(originalExpire) - if expireDiff < 30*time.Minute || expireDiff > 3*time.Hour { - t.Fatalf("bad: %#v", expireDiff) + if newExpire.Before(originalExpire) { + t.Fatalf("should expire later: %s %s", newExpire, originalExpire) + } + if newExpire.Before(beforeRenew.Add(time.Hour)) { + t.Fatalf("should have at least an hour: %s %s", newExpire, beforeRenew) } } diff --git a/website/source/docs/config/index.html.md b/website/source/docs/config/index.html.md index 278b4d9549..8cae4df7c6 100644 --- a/website/source/docs/config/index.html.md +++ b/website/source/docs/config/index.html.md @@ -76,6 +76,8 @@ durability, etc. * `s3` - Store data within an S3 bucket [S3](http://aws.amazon.com/s3/). This backend does not support HA. + * `mysql` - Store data within MySQL. This backend does not support HA. + * `inmem` - Store data in-memory. This is only really useful for development and experimentation. Data is lost whenever Vault is restarted. @@ -143,6 +145,21 @@ For S3, the following options are supported: * `region` (optional) - The AWS region. It can be sourced from the AWS_DEFAULT_REGION environment variable and will default to "us-east-1" if not specified. +#### Backend Reference: MySQL + +The MySQL backend has the following options: + + * `username` (required) - The MySQL username to connect with. + + * `password` (required) - The MySQL password to connect with. + + * `address` (optional) - The address of the MySQL host. Defaults to + "127.0.0.1:3306. + + * `database` (optional) - The name of the database to use. Defaults to "vault". + + * `table` (optional) - The name of the table to use. Defaults to "vault". + #### Backend Reference: Inmem The in-memory backend has no configuration options. diff --git a/website/source/docs/secrets/transit/index.html.md b/website/source/docs/secrets/transit/index.html.md index 3506eca03a..743ba0662e 100644 --- a/website/source/docs/secrets/transit/index.html.md +++ b/website/source/docs/secrets/transit/index.html.md @@ -54,6 +54,15 @@ $ vault read transit/keys/foo Key Value name foo cipher_mode aes-gcm +```` + +We can read from the `raw/` endpoint to see the encryption key itself: + +``` +$ vault read transit/raw/foo +Key Value +name foo +cipher_mode aes-gcm key PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8= ```` @@ -114,17 +123,7 @@ only encrypt or decrypt using the named keys they need access to.
Returns
- - ```javascript - { - "data": { - "name": "foo", - "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" - } - } - ``` - + A `204` response code.
@@ -156,7 +155,6 @@ only encrypt or decrypt using the named keys they need access to. "data": { "name": "foo", "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" } } ``` @@ -196,7 +194,9 @@ only encrypt or decrypt using the named keys they need access to.
Description
- Encrypts the provided plaintext using the named key. + Encrypts the provided plaintext using the named key. If the named key + does not already exist, it will be automatically generated for the given + name with the default parameters.
Method
@@ -269,3 +269,42 @@ only encrypt or decrypt using the named keys they need access to.
+ +### /transit/raw/ +#### GET + +
+
Description
+
+ Returns raw information about a named encryption key, + Including the underlying encryption key. This is a root protected endpoint. +
+ +
Method
+
GET
+ +
URL
+
`/transit/raw/`
+ +
Parameters
+
+ None +
+ +
Returns
+
+ + ```javascript + { + "data": { + "name": "foo", + "cipher_mode": "aes-gcm", + "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" + } + } + ``` + +
+
+ +