This commit is contained in:
kevinpollet 2025-07-31 14:13:21 +02:00
parent be8851c7f3
commit 95d343876e
No known key found for this signature in database
GPG Key ID: 0C9A5DDD1B292453
2 changed files with 97 additions and 128 deletions

View File

@ -7,6 +7,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"errors"
"io" "io"
"math/big" "math/big"
"net" "net"
@ -14,6 +15,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/pires/go-proxyproto"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle" "github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig" "github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
@ -140,7 +142,7 @@ func TestNoTLS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer backendListener.Close() defer backendListener.Close()
go fakeRedis(t, backendListener) go fakeServer(t, backendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String()) _, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
@ -186,7 +188,7 @@ func TestTLS(t *testing.T) {
tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}}) tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsListener.Close() defer tlsListener.Close()
go fakeRedis(t, tlsListener) go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String()) _, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
@ -236,7 +238,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) {
tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}}) tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsListener.Close() defer tlsListener.Close()
go fakeRedis(t, tlsListener) go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String()) _, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
@ -297,7 +299,7 @@ func TestMTLS(t *testing.T) {
}) })
defer tlsListener.Close() defer tlsListener.Close()
go fakeRedis(t, tlsListener) go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String()) _, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
@ -444,7 +446,7 @@ func TestSpiffeMTLS(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
go fakeRedis(t, tlsListener) go fakeServer(t, tlsListener)
dialerManager := NewDialerManager(test.clientSource) dialerManager := NewDialerManager(test.clientSource)
@ -506,51 +508,47 @@ func TestProxyProtocol(t *testing.T) {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0") backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err) require.NoError(t, err)
defer backendListener.Close()
receivedData := make([]byte, 1024) var version int
proxyBackendListener := proxyproto.Listener{
// Start a server that captures all data including proxy protocol headers Listener: backendListener,
go func() { ValidateHeader: func(h *proxyproto.Header) error {
conn, err := backendListener.Accept() version = int(h.Version)
require.NoError(t, err) return nil
defer conn.Close() },
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
// Read all initial data switch test.version {
_, err = conn.Read(receivedData) case 1, 2:
require.NoError(t, err) return proxyproto.USE, nil
default:
// Check if there's ping in the data and respond return proxyproto.REQUIRE, errors.New("unsupported version")
if bytes.Contains(receivedData, []byte("ping")) {
_, _ = conn.Write([]byte("PONG"))
} }
}() },
}
defer proxyBackendListener.Close()
go fakeServer(t, &proxyBackendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String()) _, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
dialerManager := NewDialerManager(nil) dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
dynamicConf := map[string]*dynamic.TCPServersTransport{
"test": { "test": {
ProxyProtocol: &dynamic.ProxyProtocol{ ProxyProtocol: &dynamic.ProxyProtocol{
Version: test.version, Version: test.version,
}, },
}, },
} })
dialerManager.Update(dynamicConf) dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{
ServersTransport: "test",
}, false)
require.NoError(t, err) require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port) conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
_, err = conn.Write([]byte("ping\n")) _, err = conn.Write([]byte("ping"))
require.NoError(t, err) require.NoError(t, err)
buf := make([]byte, 64) buf := make([]byte, 64)
@ -559,18 +557,7 @@ func TestProxyProtocol(t *testing.T) {
assert.Equal(t, 4, n) assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4])) assert.Equal(t, "PONG", string(buf[:4]))
assert.Equal(t, test.version, version)
// Verify proxy protocol header was sent
assert.NotEmpty(t, receivedData, "Should have received data")
if test.version == 1 {
// For v1, check for "PROXY" prefix
assert.True(t, bytes.HasPrefix(receivedData, []byte("PROXY TCP4")), "Should contain PROXY TCP4 header")
} else {
// For v2, check for binary signature
expectedSignature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
assert.True(t, bytes.HasPrefix(receivedData, expectedSignature), "Should contain v2 binary signature")
}
}) })
} }
} }
@ -581,11 +568,11 @@ func TestProxyProtocolWithTLS(t *testing.T) {
version int version int
}{ }{
{ {
desc: "proxy protocol v1 with TLS", desc: "proxy protocol v1",
version: 1, version: 1,
}, },
{ {
desc: "proxy protocol v2 with TLS", desc: "proxy protocol v2",
version: 2, version: 2,
}, },
} }
@ -597,20 +584,30 @@ func TestProxyProtocolWithTLS(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0") backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err) require.NoError(t, err)
defer backendListener.Close()
receivedData := make([]byte, 1024) var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
}
defer proxyBackendListener.Close()
// Create a server that captures proxy protocol headers before TLS
go func() { go func() {
conn, err := backendListener.Accept() conn, err := proxyBackendListener.Accept()
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
// Read the proxy protocol header first (before TLS handshake)
n, err := conn.Read(receivedData)
require.NoError(t, err)
// Now wrap with TLS and perform handshake // Now wrap with TLS and perform handshake
tlsConn := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{cert}}) tlsConn := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsConn.Close() defer tlsConn.Close()
@ -618,11 +615,11 @@ func TestProxyProtocolWithTLS(t *testing.T) {
err = tlsConn.Handshake() err = tlsConn.Handshake()
require.NoError(t, err) require.NoError(t, err)
_, err = tlsConn.Read(receivedData[n:]) buf := make([]byte, 64)
n, err := tlsConn.Read(buf)
require.NoError(t, err) require.NoError(t, err)
// Check if there's ping in the data and respond if bytes.Equal(buf[:n], []byte("ping")) {
if bytes.Contains(receivedData, []byte("ping")) {
_, _ = tlsConn.Write([]byte("PONG")) _, _ = tlsConn.Write([]byte("PONG"))
} }
}() }()
@ -631,8 +628,7 @@ func TestProxyProtocolWithTLS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
dialerManager := NewDialerManager(nil) dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
dynamicConf := map[string]*dynamic.TCPServersTransport{
"test": { "test": {
TLS: &dynamic.TLSClientConfig{ TLS: &dynamic.TLSClientConfig{
ServerName: "example.com", ServerName: "example.com",
@ -643,9 +639,7 @@ func TestProxyProtocolWithTLS(t *testing.T) {
Version: test.version, Version: test.version,
}, },
}, },
} })
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{
ServersTransport: "test", ServersTransport: "test",
@ -656,7 +650,7 @@ func TestProxyProtocolWithTLS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
_, err = conn.Write([]byte("ping\n")) _, err = conn.Write([]byte("ping"))
require.NoError(t, err) require.NoError(t, err)
buf := make([]byte, 64) buf := make([]byte, 64)
@ -665,18 +659,7 @@ func TestProxyProtocolWithTLS(t *testing.T) {
assert.Equal(t, 4, n) assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4])) assert.Equal(t, "PONG", string(buf[:4]))
assert.Equal(t, test.version, version)
// Verify proxy protocol header was sent
assert.NotEmpty(t, receivedData, "Proxy protocol header should not be empty")
if test.version == 1 {
// For v1, check for "PROXY" prefix
assert.True(t, bytes.HasPrefix(receivedData, []byte("PROXY TCP4")), "Should contain PROXY TCP4 header")
} else {
// For v2, check for binary signature
expectedSignature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
assert.True(t, bytes.HasPrefix(receivedData, expectedSignature), "Should contain v2 binary signature")
}
}) })
} }
} }
@ -686,20 +669,16 @@ func TestProxyProtocolDisabled(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer backendListener.Close() defer backendListener.Close()
receivedData := make([]byte, 1024)
// Start a server that captures all data
go func() { go func() {
conn, err := backendListener.Accept() conn, err := backendListener.Accept()
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
// Read first chunk of data buf := make([]byte, 64)
_, err = conn.Read(receivedData) n, err := conn.Read(buf)
require.NoError(t, err) require.NoError(t, err)
// Handle ping/pong if it's in the data if bytes.Equal(buf[:n], []byte("ping")) {
if bytes.Contains(receivedData, []byte("ping")) {
_, _ = conn.Write([]byte("PONG")) _, _ = conn.Write([]byte("PONG"))
} }
}() }()
@ -707,24 +686,19 @@ func TestProxyProtocolDisabled(t *testing.T) {
_, port, err := net.SplitHostPort(backendListener.Addr().String()) _, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
// No proxy protocol configuration.
dialerManager := NewDialerManager(nil) dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
// No proxy protocol configuration
dynamicConf := map[string]*dynamic.TCPServersTransport{
"test": {}, "test": {},
} })
dialerManager.Update(dynamicConf) dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{
ServersTransport: "test",
}, false)
require.NoError(t, err) require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port) conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err) require.NoError(t, err)
_, err = conn.Write([]byte("ping\n")) _, err = conn.Write([]byte("ping"))
require.NoError(t, err) require.NoError(t, err)
buf := make([]byte, 64) buf := make([]byte, 64)
@ -733,12 +707,6 @@ func TestProxyProtocolDisabled(t *testing.T) {
assert.Equal(t, 4, n) assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4])) assert.Equal(t, "PONG", string(buf[:4]))
err = conn.Close()
require.NoError(t, err)
// Verify no proxy protocol header was sent - data should start with "ping"
assert.False(t, bytes.HasPrefix(receivedData, []byte("PROXY")), "Should not contain PROXY header")
} }
// fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs. // fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs.

