feature/rsh: add a non-ssh encapsulated remote shell

The intent of this mode is to quack a bit like the long gone rsh, but
implemented purely inside the client and providing a faster way to use
rsync over tailscale by way of `rsync -e 'tailscale rsh dest`.

This is an initial prototype.
This commit is contained in:
James Tucker 2026-02-22 18:07:45 -08:00
parent 0bac4223d1
commit d4b565aa48
No known key found for this signature in database
13 changed files with 2573 additions and 0 deletions

View File

@ -259,6 +259,7 @@ change in the future.
pingCmd,
ncCmd,
sshCmd,
rshCmd,
nilOrCall(maybeFunnelCmd),
nilOrCall(maybeServeCmd),
versionCmd,

450
cmd/tailscale/cli/rsh.go Normal file
View File

@ -0,0 +1,450 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package cli
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"
"github.com/peterbourgon/ff/v3/ffcli"
"tailscale.com/client/tailscale/apitype"
)
var rshArgs struct {
loginUser string // -l flag: SSH login user
sshOption string // -o flag: SSH option (ignored, for compatibility)
}
var rshCmd = &ffcli.Command{
Name: "rsh",
ShortUsage: "tailscale rsh [-l user] [user@]<host> [command...]",
ShortHelp: "Execute a remote command over Tailscale without SSH overhead",
LongHelp: strings.TrimSpace(`
The 'tailscale rsh' command executes a command on a remote Tailscale node
using a direct TCP connection over the Tailscale network. Unlike SSH, it
avoids double encryption (SSH + WireGuard) and SSH's suboptimal buffering.
It is designed to be used as an rsync -e transport replacement:
rsync -e 'tailscale rsh' -avz ./local/ user@host:/remote/
The remote node must have Tailscale SSH enabled, as rsh reuses the same
SSH access policy for authorization.
SSH-compatible flags (-l user, -o option) are accepted and handled
appropriately so that rsync and similar tools can invoke rsh as a
drop-in remote shell replacement.
When used without a command, it starts the user's default login shell.
`),
FlagSet: func() *flag.FlagSet {
fs := newFlagSet("rsh")
fs.StringVar(&rshArgs.loginUser, "l", "", "remote login user (SSH-compatible)")
fs.StringVar(&rshArgs.sshOption, "o", "", "SSH option (ignored, for compatibility)")
return fs
}(),
Exec: runRsh,
}
// rshFraming constants matching feature/rsh/protocol.go.
const (
rshChanStdin byte = 0x00
rshChanStdout byte = 0x01
rshChanStderr byte = 0x02
rshChanExit byte = 0x03
rshTokenLen = 32
rshMaxFrame = 256 * 1024
rshFrameHdrSize = 5
)
func runRsh(ctx context.Context, args []string) error {
if len(args) == 0 {
return errors.New("usage: tailscale rsh [user@]<host> [command...]")
}
// Check tailscaled is running.
st, err := localClient.Status(ctx)
if err != nil {
return fixTailscaledConnectError(err)
}
description, ok := isRunningOrStarting(st)
if !ok {
printf("%s\n", description)
os.Exit(1)
}
username, host, cmdArgs, err := parseRshArgs(args)
if err != nil {
return err
}
// The -l flag parsed by ffcli takes priority over user@host.
// This handles cases like: tailscale rsh -l ubuntu james-ai
// where ffcli parses -l before runRsh sees the args.
if rshArgs.loginUser != "" {
username = rshArgs.loginUser
}
// If no explicit user, default to the current OS user.
if username == "" {
u, err := currentUser()
if err != nil {
return fmt.Errorf("cannot determine current user: %w", err)
}
username = u
}
// Resolve host to a peer.
ps, ok := peerStatusFromArg(st, host)
if !ok {
return fmt.Errorf("unknown host %q; not found in Tailscale network", host)
}
// Build the command string (rsync passes it as separate args).
command := strings.Join(cmdArgs, " ")
// Request an rsh session via the LocalAPI.
type localRshRequest struct {
PeerID string `json:"peer"`
User string `json:"user"`
Command string `json:"command,omitempty"`
}
reqBody := localRshRequest{
PeerID: string(ps.ID),
User: username,
Command: command,
}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST",
"http://"+apitype.LocalAPIHost+"/localapi/v0/rsh",
bytes.NewReader(bodyBytes))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := localClient.DoLocalRequest(req)
if err != nil {
return fmt.Errorf("rsh setup: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return fmt.Errorf("rsh setup failed: %s: %s", resp.Status, strings.TrimSpace(string(body)))
}
type rshResponse struct {
Addr string `json:"addr"`
Token string `json:"token"`
}
type rshStatusMessage struct {
Status string `json:"status"`
}
var rshResp rshResponse
ct := resp.Header.Get("Content-Type")
if strings.HasPrefix(ct, "application/x-ndjson") {
// Streaming check mode: read newline-delimited JSON lines.
// Status messages go to stderr, the final rshResponse has addr+token.
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 64*1024)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
// Try to decode as rshResponse (has "addr" field).
var candidate rshResponse
if err := json.Unmarshal(line, &candidate); err == nil && candidate.Addr != "" {
rshResp = candidate
continue
}
// Otherwise, treat as a status message.
var msg rshStatusMessage
if err := json.Unmarshal(line, &msg); err == nil && msg.Status != "" {
fmt.Fprintf(os.Stderr, "rsh: %s\n", msg.Status)
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("rsh: reading streaming response: %w", err)
}
} else {
// Simple JSON response.
if err := json.NewDecoder(resp.Body).Decode(&rshResp); err != nil {
return fmt.Errorf("rsh: invalid response: %w", err)
}
}
resp.Body.Close()
if rshResp.Addr == "" || rshResp.Token == "" {
return errors.New("rsh: server returned empty address or token")
}
// Parse the address to get host and port for DialTCP.
addrHost, portStr, err := splitHostPort(rshResp.Addr)
if err != nil {
return fmt.Errorf("rsh: invalid address %q: %w", rshResp.Addr, err)
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return fmt.Errorf("rsh: invalid port %q: %w", portStr, err)
}
// Decode the token.
token, err := hex.DecodeString(rshResp.Token)
if err != nil || len(token) != rshTokenLen {
return fmt.Errorf("rsh: invalid token")
}
// Connect to the data channel via tailscaled.
conn, err := localClient.DialTCP(ctx, addrHost, uint16(port))
if err != nil {
return fmt.Errorf("rsh: connect to %s: %w", rshResp.Addr, err)
}
defer conn.Close()
// Send the authentication token.
if _, err := conn.Write(token); err != nil {
return fmt.Errorf("rsh: send token: %w", err)
}
// Run the framing protocol.
return rshPumpIO(conn)
}
// rshPumpIO handles the framing protocol between the local stdin/stdout/stderr
// and the remote process over the connection.
func rshPumpIO(conn io.ReadWriteCloser) error {
// Goroutine: read stdin and send as ChanStdin frames.
stdinDone := make(chan struct{})
go func() {
defer close(stdinDone)
buf := make([]byte, 64*1024)
for {
n, err := os.Stdin.Read(buf)
if n > 0 {
if werr := writeFrame(conn, rshChanStdin, buf[:n]); werr != nil {
return
}
}
if err != nil {
// Send a zero-length stdin frame to signal EOF.
writeFrame(conn, rshChanStdin, nil)
return
}
}
}()
// Main loop: read frames from the connection and dispatch.
var hdr [rshFrameHdrSize]byte
for {
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
// Connection closed without exit code.
return fmt.Errorf("rsh: connection closed unexpectedly")
}
return fmt.Errorf("rsh: read frame: %w", err)
}
ch := hdr[0]
n := binary.BigEndian.Uint32(hdr[1:])
if n > rshMaxFrame {
return fmt.Errorf("rsh: frame too large: %d", n)
}
switch ch {
case rshChanStdout:
if _, err := io.CopyN(os.Stdout, conn, int64(n)); err != nil {
return fmt.Errorf("rsh: stdout: %w", err)
}
case rshChanStderr:
if _, err := io.CopyN(os.Stderr, conn, int64(n)); err != nil {
return fmt.Errorf("rsh: stderr: %w", err)
}
case rshChanExit:
if n != 4 {
return fmt.Errorf("rsh: invalid exit frame size: %d", n)
}
var exitBuf [4]byte
if _, err := io.ReadFull(conn, exitBuf[:]); err != nil {
return fmt.Errorf("rsh: read exit code: %w", err)
}
code := int(binary.BigEndian.Uint32(exitBuf[:]))
if code != 0 {
os.Exit(code)
}
return nil
default:
// Unknown channel, skip the payload.
if _, err := io.CopyN(io.Discard, conn, int64(n)); err != nil {
return fmt.Errorf("rsh: skip unknown frame: %w", err)
}
}
}
}
// writeFrame writes a single rsh protocol frame to w.
func writeFrame(w io.Writer, ch byte, data []byte) error {
var hdr [rshFrameHdrSize]byte
hdr[0] = ch
binary.BigEndian.PutUint32(hdr[1:], uint32(len(data)))
if _, err := w.Write(hdr[:]); err != nil {
return err
}
if len(data) > 0 {
if _, err := w.Write(data); err != nil {
return err
}
}
return nil
}
// parseRshArgs parses SSH-compatible arguments as passed by rsync and
// similar tools when using rsh as a remote shell transport.
//
// rsync invokes the remote shell as:
//
// tailscale rsh [user@host] [-l user] [-o option]... <host> <command...>
//
// The user@host may appear as the first positional arg (from the rsync
// URI), while -l overrides the username. The bare hostname after flags
// is the actual target. Everything after that is the remote command.
//
// Returns the resolved username (may be empty if none specified), host,
// and command args.
func parseRshArgs(args []string) (username, host string, cmdArgs []string, err error) {
if len(args) == 0 {
return "", "", nil, errors.New("usage: tailscale rsh [-l user] [user@]<host> [command...]")
}
// First, check if args[0] is a user@host or bare host (not a flag).
// rsync passes the user@host from the rsync URI as the first arg,
// before any -l flag.
i := 0
if !strings.HasPrefix(args[0], "-") {
u, h, hasAt := strings.Cut(args[0], "@")
if hasAt {
username = u
host = h
} else {
// Bare hostname (no @). Record it; it may be
// overridden if a second bare hostname appears
// after flags (the rsync pattern).
host = args[0]
}
i = 1
}
// Parse SSH-compatible flags.
flagUser := ""
hadFlags := false
for i < len(args) {
a := args[i]
if a == "--" {
i++
break
}
if !strings.HasPrefix(a, "-") {
break // first non-flag is the host
}
hadFlags = true
switch {
case a == "-l":
// -l <user>
i++
if i >= len(args) {
return "", "", nil, errors.New("rsh: -l requires an argument")
}
flagUser = args[i]
i++
case strings.HasPrefix(a, "-l"):
// -l<user> (no space)
flagUser = a[2:]
i++
case a == "-o":
// -o <option>: SSH option, ignore.
i++
if i < len(args) {
i++ // skip the option value
}
case strings.HasPrefix(a, "-o"):
// -o<option>: SSH option, ignore.
i++
default:
// Unknown flag (e.g. -4, -6, -p, etc.); skip it.
// SSH has many flags; we silently ignore ones we
// don't understand since we don't need them.
i++
}
}
// After flags, the next non-flag arg is the host. rsync passes
// the bare hostname after -l flags, so we expect it here when
// flags were present. When there were no flags and we already
// have a host from args[0], the remaining args are the command.
if hadFlags && i < len(args) && !strings.HasPrefix(args[i], "-") {
host = args[i]
i++
}
// Everything remaining is the command.
cmdArgs = args[i:]
// -l flag overrides any user from user@host.
if flagUser != "" {
username = flagUser
}
if host == "" {
return "", "", nil, errors.New("usage: tailscale rsh [-l user] [user@]<host> [command...]")
}
return username, host, cmdArgs, nil
}
// splitHostPort splits a host:port string. Unlike net.SplitHostPort,
// it handles bare IPv4 addresses with port (100.1.2.3:1234) as well
// as [IPv6]:port format.
func splitHostPort(addr string) (host, port string, err error) {
// Handle IPv6 [::]:port format.
if strings.HasPrefix(addr, "[") {
end := strings.Index(addr, "]:")
if end < 0 {
return "", "", fmt.Errorf("invalid address: %s", addr)
}
return addr[1:end], addr[end+2:], nil
}
// Handle IPv4 host:port.
i := strings.LastIndex(addr, ":")
if i < 0 {
return "", "", fmt.Errorf("no port in address: %s", addr)
}
return addr[:i], addr[i+1:], nil
}
// currentUser returns the current OS username.
func currentUser() (string, error) {
// os/user.Current() can fail in some environments (static builds, etc).
// Try it first, fall back to env vars.
if u := os.Getenv("USER"); u != "" {
return u, nil
}
return "", errors.New("cannot determine current user")
}

