mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-30 23:51:03 +01:00 
			
		
		
		
	Split code into modules
This is a massive commit that restructures the code into modules:
db/
    All functions related to modifying the Database
types/
    All type definitions and methods that can be exclusivly used on
    these types without dependencies
policy/
    All Policy related code, now without dependencies on the Database.
policy/matcher/
    Dedicated code to match machines in a list of FilterRules
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
			
			
This commit is contained in:
		
							parent
							
								
									14e29a7bee
								
							
						
					
					
						commit
						feb15365b5
					
				| @ -7,7 +7,7 @@ import ( | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/pterm/pterm" | ||||
| 	"github.com/spf13/cobra" | ||||
| 	"google.golang.org/grpc/status" | ||||
| @ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData { | ||||
| 
 | ||||
| 			continue | ||||
| 		} | ||||
| 		if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 { | ||||
| 		if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 { | ||||
| 			isPrimaryStr = "-" | ||||
| 		} else { | ||||
| 			isPrimaryStr = strconv.FormatBool(route.IsPrimary) | ||||
|  | ||||
| @ -10,6 +10,7 @@ import ( | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"google.golang.org/grpc" | ||||
| @ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { | ||||
| 
 | ||||
| 	if cfg.ACL.PolicyPath != "" { | ||||
| 		aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) | ||||
| 		err = app.LoadACLPolicyFromPath(aclPath) | ||||
| 		pol, err := policy.LoadACLPolicyFromPath(aclPath) | ||||
| 		if err != nil { | ||||
| 			log.Fatal(). | ||||
| 				Str("path", aclPath). | ||||
| 				Err(err). | ||||
| 				Msg("Could not load the ACL policy") | ||||
| 		} | ||||
| 
 | ||||
| 		app.ACLPolicy = pol | ||||
| 	} | ||||
| 
 | ||||
| 	return app, nil | ||||
|  | ||||
| @ -18,9 +18,6 @@ const ( | ||||
| 	// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. | ||||
| 	registrationHoldoff        = time.Second * 5 | ||||
| 	reservedResponseHeaderSize = 4 | ||||
| 	RegisterMethodAuthKey      = "authkey" | ||||
| 	RegisterMethodOIDC         = "oidc" | ||||
| 	RegisterMethodCLI          = "cli" | ||||
| ) | ||||
| 
 | ||||
| var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( | ||||
| @ -56,7 +53,7 @@ func (h *Headscale) HealthHandler( | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err := h.db.pingDB(req.Context()); err != nil { | ||||
| 	if err := h.db.PingDB(req.Context()); err != nil { | ||||
| 		respond(err) | ||||
| 
 | ||||
| 		return | ||||
|  | ||||
| @ -3,6 +3,7 @@ package hscontrol | ||||
| import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"tailscale.com/tailcfg" | ||||
| @ -10,13 +11,13 @@ import ( | ||||
| 
 | ||||
| func (h *Headscale) generateMapResponse( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *Machine, | ||||
| 	machine *types.Machine, | ||||
| ) (*tailcfg.MapResponse, error) { | ||||
| 	log.Trace(). | ||||
| 		Str("func", "generateMapResponse"). | ||||
| 		Str("machine", mapRequest.Hostinfo.Hostname). | ||||
| 		Msg("Creating Map response") | ||||
| 	node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) | ||||
| 	node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| @ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) | ||||
| 	peers, err := h.db.GetValidPeers(h.aclRules, machine) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| @ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	profiles := h.db.getMapResponseUserProfiles(*machine, peers) | ||||
| 	profiles := h.db.GetMapResponseUserProfiles(*machine, peers) | ||||
| 
 | ||||
| 	nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) | ||||
| 	nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
|  | ||||
							
								
								
									
										166
									
								
								hscontrol/app.go
									
									
									
									
									
								
							
							
						
						
									
										166
									
								
								hscontrol/app.go
									
									
									
									
									
								
							| @ -23,6 +23,9 @@ import ( | ||||
| 	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" | ||||
| 	"github.com/juanfont/headscale" | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/db" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/patrickmn/go-cache" | ||||
| 	zerolog "github.com/philip-bui/grpc-zerolog" | ||||
| @ -73,7 +76,7 @@ const ( | ||||
| // Headscale represents the base app of the service. | ||||
| type Headscale struct { | ||||
| 	cfg             *Config | ||||
| 	db              *HSDatabase | ||||
| 	db              *db.HSDatabase | ||||
| 	dbString        string | ||||
| 	dbType          string | ||||
| 	dbDebug         bool | ||||
| @ -83,7 +86,7 @@ type Headscale struct { | ||||
| 	DERPMap    *tailcfg.DERPMap | ||||
| 	DERPServer *DERPServer | ||||
| 
 | ||||
| 	aclPolicy *ACLPolicy | ||||
| 	ACLPolicy *policy.ACLPolicy | ||||
| 	aclRules  []tailcfg.FilterRule | ||||
| 	sshPolicy *tailcfg.SSHPolicy | ||||
| 
 | ||||
| @ -99,6 +102,12 @@ type Headscale struct { | ||||
| 
 | ||||
| 	stateUpdateChan       chan struct{} | ||||
| 	cancelStateUpdateChan chan struct{} | ||||
| 
 | ||||
| 	// TODO(kradalby): Temporary measure to make sure we can update policy | ||||
| 	// across modules, will be removed when aclRules are no longer stored | ||||
| 	// globally but generated per node basis. | ||||
| 	policyUpdateChan       chan struct{} | ||||
| 	cancelPolicyUpdateChan chan struct{} | ||||
| } | ||||
| 
 | ||||
| func NewHeadscale(cfg *Config) (*Headscale, error) { | ||||
| @ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { | ||||
| 
 | ||||
| 	var dbString string | ||||
| 	switch cfg.DBtype { | ||||
| 	case Postgres: | ||||
| 	case db.Postgres: | ||||
| 		dbString = fmt.Sprintf( | ||||
| 			"host=%s dbname=%s user=%s", | ||||
| 			cfg.DBhost, | ||||
| @ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { | ||||
| 		if cfg.DBpass != "" { | ||||
| 			dbString += fmt.Sprintf(" password=%s", cfg.DBpass) | ||||
| 		} | ||||
| 	case Sqlite: | ||||
| 	case db.Sqlite: | ||||
| 		dbString = cfg.DBpath | ||||
| 	default: | ||||
| 		return nil, errUnsupportedDatabase | ||||
| @ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { | ||||
| 
 | ||||
| 		stateUpdateChan:       make(chan struct{}), | ||||
| 		cancelStateUpdateChan: make(chan struct{}), | ||||
| 
 | ||||
| 		policyUpdateChan:       make(chan struct{}), | ||||
| 		cancelPolicyUpdateChan: make(chan struct{}), | ||||
| 	} | ||||
| 
 | ||||
| 	go app.watchStateChannel() | ||||
| 	go app.watchPolicyChannel() | ||||
| 
 | ||||
| 	db, err := NewHeadscaleDatabase( | ||||
| 	database, err := db.NewHeadscaleDatabase( | ||||
| 		cfg.DBtype, | ||||
| 		dbString, | ||||
| 		cfg.OIDC.StripEmaildomain, | ||||
| 		app.dbDebug, | ||||
| 		app.stateUpdateChan, | ||||
| 		app.policyUpdateChan, | ||||
| 		cfg.IPPrefixes, | ||||
| 		cfg.BaseDomain) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	app.db = db | ||||
| 	app.db = database | ||||
| 
 | ||||
| 	if cfg.OIDC.Issuer != "" { | ||||
| 		err = app.initOIDC() | ||||
| @ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { | ||||
| func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { | ||||
| 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | ||||
| 	for range ticker.C { | ||||
| 		h.expireEphemeralNodesWorker() | ||||
| 		h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { | ||||
| func (h *Headscale) expireExpiredMachines(milliSeconds int64) { | ||||
| 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | ||||
| 	for range ticker.C { | ||||
| 		h.expireExpiredMachinesWorker() | ||||
| 		h.db.ExpireExpiredMachines(h.getLastStateChange()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { | ||||
| 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | ||||
| 	for range ticker.C { | ||||
| 		err := h.db.handlePrimarySubnetFailover() | ||||
| 		err := h.db.HandlePrimarySubnetFailover() | ||||
| 		if err != nil { | ||||
| 			log.Error().Err(err).Msg("failed to handle primary subnet failover") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) expireEphemeralNodesWorker() { | ||||
| 	users, err := h.db.ListUsers() | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Error listing users") | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	for _, user := range users { | ||||
| 		machines, err := h.db.ListMachinesByUser(user.Name) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Err(err). | ||||
| 				Str("user", user.Name). | ||||
| 				Msg("Error listing machines in user") | ||||
| 
 | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		expiredFound := false | ||||
| 		for _, machine := range machines { | ||||
| 			if machine.isEphemeral() && machine.LastSeen != nil && | ||||
| 				time.Now(). | ||||
| 					After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { | ||||
| 				expiredFound = true | ||||
| 				log.Info(). | ||||
| 					Str("machine", machine.Hostname). | ||||
| 					Msg("Ephemeral client removed from database") | ||||
| 
 | ||||
| 				err = h.db.db.Unscoped().Delete(machine).Error | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| 						Err(err). | ||||
| 						Str("machine", machine.Hostname). | ||||
| 						Msg("🤮 Cannot delete ephemeral machine from the database") | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if expiredFound { | ||||
| 			h.setLastStateChangeToNow() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) expireExpiredMachinesWorker() { | ||||
| 	users, err := h.db.ListUsers() | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Error listing users") | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	for _, user := range users { | ||||
| 		machines, err := h.db.ListMachinesByUser(user.Name) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Err(err). | ||||
| 				Str("user", user.Name). | ||||
| 				Msg("Error listing machines in user") | ||||
| 
 | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		expiredFound := false | ||||
| 		for index, machine := range machines { | ||||
| 			if machine.isExpired() && | ||||
| 				machine.Expiry.After(h.getLastStateChange(user)) { | ||||
| 				expiredFound = true | ||||
| 
 | ||||
| 				err := h.db.ExpireMachine(&machines[index]) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| 						Err(err). | ||||
| 						Str("machine", machine.Hostname). | ||||
| 						Str("name", machine.GivenName). | ||||
| 						Msg("🤮 Cannot expire machine") | ||||
| 				} else { | ||||
| 					log.Info(). | ||||
| 						Str("machine", machine.Hostname). | ||||
| 						Str("name", machine.GivenName). | ||||
| 						Msg("Machine successfully expired") | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if expiredFound { | ||||
| 			h.setLastStateChangeToNow() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, | ||||
| 	req interface{}, | ||||
| 	info *grpc.UnaryServerInfo, | ||||
| @ -565,6 +487,8 @@ func (h *Headscale) Serve() error { | ||||
| 		go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO(kradalby): These should have cancel channels and be cleaned | ||||
| 	// up on shutdown. | ||||
| 	go h.expireEphemeralNodes(updateInterval) | ||||
| 	go h.expireExpiredMachines(updateInterval) | ||||
| 
 | ||||
| @ -774,10 +698,12 @@ func (h *Headscale) Serve() error { | ||||
| 
 | ||||
| 				if h.cfg.ACL.PolicyPath != "" { | ||||
| 					aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) | ||||
| 					err := h.LoadACLPolicyFromPath(aclPath) | ||||
| 					pol, err := policy.LoadACLPolicyFromPath(aclPath) | ||||
| 					if err != nil { | ||||
| 						log.Error().Err(err).Msg("Failed to reload ACL policy") | ||||
| 					} | ||||
| 
 | ||||
| 					h.ACLPolicy = pol | ||||
| 					log.Info(). | ||||
| 						Str("path", aclPath). | ||||
| 						Msg("ACL policy successfully reloaded, notifying nodes of change") | ||||
| @ -824,12 +750,12 @@ func (h *Headscale) Serve() error { | ||||
| 				close(h.stateUpdateChan) | ||||
| 				close(h.cancelStateUpdateChan) | ||||
| 
 | ||||
| 				<-h.cancelPolicyUpdateChan | ||||
| 				close(h.policyUpdateChan) | ||||
| 				close(h.cancelPolicyUpdateChan) | ||||
| 
 | ||||
| 				// Close db connections | ||||
| 				db, err := h.db.db.DB() | ||||
| 				if err != nil { | ||||
| 					log.Error().Err(err).Msg("Failed to get db handle") | ||||
| 				} | ||||
| 				err = db.Close() | ||||
| 				err = h.db.Close() | ||||
| 				if err != nil { | ||||
| 					log.Error().Err(err).Msg("Failed to close db") | ||||
| 				} | ||||
| @ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby): baby steps, make this more robust. | ||||
| func (h *Headscale) watchPolicyChannel() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-h.policyUpdateChan: | ||||
| 			machines, err := h.db.ListMachines() | ||||
| 			if err != nil { | ||||
| 				log.Error().Err(err).Msg("failed to fetch machines during policy update") | ||||
| 			} | ||||
| 
 | ||||
| 			rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain) | ||||
| 			if err != nil { | ||||
| 				log.Error().Err(err).Msg("failed to update ACL rules") | ||||
| 			} | ||||
| 
 | ||||
| 			h.aclRules = rules | ||||
| 			h.sshPolicy = sshPolicy | ||||
| 
 | ||||
| 		case <-h.cancelPolicyUpdateChan: | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) setLastStateChangeToNow() { | ||||
| 	var err error | ||||
| 
 | ||||
| @ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) getLastStateChange(users ...User) time.Time { | ||||
| func (h *Headscale) getLastStateChange(users ...types.User) time.Time { | ||||
| 	times := []time.Time{} | ||||
| 
 | ||||
| 	// getLastStateChange takes a list of users as a "filter", if no users | ||||
|  | ||||
							
								
								
									
										480
									
								
								hscontrol/db/acls_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										480
									
								
								hscontrol/db/acls_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,480 @@ | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"tailscale.com/envknob" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| // TODO(kradalby): | ||||
| // Convert these tests to being non-database dependent and table driven. They are | ||||
| // very verbose, and dont really need the database. | ||||
| 
 | ||||
| func (s *Suite) TestSshRules(c *check.C) { | ||||
| 	envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") | ||||
| 
 | ||||
| 	user, err := db.CreateUser("user1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user1", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	hostInfo := tailcfg.Hostinfo{ | ||||
| 		OS:          "centos", | ||||
| 		Hostname:    "testmachine", | ||||
| 		RequestTags: []string{"tag:test"}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	aclPolicy := &policy.ACLPolicy{ | ||||
| 		Groups: policy.Groups{ | ||||
| 			"group:test": []string{"user1"}, | ||||
| 		}, | ||||
| 		Hosts: policy.Hosts{ | ||||
| 			"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), | ||||
| 		}, | ||||
| 		ACLs: []policy.ACL{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"*"}, | ||||
| 				Destinations: []string{"*:*"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		SSHs: []policy.SSH{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"group:test"}, | ||||
| 				Destinations: []string{"client"}, | ||||
| 				Users:        []string{"autogroup:nonroot"}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"*"}, | ||||
| 				Destinations: []string{"client"}, | ||||
| 				Users:        []string{"autogroup:nonroot"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false) | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(sshPolicy, check.NotNil) | ||||
| 	c.Assert(sshPolicy.Rules, check.HasLen, 2) | ||||
| 	c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) | ||||
| 	c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1) | ||||
| 	c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1") | ||||
| 
 | ||||
| 	c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) | ||||
| 	c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1) | ||||
| 	c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") | ||||
| } | ||||
| 
 | ||||
| // this test should validate that we can expand a group in a TagOWner section and | ||||
| // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. | ||||
| // the tag is matched in the Sources section. | ||||
| func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { | ||||
| 	user, err := db.CreateUser("user1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user1", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	hostInfo := tailcfg.Hostinfo{ | ||||
| 		OS:          "centos", | ||||
| 		Hostname:    "testmachine", | ||||
| 		RequestTags: []string{"tag:test"}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pol := &policy.ACLPolicy{ | ||||
| 		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}}, | ||||
| 		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, | ||||
| 		ACLs: []policy.ACL{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"tag:test"}, | ||||
| 				Destinations: []string{"*:*"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") | ||||
| } | ||||
| 
 | ||||
| // this test should validate that we can expand a group in a TagOWner section and | ||||
| // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. | ||||
| // the tag is matched in the Destinations section. | ||||
| func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { | ||||
| 	user, err := db.CreateUser("user1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user1", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	hostInfo := tailcfg.Hostinfo{ | ||||
| 		OS:          "centos", | ||||
| 		Hostname:    "testmachine", | ||||
| 		RequestTags: []string{"tag:test"}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "12345", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pol := &policy.ACLPolicy{ | ||||
| 		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}}, | ||||
| 		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, | ||||
| 		ACLs: []policy.ACL{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"*"}, | ||||
| 				Destinations: []string{"tag:test:*"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||
| } | ||||
| 
 | ||||
| // need a test with: | ||||
| // tag on a host that isn't owned by a tag owners. So the user | ||||
| // of the host should be valid. | ||||
| func (s *Suite) TestInvalidTagValidUser(c *check.C) { | ||||
| 	user, err := db.CreateUser("user1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user1", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	hostInfo := tailcfg.Hostinfo{ | ||||
| 		OS:          "centos", | ||||
| 		Hostname:    "testmachine", | ||||
| 		RequestTags: []string{"tag:foo"}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "12345", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pol := &policy.ACLPolicy{ | ||||
| 		TagOwners: policy.TagOwners{"tag:test": []string{"user1"}}, | ||||
| 		ACLs: []policy.ACL{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"user1"}, | ||||
| 				Destinations: []string{"*:*"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") | ||||
| } | ||||
| 
 | ||||
| // tag on a host is owned by a tag owner, the tag is valid. | ||||
| // an ACL rule is matching the tag to a user. It should not be valid since the | ||||
| // host should be tied to the tag now. | ||||
| func (s *Suite) TestValidTagInvalidUser(c *check.C) { | ||||
| 	user, err := db.CreateUser("user1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user1", "webserver") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	hostInfo := tailcfg.Hostinfo{ | ||||
| 		OS:          "centos", | ||||
| 		Hostname:    "webserver", | ||||
| 		RequestTags: []string{"tag:webapp"}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "12345", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "webserver", | ||||
| 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user1", "user") | ||||
| 	hostInfo2 := tailcfg.Hostinfo{ | ||||
| 		OS:       "debian", | ||||
| 		Hostname: "Hostname", | ||||
| 	} | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	machine = types.Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "56789", | ||||
| 		NodeKey:        "bar2", | ||||
| 		DiscoKey:       "faab", | ||||
| 		Hostname:       "user", | ||||
| 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo2), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pol := &policy.ACLPolicy{ | ||||
| 		TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}}, | ||||
| 		ACLs: []policy.ACL{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"user1"}, | ||||
| 				Destinations: []string{"tag:webapp:80,443"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") | ||||
| 	c.Assert(rules[0].DstPorts, check.HasLen, 2) | ||||
| 	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) | ||||
| 	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) | ||||
| 	c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||
| 	c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) | ||||
| 	c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) | ||||
| 	c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestPortUser(c *check.C) { | ||||
| 	user, err := db.CreateUser("testuser") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("testuser", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	ips, _ := db.getAvailableIPs() | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "12345", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    ips, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	acl := []byte(` | ||||
| { | ||||
| 	"hosts": { | ||||
| 		"host-1": "100.100.100.100", | ||||
| 		"subnet-1": "100.100.101.100/24", | ||||
| 	}, | ||||
| 
 | ||||
| 	"acls": [ | ||||
| 		{ | ||||
| 			"action": "accept", | ||||
| 			"src": [ | ||||
| 				"testuser", | ||||
| 			], | ||||
| 			"dst": [ | ||||
| 				"host-1:*", | ||||
| 			], | ||||
| 		}, | ||||
| 	], | ||||
| } | ||||
| 	`) | ||||
| 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) | ||||
| 	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) | ||||
| 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") | ||||
| 	c.Assert(len(ips), check.Equals, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestPortGroup(c *check.C) { | ||||
| 	user, err := db.CreateUser("testuser") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("testuser", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	ips, _ := db.getAvailableIPs() | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    ips, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	err = db.MachineSave(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	acl := []byte(` | ||||
| { | ||||
| 	"groups": { | ||||
| 		"group:example": [ | ||||
| 			"testuser", | ||||
| 		], | ||||
| 	}, | ||||
| 
 | ||||
| 	"hosts": { | ||||
| 		"host-1": "100.100.100.100", | ||||
| 		"subnet-1": "100.100.101.100/24", | ||||
| 	}, | ||||
| 
 | ||||
| 	"acls": [ | ||||
| 		{ | ||||
| 			"action": "accept", | ||||
| 			"src": [ | ||||
| 				"group:example", | ||||
| 			], | ||||
| 			"dst": [ | ||||
| 				"host-1:*", | ||||
| 			], | ||||
| 		}, | ||||
| 	], | ||||
| } | ||||
| 	`) | ||||
| 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.NotNil) | ||||
| 
 | ||||
| 	c.Assert(rules, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) | ||||
| 	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) | ||||
| 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") | ||||
| 	c.Assert(len(ips), check.Equals, 1) | ||||
| 	c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") | ||||
| } | ||||
| @ -3,21 +3,22 @@ | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
| 
 | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"go4.org/netipx" | ||||
| ) | ||||
| 
 | ||||
| var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { | ||||
| 	var ips MachineAddresses | ||||
| func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) { | ||||
| 	var ips types.MachineAddresses | ||||
| 	var err error | ||||
| 	for _, ipPrefix := range hsdb.ipPrefixes { | ||||
| 		var ip *netip.Addr | ||||
| @ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { | ||||
| 	// but this was quick to get running and it should be enough | ||||
| 	// to begin experimenting with a dual stack tailnet. | ||||
| 	var addressesSlices []string | ||||
| 	hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) | ||||
| 	hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices) | ||||
| 
 | ||||
| 	var ips netipx.IPSetBuilder | ||||
| 	for _, slice := range addressesSlices { | ||||
| 		var machineAddresses MachineAddresses | ||||
| 		var machineAddresses types.MachineAddresses | ||||
| 		err := machineAddresses.Scan(slice) | ||||
| 		if err != nil { | ||||
| 			return &netipx.IPSet{}, fmt.Errorf( | ||||
| @ -1,14 +1,16 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"go4.org/netipx" | ||||
| 	"gopkg.in/check.v1" | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) TestGetAvailableIp(c *check.C) { | ||||
| 	ips, err := app.db.getAvailableIPs() | ||||
| 	ips, err := db.getAvailableIPs() | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| @ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetUsedIps(c *check.C) { | ||||
| 	ips, err := app.db.getAvailableIPs() | ||||
| 	ips, err := db.getAvailableIPs() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	user, err := app.db.CreateUser("test-ip") | ||||
| 	user, err := db.CreateUser("test-ip") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "testmachine") | ||||
| 	_, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		IPAddresses:    ips, | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	usedIps, err := app.db.getUsedIPs() | ||||
| 	usedIps, err := db.getUsedIPs() | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| @ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) { | ||||
| 	c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) | ||||
| 	c.Assert(usedIps.Contains(expected), check.Equals, true) | ||||
| 
 | ||||
| 	machine1, err := app.db.GetMachineByID(0) | ||||
| 	machine1, err := db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(machine1.IPAddresses), check.Equals, 1) | ||||
| 	c.Assert(machine1.IPAddresses[0], check.Equals, expected) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMultiIp(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test-ip-multi") | ||||
| 	user, err := db.CreateUser("test-ip-multi") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	for index := 1; index <= 350; index++ { | ||||
| 		app.db.ipAllocationMutex.Lock() | ||||
| 		db.ipAllocationMutex.Lock() | ||||
| 
 | ||||
| 		ips, err := app.db.getAvailableIPs() | ||||
| 		ips, err := db.getAvailableIPs() | ||||
| 		c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 		pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 		pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 		c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 		_, err = app.db.GetMachine("test", "testmachine") | ||||
| 		_, err = db.GetMachine("test", "testmachine") | ||||
| 		c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 		machine := Machine{ | ||||
| 		machine := types.Machine{ | ||||
| 			ID:             uint64(index), | ||||
| 			MachineKey:     "foo", | ||||
| 			NodeKey:        "bar", | ||||
| 			DiscoKey:       "faa", | ||||
| 			Hostname:       "testmachine", | ||||
| 			UserID:         user.ID, | ||||
| 			RegisterMethod: RegisterMethodAuthKey, | ||||
| 			RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 			AuthKeyID:      uint(pak.ID), | ||||
| 			IPAddresses:    ips, | ||||
| 		} | ||||
| 		app.db.db.Save(&machine) | ||||
| 		db.db.Save(&machine) | ||||
| 
 | ||||
| 		app.db.ipAllocationMutex.Unlock() | ||||
| 		db.ipAllocationMutex.Unlock() | ||||
| 	} | ||||
| 
 | ||||
| 	usedIps, err := app.db.getUsedIPs() | ||||
| 	usedIps, err := db.getUsedIPs() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	expected0 := netip.MustParseAddr("10.27.0.1") | ||||
| @ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | ||||
| 	c.Assert(usedIps.Contains(expected300), check.Equals, true) | ||||
| 
 | ||||
| 	// Check that we can read back the IPs | ||||
| 	machine1, err := app.db.GetMachineByID(1) | ||||
| 	machine1, err := db.GetMachineByID(1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(machine1.IPAddresses), check.Equals, 1) | ||||
| 	c.Assert( | ||||
| @ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | ||||
| 		netip.MustParseAddr("10.27.0.1"), | ||||
| 	) | ||||
| 
 | ||||
| 	machine50, err := app.db.GetMachineByID(50) | ||||
| 	machine50, err := db.GetMachineByID(50) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(machine50.IPAddresses), check.Equals, 1) | ||||
| 	c.Assert( | ||||
| @ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | ||||
| 	) | ||||
| 
 | ||||
| 	expectedNextIP := netip.MustParseAddr("10.27.1.95") | ||||
| 	nextIP, err := app.db.getAvailableIPs() | ||||
| 	nextIP, err := db.getAvailableIPs() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(nextIP), check.Equals, 1) | ||||
| @ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | ||||
| 
 | ||||
| 	// If we call get Available again, we should receive | ||||
| 	// the same IP, as it has not been reserved. | ||||
| 	nextIP2, err := app.db.getAvailableIPs() | ||||
| 	nextIP2, err := db.getAvailableIPs() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(nextIP2), check.Equals, 1) | ||||
| 	c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { | ||||
| 	ips, err := app.db.getAvailableIPs() | ||||
| 	ips, err := db.getAvailableIPs() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	expected := netip.MustParseAddr("10.27.0.1") | ||||
| @ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { | ||||
| 	c.Assert(len(ips), check.Equals, 1) | ||||
| 	c.Assert(ips[0].String(), check.Equals, expected.String()) | ||||
| 
 | ||||
| 	user, err := app.db.CreateUser("test-ip") | ||||
| 	user, err := db.CreateUser("test-ip") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "testmachine") | ||||
| 	_, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	ips2, err := app.db.getAvailableIPs() | ||||
| 	ips2, err := db.getAvailableIPs() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(ips2), check.Equals, 1) | ||||
| 	c.Assert(ips2[0].String(), check.Equals, expected.String()) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| @ -1,4 +1,4 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| @ -6,10 +6,9 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| @ -19,22 +18,10 @@ const ( | ||||
| 
 | ||||
| var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") | ||||
| 
 | ||||
| // APIKey describes the datamodel for API keys used to remotely authenticate with | ||||
| // headscale. | ||||
| type APIKey struct { | ||||
| 	ID     uint64 `gorm:"primary_key"` | ||||
| 	Prefix string `gorm:"uniqueIndex"` | ||||
| 	Hash   []byte | ||||
| 
 | ||||
| 	CreatedAt  *time.Time | ||||
| 	Expiration *time.Time | ||||
| 	LastSeen   *time.Time | ||||
| } | ||||
| 
 | ||||
| // CreateAPIKey creates a new ApiKey in a user, and returns it. | ||||
| func (hsdb *HSDatabase) CreateAPIKey( | ||||
| 	expiration *time.Time, | ||||
| ) (string, *APIKey, error) { | ||||
| ) (string, *types.APIKey, error) { | ||||
| 	prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| @ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey( | ||||
| 		return "", nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	key := APIKey{ | ||||
| 	key := types.APIKey{ | ||||
| 		Prefix:     prefix, | ||||
| 		Hash:       hash, | ||||
| 		Expiration: expiration, | ||||
| @ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey( | ||||
| } | ||||
| 
 | ||||
| // ListAPIKeys returns the list of ApiKeys for a user. | ||||
| func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { | ||||
| 	keys := []APIKey{} | ||||
| func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | ||||
| 	keys := []types.APIKey{} | ||||
| 	if err := hsdb.db.Find(&keys).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { | ||||
| } | ||||
| 
 | ||||
| // GetAPIKey returns a ApiKey for a given key. | ||||
| func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { | ||||
| 	key := APIKey{} | ||||
| func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | ||||
| 	key := types.APIKey{} | ||||
| 	if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { | ||||
| 		return nil, result.Error | ||||
| 	} | ||||
| @ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { | ||||
| } | ||||
| 
 | ||||
| // GetAPIKeyByID returns a ApiKey for a given id. | ||||
| func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { | ||||
| 	key := APIKey{} | ||||
| 	if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { | ||||
| func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | ||||
| 	key := types.APIKey{} | ||||
| 	if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { | ||||
| 		return nil, result.Error | ||||
| 	} | ||||
| 
 | ||||
| @ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { | ||||
| 
 | ||||
| // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey | ||||
| // does not exist. | ||||
| func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { | ||||
| func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | ||||
| 	if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { | ||||
| 		return result.Error | ||||
| 	} | ||||
| @ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { | ||||
| } | ||||
| 
 | ||||
| // ExpireAPIKey marks a ApiKey as expired. | ||||
| func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { | ||||
| func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | ||||
| 	if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { | ||||
| 
 | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (key *APIKey) toProto() *v1.ApiKey { | ||||
| 	protoKey := v1.ApiKey{ | ||||
| 		Id:     key.ID, | ||||
| 		Prefix: key.Prefix, | ||||
| 	} | ||||
| 
 | ||||
| 	if key.Expiration != nil { | ||||
| 		protoKey.Expiration = timestamppb.New(*key.Expiration) | ||||
| 	} | ||||
| 
 | ||||
| 	if key.CreatedAt != nil { | ||||
| 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) | ||||
| 	} | ||||
| 
 | ||||
| 	if key.LastSeen != nil { | ||||
| 		protoKey.LastSeen = timestamppb.New(*key.LastSeen) | ||||
| 	} | ||||
| 
 | ||||
| 	return &protoKey | ||||
| } | ||||
| @ -1,4 +1,4 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"time" | ||||
| @ -7,7 +7,7 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func (*Suite) TestCreateAPIKey(c *check.C) { | ||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) | ||||
| 	apiKeyStr, apiKey, err := db.CreateAPIKey(nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(apiKey, check.NotNil) | ||||
| 
 | ||||
| @ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) { | ||||
| 	c.Assert(apiKey.Hash, check.NotNil) | ||||
| 	c.Assert(apiKeyStr, check.Not(check.Equals), "") | ||||
| 
 | ||||
| 	_, err = app.db.ListAPIKeys() | ||||
| 	_, err = db.ListAPIKeys() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	keys, err := app.db.ListAPIKeys() | ||||
| 	keys, err := db.ListAPIKeys() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(keys), check.Equals, 1) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { | ||||
| 	key, err := app.db.GetAPIKey("does-not-exist") | ||||
| 	key, err := db.GetAPIKey("does-not-exist") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	c.Assert(key, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestValidateAPIKeyOk(c *check.C) { | ||||
| 	nowPlus2 := time.Now().Add(2 * time.Hour) | ||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) | ||||
| 	apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(apiKey, check.NotNil) | ||||
| 
 | ||||
| 	valid, err := app.db.ValidateAPIKey(apiKeyStr) | ||||
| 	valid, err := db.ValidateAPIKey(apiKeyStr) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(valid, check.Equals, true) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { | ||||
| 	nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) | ||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) | ||||
| 	apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(apiKey, check.NotNil) | ||||
| 
 | ||||
| 	valid, err := app.db.ValidateAPIKey(apiKeyStr) | ||||
| 	valid, err := db.ValidateAPIKey(apiKeyStr) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(valid, check.Equals, false) | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) | ||||
| 	apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(apiKey, check.NotNil) | ||||
| 
 | ||||
| 	validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) | ||||
| 	validNow, err := db.ValidateAPIKey(apiKeyStrNow) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(validNow, check.Equals, false) | ||||
| 
 | ||||
| 	validSilly, err := app.db.ValidateAPIKey("nota.validkey") | ||||
| 	validSilly, err := db.ValidateAPIKey("nota.validkey") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	c.Assert(validSilly, check.Equals, false) | ||||
| 
 | ||||
| 	validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") | ||||
| 	validWithErr, err := db.ValidateAPIKey("produceerrorkey") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 	c.Assert(validWithErr, check.Equals, false) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestExpireAPIKey(c *check.C) { | ||||
| 	nowPlus2 := time.Now().Add(2 * time.Hour) | ||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) | ||||
| 	apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(apiKey, check.NotNil) | ||||
| 
 | ||||
| 	valid, err := app.db.ValidateAPIKey(apiKeyStr) | ||||
| 	valid, err := db.ValidateAPIKey(apiKeyStr) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(valid, check.Equals, true) | ||||
| 
 | ||||
| 	err = app.db.ExpireAPIKey(apiKey) | ||||
| 	err = db.ExpireAPIKey(apiKey) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(apiKey.Expiration, check.NotNil) | ||||
| 
 | ||||
| 	notValid, err := app.db.ValidateAPIKey(apiKeyStr) | ||||
| 	notValid, err := db.ValidateAPIKey(apiKeyStr) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(notValid, check.Equals, false) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| @ -1,9 +1,7 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| @ -11,11 +9,12 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/glebarez/sqlite" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"gorm.io/driver/postgres" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| @ -26,7 +25,6 @@ const ( | ||||
| 
 | ||||
| var ( | ||||
| 	errValueNotFound        = errors.New("not found") | ||||
| 	ErrCannotParsePrefix    = errors.New("cannot parse prefix") | ||||
| 	errDatabaseNotSupported = errors.New("database type not supported") | ||||
| ) | ||||
| 
 | ||||
| @ -38,8 +36,9 @@ type KV struct { | ||||
| } | ||||
| 
 | ||||
| type HSDatabase struct { | ||||
| 	db              *gorm.DB | ||||
| 	notifyStateChan chan<- struct{} | ||||
| 	db               *gorm.DB | ||||
| 	notifyStateChan  chan<- struct{} | ||||
| 	notifyPolicyChan chan<- struct{} | ||||
| 
 | ||||
| 	ipAllocationMutex sync.Mutex | ||||
| 
 | ||||
| @ -54,6 +53,7 @@ func NewHeadscaleDatabase( | ||||
| 	dbType, connectionAddr string, | ||||
| 	stripEmailDomain, debug bool, | ||||
| 	notifyStateChan chan<- struct{}, | ||||
| 	notifyPolicyChan chan<- struct{}, | ||||
| 	ipPrefixes []netip.Prefix, | ||||
| 	baseDomain string, | ||||
| ) (*HSDatabase, error) { | ||||
| @ -63,8 +63,9 @@ func NewHeadscaleDatabase( | ||||
| 	} | ||||
| 
 | ||||
| 	db := HSDatabase{ | ||||
| 		db:              dbConn, | ||||
| 		notifyStateChan: notifyStateChan, | ||||
| 		db:               dbConn, | ||||
| 		notifyStateChan:  notifyStateChan, | ||||
| 		notifyPolicyChan: notifyPolicyChan, | ||||
| 
 | ||||
| 		ipPrefixes:       ipPrefixes, | ||||
| 		baseDomain:       baseDomain, | ||||
| @ -79,30 +80,30 @@ func NewHeadscaleDatabase( | ||||
| 
 | ||||
| 	_ = dbConn.Migrator().RenameTable("namespaces", "users") | ||||
| 
 | ||||
| 	err = dbConn.AutoMigrate(User{}) | ||||
| 	err = dbConn.AutoMigrate(types.User{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") | ||||
| 
 | ||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname") | ||||
| 
 | ||||
| 	// GivenName is used as the primary source of DNS names, make sure | ||||
| 	// the field is populated and normalized if it was not when the | ||||
| 	// machine was registered. | ||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") | ||||
| 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name") | ||||
| 
 | ||||
| 	// If the Machine table has a column for registered, | ||||
| 	// find all occourences of "false" and drop them. Then | ||||
| 	// remove the column. | ||||
| 	if dbConn.Migrator().HasColumn(&Machine{}, "registered") { | ||||
| 	if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") { | ||||
| 		log.Info(). | ||||
| 			Msg(`Database has legacy "registered" column in machine, removing...`) | ||||
| 
 | ||||
| 		machines := Machines{} | ||||
| 		machines := types.Machines{} | ||||
| 		if err := dbConn.Not("registered").Find(&machines).Error; err != nil { | ||||
| 			log.Error().Err(err).Msg("Error accessing db") | ||||
| 		} | ||||
| @ -112,7 +113,7 @@ func NewHeadscaleDatabase( | ||||
| 				Str("machine", machine.Hostname). | ||||
| 				Str("machine_key", machine.MachineKey). | ||||
| 				Msg("Deleting unregistered machine") | ||||
| 			if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { | ||||
| 			if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil { | ||||
| 				log.Error(). | ||||
| 					Err(err). | ||||
| 					Str("machine", machine.Hostname). | ||||
| @ -121,23 +122,23 @@ func NewHeadscaleDatabase( | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		err := dbConn.Migrator().DropColumn(&Machine{}, "registered") | ||||
| 		err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered") | ||||
| 		if err != nil { | ||||
| 			log.Error().Err(err).Msg("Error dropping registered column") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = dbConn.AutoMigrate(&Route{}) | ||||
| 	err = dbConn.AutoMigrate(&types.Route{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { | ||||
| 	if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") { | ||||
| 		log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") | ||||
| 
 | ||||
| 		type MachineAux struct { | ||||
| 			ID            uint64 | ||||
| 			EnabledRoutes IPPrefixes | ||||
| 			EnabledRoutes types.IPPrefixes | ||||
| 		} | ||||
| 
 | ||||
| 		machinesAux := []MachineAux{} | ||||
| @ -157,8 +158,8 @@ func NewHeadscaleDatabase( | ||||
| 				} | ||||
| 
 | ||||
| 				err = dbConn.Preload("Machine"). | ||||
| 					Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). | ||||
| 					First(&Route{}). | ||||
| 					Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)). | ||||
| 					First(&types.Route{}). | ||||
| 					Error | ||||
| 				if err == nil { | ||||
| 					log.Info(). | ||||
| @ -168,11 +169,11 @@ func NewHeadscaleDatabase( | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				route := Route{ | ||||
| 				route := types.Route{ | ||||
| 					MachineID:  machine.ID, | ||||
| 					Advertised: true, | ||||
| 					Enabled:    true, | ||||
| 					Prefix:     IPPrefix(prefix), | ||||
| 					Prefix:     types.IPPrefix(prefix), | ||||
| 				} | ||||
| 				if err := dbConn.Create(&route).Error; err != nil { | ||||
| 					log.Error().Err(err).Msg("Error creating route") | ||||
| @ -185,26 +186,26 @@ func NewHeadscaleDatabase( | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") | ||||
| 		err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes") | ||||
| 		if err != nil { | ||||
| 			log.Error().Err(err).Msg("Error dropping enabled_routes column") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = dbConn.AutoMigrate(&Machine{}) | ||||
| 	err = dbConn.AutoMigrate(&types.Machine{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { | ||||
| 		machines := Machines{} | ||||
| 	if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") { | ||||
| 		machines := types.Machines{} | ||||
| 		if err := dbConn.Find(&machines).Error; err != nil { | ||||
| 			log.Error().Err(err).Msg("Error accessing db") | ||||
| 		} | ||||
| 
 | ||||
| 		for item, machine := range machines { | ||||
| 			if machine.GivenName == "" { | ||||
| 				normalizedHostname, err := NormalizeToFQDNRules( | ||||
| 				normalizedHostname, err := util.NormalizeToFQDNRules( | ||||
| 					machine.Hostname, | ||||
| 					stripEmailDomain, | ||||
| 				) | ||||
| @ -233,19 +234,19 @@ func NewHeadscaleDatabase( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	err = dbConn.AutoMigrate(&PreAuthKey{}) | ||||
| 	err = dbConn.AutoMigrate(&types.PreAuthKey{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) | ||||
| 	err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	_ = dbConn.Migrator().DropTable("shared_machines") | ||||
| 
 | ||||
| 	err = dbConn.AutoMigrate(&APIKey{}) | ||||
| 	err = dbConn.AutoMigrate(&types.APIKey{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) pingDB(ctx context.Context) error { | ||||
| func (hsdb *HSDatabase) PingDB(ctx context.Context) error { | ||||
| 	ctx, cancel := context.WithTimeout(ctx, time.Second) | ||||
| 	defer cancel() | ||||
| 	sqlDB, err := hsdb.db.DB() | ||||
| @ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error { | ||||
| 	return sqlDB.PingContext(ctx) | ||||
| } | ||||
| 
 | ||||
| // This is a "wrapper" type around tailscales | ||||
| // Hostinfo to allow us to add database "serialization" | ||||
| // methods. This allows us to use a typed values throughout | ||||
| // the code and not have to marshal/unmarshal and error | ||||
| // check all over the code. | ||||
| type HostInfo tailcfg.Hostinfo | ||||
| 
 | ||||
| func (hi *HostInfo) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, hi) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), hi) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| func (hsdb *HSDatabase) Close() error { | ||||
| 	db, err := hsdb.db.DB() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (hi HostInfo) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(hi) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| type IPPrefix netip.Prefix | ||||
| 
 | ||||
| func (i *IPPrefix) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case string: | ||||
| 		prefix, err := netip.ParsePrefix(value) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		*i = IPPrefix(prefix) | ||||
| 
 | ||||
| 		return nil | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (i IPPrefix) Value() (driver.Value, error) { | ||||
| 	prefixStr := netip.Prefix(i).String() | ||||
| 
 | ||||
| 	return prefixStr, nil | ||||
| } | ||||
| 
 | ||||
| type IPPrefixes []netip.Prefix | ||||
| 
 | ||||
| func (i *IPPrefixes) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, i) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), i) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (i IPPrefixes) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(i) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| type StringList []string | ||||
| 
 | ||||
| func (i *StringList) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, i) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), i) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (i StringList) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(i) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| 
 | ||||
| 	return db.Close() | ||||
| } | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										797
									
								
								hscontrol/db/machine_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										797
									
								
								hscontrol/db/machine_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,797 @@ | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/key" | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) TestGetMachine(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := &types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(machine) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMachineByID(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMachineByNodeKey(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	nodeKey := key.NewNode() | ||||
| 	machineKey := key.NewMachine() | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | ||||
| 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByNodeKey(nodeKey.Public()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	nodeKey := key.NewNode() | ||||
| 	oldNodeKey := key.NewNode() | ||||
| 
 | ||||
| 	machineKey := key.NewMachine() | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | ||||
| 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestDeleteMachine(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(1), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = db.DeleteMachine(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(user.Name, "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestHardDeleteMachine(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine3", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(1), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = db.HardDeleteMachine(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(user.Name, "testmachine3") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestListPeers(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	for index := 0; index <= 10; index++ { | ||||
| 		machine := types.Machine{ | ||||
| 			ID:             uint64(index), | ||||
| 			MachineKey:     "foo" + strconv.Itoa(index), | ||||
| 			NodeKey:        "bar" + strconv.Itoa(index), | ||||
| 			DiscoKey:       "faa" + strconv.Itoa(index), | ||||
| 			Hostname:       "testmachine" + strconv.Itoa(index), | ||||
| 			UserID:         user.ID, | ||||
| 			RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 			AuthKeyID:      uint(pak.ID), | ||||
| 		} | ||||
| 		db.db.Save(&machine) | ||||
| 	} | ||||
| 
 | ||||
| 	machine0ByID, err := db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	peersOfMachine0, err := db.ListPeers(machine0ByID) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(peersOfMachine0), check.Equals, 9) | ||||
| 	c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") | ||||
| 	c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") | ||||
| 	c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetACLFilteredPeers(c *check.C) { | ||||
| 	type base struct { | ||||
| 		user *types.User | ||||
| 		key  *types.PreAuthKey | ||||
| 	} | ||||
| 
 | ||||
| 	stor := make([]base, 0) | ||||
| 
 | ||||
| 	for _, name := range []string{"test", "admin"} { | ||||
| 		user, err := db.CreateUser(name) | ||||
| 		c.Assert(err, check.IsNil) | ||||
| 		pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 		c.Assert(err, check.IsNil) | ||||
| 		stor = append(stor, base{user, pak}) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err := db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	for index := 0; index <= 10; index++ { | ||||
| 		machine := types.Machine{ | ||||
| 			ID:         uint64(index), | ||||
| 			MachineKey: "foo" + strconv.Itoa(index), | ||||
| 			NodeKey:    "bar" + strconv.Itoa(index), | ||||
| 			DiscoKey:   "faa" + strconv.Itoa(index), | ||||
| 			IPAddresses: types.MachineAddresses{ | ||||
| 				netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), | ||||
| 			}, | ||||
| 			Hostname:       "testmachine" + strconv.Itoa(index), | ||||
| 			UserID:         stor[index%2].user.ID, | ||||
| 			RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 			AuthKeyID:      uint(stor[index%2].key.ID), | ||||
| 		} | ||||
| 		db.db.Save(&machine) | ||||
| 	} | ||||
| 
 | ||||
| 	aclPolicy := &policy.ACLPolicy{ | ||||
| 		Groups: map[string][]string{ | ||||
| 			"group:test": {"admin"}, | ||||
| 		}, | ||||
| 		Hosts:     map[string]netip.Prefix{}, | ||||
| 		TagOwners: map[string][]string{}, | ||||
| 		ACLs: []policy.ACL{ | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"admin"}, | ||||
| 				Destinations: []string{"*:*"}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Action:       "accept", | ||||
| 				Sources:      []string{"test"}, | ||||
| 				Destinations: []string{"test:*"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Tests: []policy.ACLTest{}, | ||||
| 	} | ||||
| 
 | ||||
| 	adminMachine, err := db.GetMachineByID(1) | ||||
| 	c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	testMachine, err := db.GetMachineByID(2) | ||||
| 	c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machines, err := db.ListMachines() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines) | ||||
| 	peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines) | ||||
| 
 | ||||
| 	c.Log(peersOfTestMachine) | ||||
| 	c.Assert(len(peersOfTestMachine), check.Equals, 9) | ||||
| 	c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1") | ||||
| 	c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3") | ||||
| 	c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5") | ||||
| 
 | ||||
| 	c.Log(peersOfAdminMachine) | ||||
| 	c.Assert(len(peersOfAdminMachine), check.Equals, 9) | ||||
| 	c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") | ||||
| 	c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") | ||||
| 	c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestExpireMachine(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := &types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		Expiry:         &time.Time{}, | ||||
| 	} | ||||
| 	db.db.Save(machine) | ||||
| 
 | ||||
| 	machineFromDB, err := db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machineFromDB, check.NotNil) | ||||
| 
 | ||||
| 	c.Assert(machineFromDB.IsExpired(), check.Equals, false) | ||||
| 
 | ||||
| 	err = db.ExpireMachine(machineFromDB) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(machineFromDB.IsExpired(), check.Equals, true) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(1)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { | ||||
| 	input := types.MachineAddresses([]netip.Addr{ | ||||
| 		netip.MustParseAddr("192.0.2.1"), | ||||
| 		netip.MustParseAddr("2001:db8::1"), | ||||
| 	}) | ||||
| 	serialized, err := input.Value() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	if serial, ok := serialized.(string); ok { | ||||
| 		c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") | ||||
| 	} | ||||
| 
 | ||||
| 	var deserialized types.MachineAddresses | ||||
| 	err = deserialized.Scan(serialized) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(deserialized), check.Equals, len(input)) | ||||
| 	for i := range deserialized { | ||||
| 		c.Assert(deserialized[i], check.Equals, input[i]) | ||||
| 	} | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGenerateGivenName(c *check.C) { | ||||
| 	user1, err := db.CreateUser("user-1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("user-1", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := &types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "machine-key-1", | ||||
| 		NodeKey:        "node-key-1", | ||||
| 		DiscoKey:       "disco-key-1", | ||||
| 		Hostname:       "hostname-1", | ||||
| 		GivenName:      "hostname-1", | ||||
| 		UserID:         user1.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(machine) | ||||
| 
 | ||||
| 	givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2") | ||||
| 	comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") | ||||
| 	c.Assert(err, check.IsNil, comment) | ||||
| 	c.Assert(givenName, check.Equals, "hostname-2", comment) | ||||
| 
 | ||||
| 	givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1") | ||||
| 	comment = check.Commentf("Same user, same machine, same hostname, no conflict") | ||||
| 	c.Assert(err, check.IsNil, comment) | ||||
| 	c.Assert(givenName, check.Equals, "hostname-1", comment) | ||||
| 
 | ||||
| 	givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") | ||||
| 	comment = check.Commentf("Same user, unique machines, same hostname, conflict") | ||||
| 	c.Assert(err, check.IsNil, comment) | ||||
| 	c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) | ||||
| 
 | ||||
| 	givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") | ||||
| 	comment = check.Commentf("Unique users, unique machines, same hostname, conflict") | ||||
| 	c.Assert(err, check.IsNil, comment) | ||||
| 	c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestSetTags(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machine := &types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(machine) | ||||
| 
 | ||||
| 	// assign simple tags | ||||
| 	sTags := []string{"tag:test", "tag:foo"} | ||||
| 	err = db.SetTags(machine, sTags) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	machine, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags)) | ||||
| 
 | ||||
| 	// assign duplicat tags, expect no errors but no doubles in DB | ||||
| 	eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} | ||||
| 	err = db.SetTags(machine, eTags) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	machine, err = db.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert( | ||||
| 		machine.ForcedTags, | ||||
| 		check.DeepEquals, | ||||
| 		types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), | ||||
| 	) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(4)) | ||||
| } | ||||
| 
 | ||||
| func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		suppliedName string | ||||
| 		randomSuffix bool | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		db      *HSDatabase | ||||
| 		args    args | ||||
| 		want    *regexp.Regexp | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "simple machine name generation", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "testmachine", | ||||
| 				randomSuffix: false, | ||||
| 			}, | ||||
| 			want:    regexp.MustCompile("^testmachine$"), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 53 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", | ||||
| 				randomSuffix: false, | ||||
| 			}, | ||||
| 			want:    regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 63 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: false, | ||||
| 			}, | ||||
| 			want:    regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 64 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", | ||||
| 				randomSuffix: false, | ||||
| 			}, | ||||
| 			want:    nil, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 73 chars", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: false, | ||||
| 			}, | ||||
| 			want:    nil, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with random suffix", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "test", | ||||
| 				randomSuffix: true, | ||||
| 			}, | ||||
| 			want:    regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 63 chars with random suffix", | ||||
| 			db: &HSDatabase{ | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: true, | ||||
| 			}, | ||||
| 			want:    regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf( | ||||
| 					"Headscale.GenerateGivenName() error = %v, wantErr %v", | ||||
| 					err, | ||||
| 					tt.wantErr, | ||||
| 				) | ||||
| 
 | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			if tt.want != nil && !tt.want.MatchString(got) { | ||||
| 				t.Errorf( | ||||
| 					"Headscale.GenerateGivenName() = %v, does not match %v", | ||||
| 					tt.want, | ||||
| 					got, | ||||
| 				) | ||||
| 			} | ||||
| 
 | ||||
| 			if len(got) > util.LabelHostnameLength { | ||||
| 				t.Errorf( | ||||
| 					"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", | ||||
| 					got, | ||||
| 					util.LabelHostnameLength, | ||||
| 				) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestAutoApproveRoutes(c *check.C) { | ||||
| 	acl := []byte(` | ||||
| { | ||||
| 	"tagOwners": { | ||||
| 		"tag:exit": ["test"], | ||||
| 	}, | ||||
| 
 | ||||
| 	"groups": { | ||||
| 		"group:test": ["test"] | ||||
| 	}, | ||||
| 
 | ||||
| 	"acls": [ | ||||
| 		{"action": "accept", "users": ["*"], "ports": ["*:*"]}, | ||||
| 	], | ||||
| 
 | ||||
| 	"autoApprovers": { | ||||
| 		"exitNode": ["tag:exit"], | ||||
| 		"routes": { | ||||
| 			"10.10.0.0/16": ["group:test"], | ||||
| 			"10.11.0.0/16": ["test"], | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 	`) | ||||
| 
 | ||||
| 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pol, check.NotNil) | ||||
| 
 | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	nodeKey := key.NewNode() | ||||
| 
 | ||||
| 	defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") | ||||
| 	defaultRouteV6 := netip.MustParsePrefix("::/0") | ||||
| 	route1 := netip.MustParsePrefix("10.10.0.0/16") | ||||
| 	// Check if a subprefix of an autoapproved route is approved | ||||
| 	route2 := netip.MustParsePrefix("10.11.0.0/24") | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo: types.HostInfo{ | ||||
| 			RequestTags: []string{"tag:exit"}, | ||||
| 			RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, | ||||
| 		}, | ||||
| 		IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||
| 	} | ||||
| 
 | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = db.ProcessMachineRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine0ByID, err := db.GetMachineByID(0) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = db.EnableAutoApprovedRoutes(pol, machine0ByID) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes, err := db.GetEnabledRoutes(machine0ByID) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(enabledRoutes, check.HasLen, 4) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(4)) | ||||
| } | ||||
| 
 | ||||
| func TestMachine_canAccess(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		filter   []tailcfg.FilterRule | ||||
| 		machine2 *types.Machine | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		machine types.Machine | ||||
| 		args    args | ||||
| 		want    bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "no-rules", | ||||
| 			machine: types.Machine{ | ||||
| 				IPAddresses: types.MachineAddresses{ | ||||
| 					netip.MustParseAddr("10.0.0.1"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				filter: []tailcfg.FilterRule{}, | ||||
| 				machine2: &types.Machine{ | ||||
| 					IPAddresses: types.MachineAddresses{ | ||||
| 						netip.MustParseAddr("10.0.0.2"), | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "wildcard", | ||||
| 			machine: types.Machine{ | ||||
| 				IPAddresses: types.MachineAddresses{ | ||||
| 					netip.MustParseAddr("10.0.0.1"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				filter: []tailcfg.FilterRule{ | ||||
| 					{ | ||||
| 						SrcIPs: []string{"*"}, | ||||
| 						DstPorts: []tailcfg.NetPortRange{ | ||||
| 							{ | ||||
| 								IP: "*", | ||||
| 								Ports: tailcfg.PortRange{ | ||||
| 									First: 0, | ||||
| 									Last:  65535, | ||||
| 								}, | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				machine2: &types.Machine{ | ||||
| 					IPAddresses: types.MachineAddresses{ | ||||
| 						netip.MustParseAddr("10.0.0.2"), | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "explicit-m1-to-m2", | ||||
| 			machine: types.Machine{ | ||||
| 				IPAddresses: types.MachineAddresses{ | ||||
| 					netip.MustParseAddr("10.0.0.1"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				filter: []tailcfg.FilterRule{ | ||||
| 					{ | ||||
| 						SrcIPs: []string{"10.0.0.1"}, | ||||
| 						DstPorts: []tailcfg.NetPortRange{ | ||||
| 							{ | ||||
| 								IP: "10.0.0.2", | ||||
| 								Ports: tailcfg.PortRange{ | ||||
| 									First: 0, | ||||
| 									Last:  65535, | ||||
| 								}, | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				machine2: &types.Machine{ | ||||
| 					IPAddresses: types.MachineAddresses{ | ||||
| 						netip.MustParseAddr("10.0.0.2"), | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "explicit-m2-to-m1", | ||||
| 			machine: types.Machine{ | ||||
| 				IPAddresses: types.MachineAddresses{ | ||||
| 					netip.MustParseAddr("10.0.0.1"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				filter: []tailcfg.FilterRule{ | ||||
| 					{ | ||||
| 						SrcIPs: []string{"10.0.0.2"}, | ||||
| 						DstPorts: []tailcfg.NetPortRange{ | ||||
| 							{ | ||||
| 								IP: "10.0.0.1", | ||||
| 								Ports: tailcfg.PortRange{ | ||||
| 									First: 0, | ||||
| 									Last:  65535, | ||||
| 								}, | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				machine2: &types.Machine{ | ||||
| 					IPAddresses: types.MachineAddresses{ | ||||
| 						netip.MustParseAddr("10.0.0.2"), | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want { | ||||
| 				t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @ -1,17 +1,14 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| 
 | ||||
| @ -23,28 +20,6 @@ var ( | ||||
| 	ErrPreAuthKeyACLTagInvalid     = errors.New("AuthKey tag is invalid") | ||||
| ) | ||||
| 
 | ||||
| // PreAuthKey describes a pre-authorization key usable in a particular user. | ||||
| type PreAuthKey struct { | ||||
| 	ID        uint64 `gorm:"primary_key"` | ||||
| 	Key       string | ||||
| 	UserID    uint | ||||
| 	User      User | ||||
| 	Reusable  bool | ||||
| 	Ephemeral bool `gorm:"default:false"` | ||||
| 	Used      bool `gorm:"default:false"` | ||||
| 	ACLTags   []PreAuthKeyACLTag | ||||
| 
 | ||||
| 	CreatedAt  *time.Time | ||||
| 	Expiration *time.Time | ||||
| } | ||||
| 
 | ||||
| // PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. | ||||
| type PreAuthKeyACLTag struct { | ||||
| 	ID           uint64 `gorm:"primary_key"` | ||||
| 	PreAuthKeyID uint64 | ||||
| 	Tag          string | ||||
| } | ||||
| 
 | ||||
| // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. | ||||
| func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| 	userName string, | ||||
| @ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| 	ephemeral bool, | ||||
| 	expiration *time.Time, | ||||
| 	aclTags []string, | ||||
| ) (*PreAuthKey, error) { | ||||
| ) (*types.PreAuthKey, error) { | ||||
| 	user, err := hsdb.GetUser(userName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	key := PreAuthKey{ | ||||
| 	key := types.PreAuthKey{ | ||||
| 		Key:        kstr, | ||||
| 		UserID:     user.ID, | ||||
| 		User:       *user, | ||||
| @ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| 
 | ||||
| 			for _, tag := range aclTags { | ||||
| 				if !seenTags[tag] { | ||||
| 					if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { | ||||
| 					if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { | ||||
| 						return fmt.Errorf( | ||||
| 							"failed to ceate key tag in the database: %w", | ||||
| 							err, | ||||
| @ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| } | ||||
| 
 | ||||
| // ListPreAuthKeys returns the list of PreAuthKeys for a user. | ||||
| func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { | ||||
| func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | ||||
| 	user, err := hsdb.GetUser(userName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	keys := []PreAuthKey{} | ||||
| 	if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { | ||||
| 	keys := []types.PreAuthKey{} | ||||
| 	if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| @ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { | ||||
| } | ||||
| 
 | ||||
| // GetPreAuthKey returns a PreAuthKey for a given key. | ||||
| func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { | ||||
| 	pak, err := hsdb.checkKeyValidity(key) | ||||
| func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { | ||||
| 	pak, err := hsdb.ValidatePreAuthKey(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err | ||||
| 
 | ||||
| // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey | ||||
| // does not exist. | ||||
| func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { | ||||
| func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | ||||
| 	return hsdb.db.Transaction(func(db *gorm.DB) error { | ||||
| 		if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { | ||||
| 		if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { | ||||
| 			return result.Error | ||||
| 		} | ||||
| 
 | ||||
| @ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { | ||||
| } | ||||
| 
 | ||||
| // MarkExpirePreAuthKey marks a PreAuthKey as expired. | ||||
| func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { | ||||
| func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | ||||
| 	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { | ||||
| } | ||||
| 
 | ||||
| // UsePreAuthKey marks a PreAuthKey as used. | ||||
| func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { | ||||
| func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | ||||
| 	k.Used = true | ||||
| 	if err := hsdb.db.Save(k).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to update key used status in the database: %w", err) | ||||
| @ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node | ||||
| // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node | ||||
| // If returns no error and a PreAuthKey, it can be used. | ||||
| func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { | ||||
| 	pak := PreAuthKey{} | ||||
| func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { | ||||
| 	pak := types.PreAuthKey{} | ||||
| 	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( | ||||
| 		result.Error, | ||||
| 		gorm.ErrRecordNotFound, | ||||
| @ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { | ||||
| 		return &pak, nil | ||||
| 	} | ||||
| 
 | ||||
| 	machines := []Machine{} | ||||
| 	if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | ||||
| 	machines := types.Machines{} | ||||
| 	if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| @ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) { | ||||
| 
 | ||||
| 	return hex.EncodeToString(bytes), nil | ||||
| } | ||||
| 
 | ||||
| func (key *PreAuthKey) toProto() *v1.PreAuthKey { | ||||
| 	protoKey := v1.PreAuthKey{ | ||||
| 		User:      key.User.Name, | ||||
| 		Id:        strconv.FormatUint(key.ID, util.Base10), | ||||
| 		Key:       key.Key, | ||||
| 		Ephemeral: key.Ephemeral, | ||||
| 		Reusable:  key.Reusable, | ||||
| 		Used:      key.Used, | ||||
| 		AclTags:   make([]string, len(key.ACLTags)), | ||||
| 	} | ||||
| 
 | ||||
| 	if key.Expiration != nil { | ||||
| 		protoKey.Expiration = timestamppb.New(*key.Expiration) | ||||
| 	} | ||||
| 
 | ||||
| 	if key.CreatedAt != nil { | ||||
| 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) | ||||
| 	} | ||||
| 
 | ||||
| 	for idx := range key.ACLTags { | ||||
| 		protoKey.AclTags[idx] = key.ACLTags[idx].Tag | ||||
| 	} | ||||
| 
 | ||||
| 	return &protoKey | ||||
| } | ||||
| @ -1,20 +1,22 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| ) | ||||
| 
 | ||||
| func (*Suite) TestCreatePreAuthKey(c *check.C) { | ||||
| 	_, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) | ||||
| 	_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) | ||||
| 
 | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	// Did we get a valid key? | ||||
| @ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { | ||||
| 	// Make sure the User association is populated | ||||
| 	c.Assert(key.User.Name, check.Equals, user.Name) | ||||
| 
 | ||||
| 	_, err = app.db.ListPreAuthKeys("bogus") | ||||
| 	_, err = db.ListPreAuthKeys("bogus") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	keys, err := app.db.ListPreAuthKeys(user.Name) | ||||
| 	keys, err := db.ListPreAuthKeys(user.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(keys), check.Equals, 1) | ||||
| 
 | ||||
| @ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestExpiredPreAuthKey(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test2") | ||||
| 	user, err := db.CreateUser("test2") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | ||||
| 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.Equals, ErrPreAuthKeyExpired) | ||||
| 	c.Assert(key, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { | ||||
| 	key, err := app.db.checkKeyValidity("potatoKey") | ||||
| 	key, err := db.ValidatePreAuthKey("potatoKey") | ||||
| 	c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) | ||||
| 	c.Assert(key, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestValidateKeyOk(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test3") | ||||
| 	user, err := db.CreateUser("test3") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | ||||
| 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(key.ID, check.Equals, pak.ID) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestAlreadyUsedKey(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test4") | ||||
| 	user, err := db.CreateUser("test4") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testest", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | ||||
| 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) | ||||
| 	c.Assert(key, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestReusableBeingUsedKey(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test5") | ||||
| 	user, err := db.CreateUser("test5") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testest", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | ||||
| 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(key.ID, check.Equals, pak.ID) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test6") | ||||
| 	user, err := db.CreateUser("test6") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | ||||
| 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(key.ID, check.Equals, pak.ID) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestEphemeralKey(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test7") | ||||
| 	user, err := db.CreateUser("test7") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	machine := Machine{ | ||||
| 	now := time.Now().Add(-time.Second * 30) | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testest", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		LastSeen:       &now, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	_, err = app.db.checkKeyValidity(pak.Key) | ||||
| 	_, err = db.ValidatePreAuthKey(pak.Key) | ||||
| 	// Ephemeral keys are by definition reusable | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test7", "testest") | ||||
| 	_, err = db.GetMachine("test7", "testest") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	app.expireEphemeralNodesWorker() | ||||
| 	db.ExpireEphemeralMachines(time.Second * 20) | ||||
| 
 | ||||
| 	// The machine record should have been deleted | ||||
| 	_, err = app.db.GetMachine("test7", "testest") | ||||
| 	_, err = db.GetMachine("test7", "testest") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(1)) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestExpirePreauthKey(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test3") | ||||
| 	user, err := db.CreateUser("test3") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pak.Expiration, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.ExpirePreAuthKey(pak) | ||||
| 	err = db.ExpirePreAuthKey(pak) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(pak.Expiration, check.NotNil) | ||||
| 
 | ||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | ||||
| 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.Equals, ErrPreAuthKeyExpired) | ||||
| 	c.Assert(key, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test6") | ||||
| 	user, err := db.CreateUser("test6") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	pak.Used = true | ||||
| 	app.db.db.Save(&pak) | ||||
| 	db.db.Save(&pak) | ||||
| 
 | ||||
| 	_, err = app.db.checkKeyValidity(pak.Key) | ||||
| 	_, err = db.ValidatePreAuthKey(pak.Key) | ||||
| 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestPreAuthKeyACLTags(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test8") | ||||
| 	user, err := db.CreateUser("test8") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) | ||||
| 	_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) | ||||
| 	c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected | ||||
| 
 | ||||
| 	tags := []string{"tag:test1", "tag:test2"} | ||||
| 	tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} | ||||
| 	_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) | ||||
| 	_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	listedPaks, err := app.db.ListPreAuthKeys("test8") | ||||
| 	listedPaks, err := db.ListPreAuthKeys("test8") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) | ||||
| 	c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags) | ||||
| } | ||||
| @ -1,55 +1,19 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	ErrRouteIsNotAvailable = errors.New("route is not available") | ||||
| 	ExitRouteV4            = netip.MustParsePrefix("0.0.0.0/0") | ||||
| 	ExitRouteV6            = netip.MustParsePrefix("::/0") | ||||
| ) | ||||
| var ErrRouteIsNotAvailable = errors.New("route is not available") | ||||
| 
 | ||||
| type Route struct { | ||||
| 	gorm.Model | ||||
| 
 | ||||
| 	MachineID uint64 | ||||
| 	Machine   Machine | ||||
| 	Prefix    IPPrefix | ||||
| 
 | ||||
| 	Advertised bool | ||||
| 	Enabled    bool | ||||
| 	IsPrimary  bool | ||||
| } | ||||
| 
 | ||||
| type Routes []Route | ||||
| 
 | ||||
| func (r *Route) String() string { | ||||
| 	return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) | ||||
| } | ||||
| 
 | ||||
| func (r *Route) isExitRoute() bool { | ||||
| 	return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 | ||||
| } | ||||
| 
 | ||||
| func (rs Routes) toPrefixes() []netip.Prefix { | ||||
| 	prefixes := make([]netip.Prefix, len(rs)) | ||||
| 	for i, r := range rs { | ||||
| 		prefixes[i] = netip.Prefix(r.Prefix) | ||||
| 	} | ||||
| 
 | ||||
| 	return prefixes | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { | ||||
| 	var routes []Route | ||||
| func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db.Preload("Machine").Find(&routes).Error | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -58,8 +22,21 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { | ||||
| 	var routes []Route | ||||
| func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| 		Where("machine_id = ? AND advertised = true", machine.ID). | ||||
| 		Find(&routes).Error | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| 		Where("machine_id = ?", m.ID). | ||||
| @ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { | ||||
| 	var route Route | ||||
| func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | ||||
| 	var route types.Route | ||||
| 	err := hsdb.db.Preload("Machine").First(&route, id).Error | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { | ||||
| 	// Tailscale requires both IPv4 and IPv6 exit routes to | ||||
| 	// be enabled at the same time, as per | ||||
| 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | ||||
| 	if route.isExitRoute() { | ||||
| 		return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) | ||||
| 	if route.IsExitRoute() { | ||||
| 		return hsdb.enableRoutes( | ||||
| 			&route.Machine, | ||||
| 			types.ExitRouteV4.String(), | ||||
| 			types.ExitRouteV6.String(), | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) | ||||
| @ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 	// Tailscale requires both IPv4 and IPv6 exit routes to | ||||
| 	// be enabled at the same time, as per | ||||
| 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | ||||
| 	if !route.isExitRoute() { | ||||
| 	if !route.IsExitRoute() { | ||||
| 		route.Enabled = false | ||||
| 		route.IsPrimary = false | ||||
| 		err = hsdb.db.Save(route).Error | ||||
| @ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		return hsdb.handlePrimarySubnetFailover() | ||||
| 		return hsdb.HandlePrimarySubnetFailover() | ||||
| 	} | ||||
| 
 | ||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | ||||
| @ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 	} | ||||
| 
 | ||||
| 	for i := range routes { | ||||
| 		if routes[i].isExitRoute() { | ||||
| 		if routes[i].IsExitRoute() { | ||||
| 			routes[i].Enabled = false | ||||
| 			routes[i].IsPrimary = false | ||||
| 			err = hsdb.db.Save(&routes[i]).Error | ||||
| @ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| 	return hsdb.HandlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| @ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| 	// Tailscale requires both IPv4 and IPv6 exit routes to | ||||
| 	// be enabled at the same time, as per | ||||
| 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | ||||
| 	if !route.isExitRoute() { | ||||
| 	if !route.IsExitRoute() { | ||||
| 		if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		return hsdb.handlePrimarySubnetFailover() | ||||
| 		return hsdb.HandlePrimarySubnetFailover() | ||||
| 	} | ||||
| 
 | ||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | ||||
| @ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	routesToDelete := []Route{} | ||||
| 	routesToDelete := types.Routes{} | ||||
| 	for _, r := range routes { | ||||
| 		if r.isExitRoute() { | ||||
| 		if r.IsExitRoute() { | ||||
| 			routesToDelete = append(routesToDelete, r) | ||||
| 		} | ||||
| 	} | ||||
| @ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| 	return hsdb.HandlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { | ||||
| func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { | ||||
| 	routes, err := hsdb.GetMachineRoutes(m) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| 	return hsdb.HandlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| // isUniquePrefix returns if there is another machine providing the same route already. | ||||
| func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { | ||||
| func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { | ||||
| 	var count int64 | ||||
| 	hsdb.db. | ||||
| 		Model(&Route{}). | ||||
| 		Model(&types.Route{}). | ||||
| 		Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", | ||||
| 			route.Prefix, | ||||
| 			route.MachineID, | ||||
| @ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { | ||||
| 	return count == 0 | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { | ||||
| 	var route Route | ||||
| func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { | ||||
| 	var route types.Route | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| 		Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). | ||||
| 		Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). | ||||
| 		First(&route).Error | ||||
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||
| 		return nil, err | ||||
| @ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { | ||||
| 
 | ||||
| // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) | ||||
| // Exit nodes are not considered for this, as they are never marked as Primary. | ||||
| func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { | ||||
| 	var routes []Route | ||||
| func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| 		Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). | ||||
| @ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { | ||||
| 	currentRoutes := []Route{} | ||||
| func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | ||||
| 	currentRoutes := types.Routes{} | ||||
| 	err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { | ||||
| 
 | ||||
| 	for prefix, exists := range advertisedRoutes { | ||||
| 		if !exists { | ||||
| 			route := Route{ | ||||
| 			route := types.Route{ | ||||
| 				MachineID:  machine.ID, | ||||
| 				Prefix:     IPPrefix(prefix), | ||||
| 				Prefix:     types.IPPrefix(prefix), | ||||
| 				Advertised: true, | ||||
| 				Enabled:    false, | ||||
| 			} | ||||
| @ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||
| 	// first, get all the enabled routes | ||||
| 	var routes []Route | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| 		Where("advertised = ? AND enabled = ?", true, true). | ||||
| @ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| 
 | ||||
| 	routesChanged := false | ||||
| 	for pos, route := range routes { | ||||
| 		if route.isExitRoute() { | ||||
| 		if route.IsExitRoute() { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| @ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| 		} | ||||
| 
 | ||||
| 		if route.IsPrimary { | ||||
| 			if route.Machine.isOnline() { | ||||
| 			if route.Machine.IsOnline() { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| @ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| 				Msgf("machine offline, finding a new primary subnet") | ||||
| 
 | ||||
| 			// find a new primary route | ||||
| 			var newPrimaryRoutes []Route | ||||
| 			var newPrimaryRoutes types.Routes | ||||
| 			err := hsdb.db. | ||||
| 				Preload("Machine"). | ||||
| 				Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", | ||||
| @ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			var newPrimaryRoute *Route | ||||
| 			var newPrimaryRoute *types.Route | ||||
| 			for pos, r := range newPrimaryRoutes { | ||||
| 				if r.Machine.isOnline() { | ||||
| 				if r.Machine.IsOnline() { | ||||
| 					newPrimaryRoute = &newPrimaryRoutes[pos] | ||||
| 
 | ||||
| 					break | ||||
| @ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (rs Routes) toProto() []*v1.Route { | ||||
| 	protoRoutes := []*v1.Route{} | ||||
| 
 | ||||
| 	for _, route := range rs { | ||||
| 		protoRoute := v1.Route{ | ||||
| 			Id:         uint64(route.ID), | ||||
| 			Machine:    route.Machine.toProto(), | ||||
| 			Prefix:     netip.Prefix(route.Prefix).String(), | ||||
| 			Advertised: route.Advertised, | ||||
| 			Enabled:    route.Enabled, | ||||
| 			IsPrimary:  route.IsPrimary, | ||||
| 			CreatedAt:  timestamppb.New(route.CreatedAt), | ||||
| 			UpdatedAt:  timestamppb.New(route.UpdatedAt), | ||||
| 		} | ||||
| 
 | ||||
| 		if route.DeletedAt.Valid { | ||||
| 			protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) | ||||
| 		} | ||||
| 
 | ||||
| 		protoRoutes = append(protoRoutes, &protoRoute) | ||||
| // EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. | ||||
| func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | ||||
| 	aclPolicy *policy.ACLPolicy, | ||||
| 	machine *types.Machine, | ||||
| ) error { | ||||
| 	if len(machine.IPAddresses) == 0 { | ||||
| 		return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs | ||||
| 	} | ||||
| 
 | ||||
| 	return protoRoutes | ||||
| 	routes, err := hsdb.GetMachineAdvertisedRoutes(machine) | ||||
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Str("machine", machine.Hostname). | ||||
| 			Msg("Could not get advertised routes for machine") | ||||
| 
 | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	approvedRoutes := types.Routes{} | ||||
| 
 | ||||
| 	for _, advertisedRoute := range routes { | ||||
| 		if advertisedRoute.Enabled { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( | ||||
| 			netip.Prefix(advertisedRoute.Prefix), | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			log.Err(err). | ||||
| 				Str("advertisedRoute", advertisedRoute.String()). | ||||
| 				Uint64("machineId", machine.ID). | ||||
| 				Msg("Failed to resolve autoApprovers for advertised route") | ||||
| 
 | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		for _, approvedAlias := range routeApprovers { | ||||
| 			if approvedAlias == machine.User.Name { | ||||
| 				approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||
| 			} else { | ||||
| 				// TODO(kradalby): figure out how to get this to depend on less stuff | ||||
| 				approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) | ||||
| 				if err != nil { | ||||
| 					log.Err(err). | ||||
| 						Str("alias", approvedAlias). | ||||
| 						Msg("Failed to expand alias when processing autoApprovers policy") | ||||
| 
 | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first | ||||
| 				if approvedIps.Contains(machine.IPAddresses[0]) { | ||||
| 					approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, approvedRoute := range approvedRoutes { | ||||
| 		err := hsdb.EnableRoute(uint64(approvedRoute.ID)) | ||||
| 		if err != nil { | ||||
| 			log.Err(err). | ||||
| 				Str("approvedRoute", approvedRoute.String()). | ||||
| 				Uint64("machineId", machine.ID). | ||||
| 				Msg("Failed to enable approved route") | ||||
| 
 | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| @ -1,9 +1,11 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"tailscale.com/tailcfg" | ||||
| @ -11,13 +13,13 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) TestGetRoutes(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "test_get_route_machine") | ||||
| 	_, err = db.GetMachine("test", "test_get_route_machine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	route, err := netip.ParsePrefix("10.0.0.0/24") | ||||
| @ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) { | ||||
| 		RoutableIPs: []netip.Prefix{route}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_get_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine) | ||||
| 	err = db.ProcessMachineRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) | ||||
| 	advertisedRoutes, err := db.GetAdvertisedRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(advertisedRoutes), check.Equals, 1) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine, "192.168.0.0/24") | ||||
| 	err = db.enableRoutes(&machine, "192.168.0.0/24") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine, "10.0.0.0/24") | ||||
| 	err = db.enableRoutes(&machine, "10.0.0.0/24") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetEnableRoutes(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | ||||
| 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	route, err := netip.ParsePrefix( | ||||
| @ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { | ||||
| 		RoutableIPs: []netip.Prefix{route, route2}, | ||||
| 	} | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine) | ||||
| 	err = db.ProcessMachineRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) | ||||
| 	availableRoutes, err := db.GetAdvertisedRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(availableRoutes), check.Equals, 2) | ||||
| 
 | ||||
| 	noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) | ||||
| 	noEnabledRoutes, err := db.GetEnabledRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(noEnabledRoutes), check.Equals, 0) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine, "192.168.0.0/24") | ||||
| 	err = db.enableRoutes(&machine, "192.168.0.0/24") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine, "10.0.0.0/24") | ||||
| 	err = db.enableRoutes(&machine, "10.0.0.0/24") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes, err := app.db.GetEnabledRoutes(&machine) | ||||
| 	enabledRoutes, err := db.GetEnabledRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes), check.Equals, 1) | ||||
| 
 | ||||
| 	// Adding it twice will just let it pass through | ||||
| 	err = app.db.enableRoutes(&machine, "10.0.0.0/24") | ||||
| 	err = db.enableRoutes(&machine, "10.0.0.0/24") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) | ||||
| 	enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine, "150.0.10.0/25") | ||||
| 	err = db.enableRoutes(&machine, "150.0.10.0/25") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) | ||||
| 	enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(3)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestIsUniquePrefix(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | ||||
| 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	route, err := netip.ParsePrefix( | ||||
| @ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { | ||||
| 	hostInfo1 := tailcfg.Hostinfo{ | ||||
| 		RoutableIPs: []netip.Prefix{route, route2}, | ||||
| 	} | ||||
| 	machine1 := Machine{ | ||||
| 	machine1 := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo1), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine1) | ||||
| 	db.db.Save(&machine1) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine1) | ||||
| 	err = db.ProcessMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, route.String()) | ||||
| 	err = db.enableRoutes(&machine1, route.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, route2.String()) | ||||
| 	err = db.enableRoutes(&machine1, route2.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	hostInfo2 := tailcfg.Hostinfo{ | ||||
| 		RoutableIPs: []netip.Prefix{route2}, | ||||
| 	} | ||||
| 	machine2 := Machine{ | ||||
| 	machine2 := types.Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo2), | ||||
| 		HostInfo:       types.HostInfo(hostInfo2), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine2) | ||||
| 	db.db.Save(&machine2) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine2) | ||||
| 	err = db.ProcessMachineRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine2, route2.String()) | ||||
| 	err = db.enableRoutes(&machine2, route2.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | ||||
| 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 2) | ||||
| 
 | ||||
| 	enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) | ||||
| 	enabledRoutes2, err := db.GetEnabledRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes2), check.Equals, 1) | ||||
| 
 | ||||
| 	routes, err := app.db.getMachinePrimaryRoutes(&machine1) | ||||
| 	routes, err := db.GetMachinePrimaryRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 2) | ||||
| 
 | ||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | ||||
| 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 0) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(3)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestSubnetFailover(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | ||||
| 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	prefix, err := netip.ParsePrefix( | ||||
| @ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) { | ||||
| 	} | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	machine1 := Machine{ | ||||
| 	machine1 := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo1), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	app.db.db.Save(&machine1) | ||||
| 	db.db.Save(&machine1) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine1) | ||||
| 	err = db.ProcessMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, prefix.String()) | ||||
| 	err = db.enableRoutes(&machine1, prefix.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, prefix2.String()) | ||||
| 	err = db.enableRoutes(&machine1, prefix2.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.handlePrimarySubnetFailover() | ||||
| 	err = db.HandlePrimarySubnetFailover() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | ||||
| 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 2) | ||||
| 
 | ||||
| 	route, err := app.db.getPrimaryRoute(prefix) | ||||
| 	route, err := db.getPrimaryRoute(prefix) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(route.MachineID, check.Equals, machine1.ID) | ||||
| 
 | ||||
| 	hostInfo2 := tailcfg.Hostinfo{ | ||||
| 		RoutableIPs: []netip.Prefix{prefix2}, | ||||
| 	} | ||||
| 	machine2 := Machine{ | ||||
| 	machine2 := types.Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo2), | ||||
| 		HostInfo:       types.HostInfo(hostInfo2), | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	app.db.db.Save(&machine2) | ||||
| 	db.db.Save(&machine2) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine2) | ||||
| 	err = db.ProcessMachineRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine2, prefix2.String()) | ||||
| 	err = db.enableRoutes(&machine2, prefix2.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.handlePrimarySubnetFailover() | ||||
| 	err = db.HandlePrimarySubnetFailover() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) | ||||
| 	enabledRoutes1, err = db.GetEnabledRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 2) | ||||
| 
 | ||||
| 	enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) | ||||
| 	enabledRoutes2, err := db.GetEnabledRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes2), check.Equals, 1) | ||||
| 
 | ||||
| 	routes, err := app.db.getMachinePrimaryRoutes(&machine1) | ||||
| 	routes, err := db.GetMachinePrimaryRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 2) | ||||
| 
 | ||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | ||||
| 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 0) | ||||
| 
 | ||||
| 	// lets make machine1 lastseen 10 mins ago | ||||
| 	before := now.Add(-10 * time.Minute) | ||||
| 	machine1.LastSeen = &before | ||||
| 	err = app.db.db.Save(&machine1).Error | ||||
| 	err = db.db.Save(&machine1).Error | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.handlePrimarySubnetFailover() | ||||
| 	err = db.HandlePrimarySubnetFailover() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine1) | ||||
| 	routes, err = db.GetMachinePrimaryRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 1) | ||||
| 
 | ||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | ||||
| 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 1) | ||||
| 
 | ||||
| 	machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ | ||||
| 	machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{ | ||||
| 		RoutableIPs: []netip.Prefix{prefix, prefix2}, | ||||
| 	}) | ||||
| 	err = app.db.db.Save(&machine2).Error | ||||
| 	err = db.db.Save(&machine2).Error | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine2) | ||||
| 	err = db.ProcessMachineRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine2, prefix.String()) | ||||
| 	err = db.enableRoutes(&machine2, prefix.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.handlePrimarySubnetFailover() | ||||
| 	err = db.HandlePrimarySubnetFailover() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine1) | ||||
| 	routes, err = db.GetMachinePrimaryRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 0) | ||||
| 
 | ||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | ||||
| 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 2) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(6)) | ||||
| } | ||||
| 
 | ||||
| // TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, | ||||
| // including both the primary routes the node is responsible for, and the | ||||
| // exit node routes if enabled. | ||||
| func (s *Suite) TestAllowedIPRoutes(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | ||||
| 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	prefix, err := netip.ParsePrefix( | ||||
| @ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { | ||||
| 	machineKey := key.NewMachine() | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	machine1 := Machine{ | ||||
| 	machine1 := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | ||||
| 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||
| 		DiscoKey:       util.DiscoPublicKeyStripPrefix(discoKey.Public()), | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo1), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	app.db.db.Save(&machine1) | ||||
| 	db.db.Save(&machine1) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine1) | ||||
| 	err = db.ProcessMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, prefix.String()) | ||||
| 	err = db.enableRoutes(&machine1, prefix.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	// We do not enable this one on purpose to test that it is not enabled | ||||
| 	// err = app.db.enableRoutes(&machine1, prefix2.String()) | ||||
| 	// err = db.enableRoutes(&machine1, prefix2.String()) | ||||
| 	// c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	routes, err := app.db.GetMachineRoutes(&machine1) | ||||
| 	routes, err := db.GetMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	for _, route := range routes { | ||||
| 		if route.isExitRoute() { | ||||
| 			err = app.db.EnableRoute(uint64(route.ID)) | ||||
| 		if route.IsExitRoute() { | ||||
| 			err = db.EnableRoute(uint64(route.ID)) | ||||
| 			c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 			// We only enable one exit route, so we can test that both are enabled | ||||
| @ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = app.db.handlePrimarySubnetFailover() | ||||
| 	err = db.HandlePrimarySubnetFailover() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | ||||
| 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 3) | ||||
| 
 | ||||
| 	peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) | ||||
| 	peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(peer.AllowedIPs), check.Equals, 3) | ||||
| @ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { | ||||
| 
 | ||||
| 	// Now we disable only one of the exit routes | ||||
| 	// and we see if both are disabled | ||||
| 	var exitRouteV4 Route | ||||
| 	var exitRouteV4 types.Route | ||||
| 	for _, route := range routes { | ||||
| 		if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { | ||||
| 		if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { | ||||
| 			exitRouteV4 = route | ||||
| 
 | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = app.db.DisableRoute(uint64(exitRouteV4.ID)) | ||||
| 	err = db.DisableRoute(uint64(exitRouteV4.ID)) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) | ||||
| 	enabledRoutes1, err = db.GetEnabledRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 1) | ||||
| 
 | ||||
| 	// and now we delete only one of the exit routes | ||||
| 	// and we check if both are deleted | ||||
| 	routes, err = app.db.GetMachineRoutes(&machine1) | ||||
| 	routes, err = db.GetMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 4) | ||||
| 
 | ||||
| 	err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) | ||||
| 	err = db.DeleteRoute(uint64(exitRouteV4.ID)) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	routes, err = app.db.GetMachineRoutes(&machine1) | ||||
| 	routes, err = db.GetMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 2) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(2)) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestDeleteRoutes(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | ||||
| 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	prefix, err := netip.ParsePrefix( | ||||
| @ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { | ||||
| 	} | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	machine1 := Machine{ | ||||
| 	machine1 := types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "test_enable_route_machine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       HostInfo(hostInfo1), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	app.db.db.Save(&machine1) | ||||
| 	db.db.Save(&machine1) | ||||
| 
 | ||||
| 	err = app.db.processMachineRoutes(&machine1) | ||||
| 	err = db.ProcessMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, prefix.String()) | ||||
| 	err = db.enableRoutes(&machine1, prefix.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.enableRoutes(&machine1, prefix2.String()) | ||||
| 	err = db.enableRoutes(&machine1, prefix2.String()) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	routes, err := app.db.GetMachineRoutes(&machine1) | ||||
| 	routes, err := db.GetMachineRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.DeleteRoute(uint64(routes[0].ID)) | ||||
| 	err = db.DeleteRoute(uint64(routes[0].ID)) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | ||||
| 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 1) | ||||
| 
 | ||||
| 	c.Assert(channelUpdates, check.Equals, int32(2)) | ||||
| } | ||||
							
								
								
									
										74
									
								
								hscontrol/db/suite_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								hscontrol/db/suite_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,74 @@ | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"os" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gopkg.in/check.v1" | ||||
| ) | ||||
| 
 | ||||
| func Test(t *testing.T) { | ||||
| 	check.TestingT(t) | ||||
| } | ||||
| 
 | ||||
| var _ = check.Suite(&Suite{}) | ||||
| 
 | ||||
| type Suite struct{} | ||||
| 
 | ||||
| var ( | ||||
| 	tmpDir string | ||||
| 	db     *HSDatabase | ||||
| 
 | ||||
| 	// channelUpdates counts the number of times | ||||
| 	// either of the channels was notified. | ||||
| 	channelUpdates int32 | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) SetUpTest(c *check.C) { | ||||
| 	atomic.StoreInt32(&channelUpdates, 0) | ||||
| 	s.ResetDB(c) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TearDownTest(c *check.C) { | ||||
| 	os.RemoveAll(tmpDir) | ||||
| } | ||||
| 
 | ||||
| func notificationSink(c <-chan struct{}) { | ||||
| 	for { | ||||
| 		<-c | ||||
| 		atomic.AddInt32(&channelUpdates, 1) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) ResetDB(c *check.C) { | ||||
| 	if len(tmpDir) != 0 { | ||||
| 		os.RemoveAll(tmpDir) | ||||
| 	} | ||||
| 	var err error | ||||
| 	tmpDir, err = os.MkdirTemp("", "autoygg-client-test") | ||||
| 	if err != nil { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	sink := make(chan struct{}) | ||||
| 
 | ||||
| 	go notificationSink(sink) | ||||
| 
 | ||||
| 	db, err = NewHeadscaleDatabase( | ||||
| 		"sqlite3", | ||||
| 		tmpDir+"/headscale_test.db", | ||||
| 		false, | ||||
| 		false, | ||||
| 		sink, | ||||
| 		sink, | ||||
| 		[]netip.Prefix{ | ||||
| 			netip.MustParsePrefix("10.27.0.0/23"), | ||||
| 		}, | ||||
| 		"", | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| } | ||||
| @ -1,17 +1,12 @@ | ||||
| package hscontrol | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"gorm.io/gorm" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| @ -20,33 +15,16 @@ var ( | ||||
| 	ErrUserExists        = errors.New("user already exists") | ||||
| 	ErrUserNotFound      = errors.New("user not found") | ||||
| 	ErrUserStillHasNodes = errors.New("user not empty: node(s) found") | ||||
| 	ErrInvalidUserName   = errors.New("invalid user name") | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// value related to RFC 1123 and 952. | ||||
| 	labelHostnameLength = 63 | ||||
| ) | ||||
| 
 | ||||
| var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") | ||||
| 
 | ||||
| // User is the way Headscale implements the concept of users in Tailscale | ||||
| // | ||||
| // At the end of the day, users in Tailscale are some kind of 'bubbles' or users | ||||
| // that contain our machines. | ||||
| type User struct { | ||||
| 	gorm.Model | ||||
| 	Name string `gorm:"unique"` | ||||
| } | ||||
| 
 | ||||
| // CreateUser creates a new User. Returns error if could not be created | ||||
| // or another user already exists. | ||||
| func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { | ||||
| 	err := CheckForFQDNRules(name) | ||||
| func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | ||||
| 	err := util.CheckForFQDNRules(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	user := User{} | ||||
| 	user := types.User{} | ||||
| 	if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { | ||||
| 		return nil, ErrUserExists | ||||
| 	} | ||||
| @ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = CheckForFQDNRules(newName) | ||||
| 	err = util.CheckForFQDNRules(newName) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | ||||
| } | ||||
| 
 | ||||
| // GetUser fetches a user by name. | ||||
| func (hsdb *HSDatabase) GetUser(name string) (*User, error) { | ||||
| 	user := User{} | ||||
| func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | ||||
| 	user := types.User{} | ||||
| 	if result := hsdb.db.First(&user, "name = ?", name); errors.Is( | ||||
| 		result.Error, | ||||
| 		gorm.ErrRecordNotFound, | ||||
| @ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) { | ||||
| } | ||||
| 
 | ||||
| // ListUsers gets all the existing users. | ||||
| func (hsdb *HSDatabase) ListUsers() ([]User, error) { | ||||
| 	users := []User{} | ||||
| func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | ||||
| 	users := []types.User{} | ||||
| 	if err := hsdb.db.Find(&users).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) { | ||||
| } | ||||
| 
 | ||||
| // ListMachinesByUser gets all the nodes in a given user. | ||||
| func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { | ||||
| 	err := CheckForFQDNRules(name) | ||||
| func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { | ||||
| 	err := util.CheckForFQDNRules(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	machines := []Machine{} | ||||
| 	if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { | ||||
| 	machines := types.Machines{} | ||||
| 	if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| @ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { | ||||
| } | ||||
| 
 | ||||
| // SetMachineUser assigns a Machine to a user. | ||||
| func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { | ||||
| 	err := CheckForFQDNRules(username) | ||||
| func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { | ||||
| 	err := util.CheckForFQDNRules(username) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (n *User) toTailscaleUser() *tailcfg.User { | ||||
| 	user := tailcfg.User{ | ||||
| 		ID:            tailcfg.UserID(n.ID), | ||||
| 		LoginName:     n.Name, | ||||
| 		DisplayName:   n.Name, | ||||
| 		ProfilePicURL: "", | ||||
| 		Domain:        "headscale.net", | ||||
| 		Logins:        []tailcfg.LoginID{}, | ||||
| 		Created:       time.Time{}, | ||||
| 	} | ||||
| 
 | ||||
| 	return &user | ||||
| } | ||||
| 
 | ||||
| func (n *User) toTailscaleLogin() *tailcfg.Login { | ||||
| 	login := tailcfg.Login{ | ||||
| 		ID:            tailcfg.LoginID(n.ID), | ||||
| 		LoginName:     n.Name, | ||||
| 		DisplayName:   n.Name, | ||||
| 		ProfilePicURL: "", | ||||
| 		Domain:        "headscale.net", | ||||
| 	} | ||||
| 
 | ||||
| 	return &login | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getMapResponseUserProfiles( | ||||
| 	machine Machine, | ||||
| 	peers Machines, | ||||
| func (hsdb *HSDatabase) GetMapResponseUserProfiles( | ||||
| 	machine types.Machine, | ||||
| 	peers types.Machines, | ||||
| ) []tailcfg.UserProfile { | ||||
| 	userMap := make(map[string]User) | ||||
| 	userMap := make(map[string]types.User) | ||||
| 	userMap[machine.User.Name] = machine.User | ||||
| 	for _, peer := range peers { | ||||
| 		userMap[peer.User.Name] = peer.User // not worth checking if already is there | ||||
| @ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles( | ||||
| 
 | ||||
| 	return profiles | ||||
| } | ||||
| 
 | ||||
| func (n *User) toProto() *v1.User { | ||||
| 	return &v1.User{ | ||||
| 		Id:        strconv.FormatUint(uint64(n.ID), util.Base10), | ||||
| 		Name:      n.Name, | ||||
| 		CreatedAt: timestamppb.New(n.CreatedAt), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // NormalizeToFQDNRules will replace forbidden chars in user | ||||
| // it can also return an error if the user doesn't respect RFC 952 and 1123. | ||||
| func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { | ||||
| 	name = strings.ToLower(name) | ||||
| 	name = strings.ReplaceAll(name, "'", "") | ||||
| 	atIdx := strings.Index(name, "@") | ||||
| 	if stripEmailDomain && atIdx > 0 { | ||||
| 		name = name[:atIdx] | ||||
| 	} else { | ||||
| 		name = strings.ReplaceAll(name, "@", ".") | ||||
| 	} | ||||
| 	name = invalidCharsInUserRegex.ReplaceAllString(name, "-") | ||||
| 
 | ||||
| 	for _, elt := range strings.Split(name, ".") { | ||||
| 		if len(elt) > labelHostnameLength { | ||||
| 			return "", fmt.Errorf( | ||||
| 				"label %v is more than 63 chars: %w", | ||||
| 				elt, | ||||
| 				ErrInvalidUserName, | ||||
| 			) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return name, nil | ||||
| } | ||||
| 
 | ||||
| func CheckForFQDNRules(name string) error { | ||||
| 	if len(name) > labelHostnameLength { | ||||
| 		return fmt.Errorf( | ||||
| 			"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", | ||||
| 			name, | ||||
| 			ErrInvalidUserName, | ||||
| 		) | ||||
| 	} | ||||
| 	if strings.ToLower(name) != name { | ||||
| 		return fmt.Errorf( | ||||
| 			"DNS segment should be lowercase. %v doesn't comply with this rule: %w", | ||||
| 			name, | ||||
| 			ErrInvalidUserName, | ||||
| 		) | ||||
| 	} | ||||
| 	if invalidCharsInUserRegex.MatchString(name) { | ||||
| 		return fmt.Errorf( | ||||
| 			"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", | ||||
| 			name, | ||||
| 			ErrInvalidUserName, | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										277
									
								
								hscontrol/db/users_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										277
									
								
								hscontrol/db/users_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,277 @@ | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"gorm.io/gorm" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) TestCreateAndDestroyUser(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(user.Name, check.Equals, "test") | ||||
| 
 | ||||
| 	users, err := db.ListUsers() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(users), check.Equals, 1) | ||||
| 
 | ||||
| 	err = db.DestroyUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetUser("test") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestDestroyUserErrors(c *check.C) { | ||||
| 	err := db.DestroyUser("test") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = db.DestroyUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) | ||||
| 	// destroying a user also deletes all associated preauthkeys | ||||
| 	c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) | ||||
| 
 | ||||
| 	user, err = db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = db.DestroyUser("test") | ||||
| 	c.Assert(err, check.Equals, ErrUserStillHasNodes) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestRenameUser(c *check.C) { | ||||
| 	userTest, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(userTest.Name, check.Equals, "test") | ||||
| 
 | ||||
| 	users, err := db.ListUsers() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(users), check.Equals, 1) | ||||
| 
 | ||||
| 	err = db.RenameUser("test", "test-renamed") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetUser("test") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	_, err = db.GetUser("test-renamed") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = db.RenameUser("test-does-not-exit", "test") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	userTest2, err := db.CreateUser("test2") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(userTest2.Name, check.Equals, "test2") | ||||
| 
 | ||||
| 	err = db.RenameUser("test2", "test-renamed") | ||||
| 	c.Assert(err, check.Equals, ErrUserExists) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { | ||||
| 	userShared1, err := db.CreateUser("shared1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	userShared2, err := db.CreateUser("shared2") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	userShared3, err := db.CreateUser("shared3") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKeyShared1, err := db.CreatePreAuthKey( | ||||
| 		userShared1.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKeyShared2, err := db.CreatePreAuthKey( | ||||
| 		userShared2.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKeyShared3, err := db.CreatePreAuthKey( | ||||
| 		userShared3.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKey2Shared1, err := db.CreatePreAuthKey( | ||||
| 		userShared1.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machineInShared1 := &types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		Hostname:       "test_get_shared_nodes_1", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyShared1.ID), | ||||
| 	} | ||||
| 	db.db.Save(machineInShared1) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared2 := &types.Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		Hostname:       "test_get_shared_nodes_2", | ||||
| 		UserID:         userShared2.ID, | ||||
| 		User:           *userShared2, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyShared2.ID), | ||||
| 	} | ||||
| 	db.db.Save(machineInShared2) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared3 := &types.Machine{ | ||||
| 		ID:             3, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		Hostname:       "test_get_shared_nodes_3", | ||||
| 		UserID:         userShared3.ID, | ||||
| 		User:           *userShared3, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyShared3.ID), | ||||
| 	} | ||||
| 	db.db.Save(machineInShared3) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine2InShared1 := &types.Machine{ | ||||
| 		ID:             4, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		Hostname:       "test_get_shared_nodes_4", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||
| 		AuthKeyID:      uint(preAuthKey2Shared1.ID), | ||||
| 	} | ||||
| 	db.db.Save(machine2InShared1) | ||||
| 
 | ||||
| 	peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	userProfiles := db.GetMapResponseUserProfiles( | ||||
| 		*machineInShared1, | ||||
| 		peersOfMachine1InShared1, | ||||
| 	) | ||||
| 
 | ||||
| 	c.Assert(len(userProfiles), check.Equals, 3) | ||||
| 
 | ||||
| 	found := false | ||||
| 	for _, userProfiles := range userProfiles { | ||||
| 		if userProfiles.DisplayName == userShared1.Name { | ||||
| 			found = true | ||||
| 
 | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	c.Assert(found, check.Equals, true) | ||||
| 
 | ||||
| 	found = false | ||||
| 	for _, userProfile := range userProfiles { | ||||
| 		if userProfile.DisplayName == userShared2.Name { | ||||
| 			found = true | ||||
| 
 | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	c.Assert(found, check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestSetMachineUser(c *check.C) { | ||||
| 	oldUser, err := db.CreateUser("old") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	newUser, err := db.CreateUser("new") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         oldUser.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 	c.Assert(machine.UserID, check.Equals, oldUser.ID) | ||||
| 
 | ||||
| 	err = db.SetMachineUser(&machine, newUser.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||
| 
 | ||||
| 	err = db.SetMachineUser(&machine, "non-existing-user") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	err = db.SetMachineUser(&machine, newUser.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||
| } | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"strings" | ||||
| 
 | ||||
| 	mapset "github.com/deckarep/golang-set/v2" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"go4.org/netipx" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/dnstype" | ||||
| @ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { | ||||
| // | ||||
| // This will produce a resolver like: | ||||
| // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` | ||||
| func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { | ||||
| func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { | ||||
| 	for _, resolver := range resolvers { | ||||
| 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { | ||||
| 			attrs := url.Values{ | ||||
| @ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { | ||||
| func getMapResponseDNSConfig( | ||||
| 	dnsConfigOrig *tailcfg.DNSConfig, | ||||
| 	baseDomain string, | ||||
| 	machine Machine, | ||||
| 	peers Machines, | ||||
| 	machine types.Machine, | ||||
| 	peers types.Machines, | ||||
| ) *tailcfg.DNSConfig { | ||||
| 	var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() | ||||
| 	if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled | ||||
| @ -200,7 +201,7 @@ func getMapResponseDNSConfig( | ||||
| 			), | ||||
| 		) | ||||
| 
 | ||||
| 		userSet := mapset.NewSet[User]() | ||||
| 		userSet := mapset.NewSet[types.User]() | ||||
| 		userSet.Add(machine.User) | ||||
| 		for _, p := range peers { | ||||
| 			userSet.Add(p.User) | ||||
|  | ||||
| @ -4,6 +4,8 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/dnstype" | ||||
| @ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | ||||
| 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machineInShared1 := &Machine{ | ||||
| 	machineInShared1 := &types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| @ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_1", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyInShared1.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared1) | ||||
| 	err = app.db.MachineSave(machineInShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared2 := &Machine{ | ||||
| 	machineInShared2 := &types.Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| @ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_2", | ||||
| 		UserID:         userShared2.ID, | ||||
| 		User:           *userShared2, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyInShared2.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared2) | ||||
| 	err = app.db.MachineSave(machineInShared2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared3 := &Machine{ | ||||
| 	machineInShared3 := &types.Machine{ | ||||
| 		ID:             3, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| @ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_3", | ||||
| 		UserID:         userShared3.ID, | ||||
| 		User:           *userShared3, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyInShared3.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared3) | ||||
| 	err = app.db.MachineSave(machineInShared3) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine2InShared1 := &Machine{ | ||||
| 	machine2InShared1 := &types.Machine{ | ||||
| 		ID:             4, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| @ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_4", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||
| 		AuthKeyID:      uint(PreAuthKey2InShared1.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machine2InShared1) | ||||
| 	err = app.db.MachineSave(machine2InShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	baseDomain := "foobar.headscale.net" | ||||
| 	dnsConfigOrig := tailcfg.DNSConfig{ | ||||
| @ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | ||||
| 		Proxied: true, | ||||
| 	} | ||||
| 
 | ||||
| 	peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) | ||||
| 	peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	dnsConfig := getMapResponseDNSConfig( | ||||
| @ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | ||||
| 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machineInShared1 := &Machine{ | ||||
| 	machineInShared1 := &types.Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| @ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_1", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyInShared1.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared1) | ||||
| 	err = app.db.MachineSave(machineInShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared2 := &Machine{ | ||||
| 	machineInShared2 := &types.Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| @ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_2", | ||||
| 		UserID:         userShared2.ID, | ||||
| 		User:           *userShared2, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyInShared2.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared2) | ||||
| 	err = app.db.MachineSave(machineInShared2) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared3 := &Machine{ | ||||
| 	machineInShared3 := &types.Machine{ | ||||
| 		ID:             3, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| @ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_3", | ||||
| 		UserID:         userShared3.ID, | ||||
| 		User:           *userShared3, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyInShared3.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared3) | ||||
| 	err = app.db.MachineSave(machineInShared3) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine2InShared1 := &Machine{ | ||||
| 	machine2InShared1 := &types.Machine{ | ||||
| 		ID:             4, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| @ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | ||||
| 		Hostname:       "test_get_shared_nodes_4", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||
| 		AuthKeyID:      uint(preAuthKey2InShared1.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machine2InShared1) | ||||
| 	err = app.db.MachineSave(machine2InShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	baseDomain := "foobar.headscale.net" | ||||
| 	dnsConfigOrig := tailcfg.DNSConfig{ | ||||
| @ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | ||||
| 		Proxied: false, | ||||
| 	} | ||||
| 
 | ||||
| 	peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) | ||||
| 	peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	dnsConfig := getMapResponseDNSConfig( | ||||
|  | ||||
| @ -8,6 +8,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"google.golang.org/grpc/codes" | ||||
| @ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.GetUserResponse{User: user.toProto()}, nil | ||||
| 	return &v1.GetUserResponse{User: user.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) CreateUser( | ||||
| @ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.CreateUserResponse{User: user.toProto()}, nil | ||||
| 	return &v1.CreateUserResponse{User: user.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) RenameUser( | ||||
| @ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.RenameUserResponse{User: user.toProto()}, nil | ||||
| 	return &v1.RenameUserResponse{User: user.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) DeleteUser( | ||||
| @ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers( | ||||
| 
 | ||||
| 	response := make([]*v1.User, len(users)) | ||||
| 	for index, user := range users { | ||||
| 		response[index] = user.toProto() | ||||
| 		response[index] = user.Proto() | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace().Caller().Interface("users", response).Msg("") | ||||
| @ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil | ||||
| 	return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) ExpirePreAuthKey( | ||||
| @ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( | ||||
| 
 | ||||
| 	response := make([]*v1.PreAuthKey, len(preAuthKeys)) | ||||
| 	for index, key := range preAuthKeys { | ||||
| 		response[index] = key.toProto() | ||||
| 		response[index] = key.Proto() | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil | ||||
| @ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine( | ||||
| 		request.GetKey(), | ||||
| 		request.GetUser(), | ||||
| 		nil, | ||||
| 		RegisterMethodCLI, | ||||
| 		util.RegisterMethodCLI, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil | ||||
| 	return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) GetMachine( | ||||
| @ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.GetMachineResponse{Machine: machine.toProto()}, nil | ||||
| 	return &v1.GetMachineResponse{Machine: machine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) SetTags( | ||||
| @ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags( | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) | ||||
| 	err = api.h.db.SetTags(machine, request.GetTags()) | ||||
| 	if err != nil { | ||||
| 		return &v1.SetTagsResponse{ | ||||
| 			Machine: nil, | ||||
| @ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags( | ||||
| 		Strs("tags", request.GetTags()). | ||||
| 		Msg("Changing tags of machine") | ||||
| 
 | ||||
| 	return &v1.SetTagsResponse{Machine: machine.toProto()}, nil | ||||
| 	return &v1.SetTagsResponse{Machine: machine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func validateTag(tag string) error { | ||||
| @ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine( | ||||
| 		Time("expiry", *machine.Expiry). | ||||
| 		Msg("machine expired") | ||||
| 
 | ||||
| 	return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil | ||||
| 	return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) RenameMachine( | ||||
| @ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine( | ||||
| 		Str("new_name", request.GetNewName()). | ||||
| 		Msg("machine renamed") | ||||
| 
 | ||||
| 	return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil | ||||
| 	return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) ListMachines( | ||||
| @ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines( | ||||
| 
 | ||||
| 		response := make([]*v1.Machine, len(machines)) | ||||
| 		for index, machine := range machines { | ||||
| 			response[index] = machine.toProto() | ||||
| 			response[index] = machine.Proto() | ||||
| 		} | ||||
| 
 | ||||
| 		return &v1.ListMachinesResponse{Machines: response}, nil | ||||
| @ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines( | ||||
| 
 | ||||
| 	response := make([]*v1.Machine, len(machines)) | ||||
| 	for index, machine := range machines { | ||||
| 		m := machine.toProto() | ||||
| 		validTags, invalidTags := getTags( | ||||
| 			api.h.aclPolicy, | ||||
| 		m := machine.Proto() | ||||
| 		validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( | ||||
| 			machine, | ||||
| 			api.h.cfg.OIDC.StripEmaildomain, | ||||
| 		) | ||||
| @ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil | ||||
| 	return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) GetRoutes( | ||||
| @ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes( | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.GetRoutesResponse{ | ||||
| 		Routes: Routes(routes).toProto(), | ||||
| 		Routes: types.Routes(routes).Proto(), | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| @ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes( | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.GetMachineRoutesResponse{ | ||||
| 		Routes: Routes(routes).toProto(), | ||||
| 		Routes: types.Routes(routes).Proto(), | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| @ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey( | ||||
| 	ctx context.Context, | ||||
| 	request *v1.ExpireApiKeyRequest, | ||||
| ) (*v1.ExpireApiKeyResponse, error) { | ||||
| 	var apiKey *APIKey | ||||
| 	var apiKey *types.APIKey | ||||
| 	var err error | ||||
| 
 | ||||
| 	apiKey, err = api.h.db.GetAPIKey(request.Prefix) | ||||
| @ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys( | ||||
| 
 | ||||
| 	response := make([]*v1.ApiKey, len(apiKeys)) | ||||
| 	for index, key := range apiKeys { | ||||
| 		response[index] = key.toProto() | ||||
| 		response[index] = key.Proto() | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.ListApiKeysResponse{ApiKeys: response}, nil | ||||
| @ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	newMachine := Machine{ | ||||
| 	newMachine := types.Machine{ | ||||
| 		MachineKey: request.GetKey(), | ||||
| 		Hostname:   request.GetName(), | ||||
| 		GivenName:  givenName, | ||||
| @ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( | ||||
| 		LastSeen:             &time.Time{}, | ||||
| 		LastSuccessfulUpdate: &time.Time{}, | ||||
| 
 | ||||
| 		HostInfo: HostInfo(hostinfo), | ||||
| 		HostInfo: types.HostInfo(hostinfo), | ||||
| 	} | ||||
| 
 | ||||
| 	nodeKey := key.NodePublic{} | ||||
| @ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( | ||||
| 		registerCacheExpiration, | ||||
| 	) | ||||
| 
 | ||||
| 	return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil | ||||
| 	return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil | ||||
| } | ||||
| 
 | ||||
| func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,142 +0,0 @@ | ||||
| package hscontrol | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"go4.org/netipx" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| // This is borrowed from, and updated to use IPSet | ||||
| // https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 | ||||
| // TODO(kradalby): contribute upstream and make public. | ||||
| var ( | ||||
| 	zeroIP4 = netip.AddrFrom4([4]byte{}) | ||||
| 	zeroIP6 = netip.AddrFrom16([16]byte{}) | ||||
| ) | ||||
| 
 | ||||
| // parseIPSet parses arg as one: | ||||
| // | ||||
| //   - an IP address (IPv4 or IPv6) | ||||
| //   - the string "*" to match everything (both IPv4 & IPv6) | ||||
| //   - a CIDR (e.g. "192.168.0.0/16") | ||||
| //   - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") | ||||
| // | ||||
| // bits, if non-nil, is the legacy SrcBits CIDR length to make a IP | ||||
| // address (without a slash) treated as a CIDR of *bits length. | ||||
| // nolint | ||||
| func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { | ||||
| 	var ipSet netipx.IPSetBuilder | ||||
| 	if arg == "*" { | ||||
| 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) | ||||
| 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) | ||||
| 
 | ||||
| 		return ipSet.IPSet() | ||||
| 	} | ||||
| 	if strings.Contains(arg, "/") { | ||||
| 		pfx, err := netip.ParsePrefix(arg) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if pfx != pfx.Masked() { | ||||
| 			return nil, fmt.Errorf("%v contains non-network bits set", pfx) | ||||
| 		} | ||||
| 
 | ||||
| 		ipSet.AddPrefix(pfx) | ||||
| 
 | ||||
| 		return ipSet.IPSet() | ||||
| 	} | ||||
| 	if strings.Count(arg, "-") == 1 { | ||||
| 		ip1s, ip2s, _ := strings.Cut(arg, "-") | ||||
| 
 | ||||
| 		ip1, err := netip.ParseAddr(ip1s) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		ip2, err := netip.ParseAddr(ip2s) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		r := netipx.IPRangeFrom(ip1, ip2) | ||||
| 		if !r.IsValid() { | ||||
| 			return nil, fmt.Errorf("invalid IP range %q", arg) | ||||
| 		} | ||||
| 
 | ||||
| 		for _, prefix := range r.Prefixes() { | ||||
| 			ipSet.AddPrefix(prefix) | ||||
| 		} | ||||
| 
 | ||||
| 		return ipSet.IPSet() | ||||
| 	} | ||||
| 	ip, err := netip.ParseAddr(arg) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("invalid IP address %q", arg) | ||||
| 	} | ||||
| 	bits8 := uint8(ip.BitLen()) | ||||
| 	if bits != nil { | ||||
| 		if *bits < 0 || *bits > int(bits8) { | ||||
| 			return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) | ||||
| 		} | ||||
| 		bits8 = uint8(*bits) | ||||
| 	} | ||||
| 
 | ||||
| 	ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) | ||||
| 
 | ||||
| 	return ipSet.IPSet() | ||||
| } | ||||
| 
 | ||||
| type Match struct { | ||||
| 	Srcs  *netipx.IPSet | ||||
| 	Dests *netipx.IPSet | ||||
| } | ||||
| 
 | ||||
| func MatchFromFilterRule(rule tailcfg.FilterRule) Match { | ||||
| 	srcs := new(netipx.IPSetBuilder) | ||||
| 	dests := new(netipx.IPSetBuilder) | ||||
| 
 | ||||
| 	for _, srcIP := range rule.SrcIPs { | ||||
| 		set, _ := parseIPSet(srcIP, nil) | ||||
| 
 | ||||
| 		srcs.AddSet(set) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, dest := range rule.DstPorts { | ||||
| 		set, _ := parseIPSet(dest.IP, nil) | ||||
| 
 | ||||
| 		dests.AddSet(set) | ||||
| 	} | ||||
| 
 | ||||
| 	srcsSet, _ := srcs.IPSet() | ||||
| 	destsSet, _ := dests.IPSet() | ||||
| 
 | ||||
| 	match := Match{ | ||||
| 		Srcs:  srcsSet, | ||||
| 		Dests: destsSet, | ||||
| 	} | ||||
| 
 | ||||
| 	return match | ||||
| } | ||||
| 
 | ||||
| func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { | ||||
| 	for _, ip := range ips { | ||||
| 		if m.Srcs.Contains(ip) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (m *Match) DestsContainsIP(ips []netip.Addr) bool { | ||||
| 	for _, ip := range ips { | ||||
| 		if m.Dests.Contains(ip) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
| @ -14,6 +14,8 @@ import ( | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc/v3/oidc" | ||||
| 	"github.com/gorilla/mux" | ||||
| 	"github.com/juanfont/headscale/hscontrol/db" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"golang.org/x/oauth2" | ||||
| @ -638,7 +640,7 @@ func getUserName( | ||||
| 	claims *IDTokenClaims, | ||||
| 	stripEmaildomain bool, | ||||
| ) (string, error) { | ||||
| 	userName, err := NormalizeToFQDNRules( | ||||
| 	userName, err := util.NormalizeToFQDNRules( | ||||
| 		claims.Email, | ||||
| 		stripEmaildomain, | ||||
| 	) | ||||
| @ -663,9 +665,9 @@ func getUserName( | ||||
| func (h *Headscale) findOrCreateNewUserForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	userName string, | ||||
| ) (*User, error) { | ||||
| ) (*types.User, error) { | ||||
| 	user, err := h.db.GetUser(userName) | ||||
| 	if errors.Is(err, ErrUserNotFound) { | ||||
| 	if errors.Is(err, db.ErrUserNotFound) { | ||||
| 		user, err = h.db.CreateUser(userName) | ||||
| 
 | ||||
| 		if err != nil { | ||||
| @ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( | ||||
| 
 | ||||
| func (h *Headscale) registerMachineForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	user *User, | ||||
| 	user *types.User, | ||||
| 	nodeKey *key.NodePublic, | ||||
| 	expiry time.Time, | ||||
| ) error { | ||||
| @ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback( | ||||
| 		nodeKey.String(), | ||||
| 		user.Name, | ||||
| 		&expiry, | ||||
| 		RegisterMethodOIDC, | ||||
| 		util.RegisterMethodOIDC, | ||||
| 	); err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| package hscontrol | ||||
| package policy | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| @ -12,6 +12,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"github.com/tailscale/hujson" | ||||
| @ -22,12 +23,12 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errEmptyPolicy       = errors.New("empty policy") | ||||
| 	errInvalidAction     = errors.New("invalid action") | ||||
| 	errInvalidGroup      = errors.New("invalid group") | ||||
| 	errInvalidTag        = errors.New("invalid tag") | ||||
| 	errInvalidPortFormat = errors.New("invalid port format") | ||||
| 	errWildcardIsNeeded  = errors.New("wildcard as port is required for the protocol") | ||||
| 	ErrEmptyPolicy       = errors.New("empty policy") | ||||
| 	ErrInvalidAction     = errors.New("invalid action") | ||||
| 	ErrInvalidGroup      = errors.New("invalid group") | ||||
| 	ErrInvalidTag        = errors.New("invalid tag") | ||||
| 	ErrInvalidPortFormat = errors.New("invalid port format") | ||||
| 	ErrWildcardIsNeeded  = errors.New("wildcard as port is required for the protocol") | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| @ -56,7 +57,7 @@ const ( | ||||
| var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") | ||||
| 
 | ||||
| // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. | ||||
| func (h *Headscale) LoadACLPolicyFromPath(path string) error { | ||||
| func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { | ||||
| 	log.Debug(). | ||||
| 		Str("func", "LoadACLPolicy"). | ||||
| 		Str("path", path). | ||||
| @ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { | ||||
| 
 | ||||
| 	policyFile, err := os.Open(path) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer policyFile.Close() | ||||
| 
 | ||||
| 	policyBytes, err := io.ReadAll(policyFile) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Debug(). | ||||
| @ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { | ||||
| 
 | ||||
| 	switch filepath.Ext(path) { | ||||
| 	case ".yml", ".yaml": | ||||
| 		return h.LoadACLPolicyFromBytes(policyBytes, "yaml") | ||||
| 		return LoadACLPolicyFromBytes(policyBytes, "yaml") | ||||
| 	} | ||||
| 
 | ||||
| 	return h.LoadACLPolicyFromBytes(policyBytes, "hujson") | ||||
| 	return LoadACLPolicyFromBytes(policyBytes, "hujson") | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { | ||||
| func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { | ||||
| 	var policy ACLPolicy | ||||
| 	switch format { | ||||
| 	case "yaml": | ||||
| 		err := yaml.Unmarshal(acl, &policy) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 	default: | ||||
| 		ast, err := hujson.Parse(acl) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		ast.Standardize() | ||||
| 		acl = ast.Pack() | ||||
| 		err = json.Unmarshal(acl, &policy) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if policy.IsZero() { | ||||
| 		return errEmptyPolicy | ||||
| 		return nil, ErrEmptyPolicy | ||||
| 	} | ||||
| 
 | ||||
| 	h.aclPolicy = &policy | ||||
| 
 | ||||
| 	return h.UpdateACLRules() | ||||
| 	return &policy, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) UpdateACLRules() error { | ||||
| 	machines, err := h.db.ListMachines() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| // TODO(kradalby): This needs to be replace with something that generates | ||||
| // the rules as needed and not stores it on the global object, rules are | ||||
| // per node and that should be taken into account. | ||||
| func GenerateFilterRules( | ||||
| 	policy *ACLPolicy, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { | ||||
| 	if policy == nil { | ||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy | ||||
| 	} | ||||
| 
 | ||||
| 	if h.aclPolicy == nil { | ||||
| 		return errEmptyPolicy | ||||
| 	} | ||||
| 
 | ||||
| 	rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain) | ||||
| 	rules, err := policy.generateFilterRules(machines, stripEmailDomain) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace().Interface("ACL", rules).Msg("ACL rules generated") | ||||
| 	h.aclRules = rules | ||||
| 
 | ||||
| 	var sshPolicy *tailcfg.SSHPolicy | ||||
| 	if featureEnableSSH() { | ||||
| 		sshRules, err := h.generateSSHRules() | ||||
| 		sshRules, err := generateSSHRules(policy, machines, stripEmailDomain) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||
| 		} | ||||
| 		log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") | ||||
| 		if h.sshPolicy == nil { | ||||
| 			h.sshPolicy = &tailcfg.SSHPolicy{} | ||||
| 		if sshPolicy == nil { | ||||
| 			sshPolicy = &tailcfg.SSHPolicy{} | ||||
| 		} | ||||
| 		h.sshPolicy.Rules = sshRules | ||||
| 	} else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 { | ||||
| 		sshPolicy.Rules = sshRules | ||||
| 	} else if policy != nil && len(policy.SSHs) > 0 { | ||||
| 		log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| 	return rules, sshPolicy, nil | ||||
| } | ||||
| 
 | ||||
| // generateFilterRules takes a set of machines and an ACLPolicy and generates a | ||||
| // set of Tailscale compatible FilterRules used to allow traffic on clients. | ||||
| func (pol *ACLPolicy) generateFilterRules( | ||||
| 	machines []Machine, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]tailcfg.FilterRule, error) { | ||||
| 	rules := []tailcfg.FilterRule{} | ||||
| 
 | ||||
| 	for index, acl := range pol.ACLs { | ||||
| 		if acl.Action != "accept" { | ||||
| 			return nil, errInvalidAction | ||||
| 			return nil, ErrInvalidAction | ||||
| 		} | ||||
| 
 | ||||
| 		srcIPs := []string{} | ||||
| @ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules( | ||||
| 	return rules, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | ||||
| func generateSSHRules( | ||||
| 	policy *ACLPolicy, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]*tailcfg.SSHRule, error) { | ||||
| 	rules := []*tailcfg.SSHRule{} | ||||
| 
 | ||||
| 	if h.aclPolicy == nil { | ||||
| 		return nil, errEmptyPolicy | ||||
| 	} | ||||
| 
 | ||||
| 	machines, err := h.db.ListMachines() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	if policy == nil { | ||||
| 		return nil, ErrEmptyPolicy | ||||
| 	} | ||||
| 
 | ||||
| 	acceptAction := tailcfg.SSHAction{ | ||||
| @ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | ||||
| 		AllowLocalPortForwarding: false, | ||||
| 	} | ||||
| 
 | ||||
| 	for index, sshACL := range h.aclPolicy.SSHs { | ||||
| 	for index, sshACL := range policy.SSHs { | ||||
| 		action := rejectAction | ||||
| 		switch sshACL.Action { | ||||
| 		case "accept": | ||||
| @ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | ||||
| 			} | ||||
| 		default: | ||||
| 			log.Error(). | ||||
| 				Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action) | ||||
| 				Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action) | ||||
| 
 | ||||
| 			return nil, err | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) | ||||
| @ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | ||||
| 					Any: true, | ||||
| 				}) | ||||
| 			} else if isGroup(rawSrc) { | ||||
| 				users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain) | ||||
| 				users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| 						Msgf("Error parsing SSH %d, Source %d", index, innerIndex) | ||||
| @ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | ||||
| 					}) | ||||
| 				} | ||||
| 			} else { | ||||
| 				expandedSrcs, err := h.aclPolicy.expandAlias( | ||||
| 				expandedSrcs, err := policy.ExpandAlias( | ||||
| 					machines, | ||||
| 					rawSrc, | ||||
| 					h.cfg.OIDC.StripEmaildomain, | ||||
| 					stripEmailDomain, | ||||
| 				) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| @ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { | ||||
| // with the given src alias. | ||||
| func (pol *ACLPolicy) getIPsFromSource( | ||||
| 	src string, | ||||
| 	machines []Machine, | ||||
| 	machines types.Machines, | ||||
| 	stripEmaildomain bool, | ||||
| ) ([]string, error) { | ||||
| 	ipSet, err := pol.expandAlias(machines, src, stripEmaildomain) | ||||
| 	ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) | ||||
| 	if err != nil { | ||||
| 		return []string{}, err | ||||
| 	} | ||||
| @ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource( | ||||
| // which are associated with the dest alias. | ||||
| func (pol *ACLPolicy) getNetPortRangeFromDestination( | ||||
| 	dest string, | ||||
| 	machines []Machine, | ||||
| 	machines types.Machines, | ||||
| 	needsWildcard bool, | ||||
| 	stripEmaildomain bool, | ||||
| ) ([]tailcfg.NetPortRange, error) { | ||||
| @ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( | ||||
| 			return nil, fmt.Errorf( | ||||
| 				"failed to parse destination, tokens %v: %w", | ||||
| 				tokens, | ||||
| 				errInvalidPortFormat, | ||||
| 				ErrInvalidPortFormat, | ||||
| 			) | ||||
| 		} else { | ||||
| 			tokens = []string{maybeIPv6Str, port} | ||||
| @ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( | ||||
| 		alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) | ||||
| 	} | ||||
| 
 | ||||
| 	expanded, err := pol.expandAlias( | ||||
| 	expanded, err := pol.ExpandAlias( | ||||
| 		machines, | ||||
| 		alias, | ||||
| 		stripEmaildomain, | ||||
| @ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) { | ||||
| // - an ip | ||||
| // - a cidr | ||||
| // and transform these in IPAddresses. | ||||
| func (pol *ACLPolicy) expandAlias( | ||||
| 	machines Machines, | ||||
| func (pol *ACLPolicy) ExpandAlias( | ||||
| 	machines types.Machines, | ||||
| 	alias string, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	if isWildcard(alias) { | ||||
| 		return parseIPSet("*", nil) | ||||
| 		return util.ParseIPSet("*", nil) | ||||
| 	} | ||||
| 
 | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| @ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias( | ||||
| 	// if alias is an host | ||||
| 	// Note, this is recursive. | ||||
| 	if h, ok := pol.Hosts[alias]; ok { | ||||
| 		log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") | ||||
| 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") | ||||
| 
 | ||||
| 		return pol.expandAlias(machines, h.String(), stripEmailDomain) | ||||
| 		return pol.ExpandAlias(machines, h.String(), stripEmailDomain) | ||||
| 	} | ||||
| 
 | ||||
| 	// if alias is an IP | ||||
| @ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias( | ||||
| // we assume in this function that we only have nodes from 1 user. | ||||
| func excludeCorrectlyTaggedNodes( | ||||
| 	aclPolicy *ACLPolicy, | ||||
| 	nodes []Machine, | ||||
| 	nodes types.Machines, | ||||
| 	user string, | ||||
| 	stripEmailDomain bool, | ||||
| ) []Machine { | ||||
| 	out := []Machine{} | ||||
| ) types.Machines { | ||||
| 	out := types.Machines{} | ||||
| 	tags := []string{} | ||||
| 	for tag := range aclPolicy.TagOwners { | ||||
| 		owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) | ||||
| @ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err | ||||
| 	} | ||||
| 
 | ||||
| 	if needsWildcard { | ||||
| 		return nil, errWildcardIsNeeded | ||||
| 		return nil, ErrWildcardIsNeeded | ||||
| 	} | ||||
| 
 | ||||
| 	ports := []tailcfg.PortRange{} | ||||
| @ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err | ||||
| 			}) | ||||
| 
 | ||||
| 		default: | ||||
| 			return nil, errInvalidPortFormat | ||||
| 			return nil, ErrInvalidPortFormat | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return &ports, nil | ||||
| } | ||||
| 
 | ||||
| func filterMachinesByUser(machines []Machine, user string) []Machine { | ||||
| 	out := []Machine{} | ||||
| func filterMachinesByUser(machines types.Machines, user string) types.Machines { | ||||
| 	out := types.Machines{} | ||||
| 	for _, machine := range machines { | ||||
| 		if machine.User.Name == user { | ||||
| 			out = append(out, machine) | ||||
| @ -664,7 +664,7 @@ func getTagOwners( | ||||
| 	if !ok { | ||||
| 		return []string{}, fmt.Errorf( | ||||
| 			"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", | ||||
| 			errInvalidTag, | ||||
| 			ErrInvalidTag, | ||||
| 			tag, | ||||
| 		) | ||||
| 	} | ||||
| @ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup( | ||||
| 		return []string{}, fmt.Errorf( | ||||
| 			"group %v isn't registered. %w", | ||||
| 			group, | ||||
| 			errInvalidGroup, | ||||
| 			ErrInvalidGroup, | ||||
| 		) | ||||
| 	} | ||||
| 	for _, group := range aclGroups { | ||||
| 		if isGroup(group) { | ||||
| 			return []string{}, fmt.Errorf( | ||||
| 				"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", | ||||
| 				errInvalidGroup, | ||||
| 				ErrInvalidGroup, | ||||
| 			) | ||||
| 		} | ||||
| 		grp, err := NormalizeToFQDNRules(group, stripEmailDomain) | ||||
| 		grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) | ||||
| 		if err != nil { | ||||
| 			return []string{}, fmt.Errorf( | ||||
| 				"failed to normalize group %q, err: %w", | ||||
| 				group, | ||||
| 				errInvalidGroup, | ||||
| 				ErrInvalidGroup, | ||||
| 			) | ||||
| 		} | ||||
| 		users = append(users, grp) | ||||
| @ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup( | ||||
| 
 | ||||
| func (pol *ACLPolicy) getIPsFromGroup( | ||||
| 	group string, | ||||
| 	machines Machines, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| @ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup( | ||||
| 
 | ||||
| func (pol *ACLPolicy) getIPsFromTag( | ||||
| 	alias string, | ||||
| 	machines Machines, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| @ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag( | ||||
| 	// find tag owners | ||||
| 	owners, err := getTagOwners(pol, alias, stripEmailDomain) | ||||
| 	if err != nil { | ||||
| 		if errors.Is(err, errInvalidTag) { | ||||
| 		if errors.Is(err, ErrInvalidTag) { | ||||
| 			ipSet, _ := build.IPSet() | ||||
| 			if len(ipSet.Prefixes()) == 0 { | ||||
| 				return ipSet, fmt.Errorf( | ||||
| 					"%w. %v isn't owned by a TagOwner and no forced tags are defined", | ||||
| 					errInvalidTag, | ||||
| 					ErrInvalidTag, | ||||
| 					alias, | ||||
| 				) | ||||
| 			} | ||||
| @ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag( | ||||
| 
 | ||||
| func (pol *ACLPolicy) getIPsForUser( | ||||
| 	user string, | ||||
| 	machines Machines, | ||||
| 	machines types.Machines, | ||||
| 	stripEmailDomain bool, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| @ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser( | ||||
| 
 | ||||
| func (pol *ACLPolicy) getIPsFromSingleIP( | ||||
| 	ip netip.Addr, | ||||
| 	machines Machines, | ||||
| 	machines types.Machines, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") | ||||
| 	log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip") | ||||
| 
 | ||||
| 	matches := machines.FilterByIP(ip) | ||||
| 
 | ||||
| @ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP( | ||||
| 
 | ||||
| func (pol *ACLPolicy) getIPsFromIPPrefix( | ||||
| 	prefix netip.Prefix, | ||||
| 	machines Machines, | ||||
| 	machines types.Machines, | ||||
| ) (*netipx.IPSet, error) { | ||||
| 	log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") | ||||
| 	build := netipx.IPSetBuilder{} | ||||
| @ -862,3 +862,65 @@ func isGroup(str string) bool { | ||||
| func isTag(str string) bool { | ||||
| 	return strings.HasPrefix(str, "tag:") | ||||
| } | ||||
| 
 | ||||
| // getTags will return the tags of the current machine. | ||||
| // Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. | ||||
| // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. | ||||
| func (pol *ACLPolicy) GetTagsOfMachine( | ||||
| 	machine types.Machine, | ||||
| 	stripEmailDomain bool, | ||||
| ) ([]string, []string) { | ||||
| 	validTags := make([]string, 0) | ||||
| 	invalidTags := make([]string, 0) | ||||
| 
 | ||||
| 	validTagMap := make(map[string]bool) | ||||
| 	invalidTagMap := make(map[string]bool) | ||||
| 	for _, tag := range machine.HostInfo.RequestTags { | ||||
| 		owners, err := getTagOwners(pol, tag, stripEmailDomain) | ||||
| 		if errors.Is(err, ErrInvalidTag) { | ||||
| 			invalidTagMap[tag] = true | ||||
| 
 | ||||
| 			continue | ||||
| 		} | ||||
| 		var found bool | ||||
| 		for _, owner := range owners { | ||||
| 			if machine.User.Name == owner { | ||||
| 				found = true | ||||
| 			} | ||||
| 		} | ||||
| 		if found { | ||||
| 			validTagMap[tag] = true | ||||
| 		} else { | ||||
| 			invalidTagMap[tag] = true | ||||
| 		} | ||||
| 	} | ||||
| 	for tag := range invalidTagMap { | ||||
| 		invalidTags = append(invalidTags, tag) | ||||
| 	} | ||||
| 	for tag := range validTagMap { | ||||
| 		validTags = append(validTags, tag) | ||||
| 	} | ||||
| 
 | ||||
| 	return validTags, invalidTags | ||||
| } | ||||
| 
 | ||||
| // FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine. | ||||
| func FilterMachinesByACL( | ||||
| 	machine *types.Machine, | ||||
| 	machines types.Machines, | ||||
| 	filter []tailcfg.FilterRule, | ||||
| ) types.Machines { | ||||
| 	result := types.Machines{} | ||||
| 
 | ||||
| 	for index, peer := range machines { | ||||
| 		if peer.ID == machine.ID { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) { | ||||
| 			result = append(result, peer) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return result | ||||
| } | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,4 +1,4 @@ | ||||
| package hscontrol | ||||
| package policy | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
							
								
								
									
										61
									
								
								hscontrol/policy/matcher/matcher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								hscontrol/policy/matcher/matcher.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,61 @@ | ||||
| package matcher | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"go4.org/netipx" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| type Match struct { | ||||
| 	Srcs  *netipx.IPSet | ||||
| 	Dests *netipx.IPSet | ||||
| } | ||||
| 
 | ||||
| func MatchFromFilterRule(rule tailcfg.FilterRule) Match { | ||||
| 	srcs := new(netipx.IPSetBuilder) | ||||
| 	dests := new(netipx.IPSetBuilder) | ||||
| 
 | ||||
| 	for _, srcIP := range rule.SrcIPs { | ||||
| 		set, _ := util.ParseIPSet(srcIP, nil) | ||||
| 
 | ||||
| 		srcs.AddSet(set) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, dest := range rule.DstPorts { | ||||
| 		set, _ := util.ParseIPSet(dest.IP, nil) | ||||
| 
 | ||||
| 		dests.AddSet(set) | ||||
| 	} | ||||
| 
 | ||||
| 	srcsSet, _ := srcs.IPSet() | ||||
| 	destsSet, _ := dests.IPSet() | ||||
| 
 | ||||
| 	match := Match{ | ||||
| 		Srcs:  srcsSet, | ||||
| 		Dests: destsSet, | ||||
| 	} | ||||
| 
 | ||||
| 	return match | ||||
| } | ||||
| 
 | ||||
| func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { | ||||
| 	for _, ip := range ips { | ||||
| 		if m.Srcs.Contains(ip) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (m *Match) DestsContainsIP(ips []netip.Addr) bool { | ||||
| 	for _, ip := range ips { | ||||
| 		if m.Dests.Contains(ip) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										1
									
								
								hscontrol/policy/matcher/matcher_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								hscontrol/policy/matcher/matcher_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | ||||
| package matcher | ||||
| @ -9,6 +9,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"gorm.io/gorm" | ||||
| @ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon( | ||||
| 		// that we rely on a method that calls back some how (OpenID or CLI) | ||||
| 		// We create the machine and then keep it around until a callback | ||||
| 		// happens | ||||
| 		newMachine := Machine{ | ||||
| 		newMachine := types.Machine{ | ||||
| 			MachineKey: util.MachinePublicKeyStripPrefix(machineKey), | ||||
| 			Hostname:   registerRequest.Hostinfo.Hostname, | ||||
| 			GivenName:  givenName, | ||||
| @ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon( | ||||
| 			[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), | ||||
| 		) | ||||
| 		if err != nil || storedMachineKey.IsZero() { | ||||
| 			machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) | ||||
| 			if err := h.db.db.Save(&machine).Error; err != nil { | ||||
| 			if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil { | ||||
| 				log.Error(). | ||||
| 					Caller(). | ||||
| 					Str("func", "RegistrationHandler"). | ||||
| @ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon( | ||||
| 
 | ||||
| 			// If machine is not expired, and it is register, we have a already accepted this machine, | ||||
| 			// let it proceed with a valid registration | ||||
| 			if !machine.isExpired() { | ||||
| 			if !machine.IsExpired() { | ||||
| 				h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) | ||||
| 
 | ||||
| 				return | ||||
| @ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon( | ||||
| 
 | ||||
| 		// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration | ||||
| 		if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && | ||||
| 			!machine.isExpired() { | ||||
| 			!machine.IsExpired() { | ||||
| 			h.handleMachineRefreshKeyCommon( | ||||
| 				writer, | ||||
| 				registerRequest, | ||||
| @ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 		Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) | ||||
| 	resp := tailcfg.RegisterResponse{} | ||||
| 
 | ||||
| 	pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) | ||||
| 	pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| @ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 				Err(err). | ||||
| 				Msg("Cannot encode message") | ||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 				Inc() | ||||
| 
 | ||||
| 			return | ||||
| @ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 			Msg("Failed authentication via AuthKey") | ||||
| 
 | ||||
| 		if pak != nil { | ||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 				Inc() | ||||
| 		} else { | ||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc() | ||||
| 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc() | ||||
| 		} | ||||
| 
 | ||||
| 		return | ||||
| @ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		aclTags := pak.toProto().AclTags | ||||
| 		aclTags := pak.Proto().AclTags | ||||
| 		if len(aclTags) > 0 { | ||||
| 			// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login | ||||
| 			err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) | ||||
| 			err = h.db.SetTags(machine, aclTags) | ||||
| 
 | ||||
| 			if err != nil { | ||||
| 				log.Error(). | ||||
| @ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		machineToRegister := Machine{ | ||||
| 		machineToRegister := types.Machine{ | ||||
| 			Hostname:       registerRequest.Hostinfo.Hostname, | ||||
| 			GivenName:      givenName, | ||||
| 			UserID:         pak.User.ID, | ||||
| 			MachineKey:     util.MachinePublicKeyStripPrefix(machineKey), | ||||
| 			RegisterMethod: RegisterMethodAuthKey, | ||||
| 			RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 			Expiry:         ®isterRequest.Expiry, | ||||
| 			NodeKey:        nodeKey, | ||||
| 			LastSeen:       &now, | ||||
| 			AuthKeyID:      uint(pak.ID), | ||||
| 			ForcedTags:     pak.toProto().AclTags, | ||||
| 			ForcedTags:     pak.Proto().AclTags, | ||||
| 		} | ||||
| 
 | ||||
| 		machine, err = h.db.RegisterMachine( | ||||
| @ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 				Bool("noise", isNoise). | ||||
| 				Err(err). | ||||
| 				Msg("could not register machine") | ||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 				Inc() | ||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||
| 
 | ||||
| @ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 			Bool("noise", isNoise). | ||||
| 			Err(err). | ||||
| 			Msg("Failed to use pre-auth key") | ||||
| 		machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 		machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 			Inc() | ||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||
| 
 | ||||
| @ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 	} | ||||
| 
 | ||||
| 	resp.MachineAuthorized = true | ||||
| 	resp.User = *pak.User.toTailscaleUser() | ||||
| 	resp.User = *pak.User.TailscaleUser() | ||||
| 	// Provide LoginName when registering with pre-auth key | ||||
| 	// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* | ||||
| 	resp.Login = *pak.User.toTailscaleLogin() | ||||
| 	resp.Login = *pak.User.TailscaleLogin() | ||||
| 
 | ||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||
| 	if err != nil { | ||||
| @ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon( | ||||
| 			Str("machine", registerRequest.Hostinfo.Hostname). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot encode message") | ||||
| 		machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 		machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||
| 			Inc() | ||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| 	machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name). | ||||
| 	machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name). | ||||
| 		Inc() | ||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") | ||||
| 	writer.WriteHeader(http.StatusOK) | ||||
| @ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon( | ||||
| 
 | ||||
| func (h *Headscale) handleMachineLogOutCommon( | ||||
| 	writer http.ResponseWriter, | ||||
| 	machine Machine, | ||||
| 	machine types.Machine, | ||||
| 	machineKey key.MachinePublic, | ||||
| 	isNoise bool, | ||||
| ) { | ||||
| @ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon( | ||||
| 	resp.AuthURL = "" | ||||
| 	resp.MachineAuthorized = false | ||||
| 	resp.NodeKeyExpired = true | ||||
| 	resp.User = *machine.User.toTailscaleUser() | ||||
| 	resp.User = *machine.User.TailscaleUser() | ||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| @ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon( | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.isEphemeral() { | ||||
| 	if machine.IsEphemeral() { | ||||
| 		err = h.db.HardDeleteMachine(&machine) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| @ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon( | ||||
| 
 | ||||
| func (h *Headscale) handleMachineValidRegistrationCommon( | ||||
| 	writer http.ResponseWriter, | ||||
| 	machine Machine, | ||||
| 	machine types.Machine, | ||||
| 	machineKey key.MachinePublic, | ||||
| 	isNoise bool, | ||||
| ) { | ||||
| @ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon( | ||||
| 
 | ||||
| 	resp.AuthURL = "" | ||||
| 	resp.MachineAuthorized = true | ||||
| 	resp.User = *machine.User.toTailscaleUser() | ||||
| 	resp.Login = *machine.User.toTailscaleLogin() | ||||
| 	resp.User = *machine.User.TailscaleUser() | ||||
| 	resp.Login = *machine.User.TailscaleLogin() | ||||
| 
 | ||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||
| 	if err != nil { | ||||
| @ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon( | ||||
| func (h *Headscale) handleMachineRefreshKeyCommon( | ||||
| 	writer http.ResponseWriter, | ||||
| 	registerRequest tailcfg.RegisterRequest, | ||||
| 	machine Machine, | ||||
| 	machine types.Machine, | ||||
| 	machineKey key.MachinePublic, | ||||
| 	isNoise bool, | ||||
| ) { | ||||
| @ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon( | ||||
| 		Bool("noise", isNoise). | ||||
| 		Str("machine", machine.Hostname). | ||||
| 		Msg("We have the OldNodeKey in the database. This is a key refresh") | ||||
| 	machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) | ||||
| 
 | ||||
| 	if err := h.db.db.Save(&machine).Error; err != nil { | ||||
| 	err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| @ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( | ||||
| 	} | ||||
| 
 | ||||
| 	resp.AuthURL = "" | ||||
| 	resp.User = *machine.User.toTailscaleUser() | ||||
| 	resp.User = *machine.User.TailscaleUser() | ||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| @ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( | ||||
| func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( | ||||
| 	writer http.ResponseWriter, | ||||
| 	registerRequest tailcfg.RegisterRequest, | ||||
| 	machine Machine, | ||||
| 	machine types.Machine, | ||||
| 	machineKey key.MachinePublic, | ||||
| 	isNoise bool, | ||||
| ) { | ||||
|  | ||||
| @ -6,6 +6,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"tailscale.com/tailcfg" | ||||
| @ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName") | ||||
| func (h *Headscale) handlePollCommon( | ||||
| 	writer http.ResponseWriter, | ||||
| 	ctx context.Context, | ||||
| 	machine *Machine, | ||||
| 	machine *types.Machine, | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	isNoise bool, | ||||
| ) { | ||||
| 	machine.Hostname = mapRequest.Hostinfo.Hostname | ||||
| 	machine.HostInfo = HostInfo(*mapRequest.Hostinfo) | ||||
| 	machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) | ||||
| 	machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) | ||||
| 	now := time.Now().UTC() | ||||
| 
 | ||||
| 	err := h.db.processMachineRoutes(machine) | ||||
| 	err := h.db.ProcessMachineRoutes(machine) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| @ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon( | ||||
| 	} | ||||
| 
 | ||||
| 	// update ACLRules with peer informations (to update server tags if necessary) | ||||
| 	if h.aclPolicy != nil { | ||||
| 		err := h.UpdateACLRules() | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Bool("noise", isNoise). | ||||
| 				Str("machine", machine.Hostname). | ||||
| 				Err(err) | ||||
| 		} | ||||
| 	if h.ACLPolicy != nil { | ||||
| 		// TODO(kradalby): Since this is not blocking, I might have introduced a bug here. | ||||
| 		// It will be resolved later as we change up the policy stuff. | ||||
| 		h.policyUpdateChan <- struct{}{} | ||||
| 
 | ||||
| 		// update routes with peer information | ||||
| 		err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) | ||||
| 		err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| @ -78,19 +74,17 @@ func (h *Headscale) handlePollCommon( | ||||
| 		machine.LastSeen = &now | ||||
| 	} | ||||
| 
 | ||||
| 	if err := h.db.db.Updates(machine).Error; err != nil { | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Str("handler", "PollNetMap"). | ||||
| 				Bool("noise", isNoise). | ||||
| 				Str("node_key", machine.NodeKey). | ||||
| 				Str("machine", machine.Hostname). | ||||
| 				Err(err). | ||||
| 				Msg("Failed to persist/update machine in the database") | ||||
| 			http.Error(writer, "", http.StatusInternalServerError) | ||||
| 	if err := h.db.MachineSave(machine); err != nil { | ||||
| 		log.Error(). | ||||
| 			Str("handler", "PollNetMap"). | ||||
| 			Bool("noise", isNoise). | ||||
| 			Str("node_key", machine.NodeKey). | ||||
| 			Str("machine", machine.Hostname). | ||||
| 			Err(err). | ||||
| 			Msg("Failed to persist/update machine in the database") | ||||
| 		http.Error(writer, "", http.StatusInternalServerError) | ||||
| 
 | ||||
| 			return | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) | ||||
| @ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon( | ||||
| func (h *Headscale) pollNetMapStream( | ||||
| 	writer http.ResponseWriter, | ||||
| 	ctxReq context.Context, | ||||
| 	machine *Machine, | ||||
| 	machine *types.Machine, | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	pollDataChan chan []byte, | ||||
| 	keepAliveChan chan []byte, | ||||
| @ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream( | ||||
| 			updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). | ||||
| 				Inc() | ||||
| 
 | ||||
| 			if h.db.isOutdated(machine, h.getLastStateChange()) { | ||||
| 			if h.db.IsOutdated(machine, h.getLastStateChange()) { | ||||
| 				var lastUpdate time.Time | ||||
| 				if machine.LastSuccessfulUpdate != nil { | ||||
| 					lastUpdate = *machine.LastSuccessfulUpdate | ||||
| @ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker( | ||||
| 	updateChan chan struct{}, | ||||
| 	keepAliveChan chan []byte, | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *Machine, | ||||
| 	machine *types.Machine, | ||||
| 	isNoise bool, | ||||
| ) { | ||||
| 	keepAliveTicker := time.NewTicker(keepAliveInterval) | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/klauspost/compress/zstd" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| @ -15,7 +16,7 @@ import ( | ||||
| 
 | ||||
| func (h *Headscale) getMapResponseData( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *Machine, | ||||
| 	machine *types.Machine, | ||||
| 	isNoise bool, | ||||
| ) ([]byte, error) { | ||||
| 	mapResponse, err := h.generateMapResponse(mapRequest, machine) | ||||
| @ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData( | ||||
| 
 | ||||
| func (h *Headscale) getMapKeepAliveResponseData( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *Machine, | ||||
| 	machine *types.Machine, | ||||
| 	isNoise bool, | ||||
| ) ([]byte, error) { | ||||
| 	keepAliveResponse := tailcfg.MapResponse{ | ||||
|  | ||||
| @ -18,7 +18,7 @@ type Suite struct{} | ||||
| 
 | ||||
| var ( | ||||
| 	tmpDir string | ||||
| 	app    Headscale | ||||
| 	app    *Headscale | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) SetUpTest(c *check.C) { | ||||
| @ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) { | ||||
| 		os.RemoveAll(tmpDir) | ||||
| 	} | ||||
| 	var err error | ||||
| 	tmpDir, err = os.MkdirTemp("", "autoygg-client-test") | ||||
| 	tmpDir, err = os.MkdirTemp("", "autoygg-client-test2") | ||||
| 	if err != nil { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| 	cfg := Config{ | ||||
| 		PrivateKeyPath:      tmpDir + "/private.key", | ||||
| 		NoisePrivateKeyPath: tmpDir + "/noise_private.key", | ||||
| 		DBtype:              "sqlite3", | ||||
| 		DBpath:              tmpDir + "/headscale_test.db", | ||||
| 		IPPrefixes: []netip.Prefix{ | ||||
| 			netip.MustParsePrefix("10.27.0.0/23"), | ||||
| 		}, | ||||
| @ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) { | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO(kradalby): make this use NewHeadscale properly so it doesnt drift | ||||
| 	app = Headscale{ | ||||
| 		cfg:      &cfg, | ||||
| 		dbType:   "sqlite3", | ||||
| 		dbString: tmpDir + "/headscale_test.db", | ||||
| 
 | ||||
| 		stateUpdateChan:       make(chan struct{}), | ||||
| 		cancelStateUpdateChan: make(chan struct{}), | ||||
| 	} | ||||
| 
 | ||||
| 	go app.watchStateChannel() | ||||
| 
 | ||||
| 	db, err := NewHeadscaleDatabase( | ||||
| 		app.dbType, | ||||
| 		app.dbString, | ||||
| 		cfg.OIDC.StripEmaildomain, | ||||
| 		false, | ||||
| 		app.stateUpdateChan, | ||||
| 		cfg.IPPrefixes, | ||||
| 		"", | ||||
| 	) | ||||
| 	app, err = NewHeadscale(&cfg) | ||||
| 	if err != nil { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| 	app.db = db | ||||
| } | ||||
							
								
								
									
										41
									
								
								hscontrol/types/api_key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								hscontrol/types/api_key.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,41 @@ | ||||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| ) | ||||
| 
 | ||||
| // APIKey describes the datamodel for API keys used to remotely authenticate with | ||||
| // headscale. | ||||
| type APIKey struct { | ||||
| 	ID     uint64 `gorm:"primary_key"` | ||||
| 	Prefix string `gorm:"uniqueIndex"` | ||||
| 	Hash   []byte | ||||
| 
 | ||||
| 	CreatedAt  *time.Time | ||||
| 	Expiration *time.Time | ||||
| 	LastSeen   *time.Time | ||||
| } | ||||
| 
 | ||||
| func (key *APIKey) Proto() *v1.ApiKey { | ||||
| 	protoKey := v1.ApiKey{ | ||||
| 		Id:     key.ID, | ||||
| 		Prefix: key.Prefix, | ||||
| 	} | ||||
| 
 | ||||
| 	if key.Expiration != nil { | ||||
| 		protoKey.Expiration = timestamppb.New(*key.Expiration) | ||||
| 	} | ||||
| 
 | ||||
| 	if key.CreatedAt != nil { | ||||
| 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) | ||||
| 	} | ||||
| 
 | ||||
| 	if key.LastSeen != nil { | ||||
| 		protoKey.LastSeen = timestamppb.New(*key.LastSeen) | ||||
| 	} | ||||
| 
 | ||||
| 	return &protoKey | ||||
| } | ||||
							
								
								
									
										108
									
								
								hscontrol/types/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								hscontrol/types/common.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,108 @@ | ||||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| var ErrCannotParsePrefix = errors.New("cannot parse prefix") | ||||
| 
 | ||||
| // This is a "wrapper" type around tailscales | ||||
| // Hostinfo to allow us to add database "serialization" | ||||
| // methods. This allows us to use a typed values throughout | ||||
| // the code and not have to marshal/unmarshal and error | ||||
| // check all over the code. | ||||
| type HostInfo tailcfg.Hostinfo | ||||
| 
 | ||||
| func (hi *HostInfo) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, hi) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), hi) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (hi HostInfo) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(hi) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| type IPPrefix netip.Prefix | ||||
| 
 | ||||
| func (i *IPPrefix) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case string: | ||||
| 		prefix, err := netip.ParsePrefix(value) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		*i = IPPrefix(prefix) | ||||
| 
 | ||||
| 		return nil | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (i IPPrefix) Value() (driver.Value, error) { | ||||
| 	prefixStr := netip.Prefix(i).String() | ||||
| 
 | ||||
| 	return prefixStr, nil | ||||
| } | ||||
| 
 | ||||
| type IPPrefixes []netip.Prefix | ||||
| 
 | ||||
| func (i *IPPrefixes) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, i) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), i) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (i IPPrefixes) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(i) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| type StringList []string | ||||
| 
 | ||||
| func (i *StringList) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, i) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), i) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (i StringList) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(i) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
							
								
								
									
										254
									
								
								hscontrol/types/machine.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								hscontrol/types/machine.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,254 @@ | ||||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||
| 	"go4.org/netipx" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// TODO(kradalby): Move out of here when we got circdeps under control. | ||||
| 	keepAliveInterval = 60 * time.Second | ||||
| ) | ||||
| 
 | ||||
| var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") | ||||
| 
 | ||||
| // Machine is a Headscale client. | ||||
| type Machine struct { | ||||
| 	ID          uint64 `gorm:"primary_key"` | ||||
| 	MachineKey  string `gorm:"type:varchar(64);unique_index"` | ||||
| 	NodeKey     string | ||||
| 	DiscoKey    string | ||||
| 	IPAddresses MachineAddresses | ||||
| 
 | ||||
| 	// Hostname represents the name given by the Tailscale | ||||
| 	// client during registration | ||||
| 	Hostname string | ||||
| 
 | ||||
| 	// Givenname represents either: | ||||
| 	// a DNS normalized version of Hostname | ||||
| 	// a valid name set by the User | ||||
| 	// | ||||
| 	// GivenName is the name used in all DNS related | ||||
| 	// parts of headscale. | ||||
| 	GivenName string `gorm:"type:varchar(63);unique_index"` | ||||
| 	UserID    uint | ||||
| 	User      User `gorm:"foreignKey:UserID"` | ||||
| 
 | ||||
| 	RegisterMethod string | ||||
| 
 | ||||
| 	ForcedTags StringList | ||||
| 
 | ||||
| 	// TODO(kradalby): This seems like irrelevant information? | ||||
| 	AuthKeyID uint | ||||
| 	AuthKey   *PreAuthKey | ||||
| 
 | ||||
| 	LastSeen             *time.Time | ||||
| 	LastSuccessfulUpdate *time.Time | ||||
| 	Expiry               *time.Time | ||||
| 
 | ||||
| 	HostInfo  HostInfo | ||||
| 	Endpoints StringList | ||||
| 
 | ||||
| 	CreatedAt time.Time | ||||
| 	UpdatedAt time.Time | ||||
| 	DeletedAt *time.Time | ||||
| } | ||||
| 
 | ||||
| type ( | ||||
| 	Machines  []Machine | ||||
| 	MachinesP []*Machine | ||||
| ) | ||||
| 
 | ||||
| type MachineAddresses []netip.Addr | ||||
| 
 | ||||
| func (ma MachineAddresses) ToStringSlice() []string { | ||||
| 	strSlice := make([]string, 0, len(ma)) | ||||
| 	for _, addr := range ma { | ||||
| 		strSlice = append(strSlice, addr.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	return strSlice | ||||
| } | ||||
| 
 | ||||
| // AppendToIPSet adds the individual ips in MachineAddresses to a | ||||
| // given netipx.IPSetBuilder. | ||||
| func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { | ||||
| 	for _, ip := range ma { | ||||
| 		build.Add(ip) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (ma *MachineAddresses) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case string: | ||||
| 		addresses := strings.Split(value, ",") | ||||
| 		*ma = (*ma)[:0] | ||||
| 		for _, addr := range addresses { | ||||
| 			if len(addr) < 1 { | ||||
| 				continue | ||||
| 			} | ||||
| 			parsed, err := netip.ParseAddr(addr) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			*ma = append(*ma, parsed) | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (ma MachineAddresses) Value() (driver.Value, error) { | ||||
| 	addresses := strings.Join(ma.ToStringSlice(), ",") | ||||
| 
 | ||||
| 	return addresses, nil | ||||
| } | ||||
| 
 | ||||
| // IsExpired returns whether the machine registration has expired. | ||||
| func (machine Machine) IsExpired() bool { | ||||
| 	// If Expiry is not set, the client has not indicated that | ||||
| 	// it wants an expiry time, it is therefor considered | ||||
| 	// to mean "not expired" | ||||
| 	if machine.Expiry == nil || machine.Expiry.IsZero() { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return time.Now().UTC().After(*machine.Expiry) | ||||
| } | ||||
| 
 | ||||
| // IsOnline returns if the machine is connected to Headscale. | ||||
| // This is really a naive implementation, as we don't really see | ||||
| // if there is a working connection between the client and the server. | ||||
| func (machine *Machine) IsOnline() bool { | ||||
| 	if machine.LastSeen == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.IsExpired() { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) | ||||
| } | ||||
| 
 | ||||
| // IsEphemeral returns if the machine is registered as an Ephemeral node. | ||||
| // https://tailscale.com/kb/1111/ephemeral-nodes/ | ||||
| func (machine *Machine) IsEphemeral() bool { | ||||
| 	return machine.AuthKey != nil && machine.AuthKey.Ephemeral | ||||
| } | ||||
| 
 | ||||
| func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { | ||||
| 	for _, rule := range filter { | ||||
| 		// TODO(kradalby): Cache or pregen this | ||||
| 		matcher := matcher.MatchFromFilterRule(rule) | ||||
| 
 | ||||
| 		if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (machines Machines) FilterByIP(ip netip.Addr) Machines { | ||||
| 	found := make(Machines, 0) | ||||
| 
 | ||||
| 	for _, machine := range machines { | ||||
| 		for _, mIP := range machine.IPAddresses { | ||||
| 			if ip == mIP { | ||||
| 				found = append(found, machine) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return found | ||||
| } | ||||
| 
 | ||||
| func (machine *Machine) Proto() *v1.Machine { | ||||
| 	machineProto := &v1.Machine{ | ||||
| 		Id:         machine.ID, | ||||
| 		MachineKey: machine.MachineKey, | ||||
| 
 | ||||
| 		NodeKey:     machine.NodeKey, | ||||
| 		DiscoKey:    machine.DiscoKey, | ||||
| 		IpAddresses: machine.IPAddresses.ToStringSlice(), | ||||
| 		Name:        machine.Hostname, | ||||
| 		GivenName:   machine.GivenName, | ||||
| 		User:        machine.User.Proto(), | ||||
| 		ForcedTags:  machine.ForcedTags, | ||||
| 		Online:      machine.IsOnline(), | ||||
| 
 | ||||
| 		// TODO(kradalby): Implement register method enum converter | ||||
| 		// RegisterMethod: , | ||||
| 
 | ||||
| 		CreatedAt: timestamppb.New(machine.CreatedAt), | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.AuthKey != nil { | ||||
| 		machineProto.PreAuthKey = machine.AuthKey.Proto() | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.LastSeen != nil { | ||||
| 		machineProto.LastSeen = timestamppb.New(*machine.LastSeen) | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.LastSuccessfulUpdate != nil { | ||||
| 		machineProto.LastSuccessfulUpdate = timestamppb.New( | ||||
| 			*machine.LastSuccessfulUpdate, | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.Expiry != nil { | ||||
| 		machineProto.Expiry = timestamppb.New(*machine.Expiry) | ||||
| 	} | ||||
| 
 | ||||
| 	return machineProto | ||||
| } | ||||
| 
 | ||||
| // GetHostInfo returns a Hostinfo struct for the machine. | ||||
| func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { | ||||
| 	return tailcfg.Hostinfo(machine.HostInfo) | ||||
| } | ||||
| 
 | ||||
| func (machine Machine) String() string { | ||||
| 	return machine.Hostname | ||||
| } | ||||
| 
 | ||||
| func (machines Machines) String() string { | ||||
| 	temp := make([]string, len(machines)) | ||||
| 
 | ||||
| 	for index, machine := range machines { | ||||
| 		temp[index] = machine.Hostname | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby): Remove when we have generics... | ||||
| func (machines MachinesP) String() string { | ||||
| 	temp := make([]string, len(machines)) | ||||
| 
 | ||||
| 	for index, machine := range machines { | ||||
| 		temp[index] = machine.Hostname | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) | ||||
| } | ||||
							
								
								
									
										1
									
								
								hscontrol/types/machine_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								hscontrol/types/machine_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | ||||
| package types | ||||
							
								
								
									
										58
									
								
								hscontrol/types/preauth_key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								hscontrol/types/preauth_key.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| ) | ||||
| 
 | ||||
| // PreAuthKey describes a pre-authorization key usable in a particular user. | ||||
| type PreAuthKey struct { | ||||
| 	ID        uint64 `gorm:"primary_key"` | ||||
| 	Key       string | ||||
| 	UserID    uint | ||||
| 	User      User | ||||
| 	Reusable  bool | ||||
| 	Ephemeral bool `gorm:"default:false"` | ||||
| 	Used      bool `gorm:"default:false"` | ||||
| 	ACLTags   []PreAuthKeyACLTag | ||||
| 
 | ||||
| 	CreatedAt  *time.Time | ||||
| 	Expiration *time.Time | ||||
| } | ||||
| 
 | ||||
| // PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. | ||||
| type PreAuthKeyACLTag struct { | ||||
| 	ID           uint64 `gorm:"primary_key"` | ||||
| 	PreAuthKeyID uint64 | ||||
| 	Tag          string | ||||
| } | ||||
| 
 | ||||
| func (key *PreAuthKey) Proto() *v1.PreAuthKey { | ||||
| 	protoKey := v1.PreAuthKey{ | ||||
| 		User:      key.User.Name, | ||||
| 		Id:        strconv.FormatUint(key.ID, util.Base10), | ||||
| 		Key:       key.Key, | ||||
| 		Ephemeral: key.Ephemeral, | ||||
| 		Reusable:  key.Reusable, | ||||
| 		Used:      key.Used, | ||||
| 		AclTags:   make([]string, len(key.ACLTags)), | ||||
| 	} | ||||
| 
 | ||||
| 	if key.Expiration != nil { | ||||
| 		protoKey.Expiration = timestamppb.New(*key.Expiration) | ||||
| 	} | ||||
| 
 | ||||
| 	if key.CreatedAt != nil { | ||||
| 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) | ||||
| 	} | ||||
| 
 | ||||
| 	for idx := range key.ACLTags { | ||||
| 		protoKey.AclTags[idx] = key.ACLTags[idx].Tag | ||||
| 	} | ||||
| 
 | ||||
| 	return &protoKey | ||||
| } | ||||
							
								
								
									
										71
									
								
								hscontrol/types/routes.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								hscontrol/types/routes.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,71 @@ | ||||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") | ||||
| 	ExitRouteV6 = netip.MustParsePrefix("::/0") | ||||
| ) | ||||
| 
 | ||||
| type Route struct { | ||||
| 	gorm.Model | ||||
| 
 | ||||
| 	MachineID uint64 | ||||
| 	Machine   Machine | ||||
| 	Prefix    IPPrefix | ||||
| 
 | ||||
| 	Advertised bool | ||||
| 	Enabled    bool | ||||
| 	IsPrimary  bool | ||||
| } | ||||
| 
 | ||||
| type Routes []Route | ||||
| 
 | ||||
| func (r *Route) String() string { | ||||
| 	return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) | ||||
| } | ||||
| 
 | ||||
| func (r *Route) IsExitRoute() bool { | ||||
| 	return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 | ||||
| } | ||||
| 
 | ||||
| func (rs Routes) Prefixes() []netip.Prefix { | ||||
| 	prefixes := make([]netip.Prefix, len(rs)) | ||||
| 	for i, r := range rs { | ||||
| 		prefixes[i] = netip.Prefix(r.Prefix) | ||||
| 	} | ||||
| 
 | ||||
| 	return prefixes | ||||
| } | ||||
| 
 | ||||
| func (rs Routes) Proto() []*v1.Route { | ||||
| 	protoRoutes := []*v1.Route{} | ||||
| 
 | ||||
| 	for _, route := range rs { | ||||
| 		protoRoute := v1.Route{ | ||||
| 			Id:         uint64(route.ID), | ||||
| 			Machine:    route.Machine.Proto(), | ||||
| 			Prefix:     netip.Prefix(route.Prefix).String(), | ||||
| 			Advertised: route.Advertised, | ||||
| 			Enabled:    route.Enabled, | ||||
| 			IsPrimary:  route.IsPrimary, | ||||
| 			CreatedAt:  timestamppb.New(route.CreatedAt), | ||||
| 			UpdatedAt:  timestamppb.New(route.UpdatedAt), | ||||
| 		} | ||||
| 
 | ||||
| 		if route.DeletedAt.Valid { | ||||
| 			protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) | ||||
| 		} | ||||
| 
 | ||||
| 		protoRoutes = append(protoRoutes, &protoRoute) | ||||
| 	} | ||||
| 
 | ||||
| 	return protoRoutes | ||||
| } | ||||
							
								
								
									
										55
									
								
								hscontrol/types/users.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								hscontrol/types/users.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,55 @@ | ||||
| package types | ||||
| 
 | ||||
| import ( | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"gorm.io/gorm" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
| 
 | ||||
| // User is the way Headscale implements the concept of users in Tailscale | ||||
| // | ||||
| // At the end of the day, users in Tailscale are some kind of 'bubbles' or users | ||||
| // that contain our machines. | ||||
| type User struct { | ||||
| 	gorm.Model | ||||
| 	Name string `gorm:"unique"` | ||||
| } | ||||
| 
 | ||||
| func (n *User) TailscaleUser() *tailcfg.User { | ||||
| 	user := tailcfg.User{ | ||||
| 		ID:            tailcfg.UserID(n.ID), | ||||
| 		LoginName:     n.Name, | ||||
| 		DisplayName:   n.Name, | ||||
| 		ProfilePicURL: "", | ||||
| 		Domain:        "headscale.net", | ||||
| 		Logins:        []tailcfg.LoginID{}, | ||||
| 		Created:       time.Time{}, | ||||
| 	} | ||||
| 
 | ||||
| 	return &user | ||||
| } | ||||
| 
 | ||||
| func (n *User) TailscaleLogin() *tailcfg.Login { | ||||
| 	login := tailcfg.Login{ | ||||
| 		ID:            tailcfg.LoginID(n.ID), | ||||
| 		LoginName:     n.Name, | ||||
| 		DisplayName:   n.Name, | ||||
| 		ProfilePicURL: "", | ||||
| 		Domain:        "headscale.net", | ||||
| 	} | ||||
| 
 | ||||
| 	return &login | ||||
| } | ||||
| 
 | ||||
| func (n *User) Proto() *v1.User { | ||||
| 	return &v1.User{ | ||||
| 		Id:        strconv.FormatUint(uint64(n.ID), util.Base10), | ||||
| 		Name:      n.Name, | ||||
| 		CreatedAt: timestamppb.New(n.CreatedAt), | ||||
| 	} | ||||
| } | ||||
| @ -1,415 +0,0 @@ | ||||
| package hscontrol | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) TestCreateAndDestroyUser(c *check.C) { | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(user.Name, check.Equals, "test") | ||||
| 
 | ||||
| 	users, err := app.db.ListUsers() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(users), check.Equals, 1) | ||||
| 
 | ||||
| 	err = app.db.DestroyUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetUser("test") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestDestroyUserErrors(c *check.C) { | ||||
| 	err := app.db.DestroyUser("test") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	user, err := app.db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.DestroyUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key) | ||||
| 	// destroying a user also deletes all associated preauthkeys | ||||
| 	c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) | ||||
| 
 | ||||
| 	user, err = app.db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = app.db.DestroyUser("test") | ||||
| 	c.Assert(err, check.Equals, ErrUserStillHasNodes) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestRenameUser(c *check.C) { | ||||
| 	userTest, err := app.db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(userTest.Name, check.Equals, "test") | ||||
| 
 | ||||
| 	users, err := app.db.ListUsers() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(users), check.Equals, 1) | ||||
| 
 | ||||
| 	err = app.db.RenameUser("test", "test-renamed") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetUser("test") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	_, err = app.db.GetUser("test-renamed") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	err = app.db.RenameUser("test-does-not-exit", "test") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	userTest2, err := app.db.CreateUser("test2") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(userTest2.Name, check.Equals, "test2") | ||||
| 
 | ||||
| 	err = app.db.RenameUser("test2", "test-renamed") | ||||
| 	c.Assert(err, check.Equals, ErrUserExists) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { | ||||
| 	userShared1, err := app.db.CreateUser("shared1") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	userShared2, err := app.db.CreateUser("shared2") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	userShared3, err := app.db.CreateUser("shared3") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKeyShared1, err := app.db.CreatePreAuthKey( | ||||
| 		userShared1.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKeyShared2, err := app.db.CreatePreAuthKey( | ||||
| 		userShared2.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKeyShared3, err := app.db.CreatePreAuthKey( | ||||
| 		userShared3.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	preAuthKey2Shared1, err := app.db.CreatePreAuthKey( | ||||
| 		userShared1.Name, | ||||
| 		false, | ||||
| 		false, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 	) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	machineInShared1 := &Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||
| 		Hostname:       "test_get_shared_nodes_1", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyShared1.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared1) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared2 := &Machine{ | ||||
| 		ID:             2, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		Hostname:       "test_get_shared_nodes_2", | ||||
| 		UserID:         userShared2.ID, | ||||
| 		User:           *userShared2, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyShared2.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared2) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machineInShared3 := &Machine{ | ||||
| 		ID:             3, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		Hostname:       "test_get_shared_nodes_3", | ||||
| 		UserID:         userShared3.ID, | ||||
| 		User:           *userShared3, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||
| 		AuthKeyID:      uint(preAuthKeyShared3.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machineInShared3) | ||||
| 
 | ||||
| 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine2InShared1 := &Machine{ | ||||
| 		ID:             4, | ||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||
| 		Hostname:       "test_get_shared_nodes_4", | ||||
| 		UserID:         userShared1.ID, | ||||
| 		User:           *userShared1, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||
| 		AuthKeyID:      uint(preAuthKey2Shared1.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(machine2InShared1) | ||||
| 
 | ||||
| 	peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	userProfiles := app.db.getMapResponseUserProfiles( | ||||
| 		*machineInShared1, | ||||
| 		peersOfMachine1InShared1, | ||||
| 	) | ||||
| 
 | ||||
| 	c.Assert(len(userProfiles), check.Equals, 3) | ||||
| 
 | ||||
| 	found := false | ||||
| 	for _, userProfiles := range userProfiles { | ||||
| 		if userProfiles.DisplayName == userShared1.Name { | ||||
| 			found = true | ||||
| 
 | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	c.Assert(found, check.Equals, true) | ||||
| 
 | ||||
| 	found = false | ||||
| 	for _, userProfile := range userProfiles { | ||||
| 		if userProfile.DisplayName == userShared2.Name { | ||||
| 			found = true | ||||
| 
 | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	c.Assert(found, check.Equals, true) | ||||
| } | ||||
| 
 | ||||
| func TestNormalizeToFQDNRules(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		name             string | ||||
| 		stripEmailDomain bool | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		args    args | ||||
| 		want    string | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "normalize simple name", | ||||
| 			args: args{ | ||||
| 				name:             "normalize-simple.name", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "normalize-simple.name", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "normalize an email", | ||||
| 			args: args{ | ||||
| 				name:             "foo.bar@example.com", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "foo.bar.example.com", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "normalize an email domain should be removed", | ||||
| 			args: args{ | ||||
| 				name:             "foo.bar@example.com", | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    "foo.bar", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "strip enabled no email passed as argument", | ||||
| 			args: args{ | ||||
| 				name:             "not-email-and-strip-enabled", | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    "not-email-and-strip-enabled", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "normalize complex email", | ||||
| 			args: args{ | ||||
| 				name:             "foo.bar+complex-email@example.com", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "foo.bar-complex-email.example.com", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "user name with space", | ||||
| 			args: args{ | ||||
| 				name:             "name space", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "name-space", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "user with quote", | ||||
| 			args: args{ | ||||
| 				name:             "Jamie's iPhone 5", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "jamies-iphone-5", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf( | ||||
| 					"NormalizeToFQDNRules() error = %v, wantErr %v", | ||||
| 					err, | ||||
| 					tt.wantErr, | ||||
| 				) | ||||
| 
 | ||||
| 				return | ||||
| 			} | ||||
| 			if got != tt.want { | ||||
| 				t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCheckForFQDNRules(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		name string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		args    args | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:    "valid: user", | ||||
| 			args:    args{name: "valid-user"}, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid: capitalized user", | ||||
| 			args:    args{name: "Invalid-CapItaLIzed-user"}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid: email as user", | ||||
| 			args:    args{name: "foo.bar@example.com"}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid: chars in user name", | ||||
| 			args:    args{name: "super-user+name"}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "invalid: too long name for user", | ||||
| 			args: args{ | ||||
| 				name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", | ||||
| 			}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestSetMachineUser(c *check.C) { | ||||
| 	oldUser, err := app.db.CreateUser("old") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	newUser, err := app.db.CreateUser("new") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	machine := Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         oldUser.ID, | ||||
| 		RegisterMethod: RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	app.db.db.Save(&machine) | ||||
| 	c.Assert(machine.UserID, check.Equals, oldUser.ID) | ||||
| 
 | ||||
| 	err = app.db.SetMachineUser(&machine, newUser.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||
| 
 | ||||
| 	err = app.db.SetMachineUser(&machine, "non-existing-user") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	err = app.db.SetMachineUser(&machine, newUser.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||
| } | ||||
| @ -1,12 +1,94 @@ | ||||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"go4.org/netipx" | ||||
| ) | ||||
| 
 | ||||
| // This is borrowed from, and updated to use IPSet | ||||
| // https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 | ||||
| // TODO(kradalby): contribute upstream and make public. | ||||
| var ( | ||||
| 	zeroIP4 = netip.AddrFrom4([4]byte{}) | ||||
| 	zeroIP6 = netip.AddrFrom16([16]byte{}) | ||||
| ) | ||||
| 
 | ||||
| // parseIPSet parses arg as one: | ||||
| // | ||||
| //   - an IP address (IPv4 or IPv6) | ||||
| //   - the string "*" to match everything (both IPv4 & IPv6) | ||||
| //   - a CIDR (e.g. "192.168.0.0/16") | ||||
| //   - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") | ||||
| // | ||||
| // bits, if non-nil, is the legacy SrcBits CIDR length to make a IP | ||||
| // address (without a slash) treated as a CIDR of *bits length. | ||||
| // nolint | ||||
| func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) { | ||||
| 	var ipSet netipx.IPSetBuilder | ||||
| 	if arg == "*" { | ||||
| 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) | ||||
| 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) | ||||
| 
 | ||||
| 		return ipSet.IPSet() | ||||
| 	} | ||||
| 	if strings.Contains(arg, "/") { | ||||
| 		pfx, err := netip.ParsePrefix(arg) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if pfx != pfx.Masked() { | ||||
| 			return nil, fmt.Errorf("%v contains non-network bits set", pfx) | ||||
| 		} | ||||
| 
 | ||||
| 		ipSet.AddPrefix(pfx) | ||||
| 
 | ||||
| 		return ipSet.IPSet() | ||||
| 	} | ||||
| 	if strings.Count(arg, "-") == 1 { | ||||
| 		ip1s, ip2s, _ := strings.Cut(arg, "-") | ||||
| 
 | ||||
| 		ip1, err := netip.ParseAddr(ip1s) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		ip2, err := netip.ParseAddr(ip2s) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		r := netipx.IPRangeFrom(ip1, ip2) | ||||
| 		if !r.IsValid() { | ||||
| 			return nil, fmt.Errorf("invalid IP range %q", arg) | ||||
| 		} | ||||
| 
 | ||||
| 		for _, prefix := range r.Prefixes() { | ||||
| 			ipSet.AddPrefix(prefix) | ||||
| 		} | ||||
| 
 | ||||
| 		return ipSet.IPSet() | ||||
| 	} | ||||
| 	ip, err := netip.ParseAddr(arg) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("invalid IP address %q", arg) | ||||
| 	} | ||||
| 	bits8 := uint8(ip.BitLen()) | ||||
| 	if bits != nil { | ||||
| 		if *bits < 0 || *bits > int(bits8) { | ||||
| 			return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) | ||||
| 		} | ||||
| 		bits8 = uint8(*bits) | ||||
| 	} | ||||
| 
 | ||||
| 	ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) | ||||
| 
 | ||||
| 	return ipSet.IPSet() | ||||
| } | ||||
| 
 | ||||
| func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { | ||||
| 	var network, broadcast netip.Addr | ||||
| 	ipRange := netipx.RangeOfPrefix(na) | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| package hscontrol | ||||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"net/netip" | ||||
| @ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) { | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := parseIPSet(tt.args.arg, tt.args.bits) | ||||
| 			got, err := ParseIPSet(tt.args.arg, tt.args.bits) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 
 | ||||
							
								
								
									
										7
									
								
								hscontrol/util/const.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								hscontrol/util/const.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | ||||
| package util | ||||
| 
 | ||||
| const ( | ||||
| 	RegisterMethodAuthKey = "authkey" | ||||
| 	RegisterMethodOIDC    = "oidc" | ||||
| 	RegisterMethodCLI     = "cli" | ||||
| ) | ||||
							
								
								
									
										69
									
								
								hscontrol/util/dns.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								hscontrol/util/dns.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,69 @@ | ||||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// value related to RFC 1123 and 952. | ||||
| 	LabelHostnameLength = 63 | ||||
| ) | ||||
| 
 | ||||
| var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") | ||||
| 
 | ||||
| var ErrInvalidUserName = errors.New("invalid user name") | ||||
| 
 | ||||
| // NormalizeToFQDNRules will replace forbidden chars in user | ||||
| // it can also return an error if the user doesn't respect RFC 952 and 1123. | ||||
| func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { | ||||
| 	name = strings.ToLower(name) | ||||
| 	name = strings.ReplaceAll(name, "'", "") | ||||
| 	atIdx := strings.Index(name, "@") | ||||
| 	if stripEmailDomain && atIdx > 0 { | ||||
| 		name = name[:atIdx] | ||||
| 	} else { | ||||
| 		name = strings.ReplaceAll(name, "@", ".") | ||||
| 	} | ||||
| 	name = invalidCharsInUserRegex.ReplaceAllString(name, "-") | ||||
| 
 | ||||
| 	for _, elt := range strings.Split(name, ".") { | ||||
| 		if len(elt) > LabelHostnameLength { | ||||
| 			return "", fmt.Errorf( | ||||
| 				"label %v is more than 63 chars: %w", | ||||
| 				elt, | ||||
| 				ErrInvalidUserName, | ||||
| 			) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return name, nil | ||||
| } | ||||
| 
 | ||||
| func CheckForFQDNRules(name string) error { | ||||
| 	if len(name) > LabelHostnameLength { | ||||
| 		return fmt.Errorf( | ||||
| 			"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", | ||||
| 			name, | ||||
| 			ErrInvalidUserName, | ||||
| 		) | ||||
| 	} | ||||
| 	if strings.ToLower(name) != name { | ||||
| 		return fmt.Errorf( | ||||
| 			"DNS segment should be lowercase. %v doesn't comply with this rule: %w", | ||||
| 			name, | ||||
| 			ErrInvalidUserName, | ||||
| 		) | ||||
| 	} | ||||
| 	if invalidCharsInUserRegex.MatchString(name) { | ||||
| 		return fmt.Errorf( | ||||
| 			"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", | ||||
| 			name, | ||||
| 			ErrInvalidUserName, | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										143
									
								
								hscontrol/util/dns_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								hscontrol/util/dns_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,143 @@ | ||||
| package util | ||||
| 
 | ||||
| import "testing" | ||||
| 
 | ||||
| func TestNormalizeToFQDNRules(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		name             string | ||||
| 		stripEmailDomain bool | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		args    args | ||||
| 		want    string | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "normalize simple name", | ||||
| 			args: args{ | ||||
| 				name:             "normalize-simple.name", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "normalize-simple.name", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "normalize an email", | ||||
| 			args: args{ | ||||
| 				name:             "foo.bar@example.com", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "foo.bar.example.com", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "normalize an email domain should be removed", | ||||
| 			args: args{ | ||||
| 				name:             "foo.bar@example.com", | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    "foo.bar", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "strip enabled no email passed as argument", | ||||
| 			args: args{ | ||||
| 				name:             "not-email-and-strip-enabled", | ||||
| 				stripEmailDomain: true, | ||||
| 			}, | ||||
| 			want:    "not-email-and-strip-enabled", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "normalize complex email", | ||||
| 			args: args{ | ||||
| 				name:             "foo.bar+complex-email@example.com", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "foo.bar-complex-email.example.com", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "user name with space", | ||||
| 			args: args{ | ||||
| 				name:             "name space", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "name-space", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "user with quote", | ||||
| 			args: args{ | ||||
| 				name:             "Jamie's iPhone 5", | ||||
| 				stripEmailDomain: false, | ||||
| 			}, | ||||
| 			want:    "jamies-iphone-5", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf( | ||||
| 					"NormalizeToFQDNRules() error = %v, wantErr %v", | ||||
| 					err, | ||||
| 					tt.wantErr, | ||||
| 				) | ||||
| 
 | ||||
| 				return | ||||
| 			} | ||||
| 			if got != tt.want { | ||||
| 				t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCheckForFQDNRules(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		name string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		args    args | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:    "valid: user", | ||||
| 			args:    args{name: "valid-user"}, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid: capitalized user", | ||||
| 			args:    args{name: "Invalid-CapItaLIzed-user"}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid: email as user", | ||||
| 			args:    args{name: "foo.bar@example.com"}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid: chars in user name", | ||||
| 			args:    args{name: "super-user+name"}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "invalid: too long name for user", | ||||
| 			args: args{ | ||||
| 				name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", | ||||
| 			}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @ -6,7 +6,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| @ -45,7 +45,7 @@ var veryLargeDestination = []string{ | ||||
| 	"208.0.0.0/4:*", | ||||
| } | ||||
| 
 | ||||
| func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario { | ||||
| func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { | ||||
| 	t.Helper() | ||||
| 	scenario, err := NewScenario() | ||||
| 	assert.NoError(t, err) | ||||
| @ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 	// they can access minus one (them self). | ||||
| 	tests := map[string]struct { | ||||
| 		users  map[string]int | ||||
| 		policy hscontrol.ACLPolicy | ||||
| 		policy policy.ACLPolicy | ||||
| 		want   map[string]int | ||||
| 	}{ | ||||
| 		// Test that when we have no ACL, each client netmap has | ||||
| @ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"*"}, | ||||
| @ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"user1"}, | ||||
| @ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"user1"}, | ||||
| @ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"user1"}, | ||||
| @ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"user1"}, | ||||
| @ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		&hscontrol.ACLPolicy{ | ||||
| 			ACLs: []hscontrol.ACL{ | ||||
| 		&policy.ACLPolicy{ | ||||
| 			ACLs: []policy.ACL{ | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"user1"}, | ||||
| @ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		&hscontrol.ACLPolicy{ | ||||
| 		&policy.ACLPolicy{ | ||||
| 			Groups: map[string][]string{ | ||||
| 				"group:integration-acl-test": {"user1", "user2"}, | ||||
| 			}, | ||||
| 			ACLs: []hscontrol.ACL{ | ||||
| 			ACLs: []policy.ACL{ | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"group:integration-acl-test"}, | ||||
| @ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		&hscontrol.ACLPolicy{ | ||||
| 			ACLs: []hscontrol.ACL{ | ||||
| 		&policy.ACLPolicy{ | ||||
| 			ACLs: []policy.ACL{ | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"user1"}, | ||||
| @ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		&hscontrol.ACLPolicy{ | ||||
| 			ACLs: []hscontrol.ACL{ | ||||
| 		&policy.ACLPolicy{ | ||||
| 			ACLs: []policy.ACL{ | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"user1"}, | ||||
| @ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		&hscontrol.ACLPolicy{ | ||||
| 			Hosts: hscontrol.Hosts{ | ||||
| 		&policy.ACLPolicy{ | ||||
| 			Hosts: policy.Hosts{ | ||||
| 				"all": netip.MustParsePrefix("100.64.0.0/24"), | ||||
| 			}, | ||||
| 			ACLs: []hscontrol.ACL{ | ||||
| 			ACLs: []policy.ACL{ | ||||
| 				// Everyone can curl test3 | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| @ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	tests := map[string]struct { | ||||
| 		policy hscontrol.ACLPolicy | ||||
| 		policy policy.ACLPolicy | ||||
| 	}{ | ||||
| 		"ipv4": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				Hosts: hscontrol.Hosts{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				Hosts: policy.Hosts{ | ||||
| 					"test1": netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 					"test2": netip.MustParsePrefix("100.64.0.2/32"), | ||||
| 					"test3": netip.MustParsePrefix("100.64.0.3/32"), | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					// Everyone can curl test3 | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| @ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"ipv6": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				Hosts: hscontrol.Hosts{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				Hosts: policy.Hosts{ | ||||
| 					"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), | ||||
| 					"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), | ||||
| 					"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					// Everyone can curl test3 | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| @ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	tests := map[string]struct { | ||||
| 		policy hscontrol.ACLPolicy | ||||
| 		policy policy.ACLPolicy | ||||
| 	}{ | ||||
| 		"ipv4": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"100.64.0.1"}, | ||||
| @ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"ipv6": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"fd7a:115c:a1e0::1"}, | ||||
| @ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"hostv4cidr": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				Hosts: hscontrol.Hosts{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				Hosts: policy.Hosts{ | ||||
| 					"test1": netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 					"test2": netip.MustParsePrefix("100.64.0.2/32"), | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"test1"}, | ||||
| @ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"hostv6cidr": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 				Hosts: hscontrol.Hosts{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				Hosts: policy.Hosts{ | ||||
| 					"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), | ||||
| 					"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"test1"}, | ||||
| @ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"group": { | ||||
| 			policy: hscontrol.ACLPolicy{ | ||||
| 			policy: policy.ACLPolicy{ | ||||
| 				Groups: map[string][]string{ | ||||
| 					"group:one": {"user1"}, | ||||
| 					"group:two": {"user2"}, | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"group:one"}, | ||||
|  | ||||
| @ -23,7 +23,7 @@ import ( | ||||
| 
 | ||||
| 	"github.com/davecgh/go-spew/spew" | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/integrationutil" | ||||
| @ -60,7 +60,7 @@ type HeadscaleInContainer struct { | ||||
| 	port             int | ||||
| 	extraPorts       []string | ||||
| 	hostPortBindings map[string][]string | ||||
| 	aclPolicy        *hscontrol.ACLPolicy | ||||
| 	aclPolicy        *policy.ACLPolicy | ||||
| 	env              map[string]string | ||||
| 	tlsCert          []byte | ||||
| 	tlsKey           []byte | ||||
| @ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer) | ||||
| 
 | ||||
| // WithACLPolicy adds a hscontrol.ACLPolicy policy to the | ||||
| // HeadscaleInContainer instance. | ||||
| func WithACLPolicy(acl *hscontrol.ACLPolicy) Option { | ||||
| func WithACLPolicy(acl *policy.ACLPolicy) Option { | ||||
| 	return func(hsic *HeadscaleInContainer) { | ||||
| 		// TODO(kradalby): Move somewhere appropriate | ||||
| 		hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath | ||||
|  | ||||
| @ -6,7 +6,7 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol" | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| @ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) { | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 		[]tsic.Option{tsic.WithSSH()}, | ||||
| 		hsic.WithACLPolicy( | ||||
| 			&hscontrol.ACLPolicy{ | ||||
| 			&policy.ACLPolicy{ | ||||
| 				Groups: map[string][]string{ | ||||
| 					"group:integration-test": {"user1"}, | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"*"}, | ||||
| 						Destinations: []string{"*:*"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				SSHs: []hscontrol.SSH{ | ||||
| 				SSHs: []policy.SSH{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"group:integration-test"}, | ||||
| @ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 		[]tsic.Option{tsic.WithSSH()}, | ||||
| 		hsic.WithACLPolicy( | ||||
| 			&hscontrol.ACLPolicy{ | ||||
| 			&policy.ACLPolicy{ | ||||
| 				Groups: map[string][]string{ | ||||
| 					"group:integration-test": {"user1", "user2"}, | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"*"}, | ||||
| 						Destinations: []string{"*:*"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				SSHs: []hscontrol.SSH{ | ||||
| 				SSHs: []policy.SSH{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"group:integration-test"}, | ||||
| @ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) { | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 		[]tsic.Option{tsic.WithSSH()}, | ||||
| 		hsic.WithACLPolicy( | ||||
| 			&hscontrol.ACLPolicy{ | ||||
| 			&policy.ACLPolicy{ | ||||
| 				Groups: map[string][]string{ | ||||
| 					"group:integration-test": {"user1"}, | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"*"}, | ||||
| 						Destinations: []string{"*:*"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				SSHs: []hscontrol.SSH{}, | ||||
| 				SSHs: []policy.SSH{}, | ||||
| 			}, | ||||
| 		), | ||||
| 		hsic.WithTestName("sshnoneconfigured"), | ||||
| @ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) { | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 		[]tsic.Option{tsic.WithSSH()}, | ||||
| 		hsic.WithACLPolicy( | ||||
| 			&hscontrol.ACLPolicy{ | ||||
| 			&policy.ACLPolicy{ | ||||
| 				Groups: map[string][]string{ | ||||
| 					"group:integration-test": {"user1"}, | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"*"}, | ||||
| 						Destinations: []string{"*:80"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				SSHs: []hscontrol.SSH{ | ||||
| 				SSHs: []policy.SSH{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"group:integration-test"}, | ||||
| @ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) { | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 		[]tsic.Option{tsic.WithSSH()}, | ||||
| 		hsic.WithACLPolicy( | ||||
| 			&hscontrol.ACLPolicy{ | ||||
| 			&policy.ACLPolicy{ | ||||
| 				Groups: map[string][]string{ | ||||
| 					"group:ssh1": {"useracl1"}, | ||||
| 					"group:ssh2": {"useracl2"}, | ||||
| 				}, | ||||
| 				ACLs: []hscontrol.ACL{ | ||||
| 				ACLs: []policy.ACL{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"*"}, | ||||
| 						Destinations: []string{"*:*"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				SSHs: []hscontrol.SSH{ | ||||
| 				SSHs: []policy.SSH{ | ||||
| 					{ | ||||
| 						Action:       "accept", | ||||
| 						Sources:      []string{"group:ssh1"}, | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user