mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-30 23:51:03 +01:00 
			
		
		
		
	Resolve user to stable unique ID in policy (#2205)
This commit is contained in:
		
							parent
							
								
									3a2589f1a9
								
							
						
					
					
						commit
						fffd23602b
					
				| @ -1029,14 +1029,18 @@ func (h *Headscale) loadACLPolicy() error { | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("loading nodes from database to validate policy: %w", err) | ||||
| 		} | ||||
| 		users, err := h.db.ListUsers() | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("loading users from database to validate policy: %w", err) | ||||
| 		} | ||||
| 
 | ||||
| 		_, err = pol.CompileFilterRules(nodes) | ||||
| 		_, err = pol.CompileFilterRules(users, nodes) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("verifying policy rules: %w", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if len(nodes) > 0 { | ||||
| 			_, err = pol.CompileSSHPolicy(nodes[0], nodes) | ||||
| 			_, err = pol.CompileSSHPolicy(nodes[0], users, nodes) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("verifying SSH rules: %w", err) | ||||
| 			} | ||||
|  | ||||
| @ -256,10 +256,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(testPeers), check.Equals, 9) | ||||
| 
 | ||||
| 	adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) | ||||
| 	adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) | ||||
| 	testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) | ||||
|  | ||||
| @ -648,8 +648,13 @@ func EnableAutoApprovedRoutes( | ||||
| 			if approvedAlias == node.User.Username() { | ||||
| 				approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||
| 			} else { | ||||
| 				users, err := ListUsers(tx) | ||||
| 				if err != nil { | ||||
| 					return fmt.Errorf("looking up users to expand route alias: %w", err) | ||||
| 				} | ||||
| 
 | ||||
| 				// TODO(kradalby): figure out how to get this to depend on less stuff | ||||
| 				approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) | ||||
| 				approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) | ||||
| 				if err != nil { | ||||
| 					return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) | ||||
| 				} | ||||
|  | ||||
| @ -773,14 +773,18 @@ func (api headscaleV1APIServer) SetPolicy( | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) | ||||
| 	} | ||||
| 	users, err := api.h.db.ListUsers() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("loading users from database to validate policy: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = pol.CompileFilterRules(nodes) | ||||
| 	_, err = pol.CompileFilterRules(users, nodes) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("verifying policy rules: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(nodes) > 0 { | ||||
| 		_, err = pol.CompileSSHPolicy(nodes[0], nodes) | ||||
| 		_, err = pol.CompileSSHPolicy(nodes[0], users, nodes) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("verifying SSH rules: %w", err) | ||||
| 		} | ||||
|  | ||||
| @ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { | ||||
| func (m *Mapper) fullMapResponse( | ||||
| 	node *types.Node, | ||||
| 	peers types.Nodes, | ||||
| 	users []types.User, | ||||
| 	pol *policy.ACLPolicy, | ||||
| 	capVer tailcfg.CapabilityVersion, | ||||
| ) (*tailcfg.MapResponse, error) { | ||||
| @ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse( | ||||
| 		pol, | ||||
| 		node, | ||||
| 		capVer, | ||||
| 		users, | ||||
| 		peers, | ||||
| 		peers, | ||||
| 		m.cfg, | ||||
| @ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse( | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	users, err := m.db.ListUsers() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) | ||||
| 	resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	users, err := m.db.ListUsers() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("listing users for map response: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var removedIDs []tailcfg.NodeID | ||||
| 	var changedIDs []types.NodeID | ||||
| 	for nodeID, nodeChanged := range changed { | ||||
| @ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse( | ||||
| 		pol, | ||||
| 		node, | ||||
| 		mapRequest.Version, | ||||
| 		users, | ||||
| 		peers, | ||||
| 		changedNodes, | ||||
| 		m.cfg, | ||||
| @ -508,16 +520,17 @@ func appendPeerChanges( | ||||
| 	pol *policy.ACLPolicy, | ||||
| 	node *types.Node, | ||||
| 	capVer tailcfg.CapabilityVersion, | ||||
| 	users []types.User, | ||||
| 	peers types.Nodes, | ||||
| 	changed types.Nodes, | ||||
| 	cfg *types.Config, | ||||
| ) error { | ||||
| 	packetFilter, err := pol.CompileFilterRules(append(peers, node)) | ||||
| 	packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	sshPolicy, err := pol.CompileSSHPolicy(node, peers) | ||||
| 	sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 	lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) | ||||
| 	expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) | ||||
| 
 | ||||
| 	user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} | ||||
| 	user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} | ||||
| 
 | ||||
| 	mini := &types.Node{ | ||||
| 		ID: 0, | ||||
| 		MachineKey: mustMK( | ||||
| @ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		IPv4:       iap("100.64.0.1"), | ||||
| 		Hostname:   "mini", | ||||
| 		GivenName:  "mini", | ||||
| 		UserID:     0, | ||||
| 		User:       types.User{Name: "mini"}, | ||||
| 		UserID:     user1.ID, | ||||
| 		User:       user1, | ||||
| 		ForcedTags: []string{}, | ||||
| 		AuthKey:    &types.PreAuthKey{}, | ||||
| 		LastSeen:   &lastSeen, | ||||
| @ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		IPv4:       iap("100.64.0.2"), | ||||
| 		Hostname:   "peer1", | ||||
| 		GivenName:  "peer1", | ||||
| 		UserID:     0, | ||||
| 		User:       types.User{Name: "mini"}, | ||||
| 		UserID:     user1.ID, | ||||
| 		User:       user1, | ||||
| 		ForcedTags: []string{}, | ||||
| 		LastSeen:   &lastSeen, | ||||
| 		Expiry:     &expire, | ||||
| @ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		IPv4:       iap("100.64.0.3"), | ||||
| 		Hostname:   "peer2", | ||||
| 		GivenName:  "peer2", | ||||
| 		UserID:     1, | ||||
| 		User:       types.User{Name: "peer2"}, | ||||
| 		UserID:     user2.ID, | ||||
| 		User:       user2, | ||||
| 		ForcedTags: []string{}, | ||||
| 		LastSeen:   &lastSeen, | ||||
| 		Expiry:     &expire, | ||||
| @ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 			got, err := mappy.fullMapResponse( | ||||
| 				tt.node, | ||||
| 				tt.peers, | ||||
| 				[]types.User{user1, user2}, | ||||
| 				tt.pol, | ||||
| 				0, | ||||
| 			) | ||||
|  | ||||
| @ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( | ||||
| 	policy *ACLPolicy, | ||||
| 	node *types.Node, | ||||
| 	peers types.Nodes, | ||||
| 	users []types.User, | ||||
| ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { | ||||
| 	// If there is no policy defined, we default to allow all | ||||
| 	if policy == nil { | ||||
| 		return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	rules, err := policy.CompileFilterRules(append(peers, node)) | ||||
| 	rules, err := policy.CompileFilterRules(users, append(peers, node)) | ||||
| 	if err != nil { | ||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") | ||||
| 
 | ||||
| 	sshPolicy, err := policy.CompileSSHPolicy(node, peers) | ||||
| 	sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) | ||||
| 	if err != nil { | ||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||
| 	} | ||||
| @ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( | ||||
| // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a | ||||
| // set of Tailscale compatible FilterRules used to allow traffic on clients. | ||||
| func (pol *ACLPolicy) CompileFilterRules( | ||||
| 	users []types.User, | ||||
| 	nodes types.Nodes, | ||||
| ) ([]tailcfg.FilterRule, error) { | ||||
| 	if pol == nil { | ||||
| @ -176,7 +178,7 @@ func (pol *ACLPolicy) CompileFilterRules( | ||||
| 
 | ||||
| 		var srcIPs []string | ||||
| 		for srcIndex, src := range acl.Sources { | ||||
| 			srcs, err := pol.expandSource(src, nodes) | ||||
| 			srcs, err := pol.expandSource(src, users, nodes) | ||||
| 			if err != nil { | ||||
| 				return nil, fmt.Errorf( | ||||
| 					"parsing policy, acl index: %d->%d: %w", | ||||
| @ -202,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( | ||||
| 
 | ||||
| 			expanded, err := pol.ExpandAlias( | ||||
| 				nodes, | ||||
| 				users, | ||||
| 				alias, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| @ -286,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F | ||||
| 
 | ||||
| func (pol *ACLPolicy) CompileSSHPolicy( | ||||
| 	node *types.Node, | ||||
| 	users []types.User, | ||||
| 	peers types.Nodes, | ||||
| ) (*tailcfg.SSHPolicy, error) { | ||||
| 	if pol == nil { | ||||
| @ -317,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( | ||||
| 	for index, sshACL := range pol.SSHs { | ||||
| 		var dest netipx.IPSetBuilder | ||||
| 		for _, src := range sshACL.Destinations { | ||||
| 			expanded, err := pol.ExpandAlias(append(peers, node), src) | ||||
| 			expanded, err := pol.ExpandAlias(append(peers, node), users, src) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| @ -377,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( | ||||
| 			} else { | ||||
| 				expandedSrcs, err := pol.ExpandAlias( | ||||
| 					peers, | ||||
| 					users, | ||||
| 					rawSrc, | ||||
| 				) | ||||
| 				if err != nil { | ||||
| @ -526,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { | ||||
| // with the given src alias. | ||||
| func (pol *ACLPolicy) expandSource( | ||||
| 	src string, | ||||
| 	users []types.User, | ||||
| 	nodes types.Nodes, | ||||
| ) ([]string, error) { | ||||
| 	ipSet, err := pol.ExpandAlias(nodes, src) | ||||
| 	ipSet, err := pol.ExpandAlias(nodes, users, src) | ||||
| 	if err != nil { | ||||
| 		return []string{}, err | ||||
| 	} | ||||
| @ -552,6 +558,7 @@ func (pol *ACLPolicy) expandSource( | ||||
| // and transform these in IPAddresses. | ||||
| func (pol *ACLPolicy) ExpandAlias( | ||||
| 	nodes types.Nodes, | ||||
| 	users []types.User, | ||||
| 	alias string, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	if isWildcard(alias) { | ||||
| @ -566,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( | ||||
| 
 | ||||
| 	// if alias is a group | ||||
| 	if isGroup(alias) { | ||||
| 		return pol.expandIPsFromGroup(alias, nodes) | ||||
| 		return pol.expandIPsFromGroup(alias, users, nodes) | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is a tag | ||||
| 	if isTag(alias) { | ||||
| 		return pol.expandIPsFromTag(alias, nodes) | ||||
| 		return pol.expandIPsFromTag(alias, users, nodes) | ||||
| 	} | ||||
| 
 | ||||
| 	if isAutoGroup(alias) { | ||||
| @ -579,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is a user | ||||
| 	if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { | ||||
| 	if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { | ||||
| 		return ips, err | ||||
| 	} | ||||
| 
 | ||||
| @ -588,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( | ||||
| 	if h, ok := pol.Hosts[alias]; ok { | ||||
| 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") | ||||
| 
 | ||||
| 		return pol.ExpandAlias(nodes, h.String()) | ||||
| 		return pol.ExpandAlias(nodes, users, h.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is an IP | ||||
| @ -765,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( | ||||
| 
 | ||||
| func (pol *ACLPolicy) expandIPsFromGroup( | ||||
| 	group string, | ||||
| 	users []types.User, | ||||
| 	nodes types.Nodes, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	var build netipx.IPSetBuilder | ||||
| 
 | ||||
| 	users, err := pol.expandUsersFromGroup(group) | ||||
| 	userTokens, err := pol.expandUsersFromGroup(group) | ||||
| 	if err != nil { | ||||
| 		return &netipx.IPSet{}, err | ||||
| 	} | ||||
| 	for _, user := range users { | ||||
| 		filteredNodes := filterNodesByUser(nodes, user) | ||||
| 	for _, user := range userTokens { | ||||
| 		filteredNodes := filterNodesByUser(nodes, users, user) | ||||
| 		for _, node := range filteredNodes { | ||||
| 			node.AppendToIPSet(&build) | ||||
| 		} | ||||
| @ -785,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( | ||||
| 
 | ||||
| func (pol *ACLPolicy) expandIPsFromTag( | ||||
| 	alias string, | ||||
| 	users []types.User, | ||||
| 	nodes types.Nodes, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	var build netipx.IPSetBuilder | ||||
| @ -817,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag( | ||||
| 
 | ||||
| 	// filter out nodes per tag owner | ||||
| 	for _, user := range owners { | ||||
| 		nodes := filterNodesByUser(nodes, user) | ||||
| 		nodes := filterNodesByUser(nodes, users, user) | ||||
| 		for _, node := range nodes { | ||||
| 			if node.Hostinfo == nil { | ||||
| 				continue | ||||
| @ -834,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag( | ||||
| 
 | ||||
| func (pol *ACLPolicy) expandIPsFromUser( | ||||
| 	user string, | ||||
| 	users []types.User, | ||||
| 	nodes types.Nodes, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	var build netipx.IPSetBuilder | ||||
| 
 | ||||
| 	filteredNodes := filterNodesByUser(nodes, user) | ||||
| 	filteredNodes := filterNodesByUser(nodes, users, user) | ||||
| 	filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) | ||||
| 
 | ||||
| 	// shortcurcuit if we have no nodes to get ips from. | ||||
| @ -967,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode( | ||||
| 	return validTags, invalidTags | ||||
| } | ||||
| 
 | ||||
| func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { | ||||
| // filterNodesByUser returns a list of nodes that match the given userToken from a | ||||
| // policy. | ||||
| // Matching nodes are determined by first matching the user token to a user by checking: | ||||
| // - If it is an ID that mactches the user database ID | ||||
| // - It is the Provider Identifier from OIDC | ||||
| // - It matches the username or email of a user | ||||
| // | ||||
| // If the token matches more than one user, zero nodes will returned. | ||||
| func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { | ||||
| 	var out types.Nodes | ||||
| 
 | ||||
| 	var potentialUsers []types.User | ||||
| 	for _, user := range users { | ||||
| 		if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken { | ||||
| 			// If a user is matching with a known unique field, | ||||
| 			// disgard all other users and only keep the current | ||||
| 			// user. | ||||
| 			potentialUsers = []types.User{user} | ||||
| 
 | ||||
| 			break | ||||
| 		} | ||||
| 		if user.Email == userToken { | ||||
| 			potentialUsers = append(potentialUsers, user) | ||||
| 		} | ||||
| 		if user.Name == userToken { | ||||
| 			potentialUsers = append(potentialUsers, user) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len(potentialUsers) != 1 { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	user := potentialUsers[0] | ||||
| 
 | ||||
| 	for _, node := range nodes { | ||||
| 		if node.User.Username() == user { | ||||
| 		if node.User.ID == user.ID { | ||||
| 			out = append(out, node) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -1,9 +1,12 @@ | ||||
| package policy | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"math/rand/v2" | ||||
| 	"net/netip" | ||||
| 	"slices" | ||||
| 	"sort" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| @ -14,6 +17,7 @@ import ( | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"go4.org/netipx" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"gorm.io/gorm" | ||||
| 	"tailscale.com/net/tsaddr" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| @ -375,18 +379,24 @@ func TestParsing(t *testing.T) { | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			rules, err := pol.CompileFilterRules(types.Nodes{ | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.100.100.100"), | ||||
| 			user := types.User{ | ||||
| 				Model: gorm.Model{ID: 1}, | ||||
| 				Name:  "testuser", | ||||
| 			} | ||||
| 			rules, err := pol.CompileFilterRules( | ||||
| 				[]types.User{ | ||||
| 					user, | ||||
| 				}, | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("200.200.200.200"), | ||||
| 					User: types.User{ | ||||
| 						Name: "testuser", | ||||
| 				types.Nodes{ | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.100.100.100"), | ||||
| 					}, | ||||
| 					Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 				}, | ||||
| 			}) | ||||
| 					&types.Node{ | ||||
| 						IPv4:     iap("200.200.200.200"), | ||||
| 						User:     user, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}) | ||||
| 
 | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) | ||||
| @ -533,7 +543,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { | ||||
| 	c.Assert(pol.ACLs, check.HasLen, 6) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, err := pol.CompileFilterRules(types.Nodes{}) | ||||
| 	rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	c.Assert(rules, check.IsNil) | ||||
| } | ||||
| @ -549,7 +559,12 @@ func (s *Suite) TestInvalidAction(c *check.C) { | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) | ||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests( | ||||
| 		pol, | ||||
| 		&types.Node{}, | ||||
| 		types.Nodes{}, | ||||
| 		[]types.User{}, | ||||
| 	) | ||||
| 	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| @ -568,7 +583,12 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) | ||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests( | ||||
| 		pol, | ||||
| 		&types.Node{}, | ||||
| 		types.Nodes{}, | ||||
| 		[]types.User{}, | ||||
| 	) | ||||
| 	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| @ -584,7 +604,12 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) | ||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests( | ||||
| 		pol, | ||||
| 		&types.Node{}, | ||||
| 		types.Nodes{}, | ||||
| 		[]types.User{}, | ||||
| 	) | ||||
| 	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| @ -860,7 +885,25 @@ func Test_expandPorts(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Test_listNodesInUser(t *testing.T) { | ||||
| func Test_filterNodesByUser(t *testing.T) { | ||||
| 	users := []types.User{ | ||||
| 		{Model: gorm.Model{ID: 1}, Name: "marc"}, | ||||
| 		{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, | ||||
| 		{ | ||||
| 			Model:              gorm.Model{ID: 3}, | ||||
| 			Name:               "mikael", | ||||
| 			Email:              "mikael@headscale.net", | ||||
| 			ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true}, | ||||
| 		}, | ||||
| 		{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, | ||||
| 		{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, | ||||
| 		{Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"}, | ||||
| 		{Model: gorm.Model{ID: 7}, Name: "1"}, | ||||
| 		{Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"}, | ||||
| 		{Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"}, | ||||
| 		{Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"}, | ||||
| 	} | ||||
| 
 | ||||
| 	type args struct { | ||||
| 		nodes types.Nodes | ||||
| 		user  string | ||||
| @ -874,50 +917,258 @@ func Test_listNodesInUser(t *testing.T) { | ||||
| 			name: "1 node in user", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{User: types.User{Name: "joe"}}, | ||||
| 					&types.Node{User: users[1]}, | ||||
| 				}, | ||||
| 				user: "joe", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{User: types.User{Name: "joe"}}, | ||||
| 				&types.Node{User: users[1]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "3 nodes, 2 in user", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: types.User{Name: "joe"}}, | ||||
| 					&types.Node{ID: 2, User: types.User{Name: "marc"}}, | ||||
| 					&types.Node{ID: 3, User: types.User{Name: "marc"}}, | ||||
| 					&types.Node{ID: 1, User: users[1]}, | ||||
| 					&types.Node{ID: 2, User: users[0]}, | ||||
| 					&types.Node{ID: 3, User: users[0]}, | ||||
| 				}, | ||||
| 				user: "marc", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{ID: 2, User: types.User{Name: "marc"}}, | ||||
| 				&types.Node{ID: 3, User: types.User{Name: "marc"}}, | ||||
| 				&types.Node{ID: 2, User: users[0]}, | ||||
| 				&types.Node{ID: 3, User: users[0]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "5 nodes, 0 in user", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: types.User{Name: "joe"}}, | ||||
| 					&types.Node{ID: 2, User: types.User{Name: "marc"}}, | ||||
| 					&types.Node{ID: 3, User: types.User{Name: "marc"}}, | ||||
| 					&types.Node{ID: 4, User: types.User{Name: "marc"}}, | ||||
| 					&types.Node{ID: 5, User: types.User{Name: "marc"}}, | ||||
| 					&types.Node{ID: 1, User: users[1]}, | ||||
| 					&types.Node{ID: 2, User: users[0]}, | ||||
| 					&types.Node{ID: 3, User: users[0]}, | ||||
| 					&types.Node{ID: 4, User: users[0]}, | ||||
| 					&types.Node{ID: 5, User: users[0]}, | ||||
| 				}, | ||||
| 				user: "mickael", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "match-by-provider-ident", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[1]}, | ||||
| 					&types.Node{ID: 2, User: users[2]}, | ||||
| 				}, | ||||
| 				user: "http://oidc.org/1234", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{ID: 2, User: users[2]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "match-by-email", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[1]}, | ||||
| 					&types.Node{ID: 2, User: users[2]}, | ||||
| 					&types.Node{ID: 8, User: users[7]}, | ||||
| 				}, | ||||
| 				user: "joe@headscale.net", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{ID: 1, User: users[1]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "multi-match-is-zero", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[1]}, | ||||
| 					&types.Node{ID: 2, User: users[2]}, | ||||
| 					&types.Node{ID: 3, User: users[3]}, | ||||
| 				}, | ||||
| 				user: "mikael@headscale.net", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "multi-email-first-match-is-zero", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					// First match email, then provider id | ||||
| 					&types.Node{ID: 3, User: users[3]}, | ||||
| 					&types.Node{ID: 2, User: users[2]}, | ||||
| 				}, | ||||
| 				user: "mikael@headscale.net", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "multi-username-first-match-is-zero", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					// First match username, then provider id | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 2, User: users[2]}, | ||||
| 				}, | ||||
| 				user: "mikael", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-duplicate-username-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 				}, | ||||
| 				user: "mikael", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-unique-username-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 				}, | ||||
| 				user: "marc", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{ID: 1, User: users[0]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-no-username-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 				}, | ||||
| 				user: "not-working", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-duplicate-email-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 				}, | ||||
| 				user: "mikael@headscale.net", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-duplicate-email-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 					&types.Node{ID: 8, User: users[7]}, | ||||
| 				}, | ||||
| 				user: "joe@headscale.net", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{ID: 2, User: users[1]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "email-as-username-duplicate", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[7]}, | ||||
| 					&types.Node{ID: 2, User: users[8]}, | ||||
| 				}, | ||||
| 				user: "alex@headscale.net", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-no-email-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 				}, | ||||
| 				user: "not-working@headscale.net", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-provider-id-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 					&types.Node{ID: 6, User: users[5]}, | ||||
| 				}, | ||||
| 				user: "http://oidc.org/1234", | ||||
| 			}, | ||||
| 			want: types.Nodes{ | ||||
| 				&types.Node{ID: 3, User: users[2]}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "all-users-no-provider-id-random-order", | ||||
| 			args: args{ | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ID: 1, User: users[0]}, | ||||
| 					&types.Node{ID: 2, User: users[1]}, | ||||
| 					&types.Node{ID: 3, User: users[2]}, | ||||
| 					&types.Node{ID: 4, User: users[3]}, | ||||
| 					&types.Node{ID: 5, User: users[4]}, | ||||
| 					&types.Node{ID: 6, User: users[5]}, | ||||
| 				}, | ||||
| 				user: "http://oidc.org/4321", | ||||
| 			}, | ||||
| 			want: nil, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, test := range tests { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
| 			got := filterNodesByUser(test.args.nodes, test.args.user) | ||||
| 			for range 1000 { | ||||
| 				ns := test.args.nodes | ||||
| 				rand.Shuffle(len(ns), func(i, j int) { | ||||
| 					ns[i], ns[j] = ns[j], ns[i] | ||||
| 				}) | ||||
| 				us := users | ||||
| 				rand.Shuffle(len(us), func(i, j int) { | ||||
| 					us[i], us[j] = us[j], us[i] | ||||
| 				}) | ||||
| 				got := filterNodesByUser(ns, us, test.args.user) | ||||
| 				sort.Slice(got, func(i, j int) bool { | ||||
| 					return got[i].ID < got[j].ID | ||||
| 				}) | ||||
| 
 | ||||
| 			if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { | ||||
| 				t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) | ||||
| 				if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { | ||||
| 					t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| @ -940,6 +1191,12 @@ func Test_expandAlias(t *testing.T) { | ||||
| 		return s | ||||
| 	} | ||||
| 
 | ||||
| 	users := []types.User{ | ||||
| 		{Model: gorm.Model{ID: 1}, Name: "joe"}, | ||||
| 		{Model: gorm.Model{ID: 2}, Name: "marc"}, | ||||
| 		{Model: gorm.Model{ID: 3}, Name: "mickael"}, | ||||
| 	} | ||||
| 
 | ||||
| 	type field struct { | ||||
| 		pol ACLPolicy | ||||
| 	} | ||||
| @ -989,19 +1246,19 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.1"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.2"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.3"), | ||||
| 						User: types.User{Name: "marc"}, | ||||
| 						User: users[1], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.4"), | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 						User: users[2], | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @ -1022,19 +1279,19 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.1"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.2"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.3"), | ||||
| 						User: types.User{Name: "marc"}, | ||||
| 						User: users[1], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.4"), | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 						User: users[2], | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @ -1185,7 +1442,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.1"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| @ -1194,7 +1451,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.2"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| @ -1203,11 +1460,11 @@ func Test_expandAlias(t *testing.T) { | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.3"), | ||||
| 						User: types.User{Name: "marc"}, | ||||
| 						User: users[1], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.4"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @ -1260,21 +1517,21 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ | ||||
| 						IPv4:       iap("100.64.0.1"), | ||||
| 						User:       types.User{Name: "joe"}, | ||||
| 						User:       users[0], | ||||
| 						ForcedTags: []string{"tag:hr-webserver"}, | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4:       iap("100.64.0.2"), | ||||
| 						User:       types.User{Name: "joe"}, | ||||
| 						User:       users[0], | ||||
| 						ForcedTags: []string{"tag:hr-webserver"}, | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.3"), | ||||
| 						User: types.User{Name: "marc"}, | ||||
| 						User: users[1], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.4"), | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 						User: users[2], | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @ -1295,12 +1552,12 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				nodes: types.Nodes{ | ||||
| 					&types.Node{ | ||||
| 						IPv4:       iap("100.64.0.1"), | ||||
| 						User:       types.User{Name: "joe"}, | ||||
| 						User:       users[0], | ||||
| 						ForcedTags: []string{"tag:hr-webserver"}, | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.2"), | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						User: users[0], | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| @ -1309,11 +1566,11 @@ func Test_expandAlias(t *testing.T) { | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.3"), | ||||
| 						User: types.User{Name: "marc"}, | ||||
| 						User: users[1], | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4: iap("100.64.0.4"), | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 						User: users[2], | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @ -1350,12 +1607,12 @@ func Test_expandAlias(t *testing.T) { | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4:     iap("100.64.0.3"), | ||||
| 						User:     types.User{Name: "marc"}, | ||||
| 						User:     users[1], | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPv4:     iap("100.64.0.4"), | ||||
| 						User:     types.User{Name: "joe"}, | ||||
| 						User:     users[0], | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}, | ||||
| @ -1368,6 +1625,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
| 			got, err := test.field.pol.ExpandAlias( | ||||
| 				test.args.nodes, | ||||
| 				users, | ||||
| 				test.args.alias, | ||||
| 			) | ||||
| 			if (err != nil) != test.wantErr { | ||||
| @ -1715,6 +1973,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := tt.field.pol.CompileFilterRules( | ||||
| 				[]types.User{}, | ||||
| 				tt.args.nodes, | ||||
| 			) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| @ -1842,6 +2101,13 @@ func TestTheInternet(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestReduceFilterRules(t *testing.T) { | ||||
| 	users := []types.User{ | ||||
| 		{Model: gorm.Model{ID: 1}, Name: "mickael"}, | ||||
| 		{Model: gorm.Model{ID: 2}, Name: "user1"}, | ||||
| 		{Model: gorm.Model{ID: 3}, Name: "user2"}, | ||||
| 		{Model: gorm.Model{ID: 4}, Name: "user100"}, | ||||
| 	} | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		name  string | ||||
| 		node  *types.Node | ||||
| @ -1863,13 +2129,13 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.1"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), | ||||
| 				User: types.User{Name: "mickael"}, | ||||
| 				User: users[0], | ||||
| 			}, | ||||
| 			peers: types.Nodes{ | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), | ||||
| 					User: types.User{Name: "mickael"}, | ||||
| 					User: users[0], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{}, | ||||
| @ -1896,7 +2162,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.1"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 				User: types.User{Name: "user1"}, | ||||
| 				User: users[1], | ||||
| 				Hostinfo: &tailcfg.Hostinfo{ | ||||
| 					RoutableIPs: []netip.Prefix{ | ||||
| 						netip.MustParsePrefix("10.33.0.0/16"), | ||||
| @ -1907,7 +2173,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||
| 					User: types.User{Name: "user1"}, | ||||
| 					User: users[1], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| @ -1975,19 +2241,19 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.1"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 				User: types.User{Name: "user1"}, | ||||
| 				User: users[1], | ||||
| 			}, | ||||
| 			peers: types.Nodes{ | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||
| 					User: types.User{Name: "user2"}, | ||||
| 					User: users[2], | ||||
| 				}, | ||||
| 				// "internal" exit node | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.100"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::100"), | ||||
| 					User: types.User{Name: "user100"}, | ||||
| 					User: users[3], | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						RoutableIPs: tsaddr.ExitRoutes(), | ||||
| 					}, | ||||
| @ -2034,12 +2300,12 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||
| 					User: types.User{Name: "user2"}, | ||||
| 					User: users[2], | ||||
| 				}, | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.1"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 					User: types.User{Name: "user1"}, | ||||
| 					User: users[1], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| @ -2131,7 +2397,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.100"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||
| 				User: types.User{Name: "user100"}, | ||||
| 				User: users[3], | ||||
| 				Hostinfo: &tailcfg.Hostinfo{ | ||||
| 					RoutableIPs: tsaddr.ExitRoutes(), | ||||
| 				}, | ||||
| @ -2140,12 +2406,12 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||
| 					User: types.User{Name: "user2"}, | ||||
| 					User: users[2], | ||||
| 				}, | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.1"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 					User: types.User{Name: "user1"}, | ||||
| 					User: users[1], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| @ -2243,7 +2509,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.100"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||
| 				User: types.User{Name: "user100"}, | ||||
| 				User: users[3], | ||||
| 				Hostinfo: &tailcfg.Hostinfo{ | ||||
| 					RoutableIPs: []netip.Prefix{ | ||||
| 						netip.MustParsePrefix("8.0.0.0/16"), | ||||
| @ -2255,12 +2521,12 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||
| 					User: types.User{Name: "user2"}, | ||||
| 					User: users[2], | ||||
| 				}, | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.1"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 					User: types.User{Name: "user1"}, | ||||
| 					User: users[1], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| @ -2333,7 +2599,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.100"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||
| 				User: types.User{Name: "user100"}, | ||||
| 				User: users[3], | ||||
| 				Hostinfo: &tailcfg.Hostinfo{ | ||||
| 					RoutableIPs: []netip.Prefix{ | ||||
| 						netip.MustParsePrefix("8.0.0.0/8"), | ||||
| @ -2345,12 +2611,12 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.2"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||
| 					User: types.User{Name: "user2"}, | ||||
| 					User: users[2], | ||||
| 				}, | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.1"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 					User: types.User{Name: "user1"}, | ||||
| 					User: users[1], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| @ -2416,7 +2682,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 			node: &types.Node{ | ||||
| 				IPv4: iap("100.64.0.100"), | ||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||
| 				User: types.User{Name: "user100"}, | ||||
| 				User: users[3], | ||||
| 				Hostinfo: &tailcfg.Hostinfo{ | ||||
| 					RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, | ||||
| 				}, | ||||
| @ -2426,7 +2692,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPv4: iap("100.64.0.1"), | ||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||
| 					User: types.User{Name: "user1"}, | ||||
| 					User: users[1], | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| @ -2454,6 +2720,7 @@ func TestReduceFilterRules(t *testing.T) { | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, _ := tt.pol.CompileFilterRules( | ||||
| 				users, | ||||
| 				append(tt.peers, tt.node), | ||||
| 			) | ||||
| 
 | ||||
| @ -3461,7 +3728,7 @@ func TestSSHRules(t *testing.T) { | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) | ||||
| 			got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) | ||||
| 			require.NoError(t, err) | ||||
| 
 | ||||
| 			if diff := cmp.Diff(tt.want, got); diff != "" { | ||||
| @ -3544,14 +3811,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { | ||||
| 		RequestTags: []string{"tag:test"}, | ||||
| 	} | ||||
| 
 | ||||
| 	user := types.User{ | ||||
| 		Model: gorm.Model{ID: 1}, | ||||
| 		Name:  "user1", | ||||
| 	} | ||||
| 
 | ||||
| 	node := &types.Node{ | ||||
| 		ID:       0, | ||||
| 		Hostname: "testnodes", | ||||
| 		IPv4:     iap("100.64.0.1"), | ||||
| 		UserID:   0, | ||||
| 		User: types.User{ | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		ID:             0, | ||||
| 		Hostname:       "testnodes", | ||||
| 		IPv4:           iap("100.64.0.1"), | ||||
| 		UserID:         0, | ||||
| 		User:           user, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| @ -3568,7 +3838,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -3602,7 +3872,8 @@ func TestInvalidTagValidUser(t *testing.T) { | ||||
| 		IPv4:     iap("100.64.0.1"), | ||||
| 		UserID:   1, | ||||
| 		User: types.User{ | ||||
| 			Name: "user1", | ||||
| 			Model: gorm.Model{ID: 1}, | ||||
| 			Name:  "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &hostInfo, | ||||
| @ -3619,7 +3890,12 @@ func TestInvalidTagValidUser(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests( | ||||
| 		pol, | ||||
| 		node, | ||||
| 		types.Nodes{}, | ||||
| 		[]types.User{node.User}, | ||||
| 	) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -3653,7 +3929,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { | ||||
| 		IPv4:     iap("100.64.0.1"), | ||||
| 		UserID:   1, | ||||
| 		User: types.User{ | ||||
| 			Name: "user1", | ||||
| 			Model: gorm.Model{ID: 1}, | ||||
| 			Name:  "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &hostInfo, | ||||
| @ -3678,7 +3955,12 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { | ||||
| 	// c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||
| 	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||
| 
 | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests( | ||||
| 		pol, | ||||
| 		node, | ||||
| 		types.Nodes{}, | ||||
| 		[]types.User{node.User}, | ||||
| 	) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -3707,15 +3989,17 @@ func TestValidTagInvalidUser(t *testing.T) { | ||||
| 		Hostname:    "webserver", | ||||
| 		RequestTags: []string{"tag:webapp"}, | ||||
| 	} | ||||
| 	user := types.User{ | ||||
| 		Model: gorm.Model{ID: 1}, | ||||
| 		Name:  "user1", | ||||
| 	} | ||||
| 
 | ||||
| 	node := &types.Node{ | ||||
| 		ID:       1, | ||||
| 		Hostname: "webserver", | ||||
| 		IPv4:     iap("100.64.0.1"), | ||||
| 		UserID:   1, | ||||
| 		User: types.User{ | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		ID:             1, | ||||
| 		Hostname:       "webserver", | ||||
| 		IPv4:           iap("100.64.0.1"), | ||||
| 		UserID:         1, | ||||
| 		User:           user, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| @ -3726,13 +4010,11 @@ func TestValidTagInvalidUser(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	nodes2 := &types.Node{ | ||||
| 		ID:       2, | ||||
| 		Hostname: "user", | ||||
| 		IPv4:     iap("100.64.0.2"), | ||||
| 		UserID:   1, | ||||
| 		User: types.User{ | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		ID:             2, | ||||
| 		Hostname:       "user", | ||||
| 		IPv4:           iap("100.64.0.2"), | ||||
| 		UserID:         1, | ||||
| 		User:           user, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &hostInfo2, | ||||
| 	} | ||||
| @ -3748,7 +4030,12 @@ func TestValidTagInvalidUser(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) | ||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests( | ||||
| 		pol, | ||||
| 		node, | ||||
| 		types.Nodes{nodes2}, | ||||
| 		[]types.User{user}, | ||||
| 	) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
|  | ||||
| @ -3,6 +3,7 @@ package types | ||||
| import ( | ||||
| 	"cmp" | ||||
| 	"database/sql" | ||||
| 	"net/mail" | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| @ -56,14 +57,7 @@ type User struct { | ||||
| // should be used throughout headscale, in information returned to the | ||||
| // user and the Policy engine. | ||||
| func (u *User) Username() string { | ||||
| 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) | ||||
| 
 | ||||
| 	// TODO(kradalby): Wire up all of this for the future | ||||
| 	// if !strings.Contains(username, "@") { | ||||
| 	// 	username = username + "@" | ||||
| 	// } | ||||
| 
 | ||||
| 	return username | ||||
| 	return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) | ||||
| } | ||||
| 
 | ||||
| // DisplayNameOrUsername returns the DisplayName if it exists, otherwise | ||||
| @ -146,12 +140,20 @@ func (c *OIDCClaims) Identifier() string { | ||||
| // FromClaim overrides a User from OIDC claims. | ||||
| // All fields will be updated, except for the ID. | ||||
| func (u *User) FromClaim(claims *OIDCClaims) { | ||||
| 	err := util.CheckForFQDNRules(claims.Username) | ||||
| 	if err == nil { | ||||
| 		u.Name = claims.Username | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.EmailVerified { | ||||
| 		_, err = mail.ParseAddress(claims.Email) | ||||
| 		if err == nil { | ||||
| 			u.Email = claims.Email | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} | ||||
| 	u.DisplayName = claims.Name | ||||
| 	if claims.EmailVerified { | ||||
| 		u.Email = claims.Email | ||||
| 	} | ||||
| 	u.Name = claims.Username | ||||
| 	u.ProfilePicURL = claims.ProfilePictureURL | ||||
| 	u.Provider = util.RegisterMethodOIDC | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user