diff --git a/command/server/listener_tcp_test.go b/command/server/listener_tcp_test.go index 42da6c0d21..dde2251340 100644 --- a/command/server/listener_tcp_test.go +++ b/command/server/listener_tcp_test.go @@ -6,9 +6,7 @@ package server import ( "crypto/tls" "crypto/x509" - "fmt" "io/ioutil" - "math/rand" "net" "os" "testing" @@ -18,6 +16,7 @@ import ( "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/internalshared/configutil" "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/require" ) func TestTCPListener(t *testing.T) { @@ -25,9 +24,7 @@ func TestTCPListener(t *testing.T) { Address: "127.0.0.1:0", TLSDisable: true, }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) connFn := func(lnReal net.Listener) (net.Conn, error) { return net.Dial("tcp", ln.Addr().String()) @@ -41,19 +38,13 @@ func TestTCPListener_tls(t *testing.T) { wd, _ := os.Getwd() wd += "/test-fixtures/reload/" - td, err := ioutil.TempDir("", fmt.Sprintf("vault-test-%d", rand.New(rand.NewSource(time.Now().Unix())).Int63())) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(td) - // Setup initial certs - inBytes, _ := ioutil.ReadFile(wd + "reload_ca.pem") + inBytes, err := os.ReadFile(wd + "reload_ca.pem") + require.NoError(t, err) + certPool := x509.NewCertPool() ok := certPool.AppendCertsFromPEM(inBytes) - if !ok { - t.Fatal("not ok when appending CA cert") - } + require.True(t, ok, "not ok when appending CA cert") ln, _, _, err := tcpListenerFactory(&configutil.Listener{ Address: "127.0.0.1:0", @@ -62,9 +53,8 @@ func TestTCPListener_tls(t *testing.T) { TLSRequireAndVerifyClientCert: true, TLSClientCAFile: wd + "reload_ca.pem", }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) + cwd, _ := os.Getwd() clientCert, _ := tls.LoadX509KeyPair( @@ -100,9 +90,7 @@ func TestTCPListener_tls(t *testing.T) { TLSDisableClientCerts: true, TLSClientCAFile: wd + "reload_ca.pem", }, nil, cli.NewMockUi()) - if err == nil { - t.Fatal("expected error due to mutually exclusive client cert options") - } + require.Error(t, err, "expected error due to mutually exclusive client cert options") ln, _, _, err = tcpListenerFactory(&configutil.Listener{ Address: "127.0.0.1:0", @@ -111,9 +99,7 @@ func TestTCPListener_tls(t *testing.T) { TLSDisableClientCerts: true, TLSClientCAFile: wd + "reload_ca.pem", }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) testListenerImpl(t, ln, connFn(false), "foo.example.com", 0, "127.0.0.1", false) } @@ -122,19 +108,11 @@ func TestTCPListener_tls13(t *testing.T) { wd, _ := os.Getwd() wd += "/test-fixtures/reload/" - td, err := ioutil.TempDir("", fmt.Sprintf("vault-test-%d", rand.New(rand.NewSource(time.Now().Unix())).Int63())) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(td) - // Setup initial certs inBytes, _ := ioutil.ReadFile(wd + "reload_ca.pem") certPool := x509.NewCertPool() ok := certPool.AppendCertsFromPEM(inBytes) - if !ok { - t.Fatal("not ok when appending CA cert") - } + require.True(t, ok, "not ok when appending CA cert") ln, _, _, err := tcpListenerFactory(&configutil.Listener{ Address: "127.0.0.1:0", @@ -144,9 +122,8 @@ func TestTCPListener_tls13(t *testing.T) { TLSClientCAFile: wd + "reload_ca.pem", TLSMinVersion: "tls13", }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) + cwd, _ := os.Getwd() clientCert, _ := tls.LoadX509KeyPair( @@ -183,9 +160,7 @@ func TestTCPListener_tls13(t *testing.T) { TLSClientCAFile: wd + "reload_ca.pem", TLSMinVersion: "tls13", }, nil, cli.NewMockUi()) - if err == nil { - t.Fatal("expected error due to mutually exclusive client cert options") - } + require.Error(t, err, "expected error due to mutually exclusive client cert options") ln, _, _, err = tcpListenerFactory(&configutil.Listener{ Address: "127.0.0.1:0", @@ -195,9 +170,7 @@ func TestTCPListener_tls13(t *testing.T) { TLSClientCAFile: wd + "reload_ca.pem", TLSMinVersion: "tls13", }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) testListenerImpl(t, ln, connFn(false), "foo.example.com", tls.VersionTLS13, "127.0.0.1", false) @@ -209,9 +182,7 @@ func TestTCPListener_tls13(t *testing.T) { TLSClientCAFile: wd + "reload_ca.pem", TLSMaxVersion: "tls12", }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) testListenerImpl(t, ln, connFn(false), "foo.example.com", tls.VersionTLS12, "127.0.0.1", false) } @@ -429,9 +400,7 @@ func TestTCPListener_proxyProtocol(t *testing.T) { proxyProtocolAuthorizedAddrs := []*sockaddr.SockAddrMarshaler{} if tc.AuthorizedAddr != "" { sockAddr, err := sockaddr.NewSockAddr(tc.AuthorizedAddr) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) proxyProtocolAuthorizedAddrs = append( proxyProtocolAuthorizedAddrs, &sockaddr.SockAddrMarshaler{SockAddr: sockAddr}, @@ -444,12 +413,11 @@ func TestTCPListener_proxyProtocol(t *testing.T) { ProxyProtocolBehavior: tc.Behavior, ProxyProtocolAuthorizedAddrs: proxyProtocolAuthorizedAddrs, }, nil, cli.NewMockUi()) - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) connFn := func(lnReal net.Listener) (net.Conn, error) { - conn, err := net.Dial("tcp", ln.Addr().String()) + d := net.Dialer{Timeout: 3 * time.Second} + conn, err := d.Dial("tcp", lnReal.Addr().String()) if err != nil { return nil, err }