mirror of
				https://github.com/traefik/traefik.git
				synced 2025-10-31 00:11:38 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			1107 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			1107 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package compress
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 
 | |
| 	"github.com/andybalholm/brotli"
 | |
| 	"github.com/klauspost/compress/zstd"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| )
 | |
| 
 | |
| var (
 | |
| 	smallTestBody = []byte("aaabbc" + strings.Repeat("aaabbbccc", 9) + "aaabbbc")
 | |
| 	bigTestBody   = []byte(strings.Repeat(strings.Repeat("aaabbbccc", 66)+" ", 6) + strings.Repeat("aaabbbccc", 66))
 | |
| )
 | |
| 
 | |
| func Test_Vary(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		h              http.Handler
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "brotli",
 | |
| 			h:              newTestBrotliHandler(t, smallTestBody),
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:           "zstd",
 | |
| 			h:              newTestZstandardHandler(t, smallTestBody),
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			test.h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusAccepted, rw.Code)
 | |
| 			assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_SmallBodyNoCompression(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		h              http.Handler
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "brotli",
 | |
| 			h:              newTestBrotliHandler(t, smallTestBody),
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:           "zstd",
 | |
| 			h:              newTestZstandardHandler(t, smallTestBody),
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			test.h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			// With less than 1024 bytes the response should not be compressed.
 | |
| 			assert.Equal(t, http.StatusAccepted, rw.Code)
 | |
| 			assert.Empty(t, rw.Header().Get(contentEncoding))
 | |
| 			assert.Equal(t, smallTestBody, rw.Body.Bytes())
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_AlreadyCompressed(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		h              http.Handler
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "brotli",
 | |
| 			h:              newTestBrotliHandler(t, bigTestBody),
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:           "zstd",
 | |
| 			h:              newTestZstandardHandler(t, bigTestBody),
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			test.h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusAccepted, rw.Code)
 | |
| 			assert.Equal(t, bigTestBody, rw.Body.Bytes())
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_NoBody(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc       string
 | |
| 		statusCode int
 | |
| 		body       []byte
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:       "status no content",
 | |
| 			statusCode: http.StatusNoContent,
 | |
| 			body:       nil,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:       "status not modified",
 | |
| 			statusCode: http.StatusNotModified,
 | |
| 			body:       nil,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:       "status OK with empty body",
 | |
| 			statusCode: http.StatusOK,
 | |
| 			body:       []byte{},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:       "status OK with nil body",
 | |
| 			statusCode: http.StatusOK,
 | |
| 			body:       nil,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.WriteHeader(test.statusCode)
 | |
| 
 | |
| 				_, err := rw.Write(test.body)
 | |
| 				require.NoError(t, err)
 | |
| 			})
 | |
| 
 | |
| 			h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
 | |
| 
 | |
| 			req := httptest.NewRequest(http.MethodGet, "/", nil)
 | |
| 			req.Header.Set(acceptEncoding, "zstd")
 | |
| 
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			body, err := io.ReadAll(rw.Body)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			assert.Empty(t, rw.Header().Get(contentEncoding))
 | |
| 			assert.Empty(t, body)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_MinSize(t *testing.T) {
 | |
| 	cfg := Config{
 | |
| 		MinSize: 128,
 | |
| 	}
 | |
| 
 | |
| 	var bodySize int
 | |
| 
 | |
| 	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 		for range bodySize {
 | |
| 			// We make sure to Write at least once less than minSize so that both
 | |
| 			// cases below go through the same algo: i.e. they start buffering
 | |
| 			// because they haven't reached minSize.
 | |
| 			_, err := rw.Write([]byte{'x'})
 | |
| 			require.NoError(t, err)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	h := mustNewCompressionHandler(t, cfg, zstdName, next)
 | |
| 
 | |
| 	req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
 | |
| 	req.Header.Add(acceptEncoding, "zstd")
 | |
| 
 | |
| 	// Short response is not compressed
 | |
| 	bodySize = cfg.MinSize - 1
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	h.ServeHTTP(rw, req)
 | |
| 
 | |
| 	assert.Empty(t, rw.Result().Header.Get(contentEncoding))
 | |
| 
 | |
| 	// Long response is compressed
 | |
| 	bodySize = cfg.MinSize
 | |
| 	rw = httptest.NewRecorder()
 | |
| 	h.ServeHTTP(rw, req)
 | |
| 
 | |
| 	assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
 | |
| }
 | |
| 
 | |
| func Test_MultipleWriteHeader(t *testing.T) {
 | |
| 	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 		// We ensure that the subsequent call to WriteHeader is a noop.
 | |
| 		rw.WriteHeader(http.StatusInternalServerError)
 | |
| 		rw.WriteHeader(http.StatusNotFound)
 | |
| 	})
 | |