View File

@ -0,0 +1,142 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package cli
import (
"strings"
"testing"
)
func TestParseRshArgs(t *testing.T) {
tests := []struct {
name string
args []string
wantUser string
wantHost string
wantCmd string // joined with space
wantErr bool
}{
{
name: "simple_user_at_host_with_command",
args: []string{"alice@myhost", "ls", "-la"},
wantUser: "alice",
wantHost: "myhost",
wantCmd: "ls -la",
},
{
name: "bare_host_with_command",
args: []string{"myhost", "ls"},
wantUser: "",
wantHost: "myhost",
wantCmd: "ls",
},
{
name: "bare_host_no_command",
args: []string{"myhost"},
wantUser: "",
wantHost: "myhost",
wantCmd: "",
},
{
// This is the exact pattern from rsync:
// tailscale rsh ubuntu@james-ai -l ubuntu james-ai rsync --server --sender -vlogDtpre.iLsfxCIvu . ai/
name: "rsync_pattern",
args: []string{"ubuntu@james-ai", "-l", "ubuntu", "james-ai", "rsync", "--server", "--sender", "-vlogDtpre.iLsfxCIvu", ".", "ai/"},
wantUser: "ubuntu",
wantHost: "james-ai",
wantCmd: "rsync --server --sender -vlogDtpre.iLsfxCIvu . ai/",
},
{
name: "l_flag_overrides_user_at_host",
args: []string{"alice@myhost", "-l", "bob", "myhost", "echo", "hi"},
wantUser: "bob",
wantHost: "myhost",
wantCmd: "echo hi",
},
{
name: "l_flag_no_space",
args: []string{"myhost", "-lubuntu", "myhost", "ls"},
wantUser: "ubuntu",
wantHost: "myhost",
wantCmd: "ls",
},
{
name: "l_flag_without_user_at_host",
args: []string{"-l", "ubuntu", "myhost", "rsync", "--server"},
wantUser: "ubuntu",
wantHost: "myhost",
wantCmd: "rsync --server",
},
{
name: "o_flag_ignored",
args: []string{"alice@myhost", "-o", "StrictHostKeyChecking=no", "myhost", "ls"},
wantUser: "alice",
wantHost: "myhost",
wantCmd: "ls",
},
{
name: "o_flag_no_space_ignored",
args: []string{"alice@myhost", "-oStrictHostKeyChecking=no", "myhost", "ls"},
wantUser: "alice",
wantHost: "myhost",
wantCmd: "ls",
},
{
name: "multiple_flags",
args: []string{"alice@myhost", "-o", "BatchMode=yes", "-l", "root", "myhost", "uptime"},
wantUser: "root",
wantHost: "myhost",
wantCmd: "uptime",
},
{
name: "unknown_flags_skipped",
args: []string{"alice@myhost", "-4", "-p", "myhost", "ls"},
wantUser: "alice",
wantHost: "myhost",
wantCmd: "ls",
},
{
name: "double_dash_separator",
args: []string{"myhost", "--", "-l", "this-is-command"},
wantUser: "",
wantHost: "myhost",
wantCmd: "-l this-is-command",
},
{
name: "empty_args",
args: []string{},
wantErr: true,
},
{
name: "l_flag_missing_value",
args: []string{"myhost", "-l"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, host, cmdArgs, err := parseRshArgs(tt.args)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user != tt.wantUser {
t.Errorf("user = %q, want %q", user, tt.wantUser)
}
if host != tt.wantHost {
t.Errorf("host = %q, want %q", host, tt.wantHost)
}
gotCmd := strings.Join(cmdArgs, " ")
if gotCmd != tt.wantCmd {
t.Errorf("command = %q, want %q", gotCmd, tt.wantCmd)
}
})
}
}

View File

@ -0,0 +1,8 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd) && !ts_omit_rsh
package condregister
import _ "tailscale.com/feature/rsh"

View File

