mirror of
				https://github.com/traefik/traefik.git
				synced 2025-10-31 00:11:38 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			731 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			731 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package fast
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"crypto/sha1"
 | |
| 	"crypto/tls"
 | |
| 	"encoding/base64"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"net/url"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	gorillawebsocket "github.com/gorilla/websocket"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| 	"github.com/traefik/traefik/v3/pkg/testhelpers"
 | |
| 	"golang.org/x/net/websocket"
 | |
| )
 | |
| 
 | |
| func TestWebSocketUpgradeCase(t *testing.T) {
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		challengeKey := r.Header.Get("Sec-Websocket-Key")
 | |
| 
 | |
| 		hijacker, ok := w.(http.Hijacker)
 | |
| 		require.True(t, ok)
 | |
| 
 | |
| 		c, _, err := hijacker.Hijack()
 | |
| 		require.NoError(t, err)
 | |
| 
 | |
| 		// Force answer with "Connection: upgrade" in lowercase.
 | |
| 		_, err = c.Write([]byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Accept: " + computeAcceptKey(challengeKey) + "\r\n\n"))
 | |
| 		require.NoError(t, err)
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 	_, conn, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 	).open()
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	conn.Close()
 | |
| }
 | |
| 
 | |
| func TestWebSocketTCPClose(t *testing.T) {
 | |
| 	errChan := make(chan error, 1)
 | |
| 	upgrader := gorillawebsocket.Upgrader{}
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		c, err := upgrader.Upgrade(w, r, nil)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer c.Close()
 | |
| 
 | |
| 		for {
 | |
| 			_, _, err := c.ReadMessage()
 | |
| 			if err != nil {
 | |
| 				errChan <- err
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 	_, conn, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 	).open()
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	conn.Close()
 | |
| 
 | |
| 	serverErr := <-errChan
 | |
| 
 | |
| 	var wsErr *gorillawebsocket.CloseError
 | |
| 	require.ErrorAs(t, serverErr, &wsErr)
 | |
| 	assert.Equal(t, 1006, wsErr.Code)
 | |
| }
 | |
| 
 | |
| func TestWebSocketPingPong(t *testing.T) {
 | |
| 	upgrader := gorillawebsocket.Upgrader{
 | |
| 		HandshakeTimeout: 10 * time.Second,
 | |
| 		CheckOrigin: func(*http.Request) bool {
 | |
| 			return true
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	mux := http.NewServeMux()
 | |
| 	mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
 | |
| 		ws, err := upgrader.Upgrade(writer, request, nil)
 | |
| 		require.NoError(t, err)
 | |
| 
 | |
| 		ws.SetPingHandler(func(appData string) error {
 | |
| 			err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			return nil
 | |
| 		})
 | |
| 
 | |
| 		_, _, _ = ws.ReadMessage()
 | |
| 	})
 | |
| 
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		mux.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	serverAddr := proxy.Listener.Addr().String()
 | |
| 
 | |
| 	headers := http.Header{}
 | |
| 	webSocketURL := "ws://" + serverAddr + "/ws"
 | |
| 	headers.Add("Origin", webSocketURL)
 | |
| 
 | |
| 	conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
 | |
| 	require.NoError(t, err, "Error during Dial with response: %+v", resp)
 | |
| 	defer conn.Close()
 | |
| 
 | |
| 	goodErr := fmt.Errorf("signal: %s", "Good data")
 | |
| 	badErr := fmt.Errorf("signal: %s", "Bad data")
 | |
| 	conn.SetPongHandler(func(data string) error {
 | |
| 		if data == "PingPong" {
 | |
| 			return goodErr
 | |
| 		}
 | |
| 
 | |
| 		return badErr
 | |
| 	})
 | |
| 
 | |
| 	err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	_, _, err = conn.ReadMessage()
 | |
| 
 | |
| 	if !errors.Is(err, goodErr) {
 | |
| 		require.NoError(t, err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestWebSocketEcho(t *testing.T) {
 | |
| 	mux := http.NewServeMux()
 | |
| 	mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
 | |
| 		msg := make([]byte, 4)
 | |
| 
 | |
| 		n, err := conn.Read(msg)
 | |
| 		require.NoError(t, err)
 | |
| 
 | |
| 		_, err = conn.Write(msg[:n])
 | |
| 		require.NoError(t, err)
 | |
| 
 | |
| 		err = conn.Close()
 | |
| 		require.NoError(t, err)
 | |
| 	}))
 | |
| 
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		mux.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	serverAddr := proxy.Listener.Addr().String()
 | |
| 
 | |
| 	headers := http.Header{}
 | |
| 	webSocketURL := "ws://" + serverAddr + "/ws"
 | |
| 	headers.Add("Origin", webSocketURL)
 | |
| 
 | |
| 	conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
 | |
| 	require.NoError(t, err, "Error during Dial with response: %+v", resp)
 | |
| 
 | |
| 	err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	_, msg, err := conn.ReadMessage()
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	assert.Equal(t, "OK", string(msg))
 | |
| 
 | |
| 	err = conn.Close()
 | |
| 	require.NoError(t, err)
 | |
| }
 | |
| 
 | |
| func TestWebSocketPassHost(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc     string
 | |
| 		passHost bool
 | |
| 		expected string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:     "PassHost false",
 | |
| 			passHost: false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:     "PassHost true",
 | |
| 			passHost: true,
 | |
| 			expected: "example.com",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, test := range testCases {
 | |
| 		t.Run(test.desc, func(t *testing.T) {
 | |
| 			mux := http.NewServeMux()
 | |
| 			mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
 | |
| 				req := conn.Request()
 | |
| 
 | |
| 				if test.passHost {
 | |
| 					require.Equal(t, test.expected, req.Host)
 | |
| 				} else {
 | |
| 					require.NotEqual(t, test.expected, req.Host)
 | |
| 				}
 | |
| 
 | |
| 				msg := make([]byte, 4)
 | |
| 
 | |
| 				n, err := conn.Read(msg)
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				_, err = conn.Write(msg[:n])
 | |
| 				require.NoError(t, err)
 | |
| 
 | |
| 				err = conn.Close()
 | |
| 				require.NoError(t, err)
 | |
| 			}))
 | |
| 
 | |
| 			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 				mux.ServeHTTP(w, req)
 | |
| 			}))
 | |
