mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-31 08:11:32 +01:00 
			
		
		
		
	Add new rules to update DNAT rules for Kubernetes operator's HA ingress where it's expected that rules will be added/removed frequently (so we don't want to keep old rules around or rewrite existing rules unnecessarily): - allow deleting DNAT rules using metadata lookup - allow inserting DNAT rules if they don't already exist (using metadata lookup) Updates tailscale/tailscale#15895 Signed-off-by: Irbe Krumina <irbe@tailscale.com> Co-authored-by: chaosinthecrd <tom@tmlabs.co.uk>
		
			
				
	
	
		
			310 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			310 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| //go:build linux
 | |
| 
 | |
| package linuxfw
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net/netip"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/google/nftables"
 | |
| 	"github.com/google/nftables/binaryutil"
 | |
| 	"github.com/google/nftables/expr"
 | |
| 	"golang.org/x/sys/unix"
 | |
| )
 | |
| 
 | |
| // This file contains functionality that is currently (09/2024) used to set up
 | |
| // routing for the Tailscale Kubernetes operator egress proxies. A tailnet
 | |
| // service (identified by tailnet IP or FQDN) that gets exposed to cluster
 | |
| // workloads gets a separate prerouting chain created for it for each IP family
 | |
| // of the chain's target addresses. Each service's prerouting chain contains one
 | |
| // or more portmapping rules. A portmapping rule DNATs traffic received on a
 | |
| // particular port to a port of the tailnet service. Creating a chain per
 | |
| // service makes it easier to delete a service when no longer needed and helps
 | |
| // with readability.
 | |
| 
 | |
| // EnsurePortMapRuleForSvc:
 | |
| // - ensures that nat table exists
 | |
| // - ensures that there is a prerouting chain for the given service and IP family of the target address in the nat table
 | |
| // - ensures that there is a portmapping rule mathcing the given portmap (only creates the rule if it does not already exist)
 | |
| func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
 | |
| 	t, ch, err := n.ensureChainForSvc(svc, targetIP)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
 | |
| 	}
 | |
| 	meta := svcPortMapRuleMeta(svc, targetIP, pm)
 | |
| 	rule, err := n.findRuleByMetadata(t, ch, meta)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error looking up rule: %w", err)
 | |
| 	}
 | |
| 	if rule != nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	p, err := protoFromString(pm.Protocol)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error converting protocol %s: %w", pm.Protocol, err)
 | |
| 	}
 | |
| 
 | |
| 	rule = portMapRule(t, ch, tun, targetIP, pm.MatchPort, pm.TargetPort, p, meta)
 | |
| 	n.conn.InsertRule(rule)
 | |
| 	return n.conn.Flush()
 | |
| }
 | |
| 
 | |
| // DeletePortMapRuleForSvc deletes a portmapping rule in the given service/IP family chain.
 | |
| // It finds the matching rule using metadata attached to the rule.
 | |
| // The caller is expected to call DeleteSvc if the whole service (the chain)
 | |
| // needs to be deleted, so we don't deal with the case where this is the only
 | |
| // rule in the chain here.
 | |
| func (n *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
 | |
| 	table, err := n.getNFTByAddr(targetIP)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error setting up nftables for IP family of %s: %w", targetIP, err)
 | |
| 	}
 | |
| 	t, err := getTableIfExists(n.conn, table.Proto, "nat")
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error checking if nat table exists: %w", err)
 | |
| 	}
 | |
| 	if t == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	ch, err := getChainFromTable(n.conn, t, svc)
 | |
| 	if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | |
| 		return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
 | |
| 	}
 | |
| 	if errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | |
| 		return nil // service chain does not exist, so neither does the portmapping rule
 | |
| 	}
 | |
| 	meta := svcPortMapRuleMeta(svc, targetIP, pm)
 | |
| 	rule, err := n.findRuleByMetadata(t, ch, meta)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error checking if rule exists: %w", err)
 | |
| 	}
 | |
| 	if rule == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	if err := n.conn.DelRule(rule); err != nil {
 | |
| 		return fmt.Errorf("error deleting rule: %w", err)
 | |
| 	}
 | |