@ -0,0 +1,286 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
package rsh
import (
"bufio"
"bytes"
"encoding/json"
"net/netip"
"os/user"
"strings"
"testing"
"time"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
)
func TestExpandDelegateURL(t *testing.T) {
nm := &netmap.NetworkMap{
SelfNode: (&tailcfg.Node{
ID: 42,
StableID: "self-stable",
Key: key.NewNode().Public(),
Addresses: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
},
}).View(),
}
peerNode := (&tailcfg.Node{
ID: 99,
StableID: "peer-stable",
Key: key.NewNode().Public(),
}).View()
peerAddr := netip.MustParseAddr("100.64.1.2")
lu := &user.User{Username: "localice"}
tests := []struct {
name string
url string
want string
}{
{
name: "all_variables",
url: "https://control.example.com/check?src=$SRC_NODE_IP&srcid=$SRC_NODE_ID&dst=$DST_NODE_IP&dstid=$DST_NODE_ID&sshuser=$SSH_USER&local=$LOCAL_USER",
want: "https://control.example.com/check?src=100.64.1.2&srcid=99&dst=100.64.0.1&dstid=42&sshuser=alice&local=localice",
},
{
name: "no_variables",
url: "https://control.example.com/check?static=true",
want: "https://control.example.com/check?static=true",
},
{
name: "url_encoding",
url: "https://control.example.com/check?user=$SSH_USER",
want: "https://control.example.com/check?user=alice%40example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sshUser := "alice"
if tt.name == "url_encoding" {
sshUser = "alice@example.com"
}
got := expandDelegateURL(tt.url, nm, peerNode, peerAddr, sshUser, lu)
if got != tt.want {
t.Errorf("expandDelegateURL() =\n %s\nwant:\n %s", got, tt.want)
}
})
}
}
func TestWriteNDJSON(t *testing.T) {
var buf bytes.Buffer
// Write a status message.
writeNDJSON(&buf, nil, rshStatusMessage{Status: "waiting"})
// Write an rshResponse.
writeNDJSON(&buf, nil, rshResponse{Addr: "100.64.0.1:1234", Token: "abcd"})
// Verify output is two newline-delimited JSON lines.
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
if len(lines) != 2 {
t.Fatalf("got %d lines, want 2:\n%s", len(lines), buf.String())
}
// Verify first line is a status message.
var msg rshStatusMessage
if err := json.Unmarshal([]byte(lines[0]), &msg); err != nil {
t.Fatalf("unmarshal line 0: %v", err)
}
if msg.Status != "waiting" {
t.Errorf("status = %q, want %q", msg.Status, "waiting")
}
// Verify second line is a response.
var resp rshResponse
if err := json.Unmarshal([]byte(lines[1]), &resp); err != nil {
t.Fatalf("unmarshal line 1: %v", err)
}
if resp.Addr != "100.64.0.1:1234" {
t.Errorf("addr = %q, want %q", resp.Addr, "100.64.0.1:1234")
}
if resp.Token != "abcd" {
t.Errorf("token = %q, want %q", resp.Token, "abcd")
}
}
func TestNDJSONStreamParsing(t *testing.T) {
// Simulate a streaming NDJSON response as the CLI would see it.
var buf bytes.Buffer
writeNDJSON(&buf, nil, rshStatusMessage{Status: "Checking with control plane..."})
writeNDJSON(&buf, nil, rshStatusMessage{Status: "Waiting for approval..."})
writeNDJSON(&buf, nil, rshStatusMessage{Status: "Access approved"})
writeNDJSON(&buf, nil, rshResponse{Addr: "100.64.0.5:4567", Token: "deadbeef"})
// Parse like the CLI does.
scanner := bufio.NewScanner(&buf)
var statusMessages []string
var finalResp rshResponse
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
var candidate rshResponse
if err := json.Unmarshal(line, &candidate); err == nil && candidate.Addr != "" {
finalResp = candidate
continue
}
var msg rshStatusMessage
if err := json.Unmarshal(line, &msg); err == nil && msg.Status != "" {
statusMessages = append(statusMessages, msg.Status)
}
}
if len(statusMessages) != 3 {
t.Fatalf("got %d status messages, want 3", len(statusMessages))
}
if statusMessages[0] != "Checking with control plane..." {
t.Errorf("status[0] = %q, want %q", statusMessages[0], "Checking with control plane...")
}
if statusMessages[2] != "Access approved" {
t.Errorf("status[2] = %q, want %q", statusMessages[2], "Access approved")
}
if finalResp.Addr != "100.64.0.5:4567" {
t.Errorf("addr = %q, want %q", finalResp.Addr, "100.64.0.5:4567")
}
if finalResp.Token != "deadbeef" {
t.Errorf("token = %q, want %q", finalResp.Token, "deadbeef")
}
}
func TestEvalSSHPolicyHoldAndDelegate(t *testing.T) {
now := timeVal(2025, 1, 1)
node := (&tailcfg.Node{
ID: 1,
StableID: "stable1",
Key: key.NewNode().Public(),
}).View()
uprof := tailcfg.UserProfile{
LoginName: "alice@example.com",
}
srcAddr := netip.MustParseAddr("100.64.1.2")
tests := []struct {
name string
pol *tailcfg.SSHPolicy
sshUser string
wantResult evalResult
wantUser string
wantURL string
}{
{
name: "hold_and_delegate_with_message",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{
HoldAndDelegate: "https://control.example.com/approve?user=$SSH_USER",
Message: "Please approve in the admin panel",
},
},
},
},
sshUser: "alice",
wantResult: evalHoldDelegate,
wantUser: "alice",
wantURL: "https://control.example.com/approve?user=$SSH_USER",
},
{
name: "hold_with_specific_user_mapping",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "alice@example.com"}},
SSHUsers: map[string]string{"root": "admin"},
Action: &tailcfg.SSHAction{
HoldAndDelegate: "https://control.example.com/check",
},
},
},
},
sshUser: "root",
wantResult: evalHoldDelegate,
wantUser: "admin",
wantURL: "https://control.example.com/check",
},
{
name: "hold_rejects_unmapped_user",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"root": "admin"},
Action: &tailcfg.SSHAction{
HoldAndDelegate: "https://control.example.com/check",
},
},
},
},
sshUser: "unknown",
wantResult: evalRejectedUser,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
action, localUser, result := evalSSHPolicy(tt.pol, node, uprof, srcAddr, tt.sshUser, now)
if result != tt.wantResult {
t.Errorf("result = %v, want %v", result, tt.wantResult)
}
if tt.wantUser != "" && localUser != tt.wantUser {
t.Errorf("localUser = %q, want %q", localUser, tt.wantUser)
}
if tt.wantURL != "" && action != nil && action.HoldAndDelegate != tt.wantURL {
t.Errorf("HoldAndDelegate = %q, want %q", action.HoldAndDelegate, tt.wantURL)
}
if tt.wantResult == evalHoldDelegate && action == nil {
t.Error("expected non-nil action for evalHoldDelegate result")
}
})
}
}
func TestExpandDelegateURLNilFields(t *testing.T) {
// Test with minimal/nil fields to ensure no panics.
lu := &user.User{Username: "bob"}
peerAddr := netip.MustParseAddr("100.64.0.2")
// Nil netmap, invalid peer node.
got := expandDelegateURL(
"https://control.example.com/check?dst=$DST_NODE_ID&src=$SRC_NODE_ID",
nil,
tailcfg.NodeView{}, // invalid
peerAddr,
"bob",
lu,
)
// Should not panic; missing IDs should be empty strings.
if strings.Contains(got, "$DST_NODE_ID") {
t.Errorf("unexpanded variable in URL: %s", got)
}
if strings.Contains(got, "$SRC_NODE_ID") {
t.Errorf("unexpanded variable in URL: %s", got)
}
}
func timeVal(year, month, day int) time.Time {
return time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
}

156
feature/rsh/localapi.go Normal file
View File

