package leasttime

import (
	"context"
	"net/http"
	"net/http/httptest"
	"net/http/httptrace"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/traefik/traefik/v3/pkg/config/dynamic"
)
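
// Note on the scoring model exercised below: the expected counts throughout
// this file are consistent with a selection score of the form
//
//	score = avgResponseTime * (1 + inflightCount) / weight
//
// where the server with the lowest score wins and weighted round-robin (WRR)
// breaks ties between equal scores. This is inferred from the per-test
// comments (e.g. "(10 × (1 + 5)) / 1 = 60"), not a statement of the
// implementation.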

type key string

const serviceName key = "serviceName"

func pointer[T any](v T) *T { return &v }

// responseRecorder tracks which servers handled requests.
type responseRecorder struct {
	*httptest.ResponseRecorder
	save map[string]int
}

func (r *responseRecorder) WriteHeader(statusCode int) {
	server := r.Header().Get("server")
	if server != "" {
		r.save[server]++
	}
	r.ResponseRecorder.WriteHeader(statusCode)
}

// TestBalancer tests basic server addition and least-time selection.
func TestBalancer(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "second")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 10 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// With least-time and equal response times, both servers should get some traffic.
	assert.Positive(t, recorder.save["first"])
	assert.Positive(t, recorder.save["second"])
	assert.Equal(t, 10, recorder.save["first"]+recorder.save["second"])
}

// TestBalancerNoService tests behavior when no servers are configured.
func TestBalancerNoService(t *testing.T) {
	balancer := New(nil, false)

	recorder := httptest.NewRecorder()
	balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode)
}

// TestBalancerNoServiceUp tests behavior when all servers are marked down.
func TestBalancerNoServiceUp(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusInternalServerError)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusInternalServerError)
	}), pointer(1), false)

	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "parent"), "first", false)
	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "parent"), "second", false)

	recorder := httptest.NewRecorder()
	balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode)
}

// TestBalancerOneServerDown tests that down servers are excluded from selection.
func TestBalancerOneServerDown(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusInternalServerError)
	}), pointer(1), false)
	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "parent"), "second", false)

	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 3 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	assert.Equal(t, 3, recorder.save["first"])
	assert.Equal(t, 0, recorder.save["second"])
}

// TestBalancerOneServerDownThenUp tests server status transitions.
func TestBalancerOneServerDownThenUp(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "second")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)
	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "parent"), "second", false)

	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 3 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}
	assert.Equal(t, 3, recorder.save["first"])
	assert.Equal(t, 0, recorder.save["second"])

	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "parent"), "second", true)
	recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 20 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}
	// Both servers should get some traffic.
	assert.Positive(t, recorder.save["first"])
	assert.Positive(t, recorder.save["second"])
	assert.Equal(t, 20, recorder.save["first"]+recorder.save["second"])
}

// TestBalancerAllServersZeroWeight tests that all zero-weight servers result in no available server.
func TestBalancerAllServersZeroWeight(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(0), false)
	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(0), false)

	recorder := httptest.NewRecorder()
	balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode)
}

// TestBalancerOneServerZeroWeight tests that zero-weight servers are ignored.
func TestBalancerOneServerZeroWeight(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(0), false)

	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 3 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Only the first server should receive traffic.
	assert.Equal(t, 3, recorder.save["first"])
	assert.Equal(t, 0, recorder.save["second"])
}

// TestBalancerPropagate tests status propagation to parent balancers.
func TestBalancerPropagate(t *testing.T) {
	balancer1 := New(nil, true)

	balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)
	balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "second")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer2 := New(nil, true)
	balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "third")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)
	balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "fourth")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	topBalancer := New(nil, true)
	topBalancer.Add("balancer1", balancer1, pointer(1), false)
	topBalancer.Add("balancer2", balancer2, pointer(1), false)
	err := balancer1.RegisterStatusUpdater(func(up bool) {
		topBalancer.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "balancer1", up)
	})
	assert.NoError(t, err)
	err = balancer2.RegisterStatusUpdater(func(up bool) {
		topBalancer.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "balancer2", up)
	})
	assert.NoError(t, err)

	// Set all children of balancer1 to down; the status should propagate to the top balancer.
	balancer1.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "first", false)
	balancer1.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "second", false)

	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 4 {
		topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Only balancer2 should receive traffic.
	assert.Equal(t, 0, recorder.save["first"])
	assert.Equal(t, 0, recorder.save["second"])
	assert.Equal(t, 4, recorder.save["third"]+recorder.save["fourth"])
}

// TestBalancerOneServerFenced tests that fenced servers are excluded from selection.
func TestBalancerOneServerFenced(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "second")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), true) // fenced

	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 3 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Only the first server should receive traffic.
	assert.Equal(t, 3, recorder.save["first"])
	assert.Equal(t, 0, recorder.save["second"])
}

// TestBalancerAllFencedServers tests that all fenced servers result in no available server.
func TestBalancerAllFencedServers(t *testing.T) {
	balancer := New(nil, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(1), true)
	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(1), true)

	recorder := httptest.NewRecorder()
	balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode)
}

// TestBalancerRegisterStatusUpdaterWithoutHealthCheck tests that registering a status updater fails when health check is disabled.
func TestBalancerRegisterStatusUpdaterWithoutHealthCheck(t *testing.T) {
	balancer := New(nil, false)

	err := balancer.RegisterStatusUpdater(func(up bool) {})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "healthCheck not enabled")
}

// TestBalancerSticky tests sticky session support.
func TestBalancerSticky(t *testing.T) {
	balancer := New(&dynamic.Sticky{
		Cookie: &dynamic.Cookie{
			Name: "test",
		},
	}, false)

	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "second")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// First request should set cookie.
	recorder := httptest.NewRecorder()
	balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	firstServer := recorder.Header().Get("server")
	assert.NotEmpty(t, firstServer)

	// Extract cookie from first response.
	cookies := recorder.Result().Cookies()
	assert.NotEmpty(t, cookies)

	// Second request with cookie should hit same server.
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}

	recorder2 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder2, req)
	secondServer := recorder2.Header().Get("server")

	assert.Equal(t, firstServer, secondServer)
}

// TestBalancerStickyFallback tests that sticky sessions fall back to least-time selection when the sticky server is down.
func TestBalancerStickyFallback(t *testing.T) {
	balancer := New(&dynamic.Sticky{
		Cookie: &dynamic.Cookie{
			Name: "test",
		},
	}, false)

	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(50 * time.Millisecond)
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(50 * time.Millisecond)
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// Make an initial request to establish a sticky session.
	recorder1 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder1, httptest.NewRequest(http.MethodGet, "/", nil))
	firstServer := recorder1.Header().Get("server")
	assert.NotEmpty(t, firstServer)

	// Extract cookie from first response.
	cookies := recorder1.Result().Cookies()
	assert.NotEmpty(t, cookies)

	// Mark the sticky server as DOWN.
	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "test"), firstServer, false)

	// A request with the sticky cookie should fall back to the other server.
	req2 := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, cookie := range cookies {
		req2.AddCookie(cookie)
	}
	recorder2 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder2, req2)
	fallbackServer := recorder2.Header().Get("server")
	assert.NotEqual(t, firstServer, fallbackServer)
	assert.NotEmpty(t, fallbackServer)

	// A new sticky cookie should be written for the fallback server.
	newCookies := recorder2.Result().Cookies()
	assert.NotEmpty(t, newCookies)

	// Verify the sticky session persists with the fallback server.
	req3 := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, cookie := range newCookies {
		req3.AddCookie(cookie)
	}
	recorder3 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder3, req3)
	assert.Equal(t, fallbackServer, recorder3.Header().Get("server"))

	// Bring the original server back UP.
	balancer.SetStatus(context.WithValue(t.Context(), serviceName, "test"), firstServer, true)

	// A request with the fallback server's cookie should still stick to the fallback server.
	req4 := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, cookie := range newCookies {
		req4.AddCookie(cookie)
	}
	recorder4 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder4, req4)
	assert.Equal(t, fallbackServer, recorder4.Header().Get("server"))
}

// TestBalancerStickyFenced tests that sticky sessions persist to fenced servers (graceful shutdown).
// Fencing enables zero-downtime deployments: fenced servers reject NEW connections
// but continue serving EXISTING sticky sessions until they complete.
func TestBalancerStickyFenced(t *testing.T) {
	balancer := New(&dynamic.Sticky{
		Cookie: &dynamic.Cookie{
			Name: "test",
		},
	}, false)

	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// Establish sticky session with any server.
	recorder1 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder1, httptest.NewRequest(http.MethodGet, "/", nil))
	stickyServer := recorder1.Header().Get("server")
	assert.NotEmpty(t, stickyServer)

	cookies := recorder1.Result().Cookies()
	assert.NotEmpty(t, cookies)

	// Fence the sticky server (simulate graceful shutdown).
	balancer.handlersMu.Lock()
	balancer.fenced[stickyServer] = struct{}{}
	balancer.handlersMu.Unlock()

	// Existing sticky session should STILL work (graceful draining).
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}
	recorder2 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder2, req)
	assert.Equal(t, stickyServer, recorder2.Header().Get("server"))

	// But NEW requests should NOT go to the fenced server.
	recorder3 := httptest.NewRecorder()
	balancer.ServeHTTP(recorder3, httptest.NewRequest(http.MethodGet, "/", nil))
	newServer := recorder3.Header().Get("server")
	assert.NotEqual(t, stickyServer, newServer)
	assert.NotEmpty(t, newServer)
}

// TestRingBufferBasic tests basic ring buffer functionality with few samples.
func TestRingBufferBasic(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	// Test cold start - no samples.
	avg := handler.getAvgResponseTime()
	assert.InDelta(t, 0.0, avg, 0)

	// Add one sample.
	handler.updateResponseTime(10 * time.Millisecond)
	avg = handler.getAvgResponseTime()
	assert.InDelta(t, 10.0, avg, 0)

	// Add more samples.
	handler.updateResponseTime(20 * time.Millisecond)
	handler.updateResponseTime(30 * time.Millisecond)
	avg = handler.getAvgResponseTime()
	assert.InDelta(t, 20.0, avg, 0) // (10 + 20 + 30) / 3 = 20
}
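
// For orientation, a minimal sketch of the fixed-size ring buffer these tests
// assume (the field names match those poked directly in tests below; the
// cursor variable is hypothetical, since the real updateResponseTime lives in
// the package under test):
//
//	i := h.writeIndex % sampleSize               // hypothetical cursor
//	h.responseTimeSum += ms - h.responseTimes[i] // drop oldest, add newest
//	h.responseTimes[i] = ms
//	h.writeIndex++
//	if h.sampleCount < sampleSize {
//		h.sampleCount++
//	}
//
// getAvgResponseTime would then be responseTimeSum / sampleCount, returning 0
// when sampleCount is 0, which is what the cold-start assertions check.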

// TestRingBufferWraparound tests ring buffer behavior when it wraps around.
func TestRingBufferWraparound(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	// Fill the buffer with 100 samples of 10ms each.
	for range sampleSize {
		handler.updateResponseTime(10 * time.Millisecond)
	}
	avg := handler.getAvgResponseTime()
	assert.InDelta(t, 10.0, avg, 0)

	// Add one more sample (should replace the oldest).
	handler.updateResponseTime(20 * time.Millisecond)
	avg = handler.getAvgResponseTime()
	// Sum: 99*10 + 1*20 = 1010, avg = 1010/100 = 10.1
	assert.InDelta(t, 10.1, avg, 0)

	// Add 10 more samples of 30ms.
	for range 10 {
		handler.updateResponseTime(30 * time.Millisecond)
	}
	avg = handler.getAvgResponseTime()
	// Sum: 89*10 + 1*20 + 10*30 = 890 + 20 + 300 = 1210, avg = 1210/100 = 12.1
	assert.InDelta(t, 12.1, avg, 0)
}

// TestRingBufferLarge tests ring buffer with many samples (> 100).
func TestRingBufferLarge(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	// Add 150 samples.
	for i := range 150 {
		handler.updateResponseTime(time.Duration(i+1) * time.Millisecond)
	}

	// Should only track the last 100 samples: 51, 52, ..., 150.
	// Sum = (51 + 150) * 100 / 2 = 10050
	// Avg = 10050 / 100 = 100.5
	avg := handler.getAvgResponseTime()
	assert.InDelta(t, 100.5, avg, 0)
}

// TestInflightCounter tests inflight request tracking.
func TestInflightCounter(t *testing.T) {
	balancer := New(nil, false)

	var inflightAtRequest atomic.Int64

	balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		inflightAtRequest.Store(balancer.handlers[0].inflightCount.Load())
		rw.Header().Set("server", "test")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// Check that the inflight count is 0 initially.
	balancer.handlersMu.RLock()
	handler := balancer.handlers[0]
	balancer.handlersMu.RUnlock()
	assert.Equal(t, int64(0), handler.inflightCount.Load())

	// Make a request.
	recorder := httptest.NewRecorder()
	balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	// During the request, inflight should have been 1.
	assert.Equal(t, int64(1), inflightAtRequest.Load())

	// After the request completes, inflight should be back to 0.
	assert.Equal(t, int64(0), handler.inflightCount.Load())
}
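
// A minimal sketch of the accounting ServeHTTP is assumed to perform around
// each proxied request, as observed by the test above (a sketch, not the
// actual implementation):
//
//	handler.inflightCount.Add(1)
//	defer handler.inflightCount.Add(-1)
//	handler.ServeHTTP(rw, req)
//
// so the counter reads 1 while the handler runs and returns to 0 afterwards.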

// TestConcurrentResponseTimeUpdates tests thread safety of response time updates.
func TestConcurrentResponseTimeUpdates(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	// Concurrently update response times.
	var wg sync.WaitGroup
	numGoroutines := 10
	updatesPerGoroutine := 20

	for i := range numGoroutines {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for range updatesPerGoroutine {
				handler.updateResponseTime(time.Duration(id+1) * time.Millisecond)
			}
		}(i)
	}

	wg.Wait()

	// Should have exactly 100 samples (the buffer size): 10 goroutines * 20 updates = 200 > sampleSize.
	assert.Equal(t, sampleSize, handler.sampleCount)
}

// TestConcurrentInflightTracking tests thread safety of the inflight counter.
func TestConcurrentInflightTracking(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(10 * time.Millisecond)
			rw.WriteHeader(http.StatusOK)
		}),
		name:   "test",
		weight: 1,
	}

	var maxInflight atomic.Int64

	var wg sync.WaitGroup
	numRequests := 50

	for range numRequests {
		wg.Add(1)
		go func() {
			defer wg.Done()
			handler.inflightCount.Add(1)
			defer handler.inflightCount.Add(-1)

			// Track the maximum inflight count with a CAS loop.
			current := handler.inflightCount.Load()
			for {
				maxLoad := maxInflight.Load()
				if current <= maxLoad || maxInflight.CompareAndSwap(maxLoad, current) {
					break
				}
			}

			time.Sleep(1 * time.Millisecond)
		}()
	}

	wg.Wait()

	// All requests completed, so inflight should be 0.
	assert.Equal(t, int64(0), handler.inflightCount.Load())
	// Max inflight should be > 1 (requests ran concurrently).
	assert.Greater(t, maxInflight.Load(), int64(1))
}

// TestConcurrentRequestsRespectInflight tests that the load balancer dynamically
// adapts to inflight request counts during concurrent request processing.
func TestConcurrentRequestsRespectInflight(t *testing.T) {
	balancer := New(nil, false)

	// Use a channel to control when handlers start sleeping.
	// This ensures we can fill one server with inflight requests before routing new ones.
	blockChan := make(chan struct{})

	// Add two servers with equal response times and weights.
	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		<-blockChan // Wait for signal to proceed.
		time.Sleep(10 * time.Millisecond)
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		<-blockChan // Wait for signal to proceed.
		time.Sleep(10 * time.Millisecond)
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// Pre-warm both servers to establish equal average response times.
	for i := range sampleSize {
		balancer.handlers[0].responseTimes[i] = 10.0
	}
	balancer.handlers[0].responseTimeSum = 10.0 * sampleSize
	balancer.handlers[0].sampleCount = sampleSize

	for i := range sampleSize {
		balancer.handlers[1].responseTimes[i] = 10.0
	}
	balancer.handlers[1].responseTimeSum = 10.0 * sampleSize
	balancer.handlers[1].sampleCount = sampleSize

	// Phase 1: Launch concurrent requests that will block on the channel.
	var wg sync.WaitGroup
	inflightRequests := 5

	for range inflightRequests {
		wg.Add(1)
		go func() {
			defer wg.Done()
			recorder := httptest.NewRecorder()
			balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
		}()
	}

	// Wait for the goroutines to start and increment the inflight counters.
	// They will block on the channel, keeping the inflight count high.
	time.Sleep(50 * time.Millisecond)

	// Verify inflight counts before making new requests.
	server1Inflight := balancer.handlers[0].inflightCount.Load()
	server2Inflight := balancer.handlers[1].inflightCount.Load()
	assert.Equal(t, int64(5), server1Inflight+server2Inflight)

	// Phase 2: Make new requests while the initial requests are blocked.
	// These should see the high inflight counts and route to the less-loaded server.
	var saveMu sync.Mutex
	save := map[string]int{}
	newRequests := 50

	// Launch new requests in the background so they don't block.
	var newWg sync.WaitGroup
	for range newRequests {
		newWg.Add(1)
		go func() {
			defer newWg.Done()
			rec := httptest.NewRecorder()
			balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
			server := rec.Header().Get("server")
			if server != "" {
				saveMu.Lock()
				save[server]++
				saveMu.Unlock()
			}
		}()
	}

	// Wait for the new requests to start and observe the inflight counts.
	time.Sleep(50 * time.Millisecond)

	close(blockChan)

	wg.Wait()
	newWg.Wait()

	saveMu.Lock()
	total := save["server1"] + save["server2"]
	server1Count := save["server1"]
	server2Count := save["server2"]
	saveMu.Unlock()

	assert.Equal(t, newRequests, total)

	// With inflight tracking, load should naturally balance toward equal distribution.
	// We allow variance due to concurrent execution and race windows in server selection.
	assert.InDelta(t, 25.0, float64(server1Count), 5.0) // 20-30 requests
	assert.InDelta(t, 25.0, float64(server2Count), 5.0) // 20-30 requests
}
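
// The timing tests below call GotFirstResponseByte on the
// httptrace.ClientTrace carried in the request context. The working
// assumption is that the balancer installs such a trace on each request and
// records time-to-first-byte in that callback as the response-time sample;
// the test handlers invoke the callback manually to simulate the first
// response byte arriving.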

// TestTTFBMeasurement tests TTFB measurement accuracy.
func TestTTFBMeasurement(t *testing.T) {
	balancer := New(nil, false)

	// Add a server with a known delay.
	delay := 50 * time.Millisecond
	balancer.Add("slow", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(delay)
		rw.Header().Set("server", "slow")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// Make multiple requests to build up an average.
	for range 5 {
		recorder := httptest.NewRecorder()
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Check that the average response time is approximately the delay.
	avg := balancer.handlers[0].getAvgResponseTime()

	// Allow 5ms tolerance for Go timing jitter and test environment variations.
	assert.InDelta(t, float64(delay.Milliseconds()), avg, 5.0)
}

// TestZeroSamplesReturnsZero tests that getAvgResponseTime returns 0 when there are no samples.
func TestZeroSamplesReturnsZero(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	avg := handler.getAvgResponseTime()
	assert.InDelta(t, 0.0, avg, 0)
}

// TestScoreCalculationWithWeights tests that weights are properly considered in score calculation.
func TestScoreCalculationWithWeights(t *testing.T) {
	balancer := New(nil, false)

	// Add two servers with the same response time but different weights.
	// The server with the higher weight should be preferred.
	balancer.Add("weighted", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(50 * time.Millisecond)
		rw.Header().Set("server", "weighted")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(3), false) // Weight 3

	balancer.Add("normal", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(50 * time.Millisecond)
		rw.Header().Set("server", "normal")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false) // Weight 1

	// Make requests to build up response time averages.
	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 2 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Score for weighted: (50 × (1 + 0)) / 3 ≈ 16.67
	// Score for normal:   (50 × (1 + 0)) / 1 = 50
	// After warmup, the weighted server has a 3x better score (16.67 vs 50) and should receive nearly all requests.
	recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 10 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	assert.Equal(t, 10, recorder.save["weighted"])
	assert.Zero(t, recorder.save["normal"])
}

// TestScoreCalculationWithInflight tests that inflight requests are considered in score calculation.
func TestScoreCalculationWithInflight(t *testing.T) {
	balancer := New(nil, false)

	// We'll manually control the inflight counters to test the score calculation.
	// Add two servers with the same response time.
	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(10 * time.Millisecond)
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(10 * time.Millisecond)
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// Build up response time averages for both servers.
	for range 2 {
		recorder := httptest.NewRecorder()
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Now manually set server1 to have a high inflight count.
	balancer.handlers[0].inflightCount.Store(5)

	// Select servers directly - they should prefer server2 because:
	// Score for server1: (10 × (1 + 5)) / 1 = 60
	// Score for server2: (10 × (1 + 0)) / 1 = 10
	recorder := &responseRecorder{save: map[string]int{}}
	for range 5 {
		// Manually increment to simulate the ServeHTTP behavior.
		server, _ := balancer.nextServer()
		server.inflightCount.Add(1)

		if server.name == "server1" {
			recorder.save["server1"]++
		} else {
			recorder.save["server2"]++
		}
	}

	// Server2 should get all requests.
	assert.Equal(t, 5, recorder.save["server2"])
	assert.Zero(t, recorder.save["server1"])
}

// TestScoreCalculationColdStart tests that new servers (0ms average) get fair selection.
func TestScoreCalculationColdStart(t *testing.T) {
	balancer := New(nil, false)

	// Add a warm server with an established response time.
	balancer.Add("warm", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(50 * time.Millisecond)
		rw.Header().Set("server", "warm")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// Warm up the first server.
	for range 10 {
		recorder := httptest.NewRecorder()
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Now add a cold server (new, no response time data).
	balancer.Add("cold", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(10 * time.Millisecond) // Actually faster.
		rw.Header().Set("server", "cold")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// The cold server should get selected because:
	// Score for warm: (50 × (1 + 0)) / 1 = 50
	// Score for cold: (0 × (1 + 0)) / 1 = 0
	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 20 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// The cold server should get all or most requests initially due to its 0ms average.
	assert.Greater(t, recorder.save["cold"], 10)

	// After the cold server builds up its average, it should continue to get more traffic
	// because it is actually faster (10ms vs 50ms).
	recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 20 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}
	assert.Greater(t, recorder.save["cold"], recorder.save["warm"])
}

// TestFastServerGetsMoreTraffic verifies that servers with lower response times
// receive proportionally more traffic in steady state (after cold start).
// This tests the core selection bias of the least-time algorithm.
func TestFastServerGetsMoreTraffic(t *testing.T) {
	balancer := New(nil, false)

	// Add two servers with different static response times.
	balancer.Add("fast", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(20 * time.Millisecond)
		rw.Header().Set("server", "fast")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	balancer.Add("slow", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(100 * time.Millisecond)
		rw.Header().Set("server", "slow")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// After just one request to each server, the algorithm identifies the fastest server
	// and routes nearly all subsequent traffic there (it converges in ~2 requests).
	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 50 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	assert.Greater(t, recorder.save["fast"], recorder.save["slow"])
	assert.Greater(t, recorder.save["fast"], 48) // Expect ~96-98% to the fast server (48-49/50).
}

// TestTrafficShiftsWhenPerformanceDegrades verifies that the load balancer
// adapts to changing server performance by shifting traffic away from degraded servers.
// This tests the adaptive behavior - the core value proposition of least-time load balancing.
func TestTrafficShiftsWhenPerformanceDegrades(t *testing.T) {
	balancer := New(nil, false)

	// Use an atomic to dynamically control server1's response time.
	var server1Delay atomic.Int64
	server1Delay.Store(5) // Start with 5ms.

	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(time.Duration(server1Delay.Load()) * time.Millisecond)
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond) // Static 5ms.
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// Pre-fill the ring buffers to eliminate cold start effects and ensure a deterministic, equal performance state.
	for _, h := range balancer.handlers {
		for i := range sampleSize {
			h.responseTimes[i] = 5.0
		}
		h.responseTimeSum = 5.0 * sampleSize
		h.sampleCount = sampleSize
	}

	// Phase 1: Both servers perform equally (5ms each).
	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 50 {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// With equal performance and pre-filled buffers, distribution should be balanced via WRR tie-breaking.
	total := recorder.save["server1"] + recorder.save["server2"]
	assert.Equal(t, 50, total)
	assert.InDelta(t, 25, recorder.save["server1"], 10) // 25 ± 10 requests
	assert.InDelta(t, 25, recorder.save["server2"], 10) // 25 ± 10 requests

	// Phase 2: server1 degrades (simulating a GC pause, CPU spike, or network latency).
	server1Delay.Store(15) // Now 15ms (3x slower).

	// Make more requests to shift the moving average.
	// The ring buffer holds 100 samples, so a significant number of new samples is needed to move it.
	// server1's average will climb from ~5ms toward 15ms.
	recorder2 := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range 60 {
		balancer.ServeHTTP(recorder2, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// server2 should get significantly more traffic (>75%):
	// Score for server1: (~10-15ms × 1) / 1 = 10-15 (as the average climbs)
	// Score for server2: (5ms × 1) / 1 = 5
	total2 := recorder2.save["server1"] + recorder2.save["server2"]
	assert.Equal(t, 60, total2)
	assert.Greater(t, recorder2.save["server2"], 45) // At least 75% (45/60).
	assert.Less(t, recorder2.save["server1"], 15)    // At most 25% (15/60).
}

// TestMultipleServersWithSameScore tests WRR tie-breaking when multiple servers have identical scores.
// Uses nextServer() directly to avoid timing variations in the test.
func TestMultipleServersWithSameScore(t *testing.T) {
	balancer := New(nil, false)

	// Add three servers with identical response times and weights.
	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server3", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "server3")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// Set all servers to identical response times to trigger tie-breaking.
	for _, h := range balancer.handlers {
		for i := range sampleSize {
			h.responseTimes[i] = 5.0
		}
		h.responseTimeSum = 5.0 * sampleSize
		h.sampleCount = sampleSize
	}

	// With all servers having identical scores, WRR tie-breaking should distribute traffic fairly.
	// Test the selection logic directly, without actual HTTP requests, to avoid timing variations.
	counts := map[string]int{"server1": 0, "server2": 0, "server3": 0}
	for range 90 {
		server, err := balancer.nextServer()
		assert.NoError(t, err)
		counts[server.name]++
	}

	total := counts["server1"] + counts["server2"] + counts["server3"]
	assert.Equal(t, 90, total)

	// With WRR and 90 selections, each server should get ~30 (±1 due to initialization).
	assert.InDelta(t, 30, counts["server1"], 1)
	assert.InDelta(t, 30, counts["server2"], 1)
	assert.InDelta(t, 30, counts["server3"], 1)
}

// TestWRRTieBreakingWeightedDistribution tests weighted distribution among tied servers.
// Uses nextServer() directly to avoid timing variations in the test.
func TestWRRTieBreakingWeightedDistribution(t *testing.T) {
	balancer := New(nil, false)

	// Add two servers with different weights.
	// To create equal scores, response times must be proportional to weights.
	balancer.Add("weighted", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(15 * time.Millisecond) // 3x longer due to 3x weight.
		rw.Header().Set("server", "weighted")
		rw.WriteHeader(http.StatusOK)
	}), pointer(3), false) // Weight 3

	balancer.Add("normal", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "normal")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false) // Weight 1

	// Since response times are proportional to weights, both scores are equal, so WRR tie-breaking applies:
	// weighted: score = (15 * 1) / 3 = 5
	// normal:   score = (5 * 1) / 1 = 5
	for i := range sampleSize {
		balancer.handlers[0].responseTimes[i] = 15.0
	}
	balancer.handlers[0].responseTimeSum = 15.0 * sampleSize
	balancer.handlers[0].sampleCount = sampleSize

	for i := range sampleSize {
		balancer.handlers[1].responseTimes[i] = 5.0
	}
	balancer.handlers[1].responseTimeSum = 5.0 * sampleSize
	balancer.handlers[1].sampleCount = sampleSize

	// Test the selection logic directly, without actual HTTP requests, to avoid timing variations.
	counts := map[string]int{"weighted": 0, "normal": 0}
	for range 80 {
		server, err := balancer.nextServer()
		assert.NoError(t, err)
		counts[server.name]++
	}

	total := counts["weighted"] + counts["normal"]
	assert.Equal(t, 80, total)

	// With a 3:1 weight ratio, the distribution should be ~75%/25% (60/80 and 20/80), ±1 due to initialization.
	assert.InDelta(t, 60, counts["weighted"], 1)
	assert.InDelta(t, 20, counts["normal"], 1)
}