mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-11-04 10:11:18 +01:00 
			
		
		
		
	Add new rules to update DNAT rules for Kubernetes operator's HA ingress where it's expected that rules will be added/removed frequently (so we don't want to keep old rules around or rewrite existing rules unnecessarily): - allow deleting DNAT rules using metadata lookup - allow inserting DNAT rules if they don't already exist (using metadata lookup) Updates tailscale/tailscale#15895 Signed-off-by: Irbe Krumina <irbe@tailscale.com> Co-authored-by: chaosinthecrd <tom@tmlabs.co.uk>
		
			
				
	
	
		
			310 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			310 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) Tailscale Inc & AUTHORS
 | 
						|
// SPDX-License-Identifier: BSD-3-Clause
 | 
						|
 | 
						|
//go:build linux
 | 
						|
 | 
						|
package linuxfw
 | 
						|
 | 
						|
import (
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net/netip"
 | 
						|
	"reflect"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/google/nftables"
 | 
						|
	"github.com/google/nftables/binaryutil"
 | 
						|
	"github.com/google/nftables/expr"
 | 
						|
	"golang.org/x/sys/unix"
 | 
						|
)
 | 
						|
 | 
						|
// This file contains functionality that is currently (09/2024) used to set up
// routing for the Tailscale Kubernetes operator egress proxies. A tailnet
// service (identified by tailnet IP or FQDN) that gets exposed to cluster
// workloads gets a separate prerouting chain created for it for each IP family
// of the service's target addresses. Each service's prerouting chain contains
// one or more portmapping rules. A portmapping rule DNATs traffic received on
// a particular port to a port of the tailnet service. Creating a chain per
// service makes it easier to delete a service when it is no longer needed and
// helps with readability.
//
// The file also contains the helpers that add and remove DNAT rules for
// VIPServices (see EnsureDNATRuleForSvc and DeleteDNATRuleForSvc). Like the
// portmapping rules, those rules are looked up by the metadata attached to
// them.
 | 
						|
 | 
						|
// EnsurePortMapRuleForSvc:
 | 
						|
// - ensures that nat table exists
 | 
						|
// - ensures that there is a prerouting chain for the given service and IP family of the target address in the nat table
 | 
						|
// - ensures that there is a portmapping rule mathcing the given portmap (only creates the rule if it does not already exist)
 | 
						|
func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
 | 
						|
	t, ch, err := n.ensureChainForSvc(svc, targetIP)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
 | 
						|
	}
 | 
						|
	meta := svcPortMapRuleMeta(svc, targetIP, pm)
 | 
						|
	rule, err := n.findRuleByMetadata(t, ch, meta)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error looking up rule: %w", err)
 | 
						|
	}
 | 
						|
	if rule != nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	p, err := protoFromString(pm.Protocol)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error converting protocol %s: %w", pm.Protocol, err)
 | 
						|
	}
 | 
						|
 | 
						|
	rule = portMapRule(t, ch, tun, targetIP, pm.MatchPort, pm.TargetPort, p, meta)
 | 
						|
	n.conn.InsertRule(rule)
 | 
						|
	return n.conn.Flush()
 | 
						|
}
 | 
						|
 | 
						|
// DeletePortMapRuleForSvc deletes a portmapping rule in the given service/IP family chain.
 | 
						|
// It finds the matching rule using metadata attached to the rule.
 | 
						|
// The caller is expected to call DeleteSvc if the whole service (the chain)
 | 
						|
// needs to be deleted, so we don't deal with the case where this is the only
 | 
						|
// rule in the chain here.
 | 
						|
func (n *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
 | 
						|
	table, err := n.getNFTByAddr(targetIP)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error setting up nftables for IP family of %s: %w", targetIP, err)
 | 
						|
	}
 | 
						|
	t, err := getTableIfExists(n.conn, table.Proto, "nat")
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error checking if nat table exists: %w", err)
 | 
						|
	}
 | 
						|
	if t == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	ch, err := getChainFromTable(n.conn, t, svc)
 | 
						|
	if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | 
						|
		return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
 | 
						|
	}
 | 
						|
	if errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | 
						|
		return nil // service chain does not exist, so neither does the portmapping rule
 | 
						|
	}
 | 
						|
	meta := svcPortMapRuleMeta(svc, targetIP, pm)
 | 
						|
	rule, err := n.findRuleByMetadata(t, ch, meta)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error checking if rule exists: %w", err)
 | 
						|
	}
 | 
						|
	if rule == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	if err := n.conn.DelRule(rule); err != nil {
 | 
						|
		return fmt.Errorf("error deleting rule: %w", err)
 | 
						|
	}
 | 
						|
	return n.conn.Flush()
 | 
						|
}
 | 
						|
 | 
						|