@ -0,0 +1,156 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
package rsh
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"tailscale.com/ipn/localapi"
"tailscale.com/tailcfg"
"tailscale.com/util/clientmetric"
)
var (
metricLocalAPIRshCalls = clientmetric.NewCounter("localapi_rsh")
)
func init() {
localapi.Register("rsh", serveRsh)
}
// localRshRequest is the JSON body the CLI sends to POST /localapi/v0/rsh.
// It includes the target peer information.
type localRshRequest struct {
// PeerID is the StableNodeID of the target peer.
PeerID tailcfg.StableNodeID `json:"peer"`
// User is the SSH user to connect as.
User string `json:"user"`
// Command is the command to execute on the remote.
Command string `json:"command,omitempty"`
}
// serveRsh proxies an rsh setup request to the target peer's PeerAPI.
//
// POST /localapi/v0/rsh
//
// Request body: JSON localRshRequest
// Response body: JSON rshResponse (addr + token from the remote)
func serveRsh(h *localapi.Handler, w http.ResponseWriter, r *http.Request) {
metricLocalAPIRshCalls.Add(1)
if !h.PermitRead {
http.Error(w, "rsh access denied", http.StatusForbidden)
return
}
if r.Method != "POST" {
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
return
}
var req localRshRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
if req.PeerID == "" {
http.Error(w, "peer is required", http.StatusBadRequest)
return
}
if req.User == "" {
http.Error(w, "user is required", http.StatusBadRequest)
return
}
b := h.LocalBackend()
nm := b.NetMap()
if nm == nil {
http.Error(w, "no netmap available", http.StatusInternalServerError)
return
}
// Find the peer and its PeerAPI base URL.
var peerAPIBaseURL string
for _, p := range nm.Peers {
if p.StableID() == req.PeerID {
peerAPIBaseURL = b.PeerAPIBase(p)
break
}
}
if peerAPIBaseURL == "" {
http.Error(w, "peer not found or no PeerAPI available", http.StatusNotFound)
return
}
// Build the PeerAPI request.
peerReqBody := rshRequest{
User: req.User,
Command: req.Command,
}
bodyBytes, err := json.Marshal(peerReqBody)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
peerURL := strings.TrimRight(peerAPIBaseURL, "/") + "/v0/rsh"
peerReq, err := http.NewRequestWithContext(r.Context(), "POST", peerURL, bytes.NewReader(bodyBytes))
if err != nil {
http.Error(w, "internal error: "+err.Error(), http.StatusInternalServerError)
return
}
peerReq.Header.Set("Content-Type", "application/json")
// Use the PeerAPI transport to dial the remote peer.
tr := b.Dialer().PeerAPITransport()
resp, err := tr.RoundTrip(peerReq)
if err != nil {
h.Logf("rsh: failed to reach peer %s: %v", req.PeerID, err)
http.Error(w, fmt.Sprintf("failed to reach peer: %v", err), http.StatusBadGateway)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
h.Logf("rsh: peer returned status %d: %s", resp.StatusCode, string(body))
http.Error(w, fmt.Sprintf("peer error: %s", string(body)), resp.StatusCode)
return
}
// Pass through the response from the peer. If the peer is using
// streaming NDJSON (check mode / HoldAndDelegate), we forward
// each line as it arrives so the CLI can display status messages.
ct := resp.Header.Get("Content-Type")
w.Header().Set("Content-Type", ct)
if strings.HasPrefix(ct, "application/x-ndjson") {
// Streaming mode: flush each line as it arrives.
flusher, _ := w.(http.Flusher)
w.WriteHeader(http.StatusOK)
buf := make([]byte, 4096)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
w.Write(buf[:n])
if flusher != nil {
flusher.Flush()
}
}
if err != nil {
return
}
}
} else {
// Simple JSON response: pass through directly.
io.Copy(w, resp.Body)
}
}

152
feature/rsh/policy.go Normal file
View File

@ -0,0 +1,152 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
package rsh
import (
"errors"
"net/netip"
"time"
"tailscale.com/tailcfg"
)
// evalResult is the result of SSH policy evaluation.
type evalResult int
const (
evalAccepted evalResult = iota // rule matched with Accept
evalRejected // no matching rule, or explicit Reject
evalRejectedUser // principal matched but user mapping failed
evalHoldDelegate // rule matched with HoldAndDelegate (check mode)
)
// evalSSHPolicy evaluates the SSH policy for the given parameters.
// This replicates the core matching logic from ssh/tailssh without
// depending on the SSH connection type.
//
// It returns the matching action, the mapped local user, and the evaluation result.
func evalSSHPolicy(
pol *tailcfg.SSHPolicy,
node tailcfg.NodeView,
uprof tailcfg.UserProfile,
srcAddr netip.Addr,
sshUser string,
now time.Time,
) (action *tailcfg.SSHAction, localUser string, result evalResult) {
if pol == nil {
return nil, "", evalRejected
}
failedOnUser := false
for _, r := range pol.Rules {
if a, lu, err := matchRule(r, node, uprof, srcAddr, sshUser, now); err == nil {
if a.HoldAndDelegate != "" {
return a, lu, evalHoldDelegate
}
return a, lu, evalAccepted
} else if errors.Is(err, errUserMatch) {
failedOnUser = true
}
}
if failedOnUser {
return nil, "", evalRejectedUser
}
return nil, "", evalRejected
}
var (
errNilRule = errors.New("nil rule")
errNilAction = errors.New("nil action")
errRuleExpired = errors.New("rule expired")
errPrincipalMatch = errors.New("principal didn't match")
errUserMatch = errors.New("user didn't match")
)
// matchRule checks whether a single SSHRule matches the given parameters.
func matchRule(
r *tailcfg.SSHRule,
node tailcfg.NodeView,
uprof tailcfg.UserProfile,
srcAddr netip.Addr,
sshUser string,
now time.Time,
) (action *tailcfg.SSHAction, localUser string, err error) {
if r == nil {
return nil, "", errNilRule
}
if r.Action == nil {
return nil, "", errNilAction
}
if r.RuleExpires != nil && r.RuleExpires.Before(now) {
return nil, "", errRuleExpired
}
if !anyPrincipalMatches(r.Principals, node, uprof, srcAddr) {
return nil, "", errPrincipalMatch
}
if !r.Action.Reject {
localUser = mapLocalUser(r.SSHUsers, sshUser)
if localUser == "" {
return nil, "", errUserMatch
}
}
return r.Action, localUser, nil
}
// anyPrincipalMatches reports whether any of the given principals match
// the Tailscale identity of the connecting peer.
func anyPrincipalMatches(
ps []*tailcfg.SSHPrincipal,
node tailcfg.NodeView,
uprof tailcfg.UserProfile,
srcAddr netip.Addr,
) bool {
for _, p := range ps {
if p == nil {
continue
}
if principalMatchesTailscaleIdentity(p, node, uprof, srcAddr) {
return true
}
}
return false
}
// principalMatchesTailscaleIdentity reports whether a principal matches
// the Tailscale identity of the connecting peer.
func principalMatchesTailscaleIdentity(
p *tailcfg.SSHPrincipal,
node tailcfg.NodeView,
uprof tailcfg.UserProfile,
srcAddr netip.Addr,
) bool {
if p.Any {
return true
}
if !p.Node.IsZero() && node.Valid() && p.Node == node.StableID() {
return true
}
if p.NodeIP != "" {
if ip, _ := netip.ParseAddr(p.NodeIP); ip == srcAddr {
return true
}
}
if p.UserLogin != "" && uprof.LoginName == p.UserLogin {
return true
}
return false
}
// mapLocalUser maps an SSH user to a local user using the SSHUsers map
// from a policy rule.
func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) string {
v, ok := ruleSSHUsers[reqSSHUser]
if !ok {
v = ruleSSHUsers["*"]
}
if v == "=" {
return reqSSHUser
}
return v
}

257
feature/rsh/policy_test.go Normal file
View File

