From 980cbd790d5b84b6578711cdc676eeeaeec3d47a Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 29 Sep 2023 08:30:04 +0100 Subject: [PATCH] util/linuxfw,wgengine/router: export some of the functionality So that it can be used by iptables/nftables in containerboot Signed-off-by: Irbe Krumina --- util/linuxfw/fakes.go | 214 +++++++++++++++++++++++++++ util/linuxfw/iptables_runner.go | 8 +- util/linuxfw/iptables_runner_test.go | 132 +---------------- util/linuxfw/linuxfw.go | 4 +- util/linuxfw/nftables_runner.go | 125 +++++++++------- util/linuxfw/nftables_runner_test.go | 105 ++----------- wgengine/router/router_linux.go | 12 +- 7 files changed, 311 insertions(+), 289 deletions(-) create mode 100644 util/linuxfw/fakes.go diff --git a/util/linuxfw/fakes.go b/util/linuxfw/fakes.go new file mode 100644 index 000000000..8a43a1343 --- /dev/null +++ b/util/linuxfw/fakes.go @@ -0,0 +1,214 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "strings" + "testing" + + "github.com/google/nftables" + "github.com/mdlayher/netlink" + "github.com/vishvananda/netns" +) + +var errExec = errors.New("execution failed") + +type fakeIPTables struct { + t *testing.T + n map[string][]string +} + +type fakeRule struct { + table, chain string + args []string +} + +func NewIPTables(t *testing.T) *fakeIPTables { + return &fakeIPTables{ + t: t, + n: map[string][]string{ + "filter/INPUT": nil, + "filter/OUTPUT": nil, + "filter/FORWARD": nil, + "nat/PREROUTING": nil, + "nat/OUTPUT": nil, + "nat/POSTROUTING": nil, + }, + } +} + +func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + if pos > len(rules)+1 { + n.t.Errorf("bad position %d in %s", pos, k) + return errExec + } + rules = append(rules, "") + copy(rules[pos:], rules[pos-1:]) + rules[pos-1] = strings.Join(args, " ") + n.n[k] = rules + } else { + n.t.Errorf("unknown table/chain %s", k) + return errExec + } + return nil +} + +func (n *fakeIPTables) Append(table, chain string, args ...string) error { + k := table + "/" + chain + return n.Insert(table, chain, len(n.n[k])+1, args...) +} + +func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + for _, rule := range rules { + if rule == strings.Join(args, " ") { + return true, nil + } + } + return false, nil + } else { + n.t.Logf("unknown table/chain %s", k) + return false, errExec + } +} + +func (n *fakeIPTables) Delete(table, chain string, args ...string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + for i, rule := range rules { + if rule == strings.Join(args, " ") { + rules = append(rules[:i], rules[i+1:]...) + n.n[k] = rules + return nil + } + } + n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k) + return errExec + } else { + n.t.Errorf("unknown table/chain %s", k) + return errExec + } +} + +func (n *fakeIPTables) ClearChain(table, chain string) error { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + n.n[k] = nil + return nil + } else { + n.t.Logf("note: ClearChain: unknown table/chain %s", k) + return errors.New("exitcode:1") + } +} + +func (n *fakeIPTables) NewChain(table, chain string) error { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + n.t.Errorf("table/chain %s already exists", k) + return errExec + } + n.n[k] = nil + return nil +} + +func (n *fakeIPTables) DeleteChain(table, chain string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + if len(rules) != 0 { + n.t.Errorf("%s is not empty", k) + return errExec + } + delete(n.n, k) + return nil + } else { + n.t.Errorf("%s does not exist", k) + return errExec + } +} + +func NewTestConn(t *testing.T, want [][]byte) *nftables.Conn { + conn, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + })) + if err != nil { + t.Fatal(err) + } + return conn +} + +func cleanupSysConn(t *testing.T, ns netns.NsHandle) { + defer runtime.UnlockOSThread() + + if err := ns.Close(); err != nil { + t.Fatalf("newNS.Close() failed: %v", err) + } +} + +// linediff returns a side-by-side diff of two nfdump() return values, flagging +// lines which are not equal with an exclamation point prefix. +func linediff(a, b string) string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "got -- want\n") + linesA := strings.Split(a, "\n") + linesB := strings.Split(b, "\n") + for idx, lineA := range linesA { + if idx >= len(linesB) { + break + } + lineB := linesB[idx] + prefix := "! " + if lineA == lineB { + prefix = " " + } + fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB) + } + return buf.String() +} + +// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing +// users to make sense of large byte literals more easily. +func nfdump(b []byte) string { + var buf bytes.Buffer + i := 0 + for ; i < len(b); i += 4 { + // TODO: show printable characters as ASCII + fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", + b[i], + b[i+1], + b[i+2], + b[i+3]) + } + for ; i < len(b); i++ { + fmt.Fprintf(&buf, "%02x ", b[i]) + } + return buf.String() +} diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index 14f2fa536..5ef2a165a 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -37,7 +37,7 @@ type iptablesRunner struct { v6NATAvailable bool } -func checkIP6TablesExists() error { +func CheckIP6TablesExists() error { // Some distros ship ip6tables separately from iptables. if _, err := exec.LookPath("ip6tables"); err != nil { return fmt.Errorf("path not found: %w", err) @@ -56,8 +56,8 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { } supportsV6, supportsV6NAT := false, false - v6err := checkIPv6(logf) - ip6terr := checkIP6TablesExists() + v6err := CheckIPv6(logf) + ip6terr := CheckIP6TablesExists() switch { case v6err != nil: logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) @@ -65,7 +65,7 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr) default: supportsV6 = true - supportsV6NAT = supportsV6 && checkSupportsV6NAT() + supportsV6NAT = supportsV6 && CheckSupportsV6NAT() logf("v6nat = %v", supportsV6NAT) } diff --git a/util/linuxfw/iptables_runner_test.go b/util/linuxfw/iptables_runner_test.go index e294f064b..9694bcc62 100644 --- a/util/linuxfw/iptables_runner_test.go +++ b/util/linuxfw/iptables_runner_test.go @@ -6,7 +6,6 @@ package linuxfw import ( - "errors" "net/netip" "strings" "testing" @@ -14,136 +13,9 @@ import ( "tailscale.com/net/tsaddr" ) -var errExec = errors.New("execution failed") - -type fakeIPTables struct { - t *testing.T - n map[string][]string -} - -type fakeRule struct { - table, chain string - args []string -} - -func newIPTables(t *testing.T) *fakeIPTables { - return &fakeIPTables{ - t: t, - n: map[string][]string{ - "filter/INPUT": nil, - "filter/OUTPUT": nil, - "filter/FORWARD": nil, - "nat/PREROUTING": nil, - "nat/OUTPUT": nil, - "nat/POSTROUTING": nil, - }, - } -} - -func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - if pos > len(rules)+1 { - n.t.Errorf("bad position %d in %s", pos, k) - return errExec - } - rules = append(rules, "") - copy(rules[pos:], rules[pos-1:]) - rules[pos-1] = strings.Join(args, " ") - n.n[k] = rules - } else { - n.t.Errorf("unknown table/chain %s", k) - return errExec - } - return nil -} - -func (n *fakeIPTables) Append(table, chain string, args ...string) error { - k := table + "/" + chain - return n.Insert(table, chain, len(n.n[k])+1, args...) -} - -func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - for _, rule := range rules { - if rule == strings.Join(args, " ") { - return true, nil - } - } - return false, nil - } else { - n.t.Logf("unknown table/chain %s", k) - return false, errExec - } -} - -func hasChain(n *fakeIPTables, table, chain string) bool { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - return true - } else { - return false - } -} - -func (n *fakeIPTables) Delete(table, chain string, args ...string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - for i, rule := range rules { - if rule == strings.Join(args, " ") { - rules = append(rules[:i], rules[i+1:]...) - n.n[k] = rules - return nil - } - } - n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k) - return errExec - } else { - n.t.Errorf("unknown table/chain %s", k) - return errExec - } -} - -func (n *fakeIPTables) ClearChain(table, chain string) error { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - n.n[k] = nil - return nil - } else { - n.t.Logf("note: ClearChain: unknown table/chain %s", k) - return errors.New("exitcode:1") - } -} - -func (n *fakeIPTables) NewChain(table, chain string) error { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - n.t.Errorf("table/chain %s already exists", k) - return errExec - } - n.n[k] = nil - return nil -} - -func (n *fakeIPTables) DeleteChain(table, chain string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - if len(rules) != 0 { - n.t.Errorf("%s is not empty", k) - return errExec - } - delete(n.n, k) - return nil - } else { - n.t.Errorf("%s does not exist", k) - return errExec - } -} - func newFakeIPTablesRunner(t *testing.T) *iptablesRunner { - ipt4 := newIPTables(t) - ipt6 := newIPTables(t) + ipt4 := NewIPTables(t) + ipt6 := NewIPTables(t) iptr := &iptablesRunner{ipt4, ipt6, true, true} return iptr diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index e381e1f52..0ad430573 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -131,7 +131,7 @@ func errCode(err error) int { // missing. It does not check that IPv6 is currently functional or // that there's a global address, just that the system would support // IPv6 if it were on an IPv6 network. -func checkIPv6(logf logger.Logf) error { +func CheckIPv6(logf logger.Logf) error { _, err := os.Stat("/proc/sys/net/ipv6") if os.IsNotExist(err) { return err @@ -176,7 +176,7 @@ func checkIPv6(logf logger.Logf) error { // The nat table was added after the initial release of ipv6 // netfilter, so some older distros ship a kernel that can't NAT IPv6 // traffic. -func checkSupportsV6NAT() bool { +func CheckSupportsV6NAT() bool { bs, err := os.ReadFile("/proc/net/ip6_tables_names") if err != nil { // Can't read the file. Assume SNAT works. diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 9f56c5423..5e03f7ac1 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -30,13 +30,13 @@ const ( // chainTypeRegular is an nftables chain that does not apply to a hook. const chainTypeRegular = "" -type chainInfo struct { - table *nftables.Table - name string - chainType nftables.ChainType - chainHook *nftables.ChainHook - chainPriority *nftables.ChainPriority - chainPolicy *nftables.ChainPolicy +type ChainInfo struct { + Table *nftables.Table + Name string + ChainType nftables.ChainType + ChainHook *nftables.ChainHook + ChainPriority *nftables.ChainPriority + ChainPolicy *nftables.ChainPolicy } type nftable struct { @@ -45,6 +45,21 @@ type nftable struct { Nat *nftables.Table } +type Conn interface { + ListChainsOfTableFamily(nftables.TableFamily) ([]*nftables.Chain, error) + ListTables() ([]*nftables.Table, error) + AddTable(*nftables.Table) *nftables.Table + DelTable(*nftables.Table) + AddChain(*nftables.Chain) *nftables.Chain + DelChain(*nftables.Chain) + FlushChain(*nftables.Chain) + GetRules(*nftables.Table, *nftables.Chain) ([]*nftables.Rule, error) + InsertRule(*nftables.Rule) *nftables.Rule + AddRule(*nftables.Rule) *nftables.Rule + DelRule(*nftables.Rule) error + Flush() error +} + // nftablesRunner implements a netfilterRunner using the netlink based nftables // library. As nftables allows for arbitrary tables and chains, there is a need // to follow conventions in order to integrate well with a surrounding @@ -70,7 +85,7 @@ type nftablesRunner struct { } // createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family. -func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { +func CreateTableIfNotExist(c Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { tables, err := c.ListTables() if err != nil { return nil, fmt.Errorf("get tables: %w", err) @@ -102,7 +117,7 @@ func (e errorChainNotFound) Error() string { // getChainFromTable returns the chain with the given name from the given table. // Note that a chain name is unique within a table. -func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) { +func GetChainFromTable(c Conn, table *nftables.Table, name string) (*nftables.Chain, error) { chains, err := c.ListChainsOfTableFamily(table.Family) if err != nil { return nil, fmt.Errorf("list chains: %w", err) @@ -144,35 +159,35 @@ func isTSChain(name string) bool { // createChainIfNotExist creates a chain with the given name in the given table // if it does not exist. -func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { - chain, err := getChainFromTable(c, cinfo.table, cinfo.name) - if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) { - return fmt.Errorf("get chain: %w", err) +func CreateChainIfNotExist(c Conn, cinfo ChainInfo) (*nftables.Chain, error) { + chain, err := GetChainFromTable(c, cinfo.Table, cinfo.Name) + if err != nil && !errors.Is(err, errorChainNotFound{cinfo.Table.Name, cinfo.Name}) { + return nil, fmt.Errorf("get chain: %w", err) } else if err == nil { // The chain already exists. If it is a TS chain, check the // type/hook/priority, but for "conventional chains" assume they're what // we expect (in case iptables-nft/ufw make minor behavior changes in // the future). - if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) { - return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) + if isTSChain(chain.Name) && (chain.Type != cinfo.ChainType || chain.Hooknum != cinfo.ChainHook || chain.Priority != cinfo.ChainPriority) { + return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.Name) } - return nil + return chain, nil } - _ = c.AddChain(&nftables.Chain{ - Name: cinfo.name, - Table: cinfo.table, - Type: cinfo.chainType, - Hooknum: cinfo.chainHook, - Priority: cinfo.chainPriority, - Policy: cinfo.chainPolicy, + chain = c.AddChain(&nftables.Chain{ + Name: cinfo.Name, + Table: cinfo.Table, + Type: cinfo.ChainType, + Hooknum: cinfo.ChainHook, + Priority: cinfo.ChainPriority, + Policy: cinfo.ChainPolicy, }) if err := c.Flush(); err != nil { - return fmt.Errorf("add chain: %w", err) + return nil, fmt.Errorf("add chain: %w", err) } - return nil + return chain, nil } // NewNfTablesRunner creates a new nftablesRunner without guaranteeing @@ -184,12 +199,12 @@ func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { } nft4 := &nftable{Proto: nftables.TableFamilyIPv4} - v6err := checkIPv6(logf) + v6err := CheckIPv6(logf) if v6err != nil { logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) } supportsV6 := v6err == nil - supportsV6NAT := supportsV6 && checkSupportsV6NAT() + supportsV6NAT := supportsV6 && CheckSupportsV6NAT() var nft6 *nftable if supportsV6 { @@ -208,7 +223,7 @@ func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { }, nil } -// newLoadSaddrExpr creates a new nftables expression that loads the source +// NewLoadSaddrExpr creates a new nftables expression that loads the source // address of the packet into the given register. func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) { switch proto { @@ -358,7 +373,7 @@ func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable { func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error { nf := n.getNFTByAddr(addr) - inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) + inputChain, err := GetChainFromTable(n.conn, nf.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %w", err) } @@ -375,7 +390,7 @@ func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error { func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error { nf := n.getNFTByAddr(addr) - inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) + inputChain, err := GetChainFromTable(n.conn, nf.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %w", err) } @@ -428,23 +443,23 @@ func (n *nftablesRunner) AddChains() error { // as the name used by iptables-nft and ufw. We install rules into the // same conventional table so that `accept` verdicts from our jump // chains are conclusive. - filter, err := createTableIfNotExist(n.conn, table.Proto, "filter") + filter, err := CreateTableIfNotExist(n.conn, table.Proto, "filter") if err != nil { return fmt.Errorf("create table: %w", err) } table.Filter = filter // Adding the "conventional chains" that are used by iptables-nft and ufw. - if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { + if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { return fmt.Errorf("create forward chain: %w", err) } - if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { + if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { return fmt.Errorf("create input chain: %w", err) } // Adding the tailscale chains that contain our rules. - if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { + if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create forward chain: %w", err) } - if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { + if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create input chain: %w", err) } } @@ -454,17 +469,17 @@ func (n *nftablesRunner) AddChains() error { // as the name used by iptables-nft and ufw. We install rules into the // same conventional table so that `accept` verdicts from our jump // chains are conclusive. - nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") + nat, err := CreateTableIfNotExist(n.conn, table.Proto, "nat") if err != nil { return fmt.Errorf("create table: %w", err) } table.Nat = nat // Adding the "conventional chains" that are used by iptables-nft and ufw. - if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { + if _, err = CreateChainIfNotExist(n.conn, ChainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { return fmt.Errorf("create postrouting chain: %w", err) } // Adding the tailscale chain that contains our rules. - if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { + if _, err = CreateChainIfNotExist(n.conn, ChainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create postrouting chain: %w", err) } } @@ -474,7 +489,7 @@ func (n *nftablesRunner) AddChains() error { // deleteChainIfExists deletes a chain if it exists. func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error { - chain, err := getChainFromTable(c, table, name) + chain, err := GetChainFromTable(c, table, name) if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) { return fmt.Errorf("get chain: %w", err) } else if err != nil { @@ -557,7 +572,7 @@ func (n *nftablesRunner) AddHooks() error { conn := n.conn for _, table := range n.getTables() { - inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + inputChain, err := GetChainFromTable(conn, table.Filter, "INPUT") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } @@ -565,7 +580,7 @@ func (n *nftablesRunner) AddHooks() error { if err != nil { return fmt.Errorf("Addhook: %w", err) } - forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + forwardChain, err := GetChainFromTable(conn, table.Filter, "FORWARD") if err != nil { return fmt.Errorf("get FORWARD chain: %w", err) } @@ -576,7 +591,7 @@ func (n *nftablesRunner) AddHooks() error { } for _, table := range n.getNATTables() { - postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + postroutingChain, err := GetChainFromTable(conn, table.Nat, "POSTROUTING") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } @@ -613,7 +628,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error { conn := n.conn for _, table := range n.getTables() { - inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + inputChain, err := GetChainFromTable(conn, table.Filter, "INPUT") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } @@ -621,7 +636,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error { if err != nil { return fmt.Errorf("delhook: %w", err) } - forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + forwardChain, err := GetChainFromTable(conn, table.Filter, "FORWARD") if err != nil { return fmt.Errorf("get FORWARD chain: %w", err) } @@ -632,7 +647,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error { } for _, table := range n.getNATTables() { - postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + postroutingChain, err := GetChainFromTable(conn, table.Nat, "POSTROUTING") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } @@ -894,7 +909,7 @@ func (n *nftablesRunner) AddBase(tunname string) error { func (n *nftablesRunner) addBase4(tunname string) error { conn := n.conn - inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) + inputChain, err := GetChainFromTable(conn, n.nft4.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain v4: %v", err) } @@ -905,7 +920,7 @@ func (n *nftablesRunner) addBase4(tunname string) error { return fmt.Errorf("add drop cgnat range rule v4: %w", err) } - forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward) + forwardChain, err := GetChainFromTable(conn, n.nft4.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain v4: %v", err) } @@ -937,7 +952,7 @@ func (n *nftablesRunner) addBase4(tunname string) error { func (n *nftablesRunner) addBase6(tunname string) error { conn := n.conn - forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward) + forwardChain, err := GetChainFromTable(conn, n.nft6.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain v6: %w", err) } @@ -967,12 +982,12 @@ func (n *nftablesRunner) DelBase() error { conn := n.conn for _, table := range n.getTables() { - inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput) + inputChain, err := GetChainFromTable(conn, table.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %v", err) } conn.FlushChain(inputChain) - forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward) + forwardChain, err := GetChainFromTable(conn, table.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain: %v", err) } @@ -980,7 +995,7 @@ func (n *nftablesRunner) DelBase() error { } for _, table := range n.getNATTables() { - postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + postrouteChain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %v", err) } @@ -1050,7 +1065,7 @@ func (n *nftablesRunner) AddSNATRule() error { conn := n.conn for _, table := range n.getNATTables() { - chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + chain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %w", err) } @@ -1093,7 +1108,7 @@ func (n *nftablesRunner) DelSNATRule() error { } for _, table := range n.getNATTables() { - chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + chain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %w", err) } @@ -1124,7 +1139,7 @@ func (n *nftablesRunner) DelSNATRule() error { // the jump rule and chain continue even if one errors. func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) { // remove the jump first, before removing the jump destination. - defaultChain, err := getChainFromTable(conn, table, hookChainName) + defaultChain, err := GetChainFromTable(conn, table, hookChainName) if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { logf("cleanup: did not find default chain: %s", err) } @@ -1133,7 +1148,7 @@ func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, _ = delHookRule(conn, table, defaultChain, tsChainName) } - tsChain, err := getChainFromTable(conn, table, tsChainName) + tsChain, err := GetChainFromTable(conn, table, tsChainName) if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) { logf("cleanup: did not find ts-chain: %s", err) } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index ad068957e..088610e40 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -11,35 +11,14 @@ import ( "net/netip" "os" "runtime" - "strings" "testing" "github.com/google/nftables" "github.com/google/nftables/expr" - "github.com/mdlayher/netlink" "github.com/vishvananda/netns" "tailscale.com/net/tsaddr" ) -// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing -// users to make sense of large byte literals more easily. -func nfdump(b []byte) string { - var buf bytes.Buffer - i := 0 - for ; i < len(b); i += 4 { - // TODO: show printable characters as ASCII - fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", - b[i], - b[i+1], - b[i+2], - b[i+3]) - } - for ; i < len(b); i++ { - fmt.Fprintf(&buf, "%02x ", b[i]) - } - return buf.String() -} - func TestMaskof(t *testing.T) { pfx, err := netip.ParsePrefix("192.168.1.0/24") if err != nil { @@ -51,56 +30,6 @@ func TestMaskof(t *testing.T) { } } -// linediff returns a side-by-side diff of two nfdump() return values, flagging -// lines which are not equal with an exclamation point prefix. -func linediff(a, b string) string { - var buf bytes.Buffer - fmt.Fprintf(&buf, "got -- want\n") - linesA := strings.Split(a, "\n") - linesB := strings.Split(b, "\n") - for idx, lineA := range linesA { - if idx >= len(linesB) { - break - } - lineB := linesB[idx] - prefix := "! " - if lineA == lineB { - prefix = " " - } - fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB) - } - return buf.String() -} - -func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { - conn, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) - if err != nil { - t.Fatal(err) - } - return conn -} - func TestInsertHookRule(t *testing.T) { proto := nftables.TableFamilyIPv4 want := [][]byte{ @@ -117,7 +46,7 @@ func TestInsertHookRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -157,7 +86,7 @@ func TestInsertLoopbackRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -193,7 +122,7 @@ func TestInsertLoopbackRuleV6(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) tableV6 := testConn.AddTable(&nftables.Table{ Family: protoV6, Name: "ts-filter-test", @@ -229,7 +158,7 @@ func TestAddReturnChromeOSVMRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -261,7 +190,7 @@ func TestAddDropCGNATRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -293,7 +222,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -325,7 +254,7 @@ func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -357,7 +286,7 @@ func TestAddAcceptOutgoingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -389,7 +318,7 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-nat-test", @@ -421,7 +350,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := NewTestConn(t, want) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -458,14 +387,6 @@ func newSysConn(t *testing.T) *nftables.Conn { return c } -func cleanupSysConn(t *testing.T, ns netns.NsHandle) { - defer runtime.UnlockOSThread() - - if err := ns.Close(); err != nil { - t.Fatalf("newNS.Close() failed: %v", err) - } -} - func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner { nft4 := &nftable{Proto: nftables.TableFamilyIPv4} nft6 := &nftable{Proto: nftables.TableFamilyIPv6} @@ -843,7 +764,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) { defer runner.DelChains() runner.AddHooks() - forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD") + forwardChain, err := GetChainFromTable(conn, runner.nft4.Filter, "FORWARD") if err != nil { t.Fatalf("failed to get forwardChain: %v", err) } @@ -857,7 +778,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) { t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules)) } - inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT") + inputChain, err := GetChainFromTable(conn, runner.nft4.Filter, "INPUT") if err != nil { t.Fatalf("failed to get inputChain: %v", err) } @@ -871,7 +792,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) { t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules)) } - postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING") + postroutingChain, err := GetChainFromTable(conn, runner.nft4.Nat, "POSTROUTING") if err != nil { t.Fatalf("failed to get postroutingChain: %v", err) } diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index 8a7273bd2..e4e7c7c59 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -62,21 +62,21 @@ type tableDetector interface { nftDetect() (int, error) } -type linuxFWDetector struct{} +type LinuxFWDetector struct{} // iptDetect returns the number of iptables rules in the current namespace. -func (l *linuxFWDetector) iptDetect() (int, error) { +func (l *LinuxFWDetector) iptDetect() (int, error) { return linuxfw.DetectIptables() } // nftDetect returns the number of nftables rules in the current namespace. -func (l *linuxFWDetector) nftDetect() (int, error) { +func (l *LinuxFWDetector) nftDetect() (int, error) { return linuxfw.DetectNetfilter() } // chooseFireWallMode returns the firewall mode to use based on the // environment and the system's capabilities. -func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode { +func ChooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode { if distro.Get() == distro.Gokrazy { // Reduce startup logging on gokrazy. There's no way to do iptables on // gokrazy anyway. @@ -126,7 +126,7 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod // newNetfilterRunner creates a netfilterRunner using either nftables or iptables. // As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { - tableDetector := &linuxFWDetector{} + tableDetector := &LinuxFWDetector{} var mode linuxfw.FirewallMode // We now use iptables as default and have "auto" and "nftables" as @@ -143,7 +143,7 @@ func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { hostinfo.SetFirewallMode("nft-forced") mode = linuxfw.FirewallModeNfTables case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto": - mode = chooseFireWallMode(logf, tableDetector) + mode = ChooseFireWallMode(logf, tableDetector) case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables": logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set") hostinfo.SetFirewallMode("ipt-forced")