mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-01 00:21:02 +01:00 
			
		
		
		
	Split code into modules
This is a massive commit that restructures the code into modules:
db/
    All functions related to modifying the Database
types/
    All type definitions and methods that can be exclusivly used on
    these types without dependencies
policy/
    All Policy related code, now without dependencies on the Database.
policy/matcher/
    Dedicated code to match machines in a list of FilterRules
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
			
			
This commit is contained in:
		
							parent
							
								
									14e29a7bee
								
							
						
					
					
						commit
						feb15365b5
					
				| @ -7,7 +7,7 @@ import ( | |||||||
| 	"strconv" | 	"strconv" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| 	"github.com/juanfont/headscale/hscontrol" | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/pterm/pterm" | 	"github.com/pterm/pterm" | ||||||
| 	"github.com/spf13/cobra" | 	"github.com/spf13/cobra" | ||||||
| 	"google.golang.org/grpc/status" | 	"google.golang.org/grpc/status" | ||||||
| @ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData { | |||||||
| 
 | 
 | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 { | 		if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 { | ||||||
| 			isPrimaryStr = "-" | 			isPrimaryStr = "-" | ||||||
| 		} else { | 		} else { | ||||||
| 			isPrimaryStr = strconv.FormatBool(route.IsPrimary) | 			isPrimaryStr = strconv.FormatBool(route.IsPrimary) | ||||||
|  | |||||||
| @ -10,6 +10,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| 	"github.com/juanfont/headscale/hscontrol" | 	"github.com/juanfont/headscale/hscontrol" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| @ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { | |||||||
| 
 | 
 | ||||||
| 	if cfg.ACL.PolicyPath != "" { | 	if cfg.ACL.PolicyPath != "" { | ||||||
| 		aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) | 		aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) | ||||||
| 		err = app.LoadACLPolicyFromPath(aclPath) | 		pol, err := policy.LoadACLPolicyFromPath(aclPath) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Fatal(). | 			log.Fatal(). | ||||||
| 				Str("path", aclPath). | 				Str("path", aclPath). | ||||||
| 				Err(err). | 				Err(err). | ||||||
| 				Msg("Could not load the ACL policy") | 				Msg("Could not load the ACL policy") | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		app.ACLPolicy = pol | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return app, nil | 	return app, nil | ||||||
|  | |||||||
| @ -18,9 +18,6 @@ const ( | |||||||
| 	// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. | 	// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. | ||||||
| 	registrationHoldoff        = time.Second * 5 | 	registrationHoldoff        = time.Second * 5 | ||||||
| 	reservedResponseHeaderSize = 4 | 	reservedResponseHeaderSize = 4 | ||||||
| 	RegisterMethodAuthKey      = "authkey" |  | ||||||
| 	RegisterMethodOIDC         = "oidc" |  | ||||||
| 	RegisterMethodCLI          = "cli" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( | var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( | ||||||
| @ -56,7 +53,7 @@ func (h *Headscale) HealthHandler( | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := h.db.pingDB(req.Context()); err != nil { | 	if err := h.db.PingDB(req.Context()); err != nil { | ||||||
| 		respond(err) | 		respond(err) | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ package hscontrol | |||||||
| import ( | import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| @ -10,13 +11,13 @@ import ( | |||||||
| 
 | 
 | ||||||
| func (h *Headscale) generateMapResponse( | func (h *Headscale) generateMapResponse( | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	machine *Machine, | 	machine *types.Machine, | ||||||
| ) (*tailcfg.MapResponse, error) { | ) (*tailcfg.MapResponse, error) { | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Str("func", "generateMapResponse"). | 		Str("func", "generateMapResponse"). | ||||||
| 		Str("machine", mapRequest.Hostinfo.Hostname). | 		Str("machine", mapRequest.Hostinfo.Hostname). | ||||||
| 		Msg("Creating Map response") | 		Msg("Creating Map response") | ||||||
| 	node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) | 	node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| @ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) | 	peers, err := h.db.GetValidPeers(h.aclRules, machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| @ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	profiles := h.db.getMapResponseUserProfiles(*machine, peers) | 	profiles := h.db.GetMapResponseUserProfiles(*machine, peers) | ||||||
| 
 | 
 | ||||||
| 	nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) | 	nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
|  | |||||||
							
								
								
									
										166
									
								
								hscontrol/app.go
									
									
									
									
									
								
							
							
						
						
									
										166
									
								
								hscontrol/app.go
									
									
									
									
									
								
							| @ -23,6 +23,9 @@ import ( | |||||||
| 	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" | 	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" | ||||||
| 	"github.com/juanfont/headscale" | 	"github.com/juanfont/headscale" | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/db" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/patrickmn/go-cache" | 	"github.com/patrickmn/go-cache" | ||||||
| 	zerolog "github.com/philip-bui/grpc-zerolog" | 	zerolog "github.com/philip-bui/grpc-zerolog" | ||||||
| @ -73,7 +76,7 @@ const ( | |||||||
| // Headscale represents the base app of the service. | // Headscale represents the base app of the service. | ||||||
| type Headscale struct { | type Headscale struct { | ||||||
| 	cfg             *Config | 	cfg             *Config | ||||||
| 	db              *HSDatabase | 	db              *db.HSDatabase | ||||||
| 	dbString        string | 	dbString        string | ||||||
| 	dbType          string | 	dbType          string | ||||||
| 	dbDebug         bool | 	dbDebug         bool | ||||||
| @ -83,7 +86,7 @@ type Headscale struct { | |||||||
| 	DERPMap    *tailcfg.DERPMap | 	DERPMap    *tailcfg.DERPMap | ||||||
| 	DERPServer *DERPServer | 	DERPServer *DERPServer | ||||||
| 
 | 
 | ||||||
| 	aclPolicy *ACLPolicy | 	ACLPolicy *policy.ACLPolicy | ||||||
| 	aclRules  []tailcfg.FilterRule | 	aclRules  []tailcfg.FilterRule | ||||||
| 	sshPolicy *tailcfg.SSHPolicy | 	sshPolicy *tailcfg.SSHPolicy | ||||||
| 
 | 
 | ||||||
| @ -99,6 +102,12 @@ type Headscale struct { | |||||||
| 
 | 
 | ||||||
| 	stateUpdateChan       chan struct{} | 	stateUpdateChan       chan struct{} | ||||||
| 	cancelStateUpdateChan chan struct{} | 	cancelStateUpdateChan chan struct{} | ||||||
|  | 
 | ||||||
|  | 	// TODO(kradalby): Temporary measure to make sure we can update policy | ||||||
|  | 	// across modules, will be removed when aclRules are no longer stored | ||||||
|  | 	// globally but generated per node basis. | ||||||
|  | 	policyUpdateChan       chan struct{} | ||||||
|  | 	cancelPolicyUpdateChan chan struct{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewHeadscale(cfg *Config) (*Headscale, error) { | func NewHeadscale(cfg *Config) (*Headscale, error) { | ||||||
| @ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { | |||||||
| 
 | 
 | ||||||
| 	var dbString string | 	var dbString string | ||||||
| 	switch cfg.DBtype { | 	switch cfg.DBtype { | ||||||
| 	case Postgres: | 	case db.Postgres: | ||||||
| 		dbString = fmt.Sprintf( | 		dbString = fmt.Sprintf( | ||||||
| 			"host=%s dbname=%s user=%s", | 			"host=%s dbname=%s user=%s", | ||||||
| 			cfg.DBhost, | 			cfg.DBhost, | ||||||
| @ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { | |||||||
| 		if cfg.DBpass != "" { | 		if cfg.DBpass != "" { | ||||||
| 			dbString += fmt.Sprintf(" password=%s", cfg.DBpass) | 			dbString += fmt.Sprintf(" password=%s", cfg.DBpass) | ||||||
| 		} | 		} | ||||||
| 	case Sqlite: | 	case db.Sqlite: | ||||||
| 		dbString = cfg.DBpath | 		dbString = cfg.DBpath | ||||||
| 	default: | 	default: | ||||||
| 		return nil, errUnsupportedDatabase | 		return nil, errUnsupportedDatabase | ||||||
| @ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { | |||||||
| 
 | 
 | ||||||
| 		stateUpdateChan:       make(chan struct{}), | 		stateUpdateChan:       make(chan struct{}), | ||||||
| 		cancelStateUpdateChan: make(chan struct{}), | 		cancelStateUpdateChan: make(chan struct{}), | ||||||
|  | 
 | ||||||
|  | 		policyUpdateChan:       make(chan struct{}), | ||||||
|  | 		cancelPolicyUpdateChan: make(chan struct{}), | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	go app.watchStateChannel() | 	go app.watchStateChannel() | ||||||
|  | 	go app.watchPolicyChannel() | ||||||
| 
 | 
 | ||||||
| 	db, err := NewHeadscaleDatabase( | 	database, err := db.NewHeadscaleDatabase( | ||||||
| 		cfg.DBtype, | 		cfg.DBtype, | ||||||
| 		dbString, | 		dbString, | ||||||
| 		cfg.OIDC.StripEmaildomain, | 		cfg.OIDC.StripEmaildomain, | ||||||
| 		app.dbDebug, | 		app.dbDebug, | ||||||
| 		app.stateUpdateChan, | 		app.stateUpdateChan, | ||||||
|  | 		app.policyUpdateChan, | ||||||
| 		cfg.IPPrefixes, | 		cfg.IPPrefixes, | ||||||
| 		cfg.BaseDomain) | 		cfg.BaseDomain) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	app.db = db | 	app.db = database | ||||||
| 
 | 
 | ||||||
| 	if cfg.OIDC.Issuer != "" { | 	if cfg.OIDC.Issuer != "" { | ||||||
| 		err = app.initOIDC() | 		err = app.initOIDC() | ||||||
| @ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { | |||||||
| func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { | func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { | ||||||
| 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | ||||||
| 	for range ticker.C { | 	for range ticker.C { | ||||||
| 		h.expireEphemeralNodesWorker() | 		h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { | |||||||
| func (h *Headscale) expireExpiredMachines(milliSeconds int64) { | func (h *Headscale) expireExpiredMachines(milliSeconds int64) { | ||||||
| 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | ||||||
| 	for range ticker.C { | 	for range ticker.C { | ||||||
| 		h.expireExpiredMachinesWorker() | 		h.db.ExpireExpiredMachines(h.getLastStateChange()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { | func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { | ||||||
| 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | 	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) | ||||||
| 	for range ticker.C { | 	for range ticker.C { | ||||||
| 		err := h.db.handlePrimarySubnetFailover() | 		err := h.db.HandlePrimarySubnetFailover() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Err(err).Msg("failed to handle primary subnet failover") | 			log.Error().Err(err).Msg("failed to handle primary subnet failover") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) expireEphemeralNodesWorker() { |  | ||||||
| 	users, err := h.db.ListUsers() |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error().Err(err).Msg("Error listing users") |  | ||||||
| 
 |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, user := range users { |  | ||||||
| 		machines, err := h.db.ListMachinesByUser(user.Name) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Err(err). |  | ||||||
| 				Str("user", user.Name). |  | ||||||
| 				Msg("Error listing machines in user") |  | ||||||
| 
 |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		expiredFound := false |  | ||||||
| 		for _, machine := range machines { |  | ||||||
| 			if machine.isEphemeral() && machine.LastSeen != nil && |  | ||||||
| 				time.Now(). |  | ||||||
| 					After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { |  | ||||||
| 				expiredFound = true |  | ||||||
| 				log.Info(). |  | ||||||
| 					Str("machine", machine.Hostname). |  | ||||||
| 					Msg("Ephemeral client removed from database") |  | ||||||
| 
 |  | ||||||
| 				err = h.db.db.Unscoped().Delete(machine).Error |  | ||||||
| 				if err != nil { |  | ||||||
| 					log.Error(). |  | ||||||
| 						Err(err). |  | ||||||
| 						Str("machine", machine.Hostname). |  | ||||||
| 						Msg("🤮 Cannot delete ephemeral machine from the database") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if expiredFound { |  | ||||||
| 			h.setLastStateChangeToNow() |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (h *Headscale) expireExpiredMachinesWorker() { |  | ||||||
| 	users, err := h.db.ListUsers() |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error().Err(err).Msg("Error listing users") |  | ||||||
| 
 |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, user := range users { |  | ||||||
| 		machines, err := h.db.ListMachinesByUser(user.Name) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Err(err). |  | ||||||
| 				Str("user", user.Name). |  | ||||||
| 				Msg("Error listing machines in user") |  | ||||||
| 
 |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		expiredFound := false |  | ||||||
| 		for index, machine := range machines { |  | ||||||
| 			if machine.isExpired() && |  | ||||||
| 				machine.Expiry.After(h.getLastStateChange(user)) { |  | ||||||
| 				expiredFound = true |  | ||||||
| 
 |  | ||||||
| 				err := h.db.ExpireMachine(&machines[index]) |  | ||||||
| 				if err != nil { |  | ||||||
| 					log.Error(). |  | ||||||
| 						Err(err). |  | ||||||
| 						Str("machine", machine.Hostname). |  | ||||||
| 						Str("name", machine.GivenName). |  | ||||||
| 						Msg("🤮 Cannot expire machine") |  | ||||||
| 				} else { |  | ||||||
| 					log.Info(). |  | ||||||
| 						Str("machine", machine.Hostname). |  | ||||||
| 						Str("name", machine.GivenName). |  | ||||||
| 						Msg("Machine successfully expired") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if expiredFound { |  | ||||||
| 			h.setLastStateChangeToNow() |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, | func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, | ||||||
| 	req interface{}, | 	req interface{}, | ||||||
| 	info *grpc.UnaryServerInfo, | 	info *grpc.UnaryServerInfo, | ||||||
| @ -565,6 +487,8 @@ func (h *Headscale) Serve() error { | |||||||
| 		go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) | 		go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// TODO(kradalby): These should have cancel channels and be cleaned | ||||||
|  | 	// up on shutdown. | ||||||
| 	go h.expireEphemeralNodes(updateInterval) | 	go h.expireEphemeralNodes(updateInterval) | ||||||
| 	go h.expireExpiredMachines(updateInterval) | 	go h.expireExpiredMachines(updateInterval) | ||||||
| 
 | 
 | ||||||
| @ -774,10 +698,12 @@ func (h *Headscale) Serve() error { | |||||||
| 
 | 
 | ||||||
| 				if h.cfg.ACL.PolicyPath != "" { | 				if h.cfg.ACL.PolicyPath != "" { | ||||||
| 					aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) | 					aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) | ||||||
| 					err := h.LoadACLPolicyFromPath(aclPath) | 					pol, err := policy.LoadACLPolicyFromPath(aclPath) | ||||||
| 					if err != nil { | 					if err != nil { | ||||||
| 						log.Error().Err(err).Msg("Failed to reload ACL policy") | 						log.Error().Err(err).Msg("Failed to reload ACL policy") | ||||||
| 					} | 					} | ||||||
|  | 
 | ||||||
|  | 					h.ACLPolicy = pol | ||||||
| 					log.Info(). | 					log.Info(). | ||||||
| 						Str("path", aclPath). | 						Str("path", aclPath). | ||||||
| 						Msg("ACL policy successfully reloaded, notifying nodes of change") | 						Msg("ACL policy successfully reloaded, notifying nodes of change") | ||||||
| @ -824,12 +750,12 @@ func (h *Headscale) Serve() error { | |||||||
| 				close(h.stateUpdateChan) | 				close(h.stateUpdateChan) | ||||||
| 				close(h.cancelStateUpdateChan) | 				close(h.cancelStateUpdateChan) | ||||||
| 
 | 
 | ||||||
|  | 				<-h.cancelPolicyUpdateChan | ||||||
|  | 				close(h.policyUpdateChan) | ||||||
|  | 				close(h.cancelPolicyUpdateChan) | ||||||
|  | 
 | ||||||
| 				// Close db connections | 				// Close db connections | ||||||
| 				db, err := h.db.db.DB() | 				err = h.db.Close() | ||||||
| 				if err != nil { |  | ||||||
| 					log.Error().Err(err).Msg("Failed to get db handle") |  | ||||||
| 				} |  | ||||||
| 				err = db.Close() |  | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Error().Err(err).Msg("Failed to close db") | 					log.Error().Err(err).Msg("Failed to close db") | ||||||
| 				} | 				} | ||||||
| @ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // TODO(kradalby): baby steps, make this more robust. | ||||||
|  | func (h *Headscale) watchPolicyChannel() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-h.policyUpdateChan: | ||||||
|  | 			machines, err := h.db.ListMachines() | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Error().Err(err).Msg("failed to fetch machines during policy update") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain) | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Error().Err(err).Msg("failed to update ACL rules") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			h.aclRules = rules | ||||||
|  | 			h.sshPolicy = sshPolicy | ||||||
|  | 
 | ||||||
|  | 		case <-h.cancelPolicyUpdateChan: | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (h *Headscale) setLastStateChangeToNow() { | func (h *Headscale) setLastStateChangeToNow() { | ||||||
| 	var err error | 	var err error | ||||||
| 
 | 
 | ||||||
| @ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) getLastStateChange(users ...User) time.Time { | func (h *Headscale) getLastStateChange(users ...types.User) time.Time { | ||||||
| 	times := []time.Time{} | 	times := []time.Time{} | ||||||
| 
 | 
 | ||||||
| 	// getLastStateChange takes a list of users as a "filter", if no users | 	// getLastStateChange takes a list of users as a "filter", if no users | ||||||
|  | |||||||
							
								
								
									
										480
									
								
								hscontrol/db/acls_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										480
									
								
								hscontrol/db/acls_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,480 @@ | |||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
|  | 	"gopkg.in/check.v1" | ||||||
|  | 	"tailscale.com/envknob" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // TODO(kradalby): | ||||||
|  | // Convert these tests to being non-database dependent and table driven. They are | ||||||
|  | // very verbose, and dont really need the database. | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestSshRules(c *check.C) { | ||||||
|  | 	envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") | ||||||
|  | 
 | ||||||
|  | 	user, err := db.CreateUser("user1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user1", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	hostInfo := tailcfg.Hostinfo{ | ||||||
|  | 		OS:          "centos", | ||||||
|  | 		Hostname:    "testmachine", | ||||||
|  | 		RequestTags: []string{"tag:test"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	aclPolicy := &policy.ACLPolicy{ | ||||||
|  | 		Groups: policy.Groups{ | ||||||
|  | 			"group:test": []string{"user1"}, | ||||||
|  | 		}, | ||||||
|  | 		Hosts: policy.Hosts{ | ||||||
|  | 			"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), | ||||||
|  | 		}, | ||||||
|  | 		ACLs: []policy.ACL{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"*"}, | ||||||
|  | 				Destinations: []string{"*:*"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		SSHs: []policy.SSH{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"group:test"}, | ||||||
|  | 				Destinations: []string{"client"}, | ||||||
|  | 				Users:        []string{"autogroup:nonroot"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"*"}, | ||||||
|  | 				Destinations: []string{"client"}, | ||||||
|  | 				Users:        []string{"autogroup:nonroot"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(sshPolicy, check.NotNil) | ||||||
|  | 	c.Assert(sshPolicy.Rules, check.HasLen, 2) | ||||||
|  | 	c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) | ||||||
|  | 	c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1) | ||||||
|  | 	c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1") | ||||||
|  | 
 | ||||||
|  | 	c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) | ||||||
|  | 	c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1) | ||||||
|  | 	c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // this test should validate that we can expand a group in a TagOWner section and | ||||||
|  | // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. | ||||||
|  | // the tag is matched in the Sources section. | ||||||
|  | func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("user1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user1", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	hostInfo := tailcfg.Hostinfo{ | ||||||
|  | 		OS:          "centos", | ||||||
|  | 		Hostname:    "testmachine", | ||||||
|  | 		RequestTags: []string{"tag:test"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pol := &policy.ACLPolicy{ | ||||||
|  | 		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}}, | ||||||
|  | 		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, | ||||||
|  | 		ACLs: []policy.ACL{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"tag:test"}, | ||||||
|  | 				Destinations: []string{"*:*"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // this test should validate that we can expand a group in a TagOWner section and | ||||||
|  | // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. | ||||||
|  | // the tag is matched in the Destinations section. | ||||||
|  | func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("user1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user1", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	hostInfo := tailcfg.Hostinfo{ | ||||||
|  | 		OS:          "centos", | ||||||
|  | 		Hostname:    "testmachine", | ||||||
|  | 		RequestTags: []string{"tag:test"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             1, | ||||||
|  | 		MachineKey:     "12345", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pol := &policy.ACLPolicy{ | ||||||
|  | 		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}}, | ||||||
|  | 		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, | ||||||
|  | 		ACLs: []policy.ACL{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"*"}, | ||||||
|  | 				Destinations: []string{"tag:test:*"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // need a test with: | ||||||
|  | // tag on a host that isn't owned by a tag owners. So the user | ||||||
|  | // of the host should be valid. | ||||||
|  | func (s *Suite) TestInvalidTagValidUser(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("user1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user1", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	hostInfo := tailcfg.Hostinfo{ | ||||||
|  | 		OS:          "centos", | ||||||
|  | 		Hostname:    "testmachine", | ||||||
|  | 		RequestTags: []string{"tag:foo"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             1, | ||||||
|  | 		MachineKey:     "12345", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pol := &policy.ACLPolicy{ | ||||||
|  | 		TagOwners: policy.TagOwners{"tag:test": []string{"user1"}}, | ||||||
|  | 		ACLs: []policy.ACL{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"user1"}, | ||||||
|  | 				Destinations: []string{"*:*"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // tag on a host is owned by a tag owner, the tag is valid. | ||||||
|  | // an ACL rule is matching the tag to a user. It should not be valid since the | ||||||
|  | // host should be tied to the tag now. | ||||||
|  | func (s *Suite) TestValidTagInvalidUser(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("user1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user1", "webserver") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	hostInfo := tailcfg.Hostinfo{ | ||||||
|  | 		OS:          "centos", | ||||||
|  | 		Hostname:    "webserver", | ||||||
|  | 		RequestTags: []string{"tag:webapp"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             1, | ||||||
|  | 		MachineKey:     "12345", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "webserver", | ||||||
|  | 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user1", "user") | ||||||
|  | 	hostInfo2 := tailcfg.Hostinfo{ | ||||||
|  | 		OS:       "debian", | ||||||
|  | 		Hostname: "Hostname", | ||||||
|  | 	} | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	machine = types.Machine{ | ||||||
|  | 		ID:             2, | ||||||
|  | 		MachineKey:     "56789", | ||||||
|  | 		NodeKey:        "bar2", | ||||||
|  | 		DiscoKey:       "faab", | ||||||
|  | 		Hostname:       "user", | ||||||
|  | 		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo:       types.HostInfo(hostInfo2), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pol := &policy.ACLPolicy{ | ||||||
|  | 		TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}}, | ||||||
|  | 		ACLs: []policy.ACL{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"user1"}, | ||||||
|  | 				Destinations: []string{"tag:webapp:80,443"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") | ||||||
|  | 	c.Assert(rules[0].DstPorts, check.HasLen, 2) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") | ||||||
|  | 	c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) | ||||||
|  | 	c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) | ||||||
|  | 	c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestPortUser(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("testuser") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("testuser", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	ips, _ := db.getAvailableIPs() | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "12345", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		IPAddresses:    ips, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	acl := []byte(` | ||||||
|  | { | ||||||
|  | 	"hosts": { | ||||||
|  | 		"host-1": "100.100.100.100", | ||||||
|  | 		"subnet-1": "100.100.101.100/24", | ||||||
|  | 	}, | ||||||
|  | 
 | ||||||
|  | 	"acls": [ | ||||||
|  | 		{ | ||||||
|  | 			"action": "accept", | ||||||
|  | 			"src": [ | ||||||
|  | 				"testuser", | ||||||
|  | 			], | ||||||
|  | 			"dst": [ | ||||||
|  | 				"host-1:*", | ||||||
|  | 			], | ||||||
|  | 		}, | ||||||
|  | 	], | ||||||
|  | } | ||||||
|  | 	`) | ||||||
|  | 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(pol, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(rules, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) | ||||||
|  | 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") | ||||||
|  | 	c.Assert(len(ips), check.Equals, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestPortGroup(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("testuser") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("testuser", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 	ips, _ := db.getAvailableIPs() | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		IPAddresses:    ips, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	err = db.MachineSave(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	acl := []byte(` | ||||||
|  | { | ||||||
|  | 	"groups": { | ||||||
|  | 		"group:example": [ | ||||||
|  | 			"testuser", | ||||||
|  | 		], | ||||||
|  | 	}, | ||||||
|  | 
 | ||||||
|  | 	"hosts": { | ||||||
|  | 		"host-1": "100.100.100.100", | ||||||
|  | 		"subnet-1": "100.100.101.100/24", | ||||||
|  | 	}, | ||||||
|  | 
 | ||||||
|  | 	"acls": [ | ||||||
|  | 		{ | ||||||
|  | 			"action": "accept", | ||||||
|  | 			"src": [ | ||||||
|  | 				"group:example", | ||||||
|  | 			], | ||||||
|  | 			"dst": [ | ||||||
|  | 				"host-1:*", | ||||||
|  | 			], | ||||||
|  | 		}, | ||||||
|  | 	], | ||||||
|  | } | ||||||
|  | 	`) | ||||||
|  | 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	rules, _, err := policy.GenerateFilterRules(pol, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(rules, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].DstPorts, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) | ||||||
|  | 	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) | ||||||
|  | 	c.Assert(rules[0].SrcIPs, check.HasLen, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") | ||||||
|  | 	c.Assert(len(ips), check.Equals, 1) | ||||||
|  | 	c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") | ||||||
|  | } | ||||||
| @ -3,21 +3,22 @@ | |||||||
| // Use of this source code is governed by a BSD-style | // Use of this source code is governed by a BSD-style | ||||||
| // license that can be found in the LICENSE file. | // license that can be found in the LICENSE file. | ||||||
| 
 | 
 | ||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"go4.org/netipx" | 	"go4.org/netipx" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") | var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { | func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) { | ||||||
| 	var ips MachineAddresses | 	var ips types.MachineAddresses | ||||||
| 	var err error | 	var err error | ||||||
| 	for _, ipPrefix := range hsdb.ipPrefixes { | 	for _, ipPrefix := range hsdb.ipPrefixes { | ||||||
| 		var ip *netip.Addr | 		var ip *netip.Addr | ||||||
| @ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { | |||||||
| 	// but this was quick to get running and it should be enough | 	// but this was quick to get running and it should be enough | ||||||
| 	// to begin experimenting with a dual stack tailnet. | 	// to begin experimenting with a dual stack tailnet. | ||||||
| 	var addressesSlices []string | 	var addressesSlices []string | ||||||
| 	hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) | 	hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices) | ||||||
| 
 | 
 | ||||||
| 	var ips netipx.IPSetBuilder | 	var ips netipx.IPSetBuilder | ||||||
| 	for _, slice := range addressesSlices { | 	for _, slice := range addressesSlices { | ||||||
| 		var machineAddresses MachineAddresses | 		var machineAddresses types.MachineAddresses | ||||||
| 		err := machineAddresses.Scan(slice) | 		err := machineAddresses.Scan(slice) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return &netipx.IPSet{}, fmt.Errorf( | 			return &netipx.IPSet{}, fmt.Errorf( | ||||||
| @ -1,14 +1,16 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"go4.org/netipx" | 	"go4.org/netipx" | ||||||
| 	"gopkg.in/check.v1" | 	"gopkg.in/check.v1" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestGetAvailableIp(c *check.C) { | func (s *Suite) TestGetAvailableIp(c *check.C) { | ||||||
| 	ips, err := app.db.getAvailableIPs() | 	ips, err := db.getAvailableIPs() | ||||||
| 
 | 
 | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| @ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestGetUsedIps(c *check.C) { | func (s *Suite) TestGetUsedIps(c *check.C) { | ||||||
| 	ips, err := app.db.getAvailableIPs() | 	ips, err := db.getAvailableIPs() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	user, err := app.db.CreateUser("test-ip") | 	user, err := db.CreateUser("test-ip") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "testmachine") | 	_, err = db.GetMachine("test", "testmachine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "testmachine", | 		Hostname:       "testmachine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		IPAddresses:    ips, | 		IPAddresses:    ips, | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	usedIps, err := app.db.getUsedIPs() | 	usedIps, err := db.getUsedIPs() | ||||||
| 
 | 
 | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| @ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) { | |||||||
| 	c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) | 	c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) | ||||||
| 	c.Assert(usedIps.Contains(expected), check.Equals, true) | 	c.Assert(usedIps.Contains(expected), check.Equals, true) | ||||||
| 
 | 
 | ||||||
| 	machine1, err := app.db.GetMachineByID(0) | 	machine1, err := db.GetMachineByID(0) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(len(machine1.IPAddresses), check.Equals, 1) | 	c.Assert(len(machine1.IPAddresses), check.Equals, 1) | ||||||
| 	c.Assert(machine1.IPAddresses[0], check.Equals, expected) | 	c.Assert(machine1.IPAddresses[0], check.Equals, expected) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestGetMultiIp(c *check.C) { | func (s *Suite) TestGetMultiIp(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test-ip-multi") | 	user, err := db.CreateUser("test-ip-multi") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	for index := 1; index <= 350; index++ { | 	for index := 1; index <= 350; index++ { | ||||||
| 		app.db.ipAllocationMutex.Lock() | 		db.ipAllocationMutex.Lock() | ||||||
| 
 | 
 | ||||||
| 		ips, err := app.db.getAvailableIPs() | 		ips, err := db.getAvailableIPs() | ||||||
| 		c.Assert(err, check.IsNil) | 		c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 		pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 		pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 		c.Assert(err, check.IsNil) | 		c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 		_, err = app.db.GetMachine("test", "testmachine") | 		_, err = db.GetMachine("test", "testmachine") | ||||||
| 		c.Assert(err, check.NotNil) | 		c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 		machine := Machine{ | 		machine := types.Machine{ | ||||||
| 			ID:             uint64(index), | 			ID:             uint64(index), | ||||||
| 			MachineKey:     "foo", | 			MachineKey:     "foo", | ||||||
| 			NodeKey:        "bar", | 			NodeKey:        "bar", | ||||||
| 			DiscoKey:       "faa", | 			DiscoKey:       "faa", | ||||||
| 			Hostname:       "testmachine", | 			Hostname:       "testmachine", | ||||||
| 			UserID:         user.ID, | 			UserID:         user.ID, | ||||||
| 			RegisterMethod: RegisterMethodAuthKey, | 			RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 			AuthKeyID:      uint(pak.ID), | 			AuthKeyID:      uint(pak.ID), | ||||||
| 			IPAddresses:    ips, | 			IPAddresses:    ips, | ||||||
| 		} | 		} | ||||||
| 		app.db.db.Save(&machine) | 		db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 		app.db.ipAllocationMutex.Unlock() | 		db.ipAllocationMutex.Unlock() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	usedIps, err := app.db.getUsedIPs() | 	usedIps, err := db.getUsedIPs() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	expected0 := netip.MustParseAddr("10.27.0.1") | 	expected0 := netip.MustParseAddr("10.27.0.1") | ||||||
| @ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | |||||||
| 	c.Assert(usedIps.Contains(expected300), check.Equals, true) | 	c.Assert(usedIps.Contains(expected300), check.Equals, true) | ||||||
| 
 | 
 | ||||||
| 	// Check that we can read back the IPs | 	// Check that we can read back the IPs | ||||||
| 	machine1, err := app.db.GetMachineByID(1) | 	machine1, err := db.GetMachineByID(1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(machine1.IPAddresses), check.Equals, 1) | 	c.Assert(len(machine1.IPAddresses), check.Equals, 1) | ||||||
| 	c.Assert( | 	c.Assert( | ||||||
| @ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | |||||||
| 		netip.MustParseAddr("10.27.0.1"), | 		netip.MustParseAddr("10.27.0.1"), | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	machine50, err := app.db.GetMachineByID(50) | 	machine50, err := db.GetMachineByID(50) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(machine50.IPAddresses), check.Equals, 1) | 	c.Assert(len(machine50.IPAddresses), check.Equals, 1) | ||||||
| 	c.Assert( | 	c.Assert( | ||||||
| @ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | |||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	expectedNextIP := netip.MustParseAddr("10.27.1.95") | 	expectedNextIP := netip.MustParseAddr("10.27.1.95") | ||||||
| 	nextIP, err := app.db.getAvailableIPs() | 	nextIP, err := db.getAvailableIPs() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(len(nextIP), check.Equals, 1) | 	c.Assert(len(nextIP), check.Equals, 1) | ||||||
| @ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) { | |||||||
| 
 | 
 | ||||||
| 	// If we call get Available again, we should receive | 	// If we call get Available again, we should receive | ||||||
| 	// the same IP, as it has not been reserved. | 	// the same IP, as it has not been reserved. | ||||||
| 	nextIP2, err := app.db.getAvailableIPs() | 	nextIP2, err := db.getAvailableIPs() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(len(nextIP2), check.Equals, 1) | 	c.Assert(len(nextIP2), check.Equals, 1) | ||||||
| 	c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) | 	c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { | func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { | ||||||
| 	ips, err := app.db.getAvailableIPs() | 	ips, err := db.getAvailableIPs() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	expected := netip.MustParseAddr("10.27.0.1") | 	expected := netip.MustParseAddr("10.27.0.1") | ||||||
| @ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { | |||||||
| 	c.Assert(len(ips), check.Equals, 1) | 	c.Assert(len(ips), check.Equals, 1) | ||||||
| 	c.Assert(ips[0].String(), check.Equals, expected.String()) | 	c.Assert(ips[0].String(), check.Equals, expected.String()) | ||||||
| 
 | 
 | ||||||
| 	user, err := app.db.CreateUser("test-ip") | 	user, err := db.CreateUser("test-ip") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "testmachine") | 	_, err = db.GetMachine("test", "testmachine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "testmachine", | 		Hostname:       "testmachine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	ips2, err := app.db.getAvailableIPs() | 	ips2, err := db.getAvailableIPs() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(len(ips2), check.Equals, 1) | 	c.Assert(len(ips2), check.Equals, 1) | ||||||
| 	c.Assert(ips2[0].String(), check.Equals, expected.String()) | 	c.Assert(ips2[0].String(), check.Equals, expected.String()) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| @ -1,4 +1,4 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| @ -6,10 +6,9 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| 	"google.golang.org/protobuf/types/known/timestamppb" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -19,22 +18,10 @@ const ( | |||||||
| 
 | 
 | ||||||
| var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") | var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") | ||||||
| 
 | 
 | ||||||
| // APIKey describes the datamodel for API keys used to remotely authenticate with |  | ||||||
| // headscale. |  | ||||||
| type APIKey struct { |  | ||||||
| 	ID     uint64 `gorm:"primary_key"` |  | ||||||
| 	Prefix string `gorm:"uniqueIndex"` |  | ||||||
| 	Hash   []byte |  | ||||||
| 
 |  | ||||||
| 	CreatedAt  *time.Time |  | ||||||
| 	Expiration *time.Time |  | ||||||
| 	LastSeen   *time.Time |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // CreateAPIKey creates a new ApiKey in a user, and returns it. | // CreateAPIKey creates a new ApiKey in a user, and returns it. | ||||||
| func (hsdb *HSDatabase) CreateAPIKey( | func (hsdb *HSDatabase) CreateAPIKey( | ||||||
| 	expiration *time.Time, | 	expiration *time.Time, | ||||||
| ) (string, *APIKey, error) { | ) (string, *types.APIKey, error) { | ||||||
| 	prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) | 	prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", nil, err | 		return "", nil, err | ||||||
| @ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey( | |||||||
| 		return "", nil, err | 		return "", nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	key := APIKey{ | 	key := types.APIKey{ | ||||||
| 		Prefix:     prefix, | 		Prefix:     prefix, | ||||||
| 		Hash:       hash, | 		Hash:       hash, | ||||||
| 		Expiration: expiration, | 		Expiration: expiration, | ||||||
| @ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey( | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListAPIKeys returns the list of ApiKeys for a user. | // ListAPIKeys returns the list of ApiKeys for a user. | ||||||
| func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { | func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | ||||||
| 	keys := []APIKey{} | 	keys := []types.APIKey{} | ||||||
| 	if err := hsdb.db.Find(&keys).Error; err != nil { | 	if err := hsdb.db.Find(&keys).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetAPIKey returns a ApiKey for a given key. | // GetAPIKey returns a ApiKey for a given key. | ||||||
| func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { | func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | ||||||
| 	key := APIKey{} | 	key := types.APIKey{} | ||||||
| 	if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { | 	if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { | ||||||
| 		return nil, result.Error | 		return nil, result.Error | ||||||
| 	} | 	} | ||||||
| @ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetAPIKeyByID returns a ApiKey for a given id. | // GetAPIKeyByID returns a ApiKey for a given id. | ||||||
| func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { | func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | ||||||
| 	key := APIKey{} | 	key := types.APIKey{} | ||||||
| 	if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { | 	if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { | ||||||
| 		return nil, result.Error | 		return nil, result.Error | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { | |||||||
| 
 | 
 | ||||||
| // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey | // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey | ||||||
| // does not exist. | // does not exist. | ||||||
| func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { | func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | ||||||
| 	if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { | 	if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { | ||||||
| 		return result.Error | 		return result.Error | ||||||
| 	} | 	} | ||||||
| @ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ExpireAPIKey marks a ApiKey as expired. | // ExpireAPIKey marks a ApiKey as expired. | ||||||
| func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { | func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | ||||||
| 	if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { | 	if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { | |||||||
| 
 | 
 | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (key *APIKey) toProto() *v1.ApiKey { |  | ||||||
| 	protoKey := v1.ApiKey{ |  | ||||||
| 		Id:     key.ID, |  | ||||||
| 		Prefix: key.Prefix, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if key.Expiration != nil { |  | ||||||
| 		protoKey.Expiration = timestamppb.New(*key.Expiration) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if key.CreatedAt != nil { |  | ||||||
| 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if key.LastSeen != nil { |  | ||||||
| 		protoKey.LastSeen = timestamppb.New(*key.LastSeen) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &protoKey |  | ||||||
| } |  | ||||||
| @ -1,4 +1,4 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"time" | 	"time" | ||||||
| @ -7,7 +7,7 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestCreateAPIKey(c *check.C) { | func (*Suite) TestCreateAPIKey(c *check.C) { | ||||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) | 	apiKeyStr, apiKey, err := db.CreateAPIKey(nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(apiKey, check.NotNil) | 	c.Assert(apiKey, check.NotNil) | ||||||
| 
 | 
 | ||||||
| @ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) { | |||||||
| 	c.Assert(apiKey.Hash, check.NotNil) | 	c.Assert(apiKey.Hash, check.NotNil) | ||||||
| 	c.Assert(apiKeyStr, check.Not(check.Equals), "") | 	c.Assert(apiKeyStr, check.Not(check.Equals), "") | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.ListAPIKeys() | 	_, err = db.ListAPIKeys() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	keys, err := app.db.ListAPIKeys() | 	keys, err := db.ListAPIKeys() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(keys), check.Equals, 1) | 	c.Assert(len(keys), check.Equals, 1) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { | func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { | ||||||
| 	key, err := app.db.GetAPIKey("does-not-exist") | 	key, err := db.GetAPIKey("does-not-exist") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 	c.Assert(key, check.IsNil) | 	c.Assert(key, check.IsNil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestValidateAPIKeyOk(c *check.C) { | func (*Suite) TestValidateAPIKeyOk(c *check.C) { | ||||||
| 	nowPlus2 := time.Now().Add(2 * time.Hour) | 	nowPlus2 := time.Now().Add(2 * time.Hour) | ||||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) | 	apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(apiKey, check.NotNil) | 	c.Assert(apiKey, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	valid, err := app.db.ValidateAPIKey(apiKeyStr) | 	valid, err := db.ValidateAPIKey(apiKeyStr) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(valid, check.Equals, true) | 	c.Assert(valid, check.Equals, true) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { | func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { | ||||||
| 	nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) | 	nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) | ||||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) | 	apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(apiKey, check.NotNil) | 	c.Assert(apiKey, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	valid, err := app.db.ValidateAPIKey(apiKeyStr) | 	valid, err := db.ValidateAPIKey(apiKeyStr) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(valid, check.Equals, false) | 	c.Assert(valid, check.Equals, false) | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) | 	apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(apiKey, check.NotNil) | 	c.Assert(apiKey, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) | 	validNow, err := db.ValidateAPIKey(apiKeyStrNow) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(validNow, check.Equals, false) | 	c.Assert(validNow, check.Equals, false) | ||||||
| 
 | 
 | ||||||
| 	validSilly, err := app.db.ValidateAPIKey("nota.validkey") | 	validSilly, err := db.ValidateAPIKey("nota.validkey") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 	c.Assert(validSilly, check.Equals, false) | 	c.Assert(validSilly, check.Equals, false) | ||||||
| 
 | 
 | ||||||
| 	validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") | 	validWithErr, err := db.ValidateAPIKey("produceerrorkey") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 	c.Assert(validWithErr, check.Equals, false) | 	c.Assert(validWithErr, check.Equals, false) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestExpireAPIKey(c *check.C) { | func (*Suite) TestExpireAPIKey(c *check.C) { | ||||||
| 	nowPlus2 := time.Now().Add(2 * time.Hour) | 	nowPlus2 := time.Now().Add(2 * time.Hour) | ||||||
| 	apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) | 	apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(apiKey, check.NotNil) | 	c.Assert(apiKey, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	valid, err := app.db.ValidateAPIKey(apiKeyStr) | 	valid, err := db.ValidateAPIKey(apiKeyStr) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(valid, check.Equals, true) | 	c.Assert(valid, check.Equals, true) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.ExpireAPIKey(apiKey) | 	err = db.ExpireAPIKey(apiKey) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(apiKey.Expiration, check.NotNil) | 	c.Assert(apiKey.Expiration, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	notValid, err := app.db.ValidateAPIKey(apiKeyStr) | 	notValid, err := db.ValidateAPIKey(apiKeyStr) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(notValid, check.Equals, false) | 	c.Assert(notValid, check.Equals, false) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| @ -1,9 +1,7 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql/driver" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| @ -11,11 +9,12 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/glebarez/sqlite" | 	"github.com/glebarez/sqlite" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"gorm.io/driver/postgres" | 	"gorm.io/driver/postgres" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/logger" | 	"gorm.io/gorm/logger" | ||||||
| 	"tailscale.com/tailcfg" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -26,7 +25,6 @@ const ( | |||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	errValueNotFound        = errors.New("not found") | 	errValueNotFound        = errors.New("not found") | ||||||
| 	ErrCannotParsePrefix    = errors.New("cannot parse prefix") |  | ||||||
| 	errDatabaseNotSupported = errors.New("database type not supported") | 	errDatabaseNotSupported = errors.New("database type not supported") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -38,8 +36,9 @@ type KV struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type HSDatabase struct { | type HSDatabase struct { | ||||||
| 	db              *gorm.DB | 	db               *gorm.DB | ||||||
| 	notifyStateChan chan<- struct{} | 	notifyStateChan  chan<- struct{} | ||||||
|  | 	notifyPolicyChan chan<- struct{} | ||||||
| 
 | 
 | ||||||
| 	ipAllocationMutex sync.Mutex | 	ipAllocationMutex sync.Mutex | ||||||
| 
 | 
 | ||||||
| @ -54,6 +53,7 @@ func NewHeadscaleDatabase( | |||||||
| 	dbType, connectionAddr string, | 	dbType, connectionAddr string, | ||||||
| 	stripEmailDomain, debug bool, | 	stripEmailDomain, debug bool, | ||||||
| 	notifyStateChan chan<- struct{}, | 	notifyStateChan chan<- struct{}, | ||||||
|  | 	notifyPolicyChan chan<- struct{}, | ||||||
| 	ipPrefixes []netip.Prefix, | 	ipPrefixes []netip.Prefix, | ||||||
| 	baseDomain string, | 	baseDomain string, | ||||||
| ) (*HSDatabase, error) { | ) (*HSDatabase, error) { | ||||||
| @ -63,8 +63,9 @@ func NewHeadscaleDatabase( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	db := HSDatabase{ | 	db := HSDatabase{ | ||||||
| 		db:              dbConn, | 		db:               dbConn, | ||||||
| 		notifyStateChan: notifyStateChan, | 		notifyStateChan:  notifyStateChan, | ||||||
|  | 		notifyPolicyChan: notifyPolicyChan, | ||||||
| 
 | 
 | ||||||
| 		ipPrefixes:       ipPrefixes, | 		ipPrefixes:       ipPrefixes, | ||||||
| 		baseDomain:       baseDomain, | 		baseDomain:       baseDomain, | ||||||
| @ -79,30 +80,30 @@ func NewHeadscaleDatabase( | |||||||
| 
 | 
 | ||||||
| 	_ = dbConn.Migrator().RenameTable("namespaces", "users") | 	_ = dbConn.Migrator().RenameTable("namespaces", "users") | ||||||
| 
 | 
 | ||||||
| 	err = dbConn.AutoMigrate(User{}) | 	err = dbConn.AutoMigrate(types.User{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") | 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id") | ||||||
| 	_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") | 	_ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") | ||||||
| 
 | 
 | ||||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") | 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses") | ||||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") | 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname") | ||||||
| 
 | 
 | ||||||
| 	// GivenName is used as the primary source of DNS names, make sure | 	// GivenName is used as the primary source of DNS names, make sure | ||||||
| 	// the field is populated and normalized if it was not when the | 	// the field is populated and normalized if it was not when the | ||||||
| 	// machine was registered. | 	// machine was registered. | ||||||
| 	_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") | 	_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name") | ||||||
| 
 | 
 | ||||||
| 	// If the Machine table has a column for registered, | 	// If the Machine table has a column for registered, | ||||||
| 	// find all occourences of "false" and drop them. Then | 	// find all occourences of "false" and drop them. Then | ||||||
| 	// remove the column. | 	// remove the column. | ||||||
| 	if dbConn.Migrator().HasColumn(&Machine{}, "registered") { | 	if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") { | ||||||
| 		log.Info(). | 		log.Info(). | ||||||
| 			Msg(`Database has legacy "registered" column in machine, removing...`) | 			Msg(`Database has legacy "registered" column in machine, removing...`) | ||||||
| 
 | 
 | ||||||
| 		machines := Machines{} | 		machines := types.Machines{} | ||||||
| 		if err := dbConn.Not("registered").Find(&machines).Error; err != nil { | 		if err := dbConn.Not("registered").Find(&machines).Error; err != nil { | ||||||
| 			log.Error().Err(err).Msg("Error accessing db") | 			log.Error().Err(err).Msg("Error accessing db") | ||||||
| 		} | 		} | ||||||
| @ -112,7 +113,7 @@ func NewHeadscaleDatabase( | |||||||
| 				Str("machine", machine.Hostname). | 				Str("machine", machine.Hostname). | ||||||
| 				Str("machine_key", machine.MachineKey). | 				Str("machine_key", machine.MachineKey). | ||||||
| 				Msg("Deleting unregistered machine") | 				Msg("Deleting unregistered machine") | ||||||
| 			if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { | 			if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil { | ||||||
| 				log.Error(). | 				log.Error(). | ||||||
| 					Err(err). | 					Err(err). | ||||||
| 					Str("machine", machine.Hostname). | 					Str("machine", machine.Hostname). | ||||||
| @ -121,23 +122,23 @@ func NewHeadscaleDatabase( | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		err := dbConn.Migrator().DropColumn(&Machine{}, "registered") | 		err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Err(err).Msg("Error dropping registered column") | 			log.Error().Err(err).Msg("Error dropping registered column") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = dbConn.AutoMigrate(&Route{}) | 	err = dbConn.AutoMigrate(&types.Route{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { | 	if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") { | ||||||
| 		log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") | 		log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") | ||||||
| 
 | 
 | ||||||
| 		type MachineAux struct { | 		type MachineAux struct { | ||||||
| 			ID            uint64 | 			ID            uint64 | ||||||
| 			EnabledRoutes IPPrefixes | 			EnabledRoutes types.IPPrefixes | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		machinesAux := []MachineAux{} | 		machinesAux := []MachineAux{} | ||||||
| @ -157,8 +158,8 @@ func NewHeadscaleDatabase( | |||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				err = dbConn.Preload("Machine"). | 				err = dbConn.Preload("Machine"). | ||||||
| 					Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). | 					Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)). | ||||||
| 					First(&Route{}). | 					First(&types.Route{}). | ||||||
| 					Error | 					Error | ||||||
| 				if err == nil { | 				if err == nil { | ||||||
| 					log.Info(). | 					log.Info(). | ||||||
| @ -168,11 +169,11 @@ func NewHeadscaleDatabase( | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				route := Route{ | 				route := types.Route{ | ||||||
| 					MachineID:  machine.ID, | 					MachineID:  machine.ID, | ||||||
| 					Advertised: true, | 					Advertised: true, | ||||||
| 					Enabled:    true, | 					Enabled:    true, | ||||||
| 					Prefix:     IPPrefix(prefix), | 					Prefix:     types.IPPrefix(prefix), | ||||||
| 				} | 				} | ||||||
| 				if err := dbConn.Create(&route).Error; err != nil { | 				if err := dbConn.Create(&route).Error; err != nil { | ||||||
| 					log.Error().Err(err).Msg("Error creating route") | 					log.Error().Err(err).Msg("Error creating route") | ||||||
| @ -185,26 +186,26 @@ func NewHeadscaleDatabase( | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") | 		err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Err(err).Msg("Error dropping enabled_routes column") | 			log.Error().Err(err).Msg("Error dropping enabled_routes column") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = dbConn.AutoMigrate(&Machine{}) | 	err = dbConn.AutoMigrate(&types.Machine{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { | 	if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") { | ||||||
| 		machines := Machines{} | 		machines := types.Machines{} | ||||||
| 		if err := dbConn.Find(&machines).Error; err != nil { | 		if err := dbConn.Find(&machines).Error; err != nil { | ||||||
| 			log.Error().Err(err).Msg("Error accessing db") | 			log.Error().Err(err).Msg("Error accessing db") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for item, machine := range machines { | 		for item, machine := range machines { | ||||||
| 			if machine.GivenName == "" { | 			if machine.GivenName == "" { | ||||||
| 				normalizedHostname, err := NormalizeToFQDNRules( | 				normalizedHostname, err := util.NormalizeToFQDNRules( | ||||||
| 					machine.Hostname, | 					machine.Hostname, | ||||||
| 					stripEmailDomain, | 					stripEmailDomain, | ||||||
| 				) | 				) | ||||||
| @ -233,19 +234,19 @@ func NewHeadscaleDatabase( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = dbConn.AutoMigrate(&PreAuthKey{}) | 	err = dbConn.AutoMigrate(&types.PreAuthKey{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) | 	err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	_ = dbConn.Migrator().DropTable("shared_machines") | 	_ = dbConn.Migrator().DropTable("shared_machines") | ||||||
| 
 | 
 | ||||||
| 	err = dbConn.AutoMigrate(&APIKey{}) | 	err = dbConn.AutoMigrate(&types.APIKey{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) pingDB(ctx context.Context) error { | func (hsdb *HSDatabase) PingDB(ctx context.Context) error { | ||||||
| 	ctx, cancel := context.WithTimeout(ctx, time.Second) | 	ctx, cancel := context.WithTimeout(ctx, time.Second) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
| 	sqlDB, err := hsdb.db.DB() | 	sqlDB, err := hsdb.db.DB() | ||||||
| @ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error { | |||||||
| 	return sqlDB.PingContext(ctx) | 	return sqlDB.PingContext(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // This is a "wrapper" type around tailscales | func (hsdb *HSDatabase) Close() error { | ||||||
| // Hostinfo to allow us to add database "serialization" | 	db, err := hsdb.db.DB() | ||||||
| // methods. This allows us to use a typed values throughout | 	if err != nil { | ||||||
| // the code and not have to marshal/unmarshal and error | 		return err | ||||||
| // check all over the code. |  | ||||||
| type HostInfo tailcfg.Hostinfo |  | ||||||
| 
 |  | ||||||
| func (hi *HostInfo) Scan(destination interface{}) error { |  | ||||||
| 	switch value := destination.(type) { |  | ||||||
| 	case []byte: |  | ||||||
| 		return json.Unmarshal(value, hi) |  | ||||||
| 
 |  | ||||||
| 	case string: |  | ||||||
| 		return json.Unmarshal([]byte(value), hi) |  | ||||||
| 
 |  | ||||||
| 	default: |  | ||||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) |  | ||||||
| 	} | 	} | ||||||
| } | 
 | ||||||
| 
 | 	return db.Close() | ||||||
| // Value return json value, implement driver.Valuer interface. |  | ||||||
| func (hi HostInfo) Value() (driver.Value, error) { |  | ||||||
| 	bytes, err := json.Marshal(hi) |  | ||||||
| 
 |  | ||||||
| 	return string(bytes), err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type IPPrefix netip.Prefix |  | ||||||
| 
 |  | ||||||
| func (i *IPPrefix) Scan(destination interface{}) error { |  | ||||||
| 	switch value := destination.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		prefix, err := netip.ParsePrefix(value) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		*i = IPPrefix(prefix) |  | ||||||
| 
 |  | ||||||
| 		return nil |  | ||||||
| 	default: |  | ||||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Value return json value, implement driver.Valuer interface. |  | ||||||
| func (i IPPrefix) Value() (driver.Value, error) { |  | ||||||
| 	prefixStr := netip.Prefix(i).String() |  | ||||||
| 
 |  | ||||||
| 	return prefixStr, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type IPPrefixes []netip.Prefix |  | ||||||
| 
 |  | ||||||
| func (i *IPPrefixes) Scan(destination interface{}) error { |  | ||||||
| 	switch value := destination.(type) { |  | ||||||
| 	case []byte: |  | ||||||
| 		return json.Unmarshal(value, i) |  | ||||||
| 
 |  | ||||||
| 	case string: |  | ||||||
| 		return json.Unmarshal([]byte(value), i) |  | ||||||
| 
 |  | ||||||
| 	default: |  | ||||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Value return json value, implement driver.Valuer interface. |  | ||||||
| func (i IPPrefixes) Value() (driver.Value, error) { |  | ||||||
| 	bytes, err := json.Marshal(i) |  | ||||||
| 
 |  | ||||||
| 	return string(bytes), err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type StringList []string |  | ||||||
| 
 |  | ||||||
| func (i *StringList) Scan(destination interface{}) error { |  | ||||||
| 	switch value := destination.(type) { |  | ||||||
| 	case []byte: |  | ||||||
| 		return json.Unmarshal(value, i) |  | ||||||
| 
 |  | ||||||
| 	case string: |  | ||||||
| 		return json.Unmarshal([]byte(value), i) |  | ||||||
| 
 |  | ||||||
| 	default: |  | ||||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Value return json value, implement driver.Valuer interface. |  | ||||||
| func (i StringList) Value() (driver.Value, error) { |  | ||||||
| 	bytes, err := json.Marshal(i) |  | ||||||
| 
 |  | ||||||
| 	return string(bytes), err |  | ||||||
| } | } | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										797
									
								
								hscontrol/db/machine_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										797
									
								
								hscontrol/db/machine_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,797 @@ | |||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strconv" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
|  | 	"gopkg.in/check.v1" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | 	"tailscale.com/types/key" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGetMachine(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machine := &types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machine) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGetMachineByID(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGetMachineByNodeKey(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	nodeKey := key.NewNode() | ||||||
|  | 	machineKey := key.NewMachine() | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | ||||||
|  | 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByNodeKey(nodeKey.Public()) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	nodeKey := key.NewNode() | ||||||
|  | 	oldNodeKey := key.NewNode() | ||||||
|  | 
 | ||||||
|  | 	machineKey := key.NewMachine() | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | ||||||
|  | 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestDeleteMachine(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(1), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	err = db.DeleteMachine(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine(user.Name, "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestHardDeleteMachine(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine3", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(1), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	err = db.HardDeleteMachine(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine(user.Name, "testmachine3") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestListPeers(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	for index := 0; index <= 10; index++ { | ||||||
|  | 		machine := types.Machine{ | ||||||
|  | 			ID:             uint64(index), | ||||||
|  | 			MachineKey:     "foo" + strconv.Itoa(index), | ||||||
|  | 			NodeKey:        "bar" + strconv.Itoa(index), | ||||||
|  | 			DiscoKey:       "faa" + strconv.Itoa(index), | ||||||
|  | 			Hostname:       "testmachine" + strconv.Itoa(index), | ||||||
|  | 			UserID:         user.ID, | ||||||
|  | 			RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 			AuthKeyID:      uint(pak.ID), | ||||||
|  | 		} | ||||||
|  | 		db.db.Save(&machine) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	machine0ByID, err := db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	peersOfMachine0, err := db.ListPeers(machine0ByID) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(len(peersOfMachine0), check.Equals, 9) | ||||||
|  | 	c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") | ||||||
|  | 	c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") | ||||||
|  | 	c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGetACLFilteredPeers(c *check.C) { | ||||||
|  | 	type base struct { | ||||||
|  | 		user *types.User | ||||||
|  | 		key  *types.PreAuthKey | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	stor := make([]base, 0) | ||||||
|  | 
 | ||||||
|  | 	for _, name := range []string{"test", "admin"} { | ||||||
|  | 		user, err := db.CreateUser(name) | ||||||
|  | 		c.Assert(err, check.IsNil) | ||||||
|  | 		pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 		c.Assert(err, check.IsNil) | ||||||
|  | 		stor = append(stor, base{user, pak}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err := db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	for index := 0; index <= 10; index++ { | ||||||
|  | 		machine := types.Machine{ | ||||||
|  | 			ID:         uint64(index), | ||||||
|  | 			MachineKey: "foo" + strconv.Itoa(index), | ||||||
|  | 			NodeKey:    "bar" + strconv.Itoa(index), | ||||||
|  | 			DiscoKey:   "faa" + strconv.Itoa(index), | ||||||
|  | 			IPAddresses: types.MachineAddresses{ | ||||||
|  | 				netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), | ||||||
|  | 			}, | ||||||
|  | 			Hostname:       "testmachine" + strconv.Itoa(index), | ||||||
|  | 			UserID:         stor[index%2].user.ID, | ||||||
|  | 			RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 			AuthKeyID:      uint(stor[index%2].key.ID), | ||||||
|  | 		} | ||||||
|  | 		db.db.Save(&machine) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	aclPolicy := &policy.ACLPolicy{ | ||||||
|  | 		Groups: map[string][]string{ | ||||||
|  | 			"group:test": {"admin"}, | ||||||
|  | 		}, | ||||||
|  | 		Hosts:     map[string]netip.Prefix{}, | ||||||
|  | 		TagOwners: map[string][]string{}, | ||||||
|  | 		ACLs: []policy.ACL{ | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"admin"}, | ||||||
|  | 				Destinations: []string{"*:*"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				Action:       "accept", | ||||||
|  | 				Sources:      []string{"test"}, | ||||||
|  | 				Destinations: []string{"test:*"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		Tests: []policy.ACLTest{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	adminMachine, err := db.GetMachineByID(1) | ||||||
|  | 	c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	testMachine, err := db.GetMachineByID(2) | ||||||
|  | 	c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machines, err := db.ListMachines() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines) | ||||||
|  | 	peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines) | ||||||
|  | 
 | ||||||
|  | 	c.Log(peersOfTestMachine) | ||||||
|  | 	c.Assert(len(peersOfTestMachine), check.Equals, 9) | ||||||
|  | 	c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1") | ||||||
|  | 	c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3") | ||||||
|  | 	c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5") | ||||||
|  | 
 | ||||||
|  | 	c.Log(peersOfAdminMachine) | ||||||
|  | 	c.Assert(len(peersOfAdminMachine), check.Equals, 9) | ||||||
|  | 	c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") | ||||||
|  | 	c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") | ||||||
|  | 	c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestExpireMachine(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machine := &types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		Expiry:         &time.Time{}, | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machine) | ||||||
|  | 
 | ||||||
|  | 	machineFromDB, err := db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(machineFromDB, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(machineFromDB.IsExpired(), check.Equals, false) | ||||||
|  | 
 | ||||||
|  | 	err = db.ExpireMachine(machineFromDB) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(machineFromDB.IsExpired(), check.Equals, true) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(1)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { | ||||||
|  | 	input := types.MachineAddresses([]netip.Addr{ | ||||||
|  | 		netip.MustParseAddr("192.0.2.1"), | ||||||
|  | 		netip.MustParseAddr("2001:db8::1"), | ||||||
|  | 	}) | ||||||
|  | 	serialized, err := input.Value() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	if serial, ok := serialized.(string); ok { | ||||||
|  | 		c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var deserialized types.MachineAddresses | ||||||
|  | 	err = deserialized.Scan(serialized) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(len(deserialized), check.Equals, len(input)) | ||||||
|  | 	for i := range deserialized { | ||||||
|  | 		c.Assert(deserialized[i], check.Equals, input[i]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGenerateGivenName(c *check.C) { | ||||||
|  | 	user1, err := db.CreateUser("user-1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("user-1", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machine := &types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "machine-key-1", | ||||||
|  | 		NodeKey:        "node-key-1", | ||||||
|  | 		DiscoKey:       "disco-key-1", | ||||||
|  | 		Hostname:       "hostname-1", | ||||||
|  | 		GivenName:      "hostname-1", | ||||||
|  | 		UserID:         user1.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machine) | ||||||
|  | 
 | ||||||
|  | 	givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2") | ||||||
|  | 	comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") | ||||||
|  | 	c.Assert(err, check.IsNil, comment) | ||||||
|  | 	c.Assert(givenName, check.Equals, "hostname-2", comment) | ||||||
|  | 
 | ||||||
|  | 	givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1") | ||||||
|  | 	comment = check.Commentf("Same user, same machine, same hostname, no conflict") | ||||||
|  | 	c.Assert(err, check.IsNil, comment) | ||||||
|  | 	c.Assert(givenName, check.Equals, "hostname-1", comment) | ||||||
|  | 
 | ||||||
|  | 	givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") | ||||||
|  | 	comment = check.Commentf("Same user, unique machines, same hostname, conflict") | ||||||
|  | 	c.Assert(err, check.IsNil, comment) | ||||||
|  | 	c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) | ||||||
|  | 
 | ||||||
|  | 	givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") | ||||||
|  | 	comment = check.Commentf("Unique users, unique machines, same hostname, conflict") | ||||||
|  | 	c.Assert(err, check.IsNil, comment) | ||||||
|  | 	c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestSetTags(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machine := &types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machine) | ||||||
|  | 
 | ||||||
|  | 	// assign simple tags | ||||||
|  | 	sTags := []string{"tag:test", "tag:foo"} | ||||||
|  | 	err = db.SetTags(machine, sTags) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	machine, err = db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags)) | ||||||
|  | 
 | ||||||
|  | 	// assign duplicat tags, expect no errors but no doubles in DB | ||||||
|  | 	eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} | ||||||
|  | 	err = db.SetTags(machine, eTags) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	machine, err = db.GetMachine("test", "testmachine") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert( | ||||||
|  | 		machine.ForcedTags, | ||||||
|  | 		check.DeepEquals, | ||||||
|  | 		types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(4)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestHeadscale_generateGivenName(t *testing.T) { | ||||||
|  | 	type args struct { | ||||||
|  | 		suppliedName string | ||||||
|  | 		randomSuffix bool | ||||||
|  | 	} | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name    string | ||||||
|  | 		db      *HSDatabase | ||||||
|  | 		args    args | ||||||
|  | 		want    *regexp.Regexp | ||||||
|  | 		wantErr bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "simple machine name generation", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "testmachine", | ||||||
|  | 				randomSuffix: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    regexp.MustCompile("^testmachine$"), | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "machine name with 53 chars", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", | ||||||
|  | 				randomSuffix: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "machine name with 63 chars", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||||
|  | 				randomSuffix: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"), | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "machine name with 64 chars", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", | ||||||
|  | 				randomSuffix: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    nil, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "machine name with 73 chars", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", | ||||||
|  | 				randomSuffix: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    nil, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "machine name with random suffix", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "test", | ||||||
|  | 				randomSuffix: true, | ||||||
|  | 			}, | ||||||
|  | 			want:    regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)), | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "machine name with 63 chars with random suffix", | ||||||
|  | 			db: &HSDatabase{ | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||||
|  | 				randomSuffix: true, | ||||||
|  | 			}, | ||||||
|  | 			want:    regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)), | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) | ||||||
|  | 			if (err != nil) != tt.wantErr { | ||||||
|  | 				t.Errorf( | ||||||
|  | 					"Headscale.GenerateGivenName() error = %v, wantErr %v", | ||||||
|  | 					err, | ||||||
|  | 					tt.wantErr, | ||||||
|  | 				) | ||||||
|  | 
 | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if tt.want != nil && !tt.want.MatchString(got) { | ||||||
|  | 				t.Errorf( | ||||||
|  | 					"Headscale.GenerateGivenName() = %v, does not match %v", | ||||||
|  | 					tt.want, | ||||||
|  | 					got, | ||||||
|  | 				) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if len(got) > util.LabelHostnameLength { | ||||||
|  | 				t.Errorf( | ||||||
|  | 					"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", | ||||||
|  | 					got, | ||||||
|  | 					util.LabelHostnameLength, | ||||||
|  | 				) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestAutoApproveRoutes(c *check.C) { | ||||||
|  | 	acl := []byte(` | ||||||
|  | { | ||||||
|  | 	"tagOwners": { | ||||||
|  | 		"tag:exit": ["test"], | ||||||
|  | 	}, | ||||||
|  | 
 | ||||||
|  | 	"groups": { | ||||||
|  | 		"group:test": ["test"] | ||||||
|  | 	}, | ||||||
|  | 
 | ||||||
|  | 	"acls": [ | ||||||
|  | 		{"action": "accept", "users": ["*"], "ports": ["*:*"]}, | ||||||
|  | 	], | ||||||
|  | 
 | ||||||
|  | 	"autoApprovers": { | ||||||
|  | 		"exitNode": ["tag:exit"], | ||||||
|  | 		"routes": { | ||||||
|  | 			"10.10.0.0/16": ["group:test"], | ||||||
|  | 			"10.11.0.0/16": ["test"], | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 	`) | ||||||
|  | 
 | ||||||
|  | 	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(pol, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	nodeKey := key.NewNode() | ||||||
|  | 
 | ||||||
|  | 	defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") | ||||||
|  | 	defaultRouteV6 := netip.MustParsePrefix("::/0") | ||||||
|  | 	route1 := netip.MustParsePrefix("10.10.0.0/16") | ||||||
|  | 	// Check if a subprefix of an autoapproved route is approved | ||||||
|  | 	route2 := netip.MustParsePrefix("10.11.0.0/24") | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "test", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 		HostInfo: types.HostInfo{ | ||||||
|  | 			RequestTags: []string{"tag:exit"}, | ||||||
|  | 			RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, | ||||||
|  | 		}, | ||||||
|  | 		IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	err = db.ProcessMachineRoutes(&machine) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machine0ByID, err := db.GetMachineByID(0) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	err = db.EnableAutoApprovedRoutes(pol, machine0ByID) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	enabledRoutes, err := db.GetEnabledRoutes(machine0ByID) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(enabledRoutes, check.HasLen, 4) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(4)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestMachine_canAccess(t *testing.T) { | ||||||
|  | 	type args struct { | ||||||
|  | 		filter   []tailcfg.FilterRule | ||||||
|  | 		machine2 *types.Machine | ||||||
|  | 	} | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name    string | ||||||
|  | 		machine types.Machine | ||||||
|  | 		args    args | ||||||
|  | 		want    bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "no-rules", | ||||||
|  | 			machine: types.Machine{ | ||||||
|  | 				IPAddresses: types.MachineAddresses{ | ||||||
|  | 					netip.MustParseAddr("10.0.0.1"), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				filter: []tailcfg.FilterRule{}, | ||||||
|  | 				machine2: &types.Machine{ | ||||||
|  | 					IPAddresses: types.MachineAddresses{ | ||||||
|  | 						netip.MustParseAddr("10.0.0.2"), | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "wildcard", | ||||||
|  | 			machine: types.Machine{ | ||||||
|  | 				IPAddresses: types.MachineAddresses{ | ||||||
|  | 					netip.MustParseAddr("10.0.0.1"), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				filter: []tailcfg.FilterRule{ | ||||||
|  | 					{ | ||||||
|  | 						SrcIPs: []string{"*"}, | ||||||
|  | 						DstPorts: []tailcfg.NetPortRange{ | ||||||
|  | 							{ | ||||||
|  | 								IP: "*", | ||||||
|  | 								Ports: tailcfg.PortRange{ | ||||||
|  | 									First: 0, | ||||||
|  | 									Last:  65535, | ||||||
|  | 								}, | ||||||
|  | 							}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				machine2: &types.Machine{ | ||||||
|  | 					IPAddresses: types.MachineAddresses{ | ||||||
|  | 						netip.MustParseAddr("10.0.0.2"), | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "explicit-m1-to-m2", | ||||||
|  | 			machine: types.Machine{ | ||||||
|  | 				IPAddresses: types.MachineAddresses{ | ||||||
|  | 					netip.MustParseAddr("10.0.0.1"), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				filter: []tailcfg.FilterRule{ | ||||||
|  | 					{ | ||||||
|  | 						SrcIPs: []string{"10.0.0.1"}, | ||||||
|  | 						DstPorts: []tailcfg.NetPortRange{ | ||||||
|  | 							{ | ||||||
|  | 								IP: "10.0.0.2", | ||||||
|  | 								Ports: tailcfg.PortRange{ | ||||||
|  | 									First: 0, | ||||||
|  | 									Last:  65535, | ||||||
|  | 								}, | ||||||
|  | 							}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				machine2: &types.Machine{ | ||||||
|  | 					IPAddresses: types.MachineAddresses{ | ||||||
|  | 						netip.MustParseAddr("10.0.0.2"), | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "explicit-m2-to-m1", | ||||||
|  | 			machine: types.Machine{ | ||||||
|  | 				IPAddresses: types.MachineAddresses{ | ||||||
|  | 					netip.MustParseAddr("10.0.0.1"), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			args: args{ | ||||||
|  | 				filter: []tailcfg.FilterRule{ | ||||||
|  | 					{ | ||||||
|  | 						SrcIPs: []string{"10.0.0.2"}, | ||||||
|  | 						DstPorts: []tailcfg.NetPortRange{ | ||||||
|  | 							{ | ||||||
|  | 								IP: "10.0.0.1", | ||||||
|  | 								Ports: tailcfg.PortRange{ | ||||||
|  | 									First: 0, | ||||||
|  | 									Last:  65535, | ||||||
|  | 								}, | ||||||
|  | 							}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				machine2: &types.Machine{ | ||||||
|  | 					IPAddresses: types.MachineAddresses{ | ||||||
|  | 						netip.MustParseAddr("10.0.0.2"), | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want { | ||||||
|  | 				t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -1,17 +1,14 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"encoding/hex" | 	"encoding/hex" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strconv" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" |  | ||||||
| 	"google.golang.org/protobuf/types/known/timestamppb" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -23,28 +20,6 @@ var ( | |||||||
| 	ErrPreAuthKeyACLTagInvalid     = errors.New("AuthKey tag is invalid") | 	ErrPreAuthKeyACLTagInvalid     = errors.New("AuthKey tag is invalid") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // PreAuthKey describes a pre-authorization key usable in a particular user. |  | ||||||
| type PreAuthKey struct { |  | ||||||
| 	ID        uint64 `gorm:"primary_key"` |  | ||||||
| 	Key       string |  | ||||||
| 	UserID    uint |  | ||||||
| 	User      User |  | ||||||
| 	Reusable  bool |  | ||||||
| 	Ephemeral bool `gorm:"default:false"` |  | ||||||
| 	Used      bool `gorm:"default:false"` |  | ||||||
| 	ACLTags   []PreAuthKeyACLTag |  | ||||||
| 
 |  | ||||||
| 	CreatedAt  *time.Time |  | ||||||
| 	Expiration *time.Time |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. |  | ||||||
| type PreAuthKeyACLTag struct { |  | ||||||
| 	ID           uint64 `gorm:"primary_key"` |  | ||||||
| 	PreAuthKeyID uint64 |  | ||||||
| 	Tag          string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. | // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. | ||||||
| func (hsdb *HSDatabase) CreatePreAuthKey( | func (hsdb *HSDatabase) CreatePreAuthKey( | ||||||
| 	userName string, | 	userName string, | ||||||
| @ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | |||||||
| 	ephemeral bool, | 	ephemeral bool, | ||||||
| 	expiration *time.Time, | 	expiration *time.Time, | ||||||
| 	aclTags []string, | 	aclTags []string, | ||||||
| ) (*PreAuthKey, error) { | ) (*types.PreAuthKey, error) { | ||||||
| 	user, err := hsdb.GetUser(userName) | 	user, err := hsdb.GetUser(userName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	key := PreAuthKey{ | 	key := types.PreAuthKey{ | ||||||
| 		Key:        kstr, | 		Key:        kstr, | ||||||
| 		UserID:     user.ID, | 		UserID:     user.ID, | ||||||
| 		User:       *user, | 		User:       *user, | ||||||
| @ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | |||||||
| 
 | 
 | ||||||
| 			for _, tag := range aclTags { | 			for _, tag := range aclTags { | ||||||
| 				if !seenTags[tag] { | 				if !seenTags[tag] { | ||||||
| 					if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { | 					if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { | ||||||
| 						return fmt.Errorf( | 						return fmt.Errorf( | ||||||
| 							"failed to ceate key tag in the database: %w", | 							"failed to ceate key tag in the database: %w", | ||||||
| 							err, | 							err, | ||||||
| @ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListPreAuthKeys returns the list of PreAuthKeys for a user. | // ListPreAuthKeys returns the list of PreAuthKeys for a user. | ||||||
| func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { | func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | ||||||
| 	user, err := hsdb.GetUser(userName) | 	user, err := hsdb.GetUser(userName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	keys := []PreAuthKey{} | 	keys := []types.PreAuthKey{} | ||||||
| 	if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { | 	if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetPreAuthKey returns a PreAuthKey for a given key. | // GetPreAuthKey returns a PreAuthKey for a given key. | ||||||
| func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { | func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { | ||||||
| 	pak, err := hsdb.checkKeyValidity(key) | 	pak, err := hsdb.ValidatePreAuthKey(key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err | |||||||
| 
 | 
 | ||||||
| // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey | // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey | ||||||
| // does not exist. | // does not exist. | ||||||
| func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { | func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | ||||||
| 	return hsdb.db.Transaction(func(db *gorm.DB) error { | 	return hsdb.db.Transaction(func(db *gorm.DB) error { | ||||||
| 		if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { | 		if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { | ||||||
| 			return result.Error | 			return result.Error | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // MarkExpirePreAuthKey marks a PreAuthKey as expired. | // MarkExpirePreAuthKey marks a PreAuthKey as expired. | ||||||
| func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { | func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | ||||||
| 	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { | 	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UsePreAuthKey marks a PreAuthKey as used. | // UsePreAuthKey marks a PreAuthKey as used. | ||||||
| func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { | func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | ||||||
| 	k.Used = true | 	k.Used = true | ||||||
| 	if err := hsdb.db.Save(k).Error; err != nil { | 	if err := hsdb.db.Save(k).Error; err != nil { | ||||||
| 		return fmt.Errorf("failed to update key used status in the database: %w", err) | 		return fmt.Errorf("failed to update key used status in the database: %w", err) | ||||||
| @ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node | // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node | ||||||
| // If returns no error and a PreAuthKey, it can be used. | // If returns no error and a PreAuthKey, it can be used. | ||||||
| func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { | func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { | ||||||
| 	pak := PreAuthKey{} | 	pak := types.PreAuthKey{} | ||||||
| 	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( | 	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( | ||||||
| 		result.Error, | 		result.Error, | ||||||
| 		gorm.ErrRecordNotFound, | 		gorm.ErrRecordNotFound, | ||||||
| @ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { | |||||||
| 		return &pak, nil | 		return &pak, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machines := []Machine{} | 	machines := types.Machines{} | ||||||
| 	if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | 	if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) { | |||||||
| 
 | 
 | ||||||
| 	return hex.EncodeToString(bytes), nil | 	return hex.EncodeToString(bytes), nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (key *PreAuthKey) toProto() *v1.PreAuthKey { |  | ||||||
| 	protoKey := v1.PreAuthKey{ |  | ||||||
| 		User:      key.User.Name, |  | ||||||
| 		Id:        strconv.FormatUint(key.ID, util.Base10), |  | ||||||
| 		Key:       key.Key, |  | ||||||
| 		Ephemeral: key.Ephemeral, |  | ||||||
| 		Reusable:  key.Reusable, |  | ||||||
| 		Used:      key.Used, |  | ||||||
| 		AclTags:   make([]string, len(key.ACLTags)), |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if key.Expiration != nil { |  | ||||||
| 		protoKey.Expiration = timestamppb.New(*key.Expiration) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if key.CreatedAt != nil { |  | ||||||
| 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for idx := range key.ACLTags { |  | ||||||
| 		protoKey.AclTags[idx] = key.ACLTags[idx].Tag |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &protoKey |  | ||||||
| } |  | ||||||
| @ -1,20 +1,22 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"gopkg.in/check.v1" | 	"gopkg.in/check.v1" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestCreatePreAuthKey(c *check.C) { | func (*Suite) TestCreatePreAuthKey(c *check.C) { | ||||||
| 	_, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) | 	_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | 	key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	// Did we get a valid key? | 	// Did we get a valid key? | ||||||
| @ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { | |||||||
| 	// Make sure the User association is populated | 	// Make sure the User association is populated | ||||||
| 	c.Assert(key.User.Name, check.Equals, user.Name) | 	c.Assert(key.User.Name, check.Equals, user.Name) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.ListPreAuthKeys("bogus") | 	_, err = db.ListPreAuthKeys("bogus") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	keys, err := app.db.ListPreAuthKeys(user.Name) | 	keys, err := db.ListPreAuthKeys(user.Name) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(keys), check.Equals, 1) | 	c.Assert(len(keys), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| @ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestExpiredPreAuthKey(c *check.C) { | func (*Suite) TestExpiredPreAuthKey(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test2") | 	user, err := db.CreateUser("test2") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.Equals, ErrPreAuthKeyExpired) | 	c.Assert(err, check.Equals, ErrPreAuthKeyExpired) | ||||||
| 	c.Assert(key, check.IsNil) | 	c.Assert(key, check.IsNil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { | func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { | ||||||
| 	key, err := app.db.checkKeyValidity("potatoKey") | 	key, err := db.ValidatePreAuthKey("potatoKey") | ||||||
| 	c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) | 	c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) | ||||||
| 	c.Assert(key, check.IsNil) | 	c.Assert(key, check.IsNil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestValidateKeyOk(c *check.C) { | func (*Suite) TestValidateKeyOk(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test3") | 	user, err := db.CreateUser("test3") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(key.ID, check.Equals, pak.ID) | 	c.Assert(key.ID, check.Equals, pak.ID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestAlreadyUsedKey(c *check.C) { | func (*Suite) TestAlreadyUsedKey(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test4") | 	user, err := db.CreateUser("test4") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "testest", | 		Hostname:       "testest", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) | 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) | ||||||
| 	c.Assert(key, check.IsNil) | 	c.Assert(key, check.IsNil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestReusableBeingUsedKey(c *check.C) { | func (*Suite) TestReusableBeingUsedKey(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test5") | 	user, err := db.CreateUser("test5") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "testest", | 		Hostname:       "testest", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(key.ID, check.Equals, pak.ID) | 	c.Assert(key.ID, check.Equals, pak.ID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { | func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test6") | 	user, err := db.CreateUser("test6") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(key.ID, check.Equals, pak.ID) | 	c.Assert(key.ID, check.Equals, pak.ID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestEphemeralKey(c *check.C) { | func (*Suite) TestEphemeralKey(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test7") | 	user, err := db.CreateUser("test7") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() | 	now := time.Now().Add(-time.Second * 30) | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "testest", | 		Hostname:       "testest", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		LastSeen:       &now, | 		LastSeen:       &now, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.checkKeyValidity(pak.Key) | 	_, err = db.ValidatePreAuthKey(pak.Key) | ||||||
| 	// Ephemeral keys are by definition reusable | 	// Ephemeral keys are by definition reusable | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test7", "testest") | 	_, err = db.GetMachine("test7", "testest") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	app.expireEphemeralNodesWorker() | 	db.ExpireEphemeralMachines(time.Second * 20) | ||||||
| 
 | 
 | ||||||
| 	// The machine record should have been deleted | 	// The machine record should have been deleted | ||||||
| 	_, err = app.db.GetMachine("test7", "testest") | 	_, err = db.GetMachine("test7", "testest") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(1)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestExpirePreauthKey(c *check.C) { | func (*Suite) TestExpirePreauthKey(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test3") | 	user, err := db.CreateUser("test3") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(pak.Expiration, check.IsNil) | 	c.Assert(pak.Expiration, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.ExpirePreAuthKey(pak) | 	err = db.ExpirePreAuthKey(pak) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(pak.Expiration, check.NotNil) | 	c.Assert(pak.Expiration, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	key, err := app.db.checkKeyValidity(pak.Key) | 	key, err := db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.Equals, ErrPreAuthKeyExpired) | 	c.Assert(err, check.Equals, ErrPreAuthKeyExpired) | ||||||
| 	c.Assert(key, check.IsNil) | 	c.Assert(key, check.IsNil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { | func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test6") | 	user, err := db.CreateUser("test6") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	pak.Used = true | 	pak.Used = true | ||||||
| 	app.db.db.Save(&pak) | 	db.db.Save(&pak) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.checkKeyValidity(pak.Key) | 	_, err = db.ValidatePreAuthKey(pak.Key) | ||||||
| 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) | 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (*Suite) TestPreAuthKeyACLTags(c *check.C) { | func (*Suite) TestPreAuthKeyACLTags(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test8") | 	user, err := db.CreateUser("test8") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) | 	_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) | ||||||
| 	c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected | 	c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected | ||||||
| 
 | 
 | ||||||
| 	tags := []string{"tag:test1", "tag:test2"} | 	tags := []string{"tag:test1", "tag:test2"} | ||||||
| 	tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} | 	tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} | ||||||
| 	_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) | 	_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	listedPaks, err := app.db.ListPreAuthKeys("test8") | 	listedPaks, err := db.ListPreAuthKeys("test8") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) | 	c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags) | ||||||
| } | } | ||||||
| @ -1,55 +1,19 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"google.golang.org/protobuf/types/known/timestamppb" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ErrRouteIsNotAvailable = errors.New("route is not available") | ||||||
| 	ErrRouteIsNotAvailable = errors.New("route is not available") |  | ||||||
| 	ExitRouteV4            = netip.MustParsePrefix("0.0.0.0/0") |  | ||||||
| 	ExitRouteV6            = netip.MustParsePrefix("::/0") |  | ||||||
| ) |  | ||||||
| 
 | 
 | ||||||
| type Route struct { | func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | ||||||
| 	gorm.Model | 	var routes types.Routes | ||||||
| 
 |  | ||||||
| 	MachineID uint64 |  | ||||||
| 	Machine   Machine |  | ||||||
| 	Prefix    IPPrefix |  | ||||||
| 
 |  | ||||||
| 	Advertised bool |  | ||||||
| 	Enabled    bool |  | ||||||
| 	IsPrimary  bool |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Routes []Route |  | ||||||
| 
 |  | ||||||
| func (r *Route) String() string { |  | ||||||
| 	return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (r *Route) isExitRoute() bool { |  | ||||||
| 	return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (rs Routes) toPrefixes() []netip.Prefix { |  | ||||||
| 	prefixes := make([]netip.Prefix, len(rs)) |  | ||||||
| 	for i, r := range rs { |  | ||||||
| 		prefixes[i] = netip.Prefix(r.Prefix) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return prefixes |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { |  | ||||||
| 	var routes []Route |  | ||||||
| 	err := hsdb.db.Preload("Machine").Find(&routes).Error | 	err := hsdb.db.Preload("Machine").Find(&routes).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -58,8 +22,21 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { | |||||||
| 	return routes, nil | 	return routes, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { | func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | ||||||
| 	var routes []Route | 	var routes types.Routes | ||||||
|  | 	err := hsdb.db. | ||||||
|  | 		Preload("Machine"). | ||||||
|  | 		Where("machine_id = ? AND advertised = true", machine.ID). | ||||||
|  | 		Find(&routes).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return routes, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { | ||||||
|  | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| 		Where("machine_id = ?", m.ID). | 		Where("machine_id = ?", m.ID). | ||||||
| @ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { | |||||||
| 	return routes, nil | 	return routes, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { | func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | ||||||
| 	var route Route | 	var route types.Route | ||||||
| 	err := hsdb.db.Preload("Machine").First(&route, id).Error | 	err := hsdb.db.Preload("Machine").First(&route, id).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { | |||||||
| 	// Tailscale requires both IPv4 and IPv6 exit routes to | 	// Tailscale requires both IPv4 and IPv6 exit routes to | ||||||
| 	// be enabled at the same time, as per | 	// be enabled at the same time, as per | ||||||
| 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | ||||||
| 	if route.isExitRoute() { | 	if route.IsExitRoute() { | ||||||
| 		return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) | 		return hsdb.enableRoutes( | ||||||
|  | 			&route.Machine, | ||||||
|  | 			types.ExitRouteV4.String(), | ||||||
|  | 			types.ExitRouteV6.String(), | ||||||
|  | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) | 	return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) | ||||||
| @ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | |||||||
| 	// Tailscale requires both IPv4 and IPv6 exit routes to | 	// Tailscale requires both IPv4 and IPv6 exit routes to | ||||||
| 	// be enabled at the same time, as per | 	// be enabled at the same time, as per | ||||||
| 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | ||||||
| 	if !route.isExitRoute() { | 	if !route.IsExitRoute() { | ||||||
| 		route.Enabled = false | 		route.Enabled = false | ||||||
| 		route.IsPrimary = false | 		route.IsPrimary = false | ||||||
| 		err = hsdb.db.Save(route).Error | 		err = hsdb.db.Save(route).Error | ||||||
| @ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return hsdb.handlePrimarySubnetFailover() | 		return hsdb.HandlePrimarySubnetFailover() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | ||||||
| @ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for i := range routes { | 	for i := range routes { | ||||||
| 		if routes[i].isExitRoute() { | 		if routes[i].IsExitRoute() { | ||||||
| 			routes[i].Enabled = false | 			routes[i].Enabled = false | ||||||
| 			routes[i].IsPrimary = false | 			routes[i].IsPrimary = false | ||||||
| 			err = hsdb.db.Save(&routes[i]).Error | 			err = hsdb.db.Save(&routes[i]).Error | ||||||
| @ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.handlePrimarySubnetFailover() | 	return hsdb.HandlePrimarySubnetFailover() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||||
| @ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | |||||||
| 	// Tailscale requires both IPv4 and IPv6 exit routes to | 	// Tailscale requires both IPv4 and IPv6 exit routes to | ||||||
| 	// be enabled at the same time, as per | 	// be enabled at the same time, as per | ||||||
| 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 | ||||||
| 	if !route.isExitRoute() { | 	if !route.IsExitRoute() { | ||||||
| 		if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { | 		if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return hsdb.handlePrimarySubnetFailover() | 		return hsdb.HandlePrimarySubnetFailover() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | ||||||
| @ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routesToDelete := []Route{} | 	routesToDelete := types.Routes{} | ||||||
| 	for _, r := range routes { | 	for _, r := range routes { | ||||||
| 		if r.isExitRoute() { | 		if r.IsExitRoute() { | ||||||
| 			routesToDelete = append(routesToDelete, r) | 			routesToDelete = append(routesToDelete, r) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.handlePrimarySubnetFailover() | 	return hsdb.HandlePrimarySubnetFailover() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { | func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { | ||||||
| 	routes, err := hsdb.GetMachineRoutes(m) | 	routes, err := hsdb.GetMachineRoutes(m) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.handlePrimarySubnetFailover() | 	return hsdb.HandlePrimarySubnetFailover() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // isUniquePrefix returns if there is another machine providing the same route already. | // isUniquePrefix returns if there is another machine providing the same route already. | ||||||
| func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { | func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	hsdb.db. | 	hsdb.db. | ||||||
| 		Model(&Route{}). | 		Model(&types.Route{}). | ||||||
| 		Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", | 		Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", | ||||||
| 			route.Prefix, | 			route.Prefix, | ||||||
| 			route.MachineID, | 			route.MachineID, | ||||||
| @ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { | |||||||
| 	return count == 0 | 	return count == 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { | func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { | ||||||
| 	var route Route | 	var route types.Route | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| 		Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). | 		Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). | ||||||
| 		First(&route).Error | 		First(&route).Error | ||||||
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { | |||||||
| 
 | 
 | ||||||
| // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) | // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) | ||||||
| // Exit nodes are not considered for this, as they are never marked as Primary. | // Exit nodes are not considered for this, as they are never marked as Primary. | ||||||
| func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { | func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { | ||||||
| 	var routes []Route | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| 		Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). | 		Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). | ||||||
| @ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { | |||||||
| 	return routes, nil | 	return routes, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { | func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | ||||||
| 	currentRoutes := []Route{} | 	currentRoutes := types.Routes{} | ||||||
| 	err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error | 	err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { | |||||||
| 
 | 
 | ||||||
| 	for prefix, exists := range advertisedRoutes { | 	for prefix, exists := range advertisedRoutes { | ||||||
| 		if !exists { | 		if !exists { | ||||||
| 			route := Route{ | 			route := types.Route{ | ||||||
| 				MachineID:  machine.ID, | 				MachineID:  machine.ID, | ||||||
| 				Prefix:     IPPrefix(prefix), | 				Prefix:     types.IPPrefix(prefix), | ||||||
| 				Advertised: true, | 				Advertised: true, | ||||||
| 				Enabled:    false, | 				Enabled:    false, | ||||||
| 			} | 			} | ||||||
| @ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||||
| 	// first, get all the enabled routes | 	// first, get all the enabled routes | ||||||
| 	var routes []Route | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| 		Where("advertised = ? AND enabled = ?", true, true). | 		Where("advertised = ? AND enabled = ?", true, true). | ||||||
| @ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | |||||||
| 
 | 
 | ||||||
| 	routesChanged := false | 	routesChanged := false | ||||||
| 	for pos, route := range routes { | 	for pos, route := range routes { | ||||||
| 		if route.isExitRoute() { | 		if route.IsExitRoute() { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if route.IsPrimary { | 		if route.IsPrimary { | ||||||
| 			if route.Machine.isOnline() { | 			if route.Machine.IsOnline() { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| @ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | |||||||
| 				Msgf("machine offline, finding a new primary subnet") | 				Msgf("machine offline, finding a new primary subnet") | ||||||
| 
 | 
 | ||||||
| 			// find a new primary route | 			// find a new primary route | ||||||
| 			var newPrimaryRoutes []Route | 			var newPrimaryRoutes types.Routes | ||||||
| 			err := hsdb.db. | 			err := hsdb.db. | ||||||
| 				Preload("Machine"). | 				Preload("Machine"). | ||||||
| 				Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", | 				Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", | ||||||
| @ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | |||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			var newPrimaryRoute *Route | 			var newPrimaryRoute *types.Route | ||||||
| 			for pos, r := range newPrimaryRoutes { | 			for pos, r := range newPrimaryRoutes { | ||||||
| 				if r.Machine.isOnline() { | 				if r.Machine.IsOnline() { | ||||||
| 					newPrimaryRoute = &newPrimaryRoutes[pos] | 					newPrimaryRoute = &newPrimaryRoutes[pos] | ||||||
| 
 | 
 | ||||||
| 					break | 					break | ||||||
| @ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (rs Routes) toProto() []*v1.Route { | // EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. | ||||||
| 	protoRoutes := []*v1.Route{} | func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | ||||||
| 
 | 	aclPolicy *policy.ACLPolicy, | ||||||
| 	for _, route := range rs { | 	machine *types.Machine, | ||||||
| 		protoRoute := v1.Route{ | ) error { | ||||||
| 			Id:         uint64(route.ID), | 	if len(machine.IPAddresses) == 0 { | ||||||
| 			Machine:    route.Machine.toProto(), | 		return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs | ||||||
| 			Prefix:     netip.Prefix(route.Prefix).String(), |  | ||||||
| 			Advertised: route.Advertised, |  | ||||||
| 			Enabled:    route.Enabled, |  | ||||||
| 			IsPrimary:  route.IsPrimary, |  | ||||||
| 			CreatedAt:  timestamppb.New(route.CreatedAt), |  | ||||||
| 			UpdatedAt:  timestamppb.New(route.UpdatedAt), |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if route.DeletedAt.Valid { |  | ||||||
| 			protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		protoRoutes = append(protoRoutes, &protoRoute) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return protoRoutes | 	routes, err := hsdb.GetMachineAdvertisedRoutes(machine) | ||||||
|  | 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
|  | 		log.Error(). | ||||||
|  | 			Caller(). | ||||||
|  | 			Err(err). | ||||||
|  | 			Str("machine", machine.Hostname). | ||||||
|  | 			Msg("Could not get advertised routes for machine") | ||||||
|  | 
 | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	approvedRoutes := types.Routes{} | ||||||
|  | 
 | ||||||
|  | 	for _, advertisedRoute := range routes { | ||||||
|  | 		if advertisedRoute.Enabled { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( | ||||||
|  | 			netip.Prefix(advertisedRoute.Prefix), | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Err(err). | ||||||
|  | 				Str("advertisedRoute", advertisedRoute.String()). | ||||||
|  | 				Uint64("machineId", machine.ID). | ||||||
|  | 				Msg("Failed to resolve autoApprovers for advertised route") | ||||||
|  | 
 | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, approvedAlias := range routeApprovers { | ||||||
|  | 			if approvedAlias == machine.User.Name { | ||||||
|  | 				approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||||
|  | 			} else { | ||||||
|  | 				// TODO(kradalby): figure out how to get this to depend on less stuff | ||||||
|  | 				approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) | ||||||
|  | 				if err != nil { | ||||||
|  | 					log.Err(err). | ||||||
|  | 						Str("alias", approvedAlias). | ||||||
|  | 						Msg("Failed to expand alias when processing autoApprovers policy") | ||||||
|  | 
 | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first | ||||||
|  | 				if approvedIps.Contains(machine.IPAddresses[0]) { | ||||||
|  | 					approvedRoutes = append(approvedRoutes, advertisedRoute) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, approvedRoute := range approvedRoutes { | ||||||
|  | 		err := hsdb.EnableRoute(uint64(approvedRoute.ID)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Err(err). | ||||||
|  | 				Str("approvedRoute", approvedRoute.String()). | ||||||
|  | 				Uint64("machineId", machine.ID). | ||||||
|  | 				Msg("Failed to enable approved route") | ||||||
|  | 
 | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| @ -1,9 +1,11 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"gopkg.in/check.v1" | 	"gopkg.in/check.v1" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| @ -11,13 +13,13 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestGetRoutes(c *check.C) { | func (s *Suite) TestGetRoutes(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "test_get_route_machine") | 	_, err = db.GetMachine("test", "test_get_route_machine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	route, err := netip.ParsePrefix("10.0.0.0/24") | 	route, err := netip.ParsePrefix("10.0.0.0/24") | ||||||
| @ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) { | |||||||
| 		RoutableIPs: []netip.Prefix{route}, | 		RoutableIPs: []netip.Prefix{route}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_get_route_machine", | 		Hostname:       "test_get_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo), | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine) | 	err = db.ProcessMachineRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) | 	advertisedRoutes, err := db.GetAdvertisedRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(advertisedRoutes), check.Equals, 1) | 	c.Assert(len(advertisedRoutes), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine, "192.168.0.0/24") | 	err = db.enableRoutes(&machine, "192.168.0.0/24") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine, "10.0.0.0/24") | 	err = db.enableRoutes(&machine, "10.0.0.0/24") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(0)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestGetEnableRoutes(c *check.C) { | func (s *Suite) TestGetEnableRoutes(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	route, err := netip.ParsePrefix( | 	route, err := netip.ParsePrefix( | ||||||
| @ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { | |||||||
| 		RoutableIPs: []netip.Prefix{route, route2}, | 		RoutableIPs: []netip.Prefix{route, route2}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machine := Machine{ | 	machine := types.Machine{ | ||||||
| 		ID:             0, | 		ID:             0, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo), | 		HostInfo:       types.HostInfo(hostInfo), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine) | 	err = db.ProcessMachineRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) | 	availableRoutes, err := db.GetAdvertisedRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(availableRoutes), check.Equals, 2) | 	c.Assert(len(availableRoutes), check.Equals, 2) | ||||||
| 
 | 
 | ||||||
| 	noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) | 	noEnabledRoutes, err := db.GetEnabledRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(noEnabledRoutes), check.Equals, 0) | 	c.Assert(len(noEnabledRoutes), check.Equals, 0) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine, "192.168.0.0/24") | 	err = db.enableRoutes(&machine, "192.168.0.0/24") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine, "10.0.0.0/24") | 	err = db.enableRoutes(&machine, "10.0.0.0/24") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes, err := app.db.GetEnabledRoutes(&machine) | 	enabledRoutes, err := db.GetEnabledRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes), check.Equals, 1) | 	c.Assert(len(enabledRoutes), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	// Adding it twice will just let it pass through | 	// Adding it twice will just let it pass through | ||||||
| 	err = app.db.enableRoutes(&machine, "10.0.0.0/24") | 	err = db.enableRoutes(&machine, "10.0.0.0/24") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) | 	enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) | 	c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine, "150.0.10.0/25") | 	err = db.enableRoutes(&machine, "150.0.10.0/25") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) | 	enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) | 	c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(3)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestIsUniquePrefix(c *check.C) { | func (s *Suite) TestIsUniquePrefix(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	route, err := netip.ParsePrefix( | 	route, err := netip.ParsePrefix( | ||||||
| @ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { | |||||||
| 	hostInfo1 := tailcfg.Hostinfo{ | 	hostInfo1 := tailcfg.Hostinfo{ | ||||||
| 		RoutableIPs: []netip.Prefix{route, route2}, | 		RoutableIPs: []netip.Prefix{route, route2}, | ||||||
| 	} | 	} | ||||||
| 	machine1 := Machine{ | 	machine1 := types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo1), | 		HostInfo:       types.HostInfo(hostInfo1), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine1) | 	db.db.Save(&machine1) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine1) | 	err = db.ProcessMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, route.String()) | 	err = db.enableRoutes(&machine1, route.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, route2.String()) | 	err = db.enableRoutes(&machine1, route2.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	hostInfo2 := tailcfg.Hostinfo{ | 	hostInfo2 := tailcfg.Hostinfo{ | ||||||
| 		RoutableIPs: []netip.Prefix{route2}, | 		RoutableIPs: []netip.Prefix{route2}, | ||||||
| 	} | 	} | ||||||
| 	machine2 := Machine{ | 	machine2 := types.Machine{ | ||||||
| 		ID:             2, | 		ID:             2, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo2), | 		HostInfo:       types.HostInfo(hostInfo2), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine2) | 	db.db.Save(&machine2) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine2) | 	err = db.ProcessMachineRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine2, route2.String()) | 	err = db.enableRoutes(&machine2, route2.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 2) | 	c.Assert(len(enabledRoutes1), check.Equals, 2) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) | 	enabledRoutes2, err := db.GetEnabledRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes2), check.Equals, 1) | 	c.Assert(len(enabledRoutes2), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	routes, err := app.db.getMachinePrimaryRoutes(&machine1) | 	routes, err := db.GetMachinePrimaryRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 2) | 	c.Assert(len(routes), check.Equals, 2) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 0) | 	c.Assert(len(routes), check.Equals, 0) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(3)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestSubnetFailover(c *check.C) { | func (s *Suite) TestSubnetFailover(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	prefix, err := netip.ParsePrefix( | 	prefix, err := netip.ParsePrefix( | ||||||
| @ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	machine1 := Machine{ | 	machine1 := types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo1), | 		HostInfo:       types.HostInfo(hostInfo1), | ||||||
| 		LastSeen:       &now, | 		LastSeen:       &now, | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine1) | 	db.db.Save(&machine1) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine1) | 	err = db.ProcessMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, prefix.String()) | 	err = db.enableRoutes(&machine1, prefix.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, prefix2.String()) | 	err = db.enableRoutes(&machine1, prefix2.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.handlePrimarySubnetFailover() | 	err = db.HandlePrimarySubnetFailover() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 2) | 	c.Assert(len(enabledRoutes1), check.Equals, 2) | ||||||
| 
 | 
 | ||||||
| 	route, err := app.db.getPrimaryRoute(prefix) | 	route, err := db.getPrimaryRoute(prefix) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(route.MachineID, check.Equals, machine1.ID) | 	c.Assert(route.MachineID, check.Equals, machine1.ID) | ||||||
| 
 | 
 | ||||||
| 	hostInfo2 := tailcfg.Hostinfo{ | 	hostInfo2 := tailcfg.Hostinfo{ | ||||||
| 		RoutableIPs: []netip.Prefix{prefix2}, | 		RoutableIPs: []netip.Prefix{prefix2}, | ||||||
| 	} | 	} | ||||||
| 	machine2 := Machine{ | 	machine2 := types.Machine{ | ||||||
| 		ID:             2, | 		ID:             2, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo2), | 		HostInfo:       types.HostInfo(hostInfo2), | ||||||
| 		LastSeen:       &now, | 		LastSeen:       &now, | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine2) | 	db.db.Save(&machine2) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine2) | 	err = db.ProcessMachineRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine2, prefix2.String()) | 	err = db.enableRoutes(&machine2, prefix2.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.handlePrimarySubnetFailover() | 	err = db.HandlePrimarySubnetFailover() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) | 	enabledRoutes1, err = db.GetEnabledRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 2) | 	c.Assert(len(enabledRoutes1), check.Equals, 2) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) | 	enabledRoutes2, err := db.GetEnabledRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes2), check.Equals, 1) | 	c.Assert(len(enabledRoutes2), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	routes, err := app.db.getMachinePrimaryRoutes(&machine1) | 	routes, err := db.GetMachinePrimaryRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 2) | 	c.Assert(len(routes), check.Equals, 2) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 0) | 	c.Assert(len(routes), check.Equals, 0) | ||||||
| 
 | 
 | ||||||
| 	// lets make machine1 lastseen 10 mins ago | 	// lets make machine1 lastseen 10 mins ago | ||||||
| 	before := now.Add(-10 * time.Minute) | 	before := now.Add(-10 * time.Minute) | ||||||
| 	machine1.LastSeen = &before | 	machine1.LastSeen = &before | ||||||
| 	err = app.db.db.Save(&machine1).Error | 	err = db.db.Save(&machine1).Error | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.handlePrimarySubnetFailover() | 	err = db.HandlePrimarySubnetFailover() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine1) | 	routes, err = db.GetMachinePrimaryRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 1) | 	c.Assert(len(routes), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 1) | 	c.Assert(len(routes), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ | 	machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{ | ||||||
| 		RoutableIPs: []netip.Prefix{prefix, prefix2}, | 		RoutableIPs: []netip.Prefix{prefix, prefix2}, | ||||||
| 	}) | 	}) | ||||||
| 	err = app.db.db.Save(&machine2).Error | 	err = db.db.Save(&machine2).Error | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine2) | 	err = db.ProcessMachineRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine2, prefix.String()) | 	err = db.enableRoutes(&machine2, prefix.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.handlePrimarySubnetFailover() | 	err = db.HandlePrimarySubnetFailover() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine1) | 	routes, err = db.GetMachinePrimaryRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 0) | 	c.Assert(len(routes), check.Equals, 0) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.getMachinePrimaryRoutes(&machine2) | 	routes, err = db.GetMachinePrimaryRoutes(&machine2) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 2) | 	c.Assert(len(routes), check.Equals, 2) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(6)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, | // TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, | ||||||
| // including both the primary routes the node is responsible for, and the | // including both the primary routes the node is responsible for, and the | ||||||
| // exit node routes if enabled. | // exit node routes if enabled. | ||||||
| func (s *Suite) TestAllowedIPRoutes(c *check.C) { | func (s *Suite) TestAllowedIPRoutes(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	prefix, err := netip.ParsePrefix( | 	prefix, err := netip.ParsePrefix( | ||||||
| @ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { | |||||||
| 	machineKey := key.NewMachine() | 	machineKey := key.NewMachine() | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	machine1 := Machine{ | 	machine1 := types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | 		MachineKey:     util.MachinePublicKeyStripPrefix(machineKey.Public()), | ||||||
| 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | 		NodeKey:        util.NodePublicKeyStripPrefix(nodeKey.Public()), | ||||||
| 		DiscoKey:       util.DiscoPublicKeyStripPrefix(discoKey.Public()), | 		DiscoKey:       util.DiscoPublicKeyStripPrefix(discoKey.Public()), | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo1), | 		HostInfo:       types.HostInfo(hostInfo1), | ||||||
| 		LastSeen:       &now, | 		LastSeen:       &now, | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine1) | 	db.db.Save(&machine1) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine1) | 	err = db.ProcessMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, prefix.String()) | 	err = db.enableRoutes(&machine1, prefix.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	// We do not enable this one on purpose to test that it is not enabled | 	// We do not enable this one on purpose to test that it is not enabled | ||||||
| 	// err = app.db.enableRoutes(&machine1, prefix2.String()) | 	// err = db.enableRoutes(&machine1, prefix2.String()) | ||||||
| 	// c.Assert(err, check.IsNil) | 	// c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	routes, err := app.db.GetMachineRoutes(&machine1) | 	routes, err := db.GetMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	for _, route := range routes { | 	for _, route := range routes { | ||||||
| 		if route.isExitRoute() { | 		if route.IsExitRoute() { | ||||||
| 			err = app.db.EnableRoute(uint64(route.ID)) | 			err = db.EnableRoute(uint64(route.ID)) | ||||||
| 			c.Assert(err, check.IsNil) | 			c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 			// We only enable one exit route, so we can test that both are enabled | 			// We only enable one exit route, so we can test that both are enabled | ||||||
| @ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = app.db.handlePrimarySubnetFailover() | 	err = db.HandlePrimarySubnetFailover() | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 3) | 	c.Assert(len(enabledRoutes1), check.Equals, 3) | ||||||
| 
 | 
 | ||||||
| 	peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) | 	peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(len(peer.AllowedIPs), check.Equals, 3) | 	c.Assert(len(peer.AllowedIPs), check.Equals, 3) | ||||||
| @ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { | |||||||
| 
 | 
 | ||||||
| 	// Now we disable only one of the exit routes | 	// Now we disable only one of the exit routes | ||||||
| 	// and we see if both are disabled | 	// and we see if both are disabled | ||||||
| 	var exitRouteV4 Route | 	var exitRouteV4 types.Route | ||||||
| 	for _, route := range routes { | 	for _, route := range routes { | ||||||
| 		if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { | 		if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { | ||||||
| 			exitRouteV4 = route | 			exitRouteV4 = route | ||||||
| 
 | 
 | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = app.db.DisableRoute(uint64(exitRouteV4.ID)) | 	err = db.DisableRoute(uint64(exitRouteV4.ID)) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) | 	enabledRoutes1, err = db.GetEnabledRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 1) | 	c.Assert(len(enabledRoutes1), check.Equals, 1) | ||||||
| 
 | 
 | ||||||
| 	// and now we delete only one of the exit routes | 	// and now we delete only one of the exit routes | ||||||
| 	// and we check if both are deleted | 	// and we check if both are deleted | ||||||
| 	routes, err = app.db.GetMachineRoutes(&machine1) | 	routes, err = db.GetMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 4) | 	c.Assert(len(routes), check.Equals, 4) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) | 	err = db.DeleteRoute(uint64(exitRouteV4.ID)) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	routes, err = app.db.GetMachineRoutes(&machine1) | 	routes, err = db.GetMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(routes), check.Equals, 2) | 	c.Assert(len(routes), check.Equals, 2) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(2)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestDeleteRoutes(c *check.C) { | func (s *Suite) TestDeleteRoutes(c *check.C) { | ||||||
| 	user, err := app.db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine("test", "test_enable_route_machine") | 	_, err = db.GetMachine("test", "test_enable_route_machine") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	prefix, err := netip.ParsePrefix( | 	prefix, err := netip.ParsePrefix( | ||||||
| @ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	machine1 := Machine{ | 	machine1 := types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     "foo", | 		MachineKey:     "foo", | ||||||
| 		NodeKey:        "bar", | 		NodeKey:        "bar", | ||||||
| 		DiscoKey:       "faa", | 		DiscoKey:       "faa", | ||||||
| 		Hostname:       "test_enable_route_machine", | 		Hostname:       "test_enable_route_machine", | ||||||
| 		UserID:         user.ID, | 		UserID:         user.ID, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		AuthKeyID:      uint(pak.ID), | 		AuthKeyID:      uint(pak.ID), | ||||||
| 		HostInfo:       HostInfo(hostInfo1), | 		HostInfo:       types.HostInfo(hostInfo1), | ||||||
| 		LastSeen:       &now, | 		LastSeen:       &now, | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(&machine1) | 	db.db.Save(&machine1) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.processMachineRoutes(&machine1) | 	err = db.ProcessMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, prefix.String()) | 	err = db.enableRoutes(&machine1, prefix.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.enableRoutes(&machine1, prefix2.String()) | 	err = db.enableRoutes(&machine1, prefix2.String()) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	routes, err := app.db.GetMachineRoutes(&machine1) | 	routes, err := db.GetMachineRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	err = app.db.DeleteRoute(uint64(routes[0].ID)) | 	err = db.DeleteRoute(uint64(routes[0].ID)) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) | 	enabledRoutes1, err := db.GetEnabledRoutes(&machine1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(len(enabledRoutes1), check.Equals, 1) | 	c.Assert(len(enabledRoutes1), check.Equals, 1) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(channelUpdates, check.Equals, int32(2)) | ||||||
| } | } | ||||||
							
								
								
									
										74
									
								
								hscontrol/db/suite_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								hscontrol/db/suite_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,74 @@ | |||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"os" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"gopkg.in/check.v1" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func Test(t *testing.T) { | ||||||
|  | 	check.TestingT(t) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var _ = check.Suite(&Suite{}) | ||||||
|  | 
 | ||||||
|  | type Suite struct{} | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	tmpDir string | ||||||
|  | 	db     *HSDatabase | ||||||
|  | 
 | ||||||
|  | 	// channelUpdates counts the number of times | ||||||
|  | 	// either of the channels was notified. | ||||||
|  | 	channelUpdates int32 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (s *Suite) SetUpTest(c *check.C) { | ||||||
|  | 	atomic.StoreInt32(&channelUpdates, 0) | ||||||
|  | 	s.ResetDB(c) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TearDownTest(c *check.C) { | ||||||
|  | 	os.RemoveAll(tmpDir) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func notificationSink(c <-chan struct{}) { | ||||||
|  | 	for { | ||||||
|  | 		<-c | ||||||
|  | 		atomic.AddInt32(&channelUpdates, 1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) ResetDB(c *check.C) { | ||||||
|  | 	if len(tmpDir) != 0 { | ||||||
|  | 		os.RemoveAll(tmpDir) | ||||||
|  | 	} | ||||||
|  | 	var err error | ||||||
|  | 	tmpDir, err = os.MkdirTemp("", "autoygg-client-test") | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	sink := make(chan struct{}) | ||||||
|  | 
 | ||||||
|  | 	go notificationSink(sink) | ||||||
|  | 
 | ||||||
|  | 	db, err = NewHeadscaleDatabase( | ||||||
|  | 		"sqlite3", | ||||||
|  | 		tmpDir+"/headscale_test.db", | ||||||
|  | 		false, | ||||||
|  | 		false, | ||||||
|  | 		sink, | ||||||
|  | 		sink, | ||||||
|  | 		[]netip.Prefix{ | ||||||
|  | 			netip.MustParsePrefix("10.27.0.0/23"), | ||||||
|  | 		}, | ||||||
|  | 		"", | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -1,17 +1,12 @@ | |||||||
| package hscontrol | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"regexp" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"google.golang.org/protobuf/types/known/timestamppb" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| ) | ) | ||||||
| @ -20,33 +15,16 @@ var ( | |||||||
| 	ErrUserExists        = errors.New("user already exists") | 	ErrUserExists        = errors.New("user already exists") | ||||||
| 	ErrUserNotFound      = errors.New("user not found") | 	ErrUserNotFound      = errors.New("user not found") | ||||||
| 	ErrUserStillHasNodes = errors.New("user not empty: node(s) found") | 	ErrUserStillHasNodes = errors.New("user not empty: node(s) found") | ||||||
| 	ErrInvalidUserName   = errors.New("invalid user name") |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( |  | ||||||
| 	// value related to RFC 1123 and 952. |  | ||||||
| 	labelHostnameLength = 63 |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") |  | ||||||
| 
 |  | ||||||
| // User is the way Headscale implements the concept of users in Tailscale |  | ||||||
| // |  | ||||||
| // At the end of the day, users in Tailscale are some kind of 'bubbles' or users |  | ||||||
| // that contain our machines. |  | ||||||
| type User struct { |  | ||||||
| 	gorm.Model |  | ||||||
| 	Name string `gorm:"unique"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // CreateUser creates a new User. Returns error if could not be created | // CreateUser creates a new User. Returns error if could not be created | ||||||
| // or another user already exists. | // or another user already exists. | ||||||
| func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { | func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | ||||||
| 	err := CheckForFQDNRules(name) | 	err := util.CheckForFQDNRules(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	user := User{} | 	user := types.User{} | ||||||
| 	if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { | 	if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { | ||||||
| 		return nil, ErrUserExists | 		return nil, ErrUserExists | ||||||
| 	} | 	} | ||||||
| @ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	err = CheckForFQDNRules(newName) | 	err = util.CheckForFQDNRules(newName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetUser fetches a user by name. | // GetUser fetches a user by name. | ||||||
| func (hsdb *HSDatabase) GetUser(name string) (*User, error) { | func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | ||||||
| 	user := User{} | 	user := types.User{} | ||||||
| 	if result := hsdb.db.First(&user, "name = ?", name); errors.Is( | 	if result := hsdb.db.First(&user, "name = ?", name); errors.Is( | ||||||
| 		result.Error, | 		result.Error, | ||||||
| 		gorm.ErrRecordNotFound, | 		gorm.ErrRecordNotFound, | ||||||
| @ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListUsers gets all the existing users. | // ListUsers gets all the existing users. | ||||||
| func (hsdb *HSDatabase) ListUsers() ([]User, error) { | func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | ||||||
| 	users := []User{} | 	users := []types.User{} | ||||||
| 	if err := hsdb.db.Find(&users).Error; err != nil { | 	if err := hsdb.db.Find(&users).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListMachinesByUser gets all the nodes in a given user. | // ListMachinesByUser gets all the nodes in a given user. | ||||||
| func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { | func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { | ||||||
| 	err := CheckForFQDNRules(name) | 	err := util.CheckForFQDNRules(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machines := []Machine{} | 	machines := types.Machines{} | ||||||
| 	if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { | 	if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SetMachineUser assigns a Machine to a user. | // SetMachineUser assigns a Machine to a user. | ||||||
| func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { | func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { | ||||||
| 	err := CheckForFQDNRules(username) | 	err := util.CheckForFQDNRules(username) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (n *User) toTailscaleUser() *tailcfg.User { | func (hsdb *HSDatabase) GetMapResponseUserProfiles( | ||||||
| 	user := tailcfg.User{ | 	machine types.Machine, | ||||||
| 		ID:            tailcfg.UserID(n.ID), | 	peers types.Machines, | ||||||
| 		LoginName:     n.Name, |  | ||||||
| 		DisplayName:   n.Name, |  | ||||||
| 		ProfilePicURL: "", |  | ||||||
| 		Domain:        "headscale.net", |  | ||||||
| 		Logins:        []tailcfg.LoginID{}, |  | ||||||
| 		Created:       time.Time{}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &user |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (n *User) toTailscaleLogin() *tailcfg.Login { |  | ||||||
| 	login := tailcfg.Login{ |  | ||||||
| 		ID:            tailcfg.LoginID(n.ID), |  | ||||||
| 		LoginName:     n.Name, |  | ||||||
| 		DisplayName:   n.Name, |  | ||||||
| 		ProfilePicURL: "", |  | ||||||
| 		Domain:        "headscale.net", |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &login |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (hsdb *HSDatabase) getMapResponseUserProfiles( |  | ||||||
| 	machine Machine, |  | ||||||
| 	peers Machines, |  | ||||||
| ) []tailcfg.UserProfile { | ) []tailcfg.UserProfile { | ||||||
| 	userMap := make(map[string]User) | 	userMap := make(map[string]types.User) | ||||||
| 	userMap[machine.User.Name] = machine.User | 	userMap[machine.User.Name] = machine.User | ||||||
| 	for _, peer := range peers { | 	for _, peer := range peers { | ||||||
| 		userMap[peer.User.Name] = peer.User // not worth checking if already is there | 		userMap[peer.User.Name] = peer.User // not worth checking if already is there | ||||||
| @ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles( | |||||||
| 
 | 
 | ||||||
| 	return profiles | 	return profiles | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (n *User) toProto() *v1.User { |  | ||||||
| 	return &v1.User{ |  | ||||||
| 		Id:        strconv.FormatUint(uint64(n.ID), util.Base10), |  | ||||||
| 		Name:      n.Name, |  | ||||||
| 		CreatedAt: timestamppb.New(n.CreatedAt), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NormalizeToFQDNRules will replace forbidden chars in user |  | ||||||
| // it can also return an error if the user doesn't respect RFC 952 and 1123. |  | ||||||
| func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { |  | ||||||
| 	name = strings.ToLower(name) |  | ||||||
| 	name = strings.ReplaceAll(name, "'", "") |  | ||||||
| 	atIdx := strings.Index(name, "@") |  | ||||||
| 	if stripEmailDomain && atIdx > 0 { |  | ||||||
| 		name = name[:atIdx] |  | ||||||
| 	} else { |  | ||||||
| 		name = strings.ReplaceAll(name, "@", ".") |  | ||||||
| 	} |  | ||||||
| 	name = invalidCharsInUserRegex.ReplaceAllString(name, "-") |  | ||||||
| 
 |  | ||||||
| 	for _, elt := range strings.Split(name, ".") { |  | ||||||
| 		if len(elt) > labelHostnameLength { |  | ||||||
| 			return "", fmt.Errorf( |  | ||||||
| 				"label %v is more than 63 chars: %w", |  | ||||||
| 				elt, |  | ||||||
| 				ErrInvalidUserName, |  | ||||||
| 			) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return name, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func CheckForFQDNRules(name string) error { |  | ||||||
| 	if len(name) > labelHostnameLength { |  | ||||||
| 		return fmt.Errorf( |  | ||||||
| 			"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", |  | ||||||
| 			name, |  | ||||||
| 			ErrInvalidUserName, |  | ||||||
| 		) |  | ||||||
| 	} |  | ||||||
| 	if strings.ToLower(name) != name { |  | ||||||
| 		return fmt.Errorf( |  | ||||||
| 			"DNS segment should be lowercase. %v doesn't comply with this rule: %w", |  | ||||||
| 			name, |  | ||||||
| 			ErrInvalidUserName, |  | ||||||
| 		) |  | ||||||
| 	} |  | ||||||
| 	if invalidCharsInUserRegex.MatchString(name) { |  | ||||||
| 		return fmt.Errorf( |  | ||||||
| 			"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", |  | ||||||
| 			name, |  | ||||||
| 			ErrInvalidUserName, |  | ||||||
| 		) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
							
								
								
									
										277
									
								
								hscontrol/db/users_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										277
									
								
								hscontrol/db/users_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,277 @@ | |||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
|  | 	"gopkg.in/check.v1" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestCreateAndDestroyUser(c *check.C) { | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(user.Name, check.Equals, "test") | ||||||
|  | 
 | ||||||
|  | 	users, err := db.ListUsers() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(len(users), check.Equals, 1) | ||||||
|  | 
 | ||||||
|  | 	err = db.DestroyUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetUser("test") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestDestroyUserErrors(c *check.C) { | ||||||
|  | 	err := db.DestroyUser("test") | ||||||
|  | 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||||
|  | 
 | ||||||
|  | 	user, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	err = db.DestroyUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) | ||||||
|  | 	// destroying a user also deletes all associated preauthkeys | ||||||
|  | 	c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) | ||||||
|  | 
 | ||||||
|  | 	user, err = db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 
 | ||||||
|  | 	err = db.DestroyUser("test") | ||||||
|  | 	c.Assert(err, check.Equals, ErrUserStillHasNodes) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestRenameUser(c *check.C) { | ||||||
|  | 	userTest, err := db.CreateUser("test") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(userTest.Name, check.Equals, "test") | ||||||
|  | 
 | ||||||
|  | 	users, err := db.ListUsers() | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(len(users), check.Equals, 1) | ||||||
|  | 
 | ||||||
|  | 	err = db.RenameUser("test", "test-renamed") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetUser("test") | ||||||
|  | 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetUser("test-renamed") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	err = db.RenameUser("test-does-not-exit", "test") | ||||||
|  | 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||||
|  | 
 | ||||||
|  | 	userTest2, err := db.CreateUser("test2") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(userTest2.Name, check.Equals, "test2") | ||||||
|  | 
 | ||||||
|  | 	err = db.RenameUser("test2", "test-renamed") | ||||||
|  | 	c.Assert(err, check.Equals, ErrUserExists) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { | ||||||
|  | 	userShared1, err := db.CreateUser("shared1") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	userShared2, err := db.CreateUser("shared2") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	userShared3, err := db.CreateUser("shared3") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	preAuthKeyShared1, err := db.CreatePreAuthKey( | ||||||
|  | 		userShared1.Name, | ||||||
|  | 		false, | ||||||
|  | 		false, | ||||||
|  | 		nil, | ||||||
|  | 		nil, | ||||||
|  | 	) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	preAuthKeyShared2, err := db.CreatePreAuthKey( | ||||||
|  | 		userShared2.Name, | ||||||
|  | 		false, | ||||||
|  | 		false, | ||||||
|  | 		nil, | ||||||
|  | 		nil, | ||||||
|  | 	) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	preAuthKeyShared3, err := db.CreatePreAuthKey( | ||||||
|  | 		userShared3.Name, | ||||||
|  | 		false, | ||||||
|  | 		false, | ||||||
|  | 		nil, | ||||||
|  | 		nil, | ||||||
|  | 	) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	preAuthKey2Shared1, err := db.CreatePreAuthKey( | ||||||
|  | 		userShared1.Name, | ||||||
|  | 		false, | ||||||
|  | 		false, | ||||||
|  | 		nil, | ||||||
|  | 		nil, | ||||||
|  | 	) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||||
|  | 	c.Assert(err, check.NotNil) | ||||||
|  | 
 | ||||||
|  | 	machineInShared1 := &types.Machine{ | ||||||
|  | 		ID:             1, | ||||||
|  | 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
|  | 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
|  | 		DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
|  | 		Hostname:       "test_get_shared_nodes_1", | ||||||
|  | 		UserID:         userShared1.ID, | ||||||
|  | 		User:           *userShared1, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||||
|  | 		AuthKeyID:      uint(preAuthKeyShared1.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machineInShared1) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machineInShared2 := &types.Machine{ | ||||||
|  | 		ID:             2, | ||||||
|  | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		Hostname:       "test_get_shared_nodes_2", | ||||||
|  | 		UserID:         userShared2.ID, | ||||||
|  | 		User:           *userShared2, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||||
|  | 		AuthKeyID:      uint(preAuthKeyShared2.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machineInShared2) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machineInShared3 := &types.Machine{ | ||||||
|  | 		ID:             3, | ||||||
|  | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		Hostname:       "test_get_shared_nodes_3", | ||||||
|  | 		UserID:         userShared3.ID, | ||||||
|  | 		User:           *userShared3, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||||
|  | 		AuthKeyID:      uint(preAuthKeyShared3.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machineInShared3) | ||||||
|  | 
 | ||||||
|  | 	_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machine2InShared1 := &types.Machine{ | ||||||
|  | 		ID:             4, | ||||||
|  | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
|  | 		Hostname:       "test_get_shared_nodes_4", | ||||||
|  | 		UserID:         userShared1.ID, | ||||||
|  | 		User:           *userShared1, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||||
|  | 		AuthKeyID:      uint(preAuthKey2Shared1.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(machine2InShared1) | ||||||
|  | 
 | ||||||
|  | 	peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	userProfiles := db.GetMapResponseUserProfiles( | ||||||
|  | 		*machineInShared1, | ||||||
|  | 		peersOfMachine1InShared1, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	c.Assert(len(userProfiles), check.Equals, 3) | ||||||
|  | 
 | ||||||
|  | 	found := false | ||||||
|  | 	for _, userProfiles := range userProfiles { | ||||||
|  | 		if userProfiles.DisplayName == userShared1.Name { | ||||||
|  | 			found = true | ||||||
|  | 
 | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	c.Assert(found, check.Equals, true) | ||||||
|  | 
 | ||||||
|  | 	found = false | ||||||
|  | 	for _, userProfile := range userProfiles { | ||||||
|  | 		if userProfile.DisplayName == userShared2.Name { | ||||||
|  | 			found = true | ||||||
|  | 
 | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	c.Assert(found, check.Equals, true) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Suite) TestSetMachineUser(c *check.C) { | ||||||
|  | 	oldUser, err := db.CreateUser("old") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	newUser, err := db.CreateUser("new") | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 
 | ||||||
|  | 	machine := types.Machine{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     "foo", | ||||||
|  | 		NodeKey:        "bar", | ||||||
|  | 		DiscoKey:       "faa", | ||||||
|  | 		Hostname:       "testmachine", | ||||||
|  | 		UserID:         oldUser.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		AuthKeyID:      uint(pak.ID), | ||||||
|  | 	} | ||||||
|  | 	db.db.Save(&machine) | ||||||
|  | 	c.Assert(machine.UserID, check.Equals, oldUser.ID) | ||||||
|  | 
 | ||||||
|  | 	err = db.SetMachineUser(&machine, newUser.Name) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||||
|  | 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||||
|  | 
 | ||||||
|  | 	err = db.SetMachineUser(&machine, "non-existing-user") | ||||||
|  | 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||||
|  | 
 | ||||||
|  | 	err = db.SetMachineUser(&machine, newUser.Name) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
|  | 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||||
|  | 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||||
|  | } | ||||||
| @ -7,6 +7,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	mapset "github.com/deckarep/golang-set/v2" | 	mapset "github.com/deckarep/golang-set/v2" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"go4.org/netipx" | 	"go4.org/netipx" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| 	"tailscale.com/types/dnstype" | 	"tailscale.com/types/dnstype" | ||||||
| @ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { | |||||||
| // | // | ||||||
| // This will produce a resolver like: | // This will produce a resolver like: | ||||||
| // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` | // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` | ||||||
| func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { | func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { | ||||||
| 	for _, resolver := range resolvers { | 	for _, resolver := range resolvers { | ||||||
| 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { | 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { | ||||||
| 			attrs := url.Values{ | 			attrs := url.Values{ | ||||||
| @ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { | |||||||
| func getMapResponseDNSConfig( | func getMapResponseDNSConfig( | ||||||
| 	dnsConfigOrig *tailcfg.DNSConfig, | 	dnsConfigOrig *tailcfg.DNSConfig, | ||||||
| 	baseDomain string, | 	baseDomain string, | ||||||
| 	machine Machine, | 	machine types.Machine, | ||||||
| 	peers Machines, | 	peers types.Machines, | ||||||
| ) *tailcfg.DNSConfig { | ) *tailcfg.DNSConfig { | ||||||
| 	var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() | 	var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() | ||||||
| 	if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled | 	if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled | ||||||
| @ -200,7 +201,7 @@ func getMapResponseDNSConfig( | |||||||
| 			), | 			), | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		userSet := mapset.NewSet[User]() | 		userSet := mapset.NewSet[types.User]() | ||||||
| 		userSet.Add(machine.User) | 		userSet.Add(machine.User) | ||||||
| 		for _, p := range peers { | 		for _, p := range peers { | ||||||
| 			userSet.Add(p.User) | 			userSet.Add(p.User) | ||||||
|  | |||||||
| @ -4,6 +4,8 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"gopkg.in/check.v1" | 	"gopkg.in/check.v1" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| 	"tailscale.com/types/dnstype" | 	"tailscale.com/types/dnstype" | ||||||
| @ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | |||||||
| 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	machineInShared1 := &Machine{ | 	machineInShared1 := &types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
| @ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_1", | 		Hostname:       "test_get_shared_nodes_1", | ||||||
| 		UserID:         userShared1.ID, | 		UserID:         userShared1.ID, | ||||||
| 		User:           *userShared1, | 		User:           *userShared1, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||||
| 		AuthKeyID:      uint(preAuthKeyInShared1.ID), | 		AuthKeyID:      uint(preAuthKeyInShared1.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machineInShared1) | 	err = app.db.MachineSave(machineInShared1) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machineInShared2 := &Machine{ | 	machineInShared2 := &types.Machine{ | ||||||
| 		ID:             2, | 		ID:             2, | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| @ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_2", | 		Hostname:       "test_get_shared_nodes_2", | ||||||
| 		UserID:         userShared2.ID, | 		UserID:         userShared2.ID, | ||||||
| 		User:           *userShared2, | 		User:           *userShared2, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||||
| 		AuthKeyID:      uint(preAuthKeyInShared2.ID), | 		AuthKeyID:      uint(preAuthKeyInShared2.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machineInShared2) | 	err = app.db.MachineSave(machineInShared2) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machineInShared3 := &Machine{ | 	machineInShared3 := &types.Machine{ | ||||||
| 		ID:             3, | 		ID:             3, | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| @ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_3", | 		Hostname:       "test_get_shared_nodes_3", | ||||||
| 		UserID:         userShared3.ID, | 		UserID:         userShared3.ID, | ||||||
| 		User:           *userShared3, | 		User:           *userShared3, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||||
| 		AuthKeyID:      uint(preAuthKeyInShared3.ID), | 		AuthKeyID:      uint(preAuthKeyInShared3.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machineInShared3) | 	err = app.db.MachineSave(machineInShared3) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machine2InShared1 := &Machine{ | 	machine2InShared1 := &types.Machine{ | ||||||
| 		ID:             4, | 		ID:             4, | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| @ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_4", | 		Hostname:       "test_get_shared_nodes_4", | ||||||
| 		UserID:         userShared1.ID, | 		UserID:         userShared1.ID, | ||||||
| 		User:           *userShared1, | 		User:           *userShared1, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||||
| 		AuthKeyID:      uint(PreAuthKey2InShared1.ID), | 		AuthKeyID:      uint(PreAuthKey2InShared1.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machine2InShared1) | 	err = app.db.MachineSave(machine2InShared1) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	baseDomain := "foobar.headscale.net" | 	baseDomain := "foobar.headscale.net" | ||||||
| 	dnsConfigOrig := tailcfg.DNSConfig{ | 	dnsConfigOrig := tailcfg.DNSConfig{ | ||||||
| @ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { | |||||||
| 		Proxied: true, | 		Proxied: true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) | 	peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	dnsConfig := getMapResponseDNSConfig( | 	dnsConfig := getMapResponseDNSConfig( | ||||||
| @ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | |||||||
| 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") | ||||||
| 	c.Assert(err, check.NotNil) | 	c.Assert(err, check.NotNil) | ||||||
| 
 | 
 | ||||||
| 	machineInShared1 := &Machine{ | 	machineInShared1 := &types.Machine{ | ||||||
| 		ID:             1, | 		ID:             1, | ||||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", | ||||||
| @ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_1", | 		Hostname:       "test_get_shared_nodes_1", | ||||||
| 		UserID:         userShared1.ID, | 		UserID:         userShared1.ID, | ||||||
| 		User:           *userShared1, | 		User:           *userShared1, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, | ||||||
| 		AuthKeyID:      uint(preAuthKeyInShared1.ID), | 		AuthKeyID:      uint(preAuthKeyInShared1.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machineInShared1) | 	err = app.db.MachineSave(machineInShared1) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machineInShared2 := &Machine{ | 	machineInShared2 := &types.Machine{ | ||||||
| 		ID:             2, | 		ID:             2, | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| @ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_2", | 		Hostname:       "test_get_shared_nodes_2", | ||||||
| 		UserID:         userShared2.ID, | 		UserID:         userShared2.ID, | ||||||
| 		User:           *userShared2, | 		User:           *userShared2, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, | ||||||
| 		AuthKeyID:      uint(preAuthKeyInShared2.ID), | 		AuthKeyID:      uint(preAuthKeyInShared2.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machineInShared2) | 	err = app.db.MachineSave(machineInShared2) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machineInShared3 := &Machine{ | 	machineInShared3 := &types.Machine{ | ||||||
| 		ID:             3, | 		ID:             3, | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| @ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_3", | 		Hostname:       "test_get_shared_nodes_3", | ||||||
| 		UserID:         userShared3.ID, | 		UserID:         userShared3.ID, | ||||||
| 		User:           *userShared3, | 		User:           *userShared3, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, | ||||||
| 		AuthKeyID:      uint(preAuthKeyInShared3.ID), | 		AuthKeyID:      uint(preAuthKeyInShared3.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machineInShared3) | 	err = app.db.MachineSave(machineInShared3) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	machine2InShared1 := &Machine{ | 	machine2InShared1 := &types.Machine{ | ||||||
| 		ID:             4, | 		ID:             4, | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", | ||||||
| @ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | |||||||
| 		Hostname:       "test_get_shared_nodes_4", | 		Hostname:       "test_get_shared_nodes_4", | ||||||
| 		UserID:         userShared1.ID, | 		UserID:         userShared1.ID, | ||||||
| 		User:           *userShared1, | 		User:           *userShared1, | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, | ||||||
| 		AuthKeyID:      uint(preAuthKey2InShared1.ID), | 		AuthKeyID:      uint(preAuthKey2InShared1.ID), | ||||||
| 	} | 	} | ||||||
| 	app.db.db.Save(machine2InShared1) | 	err = app.db.MachineSave(machine2InShared1) | ||||||
|  | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	baseDomain := "foobar.headscale.net" | 	baseDomain := "foobar.headscale.net" | ||||||
| 	dnsConfigOrig := tailcfg.DNSConfig{ | 	dnsConfigOrig := tailcfg.DNSConfig{ | ||||||
| @ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { | |||||||
| 		Proxied: false, | 		Proxied: false, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) | 	peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	dnsConfig := getMapResponseDNSConfig( | 	dnsConfig := getMapResponseDNSConfig( | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
| @ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.GetUserResponse{User: user.toProto()}, nil | 	return &v1.GetUserResponse{User: user.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) CreateUser( | func (api headscaleV1APIServer) CreateUser( | ||||||
| @ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.CreateUserResponse{User: user.toProto()}, nil | 	return &v1.CreateUserResponse{User: user.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) RenameUser( | func (api headscaleV1APIServer) RenameUser( | ||||||
| @ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.RenameUserResponse{User: user.toProto()}, nil | 	return &v1.RenameUserResponse{User: user.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) DeleteUser( | func (api headscaleV1APIServer) DeleteUser( | ||||||
| @ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers( | |||||||
| 
 | 
 | ||||||
| 	response := make([]*v1.User, len(users)) | 	response := make([]*v1.User, len(users)) | ||||||
| 	for index, user := range users { | 	for index, user := range users { | ||||||
| 		response[index] = user.toProto() | 		response[index] = user.Proto() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Trace().Caller().Interface("users", response).Msg("") | 	log.Trace().Caller().Interface("users", response).Msg("") | ||||||
| @ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil | 	return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) ExpirePreAuthKey( | func (api headscaleV1APIServer) ExpirePreAuthKey( | ||||||
| @ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( | |||||||
| 
 | 
 | ||||||
| 	response := make([]*v1.PreAuthKey, len(preAuthKeys)) | 	response := make([]*v1.PreAuthKey, len(preAuthKeys)) | ||||||
| 	for index, key := range preAuthKeys { | 	for index, key := range preAuthKeys { | ||||||
| 		response[index] = key.toProto() | 		response[index] = key.Proto() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil | 	return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil | ||||||
| @ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine( | |||||||
| 		request.GetKey(), | 		request.GetKey(), | ||||||
| 		request.GetUser(), | 		request.GetUser(), | ||||||
| 		nil, | 		nil, | ||||||
| 		RegisterMethodCLI, | 		util.RegisterMethodCLI, | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil | 	return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) GetMachine( | func (api headscaleV1APIServer) GetMachine( | ||||||
| @ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.GetMachineResponse{Machine: machine.toProto()}, nil | 	return &v1.GetMachineResponse{Machine: machine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) SetTags( | func (api headscaleV1APIServer) SetTags( | ||||||
| @ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags( | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) | 	err = api.h.db.SetTags(machine, request.GetTags()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return &v1.SetTagsResponse{ | 		return &v1.SetTagsResponse{ | ||||||
| 			Machine: nil, | 			Machine: nil, | ||||||
| @ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags( | |||||||
| 		Strs("tags", request.GetTags()). | 		Strs("tags", request.GetTags()). | ||||||
| 		Msg("Changing tags of machine") | 		Msg("Changing tags of machine") | ||||||
| 
 | 
 | ||||||
| 	return &v1.SetTagsResponse{Machine: machine.toProto()}, nil | 	return &v1.SetTagsResponse{Machine: machine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func validateTag(tag string) error { | func validateTag(tag string) error { | ||||||
| @ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine( | |||||||
| 		Time("expiry", *machine.Expiry). | 		Time("expiry", *machine.Expiry). | ||||||
| 		Msg("machine expired") | 		Msg("machine expired") | ||||||
| 
 | 
 | ||||||
| 	return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil | 	return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) RenameMachine( | func (api headscaleV1APIServer) RenameMachine( | ||||||
| @ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine( | |||||||
| 		Str("new_name", request.GetNewName()). | 		Str("new_name", request.GetNewName()). | ||||||
| 		Msg("machine renamed") | 		Msg("machine renamed") | ||||||
| 
 | 
 | ||||||
| 	return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil | 	return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) ListMachines( | func (api headscaleV1APIServer) ListMachines( | ||||||
| @ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines( | |||||||
| 
 | 
 | ||||||
| 		response := make([]*v1.Machine, len(machines)) | 		response := make([]*v1.Machine, len(machines)) | ||||||
| 		for index, machine := range machines { | 		for index, machine := range machines { | ||||||
| 			response[index] = machine.toProto() | 			response[index] = machine.Proto() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return &v1.ListMachinesResponse{Machines: response}, nil | 		return &v1.ListMachinesResponse{Machines: response}, nil | ||||||
| @ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines( | |||||||
| 
 | 
 | ||||||
| 	response := make([]*v1.Machine, len(machines)) | 	response := make([]*v1.Machine, len(machines)) | ||||||
| 	for index, machine := range machines { | 	for index, machine := range machines { | ||||||
| 		m := machine.toProto() | 		m := machine.Proto() | ||||||
| 		validTags, invalidTags := getTags( | 		validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( | ||||||
| 			api.h.aclPolicy, |  | ||||||
| 			machine, | 			machine, | ||||||
| 			api.h.cfg.OIDC.StripEmaildomain, | 			api.h.cfg.OIDC.StripEmaildomain, | ||||||
| 		) | 		) | ||||||
| @ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil | 	return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) GetRoutes( | func (api headscaleV1APIServer) GetRoutes( | ||||||
| @ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.GetRoutesResponse{ | 	return &v1.GetRoutesResponse{ | ||||||
| 		Routes: Routes(routes).toProto(), | 		Routes: types.Routes(routes).Proto(), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.GetMachineRoutesResponse{ | 	return &v1.GetMachineRoutesResponse{ | ||||||
| 		Routes: Routes(routes).toProto(), | 		Routes: types.Routes(routes).Proto(), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey( | |||||||
| 	ctx context.Context, | 	ctx context.Context, | ||||||
| 	request *v1.ExpireApiKeyRequest, | 	request *v1.ExpireApiKeyRequest, | ||||||
| ) (*v1.ExpireApiKeyResponse, error) { | ) (*v1.ExpireApiKeyResponse, error) { | ||||||
| 	var apiKey *APIKey | 	var apiKey *types.APIKey | ||||||
| 	var err error | 	var err error | ||||||
| 
 | 
 | ||||||
| 	apiKey, err = api.h.db.GetAPIKey(request.Prefix) | 	apiKey, err = api.h.db.GetAPIKey(request.Prefix) | ||||||
| @ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys( | |||||||
| 
 | 
 | ||||||
| 	response := make([]*v1.ApiKey, len(apiKeys)) | 	response := make([]*v1.ApiKey, len(apiKeys)) | ||||||
| 	for index, key := range apiKeys { | 	for index, key := range apiKeys { | ||||||
| 		response[index] = key.toProto() | 		response[index] = key.Proto() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &v1.ListApiKeysResponse{ApiKeys: response}, nil | 	return &v1.ListApiKeysResponse{ApiKeys: response}, nil | ||||||
| @ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	newMachine := Machine{ | 	newMachine := types.Machine{ | ||||||
| 		MachineKey: request.GetKey(), | 		MachineKey: request.GetKey(), | ||||||
| 		Hostname:   request.GetName(), | 		Hostname:   request.GetName(), | ||||||
| 		GivenName:  givenName, | 		GivenName:  givenName, | ||||||
| @ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( | |||||||
| 		LastSeen:             &time.Time{}, | 		LastSeen:             &time.Time{}, | ||||||
| 		LastSuccessfulUpdate: &time.Time{}, | 		LastSuccessfulUpdate: &time.Time{}, | ||||||
| 
 | 
 | ||||||
| 		HostInfo: HostInfo(hostinfo), | 		HostInfo: types.HostInfo(hostinfo), | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	nodeKey := key.NodePublic{} | 	nodeKey := key.NodePublic{} | ||||||
| @ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( | |||||||
| 		registerCacheExpiration, | 		registerCacheExpiration, | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil | 	return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} | func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} | ||||||
|  | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,142 +0,0 @@ | |||||||
| package hscontrol |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"strings" |  | ||||||
| 
 |  | ||||||
| 	"go4.org/netipx" |  | ||||||
| 	"tailscale.com/tailcfg" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // This is borrowed from, and updated to use IPSet |  | ||||||
| // https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 |  | ||||||
| // TODO(kradalby): contribute upstream and make public. |  | ||||||
| var ( |  | ||||||
| 	zeroIP4 = netip.AddrFrom4([4]byte{}) |  | ||||||
| 	zeroIP6 = netip.AddrFrom16([16]byte{}) |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // parseIPSet parses arg as one: |  | ||||||
| // |  | ||||||
| //   - an IP address (IPv4 or IPv6) |  | ||||||
| //   - the string "*" to match everything (both IPv4 & IPv6) |  | ||||||
| //   - a CIDR (e.g. "192.168.0.0/16") |  | ||||||
| //   - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") |  | ||||||
| // |  | ||||||
| // bits, if non-nil, is the legacy SrcBits CIDR length to make a IP |  | ||||||
| // address (without a slash) treated as a CIDR of *bits length. |  | ||||||
| // nolint |  | ||||||
| func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { |  | ||||||
| 	var ipSet netipx.IPSetBuilder |  | ||||||
| 	if arg == "*" { |  | ||||||
| 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) |  | ||||||
| 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) |  | ||||||
| 
 |  | ||||||
| 		return ipSet.IPSet() |  | ||||||
| 	} |  | ||||||
| 	if strings.Contains(arg, "/") { |  | ||||||
| 		pfx, err := netip.ParsePrefix(arg) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		if pfx != pfx.Masked() { |  | ||||||
| 			return nil, fmt.Errorf("%v contains non-network bits set", pfx) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		ipSet.AddPrefix(pfx) |  | ||||||
| 
 |  | ||||||
| 		return ipSet.IPSet() |  | ||||||
| 	} |  | ||||||
| 	if strings.Count(arg, "-") == 1 { |  | ||||||
| 		ip1s, ip2s, _ := strings.Cut(arg, "-") |  | ||||||
| 
 |  | ||||||
| 		ip1, err := netip.ParseAddr(ip1s) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		ip2, err := netip.ParseAddr(ip2s) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		r := netipx.IPRangeFrom(ip1, ip2) |  | ||||||
| 		if !r.IsValid() { |  | ||||||
| 			return nil, fmt.Errorf("invalid IP range %q", arg) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		for _, prefix := range r.Prefixes() { |  | ||||||
| 			ipSet.AddPrefix(prefix) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return ipSet.IPSet() |  | ||||||
| 	} |  | ||||||
| 	ip, err := netip.ParseAddr(arg) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("invalid IP address %q", arg) |  | ||||||
| 	} |  | ||||||
| 	bits8 := uint8(ip.BitLen()) |  | ||||||
| 	if bits != nil { |  | ||||||
| 		if *bits < 0 || *bits > int(bits8) { |  | ||||||
| 			return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) |  | ||||||
| 		} |  | ||||||
| 		bits8 = uint8(*bits) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) |  | ||||||
| 
 |  | ||||||
| 	return ipSet.IPSet() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Match struct { |  | ||||||
| 	Srcs  *netipx.IPSet |  | ||||||
| 	Dests *netipx.IPSet |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func MatchFromFilterRule(rule tailcfg.FilterRule) Match { |  | ||||||
| 	srcs := new(netipx.IPSetBuilder) |  | ||||||
| 	dests := new(netipx.IPSetBuilder) |  | ||||||
| 
 |  | ||||||
| 	for _, srcIP := range rule.SrcIPs { |  | ||||||
| 		set, _ := parseIPSet(srcIP, nil) |  | ||||||
| 
 |  | ||||||
| 		srcs.AddSet(set) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, dest := range rule.DstPorts { |  | ||||||
| 		set, _ := parseIPSet(dest.IP, nil) |  | ||||||
| 
 |  | ||||||
| 		dests.AddSet(set) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	srcsSet, _ := srcs.IPSet() |  | ||||||
| 	destsSet, _ := dests.IPSet() |  | ||||||
| 
 |  | ||||||
| 	match := Match{ |  | ||||||
| 		Srcs:  srcsSet, |  | ||||||
| 		Dests: destsSet, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return match |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { |  | ||||||
| 	for _, ip := range ips { |  | ||||||
| 		if m.Srcs.Contains(ip) { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *Match) DestsContainsIP(ips []netip.Addr) bool { |  | ||||||
| 	for _, ip := range ips { |  | ||||||
| 		if m.Dests.Contains(ip) { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
| @ -14,6 +14,8 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc/v3/oidc" | 	"github.com/coreos/go-oidc/v3/oidc" | ||||||
| 	"github.com/gorilla/mux" | 	"github.com/gorilla/mux" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/db" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| @ -638,7 +640,7 @@ func getUserName( | |||||||
| 	claims *IDTokenClaims, | 	claims *IDTokenClaims, | ||||||
| 	stripEmaildomain bool, | 	stripEmaildomain bool, | ||||||
| ) (string, error) { | ) (string, error) { | ||||||
| 	userName, err := NormalizeToFQDNRules( | 	userName, err := util.NormalizeToFQDNRules( | ||||||
| 		claims.Email, | 		claims.Email, | ||||||
| 		stripEmaildomain, | 		stripEmaildomain, | ||||||
| 	) | 	) | ||||||
| @ -663,9 +665,9 @@ func getUserName( | |||||||
| func (h *Headscale) findOrCreateNewUserForOIDCCallback( | func (h *Headscale) findOrCreateNewUserForOIDCCallback( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	userName string, | 	userName string, | ||||||
| ) (*User, error) { | ) (*types.User, error) { | ||||||
| 	user, err := h.db.GetUser(userName) | 	user, err := h.db.GetUser(userName) | ||||||
| 	if errors.Is(err, ErrUserNotFound) { | 	if errors.Is(err, db.ErrUserNotFound) { | ||||||
| 		user, err = h.db.CreateUser(userName) | 		user, err = h.db.CreateUser(userName) | ||||||
| 
 | 
 | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( | |||||||
| 
 | 
 | ||||||
| func (h *Headscale) registerMachineForOIDCCallback( | func (h *Headscale) registerMachineForOIDCCallback( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	user *User, | 	user *types.User, | ||||||
| 	nodeKey *key.NodePublic, | 	nodeKey *key.NodePublic, | ||||||
| 	expiry time.Time, | 	expiry time.Time, | ||||||
| ) error { | ) error { | ||||||
| @ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback( | |||||||
| 		nodeKey.String(), | 		nodeKey.String(), | ||||||
| 		user.Name, | 		user.Name, | ||||||
| 		&expiry, | 		&expiry, | ||||||
| 		RegisterMethodOIDC, | 		util.RegisterMethodOIDC, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| package hscontrol | package policy | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| @ -12,6 +12,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"github.com/tailscale/hujson" | 	"github.com/tailscale/hujson" | ||||||
| @ -22,12 +23,12 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	errEmptyPolicy       = errors.New("empty policy") | 	ErrEmptyPolicy       = errors.New("empty policy") | ||||||
| 	errInvalidAction     = errors.New("invalid action") | 	ErrInvalidAction     = errors.New("invalid action") | ||||||
| 	errInvalidGroup      = errors.New("invalid group") | 	ErrInvalidGroup      = errors.New("invalid group") | ||||||
| 	errInvalidTag        = errors.New("invalid tag") | 	ErrInvalidTag        = errors.New("invalid tag") | ||||||
| 	errInvalidPortFormat = errors.New("invalid port format") | 	ErrInvalidPortFormat = errors.New("invalid port format") | ||||||
| 	errWildcardIsNeeded  = errors.New("wildcard as port is required for the protocol") | 	ErrWildcardIsNeeded  = errors.New("wildcard as port is required for the protocol") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -56,7 +57,7 @@ const ( | |||||||
| var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") | var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") | ||||||
| 
 | 
 | ||||||
| // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. | // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. | ||||||
| func (h *Headscale) LoadACLPolicyFromPath(path string) error { | func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { | ||||||
| 	log.Debug(). | 	log.Debug(). | ||||||
| 		Str("func", "LoadACLPolicy"). | 		Str("func", "LoadACLPolicy"). | ||||||
| 		Str("path", path). | 		Str("path", path). | ||||||
| @ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { | |||||||
| 
 | 
 | ||||||
| 	policyFile, err := os.Open(path) | 	policyFile, err := os.Open(path) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer policyFile.Close() | 	defer policyFile.Close() | ||||||
| 
 | 
 | ||||||
| 	policyBytes, err := io.ReadAll(policyFile) | 	policyBytes, err := io.ReadAll(policyFile) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Debug(). | 	log.Debug(). | ||||||
| @ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { | |||||||
| 
 | 
 | ||||||
| 	switch filepath.Ext(path) { | 	switch filepath.Ext(path) { | ||||||
| 	case ".yml", ".yaml": | 	case ".yml", ".yaml": | ||||||
| 		return h.LoadACLPolicyFromBytes(policyBytes, "yaml") | 		return LoadACLPolicyFromBytes(policyBytes, "yaml") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return h.LoadACLPolicyFromBytes(policyBytes, "hujson") | 	return LoadACLPolicyFromBytes(policyBytes, "hujson") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { | func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { | ||||||
| 	var policy ACLPolicy | 	var policy ACLPolicy | ||||||
| 	switch format { | 	switch format { | ||||||
| 	case "yaml": | 	case "yaml": | ||||||
| 		err := yaml.Unmarshal(acl, &policy) | 		err := yaml.Unmarshal(acl, &policy) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	default: | 	default: | ||||||
| 		ast, err := hujson.Parse(acl) | 		ast, err := hujson.Parse(acl) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		ast.Standardize() | 		ast.Standardize() | ||||||
| 		acl = ast.Pack() | 		acl = ast.Pack() | ||||||
| 		err = json.Unmarshal(acl, &policy) | 		err = json.Unmarshal(acl, &policy) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if policy.IsZero() { | 	if policy.IsZero() { | ||||||
| 		return errEmptyPolicy | 		return nil, ErrEmptyPolicy | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	h.aclPolicy = &policy | 	return &policy, nil | ||||||
| 
 |  | ||||||
| 	return h.UpdateACLRules() |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) UpdateACLRules() error { | // TODO(kradalby): This needs to be replace with something that generates | ||||||
| 	machines, err := h.db.ListMachines() | // the rules as needed and not stores it on the global object, rules are | ||||||
| 	if err != nil { | // per node and that should be taken into account. | ||||||
| 		return err | func GenerateFilterRules( | ||||||
|  | 	policy *ACLPolicy, | ||||||
|  | 	machines types.Machines, | ||||||
|  | 	stripEmailDomain bool, | ||||||
|  | ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { | ||||||
|  | 	if policy == nil { | ||||||
|  | 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if h.aclPolicy == nil { | 	rules, err := policy.generateFilterRules(machines, stripEmailDomain) | ||||||
| 		return errEmptyPolicy |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Trace().Interface("ACL", rules).Msg("ACL rules generated") | 	log.Trace().Interface("ACL", rules).Msg("ACL rules generated") | ||||||
| 	h.aclRules = rules |  | ||||||
| 
 | 
 | ||||||
|  | 	var sshPolicy *tailcfg.SSHPolicy | ||||||
| 	if featureEnableSSH() { | 	if featureEnableSSH() { | ||||||
| 		sshRules, err := h.generateSSHRules() | 		sshRules, err := generateSSHRules(policy, machines, stripEmailDomain) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err | ||||||
| 		} | 		} | ||||||
| 		log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") | 		log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") | ||||||
| 		if h.sshPolicy == nil { | 		if sshPolicy == nil { | ||||||
| 			h.sshPolicy = &tailcfg.SSHPolicy{} | 			sshPolicy = &tailcfg.SSHPolicy{} | ||||||
| 		} | 		} | ||||||
| 		h.sshPolicy.Rules = sshRules | 		sshPolicy.Rules = sshRules | ||||||
| 	} else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 { | 	} else if policy != nil && len(policy.SSHs) > 0 { | ||||||
| 		log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") | 		log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return rules, sshPolicy, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // generateFilterRules takes a set of machines and an ACLPolicy and generates a | // generateFilterRules takes a set of machines and an ACLPolicy and generates a | ||||||
| // set of Tailscale compatible FilterRules used to allow traffic on clients. | // set of Tailscale compatible FilterRules used to allow traffic on clients. | ||||||
| func (pol *ACLPolicy) generateFilterRules( | func (pol *ACLPolicy) generateFilterRules( | ||||||
| 	machines []Machine, | 	machines types.Machines, | ||||||
| 	stripEmailDomain bool, | 	stripEmailDomain bool, | ||||||
| ) ([]tailcfg.FilterRule, error) { | ) ([]tailcfg.FilterRule, error) { | ||||||
| 	rules := []tailcfg.FilterRule{} | 	rules := []tailcfg.FilterRule{} | ||||||
| 
 | 
 | ||||||
| 	for index, acl := range pol.ACLs { | 	for index, acl := range pol.ACLs { | ||||||
| 		if acl.Action != "accept" { | 		if acl.Action != "accept" { | ||||||
| 			return nil, errInvalidAction | 			return nil, ErrInvalidAction | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		srcIPs := []string{} | 		srcIPs := []string{} | ||||||
| @ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules( | |||||||
| 	return rules, nil | 	return rules, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | func generateSSHRules( | ||||||
|  | 	policy *ACLPolicy, | ||||||
|  | 	machines types.Machines, | ||||||
|  | 	stripEmailDomain bool, | ||||||
|  | ) ([]*tailcfg.SSHRule, error) { | ||||||
| 	rules := []*tailcfg.SSHRule{} | 	rules := []*tailcfg.SSHRule{} | ||||||
| 
 | 
 | ||||||
| 	if h.aclPolicy == nil { | 	if policy == nil { | ||||||
| 		return nil, errEmptyPolicy | 		return nil, ErrEmptyPolicy | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	machines, err := h.db.ListMachines() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	acceptAction := tailcfg.SSHAction{ | 	acceptAction := tailcfg.SSHAction{ | ||||||
| @ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | |||||||
| 		AllowLocalPortForwarding: false, | 		AllowLocalPortForwarding: false, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for index, sshACL := range h.aclPolicy.SSHs { | 	for index, sshACL := range policy.SSHs { | ||||||
| 		action := rejectAction | 		action := rejectAction | ||||||
| 		switch sshACL.Action { | 		switch sshACL.Action { | ||||||
| 		case "accept": | 		case "accept": | ||||||
| @ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | |||||||
| 			} | 			} | ||||||
| 		default: | 		default: | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action) | 				Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action) | ||||||
| 
 | 
 | ||||||
| 			return nil, err | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) | 		principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) | ||||||
| @ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | |||||||
| 					Any: true, | 					Any: true, | ||||||
| 				}) | 				}) | ||||||
| 			} else if isGroup(rawSrc) { | 			} else if isGroup(rawSrc) { | ||||||
| 				users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain) | 				users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Error(). | 					log.Error(). | ||||||
| 						Msgf("Error parsing SSH %d, Source %d", index, innerIndex) | 						Msgf("Error parsing SSH %d, Source %d", index, innerIndex) | ||||||
| @ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { | |||||||
| 					}) | 					}) | ||||||
| 				} | 				} | ||||||
| 			} else { | 			} else { | ||||||
| 				expandedSrcs, err := h.aclPolicy.expandAlias( | 				expandedSrcs, err := policy.ExpandAlias( | ||||||
| 					machines, | 					machines, | ||||||
| 					rawSrc, | 					rawSrc, | ||||||
| 					h.cfg.OIDC.StripEmaildomain, | 					stripEmailDomain, | ||||||
| 				) | 				) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Error(). | 					log.Error(). | ||||||
| @ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { | |||||||
| // with the given src alias. | // with the given src alias. | ||||||
| func (pol *ACLPolicy) getIPsFromSource( | func (pol *ACLPolicy) getIPsFromSource( | ||||||
| 	src string, | 	src string, | ||||||
| 	machines []Machine, | 	machines types.Machines, | ||||||
| 	stripEmaildomain bool, | 	stripEmaildomain bool, | ||||||
| ) ([]string, error) { | ) ([]string, error) { | ||||||
| 	ipSet, err := pol.expandAlias(machines, src, stripEmaildomain) | 	ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return []string{}, err | 		return []string{}, err | ||||||
| 	} | 	} | ||||||
| @ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource( | |||||||
| // which are associated with the dest alias. | // which are associated with the dest alias. | ||||||
| func (pol *ACLPolicy) getNetPortRangeFromDestination( | func (pol *ACLPolicy) getNetPortRangeFromDestination( | ||||||
| 	dest string, | 	dest string, | ||||||
| 	machines []Machine, | 	machines types.Machines, | ||||||
| 	needsWildcard bool, | 	needsWildcard bool, | ||||||
| 	stripEmaildomain bool, | 	stripEmaildomain bool, | ||||||
| ) ([]tailcfg.NetPortRange, error) { | ) ([]tailcfg.NetPortRange, error) { | ||||||
| @ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( | |||||||
| 			return nil, fmt.Errorf( | 			return nil, fmt.Errorf( | ||||||
| 				"failed to parse destination, tokens %v: %w", | 				"failed to parse destination, tokens %v: %w", | ||||||
| 				tokens, | 				tokens, | ||||||
| 				errInvalidPortFormat, | 				ErrInvalidPortFormat, | ||||||
| 			) | 			) | ||||||
| 		} else { | 		} else { | ||||||
| 			tokens = []string{maybeIPv6Str, port} | 			tokens = []string{maybeIPv6Str, port} | ||||||
| @ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( | |||||||
| 		alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) | 		alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	expanded, err := pol.expandAlias( | 	expanded, err := pol.ExpandAlias( | ||||||
| 		machines, | 		machines, | ||||||
| 		alias, | 		alias, | ||||||
| 		stripEmaildomain, | 		stripEmaildomain, | ||||||
| @ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) { | |||||||
| // - an ip | // - an ip | ||||||
| // - a cidr | // - a cidr | ||||||
| // and transform these in IPAddresses. | // and transform these in IPAddresses. | ||||||
| func (pol *ACLPolicy) expandAlias( | func (pol *ACLPolicy) ExpandAlias( | ||||||
| 	machines Machines, | 	machines types.Machines, | ||||||
| 	alias string, | 	alias string, | ||||||
| 	stripEmailDomain bool, | 	stripEmailDomain bool, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	if isWildcard(alias) { | 	if isWildcard(alias) { | ||||||
| 		return parseIPSet("*", nil) | 		return util.ParseIPSet("*", nil) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	build := netipx.IPSetBuilder{} | 	build := netipx.IPSetBuilder{} | ||||||
| @ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias( | |||||||
| 	// if alias is an host | 	// if alias is an host | ||||||
| 	// Note, this is recursive. | 	// Note, this is recursive. | ||||||
| 	if h, ok := pol.Hosts[alias]; ok { | 	if h, ok := pol.Hosts[alias]; ok { | ||||||
| 		log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") | 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") | ||||||
| 
 | 
 | ||||||
| 		return pol.expandAlias(machines, h.String(), stripEmailDomain) | 		return pol.ExpandAlias(machines, h.String(), stripEmailDomain) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// if alias is an IP | 	// if alias is an IP | ||||||
| @ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias( | |||||||
| // we assume in this function that we only have nodes from 1 user. | // we assume in this function that we only have nodes from 1 user. | ||||||
| func excludeCorrectlyTaggedNodes( | func excludeCorrectlyTaggedNodes( | ||||||
| 	aclPolicy *ACLPolicy, | 	aclPolicy *ACLPolicy, | ||||||
| 	nodes []Machine, | 	nodes types.Machines, | ||||||
| 	user string, | 	user string, | ||||||
| 	stripEmailDomain bool, | 	stripEmailDomain bool, | ||||||
| ) []Machine { | ) types.Machines { | ||||||
| 	out := []Machine{} | 	out := types.Machines{} | ||||||
| 	tags := []string{} | 	tags := []string{} | ||||||
| 	for tag := range aclPolicy.TagOwners { | 	for tag := range aclPolicy.TagOwners { | ||||||
| 		owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) | 		owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) | ||||||
| @ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if needsWildcard { | 	if needsWildcard { | ||||||
| 		return nil, errWildcardIsNeeded | 		return nil, ErrWildcardIsNeeded | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ports := []tailcfg.PortRange{} | 	ports := []tailcfg.PortRange{} | ||||||
| @ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err | |||||||
| 			}) | 			}) | ||||||
| 
 | 
 | ||||||
| 		default: | 		default: | ||||||
| 			return nil, errInvalidPortFormat | 			return nil, ErrInvalidPortFormat | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &ports, nil | 	return &ports, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func filterMachinesByUser(machines []Machine, user string) []Machine { | func filterMachinesByUser(machines types.Machines, user string) types.Machines { | ||||||
| 	out := []Machine{} | 	out := types.Machines{} | ||||||
| 	for _, machine := range machines { | 	for _, machine := range machines { | ||||||
| 		if machine.User.Name == user { | 		if machine.User.Name == user { | ||||||
| 			out = append(out, machine) | 			out = append(out, machine) | ||||||
| @ -664,7 +664,7 @@ func getTagOwners( | |||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return []string{}, fmt.Errorf( | 		return []string{}, fmt.Errorf( | ||||||
| 			"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", | 			"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", | ||||||
| 			errInvalidTag, | 			ErrInvalidTag, | ||||||
| 			tag, | 			tag, | ||||||
| 		) | 		) | ||||||
| 	} | 	} | ||||||
| @ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup( | |||||||
| 		return []string{}, fmt.Errorf( | 		return []string{}, fmt.Errorf( | ||||||
| 			"group %v isn't registered. %w", | 			"group %v isn't registered. %w", | ||||||
| 			group, | 			group, | ||||||
| 			errInvalidGroup, | 			ErrInvalidGroup, | ||||||
| 		) | 		) | ||||||
| 	} | 	} | ||||||
| 	for _, group := range aclGroups { | 	for _, group := range aclGroups { | ||||||
| 		if isGroup(group) { | 		if isGroup(group) { | ||||||
| 			return []string{}, fmt.Errorf( | 			return []string{}, fmt.Errorf( | ||||||
| 				"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", | 				"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", | ||||||
| 				errInvalidGroup, | 				ErrInvalidGroup, | ||||||
| 			) | 			) | ||||||
| 		} | 		} | ||||||
| 		grp, err := NormalizeToFQDNRules(group, stripEmailDomain) | 		grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return []string{}, fmt.Errorf( | 			return []string{}, fmt.Errorf( | ||||||
| 				"failed to normalize group %q, err: %w", | 				"failed to normalize group %q, err: %w", | ||||||
| 				group, | 				group, | ||||||
| 				errInvalidGroup, | 				ErrInvalidGroup, | ||||||
| 			) | 			) | ||||||
| 		} | 		} | ||||||
| 		users = append(users, grp) | 		users = append(users, grp) | ||||||
| @ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) getIPsFromGroup( | func (pol *ACLPolicy) getIPsFromGroup( | ||||||
| 	group string, | 	group string, | ||||||
| 	machines Machines, | 	machines types.Machines, | ||||||
| 	stripEmailDomain bool, | 	stripEmailDomain bool, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	build := netipx.IPSetBuilder{} | 	build := netipx.IPSetBuilder{} | ||||||
| @ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) getIPsFromTag( | func (pol *ACLPolicy) getIPsFromTag( | ||||||
| 	alias string, | 	alias string, | ||||||
| 	machines Machines, | 	machines types.Machines, | ||||||
| 	stripEmailDomain bool, | 	stripEmailDomain bool, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	build := netipx.IPSetBuilder{} | 	build := netipx.IPSetBuilder{} | ||||||
| @ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag( | |||||||
| 	// find tag owners | 	// find tag owners | ||||||
| 	owners, err := getTagOwners(pol, alias, stripEmailDomain) | 	owners, err := getTagOwners(pol, alias, stripEmailDomain) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if errors.Is(err, errInvalidTag) { | 		if errors.Is(err, ErrInvalidTag) { | ||||||
| 			ipSet, _ := build.IPSet() | 			ipSet, _ := build.IPSet() | ||||||
| 			if len(ipSet.Prefixes()) == 0 { | 			if len(ipSet.Prefixes()) == 0 { | ||||||
| 				return ipSet, fmt.Errorf( | 				return ipSet, fmt.Errorf( | ||||||
| 					"%w. %v isn't owned by a TagOwner and no forced tags are defined", | 					"%w. %v isn't owned by a TagOwner and no forced tags are defined", | ||||||
| 					errInvalidTag, | 					ErrInvalidTag, | ||||||
| 					alias, | 					alias, | ||||||
| 				) | 				) | ||||||
| 			} | 			} | ||||||
| @ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) getIPsForUser( | func (pol *ACLPolicy) getIPsForUser( | ||||||
| 	user string, | 	user string, | ||||||
| 	machines Machines, | 	machines types.Machines, | ||||||
| 	stripEmailDomain bool, | 	stripEmailDomain bool, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	build := netipx.IPSetBuilder{} | 	build := netipx.IPSetBuilder{} | ||||||
| @ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) getIPsFromSingleIP( | func (pol *ACLPolicy) getIPsFromSingleIP( | ||||||
| 	ip netip.Addr, | 	ip netip.Addr, | ||||||
| 	machines Machines, | 	machines types.Machines, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") | 	log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip") | ||||||
| 
 | 
 | ||||||
| 	matches := machines.FilterByIP(ip) | 	matches := machines.FilterByIP(ip) | ||||||
| 
 | 
 | ||||||
| @ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP( | |||||||
| 
 | 
 | ||||||
| func (pol *ACLPolicy) getIPsFromIPPrefix( | func (pol *ACLPolicy) getIPsFromIPPrefix( | ||||||
| 	prefix netip.Prefix, | 	prefix netip.Prefix, | ||||||
| 	machines Machines, | 	machines types.Machines, | ||||||
| ) (*netipx.IPSet, error) { | ) (*netipx.IPSet, error) { | ||||||
| 	log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") | 	log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") | ||||||
| 	build := netipx.IPSetBuilder{} | 	build := netipx.IPSetBuilder{} | ||||||
| @ -862,3 +862,65 @@ func isGroup(str string) bool { | |||||||
| func isTag(str string) bool { | func isTag(str string) bool { | ||||||
| 	return strings.HasPrefix(str, "tag:") | 	return strings.HasPrefix(str, "tag:") | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // getTags will return the tags of the current machine. | ||||||
|  | // Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. | ||||||
|  | // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. | ||||||
|  | func (pol *ACLPolicy) GetTagsOfMachine( | ||||||
|  | 	machine types.Machine, | ||||||
|  | 	stripEmailDomain bool, | ||||||
|  | ) ([]string, []string) { | ||||||
|  | 	validTags := make([]string, 0) | ||||||
|  | 	invalidTags := make([]string, 0) | ||||||
|  | 
 | ||||||
|  | 	validTagMap := make(map[string]bool) | ||||||
|  | 	invalidTagMap := make(map[string]bool) | ||||||
|  | 	for _, tag := range machine.HostInfo.RequestTags { | ||||||
|  | 		owners, err := getTagOwners(pol, tag, stripEmailDomain) | ||||||
|  | 		if errors.Is(err, ErrInvalidTag) { | ||||||
|  | 			invalidTagMap[tag] = true | ||||||
|  | 
 | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		var found bool | ||||||
|  | 		for _, owner := range owners { | ||||||
|  | 			if machine.User.Name == owner { | ||||||
|  | 				found = true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if found { | ||||||
|  | 			validTagMap[tag] = true | ||||||
|  | 		} else { | ||||||
|  | 			invalidTagMap[tag] = true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	for tag := range invalidTagMap { | ||||||
|  | 		invalidTags = append(invalidTags, tag) | ||||||
|  | 	} | ||||||
|  | 	for tag := range validTagMap { | ||||||
|  | 		validTags = append(validTags, tag) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return validTags, invalidTags | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine. | ||||||
|  | func FilterMachinesByACL( | ||||||
|  | 	machine *types.Machine, | ||||||
|  | 	machines types.Machines, | ||||||
|  | 	filter []tailcfg.FilterRule, | ||||||
|  | ) types.Machines { | ||||||
|  | 	result := types.Machines{} | ||||||
|  | 
 | ||||||
|  | 	for index, peer := range machines { | ||||||
|  | 		if peer.ID == machine.ID { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) { | ||||||
|  | 			result = append(result, peer) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return result | ||||||
|  | } | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,4 +1,4 @@ | |||||||
| package hscontrol | package policy | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
							
								
								
									
										61
									
								
								hscontrol/policy/matcher/matcher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								hscontrol/policy/matcher/matcher.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,61 @@ | |||||||
|  | package matcher | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
|  | 	"go4.org/netipx" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Match struct { | ||||||
|  | 	Srcs  *netipx.IPSet | ||||||
|  | 	Dests *netipx.IPSet | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func MatchFromFilterRule(rule tailcfg.FilterRule) Match { | ||||||
|  | 	srcs := new(netipx.IPSetBuilder) | ||||||
|  | 	dests := new(netipx.IPSetBuilder) | ||||||
|  | 
 | ||||||
|  | 	for _, srcIP := range rule.SrcIPs { | ||||||
|  | 		set, _ := util.ParseIPSet(srcIP, nil) | ||||||
|  | 
 | ||||||
|  | 		srcs.AddSet(set) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, dest := range rule.DstPorts { | ||||||
|  | 		set, _ := util.ParseIPSet(dest.IP, nil) | ||||||
|  | 
 | ||||||
|  | 		dests.AddSet(set) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	srcsSet, _ := srcs.IPSet() | ||||||
|  | 	destsSet, _ := dests.IPSet() | ||||||
|  | 
 | ||||||
|  | 	match := Match{ | ||||||
|  | 		Srcs:  srcsSet, | ||||||
|  | 		Dests: destsSet, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return match | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { | ||||||
|  | 	for _, ip := range ips { | ||||||
|  | 		if m.Srcs.Contains(ip) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *Match) DestsContainsIP(ips []netip.Addr) bool { | ||||||
|  | 	for _, ip := range ips { | ||||||
|  | 		if m.Dests.Contains(ip) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								hscontrol/policy/matcher/matcher_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								hscontrol/policy/matcher/matcher_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | |||||||
|  | package matcher | ||||||
| @ -9,6 +9,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| @ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon( | |||||||
| 		// that we rely on a method that calls back some how (OpenID or CLI) | 		// that we rely on a method that calls back some how (OpenID or CLI) | ||||||
| 		// We create the machine and then keep it around until a callback | 		// We create the machine and then keep it around until a callback | ||||||
| 		// happens | 		// happens | ||||||
| 		newMachine := Machine{ | 		newMachine := types.Machine{ | ||||||
| 			MachineKey: util.MachinePublicKeyStripPrefix(machineKey), | 			MachineKey: util.MachinePublicKeyStripPrefix(machineKey), | ||||||
| 			Hostname:   registerRequest.Hostinfo.Hostname, | 			Hostname:   registerRequest.Hostinfo.Hostname, | ||||||
| 			GivenName:  givenName, | 			GivenName:  givenName, | ||||||
| @ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon( | |||||||
| 			[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), | 			[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), | ||||||
| 		) | 		) | ||||||
| 		if err != nil || storedMachineKey.IsZero() { | 		if err != nil || storedMachineKey.IsZero() { | ||||||
| 			machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) | 			if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil { | ||||||
| 			if err := h.db.db.Save(&machine).Error; err != nil { |  | ||||||
| 				log.Error(). | 				log.Error(). | ||||||
| 					Caller(). | 					Caller(). | ||||||
| 					Str("func", "RegistrationHandler"). | 					Str("func", "RegistrationHandler"). | ||||||
| @ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon( | |||||||
| 
 | 
 | ||||||
| 			// If machine is not expired, and it is register, we have a already accepted this machine, | 			// If machine is not expired, and it is register, we have a already accepted this machine, | ||||||
| 			// let it proceed with a valid registration | 			// let it proceed with a valid registration | ||||||
| 			if !machine.isExpired() { | 			if !machine.IsExpired() { | ||||||
| 				h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) | 				h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) | ||||||
| 
 | 
 | ||||||
| 				return | 				return | ||||||
| @ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon( | |||||||
| 
 | 
 | ||||||
| 		// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration | 		// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration | ||||||
| 		if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && | 		if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && | ||||||
| 			!machine.isExpired() { | 			!machine.IsExpired() { | ||||||
| 			h.handleMachineRefreshKeyCommon( | 			h.handleMachineRefreshKeyCommon( | ||||||
| 				writer, | 				writer, | ||||||
| 				registerRequest, | 				registerRequest, | ||||||
| @ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 		Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) | 		Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) | ||||||
| 	resp := tailcfg.RegisterResponse{} | 	resp := tailcfg.RegisterResponse{} | ||||||
| 
 | 
 | ||||||
| 	pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) | 	pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| @ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 				Err(err). | 				Err(err). | ||||||
| 				Msg("Cannot encode message") | 				Msg("Cannot encode message") | ||||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||||
| 				Inc() | 				Inc() | ||||||
| 
 | 
 | ||||||
| 			return | 			return | ||||||
| @ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 			Msg("Failed authentication via AuthKey") | 			Msg("Failed authentication via AuthKey") | ||||||
| 
 | 
 | ||||||
| 		if pak != nil { | 		if pak != nil { | ||||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||||
| 				Inc() | 				Inc() | ||||||
| 		} else { | 		} else { | ||||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc() | 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| @ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		aclTags := pak.toProto().AclTags | 		aclTags := pak.Proto().AclTags | ||||||
| 		if len(aclTags) > 0 { | 		if len(aclTags) > 0 { | ||||||
| 			// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login | 			// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login | ||||||
| 			err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) | 			err = h.db.SetTags(machine, aclTags) | ||||||
| 
 | 
 | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Error(). | 				log.Error(). | ||||||
| @ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		machineToRegister := Machine{ | 		machineToRegister := types.Machine{ | ||||||
| 			Hostname:       registerRequest.Hostinfo.Hostname, | 			Hostname:       registerRequest.Hostinfo.Hostname, | ||||||
| 			GivenName:      givenName, | 			GivenName:      givenName, | ||||||
| 			UserID:         pak.User.ID, | 			UserID:         pak.User.ID, | ||||||
| 			MachineKey:     util.MachinePublicKeyStripPrefix(machineKey), | 			MachineKey:     util.MachinePublicKeyStripPrefix(machineKey), | ||||||
| 			RegisterMethod: RegisterMethodAuthKey, | 			RegisterMethod: util.RegisterMethodAuthKey, | ||||||
| 			Expiry:         ®isterRequest.Expiry, | 			Expiry:         ®isterRequest.Expiry, | ||||||
| 			NodeKey:        nodeKey, | 			NodeKey:        nodeKey, | ||||||
| 			LastSeen:       &now, | 			LastSeen:       &now, | ||||||
| 			AuthKeyID:      uint(pak.ID), | 			AuthKeyID:      uint(pak.ID), | ||||||
| 			ForcedTags:     pak.toProto().AclTags, | 			ForcedTags:     pak.Proto().AclTags, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		machine, err = h.db.RegisterMachine( | 		machine, err = h.db.RegisterMachine( | ||||||
| @ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 				Bool("noise", isNoise). | 				Bool("noise", isNoise). | ||||||
| 				Err(err). | 				Err(err). | ||||||
| 				Msg("could not register machine") | 				Msg("could not register machine") | ||||||
| 			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | 			machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||||
| 				Inc() | 				Inc() | ||||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||||
| 
 | 
 | ||||||
| @ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 			Bool("noise", isNoise). | 			Bool("noise", isNoise). | ||||||
| 			Err(err). | 			Err(err). | ||||||
| 			Msg("Failed to use pre-auth key") | 			Msg("Failed to use pre-auth key") | ||||||
| 		machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | 		machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||||
| 			Inc() | 			Inc() | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||||
| 
 | 
 | ||||||
| @ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	resp.MachineAuthorized = true | 	resp.MachineAuthorized = true | ||||||
| 	resp.User = *pak.User.toTailscaleUser() | 	resp.User = *pak.User.TailscaleUser() | ||||||
| 	// Provide LoginName when registering with pre-auth key | 	// Provide LoginName when registering with pre-auth key | ||||||
| 	// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* | 	// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* | ||||||
| 	resp.Login = *pak.User.toTailscaleLogin() | 	resp.Login = *pak.User.TailscaleLogin() | ||||||
| 
 | 
 | ||||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon( | |||||||
| 			Str("machine", registerRequest.Hostinfo.Hostname). | 			Str("machine", registerRequest.Hostinfo.Hostname). | ||||||
| 			Err(err). | 			Err(err). | ||||||
| 			Msg("Cannot encode message") | 			Msg("Cannot encode message") | ||||||
| 		machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). | 		machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). | ||||||
| 			Inc() | 			Inc() | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name). | 	machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name). | ||||||
| 		Inc() | 		Inc() | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") | 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") | ||||||
| 	writer.WriteHeader(http.StatusOK) | 	writer.WriteHeader(http.StatusOK) | ||||||
| @ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon( | |||||||
| 
 | 
 | ||||||
| func (h *Headscale) handleMachineLogOutCommon( | func (h *Headscale) handleMachineLogOutCommon( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	machine Machine, | 	machine types.Machine, | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) { | ) { | ||||||
| @ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon( | |||||||
| 	resp.AuthURL = "" | 	resp.AuthURL = "" | ||||||
| 	resp.MachineAuthorized = false | 	resp.MachineAuthorized = false | ||||||
| 	resp.NodeKeyExpired = true | 	resp.NodeKeyExpired = true | ||||||
| 	resp.User = *machine.User.toTailscaleUser() | 	resp.User = *machine.User.TailscaleUser() | ||||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| @ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if machine.isEphemeral() { | 	if machine.IsEphemeral() { | ||||||
| 		err = h.db.HardDeleteMachine(&machine) | 		err = h.db.HardDeleteMachine(&machine) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| @ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon( | |||||||
| 
 | 
 | ||||||
| func (h *Headscale) handleMachineValidRegistrationCommon( | func (h *Headscale) handleMachineValidRegistrationCommon( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	machine Machine, | 	machine types.Machine, | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) { | ) { | ||||||
| @ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon( | |||||||
| 
 | 
 | ||||||
| 	resp.AuthURL = "" | 	resp.AuthURL = "" | ||||||
| 	resp.MachineAuthorized = true | 	resp.MachineAuthorized = true | ||||||
| 	resp.User = *machine.User.toTailscaleUser() | 	resp.User = *machine.User.TailscaleUser() | ||||||
| 	resp.Login = *machine.User.toTailscaleLogin() | 	resp.Login = *machine.User.TailscaleLogin() | ||||||
| 
 | 
 | ||||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon( | |||||||
| func (h *Headscale) handleMachineRefreshKeyCommon( | func (h *Headscale) handleMachineRefreshKeyCommon( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	registerRequest tailcfg.RegisterRequest, | 	registerRequest tailcfg.RegisterRequest, | ||||||
| 	machine Machine, | 	machine types.Machine, | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) { | ) { | ||||||
| @ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon( | |||||||
| 		Bool("noise", isNoise). | 		Bool("noise", isNoise). | ||||||
| 		Str("machine", machine.Hostname). | 		Str("machine", machine.Hostname). | ||||||
| 		Msg("We have the OldNodeKey in the database. This is a key refresh") | 		Msg("We have the OldNodeKey in the database. This is a key refresh") | ||||||
| 	machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) |  | ||||||
| 
 | 
 | ||||||
| 	if err := h.db.db.Save(&machine).Error; err != nil { | 	err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey) | ||||||
|  | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| 			Err(err). | 			Err(err). | ||||||
| @ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	resp.AuthURL = "" | 	resp.AuthURL = "" | ||||||
| 	resp.User = *machine.User.toTailscaleUser() | 	resp.User = *machine.User.TailscaleUser() | ||||||
| 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | 	respBody, err := h.marshalResponse(resp, machineKey, isNoise) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| @ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( | |||||||
| func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( | func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	registerRequest tailcfg.RegisterRequest, | 	registerRequest tailcfg.RegisterRequest, | ||||||
| 	machine Machine, | 	machine types.Machine, | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) { | ) { | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| @ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName") | |||||||
| func (h *Headscale) handlePollCommon( | func (h *Headscale) handlePollCommon( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	ctx context.Context, | 	ctx context.Context, | ||||||
| 	machine *Machine, | 	machine *types.Machine, | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) { | ) { | ||||||
| 	machine.Hostname = mapRequest.Hostinfo.Hostname | 	machine.Hostname = mapRequest.Hostinfo.Hostname | ||||||
| 	machine.HostInfo = HostInfo(*mapRequest.Hostinfo) | 	machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) | ||||||
| 	machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) | 	machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) | ||||||
| 	now := time.Now().UTC() | 	now := time.Now().UTC() | ||||||
| 
 | 
 | ||||||
| 	err := h.db.processMachineRoutes(machine) | 	err := h.db.ProcessMachineRoutes(machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| @ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// update ACLRules with peer informations (to update server tags if necessary) | 	// update ACLRules with peer informations (to update server tags if necessary) | ||||||
| 	if h.aclPolicy != nil { | 	if h.ACLPolicy != nil { | ||||||
| 		err := h.UpdateACLRules() | 		// TODO(kradalby): Since this is not blocking, I might have introduced a bug here. | ||||||
| 		if err != nil { | 		// It will be resolved later as we change up the policy stuff. | ||||||
| 			log.Error(). | 		h.policyUpdateChan <- struct{}{} | ||||||
| 				Caller(). |  | ||||||
| 				Bool("noise", isNoise). |  | ||||||
| 				Str("machine", machine.Hostname). |  | ||||||
| 				Err(err) |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		// update routes with peer information | 		// update routes with peer information | ||||||
| 		err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) | 		err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Caller(). | 				Caller(). | ||||||
| @ -78,19 +74,17 @@ func (h *Headscale) handlePollCommon( | |||||||
| 		machine.LastSeen = &now | 		machine.LastSeen = &now | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := h.db.db.Updates(machine).Error; err != nil { | 	if err := h.db.MachineSave(machine); err != nil { | ||||||
| 		if err != nil { | 		log.Error(). | ||||||
| 			log.Error(). | 			Str("handler", "PollNetMap"). | ||||||
| 				Str("handler", "PollNetMap"). | 			Bool("noise", isNoise). | ||||||
| 				Bool("noise", isNoise). | 			Str("node_key", machine.NodeKey). | ||||||
| 				Str("node_key", machine.NodeKey). | 			Str("machine", machine.Hostname). | ||||||
| 				Str("machine", machine.Hostname). | 			Err(err). | ||||||
| 				Err(err). | 			Msg("Failed to persist/update machine in the database") | ||||||
| 				Msg("Failed to persist/update machine in the database") | 		http.Error(writer, "", http.StatusInternalServerError) | ||||||
| 			http.Error(writer, "", http.StatusInternalServerError) |  | ||||||
| 
 | 
 | ||||||
| 			return | 		return | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) | 	mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) | ||||||
| @ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon( | |||||||
| func (h *Headscale) pollNetMapStream( | func (h *Headscale) pollNetMapStream( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	ctxReq context.Context, | 	ctxReq context.Context, | ||||||
| 	machine *Machine, | 	machine *types.Machine, | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	pollDataChan chan []byte, | 	pollDataChan chan []byte, | ||||||
| 	keepAliveChan chan []byte, | 	keepAliveChan chan []byte, | ||||||
| @ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream( | |||||||
| 			updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). | 			updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). | ||||||
| 				Inc() | 				Inc() | ||||||
| 
 | 
 | ||||||
| 			if h.db.isOutdated(machine, h.getLastStateChange()) { | 			if h.db.IsOutdated(machine, h.getLastStateChange()) { | ||||||
| 				var lastUpdate time.Time | 				var lastUpdate time.Time | ||||||
| 				if machine.LastSuccessfulUpdate != nil { | 				if machine.LastSuccessfulUpdate != nil { | ||||||
| 					lastUpdate = *machine.LastSuccessfulUpdate | 					lastUpdate = *machine.LastSuccessfulUpdate | ||||||
| @ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker( | |||||||
| 	updateChan chan struct{}, | 	updateChan chan struct{}, | ||||||
| 	keepAliveChan chan []byte, | 	keepAliveChan chan []byte, | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	machine *Machine, | 	machine *types.Machine, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) { | ) { | ||||||
| 	keepAliveTicker := time.NewTicker(keepAliveInterval) | 	keepAliveTicker := time.NewTicker(keepAliveInterval) | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/klauspost/compress/zstd" | 	"github.com/klauspost/compress/zstd" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| @ -15,7 +16,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| func (h *Headscale) getMapResponseData( | func (h *Headscale) getMapResponseData( | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	machine *Machine, | 	machine *types.Machine, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) ([]byte, error) { | ) ([]byte, error) { | ||||||
| 	mapResponse, err := h.generateMapResponse(mapRequest, machine) | 	mapResponse, err := h.generateMapResponse(mapRequest, machine) | ||||||
| @ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData( | |||||||
| 
 | 
 | ||||||
| func (h *Headscale) getMapKeepAliveResponseData( | func (h *Headscale) getMapKeepAliveResponseData( | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	machine *Machine, | 	machine *types.Machine, | ||||||
| 	isNoise bool, | 	isNoise bool, | ||||||
| ) ([]byte, error) { | ) ([]byte, error) { | ||||||
| 	keepAliveResponse := tailcfg.MapResponse{ | 	keepAliveResponse := tailcfg.MapResponse{ | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ type Suite struct{} | |||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	tmpDir string | 	tmpDir string | ||||||
| 	app    Headscale | 	app    *Headscale | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s *Suite) SetUpTest(c *check.C) { | func (s *Suite) SetUpTest(c *check.C) { | ||||||
| @ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) { | |||||||
| 		os.RemoveAll(tmpDir) | 		os.RemoveAll(tmpDir) | ||||||
| 	} | 	} | ||||||
| 	var err error | 	var err error | ||||||
| 	tmpDir, err = os.MkdirTemp("", "autoygg-client-test") | 	tmpDir, err = os.MkdirTemp("", "autoygg-client-test2") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.Fatal(err) | 		c.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	cfg := Config{ | 	cfg := Config{ | ||||||
|  | 		PrivateKeyPath:      tmpDir + "/private.key", | ||||||
|  | 		NoisePrivateKeyPath: tmpDir + "/noise_private.key", | ||||||
|  | 		DBtype:              "sqlite3", | ||||||
|  | 		DBpath:              tmpDir + "/headscale_test.db", | ||||||
| 		IPPrefixes: []netip.Prefix{ | 		IPPrefixes: []netip.Prefix{ | ||||||
| 			netip.MustParsePrefix("10.27.0.0/23"), | 			netip.MustParsePrefix("10.27.0.0/23"), | ||||||
| 		}, | 		}, | ||||||
| @ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// TODO(kradalby): make this use NewHeadscale properly so it doesnt drift | 	app, err = NewHeadscale(&cfg) | ||||||
| 	app = Headscale{ |  | ||||||
| 		cfg:      &cfg, |  | ||||||
| 		dbType:   "sqlite3", |  | ||||||
| 		dbString: tmpDir + "/headscale_test.db", |  | ||||||
| 
 |  | ||||||
| 		stateUpdateChan:       make(chan struct{}), |  | ||||||
| 		cancelStateUpdateChan: make(chan struct{}), |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	go app.watchStateChannel() |  | ||||||
| 
 |  | ||||||
| 	db, err := NewHeadscaleDatabase( |  | ||||||
| 		app.dbType, |  | ||||||
| 		app.dbString, |  | ||||||
| 		cfg.OIDC.StripEmaildomain, |  | ||||||
| 		false, |  | ||||||
| 		app.stateUpdateChan, |  | ||||||
| 		cfg.IPPrefixes, |  | ||||||
| 		"", |  | ||||||
| 	) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.Fatal(err) | 		c.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	app.db = db |  | ||||||
| } | } | ||||||
							
								
								
									
										41
									
								
								hscontrol/types/api_key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								hscontrol/types/api_key.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,41 @@ | |||||||
|  | package types | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"google.golang.org/protobuf/types/known/timestamppb" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // APIKey describes the datamodel for API keys used to remotely authenticate with | ||||||
|  | // headscale. | ||||||
|  | type APIKey struct { | ||||||
|  | 	ID     uint64 `gorm:"primary_key"` | ||||||
|  | 	Prefix string `gorm:"uniqueIndex"` | ||||||
|  | 	Hash   []byte | ||||||
|  | 
 | ||||||
|  | 	CreatedAt  *time.Time | ||||||
|  | 	Expiration *time.Time | ||||||
|  | 	LastSeen   *time.Time | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (key *APIKey) Proto() *v1.ApiKey { | ||||||
|  | 	protoKey := v1.ApiKey{ | ||||||
|  | 		Id:     key.ID, | ||||||
|  | 		Prefix: key.Prefix, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if key.Expiration != nil { | ||||||
|  | 		protoKey.Expiration = timestamppb.New(*key.Expiration) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if key.CreatedAt != nil { | ||||||
|  | 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if key.LastSeen != nil { | ||||||
|  | 		protoKey.LastSeen = timestamppb.New(*key.LastSeen) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &protoKey | ||||||
|  | } | ||||||
							
								
								
									
										108
									
								
								hscontrol/types/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								hscontrol/types/common.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,108 @@ | |||||||
|  | package types | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/netip" | ||||||
|  | 
 | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ErrCannotParsePrefix = errors.New("cannot parse prefix") | ||||||
|  | 
 | ||||||
|  | // This is a "wrapper" type around tailscales | ||||||
|  | // Hostinfo to allow us to add database "serialization" | ||||||
|  | // methods. This allows us to use a typed values throughout | ||||||
|  | // the code and not have to marshal/unmarshal and error | ||||||
|  | // check all over the code. | ||||||
|  | type HostInfo tailcfg.Hostinfo | ||||||
|  | 
 | ||||||
|  | func (hi *HostInfo) Scan(destination interface{}) error { | ||||||
|  | 	switch value := destination.(type) { | ||||||
|  | 	case []byte: | ||||||
|  | 		return json.Unmarshal(value, hi) | ||||||
|  | 
 | ||||||
|  | 	case string: | ||||||
|  | 		return json.Unmarshal([]byte(value), hi) | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Value return json value, implement driver.Valuer interface. | ||||||
|  | func (hi HostInfo) Value() (driver.Value, error) { | ||||||
|  | 	bytes, err := json.Marshal(hi) | ||||||
|  | 
 | ||||||
|  | 	return string(bytes), err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type IPPrefix netip.Prefix | ||||||
|  | 
 | ||||||
|  | func (i *IPPrefix) Scan(destination interface{}) error { | ||||||
|  | 	switch value := destination.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		prefix, err := netip.ParsePrefix(value) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		*i = IPPrefix(prefix) | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Value return json value, implement driver.Valuer interface. | ||||||
|  | func (i IPPrefix) Value() (driver.Value, error) { | ||||||
|  | 	prefixStr := netip.Prefix(i).String() | ||||||
|  | 
 | ||||||
|  | 	return prefixStr, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type IPPrefixes []netip.Prefix | ||||||
|  | 
 | ||||||
|  | func (i *IPPrefixes) Scan(destination interface{}) error { | ||||||
|  | 	switch value := destination.(type) { | ||||||
|  | 	case []byte: | ||||||
|  | 		return json.Unmarshal(value, i) | ||||||
|  | 
 | ||||||
|  | 	case string: | ||||||
|  | 		return json.Unmarshal([]byte(value), i) | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Value return json value, implement driver.Valuer interface. | ||||||
|  | func (i IPPrefixes) Value() (driver.Value, error) { | ||||||
|  | 	bytes, err := json.Marshal(i) | ||||||
|  | 
 | ||||||
|  | 	return string(bytes), err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type StringList []string | ||||||
|  | 
 | ||||||
|  | func (i *StringList) Scan(destination interface{}) error { | ||||||
|  | 	switch value := destination.(type) { | ||||||
|  | 	case []byte: | ||||||
|  | 		return json.Unmarshal(value, i) | ||||||
|  | 
 | ||||||
|  | 	case string: | ||||||
|  | 		return json.Unmarshal([]byte(value), i) | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Value return json value, implement driver.Valuer interface. | ||||||
|  | func (i StringList) Value() (driver.Value, error) { | ||||||
|  | 	bytes, err := json.Marshal(i) | ||||||
|  | 
 | ||||||
|  | 	return string(bytes), err | ||||||
|  | } | ||||||
							
								
								
									
										254
									
								
								hscontrol/types/machine.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								hscontrol/types/machine.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,254 @@ | |||||||
|  | package types | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/policy/matcher" | ||||||
|  | 	"go4.org/netipx" | ||||||
|  | 	"google.golang.org/protobuf/types/known/timestamppb" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	// TODO(kradalby): Move out of here when we got circdeps under control. | ||||||
|  | 	keepAliveInterval = 60 * time.Second | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") | ||||||
|  | 
 | ||||||
|  | // Machine is a Headscale client. | ||||||
|  | type Machine struct { | ||||||
|  | 	ID          uint64 `gorm:"primary_key"` | ||||||
|  | 	MachineKey  string `gorm:"type:varchar(64);unique_index"` | ||||||
|  | 	NodeKey     string | ||||||
|  | 	DiscoKey    string | ||||||
|  | 	IPAddresses MachineAddresses | ||||||
|  | 
 | ||||||
|  | 	// Hostname represents the name given by the Tailscale | ||||||
|  | 	// client during registration | ||||||
|  | 	Hostname string | ||||||
|  | 
 | ||||||
|  | 	// Givenname represents either: | ||||||
|  | 	// a DNS normalized version of Hostname | ||||||
|  | 	// a valid name set by the User | ||||||
|  | 	// | ||||||
|  | 	// GivenName is the name used in all DNS related | ||||||
|  | 	// parts of headscale. | ||||||
|  | 	GivenName string `gorm:"type:varchar(63);unique_index"` | ||||||
|  | 	UserID    uint | ||||||
|  | 	User      User `gorm:"foreignKey:UserID"` | ||||||
|  | 
 | ||||||
|  | 	RegisterMethod string | ||||||
|  | 
 | ||||||
|  | 	ForcedTags StringList | ||||||
|  | 
 | ||||||
|  | 	// TODO(kradalby): This seems like irrelevant information? | ||||||
|  | 	AuthKeyID uint | ||||||
|  | 	AuthKey   *PreAuthKey | ||||||
|  | 
 | ||||||
|  | 	LastSeen             *time.Time | ||||||
|  | 	LastSuccessfulUpdate *time.Time | ||||||
|  | 	Expiry               *time.Time | ||||||
|  | 
 | ||||||
|  | 	HostInfo  HostInfo | ||||||
|  | 	Endpoints StringList | ||||||
|  | 
 | ||||||
|  | 	CreatedAt time.Time | ||||||
|  | 	UpdatedAt time.Time | ||||||
|  | 	DeletedAt *time.Time | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ( | ||||||
|  | 	Machines  []Machine | ||||||
|  | 	MachinesP []*Machine | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type MachineAddresses []netip.Addr | ||||||
|  | 
 | ||||||
|  | func (ma MachineAddresses) ToStringSlice() []string { | ||||||
|  | 	strSlice := make([]string, 0, len(ma)) | ||||||
|  | 	for _, addr := range ma { | ||||||
|  | 		strSlice = append(strSlice, addr.String()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return strSlice | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // AppendToIPSet adds the individual ips in MachineAddresses to a | ||||||
|  | // given netipx.IPSetBuilder. | ||||||
|  | func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { | ||||||
|  | 	for _, ip := range ma { | ||||||
|  | 		build.Add(ip) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ma *MachineAddresses) Scan(destination interface{}) error { | ||||||
|  | 	switch value := destination.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		addresses := strings.Split(value, ",") | ||||||
|  | 		*ma = (*ma)[:0] | ||||||
|  | 		for _, addr := range addresses { | ||||||
|  | 			if len(addr) < 1 { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			parsed, err := netip.ParseAddr(addr) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			*ma = append(*ma, parsed) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Value return json value, implement driver.Valuer interface. | ||||||
|  | func (ma MachineAddresses) Value() (driver.Value, error) { | ||||||
|  | 	addresses := strings.Join(ma.ToStringSlice(), ",") | ||||||
|  | 
 | ||||||
|  | 	return addresses, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsExpired returns whether the machine registration has expired. | ||||||
|  | func (machine Machine) IsExpired() bool { | ||||||
|  | 	// If Expiry is not set, the client has not indicated that | ||||||
|  | 	// it wants an expiry time, it is therefor considered | ||||||
|  | 	// to mean "not expired" | ||||||
|  | 	if machine.Expiry == nil || machine.Expiry.IsZero() { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return time.Now().UTC().After(*machine.Expiry) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsOnline returns if the machine is connected to Headscale. | ||||||
|  | // This is really a naive implementation, as we don't really see | ||||||
|  | // if there is a working connection between the client and the server. | ||||||
|  | func (machine *Machine) IsOnline() bool { | ||||||
|  | 	if machine.LastSeen == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if machine.IsExpired() { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsEphemeral returns if the machine is registered as an Ephemeral node. | ||||||
|  | // https://tailscale.com/kb/1111/ephemeral-nodes/ | ||||||
|  | func (machine *Machine) IsEphemeral() bool { | ||||||
|  | 	return machine.AuthKey != nil && machine.AuthKey.Ephemeral | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { | ||||||
|  | 	for _, rule := range filter { | ||||||
|  | 		// TODO(kradalby): Cache or pregen this | ||||||
|  | 		matcher := matcher.MatchFromFilterRule(rule) | ||||||
|  | 
 | ||||||
|  | 		if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (machines Machines) FilterByIP(ip netip.Addr) Machines { | ||||||
|  | 	found := make(Machines, 0) | ||||||
|  | 
 | ||||||
|  | 	for _, machine := range machines { | ||||||
|  | 		for _, mIP := range machine.IPAddresses { | ||||||
|  | 			if ip == mIP { | ||||||
|  | 				found = append(found, machine) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return found | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (machine *Machine) Proto() *v1.Machine { | ||||||
|  | 	machineProto := &v1.Machine{ | ||||||
|  | 		Id:         machine.ID, | ||||||
|  | 		MachineKey: machine.MachineKey, | ||||||
|  | 
 | ||||||
|  | 		NodeKey:     machine.NodeKey, | ||||||
|  | 		DiscoKey:    machine.DiscoKey, | ||||||
|  | 		IpAddresses: machine.IPAddresses.ToStringSlice(), | ||||||
|  | 		Name:        machine.Hostname, | ||||||
|  | 		GivenName:   machine.GivenName, | ||||||
|  | 		User:        machine.User.Proto(), | ||||||
|  | 		ForcedTags:  machine.ForcedTags, | ||||||
|  | 		Online:      machine.IsOnline(), | ||||||
|  | 
 | ||||||
|  | 		// TODO(kradalby): Implement register method enum converter | ||||||
|  | 		// RegisterMethod: , | ||||||
|  | 
 | ||||||
|  | 		CreatedAt: timestamppb.New(machine.CreatedAt), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if machine.AuthKey != nil { | ||||||
|  | 		machineProto.PreAuthKey = machine.AuthKey.Proto() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if machine.LastSeen != nil { | ||||||
|  | 		machineProto.LastSeen = timestamppb.New(*machine.LastSeen) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if machine.LastSuccessfulUpdate != nil { | ||||||
|  | 		machineProto.LastSuccessfulUpdate = timestamppb.New( | ||||||
|  | 			*machine.LastSuccessfulUpdate, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if machine.Expiry != nil { | ||||||
|  | 		machineProto.Expiry = timestamppb.New(*machine.Expiry) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return machineProto | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetHostInfo returns a Hostinfo struct for the machine. | ||||||
|  | func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { | ||||||
|  | 	return tailcfg.Hostinfo(machine.HostInfo) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (machine Machine) String() string { | ||||||
|  | 	return machine.Hostname | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (machines Machines) String() string { | ||||||
|  | 	temp := make([]string, len(machines)) | ||||||
|  | 
 | ||||||
|  | 	for index, machine := range machines { | ||||||
|  | 		temp[index] = machine.Hostname | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TODO(kradalby): Remove when we have generics... | ||||||
|  | func (machines MachinesP) String() string { | ||||||
|  | 	temp := make([]string, len(machines)) | ||||||
|  | 
 | ||||||
|  | 	for index, machine := range machines { | ||||||
|  | 		temp[index] = machine.Hostname | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								hscontrol/types/machine_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								hscontrol/types/machine_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | |||||||
|  | package types | ||||||
							
								
								
									
										58
									
								
								hscontrol/types/preauth_key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								hscontrol/types/preauth_key.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | |||||||
|  | package types | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
|  | 	"google.golang.org/protobuf/types/known/timestamppb" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // PreAuthKey describes a pre-authorization key usable in a particular user. | ||||||
|  | type PreAuthKey struct { | ||||||
|  | 	ID        uint64 `gorm:"primary_key"` | ||||||
|  | 	Key       string | ||||||
|  | 	UserID    uint | ||||||
|  | 	User      User | ||||||
|  | 	Reusable  bool | ||||||
|  | 	Ephemeral bool `gorm:"default:false"` | ||||||
|  | 	Used      bool `gorm:"default:false"` | ||||||
|  | 	ACLTags   []PreAuthKeyACLTag | ||||||
|  | 
 | ||||||
|  | 	CreatedAt  *time.Time | ||||||
|  | 	Expiration *time.Time | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. | ||||||
|  | type PreAuthKeyACLTag struct { | ||||||
|  | 	ID           uint64 `gorm:"primary_key"` | ||||||
|  | 	PreAuthKeyID uint64 | ||||||
|  | 	Tag          string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (key *PreAuthKey) Proto() *v1.PreAuthKey { | ||||||
|  | 	protoKey := v1.PreAuthKey{ | ||||||
|  | 		User:      key.User.Name, | ||||||
|  | 		Id:        strconv.FormatUint(key.ID, util.Base10), | ||||||
|  | 		Key:       key.Key, | ||||||
|  | 		Ephemeral: key.Ephemeral, | ||||||
|  | 		Reusable:  key.Reusable, | ||||||
|  | 		Used:      key.Used, | ||||||
|  | 		AclTags:   make([]string, len(key.ACLTags)), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if key.Expiration != nil { | ||||||
|  | 		protoKey.Expiration = timestamppb.New(*key.Expiration) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if key.CreatedAt != nil { | ||||||
|  | 		protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for idx := range key.ACLTags { | ||||||
|  | 		protoKey.AclTags[idx] = key.ACLTags[idx].Tag | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &protoKey | ||||||
|  | } | ||||||
							
								
								
									
										71
									
								
								hscontrol/types/routes.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								hscontrol/types/routes.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,71 @@ | |||||||
|  | package types | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/netip" | ||||||
|  | 
 | ||||||
|  | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"google.golang.org/protobuf/types/known/timestamppb" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") | ||||||
|  | 	ExitRouteV6 = netip.MustParsePrefix("::/0") | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Route struct { | ||||||
|  | 	gorm.Model | ||||||
|  | 
 | ||||||
|  | 	MachineID uint64 | ||||||
|  | 	Machine   Machine | ||||||
|  | 	Prefix    IPPrefix | ||||||
|  | 
 | ||||||
|  | 	Advertised bool | ||||||
|  | 	Enabled    bool | ||||||
|  | 	IsPrimary  bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Routes []Route | ||||||
|  | 
 | ||||||
|  | func (r *Route) String() string { | ||||||
|  | 	return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (r *Route) IsExitRoute() bool { | ||||||
|  | 	return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rs Routes) Prefixes() []netip.Prefix { | ||||||
|  | 	prefixes := make([]netip.Prefix, len(rs)) | ||||||
|  | 	for i, r := range rs { | ||||||
|  | 		prefixes[i] = netip.Prefix(r.Prefix) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return prefixes | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rs Routes) Proto() []*v1.Route { | ||||||
|  | 	protoRoutes := []*v1.Route{} | ||||||
|  | 
 | ||||||
|  | 	for _, route := range rs { | ||||||
|  | 		protoRoute := v1.Route{ | ||||||
|  | 			Id:         uint64(route.ID), | ||||||
|  | 			Machine:    route.Machine.Proto(), | ||||||
|  | 			Prefix:     netip.Prefix(route.Prefix).String(), | ||||||
|  | 			Advertised: route.Advertised, | ||||||
|  | 			Enabled:    route.Enabled, | ||||||
|  | 			IsPrimary:  route.IsPrimary, | ||||||
|  | 			CreatedAt:  timestamppb.New(route.CreatedAt), | ||||||
|  | 			UpdatedAt:  timestamppb.New(route.UpdatedAt), | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if route.DeletedAt.Valid { | ||||||
|  | 			protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		protoRoutes = append(protoRoutes, &protoRoute) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return protoRoutes | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								hscontrol/types/users.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								hscontrol/types/users.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,55 @@ | |||||||
|  | package types | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
|  | 	"google.golang.org/protobuf/types/known/timestamppb" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // User is the way Headscale implements the concept of users in Tailscale | ||||||
|  | // | ||||||
|  | // At the end of the day, users in Tailscale are some kind of 'bubbles' or users | ||||||
|  | // that contain our machines. | ||||||
|  | type User struct { | ||||||
|  | 	gorm.Model | ||||||
|  | 	Name string `gorm:"unique"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (n *User) TailscaleUser() *tailcfg.User { | ||||||
|  | 	user := tailcfg.User{ | ||||||
|  | 		ID:            tailcfg.UserID(n.ID), | ||||||
|  | 		LoginName:     n.Name, | ||||||
|  | 		DisplayName:   n.Name, | ||||||
|  | 		ProfilePicURL: "", | ||||||
|  | 		Domain:        "headscale.net", | ||||||
|  | 		Logins:        []tailcfg.LoginID{}, | ||||||
|  | 		Created:       time.Time{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &user | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (n *User) TailscaleLogin() *tailcfg.Login { | ||||||
|  | 	login := tailcfg.Login{ | ||||||
|  | 		ID:            tailcfg.LoginID(n.ID), | ||||||
|  | 		LoginName:     n.Name, | ||||||
|  | 		DisplayName:   n.Name, | ||||||
|  | 		ProfilePicURL: "", | ||||||
|  | 		Domain:        "headscale.net", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &login | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (n *User) Proto() *v1.User { | ||||||
|  | 	return &v1.User{ | ||||||
|  | 		Id:        strconv.FormatUint(uint64(n.ID), util.Base10), | ||||||
|  | 		Name:      n.Name, | ||||||
|  | 		CreatedAt: timestamppb.New(n.CreatedAt), | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -1,415 +0,0 @@ | |||||||
| package hscontrol |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"net/netip" |  | ||||||
| 	"testing" |  | ||||||
| 
 |  | ||||||
| 	"gopkg.in/check.v1" |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func (s *Suite) TestCreateAndDestroyUser(c *check.C) { |  | ||||||
| 	user, err := app.db.CreateUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(user.Name, check.Equals, "test") |  | ||||||
| 
 |  | ||||||
| 	users, err := app.db.ListUsers() |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(len(users), check.Equals, 1) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.DestroyUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetUser("test") |  | ||||||
| 	c.Assert(err, check.NotNil) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *Suite) TestDestroyUserErrors(c *check.C) { |  | ||||||
| 	err := app.db.DestroyUser("test") |  | ||||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) |  | ||||||
| 
 |  | ||||||
| 	user, err := app.db.CreateUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.DestroyUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key) |  | ||||||
| 	// destroying a user also deletes all associated preauthkeys |  | ||||||
| 	c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) |  | ||||||
| 
 |  | ||||||
| 	user, err = app.db.CreateUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	machine := Machine{ |  | ||||||
| 		ID:             0, |  | ||||||
| 		MachineKey:     "foo", |  | ||||||
| 		NodeKey:        "bar", |  | ||||||
| 		DiscoKey:       "faa", |  | ||||||
| 		Hostname:       "testmachine", |  | ||||||
| 		UserID:         user.ID, |  | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, |  | ||||||
| 		AuthKeyID:      uint(pak.ID), |  | ||||||
| 	} |  | ||||||
| 	app.db.db.Save(&machine) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.DestroyUser("test") |  | ||||||
| 	c.Assert(err, check.Equals, ErrUserStillHasNodes) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *Suite) TestRenameUser(c *check.C) { |  | ||||||
| 	userTest, err := app.db.CreateUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(userTest.Name, check.Equals, "test") |  | ||||||
| 
 |  | ||||||
| 	users, err := app.db.ListUsers() |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(len(users), check.Equals, 1) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.RenameUser("test", "test-renamed") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetUser("test") |  | ||||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetUser("test-renamed") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.RenameUser("test-does-not-exit", "test") |  | ||||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) |  | ||||||
| 
 |  | ||||||
| 	userTest2, err := app.db.CreateUser("test2") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(userTest2.Name, check.Equals, "test2") |  | ||||||
| 
 |  | ||||||
| 	err = app.db.RenameUser("test2", "test-renamed") |  | ||||||
| 	c.Assert(err, check.Equals, ErrUserExists) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { |  | ||||||
| 	userShared1, err := app.db.CreateUser("shared1") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	userShared2, err := app.db.CreateUser("shared2") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	userShared3, err := app.db.CreateUser("shared3") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	preAuthKeyShared1, err := app.db.CreatePreAuthKey( |  | ||||||
| 		userShared1.Name, |  | ||||||
| 		false, |  | ||||||
| 		false, |  | ||||||
| 		nil, |  | ||||||
| 		nil, |  | ||||||
| 	) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	preAuthKeyShared2, err := app.db.CreatePreAuthKey( |  | ||||||
| 		userShared2.Name, |  | ||||||
| 		false, |  | ||||||
| 		false, |  | ||||||
| 		nil, |  | ||||||
| 		nil, |  | ||||||
| 	) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	preAuthKeyShared3, err := app.db.CreatePreAuthKey( |  | ||||||
| 		userShared3.Name, |  | ||||||
| 		false, |  | ||||||
| 		false, |  | ||||||
| 		nil, |  | ||||||
| 		nil, |  | ||||||
| 	) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	preAuthKey2Shared1, err := app.db.CreatePreAuthKey( |  | ||||||
| 		userShared1.Name, |  | ||||||
| 		false, |  | ||||||
| 		false, |  | ||||||
| 		nil, |  | ||||||
| 		nil, |  | ||||||
| 	) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") |  | ||||||
| 	c.Assert(err, check.NotNil) |  | ||||||
| 
 |  | ||||||
| 	machineInShared1 := &Machine{ |  | ||||||
| 		ID:             1, |  | ||||||
| 		MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", |  | ||||||
| 		NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", |  | ||||||
| 		DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", |  | ||||||
| 		Hostname:       "test_get_shared_nodes_1", |  | ||||||
| 		UserID:         userShared1.ID, |  | ||||||
| 		User:           *userShared1, |  | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, |  | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.1")}, |  | ||||||
| 		AuthKeyID:      uint(preAuthKeyShared1.ID), |  | ||||||
| 	} |  | ||||||
| 	app.db.db.Save(machineInShared1) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	machineInShared2 := &Machine{ |  | ||||||
| 		ID:             2, |  | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		Hostname:       "test_get_shared_nodes_2", |  | ||||||
| 		UserID:         userShared2.ID, |  | ||||||
| 		User:           *userShared2, |  | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, |  | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.2")}, |  | ||||||
| 		AuthKeyID:      uint(preAuthKeyShared2.ID), |  | ||||||
| 	} |  | ||||||
| 	app.db.db.Save(machineInShared2) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	machineInShared3 := &Machine{ |  | ||||||
| 		ID:             3, |  | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		Hostname:       "test_get_shared_nodes_3", |  | ||||||
| 		UserID:         userShared3.ID, |  | ||||||
| 		User:           *userShared3, |  | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, |  | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.3")}, |  | ||||||
| 		AuthKeyID:      uint(preAuthKeyShared3.ID), |  | ||||||
| 	} |  | ||||||
| 	app.db.db.Save(machineInShared3) |  | ||||||
| 
 |  | ||||||
| 	_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	machine2InShared1 := &Machine{ |  | ||||||
| 		ID:             4, |  | ||||||
| 		MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", |  | ||||||
| 		Hostname:       "test_get_shared_nodes_4", |  | ||||||
| 		UserID:         userShared1.ID, |  | ||||||
| 		User:           *userShared1, |  | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, |  | ||||||
| 		IPAddresses:    []netip.Addr{netip.MustParseAddr("100.64.0.4")}, |  | ||||||
| 		AuthKeyID:      uint(preAuthKey2Shared1.ID), |  | ||||||
| 	} |  | ||||||
| 	app.db.db.Save(machine2InShared1) |  | ||||||
| 
 |  | ||||||
| 	peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	userProfiles := app.db.getMapResponseUserProfiles( |  | ||||||
| 		*machineInShared1, |  | ||||||
| 		peersOfMachine1InShared1, |  | ||||||
| 	) |  | ||||||
| 
 |  | ||||||
| 	c.Assert(len(userProfiles), check.Equals, 3) |  | ||||||
| 
 |  | ||||||
| 	found := false |  | ||||||
| 	for _, userProfiles := range userProfiles { |  | ||||||
| 		if userProfiles.DisplayName == userShared1.Name { |  | ||||||
| 			found = true |  | ||||||
| 
 |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	c.Assert(found, check.Equals, true) |  | ||||||
| 
 |  | ||||||
| 	found = false |  | ||||||
| 	for _, userProfile := range userProfiles { |  | ||||||
| 		if userProfile.DisplayName == userShared2.Name { |  | ||||||
| 			found = true |  | ||||||
| 
 |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	c.Assert(found, check.Equals, true) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestNormalizeToFQDNRules(t *testing.T) { |  | ||||||
| 	type args struct { |  | ||||||
| 		name             string |  | ||||||
| 		stripEmailDomain bool |  | ||||||
| 	} |  | ||||||
| 	tests := []struct { |  | ||||||
| 		name    string |  | ||||||
| 		args    args |  | ||||||
| 		want    string |  | ||||||
| 		wantErr bool |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			name: "normalize simple name", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "normalize-simple.name", |  | ||||||
| 				stripEmailDomain: false, |  | ||||||
| 			}, |  | ||||||
| 			want:    "normalize-simple.name", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "normalize an email", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "foo.bar@example.com", |  | ||||||
| 				stripEmailDomain: false, |  | ||||||
| 			}, |  | ||||||
| 			want:    "foo.bar.example.com", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "normalize an email domain should be removed", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "foo.bar@example.com", |  | ||||||
| 				stripEmailDomain: true, |  | ||||||
| 			}, |  | ||||||
| 			want:    "foo.bar", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "strip enabled no email passed as argument", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "not-email-and-strip-enabled", |  | ||||||
| 				stripEmailDomain: true, |  | ||||||
| 			}, |  | ||||||
| 			want:    "not-email-and-strip-enabled", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "normalize complex email", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "foo.bar+complex-email@example.com", |  | ||||||
| 				stripEmailDomain: false, |  | ||||||
| 			}, |  | ||||||
| 			want:    "foo.bar-complex-email.example.com", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "user name with space", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "name space", |  | ||||||
| 				stripEmailDomain: false, |  | ||||||
| 			}, |  | ||||||
| 			want:    "name-space", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "user with quote", |  | ||||||
| 			args: args{ |  | ||||||
| 				name:             "Jamie's iPhone 5", |  | ||||||
| 				stripEmailDomain: false, |  | ||||||
| 			}, |  | ||||||
| 			want:    "jamies-iphone-5", |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	for _, tt := range tests { |  | ||||||
| 		t.Run(tt.name, func(t *testing.T) { |  | ||||||
| 			got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) |  | ||||||
| 			if (err != nil) != tt.wantErr { |  | ||||||
| 				t.Errorf( |  | ||||||
| 					"NormalizeToFQDNRules() error = %v, wantErr %v", |  | ||||||
| 					err, |  | ||||||
| 					tt.wantErr, |  | ||||||
| 				) |  | ||||||
| 
 |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			if got != tt.want { |  | ||||||
| 				t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestCheckForFQDNRules(t *testing.T) { |  | ||||||
| 	type args struct { |  | ||||||
| 		name string |  | ||||||
| 	} |  | ||||||
| 	tests := []struct { |  | ||||||
| 		name    string |  | ||||||
| 		args    args |  | ||||||
| 		wantErr bool |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			name:    "valid: user", |  | ||||||
| 			args:    args{name: "valid-user"}, |  | ||||||
| 			wantErr: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name:    "invalid: capitalized user", |  | ||||||
| 			args:    args{name: "Invalid-CapItaLIzed-user"}, |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name:    "invalid: email as user", |  | ||||||
| 			args:    args{name: "foo.bar@example.com"}, |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name:    "invalid: chars in user name", |  | ||||||
| 			args:    args{name: "super-user+name"}, |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "invalid: too long name for user", |  | ||||||
| 			args: args{ |  | ||||||
| 				name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", |  | ||||||
| 			}, |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	for _, tt := range tests { |  | ||||||
| 		t.Run(tt.name, func(t *testing.T) { |  | ||||||
| 			if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { |  | ||||||
| 				t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *Suite) TestSetMachineUser(c *check.C) { |  | ||||||
| 	oldUser, err := app.db.CreateUser("old") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	newUser, err := app.db.CreateUser("new") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	machine := Machine{ |  | ||||||
| 		ID:             0, |  | ||||||
| 		MachineKey:     "foo", |  | ||||||
| 		NodeKey:        "bar", |  | ||||||
| 		DiscoKey:       "faa", |  | ||||||
| 		Hostname:       "testmachine", |  | ||||||
| 		UserID:         oldUser.ID, |  | ||||||
| 		RegisterMethod: RegisterMethodAuthKey, |  | ||||||
| 		AuthKeyID:      uint(pak.ID), |  | ||||||
| 	} |  | ||||||
| 	app.db.db.Save(&machine) |  | ||||||
| 	c.Assert(machine.UserID, check.Equals, oldUser.ID) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.SetMachineUser(&machine, newUser.Name) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) |  | ||||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.SetMachineUser(&machine, "non-existing-user") |  | ||||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) |  | ||||||
| 
 |  | ||||||
| 	err = app.db.SetMachineUser(&machine, newUser.Name) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) |  | ||||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) |  | ||||||
| } |  | ||||||
| @ -1,12 +1,94 @@ | |||||||
| package util | package util | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"go4.org/netipx" | 	"go4.org/netipx" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // This is borrowed from, and updated to use IPSet | ||||||
|  | // https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 | ||||||
|  | // TODO(kradalby): contribute upstream and make public. | ||||||
|  | var ( | ||||||
|  | 	zeroIP4 = netip.AddrFrom4([4]byte{}) | ||||||
|  | 	zeroIP6 = netip.AddrFrom16([16]byte{}) | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // parseIPSet parses arg as one: | ||||||
|  | // | ||||||
|  | //   - an IP address (IPv4 or IPv6) | ||||||
|  | //   - the string "*" to match everything (both IPv4 & IPv6) | ||||||
|  | //   - a CIDR (e.g. "192.168.0.0/16") | ||||||
|  | //   - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") | ||||||
|  | // | ||||||
|  | // bits, if non-nil, is the legacy SrcBits CIDR length to make a IP | ||||||
|  | // address (without a slash) treated as a CIDR of *bits length. | ||||||
|  | // nolint | ||||||
|  | func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) { | ||||||
|  | 	var ipSet netipx.IPSetBuilder | ||||||
|  | 	if arg == "*" { | ||||||
|  | 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) | ||||||
|  | 		ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) | ||||||
|  | 
 | ||||||
|  | 		return ipSet.IPSet() | ||||||
|  | 	} | ||||||
|  | 	if strings.Contains(arg, "/") { | ||||||
|  | 		pfx, err := netip.ParsePrefix(arg) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if pfx != pfx.Masked() { | ||||||
|  | 			return nil, fmt.Errorf("%v contains non-network bits set", pfx) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		ipSet.AddPrefix(pfx) | ||||||
|  | 
 | ||||||
|  | 		return ipSet.IPSet() | ||||||
|  | 	} | ||||||
|  | 	if strings.Count(arg, "-") == 1 { | ||||||
|  | 		ip1s, ip2s, _ := strings.Cut(arg, "-") | ||||||
|  | 
 | ||||||
|  | 		ip1, err := netip.ParseAddr(ip1s) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		ip2, err := netip.ParseAddr(ip2s) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		r := netipx.IPRangeFrom(ip1, ip2) | ||||||
|  | 		if !r.IsValid() { | ||||||
|  | 			return nil, fmt.Errorf("invalid IP range %q", arg) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, prefix := range r.Prefixes() { | ||||||
|  | 			ipSet.AddPrefix(prefix) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return ipSet.IPSet() | ||||||
|  | 	} | ||||||
|  | 	ip, err := netip.ParseAddr(arg) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("invalid IP address %q", arg) | ||||||
|  | 	} | ||||||
|  | 	bits8 := uint8(ip.BitLen()) | ||||||
|  | 	if bits != nil { | ||||||
|  | 		if *bits < 0 || *bits > int(bits8) { | ||||||
|  | 			return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) | ||||||
|  | 		} | ||||||
|  | 		bits8 = uint8(*bits) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) | ||||||
|  | 
 | ||||||
|  | 	return ipSet.IPSet() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { | func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { | ||||||
| 	var network, broadcast netip.Addr | 	var network, broadcast netip.Addr | ||||||
| 	ipRange := netipx.RangeOfPrefix(na) | 	ipRange := netipx.RangeOfPrefix(na) | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| package hscontrol | package util | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| @ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	for _, tt := range tests { | 	for _, tt := range tests { | ||||||
| 		t.Run(tt.name, func(t *testing.T) { | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 			got, err := parseIPSet(tt.args.arg, tt.args.bits) | 			got, err := ParseIPSet(tt.args.arg, tt.args.bits) | ||||||
| 			if (err != nil) != tt.wantErr { | 			if (err != nil) != tt.wantErr { | ||||||
| 				t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) | 				t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) | ||||||
| 
 | 
 | ||||||
							
								
								
									
										7
									
								
								hscontrol/util/const.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								hscontrol/util/const.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | |||||||
|  | package util | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	RegisterMethodAuthKey = "authkey" | ||||||
|  | 	RegisterMethodOIDC    = "oidc" | ||||||
|  | 	RegisterMethodCLI     = "cli" | ||||||
|  | ) | ||||||
							
								
								
									
										69
									
								
								hscontrol/util/dns.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								hscontrol/util/dns.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,69 @@ | |||||||
|  | package util | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	// value related to RFC 1123 and 952. | ||||||
|  | 	LabelHostnameLength = 63 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") | ||||||
|  | 
 | ||||||
|  | var ErrInvalidUserName = errors.New("invalid user name") | ||||||
|  | 
 | ||||||
|  | // NormalizeToFQDNRules will replace forbidden chars in user | ||||||
|  | // it can also return an error if the user doesn't respect RFC 952 and 1123. | ||||||
|  | func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { | ||||||
|  | 	name = strings.ToLower(name) | ||||||
|  | 	name = strings.ReplaceAll(name, "'", "") | ||||||
|  | 	atIdx := strings.Index(name, "@") | ||||||
|  | 	if stripEmailDomain && atIdx > 0 { | ||||||
|  | 		name = name[:atIdx] | ||||||
|  | 	} else { | ||||||
|  | 		name = strings.ReplaceAll(name, "@", ".") | ||||||
|  | 	} | ||||||
|  | 	name = invalidCharsInUserRegex.ReplaceAllString(name, "-") | ||||||
|  | 
 | ||||||
|  | 	for _, elt := range strings.Split(name, ".") { | ||||||
|  | 		if len(elt) > LabelHostnameLength { | ||||||
|  | 			return "", fmt.Errorf( | ||||||
|  | 				"label %v is more than 63 chars: %w", | ||||||
|  | 				elt, | ||||||
|  | 				ErrInvalidUserName, | ||||||
|  | 			) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return name, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func CheckForFQDNRules(name string) error { | ||||||
|  | 	if len(name) > LabelHostnameLength { | ||||||
|  | 		return fmt.Errorf( | ||||||
|  | 			"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", | ||||||
|  | 			name, | ||||||
|  | 			ErrInvalidUserName, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 	if strings.ToLower(name) != name { | ||||||
|  | 		return fmt.Errorf( | ||||||
|  | 			"DNS segment should be lowercase. %v doesn't comply with this rule: %w", | ||||||
|  | 			name, | ||||||
|  | 			ErrInvalidUserName, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 	if invalidCharsInUserRegex.MatchString(name) { | ||||||
|  | 		return fmt.Errorf( | ||||||
|  | 			"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", | ||||||
|  | 			name, | ||||||
|  | 			ErrInvalidUserName, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										143
									
								
								hscontrol/util/dns_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								hscontrol/util/dns_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,143 @@ | |||||||
|  | package util | ||||||
|  | 
 | ||||||
|  | import "testing" | ||||||
|  | 
 | ||||||
|  | func TestNormalizeToFQDNRules(t *testing.T) { | ||||||
|  | 	type args struct { | ||||||
|  | 		name             string | ||||||
|  | 		stripEmailDomain bool | ||||||
|  | 	} | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name    string | ||||||
|  | 		args    args | ||||||
|  | 		want    string | ||||||
|  | 		wantErr bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "normalize simple name", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "normalize-simple.name", | ||||||
|  | 				stripEmailDomain: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    "normalize-simple.name", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "normalize an email", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "foo.bar@example.com", | ||||||
|  | 				stripEmailDomain: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    "foo.bar.example.com", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "normalize an email domain should be removed", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "foo.bar@example.com", | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			want:    "foo.bar", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "strip enabled no email passed as argument", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "not-email-and-strip-enabled", | ||||||
|  | 				stripEmailDomain: true, | ||||||
|  | 			}, | ||||||
|  | 			want:    "not-email-and-strip-enabled", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "normalize complex email", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "foo.bar+complex-email@example.com", | ||||||
|  | 				stripEmailDomain: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    "foo.bar-complex-email.example.com", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "user name with space", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "name space", | ||||||
|  | 				stripEmailDomain: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    "name-space", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "user with quote", | ||||||
|  | 			args: args{ | ||||||
|  | 				name:             "Jamie's iPhone 5", | ||||||
|  | 				stripEmailDomain: false, | ||||||
|  | 			}, | ||||||
|  | 			want:    "jamies-iphone-5", | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) | ||||||
|  | 			if (err != nil) != tt.wantErr { | ||||||
|  | 				t.Errorf( | ||||||
|  | 					"NormalizeToFQDNRules() error = %v, wantErr %v", | ||||||
|  | 					err, | ||||||
|  | 					tt.wantErr, | ||||||
|  | 				) | ||||||
|  | 
 | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			if got != tt.want { | ||||||
|  | 				t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCheckForFQDNRules(t *testing.T) { | ||||||
|  | 	type args struct { | ||||||
|  | 		name string | ||||||
|  | 	} | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name    string | ||||||
|  | 		args    args | ||||||
|  | 		wantErr bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:    "valid: user", | ||||||
|  | 			args:    args{name: "valid-user"}, | ||||||
|  | 			wantErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:    "invalid: capitalized user", | ||||||
|  | 			args:    args{name: "Invalid-CapItaLIzed-user"}, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:    "invalid: email as user", | ||||||
|  | 			args:    args{name: "foo.bar@example.com"}, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:    "invalid: chars in user name", | ||||||
|  | 			args:    args{name: "super-user+name"}, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "invalid: too long name for user", | ||||||
|  | 			args: args{ | ||||||
|  | 				name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", | ||||||
|  | 			}, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { | ||||||
|  | 				t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/juanfont/headscale/hscontrol" | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
| 	"github.com/juanfont/headscale/integration/hsic" | 	"github.com/juanfont/headscale/integration/hsic" | ||||||
| 	"github.com/juanfont/headscale/integration/tsic" | 	"github.com/juanfont/headscale/integration/tsic" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| @ -45,7 +45,7 @@ var veryLargeDestination = []string{ | |||||||
| 	"208.0.0.0/4:*", | 	"208.0.0.0/4:*", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario { | func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { | ||||||
| 	t.Helper() | 	t.Helper() | ||||||
| 	scenario, err := NewScenario() | 	scenario, err := NewScenario() | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| @ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | |||||||
| 	// they can access minus one (them self). | 	// they can access minus one (them self). | ||||||
| 	tests := map[string]struct { | 	tests := map[string]struct { | ||||||
| 		users  map[string]int | 		users  map[string]int | ||||||
| 		policy hscontrol.ACLPolicy | 		policy policy.ACLPolicy | ||||||
| 		want   map[string]int | 		want   map[string]int | ||||||
| 	}{ | 	}{ | ||||||
| 		// Test that when we have no ACL, each client netmap has | 		// Test that when we have no ACL, each client netmap has | ||||||
| @ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | |||||||
| 				"user1": 2, | 				"user1": 2, | ||||||
| 				"user2": 2, | 				"user2": 2, | ||||||
| 			}, | 			}, | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"*"}, | 						Sources:      []string{"*"}, | ||||||
| @ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | |||||||
| 				"user1": 2, | 				"user1": 2, | ||||||
| 				"user2": 2, | 				"user2": 2, | ||||||
| 			}, | 			}, | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"user1"}, | 						Sources:      []string{"user1"}, | ||||||
| @ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | |||||||
| 				"user1": 2, | 				"user1": 2, | ||||||
| 				"user2": 2, | 				"user2": 2, | ||||||
| 			}, | 			}, | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"user1"}, | 						Sources:      []string{"user1"}, | ||||||
| @ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | |||||||
| 				"user1": 2, | 				"user1": 2, | ||||||
| 				"user2": 2, | 				"user2": 2, | ||||||
| 			}, | 			}, | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"user1"}, | 						Sources:      []string{"user1"}, | ||||||
| @ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { | |||||||
| 				"user1": 2, | 				"user1": 2, | ||||||
| 				"user2": 2, | 				"user2": 2, | ||||||
| 			}, | 			}, | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"user1"}, | 						Sources:      []string{"user1"}, | ||||||
| @ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	scenario := aclScenario(t, | 	scenario := aclScenario(t, | ||||||
| 		&hscontrol.ACLPolicy{ | 		&policy.ACLPolicy{ | ||||||
| 			ACLs: []hscontrol.ACL{ | 			ACLs: []policy.ACL{ | ||||||
| 				{ | 				{ | ||||||
| 					Action:       "accept", | 					Action:       "accept", | ||||||
| 					Sources:      []string{"user1"}, | 					Sources:      []string{"user1"}, | ||||||
| @ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	scenario := aclScenario(t, | 	scenario := aclScenario(t, | ||||||
| 		&hscontrol.ACLPolicy{ | 		&policy.ACLPolicy{ | ||||||
| 			Groups: map[string][]string{ | 			Groups: map[string][]string{ | ||||||
| 				"group:integration-acl-test": {"user1", "user2"}, | 				"group:integration-acl-test": {"user1", "user2"}, | ||||||
| 			}, | 			}, | ||||||
| 			ACLs: []hscontrol.ACL{ | 			ACLs: []policy.ACL{ | ||||||
| 				{ | 				{ | ||||||
| 					Action:       "accept", | 					Action:       "accept", | ||||||
| 					Sources:      []string{"group:integration-acl-test"}, | 					Sources:      []string{"group:integration-acl-test"}, | ||||||
| @ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	scenario := aclScenario(t, | 	scenario := aclScenario(t, | ||||||
| 		&hscontrol.ACLPolicy{ | 		&policy.ACLPolicy{ | ||||||
| 			ACLs: []hscontrol.ACL{ | 			ACLs: []policy.ACL{ | ||||||
| 				{ | 				{ | ||||||
| 					Action:       "accept", | 					Action:       "accept", | ||||||
| 					Sources:      []string{"user1"}, | 					Sources:      []string{"user1"}, | ||||||
| @ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	scenario := aclScenario(t, | 	scenario := aclScenario(t, | ||||||
| 		&hscontrol.ACLPolicy{ | 		&policy.ACLPolicy{ | ||||||
| 			ACLs: []hscontrol.ACL{ | 			ACLs: []policy.ACL{ | ||||||
| 				{ | 				{ | ||||||
| 					Action:       "accept", | 					Action:       "accept", | ||||||
| 					Sources:      []string{"user1"}, | 					Sources:      []string{"user1"}, | ||||||
| @ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	scenario := aclScenario(t, | 	scenario := aclScenario(t, | ||||||
| 		&hscontrol.ACLPolicy{ | 		&policy.ACLPolicy{ | ||||||
| 			Hosts: hscontrol.Hosts{ | 			Hosts: policy.Hosts{ | ||||||
| 				"all": netip.MustParsePrefix("100.64.0.0/24"), | 				"all": netip.MustParsePrefix("100.64.0.0/24"), | ||||||
| 			}, | 			}, | ||||||
| 			ACLs: []hscontrol.ACL{ | 			ACLs: []policy.ACL{ | ||||||
| 				// Everyone can curl test3 | 				// Everyone can curl test3 | ||||||
| 				{ | 				{ | ||||||
| 					Action:       "accept", | 					Action:       "accept", | ||||||
| @ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	tests := map[string]struct { | 	tests := map[string]struct { | ||||||
| 		policy hscontrol.ACLPolicy | 		policy policy.ACLPolicy | ||||||
| 	}{ | 	}{ | ||||||
| 		"ipv4": { | 		"ipv4": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				Hosts: hscontrol.Hosts{ | 				Hosts: policy.Hosts{ | ||||||
| 					"test1": netip.MustParsePrefix("100.64.0.1/32"), | 					"test1": netip.MustParsePrefix("100.64.0.1/32"), | ||||||
| 					"test2": netip.MustParsePrefix("100.64.0.2/32"), | 					"test2": netip.MustParsePrefix("100.64.0.2/32"), | ||||||
| 					"test3": netip.MustParsePrefix("100.64.0.3/32"), | 					"test3": netip.MustParsePrefix("100.64.0.3/32"), | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					// Everyone can curl test3 | 					// Everyone can curl test3 | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| @ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"ipv6": { | 		"ipv6": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				Hosts: hscontrol.Hosts{ | 				Hosts: policy.Hosts{ | ||||||
| 					"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), | 					"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), | ||||||
| 					"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), | 					"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), | ||||||
| 					"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), | 					"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					// Everyone can curl test3 | 					// Everyone can curl test3 | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| @ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | |||||||
| 	IntegrationSkip(t) | 	IntegrationSkip(t) | ||||||
| 
 | 
 | ||||||
| 	tests := map[string]struct { | 	tests := map[string]struct { | ||||||
| 		policy hscontrol.ACLPolicy | 		policy policy.ACLPolicy | ||||||
| 	}{ | 	}{ | ||||||
| 		"ipv4": { | 		"ipv4": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"100.64.0.1"}, | 						Sources:      []string{"100.64.0.1"}, | ||||||
| @ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"ipv6": { | 		"ipv6": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"fd7a:115c:a1e0::1"}, | 						Sources:      []string{"fd7a:115c:a1e0::1"}, | ||||||
| @ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"hostv4cidr": { | 		"hostv4cidr": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				Hosts: hscontrol.Hosts{ | 				Hosts: policy.Hosts{ | ||||||
| 					"test1": netip.MustParsePrefix("100.64.0.1/32"), | 					"test1": netip.MustParsePrefix("100.64.0.1/32"), | ||||||
| 					"test2": netip.MustParsePrefix("100.64.0.2/32"), | 					"test2": netip.MustParsePrefix("100.64.0.2/32"), | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"test1"}, | 						Sources:      []string{"test1"}, | ||||||
| @ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"hostv6cidr": { | 		"hostv6cidr": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				Hosts: hscontrol.Hosts{ | 				Hosts: policy.Hosts{ | ||||||
| 					"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), | 					"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), | ||||||
| 					"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), | 					"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"test1"}, | 						Sources:      []string{"test1"}, | ||||||
| @ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"group": { | 		"group": { | ||||||
| 			policy: hscontrol.ACLPolicy{ | 			policy: policy.ACLPolicy{ | ||||||
| 				Groups: map[string][]string{ | 				Groups: map[string][]string{ | ||||||
| 					"group:one": {"user1"}, | 					"group:one": {"user1"}, | ||||||
| 					"group:two": {"user2"}, | 					"group:two": {"user2"}, | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"group:one"}, | 						Sources:      []string{"group:one"}, | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/davecgh/go-spew/spew" | 	"github.com/davecgh/go-spew/spew" | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| 	"github.com/juanfont/headscale/hscontrol" | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||||
| 	"github.com/juanfont/headscale/integration/integrationutil" | 	"github.com/juanfont/headscale/integration/integrationutil" | ||||||
| @ -60,7 +60,7 @@ type HeadscaleInContainer struct { | |||||||
| 	port             int | 	port             int | ||||||
| 	extraPorts       []string | 	extraPorts       []string | ||||||
| 	hostPortBindings map[string][]string | 	hostPortBindings map[string][]string | ||||||
| 	aclPolicy        *hscontrol.ACLPolicy | 	aclPolicy        *policy.ACLPolicy | ||||||
| 	env              map[string]string | 	env              map[string]string | ||||||
| 	tlsCert          []byte | 	tlsCert          []byte | ||||||
| 	tlsKey           []byte | 	tlsKey           []byte | ||||||
| @ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer) | |||||||
| 
 | 
 | ||||||
| // WithACLPolicy adds a hscontrol.ACLPolicy policy to the | // WithACLPolicy adds a hscontrol.ACLPolicy policy to the | ||||||
| // HeadscaleInContainer instance. | // HeadscaleInContainer instance. | ||||||
| func WithACLPolicy(acl *hscontrol.ACLPolicy) Option { | func WithACLPolicy(acl *policy.ACLPolicy) Option { | ||||||
| 	return func(hsic *HeadscaleInContainer) { | 	return func(hsic *HeadscaleInContainer) { | ||||||
| 		// TODO(kradalby): Move somewhere appropriate | 		// TODO(kradalby): Move somewhere appropriate | ||||||
| 		hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath | 		hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/juanfont/headscale/hscontrol" | 	"github.com/juanfont/headscale/hscontrol/policy" | ||||||
| 	"github.com/juanfont/headscale/integration/hsic" | 	"github.com/juanfont/headscale/integration/hsic" | ||||||
| 	"github.com/juanfont/headscale/integration/tsic" | 	"github.com/juanfont/headscale/integration/tsic" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| @ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) { | |||||||
| 	err = scenario.CreateHeadscaleEnv(spec, | 	err = scenario.CreateHeadscaleEnv(spec, | ||||||
| 		[]tsic.Option{tsic.WithSSH()}, | 		[]tsic.Option{tsic.WithSSH()}, | ||||||
| 		hsic.WithACLPolicy( | 		hsic.WithACLPolicy( | ||||||
| 			&hscontrol.ACLPolicy{ | 			&policy.ACLPolicy{ | ||||||
| 				Groups: map[string][]string{ | 				Groups: map[string][]string{ | ||||||
| 					"group:integration-test": {"user1"}, | 					"group:integration-test": {"user1"}, | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"*"}, | 						Sources:      []string{"*"}, | ||||||
| 						Destinations: []string{"*:*"}, | 						Destinations: []string{"*:*"}, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				SSHs: []hscontrol.SSH{ | 				SSHs: []policy.SSH{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"group:integration-test"}, | 						Sources:      []string{"group:integration-test"}, | ||||||
| @ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { | |||||||
| 	err = scenario.CreateHeadscaleEnv(spec, | 	err = scenario.CreateHeadscaleEnv(spec, | ||||||
| 		[]tsic.Option{tsic.WithSSH()}, | 		[]tsic.Option{tsic.WithSSH()}, | ||||||
| 		hsic.WithACLPolicy( | 		hsic.WithACLPolicy( | ||||||
| 			&hscontrol.ACLPolicy{ | 			&policy.ACLPolicy{ | ||||||
| 				Groups: map[string][]string{ | 				Groups: map[string][]string{ | ||||||
| 					"group:integration-test": {"user1", "user2"}, | 					"group:integration-test": {"user1", "user2"}, | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"*"}, | 						Sources:      []string{"*"}, | ||||||
| 						Destinations: []string{"*:*"}, | 						Destinations: []string{"*:*"}, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				SSHs: []hscontrol.SSH{ | 				SSHs: []policy.SSH{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"group:integration-test"}, | 						Sources:      []string{"group:integration-test"}, | ||||||
| @ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) { | |||||||
| 	err = scenario.CreateHeadscaleEnv(spec, | 	err = scenario.CreateHeadscaleEnv(spec, | ||||||
| 		[]tsic.Option{tsic.WithSSH()}, | 		[]tsic.Option{tsic.WithSSH()}, | ||||||
| 		hsic.WithACLPolicy( | 		hsic.WithACLPolicy( | ||||||
| 			&hscontrol.ACLPolicy{ | 			&policy.ACLPolicy{ | ||||||
| 				Groups: map[string][]string{ | 				Groups: map[string][]string{ | ||||||
| 					"group:integration-test": {"user1"}, | 					"group:integration-test": {"user1"}, | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"*"}, | 						Sources:      []string{"*"}, | ||||||
| 						Destinations: []string{"*:*"}, | 						Destinations: []string{"*:*"}, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				SSHs: []hscontrol.SSH{}, | 				SSHs: []policy.SSH{}, | ||||||
| 			}, | 			}, | ||||||
| 		), | 		), | ||||||
| 		hsic.WithTestName("sshnoneconfigured"), | 		hsic.WithTestName("sshnoneconfigured"), | ||||||
| @ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) { | |||||||
| 	err = scenario.CreateHeadscaleEnv(spec, | 	err = scenario.CreateHeadscaleEnv(spec, | ||||||
| 		[]tsic.Option{tsic.WithSSH()}, | 		[]tsic.Option{tsic.WithSSH()}, | ||||||
| 		hsic.WithACLPolicy( | 		hsic.WithACLPolicy( | ||||||
| 			&hscontrol.ACLPolicy{ | 			&policy.ACLPolicy{ | ||||||
| 				Groups: map[string][]string{ | 				Groups: map[string][]string{ | ||||||
| 					"group:integration-test": {"user1"}, | 					"group:integration-test": {"user1"}, | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"*"}, | 						Sources:      []string{"*"}, | ||||||
| 						Destinations: []string{"*:80"}, | 						Destinations: []string{"*:80"}, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				SSHs: []hscontrol.SSH{ | 				SSHs: []policy.SSH{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"group:integration-test"}, | 						Sources:      []string{"group:integration-test"}, | ||||||
| @ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) { | |||||||
| 	err = scenario.CreateHeadscaleEnv(spec, | 	err = scenario.CreateHeadscaleEnv(spec, | ||||||
| 		[]tsic.Option{tsic.WithSSH()}, | 		[]tsic.Option{tsic.WithSSH()}, | ||||||
| 		hsic.WithACLPolicy( | 		hsic.WithACLPolicy( | ||||||
| 			&hscontrol.ACLPolicy{ | 			&policy.ACLPolicy{ | ||||||
| 				Groups: map[string][]string{ | 				Groups: map[string][]string{ | ||||||
| 					"group:ssh1": {"useracl1"}, | 					"group:ssh1": {"useracl1"}, | ||||||
| 					"group:ssh2": {"useracl2"}, | 					"group:ssh2": {"useracl2"}, | ||||||
| 				}, | 				}, | ||||||
| 				ACLs: []hscontrol.ACL{ | 				ACLs: []policy.ACL{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"*"}, | 						Sources:      []string{"*"}, | ||||||
| 						Destinations: []string{"*:*"}, | 						Destinations: []string{"*:*"}, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				SSHs: []hscontrol.SSH{ | 				SSHs: []policy.SSH{ | ||||||
| 					{ | 					{ | ||||||
| 						Action:       "accept", | 						Action:       "accept", | ||||||
| 						Sources:      []string{"group:ssh1"}, | 						Sources:      []string{"group:ssh1"}, | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user