From c008a8d796f1c8a5c3246d264f7619348438e5e4 Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Wed, 12 Aug 2015 13:09:32 -0700 Subject: [PATCH] Vault SSH: Moved agent's client creation code to Vault's source --- api/ssh_agent.go | 106 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/api/ssh_agent.go b/api/ssh_agent.go index 8fd31caa89..72e82a6fc8 100644 --- a/api/ssh_agent.go +++ b/api/ssh_agent.go @@ -3,11 +3,13 @@ package api import ( "crypto/tls" "crypto/x509" + "encoding/pem" "fmt" "io/ioutil" "net" "net/http" "os" + "path/filepath" "time" "github.com/hashicorp/hcl" @@ -65,6 +67,39 @@ func (c *SSHAgentConfig) TLSClient(certPool *x509.CertPool) *http.Client { return &client } +// Returns a new client for the given configuration. This client will be used +// SSH agent to communicate with Vault server to verify the OTP entered by user. +// If the configuration supplies Vault SSL certificates, then the client will +// have tls configured in its transport. +func (c *SSHAgentConfig) NewClient() (*Client, error) { + // Creating a default client configuration for communicating with vault server. + clientConfig := DefaultConfig() + + // Pointing the client to the actual address of vault server. + clientConfig.Address = c.VaultAddr + + if c.CACert != "" || c.CAPath != "" || c.TLSSkipVerify { + var certPool *x509.CertPool + var err error + if c.CACert != "" { + certPool, err = loadCACert(c.CACert) + } else if c.CAPath != "" { + certPool, err = loadCAPath(c.CAPath) + } + if err != nil { + return nil, err + } + clientConfig.HttpClient = c.TLSClient(certPool) + } + + // Creating the client object for the given configuration + client, err := NewClient(clientConfig) + if err != nil { + return nil, err + } + return client, nil +} + // Loads agent's configuration from the file and populates the corresponding // in memory structure. func LoadSSHAgentConfig(path string) (*SSHAgentConfig, error) { @@ -135,3 +170,74 @@ func (c *SSHAgent) Verify(otp string) (*SSHVerifyResponse, error) { } return &verifyResp, nil } + +// Loads the certificate from given path and creates a certificate pool from it. +func loadCACert(path string) (*x509.CertPool, error) { + certs, err := loadCertFromPEM(path) + if err != nil { + return nil, err + } + + result := x509.NewCertPool() + for _, cert := range certs { + result.AddCert(cert) + } + + return result, nil +} + +// Loads the certificates present in the given directory and creates a +// certificate pool from it. +func loadCAPath(path string) (*x509.CertPool, error) { + result := x509.NewCertPool() + fn := func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + certs, err := loadCertFromPEM(path) + if err != nil { + return err + } + + for _, cert := range certs { + result.AddCert(cert) + } + return nil + } + + return result, filepath.Walk(path, fn) +} + +// Creates a certificate from the given path +func loadCertFromPEM(path string) ([]*x509.Certificate, error) { + pemCerts, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + certs := make([]*x509.Certificate, 0, 5) + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + + certs = append(certs, cert) + } + + return certs, nil +}