mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 12:16:44 +02:00
feature/rsh: add a non-ssh encapsulated remote shell
The intent of this mode is to quack a bit like the long gone rsh, but implemented purely inside the client and providing a faster way to use rsync over tailscale by way of `rsync -e 'tailscale rsh dest`. This is an initial prototype.
This commit is contained in:
parent
0bac4223d1
commit
d4b565aa48
@ -259,6 +259,7 @@ change in the future.
|
||||
pingCmd,
|
||||
ncCmd,
|
||||
sshCmd,
|
||||
rshCmd,
|
||||
nilOrCall(maybeFunnelCmd),
|
||||
nilOrCall(maybeServeCmd),
|
||||
versionCmd,
|
||||
|
||||
450
cmd/tailscale/cli/rsh.go
Normal file
450
cmd/tailscale/cli/rsh.go
Normal file
@ -0,0 +1,450 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/peterbourgon/ff/v3/ffcli"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
)
|
||||
|
||||
var rshArgs struct {
|
||||
loginUser string // -l flag: SSH login user
|
||||
sshOption string // -o flag: SSH option (ignored, for compatibility)
|
||||
}
|
||||
|
||||
var rshCmd = &ffcli.Command{
|
||||
Name: "rsh",
|
||||
ShortUsage: "tailscale rsh [-l user] [user@]<host> [command...]",
|
||||
ShortHelp: "Execute a remote command over Tailscale without SSH overhead",
|
||||
LongHelp: strings.TrimSpace(`
|
||||
The 'tailscale rsh' command executes a command on a remote Tailscale node
|
||||
using a direct TCP connection over the Tailscale network. Unlike SSH, it
|
||||
avoids double encryption (SSH + WireGuard) and SSH's suboptimal buffering.
|
||||
|
||||
It is designed to be used as an rsync -e transport replacement:
|
||||
|
||||
rsync -e 'tailscale rsh' -avz ./local/ user@host:/remote/
|
||||
|
||||
The remote node must have Tailscale SSH enabled, as rsh reuses the same
|
||||
SSH access policy for authorization.
|
||||
|
||||
SSH-compatible flags (-l user, -o option) are accepted and handled
|
||||
appropriately so that rsync and similar tools can invoke rsh as a
|
||||
drop-in remote shell replacement.
|
||||
|
||||
When used without a command, it starts the user's default login shell.
|
||||
`),
|
||||
FlagSet: func() *flag.FlagSet {
|
||||
fs := newFlagSet("rsh")
|
||||
fs.StringVar(&rshArgs.loginUser, "l", "", "remote login user (SSH-compatible)")
|
||||
fs.StringVar(&rshArgs.sshOption, "o", "", "SSH option (ignored, for compatibility)")
|
||||
return fs
|
||||
}(),
|
||||
Exec: runRsh,
|
||||
}
|
||||
|
||||
// rshFraming constants matching feature/rsh/protocol.go.
|
||||
const (
|
||||
rshChanStdin byte = 0x00
|
||||
rshChanStdout byte = 0x01
|
||||
rshChanStderr byte = 0x02
|
||||
rshChanExit byte = 0x03
|
||||
rshTokenLen = 32
|
||||
rshMaxFrame = 256 * 1024
|
||||
rshFrameHdrSize = 5
|
||||
)
|
||||
|
||||
func runRsh(ctx context.Context, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return errors.New("usage: tailscale rsh [user@]<host> [command...]")
|
||||
}
|
||||
|
||||
// Check tailscaled is running.
|
||||
st, err := localClient.Status(ctx)
|
||||
if err != nil {
|
||||
return fixTailscaledConnectError(err)
|
||||
}
|
||||
description, ok := isRunningOrStarting(st)
|
||||
if !ok {
|
||||
printf("%s\n", description)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
username, host, cmdArgs, err := parseRshArgs(args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The -l flag parsed by ffcli takes priority over user@host.
|
||||
// This handles cases like: tailscale rsh -l ubuntu james-ai
|
||||
// where ffcli parses -l before runRsh sees the args.
|
||||
if rshArgs.loginUser != "" {
|
||||
username = rshArgs.loginUser
|
||||
}
|
||||
|
||||
// If no explicit user, default to the current OS user.
|
||||
if username == "" {
|
||||
u, err := currentUser()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot determine current user: %w", err)
|
||||
}
|
||||
username = u
|
||||
}
|
||||
|
||||
// Resolve host to a peer.
|
||||
ps, ok := peerStatusFromArg(st, host)
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown host %q; not found in Tailscale network", host)
|
||||
}
|
||||
|
||||
// Build the command string (rsync passes it as separate args).
|
||||
command := strings.Join(cmdArgs, " ")
|
||||
|
||||
// Request an rsh session via the LocalAPI.
|
||||
type localRshRequest struct {
|
||||
PeerID string `json:"peer"`
|
||||
User string `json:"user"`
|
||||
Command string `json:"command,omitempty"`
|
||||
}
|
||||
reqBody := localRshRequest{
|
||||
PeerID: string(ps.ID),
|
||||
User: username,
|
||||
Command: command,
|
||||
}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST",
|
||||
"http://"+apitype.LocalAPIHost+"/localapi/v0/rsh",
|
||||
bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := localClient.DoLocalRequest(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rsh setup: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("rsh setup failed: %s: %s", resp.Status, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
type rshResponse struct {
|
||||
Addr string `json:"addr"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
type rshStatusMessage struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
var rshResp rshResponse
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(ct, "application/x-ndjson") {
|
||||
// Streaming check mode: read newline-delimited JSON lines.
|
||||
// Status messages go to stderr, the final rshResponse has addr+token.
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 64*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
// Try to decode as rshResponse (has "addr" field).
|
||||
var candidate rshResponse
|
||||
if err := json.Unmarshal(line, &candidate); err == nil && candidate.Addr != "" {
|
||||
rshResp = candidate
|
||||
continue
|
||||
}
|
||||
// Otherwise, treat as a status message.
|
||||
var msg rshStatusMessage
|
||||
if err := json.Unmarshal(line, &msg); err == nil && msg.Status != "" {
|
||||
fmt.Fprintf(os.Stderr, "rsh: %s\n", msg.Status)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("rsh: reading streaming response: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Simple JSON response.
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rshResp); err != nil {
|
||||
return fmt.Errorf("rsh: invalid response: %w", err)
|
||||
}
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if rshResp.Addr == "" || rshResp.Token == "" {
|
||||
return errors.New("rsh: server returned empty address or token")
|
||||
}
|
||||
|
||||
// Parse the address to get host and port for DialTCP.
|
||||
addrHost, portStr, err := splitHostPort(rshResp.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rsh: invalid address %q: %w", rshResp.Addr, err)
|
||||
}
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rsh: invalid port %q: %w", portStr, err)
|
||||
}
|
||||
|
||||
// Decode the token.
|
||||
token, err := hex.DecodeString(rshResp.Token)
|
||||
if err != nil || len(token) != rshTokenLen {
|
||||
return fmt.Errorf("rsh: invalid token")
|
||||
}
|
||||
|
||||
// Connect to the data channel via tailscaled.
|
||||
conn, err := localClient.DialTCP(ctx, addrHost, uint16(port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("rsh: connect to %s: %w", rshResp.Addr, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send the authentication token.
|
||||
if _, err := conn.Write(token); err != nil {
|
||||
return fmt.Errorf("rsh: send token: %w", err)
|
||||
}
|
||||
|
||||
// Run the framing protocol.
|
||||
return rshPumpIO(conn)
|
||||
}
|
||||
|
||||
// rshPumpIO handles the framing protocol between the local stdin/stdout/stderr
|
||||
// and the remote process over the connection.
|
||||
func rshPumpIO(conn io.ReadWriteCloser) error {
|
||||
// Goroutine: read stdin and send as ChanStdin frames.
|
||||
stdinDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(stdinDone)
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if n > 0 {
|
||||
if werr := writeFrame(conn, rshChanStdin, buf[:n]); werr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Send a zero-length stdin frame to signal EOF.
|
||||
writeFrame(conn, rshChanStdin, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Main loop: read frames from the connection and dispatch.
|
||||
var hdr [rshFrameHdrSize]byte
|
||||
for {
|
||||
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// Connection closed without exit code.
|
||||
return fmt.Errorf("rsh: connection closed unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("rsh: read frame: %w", err)
|
||||
}
|
||||
ch := hdr[0]
|
||||
n := binary.BigEndian.Uint32(hdr[1:])
|
||||
if n > rshMaxFrame {
|
||||
return fmt.Errorf("rsh: frame too large: %d", n)
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case rshChanStdout:
|
||||
if _, err := io.CopyN(os.Stdout, conn, int64(n)); err != nil {
|
||||
return fmt.Errorf("rsh: stdout: %w", err)
|
||||
}
|
||||
case rshChanStderr:
|
||||
if _, err := io.CopyN(os.Stderr, conn, int64(n)); err != nil {
|
||||
return fmt.Errorf("rsh: stderr: %w", err)
|
||||
}
|
||||
case rshChanExit:
|
||||
if n != 4 {
|
||||
return fmt.Errorf("rsh: invalid exit frame size: %d", n)
|
||||
}
|
||||
var exitBuf [4]byte
|
||||
if _, err := io.ReadFull(conn, exitBuf[:]); err != nil {
|
||||
return fmt.Errorf("rsh: read exit code: %w", err)
|
||||
}
|
||||
code := int(binary.BigEndian.Uint32(exitBuf[:]))
|
||||
if code != 0 {
|
||||
os.Exit(code)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
// Unknown channel, skip the payload.
|
||||
if _, err := io.CopyN(io.Discard, conn, int64(n)); err != nil {
|
||||
return fmt.Errorf("rsh: skip unknown frame: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeFrame writes a single rsh protocol frame to w.
|
||||
func writeFrame(w io.Writer, ch byte, data []byte) error {
|
||||
var hdr [rshFrameHdrSize]byte
|
||||
hdr[0] = ch
|
||||
binary.BigEndian.PutUint32(hdr[1:], uint32(len(data)))
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) > 0 {
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseRshArgs parses SSH-compatible arguments as passed by rsync and
|
||||
// similar tools when using rsh as a remote shell transport.
|
||||
//
|
||||
// rsync invokes the remote shell as:
|
||||
//
|
||||
// tailscale rsh [user@host] [-l user] [-o option]... <host> <command...>
|
||||
//
|
||||
// The user@host may appear as the first positional arg (from the rsync
|
||||
// URI), while -l overrides the username. The bare hostname after flags
|
||||
// is the actual target. Everything after that is the remote command.
|
||||
//
|
||||
// Returns the resolved username (may be empty if none specified), host,
|
||||
// and command args.
|
||||
func parseRshArgs(args []string) (username, host string, cmdArgs []string, err error) {
|
||||
if len(args) == 0 {
|
||||
return "", "", nil, errors.New("usage: tailscale rsh [-l user] [user@]<host> [command...]")
|
||||
}
|
||||
|
||||
// First, check if args[0] is a user@host or bare host (not a flag).
|
||||
// rsync passes the user@host from the rsync URI as the first arg,
|
||||
// before any -l flag.
|
||||
i := 0
|
||||
if !strings.HasPrefix(args[0], "-") {
|
||||
u, h, hasAt := strings.Cut(args[0], "@")
|
||||
if hasAt {
|
||||
username = u
|
||||
host = h
|
||||
} else {
|
||||
// Bare hostname (no @). Record it; it may be
|
||||
// overridden if a second bare hostname appears
|
||||
// after flags (the rsync pattern).
|
||||
host = args[0]
|
||||
}
|
||||
i = 1
|
||||
}
|
||||
|
||||
// Parse SSH-compatible flags.
|
||||
flagUser := ""
|
||||
hadFlags := false
|
||||
for i < len(args) {
|
||||
a := args[i]
|
||||
if a == "--" {
|
||||
i++
|
||||
break
|
||||
}
|
||||
if !strings.HasPrefix(a, "-") {
|
||||
break // first non-flag is the host
|
||||
}
|
||||
hadFlags = true
|
||||
switch {
|
||||
case a == "-l":
|
||||
// -l <user>
|
||||
i++
|
||||
if i >= len(args) {
|
||||
return "", "", nil, errors.New("rsh: -l requires an argument")
|
||||
}
|
||||
flagUser = args[i]
|
||||
i++
|
||||
case strings.HasPrefix(a, "-l"):
|
||||
// -l<user> (no space)
|
||||
flagUser = a[2:]
|
||||
i++
|
||||
case a == "-o":
|
||||
// -o <option>: SSH option, ignore.
|
||||
i++
|
||||
if i < len(args) {
|
||||
i++ // skip the option value
|
||||
}
|
||||
case strings.HasPrefix(a, "-o"):
|
||||
// -o<option>: SSH option, ignore.
|
||||
i++
|
||||
default:
|
||||
// Unknown flag (e.g. -4, -6, -p, etc.); skip it.
|
||||
// SSH has many flags; we silently ignore ones we
|
||||
// don't understand since we don't need them.
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// After flags, the next non-flag arg is the host. rsync passes
|
||||
// the bare hostname after -l flags, so we expect it here when
|
||||
// flags were present. When there were no flags and we already
|
||||
// have a host from args[0], the remaining args are the command.
|
||||
if hadFlags && i < len(args) && !strings.HasPrefix(args[i], "-") {
|
||||
host = args[i]
|
||||
i++
|
||||
}
|
||||
|
||||
// Everything remaining is the command.
|
||||
cmdArgs = args[i:]
|
||||
|
||||
// -l flag overrides any user from user@host.
|
||||
if flagUser != "" {
|
||||
username = flagUser
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return "", "", nil, errors.New("usage: tailscale rsh [-l user] [user@]<host> [command...]")
|
||||
}
|
||||
|
||||
return username, host, cmdArgs, nil
|
||||
}
|
||||
|
||||
// splitHostPort splits a host:port string. Unlike net.SplitHostPort,
|
||||
// it handles bare IPv4 addresses with port (100.1.2.3:1234) as well
|
||||
// as [IPv6]:port format.
|
||||
func splitHostPort(addr string) (host, port string, err error) {
|
||||
// Handle IPv6 [::]:port format.
|
||||
if strings.HasPrefix(addr, "[") {
|
||||
end := strings.Index(addr, "]:")
|
||||
if end < 0 {
|
||||
return "", "", fmt.Errorf("invalid address: %s", addr)
|
||||
}
|
||||
return addr[1:end], addr[end+2:], nil
|
||||
}
|
||||
// Handle IPv4 host:port.
|
||||
i := strings.LastIndex(addr, ":")
|
||||
if i < 0 {
|
||||
return "", "", fmt.Errorf("no port in address: %s", addr)
|
||||
}
|
||||
return addr[:i], addr[i+1:], nil
|
||||
}
|
||||
|
||||
// currentUser returns the current OS username.
|
||||
func currentUser() (string, error) {
|
||||
// os/user.Current() can fail in some environments (static builds, etc).
|
||||
// Try it first, fall back to env vars.
|
||||
if u := os.Getenv("USER"); u != "" {
|
||||
return u, nil
|
||||
}
|
||||
return "", errors.New("cannot determine current user")
|
||||
}
|
||||
142
cmd/tailscale/cli/rsh_test.go
Normal file
142
cmd/tailscale/cli/rsh_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRshArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantUser string
|
||||
wantHost string
|
||||
wantCmd string // joined with space
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple_user_at_host_with_command",
|
||||
args: []string{"alice@myhost", "ls", "-la"},
|
||||
wantUser: "alice",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "ls -la",
|
||||
},
|
||||
{
|
||||
name: "bare_host_with_command",
|
||||
args: []string{"myhost", "ls"},
|
||||
wantUser: "",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "ls",
|
||||
},
|
||||
{
|
||||
name: "bare_host_no_command",
|
||||
args: []string{"myhost"},
|
||||
wantUser: "",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "",
|
||||
},
|
||||
{
|
||||
// This is the exact pattern from rsync:
|
||||
// tailscale rsh ubuntu@james-ai -l ubuntu james-ai rsync --server --sender -vlogDtpre.iLsfxCIvu . ai/
|
||||
name: "rsync_pattern",
|
||||
args: []string{"ubuntu@james-ai", "-l", "ubuntu", "james-ai", "rsync", "--server", "--sender", "-vlogDtpre.iLsfxCIvu", ".", "ai/"},
|
||||
wantUser: "ubuntu",
|
||||
wantHost: "james-ai",
|
||||
wantCmd: "rsync --server --sender -vlogDtpre.iLsfxCIvu . ai/",
|
||||
},
|
||||
{
|
||||
name: "l_flag_overrides_user_at_host",
|
||||
args: []string{"alice@myhost", "-l", "bob", "myhost", "echo", "hi"},
|
||||
wantUser: "bob",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "echo hi",
|
||||
},
|
||||
{
|
||||
name: "l_flag_no_space",
|
||||
args: []string{"myhost", "-lubuntu", "myhost", "ls"},
|
||||
wantUser: "ubuntu",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "ls",
|
||||
},
|
||||
{
|
||||
name: "l_flag_without_user_at_host",
|
||||
args: []string{"-l", "ubuntu", "myhost", "rsync", "--server"},
|
||||
wantUser: "ubuntu",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "rsync --server",
|
||||
},
|
||||
{
|
||||
name: "o_flag_ignored",
|
||||
args: []string{"alice@myhost", "-o", "StrictHostKeyChecking=no", "myhost", "ls"},
|
||||
wantUser: "alice",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "ls",
|
||||
},
|
||||
{
|
||||
name: "o_flag_no_space_ignored",
|
||||
args: []string{"alice@myhost", "-oStrictHostKeyChecking=no", "myhost", "ls"},
|
||||
wantUser: "alice",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "ls",
|
||||
},
|
||||
{
|
||||
name: "multiple_flags",
|
||||
args: []string{"alice@myhost", "-o", "BatchMode=yes", "-l", "root", "myhost", "uptime"},
|
||||
wantUser: "root",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "uptime",
|
||||
},
|
||||
{
|
||||
name: "unknown_flags_skipped",
|
||||
args: []string{"alice@myhost", "-4", "-p", "myhost", "ls"},
|
||||
wantUser: "alice",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "ls",
|
||||
},
|
||||
{
|
||||
name: "double_dash_separator",
|
||||
args: []string{"myhost", "--", "-l", "this-is-command"},
|
||||
wantUser: "",
|
||||
wantHost: "myhost",
|
||||
wantCmd: "-l this-is-command",
|
||||
},
|
||||
{
|
||||
name: "empty_args",
|
||||
args: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "l_flag_missing_value",
|
||||
args: []string{"myhost", "-l"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, host, cmdArgs, err := parseRshArgs(tt.args)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if user != tt.wantUser {
|
||||
t.Errorf("user = %q, want %q", user, tt.wantUser)
|
||||
}
|
||||
if host != tt.wantHost {
|
||||
t.Errorf("host = %q, want %q", host, tt.wantHost)
|
||||
}
|
||||
gotCmd := strings.Join(cmdArgs, " ")
|
||||
if gotCmd != tt.wantCmd {
|
||||
t.Errorf("command = %q, want %q", gotCmd, tt.wantCmd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
8
feature/condregister/maybe_rsh.go
Normal file
8
feature/condregister/maybe_rsh.go
Normal file
@ -0,0 +1,8 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd) && !ts_omit_rsh
|
||||
|
||||
package condregister
|
||||
|
||||
import _ "tailscale.com/feature/rsh"
|
||||
286
feature/rsh/checkmode_test.go
Normal file
286
feature/rsh/checkmode_test.go
Normal file
@ -0,0 +1,286 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
)
|
||||
|
||||
func TestExpandDelegateURL(t *testing.T) {
|
||||
nm := &netmap.NetworkMap{
|
||||
SelfNode: (&tailcfg.Node{
|
||||
ID: 42,
|
||||
StableID: "self-stable",
|
||||
Key: key.NewNode().Public(),
|
||||
Addresses: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
},
|
||||
}).View(),
|
||||
}
|
||||
|
||||
peerNode := (&tailcfg.Node{
|
||||
ID: 99,
|
||||
StableID: "peer-stable",
|
||||
Key: key.NewNode().Public(),
|
||||
}).View()
|
||||
|
||||
peerAddr := netip.MustParseAddr("100.64.1.2")
|
||||
lu := &user.User{Username: "localice"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "all_variables",
|
||||
url: "https://control.example.com/check?src=$SRC_NODE_IP&srcid=$SRC_NODE_ID&dst=$DST_NODE_IP&dstid=$DST_NODE_ID&sshuser=$SSH_USER&local=$LOCAL_USER",
|
||||
want: "https://control.example.com/check?src=100.64.1.2&srcid=99&dst=100.64.0.1&dstid=42&sshuser=alice&local=localice",
|
||||
},
|
||||
{
|
||||
name: "no_variables",
|
||||
url: "https://control.example.com/check?static=true",
|
||||
want: "https://control.example.com/check?static=true",
|
||||
},
|
||||
{
|
||||
name: "url_encoding",
|
||||
url: "https://control.example.com/check?user=$SSH_USER",
|
||||
want: "https://control.example.com/check?user=alice%40example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sshUser := "alice"
|
||||
if tt.name == "url_encoding" {
|
||||
sshUser = "alice@example.com"
|
||||
}
|
||||
got := expandDelegateURL(tt.url, nm, peerNode, peerAddr, sshUser, lu)
|
||||
if got != tt.want {
|
||||
t.Errorf("expandDelegateURL() =\n %s\nwant:\n %s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteNDJSON(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Write a status message.
|
||||
writeNDJSON(&buf, nil, rshStatusMessage{Status: "waiting"})
|
||||
|
||||
// Write an rshResponse.
|
||||
writeNDJSON(&buf, nil, rshResponse{Addr: "100.64.0.1:1234", Token: "abcd"})
|
||||
|
||||
// Verify output is two newline-delimited JSON lines.
|
||||
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("got %d lines, want 2:\n%s", len(lines), buf.String())
|
||||
}
|
||||
|
||||
// Verify first line is a status message.
|
||||
var msg rshStatusMessage
|
||||
if err := json.Unmarshal([]byte(lines[0]), &msg); err != nil {
|
||||
t.Fatalf("unmarshal line 0: %v", err)
|
||||
}
|
||||
if msg.Status != "waiting" {
|
||||
t.Errorf("status = %q, want %q", msg.Status, "waiting")
|
||||
}
|
||||
|
||||
// Verify second line is a response.
|
||||
var resp rshResponse
|
||||
if err := json.Unmarshal([]byte(lines[1]), &resp); err != nil {
|
||||
t.Fatalf("unmarshal line 1: %v", err)
|
||||
}
|
||||
if resp.Addr != "100.64.0.1:1234" {
|
||||
t.Errorf("addr = %q, want %q", resp.Addr, "100.64.0.1:1234")
|
||||
}
|
||||
if resp.Token != "abcd" {
|
||||
t.Errorf("token = %q, want %q", resp.Token, "abcd")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNDJSONStreamParsing(t *testing.T) {
|
||||
// Simulate a streaming NDJSON response as the CLI would see it.
|
||||
var buf bytes.Buffer
|
||||
writeNDJSON(&buf, nil, rshStatusMessage{Status: "Checking with control plane..."})
|
||||
writeNDJSON(&buf, nil, rshStatusMessage{Status: "Waiting for approval..."})
|
||||
writeNDJSON(&buf, nil, rshStatusMessage{Status: "Access approved"})
|
||||
writeNDJSON(&buf, nil, rshResponse{Addr: "100.64.0.5:4567", Token: "deadbeef"})
|
||||
|
||||
// Parse like the CLI does.
|
||||
scanner := bufio.NewScanner(&buf)
|
||||
var statusMessages []string
|
||||
var finalResp rshResponse
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
var candidate rshResponse
|
||||
if err := json.Unmarshal(line, &candidate); err == nil && candidate.Addr != "" {
|
||||
finalResp = candidate
|
||||
continue
|
||||
}
|
||||
var msg rshStatusMessage
|
||||
if err := json.Unmarshal(line, &msg); err == nil && msg.Status != "" {
|
||||
statusMessages = append(statusMessages, msg.Status)
|
||||
}
|
||||
}
|
||||
|
||||
if len(statusMessages) != 3 {
|
||||
t.Fatalf("got %d status messages, want 3", len(statusMessages))
|
||||
}
|
||||
if statusMessages[0] != "Checking with control plane..." {
|
||||
t.Errorf("status[0] = %q, want %q", statusMessages[0], "Checking with control plane...")
|
||||
}
|
||||
if statusMessages[2] != "Access approved" {
|
||||
t.Errorf("status[2] = %q, want %q", statusMessages[2], "Access approved")
|
||||
}
|
||||
if finalResp.Addr != "100.64.0.5:4567" {
|
||||
t.Errorf("addr = %q, want %q", finalResp.Addr, "100.64.0.5:4567")
|
||||
}
|
||||
if finalResp.Token != "deadbeef" {
|
||||
t.Errorf("token = %q, want %q", finalResp.Token, "deadbeef")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalSSHPolicyHoldAndDelegate(t *testing.T) {
|
||||
now := timeVal(2025, 1, 1)
|
||||
|
||||
node := (&tailcfg.Node{
|
||||
ID: 1,
|
||||
StableID: "stable1",
|
||||
Key: key.NewNode().Public(),
|
||||
}).View()
|
||||
|
||||
uprof := tailcfg.UserProfile{
|
||||
LoginName: "alice@example.com",
|
||||
}
|
||||
|
||||
srcAddr := netip.MustParseAddr("100.64.1.2")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol *tailcfg.SSHPolicy
|
||||
sshUser string
|
||||
wantResult evalResult
|
||||
wantUser string
|
||||
wantURL string
|
||||
}{
|
||||
{
|
||||
name: "hold_and_delegate_with_message",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{
|
||||
HoldAndDelegate: "https://control.example.com/approve?user=$SSH_USER",
|
||||
Message: "Please approve in the admin panel",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalHoldDelegate,
|
||||
wantUser: "alice",
|
||||
wantURL: "https://control.example.com/approve?user=$SSH_USER",
|
||||
},
|
||||
{
|
||||
name: "hold_with_specific_user_mapping",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "alice@example.com"}},
|
||||
SSHUsers: map[string]string{"root": "admin"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
HoldAndDelegate: "https://control.example.com/check",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "root",
|
||||
wantResult: evalHoldDelegate,
|
||||
wantUser: "admin",
|
||||
wantURL: "https://control.example.com/check",
|
||||
},
|
||||
{
|
||||
name: "hold_rejects_unmapped_user",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"root": "admin"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
HoldAndDelegate: "https://control.example.com/check",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "unknown",
|
||||
wantResult: evalRejectedUser,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
action, localUser, result := evalSSHPolicy(tt.pol, node, uprof, srcAddr, tt.sshUser, now)
|
||||
if result != tt.wantResult {
|
||||
t.Errorf("result = %v, want %v", result, tt.wantResult)
|
||||
}
|
||||
if tt.wantUser != "" && localUser != tt.wantUser {
|
||||
t.Errorf("localUser = %q, want %q", localUser, tt.wantUser)
|
||||
}
|
||||
if tt.wantURL != "" && action != nil && action.HoldAndDelegate != tt.wantURL {
|
||||
t.Errorf("HoldAndDelegate = %q, want %q", action.HoldAndDelegate, tt.wantURL)
|
||||
}
|
||||
if tt.wantResult == evalHoldDelegate && action == nil {
|
||||
t.Error("expected non-nil action for evalHoldDelegate result")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandDelegateURLNilFields(t *testing.T) {
|
||||
// Test with minimal/nil fields to ensure no panics.
|
||||
lu := &user.User{Username: "bob"}
|
||||
peerAddr := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
// Nil netmap, invalid peer node.
|
||||
got := expandDelegateURL(
|
||||
"https://control.example.com/check?dst=$DST_NODE_ID&src=$SRC_NODE_ID",
|
||||
nil,
|
||||
tailcfg.NodeView{}, // invalid
|
||||
peerAddr,
|
||||
"bob",
|
||||
lu,
|
||||
)
|
||||
// Should not panic; missing IDs should be empty strings.
|
||||
if strings.Contains(got, "$DST_NODE_ID") {
|
||||
t.Errorf("unexpanded variable in URL: %s", got)
|
||||
}
|
||||
if strings.Contains(got, "$SRC_NODE_ID") {
|
||||
t.Errorf("unexpanded variable in URL: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func timeVal(year, month, day int) time.Time {
|
||||
return time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
156
feature/rsh/localapi.go
Normal file
156
feature/rsh/localapi.go
Normal file
@ -0,0 +1,156 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/ipn/localapi"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/clientmetric"
|
||||
)
|
||||
|
||||
var (
|
||||
metricLocalAPIRshCalls = clientmetric.NewCounter("localapi_rsh")
|
||||
)
|
||||
|
||||
func init() {
|
||||
localapi.Register("rsh", serveRsh)
|
||||
}
|
||||
|
||||
// localRshRequest is the JSON body the CLI sends to POST /localapi/v0/rsh.
|
||||
// It includes the target peer information.
|
||||
type localRshRequest struct {
|
||||
// PeerID is the StableNodeID of the target peer.
|
||||
PeerID tailcfg.StableNodeID `json:"peer"`
|
||||
|
||||
// User is the SSH user to connect as.
|
||||
User string `json:"user"`
|
||||
|
||||
// Command is the command to execute on the remote.
|
||||
Command string `json:"command,omitempty"`
|
||||
}
|
||||
|
||||
// serveRsh proxies an rsh setup request to the target peer's PeerAPI.
|
||||
//
|
||||
// POST /localapi/v0/rsh
|
||||
//
|
||||
// Request body: JSON localRshRequest
|
||||
// Response body: JSON rshResponse (addr + token from the remote)
|
||||
func serveRsh(h *localapi.Handler, w http.ResponseWriter, r *http.Request) {
|
||||
metricLocalAPIRshCalls.Add(1)
|
||||
|
||||
if !h.PermitRead {
|
||||
http.Error(w, "rsh access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req localRshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.PeerID == "" {
|
||||
http.Error(w, "peer is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.User == "" {
|
||||
http.Error(w, "user is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
b := h.LocalBackend()
|
||||
nm := b.NetMap()
|
||||
if nm == nil {
|
||||
http.Error(w, "no netmap available", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Find the peer and its PeerAPI base URL.
|
||||
var peerAPIBaseURL string
|
||||
for _, p := range nm.Peers {
|
||||
if p.StableID() == req.PeerID {
|
||||
peerAPIBaseURL = b.PeerAPIBase(p)
|
||||
break
|
||||
}
|
||||
}
|
||||
if peerAPIBaseURL == "" {
|
||||
http.Error(w, "peer not found or no PeerAPI available", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the PeerAPI request.
|
||||
peerReqBody := rshRequest{
|
||||
User: req.User,
|
||||
Command: req.Command,
|
||||
}
|
||||
bodyBytes, err := json.Marshal(peerReqBody)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
peerURL := strings.TrimRight(peerAPIBaseURL, "/") + "/v0/rsh"
|
||||
peerReq, err := http.NewRequestWithContext(r.Context(), "POST", peerURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
http.Error(w, "internal error: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
peerReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Use the PeerAPI transport to dial the remote peer.
|
||||
tr := b.Dialer().PeerAPITransport()
|
||||
resp, err := tr.RoundTrip(peerReq)
|
||||
if err != nil {
|
||||
h.Logf("rsh: failed to reach peer %s: %v", req.PeerID, err)
|
||||
http.Error(w, fmt.Sprintf("failed to reach peer: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
h.Logf("rsh: peer returned status %d: %s", resp.StatusCode, string(body))
|
||||
http.Error(w, fmt.Sprintf("peer error: %s", string(body)), resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
// Pass through the response from the peer. If the peer is using
|
||||
// streaming NDJSON (check mode / HoldAndDelegate), we forward
|
||||
// each line as it arrives so the CLI can display status messages.
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
w.Header().Set("Content-Type", ct)
|
||||
if strings.HasPrefix(ct, "application/x-ndjson") {
|
||||
// Streaming mode: flush each line as it arrives.
|
||||
flusher, _ := w.(http.Flusher)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
w.Write(buf[:n])
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Simple JSON response: pass through directly.
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
}
|
||||
152
feature/rsh/policy.go
Normal file
152
feature/rsh/policy.go
Normal file
@ -0,0 +1,152 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// evalResult is the result of SSH policy evaluation.
|
||||
type evalResult int
|
||||
|
||||
const (
|
||||
evalAccepted evalResult = iota // rule matched with Accept
|
||||
evalRejected // no matching rule, or explicit Reject
|
||||
evalRejectedUser // principal matched but user mapping failed
|
||||
evalHoldDelegate // rule matched with HoldAndDelegate (check mode)
|
||||
)
|
||||
|
||||
// evalSSHPolicy evaluates the SSH policy for the given parameters.
|
||||
// This replicates the core matching logic from ssh/tailssh without
|
||||
// depending on the SSH connection type.
|
||||
//
|
||||
// It returns the matching action, the mapped local user, and the evaluation result.
|
||||
func evalSSHPolicy(
|
||||
pol *tailcfg.SSHPolicy,
|
||||
node tailcfg.NodeView,
|
||||
uprof tailcfg.UserProfile,
|
||||
srcAddr netip.Addr,
|
||||
sshUser string,
|
||||
now time.Time,
|
||||
) (action *tailcfg.SSHAction, localUser string, result evalResult) {
|
||||
if pol == nil {
|
||||
return nil, "", evalRejected
|
||||
}
|
||||
failedOnUser := false
|
||||
for _, r := range pol.Rules {
|
||||
if a, lu, err := matchRule(r, node, uprof, srcAddr, sshUser, now); err == nil {
|
||||
if a.HoldAndDelegate != "" {
|
||||
return a, lu, evalHoldDelegate
|
||||
}
|
||||
return a, lu, evalAccepted
|
||||
} else if errors.Is(err, errUserMatch) {
|
||||
failedOnUser = true
|
||||
}
|
||||
}
|
||||
if failedOnUser {
|
||||
return nil, "", evalRejectedUser
|
||||
}
|
||||
return nil, "", evalRejected
|
||||
}
|
||||
|
||||
var (
|
||||
errNilRule = errors.New("nil rule")
|
||||
errNilAction = errors.New("nil action")
|
||||
errRuleExpired = errors.New("rule expired")
|
||||
errPrincipalMatch = errors.New("principal didn't match")
|
||||
errUserMatch = errors.New("user didn't match")
|
||||
)
|
||||
|
||||
// matchRule checks whether a single SSHRule matches the given parameters.
|
||||
func matchRule(
|
||||
r *tailcfg.SSHRule,
|
||||
node tailcfg.NodeView,
|
||||
uprof tailcfg.UserProfile,
|
||||
srcAddr netip.Addr,
|
||||
sshUser string,
|
||||
now time.Time,
|
||||
) (action *tailcfg.SSHAction, localUser string, err error) {
|
||||
if r == nil {
|
||||
return nil, "", errNilRule
|
||||
}
|
||||
if r.Action == nil {
|
||||
return nil, "", errNilAction
|
||||
}
|
||||
if r.RuleExpires != nil && r.RuleExpires.Before(now) {
|
||||
return nil, "", errRuleExpired
|
||||
}
|
||||
if !anyPrincipalMatches(r.Principals, node, uprof, srcAddr) {
|
||||
return nil, "", errPrincipalMatch
|
||||
}
|
||||
if !r.Action.Reject {
|
||||
localUser = mapLocalUser(r.SSHUsers, sshUser)
|
||||
if localUser == "" {
|
||||
return nil, "", errUserMatch
|
||||
}
|
||||
}
|
||||
return r.Action, localUser, nil
|
||||
}
|
||||
|
||||
// anyPrincipalMatches reports whether any of the given principals match
|
||||
// the Tailscale identity of the connecting peer.
|
||||
func anyPrincipalMatches(
|
||||
ps []*tailcfg.SSHPrincipal,
|
||||
node tailcfg.NodeView,
|
||||
uprof tailcfg.UserProfile,
|
||||
srcAddr netip.Addr,
|
||||
) bool {
|
||||
for _, p := range ps {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if principalMatchesTailscaleIdentity(p, node, uprof, srcAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// principalMatchesTailscaleIdentity reports whether a principal matches
|
||||
// the Tailscale identity of the connecting peer.
|
||||
func principalMatchesTailscaleIdentity(
|
||||
p *tailcfg.SSHPrincipal,
|
||||
node tailcfg.NodeView,
|
||||
uprof tailcfg.UserProfile,
|
||||
srcAddr netip.Addr,
|
||||
) bool {
|
||||
if p.Any {
|
||||
return true
|
||||
}
|
||||
if !p.Node.IsZero() && node.Valid() && p.Node == node.StableID() {
|
||||
return true
|
||||
}
|
||||
if p.NodeIP != "" {
|
||||
if ip, _ := netip.ParseAddr(p.NodeIP); ip == srcAddr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if p.UserLogin != "" && uprof.LoginName == p.UserLogin {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mapLocalUser maps an SSH user to a local user using the SSHUsers map
|
||||
// from a policy rule.
|
||||
func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) string {
|
||||
v, ok := ruleSSHUsers[reqSSHUser]
|
||||
if !ok {
|
||||
v = ruleSSHUsers["*"]
|
||||
}
|
||||
if v == "=" {
|
||||
return reqSSHUser
|
||||
}
|
||||
return v
|
||||
}
|
||||
257
feature/rsh/policy_test.go
Normal file
257
feature/rsh/policy_test.go
Normal file
@ -0,0 +1,257 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestEvalSSHPolicy(t *testing.T) {
|
||||
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
node := (&tailcfg.Node{
|
||||
ID: 1,
|
||||
StableID: "stable1",
|
||||
Key: key.NewNode().Public(),
|
||||
}).View()
|
||||
|
||||
uprof := tailcfg.UserProfile{
|
||||
LoginName: "alice@example.com",
|
||||
}
|
||||
|
||||
srcAddr := netip.MustParseAddr("100.64.1.2")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol *tailcfg.SSHPolicy
|
||||
sshUser string
|
||||
wantResult evalResult
|
||||
wantUser string
|
||||
}{
|
||||
{
|
||||
name: "nil_policy",
|
||||
pol: nil,
|
||||
sshUser: "root",
|
||||
wantResult: evalRejected,
|
||||
},
|
||||
{
|
||||
name: "empty_policy",
|
||||
pol: &tailcfg.SSHPolicy{},
|
||||
sshUser: "root",
|
||||
wantResult: evalRejected,
|
||||
},
|
||||
{
|
||||
name: "accept_any_wildcard_user",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalAccepted,
|
||||
wantUser: "alice",
|
||||
},
|
||||
{
|
||||
name: "accept_specific_user_mapping",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"root": "admin"},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "root",
|
||||
wantResult: evalAccepted,
|
||||
wantUser: "admin",
|
||||
},
|
||||
{
|
||||
name: "reject_unmapped_user",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"root": "admin"},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "unknown",
|
||||
wantResult: evalRejectedUser,
|
||||
},
|
||||
{
|
||||
name: "match_by_node_stable_id",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Node: "stable1"}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "bob",
|
||||
wantResult: evalAccepted,
|
||||
wantUser: "bob",
|
||||
},
|
||||
{
|
||||
name: "reject_wrong_node",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Node: "other-node"}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "bob",
|
||||
wantResult: evalRejected,
|
||||
},
|
||||
{
|
||||
name: "match_by_node_ip",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.1.2"}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalAccepted,
|
||||
wantUser: "alice",
|
||||
},
|
||||
{
|
||||
name: "match_by_user_login",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "alice@example.com"}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalAccepted,
|
||||
wantUser: "alice",
|
||||
},
|
||||
{
|
||||
name: "hold_and_delegate_returns_eval_hold",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{HoldAndDelegate: "https://example.com/approve"},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalHoldDelegate,
|
||||
},
|
||||
{
|
||||
name: "explicit_reject_action",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
Action: &tailcfg.SSHAction{Reject: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalAccepted, // matchRule succeeds for Reject rules (no SSHUsers check)
|
||||
},
|
||||
{
|
||||
name: "expired_rule_skipped",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
RuleExpires: timePtr(now.Add(-time.Hour)),
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalRejected,
|
||||
},
|
||||
{
|
||||
name: "first_matching_rule_wins",
|
||||
pol: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.1.2"}},
|
||||
SSHUsers: map[string]string{"alice": "localice"},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{"*": "="},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
sshUser: "alice",
|
||||
wantResult: evalAccepted,
|
||||
wantUser: "localice",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
action, localUser, result := evalSSHPolicy(tt.pol, node, uprof, srcAddr, tt.sshUser, now)
|
||||
if result != tt.wantResult {
|
||||
t.Errorf("result = %v, want %v", result, tt.wantResult)
|
||||
}
|
||||
if tt.wantUser != "" && localUser != tt.wantUser {
|
||||
t.Errorf("localUser = %q, want %q", localUser, tt.wantUser)
|
||||
}
|
||||
_ = action // not checked in most tests
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapLocalUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sshUsers map[string]string
|
||||
reqUser string
|
||||
wantResult string
|
||||
}{
|
||||
{"exact_match", map[string]string{"root": "admin"}, "root", "admin"},
|
||||
{"wildcard_match", map[string]string{"*": "defaultuser"}, "anyone", "defaultuser"},
|
||||
{"identity_match", map[string]string{"*": "="}, "alice", "alice"},
|
||||
{"no_match", map[string]string{"root": "admin"}, "unknown", ""},
|
||||
{"exact_over_wildcard", map[string]string{"root": "admin", "*": "default"}, "root", "admin"},
|
||||
{"nil_map", nil, "alice", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mapLocalUser(tt.sshUsers, tt.reqUser)
|
||||
if got != tt.wantResult {
|
||||
t.Errorf("mapLocalUser(%v, %q) = %q, want %q", tt.sshUsers, tt.reqUser, got, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func timePtr(t time.Time) *time.Time { return &t }
|
||||
157
feature/rsh/protocol.go
Normal file
157
feature/rsh/protocol.go
Normal file
@ -0,0 +1,157 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package rsh implements a fast remote shell transport over Tailscale,
|
||||
// designed as an rsync -e compatible replacement for SSH. It uses a PeerAPI
|
||||
// endpoint for session setup and a raw TCP data channel for I/O,
|
||||
// avoiding SSH's double encryption and suboptimal buffering.
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Channel type constants for the wire protocol.
|
||||
// The protocol is length-prefixed framing:
|
||||
//
|
||||
// [1 byte: channel] [4 bytes: length (big-endian)] [N bytes: payload]
|
||||
const (
|
||||
// ChanStdin is data from client to server (remote process stdin).
|
||||
ChanStdin byte = 0x00
|
||||
|
||||
// ChanStdout is data from server to client (remote process stdout).
|
||||
ChanStdout byte = 0x01
|
||||
|
||||
// ChanStderr is data from server to client (remote process stderr).
|
||||
ChanStderr byte = 0x02
|
||||
|
||||
// ChanExit is the exit code from the remote process.
|
||||
// Payload is a 4-byte big-endian signed integer exit code.
|
||||
// Sent by server to client, then the server closes the connection.
|
||||
ChanExit byte = 0x03
|
||||
)
|
||||
|
||||
const (
|
||||
// tokenLen is the length of the one-time authentication token.
|
||||
tokenLen = 32
|
||||
|
||||
// maxFrameSize is the maximum payload size for a single frame.
|
||||
// 256KB is a good balance between throughput and memory usage,
|
||||
// matching typical rsync block sizes.
|
||||
maxFrameSize = 256 * 1024
|
||||
|
||||
// frameHeaderSize is the size of the frame header (channel + length).
|
||||
frameHeaderSize = 5
|
||||
)
|
||||
|
||||
// frameWriter writes length-prefixed frames to an underlying writer.
|
||||
// It is safe for concurrent use.
|
||||
type frameWriter struct {
|
||||
mu sync.Mutex
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// newFrameWriter creates a new frameWriter that writes to w.
|
||||
func newFrameWriter(w io.Writer) *frameWriter {
|
||||
return &frameWriter{w: w}
|
||||
}
|
||||
|
||||
// WriteFrame writes a single frame with the given channel and payload.
|
||||
func (fw *frameWriter) WriteFrame(ch byte, data []byte) error {
|
||||
if len(data) > maxFrameSize {
|
||||
return fmt.Errorf("rsh: frame payload too large: %d > %d", len(data), maxFrameSize)
|
||||
}
|
||||
fw.mu.Lock()
|
||||
defer fw.mu.Unlock()
|
||||
|
||||
var hdr [frameHeaderSize]byte
|
||||
hdr[0] = ch
|
||||
binary.BigEndian.PutUint32(hdr[1:], uint32(len(data)))
|
||||
|
||||
if _, err := fw.w.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) > 0 {
|
||||
if _, err := fw.w.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteExitCode writes an exit code frame and is a convenience wrapper.
|
||||
func (fw *frameWriter) WriteExitCode(code int) error {
|
||||
var buf [4]byte
|
||||
binary.BigEndian.PutUint32(buf[:], uint32(code))
|
||||
return fw.WriteFrame(ChanExit, buf[:])
|
||||
}
|
||||
|
||||
// frameReader reads length-prefixed frames from an underlying reader.
|
||||
type frameReader struct {
|
||||
r io.Reader
|
||||
buf []byte // reusable buffer for payloads
|
||||
}
|
||||
|
||||
// newFrameReader creates a new frameReader that reads from r.
|
||||
func newFrameReader(r io.Reader) *frameReader {
|
||||
return &frameReader{
|
||||
r: r,
|
||||
buf: make([]byte, 0, 32*1024), // start small, grow as needed
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrame reads the next frame, returning the channel type and payload.
|
||||
// The returned payload slice is valid until the next call to ReadFrame.
|
||||
func (fr *frameReader) ReadFrame() (ch byte, data []byte, err error) {
|
||||
var hdr [frameHeaderSize]byte
|
||||
if _, err := io.ReadFull(fr.r, hdr[:]); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
ch = hdr[0]
|
||||
n := binary.BigEndian.Uint32(hdr[1:])
|
||||
if n > maxFrameSize {
|
||||
return 0, nil, fmt.Errorf("rsh: frame too large: %d > %d", n, maxFrameSize)
|
||||
}
|
||||
if int(n) > cap(fr.buf) {
|
||||
fr.buf = make([]byte, n)
|
||||
} else {
|
||||
fr.buf = fr.buf[:n]
|
||||
}
|
||||
if n > 0 {
|
||||
if _, err := io.ReadFull(fr.r, fr.buf); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
}
|
||||
return ch, fr.buf, nil
|
||||
}
|
||||
|
||||
// channelWriter wraps a frameWriter to implement io.Writer for a specific channel.
|
||||
type channelWriter struct {
|
||||
fw *frameWriter
|
||||
ch byte
|
||||
}
|
||||
|
||||
// newChannelWriter returns an io.Writer that writes all data as frames on
|
||||
// the given channel.
|
||||
func newChannelWriter(fw *frameWriter, ch byte) io.Writer {
|
||||
return &channelWriter{fw: fw, ch: ch}
|
||||
}
|
||||
|
||||
func (cw *channelWriter) Write(p []byte) (int, error) {
|
||||
written := 0
|
||||
for len(p) > 0 {
|
||||
chunk := p
|
||||
if len(chunk) > maxFrameSize {
|
||||
chunk = chunk[:maxFrameSize]
|
||||
}
|
||||
if err := cw.fw.WriteFrame(cw.ch, chunk); err != nil {
|
||||
return written, err
|
||||
}
|
||||
written += len(chunk)
|
||||
p = p[len(chunk):]
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
188
feature/rsh/protocol_test.go
Normal file
188
feature/rsh/protocol_test.go
Normal file
@ -0,0 +1,188 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFrameRoundtrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
fw := newFrameWriter(&buf)
|
||||
fr := newFrameReader(&buf)
|
||||
|
||||
// Write several frames.
|
||||
if err := fw.WriteFrame(ChanStdout, []byte("hello")); err != nil {
|
||||
t.Fatalf("WriteFrame stdout: %v", err)
|
||||
}
|
||||
if err := fw.WriteFrame(ChanStderr, []byte("world")); err != nil {
|
||||
t.Fatalf("WriteFrame stderr: %v", err)
|
||||
}
|
||||
if err := fw.WriteFrame(ChanStdin, []byte("input")); err != nil {
|
||||
t.Fatalf("WriteFrame stdin: %v", err)
|
||||
}
|
||||
if err := fw.WriteExitCode(42); err != nil {
|
||||
t.Fatalf("WriteExitCode: %v", err)
|
||||
}
|
||||
|
||||
// Read them back.
|
||||
ch, data, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame 1: %v", err)
|
||||
}
|
||||
if ch != ChanStdout || string(data) != "hello" {
|
||||
t.Errorf("frame 1: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStdout, "hello")
|
||||
}
|
||||
|
||||
ch, data, err = fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame 2: %v", err)
|
||||
}
|
||||
if ch != ChanStderr || string(data) != "world" {
|
||||
t.Errorf("frame 2: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStderr, "world")
|
||||
}
|
||||
|
||||
ch, data, err = fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame 3: %v", err)
|
||||
}
|
||||
if ch != ChanStdin || string(data) != "input" {
|
||||
t.Errorf("frame 3: got ch=%d data=%q, want ch=%d data=%q", ch, data, ChanStdin, "input")
|
||||
}
|
||||
|
||||
ch, data, err = fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame 4: %v", err)
|
||||
}
|
||||
if ch != ChanExit {
|
||||
t.Errorf("frame 4: got ch=%d, want ch=%d", ch, ChanExit)
|
||||
}
|
||||
if len(data) != 4 {
|
||||
t.Fatalf("exit frame data len = %d, want 4", len(data))
|
||||
}
|
||||
code := int(binary.BigEndian.Uint32(data))
|
||||
if code != 42 {
|
||||
t.Errorf("exit code = %d, want 42", code)
|
||||
}
|
||||
|
||||
// Should get EOF now.
|
||||
_, _, err = fr.ReadFrame()
|
||||
if err != io.EOF && err != io.ErrUnexpectedEOF {
|
||||
t.Errorf("expected EOF after all frames, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameEmptyPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
fw := newFrameWriter(&buf)
|
||||
fr := newFrameReader(&buf)
|
||||
|
||||
if err := fw.WriteFrame(ChanStdin, nil); err != nil {
|
||||
t.Fatalf("WriteFrame empty: %v", err)
|
||||
}
|
||||
|
||||
ch, data, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame: %v", err)
|
||||
}
|
||||
if ch != ChanStdin {
|
||||
t.Errorf("ch = %d, want %d", ch, ChanStdin)
|
||||
}
|
||||
if len(data) != 0 {
|
||||
t.Errorf("data len = %d, want 0", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameTooLarge(t *testing.T) {
|
||||
fw := newFrameWriter(io.Discard)
|
||||
data := make([]byte, maxFrameSize+1)
|
||||
if err := fw.WriteFrame(ChanStdout, data); err == nil {
|
||||
t.Error("expected error for oversized frame, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameReaderTooLarge(t *testing.T) {
|
||||
// Construct a frame with length > maxFrameSize.
|
||||
var buf bytes.Buffer
|
||||
var hdr [frameHeaderSize]byte
|
||||
hdr[0] = ChanStdout
|
||||
binary.BigEndian.PutUint32(hdr[1:], maxFrameSize+1)
|
||||
buf.Write(hdr[:])
|
||||
|
||||
fr := newFrameReader(&buf)
|
||||
_, _, err := fr.ReadFrame()
|
||||
if err == nil {
|
||||
t.Error("expected error for oversized frame in reader, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelWriter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
fw := newFrameWriter(&buf)
|
||||
cw := newChannelWriter(fw, ChanStdout)
|
||||
|
||||
data := []byte("hello world from channel writer")
|
||||
n, err := cw.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("n = %d, want %d", n, len(data))
|
||||
}
|
||||
|
||||
fr := newFrameReader(&buf)
|
||||
ch, got, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame: %v", err)
|
||||
}
|
||||
if ch != ChanStdout {
|
||||
t.Errorf("ch = %d, want %d", ch, ChanStdout)
|
||||
}
|
||||
if !bytes.Equal(got, data) {
|
||||
t.Errorf("data mismatch: got %q, want %q", got, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelWriterChunking(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
fw := newFrameWriter(&buf)
|
||||
cw := newChannelWriter(fw, ChanStdout)
|
||||
|
||||
// Write more than maxFrameSize to verify chunking.
|
||||
data := make([]byte, maxFrameSize+100)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
n, err := cw.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("n = %d, want %d", n, len(data))
|
||||
}
|
||||
|
||||
// Should produce two frames.
|
||||
fr := newFrameReader(&buf)
|
||||
|
||||
ch, chunk1, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame 1: %v", err)
|
||||
}
|
||||
if ch != ChanStdout || len(chunk1) != maxFrameSize {
|
||||
t.Errorf("frame 1: ch=%d len=%d, want ch=%d len=%d", ch, len(chunk1), ChanStdout, maxFrameSize)
|
||||
}
|
||||
|
||||
ch, chunk2, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrame 2: %v", err)
|
||||
}
|
||||
if ch != ChanStdout || len(chunk2) != 100 {
|
||||
t.Errorf("frame 2: ch=%d len=%d, want ch=%d len=%d", ch, len(chunk2), ChanStdout, 100)
|
||||
}
|
||||
}
|
||||
739
feature/rsh/rsh.go
Normal file
739
feature/rsh/rsh.go
Normal file
@ -0,0 +1,739 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/backoff"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/osuser"
|
||||
)
|
||||
|
||||
var (
|
||||
metricRshCalls = clientmetric.NewCounter("peerapi_rsh")
|
||||
metricRshAccepts = clientmetric.NewCounter("peerapi_rsh_accept")
|
||||
metricRshRejects = clientmetric.NewCounter("peerapi_rsh_reject")
|
||||
)
|
||||
|
||||
func init() {
|
||||
ipnlocal.RegisterPeerAPIHandler("/v0/rsh", handleRsh)
|
||||
}
|
||||
|
||||
// rshRequest is the JSON body sent to POST /v0/rsh.
|
||||
type rshRequest struct {
|
||||
// User is the requested SSH user (will be mapped via SSHUsers policy).
|
||||
User string `json:"user"`
|
||||
|
||||
// Command is the command to execute. If empty, the user's default
|
||||
// login shell is started.
|
||||
Command string `json:"command,omitempty"`
|
||||
}
|
||||
|
||||
// rshResponse is returned by a successful POST /v0/rsh.
|
||||
// In streaming mode (check mode), this is the final JSON line in
|
||||
// the newline-delimited JSON stream.
|
||||
type rshResponse struct {
|
||||
// Addr is the Tailscale IP:port to connect to for the data channel.
|
||||
Addr string `json:"addr"`
|
||||
|
||||
// Token is the hex-encoded one-time authentication token that must
|
||||
// be sent as the first bytes on the data channel connection.
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// rshStatusMessage is sent as a streaming JSON line during the
|
||||
// HoldAndDelegate (check mode) flow before the final rshResponse.
|
||||
// Each message is a newline-delimited JSON object.
|
||||
type rshStatusMessage struct {
|
||||
// Status is a human-readable status message to display to the user.
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// netstackTCPListenerFunc is the type of a function that creates a TCP
|
||||
// listener on the netstack (gVisor) network stack. It is set by the
|
||||
// netstack package at init time.
|
||||
//
|
||||
// We use a function hook instead of a type assertion on NetstackImpl
|
||||
// because netstack.Impl.ListenTCP returns *gonet.TCPListener (not
|
||||
// net.Listener), and importing gonet would create an unwanted gVisor
|
||||
// dependency.
|
||||
var netstackListenTCP func(b *ipnlocal.LocalBackend, network, address string) (net.Listener, error)
|
||||
|
||||
const linux = "linux"
|
||||
|
||||
func handleRsh(ph ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) {
|
||||
metricRshCalls.Add(1)
|
||||
logf := ph.Logf
|
||||
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
b := ph.LocalBackend()
|
||||
|
||||
// Check that SSH is enabled on this node.
|
||||
if !b.ShouldRunSSH() {
|
||||
logf("rsh: denied; SSH not enabled")
|
||||
metricRshRejects.Add(1)
|
||||
http.Error(w, "SSH not enabled on this node", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request.
|
||||
var req rshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
logf("rsh: bad request body: %v", err)
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.User == "" {
|
||||
http.Error(w, "user is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Evaluate SSH policy.
|
||||
peerNode := ph.Peer()
|
||||
peerAddr := ph.RemoteAddr().Addr()
|
||||
|
||||
nm := b.NetMap()
|
||||
if nm == nil {
|
||||
logf("rsh: no netmap")
|
||||
http.Error(w, "no netmap available", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sshPol := nm.SSHPolicy
|
||||
if sshPol == nil {
|
||||
logf("rsh: no SSH policy")
|
||||
metricRshRejects.Add(1)
|
||||
http.Error(w, "no SSH policy configured", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the peer's user profile for policy matching.
|
||||
_, uprof, ok := b.WhoIs("tcp", ph.RemoteAddr())
|
||||
if !ok {
|
||||
logf("rsh: unknown peer %v", ph.RemoteAddr())
|
||||
metricRshRejects.Add(1)
|
||||
http.Error(w, "unknown peer", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
action, localUser, result := evalSSHPolicy(sshPol, peerNode, uprof, peerAddr, req.User, time.Now())
|
||||
|
||||
switch result {
|
||||
case evalAccepted:
|
||||
if action.Reject {
|
||||
logf("rsh: policy explicitly rejects %v -> %s@%s", peerAddr, req.User, localUser)
|
||||
metricRshRejects.Add(1)
|
||||
http.Error(w, "access denied by policy", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
// Good, accepted. action may still have a Message to send.
|
||||
case evalHoldDelegate:
|
||||
// Check mode: we need to poll the control plane for approval.
|
||||
// The response uses streaming newline-delimited JSON so
|
||||
// status messages can be sent while we wait.
|
||||
case evalRejectedUser:
|
||||
logf("rsh: user %q not mapped for peer %v", req.User, peerAddr)
|
||||
metricRshRejects.Add(1)
|
||||
http.Error(w, fmt.Sprintf("user %q not permitted", req.User), http.StatusForbidden)
|
||||
return
|
||||
case evalRejected:
|
||||
logf("rsh: policy rejects %v -> %s", peerAddr, req.User)
|
||||
metricRshRejects.Add(1)
|
||||
http.Error(w, "access denied by policy", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the local user. We need this for both the immediate accept
|
||||
// path and the HoldAndDelegate path (to expand delegate URL variables).
|
||||
lu, loginShell, err := osuser.LookupByUsernameWithShell(localUser)
|
||||
if err != nil {
|
||||
logf("rsh: user lookup failed for %q: %v", localUser, err)
|
||||
http.Error(w, fmt.Sprintf("user %q not found", localUser), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
groupIDs, err := osuser.GetGroupIds(lu)
|
||||
if err != nil {
|
||||
logf("rsh: group lookup failed for %q: %v", localUser, err)
|
||||
http.Error(w, "failed to look up user groups", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// If HoldAndDelegate, run the check mode loop to get a terminal action.
|
||||
// We use streaming JSON so status messages can be sent to the client.
|
||||
if result == evalHoldDelegate {
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
action, err = resolveCheckMode(r.Context(), b, action, nm, peerNode, peerAddr, req.User, lu, w, flusher, logf)
|
||||
if err != nil {
|
||||
// Connection is already streaming; send error as a status message.
|
||||
logf("rsh: check mode failed: %v", err)
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: fmt.Sprintf("check mode error: %v", err)})
|
||||
return
|
||||
}
|
||||
if action.Reject {
|
||||
logf("rsh: check mode rejected %v -> %s", peerAddr, req.User)
|
||||
metricRshRejects.Add(1)
|
||||
msg := "access denied"
|
||||
if action.Message != "" {
|
||||
msg = action.Message
|
||||
}
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: msg})
|
||||
return
|
||||
}
|
||||
if !action.Accept {
|
||||
logf("rsh: check mode returned non-terminal action for %v -> %s", peerAddr, req.User)
|
||||
metricRshRejects.Add(1)
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: "unexpected response from control"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Find a local Tailscale IP to listen on.
|
||||
listenAddr, err := pickListenAddr(nm, peerAddr)
|
||||
if err != nil {
|
||||
logf("rsh: no listen address: %v", err)
|
||||
if result == evalHoldDelegate {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: "no suitable listen address"})
|
||||
} else {
|
||||
http.Error(w, "no suitable listen address", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Create the listener.
|
||||
ln, err := listenTailscale(b, listenAddr)
|
||||
if err != nil {
|
||||
logf("rsh: listen failed: %v", err)
|
||||
if result == evalHoldDelegate {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: "failed to create listener"})
|
||||
} else {
|
||||
http.Error(w, "failed to create listener", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Generate one-time token.
|
||||
var tokenBytes [tokenLen]byte
|
||||
if _, err := rand.Read(tokenBytes[:]); err != nil {
|
||||
ln.Close()
|
||||
logf("rsh: rand failed: %v", err)
|
||||
if result == evalHoldDelegate {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: "internal error"})
|
||||
} else {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
tokenHex := hex.EncodeToString(tokenBytes[:])
|
||||
|
||||
metricRshAccepts.Add(1)
|
||||
|
||||
// Start the session handler in a goroutine. It will accept one
|
||||
// connection, verify the token, and wire up the incubator process.
|
||||
go handleRshSession(b, ln, tokenBytes[:], peerAddr, lu, loginShell, groupIDs, req, ph, logf)
|
||||
|
||||
// Return the listen address and token to the client.
|
||||
resp := rshResponse{
|
||||
Addr: ln.Addr().String(),
|
||||
Token: tokenHex,
|
||||
}
|
||||
if result == evalHoldDelegate {
|
||||
// Streaming mode: send a final accept message then the response.
|
||||
flusher, _ := w.(http.Flusher)
|
||||
if action.Message != "" {
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: action.Message})
|
||||
}
|
||||
writeNDJSON(w, flusher, resp)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// writeNDJSON writes v as a single newline-delimited JSON line to w
|
||||
// and flushes. This is used for the streaming check mode response.
|
||||
func writeNDJSON(w io.Writer, flusher http.Flusher, v any) {
|
||||
json.NewEncoder(w).Encode(v) // Encode appends '\n'
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCheckMode runs the HoldAndDelegate loop, polling the control plane
|
||||
// until a terminal action (Accept or Reject) is returned. It sends status
|
||||
// messages to the client as streaming JSON lines while waiting.
|
||||
//
|
||||
// This is the rsh equivalent of SSH's clientAuth HoldAndDelegate loop.
|
||||
func resolveCheckMode(
|
||||
ctx context.Context,
|
||||
b *ipnlocal.LocalBackend,
|
||||
action *tailcfg.SSHAction,
|
||||
nm *netmap.NetworkMap,
|
||||
peerNode tailcfg.NodeView,
|
||||
peerAddr netip.Addr,
|
||||
sshUser string,
|
||||
lu *user.User,
|
||||
w io.Writer,
|
||||
flusher http.Flusher,
|
||||
logf func(string, ...any),
|
||||
) (*tailcfg.SSHAction, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
if action.Message != "" {
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: action.Message})
|
||||
}
|
||||
|
||||
if action.Accept || action.Reject {
|
||||
return action, nil
|
||||
}
|
||||
if action.HoldAndDelegate == "" {
|
||||
return nil, fmt.Errorf("action has neither Accept, Reject, nor HoldAndDelegate")
|
||||
}
|
||||
|
||||
delegateURL := expandDelegateURL(action.HoldAndDelegate, nm, peerNode, peerAddr, sshUser, lu)
|
||||
logf("rsh: check mode: polling %s", delegateURL)
|
||||
writeNDJSON(w, flusher, rshStatusMessage{Status: "Waiting for approval..."})
|
||||
|
||||
var err error
|
||||
action, err = fetchSSHAction(ctx, b, delegateURL, logf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching SSH action: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expandDelegateURL expands the variables in a HoldAndDelegate URL.
|
||||
// The variables match those used by SSH: $SRC_NODE_IP, $SRC_NODE_ID,
|
||||
// $DST_NODE_IP, $DST_NODE_ID, $SSH_USER, $LOCAL_USER.
|
||||
func expandDelegateURL(
|
||||
actionURL string,
|
||||
nm *netmap.NetworkMap,
|
||||
peerNode tailcfg.NodeView,
|
||||
peerAddr netip.Addr,
|
||||
sshUser string,
|
||||
lu *user.User,
|
||||
) string {
|
||||
var dstNodeID string
|
||||
if nm != nil {
|
||||
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID()))
|
||||
}
|
||||
var srcNodeID string
|
||||
if peerNode.Valid() {
|
||||
srcNodeID = fmt.Sprint(int64(peerNode.ID()))
|
||||
}
|
||||
var dstNodeIP string
|
||||
if nm != nil {
|
||||
addrs := nm.GetAddresses()
|
||||
for _, pfx := range addrs.All() {
|
||||
if pfx.IsSingleIP() {
|
||||
dstNodeIP = pfx.Addr().String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.NewReplacer(
|
||||
"$SRC_NODE_IP", url.QueryEscape(peerAddr.String()),
|
||||
"$SRC_NODE_ID", srcNodeID,
|
||||
"$DST_NODE_IP", url.QueryEscape(dstNodeIP),
|
||||
"$DST_NODE_ID", dstNodeID,
|
||||
"$SSH_USER", url.QueryEscape(sshUser),
|
||||
"$LOCAL_USER", url.QueryEscape(lu.Username),
|
||||
).Replace(actionURL)
|
||||
}
|
||||
|
||||
// fetchSSHAction polls a control plane URL over the Noise transport
|
||||
// and returns the SSHAction. It retries with exponential backoff on
|
||||
// transient errors, matching the behavior of SSH's fetchSSHAction.
|
||||
func fetchSSHAction(ctx context.Context, b *ipnlocal.LocalBackend, url string, logf func(string, ...any)) (*tailcfg.SSHAction, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
bo := backoff.NewBackoff("rsh-fetch-ssh-action", logf, 10*time.Second)
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := b.DoNoiseRequest(req)
|
||||
if err != nil {
|
||||
bo.BackOff(ctx, err)
|
||||
continue
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if len(body) > 1<<10 {
|
||||
body = body[:1<<10]
|
||||
}
|
||||
logf("rsh: fetch of %v: %s, %s", url, res.Status, body)
|
||||
bo.BackOff(ctx, fmt.Errorf("unexpected status: %v", res.Status))
|
||||
continue
|
||||
}
|
||||
a := new(tailcfg.SSHAction)
|
||||
err = json.NewDecoder(res.Body).Decode(a)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
logf("rsh: invalid SSHAction JSON from %v: %v", url, err)
|
||||
bo.BackOff(ctx, err)
|
||||
continue
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
}
|
||||
|
||||
// pickListenAddr selects a local Tailscale IP address that matches the
|
||||
// address family of the peer. This ensures the data channel connection
|
||||
// uses the same protocol version.
|
||||
func pickListenAddr(nm *netmap.NetworkMap, peerAddr netip.Addr) (netip.Addr, error) {
|
||||
addrs := nm.GetAddresses()
|
||||
wantV4 := peerAddr.Is4()
|
||||
|
||||
for _, pfx := range addrs.All() {
|
||||
if !pfx.IsSingleIP() {
|
||||
continue
|
||||
}
|
||||
a := pfx.Addr()
|
||||
if wantV4 && a.Is4() {
|
||||
return a, nil
|
||||
}
|
||||
if !wantV4 && a.Is6() {
|
||||
return a, nil
|
||||
}
|
||||
}
|
||||
// Fallback: return any address.
|
||||
for _, pfx := range addrs.All() {
|
||||
if pfx.IsSingleIP() {
|
||||
return pfx.Addr(), nil
|
||||
}
|
||||
}
|
||||
return netip.Addr{}, fmt.Errorf("no Tailscale addresses available")
|
||||
}
|
||||
|
||||
// listenTailscale creates a TCP listener on the given Tailscale IP.
|
||||
// In netstack mode, it uses the gVisor stack via the netstackListenTCP hook.
|
||||
// In kernel TUN mode, it uses the standard library.
|
||||
func listenTailscale(b *ipnlocal.LocalBackend, addr netip.Addr) (net.Listener, error) {
|
||||
network := "tcp4"
|
||||
if addr.Is6() {
|
||||
network = "tcp6"
|
||||
}
|
||||
listenAddr := netip.AddrPortFrom(addr, 0).String()
|
||||
|
||||
if b.Sys().IsNetstack() {
|
||||
// In full netstack mode, we need to use the gVisor stack to listen
|
||||
// since all local IP traffic is handled by netstack.
|
||||
if netstackListenTCP == nil {
|
||||
return nil, fmt.Errorf("netstack listener not available (rsh_netstack not linked)")
|
||||
}
|
||||
return netstackListenTCP(b, network, listenAddr)
|
||||
}
|
||||
|
||||
// In kernel TUN mode, the Tailscale IP is assigned to the TUN device
|
||||
// and the kernel handles routing. Standard net.Listen works.
|
||||
return net.Listen(network, listenAddr)
|
||||
}
|
||||
|
||||
// handleRshSession is run in a goroutine. It accepts a single connection
|
||||
// from the listener, verifies the token and source, then spawns the
|
||||
// remote command via the incubator.
|
||||
func handleRshSession(
|
||||
b *ipnlocal.LocalBackend,
|
||||
ln net.Listener,
|
||||
token []byte,
|
||||
expectedPeer netip.Addr,
|
||||
lu *user.User,
|
||||
loginShell string,
|
||||
groupIDs []string,
|
||||
req rshRequest,
|
||||
ph ipnlocal.PeerAPIHandler,
|
||||
logf func(string, ...any),
|
||||
) {
|
||||
defer ln.Close()
|
||||
|
||||
// Set a deadline for the client to connect.
|
||||
if dl, ok := ln.(interface{ SetDeadline(time.Time) error }); ok {
|
||||
dl.SetDeadline(time.Now().Add(30 * time.Second))
|
||||
}
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
logf("rsh: accept failed: %v", err)
|
||||
return
|
||||
}
|
||||
ln.Close() // Only accept one connection.
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
// Verify source IP.
|
||||
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
logf("rsh: unexpected remote addr type: %T", conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
logf("rsh: invalid remote IP")
|
||||
return
|
||||
}
|
||||
remoteIP = remoteIP.Unmap()
|
||||
if remoteIP != expectedPeer {
|
||||
logf("rsh: unexpected peer %v, expected %v", remoteIP, expectedPeer)
|
||||
return
|
||||
}
|
||||
|
||||
// Read and verify token.
|
||||
var gotToken [tokenLen]byte
|
||||
if _, err := io.ReadFull(conn, gotToken[:]); err != nil {
|
||||
logf("rsh: failed to read token: %v", err)
|
||||
return
|
||||
}
|
||||
if subtle.ConstantTimeCompare(gotToken[:], token) != 1 {
|
||||
logf("rsh: invalid token from %v", remoteIP)
|
||||
return
|
||||
}
|
||||
|
||||
// Set TCP_NODELAY for low-latency rsync control messages.
|
||||
if tc, ok := conn.(*net.TCPConn); ok {
|
||||
tc.SetNoDelay(true)
|
||||
}
|
||||
|
||||
logf("rsh: session accepted from %v as %s, command=%q", remoteIP, lu.Username, req.Command)
|
||||
|
||||
// Build and run the incubator command.
|
||||
runIncubator(b, conn, lu, loginShell, groupIDs, req, ph, logf)
|
||||
}
|
||||
|
||||
// runIncubator spawns the remote command using the existing SSH incubator
|
||||
// mechanism for privilege dropping and PAM integration.
|
||||
func runIncubator(
|
||||
b *ipnlocal.LocalBackend,
|
||||
conn net.Conn,
|
||||
lu *user.User,
|
||||
loginShell string,
|
||||
groupIDs []string,
|
||||
req rshRequest,
|
||||
ph ipnlocal.PeerAPIHandler,
|
||||
logf func(string, ...any),
|
||||
) {
|
||||
tailscaledPath, err := os.Executable()
|
||||
if err != nil {
|
||||
logf("rsh: os.Executable: %v", err)
|
||||
sendExitCode(conn, 1)
|
||||
return
|
||||
}
|
||||
|
||||
peerNode := ph.Peer()
|
||||
remoteUser := "unknown"
|
||||
if peerNode.Valid() {
|
||||
if peerNode.IsTagged() {
|
||||
remoteUser = strings.Join(peerNode.Tags().AsSlice(), ",")
|
||||
} else {
|
||||
_, uprof, ok := b.WhoIs("tcp", ph.RemoteAddr())
|
||||
if ok {
|
||||
remoteUser = uprof.LoginName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groups := strings.Join(groupIDs, ",")
|
||||
isShell := req.Command == ""
|
||||
|
||||
incubatorArgs := []string{
|
||||
"be-child",
|
||||
"ssh",
|
||||
"--login-shell=" + loginShell,
|
||||
"--uid=" + lu.Uid,
|
||||
"--gid=" + lu.Gid,
|
||||
"--groups=" + groups,
|
||||
"--local-user=" + lu.Username,
|
||||
"--home-dir=" + lu.HomeDir,
|
||||
"--remote-user=" + remoteUser,
|
||||
"--remote-ip=" + ph.RemoteAddr().Addr().String(),
|
||||
"--has-tty=false",
|
||||
"--tty-name=",
|
||||
}
|
||||
|
||||
if runtime.GOOS == linux && hostinfo.IsSELinuxEnforcing() {
|
||||
incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing")
|
||||
}
|
||||
|
||||
nm := b.NetMap()
|
||||
if nm != nil && nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) {
|
||||
incubatorArgs = append(incubatorArgs, "--force-v1-behavior")
|
||||
}
|
||||
|
||||
if isShell {
|
||||
incubatorArgs = append(incubatorArgs, "--shell")
|
||||
} else {
|
||||
incubatorArgs = append(incubatorArgs, "--cmd="+req.Command)
|
||||
}
|
||||
|
||||
cmd := exec.Command(tailscaledPath, incubatorArgs...)
|
||||
cmd.Dir = "/"
|
||||
|
||||
// Set up the environment for the child.
|
||||
cmd.Env = []string{
|
||||
"SHELL=" + loginShell,
|
||||
"USER=" + lu.Username,
|
||||
"HOME=" + lu.HomeDir,
|
||||
"PATH=" + defaultPathForUser(lu),
|
||||
}
|
||||
|
||||
// Create stdin/stdout/stderr pipes.
|
||||
stdinPipe, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
logf("rsh: stdin pipe: %v", err)
|
||||
sendExitCode(conn, 1)
|
||||
return
|
||||
}
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
logf("rsh: stdout pipe: %v", err)
|
||||
sendExitCode(conn, 1)
|
||||
return
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
logf("rsh: stderr pipe: %v", err)
|
||||
sendExitCode(conn, 1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
logf("rsh: start incubator: %v", err)
|
||||
sendExitCode(conn, 1)
|
||||
return
|
||||
}
|
||||
|
||||
fw := newFrameWriter(conn)
|
||||
fr := newFrameReader(conn)
|
||||
|
||||
// Goroutine: read frames from client, write stdin to incubator.
|
||||
stdinDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(stdinDone)
|
||||
defer stdinPipe.Close()
|
||||
for {
|
||||
ch, data, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ch == ChanStdin {
|
||||
if _, err := stdinPipe.Write(data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine: read stdout from incubator, write frames to client.
|
||||
stdoutDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(stdoutDone)
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := stdoutPipe.Read(buf)
|
||||
if n > 0 {
|
||||
if werr := fw.WriteFrame(ChanStdout, buf[:n]); werr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine: read stderr from incubator, write frames to client.
|
||||
stderrDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(stderrDone)
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := stderrPipe.Read(buf)
|
||||
if n > 0 {
|
||||
if werr := fw.WriteFrame(ChanStderr, buf[:n]); werr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for the process to exit.
|
||||
exitCode := 0
|
||||
if err := cmd.Wait(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
logf("rsh: wait: %v", err)
|
||||
exitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for output goroutines to drain.
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
|
||||
// Send exit code and close.
|
||||
logf("rsh: session ended for %s, exit code %d", lu.Username, exitCode)
|
||||
fw.WriteExitCode(exitCode)
|
||||
}
|
||||
|
||||
// sendExitCode is a helper used before the framing writer is set up.
|
||||
func sendExitCode(conn net.Conn, code int) {
|
||||
fw := newFrameWriter(conn)
|
||||
fw.WriteExitCode(code)
|
||||
}
|
||||
|
||||
// defaultPathForUser returns an appropriate default PATH for the user.
|
||||
// This is a simplified version of the logic in ssh/tailssh/user.go.
|
||||
func defaultPathForUser(u *user.User) string {
|
||||
if u.Uid == "0" {
|
||||
return "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
}
|
||||
return "/usr/local/bin:/usr/bin:/bin"
|
||||
}
|
||||
|
||||
// envknobs for debugging.
|
||||
var rshVerbose = envknob.RegisterBool("TS_DEBUG_RSH_VLOG")
|
||||
30
feature/rsh/rsh_netstack.go
Normal file
30
feature/rsh/rsh_netstack.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd
|
||||
|
||||
package rsh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/wgengine/netstack"
|
||||
)
|
||||
|
||||
func init() {
|
||||
netstackListenTCP = netstackListenTCPImpl
|
||||
}
|
||||
|
||||
func netstackListenTCPImpl(b *ipnlocal.LocalBackend, network, address string) (net.Listener, error) {
|
||||
ns, ok := b.Sys().Netstack.GetOK()
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
// Type-assert to *netstack.Impl which has the ListenTCP method.
|
||||
impl, ok := ns.(*netstack.Impl)
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return impl.ListenTCP(network, address)
|
||||
}
|
||||
@ -653,6 +653,13 @@ func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) {
|
||||
func (b *LocalBackend) Clock() tstime.Clock { return b.clock }
|
||||
func (b *LocalBackend) Sys() *tsd.System { return b.sys }
|
||||
|
||||
// PeerAPIBase returns the "http://ip:port" URL base to reach a peer's PeerAPI.
|
||||
// It returns the empty string if the peer doesn't support PeerAPI or there's
|
||||
// no matching address family.
|
||||
func (b *LocalBackend) PeerAPIBase(peer tailcfg.NodeView) string {
|
||||
return peerAPIBase(b.NetMap(), peer)
|
||||
}
|
||||
|
||||
// NodeBackend returns the current node's NodeBackend interface.
|
||||
func (b *LocalBackend) NodeBackend() ipnext.NodeBackend {
|
||||
return b.currentNode()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user