mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-06 04:36:15 +02:00
ssh/tailssh: explore client connection monitoring
Run a connection monitor that pings the SSH client while a session is being recorded. If the pings fail consecutively, cancel the session and close the recording. This is one way to ensure that session recordings are flushed promptly when using S3 multi-part upload. The ping timeout, ping interval, and consecutive-failure threshold are hardcoded because this is just an experiment. Fixes tailscale.com/corp#33968 Signed-off-by: Gesa Stupperich <gesa@tailscale.com>
This commit is contained in:
parent
1eba5b0cbd
commit
9f3da7ab26
@ -32,6 +32,7 @@ import (
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/sessionrecording"
|
||||
@ -76,6 +77,7 @@ type ipnLocalBackend interface {
|
||||
Dialer() *tsdial.Dialer
|
||||
TailscaleVarRoot() string
|
||||
NodeKey() key.NodePublic
|
||||
Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error)
|
||||
}
|
||||
|
||||
type server struct {
|
||||
@ -834,6 +836,7 @@ func (c *conn) detachSession(ss *sshSession) {
|
||||
}
|
||||
|
||||
var errSessionDone = errors.New("session is done")
|
||||
var errClientUnreachable = errors.New("client is unreachable")
|
||||
|
||||
// handleSSHAgentForwarding starts a Unix socket listener and in the background
|
||||
// forwards agent connections between the listener and the ssh.Session.
|
||||
@ -954,6 +957,57 @@ func (ss *sshSession) run() {
|
||||
ss.logf("startNewRecording: <nil>")
|
||||
if rec != nil {
|
||||
defer rec.Close()
|
||||
|
||||
ping := func() bool {
|
||||
clientIP := ss.conn.info.src.Addr()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := ss.conn.srv.lb.Ping(ctx, clientIP, tailcfg.PingICMP, 0)
|
||||
if err != nil {
|
||||
ss.logf("pinging SSH client %s failed: %v", clientIP, err)
|
||||
return false
|
||||
}
|
||||
|
||||
ss.logf("pinging SSH client %s successful", clientIP)
|
||||
return true
|
||||
}
|
||||
|
||||
go func() {
|
||||
ss.logf("starting connection monitor for session %s", ss.sharedID)
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
consecutiveFailures := 0
|
||||
const maxFailures = 3
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ss.ctx.Done():
|
||||
ss.logf("session terminated, closing recording: %v", context.Cause(ss.ctx))
|
||||
rec.Close()
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
pong := ping()
|
||||
if pong {
|
||||
consecutiveFailures = 0
|
||||
ss.logf("connection test passed for session %s", ss.sharedID)
|
||||
} else {
|
||||
consecutiveFailures++
|
||||
ss.logf("connection test failed (%d/%d) for session %s", consecutiveFailures, maxFailures, ss.sharedID)
|
||||
|
||||
if consecutiveFailures >= maxFailures {
|
||||
ss.logf("connection lost (connection test failed %d times), closing recording", maxFailures)
|
||||
ss.cancelCtx(errClientUnreachable)
|
||||
rec.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user