| 			defer srv.Close()
 | |
| 
 | |
| 			proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 
 | |
| 			serverAddr := proxy.Listener.Addr().String()
 | |
| 
 | |
| 			headers := http.Header{}
 | |
| 			webSocketURL := "ws://" + serverAddr + "/ws"
 | |
| 			headers.Add("Origin", webSocketURL)
 | |
| 			headers.Add("Host", "example.com")
 | |
| 
 | |
| 			conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
 | |
| 			require.NoError(t, err, "Error during Dial with response: %+v", resp)
 | |
| 
 | |
| 			err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			_, msg, err := conn.ReadMessage()
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			assert.Equal(t, "OK", string(msg))
 | |
| 
 | |
| 			err = conn.Close()
 | |
| 			require.NoError(t, err)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
 | |
| 	upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
 | |
| 		return true
 | |
| 	}}
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		c, err := upgrader.Upgrade(w, r, nil)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer c.Close()
 | |
| 		for {
 | |
| 			mt, message, err := c.ReadMessage()
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 			err = c.WriteMessage(mt, message)
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 	resp, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 		withData("ok"),
 | |
| 		withOrigin("http://127.0.0.2"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	assert.Equal(t, "ok", resp)
 | |
| }
 | |
| 
 | |
| func TestWebSocketRequestWithOrigin(t *testing.T) {
 | |
| 	upgrader := gorillawebsocket.Upgrader{}
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		c, err := upgrader.Upgrade(w, r, nil)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer c.Close()
 | |
| 		for {
 | |
| 			mt, message, err := c.ReadMessage()
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 			err = c.WriteMessage(mt, message)
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 	_, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 		withData("echo"),
 | |
| 		withOrigin("http://127.0.0.2"),
 | |
| 	).send()
 | |
| 	require.EqualError(t, err, "bad status")
 | |
| 
 | |
| 	resp, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 		withData("ok"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	assert.Equal(t, "ok", resp)
 | |
| }
 | |
| 
 | |
| func TestWebSocketRequestWithQueryParams(t *testing.T) {
 | |
| 	upgrader := gorillawebsocket.Upgrader{}
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		conn, err := upgrader.Upgrade(w, r, nil)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer conn.Close()
 | |
| 
 | |
| 		assert.Equal(t, "test", r.URL.Query().Get("query"))
 | |
| 		for {
 | |
| 			mt, message, err := conn.ReadMessage()
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 
 | |
| 			err = conn.WriteMessage(mt, message)
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 
 | |
| 	resp, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws?query=test"),
 | |
| 		withData("ok"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	assert.Equal(t, "ok", resp)
 | |
| }
 | |
| 
 | |
| func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
 | |
| 	mux := http.NewServeMux()
 | |
| 	mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
 | |
| 		_ = conn.Close()
 | |
| 	}))
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		mux.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	u := parseURI(t, srv.URL)
 | |
| 
 | |
| 	f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
 | |
| 		return net.Dial("tcp", u.Host)
 | |
| 	}))
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		req.URL = parseURI(t, srv.URL)
 | |
