diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 04c9cd2f5..20374bf6f 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -523,6 +523,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { handler func(w http.ResponseWriter, r *http.Request) sshCommand string wantClientOutput string + wantExitCode int // expected SSH exit code; 0 means don't check clientOutputMustNotContain []string }{ @@ -533,6 +534,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { }, sshCommand: "echo hello", wantClientOutput: "session rejected\r\n", + wantExitCode: 254, clientOutputMustNotContain: []string{"hello"}, }, @@ -580,6 +582,16 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { } else { t.Errorf("client did not get kicked out: %q", got) } + if tt.wantExitCode != 0 { + var exitErr *testssh.ExitError + if errors.As(err, &exitErr) { + if exitErr.ExitStatus() != tt.wantExitCode { + t.Errorf("exit code = %d, want %d", exitErr.ExitStatus(), tt.wantExitCode) + } + } else { + t.Errorf("expected *ssh.ExitError, got %T: %v", err, err) + } + } gotStr := string(got) if !strings.HasSuffix(gotStr, tt.wantClientOutput) { t.Errorf("client got %q, want %q", got, tt.wantClientOutput) @@ -1218,6 +1230,45 @@ func TestSSH(t *testing.T) { t.Errorf("got %q; want %q", got, str) } }) + + t.Run("exit_code_zero", func(t *testing.T) { + cmd := execSSH("true") + if err := cmd.Run(); err != nil { + t.Fatalf("expected exit code 0, got error: %v", err) + } + }) + + t.Run("exit_code_passthrough", func(t *testing.T) { + cmd := execSSH("exit 42") + err := cmd.Run() + if err == nil { + t.Fatal("expected non-zero exit code") + } + var ee *exec.ExitError + if !errors.As(err, &ee) { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + if got := ee.ExitCode(); got != 42 { + t.Errorf("exit code = %d, want 42", got) + } + }) + + t.Run("exit_code_127_command_not_found", func(t *testing.T) { + cmd := execSSH("/nonexistent/binary") + err := cmd.Run() + if err == nil { + t.Fatal("expected non-zero exit code") + } + var ee *exec.ExitError + if !errors.As(err, &ee) { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + // Exit code 127 for command not found, per POSIX shell convention. + // https://pubs.opengroup.org/onlinepubs/9699919799/utilities/V3_chap02.html#tag_18_08_02 + if got := ee.ExitCode(); got != 127 { + t.Errorf("exit code = %d, want 127", got) + } + }) } func parseEnv(out []byte) map[string]string {