diff --git a/api/ssh.go b/api/ssh.go index ee6c959551..3fa8a28c07 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -5,14 +5,17 @@ import ( "fmt" ) +// SSH is used to return a client to invoke operations on SSH backend. type SSH struct { c *Client } +// SSH is used to return the client for logical-backend API calls. func (c *Client) SSH() *SSH { return &SSH{c: c} } +// Invokes the SSH backend API to create a dynamic key func (c *SSH) KeyCreate(role string, data map[string]interface{}) (*Secret, error) { r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/ssh/creds/%s", role)) if err := r.SetJSONBody(data); err != nil { @@ -28,6 +31,7 @@ func (c *SSH) KeyCreate(role string, data map[string]interface{}) (*Secret, erro return ParseSecret(resp.Body) } +// Invokes the SSH backend API to list the roles associated with given IP address. func (c *SSH) Lookup(data map[string]interface{}) (*SSHRoles, error) { r := c.c.NewRequest("PUT", "/v1/ssh/lookup") if err := r.SetJSONBody(data); err != nil { @@ -48,6 +52,10 @@ func (c *SSH) Lookup(data map[string]interface{}) (*SSHRoles, error) { return &roles, nil } +// Structures for the requests/resposne are all down here. They aren't +// individually documentd because the map almost directly to the raw HTTP API +// documentation. Please refer to that documentation for more details. + type SSHRoles struct { Data map[string]interface{} `json:"data"` } diff --git a/builtin/logical/ssh/path_lookup.go b/builtin/logical/ssh/path_lookup.go index be49fb8c53..cca0430967 100644 --- a/builtin/logical/ssh/path_lookup.go +++ b/builtin/logical/ssh/path_lookup.go @@ -39,18 +39,11 @@ func (b *backend) pathLookupWrite(req *logical.Request, d *framework.FieldData) if err != nil { return nil, err } - if len(keys) == 0 { - return &logical.Response{ - Data: map[string]interface{}{ - "roles": nil, - }, - }, nil - } var matchingRoles []string - for _, item := range keys { - if contains, _ := containsIP(req.Storage, item, ip.String()); contains { - matchingRoles = append(matchingRoles, item) + for _, role := range keys { + if contains, _ := roleContainsIP(req.Storage, role, ip.String()); contains { + matchingRoles = append(matchingRoles, role) } } return &logical.Response{ diff --git a/builtin/logical/ssh/path_role_create.go b/builtin/logical/ssh/path_role_create.go index 5501cf33d4..d06ac1e681 100644 --- a/builtin/logical/ssh/path_role_create.go +++ b/builtin/logical/ssh/path_role_create.go @@ -2,9 +2,7 @@ package ssh import ( "fmt" - "io/ioutil" "net" - "strings" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -47,7 +45,7 @@ func (b *backend) pathRoleCreateWrite( return logical.ErrorResponse("Missing ip"), nil } - //find the role to be used for installing dynamic key + // Find the role to be used for installing dynamic key roleEntry, err := req.Storage.Get(fmt.Sprintf("policy/%s", roleName)) if err != nil { return nil, fmt.Errorf("error retrieving role: %s", err) @@ -60,33 +58,26 @@ func (b *backend) pathRoleCreateWrite( return nil, err } + // Set the default username if username == "" { username = role.DefaultUser } - //validate the IP address + // Validate the IP address ipAddr := net.ParseIP(ipRaw) if ipAddr == nil { return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ipRaw)), nil } ip := ipAddr.String() - - ipMatched := false - for _, item := range strings.Split(role.CIDR, ",") { - _, cidrIPNet, err := net.ParseCIDR(item) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Invalid cidr entry '%s'", item)), nil - } - ipMatched = cidrIPNet.Contains(ipAddr) - if ipMatched { - break - } + ipMatched, err := cidrContainsIP(ip, role.CIDR) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %s", err)), nil } if !ipMatched { return logical.ErrorResponse(fmt.Sprintf("IP[%s] does not belong to role[%s]", ip, roleName)), nil } - //fetch the host key to be used for installation + // Fetch the host key to be used for dynamic key installation keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", role.KeyName)) if err != nil { return nil, fmt.Errorf("key '%s' not found error:%s", role.KeyName, err) @@ -96,60 +87,20 @@ func (b *backend) pathRoleCreateWrite( return nil, fmt.Errorf("error reading the host key: %s", err) } - //store the host key to file. Use it as parameter for scp command - hostKeyFileName := fmt.Sprintf("./vault_ssh_%s_%s_shared.pem", username, ip) - err = ioutil.WriteFile(hostKeyFileName, []byte(hostKey.Key), 0600) - - dynamicPrivateKeyFileName := fmt.Sprintf("vault_ssh_%s_%s_otk.pem", username, ip) - dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s_otk.pem.pub", username, ip) - - //delete the temporary files if they are already present - err = removeFile(dynamicPrivateKeyFileName) - if err != nil { - return nil, fmt.Errorf("error removing dynamic private key file: '%s'", err) - } - err = removeFile(dynamicPublicKeyFileName) - if err != nil { - return nil, fmt.Errorf("error removing dynamic private key file: '%s'", err) - } - - //generate RSA key pair + // Generate RSA key pair dynamicPublicKey, dynamicPrivateKey, _ := generateRSAKeys() - //save the public key pair to a file - ioutil.WriteFile(dynamicPublicKeyFileName, []byte(dynamicPublicKey), 0644) - - //send the public key to target machine - err = uploadFileScp(dynamicPublicKeyFileName, username, ip, hostKey.Key) + // Transfer the public key to target machine + err = uploadPublicKeyScp(dynamicPublicKey, username, ip, hostKey.Key) if err != nil { return nil, err } - //connect to target machine - session, err := createSSHPublicKeysSession(username, ip, hostKey.Key) + // Add the public key to authorized_keys file in target machine + err = installPublicKeyInTarget(username, ip, hostKey.Key) if err != nil { - return nil, fmt.Errorf("unable to create SSH Session using public keys: %s", err) + return nil, fmt.Errorf("error adding public key to authorized_keys file in target") } - if session == nil { - return nil, fmt.Errorf("invalid session object") - } - - authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username) - tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) - - //commands to be run on target machine - grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) - catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) - catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName) - removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) - - targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd) - - //run the commands on target machine - if err := session.Run(targetCmd); err != nil { - return nil, err - } - session.Close() result := b.Secret(SecretOneTimeKeyType).Response(map[string]interface{}{ "key": dynamicPrivateKey, @@ -159,6 +110,8 @@ func (b *backend) pathRoleCreateWrite( "host_key_name": role.KeyName, "dynamic_public_key": dynamicPublicKey, }) + + // Change the lease information to reflect user's choice lease, _ := b.Lease(req.Storage) if lease != nil { result.Secret.Lease = lease.Lease diff --git a/builtin/logical/ssh/path_roles.go b/builtin/logical/ssh/path_roles.go index f328985b86..bf10742e0b 100644 --- a/builtin/logical/ssh/path_roles.go +++ b/builtin/logical/ssh/path_roles.go @@ -53,7 +53,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (* defaultUser := d.Get("default_user").(string) cidr := d.Get("cidr").(string) - //input validations + // Input validations if roleName == "" { return logical.ErrorResponse("Missing role name"), nil } @@ -102,16 +102,23 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (* func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { roleName := d.Get("name").(string) - entry, err := req.Storage.Get(fmt.Sprintf("policy/%s", roleName)) + roleEntry, err := req.Storage.Get(fmt.Sprintf("policy/%s", roleName)) if err != nil { return nil, err } - if entry == nil { + if roleEntry == nil { return nil, nil } + var role sshRole + if err := roleEntry.DecodeJSON(&role); err != nil { + return nil, err + } return &logical.Response{ Data: map[string]interface{}{ - "policy": string(entry.Value), + "key": role.KeyName, + "admin_user": role.AdminUser, + "default_user": role.DefaultUser, + "cidr": role.CIDR, }, }, nil } diff --git a/builtin/logical/ssh/secret_ssh_key.go b/builtin/logical/ssh/secret_ssh_key.go index 4520a531aa..4f7377ce38 100644 --- a/builtin/logical/ssh/secret_ssh_key.go +++ b/builtin/logical/ssh/secret_ssh_key.go @@ -2,7 +2,6 @@ package ssh import ( "fmt" - "io/ioutil" "time" "github.com/hashicorp/vault/logical" @@ -77,7 +76,7 @@ func (b *backend) secretSSHKeyRevoke(req *logical.Request, d *framework.FieldDat return nil, fmt.Errorf("secret is missing internal data") } - //fetch the host key using the key name + // Fetch the host key using the key name hostKeyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", hostKeyName)) if err != nil { return nil, fmt.Errorf("key '%s' not found error:%s", hostKeyName, err) @@ -87,43 +86,16 @@ func (b *backend) secretSSHKeyRevoke(req *logical.Request, d *framework.FieldDat return nil, fmt.Errorf("error reading the host key: %s", err) } - //write host key to file and use it as argument to scp command - hostKeyFileName := fmt.Sprintf("./vault_ssh_%s_%s_shared.pem", username, ip) - err = ioutil.WriteFile(hostKeyFileName, []byte(hostKey.Key), 0400) - - //write dynamicPublicKey to file and use it as argument to scp command - dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s_otk.pem.pub", username, ip) - err = ioutil.WriteFile(dynamicPublicKeyFileName, []byte(dynamicPublicKey), 0400) - - //transfer the dynamic public key to target machine and use it to remove the entry from authorized_keys file - err = uploadFileScp(dynamicPublicKeyFileName, username, ip, hostKey.Key) + // Transfer the dynamic public key to target machine and use it to remove the entry from authorized_keys file + err = uploadPublicKeyScp(dynamicPublicKey, username, ip, hostKey.Key) if err != nil { return nil, fmt.Errorf("public key transfer failed: %s", err) } - //connect to target machine - session, err := createSSHPublicKeysSession(username, ip, hostKey.Key) + // Remove the public key from authorized_keys file in target machine + err = uninstallPublicKeyInTarget(username, ip, hostKey.Key) if err != nil { - return nil, fmt.Errorf("unable to create SSH Session using public keys: %s", err) + return nil, fmt.Errorf("error removing public key from authorized_keys file in target") } - if session == nil { - return nil, fmt.Errorf("invalid session object") - } - - authKeysFileName := "/home/" + username + "/.ssh/authorized_keys" - tempKeysFileName := "/home/" + username + "/temp_authorized_keys" - - //commands to be run on target machine - grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) - catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) - removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) - - remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd) - - //run the commands in target machine - if err := session.Run(remoteCmd); err != nil { - return nil, err - } - return nil, nil } diff --git a/builtin/logical/ssh/util.go b/builtin/logical/ssh/util.go index 8331ecc62c..96b94bd5d8 100644 --- a/builtin/logical/ssh/util.go +++ b/builtin/logical/ssh/util.go @@ -9,9 +9,6 @@ import ( "fmt" "io" "net" - "os" - "os/exec" - "path/filepath" "strings" "github.com/hashicorp/vault/logical" @@ -19,34 +16,11 @@ import ( "golang.org/x/crypto/ssh" ) -/* -Executes the command represented by the input. -Multiple commands can be run by concatinating strings with ';'. -Currently, it is supported only for linux platforms and user bash shell. -*/ -func exec_command(cmdString string) error { - cmd := exec.Command("/bin/bash", "-c", cmdString) - if _, err := cmd.Output(); err != nil { - return err - } - return nil -} - -/* -Transfers the file to the target machine by establishing an SSH session with the target. -Uses the public key authentication method and hence the parameter 'key' takes in the private key. -The fileName parameter takes an absolute path. -*/ -func uploadFileScp(fileName, username, ip, key string) error { - nameBase := filepath.Base(fileName) - file, err := os.Open(fileName) - if err != nil { - return err - } - stat, err := file.Stat() - if os.IsNotExist(err) { - return fmt.Errorf("file does not exist") - } +// Transfers the file to the target machine by establishing an SSH session with the target. +// Uses the public key authentication method and hence the parameter 'key' takes in the private key. +// The fileName parameter takes an absolute path. +func uploadPublicKeyScp(publicKey, username, ip, key string) error { + dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) session, err := createSSHPublicKeysSession(username, ip, key) if err != nil { return err @@ -57,21 +31,19 @@ func uploadFileScp(fileName, username, ip, key string) error { defer session.Close() go func() { w, _ := session.StdinPipe() - fmt.Fprintln(w, "C0644", stat.Size(), nameBase) - io.Copy(w, file) + fmt.Fprintln(w, "C0644", len(publicKey), dynamicPublicKeyFileName) + io.Copy(w, strings.NewReader(publicKey)) fmt.Fprint(w, "\x00") w.Close() }() - if err := session.Run(fmt.Sprintf("scp -vt %s", nameBase)); err != nil { + if err := session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)); err != nil { return err } return nil } -/* -Creates a SSH session object which can be used to run commands in the target machine. -The session will use public key authentication method with port 22. -*/ +// Creates a SSH session object which can be used to run commands in the target machine. +// The session will use public key authentication method with port 22. func createSSHPublicKeysSession(username, ipAddr, hostKey string) (*ssh.Session, error) { if username == "" { return nil, fmt.Errorf("missing username") @@ -109,33 +81,8 @@ func createSSHPublicKeysSession(username, ipAddr, hostKey string) (*ssh.Session, return session, nil } -/* -Deletes the file in the current directory. -The parameter is just the name of the file and not a path. -*/ -func removeFile(fileName string) error { - if fileName == "" { - return fmt.Errorf("missing file name") - } - wd, err := os.Getwd() - if err != nil { - return err - } - absFileName := wd + "/" + fileName - - if _, err := os.Stat(absFileName); err == nil { - err := os.Remove(absFileName) - if err != nil { - return err - } - } - return nil -} - -/* -Creates a new RSA key pair with key length of 2048. -The private key will be of pem format and the public key will be of OpenSSH format. -*/ +// Creates a new RSA key pair with key length of 2048. +// The private key will be of pem format and the public key will be of OpenSSH format. func generateRSAKeys() (publicKeyRsa string, privateKeyRsa string, err error) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -155,7 +102,67 @@ func generateRSAKeys() (publicKeyRsa string, privateKeyRsa string, err error) { return } -func containsIP(s logical.Storage, roleName string, ip string) (bool, error) { +// Concatenates the public present in that target machine's home folder to ~/.ssh/authorized_keys file +func installPublicKeyInTarget(username, ip, hostKey string) error { + session, err := createSSHPublicKeysSession(username, ip, hostKey) + if err != nil { + return fmt.Errorf("unable to create SSH Session using public keys: %s", err) + } + if session == nil { + return fmt.Errorf("invalid session object") + } + defer session.Close() + + authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username) + tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) + + // Commands to be run on target machine + dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) + grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) + catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) + catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName) + removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) + + targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd) + + // Run the commands on target machine + if err := session.Run(targetCmd); err != nil { + return err + } + return nil +} + +// Removes the installed public key from the authorized_keys file in target machine +func uninstallPublicKeyInTarget(username, ip, hostKey string) error { + session, err := createSSHPublicKeysSession(username, ip, hostKey) + if err != nil { + return fmt.Errorf("unable to create SSH Session using public keys: %s", err) + } + if session == nil { + return fmt.Errorf("invalid session object") + } + defer session.Close() + + authKeysFileName := "/home/" + username + "/.ssh/authorized_keys" + tempKeysFileName := "/home/" + username + "/temp_authorized_keys" + + // Commands to be run on target machine + dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) + grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) + catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) + removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) + + remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd) + + // Run the commands in target machine + if err := session.Run(remoteCmd); err != nil { + return err + } + return nil +} + +// Takes an IP address and role name and checks if the IP is part of CIDR blocks belonging to the role. +func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error) { if roleName == "" { return false, fmt.Errorf("missing role name") } @@ -173,16 +180,24 @@ func containsIP(s logical.Storage, roleName string, ip string) (bool, error) { if err := roleEntry.DecodeJSON(&role); err != nil { return false, fmt.Errorf("error decoding role '%s'", roleName) } - ipMatched := false - for _, item := range strings.Split(role.CIDR, ",") { + + if matched, err := cidrContainsIP(ip, role.CIDR); err != nil { + return false, err + } else { + return matched, nil + } +} + +// Returns true if the IP supplied by the user is part of the comma separated CIDR blocks +func cidrContainsIP(ip, cidr string) (bool, error) { + for _, item := range strings.Split(cidr, ",") { _, cidrIPNet, err := net.ParseCIDR(item) if err != nil { return false, fmt.Errorf("invalid cidr entry '%s'", item) } - ipMatched = cidrIPNet.Contains(net.ParseIP(ip)) - if ipMatched { - break + if cidrIPNet.Contains(net.ParseIP(ip)) { + return true, nil } } - return ipMatched, nil + return false, nil } diff --git a/command/ssh.go b/command/ssh.go index fc0697caca..2d2a473a7b 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -10,6 +10,7 @@ import ( "syscall" ) +// SSHCommand is a Command that establishes a SSH connection with target by generating a dynamic key type SSHCommand struct { Meta } @@ -35,7 +36,7 @@ func (c *SSHCommand) Run(args []string) int { } input := strings.Split(args[0], "@") username := input[0] - ip, err := net.ResolveIPAddr("ip4", input[1]) + ip, err := net.ResolveIPAddr("ip", input[1]) if err != nil { c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %s", err)) return 2 @@ -109,7 +110,28 @@ func (c *SSHCommand) Synopsis() string { func (c *SSHCommand) Help() string { helpText := ` - SSHCommand Help String - ` +Usage: vault ssh [options] username@ip + + Establishes an SSH connection with the target machine. + + This command generates a dynamic key and uses it to establish an + SSH connection with the target machine. This operation requires + that SSH backend is mounted and at least one 'role' be registed + with vault at priori. + +General Options: + + ` + generalOptionsUsage() + ` + +SSH Options: + + -role Mention the role to be used to create dynamic key. + Each IP is associated with a role. To see the associated + roles with IP, use "lookup" endpoint. If you are certain that + there is only one role associated with the IP, you can + skip mentioning the role. It will be chosen by default. + If there are no roless associated with the IP, register + the CIDR block of that IP using the "roles/" endpoint. +` return strings.TrimSpace(helpText) }