| 		w.Header().Set("HEADER-KEY", "HEADER-VALUE")
 | |
| 		f.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	serverAddr := proxy.Listener.Addr().String()
 | |
| 
 | |
| 	headers := http.Header{}
 | |
| 	webSocketURL := "ws://" + serverAddr + "/ws"
 | |
| 	headers.Add("Origin", webSocketURL)
 | |
| 	conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
 | |
| 	require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
 | |
| 	defer conn.Close()
 | |
| 
 | |
| 	assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
 | |
| }
 | |
| 
 | |
| func TestWebSocketRequestWithEncodedChar(t *testing.T) {
 | |
| 	upgrader := gorillawebsocket.Upgrader{}
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		conn, err := upgrader.Upgrade(w, r, nil)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer conn.Close()
 | |
| 		assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
 | |
| 		for {
 | |
| 			mt, message, err := conn.ReadMessage()
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 			err = conn.WriteMessage(mt, message)
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 
 | |
| 	resp, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/%3A%2F%2F"),
 | |
| 		withData("ok"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	assert.Equal(t, "ok", resp)
 | |
| }
 | |
| 
 | |
| func TestWebSocketUpgradeFailed(t *testing.T) {
 | |
| 	mux := http.NewServeMux()
 | |
| 	mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
 | |
| 		w.WriteHeader(http.StatusBadRequest)
 | |
| 	})
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		mux.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	u := parseURI(t, srv.URL)
 | |
| 	f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
 | |
| 		return net.Dial("tcp", u.Host)
 | |
| 	}))
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		path := req.URL.Path // keep the original path
 | |
| 
 | |
| 		if path != "/ws" {
 | |
| 			w.WriteHeader(http.StatusOK)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// Set new backend URL
 | |
| 		req.URL = parseURI(t, srv.URL)
 | |
| 		req.URL.Path = path
 | |
| 		f.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 	conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	defer conn.Close()
 | |
| 
 | |
| 	req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	req.Header.Add("upgrade", "websocket")
 | |
| 	req.Header.Add("Connection", "upgrade")
 | |
| 
 | |
| 	err = req.Write(conn)
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	// First request works with 400
 | |
| 	br := bufio.NewReader(conn)
 | |
| 	resp, err := http.ReadResponse(br, req)
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	assert.Equal(t, 400, resp.StatusCode)
 | |
| }
 | |
| 
 | |
| func TestForwardsWebsocketTraffic(t *testing.T) {
 | |
| 	mux := http.NewServeMux()
 | |
| 	mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
 | |
| 		_, err := conn.Write([]byte("ok"))
 | |
| 		require.NoError(t, err)
 | |
| 
 | |
| 		err = conn.Close()
 | |
| 		require.NoError(t, err)
 | |
| 	}))
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		mux.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	defer proxy.Close()
 | |
| 
 | |
| 	proxyAddr := proxy.Listener.Addr().String()
 | |
| 	resp, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 		withData("echo"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	assert.Equal(t, "ok", resp)
 | |
| }
 | |
| 
 | |
| func createTLSWebsocketServer() *httptest.Server {
 | |
| 	upgrader := gorillawebsocket.Upgrader{}
 | |
| 	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		conn, err := upgrader.Upgrade(w, r, nil)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer conn.Close()
 | |
| 		for {
 | |
| 			mt, message, err := conn.ReadMessage()
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 
 | |
| 			err = conn.WriteMessage(mt, message)
 | |
| 			if err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}))
 | |
| 	return srv
 | |
| }
 | |
| 
 | |
