mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 16:11:03 +01:00 
			
		
		
		
	Make matchers part of the Policy interface (#2514)
* Make matchers part of the Policy interface * Prevent race condition between rules and matchers * Test also matchers in tests for Policy.Filter * Compute `filterChanged` in v2 policy correctly * Fix nil vs. empty list issue in v2 policy test * policy/v2: always clear ssh map Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> Co-authored-by: Aras Ergus <aras.ergus@tngtech.com> Co-authored-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									eb1ecefd9e
								
							
						
					
					
						commit
						4651d06fa8
					
				| @ -40,7 +40,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { | ||||
| 		w.Write(pol) | ||||
| 	})) | ||||
| 	debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		filter := h.polMan.Filter() | ||||
| 		filter, _ := h.polMan.Filter() | ||||
| 
 | ||||
| 		filterJSON, err := json.MarshalIndent(filter, "", "  ") | ||||
| 		if err != nil { | ||||
|  | ||||
| @ -536,7 +536,7 @@ func appendPeerChanges( | ||||
| 	changed types.Nodes, | ||||
| 	cfg *types.Config, | ||||
| ) error { | ||||
| 	filter := polMan.Filter() | ||||
| 	filter, matchers := polMan.Filter() | ||||
| 
 | ||||
| 	sshPolicy, err := polMan.SSHPolicy(node) | ||||
| 	if err != nil { | ||||
| @ -546,7 +546,7 @@ func appendPeerChanges( | ||||
| 	// If there are filter rules present, see if there are any nodes that cannot | ||||
| 	// access each-other at all and remove them from the peers. | ||||
| 	if len(filter) > 0 { | ||||
| 		changed = policy.FilterNodesByACL(node, changed, filter) | ||||
| 		changed = policy.FilterNodesByACL(node, changed, matchers) | ||||
| 	} | ||||
| 
 | ||||
| 	profiles := generateUserProfiles(node, changed) | ||||
|  | ||||
| @ -13,6 +13,14 @@ type Match struct { | ||||
| 	dests *netipx.IPSet | ||||
| } | ||||
| 
 | ||||
| func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { | ||||
| 	matches := make([]Match, 0, len(rules)) | ||||
| 	for _, rule := range rules { | ||||
| 		matches = append(matches, MatchFromFilterRule(rule)) | ||||
| 	} | ||||
| 	return matches | ||||
| } | ||||
| 
 | ||||
| func MatchFromFilterRule(rule tailcfg.FilterRule) Match { | ||||
| 	dests := []string{} | ||||
| 	for _, dest := range rule.DstPorts { | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package policy | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" | ||||
| @ -15,7 +16,8 @@ var ( | ||||
| ) | ||||
| 
 | ||||
| type PolicyManager interface { | ||||
| 	Filter() []tailcfg.FilterRule | ||||
| 	// Filter returns the current filter rules for the entire tailnet and the associated matchers. | ||||
| 	Filter() ([]tailcfg.FilterRule, []matcher.Match) | ||||
| 	SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) | ||||
| 	SetPolicy([]byte) (bool, error) | ||||
| 	SetUsers(users []types.User) (bool, error) | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package policy | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"net/netip" | ||||
| 	"slices" | ||||
| 
 | ||||
| @ -15,7 +16,7 @@ import ( | ||||
| func FilterNodesByACL( | ||||
| 	node *types.Node, | ||||
| 	nodes types.Nodes, | ||||
| 	filter []tailcfg.FilterRule, | ||||
| 	matchers []matcher.Match, | ||||
| ) types.Nodes { | ||||
| 	var result types.Nodes | ||||
| 
 | ||||
| @ -24,7 +25,7 @@ func FilterNodesByACL( | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { | ||||
| 		if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) { | ||||
| 			result = append(result, peer) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -2,6 +2,7 @@ package policy | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 
 | ||||
| @ -769,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				var err error | ||||
| 				pm, err = pmf(users, append(tt.peers, tt.node)) | ||||
| 				require.NoError(t, err) | ||||
| 				got := pm.Filter() | ||||
| 				got, _ := pm.Filter() | ||||
| 				got = ReduceFilterRules(tt.node, got) | ||||
| 
 | ||||
| 				if diff := cmp.Diff(tt.want, got); diff != "" { | ||||
| @ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) { | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			matchers := matcher.MatchesFromFilterRules(tt.args.rules) | ||||
| 			got := FilterNodesByACL( | ||||
| 				tt.args.node, | ||||
| 				tt.args.nodes, | ||||
| 				tt.args.rules, | ||||
| 				matchers, | ||||
| 			) | ||||
| 			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { | ||||
| 				t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) | ||||
|  | ||||
| @ -2,6 +2,7 @@ package v1 | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"io" | ||||
| 	"net/netip" | ||||
| 	"os" | ||||
| @ -88,10 +89,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) { | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (pm *PolicyManager) Filter() []tailcfg.FilterRule { | ||||
| func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { | ||||
| 	pm.mu.Lock() | ||||
| 	defer pm.mu.Unlock() | ||||
| 	return pm.filter | ||||
| 	return pm.filter, matcher.MatchesFromFilterRules(pm.filter) | ||||
| } | ||||
| 
 | ||||
| func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package v1 | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| @ -27,6 +28,7 @@ func TestPolicySetChange(t *testing.T) { | ||||
| 		wantNodesChange  bool | ||||
| 		wantPolicyChange bool | ||||
| 		wantFilter       []tailcfg.FilterRule | ||||
| 		wantMatchers     []matcher.Match | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "set-nodes", | ||||
| @ -42,6 +44,9 @@ func TestPolicySetChange(t *testing.T) { | ||||
| 					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantMatchers: []matcher.Match{ | ||||
| 				matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}), | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:            "set-users", | ||||
| @ -52,6 +57,9 @@ func TestPolicySetChange(t *testing.T) { | ||||
| 					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantMatchers: []matcher.Match{ | ||||
| 				matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}), | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "set-users-and-node", | ||||
| @ -70,6 +78,9 @@ func TestPolicySetChange(t *testing.T) { | ||||
| 					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantMatchers: []matcher.Match{ | ||||
| 				matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}), | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "set-policy", | ||||
| @ -95,6 +106,9 @@ func TestPolicySetChange(t *testing.T) { | ||||
| 					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantMatchers: []matcher.Match{ | ||||
| 				matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}), | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| @ -150,8 +164,16 @@ func TestPolicySetChange(t *testing.T) { | ||||
| 				assert.Equal(t, tt.wantNodesChange, change) | ||||
| 			} | ||||
| 
 | ||||
| 			if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { | ||||
| 				t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) | ||||
| 			filter, matchers := pm.Filter() | ||||
| 			if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { | ||||
| 				t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 			if diff := cmp.Diff( | ||||
| 				tt.wantMatchers, | ||||
| 				matchers, | ||||
| 				cmp.AllowUnexported(matcher.Match{}), | ||||
| 			); diff != "" { | ||||
| 				t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| @ -7,6 +7,8 @@ import ( | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 
 | ||||
| 	"slices" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| @ -24,6 +26,7 @@ type PolicyManager struct { | ||||
| 
 | ||||
| 	filterHash deephash.Sum | ||||
| 	filter     []tailcfg.FilterRule | ||||
| 	matchers   []matcher.Match | ||||
| 
 | ||||
| 	tagOwnerMapHash deephash.Sum | ||||
| 	tagOwnerMap     map[Tag]*netipx.IPSet | ||||
| @ -62,15 +65,24 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM | ||||
| // updateLocked updates the filter rules based on the current policy and nodes. | ||||
| // It must be called with the lock held. | ||||
| func (pm *PolicyManager) updateLocked() (bool, error) { | ||||
| 	// Clear the SSH policy map to ensure it's recalculated with the new policy. | ||||
| 	// TODO(kradalby): This could potentially be optimized by only clearing the | ||||
| 	// policies for nodes that have changed. Particularly if the only difference is | ||||
| 	// that nodes has been added or removed. | ||||
| 	defer clear(pm.sshPolicyMap) | ||||
| 
 | ||||
| 	filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes) | ||||
| 	if err != nil { | ||||
| 		return false, fmt.Errorf("compiling filter rules: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	filterHash := deephash.Hash(&filter) | ||||
| 	filterChanged := filterHash == pm.filterHash | ||||
| 	filterChanged := filterHash != pm.filterHash | ||||
| 	pm.filter = filter | ||||
| 	pm.filterHash = filterHash | ||||
| 	if filterChanged { | ||||
| 		pm.matchers = matcher.MatchesFromFilterRules(pm.filter) | ||||
| 	} | ||||
| 
 | ||||
| 	// Order matters, tags might be used in autoapprovers, so we need to ensure | ||||
| 	// that the map for tag owners is resolved before resolving autoapprovers. | ||||
| @ -100,12 +112,6 @@ func (pm *PolicyManager) updateLocked() (bool, error) { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Clear the SSH policy map to ensure it's recalculated with the new policy. | ||||
| 	// TODO(kradalby): This could potentially be optimized by only clearing the | ||||
| 	// policies for nodes that have changed. Particularly if the only difference is | ||||
| 	// that nodes has been added or removed. | ||||
| 	clear(pm.sshPolicyMap) | ||||
| 
 | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| @ -144,11 +150,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { | ||||
| 	return pm.updateLocked() | ||||
| } | ||||
| 
 | ||||
| // Filter returns the current filter rules for the entire tailnet. | ||||
| func (pm *PolicyManager) Filter() []tailcfg.FilterRule { | ||||
| // Filter returns the current filter rules for the entire tailnet and the associated matchers. | ||||
| func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { | ||||
| 	pm.mu.Lock() | ||||
| 	defer pm.mu.Unlock() | ||||
| 	return pm.filter | ||||
| 	return pm.filter, pm.matchers | ||||
| } | ||||
| 
 | ||||
| // SetUsers updates the users in the policy manager and updates the filter rules. | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package v2 | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| @ -29,16 +30,18 @@ func TestPolicyManager(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		name       string | ||||
| 		pol        string | ||||
| 		nodes      types.Nodes | ||||
| 		wantFilter []tailcfg.FilterRule | ||||
| 		name         string | ||||
| 		pol          string | ||||
| 		nodes        types.Nodes | ||||
| 		wantFilter   []tailcfg.FilterRule | ||||
| 		wantMatchers []matcher.Match | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:       "empty-policy", | ||||
| 			pol:        "{}", | ||||
| 			nodes:      types.Nodes{}, | ||||
| 			wantFilter: nil, | ||||
| 			name:         "empty-policy", | ||||
| 			pol:          "{}", | ||||
| 			nodes:        types.Nodes{}, | ||||
| 			wantFilter:   nil, | ||||
| 			wantMatchers: []matcher.Match{}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| @ -47,9 +50,16 @@ func TestPolicyManager(t *testing.T) { | ||||
| 			pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) | ||||
| 			require.NoError(t, err) | ||||
| 
 | ||||
| 			filter := pm.Filter() | ||||
| 			if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { | ||||
| 				t.Errorf("Filter() mismatch (-want +got):\n%s", diff) | ||||
| 			filter, matchers := pm.Filter() | ||||
| 			if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { | ||||
| 				t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 			if diff := cmp.Diff( | ||||
| 				tt.wantMatchers, | ||||
| 				matchers, | ||||
| 				cmp.AllowUnexported(matcher.Match{}), | ||||
| 			); diff != "" { | ||||
| 				t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 
 | ||||
| 			// TODO(kradalby): Test SSH Policy | ||||
|  | ||||
| @ -270,18 +270,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { | ||||
| func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { | ||||
| 	src := node.IPs() | ||||
| 	allowedIPs := node2.IPs() | ||||
| 
 | ||||
| 	// TODO(kradalby): Regenerate this every time the filter change, instead of | ||||
| 	// every time we use it. | ||||
| 	// Part of #2416 | ||||
| 	matchers := make([]matcher.Match, len(filter)) | ||||
| 	for i, rule := range filter { | ||||
| 		matchers[i] = matcher.MatchFromFilterRule(rule) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, matcher := range matchers { | ||||
| 		if !matcher.SrcsContainsIPs(src...) { | ||||
| 			continue | ||||
|  | ||||
| @ -2,6 +2,7 @@ package types | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| @ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) { | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got := tt.node1.CanAccess(tt.rules, &tt.node2) | ||||
| 			matchers := matcher.MatchesFromFilterRules(tt.rules) | ||||
| 			got := tt.node1.CanAccess(matchers, &tt.node2) | ||||
| 
 | ||||
| 			if got != tt.want { | ||||
| 				t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user