View File

@ -10,41 +10,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func fakeRedis(t *testing.T, listener net.Listener) {
t.Helper()
for {
conn, err := listener.Accept()
require.NoError(t, err)
for {
withErr := false
buf := make([]byte, 64)
if _, err := conn.Read(buf); err != nil {
withErr = true
}
if string(buf[:4]) == "ping" {
time.Sleep(1 * time.Millisecond)
if _, err := conn.Write([]byte("PONG")); err != nil {
_ = conn.Close()
return
}
}
if withErr {
_ = conn.Close()
return
}
}
}
}
func TestCloseWrite(t *testing.T) { func TestCloseWrite(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0") backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err) require.NoError(t, err)
go fakeRedis(t, backendListener) go fakeServer(t, backendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String()) _, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err) require.NoError(t, err)
@ -80,6 +50,37 @@ func TestCloseWrite(t *testing.T) {
buffer := bytes.NewBuffer(buf) buffer := bytes.NewBuffer(buf)
n, err := io.Copy(buffer, conn) n, err := io.Copy(buffer, conn)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(4), n) require.Equal(t, int64(4), n)
require.Equal(t, "PONG", buffer.String()) require.Equal(t, "PONG", buffer.String())
} }
func fakeServer(t *testing.T, listener net.Listener) {
t.Helper()
for {
conn, err := listener.Accept()
require.NoError(t, err)
for {
withErr := false
buf := make([]byte, 64)
if _, err := conn.Read(buf); err != nil {
withErr = true
}
if string(buf[:4]) == "ping" {
time.Sleep(1 * time.Millisecond)
if _, err := conn.Write([]byte("PONG")); err != nil {
_ = conn.Close()
return
}
}
if withErr {
_ = conn.Close()
return
}
}
}
}