| func TestWebSocketTransferTLSConfig(t *testing.T) {
 | |
| 	srv := createTLSWebsocketServer()
 | |
| 	defer srv.Close()
 | |
| 
 | |
| 	proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
 | |
| 	defer proxyWithoutTLSConfig.Close()
 | |
| 
 | |
| 	proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
 | |
| 
 | |
| 	_, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 		withData("ok"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.EqualError(t, err, "bad status")
 | |
| 
 | |
| 	pool := createConnectionPool(srv.URL, &tls.Config{InsecureSkipVerify: true})
 | |
| 
 | |
| 	proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, pool)
 | |
| 	defer proxyWithTLSConfig.Close()
 | |
| 
 | |
| 	proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
 | |
| 
 | |
| 	resp, err := newWebsocketRequest(
 | |
| 		withServer(proxyAddr),
 | |
| 		withPath("/ws"),
 | |
| 		withData("ok"),
 | |
| 	).send()
 | |
| 
 | |
| 	require.NoError(t, err)
 | |
| 	assert.Equal(t, "ok", resp)
 | |
| }
 | |
| 
 | |
| const dialTimeout = time.Second
 | |
| 
 | |
| type websocketRequestOpt func(w *websocketRequest)
 | |
| 
 | |
| func withServer(server string) websocketRequestOpt {
 | |
| 	return func(w *websocketRequest) {
 | |
| 		w.ServerAddr = server
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func withPath(path string) websocketRequestOpt {
 | |
| 	return func(w *websocketRequest) {
 | |
| 		w.Path = path
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func withData(data string) websocketRequestOpt {
 | |
| 	return func(w *websocketRequest) {
 | |
| 		w.Data = data
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func withOrigin(origin string) websocketRequestOpt {
 | |
| 	return func(w *websocketRequest) {
 | |
| 		w.Origin = origin
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
 | |
| 	wsrequest := &websocketRequest{}
 | |
| 	for _, opt := range opts {
 | |
| 		opt(wsrequest)
 | |
| 	}
 | |
| 
 | |
| 	if wsrequest.Origin == "" {
 | |
| 		wsrequest.Origin = "http://" + wsrequest.ServerAddr
 | |
| 	}
 | |
| 
 | |
| 	if wsrequest.Config == nil {
 | |
| 		wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
 | |
| 	}
 | |
| 
 | |
| 	return wsrequest
 | |
| }
 | |
| 
 | |
| type websocketRequest struct {
 | |
| 	ServerAddr string
 | |
| 	Path       string
 | |
| 	Data       string
 | |
| 	Origin     string
 | |
| 	Config     *websocket.Config
 | |
| }
 | |
| 
 | |
| func (w *websocketRequest) send() (string, error) {
 | |
| 	conn, _, err := w.open()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	defer conn.Close()
 | |
| 
 | |
| 	if _, err := conn.Write([]byte(w.Data)); err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	msg := make([]byte, 512)
 | |
| 
 | |
| 	var n int
 | |
| 	n, err = conn.Read(msg)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	received := string(msg[:n])
 | |
| 	return received, nil
 | |
| }
 | |
| 
 | |
| func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
 | |
| 	client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	conn, err := websocket.NewClient(w.Config, client)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	return conn, client, err
 | |
| }
 | |
| 
 | |
| func parseURI(t *testing.T, uri string) *url.URL {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	out, err := url.ParseRequestURI(uri)
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	return out
 | |
| }
 | |
| 
 | |
| func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
 | |
| 	u := testhelpers.MustParseURL(target)
 | |
| 	return newConnPool(200, 0, 0, func() (net.Conn, error) {
 | |
| 		if tlsConfig != nil {
 | |
| 			return tls.Dial("tcp", u.Host, tlsConfig)
 | |
| 		}
 | |
| 
 | |
| 		return net.Dial("tcp", u.Host)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptest.Server {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	u := parseURI(t, uri)
 | |
| 	proxy, err := NewReverseProxy(u, nil, false, true, false, pool)
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | |
| 		path := req.URL.Path // keep the original path
 | |
| 		// Set new backend URL
 | |
| 		req.URL = u
 | |
| 		req.URL.Path = path
 | |
| 
 | |
| 		proxy.ServeHTTP(w, req)
 | |
| 	}))
 | |
| 	t.Cleanup(srv.Close)
 | |
| 
 | |
| 	return srv
 | |
| }
 | |
| 
 | |
// computeAcceptKey returns the base64-encoded SHA-1 digest of challengeKey
// followed by the fixed GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", which the
// tests use as the Sec-WebSocket-Accept value.
func computeAcceptKey(challengeKey string) string {
	const websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

	digest := sha1.Sum([]byte(challengeKey + websocketGUID)) // #nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54
	return base64.StdEncoding.EncodeToString(digest[:])
}
 |