From 5f256f114f178da7bf59da97537ab77052b379dc Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Sun, 26 Feb 2023 17:32:36 -0500 Subject: [PATCH] net/pidlisten: new package that restricts dials to the current process To be used in the C library wrapping tsnet to provide LocalAPI access. This commit contains a linux implementation. More operating systems to follow. Signed-off-by: David Crawshaw --- net/pidlisten/pidlisten.go | 42 ++++++++++ net/pidlisten/pidlisten_linux.go | 63 +++++++++++++++ net/pidlisten/pidlisten_noimpl.go | 13 ++++ net/pidlisten/pidlisten_test.go | 122 ++++++++++++++++++++++++++++++ 4 files changed, 240 insertions(+) create mode 100644 net/pidlisten/pidlisten.go create mode 100644 net/pidlisten/pidlisten_linux.go create mode 100644 net/pidlisten/pidlisten_noimpl.go create mode 100644 net/pidlisten/pidlisten_test.go diff --git a/net/pidlisten/pidlisten.go b/net/pidlisten/pidlisten.go new file mode 100644 index 000000000..520ef7765 --- /dev/null +++ b/net/pidlisten/pidlisten.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package pidlisten implements a TCP listener that only +// accepts connections from the current process. +package pidlisten + +import ( + "fmt" + "net" +) + +type listener struct { + ln net.Listener +} + +func (pln *listener) Accept() (net.Conn, error) { + for { + conn, err := pln.ln.Accept() + if err != nil { + return nil, err + } + ok, err := checkPIDLocal(conn) + if err != nil { + conn.Close() + return nil, fmt.Errorf("pidlisten: %w", err) + } + if !ok { + conn.Close() + continue + } + return conn, nil + } +} + +func (pln *listener) Close() error { + return pln.ln.Close() +} + +func (pln *listener) Addr() net.Addr { + return pln.ln.Addr() +} diff --git a/net/pidlisten/pidlisten_linux.go b/net/pidlisten/pidlisten_linux.go new file mode 100644 index 000000000..6e4ffbdf2 --- /dev/null +++ b/net/pidlisten/pidlisten_linux.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package pidlisten + +import ( + "errors" + "fmt" + "go4.org/mem" + "io/fs" + "net" + "os" + "path/filepath" + "tailscale.com/util/dirwalk" + + "github.com/vishvananda/netlink" +) + +// NewPIDListener wraps a net.Listener so that it only accepts connections from the current process. +func NewPIDListener(ln net.Listener) net.Listener { + return &listener{ln: ln} +} + +var errFoundSocket = errors.New("found socket") + +func checkPIDLocal(conn net.Conn) (bool, error) { + remoteAddr := conn.RemoteAddr() + var remoteIP net.IP + switch remoteAddr.Network() { + case "tcp": + remoteIP = remoteAddr.(*net.TCPAddr).IP + case "udp": + remoteIP = remoteAddr.(*net.UDPAddr).IP + default: + return false, nil + } + if !remoteIP.IsLoopback() { + return false, nil + } + + // You can look up a net.Conn in both directions. + // There are different inodes for remote->local and local->remote. + // We want to look up the starting side of the net.Conn and check + // that its inode belongs to the current PID. + s, err := netlink.SocketGet(conn.RemoteAddr(), conn.LocalAddr()) + if err != nil { + return false, err + } + + want := fmt.Sprintf("socket:[%d]", s.INode) + dir := fmt.Sprintf("/proc/%d/fd", os.Getpid()) + err = dirwalk.WalkShallow(mem.S(dir), func(name mem.RO, de fs.DirEntry) error { + n, err := os.Readlink(filepath.Join(dir, name.StringCopy())) + if err == nil && want == n { + return errFoundSocket + } + return nil + }) + if err == errFoundSocket { + return true, nil + } + return false, err +} diff --git a/net/pidlisten/pidlisten_noimpl.go b/net/pidlisten/pidlisten_noimpl.go new file mode 100644 index 000000000..395425731 --- /dev/null +++ b/net/pidlisten/pidlisten_noimpl.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux +// +build !linux + +package pidlisten + +import "net" + +func checkPIDLocal(conn net.Conn) (bool, error) { + panic("not implemented") +} diff --git a/net/pidlisten/pidlisten_test.go b/net/pidlisten/pidlisten_test.go new file mode 100644 index 000000000..cce02c2b8 --- /dev/null +++ b/net/pidlisten/pidlisten_test.go @@ -0,0 +1,122 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux +// +build linux + +package pidlisten + +import ( + "errors" + "flag" + "fmt" + "io" + "net" + "os" + "os/exec" + "testing" + "time" +) + +var flagDial = flag.String("dial", "", "if set, dials the given addr and reads until close") + +func TestMain(m *testing.M) { + flag.Parse() + if *flagDial != "" { + conn, err := net.DialTimeout("tcp", *flagDial, 5*time.Second) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + conn.SetDeadline(time.Now().Add(5 * time.Second)) + b, err := io.ReadAll(conn) + fmt.Fprintf(os.Stderr, "%s", b) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + os.Exit(0) + } + os.Exit(m.Run()) +} + +func TestPIDLocal(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer clientConn.Close() + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + ok, err := checkPIDLocal(conn) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("checkPIDLocal=false, want true") + } +} + +func testExternalProcess(t *testing.T, ln net.Listener) string { + go func() { + for { + c, err := ln.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + panic(err) + } + fmt.Fprintf(c, "hello\n") + c.Close() + } + }() + + exe, err := os.Executable() + if err != nil { + t.Fatal(err) + } + + out, err := exec.Command(exe, "-dial="+ln.Addr().String()).CombinedOutput() + if err != nil { + t.Fatal(err) + } + return string(out) +} + +func TestExternalDialWorks(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + out := testExternalProcess(t, ln) + if out != "hello\n" { + t.Errorf("out=%q, want hello", out) + } +} + +func TestPIDExternal(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + ln = NewPIDListener(ln) + out := testExternalProcess(t, ln) + + if len(out) != 0 { + t.Errorf("unexpected socket output: %q", out) + } +}