mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-06 04:36:15 +02:00
util/linuxfw,wgengine/router: export some of the functionality
So that it can be used by iptables/nftables in containerboot Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
parent
651620623b
commit
980cbd790d
214
util/linuxfw/fakes.go
Normal file
214
util/linuxfw/fakes.go
Normal file
@ -0,0 +1,214 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build linux
|
||||
|
||||
package linuxfw
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/vishvananda/netns"
|
||||
)
|
||||
|
||||
var errExec = errors.New("execution failed")
|
||||
|
||||
type fakeIPTables struct {
|
||||
t *testing.T
|
||||
n map[string][]string
|
||||
}
|
||||
|
||||
type fakeRule struct {
|
||||
table, chain string
|
||||
args []string
|
||||
}
|
||||
|
||||
func NewIPTables(t *testing.T) *fakeIPTables {
|
||||
return &fakeIPTables{
|
||||
t: t,
|
||||
n: map[string][]string{
|
||||
"filter/INPUT": nil,
|
||||
"filter/OUTPUT": nil,
|
||||
"filter/FORWARD": nil,
|
||||
"nat/PREROUTING": nil,
|
||||
"nat/OUTPUT": nil,
|
||||
"nat/POSTROUTING": nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
if pos > len(rules)+1 {
|
||||
n.t.Errorf("bad position %d in %s", pos, k)
|
||||
return errExec
|
||||
}
|
||||
rules = append(rules, "")
|
||||
copy(rules[pos:], rules[pos-1:])
|
||||
rules[pos-1] = strings.Join(args, " ")
|
||||
n.n[k] = rules
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return errExec
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
return n.Insert(table, chain, len(n.n[k])+1, args...)
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
for _, rule := range rules {
|
||||
if rule == strings.Join(args, " ") {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
n.t.Logf("unknown table/chain %s", k)
|
||||
return false, errExec
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
for i, rule := range rules {
|
||||
if rule == strings.Join(args, " ") {
|
||||
rules = append(rules[:i], rules[i+1:]...)
|
||||
n.n[k] = rules
|
||||
return nil
|
||||
}
|
||||
}
|
||||
n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
|
||||
return errExec
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return errExec
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) ClearChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
n.n[k] = nil
|
||||
return nil
|
||||
} else {
|
||||
n.t.Logf("note: ClearChain: unknown table/chain %s", k)
|
||||
return errors.New("exitcode:1")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) NewChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
n.t.Errorf("table/chain %s already exists", k)
|
||||
return errExec
|
||||
}
|
||||
n.n[k] = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) DeleteChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
if len(rules) != 0 {
|
||||
n.t.Errorf("%s is not empty", k)
|
||||
return errExec
|
||||
}
|
||||
delete(n.n, k)
|
||||
return nil
|
||||
} else {
|
||||
n.t.Errorf("%s does not exist", k)
|
||||
return errExec
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestConn(t *testing.T, want [][]byte) *nftables.Conn {
|
||||
conn, err := nftables.New(nftables.WithTestDial(
|
||||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||||
for idx, msg := range req {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(b) < 16 {
|
||||
continue
|
||||
}
|
||||
b = b[16:]
|
||||
if len(want) == 0 {
|
||||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||||
continue
|
||||
}
|
||||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
return req, nil
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
if err := ns.Close(); err != nil {
|
||||
t.Fatalf("newNS.Close() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// linediff returns a side-by-side diff of two nfdump() return values, flagging
|
||||
// lines which are not equal with an exclamation point prefix.
|
||||
func linediff(a, b string) string {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "got -- want\n")
|
||||
linesA := strings.Split(a, "\n")
|
||||
linesB := strings.Split(b, "\n")
|
||||
for idx, lineA := range linesA {
|
||||
if idx >= len(linesB) {
|
||||
break
|
||||
}
|
||||
lineB := linesB[idx]
|
||||
prefix := "! "
|
||||
if lineA == lineB {
|
||||
prefix = " "
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
|
||||
// users to make sense of large byte literals more easily.
|
||||
func nfdump(b []byte) string {
|
||||
var buf bytes.Buffer
|
||||
i := 0
|
||||
for ; i < len(b); i += 4 {
|
||||
// TODO: show printable characters as ASCII
|
||||
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
|
||||
b[i],
|
||||
b[i+1],
|
||||
b[i+2],
|
||||
b[i+3])
|
||||
}
|
||||
for ; i < len(b); i++ {
|
||||
fmt.Fprintf(&buf, "%02x ", b[i])
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
@ -37,7 +37,7 @@ type iptablesRunner struct {
|
||||
v6NATAvailable bool
|
||||
}
|
||||
|
||||
func checkIP6TablesExists() error {
|
||||
func CheckIP6TablesExists() error {
|
||||
// Some distros ship ip6tables separately from iptables.
|
||||
if _, err := exec.LookPath("ip6tables"); err != nil {
|
||||
return fmt.Errorf("path not found: %w", err)
|
||||
@ -56,8 +56,8 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
|
||||
}
|
||||
|
||||
supportsV6, supportsV6NAT := false, false
|
||||
v6err := checkIPv6(logf)
|
||||
ip6terr := checkIP6TablesExists()
|
||||
v6err := CheckIPv6(logf)
|
||||
ip6terr := CheckIP6TablesExists()
|
||||
switch {
|
||||
case v6err != nil:
|
||||
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
|
||||
@ -65,7 +65,7 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
|
||||
logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr)
|
||||
default:
|
||||
supportsV6 = true
|
||||
supportsV6NAT = supportsV6 && checkSupportsV6NAT()
|
||||
supportsV6NAT = supportsV6 && CheckSupportsV6NAT()
|
||||
logf("v6nat = %v", supportsV6NAT)
|
||||
}
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
package linuxfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -14,136 +13,9 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
var errExec = errors.New("execution failed")
|
||||
|
||||
type fakeIPTables struct {
|
||||
t *testing.T
|
||||
n map[string][]string
|
||||
}
|
||||
|
||||
type fakeRule struct {
|
||||
table, chain string
|
||||
args []string
|
||||
}
|
||||
|
||||
func newIPTables(t *testing.T) *fakeIPTables {
|
||||
return &fakeIPTables{
|
||||
t: t,
|
||||
n: map[string][]string{
|
||||
"filter/INPUT": nil,
|
||||
"filter/OUTPUT": nil,
|
||||
"filter/FORWARD": nil,
|
||||
"nat/PREROUTING": nil,
|
||||
"nat/OUTPUT": nil,
|
||||
"nat/POSTROUTING": nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
if pos > len(rules)+1 {
|
||||
n.t.Errorf("bad position %d in %s", pos, k)
|
||||
return errExec
|
||||
}
|
||||
rules = append(rules, "")
|
||||
copy(rules[pos:], rules[pos-1:])
|
||||
rules[pos-1] = strings.Join(args, " ")
|
||||
n.n[k] = rules
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return errExec
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
return n.Insert(table, chain, len(n.n[k])+1, args...)
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
for _, rule := range rules {
|
||||
if rule == strings.Join(args, " ") {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
n.t.Logf("unknown table/chain %s", k)
|
||||
return false, errExec
|
||||
}
|
||||
}
|
||||
|
||||
func hasChain(n *fakeIPTables, table, chain string) bool {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
for i, rule := range rules {
|
||||
if rule == strings.Join(args, " ") {
|
||||
rules = append(rules[:i], rules[i+1:]...)
|
||||
n.n[k] = rules
|
||||
return nil
|
||||
}
|
||||
}
|
||||
n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
|
||||
return errExec
|
||||
} else {
|
||||
n.t.Errorf("unknown table/chain %s", k)
|
||||
return errExec
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) ClearChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
n.n[k] = nil
|
||||
return nil
|
||||
} else {
|
||||
n.t.Logf("note: ClearChain: unknown table/chain %s", k)
|
||||
return errors.New("exitcode:1")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) NewChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if _, ok := n.n[k]; ok {
|
||||
n.t.Errorf("table/chain %s already exists", k)
|
||||
return errExec
|
||||
}
|
||||
n.n[k] = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *fakeIPTables) DeleteChain(table, chain string) error {
|
||||
k := table + "/" + chain
|
||||
if rules, ok := n.n[k]; ok {
|
||||
if len(rules) != 0 {
|
||||
n.t.Errorf("%s is not empty", k)
|
||||
return errExec
|
||||
}
|
||||
delete(n.n, k)
|
||||
return nil
|
||||
} else {
|
||||
n.t.Errorf("%s does not exist", k)
|
||||
return errExec
|
||||
}
|
||||
}
|
||||
|
||||
func newFakeIPTablesRunner(t *testing.T) *iptablesRunner {
|
||||
ipt4 := newIPTables(t)
|
||||
ipt6 := newIPTables(t)
|
||||
ipt4 := NewIPTables(t)
|
||||
ipt6 := NewIPTables(t)
|
||||
|
||||
iptr := &iptablesRunner{ipt4, ipt6, true, true}
|
||||
return iptr
|
||||
|
||||
@ -131,7 +131,7 @@ func errCode(err error) int {
|
||||
// missing. It does not check that IPv6 is currently functional or
|
||||
// that there's a global address, just that the system would support
|
||||
// IPv6 if it were on an IPv6 network.
|
||||
func checkIPv6(logf logger.Logf) error {
|
||||
func CheckIPv6(logf logger.Logf) error {
|
||||
_, err := os.Stat("/proc/sys/net/ipv6")
|
||||
if os.IsNotExist(err) {
|
||||
return err
|
||||
@ -176,7 +176,7 @@ func checkIPv6(logf logger.Logf) error {
|
||||
// The nat table was added after the initial release of ipv6
|
||||
// netfilter, so some older distros ship a kernel that can't NAT IPv6
|
||||
// traffic.
|
||||
func checkSupportsV6NAT() bool {
|
||||
func CheckSupportsV6NAT() bool {
|
||||
bs, err := os.ReadFile("/proc/net/ip6_tables_names")
|
||||
if err != nil {
|
||||
// Can't read the file. Assume SNAT works.
|
||||
|
||||
@ -30,13 +30,13 @@ const (
|
||||
// chainTypeRegular is an nftables chain that does not apply to a hook.
|
||||
const chainTypeRegular = ""
|
||||
|
||||
type chainInfo struct {
|
||||
table *nftables.Table
|
||||
name string
|
||||
chainType nftables.ChainType
|
||||
chainHook *nftables.ChainHook
|
||||
chainPriority *nftables.ChainPriority
|
||||
chainPolicy *nftables.ChainPolicy
|
||||
type ChainInfo struct {
|
||||
Table *nftables.Table
|
||||
Name string
|
||||
ChainType nftables.ChainType
|
||||
ChainHook *nftables.ChainHook
|
||||
ChainPriority *nftables.ChainPriority
|
||||
ChainPolicy *nftables.ChainPolicy
|
||||
}
|
||||
|
||||
type nftable struct {
|
||||
@ -45,6 +45,21 @@ type nftable struct {
|
||||
Nat *nftables.Table
|
||||
}
|
||||
|
||||
type Conn interface {
|
||||
ListChainsOfTableFamily(nftables.TableFamily) ([]*nftables.Chain, error)
|
||||
ListTables() ([]*nftables.Table, error)
|
||||
AddTable(*nftables.Table) *nftables.Table
|
||||
DelTable(*nftables.Table)
|
||||
AddChain(*nftables.Chain) *nftables.Chain
|
||||
DelChain(*nftables.Chain)
|
||||
FlushChain(*nftables.Chain)
|
||||
GetRules(*nftables.Table, *nftables.Chain) ([]*nftables.Rule, error)
|
||||
InsertRule(*nftables.Rule) *nftables.Rule
|
||||
AddRule(*nftables.Rule) *nftables.Rule
|
||||
DelRule(*nftables.Rule) error
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// nftablesRunner implements a netfilterRunner using the netlink based nftables
|
||||
// library. As nftables allows for arbitrary tables and chains, there is a need
|
||||
// to follow conventions in order to integrate well with a surrounding
|
||||
@ -70,7 +85,7 @@ type nftablesRunner struct {
|
||||
}
|
||||
|
||||
// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
|
||||
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
||||
func CreateTableIfNotExist(c Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
||||
tables, err := c.ListTables()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get tables: %w", err)
|
||||
@ -102,7 +117,7 @@ func (e errorChainNotFound) Error() string {
|
||||
|
||||
// getChainFromTable returns the chain with the given name from the given table.
|
||||
// Note that a chain name is unique within a table.
|
||||
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
|
||||
func GetChainFromTable(c Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
|
||||
chains, err := c.ListChainsOfTableFamily(table.Family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list chains: %w", err)
|
||||
@ -144,35 +159,35 @@ func isTSChain(name string) bool {
|
||||
|
||||
// createChainIfNotExist creates a chain with the given name in the given table
|
||||
// if it does not exist.
|
||||
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
||||
chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
||||
return fmt.Errorf("get chain: %w", err)
|
||||
func CreateChainIfNotExist(c Conn, cinfo ChainInfo) (*nftables.Chain, error) {
|
||||
chain, err := GetChainFromTable(c, cinfo.Table, cinfo.Name)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.Table.Name, cinfo.Name}) {
|
||||
return nil, fmt.Errorf("get chain: %w", err)
|
||||
} else if err == nil {
|
||||
// The chain already exists. If it is a TS chain, check the
|
||||
// type/hook/priority, but for "conventional chains" assume they're what
|
||||
// we expect (in case iptables-nft/ufw make minor behavior changes in
|
||||
// the future).
|
||||
if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
|
||||
return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
||||
if isTSChain(chain.Name) && (chain.Type != cinfo.ChainType || chain.Hooknum != cinfo.ChainHook || chain.Priority != cinfo.ChainPriority) {
|
||||
return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.Name)
|
||||
}
|
||||
return nil
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
_ = c.AddChain(&nftables.Chain{
|
||||
Name: cinfo.name,
|
||||
Table: cinfo.table,
|
||||
Type: cinfo.chainType,
|
||||
Hooknum: cinfo.chainHook,
|
||||
Priority: cinfo.chainPriority,
|
||||
Policy: cinfo.chainPolicy,
|
||||
chain = c.AddChain(&nftables.Chain{
|
||||
Name: cinfo.Name,
|
||||
Table: cinfo.Table,
|
||||
Type: cinfo.ChainType,
|
||||
Hooknum: cinfo.ChainHook,
|
||||
Priority: cinfo.ChainPriority,
|
||||
Policy: cinfo.ChainPolicy,
|
||||
})
|
||||
|
||||
if err := c.Flush(); err != nil {
|
||||
return fmt.Errorf("add chain: %w", err)
|
||||
return nil, fmt.Errorf("add chain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// NewNfTablesRunner creates a new nftablesRunner without guaranteeing
|
||||
@ -184,12 +199,12 @@ func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
|
||||
}
|
||||
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
|
||||
|
||||
v6err := checkIPv6(logf)
|
||||
v6err := CheckIPv6(logf)
|
||||
if v6err != nil {
|
||||
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
|
||||
}
|
||||
supportsV6 := v6err == nil
|
||||
supportsV6NAT := supportsV6 && checkSupportsV6NAT()
|
||||
supportsV6NAT := supportsV6 && CheckSupportsV6NAT()
|
||||
|
||||
var nft6 *nftable
|
||||
if supportsV6 {
|
||||
@ -208,7 +223,7 @@ func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newLoadSaddrExpr creates a new nftables expression that loads the source
|
||||
// NewLoadSaddrExpr creates a new nftables expression that loads the source
|
||||
// address of the packet into the given register.
|
||||
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
|
||||
switch proto {
|
||||
@ -358,7 +373,7 @@ func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
|
||||
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
||||
nf := n.getNFTByAddr(addr)
|
||||
|
||||
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
||||
inputChain, err := GetChainFromTable(n.conn, nf.Filter, chainNameInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get input chain: %w", err)
|
||||
}
|
||||
@ -375,7 +390,7 @@ func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
||||
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
||||
nf := n.getNFTByAddr(addr)
|
||||
|
||||
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
||||
inputChain, err := GetChainFromTable(n.conn, nf.Filter, chainNameInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get input chain: %w", err)
|
||||
}
|
||||
@ -428,23 +443,23 @@ func (n *nftablesRunner) AddChains() error {
|
||||
// as the name used by iptables-nft and ufw. We install rules into the
|
||||
// same conventional table so that `accept` verdicts from our jump
|
||||
// chains are conclusive.
|
||||
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
||||
filter, err := CreateTableIfNotExist(n.conn, table.Proto, "filter")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table: %w", err)
|
||||
}
|
||||
table.Filter = filter
|
||||
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||
return fmt.Errorf("create forward chain: %w", err)
|
||||
}
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||
return fmt.Errorf("create input chain: %w", err)
|
||||
}
|
||||
// Adding the tailscale chains that contain our rules.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
return fmt.Errorf("create forward chain: %w", err)
|
||||
}
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
return fmt.Errorf("create input chain: %w", err)
|
||||
}
|
||||
}
|
||||
@ -454,17 +469,17 @@ func (n *nftablesRunner) AddChains() error {
|
||||
// as the name used by iptables-nft and ufw. We install rules into the
|
||||
// same conventional table so that `accept` verdicts from our jump
|
||||
// chains are conclusive.
|
||||
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
||||
nat, err := CreateTableIfNotExist(n.conn, table.Proto, "nat")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table: %w", err)
|
||||
}
|
||||
table.Nat = nat
|
||||
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
||||
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
||||
return fmt.Errorf("create postrouting chain: %w", err)
|
||||
}
|
||||
// Adding the tailscale chain that contains our rules.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
if _, err = CreateChainIfNotExist(n.conn, ChainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
return fmt.Errorf("create postrouting chain: %w", err)
|
||||
}
|
||||
}
|
||||
@ -474,7 +489,7 @@ func (n *nftablesRunner) AddChains() error {
|
||||
|
||||
// deleteChainIfExists deletes a chain if it exists.
|
||||
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
|
||||
chain, err := getChainFromTable(c, table, name)
|
||||
chain, err := GetChainFromTable(c, table, name)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
|
||||
return fmt.Errorf("get chain: %w", err)
|
||||
} else if err != nil {
|
||||
@ -557,7 +572,7 @@ func (n *nftablesRunner) AddHooks() error {
|
||||
conn := n.conn
|
||||
|
||||
for _, table := range n.getTables() {
|
||||
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
||||
inputChain, err := GetChainFromTable(conn, table.Filter, "INPUT")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
@ -565,7 +580,7 @@ func (n *nftablesRunner) AddHooks() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("Addhook: %w", err)
|
||||
}
|
||||
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
||||
forwardChain, err := GetChainFromTable(conn, table.Filter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get FORWARD chain: %w", err)
|
||||
}
|
||||
@ -576,7 +591,7 @@ func (n *nftablesRunner) AddHooks() error {
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||
postroutingChain, err := GetChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
@ -613,7 +628,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
||||
conn := n.conn
|
||||
|
||||
for _, table := range n.getTables() {
|
||||
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
||||
inputChain, err := GetChainFromTable(conn, table.Filter, "INPUT")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
@ -621,7 +636,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("delhook: %w", err)
|
||||
}
|
||||
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
||||
forwardChain, err := GetChainFromTable(conn, table.Filter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get FORWARD chain: %w", err)
|
||||
}
|
||||
@ -632,7 +647,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||
postroutingChain, err := GetChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
@ -894,7 +909,7 @@ func (n *nftablesRunner) AddBase(tunname string) error {
|
||||
func (n *nftablesRunner) addBase4(tunname string) error {
|
||||
conn := n.conn
|
||||
|
||||
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
|
||||
inputChain, err := GetChainFromTable(conn, n.nft4.Filter, chainNameInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get input chain v4: %v", err)
|
||||
}
|
||||
@ -905,7 +920,7 @@ func (n *nftablesRunner) addBase4(tunname string) error {
|
||||
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
|
||||
}
|
||||
|
||||
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
|
||||
forwardChain, err := GetChainFromTable(conn, n.nft4.Filter, chainNameForward)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get forward chain v4: %v", err)
|
||||
}
|
||||
@ -937,7 +952,7 @@ func (n *nftablesRunner) addBase4(tunname string) error {
|
||||
func (n *nftablesRunner) addBase6(tunname string) error {
|
||||
conn := n.conn
|
||||
|
||||
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
|
||||
forwardChain, err := GetChainFromTable(conn, n.nft6.Filter, chainNameForward)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get forward chain v6: %w", err)
|
||||
}
|
||||
@ -967,12 +982,12 @@ func (n *nftablesRunner) DelBase() error {
|
||||
conn := n.conn
|
||||
|
||||
for _, table := range n.getTables() {
|
||||
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
|
||||
inputChain, err := GetChainFromTable(conn, table.Filter, chainNameInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get input chain: %v", err)
|
||||
}
|
||||
conn.FlushChain(inputChain)
|
||||
forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
||||
forwardChain, err := GetChainFromTable(conn, table.Filter, chainNameForward)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get forward chain: %v", err)
|
||||
}
|
||||
@ -980,7 +995,7 @@ func (n *nftablesRunner) DelBase() error {
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||||
postrouteChain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get postrouting chain v4: %v", err)
|
||||
}
|
||||
@ -1050,7 +1065,7 @@ func (n *nftablesRunner) AddSNATRule() error {
|
||||
conn := n.conn
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||||
chain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get postrouting chain v4: %w", err)
|
||||
}
|
||||
@ -1093,7 +1108,7 @@ func (n *nftablesRunner) DelSNATRule() error {
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||||
chain, err := GetChainFromTable(conn, table.Nat, chainNamePostrouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get postrouting chain v4: %w", err)
|
||||
}
|
||||
@ -1124,7 +1139,7 @@ func (n *nftablesRunner) DelSNATRule() error {
|
||||
// the jump rule and chain continue even if one errors.
|
||||
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
|
||||
// remove the jump first, before removing the jump destination.
|
||||
defaultChain, err := getChainFromTable(conn, table, hookChainName)
|
||||
defaultChain, err := GetChainFromTable(conn, table, hookChainName)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
||||
logf("cleanup: did not find default chain: %s", err)
|
||||
}
|
||||
@ -1133,7 +1148,7 @@ func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table,
|
||||
_ = delHookRule(conn, table, defaultChain, tsChainName)
|
||||
}
|
||||
|
||||
tsChain, err := getChainFromTable(conn, table, tsChainName)
|
||||
tsChain, err := GetChainFromTable(conn, table, tsChainName)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
|
||||
logf("cleanup: did not find ts-chain: %s", err)
|
||||
}
|
||||
|
||||
@ -11,35 +11,14 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/vishvananda/netns"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
|
||||
// users to make sense of large byte literals more easily.
|
||||
func nfdump(b []byte) string {
|
||||
var buf bytes.Buffer
|
||||
i := 0
|
||||
for ; i < len(b); i += 4 {
|
||||
// TODO: show printable characters as ASCII
|
||||
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
|
||||
b[i],
|
||||
b[i+1],
|
||||
b[i+2],
|
||||
b[i+3])
|
||||
}
|
||||
for ; i < len(b); i++ {
|
||||
fmt.Fprintf(&buf, "%02x ", b[i])
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestMaskof(t *testing.T) {
|
||||
pfx, err := netip.ParsePrefix("192.168.1.0/24")
|
||||
if err != nil {
|
||||
@ -51,56 +30,6 @@ func TestMaskof(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// linediff returns a side-by-side diff of two nfdump() return values, flagging
|
||||
// lines which are not equal with an exclamation point prefix.
|
||||
func linediff(a, b string) string {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "got -- want\n")
|
||||
linesA := strings.Split(a, "\n")
|
||||
linesB := strings.Split(b, "\n")
|
||||
for idx, lineA := range linesA {
|
||||
if idx >= len(linesB) {
|
||||
break
|
||||
}
|
||||
lineB := linesB[idx]
|
||||
prefix := "! "
|
||||
if lineA == lineB {
|
||||
prefix = " "
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func newTestConn(t *testing.T, want [][]byte) *nftables.Conn {
|
||||
conn, err := nftables.New(nftables.WithTestDial(
|
||||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||||
for idx, msg := range req {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(b) < 16 {
|
||||
continue
|
||||
}
|
||||
b = b[16:]
|
||||
if len(want) == 0 {
|
||||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||||
continue
|
||||
}
|
||||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
return req, nil
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestInsertHookRule(t *testing.T) {
|
||||
proto := nftables.TableFamilyIPv4
|
||||
want := [][]byte{
|
||||
@ -117,7 +46,7 @@ func TestInsertHookRule(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -157,7 +86,7 @@ func TestInsertLoopbackRule(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -193,7 +122,7 @@ func TestInsertLoopbackRuleV6(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
tableV6 := testConn.AddTable(&nftables.Table{
|
||||
Family: protoV6,
|
||||
Name: "ts-filter-test",
|
||||
@ -229,7 +158,7 @@ func TestAddReturnChromeOSVMRangeRule(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -261,7 +190,7 @@ func TestAddDropCGNATRangeRule(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -293,7 +222,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -325,7 +254,7 @@ func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -357,7 +286,7 @@ func TestAddAcceptOutgoingPacketRule(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -389,7 +318,7 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-nat-test",
|
||||
@ -421,7 +350,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) {
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
testConn := NewTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
@ -458,14 +387,6 @@ func newSysConn(t *testing.T) *nftables.Conn {
|
||||
return c
|
||||
}
|
||||
|
||||
func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
if err := ns.Close(); err != nil {
|
||||
t.Fatalf("newNS.Close() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
|
||||
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
|
||||
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}
|
||||
@ -843,7 +764,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
|
||||
defer runner.DelChains()
|
||||
runner.AddHooks()
|
||||
|
||||
forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD")
|
||||
forwardChain, err := GetChainFromTable(conn, runner.nft4.Filter, "FORWARD")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get forwardChain: %v", err)
|
||||
}
|
||||
@ -857,7 +778,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
|
||||
t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules))
|
||||
}
|
||||
|
||||
inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
|
||||
inputChain, err := GetChainFromTable(conn, runner.nft4.Filter, "INPUT")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get inputChain: %v", err)
|
||||
}
|
||||
@ -871,7 +792,7 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
|
||||
t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules))
|
||||
}
|
||||
|
||||
postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
|
||||
postroutingChain, err := GetChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get postroutingChain: %v", err)
|
||||
}
|
||||
|
||||
@ -62,21 +62,21 @@ type tableDetector interface {
|
||||
nftDetect() (int, error)
|
||||
}
|
||||
|
||||
type linuxFWDetector struct{}
|
||||
type LinuxFWDetector struct{}
|
||||
|
||||
// iptDetect returns the number of iptables rules in the current namespace.
|
||||
func (l *linuxFWDetector) iptDetect() (int, error) {
|
||||
func (l *LinuxFWDetector) iptDetect() (int, error) {
|
||||
return linuxfw.DetectIptables()
|
||||
}
|
||||
|
||||
// nftDetect returns the number of nftables rules in the current namespace.
|
||||
func (l *linuxFWDetector) nftDetect() (int, error) {
|
||||
func (l *LinuxFWDetector) nftDetect() (int, error) {
|
||||
return linuxfw.DetectNetfilter()
|
||||
}
|
||||
|
||||
// chooseFireWallMode returns the firewall mode to use based on the
|
||||
// environment and the system's capabilities.
|
||||
func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode {
|
||||
func ChooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode {
|
||||
if distro.Get() == distro.Gokrazy {
|
||||
// Reduce startup logging on gokrazy. There's no way to do iptables on
|
||||
// gokrazy anyway.
|
||||
@ -126,7 +126,7 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod
|
||||
// newNetfilterRunner creates a netfilterRunner using either nftables or iptables.
|
||||
// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
|
||||
func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
|
||||
tableDetector := &linuxFWDetector{}
|
||||
tableDetector := &LinuxFWDetector{}
|
||||
var mode linuxfw.FirewallMode
|
||||
|
||||
// We now use iptables as default and have "auto" and "nftables" as
|
||||
@ -143,7 +143,7 @@ func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
|
||||
hostinfo.SetFirewallMode("nft-forced")
|
||||
mode = linuxfw.FirewallModeNfTables
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto":
|
||||
mode = chooseFireWallMode(logf, tableDetector)
|
||||
mode = ChooseFireWallMode(logf, tableDetector)
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables":
|
||||
logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
|
||||
hostinfo.SetFirewallMode("ipt-forced")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user