mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 08:01:34 +01:00 
			
		
		
		
	remove "stripEmailDomain" argument
This commit makes a wrapper function round the normalisation requiring "stripEmailDomain" which has to be passed in almost all functions of headscale by loading it from Viper instead. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									161243c787
								
							
						
					
					
						commit
						717abe89c1
					
				| @ -169,7 +169,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { | ||||
| 	database, err := db.NewHeadscaleDatabase( | ||||
| 		cfg.DBtype, | ||||
| 		dbString, | ||||
| 		cfg.OIDC.StripEmaildomain, | ||||
| 		app.dbDebug, | ||||
| 		app.stateUpdateChan, | ||||
| 		cfg.IPPrefixes, | ||||
|  | ||||
| @ -53,7 +53,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -107,7 +107,7 @@ func TestInvalidTagValidUser(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -169,7 +169,7 @@ func TestPortGroup(t *testing.T) { | ||||
| 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -224,7 +224,7 @@ func TestPortUser(t *testing.T) { | ||||
| 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -285,7 +285,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { | ||||
| 	// c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||
| 	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||
| 
 | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
| @ -361,7 +361,7 @@ func TestValidTagInvalidUser(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false) | ||||
| 	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	want := []tailcfg.FilterRule{ | ||||
|  | ||||
| @ -41,16 +41,15 @@ type HSDatabase struct { | ||||
| 
 | ||||
| 	ipAllocationMutex sync.Mutex | ||||
| 
 | ||||
| 	ipPrefixes       []netip.Prefix | ||||
| 	baseDomain       string | ||||
| 	stripEmailDomain bool | ||||
| 	ipPrefixes []netip.Prefix | ||||
| 	baseDomain string | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby): assemble this struct from toptions or something typed | ||||
| // rather than arguments. | ||||
| func NewHeadscaleDatabase( | ||||
| 	dbType, connectionAddr string, | ||||
| 	stripEmailDomain, debug bool, | ||||
| 	debug bool, | ||||
| 	notifyStateChan chan<- struct{}, | ||||
| 	ipPrefixes []netip.Prefix, | ||||
| 	baseDomain string, | ||||
| @ -64,9 +63,8 @@ func NewHeadscaleDatabase( | ||||
| 		db:              dbConn, | ||||
| 		notifyStateChan: notifyStateChan, | ||||
| 
 | ||||
| 		ipPrefixes:       ipPrefixes, | ||||
| 		baseDomain:       baseDomain, | ||||
| 		stripEmailDomain: stripEmailDomain, | ||||
| 		ipPrefixes: ipPrefixes, | ||||
| 		baseDomain: baseDomain, | ||||
| 	} | ||||
| 
 | ||||
| 	log.Debug().Msgf("database %#v", dbConn) | ||||
| @ -202,9 +200,8 @@ func NewHeadscaleDatabase( | ||||
| 
 | ||||
| 		for item, machine := range machines { | ||||
| 			if machine.GivenName == "" { | ||||
| 				normalizedHostname, err := util.NormalizeToFQDNRules( | ||||
| 				normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( | ||||
| 					machine.Hostname, | ||||
| 					stripEmailDomain, | ||||
| 				) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
|  | ||||
| @ -632,9 +632,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { | ||||
| 	normalizedHostname, err := util.NormalizeToFQDNRules( | ||||
| 	normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( | ||||
| 		suppliedName, | ||||
| 		hsdb.stripEmailDomain, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
|  | ||||
| @ -293,10 +293,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { | ||||
| 	testPeers, err := db.ListPeers(testMachine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false) | ||||
| 	adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false) | ||||
| 	testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules) | ||||
| @ -482,9 +482,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "simple machine name generation", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "testmachine", | ||||
| 				randomSuffix: false, | ||||
| @ -494,9 +492,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 53 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", | ||||
| 				randomSuffix: false, | ||||
| @ -506,9 +502,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 63 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: false, | ||||
| @ -518,9 +512,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 64 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", | ||||
| 				randomSuffix: false, | ||||
| @ -530,9 +522,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 73 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: false, | ||||
| @ -542,9 +532,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with random suffix", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "test", | ||||
| 				randomSuffix: true, | ||||
| @ -554,9 +542,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 63 chars with random suffix", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: true, | ||||
|  | ||||
| @ -424,7 +424,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | ||||
| 				approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||
| 			} else { | ||||
| 				// TODO(kradalby): figure out how to get this to depend on less stuff | ||||
| 				approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) | ||||
| 				approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias) | ||||
| 				if err != nil { | ||||
| 					log.Err(err). | ||||
| 						Str("alias", approvedAlias). | ||||
|  | ||||
| @ -60,7 +60,6 @@ func (s *Suite) ResetDB(c *check.C) { | ||||
| 		"sqlite3", | ||||
| 		tmpDir+"/headscale_test.db", | ||||
| 		false, | ||||
| 		false, | ||||
| 		sink, | ||||
| 		[]netip.Prefix{ | ||||
| 			netip.MustParsePrefix("10.27.0.0/23"), | ||||
|  | ||||
| @ -340,7 +340,6 @@ func (api headscaleV1APIServer) ListMachines( | ||||
| 		m := machine.Proto() | ||||
| 		validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( | ||||
| 			machine, | ||||
| 			api.h.cfg.OIDC.StripEmaildomain, | ||||
| 		) | ||||
| 		m.InvalidTags = invalidTags | ||||
| 		m.ValidTags = validTags | ||||
|  | ||||
| @ -41,7 +41,6 @@ type Mapper struct { | ||||
| 	dnsCfg           *tailcfg.DNSConfig | ||||
| 	logtail          bool | ||||
| 	randomClientPort bool | ||||
| 	stripEmailDomain bool | ||||
| } | ||||
| 
 | ||||
| func NewMapper( | ||||
| @ -53,7 +52,6 @@ func NewMapper( | ||||
| 	dnsCfg *tailcfg.DNSConfig, | ||||
| 	logtail bool, | ||||
| 	randomClientPort bool, | ||||
| 	stripEmailDomain bool, | ||||
| ) *Mapper { | ||||
| 	return &Mapper{ | ||||
| 		db: db, | ||||
| @ -66,7 +64,6 @@ func NewMapper( | ||||
| 		dnsCfg:           dnsCfg, | ||||
| 		logtail:          logtail, | ||||
| 		randomClientPort: randomClientPort, | ||||
| 		stripEmailDomain: stripEmailDomain, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -87,14 +84,13 @@ func fullMapResponse( | ||||
| 	machine *types.Machine, | ||||
| 	peers types.Machines, | ||||
| 
 | ||||
| 	stripEmailDomain bool, | ||||
| 	baseDomain string, | ||||
| 	dnsCfg *tailcfg.DNSConfig, | ||||
| 	derpMap *tailcfg.DERPMap, | ||||
| 	logtail bool, | ||||
| 	randomClientPort bool, | ||||
| ) (*tailcfg.MapResponse, error) { | ||||
| 	tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain, stripEmailDomain) | ||||
| 	tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -103,7 +99,6 @@ func fullMapResponse( | ||||
| 		pol, | ||||
| 		machine, | ||||
| 		peers, | ||||
| 		stripEmailDomain, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -129,7 +124,7 @@ func fullMapResponse( | ||||
| 		peers, | ||||
| 	) | ||||
| 
 | ||||
| 	tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain, stripEmailDomain) | ||||
| 	tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -296,7 +291,6 @@ func (m Mapper) CreateMapResponse( | ||||
| 		pol, | ||||
| 		machine, | ||||
| 		peers, | ||||
| 		m.stripEmailDomain, | ||||
| 		m.baseDomain, | ||||
| 		m.dnsCfg, | ||||
| 		m.derpMap, | ||||
|  | ||||
| @ -320,7 +320,6 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		machine *types.Machine | ||||
| 		peers   types.Machines | ||||
| 
 | ||||
| 		stripEmailDomain bool | ||||
| 		baseDomain       string | ||||
| 		dnsConfig        *tailcfg.DNSConfig | ||||
| 		derpMap          *tailcfg.DERPMap | ||||
| @ -335,7 +334,6 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		// 	pol:              &policy.ACLPolicy{}, | ||||
| 		// 	dnsConfig:        &tailcfg.DNSConfig{}, | ||||
| 		// 	baseDomain:       "", | ||||
| 		// 	stripEmailDomain: false, | ||||
| 		// 	want:             nil, | ||||
| 		// 	wantErr:          true, | ||||
| 		// }, | ||||
| @ -344,7 +342,6 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 			pol:              &policy.ACLPolicy{}, | ||||
| 			machine:          mini, | ||||
| 			peers:            []types.Machine{}, | ||||
| 			stripEmailDomain: false, | ||||
| 			baseDomain:       "", | ||||
| 			dnsConfig:        &tailcfg.DNSConfig{}, | ||||
| 			derpMap:          &tailcfg.DERPMap{}, | ||||
| @ -375,7 +372,6 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 			peers: []types.Machine{ | ||||
| 				peer1, | ||||
| 			}, | ||||
| 			stripEmailDomain: false, | ||||
| 			baseDomain:       "", | ||||
| 			dnsConfig:        &tailcfg.DNSConfig{}, | ||||
| 			derpMap:          &tailcfg.DERPMap{}, | ||||
| @ -417,7 +413,6 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 				peer1, | ||||
| 				peer2, | ||||
| 			}, | ||||
| 			stripEmailDomain: false, | ||||
| 			baseDomain:       "", | ||||
| 			dnsConfig:        &tailcfg.DNSConfig{}, | ||||
| 			derpMap:          &tailcfg.DERPMap{}, | ||||
| @ -458,7 +453,6 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 				tt.pol, | ||||
| 				tt.machine, | ||||
| 				tt.peers, | ||||
| 				tt.stripEmailDomain, | ||||
| 				tt.baseDomain, | ||||
| 				tt.dnsConfig, | ||||
| 				tt.derpMap, | ||||
|  | ||||
| @ -18,7 +18,6 @@ func tailNodes( | ||||
| 	pol *policy.ACLPolicy, | ||||
| 	dnsConfig *tailcfg.DNSConfig, | ||||
| 	baseDomain string, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]*tailcfg.Node, error) { | ||||
| 	nodes := make([]*tailcfg.Node, len(machines)) | ||||
| 
 | ||||
| @ -28,7 +27,6 @@ func tailNodes( | ||||
| 			pol, | ||||
| 			dnsConfig, | ||||
| 			baseDomain, | ||||
| 			stripEmailDomain, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| @ -47,7 +45,6 @@ func tailNode( | ||||
| 	pol *policy.ACLPolicy, | ||||
| 	dnsConfig *tailcfg.DNSConfig, | ||||
| 	baseDomain string, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*tailcfg.Node, error) { | ||||
| 	nodeKey, err := machine.NodePublicKey() | ||||
| 	if err != nil { | ||||
| @ -107,7 +104,7 @@ func tailNode( | ||||
| 
 | ||||
| 	online := machine.IsOnline() | ||||
| 
 | ||||
| 	tags, _ := pol.GetTagsOfMachine(machine, stripEmailDomain) | ||||
| 	tags, _ := pol.GetTagsOfMachine(machine) | ||||
| 	tags = lo.Uniq(append(tags, machine.ForcedTags...)) | ||||
| 
 | ||||
| 	node := tailcfg.Node{ | ||||
|  | ||||
| @ -44,24 +44,22 @@ func TestTailNode(t *testing.T) { | ||||
| 	expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		name             string | ||||
| 		machine          types.Machine | ||||
| 		pol              *policy.ACLPolicy | ||||
| 		dnsConfig        *tailcfg.DNSConfig | ||||
| 		baseDomain       string | ||||
| 		stripEmailDomain bool | ||||
| 		want             *tailcfg.Node | ||||
| 		wantErr          bool | ||||
| 		name       string | ||||
| 		machine    types.Machine | ||||
| 		pol        *policy.ACLPolicy | ||||
| 		dnsConfig  *tailcfg.DNSConfig | ||||
| 		baseDomain string | ||||
| 		want       *tailcfg.Node | ||||
| 		wantErr    bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:             "empty-machine", | ||||
| 			machine:          types.Machine{}, | ||||
| 			pol:              &policy.ACLPolicy{}, | ||||
| 			dnsConfig:        &tailcfg.DNSConfig{}, | ||||
| 			baseDomain:       "", | ||||
| 			stripEmailDomain: false, | ||||
| 			want:             nil, | ||||
| 			wantErr:          true, | ||||
| 			name:       "empty-machine", | ||||
| 			machine:    types.Machine{}, | ||||
| 			pol:        &policy.ACLPolicy{}, | ||||
| 			dnsConfig:  &tailcfg.DNSConfig{}, | ||||
| 			baseDomain: "", | ||||
| 			want:       nil, | ||||
| 			wantErr:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "minimal-machine", | ||||
| @ -108,10 +106,9 @@ func TestTailNode(t *testing.T) { | ||||
| 				}, | ||||
| 				CreatedAt: created, | ||||
| 			}, | ||||
| 			pol:              &policy.ACLPolicy{}, | ||||
| 			dnsConfig:        &tailcfg.DNSConfig{}, | ||||
| 			baseDomain:       "", | ||||
| 			stripEmailDomain: false, | ||||
| 			pol:        &policy.ACLPolicy{}, | ||||
| 			dnsConfig:  &tailcfg.DNSConfig{}, | ||||
| 			baseDomain: "", | ||||
| 			want: &tailcfg.Node{ | ||||
| 				ID:       0, | ||||
| 				StableID: "0", | ||||
| @ -172,7 +169,6 @@ func TestTailNode(t *testing.T) { | ||||
| 				tt.pol, | ||||
| 				tt.dnsConfig, | ||||
| 				tt.baseDomain, | ||||
| 				tt.stripEmailDomain, | ||||
| 			) | ||||
| 
 | ||||
| 			if (err != nil) != tt.wantErr { | ||||
|  | ||||
| @ -121,14 +121,13 @@ func GenerateFilterRules( | ||||
| 	policy *ACLPolicy, | ||||
| 	machine *types.Machine, | ||||
| 	peers types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { | ||||
| 	// If there is no policy defined, we default to allow all | ||||
| 	if policy == nil { | ||||
| 		return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain) | ||||
| 	rules, err := policy.generateFilterRules(machine, peers) | ||||
| 	if err != nil { | ||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||
| 	} | ||||
| @ -136,7 +135,7 @@ func GenerateFilterRules( | ||||
| 	log.Trace().Interface("ACL", rules).Msg("ACL rules generated") | ||||
| 
 | ||||
| 	var sshPolicy *tailcfg.SSHPolicy | ||||
| 	sshRules, err := policy.generateSSHRules(machine, peers, stripEmailDomain) | ||||
| 	sshRules, err := policy.generateSSHRules(machine, peers) | ||||
| 	if err != nil { | ||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||
| 	} | ||||
| @ -154,7 +153,6 @@ func GenerateFilterRules( | ||||
| func (pol *ACLPolicy) generateFilterRules( | ||||
| 	machine *types.Machine, | ||||
| 	peers types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]tailcfg.FilterRule, error) { | ||||
| 	rules := []tailcfg.FilterRule{} | ||||
| 	machines := append(peers, *machine) | ||||
| @ -166,7 +164,7 @@ func (pol *ACLPolicy) generateFilterRules( | ||||
| 
 | ||||
| 		srcIPs := []string{} | ||||
| 		for srcIndex, src := range acl.Sources { | ||||
| 			srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain) | ||||
| 			srcs, err := pol.getIPsFromSource(src, machines) | ||||
| 			if err != nil { | ||||
| 				log.Error(). | ||||
| 					Interface("src", src). | ||||
| @ -193,7 +191,6 @@ func (pol *ACLPolicy) generateFilterRules( | ||||
| 				dest, | ||||
| 				machines, | ||||
| 				needsWildcard, | ||||
| 				stripEmailDomain, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				log.Error(). | ||||
| @ -220,7 +217,6 @@ func (pol *ACLPolicy) generateFilterRules( | ||||
| func (pol *ACLPolicy) generateSSHRules( | ||||
| 	machine *types.Machine, | ||||
| 	peers types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]*tailcfg.SSHRule, error) { | ||||
| 	rules := []*tailcfg.SSHRule{} | ||||
| 
 | ||||
| @ -247,7 +243,7 @@ func (pol *ACLPolicy) generateSSHRules( | ||||
| 	for index, sshACL := range pol.SSHs { | ||||
| 		var dest netipx.IPSetBuilder | ||||
| 		for _, src := range sshACL.Destinations { | ||||
| 			expanded, err := pol.ExpandAlias(append(peers, *machine), src, stripEmailDomain) | ||||
| 			expanded, err := pol.ExpandAlias(append(peers, *machine), src) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| @ -289,7 +285,7 @@ func (pol *ACLPolicy) generateSSHRules( | ||||
| 					Any: true, | ||||
| 				}) | ||||
| 			} else if isGroup(rawSrc) { | ||||
| 				users, err := pol.getUsersInGroup(rawSrc, stripEmailDomain) | ||||
| 				users, err := pol.getUsersInGroup(rawSrc) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| 						Msgf("Error parsing SSH %d, Source %d", index, innerIndex) | ||||
| @ -306,7 +302,6 @@ func (pol *ACLPolicy) generateSSHRules( | ||||
| 				expandedSrcs, err := pol.ExpandAlias( | ||||
| 					peers, | ||||
| 					rawSrc, | ||||
| 					stripEmailDomain, | ||||
| 				) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| @ -358,9 +353,8 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { | ||||
| func (pol *ACLPolicy) getIPsFromSource( | ||||
| 	src string, | ||||
| 	machines types.Machines, | ||||
| 	stripEmaildomain bool, | ||||
| ) ([]string, error) { | ||||
| 	ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) | ||||
| 	ipSet, err := pol.ExpandAlias(machines, src) | ||||
| 	if err != nil { | ||||
| 		return []string{}, err | ||||
| 	} | ||||
| @ -380,7 +374,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( | ||||
| 	dest string, | ||||
| 	machines types.Machines, | ||||
| 	needsWildcard bool, | ||||
| 	stripEmaildomain bool, | ||||
| ) ([]tailcfg.NetPortRange, error) { | ||||
| 	var tokens []string | ||||
| 
 | ||||
| @ -434,7 +427,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( | ||||
| 	expanded, err := pol.ExpandAlias( | ||||
| 		machines, | ||||
| 		alias, | ||||
| 		stripEmaildomain, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -519,7 +511,6 @@ func parseProtocol(protocol string) ([]int, bool, error) { | ||||
| func (pol *ACLPolicy) ExpandAlias( | ||||
| 	machines types.Machines, | ||||
| 	alias string, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	if isWildcard(alias) { | ||||
| 		return util.ParseIPSet("*", nil) | ||||
| @ -533,16 +524,16 @@ func (pol *ACLPolicy) ExpandAlias( | ||||
| 
 | ||||
| 	// if alias is a group | ||||
| 	if isGroup(alias) { | ||||
| 		return pol.getIPsFromGroup(alias, machines, stripEmailDomain) | ||||
| 		return pol.getIPsFromGroup(alias, machines) | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is a tag | ||||
| 	if isTag(alias) { | ||||
| 		return pol.getIPsFromTag(alias, machines, stripEmailDomain) | ||||
| 		return pol.getIPsFromTag(alias, machines) | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is a user | ||||
| 	if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil { | ||||
| 	if ips, err := pol.getIPsForUser(alias, machines); ips != nil { | ||||
| 		return ips, err | ||||
| 	} | ||||
| 
 | ||||
| @ -551,7 +542,7 @@ func (pol *ACLPolicy) ExpandAlias( | ||||
| 	if h, ok := pol.Hosts[alias]; ok { | ||||
| 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") | ||||
| 
 | ||||
| 		return pol.ExpandAlias(machines, h.String(), stripEmailDomain) | ||||
| 		return pol.ExpandAlias(machines, h.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is an IP | ||||
| @ -576,12 +567,11 @@ func excludeCorrectlyTaggedNodes( | ||||
| 	aclPolicy *ACLPolicy, | ||||
| 	nodes types.Machines, | ||||
| 	user string, | ||||
| 	stripEmailDomain bool, | ||||
| ) types.Machines { | ||||
| 	out := types.Machines{} | ||||
| 	tags := []string{} | ||||
| 	for tag := range aclPolicy.TagOwners { | ||||
| 		owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) | ||||
| 		owners, _ := getTagOwners(aclPolicy, user) | ||||
| 		ns := append(owners, user) | ||||
| 		if util.StringOrPrefixListContains(ns, user) { | ||||
| 			tags = append(tags, tag) | ||||
| @ -674,7 +664,6 @@ func filterMachinesByUser(machines types.Machines, user string) types.Machines { | ||||
| func getTagOwners( | ||||
| 	pol *ACLPolicy, | ||||
| 	tag string, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]string, error) { | ||||
| 	var owners []string | ||||
| 	ows, ok := pol.TagOwners[tag] | ||||
| @ -687,7 +676,7 @@ func getTagOwners( | ||||
| 	} | ||||
| 	for _, owner := range ows { | ||||
| 		if isGroup(owner) { | ||||
| 			gs, err := pol.getUsersInGroup(owner, stripEmailDomain) | ||||
| 			gs, err := pol.getUsersInGroup(owner) | ||||
| 			if err != nil { | ||||
| 				return []string{}, err | ||||
| 			} | ||||
| @ -704,7 +693,6 @@ func getTagOwners( | ||||
| // after some validation. | ||||
| func (pol *ACLPolicy) getUsersInGroup( | ||||
| 	group string, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]string, error) { | ||||
| 	users := []string{} | ||||
| 	log.Trace().Caller().Interface("pol", pol).Msg("test") | ||||
| @ -723,7 +711,7 @@ func (pol *ACLPolicy) getUsersInGroup( | ||||
| 				ErrInvalidGroup, | ||||
| 			) | ||||
| 		} | ||||
| 		grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) | ||||
| 		grp, err := util.NormalizeToFQDNRulesConfigFromViper(group) | ||||
| 		if err != nil { | ||||
| 			return []string{}, fmt.Errorf( | ||||
| 				"failed to normalize group %q, err: %w", | ||||
| @ -740,11 +728,10 @@ func (pol *ACLPolicy) getUsersInGroup( | ||||
| func (pol *ACLPolicy) getIPsFromGroup( | ||||
| 	group string, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| 
 | ||||
| 	users, err := pol.getUsersInGroup(group, stripEmailDomain) | ||||
| 	users, err := pol.getUsersInGroup(group) | ||||
| 	if err != nil { | ||||
| 		return &netipx.IPSet{}, err | ||||
| 	} | ||||
| @ -761,7 +748,6 @@ func (pol *ACLPolicy) getIPsFromGroup( | ||||
| func (pol *ACLPolicy) getIPsFromTag( | ||||
| 	alias string, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| 
 | ||||
| @ -773,7 +759,7 @@ func (pol *ACLPolicy) getIPsFromTag( | ||||
| 	} | ||||
| 
 | ||||
| 	// find tag owners | ||||
| 	owners, err := getTagOwners(pol, alias, stripEmailDomain) | ||||
| 	owners, err := getTagOwners(pol, alias) | ||||
| 	if err != nil { | ||||
| 		if errors.Is(err, ErrInvalidTag) { | ||||
| 			ipSet, _ := build.IPSet() | ||||
| @ -808,12 +794,11 @@ func (pol *ACLPolicy) getIPsFromTag( | ||||
| func (pol *ACLPolicy) getIPsForUser( | ||||
| 	user string, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| 
 | ||||
| 	filteredMachines := filterMachinesByUser(machines, user) | ||||
| 	filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain) | ||||
| 	filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user) | ||||
| 
 | ||||
| 	// shortcurcuit if we have no machines to get ips from. | ||||
| 	if len(filteredMachines) == 0 { | ||||
| @ -885,7 +870,6 @@ func isTag(str string) bool { | ||||
| // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. | ||||
| func (pol *ACLPolicy) GetTagsOfMachine( | ||||
| 	machine types.Machine, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]string, []string) { | ||||
| 	validTags := make([]string, 0) | ||||
| 	invalidTags := make([]string, 0) | ||||
| @ -893,7 +877,7 @@ func (pol *ACLPolicy) GetTagsOfMachine( | ||||
| 	validTagMap := make(map[string]bool) | ||||
| 	invalidTagMap := make(map[string]bool) | ||||
| 	for _, tag := range machine.HostInfo.RequestTags { | ||||
| 		owners, err := getTagOwners(pol, tag, stripEmailDomain) | ||||
| 		owners, err := getTagOwners(pol, tag) | ||||
| 		if errors.Is(err, ErrInvalidTag) { | ||||
| 			invalidTagMap[tag] = true | ||||
| 
 | ||||
|  | ||||
| @ -10,6 +10,7 @@ import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"github.com/spf13/viper" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"go4.org/netipx" | ||||
| 	"gopkg.in/check.v1" | ||||
| @ -199,7 +200,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { | ||||
| 	c.Assert(pol.ACLs, check.HasLen, 6) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	c.Assert(rules, check.IsNil) | ||||
| } | ||||
| @ -230,7 +231,7 @@ func (s *Suite) TestBasicRule(c *check.C) { | ||||
| 	pol, err := LoadACLPolicyFromBytes(acl, "hujson") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| } | ||||
| @ -246,7 +247,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) | ||||
| 	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| @ -265,7 +266,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) | ||||
| 	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| @ -281,7 +282,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) | ||||
| 	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| @ -310,7 +311,7 @@ func (s *Suite) TestPortRange(c *check.C) { | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| @ -366,7 +367,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) { | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| @ -401,7 +402,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| @ -428,7 +429,7 @@ acls: | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| @ -459,7 +460,7 @@ acls: | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) | ||||
| 	rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| @ -483,8 +484,8 @@ func Test_expandGroup(t *testing.T) { | ||||
| 		pol ACLPolicy | ||||
| 	} | ||||
| 	type args struct { | ||||
| 		group            string | ||||
| 		stripEmailDomain bool | ||||
| 		group      string | ||||
| 		stripEmail bool | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| @ -504,8 +505,7 @@ func Test_expandGroup(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				group:            "group:test", | ||||
| 				stripEmailDomain: true, | ||||
| 				group: "group:test", | ||||
| 			}, | ||||
| 			want:    []string{"user1", "user2", "user3"}, | ||||
| 			wantErr: false, | ||||
| @ -521,14 +521,13 @@ func Test_expandGroup(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				group:            "group:undefined", | ||||
| 				stripEmailDomain: true, | ||||
| 				group: "group:undefined", | ||||
| 			}, | ||||
| 			want:    []string{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Expand emails in group", | ||||
| 			name: "Expand emails in group strip domains", | ||||
| 			field: field{ | ||||
| 				pol: ACLPolicy{ | ||||
| 					Groups: Groups{ | ||||
| @ -540,8 +539,8 @@ func Test_expandGroup(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				group:            "group:admin", | ||||
| 				stripEmailDomain: true, | ||||
| 				group:      "group:admin", | ||||
| 				stripEmail: true, | ||||
| 			}, | ||||
| 			want:    []string{"joe.bar", "john.doe"}, | ||||
| 			wantErr: false, | ||||
| @ -559,8 +558,7 @@ func Test_expandGroup(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				group:            "group:admin", | ||||
| 				stripEmailDomain: false, | ||||
| 				group: "group:admin", | ||||
| 			}, | ||||
| 			want:    []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"}, | ||||
| 			wantErr: false, | ||||
| @ -568,17 +566,20 @@ func Test_expandGroup(t *testing.T) { | ||||
| 	} | ||||
| 	for _, test := range tests { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
| 			viper.Set("oidc.strip_email_domain", test.args.stripEmail) | ||||
| 
 | ||||
| 			got, err := test.field.pol.getUsersInGroup( | ||||
| 				test.args.group, | ||||
| 				test.args.stripEmailDomain, | ||||
| 			) | ||||
| 
 | ||||
| 			if (err != nil) != test.wantErr { | ||||
| 				t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr) | ||||
| 
 | ||||
| 				return | ||||
| 			} | ||||
| 			if !reflect.DeepEqual(got, test.want) { | ||||
| 				t.Errorf("expandGroup() = %v, want %v", got, test.want) | ||||
| 
 | ||||
| 			if diff := cmp.Diff(test.want, got); diff != "" { | ||||
| 				t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| @ -586,9 +587,8 @@ func Test_expandGroup(t *testing.T) { | ||||
| 
 | ||||
| func Test_expandTagOwners(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		aclPolicy        *ACLPolicy | ||||
| 		tag              string | ||||
| 		stripEmailDomain bool | ||||
| 		aclPolicy *ACLPolicy | ||||
| 		tag       string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| @ -602,8 +602,7 @@ func Test_expandTagOwners(t *testing.T) { | ||||
| 				aclPolicy: &ACLPolicy{ | ||||
| 					TagOwners: TagOwners{"tag:test": []string{"user1"}}, | ||||
| 				}, | ||||
| 				tag:              "tag:test", | ||||
| 				stripEmailDomain: true, | ||||
| 				tag: "tag:test", | ||||
| 			}, | ||||
| 			want:    []string{"user1"}, | ||||
| 			wantErr: false, | ||||
| @ -615,8 +614,7 @@ func Test_expandTagOwners(t *testing.T) { | ||||
| 					Groups:    Groups{"group:foo": []string{"user1", "user2"}}, | ||||
| 					TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, | ||||
| 				}, | ||||
| 				tag:              "tag:test", | ||||
| 				stripEmailDomain: true, | ||||
| 				tag: "tag:test", | ||||
| 			}, | ||||
| 			want:    []string{"user1", "user2"}, | ||||
| 			wantErr: false, | ||||
| @ -628,8 +626,7 @@ func Test_expandTagOwners(t *testing.T) { | ||||
| 					Groups:    Groups{"group:foo": []string{"user1", "user2"}}, | ||||
| 					TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}}, | ||||
| 				}, | ||||
| 				tag:              "tag:test", | ||||
| 				stripEmailDomain: true, | ||||
| 				tag: "tag:test", | ||||
| 			}, | ||||
| 			want:    []string{"user1", "user2", "user3"}, | ||||
| 			wantErr: false, | ||||
| @ -640,8 +637,7 @@ func Test_expandTagOwners(t *testing.T) { | ||||
| 				aclPolicy: &ACLPolicy{ | ||||
| 					TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}}, | ||||
| 				}, | ||||
| 				tag:              "tag:test", | ||||
| 				stripEmailDomain: true, | ||||
| 				tag: "tag:test", | ||||
| 			}, | ||||
| 			want:    []string{}, | ||||
| 			wantErr: true, | ||||
| @ -653,8 +649,7 @@ func Test_expandTagOwners(t *testing.T) { | ||||
| 					Groups:    Groups{"group:bar": []string{"user1", "user2"}}, | ||||
| 					TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}}, | ||||
| 				}, | ||||
| 				tag:              "tag:test", | ||||
| 				stripEmailDomain: true, | ||||
| 				tag: "tag:test", | ||||
| 			}, | ||||
| 			want:    []string{}, | ||||
| 			wantErr: true, | ||||
| @ -665,7 +660,6 @@ func Test_expandTagOwners(t *testing.T) { | ||||
| 			got, err := getTagOwners( | ||||
| 				test.args.aclPolicy, | ||||
| 				test.args.tag, | ||||
| 				test.args.stripEmailDomain, | ||||
| 			) | ||||
| 			if (err != nil) != test.wantErr { | ||||
| 				t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr) | ||||
| @ -861,10 +855,9 @@ func Test_expandAlias(t *testing.T) { | ||||
| 		pol ACLPolicy | ||||
| 	} | ||||
| 	type args struct { | ||||
| 		machines         types.Machines | ||||
| 		aclPolicy        ACLPolicy | ||||
| 		alias            string | ||||
| 		stripEmailDomain bool | ||||
| 		machines  types.Machines | ||||
| 		aclPolicy ACLPolicy | ||||
| 		alias     string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| @ -888,7 +881,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: set([]string{}, []string{ | ||||
| 				"0.0.0.0/0", | ||||
| @ -931,7 +923,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"100.64.0.1", "100.64.0.2", "100.64.0.3", | ||||
| @ -973,7 +964,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    set([]string{}, []string{}), | ||||
| 			wantErr: true, | ||||
| @ -984,9 +974,8 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				pol: ACLPolicy{}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				alias:            "10.0.0.3", | ||||
| 				machines:         types.Machines{}, | ||||
| 				stripEmailDomain: true, | ||||
| 				alias:    "10.0.0.3", | ||||
| 				machines: types.Machines{}, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"10.0.0.3", | ||||
| @ -999,9 +988,8 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				pol: ACLPolicy{}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				alias:            "10.0.0.1", | ||||
| 				machines:         types.Machines{}, | ||||
| 				stripEmailDomain: true, | ||||
| 				alias:    "10.0.0.1", | ||||
| 				machines: types.Machines{}, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"10.0.0.1", | ||||
| @ -1023,7 +1011,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"10.0.0.1", | ||||
| @ -1046,7 +1033,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", | ||||
| @ -1069,7 +1055,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1", | ||||
| @ -1086,9 +1071,8 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				alias:            "testy", | ||||
| 				machines:         types.Machines{}, | ||||
| 				stripEmailDomain: true, | ||||
| 				alias:    "testy", | ||||
| 				machines: types.Machines{}, | ||||
| 			}, | ||||
| 			want:    set([]string{}, []string{"10.0.0.132/32"}), | ||||
| 			wantErr: false, | ||||
| @ -1103,9 +1087,8 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				alias:            "homeNetwork", | ||||
| 				machines:         types.Machines{}, | ||||
| 				stripEmailDomain: true, | ||||
| 				alias:    "homeNetwork", | ||||
| 				machines: types.Machines{}, | ||||
| 			}, | ||||
| 			want:    set([]string{}, []string{"192.168.1.0/24"}), | ||||
| 			wantErr: false, | ||||
| @ -1116,10 +1099,9 @@ func Test_expandAlias(t *testing.T) { | ||||
| 				pol: ACLPolicy{}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				alias:            "10.0.0.0/16", | ||||
| 				machines:         types.Machines{}, | ||||
| 				aclPolicy:        ACLPolicy{}, | ||||
| 				stripEmailDomain: true, | ||||
| 				alias:     "10.0.0.0/16", | ||||
| 				machines:  types.Machines{}, | ||||
| 				aclPolicy: ACLPolicy{}, | ||||
| 			}, | ||||
| 			want:    set([]string{}, []string{"10.0.0.0/16"}), | ||||
| 			wantErr: false, | ||||
| @ -1169,7 +1151,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: set([]string{ | ||||
| 				"100.64.0.1", "100.64.0.2", | ||||
| @ -1214,7 +1195,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    set([]string{}, []string{}), | ||||
| 			wantErr: true, | ||||
| @ -1254,7 +1234,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), | ||||
| 			wantErr: false, | ||||
| @ -1302,7 +1281,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), | ||||
| 			wantErr: false, | ||||
| @ -1352,7 +1330,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    set([]string{"100.64.0.4"}, []string{}), | ||||
| 			wantErr: false, | ||||
| @ -1363,7 +1340,6 @@ func Test_expandAlias(t *testing.T) { | ||||
| 			got, err := test.field.pol.ExpandAlias( | ||||
| 				test.args.machines, | ||||
| 				test.args.alias, | ||||
| 				test.args.stripEmailDomain, | ||||
| 			) | ||||
| 			if (err != nil) != test.wantErr { | ||||
| 				t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr) | ||||
| @ -1379,10 +1355,9 @@ func Test_expandAlias(t *testing.T) { | ||||
| 
 | ||||
| func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		aclPolicy        *ACLPolicy | ||||
| 		nodes            types.Machines | ||||
| 		user             string | ||||
| 		stripEmailDomain bool | ||||
| 		aclPolicy *ACLPolicy | ||||
| 		nodes     types.Machines | ||||
| 		user      string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| @ -1426,8 +1401,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user:             "joe", | ||||
| 				stripEmailDomain: true, | ||||
| 				user: "joe", | ||||
| 			}, | ||||
| 			want: types.Machines{ | ||||
| 				{ | ||||
| @ -1477,8 +1451,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user:             "joe", | ||||
| 				stripEmailDomain: true, | ||||
| 				user: "joe", | ||||
| 			}, | ||||
| 			want: types.Machines{ | ||||
| 				{ | ||||
| @ -1519,8 +1492,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user:             "joe", | ||||
| 				stripEmailDomain: true, | ||||
| 				user: "joe", | ||||
| 			}, | ||||
| 			want: types.Machines{ | ||||
| 				{ | ||||
| @ -1565,8 +1537,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user:             "joe", | ||||
| 				stripEmailDomain: true, | ||||
| 				user: "joe", | ||||
| 			}, | ||||
| 			want: types.Machines{ | ||||
| 				{ | ||||
| @ -1606,7 +1577,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 				test.args.aclPolicy, | ||||
| 				test.args.nodes, | ||||
| 				test.args.user, | ||||
| 				test.args.stripEmailDomain, | ||||
| 			) | ||||
| 			if !reflect.DeepEqual(got, test.want) { | ||||
| 				t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want) | ||||
| @ -1620,9 +1590,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | ||||
| 		pol ACLPolicy | ||||
| 	} | ||||
| 	type args struct { | ||||
| 		machine          types.Machine | ||||
| 		peers            types.Machines | ||||
| 		stripEmailDomain bool | ||||
| 		machine types.Machine | ||||
| 		peers   types.Machines | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| @ -1652,9 +1621,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				machine:          types.Machine{}, | ||||
| 				peers:            types.Machines{}, | ||||
| 				stripEmailDomain: true, | ||||
| 				machine: types.Machine{}, | ||||
| 				peers:   types.Machines{}, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| 				{ | ||||
| @ -1709,7 +1677,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | ||||
| 						User: types.User{Name: "mickael"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want: []tailcfg.FilterRule{ | ||||
| 				{ | ||||
| @ -1743,7 +1710,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | ||||
| 			got, err := tt.field.pol.generateFilterRules( | ||||
| 				&tt.args.machine, | ||||
| 				tt.args.peers, | ||||
| 				tt.args.stripEmailDomain, | ||||
| 			) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) | ||||
| @ -1761,9 +1727,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { | ||||
| 
 | ||||
| func Test_getTags(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		aclPolicy        *ACLPolicy | ||||
| 		machine          types.Machine | ||||
| 		stripEmailDomain bool | ||||
| 		aclPolicy *ACLPolicy | ||||
| 		machine   types.Machine | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name        string | ||||
| @ -1787,7 +1752,6 @@ func Test_getTags(t *testing.T) { | ||||
| 						RequestTags: []string{"tag:valid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			wantValid:   []string{"tag:valid"}, | ||||
| 			wantInvalid: nil, | ||||
| @ -1808,7 +1772,6 @@ func Test_getTags(t *testing.T) { | ||||
| 						RequestTags: []string{"tag:valid", "tag:invalid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			wantValid:   []string{"tag:valid"}, | ||||
| 			wantInvalid: []string{"tag:invalid"}, | ||||
| @ -1833,7 +1796,6 @@ func Test_getTags(t *testing.T) { | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			wantValid:   []string{"tag:valid"}, | ||||
| 			wantInvalid: []string{"tag:invalid"}, | ||||
| @ -1854,7 +1816,6 @@ func Test_getTags(t *testing.T) { | ||||
| 						RequestTags: []string{"tag:invalid", "very-invalid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			wantValid:   nil, | ||||
| 			wantInvalid: []string{"tag:invalid", "very-invalid"}, | ||||
| @ -1871,7 +1832,6 @@ func Test_getTags(t *testing.T) { | ||||
| 						RequestTags: []string{"tag:invalid", "very-invalid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			wantValid:   nil, | ||||
| 			wantInvalid: []string{"tag:invalid", "very-invalid"}, | ||||
| @ -1881,7 +1841,6 @@ func Test_getTags(t *testing.T) { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
| 			gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( | ||||
| 				test.args.machine, | ||||
| 				test.args.stripEmailDomain, | ||||
| 			) | ||||
| 			for _, valid := range gotValid { | ||||
| 				if !util.StringOrPrefixListContains(test.wantValid, valid) { | ||||
| @ -2589,7 +2548,7 @@ func TestSSHRules(t *testing.T) { | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers, false) | ||||
| 			got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			if diff := cmp.Diff(tt.want, got); diff != "" { | ||||
|  | ||||
| @ -40,7 +40,6 @@ func (h *Headscale) handlePoll( | ||||
| 		h.cfg.DNSConfig, | ||||
| 		h.cfg.LogTail.Enabled, | ||||
| 		h.cfg.RandomizeClientPort, | ||||
| 		h.cfg.OIDC.StripEmaildomain, | ||||
| 	) | ||||
| 
 | ||||
| 	machine.Hostname = mapRequest.Hostinfo.Hostname | ||||
| @ -265,7 +264,6 @@ func (h *Headscale) pollNetMapStream( | ||||
| 		h.cfg.DNSConfig, | ||||
| 		h.cfg.LogTail.Enabled, | ||||
| 		h.cfg.RandomizeClientPort, | ||||
| 		h.cfg.OIDC.StripEmaildomain, | ||||
| 	) | ||||
| 
 | ||||
| 	h.pollNetMapStreamWG.Add(1) | ||||
| @ -656,7 +654,6 @@ func (h *Headscale) scheduledPollWorker( | ||||
| 		h.cfg.DNSConfig, | ||||
| 		h.cfg.LogTail.Enabled, | ||||
| 		h.cfg.RandomizeClientPort, | ||||
| 		h.cfg.OIDC.StripEmaildomain, | ||||
| 	) | ||||
| 
 | ||||
| 	keepAliveTicker := time.NewTicker(keepAliveInterval) | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/spf13/viper" | ||||
| 	"go4.org/netipx" | ||||
| 	"tailscale.com/util/dnsname" | ||||
| ) | ||||
| @ -24,6 +25,12 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") | ||||
| 
 | ||||
| var ErrInvalidUserName = errors.New("invalid user name") | ||||
| 
 | ||||
| func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) { | ||||
| 	strip := viper.GetBool("oidc.strip_email_domain") | ||||
| 
 | ||||
| 	return NormalizeToFQDNRules(name, strip) | ||||
| } | ||||
| 
 | ||||
| // NormalizeToFQDNRules will replace forbidden chars in user | ||||
| // it can also return an error if the user doesn't respect RFC 952 and 1123. | ||||
| func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user