mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-30 23:51:03 +01:00 
			
		
		
		
	Resolve user to stable unique ID in policy (#2205)
This commit is contained in:
		
							parent
							
								
									3a2589f1a9
								
							
						
					
					
						commit
						fffd23602b
					
				| @ -1029,14 +1029,18 @@ func (h *Headscale) loadACLPolicy() error { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return fmt.Errorf("loading nodes from database to validate policy: %w", err) | 			return fmt.Errorf("loading nodes from database to validate policy: %w", err) | ||||||
| 		} | 		} | ||||||
|  | 		users, err := h.db.ListUsers() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("loading users from database to validate policy: %w", err) | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		_, err = pol.CompileFilterRules(nodes) | 		_, err = pol.CompileFilterRules(users, nodes) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return fmt.Errorf("verifying policy rules: %w", err) | 			return fmt.Errorf("verifying policy rules: %w", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if len(nodes) > 0 { | 		if len(nodes) > 0 { | ||||||
| 			_, err = pol.CompileSSHPolicy(nodes[0], nodes) | 			_, err = pol.CompileSSHPolicy(nodes[0], users, nodes) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return fmt.Errorf("verifying SSH rules: %w", err) | 				return fmt.Errorf("verifying SSH rules: %w", err) | ||||||
| 			} | 			} | ||||||
|  | |||||||
| @ -256,10 +256,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { | |||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(testPeers), check.Equals, 9) | 	c.Assert(len(testPeers), check.Equals, 9) | ||||||
| 
 | 
 | ||||||
| 	adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) | 	adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) | 	testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) | 	peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) | ||||||
|  | |||||||
| @ -648,8 +648,13 @@ func EnableAutoApprovedRoutes( | |||||||
| 			if approvedAlias == node.User.Username() { | 			if approvedAlias == node.User.Username() { | ||||||
| 				approvedRoutes = append(approvedRoutes, advertisedRoute) | 				approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||||
| 			} else { | 			} else { | ||||||
|  | 				users, err := ListUsers(tx) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return fmt.Errorf("looking up users to expand route alias: %w", err) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
| 				// TODO(kradalby): figure out how to get this to depend on less stuff | 				// TODO(kradalby): figure out how to get this to depend on less stuff | ||||||
| 				approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) | 				approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) | 					return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) | ||||||
| 				} | 				} | ||||||
|  | |||||||
| @ -773,14 +773,18 @@ func (api headscaleV1APIServer) SetPolicy( | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) | 		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) | ||||||
| 	} | 	} | ||||||
|  | 	users, err := api.h.db.ListUsers() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("loading users from database to validate policy: %w", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	_, err = pol.CompileFilterRules(nodes) | 	_, err = pol.CompileFilterRules(users, nodes) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("verifying policy rules: %w", err) | 		return nil, fmt.Errorf("verifying policy rules: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(nodes) > 0 { | 	if len(nodes) > 0 { | ||||||
| 		_, err = pol.CompileSSHPolicy(nodes[0], nodes) | 		_, err = pol.CompileSSHPolicy(nodes[0], users, nodes) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("verifying SSH rules: %w", err) | 			return nil, fmt.Errorf("verifying SSH rules: %w", err) | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { | |||||||
| func (m *Mapper) fullMapResponse( | func (m *Mapper) fullMapResponse( | ||||||
| 	node *types.Node, | 	node *types.Node, | ||||||
| 	peers types.Nodes, | 	peers types.Nodes, | ||||||
|  | 	users []types.User, | ||||||
| 	pol *policy.ACLPolicy, | 	pol *policy.ACLPolicy, | ||||||
| 	capVer tailcfg.CapabilityVersion, | 	capVer tailcfg.CapabilityVersion, | ||||||
| ) (*tailcfg.MapResponse, error) { | ) (*tailcfg.MapResponse, error) { | ||||||
| @ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse( | |||||||
| 		pol, | 		pol, | ||||||
| 		node, | 		node, | ||||||
| 		capVer, | 		capVer, | ||||||
|  | 		users, | ||||||
| 		peers, | 		peers, | ||||||
| 		peers, | 		peers, | ||||||
| 		m.cfg, | 		m.cfg, | ||||||
| @ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse( | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	users, err := m.db.ListUsers() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) | 	resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	users, err := m.db.ListUsers() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("listing users for map response: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	var removedIDs []tailcfg.NodeID | 	var removedIDs []tailcfg.NodeID | ||||||
| 	var changedIDs []types.NodeID | 	var changedIDs []types.NodeID | ||||||
| 	for nodeID, nodeChanged := range changed { | 	for nodeID, nodeChanged := range changed { | ||||||
| @ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse( | |||||||
| 		pol, | 		pol, | ||||||
| 		node, | 		node, | ||||||
| 		mapRequest.Version, | 		mapRequest.Version, | ||||||
|  | 		users, | ||||||
| 		peers, | 		peers, | ||||||
| 		changedNodes, | 		changedNodes, | ||||||
| 		m.cfg, | 		m.cfg, | ||||||
| @ -508,16 +520,17 @@ func appendPeerChanges( | |||||||
| 	pol *policy.ACLPolicy, | 	pol *policy.ACLPolicy, | ||||||
| 	node *types.Node, | 	node *types.Node, | ||||||
| 	capVer tailcfg.CapabilityVersion, | 	capVer tailcfg.CapabilityVersion, | ||||||
|  | 	users []types.User, | ||||||
| 	peers types.Nodes, | 	peers types.Nodes, | ||||||
| 	changed types.Nodes, | 	changed types.Nodes, | ||||||
| 	cfg *types.Config, | 	cfg *types.Config, | ||||||
| ) error { | ) error { | ||||||
| 	packetFilter, err := pol.CompileFilterRules(append(peers, node)) | 	packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	sshPolicy, err := pol.CompileSSHPolicy(node, peers) | 	sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 	lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) | 	lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) | ||||||
| 	expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) | 	expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) | ||||||
| 
 | 
 | ||||||
|  | 	user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} | ||||||
|  | 	user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} | ||||||
|  | 
 | ||||||
| 	mini := &types.Node{ | 	mini := &types.Node{ | ||||||
| 		ID: 0, | 		ID: 0, | ||||||
| 		MachineKey: mustMK( | 		MachineKey: mustMK( | ||||||
| @ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 		IPv4:       iap("100.64.0.1"), | 		IPv4:       iap("100.64.0.1"), | ||||||
| 		Hostname:   "mini", | 		Hostname:   "mini", | ||||||
| 		GivenName:  "mini", | 		GivenName:  "mini", | ||||||
| 		UserID:     0, | 		UserID:     user1.ID, | ||||||
| 		User:       types.User{Name: "mini"}, | 		User:       user1, | ||||||
| 		ForcedTags: []string{}, | 		ForcedTags: []string{}, | ||||||
| 		AuthKey:    &types.PreAuthKey{}, | 		AuthKey:    &types.PreAuthKey{}, | ||||||
| 		LastSeen:   &lastSeen, | 		LastSeen:   &lastSeen, | ||||||
| @ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 		IPv4:       iap("100.64.0.2"), | 		IPv4:       iap("100.64.0.2"), | ||||||
| 		Hostname:   "peer1", | 		Hostname:   "peer1", | ||||||
| 		GivenName:  "peer1", | 		GivenName:  "peer1", | ||||||
| 		UserID:     0, | 		UserID:     user1.ID, | ||||||
| 		User:       types.User{Name: "mini"}, | 		User:       user1, | ||||||
| 		ForcedTags: []string{}, | 		ForcedTags: []string{}, | ||||||
| 		LastSeen:   &lastSeen, | 		LastSeen:   &lastSeen, | ||||||
| 		Expiry:     &expire, | 		Expiry:     &expire, | ||||||
| @ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 		IPv4:       iap("100.64.0.3"), | 		IPv4:       iap("100.64.0.3"), | ||||||
| 		Hostname:   "peer2", | 		Hostname:   "peer2", | ||||||
| 		GivenName:  "peer2", | 		GivenName:  "peer2", | ||||||
| 		UserID:     1, | 		UserID:     user2.ID, | ||||||
| 		User:       types.User{Name: "peer2"}, | 		User:       user2, | ||||||
| 		ForcedTags: []string{}, | 		ForcedTags: []string{}, | ||||||
| 		LastSeen:   &lastSeen, | 		LastSeen:   &lastSeen, | ||||||
| 		Expiry:     &expire, | 		Expiry:     &expire, | ||||||
| @ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 			got, err := mappy.fullMapResponse( | 			got, err := mappy.fullMapResponse( | ||||||
| 				tt.node, | 				tt.node, | ||||||
| 				tt.peers, | 				tt.peers, | ||||||
|  | 				[]types.User{user1, user2}, | ||||||
| 				tt.pol, | 				tt.pol, | ||||||
| 				0, | 				0, | ||||||
| 			) | 			) | ||||||
|  | |||||||
| @ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( | |||||||
| 	policy *ACLPolicy, | 	policy *ACLPolicy, | ||||||
| 	node *types.Node, | 	node *types.Node, | ||||||
| 	peers types.Nodes, | 	peers types.Nodes, | ||||||
|  | 	users []types.User, | ||||||
| ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { | ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { | ||||||
| 	// If there is no policy defined, we default to allow all | 	// If there is no policy defined, we default to allow all | ||||||
| 	if policy == nil { | 	if policy == nil { | ||||||
| 		return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil | 		return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	rules, err := policy.CompileFilterRules(append(peers, node)) | 	rules, err := policy.CompileFilterRules(users, append(peers, node)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") | 	log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") | ||||||
| 
 | 
 | ||||||
| 	sshPolicy, err := policy.CompileSSHPolicy(node, peers) | 	sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||||
| 	} | 	} | ||||||
| @ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( | |||||||
| // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a | // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a | ||||||
| // set of Tailscale compatible FilterRules used to allow traffic on clients. | // set of Tailscale compatible FilterRules used to allow traffic on clients. | ||||||
| func (pol *ACLPolicy) CompileFilterRules( | func (pol *ACLPolicy) CompileFilterRules( | ||||||
|  | 	users []types.User, | ||||||
| 	nodes types.Nodes, | 	nodes types.Nodes, | ||||||
| ) ([]tailcfg.FilterRule, error) { | ) ([]tailcfg.FilterRule, error) { | ||||||
| 	if pol == nil { | 	if pol == nil { | ||||||
| @ -176,7 +178,7 @@ func (pol *ACLPolicy) CompileFilterRules( | |||||||
| 
 | 
 | ||||||
| 		var srcIPs []string | 		var srcIPs []string | ||||||
| 		for srcIndex, src := range acl.Sources { | 		for srcIndex, src := range acl.Sources { | ||||||
| 			srcs, err := pol.expandSource(src, nodes) | 			srcs, err := pol.expandSource(src, users, nodes) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, fmt.Errorf( | 				return nil, fmt.Errorf( | ||||||
| 					"parsing policy, acl index: %d->%d: %w", | 					"parsing policy, acl index: %d->%d: %w", | ||||||
| @ -202,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( | |||||||
| 
 | 
 | ||||||
| 			expanded, err := pol.ExpandAlias( | 			expanded, err := pol.ExpandAlias( | ||||||
| 				nodes, | 				nodes, | ||||||
|  | 				users, | ||||||
| 				alias, | 				alias, | ||||||
| 			) | 			) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @ -286,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) CompileSSHPolicy( | func (pol *ACLPolicy) CompileSSHPolicy( | ||||||
| 	node *types.Node, | 	node *types.Node, | ||||||
|  | 	users []types.User, | ||||||
| 	peers types.Nodes, | 	peers types.Nodes, | ||||||
| ) (*tailcfg.SSHPolicy, error) { | ) (*tailcfg.SSHPolicy, error) { | ||||||
| 	if pol == nil { | 	if pol == nil { | ||||||
| @ -317,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( | |||||||
| 	for index, sshACL := range pol.SSHs { | 	for index, sshACL := range pol.SSHs { | ||||||
| 		var dest netipx.IPSetBuilder | 		var dest netipx.IPSetBuilder | ||||||
| 		for _, src := range sshACL.Destinations { | 		for _, src := range sshACL.Destinations { | ||||||
| 			expanded, err := pol.ExpandAlias(append(peers, node), src) | 			expanded, err := pol.ExpandAlias(append(peers, node), users, src) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| @ -377,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( | |||||||
| 			} else { | 			} else { | ||||||
| 				expandedSrcs, err := pol.ExpandAlias( | 				expandedSrcs, err := pol.ExpandAlias( | ||||||
| 					peers, | 					peers, | ||||||
|  | 					users, | ||||||
| 					rawSrc, | 					rawSrc, | ||||||
| 				) | 				) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| @ -526,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { | |||||||
| // with the given src alias. | // with the given src alias. | ||||||
| func (pol *ACLPolicy) expandSource( | func (pol *ACLPolicy) expandSource( | ||||||
| 	src string, | 	src string, | ||||||
|  | 	users []types.User, | ||||||
| 	nodes types.Nodes, | 	nodes types.Nodes, | ||||||
| ) ([]string, error) { | ) ([]string, error) { | ||||||
| 	ipSet, err := pol.ExpandAlias(nodes, src) | 	ipSet, err := pol.ExpandAlias(nodes, users, src) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return []string{}, err | 		return []string{}, err | ||||||
| 	} | 	} | ||||||
| @ -552,6 +558,7 @@ func (pol *ACLPolicy) expandSource( | |||||||
| // and transform these in IPAddresses. | // and transform these in IPAddresses. | ||||||
| func (pol *ACLPolicy) ExpandAlias( | func (pol *ACLPolicy) ExpandAlias( | ||||||
| 	nodes types.Nodes, | 	nodes types.Nodes, | ||||||
|  | 	users []types.User, | ||||||
| 	alias string, | 	alias string, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	if isWildcard(alias) { | 	if isWildcard(alias) { | ||||||
| @ -566,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( | |||||||
| 
 | 
 | ||||||
| 	// if alias is a group | 	// if alias is a group | ||||||
| 	if isGroup(alias) { | 	if isGroup(alias) { | ||||||
| 		return pol.expandIPsFromGroup(alias, nodes) | 		return pol.expandIPsFromGroup(alias, users, nodes) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// if alias is a tag | 	// if alias is a tag | ||||||
| 	if isTag(alias) { | 	if isTag(alias) { | ||||||
| 		return pol.expandIPsFromTag(alias, nodes) | 		return pol.expandIPsFromTag(alias, users, nodes) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if isAutoGroup(alias) { | 	if isAutoGroup(alias) { | ||||||
| @ -579,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// if alias is a user | 	// if alias is a user | ||||||
| 	if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { | 	if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { | ||||||
| 		return ips, err | 		return ips, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -588,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( | |||||||
| 	if h, ok := pol.Hosts[alias]; ok { | 	if h, ok := pol.Hosts[alias]; ok { | ||||||
| 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") | 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") | ||||||
| 
 | 
 | ||||||
| 		return pol.ExpandAlias(nodes, h.String()) | 		return pol.ExpandAlias(nodes, users, h.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// if alias is an IP | 	// if alias is an IP | ||||||
| @ -765,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) expandIPsFromGroup( | func (pol *ACLPolicy) expandIPsFromGroup( | ||||||
| 	group string, | 	group string, | ||||||
|  | 	users []types.User, | ||||||
| 	nodes types.Nodes, | 	nodes types.Nodes, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	var build netipx.IPSetBuilder | 	var build netipx.IPSetBuilder | ||||||
| 
 | 
 | ||||||
| 	users, err := pol.expandUsersFromGroup(group) | 	userTokens, err := pol.expandUsersFromGroup(group) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return &netipx.IPSet{}, err | 		return &netipx.IPSet{}, err | ||||||
| 	} | 	} | ||||||
| 	for _, user := range users { | 	for _, user := range userTokens { | ||||||
| 		filteredNodes := filterNodesByUser(nodes, user) | 		filteredNodes := filterNodesByUser(nodes, users, user) | ||||||
| 		for _, node := range filteredNodes { | 		for _, node := range filteredNodes { | ||||||
| 			node.AppendToIPSet(&build) | 			node.AppendToIPSet(&build) | ||||||
| 		} | 		} | ||||||
| @ -785,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) expandIPsFromTag( | func (pol *ACLPolicy) expandIPsFromTag( | ||||||
| 	alias string, | 	alias string, | ||||||
|  | 	users []types.User, | ||||||
| 	nodes types.Nodes, | 	nodes types.Nodes, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	var build netipx.IPSetBuilder | 	var build netipx.IPSetBuilder | ||||||
| @ -817,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag( | |||||||
| 
 | 
 | ||||||
| 	// filter out nodes per tag owner | 	// filter out nodes per tag owner | ||||||
| 	for _, user := range owners { | 	for _, user := range owners { | ||||||
| 		nodes := filterNodesByUser(nodes, user) | 		nodes := filterNodesByUser(nodes, users, user) | ||||||
| 		for _, node := range nodes { | 		for _, node := range nodes { | ||||||
| 			if node.Hostinfo == nil { | 			if node.Hostinfo == nil { | ||||||
| 				continue | 				continue | ||||||
| @ -834,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) expandIPsFromUser( | func (pol *ACLPolicy) expandIPsFromUser( | ||||||
| 	user string, | 	user string, | ||||||
|  | 	users []types.User, | ||||||
| 	nodes types.Nodes, | 	nodes types.Nodes, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	var build netipx.IPSetBuilder | 	var build netipx.IPSetBuilder | ||||||
| 
 | 
 | ||||||
| 	filteredNodes := filterNodesByUser(nodes, user) | 	filteredNodes := filterNodesByUser(nodes, users, user) | ||||||
| 	filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) | 	filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) | ||||||
| 
 | 
 | ||||||
| 	// shortcurcuit if we have no nodes to get ips from. | 	// shortcurcuit if we have no nodes to get ips from. | ||||||
| @ -967,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode( | |||||||
| 	return validTags, invalidTags | 	return validTags, invalidTags | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { | // filterNodesByUser returns a list of nodes that match the given userToken from a | ||||||
|  | // policy. | ||||||
|  | // Matching nodes are determined by first matching the user token to a user by checking: | ||||||
|  | // - If it is an ID that mactches the user database ID | ||||||
|  | // - It is the Provider Identifier from OIDC | ||||||
|  | // - It matches the username or email of a user | ||||||
|  | // | ||||||
|  | // If the token matches more than one user, zero nodes will returned. | ||||||
|  | func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { | ||||||
| 	var out types.Nodes | 	var out types.Nodes | ||||||
|  | 
 | ||||||
|  | 	var potentialUsers []types.User | ||||||
|  | 	for _, user := range users { | ||||||
|  | 		if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken { | ||||||
|  | 			// If a user is matching with a known unique field, | ||||||
|  | 			// disgard all other users and only keep the current | ||||||
|  | 			// user. | ||||||
|  | 			potentialUsers = []types.User{user} | ||||||
|  | 
 | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		if user.Email == userToken { | ||||||
|  | 			potentialUsers = append(potentialUsers, user) | ||||||
|  | 		} | ||||||
|  | 		if user.Name == userToken { | ||||||
|  | 			potentialUsers = append(potentialUsers, user) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(potentialUsers) != 1 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user := potentialUsers[0] | ||||||
|  | 
 | ||||||
| 	for _, node := range nodes { | 	for _, node := range nodes { | ||||||
| 		if node.User.Username() == user { | 		if node.User.ID == user.ID { | ||||||
| 			out = append(out, node) | 			out = append(out, node) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -1,9 +1,12 @@ | |||||||
| package policy | package policy | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"math/rand/v2" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"slices" | 	"slices" | ||||||
|  | 	"sort" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/google/go-cmp/cmp" | 	"github.com/google/go-cmp/cmp" | ||||||
| @ -14,6 +17,7 @@ import ( | |||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| 	"go4.org/netipx" | 	"go4.org/netipx" | ||||||
| 	"gopkg.in/check.v1" | 	"gopkg.in/check.v1" | ||||||
|  | 	"gorm.io/gorm" | ||||||
| 	"tailscale.com/net/tsaddr" | 	"tailscale.com/net/tsaddr" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| ) | ) | ||||||
| @ -375,15 +379,21 @@ func TestParsing(t *testing.T) { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			rules, err := pol.CompileFilterRules(types.Nodes{ | 			user := types.User{ | ||||||
|  | 				Model: gorm.Model{ID: 1}, | ||||||
|  | 				Name:  "testuser", | ||||||
|  | 			} | ||||||
|  | 			rules, err := pol.CompileFilterRules( | ||||||
|  | 				[]types.User{ | ||||||
|  | 					user, | ||||||
|  | 				}, | ||||||
|  | 				types.Nodes{ | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.100.100.100"), | 						IPv4: iap("100.100.100.100"), | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4:     iap("200.200.200.200"), | 						IPv4:     iap("200.200.200.200"), | ||||||
| 					User: types.User{ | 						User:     user, | ||||||
| 						Name: "testuser", |  | ||||||
| 					}, |  | ||||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | 						Hostinfo: &tailcfg.Hostinfo{}, | ||||||
| 					}, | 					}, | ||||||
| 				}) | 				}) | ||||||
| @ -533,7 +543,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { | |||||||
| 	c.Assert(pol.ACLs, check.HasLen, 6) | 	c.Assert(pol.ACLs, check.HasLen, 6) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	rules, err := pol.CompileFilterRules(types.Nodes{}) | 	rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 	c.Assert(rules, check.IsNil) | 	c.Assert(rules, check.IsNil) | ||||||
| } | } | ||||||
| @ -549,7 +559,12 @@ func (s *Suite) TestInvalidAction(c *check.C) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) | 	_, _, err := GenerateFilterAndSSHRulesForTests( | ||||||
|  | 		pol, | ||||||
|  | 		&types.Node{}, | ||||||
|  | 		types.Nodes{}, | ||||||
|  | 		[]types.User{}, | ||||||
|  | 	) | ||||||
| 	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) | 	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -568,7 +583,12 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) | 	_, _, err := GenerateFilterAndSSHRulesForTests( | ||||||
|  | 		pol, | ||||||
|  | 		&types.Node{}, | ||||||
|  | 		types.Nodes{}, | ||||||
|  | 		[]types.User{}, | ||||||
|  | 	) | ||||||
| 	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) | 	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -584,7 +604,12 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) | 	_, _, err := GenerateFilterAndSSHRulesForTests( | ||||||
|  | 		pol, | ||||||
|  | 		&types.Node{}, | ||||||
|  | 		types.Nodes{}, | ||||||
|  | 		[]types.User{}, | ||||||
|  | 	) | ||||||
| 	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) | 	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -860,7 +885,25 @@ func Test_expandPorts(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Test_listNodesInUser(t *testing.T) { | func Test_filterNodesByUser(t *testing.T) { | ||||||
|  | 	users := []types.User{ | ||||||
|  | 		{Model: gorm.Model{ID: 1}, Name: "marc"}, | ||||||
|  | 		{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, | ||||||
|  | 		{ | ||||||
|  | 			Model:              gorm.Model{ID: 3}, | ||||||
|  | 			Name:               "mikael", | ||||||
|  | 			Email:              "mikael@headscale.net", | ||||||
|  | 			ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true}, | ||||||
|  | 		}, | ||||||
|  | 		{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, | ||||||
|  | 		{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, | ||||||
|  | 		{Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"}, | ||||||
|  | 		{Model: gorm.Model{ID: 7}, Name: "1"}, | ||||||
|  | 		{Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"}, | ||||||
|  | 		{Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"}, | ||||||
|  | 		{Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	type args struct { | 	type args struct { | ||||||
| 		nodes types.Nodes | 		nodes types.Nodes | ||||||
| 		user  string | 		user  string | ||||||
| @ -874,50 +917,258 @@ func Test_listNodesInUser(t *testing.T) { | |||||||
| 			name: "1 node in user", | 			name: "1 node in user", | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{User: types.User{Name: "joe"}}, | 					&types.Node{User: users[1]}, | ||||||
| 				}, | 				}, | ||||||
| 				user: "joe", | 				user: "joe", | ||||||
| 			}, | 			}, | ||||||
| 			want: types.Nodes{ | 			want: types.Nodes{ | ||||||
| 				&types.Node{User: types.User{Name: "joe"}}, | 				&types.Node{User: users[1]}, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "3 nodes, 2 in user", | 			name: "3 nodes, 2 in user", | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ID: 1, User: types.User{Name: "joe"}}, | 					&types.Node{ID: 1, User: users[1]}, | ||||||
| 					&types.Node{ID: 2, User: types.User{Name: "marc"}}, | 					&types.Node{ID: 2, User: users[0]}, | ||||||
| 					&types.Node{ID: 3, User: types.User{Name: "marc"}}, | 					&types.Node{ID: 3, User: users[0]}, | ||||||
| 				}, | 				}, | ||||||
| 				user: "marc", | 				user: "marc", | ||||||
| 			}, | 			}, | ||||||
| 			want: types.Nodes{ | 			want: types.Nodes{ | ||||||
| 				&types.Node{ID: 2, User: types.User{Name: "marc"}}, | 				&types.Node{ID: 2, User: users[0]}, | ||||||
| 				&types.Node{ID: 3, User: types.User{Name: "marc"}}, | 				&types.Node{ID: 3, User: users[0]}, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "5 nodes, 0 in user", | 			name: "5 nodes, 0 in user", | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ID: 1, User: types.User{Name: "joe"}}, | 					&types.Node{ID: 1, User: users[1]}, | ||||||
| 					&types.Node{ID: 2, User: types.User{Name: "marc"}}, | 					&types.Node{ID: 2, User: users[0]}, | ||||||
| 					&types.Node{ID: 3, User: types.User{Name: "marc"}}, | 					&types.Node{ID: 3, User: users[0]}, | ||||||
| 					&types.Node{ID: 4, User: types.User{Name: "marc"}}, | 					&types.Node{ID: 4, User: users[0]}, | ||||||
| 					&types.Node{ID: 5, User: types.User{Name: "marc"}}, | 					&types.Node{ID: 5, User: users[0]}, | ||||||
| 				}, | 				}, | ||||||
| 				user: "mickael", | 				user: "mickael", | ||||||
| 			}, | 			}, | ||||||
| 			want: nil, | 			want: nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "match-by-provider-ident", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[2]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "http://oidc.org/1234", | ||||||
|  | 			}, | ||||||
|  | 			want: types.Nodes{ | ||||||
|  | 				&types.Node{ID: 2, User: users[2]}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "match-by-email", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 8, User: users[7]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "joe@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: types.Nodes{ | ||||||
|  | 				&types.Node{ID: 1, User: users[1]}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "multi-match-is-zero", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[3]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "mikael@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "multi-email-first-match-is-zero", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					// First match email, then provider id | ||||||
|  | 					&types.Node{ID: 3, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[2]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "mikael@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "multi-username-first-match-is-zero", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					// First match username, then provider id | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[2]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "mikael", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-duplicate-username-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "mikael", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-unique-username-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "marc", | ||||||
|  | 			}, | ||||||
|  | 			want: types.Nodes{ | ||||||
|  | 				&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-no-username-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "not-working", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-duplicate-email-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "mikael@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-duplicate-email-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 					&types.Node{ID: 8, User: users[7]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "joe@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: types.Nodes{ | ||||||
|  | 				&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "email-as-username-duplicate", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[7]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[8]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "alex@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-no-email-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "not-working@headscale.net", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-provider-id-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 					&types.Node{ID: 6, User: users[5]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "http://oidc.org/1234", | ||||||
|  | 			}, | ||||||
|  | 			want: types.Nodes{ | ||||||
|  | 				&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "all-users-no-provider-id-random-order", | ||||||
|  | 			args: args{ | ||||||
|  | 				nodes: types.Nodes{ | ||||||
|  | 					&types.Node{ID: 1, User: users[0]}, | ||||||
|  | 					&types.Node{ID: 2, User: users[1]}, | ||||||
|  | 					&types.Node{ID: 3, User: users[2]}, | ||||||
|  | 					&types.Node{ID: 4, User: users[3]}, | ||||||
|  | 					&types.Node{ID: 5, User: users[4]}, | ||||||
|  | 					&types.Node{ID: 6, User: users[5]}, | ||||||
|  | 				}, | ||||||
|  | 				user: "http://oidc.org/4321", | ||||||
|  | 			}, | ||||||
|  | 			want: nil, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
| 		t.Run(test.name, func(t *testing.T) { | 		t.Run(test.name, func(t *testing.T) { | ||||||
| 			got := filterNodesByUser(test.args.nodes, test.args.user) | 			for range 1000 { | ||||||
|  | 				ns := test.args.nodes | ||||||
|  | 				rand.Shuffle(len(ns), func(i, j int) { | ||||||
|  | 					ns[i], ns[j] = ns[j], ns[i] | ||||||
|  | 				}) | ||||||
|  | 				us := users | ||||||
|  | 				rand.Shuffle(len(us), func(i, j int) { | ||||||
|  | 					us[i], us[j] = us[j], us[i] | ||||||
|  | 				}) | ||||||
|  | 				got := filterNodesByUser(ns, us, test.args.user) | ||||||
|  | 				sort.Slice(got, func(i, j int) bool { | ||||||
|  | 					return got[i].ID < got[j].ID | ||||||
|  | 				}) | ||||||
| 
 | 
 | ||||||
| 				if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { | 				if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { | ||||||
| 				t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) | 					t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| @ -940,6 +1191,12 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 		return s | 		return s | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	users := []types.User{ | ||||||
|  | 		{Model: gorm.Model{ID: 1}, Name: "joe"}, | ||||||
|  | 		{Model: gorm.Model{ID: 2}, Name: "marc"}, | ||||||
|  | 		{Model: gorm.Model{ID: 3}, Name: "mickael"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	type field struct { | 	type field struct { | ||||||
| 		pol ACLPolicy | 		pol ACLPolicy | ||||||
| 	} | 	} | ||||||
| @ -989,19 +1246,19 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.1"), | 						IPv4: iap("100.64.0.1"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.2"), | 						IPv4: iap("100.64.0.2"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.3"), | 						IPv4: iap("100.64.0.3"), | ||||||
| 						User: types.User{Name: "marc"}, | 						User: users[1], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.4"), | 						IPv4: iap("100.64.0.4"), | ||||||
| 						User: types.User{Name: "mickael"}, | 						User: users[2], | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| @ -1022,19 +1279,19 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.1"), | 						IPv4: iap("100.64.0.1"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.2"), | 						IPv4: iap("100.64.0.2"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.3"), | 						IPv4: iap("100.64.0.3"), | ||||||
| 						User: types.User{Name: "marc"}, | 						User: users[1], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.4"), | 						IPv4: iap("100.64.0.4"), | ||||||
| 						User: types.User{Name: "mickael"}, | 						User: users[2], | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| @ -1185,7 +1442,7 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.1"), | 						IPv4: iap("100.64.0.1"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 						Hostinfo: &tailcfg.Hostinfo{ | 						Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 							OS:          "centos", | 							OS:          "centos", | ||||||
| 							Hostname:    "foo", | 							Hostname:    "foo", | ||||||
| @ -1194,7 +1451,7 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.2"), | 						IPv4: iap("100.64.0.2"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 						Hostinfo: &tailcfg.Hostinfo{ | 						Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 							OS:          "centos", | 							OS:          "centos", | ||||||
| 							Hostname:    "foo", | 							Hostname:    "foo", | ||||||
| @ -1203,11 +1460,11 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.3"), | 						IPv4: iap("100.64.0.3"), | ||||||
| 						User: types.User{Name: "marc"}, | 						User: users[1], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.4"), | 						IPv4: iap("100.64.0.4"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| @ -1260,21 +1517,21 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4:       iap("100.64.0.1"), | 						IPv4:       iap("100.64.0.1"), | ||||||
| 						User:       types.User{Name: "joe"}, | 						User:       users[0], | ||||||
| 						ForcedTags: []string{"tag:hr-webserver"}, | 						ForcedTags: []string{"tag:hr-webserver"}, | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4:       iap("100.64.0.2"), | 						IPv4:       iap("100.64.0.2"), | ||||||
| 						User:       types.User{Name: "joe"}, | 						User:       users[0], | ||||||
| 						ForcedTags: []string{"tag:hr-webserver"}, | 						ForcedTags: []string{"tag:hr-webserver"}, | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.3"), | 						IPv4: iap("100.64.0.3"), | ||||||
| 						User: types.User{Name: "marc"}, | 						User: users[1], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.4"), | 						IPv4: iap("100.64.0.4"), | ||||||
| 						User: types.User{Name: "mickael"}, | 						User: users[2], | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| @ -1295,12 +1552,12 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 				nodes: types.Nodes{ | 				nodes: types.Nodes{ | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4:       iap("100.64.0.1"), | 						IPv4:       iap("100.64.0.1"), | ||||||
| 						User:       types.User{Name: "joe"}, | 						User:       users[0], | ||||||
| 						ForcedTags: []string{"tag:hr-webserver"}, | 						ForcedTags: []string{"tag:hr-webserver"}, | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.2"), | 						IPv4: iap("100.64.0.2"), | ||||||
| 						User: types.User{Name: "joe"}, | 						User: users[0], | ||||||
| 						Hostinfo: &tailcfg.Hostinfo{ | 						Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 							OS:          "centos", | 							OS:          "centos", | ||||||
| 							Hostname:    "foo", | 							Hostname:    "foo", | ||||||
| @ -1309,11 +1566,11 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.3"), | 						IPv4: iap("100.64.0.3"), | ||||||
| 						User: types.User{Name: "marc"}, | 						User: users[1], | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4: iap("100.64.0.4"), | 						IPv4: iap("100.64.0.4"), | ||||||
| 						User: types.User{Name: "mickael"}, | 						User: users[2], | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| @ -1350,12 +1607,12 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4:     iap("100.64.0.3"), | 						IPv4:     iap("100.64.0.3"), | ||||||
| 						User:     types.User{Name: "marc"}, | 						User:     users[1], | ||||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | 						Hostinfo: &tailcfg.Hostinfo{}, | ||||||
| 					}, | 					}, | ||||||
| 					&types.Node{ | 					&types.Node{ | ||||||
| 						IPv4:     iap("100.64.0.4"), | 						IPv4:     iap("100.64.0.4"), | ||||||
| 						User:     types.User{Name: "joe"}, | 						User:     users[0], | ||||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | 						Hostinfo: &tailcfg.Hostinfo{}, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| @ -1368,6 +1625,7 @@ func Test_expandAlias(t *testing.T) { | |||||||
| 		t.Run(test.name, func(t *testing.T) { | 		t.Run(test.name, func(t *testing.T) { | ||||||
| 			got, err := test.field.pol.ExpandAlias( | 			got, err := test.field.pol.ExpandAlias( | ||||||
| 				test.args.nodes, | 				test.args.nodes, | ||||||
|  | 				users, | ||||||
| 				test.args.alias, | 				test.args.alias, | ||||||
| 			) | 			) | ||||||
| 			if (err != nil) != test.wantErr { | 			if (err != nil) != test.wantErr { | ||||||
| @ -1715,6 +1973,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | |||||||
| 	for _, tt := range tests { | 	for _, tt := range tests { | ||||||
| 		t.Run(tt.name, func(t *testing.T) { | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 			got, err := tt.field.pol.CompileFilterRules( | 			got, err := tt.field.pol.CompileFilterRules( | ||||||
|  | 				[]types.User{}, | ||||||
| 				tt.args.nodes, | 				tt.args.nodes, | ||||||
| 			) | 			) | ||||||
| 			if (err != nil) != tt.wantErr { | 			if (err != nil) != tt.wantErr { | ||||||
| @ -1842,6 +2101,13 @@ func TestTheInternet(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestReduceFilterRules(t *testing.T) { | func TestReduceFilterRules(t *testing.T) { | ||||||
|  | 	users := []types.User{ | ||||||
|  | 		{Model: gorm.Model{ID: 1}, Name: "mickael"}, | ||||||
|  | 		{Model: gorm.Model{ID: 2}, Name: "user1"}, | ||||||
|  | 		{Model: gorm.Model{ID: 3}, Name: "user2"}, | ||||||
|  | 		{Model: gorm.Model{ID: 4}, Name: "user100"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		name  string | 		name  string | ||||||
| 		node  *types.Node | 		node  *types.Node | ||||||
| @ -1863,13 +2129,13 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.1"), | 				IPv4: iap("100.64.0.1"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), | 				IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), | ||||||
| 				User: types.User{Name: "mickael"}, | 				User: users[0], | ||||||
| 			}, | 			}, | ||||||
| 			peers: types.Nodes{ | 			peers: types.Nodes{ | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), | 					IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), | ||||||
| 					User: types.User{Name: "mickael"}, | 					User: users[0], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{}, | 			want: []tailcfg.FilterRule{}, | ||||||
| @ -1896,7 +2162,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.1"), | 				IPv4: iap("100.64.0.1"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0::1"), | 				IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 				User: types.User{Name: "user1"}, | 				User: users[1], | ||||||
| 				Hostinfo: &tailcfg.Hostinfo{ | 				Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 					RoutableIPs: []netip.Prefix{ | 					RoutableIPs: []netip.Prefix{ | ||||||
| 						netip.MustParsePrefix("10.33.0.0/16"), | 						netip.MustParsePrefix("10.33.0.0/16"), | ||||||
| @ -1907,7 +2173,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||||
| 					User: types.User{Name: "user1"}, | 					User: users[1], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{ | 			want: []tailcfg.FilterRule{ | ||||||
| @ -1975,19 +2241,19 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.1"), | 				IPv4: iap("100.64.0.1"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0::1"), | 				IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 				User: types.User{Name: "user1"}, | 				User: users[1], | ||||||
| 			}, | 			}, | ||||||
| 			peers: types.Nodes{ | 			peers: types.Nodes{ | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||||
| 					User: types.User{Name: "user2"}, | 					User: users[2], | ||||||
| 				}, | 				}, | ||||||
| 				// "internal" exit node | 				// "internal" exit node | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.100"), | 					IPv4: iap("100.64.0.100"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::100"), | 					IPv6: iap("fd7a:115c:a1e0::100"), | ||||||
| 					User: types.User{Name: "user100"}, | 					User: users[3], | ||||||
| 					Hostinfo: &tailcfg.Hostinfo{ | 					Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 						RoutableIPs: tsaddr.ExitRoutes(), | 						RoutableIPs: tsaddr.ExitRoutes(), | ||||||
| 					}, | 					}, | ||||||
| @ -2034,12 +2300,12 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||||
| 					User: types.User{Name: "user2"}, | 					User: users[2], | ||||||
| 				}, | 				}, | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.1"), | 					IPv4: iap("100.64.0.1"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 					User: types.User{Name: "user1"}, | 					User: users[1], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{ | 			want: []tailcfg.FilterRule{ | ||||||
| @ -2131,7 +2397,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.100"), | 				IPv4: iap("100.64.0.100"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||||
| 				User: types.User{Name: "user100"}, | 				User: users[3], | ||||||
| 				Hostinfo: &tailcfg.Hostinfo{ | 				Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 					RoutableIPs: tsaddr.ExitRoutes(), | 					RoutableIPs: tsaddr.ExitRoutes(), | ||||||
| 				}, | 				}, | ||||||
| @ -2140,12 +2406,12 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||||
| 					User: types.User{Name: "user2"}, | 					User: users[2], | ||||||
| 				}, | 				}, | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.1"), | 					IPv4: iap("100.64.0.1"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 					User: types.User{Name: "user1"}, | 					User: users[1], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{ | 			want: []tailcfg.FilterRule{ | ||||||
| @ -2243,7 +2509,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.100"), | 				IPv4: iap("100.64.0.100"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||||
| 				User: types.User{Name: "user100"}, | 				User: users[3], | ||||||
| 				Hostinfo: &tailcfg.Hostinfo{ | 				Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 					RoutableIPs: []netip.Prefix{ | 					RoutableIPs: []netip.Prefix{ | ||||||
| 						netip.MustParsePrefix("8.0.0.0/16"), | 						netip.MustParsePrefix("8.0.0.0/16"), | ||||||
| @ -2255,12 +2521,12 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||||
| 					User: types.User{Name: "user2"}, | 					User: users[2], | ||||||
| 				}, | 				}, | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.1"), | 					IPv4: iap("100.64.0.1"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 					User: types.User{Name: "user1"}, | 					User: users[1], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{ | 			want: []tailcfg.FilterRule{ | ||||||
| @ -2333,7 +2599,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.100"), | 				IPv4: iap("100.64.0.100"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||||
| 				User: types.User{Name: "user100"}, | 				User: users[3], | ||||||
| 				Hostinfo: &tailcfg.Hostinfo{ | 				Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 					RoutableIPs: []netip.Prefix{ | 					RoutableIPs: []netip.Prefix{ | ||||||
| 						netip.MustParsePrefix("8.0.0.0/8"), | 						netip.MustParsePrefix("8.0.0.0/8"), | ||||||
| @ -2345,12 +2611,12 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.2"), | 					IPv4: iap("100.64.0.2"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::2"), | 					IPv6: iap("fd7a:115c:a1e0::2"), | ||||||
| 					User: types.User{Name: "user2"}, | 					User: users[2], | ||||||
| 				}, | 				}, | ||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.1"), | 					IPv4: iap("100.64.0.1"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 					User: types.User{Name: "user1"}, | 					User: users[1], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{ | 			want: []tailcfg.FilterRule{ | ||||||
| @ -2416,7 +2682,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 			node: &types.Node{ | 			node: &types.Node{ | ||||||
| 				IPv4: iap("100.64.0.100"), | 				IPv4: iap("100.64.0.100"), | ||||||
| 				IPv6: iap("fd7a:115c:a1e0::100"), | 				IPv6: iap("fd7a:115c:a1e0::100"), | ||||||
| 				User: types.User{Name: "user100"}, | 				User: users[3], | ||||||
| 				Hostinfo: &tailcfg.Hostinfo{ | 				Hostinfo: &tailcfg.Hostinfo{ | ||||||
| 					RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, | 					RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, | ||||||
| 				}, | 				}, | ||||||
| @ -2426,7 +2692,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 				&types.Node{ | 				&types.Node{ | ||||||
| 					IPv4: iap("100.64.0.1"), | 					IPv4: iap("100.64.0.1"), | ||||||
| 					IPv6: iap("fd7a:115c:a1e0::1"), | 					IPv6: iap("fd7a:115c:a1e0::1"), | ||||||
| 					User: types.User{Name: "user1"}, | 					User: users[1], | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 			want: []tailcfg.FilterRule{ | 			want: []tailcfg.FilterRule{ | ||||||
| @ -2454,6 +2720,7 @@ func TestReduceFilterRules(t *testing.T) { | |||||||
| 	for _, tt := range tests { | 	for _, tt := range tests { | ||||||
| 		t.Run(tt.name, func(t *testing.T) { | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 			got, _ := tt.pol.CompileFilterRules( | 			got, _ := tt.pol.CompileFilterRules( | ||||||
|  | 				users, | ||||||
| 				append(tt.peers, tt.node), | 				append(tt.peers, tt.node), | ||||||
| 			) | 			) | ||||||
| 
 | 
 | ||||||
| @ -3461,7 +3728,7 @@ func TestSSHRules(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	for _, tt := range tests { | 	for _, tt := range tests { | ||||||
| 		t.Run(tt.name, func(t *testing.T) { | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 			got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) | 			got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			if diff := cmp.Diff(tt.want, got); diff != "" { | 			if diff := cmp.Diff(tt.want, got); diff != "" { | ||||||
| @ -3544,14 +3811,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { | |||||||
| 		RequestTags: []string{"tag:test"}, | 		RequestTags: []string{"tag:test"}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	user := types.User{ | ||||||
|  | 		Model: gorm.Model{ID: 1}, | ||||||
|  | 		Name:  "user1", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	node := &types.Node{ | 	node := &types.Node{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		Hostname:       "testnodes", | 		Hostname:       "testnodes", | ||||||
| 		IPv4:           iap("100.64.0.1"), | 		IPv4:           iap("100.64.0.1"), | ||||||
| 		UserID:         0, | 		UserID:         0, | ||||||
| 		User: types.User{ | 		User:           user, | ||||||
| 			Name: "user1", |  | ||||||
| 		}, |  | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		Hostinfo:       &hostInfo, | 		Hostinfo:       &hostInfo, | ||||||
| 	} | 	} | ||||||
| @ -3568,7 +3838,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) | 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	want := []tailcfg.FilterRule{ | 	want := []tailcfg.FilterRule{ | ||||||
| @ -3602,6 +3872,7 @@ func TestInvalidTagValidUser(t *testing.T) { | |||||||
| 		IPv4:     iap("100.64.0.1"), | 		IPv4:     iap("100.64.0.1"), | ||||||
| 		UserID:   1, | 		UserID:   1, | ||||||
| 		User: types.User{ | 		User: types.User{ | ||||||
|  | 			Model: gorm.Model{ID: 1}, | ||||||
| 			Name:  "user1", | 			Name:  "user1", | ||||||
| 		}, | 		}, | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| @ -3619,7 +3890,12 @@ func TestInvalidTagValidUser(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) | 	got, _, err := GenerateFilterAndSSHRulesForTests( | ||||||
|  | 		pol, | ||||||
|  | 		node, | ||||||
|  | 		types.Nodes{}, | ||||||
|  | 		[]types.User{node.User}, | ||||||
|  | 	) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	want := []tailcfg.FilterRule{ | 	want := []tailcfg.FilterRule{ | ||||||
| @ -3653,6 +3929,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { | |||||||
| 		IPv4:     iap("100.64.0.1"), | 		IPv4:     iap("100.64.0.1"), | ||||||
| 		UserID:   1, | 		UserID:   1, | ||||||
| 		User: types.User{ | 		User: types.User{ | ||||||
|  | 			Model: gorm.Model{ID: 1}, | ||||||
| 			Name:  "user1", | 			Name:  "user1", | ||||||
| 		}, | 		}, | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| @ -3678,7 +3955,12 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { | |||||||
| 	// c.Assert(rules[0].DstPorts, check.HasLen, 1) | 	// c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||||
| 	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | 	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||||
| 
 | 
 | ||||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) | 	got, _, err := GenerateFilterAndSSHRulesForTests( | ||||||
|  | 		pol, | ||||||
|  | 		node, | ||||||
|  | 		types.Nodes{}, | ||||||
|  | 		[]types.User{node.User}, | ||||||
|  | 	) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	want := []tailcfg.FilterRule{ | 	want := []tailcfg.FilterRule{ | ||||||
| @ -3707,15 +3989,17 @@ func TestValidTagInvalidUser(t *testing.T) { | |||||||
| 		Hostname:    "webserver", | 		Hostname:    "webserver", | ||||||
| 		RequestTags: []string{"tag:webapp"}, | 		RequestTags: []string{"tag:webapp"}, | ||||||
| 	} | 	} | ||||||
|  | 	user := types.User{ | ||||||
|  | 		Model: gorm.Model{ID: 1}, | ||||||
|  | 		Name:  "user1", | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	node := &types.Node{ | 	node := &types.Node{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		Hostname:       "webserver", | 		Hostname:       "webserver", | ||||||
| 		IPv4:           iap("100.64.0.1"), | 		IPv4:           iap("100.64.0.1"), | ||||||
| 		UserID:         1, | 		UserID:         1, | ||||||
| 		User: types.User{ | 		User:           user, | ||||||
| 			Name: "user1", |  | ||||||
| 		}, |  | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		Hostinfo:       &hostInfo, | 		Hostinfo:       &hostInfo, | ||||||
| 	} | 	} | ||||||
| @ -3730,9 +4014,7 @@ func TestValidTagInvalidUser(t *testing.T) { | |||||||
| 		Hostname:       "user", | 		Hostname:       "user", | ||||||
| 		IPv4:           iap("100.64.0.2"), | 		IPv4:           iap("100.64.0.2"), | ||||||
| 		UserID:         1, | 		UserID:         1, | ||||||
| 		User: types.User{ | 		User:           user, | ||||||
| 			Name: "user1", |  | ||||||
| 		}, |  | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		Hostinfo:       &hostInfo2, | 		Hostinfo:       &hostInfo2, | ||||||
| 	} | 	} | ||||||
| @ -3748,7 +4030,12 @@ func TestValidTagInvalidUser(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) | 	got, _, err := GenerateFilterAndSSHRulesForTests( | ||||||
|  | 		pol, | ||||||
|  | 		node, | ||||||
|  | 		types.Nodes{nodes2}, | ||||||
|  | 		[]types.User{user}, | ||||||
|  | 	) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	want := []tailcfg.FilterRule{ | 	want := []tailcfg.FilterRule{ | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ package types | |||||||
| import ( | import ( | ||||||
| 	"cmp" | 	"cmp" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
|  | 	"net/mail" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| @ -56,14 +57,7 @@ type User struct { | |||||||
| // should be used throughout headscale, in information returned to the | // should be used throughout headscale, in information returned to the | ||||||
| // user and the Policy engine. | // user and the Policy engine. | ||||||
| func (u *User) Username() string { | func (u *User) Username() string { | ||||||
| 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) | 	return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) | ||||||
| 
 |  | ||||||
| 	// TODO(kradalby): Wire up all of this for the future |  | ||||||
| 	// if !strings.Contains(username, "@") { |  | ||||||
| 	// 	username = username + "@" |  | ||||||
| 	// } |  | ||||||
| 
 |  | ||||||
| 	return username |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // DisplayNameOrUsername returns the DisplayName if it exists, otherwise | // DisplayNameOrUsername returns the DisplayName if it exists, otherwise | ||||||
| @ -146,12 +140,20 @@ func (c *OIDCClaims) Identifier() string { | |||||||
| // FromClaim overrides a User from OIDC claims. | // FromClaim overrides a User from OIDC claims. | ||||||
| // All fields will be updated, except for the ID. | // All fields will be updated, except for the ID. | ||||||
| func (u *User) FromClaim(claims *OIDCClaims) { | func (u *User) FromClaim(claims *OIDCClaims) { | ||||||
| 	u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} | 	err := util.CheckForFQDNRules(claims.Username) | ||||||
| 	u.DisplayName = claims.Name | 	if err == nil { | ||||||
|  | 		u.Name = claims.Username | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if claims.EmailVerified { | 	if claims.EmailVerified { | ||||||
|  | 		_, err = mail.ParseAddress(claims.Email) | ||||||
|  | 		if err == nil { | ||||||
| 			u.Email = claims.Email | 			u.Email = claims.Email | ||||||
| 		} | 		} | ||||||
| 	u.Name = claims.Username | 	} | ||||||
|  | 
 | ||||||
|  | 	u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} | ||||||
|  | 	u.DisplayName = claims.Name | ||||||
| 	u.ProfilePicURL = claims.ProfilePictureURL | 	u.ProfilePicURL = claims.ProfilePictureURL | ||||||
| 	u.Provider = util.RegisterMethodOIDC | 	u.Provider = util.RegisterMethodOIDC | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user