diff --git a/ssh/tailssh/tailssh_integration_test.go b/ssh/tailssh/tailssh_integration_test.go index 7b70a6d51..0d9760db7 100644 --- a/ssh/tailssh/tailssh_integration_test.go +++ b/ssh/tailssh/tailssh_integration_test.go @@ -951,3 +951,193 @@ func (conn *addressFakingConn) RemoteAddr() net.Addr { Port: 10002, } } + +// TestIntegrationExitCodes verifies that SSH exit codes are correctly +// delivered to the client through the full server stack. +func TestIntegrationExitCodes(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { debugTest.Store(false) }) + + tests := []struct { + name string + cmd string + wantCode int + }{ + { + name: "success", + cmd: "true", + wantCode: 0, + }, + { + name: "exit_code_passthrough", + cmd: "exit 42", + wantCode: 42, + }, + { + // Exit code 127 for command not found, per POSIX shell convention. + // https://pubs.opengroup.org/onlinepubs/9699919799/utilities/V3_chap02.html#tag_18_08_02 + name: "command_not_found", + cmd: "/nonexistent/binary", + wantCode: 127, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := testSession(t, false, false, nil) + err := s.Run(tt.cmd) + if tt.wantCode == 0 { + if err != nil { + t.Fatalf("expected exit code 0, got error: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected exit code %d, got nil error", tt.wantCode) + } + var exitErr *ssh.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected *ssh.ExitError, got %T: %v", err, err) + } + if exitErr.ExitStatus() != tt.wantCode { + t.Errorf("exit code = %d, want %d", exitErr.ExitStatus(), tt.wantCode) + } + }) + } +} + +// TestLocalUnixForwardingHalfClose verifies that the bidirectional copy +// in Unix socket forwarding uses half-close correctly: when one direction +// finishes, the other direction's data is not lost. This tests the bicopy +// fix where the old cancel-on-first-direction-complete approach would +// prematurely teardown the slower direction. +func TestLocalUnixForwardingHalfClose(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { debugTest.Store(false) }) + + socketDir, err := os.MkdirTemp("", "tailssh-test-") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(socketDir) }) + socketPath := filepath.Join(socketDir, "halfclose.sock") + + ul, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { ul.Close() }) + + // Service that reads all input, then sends a delayed response. + // With the old bicopy (cancel on first direction complete), + // the response would be lost because the channel would be torn + // down when the client's write side finished. + const response = "delayed-response-after-client-closes-write" + go func() { + for { + conn, err := ul.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + // Read all input from client. + io.ReadAll(conn) + // Delay, then send response. + time.Sleep(100 * time.Millisecond) + io.WriteString(conn, response) + }() + } + }() + + addr := testServerWithOpts(t, testServerOpts{ + username: "testuser", + allowLocalPortForwarding: true, + }) + + cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cl.Close() }) + + conn, err := cl.Dial("unix", socketPath) + if err != nil { + t.Fatalf("failed to dial unix socket through SSH: %s", err) + } + + // Send data and close write side (half-close). + _, err = io.WriteString(conn, "request data") + if err != nil { + t.Fatalf("failed to write: %s", err) + } + if tc, ok := conn.(*net.TCPConn); ok { + tc.CloseWrite() + } else { + // ssh.Conn doesn't expose CloseWrite, close the write side + // by closing the whole conn -- but then we can't read. + // Instead, just close and rely on the server seeing EOF. + conn.Close() + } + + // Read the delayed response. This is the critical assertion: + // with the old bicopy, this would fail because the connection + // would be torn down when we closed the write side. + got, err := io.ReadAll(conn) + if err != nil { + t.Fatalf("failed to read response: %s", err) + } + if string(got) != response { + t.Errorf("got %q, want %q", got, response) + } +} + +// TestIntegrationSIGHUP verifies that when an SSH session is terminated, +// the child process receives SIGHUP (matching POSIX terminal disconnect +// semantics) rather than SIGKILL. +func TestIntegrationSIGHUP(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { debugTest.Store(false) }) + + markerFile := filepath.Join(t.TempDir(), "sighup-received") + + cl := testClient(t, false, false) + s, err := cl.NewSession() + if err != nil { + t.Fatal(err) + } + + // Start a process that traps SIGHUP and writes a marker file. + // The process sleeps long enough for us to close the session. + cmd := fmt.Sprintf( + `trap 'echo received > %s; exit 0' HUP; sleep 30`, + markerFile, + ) + if err := s.Start(cmd); err != nil { + t.Fatalf("failed to start command: %v", err) + } + + // Give the process time to set up the signal handler. + time.Sleep(500 * time.Millisecond) + + // Close the session, which should trigger SIGHUP to the process. + s.Close() + cl.Close() + + // Wait for the signal handler to run and write the marker file. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if _, err := os.Stat(markerFile); err == nil { + // Marker file exists -- process received SIGHUP. + data, _ := os.ReadFile(markerFile) + if strings.TrimSpace(string(data)) != "received" { + t.Fatalf("unexpected marker content: %q", data) + } + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatal("process did not receive SIGHUP within timeout") +}