2025-10-23 16:16:05 +02:00

1094 lines
38 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package leasttime
import (
"context"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
)
// key is the context-key type used to tag test contexts with a service name.
type key string

// serviceName is the context key under which tests store the parent service name.
const serviceName key = "serviceName"

// pointer returns the address of a fresh copy of v (used for optional weight arguments).
func pointer[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// responseRecorder wraps httptest.ResponseRecorder and counts, per value of the
// "server" response header, how many times WriteHeader was called.
type responseRecorder struct {
	*httptest.ResponseRecorder
	save map[string]int
}

// WriteHeader tallies the "server" header set by the handler (when non-empty)
// and then delegates to the embedded recorder. Every call is counted, even if the
// embedded recorder ignores repeated WriteHeader calls.
func (r *responseRecorder) WriteHeader(statusCode int) {
	if name := r.Header().Get("server"); name != "" {
		r.save[name]++
	}
	r.ResponseRecorder.WriteHeader(statusCode)
}
// TestBalancer checks that two equally fast, equally weighted servers both
// receive traffic and that every request is accounted for.
func TestBalancer(t *testing.T) {
	balancer := New(nil, false)

	for _, name := range []string{"first", "second"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(5 * time.Millisecond)
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), false)
	}

	const requests = 10
	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for i := 0; i < requests; i++ {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Equal latencies: least-time should spread traffic over both servers.
	first, second := recorder.save["first"], recorder.save["second"]
	assert.Positive(t, first)
	assert.Positive(t, second)
	assert.Equal(t, requests, first+second)
}
// TestBalancerNoService checks that a balancer without any server answers 503.
func TestBalancerNoService(t *testing.T) {
	rec := httptest.NewRecorder()
	New(nil, false).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, http.StatusServiceUnavailable, rec.Result().StatusCode)
}
// TestBalancerNoServiceUp checks that the balancer answers 503 once all of its
// servers have been marked down.
func TestBalancerNoServiceUp(t *testing.T) {
	balancer := New(nil, false)

	names := []string{"first", "second"}
	for _, name := range names {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			rw.WriteHeader(http.StatusInternalServerError)
		}), pointer(1), false)
	}
	for _, name := range names {
		balancer.SetStatus(context.WithValue(t.Context(), serviceName, "parent"), name, false)
	}

	rec := httptest.NewRecorder()
	balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	assert.Equal(t, http.StatusServiceUnavailable, rec.Result().StatusCode)
}
// TestBalancerOneServerDown checks that a server marked down never receives traffic.
func TestBalancerOneServerDown(t *testing.T) {
	balancer := New(nil, false)
	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)
	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusInternalServerError)
	}), pointer(1), false)

	ctx := context.WithValue(t.Context(), serviceName, "parent")
	balancer.SetStatus(ctx, "second", false)

	const requests = 3
	hits := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for i := 0; i < requests; i++ {
		balancer.ServeHTTP(hits, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Every request must land on the only healthy server.
	assert.Equal(t, requests, hits.save["first"])
	assert.Equal(t, 0, hits.save["second"])
}
// TestBalancerOneServerDownThenUp checks that a server marked down is skipped and
// that it receives traffic again once it is marked back up.
func TestBalancerOneServerDownThenUp(t *testing.T) {
	balancer := New(nil, false)
	for _, name := range []string{"first", "second"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(5 * time.Millisecond)
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), false)
	}

	ctx := context.WithValue(t.Context(), serviceName, "parent")

	// serve sends n requests through a fresh counting recorder and returns the per-server tally.
	serve := func(n int) map[string]int {
		rec := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
		for range n {
			balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
		}
		return rec.save
	}

	// While "second" is down, only "first" is eligible.
	balancer.SetStatus(ctx, "second", false)
	hits := serve(3)
	assert.Equal(t, 3, hits["first"])
	assert.Equal(t, 0, hits["second"])

	// Once "second" is back up, both servers share the load again.
	balancer.SetStatus(ctx, "second", true)
	hits = serve(20)
	assert.Positive(t, hits["first"])
	assert.Positive(t, hits["second"])
	assert.Equal(t, 20, hits["first"]+hits["second"])
}
// TestBalancerAllServersZeroWeight checks that a balancer whose servers all have
// weight 0 answers 503, as if it had no server at all.
func TestBalancerAllServersZeroWeight(t *testing.T) {
	balancer := New(nil, false)
	for _, name := range []string{"first", "second"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(0), false)
	}

	rec := httptest.NewRecorder()
	balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	assert.Equal(t, http.StatusServiceUnavailable, rec.Result().StatusCode)
}
// TestBalancerOneServerZeroWeight checks that a weight-0 server is never selected
// when another server has a positive weight.
func TestBalancerOneServerZeroWeight(t *testing.T) {
	balancer := New(nil, false)
	balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("server", "first")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)
	balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(0), false)

	hits := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for i := 0; i < 3; i++ {
		balancer.ServeHTTP(hits, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// All traffic goes to the only server with a non-zero weight.
	assert.Equal(t, 3, hits.save["first"])
	assert.Equal(t, 0, hits.save["second"])
}
// TestBalancerPropagate tests status propagation to parent balancers.
func TestBalancerPropagate(t *testing.T) {
balancer1 := New(nil, true)
balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), pointer(1), false)
balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "second")
rw.WriteHeader(http.StatusOK)
}), pointer(1), false)
balancer2 := New(nil, true)
balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "third")
rw.WriteHeader(http.StatusOK)
}), pointer(1), false)
balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "fourth")
rw.WriteHeader(http.StatusOK)
}), pointer(1), false)
topBalancer := New(nil, true)
topBalancer.Add("balancer1", balancer1, pointer(1), false)
topBalancer.Add("balancer2", balancer2, pointer(1), false)
err := balancer1.RegisterStatusUpdater(func(up bool) {
topBalancer.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "balancer1", up)
})
assert.NoError(t, err)
err = balancer2.RegisterStatusUpdater(func(up bool) {
topBalancer.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "balancer2", up)
})
assert.NoError(t, err)
// Set all children of balancer1 to down, should propagate to top.
balancer1.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "first", false)
balancer1.SetStatus(context.WithValue(t.Context(), serviceName, "top"), "second", false)
recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
for range 4 {
topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
}
// Only balancer2 should receive traffic.
assert.Equal(t, 0, recorder.save["first"])
assert.Equal(t, 0, recorder.save["second"])
assert.Equal(t, 4, recorder.save["third"]+recorder.save["fourth"])
}
// TestBalancerOneServerFenced checks that a server added as fenced receives no
// traffic while an unfenced server is available.
func TestBalancerOneServerFenced(t *testing.T) {
	balancer := New(nil, false)

	servers := []struct {
		name   string
		fenced bool
	}{
		{name: "first", fenced: false},
		{name: "second", fenced: true},
	}
	for _, s := range servers {
		balancer.Add(s.name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			rw.Header().Set("server", s.name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), s.fenced)
	}

	hits := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for i := 0; i < 3; i++ {
		balancer.ServeHTTP(hits, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// The fenced server must be skipped entirely.
	assert.Equal(t, 3, hits.save["first"])
	assert.Equal(t, 0, hits.save["second"])
}
// TestBalancerAllFencedServers checks that a balancer whose servers are all fenced answers 503.
func TestBalancerAllFencedServers(t *testing.T) {
	balancer := New(nil, false)
	for _, name := range []string{"first", "second"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), pointer(1), true)
	}

	rec := httptest.NewRecorder()
	balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	assert.Equal(t, http.StatusServiceUnavailable, rec.Result().StatusCode)
}
// TestBalancerRegisterStatusUpdaterWithoutHealthCheck tests that registering a
// status updater on a balancer created without health checking returns an error.
func TestBalancerRegisterStatusUpdaterWithoutHealthCheck(t *testing.T) {
	balancer := New(nil, false)

	err := balancer.RegisterStatusUpdater(func(up bool) {})

	// ErrorContains reports a nil error as a normal test failure. Calling
	// err.Error() after a non-fatal assert.Error would instead panic on nil and
	// abort the whole test binary.
	assert.ErrorContains(t, err, "healthCheck not enabled")
}
// TestBalancerSticky checks that a request replaying the sticky cookie set by a
// previous response is routed to the same server.
func TestBalancerSticky(t *testing.T) {
	balancer := New(&dynamic.Sticky{
		Cookie: &dynamic.Cookie{Name: "test"},
	}, false)
	for _, name := range []string{"first", "second"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), false)
	}

	// The first request picks a server and should set the sticky cookie.
	initial := httptest.NewRecorder()
	balancer.ServeHTTP(initial, httptest.NewRequest(http.MethodGet, "/", nil))
	firstServer := initial.Header().Get("server")
	assert.NotEmpty(t, firstServer)

	cookies := initial.Result().Cookies()
	assert.NotEmpty(t, cookies)

	// Replaying the cookie must land on the same server.
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, c := range cookies {
		req.AddCookie(c)
	}
	followUp := httptest.NewRecorder()
	balancer.ServeHTTP(followUp, req)
	assert.Equal(t, firstServer, followUp.Header().Get("server"))
}
// TestBalancerStickyFallback checks that a sticky request whose server is down is
// served by another server, and that the fallback response carries a new cookie
// pinning the client to that server. It also checks that this pin survives the
// original server coming back up.
func TestBalancerStickyFallback(t *testing.T) {
	balancer := New(&dynamic.Sticky{
		Cookie: &dynamic.Cookie{Name: "test"},
	}, false)
	for _, name := range []string{"server1", "server2"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(50 * time.Millisecond)
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), false)
	}

	ctx := context.WithValue(t.Context(), serviceName, "test")

	// send issues a GET carrying the given cookies and returns the recorded response.
	send := func(cookies []*http.Cookie) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		for _, c := range cookies {
			req.AddCookie(c)
		}
		rec := httptest.NewRecorder()
		balancer.ServeHTTP(rec, req)
		return rec
	}

	// Establish a sticky session.
	rec := send(nil)
	firstServer := rec.Header().Get("server")
	assert.NotEmpty(t, firstServer)
	cookies := rec.Result().Cookies()
	assert.NotEmpty(t, cookies)

	// Take the sticky server DOWN: the old cookie must fall back to the other server.
	balancer.SetStatus(ctx, firstServer, false)
	rec = send(cookies)
	fallbackServer := rec.Header().Get("server")
	assert.NotEqual(t, firstServer, fallbackServer)
	assert.NotEmpty(t, fallbackServer)

	// The fallback response should carry a new sticky cookie...
	newCookies := rec.Result().Cookies()
	assert.NotEmpty(t, newCookies)

	// ...which keeps routing to the fallback server.
	assert.Equal(t, fallbackServer, send(newCookies).Header().Get("server"))

	// Bringing the original server back UP must not break the new pin.
	balancer.SetStatus(ctx, firstServer, true)
	assert.Equal(t, fallbackServer, send(newCookies).Header().Get("server"))
}
// TestBalancerStickyFenced checks graceful draining. Once a server is fenced,
// clients already pinned to it by a sticky cookie keep reaching it, while new
// cookie-less requests are routed elsewhere. This lets a fenced server finish its
// existing sessions during a zero-downtime rollout.
func TestBalancerStickyFenced(t *testing.T) {
	balancer := New(&dynamic.Sticky{
		Cookie: &dynamic.Cookie{Name: "test"},
	}, false)
	for _, name := range []string{"server1", "server2"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), false)
	}

	// send issues a GET carrying the given cookies and returns the recorded response.
	send := func(cookies []*http.Cookie) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		for _, c := range cookies {
			req.AddCookie(c)
		}
		rec := httptest.NewRecorder()
		balancer.ServeHTTP(rec, req)
		return rec
	}

	// Pin a session to whichever server is picked first.
	rec := send(nil)
	stickyServer := rec.Header().Get("server")
	assert.NotEmpty(t, stickyServer)
	cookies := rec.Result().Cookies()
	assert.NotEmpty(t, cookies)

	// Fence that server, as a graceful shutdown would.
	balancer.handlersMu.Lock()
	balancer.fenced[stickyServer] = struct{}{}
	balancer.handlersMu.Unlock()

	// The existing sticky session keeps working...
	assert.Equal(t, stickyServer, send(cookies).Header().Get("server"))

	// ...but a fresh request must avoid the fenced server.
	newServer := send(nil).Header().Get("server")
	assert.NotEqual(t, stickyServer, newServer)
	assert.NotEmpty(t, newServer)
}
// TestRingBufferBasic checks the running average while the ring buffer is only
// partially filled.
func TestRingBufferBasic(t *testing.T) {
	h := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	// Cold start: with no samples the average is 0.
	assert.InDelta(t, 0.0, h.getAvgResponseTime(), 0)

	// One sample: the average is that sample.
	h.updateResponseTime(10 * time.Millisecond)
	assert.InDelta(t, 10.0, h.getAvgResponseTime(), 0)

	// Three samples: (10 + 20 + 30) / 3 = 20.
	for _, d := range []time.Duration{20 * time.Millisecond, 30 * time.Millisecond} {
		h.updateResponseTime(d)
	}
	assert.InDelta(t, 20.0, h.getAvgResponseTime(), 0)
}
// TestRingBufferWraparound tests ring buffer behavior when it wraps around: once
// sampleSize samples are stored, each new sample replaces the oldest one.
//
// Expected averages are derived from sampleSize rather than hard-coded for a
// 100-slot buffer. They are compared with a small tolerance because values such
// as 10.1 are not exactly representable in binary floating point.
func TestRingBufferWraparound(t *testing.T) {
	const epsilon = 1e-9

	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}
	n := float64(sampleSize)

	// Fill the buffer with 10ms samples.
	for range sampleSize {
		handler.updateResponseTime(10 * time.Millisecond)
	}
	assert.InDelta(t, 10.0, handler.getAvgResponseTime(), epsilon)

	// One 20ms sample replaces the oldest 10ms one.
	// For sampleSize = 100: (99*10 + 20) / 100 = 10.1.
	handler.updateResponseTime(20 * time.Millisecond)
	assert.InDelta(t, (10*(n-1)+20)/n, handler.getAvgResponseTime(), epsilon)

	// Ten 30ms samples replace ten more 10ms ones (requires sampleSize >= 11).
	// For sampleSize = 100: (89*10 + 20 + 10*30) / 100 = 12.1.
	for range 10 {
		handler.updateResponseTime(30 * time.Millisecond)
	}
	assert.InDelta(t, (10*(n-11)+20+10*30)/n, handler.getAvgResponseTime(), epsilon)
}
// TestRingBufferLarge checks that after more samples than the buffer holds, the
// average covers only the most recent ones.
func TestRingBufferLarge(t *testing.T) {
	h := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	// Feed 1ms, 2ms, ..., 150ms.
	for ms := 1; ms <= 150; ms++ {
		h.updateResponseTime(time.Duration(ms) * time.Millisecond)
	}

	// Only the last 100 samples (51..150 ms) remain:
	// sum = (51 + 150) * 100 / 2 = 10050, average = 100.5.
	assert.InDelta(t, 100.5, h.getAvgResponseTime(), 0)
}
// TestInflightCounter checks that a server's inflight counter reads 1 while its
// handler is running and drops back to 0 after the request completes.
func TestInflightCounter(t *testing.T) {
	balancer := New(nil, false)

	// seen holds the counter value observed from inside the handler.
	var seen atomic.Int64
	balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		seen.Store(balancer.handlers[0].inflightCount.Load())
		rw.Header().Set("server", "test")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.handlersMu.RLock()
	h := balancer.handlers[0]
	balancer.handlersMu.RUnlock()

	// Idle before any request.
	assert.Equal(t, int64(0), h.inflightCount.Load())

	balancer.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))

	// Exactly one request was inflight while the handler ran; none remain afterwards.
	assert.Equal(t, int64(1), seen.Load())
	assert.Equal(t, int64(0), h.inflightCount.Load())
}
// TestConcurrentResponseTimeUpdates checks that concurrent calls to
// updateResponseTime leave the buffer in a consistent, fully-populated state.
func TestConcurrentResponseTimeUpdates(t *testing.T) {
	h := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	const (
		workers          = 10
		updatesPerWorker = 20
	)

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func(id int) {
			defer wg.Done()
			sample := time.Duration(id+1) * time.Millisecond
			for i := 0; i < updatesPerWorker; i++ {
				h.updateResponseTime(sample)
			}
		}(w)
	}
	wg.Wait()

	// 200 updates overflow the buffer, so it must report exactly sampleSize samples.
	assert.Equal(t, sampleSize, h.sampleCount)
}
// TestConcurrentInflightTracking tests thread safety of the inflight counter.
// Concurrent increments and decrements must not be lost, and the counter must
// observe genuine concurrency.
//
// Every goroutine is held at a barrier while "inflight", so the peak count is
// deterministically numRequests. A max-tracking loop around a short sleep would
// instead depend on scheduler timing.
func TestConcurrentInflightTracking(t *testing.T) {
	handler := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	const numRequests = 50

	var (
		done    sync.WaitGroup // every goroutine has finished
		arrived sync.WaitGroup // every goroutine has incremented the counter
	)
	release := make(chan struct{})

	done.Add(numRequests)
	arrived.Add(numRequests)
	for range numRequests {
		go func() {
			defer done.Done()
			handler.inflightCount.Add(1)
			defer handler.inflightCount.Add(-1)
			arrived.Done()
			<-release
		}()
	}

	// All goroutines are now inflight at the same time.
	arrived.Wait()
	assert.Equal(t, int64(numRequests), handler.inflightCount.Load())

	close(release)
	done.Wait()

	// All requests completed, inflight should be back to 0.
	assert.Equal(t, int64(0), handler.inflightCount.Load())
}
// TestConcurrentRequestsRespectInflight tests that the load balancer dynamically
// adapts to inflight request counts during concurrent request processing.
//
// Both servers are pre-warmed to the same average response time, so their
// inflight counts are the only signal that distinguishes them. The test does not
// rely on fixed sleeps, which are flaky on slow machines. Instead it waits
// (bounded) until the expected number of requests is parked inside the handlers.
func TestConcurrentRequestsRespectInflight(t *testing.T) {
	balancer := New(nil, false)

	// Handlers block on this channel, so requests stay inflight until it is closed.
	blockChan := make(chan struct{})

	// Add two servers with equal response times and weights.
	for _, name := range []string{"server1", "server2"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			<-blockChan // Wait for signal to proceed.
			time.Sleep(10 * time.Millisecond)
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
		}), pointer(1), false)
	}

	// Pre-warm both servers to the same 10ms average. No request is running yet,
	// so these fields can be written without synchronization.
	for _, h := range balancer.handlers {
		for i := range sampleSize {
			h.responseTimes[i] = 10.0
		}
		h.responseTimeSum = 10.0 * sampleSize
		h.sampleCount = sampleSize
	}

	h0, h1 := balancer.handlers[0], balancer.handlers[1]
	// waitInflight blocks (bounded) until the total inflight count equals want.
	waitInflight := func(want int64) {
		t.Helper()
		assert.Eventually(t, func() bool {
			return h0.inflightCount.Load()+h1.inflightCount.Load() == want
		}, 5*time.Second, time.Millisecond, "expected %d inflight requests", want)
	}

	// Phase 1: park inflightRequests requests inside the handlers.
	const inflightRequests = 5
	var wg sync.WaitGroup
	for range inflightRequests {
		wg.Add(1)
		go func() {
			defer wg.Done()
			balancer.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
		}()
	}
	waitInflight(inflightRequests)

	// Phase 2: while those are blocked, send new requests. They observe the
	// inflight counts and should be routed toward the less-loaded server.
	const newRequests = 50
	var (
		saveMu sync.Mutex
		save   = map[string]int{}
		newWg  sync.WaitGroup
	)
	for range newRequests {
		newWg.Add(1)
		go func() {
			defer newWg.Done()
			rec := httptest.NewRecorder()
			balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
			if server := rec.Header().Get("server"); server != "" {
				saveMu.Lock()
				save[server]++
				saveMu.Unlock()
			}
		}()
	}

	// Wait until every new request has been routed (and is blocked), then release all.
	// If the wait times out, closing the channel still lets every goroutine finish.
	waitInflight(inflightRequests + newRequests)
	close(blockChan)
	wg.Wait()
	newWg.Wait()

	// All goroutines have finished, so save is no longer written concurrently.
	server1Count, server2Count := save["server1"], save["server2"]
	assert.Equal(t, newRequests, server1Count+server2Count)

	// With inflight tracking, load should balance toward an equal split. The delta
	// allows for the race window between server selection and the inflight increment.
	assert.InDelta(t, 25.0, float64(server1Count), 5.0) // 20-30 requests
	assert.InDelta(t, 25.0, float64(server2Count), 5.0) // 20-30 requests
}
// TestTTFBMeasurement checks that the recorded average response time tracks the
// delay a handler spends before it signals its first response byte.
func TestTTFBMeasurement(t *testing.T) {
	const delay = 50 * time.Millisecond

	balancer := New(nil, false)
	balancer.Add("slow", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(delay)
		rw.Header().Set("server", "slow")
		rw.WriteHeader(http.StatusOK)
		// Fire the client-trace hook that the balancer presumably installs on the request context.
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// Five samples to build up the average.
	for i := 0; i < 5; i++ {
		balancer.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// 5ms of slack absorbs timer and scheduler jitter.
	assert.InDelta(t, float64(delay.Milliseconds()), balancer.handlers[0].getAvgResponseTime(), 5.0)
}
// TestZeroSamplesReturnsZero checks that a handler with no recorded samples
// reports an average response time of 0.
func TestZeroSamplesReturnsZero(t *testing.T) {
	fresh := &namedHandler{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}),
		name:    "test",
		weight:  1,
	}

	assert.InDelta(t, 0.0, fresh.getAvgResponseTime(), 0)
}
// TestScoreCalculationWithWeights checks that weights enter the score. With
// equal latencies, the server with the higher weight should win every selection
// once averages exist.
func TestScoreCalculationWithWeights(t *testing.T) {
	balancer := New(nil, false)

	// Same 50ms latency, different weights; "weighted" is added first.
	servers := []struct {
		name   string
		weight int
	}{
		{name: "weighted", weight: 3},
		{name: "normal", weight: 1},
	}
	for _, s := range servers {
		balancer.Add(s.name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(50 * time.Millisecond)
			rw.Header().Set("server", s.name)
			rw.WriteHeader(http.StatusOK)
			httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
		}), pointer(s.weight), false)
	}

	// Warm-up: build response time averages.
	warmup := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for i := 0; i < 2; i++ {
		balancer.ServeHTTP(warmup, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// score(weighted) = 50 * (1 + 0) / 3 ~= 16.67
	// score(normal)   = 50 * (1 + 0) / 1  = 50
	// so the weighted server should receive every request after warm-up.
	hits := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for i := 0; i < 10; i++ {
		balancer.ServeHTTP(hits, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	assert.Equal(t, 10, hits.save["weighted"])
	assert.Zero(t, hits.save["normal"])
}
// TestScoreCalculationWithInflight tests that inflight requests are considered in
// score calculation. Servers are picked directly via nextServer so the inflight
// counters can be controlled by hand.
func TestScoreCalculationWithInflight(t *testing.T) {
	balancer := New(nil, false)

	// Two servers with the same response time.
	for _, name := range []string{"server1", "server2"} {
		balancer.Add(name, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(10 * time.Millisecond)
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
			httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
		}), pointer(1), false)
	}

	// Build up response time averages for both servers.
	for range 2 {
		balancer.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Give server1 a high inflight count.
	balancer.handlers[0].inflightCount.Store(5)

	// Selections should now prefer server2:
	//   score(server1) = 10 * (1 + 5) / 1 = 60
	//   score(server2) = 10 * (1 + 0) / 1 = 10
	// Each pick increments the chosen server's inflight count (as ServeHTTP would)
	// and never decrements it. Even on the fifth pick, server2 scores at most
	// 10 * (1 + 4) = 50 < 60.
	picks := map[string]int{}
	for range 5 {
		server, err := balancer.nextServer()
		// Stop on error: the returned server is presumably nil then, and using it would panic.
		if !assert.NoError(t, err) {
			return
		}
		server.inflightCount.Add(1)
		picks[server.name]++
	}

	// Server2 should get all requests.
	assert.Equal(t, 5, picks["server2"])
	assert.Zero(t, picks["server1"])
}
// TestScoreCalculationColdStart checks two things about a newly added server with
// no samples (0ms average). It should be preferred at first, and it should keep
// receiving more traffic once its measured average shows it is genuinely faster.
func TestScoreCalculationColdStart(t *testing.T) {
	balancer := New(nil, false)

	// handlerWithDelay returns a handler that waits d, tags its response with name
	// and signals the first response byte.
	handlerWithDelay := func(name string, d time.Duration) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			time.Sleep(d)
			rw.Header().Set("server", name)
			rw.WriteHeader(http.StatusOK)
			httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
		})
	}
	// serve sends n requests through a fresh counting recorder and returns the per-server tally.
	serve := func(n int) map[string]int {
		rec := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
		for range n {
			balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
		}
		return rec.save
	}

	// Warm up a 50ms server on its own.
	balancer.Add("warm", handlerWithDelay("warm", 50*time.Millisecond), pointer(1), false)
	for range 10 {
		balancer.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// Add a cold server with no samples; it is actually faster (10ms).
	balancer.Add("cold", handlerWithDelay("cold", 10*time.Millisecond), pointer(1), false)

	// score(warm) = 50 * (1 + 0) / 1 = 50
	// score(cold) =  0 * (1 + 0) / 1 = 0
	// so the cold server should take most of the initial requests.
	initial := serve(20)
	assert.Greater(t, initial["cold"], 10)

	// With real samples (10ms vs 50ms), cold should still get more traffic than warm.
	later := serve(20)
	assert.Greater(t, later["cold"], later["warm"])
}
// TestFastServerGetsMoreTraffic verifies that servers with lower response times
// receive proportionally more traffic in steady state (after cold start).
// This tests the core selection bias of the least-time algorithm.
func TestFastServerGetsMoreTraffic(t *testing.T) {
	balancer := New(nil, false)

	// Add two servers with different static response times.
	balancer.Add("fast", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(20 * time.Millisecond)
		rw.Header().Set("server", "fast")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)
	balancer.Add("slow", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(100 * time.Millisecond)
		rw.Header().Set("server", "slow")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	const requests = 50
	recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
	for range requests {
		balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
	}

	// During cold start each server is tried once. From then on, the fast
	// server's lower average is expected to win every selection, so "slow"
	// should be hit at most once, i.e. at least 49 of 50 (98%) requests go
	// to "fast".
	assert.Greater(t, recorder.save["fast"], recorder.save["slow"])
	assert.GreaterOrEqual(t, recorder.save["fast"], requests-1)
}
// TestTrafficShiftsWhenPerformanceDegrades checks that the balancer reacts to one
// server getting slower by steering most of the traffic to its healthy peer.
// Adapting to changing latency is the main value of least-time load balancing.
func TestTrafficShiftsWhenPerformanceDegrades(t *testing.T) {
	balancer := New(nil, false)

	// server1's latency, in milliseconds, can be changed while the test runs.
	var server1Delay atomic.Int64
	server1Delay.Store(5) // Starts as fast as server2.

	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(time.Millisecond * time.Duration(server1Delay.Load()))
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond) // Always 5ms.
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
		httptrace.ContextClientTrace(req.Context()).GotFirstResponseByte()
	}), pointer(1), false)

	// Seed every handler with a full 5ms history, so the test starts from a known
	// "equal performance" state rather than from a cold start.
	for _, h := range balancer.handlers {
		for i := range sampleSize {
			h.responseTimes[i] = 5.0
		}
		h.responseTimeSum = 5.0 * sampleSize
		h.sampleCount = sampleSize
	}

	// serve sends n sequential GET requests and returns how many each server handled.
	serve := func(n int) map[string]int {
		rec := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
		for range n {
			balancer.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
		}
		return rec.save
	}

	// Phase 1: both servers answer in 5ms, so WRR tie-breaking should split traffic
	// roughly in half (25 ± 10 each).
	phase1 := serve(50)
	assert.Equal(t, 50, phase1["server1"]+phase1["server2"])
	assert.InDelta(t, 25, phase1["server1"], 10)
	assert.InDelta(t, 25, phase1["server2"], 10)

	// Phase 2: server1 becomes 3x slower (15ms), as during a GC pause, CPU spike, or
	// network hiccup.
	server1Delay.Store(15)

	// server1's moving average only rises as it records new 15ms samples over the
	// seeded history, so a fair number of requests is needed to move it.
	// Expected scores: server1 ≈ (~10-15ms × 1) / 1 while its average climbs,
	// server2 ≈ (5ms × 1) / 1.
	phase2 := serve(60)
	assert.Equal(t, 60, phase2["server1"]+phase2["server2"])
	// server2 must take more than 45/60 (75%), server1 fewer than 15/60 (25%).
	assert.Greater(t, phase2["server2"], 45)
	assert.Less(t, phase2["server1"], 15)
}
// TestMultipleServersWithSameScore tests WRR tie-breaking when several servers have identical scores.
// It calls nextServer() directly so that real request timing cannot perturb the scores.
func TestMultipleServersWithSameScore(t *testing.T) {
	balancer := New(nil, false)

	// Add three servers with identical response times and weights.
	balancer.Add("server1", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "server1")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "server2")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	balancer.Add("server3", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "server3")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false)

	// Give every server the same full response-time history to force a tie.
	for _, h := range balancer.handlers {
		for i := range sampleSize {
			h.responseTimes[i] = 5.0
		}
		h.responseTimeSum = 5.0 * sampleSize
		h.sampleCount = sampleSize
	}

	// With identical scores, WRR tie-breaking alone decides who is picked.
	const picks = 90
	counts := map[string]int{}
	for range picks {
		server, err := balancer.nextServer()
		// assert.NoError does not stop the test. On error the returned server is
		// presumably nil, and dereferencing it would panic and hide the real failure.
		if !assert.NoError(t, err) {
			t.FailNow()
		}
		counts[server.name]++
	}

	// Any unexpected server name would make this sum fall short of picks.
	assert.Equal(t, picks, counts["server1"]+counts["server2"]+counts["server3"])
	// Equal weights over 90 picks: ~30 each (±1 due to initialization).
	assert.InDelta(t, 30, counts["server1"], 1)
	assert.InDelta(t, 30, counts["server2"], 1)
	assert.InDelta(t, 30, counts["server3"], 1)
}
// TestWRRTieBreakingWeightedDistribution tests that tie-breaking among servers with
// equal scores still honors their weights.
// It calls nextServer() directly so that real request timing cannot perturb the scores.
func TestWRRTieBreakingWeightedDistribution(t *testing.T) {
	balancer := New(nil, false)

	// Add two servers with different weights. For their scores to be equal, their
	// response times must be proportional to their weights.
	balancer.Add("weighted", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(15 * time.Millisecond) // 3x the latency to offset the 3x weight.
		rw.Header().Set("server", "weighted")
		rw.WriteHeader(http.StatusOK)
	}), pointer(3), false) // Weight 3.

	balancer.Add("normal", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		time.Sleep(5 * time.Millisecond)
		rw.Header().Set("server", "normal")
		rw.WriteHeader(http.StatusOK)
	}), pointer(1), false) // Weight 1.

	// Since response times are proportional to weights, both scores are equal and
	// WRR tie-breaking applies:
	//   weighted: score = (15 * 1) / 3 = 5
	//   normal:   score = (5 * 1) / 1 = 5
	// NOTE: handlers[0] and handlers[1] are assumed to follow Add order
	// ("weighted", then "normal").
	for i := range sampleSize {
		balancer.handlers[0].responseTimes[i] = 15.0
	}
	balancer.handlers[0].responseTimeSum = 15.0 * sampleSize
	balancer.handlers[0].sampleCount = sampleSize

	for i := range sampleSize {
		balancer.handlers[1].responseTimes[i] = 5.0
	}
	balancer.handlers[1].responseTimeSum = 5.0 * sampleSize
	balancer.handlers[1].sampleCount = sampleSize

	const picks = 80
	counts := map[string]int{}
	for range picks {
		server, err := balancer.nextServer()
		// assert.NoError does not stop the test. On error the returned server is
		// presumably nil, and dereferencing it would panic and hide the real failure.
		if !assert.NoError(t, err) {
			t.FailNow()
		}
		counts[server.name]++
	}

	// Any unexpected server name would make this sum fall short of picks.
	assert.Equal(t, picks, counts["weighted"]+counts["normal"])
	// A 3:1 weight ratio should give ~75%/25% (60/80 and 20/80), ±1 due to initialization.
	assert.InDelta(t, 60, counts["weighted"], 1)
	assert.InDelta(t, 20, counts["normal"], 1)
}