util/linuxfw,wgengine/router: export some of the functionality

So that it can be used by iptables/nftables in containerboot

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
Irbe Krumina 2023-09-29 08:30:04 +01:00
parent 651620623b
commit 980cbd790d
7 changed files with 311 additions and 289 deletions

214
util/linuxfw/fakes.go Normal file
View File

@ -0,0 +1,214 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"bytes"
"errors"
"fmt"
"runtime"
"strings"
"testing"
"github.com/google/nftables"
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
)
var errExec = errors.New("execution failed")
type fakeIPTables struct {
t *testing.T
n map[string][]string
}
type fakeRule struct {
table, chain string
args []string
}
func NewIPTables(t *testing.T) *fakeIPTables {
return &fakeIPTables{
t: t,
n: map[string][]string{
"filter/INPUT": nil,
"filter/OUTPUT": nil,
"filter/FORWARD": nil,
"nat/PREROUTING": nil,
"nat/OUTPUT": nil,
"nat/POSTROUTING": nil,
},
}
}
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if pos > len(rules)+1 {
n.t.Errorf("bad position %d in %s", pos, k)
return errExec
}
rules = append(rules, "")
copy(rules[pos:], rules[pos-1:])
rules[pos-1] = strings.Join(args, " ")
n.n[k] = rules
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
return nil
}
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
k := table + "/" + chain
return n.Insert(table, chain, len(n.n[k])+1, args...)
}
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for _, rule := range rules {
if rule == strings.Join(args, " ") {
return true, nil
}
}
return false, nil
} else {
n.t.Logf("unknown table/chain %s", k)
return false, errExec
}
}
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for i, rule := range rules {
if rule == strings.Join(args, " ") {
rules = append(rules[:i], rules[i+1:]...)
n.n[k] = rules
return nil
}
}
n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
return errExec
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
}
func (n *fakeIPTables) ClearChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.n[k] = nil
return nil
} else {
n.t.Logf("note: ClearChain: unknown table/chain %s", k)
return errors.New("exitcode:1")
}
}
func (n *fakeIPTables) NewChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.t.Errorf("table/chain %s already exists", k)
return errExec
}
n.n[k] = nil
return nil
}
func (n *fakeIPTables) DeleteChain(table, chain string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if len(rules) != 0 {
n.t.Errorf("%s is not empty", k)
return errExec
}
delete(n.n, k)
return nil
} else {
n.t.Errorf("%s does not exist", k)
return errExec
}
}
func NewTestConn(t *testing.T, want [][]byte) *nftables.Conn {
conn, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
return req, nil
}))
if err != nil {
t.Fatal(err)
}
return conn
}
func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
defer runtime.UnlockOSThread()
if err := ns.Close(); err != nil {
t.Fatalf("newNS.Close() failed: %v", err)
}
}
// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix.
func linediff(a, b string) string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "got -- want\n")
linesA := strings.Split(a, "\n")
linesB := strings.Split(b, "\n")
for idx, lineA := range linesA {
if idx >= len(linesB) {
break
}
lineB := linesB[idx]
prefix := "! "
if lineA == lineB {
prefix = " "
}
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
}
return buf.String()
}
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
func nfdump(b []byte) string {
var buf bytes.Buffer
i := 0
for ; i < len(b); i += 4 {
// TODO: show printable characters as ASCII
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
b[i],
b[i+1],
b[i+2],
b[i+3])
}
for ; i < len(b); i++ {
fmt.Fprintf(&buf, "%02x ", b[i])
}
return buf.String()
}

View File

