mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 16:11:03 +01:00 
			
		
		
		
	introduce rw lock for db, ish...
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									a1a3ff4ba8
								
							
						
					
					
						commit
						eff529f2c5
					
				| @ -309,7 +309,7 @@ func (h *Headscale) handleAuthKey( | ||||
| 
 | ||||
| 		machine.NodeKey = nodeKey | ||||
| 		machine.AuthKeyID = uint(pak.ID) | ||||
| 		err := h.db.RefreshMachine(machine, registerRequest.Expiry) | ||||
| 		err := h.db.MachineSetExpiry(machine, registerRequest.Expiry) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| @ -510,7 +510,8 @@ func (h *Headscale) handleMachineLogOut( | ||||
| 		Str("machine", machine.Hostname). | ||||
| 		Msg("Client requested logout") | ||||
| 
 | ||||
| 	err := h.db.ExpireMachine(&machine) | ||||
| 	now := time.Now() | ||||
| 	err := h.db.MachineSetExpiry(&machine, now) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| @ -552,7 +553,7 @@ func (h *Headscale) handleMachineLogOut( | ||||
| 	} | ||||
| 
 | ||||
| 	if machine.IsEphemeral() { | ||||
| 		err = h.db.HardDeleteMachine(&machine) | ||||
| 		err = h.db.DeleteMachine(&machine) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Err(err). | ||||
|  | ||||
| @ -22,6 +22,9 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") | ||||
| func (hsdb *HSDatabase) CreateAPIKey( | ||||
| 	expiration *time.Time, | ||||
| ) (string, *types.APIKey, error) { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| @ -55,6 +58,9 @@ func (hsdb *HSDatabase) CreateAPIKey( | ||||
| 
 | ||||
| // ListAPIKeys returns the list of ApiKeys for a user. | ||||
| func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	keys := []types.APIKey{} | ||||
| 	if err := hsdb.db.Find(&keys).Error; err != nil { | ||||
| 		return nil, err | ||||
| @ -65,6 +71,9 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | ||||
| 
 | ||||
| // GetAPIKey returns a ApiKey for a given key. | ||||
| func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	key := types.APIKey{} | ||||
| 	if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { | ||||
| 		return nil, result.Error | ||||
| @ -75,6 +84,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | ||||
| 
 | ||||
| // GetAPIKeyByID returns a ApiKey for a given id. | ||||
| func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	key := types.APIKey{} | ||||
| 	if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { | ||||
| 		return nil, result.Error | ||||
| @ -86,6 +98,9 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | ||||
| // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey | ||||
| // does not exist. | ||||
| func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { | ||||
| 		return result.Error | ||||
| 	} | ||||
| @ -95,6 +110,9 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | ||||
| 
 | ||||
| // ExpireAPIKey marks a ApiKey as expired. | ||||
| func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -103,6 +121,9 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	prefix, hash, found := strings.Cut(keyStr, ".") | ||||
| 	if !found { | ||||
| 		return false, ErrAPIKeyFailedToParse | ||||
|  | ||||
| @ -40,6 +40,8 @@ type HSDatabase struct { | ||||
| 	db       *gorm.DB | ||||
| 	notifier *notifier.Notifier | ||||
| 
 | ||||
| 	mu sync.RWMutex | ||||
| 
 | ||||
| 	ipAllocationMutex sync.Mutex | ||||
| 
 | ||||
| 	ipPrefixes []netip.Prefix | ||||
|  | ||||
| @ -36,6 +36,13 @@ var ( | ||||
| 
 | ||||
| // ListPeers returns all peers of machine, regardless of any Policy or if the node is expired. | ||||
| func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.listPeers(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) listPeers(machine *types.Machine) (types.Machines, error) { | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", machine.Hostname). | ||||
| @ -63,6 +70,13 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.listMachines() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) listMachines() ([]types.Machine, error) { | ||||
| 	machines := []types.Machine{} | ||||
| 	if err := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| @ -77,6 +91,13 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.listMachinesByGivenName(givenName) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) listMachinesByGivenName(givenName string) (types.Machines, error) { | ||||
| 	machines := types.Machines{} | ||||
| 	if err := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| @ -92,6 +113,9 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine | ||||
| 
 | ||||
| // GetMachine finds a Machine by name and user and returns the Machine struct. | ||||
| func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	machines, err := hsdb.ListMachinesByUser(user) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -111,15 +135,17 @@ func (hsdb *HSDatabase) GetMachineByGivenName( | ||||
| 	user string, | ||||
| 	givenName string, | ||||
| ) (*types.Machine, error) { | ||||
| 	machines, err := hsdb.ListMachinesByUser(user) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	for _, m := range machines { | ||||
| 		if m.GivenName == givenName { | ||||
| 			return &m, nil | ||||
| 		} | ||||
| 	machine := types.Machine{} | ||||
| 	if err := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| 		Preload("AuthKey.User"). | ||||
| 		Preload("User"). | ||||
| 		Preload("Routes"). | ||||
| 		Where("given_name = ?", givenName).First(&machine).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return nil, ErrMachineNotFound | ||||
| @ -127,6 +153,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName( | ||||
| 
 | ||||
| // GetMachineByID finds a Machine by ID and returns the Machine struct. | ||||
| func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	mach := types.Machine{} | ||||
| 	if result := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| @ -144,6 +173,9 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { | ||||
| func (hsdb *HSDatabase) GetMachineByMachineKey( | ||||
| 	machineKey key.MachinePublic, | ||||
| ) (*types.Machine, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	mach := types.Machine{} | ||||
| 	if result := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| @ -161,6 +193,9 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( | ||||
| func (hsdb *HSDatabase) GetMachineByNodeKey( | ||||
| 	nodeKey key.NodePublic, | ||||
| ) (*types.Machine, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	machine := types.Machine{} | ||||
| 	if result := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| @ -179,6 +214,9 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( | ||||
| func (hsdb *HSDatabase) GetMachineByAnyKey( | ||||
| 	machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, | ||||
| ) (*types.Machine, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	machine := types.Machine{} | ||||
| 	if result := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| @ -195,10 +233,10 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( | ||||
| 	return &machine, nil | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby): rename this, it sounds like a mix of getting and setting to db | ||||
| // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database | ||||
| // and updates it with the latest data from the database. | ||||
| func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { | ||||
| func (hsdb *HSDatabase) MachineReloadFromDatabase(machine *types.Machine) error { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { | ||||
| 		return result.Error | ||||
| 	} | ||||
| @ -211,46 +249,36 @@ func (hsdb *HSDatabase) SetTags( | ||||
| 	machine *types.Machine, | ||||
| 	tags []string, | ||||
| ) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	newTags := []string{} | ||||
| 	for _, tag := range tags { | ||||
| 		if !util.StringOrPrefixListContains(newTags, tag) { | ||||
| 			newTags = append(newTags, tag) | ||||
| 		} | ||||
| 	} | ||||
| 	machine.ForcedTags = newTags | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||
| 		ForcedTags: newTags, | ||||
| 	}).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to update tags for machine in the database: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // ExpireMachine takes a Machine struct and sets the expire field to now. | ||||
| func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { | ||||
| 	now := time.Now() | ||||
| 	machine.Expiry = &now | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to expire machine in the database: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RenameMachine takes a Machine struct and a new GivenName for the machines | ||||
| // and renames it. | ||||
| func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	err := util.CheckForFQDNRules( | ||||
| 		newName, | ||||
| 	) | ||||
| @ -260,82 +288,93 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er | ||||
| 			Str("func", "RenameMachine"). | ||||
| 			Str("machine", machine.Hostname). | ||||
| 			Str("newName", newName). | ||||
| 			Err(err) | ||||
| 			Err(err). | ||||
| 			Msg("failed to rename machine") | ||||
| 
 | ||||
| 		return err | ||||
| 	} | ||||
| 	machine.GivenName = newName | ||||
| 
 | ||||
| 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||
| 		GivenName: newName, | ||||
| 	}).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to rename machine in the database: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to rename machine in the database: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RefreshMachine takes a Machine struct and  a new expiry time. | ||||
| func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { | ||||
| // MachineSetExpiry takes a Machine struct and  a new expiry time. | ||||
| func (hsdb *HSDatabase) MachineSetExpiry(machine *types.Machine, expiry time.Time) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.machineSetExpiry(machine, expiry) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) machineSetExpiry(machine *types.Machine, expiry time.Time) error { | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	machine.LastSuccessfulUpdate = &now | ||||
| 	machine.Expiry = &expiry | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||
| 		LastSuccessfulUpdate: &now, | ||||
| 		Expiry:               &expiry, | ||||
| 	}).Error; err != nil { | ||||
| 		return fmt.Errorf( | ||||
| 			"failed to refresh machine (update expiration) in the database: %w", | ||||
| 			err, | ||||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // DeleteMachine softs deletes a Machine from the database. | ||||
| // DeleteMachine deletes a Machine from the database. | ||||
| func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { | ||||
| 	err := hsdb.DeleteMachineRoutes(machine) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.deleteMachine(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) deleteMachine(machine *types.Machine) error { | ||||
| 	err := hsdb.deleteMachineRoutes(machine) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := hsdb.db.Delete(&machine).Error; err != nil { | ||||
| 	// Unscoped causes the machine to be fully removed from the database. | ||||
| 	if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerRemoved, | ||||
| 		Removed: []tailcfg.NodeID{tailcfg.NodeID(machine.ID)}, | ||||
| 	}) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { | ||||
| 	return hsdb.db.Updates(types.Machine{ | ||||
| 		ID:                   machine.ID, | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.db.Model(machine).Updates(types.Machine{ | ||||
| 		LastSeen:             machine.LastSeen, | ||||
| 		LastSuccessfulUpdate: machine.LastSuccessfulUpdate, | ||||
| 	}).Error | ||||
| } | ||||
| 
 | ||||
| // HardDeleteMachine hard deletes a Machine from the database. | ||||
| func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { | ||||
| 	err := hsdb.DeleteMachineRoutes(machine) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | ||||
| 	cache *cache.Cache, | ||||
| 	nodeKeyStr string, | ||||
| @ -343,6 +382,9 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | ||||
| 	machineExpiry *time.Time, | ||||
| 	registrationMethod string, | ||||
| ) (*types.Machine, error) { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	nodeKey := key.NodePublic{} | ||||
| 	err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) | ||||
| 	if err != nil { | ||||
| @ -358,7 +400,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | ||||
| 
 | ||||
| 	if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { | ||||
| 		if registrationMachine, ok := machineInterface.(types.Machine); ok { | ||||
| 			user, err := hsdb.GetUser(userName) | ||||
| 			user, err := hsdb.getUser(userName) | ||||
| 			if err != nil { | ||||
| 				return nil, fmt.Errorf( | ||||
| 					"failed to find user in register machine from auth callback, %w", | ||||
| @ -379,7 +421,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | ||||
| 				registrationMachine.Expiry = machineExpiry | ||||
| 			} | ||||
| 
 | ||||
| 			machine, err := hsdb.RegisterMachine( | ||||
| 			machine, err := hsdb.registerMachine( | ||||
| 				registrationMachine, | ||||
| 			) | ||||
| 
 | ||||
| @ -397,8 +439,14 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | ||||
| } | ||||
| 
 | ||||
| // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. | ||||
| func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, | ||||
| ) (*types.Machine, error) { | ||||
| func (hsdb *HSDatabase) RegisterMachine(machine types.Machine) (*types.Machine, error) { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.registerMachine(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) registerMachine(machine types.Machine) (*types.Machine, error) { | ||||
| 	log.Debug(). | ||||
| 		Str("machine", machine.Hostname). | ||||
| 		Str("machine_key", machine.MachineKey). | ||||
| @ -456,9 +504,12 @@ func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, | ||||
| 
 | ||||
| // MachineSetNodeKey sets the node key of a machine and saves it to the database. | ||||
| func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { | ||||
| 	machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||
| 		NodeKey: util.NodePublicKeyStripPrefix(nodeKey), | ||||
| 	}).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| @ -468,11 +519,14 @@ func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.No | ||||
| // MachineSetMachineKey sets the machine key of a machine and saves it to the database. | ||||
| func (hsdb *HSDatabase) MachineSetMachineKey( | ||||
| 	machine *types.Machine, | ||||
| 	nodeKey key.MachinePublic, | ||||
| 	machineKey key.MachinePublic, | ||||
| ) error { | ||||
| 	machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||
| 		MachineKey: util.MachinePublicKeyStripPrefix(machineKey), | ||||
| 	}).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| @ -482,6 +536,9 @@ func (hsdb *HSDatabase) MachineSetMachineKey( | ||||
| // MachineSave saves a machine object to the database, prefer to use a specific save method rather | ||||
| // than this. It is intended to be used when we are changing or. | ||||
| func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -491,6 +548,13 @@ func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { | ||||
| 
 | ||||
| // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. | ||||
| func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getAdvertisedRoutes(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||
| 	routes := types.Routes{} | ||||
| 
 | ||||
| 	err := hsdb.db. | ||||
| @ -516,6 +580,13 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Pre | ||||
| 
 | ||||
| // GetEnabledRoutes returns the routes that are enabled for the machine. | ||||
| func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getEnabledRoutes(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||
| 	routes := types.Routes{} | ||||
| 
 | ||||
| 	err := hsdb.db. | ||||
| @ -541,12 +612,15 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	route, err := netip.ParsePrefix(routeStr) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	enabledRoutes, err := hsdb.GetEnabledRoutes(machine) | ||||
| 	enabledRoutes, err := hsdb.getEnabledRoutes(machine) | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Could not get enabled routes") | ||||
| 
 | ||||
| @ -575,7 +649,10 @@ func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { | ||||
| func (hsdb *HSDatabase) ListOnlineMachines( | ||||
| 	machine *types.Machine, | ||||
| ) (map[tailcfg.NodeID]bool, error) { | ||||
| 	peers, err := hsdb.ListPeers(machine) | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	peers, err := hsdb.listPeers(machine) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -595,7 +672,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | ||||
| 		newRoutes[index] = route | ||||
| 	} | ||||
| 
 | ||||
| 	advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) | ||||
| 	advertisedRoutes, err := hsdb.getAdvertisedRoutes(machine) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -642,7 +719,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { | ||||
| func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { | ||||
| 	normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( | ||||
| 		suppliedName, | ||||
| 	) | ||||
| @ -669,20 +746,23 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { | ||||
| 	givenName, err := hsdb.generateGivenName(suppliedName, false) | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	givenName, err := generateGivenName(suppliedName, false) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ | ||||
| 	machines, err := hsdb.ListMachinesByGivenName(givenName) | ||||
| 	machines, err := hsdb.listMachinesByGivenName(givenName) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, machine := range machines { | ||||
| 		if machine.MachineKey != machineKey && machine.GivenName == givenName { | ||||
| 			postfixedName, err := hsdb.generateGivenName(suppliedName, true) | ||||
| 			postfixedName, err := generateGivenName(suppliedName, true) | ||||
| 			if err != nil { | ||||
| 				return "", err | ||||
| 			} | ||||
| @ -695,7 +775,10 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { | ||||
| 	users, err := hsdb.ListUsers() | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	users, err := hsdb.listUsers() | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Error listing users") | ||||
| 
 | ||||
| @ -703,7 +786,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | ||||
| 	} | ||||
| 
 | ||||
| 	for _, user := range users { | ||||
| 		machines, err := hsdb.ListMachinesByUser(user.Name) | ||||
| 		machines, err := hsdb.listMachinesByUser(user.Name) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Err(err). | ||||
| @ -724,7 +807,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | ||||
| 					Str("machine", machine.Hostname). | ||||
| 					Msg("Ephemeral client removed from database") | ||||
| 
 | ||||
| 				err = hsdb.HardDeleteMachine(&machines[idx]) | ||||
| 				err = hsdb.deleteMachine(&machines[idx]) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| 						Err(err). | ||||
| @ -744,12 +827,15 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	// use the time of the start of the function to ensure we | ||||
| 	// dont miss some machines by returning it _after_ we have | ||||
| 	// checked everything. | ||||
| 	started := time.Now() | ||||
| 
 | ||||
| 	users, err := hsdb.ListUsers() | ||||
| 	users, err := hsdb.listUsers() | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Error listing users") | ||||
| 
 | ||||
| @ -757,7 +843,7 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | ||||
| 	} | ||||
| 
 | ||||
| 	for _, user := range users { | ||||
| 		machines, err := hsdb.ListMachinesByUser(user.Name) | ||||
| 		machines, err := hsdb.listMachinesByUser(user.Name) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Err(err). | ||||
| @ -773,7 +859,8 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | ||||
| 				machine.Expiry.After(lastCheck) { | ||||
| 				expired = append(expired, tailcfg.NodeID(machine.ID)) | ||||
| 
 | ||||
| 				err := hsdb.ExpireMachine(&machines[index]) | ||||
| 				now := time.Now() | ||||
| 				err := hsdb.machineSetExpiry(&machines[index], now) | ||||
| 				if err != nil { | ||||
| 					log.Error(). | ||||
| 						Err(err). | ||||
|  | ||||
| @ -127,28 +127,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { | ||||
| 	c.Assert(err, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestDeleteMachine(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	machine := types.Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Hostname:       "testmachine", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(1), | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = db.DeleteMachine(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(user.Name, "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestHardDeleteMachine(c *check.C) { | ||||
| 	user, err := db.CreateUser("test") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| @ -164,7 +142,7 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { | ||||
| 	} | ||||
| 	db.db.Save(&machine) | ||||
| 
 | ||||
| 	err = db.HardDeleteMachine(&machine) | ||||
| 	err = db.DeleteMachine(&machine) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = db.GetMachine(user.Name, "testmachine3") | ||||
| @ -329,7 +307,8 @@ func (s *Suite) TestExpireMachine(c *check.C) { | ||||
| 
 | ||||
| 	c.Assert(machineFromDB.IsExpired(), check.Equals, false) | ||||
| 
 | ||||
| 	err = db.ExpireMachine(machineFromDB) | ||||
| 	now := time.Now() | ||||
| 	err = db.MachineSetExpiry(machineFromDB, now) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(machineFromDB.IsExpired(), check.Equals, true) | ||||
| @ -450,14 +429,12 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		db      *HSDatabase | ||||
| 		args    args | ||||
| 		want    *regexp.Regexp | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "simple machine name generation", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "testmachine", | ||||
| 				randomSuffix: false, | ||||
| @ -467,7 +444,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 53 chars", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", | ||||
| 				randomSuffix: false, | ||||
| @ -477,7 +453,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 63 chars", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: false, | ||||
| @ -487,7 +462,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 64 chars", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", | ||||
| 				randomSuffix: false, | ||||
| @ -497,7 +471,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 73 chars", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: false, | ||||
| @ -507,7 +480,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with random suffix", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "test", | ||||
| 				randomSuffix: true, | ||||
| @ -517,7 +489,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "machine name with 63 chars with random suffix", | ||||
| 			db:   &HSDatabase{}, | ||||
| 			args: args{ | ||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||
| 				randomSuffix: true, | ||||
| @ -528,7 +499,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) | ||||
| 			got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf( | ||||
| 					"Headscale.GenerateGivenName() error = %v, wantErr %v", | ||||
|  | ||||
| @ -28,6 +28,10 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| 	expiration *time.Time, | ||||
| 	aclTags []string, | ||||
| ) (*types.PreAuthKey, error) { | ||||
| 	// TODO(kradalby): figure out this lock | ||||
| 	// hsdb.mu.Lock() | ||||
| 	// defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	user, err := hsdb.GetUser(userName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -92,7 +96,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | ||||
| 
 | ||||
| // ListPreAuthKeys returns the list of PreAuthKeys for a user. | ||||
| func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | ||||
| 	user, err := hsdb.GetUser(userName) | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.listPreAuthKeys(userName) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | ||||
| 	user, err := hsdb.getUser(userName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -107,6 +118,9 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er | ||||
| 
 | ||||
| // GetPreAuthKey returns a PreAuthKey for a given key. | ||||
| func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	pak, err := hsdb.ValidatePreAuthKey(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -122,6 +136,13 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe | ||||
| // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey | ||||
| // does not exist. | ||||
| func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.destroyPreAuthKey(pak) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { | ||||
| 	return hsdb.db.Transaction(func(db *gorm.DB) error { | ||||
| 		if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { | ||||
| 			return result.Error | ||||
| @ -137,6 +158,9 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | ||||
| 
 | ||||
| // MarkExpirePreAuthKey marks a PreAuthKey as expired. | ||||
| func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -146,6 +170,9 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | ||||
| 
 | ||||
| // UsePreAuthKey marks a PreAuthKey as used. | ||||
| func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	k.Used = true | ||||
| 	if err := hsdb.db.Save(k).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to update key used status in the database: %w", err) | ||||
| @ -157,6 +184,9 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | ||||
| // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node | ||||
| // If returns no error and a PreAuthKey, it can be used. | ||||
| func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	pak := types.PreAuthKey{} | ||||
| 	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( | ||||
| 		result.Error, | ||||
| @ -174,7 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) | ||||
| 	} | ||||
| 
 | ||||
| 	machines := types.Machines{} | ||||
| 	if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | ||||
| 	if err := hsdb.db. | ||||
| 		Preload("AuthKey"). | ||||
| 		Where(&types.Machine{AuthKeyID: uint(pak.ID)}). | ||||
| 		Find(&machines).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -13,6 +13,13 @@ import ( | ||||
| var ErrRouteIsNotAvailable = errors.New("route is not available") | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getRoutes() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db.Preload("Machine").Find(&routes).Error | ||||
| 	if err != nil { | ||||
| @ -23,6 +30,13 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getMachineAdvertisedRoutes(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| @ -36,6 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getMachineRoutes(m) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) { | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| @ -49,6 +70,13 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getRoute(id) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { | ||||
| 	var route types.Route | ||||
| 	err := hsdb.db.Preload("Machine").First(&route, id).Error | ||||
| 	if err != nil { | ||||
| @ -59,7 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) EnableRoute(id uint64) error { | ||||
| 	route, err := hsdb.GetRoute(id) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.enableRoute(id) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) enableRoute(id uint64) error { | ||||
| 	route, err := hsdb.getRoute(id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -79,7 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 	route, err := hsdb.GetRoute(id) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	route, err := hsdb.getRoute(id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -95,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		return hsdb.HandlePrimarySubnetFailover() | ||||
| 		return hsdb.handlePrimarySubnetFailover() | ||||
| 	} | ||||
| 
 | ||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | ||||
| 	routes, err := hsdb.getMachineRoutes(&route.Machine) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -114,11 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.HandlePrimarySubnetFailover() | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| 	route, err := hsdb.GetRoute(id) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	route, err := hsdb.getRoute(id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -131,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		return hsdb.HandlePrimarySubnetFailover() | ||||
| 		return hsdb.handlePrimarySubnetFailover() | ||||
| 	} | ||||
| 
 | ||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | ||||
| 	routes, err := hsdb.getMachineRoutes(&route.Machine) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -150,11 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.HandlePrimarySubnetFailover() | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { | ||||
| 	routes, err := hsdb.GetMachineRoutes(m) | ||||
| func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error { | ||||
| 	routes, err := hsdb.getMachineRoutes(m) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -165,7 +206,7 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return hsdb.HandlePrimarySubnetFailover() | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| // isUniquePrefix returns if there is another machine providing the same route already. | ||||
| @ -201,6 +242,9 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro | ||||
| // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) | ||||
| // Exit nodes are not considered for this, as they are never marked as Primary. | ||||
| func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| 		Preload("Machine"). | ||||
| @ -214,6 +258,13 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.processMachineRoutes(machine) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) processMachineRoutes(machine *types.Machine) error { | ||||
| 	currentRoutes := types.Routes{} | ||||
| 	err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error | ||||
| 	if err != nil { | ||||
| @ -264,6 +315,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	return hsdb.handlePrimarySubnetFailover() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||
| 	// first, get all the enabled routes | ||||
| 	var routes types.Routes | ||||
| 	err := hsdb.db. | ||||
| @ -388,11 +446,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | ||||
| 	aclPolicy *policy.ACLPolicy, | ||||
| 	machine *types.Machine, | ||||
| ) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	if len(machine.IPAddresses) == 0 { | ||||
| 		return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs | ||||
| 	} | ||||
| 
 | ||||
| 	routes, err := hsdb.GetMachineAdvertisedRoutes(machine) | ||||
| 	routes, err := hsdb.getMachineAdvertisedRoutes(machine) | ||||
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| @ -445,7 +506,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | ||||
| 	} | ||||
| 
 | ||||
| 	for _, approvedRoute := range approvedRoutes { | ||||
| 		err := hsdb.EnableRoute(uint64(approvedRoute.ID)) | ||||
| 		err := hsdb.enableRoute(uint64(approvedRoute.ID)) | ||||
| 		if err != nil { | ||||
| 			log.Err(err). | ||||
| 				Str("approvedRoute", approvedRoute.String()). | ||||
|  | ||||
| @ -18,6 +18,9 @@ var ( | ||||
| // CreateUser creates a new User. Returns error if could not be created | ||||
| // or another user already exists. | ||||
| func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	err := util.CheckForFQDNRules(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -42,12 +45,15 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | ||||
| // DestroyUser destroys a User. Returns error if the User does | ||||
| // not exist or if there are machines associated with it. | ||||
| func (hsdb *HSDatabase) DestroyUser(name string) error { | ||||
| 	user, err := hsdb.GetUser(name) | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	user, err := hsdb.getUser(name) | ||||
| 	if err != nil { | ||||
| 		return ErrUserNotFound | ||||
| 	} | ||||
| 
 | ||||
| 	machines, err := hsdb.ListMachinesByUser(name) | ||||
| 	machines, err := hsdb.listMachinesByUser(name) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -55,12 +61,12 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { | ||||
| 		return ErrUserStillHasNodes | ||||
| 	} | ||||
| 
 | ||||
| 	keys, err := hsdb.ListPreAuthKeys(name) | ||||
| 	keys, err := hsdb.listPreAuthKeys(name) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for _, key := range keys { | ||||
| 		err = hsdb.DestroyPreAuthKey(key) | ||||
| 		err = hsdb.destroyPreAuthKey(key) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -76,8 +82,11 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { | ||||
| // RenameUser renames a User. Returns error if the User does | ||||
| // not exist or if another User exists with the new name. | ||||
| func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	var err error | ||||
| 	oldUser, err := hsdb.GetUser(oldName) | ||||
| 	oldUser, err := hsdb.getUser(oldName) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -85,7 +94,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = hsdb.GetUser(newName) | ||||
| 	_, err = hsdb.getUser(newName) | ||||
| 	if err == nil { | ||||
| 		return ErrUserExists | ||||
| 	} | ||||
| @ -104,6 +113,13 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | ||||
| 
 | ||||
| // GetUser fetches a user by name. | ||||
| func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.getUser(name) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { | ||||
| 	user := types.User{} | ||||
| 	if result := hsdb.db.First(&user, "name = ?", name); errors.Is( | ||||
| 		result.Error, | ||||
| @ -117,6 +133,13 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | ||||
| 
 | ||||
| // ListUsers gets all the existing users. | ||||
| func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.listUsers() | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) listUsers() ([]types.User, error) { | ||||
| 	users := []types.User{} | ||||
| 	if err := hsdb.db.Find(&users).Error; err != nil { | ||||
| 		return nil, err | ||||
| @ -127,11 +150,18 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | ||||
| 
 | ||||
| // ListMachinesByUser gets all the nodes in a given user. | ||||
| func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { | ||||
| 	hsdb.mu.RLock() | ||||
| 	defer hsdb.mu.RUnlock() | ||||
| 
 | ||||
| 	return hsdb.listMachinesByUser(name) | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) listMachinesByUser(name string) (types.Machines, error) { | ||||
| 	err := util.CheckForFQDNRules(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	user, err := hsdb.GetUser(name) | ||||
| 	user, err := hsdb.getUser(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -144,13 +174,16 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) | ||||
| 	return machines, nil | ||||
| } | ||||
| 
 | ||||
| // SetMachineUser assigns a Machine to a user. | ||||
| func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { | ||||
| // AssignMachineToUser assigns a Machine to a user. | ||||
| func (hsdb *HSDatabase) AssignMachineToUser(machine *types.Machine, username string) error { | ||||
| 	hsdb.mu.Lock() | ||||
| 	defer hsdb.mu.Unlock() | ||||
| 
 | ||||
| 	err := util.CheckForFQDNRules(username) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	user, err := hsdb.GetUser(username) | ||||
| 	user, err := hsdb.getUser(username) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @ -114,15 +114,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { | ||||
| 	db.db.Save(&machine) | ||||
| 	c.Assert(machine.UserID, check.Equals, oldUser.ID) | ||||
| 
 | ||||
| 	err = db.SetMachineUser(&machine, newUser.Name) | ||||
| 	err = db.AssignMachineToUser(&machine, newUser.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||
| 
 | ||||
| 	err = db.SetMachineUser(&machine, "non-existing-user") | ||||
| 	err = db.AssignMachineToUser(&machine, "non-existing-user") | ||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||
| 
 | ||||
| 	err = db.SetMachineUser(&machine, newUser.Name) | ||||
| 	err = db.AssignMachineToUser(&machine, newUser.Name) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||
|  | ||||
| @ -275,8 +275,11 @@ func (api headscaleV1APIServer) ExpireMachine( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	api.h.db.ExpireMachine( | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	api.h.db.MachineSetExpiry( | ||||
| 		machine, | ||||
| 		now, | ||||
| 	) | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| @ -358,7 +361,7 @@ func (api headscaleV1APIServer) MoveMachine( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	err = api.h.db.SetMachineUser(machine, request.GetUser()) | ||||
| 	err = api.h.db.AssignMachineToUser(machine, request.GetUser()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @ -523,7 +523,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 			Str("machine", machine.Hostname). | ||||
| 			Msg("machine already registered, reauthenticating") | ||||
| 
 | ||||
| 		err := h.db.RefreshMachine(machine, expiry) | ||||
| 		err := h.db.MachineSetExpiry(machine, expiry) | ||||
| 		if err != nil { | ||||
| 			util.LogErr(err, "Failed to refresh machine") | ||||
| 			http.Error( | ||||
|  | ||||
| @ -107,6 +107,7 @@ func (h *Headscale) handlePoll( | ||||
| 		machine.LastSeen = &now | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO(kradalby): Save specific stuff, not whole object. | ||||
| 	if err := h.db.MachineSave(machine); err != nil { | ||||
| 		logErr(err, "Failed to persist/update machine in the database") | ||||
| 		http.Error(writer, "", http.StatusInternalServerError) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user