diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index f1d0585f5..8273097c3 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -11,7 +11,6 @@ import ( "io" "net" "net/http" - "net/http/httptest" "os" "strings" "sync" @@ -20,6 +19,7 @@ import ( "time" "github.com/go-json-experiment/json/jsontext" + "tailscale.com/net/memnet" "tailscale.com/tstest" "tailscale.com/tstime" "tailscale.com/util/eventbus/eventbustest" @@ -30,6 +30,7 @@ import ( // test in this package. Config.BaseURL defaults to https://log.tailscale.com // and Config.HTTPC defaults to http.DefaultClient, so a test that forgets to // override either can otherwise silently hit the real logtail server. +// Tests that need an HTTP server should use memnet (see newTestLogtailServer). func TestMain(m *testing.M) { tr := http.DefaultTransport.(*http.Transport) orig := tr.DialContext @@ -38,25 +39,19 @@ func TestMain(m *testing.M) { if err == nil && (host == "127.0.0.1" || host == "::1" || host == "localhost") { return orig(ctx, network, addr) } - return nil, fmt.Errorf("logtail tests: refusing to dial non-localhost address %q; use httptest.Server or a custom Config.HTTPC", addr) + return nil, fmt.Errorf("logtail tests: refusing to dial non-localhost address %q; use memnet or a custom Config.HTTPC", addr) } os.Exit(m.Run()) } -func TestFastShutdown(t *testing.T) { +func TestFastShutdown(t *testing.T) { synctest.Test(t, synctestFastShutdown) } + +func synctestFastShutdown(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - testServ := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) {})) - defer testServ.Close() - - logger := NewLogger(Config{ - BaseURL: testServ.URL, - Bus: eventbustest.NewBus(t), - }, t.Logf) - err := logger.Shutdown(ctx) - if err != nil { + _, logger := newTestLogtailServer(t) + if err := logger.Shutdown(ctx); err != nil { t.Error(err) } } @@ -65,49 +60,60 @@ func TestFastShutdown(t *testing.T) { const logLines = 3 type LogtailTestServer struct { - srv *httptest.Server // Log server uploaded chan []byte } -func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { - ts := LogtailTestServer{} +// newTestLogtailServer wires up an in-memory HTTP server (via memnet) and a +// *Logger whose HTTPC dials it. Lives inside the caller's synctest bubble so +// the default FlushDelay and any other fake timers advance automatically. +func newTestLogtailServer(t *testing.T) (*LogtailTestServer, *Logger) { + ts := &LogtailTestServer{ + // max channel backlog = 1 "started" + #logLines x "log line" + 1 "closed" + uploaded: make(chan []byte, 2+logLines), + } - // max channel backlog = 1 "started" + #logLines x "log line" + 1 "closed" - ts.uploaded = make(chan []byte, 2+logLines) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error("failed to read HTTP request") + } + ts.uploaded <- body + }) - ts.srv = httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - t.Error("failed to read HTTP request") - } - ts.uploaded <- body - })) - - t.Cleanup(ts.srv.Close) + ln := memnet.Listen("logtail-test:0") + httpsrv := &http.Server{Handler: handler} + go httpsrv.Serve(ln) + t.Cleanup(func() { + httpsrv.Close() + ln.Close() + }) logger := NewLogger(Config{ - BaseURL: ts.srv.URL, + BaseURL: "http://" + ln.Addr().String(), Bus: eventbustest.NewBus(t), + HTTPC: &http.Client{ + Transport: &http.Transport{DialContext: ln.Dial}, + }, }, t.Logf) - // There is always an initial "logtail started" message + // There is always an initial "logtail started" message. body := <-ts.uploaded if !strings.Contains(string(body), "started") { t.Errorf("unknown start logging statement: %q", string(body)) } - - return &ts, logger + return ts, logger } -func TestDrainPendingMessages(t *testing.T) { - ts, logger := NewLogtailTestHarness(t) +func TestDrainPendingMessages(t *testing.T) { synctest.Test(t, synctestDrainPendingMessages) } + +func synctestDrainPendingMessages(t *testing.T) { + ts, logger := newTestLogtailServer(t) for range logLines { logger.Write([]byte("log line")) } - // all of the "log line" messages usually arrive at once, but poll if needed. + // All the "log line" messages usually arrive at once, but poll if needed. var body strings.Builder for i := 0; i <= logLines; i++ { body.WriteString(string(<-ts.uploaded)) @@ -115,17 +121,17 @@ func TestDrainPendingMessages(t *testing.T) { if count == logLines { break } - // if we never find count == logLines, the test will eventually time out. } - err := logger.Shutdown(context.Background()) - if err != nil { + if err := logger.Shutdown(context.Background()); err != nil { t.Error(err) } } -func TestEncodeAndUploadMessages(t *testing.T) { - ts, logger := NewLogtailTestHarness(t) +func TestEncodeAndUploadMessages(t *testing.T) { synctest.Test(t, synctestEncodeAndUploadMessages) } + +func synctestEncodeAndUploadMessages(t *testing.T) { + ts, logger := newTestLogtailServer(t) tests := []struct { name string @@ -166,8 +172,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } } - err := logger.Shutdown(context.Background()) - if err != nil { + if err := logger.Shutdown(context.Background()); err != nil { t.Error(err) } }