@ -0,0 +1,257 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
package rsh
import (
"net/netip"
"testing"
"time"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func TestEvalSSHPolicy(t *testing.T) {
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
node := (&tailcfg.Node{
ID: 1,
StableID: "stable1",
Key: key.NewNode().Public(),
}).View()
uprof := tailcfg.UserProfile{
LoginName: "alice@example.com",
}
srcAddr := netip.MustParseAddr("100.64.1.2")
tests := []struct {
name string
pol *tailcfg.SSHPolicy
sshUser string
wantResult evalResult
wantUser string
}{
{
name: "nil_policy",
pol: nil,
sshUser: "root",
wantResult: evalRejected,
},
{
name: "empty_policy",
pol: &tailcfg.SSHPolicy{},
sshUser: "root",
wantResult: evalRejected,
},
{
name: "accept_any_wildcard_user",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "alice",
wantResult: evalAccepted,
wantUser: "alice",
},
{
name: "accept_specific_user_mapping",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"root": "admin"},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "root",
wantResult: evalAccepted,
wantUser: "admin",
},
{
name: "reject_unmapped_user",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"root": "admin"},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "unknown",
wantResult: evalRejectedUser,
},
{
name: "match_by_node_stable_id",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Node: "stable1"}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "bob",
wantResult: evalAccepted,
wantUser: "bob",
},
{
name: "reject_wrong_node",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Node: "other-node"}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "bob",
wantResult: evalRejected,
},
{
name: "match_by_node_ip",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.1.2"}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "alice",
wantResult: evalAccepted,
wantUser: "alice",
},
{
name: "match_by_user_login",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "alice@example.com"}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "alice",
wantResult: evalAccepted,
wantUser: "alice",
},
{
name: "hold_and_delegate_returns_eval_hold",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{HoldAndDelegate: "https://example.com/approve"},
},
},
},
sshUser: "alice",
wantResult: evalHoldDelegate,
},
{
name: "explicit_reject_action",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Reject: true},
},
},
},
sshUser: "alice",
wantResult: evalAccepted, // matchRule succeeds for Reject rules (no SSHUsers check)
},
{
name: "expired_rule_skipped",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
RuleExpires: timePtr(now.Add(-time.Hour)),
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "alice",
wantResult: evalRejected,
},
{
name: "first_matching_rule_wins",
pol: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.1.2"}},
SSHUsers: map[string]string{"alice": "localice"},
Action: &tailcfg.SSHAction{Accept: true},
},
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{"*": "="},
Action: &tailcfg.SSHAction{Accept: true},
},
},
},
sshUser: "alice",
wantResult: evalAccepted,
wantUser: "localice",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
action, localUser, result := evalSSHPolicy(tt.pol, node, uprof, srcAddr, tt.sshUser, now)
if result != tt.wantResult {
t.Errorf("result = %v, want %v", result, tt.wantResult)
}
if tt.wantUser != "" && localUser != tt.wantUser {
t.Errorf("localUser = %q, want %q", localUser, tt.wantUser)
}
_ = action // not checked in most tests
})
}
}
func TestMapLocalUser(t *testing.T) {
tests := []struct {
name string
sshUsers map[string]string
reqUser string
wantResult string
}{
{"exact_match", map[string]string{"root": "admin"}, "root", "admin"},
{"wildcard_match", map[string]string{"*": "defaultuser"}, "anyone", "defaultuser"},
{"identity_match", map[string]string{"*": "="}, "alice", "alice"},
{"no_match", map[string]string{"root": "admin"}, "unknown", ""},
{"exact_over_wildcard", map[string]string{"root": "admin", "*": "default"}, "root", "admin"},
{"nil_map", nil, "alice", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := mapLocalUser(tt.sshUsers, tt.reqUser)
if got != tt.wantResult {
t.Errorf("mapLocalUser(%v, %q) = %q, want %q", tt.sshUsers, tt.reqUser, got, tt.wantResult)
}
})
}
}
func timePtr(t time.Time) *time.Time { return &t }

157
feature/rsh/protocol.go Normal file
View File

@ -0,0 +1,157 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
// Package rsh implements a fast remote shell transport over Tailscale,
// designed as an rsync -e compatible replacement for SSH. It uses a PeerAPI
// endpoint for session setup and a raw TCP data channel for I/O,
// avoiding SSH's double encryption and suboptimal buffering.
package rsh
import (
"encoding/binary"
"fmt"
"io"
"sync"
)
// Channel type constants for the wire protocol.
// The protocol is length-prefixed framing:
//
// [1 byte: channel] [4 bytes: length (big-endian)] [N bytes: payload]
const (
// ChanStdin is data from client to server (remote process stdin).
ChanStdin byte = 0x00
// ChanStdout is data from server to client (remote process stdout).
ChanStdout byte = 0x01
// ChanStderr is data from server to client (remote process stderr).
ChanStderr byte = 0x02
// ChanExit is the exit code from the remote process.
// Payload is a 4-byte big-endian signed integer exit code.
// Sent by server to client, then the server closes the connection.
ChanExit byte = 0x03
)
const (
// tokenLen is the length of the one-time authentication token.
tokenLen = 32
// maxFrameSize is the maximum payload size for a single frame.
// 256KB is a good balance between throughput and memory usage,
// matching typical rsync block sizes.
maxFrameSize = 256 * 1024
// frameHeaderSize is the size of the frame header (channel + length).
frameHeaderSize = 5
)
// frameWriter writes length-prefixed frames to an underlying writer.
// It is safe for concurrent use.
type frameWriter struct {
mu sync.Mutex
w io.Writer
}
// newFrameWriter creates a new frameWriter that writes to w.
func newFrameWriter(w io.Writer) *frameWriter {
return &frameWriter{w: w}
}
// WriteFrame writes a single frame with the given channel and payload.
func (fw *frameWriter) WriteFrame(ch byte, data []byte) error {
if len(data) > maxFrameSize {
return fmt.Errorf("rsh: frame payload too large: %d > %d", len(data), maxFrameSize)
}
fw.mu.Lock()
defer fw.mu.Unlock()
var hdr [frameHeaderSize]byte
hdr[0] = ch
binary.BigEndian.PutUint32(hdr[1:], uint32(len(data)))
if _, err := fw.w.Write(hdr[:]); err != nil {
return err
}
if len(data) > 0 {
if _, err := fw.w.Write(data); err != nil {
return err
}
}
return nil
}
// WriteExitCode writes an exit code frame and is a convenience wrapper.
func (fw *frameWriter) WriteExitCode(code int) error {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], uint32(code))
return fw.WriteFrame(ChanExit, buf[:])
}
// frameReader reads length-prefixed frames from an underlying reader.
type frameReader struct {
r io.Reader
buf []byte // reusable buffer for payloads
}
// newFrameReader creates a new frameReader that reads from r.
func newFrameReader(r io.Reader) *frameReader {
return &frameReader{
r: r,
buf: make([]byte, 0, 32*1024), // start small, grow as needed
}
}
// ReadFrame reads the next frame, returning the channel type and payload.
// The returned payload slice is valid until the next call to ReadFrame.
func (fr *frameReader) ReadFrame() (ch byte, data []byte, err error) {
var hdr [frameHeaderSize]byte
if _, err := io.ReadFull(fr.r, hdr[:]); err != nil {
return 0, nil, err
}
ch = hdr[0]
n := binary.BigEndian.Uint32(hdr[1:])
if n > maxFrameSize {
return 0, nil, fmt.Errorf("rsh: frame too large: %d > %d", n, maxFrameSize)
}
if int(n) > cap(fr.buf) {
fr.buf = make([]byte, n)
} else {
fr.buf = fr.buf[:n]
}
if n > 0 {
if _, err := io.ReadFull(fr.r, fr.buf); err != nil {
return 0, nil, err
}
}
return ch, fr.buf, nil
}
// channelWriter wraps a frameWriter to implement io.Writer for a specific channel.
type channelWriter struct {
fw *frameWriter
ch byte
}
// newChannelWriter returns an io.Writer that writes all data as frames on
// the given channel.
func newChannelWriter(fw *frameWriter, ch byte) io.Writer {
return &channelWriter{fw: fw, ch: ch}
}
func (cw *channelWriter) Write(p []byte) (int, error) {
written := 0
for len(p) > 0 {
chunk := p
if len(chunk) > maxFrameSize {
chunk = chunk[:maxFrameSize]
}
if err := cw.fw.WriteFrame(cw.ch, chunk); err != nil {
return written, err
}
written += len(chunk)
p = p[len(chunk):]
}
return written, nil
}

View File