@ -37,7 +37,7 @@ type iptablesRunner struct {
v6NATAvailable bool
}
func checkIP6TablesExists() error {
func CheckIP6TablesExists() error {
// Some distros ship ip6tables separately from iptables.
if _, err := exec.LookPath("ip6tables"); err != nil {
return fmt.Errorf("path not found: %w", err)
@ -56,8 +56,8 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
}
supportsV6, supportsV6NAT := false, false
v6err := checkIPv6(logf)
ip6terr := checkIP6TablesExists()
v6err := CheckIPv6(logf)
ip6terr := CheckIP6TablesExists()
switch {
case v6err != nil:
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
@ -65,7 +65,7 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr)
default:
supportsV6 = true
supportsV6NAT = supportsV6 && checkSupportsV6NAT()
supportsV6NAT = supportsV6 && CheckSupportsV6NAT()
logf("v6nat = %v", supportsV6NAT)
}

View File

@ -6,7 +6,6 @@
package linuxfw
import (
"errors"
"net/netip"
"strings"
"testing"
@ -14,136 +13,9 @@ import (
"tailscale.com/net/tsaddr"
)
var errExec = errors.New("execution failed")
type fakeIPTables struct {
t *testing.T
n map[string][]string
}
type fakeRule struct {
table, chain string
args []string
}
func newIPTables(t *testing.T) *fakeIPTables {
return &fakeIPTables{
t: t,
n: map[string][]string{
"filter/INPUT": nil,
"filter/OUTPUT": nil,
"filter/FORWARD": nil,
"nat/PREROUTING": nil,
"nat/OUTPUT": nil,
"nat/POSTROUTING": nil,
},
}
}
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if pos > len(rules)+1 {
n.t.Errorf("bad position %d in %s", pos, k)
return errExec
}
rules = append(rules, "")
copy(rules[pos:], rules[pos-1:])
rules[pos-1] = strings.Join(args, " ")
n.n[k] = rules
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
return nil
}
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
k := table + "/" + chain
return n.Insert(table, chain, len(n.n[k])+1, args...)
}
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for _, rule := range rules {
if rule == strings.Join(args, " ") {
return true, nil
}
}
return false, nil
} else {
n.t.Logf("unknown table/chain %s", k)
return false, errExec
}
}
func hasChain(n *fakeIPTables, table, chain string) bool {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
return true
} else {
return false
}
}
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for i, rule := range rules {
if rule == strings.Join(args, " ") {
rules = append(rules[:i], rules[i+1:]...)
n.n[k] = rules
return nil
}
}
n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
return errExec
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
}
func (n *fakeIPTables) ClearChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.n[k] = nil
return nil
} else {
n.t.Logf("note: ClearChain: unknown table/chain %s", k)
return errors.New("exitcode:1")
}
}
func (n *fakeIPTables) NewChain(table, chain string) error {
k := table + "/" + chain
if _, ok := n.n[k]; ok {
n.t.Errorf("table/chain %s already exists", k)
return errExec
}
n.n[k] = nil
return nil
}
func (n *fakeIPTables) DeleteChain(table, chain string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if len(rules) != 0 {
n.t.Errorf("%s is not empty", k)
return errExec
}
delete(n.n, k)
return nil
} else {
n.t.Errorf("%s does not exist", k)
return errExec
}
}
func newFakeIPTablesRunner(t *testing.T) *iptablesRunner {
ipt4 := newIPTables(t)
ipt6 := newIPTables(t)
ipt4 := NewIPTables(t)
ipt6 := NewIPTables(t)
iptr := &iptablesRunner{ipt4, ipt6, true, true}
return iptr

View File

@ -131,7 +131,7 @@ func errCode(err error) int {
// missing. It does not check that IPv6 is currently functional or
// that there's a global address, just that the system would support
// IPv6 if it were on an IPv6 network.
func checkIPv6(logf logger.Logf) error {
func CheckIPv6(logf logger.Logf) error {
_, err := os.Stat("/proc/sys/net/ipv6")
if os.IsNotExist(err) {
return err
@ -176,7 +176,7 @@ func checkIPv6(logf logger.Logf) error {
// The nat table was added after the initial release of ipv6
// netfilter, so some older distros ship a kernel that can't NAT IPv6
// traffic.
func checkSupportsV6NAT() bool {
func CheckSupportsV6NAT() bool {
bs, err := os.ReadFile("/proc/net/ip6_tables_names")
if err != nil {
// Can't read the file. Assume SNAT works.

View File

@ -30,13 +30,13 @@ const (
// chainTypeRegular is an nftables chain that does not apply to a hook.
const chainTypeRegular = ""
type chainInfo struct {
table *nftables.Table
name string
chainType nftables.ChainType
chainHook *nftables.ChainHook
chainPriority *nftables.ChainPriority
chainPolicy *nftables.ChainPolicy
type ChainInfo struct {
Table *nftables.Table
Name string
ChainType nftables.ChainType
ChainHook *nftables.ChainHook
ChainPriority *nftables.ChainPriority
ChainPolicy *nftables.ChainPolicy
}
type nftable struct {
@ -45,6 +45,21 @@ type nftable struct {
Nat *nftables.Table
}
type Conn interface {
ListChainsOfTableFamily(nftables.TableFamily) ([]*nftables.Chain, error)
ListTables() ([]*nftables.Table, error)
AddTable(*nftables.Table) *nftables.Table
DelTable(*nftables.Table)
AddChain(*nftables.Chain) *nftables.Chain
DelChain(*nftables.Chain)
FlushChain(*nftables.Chain)
GetRules(*nftables.Table, *nftables.Chain) ([]*nftables.Rule, error)
InsertRule(*nftables.Rule) *nftables.Rule
AddRule(*nftables.Rule) *nftables.Rule
DelRule(*nftables.Rule) error
Flush() error
}
// nftablesRunner implements a netfilterRunner using the netlink based nftables
// library. As nftables allows for arbitrary tables and chains, there is a need
// to follow conventions in order to integrate well with a surrounding
@ -70,7 +85,7 @@ type nftablesRunner struct {
}
// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
func CreateTableIfNotExist(c Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
tables, err := c.ListTables()
if err != nil {
return nil, fmt.Errorf("get tables: %w", err)
@ -102,7 +117,7 @@ func (e errorChainNotFound) Error() string {
// getChainFromTable returns the chain with the given name from the given table.
// Note that a chain name is unique within a table.
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
func GetChainFromTable(c Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
chains, err := c.ListChainsOfTableFamily(table.Family)
if err != nil {
return nil, fmt.Errorf("list chains: %w", err)
@ -144,35 +159,35 @@ func isTSChain(name string) bool {
// createChainIfNotExist creates a chain with the given name in the given table
// if it does not exist.
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
return fmt.Errorf("get chain: %w", err)
func CreateChainIfNotExist(c Conn, cinfo ChainInfo) (*nftables.Chain, error) {
chain, err := GetChainFromTable(c, cinfo.Table, cinfo.Name)
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.Table.Name, cinfo.Name}) {
return nil, fmt.Errorf("get chain: %w", err)
} else if err == nil {
// The chain already exists. If it is a TS chain, check the
// type/hook/priority, but for "conventional chains" assume they're what
// we expect (in case iptables-nft/ufw make minor behavior changes in
// the future).
if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
if isTSChain(chain.Name) && (chain.Type != cinfo.ChainType || chain.Hooknum != cinfo.ChainHook || chain.Priority != cinfo.ChainPriority) {
return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.Name)
}
return nil
return chain, nil
}
_ = c.AddChain(&nftables.Chain{
Name: cinfo.name,
Table: cinfo.table,
Type: cinfo.chainType,
Hooknum: cinfo.chainHook,
Priority: cinfo.chainPriority,
Policy: cinfo.chainPolicy,
chain = c.AddChain(&nftables.Chain{
Name: cinfo.Name,
Table: cinfo.Table,
Type: cinfo.ChainType,
Hooknum: cinfo.ChainHook,
Priority: cinfo.ChainPriority,
Policy: cinfo.ChainPolicy,
})
if err := c.Flush(); err != nil {
return fmt.Errorf("add chain: %w", err)
return nil, fmt.Errorf("add chain: %w", err)
}
return nil
return chain, nil
}
// NewNfTablesRunner creates a new nftablesRunner without guaranteeing
@ -184,12 +199,12 @@ func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
}
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
v6err := checkIPv6(logf)
v6err := CheckIPv6(logf)
if v6err != nil {
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
}
supportsV6 := v6err == nil
supportsV6NAT := supportsV6 && checkSupportsV6NAT()
supportsV6NAT := supportsV6 && CheckSupportsV6NAT()
var nft6 *nftable
if supportsV6 {
@ -208,7 +223,7 @@ func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
}, nil
}
// newLoadSaddrExpr creates a new nftables expression that loads the source
// NewLoadSaddrExpr creates a new nftables expression that loads the source
// address of the packet into the given register.
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
switch proto {
@ -358,7 +373,7 @@ func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
nf := n.getNFTByAddr(addr)
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
inputChain, err := GetChainFromTable(n.conn, nf.Filter, chainNameInput)
if err != nil {
return fmt.Errorf("get input chain: %w", err)
}
@ -375,7 +390,7 @@ func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
nf := n.getNFTByAddr(addr)
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
inputChain, err := GetChainFromTable(n.conn, nf.Filter, chainNameInput)
if err != nil {
return fmt.Errorf("get input chain: %w", err)
}
@ -428,23 +443,23 @@ func (n *nftablesRunner) AddChains() error {
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
// chains are conclusive.
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
filter, err := CreateTableIfNotExist(n.conn, table.Proto, "filter")
if err != nil {
return fmt.Errorf("create table: %w", err)
}
table.Filter = filter
// Adding the "conventional chains" that are used by iptables-nft and ufw.
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
return fmt.Errorf("create forward chain: %w", err)
}
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
return fmt.Errorf("create input chain: %w", err)
}
// Adding the tailscale chains that contain our rules.
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create forward chain: %w", err)
}
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create input chain: %w", err)
}
}
@ -454,17 +469,17 @@ func (n *nftablesRunner) AddChains() error {
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
// chains are conclusive.
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
nat, err := CreateTableIfNotExist(n.conn, table.Proto, "nat")
if err != nil {
return fmt.Errorf("create table: %w", err)
}
table.Nat = nat
// Adding the "conventional chains" that are used by iptables-nft and ufw.
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
return fmt.Errorf("create postrouting chain: %w", err)
}
// Adding the tailscale chain that contains our rules.
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create postrouting chain: %w", err)
}
}
@ -474,7 +489,7 @@ func (n *nftablesRunner) AddChains() error {
// deleteChainIfExists deletes a chain if it exists.
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
chain, err := getChainFromTable(c, table, name)
chain, err := GetChainFromTable(c, table, name)
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
return fmt.Errorf("get chain: %w", err)
} else if err != nil {
@ -557,7 +572,7 @@ func (n *nftablesRunner) AddHooks() error {
conn := n.conn
for _, table := range n.getTables() {
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
inputChain, err := GetChainFromTable(conn, table.Filter, "INPUT")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
}
@ -565,7 +580,7 @@ func (n *nftablesRunner) AddHooks() error {
if err != nil {
return fmt.Errorf("Addhook: %w", err)
}
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
forwardChain, err := GetChainFromTable(conn, table.Filter, "FORWARD")
if err != nil {
return fmt.Errorf("get FORWARD chain: %w", err)
}
@ -576,7 +591,7 @@ func (n *nftablesRunner) AddHooks() error {
}
for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
postroutingChain, err := GetChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
}
@ -613,7 +628,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
conn := n.conn
for _, table := range n.getTables() {
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
inputChain, err := GetChainFromTable(conn, table.Filter, "INPUT")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
}
@ -621,7 +636,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
if err != nil {
return fmt.Errorf("delhook: %w", err)
}
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
forwardChain, err := GetChainFromTable(conn, table.Filter, "FORWARD")
if err != nil {
return fmt.Errorf("get FORWARD chain: %w", err)
}
@ -632,7 +647,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
}
for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
postroutingChain, err := GetChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
}
@ -894,7 +909,7 @@ func (n *nftablesRunner) AddBase(tunname string) error {
func (n *nftablesRunner) addBase4(tunname string) error {
conn := n.conn
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
inputChain, err := GetChainFromTable(conn, n.nft4.Filter, chainNameInput)
if err != nil {
return fmt.Errorf("get input chain v4: %v", err)
}
@ -905,7 +920,7 @@ func (n *nftablesRunner) addBase4(tunname string) error {
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
}
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
forwardChain, err := GetChainFromTable(conn, n.nft4.Filter, chainNameForward)
if err != nil {
return fmt.Errorf("get forward chain v4: %v", err)
}
@ -937,7 +952,7 @@ func (n *nftablesRunner) addBase4(tunname string) error {
func (n *nftablesRunner) addBase6(tunname string) error {
conn := n.conn
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
forwardChain, err := GetChainFromTable(conn, n.nft6.Filter, chainNameForward)
if err != nil {
return fmt.Errorf("get forward chain v6: %w", err)
}
@ -967,12 +982,12 @@ func (n *nftablesRunner) DelBase() error {
conn := n.conn
for _, table := range n.getTables() {
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
inputChain, err := GetChainFromTable(conn, table.Filter, chainNameInput)
if err != nil {
return fmt.Errorf("get input chain: %v", err)
}
conn.FlushChain(inputChain)
forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
forwardChain, err := GetChainFromTable(conn, table.Filter, chainNameForward)
if err != nil {
return fmt.Errorf("get forward chain: %v", err)
}
@ -980,7 +995,7 @@ func (n *nftablesRunner) DelBase() error {
}
for _, table := range n.getNATTables() {
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
postrouteChain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %v", err)
}
@ -1050,7 +1065,7 @@ func (n *nftablesRunner) AddSNATRule() error {
conn := n.conn
for _, table := range n.getNATTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
chain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
}
@ -1093,7 +1108,7 @@ func (n *nftablesRunner) DelSNATRule() error {
}
for _, table := range n.getNATTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
chain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
}
@ -1124,7 +1139,7 @@ func (n *nftablesRunner) DelSNATRule() error {
// the jump rule and chain continue even if one errors.
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
// remove the jump first, before removing the jump destination.
defaultChain, err := getChainFromTable(conn, table, hookChainName)
defaultChain, err := GetChainFromTable(conn, table, hookChainName)
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
logf("cleanup: did not find default chain: %s", err)
}
@ -1133,7 +1148,7 @@ func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table,
_ = delHookRule(conn, table, defaultChain, tsChainName)
}
tsChain, err := getChainFromTable(conn, table, tsChainName)
tsChain, err := GetChainFromTable(conn, table, tsChainName)
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
logf("cleanup: did not find ts-chain: %s", err)
}

View File

@ -11,35 +11,14 @@ import (
"net/netip"
"os"
"runtime"
"strings"
"testing"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
"tailscale.com/net/tsaddr"
)
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
func nfdump(b []byte) string {
var buf bytes.Buffer
i := 0
for ; i < len(b); i += 4 {
// TODO: show printable characters as ASCII
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
b[i],
b[i+1],
b[i+2],
b[i+3])
}
for ; i < len(b); i++ {
fmt.Fprintf(&buf, "%02x ", b[i])
}
return buf.String()
}
func TestMaskof(t *testing.T) {
pfx, err := netip.ParsePrefix("192.168.1.0/24")
if err != nil {
@ -51,56 +30,6 @@ func TestMaskof(t *testing.T) {
}
}
// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix.
func linediff(a, b string) string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "got -- want\n")
linesA := strings.Split(a, "\n")
linesB := strings.Split(b, "\n")
for idx, lineA := range linesA {
if idx >= len(linesB) {
break
}
lineB := linesB[idx]
prefix := "! "
if lineA == lineB {
prefix = " "
}
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
}
return buf.String()
}
func newTestConn(t *testing.T, want [][]byte) *nftables.Conn {
conn, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
return req, nil
}))
if err != nil {
t.Fatal(err)
}
return conn
}
func TestInsertHookRule(t *testing.T) {
proto := nftables.TableFamilyIPv4
want := [][]byte{
@ -117,7 +46,7 @@ func TestInsertHookRule(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -157,7 +86,7 @@ func TestInsertLoopbackRule(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -193,7 +122,7 @@ func TestInsertLoopbackRuleV6(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
tableV6 := testConn.AddTable(&nftables.Table{
Family: protoV6,
Name: "ts-filter-test",
@ -229,7 +158,7 @@ func TestAddReturnChromeOSVMRangeRule(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -261,7 +190,7 @@ func TestAddDropCGNATRangeRule(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -293,7 +222,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -325,7 +254,7 @@ func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -357,7 +286,7 @@ func TestAddAcceptOutgoingPacketRule(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -389,7 +318,7 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-nat-test",
@ -421,7 +350,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
testConn := newTestConn(t, want)
testConn := NewTestConn(t, want)
table := testConn.AddTable(&nftables.Table{
Family: proto,
Name: "ts-filter-test",
@ -458,14 +387,6 @@ func newSysConn(t *testing.T) *nftables.Conn {
return c
}
func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
defer runtime.UnlockOSThread()
if err := ns.Close(); err != nil {
t.Fatalf("newNS.Close() failed: %v", err)
}
}
func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}
@ -843,7 +764,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
defer runner.DelChains()
runner.AddHooks()
forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD")
forwardChain, err := GetChainFromTable(conn, runner.nft4.Filter, "FORWARD")
if err != nil {
t.Fatalf("failed to get forwardChain: %v", err)
}
@ -857,7 +778,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules))
}
inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
inputChain, err := GetChainFromTable(conn, runner.nft4.Filter, "INPUT")
if err != nil {
t.Fatalf("failed to get inputChain: %v", err)
}
@ -871,7 +792,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules))
}
postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
postroutingChain, err := GetChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
if err != nil {
t.Fatalf("failed to get postroutingChain: %v", err)
}

View File

@ -62,21 +62,21 @@ type tableDetector interface {
nftDetect() (int, error)
}
type linuxFWDetector struct{}
type LinuxFWDetector struct{}
// iptDetect returns the number of iptables rules in the current namespace.
func (l *linuxFWDetector) iptDetect() (int, error) {
func (l *LinuxFWDetector) iptDetect() (int, error) {
return linuxfw.DetectIptables()
}
// nftDetect returns the number of nftables rules in the current namespace.
func (l *linuxFWDetector) nftDetect() (int, error) {
func (l *LinuxFWDetector) nftDetect() (int, error) {
return linuxfw.DetectNetfilter()
}
// chooseFireWallMode returns the firewall mode to use based on the
// environment and the system's capabilities.
func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode {
func ChooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode {
if distro.Get() == distro.Gokrazy {
// Reduce startup logging on gokrazy. There's no way to do iptables on
// gokrazy anyway.
@ -126,7 +126,7 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod
// newNetfilterRunner creates a netfilterRunner using either nftables or iptables.
// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
tableDetector := &linuxFWDetector{}
tableDetector := &LinuxFWDetector{}
var mode linuxfw.FirewallMode
// We now use iptables as default and have "auto" and "nftables" as
@ -143,7 +143,7 @@ func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
hostinfo.SetFirewallMode("nft-forced")
mode = linuxfw.FirewallModeNfTables
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto":
mode = chooseFireWallMode(logf, tableDetector)
mode = ChooseFireWallMode(logf, tableDetector)
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables":
logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
hostinfo.SetFirewallMode("ipt-forced")