From 70bbaf4115a28b31bd06cb718dbb6ff9192573ac Mon Sep 17 00:00:00 2001 From: Chris Hoffman <99742+chrishoffman@users.noreply.github.com> Date: Tue, 23 Apr 2019 15:13:56 -0400 Subject: [PATCH] refactoring to unit test transit seal (#6605) --- vault/seal/transit/transit.go | 184 ++-------------------- vault/seal/transit/transit_acc_test.go | 7 +- vault/seal/transit/transit_client.go | 203 +++++++++++++++++++++++++ vault/seal/transit/transit_test.go | 82 ++++++++++ 4 files changed, 304 insertions(+), 172 deletions(-) create mode 100644 vault/seal/transit/transit_client.go create mode 100644 vault/seal/transit/transit_test.go diff --git a/vault/seal/transit/transit.go b/vault/seal/transit/transit.go index bcb14ec21b..eface946a3 100644 --- a/vault/seal/transit/transit.go +++ b/vault/seal/transit/transit.go @@ -2,12 +2,7 @@ package transit import ( "context" - "encoding/base64" "errors" - "fmt" - "os" - "path" - "strconv" "strings" "sync/atomic" "time" @@ -15,7 +10,6 @@ import ( "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" - "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/vault/seal" ) @@ -23,13 +17,8 @@ import ( // Seal is a seal that leverages Vault's Transit secret // engine type Seal struct { - logger log.Logger - client *api.Client - renewer *api.Renewer - - mountPath string - keyName string - + logger log.Logger + client transitClientEncryptor currentKeyID *atomic.Value } @@ -47,142 +36,16 @@ func NewSeal(logger log.Logger) *Seal { // SetConfig processes the config info from the server config func (s *Seal) SetConfig(config map[string]string) (map[string]string, error) { - if config == nil { - config = map[string]string{} + client, sealInfo, err := newTransitClient(s.logger, config) + if err != nil { + return nil, err } + s.client = client - switch { - case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "": - s.mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") - case config["mount_path"] != "": - s.mountPath = config["mount_path"] - default: - return nil, fmt.Errorf("mount_path is required") - } - - switch { - case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "": - s.keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") - case config["key_name"] != "": - s.keyName = config["key_name"] - default: - return nil, fmt.Errorf("key_name is required") - } - - var disableRenewal bool - var disableRenewalRaw string - switch { - case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "": - disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") - case config["disable_renewal"] != "": - disableRenewalRaw = config["disable_renewal"] - } - if disableRenewalRaw != "" { - var err error - disableRenewal, err = strconv.ParseBool(disableRenewalRaw) - if err != nil { - return nil, err - } - } - - var namespace string - switch { - case os.Getenv("VAULT_NAMESPACE") != "": - namespace = os.Getenv("VAULT_NAMESPACE") - case config["namespace"] != "": - namespace = config["namespace"] - } - - apiConfig := api.DefaultConfig() - if config["address"] != "" { - apiConfig.Address = config["address"] - } - if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" || - config["tls_server_name"] != "" || config["tls_skip_verify"] != "" { - var tlsSkipVerify bool - if config["tls_skip_verify"] != "" { - var err error - tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"]) - if err != nil { - return nil, err - } - } - - tlsConfig := &api.TLSConfig{ - CACert: config["tls_ca_cert"], - CAPath: config["tls_ca_path"], - ClientCert: config["tls_client_cert"], - ClientKey: config["tls_client_key"], - TLSServerName: config["tls_server_name"], - Insecure: tlsSkipVerify, - } - if err := apiConfig.ConfigureTLS(tlsConfig); err != nil { - return nil, err - } - } - - if s.client == nil { - client, err := api.NewClient(apiConfig) - if err != nil { - return nil, err - } - if config["token"] != "" { - client.SetToken(config["token"]) - } - if namespace != "" { - client.SetNamespace(namespace) - } - if client.Token() == "" { - return nil, errors.New("missing token") - } - s.client = client - - // Send a value to test the seal and to set the current key id - if _, err := s.Encrypt(context.Background(), []byte("a")); err != nil { - return nil, err - } - - if !disableRenewal { - // Renew the token immediately to get a secret to pass to renewer - secret, err := client.Auth().Token().RenewTokenAsSelf(s.client.Token(), 0) - // If we don't get an error renewing, set up a renewer. The token may not be renewable or not have - // permission to renew-self. - if err == nil { - renewer, err := s.client.NewRenewer(&api.RenewerInput{ - Secret: secret, - }) - if err != nil { - return nil, err - } - s.renewer = renewer - - go func() { - for { - select { - case err := <-renewer.DoneCh(): - s.logger.Info("shutting down token renewal") - if err != nil { - s.logger.Error("error renewing token", "error", err) - } - return - case <-renewer.RenewCh(): - s.logger.Trace("successfully renewed token") - } - } - }() - go s.renewer.Renew() - } else { - s.logger.Info("unable to renew token, disabling renewal", "err", err) - } - } - } - - sealInfo := make(map[string]string) - sealInfo["address"] = s.client.Address() - sealInfo["mount_path"] = s.mountPath - sealInfo["key_name"] = s.keyName - if namespace != "" { - sealInfo["namespace"] = namespace + // Send a value to test the seal and to set the current key id + if _, err := s.Encrypt(context.Background(), []byte("a")); err != nil { + client.Close() + return nil, err } return sealInfo, nil @@ -195,10 +58,7 @@ func (s *Seal) Init(_ context.Context) error { // Finalize is called during shutdown func (s *Seal) Finalize(_ context.Context) error { - if s.renewer != nil { - s.renewer.Stop() - } - + s.client.Close() return nil } @@ -227,17 +87,12 @@ func (s *Seal) Encrypt(_ context.Context, plaintext []byte) (blob *physical.Encr metrics.IncrCounter([]string{"seal", "encrypt"}, 1) metrics.IncrCounter([]string{"seal", "transit", "encrypt"}, 1) - encPlaintext := base64.StdEncoding.EncodeToString(plaintext) - path := path.Join(s.mountPath, "encrypt", s.keyName) - secret, err := s.client.Logical().Write(path, map[string]interface{}{ - "plaintext": encPlaintext, - }) + ciphertext, err := s.client.Encrypt(plaintext) if err != nil { return nil, err } - ciphertext := secret.Data["ciphertext"].(string) - splitKey := strings.Split(ciphertext, ":") + splitKey := strings.Split(string(ciphertext), ":") if len(splitKey) != 3 { return nil, errors.New("invalid ciphertext returned") } @@ -245,7 +100,7 @@ func (s *Seal) Encrypt(_ context.Context, plaintext []byte) (blob *physical.Encr s.currentKeyID.Store(keyID) ret := &physical.EncryptedBlobInfo{ - Ciphertext: []byte(ciphertext), + Ciphertext: ciphertext, KeyInfo: &physical.SealKeyInfo{ KeyID: keyID, }, @@ -268,18 +123,9 @@ func (s *Seal) Decrypt(_ context.Context, in *physical.EncryptedBlobInfo) (pt [] metrics.IncrCounter([]string{"seal", "decrypt"}, 1) metrics.IncrCounter([]string{"seal", "transit", "decrypt"}, 1) - path := path.Join(s.mountPath, "decrypt", s.keyName) - secret, err := s.client.Logical().Write(path, map[string]interface{}{ - "ciphertext": string(in.Ciphertext), - }) + plaintext, err := s.client.Decrypt(in.Ciphertext) if err != nil { return nil, err } - - plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string)) - if err != nil { - return nil, err - } - return plaintext, nil } diff --git a/vault/seal/transit/transit_acc_test.go b/vault/seal/transit/transit_acc_test.go index 405cbd339c..11896b76b8 100644 --- a/vault/seal/transit/transit_acc_test.go +++ b/vault/seal/transit/transit_acc_test.go @@ -1,4 +1,4 @@ -package transit +package transit_test import ( "context" @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/testhelpers/docker" "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/vault/seal/transit" "github.com/ory/dockertest" ) @@ -30,7 +31,7 @@ func TestTransitSeal_Lifecycle(t *testing.T) { "mount_path": mountPath, "key_name": keyName, } - s := NewSeal(logging.NewVaultLogger(log.Trace)) + s := transit.NewSeal(logging.NewVaultLogger(log.Trace)) _, err := s.SetConfig(sealConfig) if err != nil { t.Fatalf("error setting seal config: %v", err) @@ -87,7 +88,7 @@ func TestTransitSeal_TokenRenewal(t *testing.T) { "mount_path": mountPath, "key_name": keyName, } - s := NewSeal(logging.NewVaultLogger(log.Trace)) + s := transit.NewSeal(logging.NewVaultLogger(log.Trace)) _, err = s.SetConfig(sealConfig) if err != nil { t.Fatalf("error setting seal config: %v", err) diff --git a/vault/seal/transit/transit_client.go b/vault/seal/transit/transit_client.go new file mode 100644 index 0000000000..b9d6c1f12d --- /dev/null +++ b/vault/seal/transit/transit_client.go @@ -0,0 +1,203 @@ +package transit + +import ( + "encoding/base64" + "errors" + "fmt" + "os" + "path" + "strconv" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" +) + +type transitClientEncryptor interface { + Close() + Encrypt(plaintext []byte) (ciphertext []byte, err error) + Decrypt(ciphertext []byte) (plaintext []byte, err error) +} + +type transitClient struct { + client *api.Client + renewer *api.Renewer + + mountPath string + keyName string +} + +func newTransitClient(logger log.Logger, config map[string]string) (*transitClient, map[string]string, error) { + if config == nil { + config = map[string]string{} + } + + var mountPath, keyName string + switch { + case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "": + mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") + case config["mount_path"] != "": + mountPath = config["mount_path"] + default: + return nil, nil, fmt.Errorf("mount_path is required") + } + + switch { + case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "": + keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") + case config["key_name"] != "": + keyName = config["key_name"] + default: + return nil, nil, fmt.Errorf("key_name is required") + } + + var disableRenewal bool + var disableRenewalRaw string + switch { + case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "": + disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") + case config["disable_renewal"] != "": + disableRenewalRaw = config["disable_renewal"] + } + if disableRenewalRaw != "" { + var err error + disableRenewal, err = strconv.ParseBool(disableRenewalRaw) + if err != nil { + return nil, nil, err + } + } + + var namespace string + switch { + case os.Getenv("VAULT_NAMESPACE") != "": + namespace = os.Getenv("VAULT_NAMESPACE") + case config["namespace"] != "": + namespace = config["namespace"] + } + + apiConfig := api.DefaultConfig() + if config["address"] != "" { + apiConfig.Address = config["address"] + } + if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" || + config["tls_server_name"] != "" || config["tls_skip_verify"] != "" { + var tlsSkipVerify bool + if config["tls_skip_verify"] != "" { + var err error + tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"]) + if err != nil { + return nil, nil, err + } + } + + tlsConfig := &api.TLSConfig{ + CACert: config["tls_ca_cert"], + CAPath: config["tls_ca_path"], + ClientCert: config["tls_client_cert"], + ClientKey: config["tls_client_key"], + TLSServerName: config["tls_server_name"], + Insecure: tlsSkipVerify, + } + if err := apiConfig.ConfigureTLS(tlsConfig); err != nil { + return nil, nil, err + } + } + + apiClient, err := api.NewClient(apiConfig) + if err != nil { + return nil, nil, err + } + if config["token"] != "" { + apiClient.SetToken(config["token"]) + } + if namespace != "" { + apiClient.SetNamespace(namespace) + } + if apiClient.Token() == "" { + return nil, nil, errors.New("missing token") + } + + client := &transitClient{ + client: apiClient, + mountPath: mountPath, + keyName: keyName, + } + + if !disableRenewal { + // Renew the token immediately to get a secret to pass to renewer + secret, err := apiClient.Auth().Token().RenewTokenAsSelf(apiClient.Token(), 0) + // If we don't get an error renewing, set up a renewer. The token may not be renewable or not have + // permission to renew-self. + if err == nil { + renewer, err := apiClient.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + return nil, nil, err + } + client.renewer = renewer + + go func() { + for { + select { + case err := <-renewer.DoneCh(): + logger.Info("shutting down token renewal") + if err != nil { + logger.Error("error renewing token", "error", err) + } + return + case <-renewer.RenewCh(): + logger.Trace("successfully renewed token") + } + } + }() + go renewer.Renew() + } else { + logger.Info("unable to renew token, disabling renewal", "err", err) + } + } + + sealInfo := make(map[string]string) + sealInfo["address"] = apiClient.Address() + sealInfo["mount_path"] = mountPath + sealInfo["key_name"] = keyName + if namespace != "" { + sealInfo["namespace"] = namespace + } + + return client, sealInfo, nil +} + +func (c *transitClient) Close() { + if c.renewer != nil { + c.renewer.Stop() + } +} + +func (c *transitClient) Encrypt(plaintext []byte) ([]byte, error) { + encPlaintext := base64.StdEncoding.EncodeToString(plaintext) + path := path.Join(c.mountPath, "encrypt", c.keyName) + secret, err := c.client.Logical().Write(path, map[string]interface{}{ + "plaintext": encPlaintext, + }) + if err != nil { + return nil, err + } + + return []byte(secret.Data["ciphertext"].(string)), nil +} + +func (c *transitClient) Decrypt(ciphertext []byte) ([]byte, error) { + path := path.Join(c.mountPath, "decrypt", c.keyName) + secret, err := c.client.Logical().Write(path, map[string]interface{}{ + "ciphertext": string(ciphertext), + }) + if err != nil { + return nil, err + } + + plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string)) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/vault/seal/transit/transit_test.go b/vault/seal/transit/transit_test.go new file mode 100644 index 0000000000..e0c7a38708 --- /dev/null +++ b/vault/seal/transit/transit_test.go @@ -0,0 +1,82 @@ +package transit + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "testing" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/physical" + "github.com/hashicorp/vault/vault/seal" +) + +type testTransitClient struct { + keyID string + seal seal.Access +} + +func newTestTransitClient(keyID string) *testTransitClient { + return &testTransitClient{ + keyID: keyID, + seal: seal.NewTestSeal(nil), + } +} + +func (m *testTransitClient) Close() {} + +func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) { + v, err := m.seal.Encrypt(context.Background(), plaintext) + if err != nil { + return nil, err + } + + return []byte(fmt.Sprintf("v1:%s:%s", m.keyID, string(v.Ciphertext))), nil +} + +func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) { + splitKey := strings.Split(string(ciphertext), ":") + if len(splitKey) != 3 { + return nil, errors.New("invalid ciphertext returned") + } + + data := &physical.EncryptedBlobInfo{ + Ciphertext: []byte(splitKey[2]), + } + v, err := m.seal.Decrypt(context.Background(), data) + if err != nil { + return nil, err + } + + return v, nil +} + +func TestTransitSeal_Lifecycle(t *testing.T) { + s := NewSeal(logging.NewVaultLogger(log.Trace)) + + keyID := "test-key" + s.client = newTestTransitClient(keyID) + + // Test Encrypt and Decrypt calls + input := []byte("foo") + swi, err := s.Encrypt(context.Background(), input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + pt, err := s.Decrypt(context.Background(), swi) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + if !reflect.DeepEqual(input, pt) { + t.Fatalf("expected %s, got %s", input, pt) + } + + if s.KeyID() != keyID { + t.Fatalf("key id does not match: expected %s, got %s", keyID, s.KeyID()) + } +}