@ -0,0 +1,188 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package rsh
import (
"bytes"
"encoding/binary"
"io"
"testing"
)
func TestFrameRoundtrip(t *testing.T) {
var buf bytes.Buffer
fw := newFrameWriter(&buf)
fr := newFrameReader(&buf)
// Write several frames.
if err := fw.WriteFrame(ChanStdout, []byte("hello")); err != nil {
t.Fatalf("WriteFrame stdout: %v", err)
}
if err := fw.WriteFrame(ChanStderr, []byte("world")); err != nil {
t.Fatalf("WriteFrame stderr: %v", err)
}
if err := fw.WriteFrame(ChanStdin, []byte("input")); err != nil {
t.Fatalf("WriteFrame stdin: %v", err)
}
if err := fw.WriteExitCode(42); err != nil {
t.Fatalf("WriteExitCode: %v", err)
}
// Read them back.
ch, data, err := fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame 1: %v", err)
}
if ch != ChanStdout || string(data) != "hello" {
t.Errorf("frame 1: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStdout, "hello")
}
ch, data, err = fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame 2: %v", err)
}
if ch != ChanStderr || string(data) != "world" {
t.Errorf("frame 2: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStderr, "world")
}
ch, data, err = fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame 3: %v", err)
}
if ch != ChanStdin || string(data) != "input" {
t.Errorf("frame 3: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStdin, "input")
}
ch, data, err = fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame 4: %v", err)
}
if ch != ChanExit {
t.Errorf("frame 4: got ch=%d, want ch=%d", ch, ChanExit)
}
if len(data) != 4 {
t.Fatalf("exit frame data len = %d, want 4", len(data))
}
code := int(binary.BigEndian.Uint32(data))
if code != 42 {
t.Errorf("exit code = %d, want 42", code)
}
// Should get EOF now.
_, _, err = fr.ReadFrame()
if err != io.EOF && err != io.ErrUnexpectedEOF {
t.Errorf("expected EOF after all frames, got: %v", err)
}
}
func TestFrameEmptyPayload(t *testing.T) {
var buf bytes.Buffer
fw := newFrameWriter(&buf)
fr := newFrameReader(&buf)
if err := fw.WriteFrame(ChanStdin, nil); err != nil {
t.Fatalf("WriteFrame empty: %v", err)
}
ch, data, err := fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame: %v", err)
}
if ch != ChanStdin {
t.Errorf("ch = %d, want %d", ch, ChanStdin)
}
if len(data) != 0 {
t.Errorf("data len = %d, want 0", len(data))
}
}
func TestFrameTooLarge(t *testing.T) {
fw := newFrameWriter(io.Discard)
data := make([]byte, maxFrameSize+1)
if err := fw.WriteFrame(ChanStdout, data); err == nil {
t.Error("expected error for oversized frame, got nil")
}
}
func TestFrameReaderTooLarge(t *testing.T) {
// Construct a frame with length > maxFrameSize.
var buf bytes.Buffer
var hdr [frameHeaderSize]byte
hdr[0] = ChanStdout
binary.BigEndian.PutUint32(hdr[1:], maxFrameSize+1)
buf.Write(hdr[:])
fr := newFrameReader(&buf)
_, _, err := fr.ReadFrame()
if err == nil {
t.Error("expected error for oversized frame in reader, got nil")
}
}
func TestChannelWriter(t *testing.T) {
var buf bytes.Buffer
fw := newFrameWriter(&buf)
cw := newChannelWriter(fw, ChanStdout)
data := []byte("hello world from channel writer")
n, err := cw.Write(data)
if err != nil {
t.Fatalf("Write: %v", err)
}
if n != len(data) {
t.Errorf("n = %d, want %d", n, len(data))
}
fr := newFrameReader(&buf)
ch, got, err := fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame: %v", err)
}
if ch != ChanStdout {
t.Errorf("ch = %d, want %d", ch, ChanStdout)
}
if !bytes.Equal(got, data) {
t.Errorf("data mismatch: got %q, want %q", got, data)
}
}
func TestChannelWriterChunking(t *testing.T) {
var buf bytes.Buffer
fw := newFrameWriter(&buf)
cw := newChannelWriter(fw, ChanStdout)
// Write more than maxFrameSize to verify chunking.
data := make([]byte, maxFrameSize+100)
for i := range data {
data[i] = byte(i % 256)
}
n, err := cw.Write(data)
if err != nil {
t.Fatalf("Write: %v", err)
}
if n != len(data) {
t.Errorf("n = %d, want %d", n, len(data))
}
// Should produce two frames.
fr := newFrameReader(&buf)
ch, chunk1, err := fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame 1: %v", err)
}
if ch != ChanStdout || len(chunk1) != maxFrameSize {
t.Errorf("frame 1: ch=%d len=%d, want ch=%d len=%d", ch, len(chunk1), ChanStdout, maxFrameSize)
}
ch, chunk2, err := fr.ReadFrame()
if err != nil {
t.Fatalf("ReadFrame 2: %v", err)
}
if ch != ChanStdout || len(chunk2) != 100 {
t.Errorf("frame 2: ch=%d len=%d, want ch=%d len=%d", ch, len(chunk2), ChanStdout, 100)
}
}

739
feature/rsh/rsh.go Normal file
View File

