From 2d0c3ff335b9bb1169a01119b2b06ad54311ed4e Mon Sep 17 00:00:00 2001 From: Chris Hoffman <99742+chrishoffman@users.noreply.github.com> Date: Thu, 28 Feb 2019 16:13:56 -0500 Subject: [PATCH] Transit Autounseal (#5995) * Adding Transit Autoseal * adding tests * adding more tests * updating seal info * send a value to test and set current key id * updating message * cleanup * Adding tls config, addressing some feedback * adding tls testing * renaming config fields for tls --- command/server/seal/server_seal.go | 3 + command/server/seal/server_seal_transit.go | 34 +++ vault/seal/seal.go | 1 + vault/seal/transit/transit.go | 256 +++++++++++++++++++++ vault/seal/transit/transit_acc_test.go | 211 +++++++++++++++++ 5 files changed, 505 insertions(+) create mode 100644 command/server/seal/server_seal_transit.go create mode 100644 vault/seal/transit/transit.go create mode 100644 vault/seal/transit/transit_acc_test.go diff --git a/command/server/seal/server_seal.go b/command/server/seal/server_seal.go index 2b42cd0874..70fca28146 100644 --- a/command/server/seal/server_seal.go +++ b/command/server/seal/server_seal.go @@ -34,6 +34,9 @@ func configureSeal(config *server.Config, infoKeys *[]string, info *map[string]s case seal.AzureKeyVault: return configureAzureKeyVaultSeal(config, infoKeys, info, logger, inseal) + case seal.Transit: + return configureTransitSeal(config, infoKeys, info, logger, inseal) + case seal.PKCS11: return nil, fmt.Errorf("Seal type 'pkcs11' requires the Vault Enterprise HSM binary") diff --git a/command/server/seal/server_seal_transit.go b/command/server/seal/server_seal_transit.go new file mode 100644 index 0000000000..29db5c12af --- /dev/null +++ b/command/server/seal/server_seal_transit.go @@ -0,0 +1,34 @@ +package seal + +import ( + "github.com/hashicorp/errwrap" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/command/server" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/vault/seal/transit" +) + +func configureTransitSeal(config *server.Config, infoKeys *[]string, info *map[string]string, logger log.Logger, inseal vault.Seal) (vault.Seal, error) { + transitSeal := transit.NewSeal(logger) + sealInfo, err := transitSeal.SetConfig(config.Seal.Config) + if err != nil { + // If the error is any other than logical.KeyNotFoundError, return the error + if !errwrap.ContainsType(err, new(logical.KeyNotFoundError)) { + return nil, err + } + } + autoseal := vault.NewAutoSeal(transitSeal) + if sealInfo != nil { + *infoKeys = append(*infoKeys, "Seal Type", "Transit Address", "Transit Mount Path", "Transit Key Name") + (*info)["Seal Type"] = config.Seal.Type + (*info)["Transit Address"] = sealInfo["address"] + (*info)["Transit Mount Path"] = sealInfo["mount_path"] + (*info)["Transit Key Name"] = sealInfo["key_name"] + if namespace, ok := sealInfo["namespace"]; ok { + *infoKeys = append(*infoKeys, "Transit Namespace") + (*info)["Transit Namespace"] = namespace + } + } + return autoseal, nil +} diff --git a/vault/seal/seal.go b/vault/seal/seal.go index b80217a010..13552f99ca 100644 --- a/vault/seal/seal.go +++ b/vault/seal/seal.go @@ -13,6 +13,7 @@ const ( AWSKMS = "awskms" GCPCKMS = "gcpckms" AzureKeyVault = "azurekeyvault" + Transit = "transit" Test = "test-auto" // HSMAutoDeprecated is a deprecated seal type prior to 0.9.0. diff --git a/vault/seal/transit/transit.go b/vault/seal/transit/transit.go new file mode 100644 index 0000000000..a8a70da2bb --- /dev/null +++ b/vault/seal/transit/transit.go @@ -0,0 +1,256 @@ +package transit + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "os" + "path" + "strconv" + "strings" + "sync/atomic" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/vault/seal" +) + +// Seal is a seal that leverages Vault's Transit secret +// engine +type Seal struct { + logger log.Logger + client *api.Client + renewer *api.Renewer + + mountPath string + keyName string + + currentKeyID *atomic.Value +} + +var _ seal.Access = (*Seal)(nil) + +// NewSeal creates a new transit seal +func NewSeal(logger log.Logger) *Seal { + s := &Seal{ + logger: logger.ResetNamed("seal-transit"), + currentKeyID: new(atomic.Value), + } + s.currentKeyID.Store("") + return s +} + +// SetConfig processes the config info from the server config +func (s *Seal) SetConfig(config map[string]string) (map[string]string, error) { + if config == nil { + config = map[string]string{} + } + + switch { + case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "": + s.mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") + case config["mount_path"] != "": + s.mountPath = config["mount_path"] + default: + return nil, fmt.Errorf("mount_path is required") + } + + switch { + case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "": + s.keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") + case config["key_name"] != "": + s.keyName = config["key_name"] + default: + return nil, fmt.Errorf("key_name is required") + } + + var disableRenewal bool + var disableRenewalRaw string + switch { + case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "": + disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") + case config["disable_renewal"] != "": + disableRenewalRaw = config["disable_renewal"] + } + if disableRenewalRaw != "" { + var err error + disableRenewal, err = strconv.ParseBool(disableRenewalRaw) + if err != nil { + return nil, err + } + } + + var namespace string + switch { + case os.Getenv("VAULT_NAMESPACE") != "": + namespace = os.Getenv("VAULT_NAMESPACE") + case config["namespace"] != "": + namespace = config["namespace"] + } + + apiConfig := api.DefaultConfig() + if config["address"] != "" { + apiConfig.Address = config["address"] + } + if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" || + config["tls_server_name"] != "" || config["tls_skip_veriry"] != "" { + var tlsSkipVerify bool + if config["tls_skip_verify"] != "" { + var err error + tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"]) + if err != nil { + return nil, err + } + } + + tlsConfig := &api.TLSConfig{ + CACert: config["tls_ca_cert"], + CAPath: config["tls_ca_path"], + ClientCert: config["tls_client_cert"], + ClientKey: config["tls_client_key"], + TLSServerName: config["tls_server_name"], + Insecure: tlsSkipVerify, + } + if err := apiConfig.ConfigureTLS(tlsConfig); err != nil { + return nil, err + } + } + + if s.client == nil { + client, err := api.NewClient(apiConfig) + if err != nil { + return nil, err + } + if config["token"] != "" { + client.SetToken(config["token"]) + } + if namespace != "" { + client.SetNamespace(namespace) + } + if client.Token() == "" { + return nil, errors.New("missing token") + } + s.client = client + + // Send a value to test the seal and to set the current key id + if _, err := s.Encrypt(context.Background(), []byte("a")); err != nil { + return nil, err + } + + if !disableRenewal { + // Renew the token immediately to get a secret to pass to renewer + secret, err := client.Auth().Token().RenewTokenAsSelf(s.client.Token(), 0) + // If we don't get an error renewing, set up a renewer. The token may not be renewable or not have + // permission to renew-self. + if err == nil { + renewer, err := s.client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + return nil, err + } + s.renewer = renewer + + go func() { + for { + select { + case err := <-renewer.DoneCh(): + s.logger.Info("shutting down token renewal") + if err != nil { + s.logger.Error("error renewing token", "error", err) + } + return + case <-renewer.RenewCh(): + s.logger.Trace("successfully renewed token") + } + } + }() + go s.renewer.Renew() + } else { + s.logger.Info("unable to renew token, disabling renewal", "err", err) + } + } + } + + sealInfo := make(map[string]string) + sealInfo["address"] = s.client.Address() + sealInfo["mount_path"] = s.mountPath + sealInfo["key_name"] = s.keyName + if namespace != "" { + sealInfo["namespace"] = namespace + } + + return sealInfo, nil +} + +// Init is called during core.Initialize +func (s *Seal) Init(_ context.Context) error { + return nil +} + +// Finalize is called during shutdown +func (s *Seal) Finalize(_ context.Context) error { + if s.renewer != nil { + s.renewer.Stop() + } + + return nil +} + +// SealType returns the seal type for this particular seal implementation. +func (s *Seal) SealType() string { + return seal.Transit +} + +// KeyID returns the last known key id. +func (s *Seal) KeyID() string { + return s.currentKeyID.Load().(string) +} + +// Encrypt is used to encrypt using Vaults Transit engine +func (s *Seal) Encrypt(_ context.Context, plaintext []byte) (*physical.EncryptedBlobInfo, error) { + encPlaintext := base64.StdEncoding.EncodeToString(plaintext) + path := path.Join(s.mountPath, "encrypt", s.keyName) + secret, err := s.client.Logical().Write(path, map[string]interface{}{ + "plaintext": encPlaintext, + }) + if err != nil { + return nil, err + } + + ciphertext := secret.Data["ciphertext"].(string) + splitKey := strings.Split(ciphertext, ":") + if len(splitKey) != 3 { + return nil, errors.New("invalid ciphertext returned") + } + keyID := splitKey[1] + s.currentKeyID.Store(keyID) + + ret := &physical.EncryptedBlobInfo{ + Ciphertext: []byte(ciphertext), + KeyInfo: &physical.SealKeyInfo{ + KeyID: keyID, + }, + } + return ret, nil +} + +// Decrypt is used to decrypt the ciphertext +func (s *Seal) Decrypt(_ context.Context, in *physical.EncryptedBlobInfo) ([]byte, error) { + path := path.Join(s.mountPath, "decrypt", s.keyName) + secret, err := s.client.Logical().Write(path, map[string]interface{}{ + "ciphertext": string(in.Ciphertext), + }) + if err != nil { + return nil, err + } + + plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string)) + if err != nil { + return nil, err + } + + return plaintext, nil +} diff --git a/vault/seal/transit/transit_acc_test.go b/vault/seal/transit/transit_acc_test.go new file mode 100644 index 0000000000..633f0aa3d6 --- /dev/null +++ b/vault/seal/transit/transit_acc_test.go @@ -0,0 +1,211 @@ +package transit + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "path" + "reflect" + "runtime" + "testing" + "time" + + log "github.com/hashicorp/go-hclog" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/logging" + "github.com/ory/dockertest" +) + +func TestTransitSeal_Lifecycle(t *testing.T) { + cleanup, retAddress, token, mountPath, keyName, tlsConfig := prepareTestContainer(t) + defer cleanup() + + sealConfig := map[string]string{ + "address": retAddress, + "token": token, + "mount_path": mountPath, + "key_name": keyName, + "tls_ca_cert": tlsConfig.CACert, + "tls_client_cert": tlsConfig.ClientCert, + "tls_client_key": tlsConfig.ClientKey, + } + s := NewSeal(logging.NewVaultLogger(log.Trace)) + _, err := s.SetConfig(sealConfig) + if err != nil { + t.Fatalf("error setting seal config: %v", err) + } + + // Test Encrypt and Decrypt calls + input := []byte("foo") + swi, err := s.Encrypt(context.Background(), input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + pt, err := s.Decrypt(context.Background(), swi) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + if !reflect.DeepEqual(input, pt) { + t.Fatalf("expected %s, got %s", input, pt) + } +} + +func TestTransitSeal_TokenRenewal(t *testing.T) { + cleanup, retAddress, token, mountPath, keyName, tlsConfig := prepareTestContainer(t) + defer cleanup() + + clientConfig := &api.Config{ + Address: retAddress, + } + if err := clientConfig.ConfigureTLS(tlsConfig); err != nil { + t.Fatalf("err: %s", err) + } + + remoteClient, err := api.NewClient(clientConfig) + if err != nil { + t.Fatalf("err: %s", err) + } + remoteClient.SetToken(token) + + req := &api.TokenCreateRequest{ + Period: "5s", + } + rsp, err := remoteClient.Auth().Token().Create(req) + if err != nil { + t.Fatalf("err: %s", err) + } + + sealConfig := map[string]string{ + "address": retAddress, + "token": rsp.Auth.ClientToken, + "mount_path": mountPath, + "key_name": keyName, + "tls_ca_cert": tlsConfig.CACert, + "tls_client_cert": tlsConfig.ClientCert, + "tls_client_key": tlsConfig.ClientKey, + } + s := NewSeal(logging.NewVaultLogger(log.Trace)) + _, err = s.SetConfig(sealConfig) + if err != nil { + t.Fatalf("error setting seal config: %v", err) + } + + time.Sleep(7 * time.Second) + + // Test Encrypt and Decrypt calls + input := []byte("foo") + swi, err := s.Encrypt(context.Background(), input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + pt, err := s.Decrypt(context.Background(), swi) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + if !reflect.DeepEqual(input, pt) { + t.Fatalf("expected %s, got %s", input, pt) + } +} + +func prepareTestContainer(t *testing.T) (cleanup func(), retAddress, token, mountPath, keyName string, tlsConfig *api.TLSConfig) { + testToken, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("err: %s", err) + } + testMountPath, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("err: %s", err) + } + testKeyName, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("err: %s", err) + } + + var tempDir string + // Docker for Mac does not play nice with TempDir + if runtime.GOOS == "darwin" { + uniqueTempDir, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("err: %s", err) + } + tempDir = path.Join("/tmp", uniqueTempDir) + } else { + tempDir, err = ioutil.TempDir("", "transit-autoseal-test") + if err != nil { + t.Fatal(err) + } + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + dockerOptions := &dockertest.RunOptions{ + Repository: "vault", + Tag: "latest", + Cmd: []string{"server", "-log-level=trace", "-dev", "-dev-three-node", fmt.Sprintf("-dev-root-token-id=%s", testToken), + "-dev-listen-address=0.0.0.0:8200"}, + Env: []string{"VAULT_DEV_TEMP_DIR=/tmp"}, + Mounts: []string{fmt.Sprintf("%s:/tmp", tempDir)}, + } + resource, err := pool.RunWithOptions(dockerOptions) + if err != nil { + t.Fatalf("Could not start local Vault docker container: %s", err) + } + + cleanup = func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("error removing temp directory: %s", err) + } + + if err := pool.Purge(resource); err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retAddress = fmt.Sprintf("https://127.0.0.1:%s", resource.GetPort("8200/tcp")) + tlsConfig = &api.TLSConfig{ + CACert: path.Join(tempDir, "ca_cert.pem"), + ClientCert: path.Join(tempDir, "node1_port_8200_cert.pem"), + ClientKey: path.Join(tempDir, "node1_port_8200_key.pem"), + } + + // exponential backoff-retry + if err = pool.Retry(func() error { + vaultConfig := api.DefaultConfig() + vaultConfig.Address = retAddress + if err := vaultConfig.ConfigureTLS(tlsConfig); err != nil { + return err + } + vault, err := api.NewClient(vaultConfig) + if err != nil { + return err + } + vault.SetToken(testToken) + + // Set up transit + if err := vault.Sys().Mount(testMountPath, &api.MountInput{ + Type: "transit", + }); err != nil { + return err + } + + // Create default aesgcm key + if _, err := vault.Logical().Write(path.Join(testMountPath, "keys", testKeyName), map[string]interface{}{}); err != nil { + return err + } + + return nil + }); err != nil { + cleanup() + t.Fatalf("Could not connect to vault: %s", err) + } + return cleanup, retAddress, testToken, testMountPath, testKeyName, tlsConfig +}