traefik/pkg/tcp/wrr_load_balancer_test.go
Douglas De Toni Machado 8392503df7
Add TCP Healthcheck
2025-10-22 11:42:05 +02:00

295 lines
6.8 KiB
Go

package tcp
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestWRRLoadBalancer_LoadBalancing checks how connections are distributed
// across servers for several weight configurations. Each server's handler
// writes the server's own name to the shared fake connection, so the per-name
// write counters show how many connections each server received. Connections
// that no handler wrote to are observed through the Close counter instead.
func TestWRRLoadBalancer_LoadBalancing(t *testing.T) {
	type scenario struct {
		desc          string
		serversWeight map[string]int // server name -> weight passed to Add
		totalCall     int            // number of ServeTCP invocations
		expectedWrite map[string]int // expected writes per server name
		expectedClose int            // expected Close calls on the fake conn
	}

	scenarios := []scenario{
		{
			desc:          "RoundRobin",
			serversWeight: map[string]int{"h1": 1, "h2": 1},
			totalCall:     4,
			expectedWrite: map[string]int{"h1": 2, "h2": 2},
		},
		{
			desc:          "WeighedRoundRobin",
			serversWeight: map[string]int{"h1": 3, "h2": 1},
			totalCall:     4,
			expectedWrite: map[string]int{"h1": 3, "h2": 1},
		},
		{
			desc:          "WeighedRoundRobin with more call",
			serversWeight: map[string]int{"h1": 3, "h2": 1},
			totalCall:     16,
			expectedWrite: map[string]int{"h1": 12, "h2": 4},
		},
		{
			desc:          "WeighedRoundRobin with one 0 weight server",
			serversWeight: map[string]int{"h1": 3, "h2": 0},
			totalCall:     16,
			expectedWrite: map[string]int{"h1": 16},
		},
		{
			desc:          "WeighedRoundRobin with all servers with 0 weight",
			serversWeight: map[string]int{"h1": 0, "h2": 0, "h3": 0},
			totalCall:     10,
			expectedWrite: map[string]int{},
			expectedClose: 10,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.desc, func(t *testing.T) {
			t.Parallel()

			lb := NewWRRLoadBalancer(false)
			for name, weight := range sc.serversWeight {
				handler := func(conn WriteCloser) {
					_, err := conn.Write([]byte(name))
					require.NoError(t, err)
				}
				// weight is a fresh variable on every iteration, so each
				// server gets its own pointer.
				lb.Add(name, HandlerFunc(handler), &weight)
			}

			conn := &fakeConn{writeCall: make(map[string]int)}
			for i := 0; i < sc.totalCall; i++ {
				lb.ServeTCP(conn)
			}

			assert.Equal(t, sc.expectedWrite, conn.writeCall)
			assert.Equal(t, sc.expectedClose, conn.closeCall)
		})
	}
}
// TestWRRLoadBalancer_NoServiceUp verifies the behaviour when every server has
// been marked down: ServeTCP must not hand the connection to any handler (no
// writes are recorded) and the connection is closed exactly once.
func TestWRRLoadBalancer_NoServiceUp(t *testing.T) {
	lb := NewWRRLoadBalancer(false)

	names := []string{"first", "second"}
	for _, name := range names {
		lb.Add(name, HandlerFunc(func(conn WriteCloser) {
			_, err := conn.Write([]byte(name))
			require.NoError(t, err)
		}), pointer(1))
	}
	for _, name := range names {
		lb.SetStatus(t.Context(), name, false)
	}

	conn := &fakeConn{writeCall: make(map[string]int)}
	lb.ServeTCP(conn)

	assert.Empty(t, conn.writeCall)
	assert.Equal(t, 1, conn.closeCall)
}
func TestWRRLoadBalancer_OneServerDown(t *testing.T) {
balancer := NewWRRLoadBalancer(false)
balancer.Add("first", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("first"))
require.NoError(t, err)
}), pointer(1))
balancer.Add("second", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("second"))
require.NoError(t, err)
}), pointer(1))
balancer.SetStatus(t.Context(), "second", false)
conn := &fakeConn{writeCall: make(map[string]int)}
for range 3 {
balancer.ServeTCP(conn)
}
assert.Equal(t, 3, conn.writeCall["first"])
}
func TestWRRLoadBalancer_DownThenUp(t *testing.T) {
balancer := NewWRRLoadBalancer(false)
balancer.Add("first", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("first"))
require.NoError(t, err)
}), pointer(1))
balancer.Add("second", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("second"))
require.NoError(t, err)
}), pointer(1))
balancer.SetStatus(t.Context(), "second", false)
conn := &fakeConn{writeCall: make(map[string]int)}
for range 3 {
balancer.ServeTCP(conn)
}
assert.Equal(t, 3, conn.writeCall["first"])
balancer.SetStatus(t.Context(), "second", true)
conn = &fakeConn{writeCall: make(map[string]int)}
for range 2 {
balancer.ServeTCP(conn)
}
assert.Equal(t, 1, conn.writeCall["first"])
assert.Equal(t, 1, conn.writeCall["second"])
}
func TestWRRLoadBalancer_Propagate(t *testing.T) {
balancer1 := NewWRRLoadBalancer(true)
balancer1.Add("first", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("first"))
require.NoError(t, err)
}), pointer(1))
balancer1.Add("second", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("second"))
require.NoError(t, err)
}), pointer(1))
balancer2 := NewWRRLoadBalancer(true)
balancer2.Add("third", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("third"))
require.NoError(t, err)
}), pointer(1))
balancer2.Add("fourth", HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte("fourth"))
require.NoError(t, err)
}), pointer(1))
topBalancer := NewWRRLoadBalancer(true)
topBalancer.Add("balancer1", balancer1, pointer(1))
_ = balancer1.RegisterStatusUpdater(func(up bool) {
topBalancer.SetStatus(t.Context(), "balancer1", up)
})
topBalancer.Add("balancer2", balancer2, pointer(1))
_ = balancer2.RegisterStatusUpdater(func(up bool) {
topBalancer.SetStatus(t.Context(), "balancer2", up)
})
conn := &fakeConn{writeCall: make(map[string]int)}
for range 8 {
topBalancer.ServeTCP(conn)
}
assert.Equal(t, 2, conn.writeCall["first"])
assert.Equal(t, 2, conn.writeCall["second"])
assert.Equal(t, 2, conn.writeCall["third"])
assert.Equal(t, 2, conn.writeCall["fourth"])
// fourth gets downed, but balancer2 still up since third is still up.
balancer2.SetStatus(t.Context(), "fourth", false)
conn = &fakeConn{writeCall: make(map[string]int)}
for range 8 {
topBalancer.ServeTCP(conn)
}
assert.Equal(t, 2, conn.writeCall["first"])
assert.Equal(t, 2, conn.writeCall["second"])
assert.Equal(t, 4, conn.writeCall["third"])
assert.Equal(t, 0, conn.writeCall["fourth"])
// third gets downed, and the propagation triggers balancer2 to be marked as
// down as well for topBalancer.
balancer2.SetStatus(t.Context(), "third", false)
conn = &fakeConn{writeCall: make(map[string]int)}
for range 8 {
topBalancer.ServeTCP(conn)
}
assert.Equal(t, 4, conn.writeCall["first"])
assert.Equal(t, 4, conn.writeCall["second"])
assert.Equal(t, 0, conn.writeCall["third"])
assert.Equal(t, 0, conn.writeCall["fourth"])
}
// pointer returns a pointer to a freshly allocated copy of v. The tests use it
// to pass literal weights (e.g. pointer(1)) where a *int is expected. Mutating
// through the result never affects the caller's variable.
func pointer[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// fakeConn is an in-memory stand-in for a TCP connection used by the load
// balancer tests. It records how many times each distinct payload was written
// and how many times Close was called. Every other method panics because the
// tests never exercise it.
type fakeConn struct {
	writeCall map[string]int // number of writes, keyed by payload as a string
	closeCall int            // number of Close calls
}

// Compile-time check that fakeConn implements net.Conn.
var _ net.Conn = (*fakeConn)(nil)

func (f *fakeConn) Read(_ []byte) (n int, err error) {
	panic("implement me")
}

// Write counts one write of string(b) and reports the whole payload as
// written. The counter map is created on first use, so a zero-value fakeConn
// is usable instead of panicking on a nil-map write.
func (f *fakeConn) Write(b []byte) (n int, err error) {
	if f.writeCall == nil {
		f.writeCall = make(map[string]int)
	}
	f.writeCall[string(b)]++
	return len(b), nil
}

// Close counts the call and always succeeds.
func (f *fakeConn) Close() error {
	f.closeCall++
	return nil
}

func (f *fakeConn) LocalAddr() net.Addr {
	panic("implement me")
}

func (f *fakeConn) RemoteAddr() net.Addr {
	panic("implement me")
}

func (f *fakeConn) SetDeadline(_ time.Time) error {
	panic("implement me")
}

func (f *fakeConn) SetReadDeadline(_ time.Time) error {
	panic("implement me")
}

func (f *fakeConn) SetWriteDeadline(_ time.Time) error {
	panic("implement me")
}

func (f *fakeConn) CloseWrite() error {
	panic("implement me")
}