| 	return n.conn.Flush()
 | |
| }
 | |
| 
 | |
| // DeleteSvc deletes the chains for the given service if any exist.
 | |
| func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error {
 | |
| 	for _, tip := range targetIPs {
 | |
| 		table, err := n.getNFTByAddr(tip)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("error setting up nftables for IP family of %s: %w", tip, err)
 | |
| 		}
 | |
| 		t, err := getTableIfExists(n.conn, table.Proto, "nat")
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("error checking if nat table exists: %w", err)
 | |
| 		}
 | |
| 		if t == nil {
 | |
| 			return nil
 | |
| 		}
 | |
| 		ch, err := getChainFromTable(n.conn, t, svc)
 | |
| 		if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | |
| 			return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
 | |
| 		}
 | |
| 		if errors.Is(err, errorChainNotFound{t.Name, svc}) {
 | |
| 			return nil
 | |
| 		}
 | |
| 		n.conn.DelChain(ch)
 | |
| 	}
 | |
| 	return n.conn.Flush()
 | |
| }
 | |
| 
 | |
| // EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
 | |
| // VIPService IP address to a local address. This is used by the Kubernetes
 | |
| // operator's network layer proxies to forward tailnet traffic for VIPServices
 | |
| // to Kubernetes Services.
 | |
| func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error {
 | |
| 	t, ch, err := n.ensurePreroutingChain(origDst)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
 | |
| 	}
 | |
| 	meta := svcRuleMeta(svc, origDst, dst)
 | |
| 	rule, err := n.findRuleByMetadata(t, ch, meta)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error looking up rule: %w", err)
 | |
| 	}
 | |
| 	if rule != nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	rule = dnatRuleForChain(t, ch, origDst, dst, meta)
 | |
| 	n.conn.InsertRule(rule)
 | |
| 	return n.conn.Flush()
 | |
| }
 | |
| 
 | |
| // DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
 | |
| // We use the metadata attached to the rule to look it up.
 | |
| func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
 | |
| 	table, err := n.getNFTByAddr(origDst)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err)
 | |
| 	}
 | |
| 	t, err := getTableIfExists(n.conn, table.Proto, "nat")
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error checking if nat table exists: %w", err)
 | |
| 	}
 | |
| 	if t == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	ch, err := getChainFromTable(n.conn, t, "PREROUTING")
 | |
| 	if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) {
 | |
| 		return nil
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error checking if chain PREROUTING exists: %w", err)
 | |
| 	}
 | |
| 	meta := svcRuleMeta(svcName, origDst, dst)
 | |
| 	rule, err := n.findRuleByMetadata(t, ch, meta)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error checking if rule exists: %w", err)
 | |
| 	}
 | |
| 	if rule == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	if err := n.conn.DelRule(rule); err != nil {
 | |
| 		return fmt.Errorf("error deleting rule: %w", err)
 | |
| 	}
 | |
| 	return n.conn.Flush()
 | |
| }
 | |
| 
 | |
| func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
 | |
| 	var fam uint32
 | |
| 	if targetIP.Is4() {
 | |
| 		fam = unix.NFPROTO_IPV4
 | |
| 	} else {
 | |
| 		fam = unix.NFPROTO_IPV6
 | |
| 	}
 | |