| 
 | |
| 	h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
 | |
| 
 | |
| 	req := httptest.NewRequest(http.MethodGet, "/", nil)
 | |
| 	req.Header.Set(acceptEncoding, "zstd")
 | |
| 
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	h.ServeHTTP(rw, req)
 | |
| 
 | |
| 	assert.Equal(t, http.StatusInternalServerError, rw.Code)
 | |
| }
 | |
| 
 | |
| func Test_FlushBeforeWrite(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		cfg            Config
 | |
| 		algo           string
 | |
| 		readerBuilder  func(io.Reader) (io.Reader, error)
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc: "brotli",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: brotliName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return brotli.NewReader(reader), nil
 | |
| 			},
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "zstd",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: zstdName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return zstd.NewReader(reader)
 | |
| 			},
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.WriteHeader(http.StatusOK)
 | |
| 				rw.(http.Flusher).Flush()
 | |
| 
 | |
| 				_, err := rw.Write(bigTestBody)
 | |
| 				require.NoError(t, err)
 | |
| 			})
 | |
| 
 | |
| 			srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
 | |
| 			defer srv.Close()
 | |
| 
 | |
| 			req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			res, err := http.DefaultClient.Do(req)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			defer res.Body.Close()
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, res.StatusCode)
 | |
| 			assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
 | |
| 
 | |
| 			reader, err := test.readerBuilder(res.Body)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			got, err := io.ReadAll(reader)
 | |
| 			require.NoError(t, err)
 | |
| 			assert.Equal(t, bigTestBody, got)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_FlushAfterWrite(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		cfg            Config
 | |
| 		algo           string
 | |
| 		readerBuilder  func(io.Reader) (io.Reader, error)
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc: "brotli",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: brotliName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return brotli.NewReader(reader), nil
 | |
| 			},
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "zstd",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: zstdName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return zstd.NewReader(reader)
 | |
| 			},
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.WriteHeader(http.StatusOK)
 | |
| 
 | |
| 				_, err := rw.Write(bigTestBody[0:1])
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				rw.(http.Flusher).Flush()
 | |
| 				for _, b := range bigTestBody[1:] {
 | |
| 					_, err := rw.Write([]byte{b})
 | |
| 					require.NoError(t, err)
 | |
| 				}
 | |
| 			})
 | |
| 
 | |
| 			srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
 | |
| 			defer srv.Close()
 | |
| 
 | |
| 			req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			res, err := http.DefaultClient.Do(req)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			defer res.Body.Close()
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, res.StatusCode)
 | |
| 			assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
 | |
| 
 | |
| 			reader, err := test.readerBuilder(res.Body)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			got, err := io.ReadAll(reader)
 | |
| 			require.NoError(t, err)
 | |
| 			assert.Equal(t, bigTestBody, got)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_FlushAfterWriteNil(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		cfg            Config
 | |
| 		algo           string
 | |
| 		readerBuilder  func(io.Reader) (io.Reader, error)
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc: "brotli",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: brotliName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return brotli.NewReader(reader), nil
 | |
| 			},
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "zstd",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: zstdName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return zstd.NewReader(reader)
 | |
| 			},
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.WriteHeader(http.StatusOK)
 | |
| 
 | |
