Attempt to fix TestProxyFromEnvironment test

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Romain 2025-06-02 10:46:04 +02:00 committed by GitHub
parent 0b4058dde0
commit 2fdee25bb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -20,7 +20,6 @@ import (
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/config/static"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/traefik/traefik/v3/pkg/tls/generate"
)
const (
@ -125,9 +124,17 @@ func TestProxyFromEnvironment(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
backendURL, backendCert := newBackendServer(t, test.tls, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
var backendServer *httptest.Server
if test.tls {
backendServer = httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte("backendTLS"))
}))
} else {
backendServer = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte("backend"))
}))
}
t.Cleanup(backendServer.Close)
var proxyCalled bool
proxyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -155,8 +162,21 @@ func TestProxyFromEnvironment(t *testing.T) {
connHj, _, err := hj.Hijack()
require.NoError(t, err)
go func() { _, _ = io.Copy(connHj, conn) }()
_, _ = io.Copy(conn, connHj)
defer func() {
_ = connHj.Close()
_ = conn.Close()
}()
errCh := make(chan error, 1)
go func() {
_, err = io.Copy(connHj, conn)
errCh <- err
}()
go func() {
_, err = io.Copy(conn, connHj)
errCh <- err
}()
<-errCh // Wait for one of the copy operations to finish
})
var proxyURL string
@ -198,7 +218,7 @@ func TestProxyFromEnvironment(t *testing.T) {
proxyURL = proxyServer.URL
case proxyHTTPS:
proxyServer := httptest.NewServer(proxyHandler)
proxyServer := httptest.NewTLSServer(proxyHandler)
t.Cleanup(proxyServer.Close)
proxyURL = proxyServer.URL
@ -209,11 +229,8 @@ func TestProxyFromEnvironment(t *testing.T) {
if proxyCert != nil {
certPool.AddCert(proxyCert)
}
if backendCert != nil {
cert, err := x509.ParseCertificate(backendCert.Certificate[0])
require.NoError(t, err)
certPool.AddCert(cert)
if backendServer.Certificate() != nil {
certPool.AddCert(backendServer.Certificate())
}
builder := NewProxyBuilder(&transportManagerMock{tlsConfig: &tls.Config{RootCAs: certPool}}, static.FastProxyConfig{})
@ -230,7 +247,7 @@ func TestProxyFromEnvironment(t *testing.T) {
return u, nil
}
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false, false)
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendServer.URL), false, false)
require.NoError(t, err)
reverseProxyServer := httptest.NewServer(reverseProxy)
@ -246,7 +263,11 @@ func TestProxyFromEnvironment(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
if test.tls {
assert.Equal(t, "backendTLS", string(body))
} else {
assert.Equal(t, "backend", string(body))
}
assert.True(t, proxyCalled)
})
}
@ -385,52 +406,6 @@ func TestTransferEncodingChunked(t *testing.T) {
assert.Equal(t, "chunk 0\nchunk 1\nchunk 2\n", string(body))
}
func newCertificate(t *testing.T, domain string) *tls.Certificate {
t.Helper()
certPEM, keyPEM, err := generate.KeyPair(domain, time.Time{})
require.NoError(t, err)
certificate, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(t, err)
return &certificate
}
func newBackendServer(t *testing.T, isTLS bool, handler http.Handler) (string, *tls.Certificate) {
t.Helper()
var ln net.Listener
var err error
var cert *tls.Certificate
scheme := "http"
domain := "backend.localhost"
if isTLS {
scheme = "https"
cert = newCertificate(t, domain)
ln, err = tls.Listen("tcp", ":0", &tls.Config{Certificates: []tls.Certificate{*cert}})
require.NoError(t, err)
} else {
ln, err = net.Listen("tcp", ":0")
require.NoError(t, err)
}
srv := &http.Server{Handler: handler}
go func() { _ = srv.Serve(ln) }()
t.Cleanup(func() { _ = srv.Close() })
_, port, err := net.SplitHostPort(ln.Addr().String())
require.NoError(t, err)
backendURL := fmt.Sprintf("%s://%s:%s", scheme, domain, port)
return backendURL, cert
}
type transportManagerMock struct {
tlsConfig *tls.Config
}