diff --git a/command/command_test.go b/command/command_test.go index f41813c28d..0ff084f45b 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -1,64 +1,212 @@ package command import ( + "context" + "encoding/base64" + "net" + "net/http" + "strings" "testing" + "time" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/builtin/logical/pki" + "github.com/hashicorp/vault/builtin/logical/ssh" "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/physical/inmem" "github.com/hashicorp/vault/vault" + "github.com/mitchellh/cli" + auditFile "github.com/hashicorp/vault/builtin/audit/file" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" vaulthttp "github.com/hashicorp/vault/http" logxi "github.com/mgutz/logxi/v1" ) -var testVaultServerDefaultBackends = map[string]logical.Factory{ - "transit": transit.Factory, - "pki": pki.Factory, -} +var ( + defaultVaultLogger = logxi.NullLog -func testVaultServer(t testing.TB) (*api.Client, func()) { - return testVaultServerBackends(t, testVaultServerDefaultBackends) -} - -func testVaultServerBackends(t testing.TB, backends map[string]logical.Factory) (*api.Client, func()) { - coreConfig := &vault.CoreConfig{ - DisableMlock: true, - DisableCache: true, - Logger: logxi.NullLog, - LogicalBackends: backends, + defaultVaultCredentialBackends = map[string]logical.Factory{ + "userpass": credUserpass.Factory, } - cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + defaultVaultAuditBackends = map[string]audit.Factory{ + "file": auditFile.Factory, + } + + defaultVaultLogicalBackends = map[string]logical.Factory{ + "generic-leased": vault.LeasedPassthroughBackendFactory, + "pki": pki.Factory, + "ssh": ssh.Factory, + "transit": transit.Factory, + } +) + +// assertNoTabs asserts the CLI help has no tab characters. +func assertNoTabs(tb testing.TB, c cli.Command) { + tb.Helper() + + if strings.ContainsRune(c.Help(), '\t') { + tb.Errorf("%#v help output contains tabs", c) + } +} + +// testVaultServer creates a test vault cluster and returns a configured API +// client and closer function. +func testVaultServer(tb testing.TB) (*api.Client, func()) { + tb.Helper() + + client, _, closer := testVaultServerUnseal(tb) + return client, closer +} + +// testVaultServerUnseal creates a test vault cluster and returns a configured +// API client, list of unseal keys (as strings), and a closer function. +func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { + tb.Helper() + + return testVaultServerCoreConfig(tb, &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: defaultVaultLogger, + CredentialBackends: defaultVaultCredentialBackends, + AuditBackends: defaultVaultAuditBackends, + LogicalBackends: defaultVaultLogicalBackends, + }) +} + +// testVaultServerCoreConfig creates a new vault cluster with the given core +// configuration. This is a lower-level test helper. +func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) { + tb.Helper() + + cluster := vault.NewTestCluster(tb, coreConfig, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, + NumCores: 1, // Default is 3, but we don't need that many }) cluster.Start() - // make it easy to get access to the active + // Make it easy to get access to the active core := cluster.Cores[0].Core - vault.TestWaitActive(t, core) + vault.TestWaitActive(tb, core) + // Get the client already setup for us! client := cluster.Cores[0].Client client.SetToken(cluster.RootToken) - // Sanity check - secret, err := client.Auth().Token().LookupSelf() - if err != nil { - t.Fatal(err) + // Convert the unseal keys to base64 encoded, since these are how the user + // will get them. + unsealKeys := make([]string, len(cluster.BarrierKeys)) + for i := range unsealKeys { + unsealKeys[i] = base64.StdEncoding.EncodeToString(cluster.BarrierKeys[i]) } - if secret == nil || secret.Data["id"].(string) != cluster.RootToken { - t.Fatalf("token mismatch: %#v vs %q", secret, cluster.RootToken) - } - return client, func() { defer cluster.Cleanup() } + + return client, unsealKeys, func() { defer cluster.Cleanup() } } -func testClient(t *testing.T, addr string, token string) *api.Client { +// testVaultServerUninit creates an uninitialized server. +func testVaultServerUninit(tb testing.TB) (*api.Client, func()) { + tb.Helper() + + inm, err := inmem.NewInmem(nil, defaultVaultLogger) + if err != nil { + tb.Fatal(err) + } + + core, err := vault.NewCore(&vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: defaultVaultLogger, + Physical: inm, + CredentialBackends: defaultVaultCredentialBackends, + AuditBackends: defaultVaultAuditBackends, + LogicalBackends: defaultVaultLogicalBackends, + }) + if err != nil { + tb.Fatal(err) + } + + ln, addr := vaulthttp.TestServer(tb, core) + + client, err := api.NewClient(&api.Config{ + Address: addr, + }) + if err != nil { + tb.Fatal(err) + } + + return client, func() { ln.Close() } +} + +// testVaultServerBad creates an http server that returns a 500 on each request +// to simulate failures. +func testVaultServerBad(tb testing.TB) (*api.Client, func()) { + tb.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + tb.Fatal(err) + } + + server := &http.Server{ + Addr: "127.0.0.1:0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "500 internal server error", http.StatusInternalServerError) + }), + ReadTimeout: 1 * time.Second, + ReadHeaderTimeout: 1 * time.Second, + WriteTimeout: 1 * time.Second, + IdleTimeout: 1 * time.Second, + } + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + tb.Fatal(err) + } + }() + + client, err := api.NewClient(&api.Config{ + Address: "http://" + listener.Addr().String(), + }) + if err != nil { + tb.Fatal(err) + } + + return client, func() { + ctx, done := context.WithTimeout(context.Background(), 5*time.Second) + defer done() + + server.Shutdown(ctx) + } +} + +// testTokenAndAccessor creates a new authentication token capable of being renewed with +// the default policy attached. It returns the token and it's accessor. +func testTokenAndAccessor(tb testing.TB, client *api.Client) (string, string) { + tb.Helper() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + }) + if err != nil { + tb.Fatal(err) + } + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + tb.Fatalf("missing auth data: %#v", secret) + } + return secret.Auth.ClientToken, secret.Auth.Accessor +} + +func testClient(tb testing.TB, addr string, token string) *api.Client { + tb.Helper() config := api.DefaultConfig() config.Address = addr client, err := api.NewClient(config) if err != nil { - t.Fatalf("err: %s", err) + tb.Fatal(err) } client.SetToken(token)