| 				_, err := rw.Write(nil)
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				rw.(http.Flusher).Flush()
 | |
| 			})
 | |
| 
 | |
| 			srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
 | |
| 			defer srv.Close()
 | |
| 
 | |
| 			req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			res, err := http.DefaultClient.Do(req)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			defer res.Body.Close()
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, res.StatusCode)
 | |
| 			assert.Empty(t, res.Header.Get(contentEncoding))
 | |
| 
 | |
| 			reader, err := test.readerBuilder(res.Body)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			got, err := io.ReadAll(reader)
 | |
| 			require.NoError(t, err)
 | |
| 			assert.Empty(t, got)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_FlushAfterAllWrites(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		cfg            Config
 | |
| 		algo           string
 | |
| 		readerBuilder  func(io.Reader) (io.Reader, error)
 | |
| 		acceptEncoding string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc: "brotli",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: brotliName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return brotli.NewReader(reader), nil
 | |
| 			},
 | |
| 			acceptEncoding: "br",
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "zstd",
 | |
| 			cfg:  Config{MinSize: 1024, MiddlewareName: "Test"},
 | |
| 			algo: zstdName,
 | |
| 			readerBuilder: func(reader io.Reader) (io.Reader, error) {
 | |
| 				return zstd.NewReader(reader)
 | |
| 			},
 | |
| 			acceptEncoding: "zstd",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				for i := range bigTestBody {
 | |
| 					_, err := rw.Write(bigTestBody[i : i+1])
 | |
| 					require.NoError(t, err)
 | |
| 				}
 | |
| 				rw.(http.Flusher).Flush()
 | |
| 			})
 | |
| 
 | |
| 			srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
 | |
| 			defer srv.Close()
 | |
| 
 | |
| 			req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			req.Header.Set(acceptEncoding, test.acceptEncoding)
 | |
| 
 | |
| 			res, err := http.DefaultClient.Do(req)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			defer res.Body.Close()
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, res.StatusCode)
 | |
| 			assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
 | |
| 
 | |
| 			reader, err := test.readerBuilder(res.Body)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			got, err := io.ReadAll(reader)
 | |
| 			require.NoError(t, err)
 | |
| 			assert.Equal(t, bigTestBody, got)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_ExcludedContentTypes(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc                 string
 | |
| 		contentType          string
 | |
| 		excludedContentTypes []string
 | |
| 		expCompression       bool
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "Always compress when content types are empty",
 | |
| 			contentType:    "",
 | |
| 			expCompression: true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match",
 | |
| 			contentType:          "application/json",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME no match",
 | |
| 			contentType:          "text/xml",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with no other directive ignores non-MIME directives",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, different charset",
 | |
| 			contentType:          "application/json; charset=ascii",
 | |
| 			excludedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, same charset",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			excludedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
 | |
| 			contentType:          "application/json",
 | |
| 			excludedContentTypes: []string{"application/json; charset=ascii"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match case insensitive",
 | |
| 			contentType:          "Application/Json",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match ignore whitespace",
 | |
| 			contentType:          "application/json;charset=utf-8",
 | |
| 			excludedContentTypes: []string{"application/json;            charset=utf-8"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			cfg := Config{
 | |
| 				MinSize:              1024,
 | |
| 				ExcludedContentTypes: test.excludedContentTypes,
 | |
| 			}
 | |
| 
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.Header().Set(contentType, test.contentType)
 | |
| 
 | |
| 				rw.WriteHeader(http.StatusAccepted)
 | |
| 
 | |
| 				_, err := rw.Write(bigTestBody)
 | |
| 				require.NoError(t, err)
 | |
| 			})
 | |
| 
 | |
| 			h := mustNewCompressionHandler(t, cfg, zstdName, next)
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
 | |
| 			req.Header.Set(acceptEncoding, zstdName)
 | |
| 
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusAccepted, rw.Code)
 | |
| 
 | |
