diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 1e46e2292..3fc558bd4 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -18,9 +18,12 @@ package http import ( "fmt" + "io" "net/http" + "net/url" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -57,3 +60,98 @@ func TestNewInstrumentedClient(t *testing.T) { _, ok = result2.Transport.(*CustomRoundTripper) require.True(t, ok) } + +func TestCancelRequest(t *testing.T) { + for _, tt := range []struct { + title string + customRoundTripper CustomRoundTripper + request *http.Request + }{ + { + title: "CancelRequest does nothing", + customRoundTripper: CustomRoundTripper{}, + request: &http.Request{}, + }, + } { + t.Run(tt.title, func(t *testing.T) { + tt.customRoundTripper.CancelRequest(tt.request) + }) + } +} + +type mockRoundTripper struct { + response *http.Response + error error +} + +func (mrt mockRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return mrt.response, mrt.error +} + +func TestRoundTrip(t *testing.T) { + for _, tt := range []struct { + title string + nextRoundTripper mockRoundTripper + request *http.Request + method string + url string + body io.Reader + + expectError bool + expectedResponse *http.Response + }{ + { + title: "RoundTrip returns no error", + nextRoundTripper: mockRoundTripper{}, + request: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Scheme: "HTTPS", + Host: "test.local", + Path: "/path", + }, + Body: nil, + }, + expectError: false, + expectedResponse: nil, + }, + { + title: "RoundTrip extracts status from request", + nextRoundTripper: mockRoundTripper{ + response: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + request: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Scheme: "HTTPS", + Host: "test.local", + Path: "/path", + }, + Body: nil, + }, + expectError: false, + expectedResponse: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + } { + t.Run(tt.title, func(t *testing.T) { + req, err := http.NewRequest(tt.method, tt.url, tt.body) + customRoundTripper := CustomRoundTripper{ + next: tt.nextRoundTripper, + } + + resp, err := customRoundTripper.RoundTrip(req) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tt.expectedResponse, resp) + }) + } +}