mirror of
https://github.com/tailscale/tailscale.git
synced 2025-10-27 06:11:43 +01:00
The hijacker on k8s-proxy's reverse proxy is used to stream recordings to tsrecorder as they pass through the proxy to the kubernetes api server. The connection to the recorder was using the client's (e.g., kubectl) context, rather than a dedicated one. This was causing the recording stream to get cut off in scenarios where the client cancelled the context before streaming could be completed. By using a dedicated context, we can continue streaming even if the client cancels the context (for example if the client request completes). Fixes #17404 Signed-off-by: chaosinthecrd <tom@tmlabs.co.uk>
127 lines
3.6 KiB
Go
127 lines
3.6 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
//go:build !plan9
|
|
|
|
package sessionrecording
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/netip"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
"tailscale.com/client/tailscale/apitype"
|
|
"tailscale.com/k8s-operator/sessionrecording/fakes"
|
|
"tailscale.com/net/netx"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/tsnet"
|
|
"tailscale.com/tstest"
|
|
)
|
|
|
|
func Test_Hijacker(t *testing.T) {
|
|
zl, err := zap.NewDevelopment()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
failOpen bool
|
|
failRecorderConnect bool // fail initial connect to the recorder
|
|
failRecorderConnPostConnect bool // send error down the error channel
|
|
wantsConnClosed bool
|
|
wantsSetupErr bool
|
|
proto Protocol
|
|
}{
|
|
{
|
|
name: "setup_succeeds_conn_stays_open",
|
|
proto: SPDYProtocol,
|
|
},
|
|
{
|
|
name: "setup_succeeds_conn_stays_open_ws",
|
|
proto: WSProtocol,
|
|
},
|
|
{
|
|
name: "setup_fails_policy_is_to_fail_open_conn_stays_open",
|
|
failOpen: true,
|
|
failRecorderConnect: true,
|
|
proto: SPDYProtocol,
|
|
},
|
|
{
|
|
name: "setup_fails_policy_is_to_fail_closed_conn_is_closed",
|
|
failRecorderConnect: true,
|
|
wantsSetupErr: true,
|
|
wantsConnClosed: true,
|
|
proto: SPDYProtocol,
|
|
},
|
|
{
|
|
name: "connection_fails_post-initial_connect_policy_is_to_fail_open_conn_stays_open",
|
|
failRecorderConnPostConnect: true,
|
|
failOpen: true,
|
|
proto: SPDYProtocol,
|
|
},
|
|
{
|
|
name: "connection_fails_post-initial_connect,_policy_is_to_fail_closed_conn_is_closed",
|
|
failRecorderConnPostConnect: true,
|
|
wantsConnClosed: true,
|
|
proto: SPDYProtocol,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tc := &fakes.TestConn{}
|
|
ch := make(chan error)
|
|
h := &Hijacker{
|
|
connectToRecorder: func(context.Context,
|
|
[]netip.AddrPort,
|
|
netx.DialFunc,
|
|
) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) {
|
|
if tt.failRecorderConnect {
|
|
err = errors.New("test")
|
|
}
|
|
return wc, rec, ch, err
|
|
},
|
|
failOpen: tt.failOpen,
|
|
who: &apitype.WhoIsResponse{Node: &tailcfg.Node{}, UserProfile: &tailcfg.UserProfile{}},
|
|
log: zl.Sugar(),
|
|
ts: &tsnet.Server{},
|
|
req: &http.Request{URL: &url.URL{RawQuery: "tty=true"}},
|
|
proto: tt.proto,
|
|
}
|
|
ctx := context.Background()
|
|
_, err := h.setUpRecording(tc)
|
|
if (err != nil) != tt.wantsSetupErr {
|
|
t.Errorf("spdyHijacker.setupRecording() error = %v, wantErr %v", err, tt.wantsSetupErr)
|
|
return
|
|
}
|
|
if tt.failRecorderConnPostConnect {
|
|
select {
|
|
case ch <- errors.New("err"):
|
|
case <-time.After(time.Second * 15):
|
|
t.Errorf("error from recorder conn was not read within 15 seconds")
|
|
}
|
|
}
|
|
timeout := time.Second * 20
|
|
// TODO (irbekrm): cover case where an error is received
|
|
// over channel and the failure policy is to fail open
|
|
// (test that connection remains open over some period
|
|
// of time).
|
|
if err := tstest.WaitFor(timeout, func() (err error) {
|
|
if tt.wantsConnClosed != tc.IsClosed() {
|
|
return fmt.Errorf("got connection state: %t, wants connection state: %t", tc.IsClosed(), tt.wantsConnClosed)
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
t.Errorf("connection did not reach the desired state within %s", timeout.String())
|
|
}
|
|
ctx.Done()
|
|
})
|
|
}
|
|
}
|