From 95d343876e186a7cc7fbf20a0f70a0732af1c147 Mon Sep 17 00:00:00 2001 From: kevinpollet Date: Thu, 31 Jul 2025 14:13:21 +0200 Subject: [PATCH] review --- pkg/tcp/dialer_test.go | 162 +++++++++++++++++------------------------ pkg/tcp/proxy_test.go | 63 ++++++++-------- 2 files changed, 97 insertions(+), 128 deletions(-) diff --git a/pkg/tcp/dialer_test.go b/pkg/tcp/dialer_test.go index f90ab519d..62af030db 100644 --- a/pkg/tcp/dialer_test.go +++ b/pkg/tcp/dialer_test.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "errors" "io" "math/big" "net" @@ -14,6 +15,7 @@ import ( "testing" "time" + "github.com/pires/go-proxyproto" "github.com/spiffe/go-spiffe/v2/bundle/x509bundle" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig" @@ -140,7 +142,7 @@ func TestNoTLS(t *testing.T) { require.NoError(t, err) defer backendListener.Close() - go fakeRedis(t, backendListener) + go fakeServer(t, backendListener) _, port, err := net.SplitHostPort(backendListener.Addr().String()) require.NoError(t, err) @@ -186,7 +188,7 @@ func TestTLS(t *testing.T) { tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}}) defer tlsListener.Close() - go fakeRedis(t, tlsListener) + go fakeServer(t, tlsListener) _, port, err := net.SplitHostPort(tlsListener.Addr().String()) require.NoError(t, err) @@ -236,7 +238,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) { tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}}) defer tlsListener.Close() - go fakeRedis(t, tlsListener) + go fakeServer(t, tlsListener) _, port, err := net.SplitHostPort(tlsListener.Addr().String()) require.NoError(t, err) @@ -297,7 +299,7 @@ func TestMTLS(t *testing.T) { }) defer tlsListener.Close() - go fakeRedis(t, tlsListener) + go fakeServer(t, tlsListener) _, port, err := net.SplitHostPort(tlsListener.Addr().String()) require.NoError(t, err) @@ -444,7 +446,7 @@ func TestSpiffeMTLS(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - go fakeRedis(t, tlsListener) + go fakeServer(t, tlsListener) dialerManager := NewDialerManager(test.clientSource) @@ -506,51 +508,47 @@ func TestProxyProtocol(t *testing.T) { t.Run(test.desc, func(t *testing.T) { backendListener, err := net.Listen("tcp", ":0") require.NoError(t, err) - defer backendListener.Close() - receivedData := make([]byte, 1024) + var version int + proxyBackendListener := proxyproto.Listener{ + Listener: backendListener, + ValidateHeader: func(h *proxyproto.Header) error { + version = int(h.Version) + return nil + }, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + switch test.version { + case 1, 2: + return proxyproto.USE, nil + default: + return proxyproto.REQUIRE, errors.New("unsupported version") + } + }, + } + defer proxyBackendListener.Close() - // Start a server that captures all data including proxy protocol headers - go func() { - conn, err := backendListener.Accept() - require.NoError(t, err) - defer conn.Close() - - // Read all initial data - _, err = conn.Read(receivedData) - require.NoError(t, err) - - // Check if there's ping in the data and respond - if bytes.Contains(receivedData, []byte("ping")) { - _, _ = conn.Write([]byte("PONG")) - } - }() + go fakeServer(t, &proxyBackendListener) _, port, err := net.SplitHostPort(backendListener.Addr().String()) require.NoError(t, err) dialerManager := NewDialerManager(nil) - - dynamicConf := map[string]*dynamic.TCPServersTransport{ + dialerManager.Update(map[string]*dynamic.TCPServersTransport{ "test": { ProxyProtocol: &dynamic.ProxyProtocol{ Version: test.version, }, }, - } + }) - dialerManager.Update(dynamicConf) - - dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ - ServersTransport: "test", - }, false) + dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false) require.NoError(t, err) conn, err := dialer.Dial("tcp", ":"+port) require.NoError(t, err) defer conn.Close() - _, err = conn.Write([]byte("ping\n")) + _, err = conn.Write([]byte("ping")) require.NoError(t, err) buf := make([]byte, 64) @@ -559,18 +557,7 @@ func TestProxyProtocol(t *testing.T) { assert.Equal(t, 4, n) assert.Equal(t, "PONG", string(buf[:4])) - - // Verify proxy protocol header was sent - assert.NotEmpty(t, receivedData, "Should have received data") - - if test.version == 1 { - // For v1, check for "PROXY" prefix - assert.True(t, bytes.HasPrefix(receivedData, []byte("PROXY TCP4")), "Should contain PROXY TCP4 header") - } else { - // For v2, check for binary signature - expectedSignature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} - assert.True(t, bytes.HasPrefix(receivedData, expectedSignature), "Should contain v2 binary signature") - } + assert.Equal(t, test.version, version) }) } } @@ -581,11 +568,11 @@ func TestProxyProtocolWithTLS(t *testing.T) { version int }{ { - desc: "proxy protocol v1 with TLS", + desc: "proxy protocol v1", version: 1, }, { - desc: "proxy protocol v2 with TLS", + desc: "proxy protocol v2", version: 2, }, } @@ -597,20 +584,30 @@ func TestProxyProtocolWithTLS(t *testing.T) { backendListener, err := net.Listen("tcp", ":0") require.NoError(t, err) - defer backendListener.Close() - receivedData := make([]byte, 1024) + var version int + proxyBackendListener := proxyproto.Listener{ + Listener: backendListener, + ValidateHeader: func(h *proxyproto.Header) error { + version = int(h.Version) + return nil + }, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + switch test.version { + case 1, 2: + return proxyproto.USE, nil + default: + return proxyproto.REQUIRE, errors.New("unsupported version") + } + }, + } + defer proxyBackendListener.Close() - // Create a server that captures proxy protocol headers before TLS go func() { - conn, err := backendListener.Accept() + conn, err := proxyBackendListener.Accept() require.NoError(t, err) defer conn.Close() - // Read the proxy protocol header first (before TLS handshake) - n, err := conn.Read(receivedData) - require.NoError(t, err) - // Now wrap with TLS and perform handshake tlsConn := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{cert}}) defer tlsConn.Close() @@ -618,11 +615,11 @@ func TestProxyProtocolWithTLS(t *testing.T) { err = tlsConn.Handshake() require.NoError(t, err) - _, err = tlsConn.Read(receivedData[n:]) + buf := make([]byte, 64) + n, err := tlsConn.Read(buf) require.NoError(t, err) - // Check if there's ping in the data and respond - if bytes.Contains(receivedData, []byte("ping")) { + if bytes.Equal(buf[:n], []byte("ping")) { _, _ = tlsConn.Write([]byte("PONG")) } }() @@ -631,8 +628,7 @@ func TestProxyProtocolWithTLS(t *testing.T) { require.NoError(t, err) dialerManager := NewDialerManager(nil) - - dynamicConf := map[string]*dynamic.TCPServersTransport{ + dialerManager.Update(map[string]*dynamic.TCPServersTransport{ "test": { TLS: &dynamic.TLSClientConfig{ ServerName: "example.com", @@ -643,9 +639,7 @@ func TestProxyProtocolWithTLS(t *testing.T) { Version: test.version, }, }, - } - - dialerManager.Update(dynamicConf) + }) dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ ServersTransport: "test", @@ -656,7 +650,7 @@ func TestProxyProtocolWithTLS(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, err = conn.Write([]byte("ping\n")) + _, err = conn.Write([]byte("ping")) require.NoError(t, err) buf := make([]byte, 64) @@ -665,18 +659,7 @@ func TestProxyProtocolWithTLS(t *testing.T) { assert.Equal(t, 4, n) assert.Equal(t, "PONG", string(buf[:4])) - - // Verify proxy protocol header was sent - assert.NotEmpty(t, receivedData, "Proxy protocol header should not be empty") - - if test.version == 1 { - // For v1, check for "PROXY" prefix - assert.True(t, bytes.HasPrefix(receivedData, []byte("PROXY TCP4")), "Should contain PROXY TCP4 header") - } else { - // For v2, check for binary signature - expectedSignature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} - assert.True(t, bytes.HasPrefix(receivedData, expectedSignature), "Should contain v2 binary signature") - } + assert.Equal(t, test.version, version) }) } } @@ -686,20 +669,16 @@ func TestProxyProtocolDisabled(t *testing.T) { require.NoError(t, err) defer backendListener.Close() - receivedData := make([]byte, 1024) - - // Start a server that captures all data go func() { conn, err := backendListener.Accept() require.NoError(t, err) defer conn.Close() - // Read first chunk of data - _, err = conn.Read(receivedData) + buf := make([]byte, 64) + n, err := conn.Read(buf) require.NoError(t, err) - // Handle ping/pong if it's in the data - if bytes.Contains(receivedData, []byte("ping")) { + if bytes.Equal(buf[:n], []byte("ping")) { _, _ = conn.Write([]byte("PONG")) } }() @@ -707,24 +686,19 @@ func TestProxyProtocolDisabled(t *testing.T) { _, port, err := net.SplitHostPort(backendListener.Addr().String()) require.NoError(t, err) + // No proxy protocol configuration. dialerManager := NewDialerManager(nil) - - // No proxy protocol configuration - dynamicConf := map[string]*dynamic.TCPServersTransport{ + dialerManager.Update(map[string]*dynamic.TCPServersTransport{ "test": {}, - } + }) - dialerManager.Update(dynamicConf) - - dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ - ServersTransport: "test", - }, false) + dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false) require.NoError(t, err) conn, err := dialer.Dial("tcp", ":"+port) require.NoError(t, err) - _, err = conn.Write([]byte("ping\n")) + _, err = conn.Write([]byte("ping")) require.NoError(t, err) buf := make([]byte, 64) @@ -733,12 +707,6 @@ func TestProxyProtocolDisabled(t *testing.T) { assert.Equal(t, 4, n) assert.Equal(t, "PONG", string(buf[:4])) - - err = conn.Close() - require.NoError(t, err) - - // Verify no proxy protocol header was sent - data should start with "ping" - assert.False(t, bytes.HasPrefix(receivedData, []byte("PROXY")), "Should not contain PROXY header") } // fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs. diff --git a/pkg/tcp/proxy_test.go b/pkg/tcp/proxy_test.go index 4b21755d7..34bd6cdfe 100644 --- a/pkg/tcp/proxy_test.go +++ b/pkg/tcp/proxy_test.go @@ -10,41 +10,11 @@ import ( "github.com/stretchr/testify/require" ) -func fakeRedis(t *testing.T, listener net.Listener) { - t.Helper() - - for { - conn, err := listener.Accept() - require.NoError(t, err) - - for { - withErr := false - buf := make([]byte, 64) - if _, err := conn.Read(buf); err != nil { - withErr = true - } - - if string(buf[:4]) == "ping" { - time.Sleep(1 * time.Millisecond) - if _, err := conn.Write([]byte("PONG")); err != nil { - _ = conn.Close() - return - } - } - - if withErr { - _ = conn.Close() - return - } - } - } -} - func TestCloseWrite(t *testing.T) { backendListener, err := net.Listen("tcp", ":0") require.NoError(t, err) - go fakeRedis(t, backendListener) + go fakeServer(t, backendListener) _, port, err := net.SplitHostPort(backendListener.Addr().String()) require.NoError(t, err) @@ -80,6 +50,37 @@ func TestCloseWrite(t *testing.T) { buffer := bytes.NewBuffer(buf) n, err := io.Copy(buffer, conn) require.NoError(t, err) + require.Equal(t, int64(4), n) require.Equal(t, "PONG", buffer.String()) } + +func fakeServer(t *testing.T, listener net.Listener) { + t.Helper() + + for { + conn, err := listener.Accept() + require.NoError(t, err) + + for { + withErr := false + buf := make([]byte, 64) + if _, err := conn.Read(buf); err != nil { + withErr = true + } + + if string(buf[:4]) == "ping" { + time.Sleep(1 * time.Millisecond) + if _, err := conn.Write([]byte("PONG")); err != nil { + _ = conn.Close() + return + } + } + + if withErr { + _ = conn.Close() + return + } + } + } +}