| 			if test.expCompression {
 | |
| 				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
 | |
| 
 | |
| 				reader, err := zstd.NewReader(rw.Body)
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				got, err := io.ReadAll(reader)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			} else {
 | |
| 				assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
 | |
| 
 | |
| 				got, err := io.ReadAll(rw.Body)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_IncludedContentTypes(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc                 string
 | |
| 		contentType          string
 | |
| 		includedContentTypes []string
 | |
| 		expCompression       bool
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "Always compress when content types are empty",
 | |
| 			contentType:    "",
 | |
| 			expCompression: true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match",
 | |
| 			contentType:          "application/json",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME no match",
 | |
| 			contentType:          "text/xml",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with no other directive ignores non-MIME directives",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, different charset",
 | |
| 			contentType:          "application/json; charset=ascii",
 | |
| 			includedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, same charset",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			includedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
 | |
| 			contentType:          "application/json",
 | |
| 			includedContentTypes: []string{"application/json; charset=ascii"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match case insensitive",
 | |
| 			contentType:          "Application/Json",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match ignore whitespace",
 | |
| 			contentType:          "application/json;charset=utf-8",
 | |
| 			includedContentTypes: []string{"application/json;            charset=utf-8"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			cfg := Config{
 | |
| 				MinSize:              1024,
 | |
| 				IncludedContentTypes: test.includedContentTypes,
 | |
| 			}
 | |
| 
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.Header().Set(contentType, test.contentType)
 | |
| 
 | |
| 				rw.WriteHeader(http.StatusAccepted)
 | |
| 
 | |
| 				_, err := rw.Write(bigTestBody)
 | |
| 				require.NoError(t, err)
 | |
| 			})
 | |
| 
 | |
| 			h := mustNewCompressionHandler(t, cfg, zstdName, next)
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
 | |
| 			req.Header.Set(acceptEncoding, zstdName)
 | |
| 
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusAccepted, rw.Code)
 | |
| 
 | |
| 			if test.expCompression {
 | |
| 				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
 | |
| 
 | |
| 				reader, err := zstd.NewReader(rw.Body)
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				got, err := io.ReadAll(reader)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			} else {
 | |
| 				assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
 | |
| 
 | |
| 				got, err := io.ReadAll(rw.Body)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_FlushExcludedContentTypes(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc                 string
 | |
| 		contentType          string
 | |
| 		excludedContentTypes []string
 | |
| 		expCompression       bool
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "Always compress when content types are empty",
 | |
| 			contentType:    "",
 | |
| 			expCompression: true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match",
 | |
| 			contentType:          "application/json",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME no match",
 | |
| 			contentType:          "text/xml",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with no other directive ignores non-MIME directives",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, different charset",
 | |
| 			contentType:          "application/json; charset=ascii",
 | |
| 			excludedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, same charset",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			excludedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
 | |
| 			contentType:          "application/json",
 | |
| 			excludedContentTypes: []string{"application/json; charset=ascii"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match case insensitive",
 | |
| 			contentType:          "Application/Json",
 | |
| 			excludedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match ignore whitespace",
 | |
| 			contentType:          "application/json;charset=utf-8",
 | |
| 			excludedContentTypes: []string{"application/json;            charset=utf-8"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			cfg := Config{
 | |
| 				MinSize:              1024,
 | |
| 				ExcludedContentTypes: test.excludedContentTypes,
 | |
| 			}
 | |
| 
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.Header().Set(contentType, test.contentType)
 | |
| 				rw.WriteHeader(http.StatusOK)
 | |
| 
 | |
| 				tb := bigTestBody
 | |
| 				for len(tb) > 0 {
 | |
| 					// Write 100 bytes per run
 | |
| 					// Detection should not be affected (we send 100 bytes)
 | |
| 					toWrite := 100
 | |
| 					if toWrite > len(tb) {
 | |
| 						toWrite = len(tb)
 | |
| 					}
 | |
| 
 | |
| 					_, err := rw.Write(tb[:toWrite])
 | |
| 					require.NoError(t, err)
 | |
| 
 | |
| 					// Flush between each write
 | |
| 					rw.(http.Flusher).Flush()
 | |
| 					tb = tb[toWrite:]
 | |
| 				}
 | |
| 			})
 | |
| 
 | |
| 			h := mustNewCompressionHandler(t, cfg, zstdName, next)
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
 | |
| 			req.Header.Set(acceptEncoding, zstdName)
 | |
| 
 | |
| 			// This doesn't allow checking flushes, but we validate if content is correct.
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, rw.Code)
 | |
| 
 | |
| 			if test.expCompression {
 | |
| 				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
 | |
| 
 | |
| 				reader, err := zstd.NewReader(rw.Body)
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				got, err := io.ReadAll(reader)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			} else {
 | |
| 				assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
 | |
| 
 | |
| 				got, err := io.ReadAll(rw.Body)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_FlushIncludedContentTypes(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc                 string
 | |
| 		contentType          string
 | |
| 		includedContentTypes []string
 | |
| 		expCompression       bool
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:           "Always compress when content types are empty",
 | |
| 			contentType:    "",
 | |
| 			expCompression: true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match",
 | |
| 			contentType:          "application/json",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME no match",
 | |
| 			contentType:          "text/xml",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with no other directive ignores non-MIME directives",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, different charset",
 | |
| 			contentType:          "application/json; charset=ascii",
 | |
| 			includedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, same charset",
 | |
| 			contentType:          "application/json; charset=utf-8",
 | |
| 			includedContentTypes: []string{"application/json; charset=utf-8"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
 | |
| 			contentType:          "application/json",
 | |
| 			includedContentTypes: []string{"application/json; charset=ascii"},
 | |
| 			expCompression:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match case insensitive",
 | |
| 			contentType:          "Application/Json",
 | |
| 			includedContentTypes: []string{"application/json"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                 "MIME match ignore whitespace",
 | |
| 			contentType:          "application/json;charset=utf-8",
 | |
| 			includedContentTypes: []string{"application/json;            charset=utf-8"},
 | |
| 			expCompression:       true,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			cfg := Config{
 | |
| 				MinSize:              1024,
 | |
| 				IncludedContentTypes: test.includedContentTypes,
 | |
| 			}
 | |
| 
 | |
| 			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 				rw.Header().Set(contentType, test.contentType)
 | |
| 				rw.WriteHeader(http.StatusOK)
 | |
| 
 | |
| 				tb := bigTestBody
 | |
| 				for len(tb) > 0 {
 | |
| 					// Write 100 bytes per run
 | |
| 					// Detection should not be affected (we send 100 bytes)
 | |
| 					toWrite := 100
 | |
| 					if toWrite > len(tb) {
 | |
| 						toWrite = len(tb)
 | |
| 					}
 | |
| 
 | |
| 					_, err := rw.Write(tb[:toWrite])
 | |
| 					require.NoError(t, err)
 | |
| 
 | |
| 					// Flush between each write
 | |
| 					rw.(http.Flusher).Flush()
 | |
| 					tb = tb[toWrite:]
 | |
| 				}
 | |
| 			})
 | |
| 
 | |
| 			h := mustNewCompressionHandler(t, cfg, zstdName, next)
 | |
| 
 | |
| 			req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
 | |
| 			req.Header.Set(acceptEncoding, zstdName)
 | |
| 
 | |
| 			// This doesn't allow checking flushes, but we validate if content is correct.
 | |
| 			rw := httptest.NewRecorder()
 | |
| 			h.ServeHTTP(rw, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, rw.Code)
 | |
| 
 | |
| 			if test.expCompression {
 | |
| 				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
 | |
| 
 | |
| 				reader, err := zstd.NewReader(rw.Body)
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				got, err := io.ReadAll(reader)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			} else {
 | |
| 				assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
 | |
| 
 | |
| 				got, err := io.ReadAll(rw.Body)
 | |
| 				assert.NoError(t, err)
 | |
| 				assert.Equal(t, bigTestBody, got)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func mustNewCompressionHandler(t *testing.T, cfg Config, algo string, next http.Handler) http.Handler {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	var writer NewCompressionWriter
 | |
| 	switch algo {
 | |
| 	case zstdName:
 | |
| 		writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
 | |
| 			writer, err := zstd.NewWriter(rw)
 | |
| 			require.NoError(t, err)
 | |
| 			return writer, zstdName, nil
 | |
| 		}
 | |
| 	case brotliName:
 | |
| 		writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
 | |
| 			return brotli.NewWriter(rw), brotliName, nil
 | |
| 		}
 | |
| 	default:
 | |
| 		assert.Failf(t, "unknown compression algorithm: %s", algo)
 | |
| 	}
 | |
| 
 | |
| 	w, err := NewCompressionHandler(cfg, writer, next)
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	return w
 | |
| }
 | |
| 
 | |
| func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 		if req.URL.Path == "/compressed" {
 | |
| 			rw.Header().Set("Content-Encoding", brotliName)
 | |
| 		}
 | |
| 
 | |
| 		rw.WriteHeader(http.StatusAccepted)
 | |
| 		_, err := rw.Write(body)
 | |
| 		require.NoError(t, err)
 | |
| 	})
 | |
| 
 | |
| 	return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, brotliName, next)
 | |
