vault/sdk/testing/stepwise/environments/docker/environment.go

865 lines
22 KiB
Go

package docker
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"math/big"
mathrand "math/rand"
"net"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"testing"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
docker "github.com/docker/docker/client"
"github.com/docker/go-connections/nat"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-multierror"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/testing/stepwise"
"golang.org/x/net/http2"
)
// Compile-time assertion that *DockerCluster implements stepwise.Environment.
var _ stepwise.Environment = (*DockerCluster)(nil)
// dockerVersion is the API version handed to docker.WithVersion whenever this
// file constructs a Docker client.
const dockerVersion = "1.40"
// DockerCluster manages the lifecycle of the test Vault cluster. It is the
// stepwise.Environment implementation whose Name is "docker".
type DockerCluster struct {
// ID is the Vault cluster ID; Initialize records it from the health response.
ID string
// PluginName is the input from the test case
PluginName string
// ClusterName is the name of the cluster. NewEnvironment builds it as
// "test-<name>-<uuid>"; node names are derived from it.
ClusterName string
// MountOptions are a set of options for registering and mounting the plugin
MountOptions stepwise.MountOptions
// RaftStorage, when true, makes Initialize join every node after the first
// to node 0 with a raft join request.
RaftStorage bool
// ClusterNodes holds the nodes of the cluster. Node 0 is the one that is
// initialized and expected to become leader.
ClusterNodes []*dockerClusterNode
// Certificate fields
// CACert is the parsed CA certificate and CACertBytes its DER encoding.
CACert *x509.Certificate
CACertBytes []byte
// CACertPEM is the PEM encoding of the CA certificate, and CACertPEMFile
// the path (<tmpDir>/ca/ca.pem) it is written to.
CACertPEM []byte
CACertPEMFile string
// CAKey is the CA private key used to sign node certificates; CAKeyPEM is
// its PEM encoding.
CAKey *ecdsa.PrivateKey
CAKeyPEM []byte
// RootCAs is a pool containing only CACert.
RootCAs *x509.CertPool
// networkID tracks the network ID of the created docker network
networkID string
// barrierKeys and recoveryKeys are the hex-decoded keys returned by the
// init request.
barrierKeys [][]byte
recoveryKeys [][]byte
// tmpDir is the directory holding the "ca" directory and one work
// directory per node.
tmpDir string
// clientAuthRequired makes node TLS configs require and verify client
// certificates instead of merely requesting them.
clientAuthRequired bool
// the mountpath of the plugin under test
mountPath string
// rootToken is the initial root token created when the Vault cluster is
// created.
rootToken string
}
// Teardown kills the container of every node and then removes the Docker
// network that was created for the cluster, if one was recorded.
//
// Errors from killing containers are accumulated so that one failing node
// does not stop the remaining nodes from being cleaned up. The returned error
// is nil only when every step succeeded.
func (dc *DockerCluster) Teardown() error {
	var result *multierror.Error
	for _, node := range dc.ClusterNodes {
		if err := node.Cleanup(); err != nil {
			result = multierror.Append(result, err)
		}
	}

	// Clean up the network. The cluster does not keep a Docker client of its
	// own, so a short-lived one is created here.
	if dc.networkID != "" {
		cli, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithVersion(dockerVersion))
		if err != nil {
			return multierror.Append(result, err)
		}
		// Release the client's transport; it is only needed for this call.
		defer cli.Close()

		if err := cli.NetworkRemove(context.Background(), dc.networkID); err != nil {
			return multierror.Append(result, err)
		}
	}
	return result.ErrorOrNil()
}
// MountPath reports where the plugin under test is mounted. The path is
// computed on first use and cached: it is "<prefix>_<uuid>", where the prefix
// is MountOptions.MountPathPrefix when set and PluginName otherwise.
// Credential plugins additionally get an "auth/" prefix. It panics if a UUID
// cannot be generated.
func (dc *DockerCluster) MountPath() string {
	if dc.mountPath == "" {
		id, err := uuid.GenerateUUID()
		if err != nil {
			panic(err)
		}

		base := dc.MountOptions.MountPathPrefix
		if base == "" {
			base = dc.PluginName
		}

		mounted := fmt.Sprintf("%s_%s", base, id)
		if dc.MountOptions.PluginType == stepwise.PluginTypeCredential {
			mounted = path.Join("auth", mounted)
		}
		dc.mountPath = mounted
	}
	return dc.mountPath
}
// RootToken reports the root token stored on the cluster. It is the empty
// string until one has been recorded.
func (dc *DockerCluster) RootToken() string {
	token := dc.rootToken
	return token
}
// Name identifies this environment implementation; it is always "docker".
func (dc *DockerCluster) Name() string {
	const envName = "docker"
	return envName
}
// Client hands out a clone of the first node's Vault API client, carrying the
// same token as the original. It fails when the cluster has no nodes or the
// first node has no client yet.
func (dc *DockerCluster) Client() (*api.Client, error) {
	if len(dc.ClusterNodes) == 0 || dc.ClusterNodes[0].Client == nil {
		return nil, errors.New("no configured client found")
	}

	source := dc.ClusterNodes[0].Client
	clone, err := source.Clone()
	if err != nil {
		return nil, err
	}
	clone.SetToken(source.Token())
	return clone, nil
}
// Name builds the node's name as "<cluster name>-<node ID>".
func (n *dockerClusterNode) Name() string {
	clusterName, nodeID := n.Cluster.ClusterName, n.NodeID
	return clusterName + "-" + nodeID
}
// Initialize initializes Vault through the first cluster node, then unseals
// the nodes and waits for the cluster to settle.
//
// The steps are:
//  1. Send an init request (3 key shares, threshold 3) to node 0, retrying
//     every 500ms until it succeeds or ctx is done.
//  2. Store the hex-decoded barrier and recovery keys and the root token.
//  3. For each node: create its API client, join it to node 0 when
//     RaftStorage is set and it is not node 0, submit every barrier key, and
//     wait until the health endpoint reports it unsealed with a cluster ID.
//  4. Check that node 0 reports itself as leader and the other nodes do not.
//
// Steps 3 and 4 share one 15 second timeout derived from ctx. The function
// indexes dc.ClusterNodes[0] unconditionally, so it panics on a cluster with
// no nodes.
func (dc *DockerCluster) Initialize(ctx context.Context) error {
client, err := dc.ClusterNodes[0].NewAPIClient()
if err != nil {
return err
}
var resp *api.InitResponse
// Retry until the init request succeeds. The only exit besides success is
// ctx being done, so with a context that never ends this loops for as long
// as the request keeps failing.
for ctx.Err() == nil {
resp, err = client.Sys().Init(&api.InitRequest{
SecretShares: 3,
SecretThreshold: 3,
})
if err == nil && resp != nil {
break
}
time.Sleep(500 * time.Millisecond)
}
// err here is the result of the last attempt, if any attempt was made.
if err != nil {
return err
}
if resp == nil {
return fmt.Errorf("nil response to init request")
}
for _, k := range resp.Keys {
raw, err := hex.DecodeString(k)
if err != nil {
return err
}
dc.barrierKeys = append(dc.barrierKeys, raw)
}
for _, k := range resp.RecoveryKeys {
raw, err := hex.DecodeString(k)
if err != nil {
return err
}
dc.recoveryKeys = append(dc.recoveryKeys, raw)
}
dc.rootToken = resp.RootToken
// From here on ctx is the derived context: every wait below shares this
// single 15 second budget.
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
// Unseal
for j, node := range dc.ClusterNodes {
// copy the index value, so we're not reusing it in deeper scopes
i := j
client, err := node.NewAPIClient()
if err != nil {
return err
}
node.Client = client
// Nodes other than the first join node 0, addressed by its node name.
if i > 0 && dc.RaftStorage {
leader := dc.ClusterNodes[0]
resp, err := client.Sys().RaftJoin(&api.RaftJoinRequest{
LeaderAPIAddr: fmt.Sprintf("https://%s:%d", dc.ClusterNodes[0].Name(), leader.Address.Port),
LeaderCACert: string(dc.CACertPEM),
LeaderClientCert: string(node.ServerCertPEM),
LeaderClientKey: string(node.ServerKeyPEM),
})
if err != nil {
return err
}
if resp == nil || !resp.Joined {
return fmt.Errorf("nil or negative response from raft join request: %v", resp)
}
}
// Submit every key; unsealed reflects the response to the last one.
var unsealed bool
for _, key := range dc.barrierKeys {
resp, err := client.Sys().Unseal(hex.EncodeToString(key))
if err != nil {
return err
}
unsealed = !resp.Sealed
}
// Only node 0 must be unsealed at this point; the other nodes are covered
// by the health wait below.
// NOTE(review): presumably joined nodes finish unsealing asynchronously — confirm.
if i == 0 && !unsealed {
return fmt.Errorf("could not unseal node %d", i)
}
client.SetToken(dc.rootToken)
// Wait until the node reports itself unsealed with a cluster ID, and record
// that ID on the cluster.
err = ensureHealthMatches(ctx, node.Client, func(health *api.HealthResponse) error {
if health.Sealed {
return fmt.Errorf("node %d is sealed: %#v", i, health)
}
if health.ClusterID == "" {
return fmt.Errorf("node %d has no cluster ID", i)
}
dc.ID = health.ClusterID
return nil
})
if err != nil {
return err
}
// Node 0 must report itself as leader before the next node is processed.
if i == 0 {
err = ensureLeaderMatches(ctx, node.Client, func(leader *api.LeaderResponse) error {
if !leader.IsSelf {
return fmt.Errorf("node %d leader=%v, expected=%v", i, leader.IsSelf, true)
}
return nil
})
if err != nil {
return err
}
}
}
// Final pass: node 0 must report itself as leader and every other node must not.
for i, node := range dc.ClusterNodes {
expectLeader := i == 0
err = ensureLeaderMatches(ctx, node.Client, func(leader *api.LeaderResponse) error {
if expectLeader != leader.IsSelf {
return fmt.Errorf("node %d leader=%v, expected=%v", i, leader.IsSelf, expectLeader)
}
return nil
})
if err != nil {
return err
}
}
return nil
}
// setupCA establishes the certificate authority used to sign node
// certificates and records it on the cluster.
//
// opts may be nil. When opts.CAKey is set it is used as the CA key, otherwise
// a fresh P-256 key is generated. When opts.CACert (DER encoded) is set it is
// used as the CA certificate, otherwise a self-signed certificate for
// "localhost" and the loopback addresses is created.
//
// The certificate's public key must belong to the CA key. Supplying a
// certificate without its key (or with a different key) is rejected, because
// every node certificate would otherwise be signed by a key the CA
// certificate cannot vouch for.
//
// The PEM-encoded certificate is written to <tmpDir>/ca/ca.pem; that
// directory must already exist.
func (dc *DockerCluster) setupCA(opts *DockerClusterOptions) error {
	var err error

	var caKey *ecdsa.PrivateKey
	if opts != nil && opts.CAKey != nil {
		caKey = opts.CAKey
	} else {
		caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			return err
		}
	}

	var caBytes []byte
	if opts != nil && len(opts.CACert) > 0 {
		caBytes = opts.CACert
	} else {
		certIPs := []net.IP{
			net.IPv6loopback,
			net.ParseIP("127.0.0.1"),
		}
		serialNumber := mathrand.New(mathrand.NewSource(time.Now().UnixNano())).Int63()
		caCertTemplate := &x509.Certificate{
			Subject: pkix.Name{
				CommonName: "localhost",
			},
			DNSNames:              []string{"localhost"},
			IPAddresses:           certIPs,
			KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
			SerialNumber:          big.NewInt(serialNumber),
			NotBefore:             time.Now().Add(-30 * time.Second),
			NotAfter:              time.Now().Add(262980 * time.Hour),
			BasicConstraintsValid: true,
			IsCA:                  true,
		}
		caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
		if err != nil {
			return err
		}
	}

	caCert, err := x509.ParseCertificate(caBytes)
	if err != nil {
		return err
	}

	// Guard against a caller-supplied certificate that does not pair with the
	// key we are about to sign with.
	certPub, ok := caCert.PublicKey.(*ecdsa.PublicKey)
	if !ok || certPub.X.Cmp(caKey.X) != 0 || certPub.Y.Cmp(caKey.Y) != 0 {
		return errors.New("CA certificate public key does not match the CA private key")
	}

	dc.CAKey = caKey
	dc.CACert = caCert
	dc.CACertBytes = caBytes

	dc.RootCAs = x509.NewCertPool()
	dc.RootCAs.AddCert(caCert)

	dc.CACertPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	})
	dc.CACertPEMFile = filepath.Join(dc.tmpDir, "ca", "ca.pem")
	// A certificate only needs to be readable, not executable.
	if err := ioutil.WriteFile(dc.CACertPEMFile, dc.CACertPEM, 0644); err != nil {
		return err
	}

	marshaledCAKey, err := x509.MarshalECPrivateKey(caKey)
	if err != nil {
		return err
	}
	dc.CAKeyPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: marshaledCAKey,
	})

	return nil
}
// setupCert creates the node's key pair and a certificate signed by the
// cluster CA, writes both as PEM files into the node's work directory
// (cert.pem and key.pem), and builds the node's TLS configuration.
//
// The cluster CA (n.Cluster.CACert, CAKey and RootCAs) must be set up first.
// The certificate is issued for "localhost", the node name and the loopback
// addresses; n.Address is not consulted.
func (n *dockerClusterNode) setupCert() error {
	var err error

	if n.ServerKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil {
		return err
	}

	serial := mathrand.New(mathrand.NewSource(time.Now().UnixNano())).Int63()
	template := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: n.Name(),
		},
		DNSNames:    []string{"localhost", n.Name()},
		IPAddresses: []net.IP{net.IPv6loopback, net.ParseIP("127.0.0.1")},
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
		SerialNumber: big.NewInt(serial),
		NotBefore:    time.Now().Add(-30 * time.Second),
		NotAfter:     time.Now().Add(262980 * time.Hour),
	}

	// Sign with the cluster CA, then parse the result back so the node keeps
	// both the DER bytes and the parsed certificate.
	n.ServerCertBytes, err = x509.CreateCertificate(rand.Reader, template, n.Cluster.CACert, n.ServerKey.Public(), n.Cluster.CAKey)
	if err != nil {
		return err
	}
	if n.ServerCert, err = x509.ParseCertificate(n.ServerCertBytes); err != nil {
		return err
	}
	n.ServerCertPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: n.ServerCertBytes,
	})

	keyDER, err := x509.MarshalECPrivateKey(n.ServerKey)
	if err != nil {
		return err
	}
	n.ServerKeyPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: keyDER,
	})

	// Persist certificate and key next to the node's other config files.
	n.ServerCertPEMFile = filepath.Join(n.WorkDir, "cert.pem")
	if err := ioutil.WriteFile(n.ServerCertPEMFile, n.ServerCertPEM, 0755); err != nil {
		return err
	}
	n.ServerKeyPEMFile = filepath.Join(n.WorkDir, "key.pem")
	if err := ioutil.WriteFile(n.ServerKeyPEMFile, n.ServerKeyPEM, 0755); err != nil {
		return err
	}

	keyPair, err := tls.X509KeyPair(n.ServerCertPEM, n.ServerKeyPEM)
	if err != nil {
		return err
	}

	getter := stepwise.NewCertificateGetter(n.ServerCertPEMFile, n.ServerKeyPEMFile, "")
	if err := getter.Reload(); err != nil {
		return err
	}

	// Client certificates are requested by default and only enforced when the
	// cluster asks for it.
	clientAuth := tls.RequestClientCert
	if n.Cluster.clientAuthRequired {
		clientAuth = tls.RequireAndVerifyClientCert
	}

	n.TLSConfig = &tls.Config{
		Certificates:   []tls.Certificate{keyPair},
		RootCAs:        n.Cluster.RootCAs,
		ClientCAs:      n.Cluster.RootCAs,
		ClientAuth:     clientAuth,
		NextProtos:     []string{"h2", "http/1.1"},
		GetCertificate: getter.GetCertificate,
	}
	return nil
}
// NewEnvironment creates a new Stepwise Environment for executing tests.
//
// The cluster is named "test-<name>-<uuid>", uses raft storage, and takes its
// plugin name and mount options from options. It returns nil when options is
// nil and panics if a UUID cannot be generated.
func NewEnvironment(name string, options *stepwise.MountOptions) *DockerCluster {
	if options == nil {
		return nil
	}

	id, err := uuid.GenerateUUID()
	if err != nil {
		panic(err)
	}

	env := new(DockerCluster)
	env.PluginName = options.PluginName
	env.ClusterName = fmt.Sprintf("test-%s-%s", name, id)
	env.RaftStorage = true
	env.MountOptions = *options
	return env
}
// dockerClusterNode represents a single instance of Vault in a cluster
type dockerClusterNode struct {
// NodeID is the node's identifier within the cluster ("vault-<index>" when
// created by setupDockerCluster); it is also used as the raft node_id.
NodeID string
// Address is the container's IP address and the Vault port (8200); it is
// set once the container has started.
Address *net.TCPAddr
// HostPort is the host port bound to the container's 8200/tcp.
HostPort string
// Client is the API client for this node; Initialize assigns it.
Client *api.Client
// ServerCert is the node certificate, with its DER and PEM encodings and the
// path of the PEM file in WorkDir.
ServerCert *x509.Certificate
ServerCertBytes []byte
ServerCertPEM []byte
ServerCertPEMFile string
// ServerKey is the node's private key, with its PEM encoding and the path
// of the PEM file in WorkDir.
ServerKey *ecdsa.PrivateKey
ServerKeyPEM []byte
ServerKeyPEMFile string
// TLSConfig is built by setupCert and cloned for each new API client.
TLSConfig *tls.Config
// WorkDir is the host directory holding this node's certificate, key and
// Vault config; it is copied to /vault/config in the container.
WorkDir string
// Cluster is the cluster this node belongs to.
Cluster *DockerCluster
// container describes the started container and dockerAPI is the client
// used to manage it; both are set by start.
container *types.ContainerJSON
dockerAPI *docker.Client
}
// NewAPIClient builds a fresh Vault API client for this node. The client talks
// to https://127.0.0.1:<HostPort> over HTTP/2-capable TLS using a clone of the
// node's TLS config, never retries, refuses redirects, and is pre-loaded with
// the cluster's root token.
func (n *dockerClusterNode) NewAPIClient() (*api.Client, error) {
	tr := cleanhttp.DefaultPooledTransport()
	tr.TLSClientConfig = n.TLSConfig.Clone()
	if err := http2.ConfigureTransport(tr); err != nil {
		return nil, err
	}

	// This can of course be overridden per-test by using its own client
	refuseRedirect := func(*http.Request, []*http.Request) error {
		return fmt.Errorf("redirects not allowed in these tests")
	}
	httpClient := &http.Client{
		Transport:     tr,
		CheckRedirect: refuseRedirect,
	}

	cfg := api.DefaultConfig()
	if cfg.Error != nil {
		return nil, cfg.Error
	}
	cfg.Address = "https://127.0.0.1:" + n.HostPort
	cfg.HttpClient = httpClient
	cfg.MaxRetries = 0

	vaultClient, err := api.NewClient(cfg)
	if err != nil {
		return nil, err
	}
	vaultClient.SetToken(n.Cluster.RootToken())
	return vaultClient, nil
}
// Cleanup kills the container of the node, allowing up to 30 seconds for the
// request. It is a no-op for a node that has no container or Docker client,
// which is the state of a node that was registered on the cluster but never
// (successfully) started; this keeps teardown after a failed setup from
// panicking.
func (n *dockerClusterNode) Cleanup() error {
	if n.dockerAPI == nil || n.container == nil {
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	return n.dockerAPI.ContainerKill(ctx, n.container.ID, "KILL")
}
// start prepares and launches the Vault container for this node.
//
// It generates the node certificate, writes the Vault configuration to
// <WorkDir>/local.json, and starts a container on the netName network with
// the work directory copied to /vault/config and caDir copied to the
// container's CA certificate directory. When pluginBinPath is non-empty the
// plugin binary at that host path is copied into /vault/config as well.
//
// On success the node's container, Address (container IP, port 8200) and
// HostPort (host port bound to 8200/tcp) are populated. netCIDR is accepted
// for signature compatibility and is not used.
func (n *dockerClusterNode) start(cli *docker.Client, caDir, netName string, netCIDR *dockerClusterNode, pluginBinPath string) error {
	n.dockerAPI = cli

	err := n.setupCert()
	if err != nil {
		return err
	}

	vaultCfg := map[string]interface{}{
		"listener": map[string]interface{}{
			"tcp": map[string]interface{}{
				"address":       fmt.Sprintf("%s:%d", "0.0.0.0", 8200),
				"tls_cert_file": "/vault/config/cert.pem",
				"tls_key_file":  "/vault/config/key.pem",
				"telemetry": map[string]interface{}{
					"unauthenticated_metrics_access": true,
				},
			},
		},
		"telemetry": map[string]interface{}{
			"disable_hostname": true,
		},
		"storage": map[string]interface{}{
			"raft": map[string]interface{}{
				"path":    "/vault/file",
				"node_id": n.NodeID,
			},
		},
		"cluster_name":         netName,
		"log_level":            "TRACE",
		"raw_storage_endpoint": true,
		"plugin_directory":     "/vault/config",
		// disable_mlock is required for working in the Docker environment with
		// custom plugins
		"disable_mlock": true,
		// These are being provided by docker-entrypoint now, since we don't know
		// the address before the container starts.
		//"api_addr":     fmt.Sprintf("https://%s:%d", n.Address.IP, n.Address.Port),
		//"cluster_addr": fmt.Sprintf("https://%s:%d", n.Address.IP, n.Address.Port+1),
	}
	cfgJSON, err := json.Marshal(vaultCfg)
	if err != nil {
		return err
	}

	err = ioutil.WriteFile(filepath.Join(n.WorkDir, "local.json"), cfgJSON, 0644)
	if err != nil {
		return err
	}

	// setup plugin bin copy if needed
	copyFromTo := map[string]string{
		n.WorkDir: "/vault/config",
		caDir:     "/usr/local/share/ca-certificates/",
	}
	if pluginBinPath != "" {
		// The source is a host path (OS-specific separators); the destination
		// is a path inside the Linux container (always slash-separated).
		base := filepath.Base(pluginBinPath)
		copyFromTo[pluginBinPath] = path.Join("/vault/config", base)
	}

	r := &Runner{
		dockerAPI: cli,
		ContainerConfig: &container.Config{
			Image: "vault",
			Entrypoint: []string{"/bin/sh", "-c", "update-ca-certificates && " +
				"exec /usr/local/bin/docker-entrypoint.sh vault server -log-level=trace -dev-plugin-dir=./vault/config -config /vault/config/local.json"},
			Env: []string{
				"VAULT_CLUSTER_INTERFACE=eth0",
				"VAULT_API_ADDR=https://127.0.0.1:8200",
				fmt.Sprintf("VAULT_REDIRECT_ADDR=https://%s:8200", n.Name()),
			},
			Labels:       nil,
			ExposedPorts: nat.PortSet{"8200/tcp": {}, "8201/tcp": {}},
		},
		ContainerName: n.Name(),
		NetName:       netName,
		CopyFromTo:    copyFromTo,
	}
	n.container, err = r.Start(context.Background())
	if err != nil {
		return err
	}

	n.Address = &net.TCPAddr{
		IP:   net.ParseIP(n.container.NetworkSettings.IPAddress),
		Port: 8200,
	}
	ports := n.container.NetworkSettings.NetworkSettingsBase.Ports[nat.Port("8200/tcp")]
	if len(ports) == 0 {
		// Without a host port the node is unreachable; stop the container and
		// report a failed cleanup instead of silently dropping it.
		if cleanupErr := n.Cleanup(); cleanupErr != nil {
			return fmt.Errorf("could not find port binding for 8200/tcp (container cleanup also failed: %v)", cleanupErr)
		}
		return fmt.Errorf("could not find port binding for 8200/tcp")
	}
	n.HostPort = ports[0].HostPort

	return nil
}
// DockerClusterOptions has options for setting up the docker cluster
type DockerClusterOptions struct {
// NOTE(review): KeepStandbysSealed is not read anywhere in the visible part
// of this file — confirm whether it is still honored.
KeepStandbysSealed bool
// RequireClientAuth makes the nodes require and verify client certificates.
RequireClientAuth bool
// SkipInit skips initializing and unsealing the cluster after the
// containers have started.
SkipInit bool
// CACert is an optional DER-encoded CA certificate to use instead of
// generating one.
CACert []byte
// NumCores is the number of nodes to create; zero selects DefaultNumCores.
NumCores int
// tmpDir is the directory to place cluster files in; when empty a fresh
// temporary directory is created.
tmpDir string
// PluginTestBin is the host path of the compiled plugin binary to copy
// into each container; empty means no plugin is copied.
PluginTestBin string
// SetupFunc is called after the cluster is started.
// NOTE(review): no call site is visible in this part of the file, and the
// first parameter is a testing.T value rather than a *testing.T — confirm
// both are intended.
SetupFunc func(t testing.T, c *DockerCluster)
// CAKey is an optional CA private key to use instead of generating one.
CAKey *ecdsa.PrivateKey
}
//
// helper methods/functions
//
// ensureHealthMatches polls the node's health endpoint every 500ms until
// ready accepts the response (returns nil) or ctx is done.
//
// It returns nil as soon as ready succeeds. Otherwise the returned error
// wraps the most recent failure: the request error, a nil-response error, or
// the error returned by ready. If no attempt could be made because ctx was
// already done, it wraps ctx.Err() instead.
func ensureHealthMatches(ctx context.Context, client *api.Client, ready func(response *api.HealthResponse) error) error {
	var err error
	for ctx.Err() == nil {
		var health *api.HealthResponse
		health, err = client.Sys().Health()
		switch {
		case err != nil:
			// Keep the request error and retry.
		case health == nil:
			err = fmt.Errorf("nil response to health check")
		default:
			if err = ready(health); err == nil {
				return nil
			}
		}

		// Wait before retrying, but stop waiting as soon as ctx is done.
		select {
		case <-ctx.Done():
		case <-time.After(500 * time.Millisecond):
		}
	}
	if err == nil {
		err = ctx.Err()
	}
	return fmt.Errorf("error checking health: %w", err)
}
// ensureLeaderMatches polls the node's leader endpoint every 500ms until
// ready accepts the response (returns nil) or ctx is done.
//
// It returns nil as soon as ready succeeds. Otherwise the returned error
// wraps the most recent failure: the request error, a nil-response error, or
// the error returned by ready. If no attempt could be made because ctx was
// already done, it wraps ctx.Err() instead.
func ensureLeaderMatches(ctx context.Context, client *api.Client, ready func(response *api.LeaderResponse) error) error {
	var err error
	for ctx.Err() == nil {
		var leader *api.LeaderResponse
		leader, err = client.Sys().Leader()
		switch {
		case err != nil:
			// Keep the request error and retry.
		case leader == nil:
			err = fmt.Errorf("nil response to leader check")
		default:
			if err = ready(leader); err == nil {
				return nil
			}
		}

		// Wait before retrying, but stop waiting as soon as ctx is done.
		select {
		case <-ctx.Done():
		case <-time.After(500 * time.Millisecond):
		}
	}
	if err == nil {
		err = ctx.Err()
	}
	return fmt.Errorf("error checking leader: %w", err)
}
// end test helper methods
// TODO: allow number of cores/servers to be configurable
// DefaultNumCores is the number of cluster nodes created by
// setupDockerCluster when the options are nil or leave NumCores at zero.
var DefaultNumCores = 1
// setupDockerCluster creates a managed set of Docker containers running
// Vault.
//
// It prepares the cluster's temporary directory (opts.tmpDir, or a fresh one),
// registers opts.NumCores nodes (DefaultNumCores when unset), sets up the CA,
// creates a dedicated Docker network, starts one container per node and,
// unless opts.SkipInit is set, initializes and unseals the cluster. opts may
// be nil, in which case all defaults apply.
//
// On failure, resources created up to that point (nodes, network) stay
// recorded on the cluster and are not released here.
func (dc *DockerCluster) setupDockerCluster(opts *DockerClusterOptions) error {
	if opts != nil && opts.tmpDir != "" {
		// MkdirAll is a no-op when the directory already exists.
		if err := os.MkdirAll(opts.tmpDir, 0700); err != nil {
			return err
		}
		dc.tmpDir = opts.tmpDir
	} else {
		tempDir, err := ioutil.TempDir("", "vault-test-cluster-")
		if err != nil {
			return err
		}
		dc.tmpDir = tempDir
	}

	caDir := filepath.Join(dc.tmpDir, "ca")
	if err := os.MkdirAll(caDir, 0755); err != nil {
		return err
	}

	numCores := DefaultNumCores
	if opts != nil && opts.NumCores != 0 {
		numCores = opts.NumCores
	}

	if opts != nil && opts.RequireClientAuth {
		dc.clientAuthRequired = true
	}

	for i := 0; i < numCores; i++ {
		nodeID := fmt.Sprintf("vault-%d", i)
		node := &dockerClusterNode{
			NodeID:  nodeID,
			Cluster: dc,
			WorkDir: filepath.Join(dc.tmpDir, nodeID),
		}
		dc.ClusterNodes = append(dc.ClusterNodes, node)
		if err := os.MkdirAll(node.WorkDir, 0700); err != nil {
			return err
		}
	}

	if err := dc.setupCA(opts); err != nil {
		return err
	}

	// The client is handed to every node, which keeps it for later cleanup, so
	// it must stay open after this function returns.
	cli, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithVersion(dockerVersion))
	if err != nil {
		return err
	}

	netUUID, err := uuid.GenerateUUID()
	if err != nil {
		return fmt.Errorf("generating network name: %w", err)
	}
	netName := fmt.Sprintf("%s-%s", "vault-test", netUUID)
	netID, err := setupNetwork(cli, netName)
	if err != nil {
		return err
	}
	dc.networkID = netID

	pluginBinPath := ""
	if opts != nil {
		pluginBinPath = opts.PluginTestBin
	}
	for _, node := range dc.ClusterNodes {
		if err := node.start(cli, caDir, netName, node, pluginBinPath); err != nil {
			return err
		}
	}

	if opts == nil || !opts.SkipInit {
		if err := dc.Initialize(context.Background()); err != nil {
			return err
		}
	}

	return nil
}
// Docker networking functions
// setupNetwork creates the Docker network named netName and returns its ID.
// A creation failure is wrapped with the network name for context.
func setupNetwork(cli *docker.Client, netName string) (string, error) {
	netID, createErr := createNetwork(cli, netName)
	if createErr == nil {
		return netID, nil
	}
	return "", fmt.Errorf("couldn't create network %s: %w", netName, createErr)
}
// createNetwork asks Docker for a new bridge network called netName, using
// the default IPAM driver and duplicate-name checking, and returns the ID of
// the created network.
func createNetwork(cli *docker.Client, netName string) (string, error) {
	spec := types.NetworkCreate{
		CheckDuplicate: true,
		Driver:         "bridge",
		Options:        map[string]string{},
		IPAM: &network.IPAM{
			Driver:  "default",
			Options: map[string]string{},
		},
	}

	created, err := cli.NetworkCreate(context.Background(), netName, spec)
	if err != nil {
		return "", err
	}
	return created.ID, nil
}
// Setup compiles the plugin under test, brings up the Docker cluster with the
// compiled binary, registers the plugin with Vault and mounts it.
//
// The plugin is compiled from the current working directory into a fresh
// temporary directory. Credential plugins are enabled as auth methods (the
// "auth/" prefix of the mount path is stripped for that call), secrets
// plugins are mounted at the mount path, and database plugins are registered
// only. Any other plugin type is an error.
func (dc *DockerCluster) Setup() error {
	// The plugin is built from the directory the test runs in.
	workDir, err := os.Getwd()
	if err != nil {
		return err
	}

	// NOTE(review): nothing visible in this file removes this directory again
	// — confirm where it is cleaned up.
	binDir, err := ioutil.TempDir("", "bin")
	if err != nil {
		return err
	}

	registryName, pluginName := dc.MountOptions.RegistryName, dc.MountOptions.PluginName
	binName, binPath, sha256value, err := stepwise.CompilePlugin(registryName, pluginName, workDir, binDir)
	if err != nil {
		return err
	}

	if err := dc.setupDockerCluster(&DockerClusterOptions{PluginTestBin: binPath}); err != nil {
		return err
	}

	// All registration and mounting goes through the first node's client.
	client := dc.ClusterNodes[0].Client
	registration := &api.RegisterPluginInput{
		Name:    registryName,
		Type:    consts.PluginType(dc.MountOptions.PluginType),
		Command: binName,
		SHA256:  sha256value,
	}
	if err := client.Sys().RegisterPlugin(registration); err != nil {
		return err
	}

	switch pluginType := dc.MountOptions.PluginType; pluginType {
	case stepwise.PluginTypeCredential:
		// the mount path includes "auth/" for credential type plugins. For enabling
		// auth mounts via the /sys endpoint, we need to remove that prefix
		authPath := strings.TrimPrefix(dc.MountPath(), "auth/")
		return client.Sys().EnableAuthWithOptions(authPath, &api.EnableAuthOptions{
			Type: registryName,
		})
	case stepwise.PluginTypeDatabase:
		// Registered above but not mounted.
		// NOTE(review): presumably intentional rather than a missing
		// fallthrough to the secrets case — confirm.
		return nil
	case stepwise.PluginTypeSecrets:
		return client.Sys().Mount(dc.MountPath(), &api.MountInput{
			Type: registryName,
		})
	default:
		return fmt.Errorf("unknown plugin type: %s", pluginType.String())
	}
}