@ -0,0 +1,739 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
package rsh
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"time"
"tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/tailcfg"
"tailscale.com/types/netmap"
"tailscale.com/util/backoff"
"tailscale.com/util/clientmetric"
"tailscale.com/util/osuser"
)
var (
metricRshCalls = clientmetric.NewCounter("peerapi_rsh")
metricRshAccepts = clientmetric.NewCounter("peerapi_rsh_accept")
metricRshRejects = clientmetric.NewCounter("peerapi_rsh_reject")
)
func init() {
ipnlocal.RegisterPeerAPIHandler("/v0/rsh", handleRsh)
}
// rshRequest is the JSON body sent to POST /v0/rsh.
type rshRequest struct {
// User is the requested SSH user (will be mapped via SSHUsers policy).
User string `json:"user"`
// Command is the command to execute. If empty, the user's default
// login shell is started.
Command string `json:"command,omitempty"`
}
// rshResponse is returned by a successful POST /v0/rsh.
// In streaming mode (check mode), this is the final JSON line in
// the newline-delimited JSON stream.
type rshResponse struct {
// Addr is the Tailscale IP:port to connect to for the data channel.
Addr string `json:"addr"`
// Token is the hex-encoded one-time authentication token that must
// be sent as the first bytes on the data channel connection.
Token string `json:"token"`
}
// rshStatusMessage is sent as a streaming JSON line during the
// HoldAndDelegate (check mode) flow before the final rshResponse.
// Each message is a newline-delimited JSON object.
type rshStatusMessage struct {
// Status is a human-readable status message to display to the user.
Status string `json:"status"`
}
// netstackTCPListenerFunc is the type of a function that creates a TCP
// listener on the netstack (gVisor) network stack. It is set by the
// netstack package at init time.
//
// We use a function hook instead of a type assertion on NetstackImpl
// because netstack.Impl.ListenTCP returns *gonet.TCPListener (not
// net.Listener), and importing gonet would create an unwanted gVisor
// dependency.
var netstackListenTCP func(b *ipnlocal.LocalBackend, network, address string) (net.Listener, error)
const linux = "linux"
func handleRsh(ph ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) {
metricRshCalls.Add(1)
logf := ph.Logf
if r.Method != "POST" {
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
return
}
b := ph.LocalBackend()
// Check that SSH is enabled on this node.
if !b.ShouldRunSSH() {
logf("rsh: denied; SSH not enabled")
metricRshRejects.Add(1)
http.Error(w, "SSH not enabled on this node", http.StatusForbidden)
return
}
// Parse request.
var req rshRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
logf("rsh: bad request body: %v", err)
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
if req.User == "" {
http.Error(w, "user is required", http.StatusBadRequest)
return
}
// Evaluate SSH policy.
peerNode := ph.Peer()
peerAddr := ph.RemoteAddr().Addr()
nm := b.NetMap()
if nm == nil {
logf("rsh: no netmap")
http.Error(w, "no netmap available", http.StatusInternalServerError)
return
}
sshPol := nm.SSHPolicy
if sshPol == nil {
logf("rsh: no SSH policy")
metricRshRejects.Add(1)
http.Error(w, "no SSH policy configured", http.StatusForbidden)
return
}
// Look up the peer's user profile for policy matching.
_, uprof, ok := b.WhoIs("tcp", ph.RemoteAddr())
if !ok {
logf("rsh: unknown peer %v", ph.RemoteAddr())
metricRshRejects.Add(1)
http.Error(w, "unknown peer", http.StatusForbidden)
return
}
action, localUser, result := evalSSHPolicy(sshPol, peerNode, uprof, peerAddr, req.User, time.Now())
switch result {
case evalAccepted:
if action.Reject {
logf("rsh: policy explicitly rejects %v -> %s@%s", peerAddr, req.User, localUser)
metricRshRejects.Add(1)
http.Error(w, "access denied by policy", http.StatusForbidden)
return
}
// Good, accepted. action may still have a Message to send.
case evalHoldDelegate:
// Check mode: we need to poll the control plane for approval.
// The response uses streaming newline-delimited JSON so
// status messages can be sent while we wait.
case evalRejectedUser:
logf("rsh: user %q not mapped for peer %v", req.User, peerAddr)
metricRshRejects.Add(1)
http.Error(w, fmt.Sprintf("user %q not permitted", req.User), http.StatusForbidden)
return
case evalRejected:
logf("rsh: policy rejects %v -> %s", peerAddr, req.User)
metricRshRejects.Add(1)
http.Error(w, "access denied by policy", http.StatusForbidden)
return
}
// Look up the local user. We need this for both the immediate accept
// path and the HoldAndDelegate path (to expand delegate URL variables).
lu, loginShell, err := osuser.LookupByUsernameWithShell(localUser)
if err != nil {
logf("rsh: user lookup failed for %q: %v", localUser, err)
http.Error(w, fmt.Sprintf("user %q not found", localUser), http.StatusInternalServerError)
return
}
groupIDs, err := osuser.GetGroupIds(lu)
if err != nil {
logf("rsh: group lookup failed for %q: %v", localUser, err)
http.Error(w, "failed to look up user groups", http.StatusInternalServerError)
return
}
// If HoldAndDelegate, run the check mode loop to get a terminal action.
// We use streaming JSON so status messages can be sent to the client.
if result == evalHoldDelegate {
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
action, err = resolveCheckMode(r.Context(), b, action, nm, peerNode, peerAddr, req.User, lu, w, flusher, logf)
if err != nil {
// Connection is already streaming; send error as a status message.
logf("rsh: check mode failed: %v", err)
writeNDJSON(w, flusher, rshStatusMessage{Status: fmt.Sprintf("check mode error: %v", err)})
return
}
if action.Reject {
logf("rsh: check mode rejected %v -> %s", peerAddr, req.User)
metricRshRejects.Add(1)
msg := "access denied"
if action.Message != "" {
msg = action.Message
}
writeNDJSON(w, flusher, rshStatusMessage{Status: msg})
return
}
if !action.Accept {
logf("rsh: check mode returned non-terminal action for %v -> %s", peerAddr, req.User)
metricRshRejects.Add(1)
writeNDJSON(w, flusher, rshStatusMessage{Status: "unexpected response from control"})
return
}
}
// Find a local Tailscale IP to listen on.
listenAddr, err := pickListenAddr(nm, peerAddr)
if err != nil {
logf("rsh: no listen address: %v", err)
if result == evalHoldDelegate {
flusher, _ := w.(http.Flusher)
writeNDJSON(w, flusher, rshStatusMessage{Status: "no suitable listen address"})
} else {
http.Error(w, "no suitable listen address", http.StatusInternalServerError)
}
return
}
// Create the listener.
ln, err := listenTailscale(b, listenAddr)
if err != nil {
logf("rsh: listen failed: %v", err)
if result == evalHoldDelegate {
flusher, _ := w.(http.Flusher)
writeNDJSON(w, flusher, rshStatusMessage{Status: "failed to create listener"})
} else {
http.Error(w, "failed to create listener", http.StatusInternalServerError)
}
return
}
// Generate one-time token.
var tokenBytes [tokenLen]byte
if _, err := rand.Read(tokenBytes[:]); err != nil {
ln.Close()
logf("rsh: rand failed: %v", err)
if result == evalHoldDelegate {
flusher, _ := w.(http.Flusher)
writeNDJSON(w, flusher, rshStatusMessage{Status: "internal error"})
} else {
http.Error(w, "internal error", http.StatusInternalServerError)
}
return
}
tokenHex := hex.EncodeToString(tokenBytes[:])
metricRshAccepts.Add(1)
// Start the session handler in a goroutine. It will accept one
// connection, verify the token, and wire up the incubator process.
go handleRshSession(b, ln, tokenBytes[:], peerAddr, lu, loginShell, groupIDs, req, ph, logf)
// Return the listen address and token to the client.
resp := rshResponse{
Addr: ln.Addr().String(),
Token: tokenHex,
}
if result == evalHoldDelegate {
// Streaming mode: send a final accept message then the response.
flusher, _ := w.(http.Flusher)
if action.Message != "" {
writeNDJSON(w, flusher, rshStatusMessage{Status: action.Message})
}
writeNDJSON(w, flusher, resp)
} else {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
}
// writeNDJSON writes v as a single newline-delimited JSON line to w
// and flushes. This is used for the streaming check mode response.
func writeNDJSON(w io.Writer, flusher http.Flusher, v any) {
json.NewEncoder(w).Encode(v) // Encode appends '\n'
if flusher != nil {
flusher.Flush()
}
}
// resolveCheckMode runs the HoldAndDelegate loop, polling the control plane
// until a terminal action (Accept or Reject) is returned. It sends status
// messages to the client as streaming JSON lines while waiting.
//
// This is the rsh equivalent of SSH's clientAuth HoldAndDelegate loop.
func resolveCheckMode(
ctx context.Context,
b *ipnlocal.LocalBackend,
action *tailcfg.SSHAction,
nm *netmap.NetworkMap,
peerNode tailcfg.NodeView,
peerAddr netip.Addr,
sshUser string,
lu *user.User,
w io.Writer,
flusher http.Flusher,
logf func(string, ...any),
) (*tailcfg.SSHAction, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel()
for {
if action.Message != "" {
writeNDJSON(w, flusher, rshStatusMessage{Status: action.Message})
}
if action.Accept || action.Reject {
return action, nil
}
if action.HoldAndDelegate == "" {
return nil, fmt.Errorf("action has neither Accept, Reject, nor HoldAndDelegate")
}
delegateURL := expandDelegateURL(action.HoldAndDelegate, nm, peerNode, peerAddr, sshUser, lu)
logf("rsh: check mode: polling %s", delegateURL)
writeNDJSON(w, flusher, rshStatusMessage{Status: "Waiting for approval..."})
var err error
action, err = fetchSSHAction(ctx, b, delegateURL, logf)
if err != nil {
return nil, fmt.Errorf("fetching SSH action: %w", err)
}
}
}
// expandDelegateURL expands the variables in a HoldAndDelegate URL.
// The variables match those used by SSH: $SRC_NODE_IP, $SRC_NODE_ID,
// $DST_NODE_IP, $DST_NODE_ID, $SSH_USER, $LOCAL_USER.
func expandDelegateURL(
actionURL string,
nm *netmap.NetworkMap,
peerNode tailcfg.NodeView,
peerAddr netip.Addr,
sshUser string,
lu *user.User,
) string {
var dstNodeID string
if nm != nil {
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID()))
}
var srcNodeID string
if peerNode.Valid() {
srcNodeID = fmt.Sprint(int64(peerNode.ID()))
}
var dstNodeIP string
if nm != nil {
addrs := nm.GetAddresses()
for _, pfx := range addrs.All() {
if pfx.IsSingleIP() {
dstNodeIP = pfx.Addr().String()
break
}
}
}
return strings.NewReplacer(
"$SRC_NODE_IP", url.QueryEscape(peerAddr.String()),
"$SRC_NODE_ID", srcNodeID,
"$DST_NODE_IP", url.QueryEscape(dstNodeIP),
"$DST_NODE_ID", dstNodeID,
"$SSH_USER", url.QueryEscape(sshUser),
"$LOCAL_USER", url.QueryEscape(lu.Username),
).Replace(actionURL)
}
// fetchSSHAction polls a control plane URL over the Noise transport
// and returns the SSHAction. It retries with exponential backoff on
// transient errors, matching the behavior of SSH's fetchSSHAction.
func fetchSSHAction(ctx context.Context, b *ipnlocal.LocalBackend, url string, logf func(string, ...any)) (*tailcfg.SSHAction, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel()
bo := backoff.NewBackoff("rsh-fetch-ssh-action", logf, 10*time.Second)
for {
if err := ctx.Err(); err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
res, err := b.DoNoiseRequest(req)
if err != nil {
bo.BackOff(ctx, err)
continue
}
if res.StatusCode != 200 {
body, _ := io.ReadAll(res.Body)
res.Body.Close()
if len(body) > 1<<10 {
body = body[:1<<10]
}
logf("rsh: fetch of %v: %s, %s", url, res.Status, body)
bo.BackOff(ctx, fmt.Errorf("unexpected status: %v", res.Status))
continue
}
a := new(tailcfg.SSHAction)
err = json.NewDecoder(res.Body).Decode(a)
res.Body.Close()
if err != nil {
logf("rsh: invalid SSHAction JSON from %v: %v", url, err)
bo.BackOff(ctx, err)
continue
}
return a, nil
}
}
// pickListenAddr selects a local Tailscale IP address that matches the
// address family of the peer. This ensures the data channel connection
// uses the same protocol version.
func pickListenAddr(nm *netmap.NetworkMap, peerAddr netip.Addr) (netip.Addr, error) {
addrs := nm.GetAddresses()
wantV4 := peerAddr.Is4()
for _, pfx := range addrs.All() {
if !pfx.IsSingleIP() {
continue
}
a := pfx.Addr()
if wantV4 && a.Is4() {
return a, nil
}
if !wantV4 && a.Is6() {
return a, nil
}
}
// Fallback: return any address.
for _, pfx := range addrs.All() {
if pfx.IsSingleIP() {
return pfx.Addr(), nil
}
}
return netip.Addr{}, fmt.Errorf("no Tailscale addresses available")
}
// listenTailscale creates a TCP listener on the given Tailscale IP.
// In netstack mode, it uses the gVisor stack via the netstackListenTCP hook.
// In kernel TUN mode, it uses the standard library.
func listenTailscale(b *ipnlocal.LocalBackend, addr netip.Addr) (net.Listener, error) {
network := "tcp4"
if addr.Is6() {
network = "tcp6"
}
listenAddr := netip.AddrPortFrom(addr, 0).String()
if b.Sys().IsNetstack() {
// In full netstack mode, we need to use the gVisor stack to listen
// since all local IP traffic is handled by netstack.
if netstackListenTCP == nil {
return nil, fmt.Errorf("netstack listener not available (rsh_netstack not linked)")
}
return netstackListenTCP(b, network, listenAddr)
}
// In kernel TUN mode, the Tailscale IP is assigned to the TUN device
// and the kernel handles routing. Standard net.Listen works.
return net.Listen(network, listenAddr)
}
// handleRshSession is run in a goroutine. It accepts a single connection
// from the listener, verifies the token and source, then spawns the
// remote command via the incubator.
func handleRshSession(
b *ipnlocal.LocalBackend,
ln net.Listener,
token []byte,
expectedPeer netip.Addr,
lu *user.User,
loginShell string,
groupIDs []string,
req rshRequest,
ph ipnlocal.PeerAPIHandler,
logf func(string, ...any),
) {
defer ln.Close()
// Set a deadline for the client to connect.
if dl, ok := ln.(interface{ SetDeadline(time.Time) error }); ok {
dl.SetDeadline(time.Now().Add(30 * time.Second))
}
conn, err := ln.Accept()
if err != nil {
logf("rsh: accept failed: %v", err)
return
}
ln.Close() // Only accept one connection.
defer conn.Close()
// Verify source IP.
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logf("rsh: unexpected remote addr type: %T", conn.RemoteAddr())
return
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
if !ok {
logf("rsh: invalid remote IP")
return
}
remoteIP = remoteIP.Unmap()
if remoteIP != expectedPeer {
logf("rsh: unexpected peer %v, expected %v", remoteIP, expectedPeer)
return
}
// Read and verify token.
var gotToken [tokenLen]byte
if _, err := io.ReadFull(conn, gotToken[:]); err != nil {
logf("rsh: failed to read token: %v", err)
return
}
if subtle.ConstantTimeCompare(gotToken[:], token) != 1 {
logf("rsh: invalid token from %v", remoteIP)
return
}
// Set TCP_NODELAY for low-latency rsync control messages.
if tc, ok := conn.(*net.TCPConn); ok {
tc.SetNoDelay(true)
}
logf("rsh: session accepted from %v as %s, command=%q", remoteIP, lu.Username, req.Command)
// Build and run the incubator command.
runIncubator(b, conn, lu, loginShell, groupIDs, req, ph, logf)
}
// runIncubator spawns the remote command using the existing SSH incubator
// mechanism for privilege dropping and PAM integration.
func runIncubator(
b *ipnlocal.LocalBackend,
conn net.Conn,
lu *user.User,
loginShell string,
groupIDs []string,
req rshRequest,
ph ipnlocal.PeerAPIHandler,
logf func(string, ...any),
) {
tailscaledPath, err := os.Executable()
if err != nil {
logf("rsh: os.Executable: %v", err)
sendExitCode(conn, 1)
return
}
peerNode := ph.Peer()
remoteUser := "unknown"
if peerNode.Valid() {
if peerNode.IsTagged() {
remoteUser = strings.Join(peerNode.Tags().AsSlice(), ",")
} else {
_, uprof, ok := b.WhoIs("tcp", ph.RemoteAddr())
if ok {
remoteUser = uprof.LoginName
}
}
}
groups := strings.Join(groupIDs, ",")
isShell := req.Command == ""
incubatorArgs := []string{
"be-child",
"ssh",
"--login-shell=" + loginShell,
"--uid=" + lu.Uid,
"--gid=" + lu.Gid,
"--groups=" + groups,
"--local-user=" + lu.Username,
"--home-dir=" + lu.HomeDir,
"--remote-user=" + remoteUser,
"--remote-ip=" + ph.RemoteAddr().Addr().String(),
"--has-tty=false",
"--tty-name=",
}
if runtime.GOOS == linux && hostinfo.IsSELinuxEnforcing() {
incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing")
}
nm := b.NetMap()
if nm != nil && nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) {
incubatorArgs = append(incubatorArgs, "--force-v1-behavior")
}
if isShell {
incubatorArgs = append(incubatorArgs, "--shell")
} else {
incubatorArgs = append(incubatorArgs, "--cmd="+req.Command)
}
cmd := exec.Command(tailscaledPath, incubatorArgs...)
cmd.Dir = "/"
// Set up the environment for the child.
cmd.Env = []string{
"SHELL=" + loginShell,
"USER=" + lu.Username,
"HOME=" + lu.HomeDir,
"PATH=" + defaultPathForUser(lu),
}
// Create stdin/stdout/stderr pipes.
stdinPipe, err := cmd.StdinPipe()
if err != nil {
logf("rsh: stdin pipe: %v", err)
sendExitCode(conn, 1)
return
}
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
logf("rsh: stdout pipe: %v", err)
sendExitCode(conn, 1)
return
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
logf("rsh: stderr pipe: %v", err)
sendExitCode(conn, 1)
return
}
if err := cmd.Start(); err != nil {
logf("rsh: start incubator: %v", err)
sendExitCode(conn, 1)
return
}
fw := newFrameWriter(conn)
fr := newFrameReader(conn)
// Goroutine: read frames from client, write stdin to incubator.
stdinDone := make(chan struct{})
go func() {
defer close(stdinDone)
defer stdinPipe.Close()
for {
ch, data, err := fr.ReadFrame()
if err != nil {
return
}
if ch == ChanStdin {
if _, err := stdinPipe.Write(data); err != nil {
return
}
}
}
}()
// Goroutine: read stdout from incubator, write frames to client.
stdoutDone := make(chan struct{})
go func() {
defer close(stdoutDone)
buf := make([]byte, 64*1024)
for {
n, err := stdoutPipe.Read(buf)
if n > 0 {
if werr := fw.WriteFrame(ChanStdout, buf[:n]); werr != nil {
return
}
}
if err != nil {
return
}
}
}()
// Goroutine: read stderr from incubator, write frames to client.
stderrDone := make(chan struct{})
go func() {
defer close(stderrDone)
buf := make([]byte, 64*1024)
for {
n, err := stderrPipe.Read(buf)
if n > 0 {
if werr := fw.WriteFrame(ChanStderr, buf[:n]); werr != nil {
return
}
}
if err != nil {
return
}
}
}()
// Wait for the process to exit.
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
logf("rsh: wait: %v", err)
exitCode = 1
}
}
// Wait for output goroutines to drain.
<-stdoutDone
<-stderrDone
// Send exit code and close.
logf("rsh: session ended for %s, exit code %d", lu.Username, exitCode)
fw.WriteExitCode(exitCode)
}
// sendExitCode is a helper used before the framing writer is set up.
func sendExitCode(conn net.Conn, code int) {
fw := newFrameWriter(conn)
fw.WriteExitCode(code)
}
// defaultPathForUser returns an appropriate default PATH for the user.
// This is a simplified version of the logic in ssh/tailssh/user.go.
func defaultPathForUser(u *user.User) string {
if u.Uid == "0" {
return "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
}
return "/usr/local/bin:/usr/bin:/bin"
}
// envknobs for debugging.
var rshVerbose = envknob.RegisterBool("TS_DEBUG_RSH_VLOG")