// DeleteSvc deletes the chains for the given service if any exist.
 | 
						|
func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error {
 | 
						|
	for _, tip := range targetIPs {
 | 
						|
		table, err := n.getNFTByAddr(tip)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("error setting up nftables for IP family of %s: %w", tip, err)
 | 
						|
		}
 | 
						|
		t, err := getTableIfExists(n.conn, table.Proto, "nat")
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("error checking if nat table exists: %w", err)
 | 
						|
		}
 | 
						|
		if t == nil {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
		ch, err := getChainFromTable(n.conn, t, svc)
 | 
						|
		if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | 
						|
			return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
 | 
						|
		}
 | 
						|
		if errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
		n.conn.DelChain(ch)
 | 
						|
	}
 | 
						|
	return n.conn.Flush()
 | 
						|
}
 | 
						|
 | 
						|
// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
 | 
						|
// VIPService IP address to a local address. This is used by the Kubernetes
 | 
						|
// operator's network layer proxies to forward tailnet traffic for VIPServices
 | 
						|
// to Kubernetes Services.
 | 
						|
func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error {
 | 
						|
	t, ch, err := n.ensurePreroutingChain(origDst)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
 | 
						|
	}
 | 
						|
	meta := svcRuleMeta(svc, origDst, dst)
 | 
						|
	rule, err := n.findRuleByMetadata(t, ch, meta)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error looking up rule: %w", err)
 | 
						|
	}
 | 
						|
	if rule != nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	rule = dnatRuleForChain(t, ch, origDst, dst, meta)
 | 
						|
	n.conn.InsertRule(rule)
 | 
						|
	return n.conn.Flush()
 | 
						|
}
 | 
						|
 | 
						|
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
 | 
						|
// We use the metadata attached to the rule to look it up.
 | 
						|
func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
 | 
						|
	table, err := n.getNFTByAddr(origDst)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err)
 | 
						|
	}
 | 
						|
	t, err := getTableIfExists(n.conn, table.Proto, "nat")
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error checking if nat table exists: %w", err)
 | 
						|
	}
 | 
						|
	if t == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	ch, err := getChainFromTable(n.conn, t, "PREROUTING")
 | 
						|
	if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error checking if chain PREROUTING exists: %w", err)
 | 
						|
	}
 | 
						|
	meta := svcRuleMeta(svcName, origDst, dst)
 | 
						|
	rule, err := n.findRuleByMetadata(t, ch, meta)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("error checking if rule exists: %w", err)
 | 
						|
	}
 | 
						|
	if rule == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	if err := n.conn.DelRule(rule); err != nil {
 | 
						|
		return fmt.Errorf("error deleting rule: %w", err)
 | 
						|
	}
 | 
						|
	return n.conn.Flush()
 | 
						|
}
 | 
						|
 | 
						|
