diff --git a/CHANGELOG.md b/CHANGELOG.md index 14b9a08b1d..0c834ec120 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ IMPROVEMENTS: * api: API client now uses a 60 second timeout instead of indefinite [GH-681] * api: Implement LookupSelf, RenewSelf, and RevokeSelf functions for auth tokens [GH-739] + * api: Standardize environment variable reading logic inside the API; the CLI + now uses this but can still override via command-line parameters [GH-618] * audit: HMAC-SHA256'd client tokens are now stored with each request entry. Previously they were only displayed at creation time; this allows much better traceability of client actions. [GH-713] diff --git a/api/client.go b/api/client.go index d5ef2e5192..16e6f62800 100644 --- a/api/client.go +++ b/api/client.go @@ -1,11 +1,17 @@ package api import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" "errors" "fmt" + "io/ioutil" "net/http" "net/url" "os" + "path/filepath" + "strconv" "strings" "sync" "time" @@ -13,6 +19,13 @@ import ( "github.com/hashicorp/go-cleanhttp" ) +const EnvVaultAddress = "VAULT_ADDR" +const EnvVaultCACert = "VAULT_CACERT" +const EnvVaultCAPath = "VAULT_CAPATH" +const EnvVaultClientCert = "VAULT_CLIENT_CERT" +const EnvVaultClientKey = "VAULT_CLIENT_KEY" +const EnvVaultInsecure = "VAULT_SKIP_VERIFY" + var ( errRedirect = errors.New("redirect") ) @@ -44,14 +57,99 @@ func DefaultConfig() *Config { HttpClient: cleanhttp.DefaultClient(), } config.HttpClient.Timeout = time.Second * 60 + transport := config.HttpClient.Transport.(*http.Transport) + transport.TLSHandshakeTimeout = 10 * time.Second + transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } - if addr := os.Getenv("VAULT_ADDR"); addr != "" { - config.Address = addr + if v := os.Getenv(EnvVaultAddress); v != "" { + config.Address = v } return config } +// ReadEnvironment reads configuration information from the +// environment. If there is an error, no configuration value +// is updated. +func (c *Config) ReadEnvironment() error { + var envAddress string + var envCACert string + var envCAPath string + var envClientCert string + var envClientKey string + var envInsecure bool + var foundInsecure bool + + var newCertPool *x509.CertPool + var clientCert tls.Certificate + var foundClientCert bool + + if v := os.Getenv(EnvVaultAddress); v != "" { + envAddress = v + } + if v := os.Getenv(EnvVaultCACert); v != "" { + envCACert = v + } + if v := os.Getenv(EnvVaultCAPath); v != "" { + envCAPath = v + } + if v := os.Getenv(EnvVaultClientCert); v != "" { + envClientCert = v + } + if v := os.Getenv(EnvVaultClientKey); v != "" { + envClientKey = v + } + if v := os.Getenv(EnvVaultInsecure); v != "" { + var err error + envInsecure, err = strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("Could not parse VAULT_SKIP_VERIFY") + } + foundInsecure = true + } + // If we need custom TLS configuration, then set it + if envCACert != "" || envCAPath != "" || envClientCert != "" || envClientKey != "" || envInsecure { + var err error + if envCACert != "" { + newCertPool, err = LoadCACert(envCACert) + } else if envCAPath != "" { + newCertPool, err = LoadCAPath(envCAPath) + } + if err != nil { + return fmt.Errorf("Error setting up CA path: %s", err) + } + + if envClientCert != "" && envClientKey != "" { + clientCert, err = tls.LoadX509KeyPair(envClientCert, envClientKey) + if err != nil { + return err + } + foundClientCert = true + } else if envClientCert != "" || envClientKey != "" { + return fmt.Errorf("Both client cert and client key must be provided") + } + } + + if envAddress != "" { + c.Address = envAddress + } + + clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig + if foundInsecure { + clientTLSConfig.InsecureSkipVerify = envInsecure + } + if newCertPool != nil { + clientTLSConfig.RootCAs = newCertPool + } + if foundClientCert { + clientTLSConfig.Certificates = []tls.Certificate{clientCert} + } + + return nil +} + // Client is the client to the Vault API. Create a client with // NewClient. type Client struct { @@ -66,6 +164,7 @@ type Client struct { // automatically added to the client. Otherwise, you must manually call // `SetToken()`. func NewClient(c *Config) (*Client, error) { + u, err := url.Parse(c.Address) if err != nil { return nil, err @@ -203,3 +302,74 @@ START: return result, nil } + +// Loads the certificate from given path and creates a certificate pool from it. +func LoadCACert(path string) (*x509.CertPool, error) { + certs, err := loadCertFromPEM(path) + if err != nil { + return nil, err + } + + result := x509.NewCertPool() + for _, cert := range certs { + result.AddCert(cert) + } + + return result, nil +} + +// Loads the certificates present in the given directory and creates a +// certificate pool from it. +func LoadCAPath(path string) (*x509.CertPool, error) { + result := x509.NewCertPool() + fn := func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + certs, err := loadCertFromPEM(path) + if err != nil { + return err + } + + for _, cert := range certs { + result.AddCert(cert) + } + return nil + } + + return result, filepath.Walk(path, fn) +} + +// Creates a certificate from the given path +func loadCertFromPEM(path string) ([]*x509.Certificate, error) { + pemCerts, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + certs := make([]*x509.Certificate, 0, 5) + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + + certs = append(certs, cert) + } + + return certs, nil +} diff --git a/api/client_test.go b/api/client_test.go index d1f709e4c0..5a5b48940f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -99,3 +99,38 @@ func TestClientRedirect(t *testing.T) { t.Fatalf("Bad: %s", buf.String()) } } + +func TestClientEnvSettings(t *testing.T) { + cwd, _ := os.Getwd() + oldCACert := os.Getenv(EnvVaultCACert) + oldCAPath := os.Getenv(EnvVaultCAPath) + oldClientCert := os.Getenv(EnvVaultClientCert) + oldClientKey := os.Getenv(EnvVaultClientKey) + oldSkipVerify := os.Getenv(EnvVaultInsecure) + os.Setenv("VAULT_CACERT", cwd+"/../test/key/ourdomain.cer") + os.Setenv("VAULT_CAPATH", cwd+"/../test/key") + os.Setenv("VAULT_CLIENT_CERT", cwd+"/../test/key/ourdomain.cer") + os.Setenv("VAULT_CLIENT_KEY", cwd+"/../test/key/ourdomain.key") + os.Setenv("VAULT_SKIP_VERIFY", "true") + defer os.Setenv("VAULT_CACERT", oldCACert) + defer os.Setenv("VAULT_CAPATH", oldCAPath) + defer os.Setenv("VAULT_CLIENT_CERT", oldClientCert) + defer os.Setenv("VAULT_CLIENT_KEY", oldClientKey) + defer os.Setenv("VAULT_SKIP_VERIFY", oldSkipVerify) + + config := DefaultConfig() + if err := config.ReadEnvironment(); err != nil { + t.Fatalf("error reading environment: %v", err) + } + + tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig + if len(tlsConfig.RootCAs.Subjects()) == 0 { + t.Fatalf("bad: expected a cert pool with at least one subject") + } + if len(tlsConfig.Certificates) != 1 { + t.Fatalf("bad: expected client tls config to have a client certificate") + } + if tlsConfig.InsecureSkipVerify != true { + t.Fatalf("bad: %s", tlsConfig.InsecureSkipVerify) + } +} diff --git a/api/ssh_agent.go b/api/ssh_agent.go index dcf654bd8b..9d24043d53 100644 --- a/api/ssh_agent.go +++ b/api/ssh_agent.go @@ -3,13 +3,11 @@ package api import ( "crypto/tls" "crypto/x509" - "encoding/pem" "fmt" "io/ioutil" "net" "net/http" "os" - "path/filepath" "time" "github.com/hashicorp/hcl" @@ -95,9 +93,9 @@ func (c *SSHAgentConfig) NewClient() (*Client, error) { var certPool *x509.CertPool var err error if c.CACert != "" { - certPool, err = loadCACert(c.CACert) + certPool, err = LoadCACert(c.CACert) } else if c.CAPath != "" { - certPool, err = loadCAPath(c.CAPath) + certPool, err = LoadCAPath(c.CAPath) } if err != nil { return nil, err @@ -199,74 +197,3 @@ func (c *SSHAgent) Verify(otp string) (*SSHVerifyResponse, error) { } return &verifyResp, nil } - -// Loads the certificate from given path and creates a certificate pool from it. -func loadCACert(path string) (*x509.CertPool, error) { - certs, err := loadCertFromPEM(path) - if err != nil { - return nil, err - } - - result := x509.NewCertPool() - for _, cert := range certs { - result.AddCert(cert) - } - - return result, nil -} - -// Loads the certificates present in the given directory and creates a -// certificate pool from it. -func loadCAPath(path string) (*x509.CertPool, error) { - result := x509.NewCertPool() - fn := func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if info.IsDir() { - return nil - } - - certs, err := loadCertFromPEM(path) - if err != nil { - return err - } - - for _, cert := range certs { - result.AddCert(cert) - } - return nil - } - - return result, filepath.Walk(path, fn) -} - -// Creates a certificate from the given path -func loadCertFromPEM(path string) ([]*x509.Certificate, error) { - pemCerts, err := ioutil.ReadFile(path) - if err != nil { - return nil, err - } - - certs := make([]*x509.Certificate, 0, 5) - for len(pemCerts) > 0 { - var block *pem.Block - block, pemCerts = pem.Decode(pemCerts) - if block == nil { - break - } - if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { - continue - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, err - } - - certs = append(certs, cert) - } - - return certs, nil -} diff --git a/command/meta.go b/command/meta.go index 2537f5fa14..7087a2cc03 100644 --- a/command/meta.go +++ b/command/meta.go @@ -9,27 +9,17 @@ import ( "fmt" "io" "io/ioutil" - "net" "net/http" "os" "path/filepath" - "strconv" "strings" - "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/command/token" "github.com/mitchellh/cli" ) -// EnvVaultAddress can be used to set the address of Vault -const EnvVaultAddress = "VAULT_ADDR" -const EnvVaultCACert = "VAULT_CACERT" -const EnvVaultCAPath = "VAULT_CAPATH" -const EnvVaultClientCert = "VAULT_CLIENT_CERT" -const EnvVaultClientKey = "VAULT_CLIENT_KEY" -const EnvVaultInsecure = "VAULT_SKIP_VERIFY" - // FlagSetFlags is an enum to define what flags are present in the // default FlagSet returned by Meta.FlagSet. type FlagSetFlags uint @@ -67,51 +57,40 @@ type Meta struct { // flag settings for this command. func (m *Meta) Client() (*api.Client, error) { config := api.DefaultConfig() - if v := os.Getenv(EnvVaultAddress); v != "" { - config.Address = v + + err := config.ReadEnvironment() + if err != nil { + return nil, errwrap.Wrapf("error reading environment: {{err}}", err) } + if m.flagAddress != "" { config.Address = m.flagAddress } if m.ForceAddress != "" { config.Address = m.ForceAddress } - if v := os.Getenv(EnvVaultCACert); v != "" { - m.flagCACert = v - } - if v := os.Getenv(EnvVaultCAPath); v != "" { - m.flagCAPath = v - } - if v := os.Getenv(EnvVaultClientCert); v != "" { - m.flagClientCert = v - } - if v := os.Getenv(EnvVaultClientKey); v != "" { - m.flagClientKey = v - } - if v := os.Getenv(EnvVaultInsecure); v != "" { - var err error - m.flagInsecure, err = strconv.ParseBool(v) - if err != nil { - return nil, fmt.Errorf("Invalid value passed in for -insecure flag: %s", err) - } - } // If we need custom TLS configuration, then set it if m.flagCACert != "" || m.flagCAPath != "" || m.flagClientCert != "" || m.flagClientKey != "" || m.flagInsecure { + // We may have set items from the environment so start with the + // existing TLS config + tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig + var certPool *x509.CertPool var err error if m.flagCACert != "" { - certPool, err = m.loadCACert(m.flagCACert) + certPool, err = api.LoadCACert(m.flagCACert) } else if m.flagCAPath != "" { - certPool, err = m.loadCAPath(m.flagCAPath) + certPool, err = api.LoadCAPath(m.flagCAPath) } if err != nil { - return nil, fmt.Errorf("Error setting up CA path: %s", err) + return nil, errwrap.Wrapf("Error setting up CA path: {{err}}", err) } - tlsConfig := &tls.Config{ - InsecureSkipVerify: m.flagInsecure, - MinVersion: tls.VersionTLS12, - RootCAs: certPool, + if certPool != nil { + tlsConfig.RootCAs = certPool + } + if m.flagInsecure { + tlsConfig.InsecureSkipVerify = true } if m.flagClientCert != "" && m.flagClientKey != "" { @@ -123,20 +102,6 @@ func (m *Meta) Client() (*api.Client, error) { } else if m.flagClientCert != "" || m.flagClientKey != "" { return nil, fmt.Errorf("Both client cert and client key must be provided") } - - client := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSClientConfig: tlsConfig, - TLSHandshakeTimeout: 10 * time.Second, - }, - } - - config.HttpClient = client } // Build the client diff --git a/command/meta_test.go b/command/meta_test.go index ee463f32ac..c9988b2043 100644 --- a/command/meta_test.go +++ b/command/meta_test.go @@ -2,7 +2,6 @@ package command import ( "flag" - "os" "reflect" "sort" "testing" @@ -40,37 +39,3 @@ func TestFlagSet(t *testing.T) { } } } - -func TestEnvSettings(t *testing.T) { - os.Setenv("VAULT_CACERT", "/path/to/fake/cert.crt") - os.Setenv("VAULT_CAPATH", "/path/to/fake/certs") - os.Setenv("VAULT_CLIENT_CERT", "/path/to/fake/client.crt") - os.Setenv("VAULT_CLIENT_KEY", "/path/to/fake/client.key") - os.Setenv("VAULT_SKIP_VERIFY", "true") - defer os.Setenv("VAULT_CACERT", "") - defer os.Setenv("VAULT_CAPATH", "") - defer os.Setenv("VAULT_CLIENT_CERT", "") - defer os.Setenv("VAULT_CLIENT_KEY", "") - defer os.Setenv("VAULT_SKIP_VERIFY", "") - var m Meta - - // Err is ignored as it is expected that the test settings - // will cause errors; just check the flag settings - m.Client() - - if m.flagCACert != "/path/to/fake/cert.crt" { - t.Fatalf("bad: %s", m.flagAddress) - } - if m.flagCAPath != "/path/to/fake/certs" { - t.Fatalf("bad: %s", m.flagAddress) - } - if m.flagClientCert != "/path/to/fake/client.crt" { - t.Fatalf("bad: %s", m.flagAddress) - } - if m.flagClientKey != "/path/to/fake/client.key" { - t.Fatalf("bad: %s", m.flagAddress) - } - if m.flagInsecure != true { - t.Fatalf("bad: %s", m.flagAddress) - } -}