diff --git a/builtin/logical/ssh/path_config_add_host_key.go b/builtin/logical/ssh/path_config_add_host_key.go index dddb59094c..bb43aa588c 100644 --- a/builtin/logical/ssh/path_config_add_host_key.go +++ b/builtin/logical/ssh/path_config_add_host_key.go @@ -1,8 +1,6 @@ package ssh import ( - "bytes" - "fmt" "log" "github.com/hashicorp/vault/logical" @@ -39,30 +37,30 @@ func (b *backend) pathAddHostKeyWrite(req *logical.Request, d *framework.FieldDa log.Printf("Vishal: ssh.pathAddHostKeyWrite\n") username := d.Get("username").(string) ip := d.Get("ip").(string) + //TODO: parse ip into ipv4 address and validate it key := d.Get("key").(string) log.Printf("Vishal: ssh.pathAddHostKeyWrite username:%#v ip:%#v key:%#v\n", username, ip, key) - localCmdString := ` - rm -f vault_ssh_otk.pem vault_ssh_otk.pem.pub; - ssh-keygen -f vault_ssh_otk.pem -t rsa -N ''; - chmod 400 vault_ssh_otk.pem; - scp -i vault_ssh_shared.pem vault_ssh_otk.pem.pub vishal@localhost:/home/vishal - echo done! - ` - err := exec_command(localCmdString) + + entry, err := logical.StorageEntryJSON("hosts/"+ip+"/"+username, &sshAddHostKey{ + Username: username, + IP: ip, + Key: key, + }) if err != nil { - fmt.Errorf("Running command failed " + err.Error()) + return nil, err } - session := createSSHPublicKeysSession("vishal", "127.0.0.1") - var buf bytes.Buffer - session.Stdout = &buf - if err := installSshOtkInTarget(session); err != nil { - fmt.Errorf("Failed to install one-time-key at target machine: " + err.Error()) + if err := req.Storage.Put(entry); err != nil { + return nil, err } - session.Close() - fmt.Println(buf.String()) return nil, nil } +type sshAddHostKey struct { + Username string + IP string + Key string +} + const pathConfigAddHostKeySyn = ` pathConfigAddHostKeySyn ` diff --git a/builtin/logical/ssh/path_config_remove_host_key.go b/builtin/logical/ssh/path_config_remove_host_key.go index 4acc330713..e48f0a5720 100644 --- a/builtin/logical/ssh/path_config_remove_host_key.go +++ b/builtin/logical/ssh/path_config_remove_host_key.go @@ -20,10 +20,6 @@ func pathConfigRemoveHostKey(b *backend) *framework.Path { Type: framework.TypeString, Description: "IP address of host.", }, - "key": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "SSH private key for host.", - }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.WriteOperation: b.pathRemoveHostKeyWrite, @@ -33,8 +29,16 @@ func pathConfigRemoveHostKey(b *backend) *framework.Path { } } -func (b *backend) pathRemoveHostKeyWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { +func (b *backend) pathRemoveHostKeyWrite(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { log.Printf("Vishal: ssh.pathRemoveHostKeyWrite\n") + username := d.Get("username").(string) + ip := d.Get("ip").(string) + //TODO: parse ip into ipv4 address and validate it + log.Printf("Vishal: ssh.pathRemoveHostKeyWrite username:%#v ip:%#v\n", username, ip) + err := req.Storage.Delete("hosts/" + ip + "/" + username) + if err != nil { + return nil, err + } return nil, nil } diff --git a/builtin/logical/ssh/secret_one_time_key.go b/builtin/logical/ssh/secret_one_time_key.go index 04a9941a62..96e73f3780 100644 --- a/builtin/logical/ssh/secret_one_time_key.go +++ b/builtin/logical/ssh/secret_one_time_key.go @@ -30,7 +30,7 @@ func secretOneTimeKey(b *backend) *framework.Secret { }, DefaultDuration: 1 * time.Hour, DefaultGracePeriod: 10 * time.Minute, - Renew: framework.LeaseExtend(1*time.Hour, 0), + Renew: framework.LeaseExtend(1*time.Hour, 0, false), Revoke: b.secretPrivateKeyRevoke, } } diff --git a/command/ssh.go b/command/ssh.go index 8621cae66f..a7623766e2 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -48,6 +48,7 @@ func (c *SshCommand) Run(args []string) int { sshEnv := os.Environ() sshCmdArgs := []string{"ssh", "-i", "vault_ssh_otk_" + args[0] + ".pem", "vishal@localhost"} + defer os.Remove("vault_ssh_otk_" + args[0] + ".pem") if err := syscall.Exec(sshBinary, sshCmdArgs, sshEnv); err != nil { log.Printf("Execution failed: sshCommand: " + err.Error())