func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
 | 
						|
	var fam uint32
 | 
						|
	if targetIP.Is4() {
 | 
						|
		fam = unix.NFPROTO_IPV4
 | 
						|
	} else {
 | 
						|
		fam = unix.NFPROTO_IPV6
 | 
						|
	}
 | 
						|
	rule := &nftables.Rule{
 | 
						|
		Table:    t,
 | 
						|
		Chain:    ch,
 | 
						|
		UserData: meta,
 | 
						|
		Exprs: []expr.Any{
 | 
						|
			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
 | 
						|
			&expr.Cmp{
 | 
						|
				Op:       expr.CmpOpNeq,
 | 
						|
				Register: 1,
 | 
						|
				Data:     []byte(tun),
 | 
						|
			},
 | 
						|
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
 | 
						|
			&expr.Cmp{
 | 
						|
				Op:       expr.CmpOpEq,
 | 
						|
				Register: 1,
 | 
						|
				Data:     []byte{proto},
 | 
						|
			},
 | 
						|
			&expr.Payload{
 | 
						|
				DestRegister: 1,
 | 
						|
				Base:         expr.PayloadBaseTransportHeader,
 | 
						|
				Offset:       2,
 | 
						|
				Len:          2,
 | 
						|
			},
 | 
						|
			&expr.Cmp{
 | 
						|
				Op:       expr.CmpOpEq,
 | 
						|
				Register: 1,
 | 
						|
				Data:     binaryutil.BigEndian.PutUint16(matchPort),
 | 
						|
			},
 | 
						|
			&expr.Immediate{
 | 
						|
				Register: 1,
 | 
						|
				Data:     targetIP.AsSlice(),
 | 
						|
			},
 | 
						|
			&expr.Immediate{
 | 
						|
				Register: 2,
 | 
						|
				Data:     binaryutil.BigEndian.PutUint16(targetPort),
 | 
						|
			},
 | 
						|
			&expr.NAT{
 | 
						|
				Type:        expr.NATTypeDestNAT,
 | 
						|
				Family:      fam,
 | 
						|
				RegAddrMin:  1,
 | 
						|
				RegAddrMax:  1,
 | 
						|
				RegProtoMin: 2,
 | 
						|
				RegProtoMax: 2,
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
	return rule
 | 
						|
}
 | 
						|
 | 
						|
// svcPortMapRuleMeta generates metadata for a rule.
 | 
						|
// This metadata can then be used to find the rule.
 | 
						|
// https://github.com/google/nftables/issues/48
 | 
						|
func svcPortMapRuleMeta(svcName string, targetIP netip.Addr, pm PortMap) []byte {
 | 
						|
	return []byte(fmt.Sprintf("svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v", svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol))
 | 
						|
}
 | 
						|
 | 
						|
func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) {
 | 
						|
	if n.conn == nil || t == nil || ch == nil || len(meta) == 0 {
 | 
						|
		return nil, nil
 | 
						|
	}
 | 
						|
	rules, err := n.conn.GetRules(t, ch)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("error listing rules: %w", err)
 | 
						|
	}
 | 
						|
	for _, rule := range rules {
 | 
						|
		if reflect.DeepEqual(rule.UserData, meta) {
 | 
						|
			return rule, nil
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil, nil
 | 
						|
}
 | 
						|
 | 
						|
func (n *nftablesRunner) ensureChainForSvc(svc string, targetIP netip.Addr) (*nftables.Table, *nftables.Chain, error) {
 | 
						|
	polAccept := nftables.ChainPolicyAccept
 | 
						|
	table, err := n.getNFTByAddr(targetIP)
 | 
						|
	if err != nil {
 | 
						|
		return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", targetIP, err)
 | 
						|
	}
 | 
						|
	nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
 | 
						|
	if err != nil {
 | 
						|
		return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
 | 
						|
	}
 | 
						|
	svcCh, err := getOrCreateChain(n.conn, chainInfo{
 | 
						|
		table:         nat,
 | 
						|
		name:          svc,
 | 
						|
		chainType:     nftables.ChainTypeNAT,
 | 
						|
		chainHook:     nftables.ChainHookPrerouting,
 | 
						|
		chainPriority: nftables.ChainPriorityNATDest,
 | 
						|
		chainPolicy:   &polAccept,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
 | 
						|
	}
 | 
						|
	return nat, svcCh, nil
 | 
						|
}
 | 
						|
 | 
						|
// PortMap is the port mapping for a service rule: traffic matched on
// MatchPort is forwarded to TargetPort.
type PortMap struct {
	// MatchPort is the local port that the rule applies to.
	MatchPort uint16
	// TargetPort is the port that matching traffic is forwarded to.
	TargetPort uint16
	// Protocol names the protocol to match packets on. Only TCP and UDP are
	// supported.
	Protocol string
}
 | 
						|
 | 
						|
func protoFromString(s string) (uint8, error) {
 | 
						|
	switch strings.ToLower(s) {
 | 
						|
	case "tcp":
 | 
						|
		return unix.IPPROTO_TCP, nil
 | 
						|
	case "udp":
 | 
						|
		return unix.IPPROTO_UDP, nil
 | 
						|
	default:
 | 
						|
		return 0, fmt.Errorf("unrecognized protocol: %q", s)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// svcRuleMeta builds the metadata that is attached to a service's DNAT rule
// and later used to find that rule again.
// See https://github.com/google/nftables/issues/48
//
// origDst is recorded under the "VIP" key and dst under the "ClusterIP" key.
func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte {
	vip, clusterIP := origDst.String(), dst.String()
	meta := fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, vip, clusterIP)
	return []byte(meta)
}
 |