diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 4ad1c0b8c8..36bebe8860 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -5,7 +5,6 @@ import ( "fmt" "io/ioutil" "log" - stdhttp "net/http" "os" "reflect" "sync" @@ -78,30 +77,24 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac return } -func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { +func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "database": Factory, }, } - handler1 := stdhttp.NewServeMux() - handler2 := stdhttp.NewServeMux() - handler3 := stdhttp.NewServeMux() + cluster := vault.NewTestCluster(t, coreConfig, false) + cluster.StartListeners() + cores := cluster.Cores + cores[0].Handler.Handle("/", http.Handler(cores[0].Core)) + cores[1].Handler.Handle("/", http.Handler(cores[1].Core)) + cores[2].Handler.Handle("/", http.Handler(cores[2].Core)) - // Chicken-and-egg: Handler needs a core. So we create handlers first, then - // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) - handler1.Handle("/", http.Handler(cores[0].Core)) - handler2.Handle("/", http.Handler(cores[1].Core)) - handler3.Handle("/", http.Handler(cores[2].Core)) + sys := vault.TestDynamicSystemView(cores[0].Core) + vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", "TestBackend_PluginMain") - core := cores[0] - - sys := vault.TestDynamicSystemView(core.Core) - vault.TestAddTestPlugin(t, core.Core, "postgresql-database-plugin", "TestBackend_PluginMain") - - return cores, sys + return cluster, sys } func TestBackend_PluginMain(t *testing.T) { @@ -136,10 +129,9 @@ func TestBackend_PluginMain(t *testing.T) { func TestBackend_config_connection(t *testing.T) { var resp *logical.Response var err error - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + + cluster, sys := getCluster(t) + defer cluster.CloseListeners() config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -201,10 +193,8 @@ func TestBackend_config_connection(t *testing.T) { } func TestBackend_basic(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -294,10 +284,8 @@ func TestBackend_basic(t *testing.T) { } func TestBackend_connectionCrud(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -441,10 +429,8 @@ func TestBackend_connectionCrud(t *testing.T) { } func TestBackend_roleCrud(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -553,10 +539,8 @@ func TestBackend_roleCrud(t *testing.T) { } } func TestBackend_allowedRoles(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index d0e7073d87..b6d5cfa81c 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -2,7 +2,6 @@ package dbplugin_test import ( "errors" - stdhttp "net/http" "os" "testing" "time" @@ -73,26 +72,20 @@ func (m *mockPlugin) Close() error { return nil } -func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { +func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { coreConfig := &vault.CoreConfig{} - handler1 := stdhttp.NewServeMux() - handler2 := stdhttp.NewServeMux() - handler3 := stdhttp.NewServeMux() + cluster := vault.NewTestCluster(t, coreConfig, false) + cluster.StartListeners() + cores := cluster.Cores + cores[0].Handler.Handle("/", http.Handler(cores[0].Core)) + cores[1].Handler.Handle("/", http.Handler(cores[1].Core)) + cores[2].Handler.Handle("/", http.Handler(cores[2].Core)) - // Chicken-and-egg: Handler needs a core. So we create handlers first, then - // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) - handler1.Handle("/", http.Handler(cores[0].Core)) - handler2.Handle("/", http.Handler(cores[1].Core)) - handler3.Handle("/", http.Handler(cores[2].Core)) + sys := vault.TestDynamicSystemView(cores[0].Core) + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main") - core := cores[0] - - sys := vault.TestDynamicSystemView(core.Core) - vault.TestAddTestPlugin(t, core.Core, "test-plugin", "TestPlugin_Main") - - return cores, sys + return cluster, sys } // This is not an actual test case, it's a helper function that will be executed @@ -116,10 +109,8 @@ func TestPlugin_Main(t *testing.T) { } func TestPlugin_Initialize(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -142,10 +133,8 @@ func TestPlugin_Initialize(t *testing.T) { } func TestPlugin_CreateUser(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -184,10 +173,8 @@ func TestPlugin_CreateUser(t *testing.T) { } func TestPlugin_RenewUser(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -220,10 +207,8 @@ func TestPlugin_RenewUser(t *testing.T) { } func TestPlugin_RevokeUser(t *testing.T) { - cores, sys := getCore(t) - for _, core := range cores { - defer core.CloseListeners() - } + cluster, sys := getCluster(t) + defer cluster.CloseListeners() db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { diff --git a/http/forwarding_test.go b/http/forwarding_test.go index 67cb65220c..02c5c079bd 100644 --- a/http/forwarding_test.go +++ b/http/forwarding_test.go @@ -26,10 +26,6 @@ import ( ) func TestHTTP_Fallback_Bad_Address(t *testing.T) { - handler1 := http.NewServeMux() - handler2 := http.NewServeMux() - handler3 := http.NewServeMux() - coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, @@ -37,15 +33,14 @@ func TestHTTP_Fallback_Bad_Address(t *testing.T) { ClusterAddr: "https://127.3.4.1:8382", } - // Chicken-and-egg: Handler needs a core. So we create handlers first, then - // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) - for _, core := range cores { - defer core.CloseListeners() - } - handler1.Handle("/", Handler(cores[0].Core)) - handler2.Handle("/", Handler(cores[1].Core)) - handler3.Handle("/", Handler(cores[2].Core)) + cluster := vault.NewTestCluster(t, coreConfig, true) + cluster.StartListeners() + defer cluster.CloseListeners() + cores := cluster.Cores + + cores[0].Handler.Handle("/", Handler(cores[0].Core)) + cores[1].Handler.Handle("/", Handler(cores[1].Core)) + cores[2].Handler.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core @@ -83,10 +78,6 @@ func TestHTTP_Fallback_Bad_Address(t *testing.T) { } func TestHTTP_Fallback_Disabled(t *testing.T) { - handler1 := http.NewServeMux() - handler2 := http.NewServeMux() - handler3 := http.NewServeMux() - coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, @@ -94,15 +85,14 @@ func TestHTTP_Fallback_Disabled(t *testing.T) { ClusterAddr: "empty", } - // Chicken-and-egg: Handler needs a core. So we create handlers first, then - // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) - for _, core := range cores { - defer core.CloseListeners() - } - handler1.Handle("/", Handler(cores[0].Core)) - handler2.Handle("/", Handler(cores[1].Core)) - handler3.Handle("/", Handler(cores[2].Core)) + cluster := vault.NewTestCluster(t, coreConfig, true) + cluster.StartListeners() + defer cluster.CloseListeners() + cores := cluster.Cores + + cores[0].Handler.Handle("/", Handler(cores[0].Core)) + cores[1].Handler.Handle("/", Handler(cores[1].Core)) + cores[2].Handler.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core @@ -150,25 +140,20 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) testPlaintext := "the quick brown fox" testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA==" - handler1 := http.NewServeMux() - handler2 := http.NewServeMux() - handler3 := http.NewServeMux() - coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, }, } - // Chicken-and-egg: Handler needs a core. So we create handlers first, then - // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) - for _, core := range cores { - defer core.CloseListeners() - } - handler1.Handle("/", Handler(cores[0].Core)) - handler2.Handle("/", Handler(cores[1].Core)) - handler3.Handle("/", Handler(cores[2].Core)) + cluster := vault.NewTestCluster(t, coreConfig, true) + cluster.StartListeners() + defer cluster.CloseListeners() + cores := cluster.Cores + + cores[0].Handler.Handle("/", Handler(cores[0].Core)) + cores[1].Handler.Handle("/", Handler(cores[1].Core)) + cores[2].Handler.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core @@ -463,25 +448,20 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) // This tests TLS connection state forwarding by ensuring that we can use a // client TLS to authenticate against the cert backend func TestHTTP_Forwarding_ClientTLS(t *testing.T) { - handler1 := http.NewServeMux() - handler2 := http.NewServeMux() - handler3 := http.NewServeMux() - coreConfig := &vault.CoreConfig{ CredentialBackends: map[string]logical.Factory{ "cert": credCert.Factory, }, } - // Chicken-and-egg: Handler needs a core. So we create handlers first, then - // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) - for _, core := range cores { - defer core.CloseListeners() - } - handler1.Handle("/", Handler(cores[0].Core)) - handler2.Handle("/", Handler(cores[1].Core)) - handler3.Handle("/", Handler(cores[2].Core)) + cluster := vault.NewTestCluster(t, coreConfig, true) + cluster.StartListeners() + defer cluster.CloseListeners() + cores := cluster.Cores + + cores[0].Handler.Handle("/", Handler(cores[0].Core)) + cores[1].Handler.Handle("/", Handler(cores[1].Core)) + cores[2].Handler.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core @@ -587,18 +567,14 @@ func TestHTTP_Forwarding_ClientTLS(t *testing.T) { } func TestHTTP_Forwarding_HelpOperation(t *testing.T) { - handler1 := http.NewServeMux() - handler2 := http.NewServeMux() - handler3 := http.NewServeMux() + cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, true) + defer cluster.CloseListeners() + cluster.StartListeners() + cores := cluster.Cores - cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, &vault.CoreConfig{}, true) - for _, core := range cores { - defer core.CloseListeners() - } - - handler1.Handle("/", Handler(cores[0].Core)) - handler2.Handle("/", Handler(cores[1].Core)) - handler3.Handle("/", Handler(cores[2].Core)) + cores[0].Handler.Handle("/", Handler(cores[0].Core)) + cores[1].Handler.Handle("/", Handler(cores[1].Core)) + cores[2].Handler.Handle("/", Handler(cores[2].Core)) vault.TestWaitActive(t, cores[0].Core) diff --git a/http/sys_wrapping_test.go b/http/sys_wrapping_test.go index 9c27ebb81e..6102ff803e 100644 --- a/http/sys_wrapping_test.go +++ b/http/sys_wrapping_test.go @@ -2,7 +2,6 @@ package http import ( "encoding/json" - "net/http" "reflect" "testing" "time" @@ -14,21 +13,17 @@ import ( // Test wrapping functionality func TestHTTP_Wrapping(t *testing.T) { - handler1 := http.NewServeMux() - handler2 := http.NewServeMux() - handler3 := http.NewServeMux() - coreConfig := &vault.CoreConfig{} // Chicken-and-egg: Handler needs a core. So we create handlers first, then // add routes chained to a Handler-created handler. - cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) - for _, core := range cores { - defer core.CloseListeners() - } - handler1.Handle("/", Handler(cores[0].Core)) - handler2.Handle("/", Handler(cores[1].Core)) - handler3.Handle("/", Handler(cores[2].Core)) + cluster := vault.NewTestCluster(t, coreConfig, true) + defer cluster.CloseListeners() + cluster.StartListeners() + cores := cluster.Cores + cores[0].Handler.Handle("/", Handler(cores[0].Core)) + cores[1].Handler.Handle("/", Handler(cores[1].Core)) + cores[2].Handler.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core diff --git a/vault/cluster_test.go b/vault/cluster_test.go index ba99dcfa71..05bf10684e 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -85,10 +85,10 @@ func TestCluster_ListenForRequests(t *testing.T) { // Make this nicer for tests manualStepDownSleepPeriod = 5 * time.Second - cores := TestCluster(t, []http.Handler{nil, nil, nil}, nil, false) - for _, core := range cores { - defer core.CloseListeners() - } + cluster := NewTestCluster(t, nil, false) + cluster.StartListeners() + defer cluster.CloseListeners() + cores := cluster.Cores root := cores[0].Root @@ -198,10 +198,25 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { w.Write([]byte("core3")) }) - cores := TestCluster(t, []http.Handler{handler1, handler2, handler3}, nil, true) - for _, core := range cores { - defer core.CloseListeners() - } + cluster := NewTestCluster(t, nil, true) + cluster.StartListeners() + defer cluster.CloseListeners() + cores := cluster.Cores + cores[0].Handler.HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(201) + w.Write([]byte("core1")) + }) + cores[1].Handler.HandleFunc("/core2", func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(202) + w.Write([]byte("core2")) + }) + cores[2].Handler.HandleFunc("/core3", func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(203) + w.Write([]byte("core3")) + }) root := cores[0].Root diff --git a/vault/core.go b/vault/core.go index 38df589e0c..725f7f775e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -275,7 +275,7 @@ type Core struct { // reloadFuncs is a map containing reload functions reloadFuncs map[string][]ReloadFunc - // reloadFuncsLock controlls access to the funcs + // reloadFuncsLock controls access to the funcs reloadFuncsLock sync.RWMutex // wrappingJWTKey is the key used for generating JWTs containing response diff --git a/vault/testing.go b/vault/testing.go index 051238b502..9ccf7e6632 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -572,6 +572,32 @@ func TestWaitActive(t testing.TB, core *Core) { } } +type TestCluster struct { + Cores []*TestClusterCore +} + +func (t *TestCluster) StartListeners() { + for _, core := range t.Cores { + if core.Server != nil { + for _, ln := range core.Listeners { + go core.Server.Serve(ln) + } + } + } +} + +func (t *TestCluster) CloseListeners() { + for _, core := range t.Cores { + if core.Listeners != nil { + for _, ln := range core.Listeners { + ln.Close() + } + } + } + // Give time to actually shut down/clean up before the next test + time.Sleep(time.Second) +} + type TestListener struct { net.Listener Address *net.TCPAddr @@ -580,6 +606,8 @@ type TestListener struct { type TestClusterCore struct { *Core Listeners []*TestListener + Handler *http.ServeMux + Server *http.Server Root string BarrierKeys [][]byte CACertBytes []byte @@ -589,21 +617,7 @@ type TestClusterCore struct { Client *api.Client } -func (t *TestClusterCore) CloseListeners() { - if t.Listeners != nil { - for _, ln := range t.Listeners { - ln.Close() - } - } - // Give time to actually shut down/clean up before the next test - time.Sleep(time.Second) -} - -func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unsealStandbys bool) []*TestClusterCore { - if handlers == nil || len(handlers) != 3 { - t.Fatal("handlers must be size 3") - } - +func NewTestCluster(t testing.TB, base *CoreConfig, unsealStandbys bool) *TestCluster { // // TLS setup // @@ -692,15 +706,13 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal Listener: tls.NewListener(ln, tlsConfig), Address: ln.Addr().(*net.TCPAddr), }) + handler1 := http.NewServeMux() server1 := &http.Server{ - Handler: handlers[0], + Handler: handler1, } if err := http2.ConfigureServer(server1, nil); err != nil { t.Fatal(err) } - for _, ln := range c1lns { - go server1.Serve(ln) - } ln, err = net.ListenTCP("tcp", &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -714,15 +726,13 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal Address: ln.Addr().(*net.TCPAddr), }, } + handler2 := http.NewServeMux() server2 := &http.Server{ - Handler: handlers[1], + Handler: handler2, } if err := http2.ConfigureServer(server2, nil); err != nil { t.Fatal(err) } - for _, ln := range c2lns { - go server2.Serve(ln) - } ln, err = net.ListenTCP("tcp", &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -736,15 +746,13 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal Address: ln.Addr().(*net.TCPAddr), }, } + handler3 := http.NewServeMux() server3 := &http.Server{ - Handler: handlers[2], + Handler: handler3, } if err := http2.ConfigureServer(server3, nil); err != nil { t.Fatal(err) } - for _, ln := range c3lns { - go server3.Serve(ln) - } // Create three cores with the same physical and different redirect/cluster addrs // N.B.: On OSX, instead of random ports, it assigns new ports to new @@ -844,10 +852,10 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal } c2.SetClusterListenerAddrs(clusterAddrGen(c2lns)) - c2.SetClusterHandler(handlers[1]) + c2.SetClusterHandler(handler2) c3.SetClusterListenerAddrs(clusterAddrGen(c3lns)) - c3.SetClusterHandler(handlers[2]) - keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handlers[0]) + c3.SetClusterHandler(handler3) + keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handler1) for _, key := range keys { if _, err := c1.Unseal(TestKeyCopy(key)); err != nil { t.Fatalf("unseal err: %s", err) @@ -928,6 +936,8 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal ret = append(ret, &TestClusterCore{ Core: c1, Listeners: c1lns, + Handler: handler1, + Server: server1, Root: root, BarrierKeys: keyCopies.([][]byte), CACertBytes: caBytes, @@ -941,6 +951,8 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal ret = append(ret, &TestClusterCore{ Core: c2, Listeners: c2lns, + Handler: handler2, + Server: server2, Root: root, BarrierKeys: keyCopies.([][]byte), CACertBytes: caBytes, @@ -954,6 +966,8 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal ret = append(ret, &TestClusterCore{ Core: c3, Listeners: c3lns, + Handler: handler3, + Server: server3, Root: root, BarrierKeys: keyCopies.([][]byte), CACertBytes: caBytes, @@ -963,7 +977,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal Client: getAPIClient(c3lns[0].Address.Port), }) - return ret + return &TestCluster{Cores: ret} } const (