mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-07 21:36:26 +02:00
refactoring to unit test transit seal (#6605)
This commit is contained in:
parent
1eabcc0eb4
commit
70bbaf4115
@ -2,12 +2,7 @@ package transit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -15,7 +10,6 @@ import (
|
||||
"github.com/armon/go-metrics"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
"github.com/hashicorp/vault/vault/seal"
|
||||
)
|
||||
@ -23,13 +17,8 @@ import (
|
||||
// Seal is a seal that leverages Vault's Transit secret
|
||||
// engine
|
||||
type Seal struct {
|
||||
logger log.Logger
|
||||
client *api.Client
|
||||
renewer *api.Renewer
|
||||
|
||||
mountPath string
|
||||
keyName string
|
||||
|
||||
logger log.Logger
|
||||
client transitClientEncryptor
|
||||
currentKeyID *atomic.Value
|
||||
}
|
||||
|
||||
@ -47,142 +36,16 @@ func NewSeal(logger log.Logger) *Seal {
|
||||
|
||||
// SetConfig processes the config info from the server config
|
||||
func (s *Seal) SetConfig(config map[string]string) (map[string]string, error) {
|
||||
if config == nil {
|
||||
config = map[string]string{}
|
||||
client, sealInfo, err := newTransitClient(s.logger, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.client = client
|
||||
|
||||
switch {
|
||||
case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "":
|
||||
s.mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH")
|
||||
case config["mount_path"] != "":
|
||||
s.mountPath = config["mount_path"]
|
||||
default:
|
||||
return nil, fmt.Errorf("mount_path is required")
|
||||
}
|
||||
|
||||
switch {
|
||||
case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "":
|
||||
s.keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME")
|
||||
case config["key_name"] != "":
|
||||
s.keyName = config["key_name"]
|
||||
default:
|
||||
return nil, fmt.Errorf("key_name is required")
|
||||
}
|
||||
|
||||
var disableRenewal bool
|
||||
var disableRenewalRaw string
|
||||
switch {
|
||||
case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "":
|
||||
disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL")
|
||||
case config["disable_renewal"] != "":
|
||||
disableRenewalRaw = config["disable_renewal"]
|
||||
}
|
||||
if disableRenewalRaw != "" {
|
||||
var err error
|
||||
disableRenewal, err = strconv.ParseBool(disableRenewalRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var namespace string
|
||||
switch {
|
||||
case os.Getenv("VAULT_NAMESPACE") != "":
|
||||
namespace = os.Getenv("VAULT_NAMESPACE")
|
||||
case config["namespace"] != "":
|
||||
namespace = config["namespace"]
|
||||
}
|
||||
|
||||
apiConfig := api.DefaultConfig()
|
||||
if config["address"] != "" {
|
||||
apiConfig.Address = config["address"]
|
||||
}
|
||||
if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" ||
|
||||
config["tls_server_name"] != "" || config["tls_skip_verify"] != "" {
|
||||
var tlsSkipVerify bool
|
||||
if config["tls_skip_verify"] != "" {
|
||||
var err error
|
||||
tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig := &api.TLSConfig{
|
||||
CACert: config["tls_ca_cert"],
|
||||
CAPath: config["tls_ca_path"],
|
||||
ClientCert: config["tls_client_cert"],
|
||||
ClientKey: config["tls_client_key"],
|
||||
TLSServerName: config["tls_server_name"],
|
||||
Insecure: tlsSkipVerify,
|
||||
}
|
||||
if err := apiConfig.ConfigureTLS(tlsConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if s.client == nil {
|
||||
client, err := api.NewClient(apiConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config["token"] != "" {
|
||||
client.SetToken(config["token"])
|
||||
}
|
||||
if namespace != "" {
|
||||
client.SetNamespace(namespace)
|
||||
}
|
||||
if client.Token() == "" {
|
||||
return nil, errors.New("missing token")
|
||||
}
|
||||
s.client = client
|
||||
|
||||
// Send a value to test the seal and to set the current key id
|
||||
if _, err := s.Encrypt(context.Background(), []byte("a")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !disableRenewal {
|
||||
// Renew the token immediately to get a secret to pass to renewer
|
||||
secret, err := client.Auth().Token().RenewTokenAsSelf(s.client.Token(), 0)
|
||||
// If we don't get an error renewing, set up a renewer. The token may not be renewable or not have
|
||||
// permission to renew-self.
|
||||
if err == nil {
|
||||
renewer, err := s.client.NewRenewer(&api.RenewerInput{
|
||||
Secret: secret,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.renewer = renewer
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case err := <-renewer.DoneCh():
|
||||
s.logger.Info("shutting down token renewal")
|
||||
if err != nil {
|
||||
s.logger.Error("error renewing token", "error", err)
|
||||
}
|
||||
return
|
||||
case <-renewer.RenewCh():
|
||||
s.logger.Trace("successfully renewed token")
|
||||
}
|
||||
}
|
||||
}()
|
||||
go s.renewer.Renew()
|
||||
} else {
|
||||
s.logger.Info("unable to renew token, disabling renewal", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sealInfo := make(map[string]string)
|
||||
sealInfo["address"] = s.client.Address()
|
||||
sealInfo["mount_path"] = s.mountPath
|
||||
sealInfo["key_name"] = s.keyName
|
||||
if namespace != "" {
|
||||
sealInfo["namespace"] = namespace
|
||||
// Send a value to test the seal and to set the current key id
|
||||
if _, err := s.Encrypt(context.Background(), []byte("a")); err != nil {
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sealInfo, nil
|
||||
@ -195,10 +58,7 @@ func (s *Seal) Init(_ context.Context) error {
|
||||
|
||||
// Finalize is called during shutdown
|
||||
func (s *Seal) Finalize(_ context.Context) error {
|
||||
if s.renewer != nil {
|
||||
s.renewer.Stop()
|
||||
}
|
||||
|
||||
s.client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -227,17 +87,12 @@ func (s *Seal) Encrypt(_ context.Context, plaintext []byte) (blob *physical.Encr
|
||||
metrics.IncrCounter([]string{"seal", "encrypt"}, 1)
|
||||
metrics.IncrCounter([]string{"seal", "transit", "encrypt"}, 1)
|
||||
|
||||
encPlaintext := base64.StdEncoding.EncodeToString(plaintext)
|
||||
path := path.Join(s.mountPath, "encrypt", s.keyName)
|
||||
secret, err := s.client.Logical().Write(path, map[string]interface{}{
|
||||
"plaintext": encPlaintext,
|
||||
})
|
||||
ciphertext, err := s.client.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := secret.Data["ciphertext"].(string)
|
||||
splitKey := strings.Split(ciphertext, ":")
|
||||
splitKey := strings.Split(string(ciphertext), ":")
|
||||
if len(splitKey) != 3 {
|
||||
return nil, errors.New("invalid ciphertext returned")
|
||||
}
|
||||
@ -245,7 +100,7 @@ func (s *Seal) Encrypt(_ context.Context, plaintext []byte) (blob *physical.Encr
|
||||
s.currentKeyID.Store(keyID)
|
||||
|
||||
ret := &physical.EncryptedBlobInfo{
|
||||
Ciphertext: []byte(ciphertext),
|
||||
Ciphertext: ciphertext,
|
||||
KeyInfo: &physical.SealKeyInfo{
|
||||
KeyID: keyID,
|
||||
},
|
||||
@ -268,18 +123,9 @@ func (s *Seal) Decrypt(_ context.Context, in *physical.EncryptedBlobInfo) (pt []
|
||||
metrics.IncrCounter([]string{"seal", "decrypt"}, 1)
|
||||
metrics.IncrCounter([]string{"seal", "transit", "decrypt"}, 1)
|
||||
|
||||
path := path.Join(s.mountPath, "decrypt", s.keyName)
|
||||
secret, err := s.client.Logical().Write(path, map[string]interface{}{
|
||||
"ciphertext": string(in.Ciphertext),
|
||||
})
|
||||
plaintext, err := s.client.Decrypt(in.Ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package transit
|
||||
package transit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -14,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/vault/seal/transit"
|
||||
"github.com/ory/dockertest"
|
||||
)
|
||||
|
||||
@ -30,7 +31,7 @@ func TestTransitSeal_Lifecycle(t *testing.T) {
|
||||
"mount_path": mountPath,
|
||||
"key_name": keyName,
|
||||
}
|
||||
s := NewSeal(logging.NewVaultLogger(log.Trace))
|
||||
s := transit.NewSeal(logging.NewVaultLogger(log.Trace))
|
||||
_, err := s.SetConfig(sealConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("error setting seal config: %v", err)
|
||||
@ -87,7 +88,7 @@ func TestTransitSeal_TokenRenewal(t *testing.T) {
|
||||
"mount_path": mountPath,
|
||||
"key_name": keyName,
|
||||
}
|
||||
s := NewSeal(logging.NewVaultLogger(log.Trace))
|
||||
s := transit.NewSeal(logging.NewVaultLogger(log.Trace))
|
||||
_, err = s.SetConfig(sealConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("error setting seal config: %v", err)
|
||||
|
||||
203
vault/seal/transit/transit_client.go
Normal file
203
vault/seal/transit/transit_client.go
Normal file
@ -0,0 +1,203 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
type transitClientEncryptor interface {
|
||||
Close()
|
||||
Encrypt(plaintext []byte) (ciphertext []byte, err error)
|
||||
Decrypt(ciphertext []byte) (plaintext []byte, err error)
|
||||
}
|
||||
|
||||
type transitClient struct {
|
||||
client *api.Client
|
||||
renewer *api.Renewer
|
||||
|
||||
mountPath string
|
||||
keyName string
|
||||
}
|
||||
|
||||
func newTransitClient(logger log.Logger, config map[string]string) (*transitClient, map[string]string, error) {
|
||||
if config == nil {
|
||||
config = map[string]string{}
|
||||
}
|
||||
|
||||
var mountPath, keyName string
|
||||
switch {
|
||||
case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "":
|
||||
mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH")
|
||||
case config["mount_path"] != "":
|
||||
mountPath = config["mount_path"]
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("mount_path is required")
|
||||
}
|
||||
|
||||
switch {
|
||||
case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "":
|
||||
keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME")
|
||||
case config["key_name"] != "":
|
||||
keyName = config["key_name"]
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("key_name is required")
|
||||
}
|
||||
|
||||
var disableRenewal bool
|
||||
var disableRenewalRaw string
|
||||
switch {
|
||||
case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "":
|
||||
disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL")
|
||||
case config["disable_renewal"] != "":
|
||||
disableRenewalRaw = config["disable_renewal"]
|
||||
}
|
||||
if disableRenewalRaw != "" {
|
||||
var err error
|
||||
disableRenewal, err = strconv.ParseBool(disableRenewalRaw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var namespace string
|
||||
switch {
|
||||
case os.Getenv("VAULT_NAMESPACE") != "":
|
||||
namespace = os.Getenv("VAULT_NAMESPACE")
|
||||
case config["namespace"] != "":
|
||||
namespace = config["namespace"]
|
||||
}
|
||||
|
||||
apiConfig := api.DefaultConfig()
|
||||
if config["address"] != "" {
|
||||
apiConfig.Address = config["address"]
|
||||
}
|
||||
if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" ||
|
||||
config["tls_server_name"] != "" || config["tls_skip_verify"] != "" {
|
||||
var tlsSkipVerify bool
|
||||
if config["tls_skip_verify"] != "" {
|
||||
var err error
|
||||
tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig := &api.TLSConfig{
|
||||
CACert: config["tls_ca_cert"],
|
||||
CAPath: config["tls_ca_path"],
|
||||
ClientCert: config["tls_client_cert"],
|
||||
ClientKey: config["tls_client_key"],
|
||||
TLSServerName: config["tls_server_name"],
|
||||
Insecure: tlsSkipVerify,
|
||||
}
|
||||
if err := apiConfig.ConfigureTLS(tlsConfig); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
apiClient, err := api.NewClient(apiConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if config["token"] != "" {
|
||||
apiClient.SetToken(config["token"])
|
||||
}
|
||||
if namespace != "" {
|
||||
apiClient.SetNamespace(namespace)
|
||||
}
|
||||
if apiClient.Token() == "" {
|
||||
return nil, nil, errors.New("missing token")
|
||||
}
|
||||
|
||||
client := &transitClient{
|
||||
client: apiClient,
|
||||
mountPath: mountPath,
|
||||
keyName: keyName,
|
||||
}
|
||||
|
||||
if !disableRenewal {
|
||||
// Renew the token immediately to get a secret to pass to renewer
|
||||
secret, err := apiClient.Auth().Token().RenewTokenAsSelf(apiClient.Token(), 0)
|
||||
// If we don't get an error renewing, set up a renewer. The token may not be renewable or not have
|
||||
// permission to renew-self.
|
||||
if err == nil {
|
||||
renewer, err := apiClient.NewRenewer(&api.RenewerInput{
|
||||
Secret: secret,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
client.renewer = renewer
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case err := <-renewer.DoneCh():
|
||||
logger.Info("shutting down token renewal")
|
||||
if err != nil {
|
||||
logger.Error("error renewing token", "error", err)
|
||||
}
|
||||
return
|
||||
case <-renewer.RenewCh():
|
||||
logger.Trace("successfully renewed token")
|
||||
}
|
||||
}
|
||||
}()
|
||||
go renewer.Renew()
|
||||
} else {
|
||||
logger.Info("unable to renew token, disabling renewal", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
sealInfo := make(map[string]string)
|
||||
sealInfo["address"] = apiClient.Address()
|
||||
sealInfo["mount_path"] = mountPath
|
||||
sealInfo["key_name"] = keyName
|
||||
if namespace != "" {
|
||||
sealInfo["namespace"] = namespace
|
||||
}
|
||||
|
||||
return client, sealInfo, nil
|
||||
}
|
||||
|
||||
func (c *transitClient) Close() {
|
||||
if c.renewer != nil {
|
||||
c.renewer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *transitClient) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
encPlaintext := base64.StdEncoding.EncodeToString(plaintext)
|
||||
path := path.Join(c.mountPath, "encrypt", c.keyName)
|
||||
secret, err := c.client.Logical().Write(path, map[string]interface{}{
|
||||
"plaintext": encPlaintext,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(secret.Data["ciphertext"].(string)), nil
|
||||
}
|
||||
|
||||
func (c *transitClient) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
path := path.Join(c.mountPath, "decrypt", c.keyName)
|
||||
secret, err := c.client.Logical().Write(path, map[string]interface{}{
|
||||
"ciphertext": string(ciphertext),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
82
vault/seal/transit/transit_test.go
Normal file
82
vault/seal/transit/transit_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
"github.com/hashicorp/vault/vault/seal"
|
||||
)
|
||||
|
||||
type testTransitClient struct {
|
||||
keyID string
|
||||
seal seal.Access
|
||||
}
|
||||
|
||||
func newTestTransitClient(keyID string) *testTransitClient {
|
||||
return &testTransitClient{
|
||||
keyID: keyID,
|
||||
seal: seal.NewTestSeal(nil),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *testTransitClient) Close() {}
|
||||
|
||||
func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
v, err := m.seal.Encrypt(context.Background(), plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf("v1:%s:%s", m.keyID, string(v.Ciphertext))), nil
|
||||
}
|
||||
|
||||
func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
splitKey := strings.Split(string(ciphertext), ":")
|
||||
if len(splitKey) != 3 {
|
||||
return nil, errors.New("invalid ciphertext returned")
|
||||
}
|
||||
|
||||
data := &physical.EncryptedBlobInfo{
|
||||
Ciphertext: []byte(splitKey[2]),
|
||||
}
|
||||
v, err := m.seal.Decrypt(context.Background(), data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func TestTransitSeal_Lifecycle(t *testing.T) {
|
||||
s := NewSeal(logging.NewVaultLogger(log.Trace))
|
||||
|
||||
keyID := "test-key"
|
||||
s.client = newTestTransitClient(keyID)
|
||||
|
||||
// Test Encrypt and Decrypt calls
|
||||
input := []byte("foo")
|
||||
swi, err := s.Encrypt(context.Background(), input)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err.Error())
|
||||
}
|
||||
|
||||
pt, err := s.Decrypt(context.Background(), swi)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, pt) {
|
||||
t.Fatalf("expected %s, got %s", input, pt)
|
||||
}
|
||||
|
||||
if s.KeyID() != keyID {
|
||||
t.Fatalf("key id does not match: expected %s, got %s", keyID, s.KeyID())
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user