View File

@ -0,0 +1,30 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
package rsh
import (
"net"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/wgengine/netstack"
)
func init() {
netstackListenTCP = netstackListenTCPImpl
}
func netstackListenTCPImpl(b *ipnlocal.LocalBackend, network, address string) (net.Listener, error) {
ns, ok := b.Sys().Netstack.GetOK()
if !ok {
return nil, net.ErrClosed
}
// Type-assert to *netstack.Impl which has the ListenTCP method.
impl, ok := ns.(*netstack.Impl)
if !ok {
return nil, net.ErrClosed
}
return impl.ListenTCP(network, address)
}

View File

@ -653,6 +653,13 @@ func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) {
func (b *LocalBackend) Clock() tstime.Clock { return b.clock }
func (b *LocalBackend) Sys() *tsd.System { return b.sys }
// PeerAPIBase returns the "http://ip:port" URL base to reach a peer's PeerAPI.
// It returns the empty string if the peer doesn't support PeerAPI or there's
// no matching address family.
func (b *LocalBackend) PeerAPIBase(peer tailcfg.NodeView) string {
return peerAPIBase(b.NetMap(), peer)
}
// NodeBackend returns the current node's NodeBackend interface.
func (b *LocalBackend) NodeBackend() ipnext.NodeBackend {
return b.currentNode()