mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 20:26:47 +02:00
cmd/tswrap: command to run a child process and make it accessible over Tailscale.
Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
21ef7e5c35
commit
2a619d3bcf
326
cmd/tswrap/main.go
Normal file
326
cmd/tswrap/main.go
Normal file
@ -0,0 +1,326 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
|
||||
// The tswrap binary runs a child process and makes it accessible over
|
||||
// Tailscale.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/client/tailscale"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tsnet"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
tsDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state")
|
||||
verbose = flag.Bool("verbose", false, "Output tailscaled logs to stderr")
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigch := make(chan os.Signal, 1)
|
||||
signal.Notify(sigch, os.Interrupt, unix.SIGTERM)
|
||||
|
||||
flag.Parse()
|
||||
|
||||
argv := flag.Args()
|
||||
|
||||
if len(argv) < 2 {
|
||||
log.Fatalf("Usage: %s tailscale-host:port child-cmd...", os.Args[0])
|
||||
}
|
||||
|
||||
p := proxy{
|
||||
ListenAddr: argv[0],
|
||||
Command: argv[1:],
|
||||
AuthKey: os.Getenv("TS_AUTHKEY"),
|
||||
Dir: *tsDir,
|
||||
Verbose: *verbose,
|
||||
}
|
||||
|
||||
if err := p.Start(); err != nil {
|
||||
log.Fatalf("Failed to start tswrap: %v", err)
|
||||
}
|
||||
go func() {
|
||||
<-sigch
|
||||
p.Stop()
|
||||
}()
|
||||
|
||||
p.Wait()
|
||||
}
|
||||
|
||||
type proxy struct {
|
||||
ListenAddr string
|
||||
Command []string
|
||||
AuthKey string
|
||||
Dir string
|
||||
Verbose bool
|
||||
|
||||
shutdownCtx context.Context
|
||||
startShutdown context.CancelFunc
|
||||
srv *tsnet.Server
|
||||
client *tailscale.LocalClient
|
||||
ln net.Listener
|
||||
cmd *exec.Cmd
|
||||
ports syncs.AtomicValue[[]int]
|
||||
}
|
||||
|
||||
func (p *proxy) Start() error {
|
||||
host, port, err := net.SplitHostPort(p.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing %q: %v", p.ListenAddr, err)
|
||||
}
|
||||
if _, err := strconv.Atoi(port); err != nil {
|
||||
return fmt.Errorf("parsing port number %q: %v", port, err)
|
||||
}
|
||||
|
||||
if p.Dir == "" && p.AuthKey == "" {
|
||||
return errors.New("must provide either a TS_AUTHKEY or a state storage dir")
|
||||
}
|
||||
|
||||
p.srv = &tsnet.Server{
|
||||
Hostname: host,
|
||||
AuthKey: p.AuthKey,
|
||||
Logf: logger.Discard,
|
||||
Dir: p.Dir,
|
||||
}
|
||||
if p.Dir == "" {
|
||||
p.srv.Store = new(mem.Store)
|
||||
p.srv.Ephemeral = true
|
||||
}
|
||||
if p.Verbose {
|
||||
p.srv.Logf = log.Printf
|
||||
}
|
||||
|
||||
p.shutdownCtx, p.startShutdown = context.WithCancel(context.Background())
|
||||
|
||||
p.client, err = p.srv.LocalClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting tsnet server failed: %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
looped = false
|
||||
authURLShown = false
|
||||
status *ipnstate.Status
|
||||
)
|
||||
loginLoop:
|
||||
for {
|
||||
if looped {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
looped = true
|
||||
|
||||
status, err = p.client.Status(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting tsnet status: %v", err)
|
||||
}
|
||||
|
||||
switch status.BackendState {
|
||||
case "Running":
|
||||
if status.Self == nil || status.Self.DNSName == "" {
|
||||
// No known DNS name yet, keep going
|
||||
continue
|
||||
}
|
||||
break loginLoop
|
||||
case "NeedsLogin":
|
||||
if status.AuthURL != "" && p.AuthKey != "" {
|
||||
return errors.New("failed to auth with provided authkey")
|
||||
}
|
||||
if status.AuthURL != "" && !authURLShown {
|
||||
log.Printf("To log into Tailscale, please visit: %s", status.AuthURL)
|
||||
authURLShown = true
|
||||
}
|
||||
default:
|
||||
// Just keep trying, eventually we should get into either
|
||||
// NeedsLogin or Running.
|
||||
}
|
||||
}
|
||||
|
||||
addr := ":" + port
|
||||
p.ln, err = p.srv.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tailscale listen on %q: %v", addr, err)
|
||||
}
|
||||
|
||||
log.Printf("Listening on %s:%s", status.Self.DNSName, port)
|
||||
|
||||
p.cmd = exec.Command(p.Command[0], p.Command[1:]...)
|
||||
p.cmd.Stdin = os.Stdin
|
||||
p.cmd.Stdout = os.Stdout
|
||||
p.cmd.Stderr = os.Stderr
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting child failed: %v", err)
|
||||
}
|
||||
|
||||
go p.listen()
|
||||
go p.waitForChildExit()
|
||||
go p.monitorChildPorts()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxy) Stop() {
|
||||
p.startShutdown()
|
||||
}
|
||||
|
||||
func (p *proxy) Wait() {
|
||||
<-p.shutdownCtx.Done()
|
||||
p.cmd.Process.Signal(unix.SIGTERM)
|
||||
p.ln.Close()
|
||||
if p.srv.Ephemeral {
|
||||
p.client.Logout(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *proxy) listen() {
|
||||
for {
|
||||
conn, err := p.ln.Accept()
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Printf("accept: %v", err)
|
||||
p.startShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := p.proxy(conn); err != nil {
|
||||
log.Printf("proxying %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *proxy) proxy(conn net.Conn) error {
|
||||
defer conn.Close()
|
||||
ports, err := p.getPorts()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ports) > 1 {
|
||||
log.Printf("warning: multiple listening ports found on child, proxying to lowest one (%d)", ports[0])
|
||||
}
|
||||
|
||||
prox, err := net.Dial("tcp", net.JoinHostPort("localhost", strconv.Itoa(ports[0])))
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing child port %d: %v", ports[0], err)
|
||||
}
|
||||
defer prox.Close()
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go proxyCopy(errc, conn, prox)
|
||||
go proxyCopy(errc, prox, conn)
|
||||
<-errc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxy) getPorts() ([]int, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for ctx.Err() == nil {
|
||||
if ports := p.ports.Load(); len(ports) > 0 {
|
||||
return ports, nil
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
return nil, errors.New("timed out waiting for child listening ports")
|
||||
}
|
||||
|
||||
func (p *proxy) waitForChildExit() {
|
||||
if err := p.cmd.Wait(); err != nil {
|
||||
log.Printf("child exited with error: %v", err)
|
||||
} else {
|
||||
log.Printf("child exited, shutting down")
|
||||
}
|
||||
p.startShutdown()
|
||||
}
|
||||
|
||||
func (p *proxy) monitorChildPorts() {
|
||||
for p.shutdownCtx.Err() == nil {
|
||||
ports, err := portsOfCmd(p.cmd)
|
||||
if err == nil {
|
||||
p.ports.Store(ports)
|
||||
}
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-p.shutdownCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func proxyCopy(errc chan<- error, dst, src net.Conn) {
|
||||
// TODO: still need the unwrap hack from tcpproxy? Or is io.Copy
|
||||
// smart now?
|
||||
_, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func portsOfCmd(cmd *exec.Cmd) (ports []int, err error) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil, errors.New("no process")
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
wantSub := fmt.Sprintf(" %d/", pid)
|
||||
|
||||
ns := exec.Command("netstat", "-p", "--inet", "-l", "-n")
|
||||
out, err := ns.Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bs := bufio.NewScanner(bytes.NewReader(out))
|
||||
for bs.Scan() {
|
||||
line := bs.Text()
|
||||
if !strings.HasPrefix(line, "tcp") ||
|
||||
!strings.Contains(line, "LISTEN") ||
|
||||
!strings.Contains(line, wantSub) {
|
||||
continue
|
||||
}
|
||||
f := strings.Fields(line)
|
||||
if len(f) < 4 {
|
||||
continue
|
||||
}
|
||||
ipp, err := netip.ParseAddrPort(f[3])
|
||||
if err == nil {
|
||||
ports = append(ports, int(ipp.Port()))
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := bs.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
return nil, errors.New("no listening ports found")
|
||||
}
|
||||
sort.Ints(ports)
|
||||
return ports, nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user