| }
 | |
| 
 | |
| func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 		if req.URL.Path == "/compressed" {
 | |
| 			rw.Header().Set("Content-Encoding", zstdName)
 | |
| 		}
 | |
| 
 | |
| 		rw.WriteHeader(http.StatusAccepted)
 | |
| 		_, err := rw.Write(body)
 | |
| 		require.NoError(t, err)
 | |
| 	})
 | |
| 
 | |
| 	return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, zstdName, next)
 | |
| }
 | |
| 
 | |
| func Test_ParseContentType_equals(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc      string
 | |
| 		pct       parsedContentType
 | |
| 		mediaType string
 | |
| 		params    map[string]string
 | |
| 		expect    assert.BoolAssertionFunc
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:   "empty parsed content type",
 | |
| 			expect: assert.True,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "simple content type",
 | |
| 			pct: parsedContentType{
 | |
| 				mediaType: "plain/text",
 | |
| 			},
 | |
| 			mediaType: "plain/text",
 | |
| 			expect:    assert.True,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "content type with params",
 | |
| 			pct: parsedContentType{
 | |
| 				mediaType: "plain/text",
 | |
| 				params: map[string]string{
 | |
| 					"charset": "utf8",
 | |
| 				},
 | |
| 			},
 | |
| 			mediaType: "plain/text",
 | |
| 			params: map[string]string{
 | |
| 				"charset": "utf8",
 | |
| 			},
 | |
| 			expect: assert.True,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "different content type",
 | |
| 			pct: parsedContentType{
 | |
| 				mediaType: "plain/text",
 | |
| 			},
 | |
| 			mediaType: "application/json",
 | |
| 			expect:    assert.False,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "content type with params",
 | |
| 			pct: parsedContentType{
 | |
| 				mediaType: "plain/text",
 | |
| 				params: map[string]string{
 | |
| 					"charset": "utf8",
 | |
| 				},
 | |
| 			},
 | |
| 			mediaType: "plain/text",
 | |
| 			params: map[string]string{
 | |
| 				"charset": "latin-1",
 | |
| 			},
 | |
| 			expect: assert.False,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "different number of parameters",
 | |
| 			pct: parsedContentType{
 | |
| 				mediaType: "plain/text",
 | |
| 				params: map[string]string{
 | |
| 					"charset": "utf8",
 | |
| 				},
 | |
| 			},
 | |
| 			mediaType: "plain/text",
 | |
| 			params: map[string]string{
 | |
| 				"charset": "utf8",
 | |
| 				"q":       "0.8",
 | |
| 			},
 | |
| 			expect: assert.False,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			t.Parallel()
 | |
| 
 | |
| 			test.expect(t, test.pct.equals(test.mediaType, test.params))
 | |
| 		})
 | |
| 	}
 | |
| }
 |