Refactor the generate_signing_key processing (#2430)

This commit is contained in:
Vishal Nayak 2017-03-02 16:22:06 -05:00 committed by Jeff Mitchell
parent 1c821e448d
commit 93b74ebe71

View File

@ -1,15 +1,15 @@
package ssh package ssh
import ( import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt" "fmt"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"crypto/rsa"
"encoding/pem"
"crypto/rand"
"crypto/x509"
) )
func pathConfigCA(b *backend) *framework.Path { func pathConfigCA(b *backend) *framework.Path {
@ -27,7 +27,7 @@ func pathConfigCA(b *backend) *framework.Path {
"generate_signing_key": &framework.FieldSchema{ "generate_signing_key": &framework.FieldSchema{
Type: framework.TypeBool, Type: framework.TypeBool,
Description: `Generate SSH key pair internally rather than use the private_key and public_key fields.`, Description: `Generate SSH key pair internally rather than use the private_key and public_key fields.`,
Default: true, Default: true,
}, },
}, },
@ -44,22 +44,50 @@ For security reasons, the private key cannot be retrieved later.`,
} }
func (b *backend) pathCAWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathCAWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var publicKey, privateKey string
var err error var err error
if data.Get("generate_signing_key").(bool) { publicKey := data.Get("public_key").(string)
publicKey, privateKey, err = generateSSHKeyPair() privateKey := data.Get("private_key").(string)
if err != nil {
return nil, err
}
} else {
publicKey, privateKey, err = parseSSHKeyPair(data)
if err != nil { signingKeyGenerated := false
return nil, err generateSigningKeyRaw, ok := data.GetOk("generate_signing_key")
if ok {
if generateSigningKeyRaw.(bool) {
if publicKey != "" || privateKey != "" {
return logical.ErrorResponse("public_key and private_key is not required when generate_signing_key is set to true"), nil
}
publicKey, privateKey, err = generateSSHKeyPair()
if err != nil {
return nil, err
}
signingKeyGenerated = true
} }
} }
if !signingKeyGenerated {
if publicKey == "" {
return logical.ErrorResponse("missing public_key"), nil
}
if privateKey == "" {
return logical.ErrorResponse("missing private_key"), nil
}
_, err := ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Unable to parse private_key as an SSH private key: %v", err)), nil
}
_, err = parsePublicSSHKey(publicKey)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Unable to parse public_key as an SSH public key: %v", err)), nil
}
}
if publicKey == "" || privateKey == "" {
return nil, fmt.Errorf("failed to generate or parse the keys")
}
err = req.Storage.Put(&logical.StorageEntry{ err = req.Storage.Put(&logical.StorageEntry{
Key: "public_key", Key: "public_key",
Value: []byte(publicKey), Value: []byte(publicKey),
@ -81,7 +109,6 @@ func (b *backend) pathCAWrite(req *logical.Request, data *framework.FieldData) (
return nil, err return nil, err
} }
func generateSSHKeyPair() (string, string, error) { func generateSSHKeyPair() (string, string, error) {
privateSeed, err := rsa.GenerateKey(rand.Reader, 4096) privateSeed, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil { if err != nil {
@ -89,9 +116,9 @@ func generateSSHKeyPair() (string, string, error) {
} }
privateBlock := &pem.Block{ privateBlock := &pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Headers: nil, Headers: nil,
Bytes: x509.MarshalPKCS1PrivateKey(privateSeed), Bytes: x509.MarshalPKCS1PrivateKey(privateSeed),
} }
public, err := ssh.NewPublicKey(&privateSeed.PublicKey) public, err := ssh.NewPublicKey(&privateSeed.PublicKey)
@ -101,28 +128,3 @@ func generateSSHKeyPair() (string, string, error) {
return string(ssh.MarshalAuthorizedKey(public)), string(pem.EncodeToMemory(privateBlock)), nil return string(ssh.MarshalAuthorizedKey(public)), string(pem.EncodeToMemory(privateBlock)), nil
} }
func parseSSHKeyPair(data *framework.FieldData) (string, string, error) {
publicKey := data.Get("public_key").(string)
if publicKey == "" {
return "", "", errutil.UserError{Err: `missing public_key`}
}
privateKey := data.Get("private_key").(string)
if privateKey == "" {
return "", "", errutil.UserError{Err: `missing public_key`}
}
_, err := ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return "", "", errutil.UserError{Err: fmt.Sprintf(`Unable to parse "private_key" as an SSH private key: %s`, err)}
}
_, err = parsePublicSSHKey(publicKey)
if err != nil {
return "", "", errutil.UserError{Err: fmt.Sprintf(`Unable to parse "public_key" as an SSH public key: %s`, err)}
}
return publicKey, privateKey, nil
}