| 	rule := &nftables.Rule{
 | |
| 		Table:    t,
 | |
| 		Chain:    ch,
 | |
| 		UserData: meta,
 | |
| 		Exprs: []expr.Any{
 | |
| 			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
 | |
| 			&expr.Cmp{
 | |
| 				Op:       expr.CmpOpNeq,
 | |
| 				Register: 1,
 | |
| 				Data:     []byte(tun),
 | |
| 			},
 | |
| 			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
 | |
| 			&expr.Cmp{
 | |
| 				Op:       expr.CmpOpEq,
 | |
| 				Register: 1,
 | |
| 				Data:     []byte{proto},
 | |
| 			},
 | |
| 			&expr.Payload{
 | |
| 				DestRegister: 1,
 | |
| 				Base:         expr.PayloadBaseTransportHeader,
 | |
| 				Offset:       2,
 | |
| 				Len:          2,
 | |
| 			},
 | |
| 			&expr.Cmp{
 | |
| 				Op:       expr.CmpOpEq,
 | |
| 				Register: 1,
 | |
| 				Data:     binaryutil.BigEndian.PutUint16(matchPort),
 | |
| 			},
 | |
| 			&expr.Immediate{
 | |
| 				Register: 1,
 | |
| 				Data:     targetIP.AsSlice(),
 | |
| 			},
 | |
| 			&expr.Immediate{
 | |
| 				Register: 2,
 | |
| 				Data:     binaryutil.BigEndian.PutUint16(targetPort),
 | |
| 			},
 | |
| 			&expr.NAT{
 | |
| 				Type:        expr.NATTypeDestNAT,
 | |
| 				Family:      fam,
 | |
| 				RegAddrMin:  1,
 | |
| 				RegAddrMax:  1,
 | |
| 				RegProtoMin: 2,
 | |
| 				RegProtoMax: 2,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 	return rule
 | |
| }
 | |
| 
 | |
| // svcPortMapRuleMeta generates metadata for a rule.
 | |
| // This metadata can then be used to find the rule.
 | |
| // https://github.com/google/nftables/issues/48
 | |
| func svcPortMapRuleMeta(svcName string, targetIP netip.Addr, pm PortMap) []byte {
 | |
| 	return []byte(fmt.Sprintf("svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v", svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol))
 | |
| }
 | |
| 
 | |
| func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) {
 | |
| 	if n.conn == nil || t == nil || ch == nil || len(meta) == 0 {
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 	rules, err := n.conn.GetRules(t, ch)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("error listing rules: %w", err)
 | |
| 	}
 | |
| 	for _, rule := range rules {
 | |
| 		if reflect.DeepEqual(rule.UserData, meta) {
 | |
| 			return rule, nil
 | |
| 		}
 | |
| 	}
 | |
| 	return nil, nil
 | |
| }
 | |
| 
 | |
| func (n *nftablesRunner) ensureChainForSvc(svc string, targetIP netip.Addr) (*nftables.Table, *nftables.Chain, error) {
 | |
| 	polAccept := nftables.ChainPolicyAccept
 | |
| 	table, err := n.getNFTByAddr(targetIP)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", targetIP, err)
 | |
| 	}
 | |
| 	nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
 | |
| 	if err != nil {
 | |
| 		return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
 | |
| 	}
 | |
| 	svcCh, err := getOrCreateChain(n.conn, chainInfo{
 | |
| 		table:         nat,
 | |
| 		name:          svc,
 | |
| 		chainType:     nftables.ChainTypeNAT,
 | |
| 		chainHook:     nftables.ChainHookPrerouting,
 | |
| 		chainPriority: nftables.ChainPriorityNATDest,
 | |
| 		chainPolicy:   &polAccept,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
 | |
| 	}
 | |
| 	return nat, svcCh, nil
 | |
| }
 | |
| 
 | |
| // // PortMap is the port mapping for a service rule.
 | |
| type PortMap struct {
 | |
| 	// MatchPort is the local port to which the rule should apply.
 | |
| 	MatchPort uint16
 | |
| 	// TargetPort is the port to which the traffic should be forwarded.
 | |
| 	TargetPort uint16
 | |
| 	// Protocol is the protocol to match packets on. Only TCP and UDP are
 | |
| 	// supported.
 | |
| 	Protocol string
 | |
| }
 | |
| 
 | |
| func protoFromString(s string) (uint8, error) {
 | |
| 	switch strings.ToLower(s) {
 | |
| 	case "tcp":
 | |
| 		return unix.IPPROTO_TCP, nil
 | |
| 	case "udp":
 | |
| 		return unix.IPPROTO_UDP, nil
 | |
| 	default:
 | |
| 		return 0, fmt.Errorf("unrecognized protocol: %q", s)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // svcRuleMeta generates metadata for a rule.
 | |
| // This metadata can then be used to find the rule.
 | |
| // https://github.com/google/nftables/issues/48
 | |
| func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte {
 | |
| 	return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String()))
 | |
| }
 |