mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 16:11:03 +01:00 
			
		
		
		
	introduce rw lock for db, ish...
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									a1a3ff4ba8
								
							
						
					
					
						commit
						eff529f2c5
					
				| @ -309,7 +309,7 @@ func (h *Headscale) handleAuthKey( | |||||||
| 
 | 
 | ||||||
| 		machine.NodeKey = nodeKey | 		machine.NodeKey = nodeKey | ||||||
| 		machine.AuthKeyID = uint(pak.ID) | 		machine.AuthKeyID = uint(pak.ID) | ||||||
| 		err := h.db.RefreshMachine(machine, registerRequest.Expiry) | 		err := h.db.MachineSetExpiry(machine, registerRequest.Expiry) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Caller(). | 				Caller(). | ||||||
| @ -510,7 +510,8 @@ func (h *Headscale) handleMachineLogOut( | |||||||
| 		Str("machine", machine.Hostname). | 		Str("machine", machine.Hostname). | ||||||
| 		Msg("Client requested logout") | 		Msg("Client requested logout") | ||||||
| 
 | 
 | ||||||
| 	err := h.db.ExpireMachine(&machine) | 	now := time.Now() | ||||||
|  | 	err := h.db.MachineSetExpiry(&machine, now) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| @ -552,7 +553,7 @@ func (h *Headscale) handleMachineLogOut( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if machine.IsEphemeral() { | 	if machine.IsEphemeral() { | ||||||
| 		err = h.db.HardDeleteMachine(&machine) | 		err = h.db.DeleteMachine(&machine) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Err(err). | 				Err(err). | ||||||
|  | |||||||
| @ -22,6 +22,9 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") | |||||||
| func (hsdb *HSDatabase) CreateAPIKey( | func (hsdb *HSDatabase) CreateAPIKey( | ||||||
| 	expiration *time.Time, | 	expiration *time.Time, | ||||||
| ) (string, *types.APIKey, error) { | ) (string, *types.APIKey, error) { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) | 	prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", nil, err | 		return "", nil, err | ||||||
| @ -55,6 +58,9 @@ func (hsdb *HSDatabase) CreateAPIKey( | |||||||
| 
 | 
 | ||||||
| // ListAPIKeys returns the list of ApiKeys for a user. | // ListAPIKeys returns the list of ApiKeys for a user. | ||||||
| func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	keys := []types.APIKey{} | 	keys := []types.APIKey{} | ||||||
| 	if err := hsdb.db.Find(&keys).Error; err != nil { | 	if err := hsdb.db.Find(&keys).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -65,6 +71,9 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { | |||||||
| 
 | 
 | ||||||
| // GetAPIKey returns a ApiKey for a given key. | // GetAPIKey returns a ApiKey for a given key. | ||||||
| func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	key := types.APIKey{} | 	key := types.APIKey{} | ||||||
| 	if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { | 	if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { | ||||||
| 		return nil, result.Error | 		return nil, result.Error | ||||||
| @ -75,6 +84,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { | |||||||
| 
 | 
 | ||||||
| // GetAPIKeyByID returns a ApiKey for a given id. | // GetAPIKeyByID returns a ApiKey for a given id. | ||||||
| func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	key := types.APIKey{} | 	key := types.APIKey{} | ||||||
| 	if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { | 	if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { | ||||||
| 		return nil, result.Error | 		return nil, result.Error | ||||||
| @ -86,6 +98,9 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { | |||||||
| // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey | // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey | ||||||
| // does not exist. | // does not exist. | ||||||
| func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { | 	if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { | ||||||
| 		return result.Error | 		return result.Error | ||||||
| 	} | 	} | ||||||
| @ -95,6 +110,9 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { | |||||||
| 
 | 
 | ||||||
| // ExpireAPIKey marks a ApiKey as expired. | // ExpireAPIKey marks a ApiKey as expired. | ||||||
| func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { | 	if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -103,6 +121,9 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { | func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	prefix, hash, found := strings.Cut(keyStr, ".") | 	prefix, hash, found := strings.Cut(keyStr, ".") | ||||||
| 	if !found { | 	if !found { | ||||||
| 		return false, ErrAPIKeyFailedToParse | 		return false, ErrAPIKeyFailedToParse | ||||||
|  | |||||||
| @ -40,6 +40,8 @@ type HSDatabase struct { | |||||||
| 	db       *gorm.DB | 	db       *gorm.DB | ||||||
| 	notifier *notifier.Notifier | 	notifier *notifier.Notifier | ||||||
| 
 | 
 | ||||||
|  | 	mu sync.RWMutex | ||||||
|  | 
 | ||||||
| 	ipAllocationMutex sync.Mutex | 	ipAllocationMutex sync.Mutex | ||||||
| 
 | 
 | ||||||
| 	ipPrefixes []netip.Prefix | 	ipPrefixes []netip.Prefix | ||||||
|  | |||||||
| @ -36,6 +36,13 @@ var ( | |||||||
| 
 | 
 | ||||||
| // ListPeers returns all peers of machine, regardless of any Policy or if the node is expired. | // ListPeers returns all peers of machine, regardless of any Policy or if the node is expired. | ||||||
| func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { | func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.listPeers(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) listPeers(machine *types.Machine) (types.Machines, error) { | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Caller(). | 		Caller(). | ||||||
| 		Str("machine", machine.Hostname). | 		Str("machine", machine.Hostname). | ||||||
| @ -63,6 +70,13 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { | func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.listMachines() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) listMachines() ([]types.Machine, error) { | ||||||
| 	machines := []types.Machine{} | 	machines := []types.Machine{} | ||||||
| 	if err := hsdb.db. | 	if err := hsdb.db. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| @ -77,6 +91,13 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { | func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.listMachinesByGivenName(givenName) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) listMachinesByGivenName(givenName string) (types.Machines, error) { | ||||||
| 	machines := types.Machines{} | 	machines := types.Machines{} | ||||||
| 	if err := hsdb.db. | 	if err := hsdb.db. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| @ -92,6 +113,9 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine | |||||||
| 
 | 
 | ||||||
| // GetMachine finds a Machine by name and user and returns the Machine struct. | // GetMachine finds a Machine by name and user and returns the Machine struct. | ||||||
| func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { | func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	machines, err := hsdb.ListMachinesByUser(user) | 	machines, err := hsdb.ListMachinesByUser(user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -111,15 +135,17 @@ func (hsdb *HSDatabase) GetMachineByGivenName( | |||||||
| 	user string, | 	user string, | ||||||
| 	givenName string, | 	givenName string, | ||||||
| ) (*types.Machine, error) { | ) (*types.Machine, error) { | ||||||
| 	machines, err := hsdb.ListMachinesByUser(user) | 	hsdb.mu.RLock() | ||||||
| 	if err != nil { | 	defer hsdb.mu.RUnlock() | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	for _, m := range machines { | 	machine := types.Machine{} | ||||||
| 		if m.GivenName == givenName { | 	if err := hsdb.db. | ||||||
| 			return &m, nil | 		Preload("AuthKey"). | ||||||
| 		} | 		Preload("AuthKey.User"). | ||||||
|  | 		Preload("User"). | ||||||
|  | 		Preload("Routes"). | ||||||
|  | 		Where("given_name = ?", givenName).First(&machine).Error; err != nil { | ||||||
|  | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil, ErrMachineNotFound | 	return nil, ErrMachineNotFound | ||||||
| @ -127,6 +153,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName( | |||||||
| 
 | 
 | ||||||
| // GetMachineByID finds a Machine by ID and returns the Machine struct. | // GetMachineByID finds a Machine by ID and returns the Machine struct. | ||||||
| func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { | func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	mach := types.Machine{} | 	mach := types.Machine{} | ||||||
| 	if result := hsdb.db. | 	if result := hsdb.db. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| @ -144,6 +173,9 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { | |||||||
| func (hsdb *HSDatabase) GetMachineByMachineKey( | func (hsdb *HSDatabase) GetMachineByMachineKey( | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| ) (*types.Machine, error) { | ) (*types.Machine, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	mach := types.Machine{} | 	mach := types.Machine{} | ||||||
| 	if result := hsdb.db. | 	if result := hsdb.db. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| @ -161,6 +193,9 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( | |||||||
| func (hsdb *HSDatabase) GetMachineByNodeKey( | func (hsdb *HSDatabase) GetMachineByNodeKey( | ||||||
| 	nodeKey key.NodePublic, | 	nodeKey key.NodePublic, | ||||||
| ) (*types.Machine, error) { | ) (*types.Machine, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	machine := types.Machine{} | 	machine := types.Machine{} | ||||||
| 	if result := hsdb.db. | 	if result := hsdb.db. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| @ -179,6 +214,9 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( | |||||||
| func (hsdb *HSDatabase) GetMachineByAnyKey( | func (hsdb *HSDatabase) GetMachineByAnyKey( | ||||||
| 	machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, | 	machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, | ||||||
| ) (*types.Machine, error) { | ) (*types.Machine, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	machine := types.Machine{} | 	machine := types.Machine{} | ||||||
| 	if result := hsdb.db. | 	if result := hsdb.db. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| @ -195,10 +233,10 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( | |||||||
| 	return &machine, nil | 	return &machine, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // TODO(kradalby): rename this, it sounds like a mix of getting and setting to db | func (hsdb *HSDatabase) MachineReloadFromDatabase(machine *types.Machine) error { | ||||||
| // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database | 	hsdb.mu.RLock() | ||||||
| // and updates it with the latest data from the database. | 	defer hsdb.mu.RUnlock() | ||||||
| func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { | 
 | ||||||
| 	if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { | 	if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { | ||||||
| 		return result.Error | 		return result.Error | ||||||
| 	} | 	} | ||||||
| @ -211,46 +249,36 @@ func (hsdb *HSDatabase) SetTags( | |||||||
| 	machine *types.Machine, | 	machine *types.Machine, | ||||||
| 	tags []string, | 	tags []string, | ||||||
| ) error { | ) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	newTags := []string{} | 	newTags := []string{} | ||||||
| 	for _, tag := range tags { | 	for _, tag := range tags { | ||||||
| 		if !util.StringOrPrefixListContains(newTags, tag) { | 		if !util.StringOrPrefixListContains(newTags, tag) { | ||||||
| 			newTags = append(newTags, tag) | 			newTags = append(newTags, tag) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	machine.ForcedTags = newTags |  | ||||||
| 
 | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||||
| 		Type:    types.StatePeerChanged, | 		ForcedTags: newTags, | ||||||
| 		Changed: []uint64{machine.ID}, | 	}).Error; err != nil { | ||||||
| 	}, machine.MachineKey) |  | ||||||
| 
 |  | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { |  | ||||||
| 		return fmt.Errorf("failed to update tags for machine in the database: %w", err) | 		return fmt.Errorf("failed to update tags for machine in the database: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ExpireMachine takes a Machine struct and sets the expire field to now. |  | ||||||
| func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { |  | ||||||
| 	now := time.Now() |  | ||||||
| 	machine.Expiry = &now |  | ||||||
| 
 |  | ||||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
| 		Type:    types.StatePeerChanged, | 		Type:    types.StatePeerChanged, | ||||||
| 		Changed: []uint64{machine.ID}, | 		Changed: []uint64{machine.ID}, | ||||||
| 	}, machine.MachineKey) | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { |  | ||||||
| 		return fmt.Errorf("failed to expire machine in the database: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RenameMachine takes a Machine struct and a new GivenName for the machines | // RenameMachine takes a Machine struct and a new GivenName for the machines | ||||||
| // and renames it. | // and renames it. | ||||||
| func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { | func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	err := util.CheckForFQDNRules( | 	err := util.CheckForFQDNRules( | ||||||
| 		newName, | 		newName, | ||||||
| 	) | 	) | ||||||
| @ -260,82 +288,93 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er | |||||||
| 			Str("func", "RenameMachine"). | 			Str("func", "RenameMachine"). | ||||||
| 			Str("machine", machine.Hostname). | 			Str("machine", machine.Hostname). | ||||||
| 			Str("newName", newName). | 			Str("newName", newName). | ||||||
| 			Err(err) | 			Err(err). | ||||||
|  | 			Msg("failed to rename machine") | ||||||
| 
 | 
 | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	machine.GivenName = newName | 	machine.GivenName = newName | ||||||
| 
 | 
 | ||||||
|  | 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||||
|  | 		GivenName: newName, | ||||||
|  | 	}).Error; err != nil { | ||||||
|  | 		return fmt.Errorf("failed to rename machine in the database: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
| 		Type:    types.StatePeerChanged, | 		Type:    types.StatePeerChanged, | ||||||
| 		Changed: []uint64{machine.ID}, | 		Changed: []uint64{machine.ID}, | ||||||
| 	}, machine.MachineKey) | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { |  | ||||||
| 		return fmt.Errorf("failed to rename machine in the database: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshMachine takes a Machine struct and  a new expiry time. | // MachineSetExpiry takes a Machine struct and  a new expiry time. | ||||||
| func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { | func (hsdb *HSDatabase) MachineSetExpiry(machine *types.Machine, expiry time.Time) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.machineSetExpiry(machine, expiry) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) machineSetExpiry(machine *types.Machine, expiry time.Time) error { | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 
 | 
 | ||||||
| 	machine.LastSuccessfulUpdate = &now | 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||||
| 	machine.Expiry = &expiry | 		LastSuccessfulUpdate: &now, | ||||||
| 
 | 		Expiry:               &expiry, | ||||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | 	}).Error; err != nil { | ||||||
| 		Type:    types.StatePeerChanged, |  | ||||||
| 		Changed: []uint64{machine.ID}, |  | ||||||
| 	}, machine.MachineKey) |  | ||||||
| 
 |  | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { |  | ||||||
| 		return fmt.Errorf( | 		return fmt.Errorf( | ||||||
| 			"failed to refresh machine (update expiration) in the database: %w", | 			"failed to refresh machine (update expiration) in the database: %w", | ||||||
| 			err, | 			err, | ||||||
| 		) | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerChanged, | ||||||
|  | 		Changed: []uint64{machine.ID}, | ||||||
|  | 	}, machine.MachineKey) | ||||||
|  | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // DeleteMachine softs deletes a Machine from the database. | // DeleteMachine deletes a Machine from the database. | ||||||
| func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { | func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { | ||||||
| 	err := hsdb.DeleteMachineRoutes(machine) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.deleteMachine(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) deleteMachine(machine *types.Machine) error { | ||||||
|  | 	err := hsdb.deleteMachineRoutes(machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Delete(&machine).Error; err != nil { | 	// Unscoped causes the machine to be fully removed from the database. | ||||||
|  | 	if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerRemoved, | ||||||
|  | 		Removed: []tailcfg.NodeID{tailcfg.NodeID(machine.ID)}, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { | func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { | ||||||
| 	return hsdb.db.Updates(types.Machine{ | 	hsdb.mu.Lock() | ||||||
| 		ID:                   machine.ID, | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.db.Model(machine).Updates(types.Machine{ | ||||||
| 		LastSeen:             machine.LastSeen, | 		LastSeen:             machine.LastSeen, | ||||||
| 		LastSuccessfulUpdate: machine.LastSuccessfulUpdate, | 		LastSuccessfulUpdate: machine.LastSuccessfulUpdate, | ||||||
| 	}).Error | 	}).Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // HardDeleteMachine hard deletes a Machine from the database. |  | ||||||
| func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { |  | ||||||
| 	err := hsdb.DeleteMachineRoutes(machine) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | ||||||
| 	cache *cache.Cache, | 	cache *cache.Cache, | ||||||
| 	nodeKeyStr string, | 	nodeKeyStr string, | ||||||
| @ -343,6 +382,9 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | |||||||
| 	machineExpiry *time.Time, | 	machineExpiry *time.Time, | ||||||
| 	registrationMethod string, | 	registrationMethod string, | ||||||
| ) (*types.Machine, error) { | ) (*types.Machine, error) { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	nodeKey := key.NodePublic{} | 	nodeKey := key.NodePublic{} | ||||||
| 	err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) | 	err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -358,7 +400,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | |||||||
| 
 | 
 | ||||||
| 	if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { | 	if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { | ||||||
| 		if registrationMachine, ok := machineInterface.(types.Machine); ok { | 		if registrationMachine, ok := machineInterface.(types.Machine); ok { | ||||||
| 			user, err := hsdb.GetUser(userName) | 			user, err := hsdb.getUser(userName) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, fmt.Errorf( | 				return nil, fmt.Errorf( | ||||||
| 					"failed to find user in register machine from auth callback, %w", | 					"failed to find user in register machine from auth callback, %w", | ||||||
| @ -379,7 +421,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | |||||||
| 				registrationMachine.Expiry = machineExpiry | 				registrationMachine.Expiry = machineExpiry | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			machine, err := hsdb.RegisterMachine( | 			machine, err := hsdb.registerMachine( | ||||||
| 				registrationMachine, | 				registrationMachine, | ||||||
| 			) | 			) | ||||||
| 
 | 
 | ||||||
| @ -397,8 +439,14 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. | // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. | ||||||
| func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, | func (hsdb *HSDatabase) RegisterMachine(machine types.Machine) (*types.Machine, error) { | ||||||
| ) (*types.Machine, error) { | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.registerMachine(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) registerMachine(machine types.Machine) (*types.Machine, error) { | ||||||
| 	log.Debug(). | 	log.Debug(). | ||||||
| 		Str("machine", machine.Hostname). | 		Str("machine", machine.Hostname). | ||||||
| 		Str("machine_key", machine.MachineKey). | 		Str("machine_key", machine.MachineKey). | ||||||
| @ -456,9 +504,12 @@ func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, | |||||||
| 
 | 
 | ||||||
| // MachineSetNodeKey sets the node key of a machine and saves it to the database. | // MachineSetNodeKey sets the node key of a machine and saves it to the database. | ||||||
| func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { | func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { | ||||||
| 	machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||||
|  | 		NodeKey: util.NodePublicKeyStripPrefix(nodeKey), | ||||||
|  | 	}).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -468,11 +519,14 @@ func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.No | |||||||
| // MachineSetMachineKey sets the machine key of a machine and saves it to the database. | // MachineSetMachineKey sets the machine key of a machine and saves it to the database. | ||||||
| func (hsdb *HSDatabase) MachineSetMachineKey( | func (hsdb *HSDatabase) MachineSetMachineKey( | ||||||
| 	machine *types.Machine, | 	machine *types.Machine, | ||||||
| 	nodeKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| ) error { | ) error { | ||||||
| 	machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Model(machine).Updates(types.Machine{ | ||||||
|  | 		MachineKey: util.MachinePublicKeyStripPrefix(machineKey), | ||||||
|  | 	}).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -482,6 +536,9 @@ func (hsdb *HSDatabase) MachineSetMachineKey( | |||||||
| // MachineSave saves a machine object to the database, prefer to use a specific save method rather | // MachineSave saves a machine object to the database, prefer to use a specific save method rather | ||||||
| // than this. It is intended to be used when we are changing or. | // than this. It is intended to be used when we are changing or. | ||||||
| func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { | func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -491,6 +548,13 @@ func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { | |||||||
| 
 | 
 | ||||||
| // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. | // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. | ||||||
| func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { | func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getAdvertisedRoutes(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||||
| 	routes := types.Routes{} | 	routes := types.Routes{} | ||||||
| 
 | 
 | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| @ -516,6 +580,13 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Pre | |||||||
| 
 | 
 | ||||||
| // GetEnabledRoutes returns the routes that are enabled for the machine. | // GetEnabledRoutes returns the routes that are enabled for the machine. | ||||||
| func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { | func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getEnabledRoutes(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { | ||||||
| 	routes := types.Routes{} | 	routes := types.Routes{} | ||||||
| 
 | 
 | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| @ -541,12 +612,15 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { | func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	route, err := netip.ParsePrefix(routeStr) | 	route, err := netip.ParsePrefix(routeStr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	enabledRoutes, err := hsdb.GetEnabledRoutes(machine) | 	enabledRoutes, err := hsdb.getEnabledRoutes(machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Could not get enabled routes") | 		log.Error().Err(err).Msg("Could not get enabled routes") | ||||||
| 
 | 
 | ||||||
| @ -575,7 +649,10 @@ func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { | |||||||
| func (hsdb *HSDatabase) ListOnlineMachines( | func (hsdb *HSDatabase) ListOnlineMachines( | ||||||
| 	machine *types.Machine, | 	machine *types.Machine, | ||||||
| ) (map[tailcfg.NodeID]bool, error) { | ) (map[tailcfg.NodeID]bool, error) { | ||||||
| 	peers, err := hsdb.ListPeers(machine) | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	peers, err := hsdb.listPeers(machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -595,7 +672,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | |||||||
| 		newRoutes[index] = route | 		newRoutes[index] = route | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) | 	advertisedRoutes, err := hsdb.getAdvertisedRoutes(machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -642,7 +719,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { | func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { | ||||||
| 	normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( | 	normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( | ||||||
| 		suppliedName, | 		suppliedName, | ||||||
| 	) | 	) | ||||||
| @ -669,20 +746,23 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { | func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { | ||||||
| 	givenName, err := hsdb.generateGivenName(suppliedName, false) | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	givenName, err := generateGivenName(suppliedName, false) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ | 	// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ | ||||||
| 	machines, err := hsdb.ListMachinesByGivenName(givenName) | 	machines, err := hsdb.listMachinesByGivenName(givenName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, machine := range machines { | 	for _, machine := range machines { | ||||||
| 		if machine.MachineKey != machineKey && machine.GivenName == givenName { | 		if machine.MachineKey != machineKey && machine.GivenName == givenName { | ||||||
| 			postfixedName, err := hsdb.generateGivenName(suppliedName, true) | 			postfixedName, err := generateGivenName(suppliedName, true) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return "", err | 				return "", err | ||||||
| 			} | 			} | ||||||
| @ -695,7 +775,10 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { | func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { | ||||||
| 	users, err := hsdb.ListUsers() | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	users, err := hsdb.listUsers() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Error listing users") | 		log.Error().Err(err).Msg("Error listing users") | ||||||
| 
 | 
 | ||||||
| @ -703,7 +786,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, user := range users { | 	for _, user := range users { | ||||||
| 		machines, err := hsdb.ListMachinesByUser(user.Name) | 		machines, err := hsdb.listMachinesByUser(user.Name) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Err(err). | 				Err(err). | ||||||
| @ -724,7 +807,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | |||||||
| 					Str("machine", machine.Hostname). | 					Str("machine", machine.Hostname). | ||||||
| 					Msg("Ephemeral client removed from database") | 					Msg("Ephemeral client removed from database") | ||||||
| 
 | 
 | ||||||
| 				err = hsdb.HardDeleteMachine(&machines[idx]) | 				err = hsdb.deleteMachine(&machines[idx]) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Error(). | 					log.Error(). | ||||||
| 						Err(err). | 						Err(err). | ||||||
| @ -744,12 +827,15 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	// use the time of the start of the function to ensure we | 	// use the time of the start of the function to ensure we | ||||||
| 	// dont miss some machines by returning it _after_ we have | 	// dont miss some machines by returning it _after_ we have | ||||||
| 	// checked everything. | 	// checked everything. | ||||||
| 	started := time.Now() | 	started := time.Now() | ||||||
| 
 | 
 | ||||||
| 	users, err := hsdb.ListUsers() | 	users, err := hsdb.listUsers() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Error listing users") | 		log.Error().Err(err).Msg("Error listing users") | ||||||
| 
 | 
 | ||||||
| @ -757,7 +843,7 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, user := range users { | 	for _, user := range users { | ||||||
| 		machines, err := hsdb.ListMachinesByUser(user.Name) | 		machines, err := hsdb.listMachinesByUser(user.Name) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Err(err). | 				Err(err). | ||||||
| @ -773,7 +859,8 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | |||||||
| 				machine.Expiry.After(lastCheck) { | 				machine.Expiry.After(lastCheck) { | ||||||
| 				expired = append(expired, tailcfg.NodeID(machine.ID)) | 				expired = append(expired, tailcfg.NodeID(machine.ID)) | ||||||
| 
 | 
 | ||||||
| 				err := hsdb.ExpireMachine(&machines[index]) | 				now := time.Now() | ||||||
|  | 				err := hsdb.machineSetExpiry(&machines[index], now) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Error(). | 					log.Error(). | ||||||
| 						Err(err). | 						Err(err). | ||||||
|  | |||||||
| @ -127,28 +127,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { | |||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Suite) TestDeleteMachine(c *check.C) { |  | ||||||
| 	user, err := db.CreateUser("test") |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 	machine := types.Machine{ |  | ||||||
| 		ID:             0, |  | ||||||
| 		MachineKey:     "foo", |  | ||||||
| 		NodeKey:        "bar", |  | ||||||
| 		DiscoKey:       "faa", |  | ||||||
| 		Hostname:       "testmachine", |  | ||||||
| 		UserID:         user.ID, |  | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, |  | ||||||
| 		AuthKeyID:      uint(1), |  | ||||||
| 	} |  | ||||||
| 	db.db.Save(&machine) |  | ||||||
| 
 |  | ||||||
| 	err = db.DeleteMachine(&machine) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| 
 |  | ||||||
| 	_, err = db.GetMachine(user.Name, "testmachine") |  | ||||||
| 	c.Assert(err, check.NotNil) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *Suite) TestHardDeleteMachine(c *check.C) { | func (s *Suite) TestHardDeleteMachine(c *check.C) { | ||||||
| 	user, err := db.CreateUser("test") | 	user, err := db.CreateUser("test") | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| @ -164,7 +142,7 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { | |||||||
| 	} | 	} | ||||||
| 	db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 
 | 
 | ||||||
| 	err = db.HardDeleteMachine(&machine) | 	err = db.DeleteMachine(&machine) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	_, err = db.GetMachine(user.Name, "testmachine3") | 	_, err = db.GetMachine(user.Name, "testmachine3") | ||||||
| @ -329,7 +307,8 @@ func (s *Suite) TestExpireMachine(c *check.C) { | |||||||
| 
 | 
 | ||||||
| 	c.Assert(machineFromDB.IsExpired(), check.Equals, false) | 	c.Assert(machineFromDB.IsExpired(), check.Equals, false) | ||||||
| 
 | 
 | ||||||
| 	err = db.ExpireMachine(machineFromDB) | 	now := time.Now() | ||||||
|  | 	err = db.MachineSetExpiry(machineFromDB, now) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 
 | 
 | ||||||
| 	c.Assert(machineFromDB.IsExpired(), check.Equals, true) | 	c.Assert(machineFromDB.IsExpired(), check.Equals, true) | ||||||
| @ -450,14 +429,12 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		name    string | 		name    string | ||||||
| 		db      *HSDatabase |  | ||||||
| 		args    args | 		args    args | ||||||
| 		want    *regexp.Regexp | 		want    *regexp.Regexp | ||||||
| 		wantErr bool | 		wantErr bool | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			name: "simple machine name generation", | 			name: "simple machine name generation", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "testmachine", | 				suppliedName: "testmachine", | ||||||
| 				randomSuffix: false, | 				randomSuffix: false, | ||||||
| @ -467,7 +444,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "machine name with 53 chars", | 			name: "machine name with 53 chars", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", | 				suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", | ||||||
| 				randomSuffix: false, | 				randomSuffix: false, | ||||||
| @ -477,7 +453,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "machine name with 63 chars", | 			name: "machine name with 63 chars", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||||
| 				randomSuffix: false, | 				randomSuffix: false, | ||||||
| @ -487,7 +462,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "machine name with 64 chars", | 			name: "machine name with 64 chars", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", | 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", | ||||||
| 				randomSuffix: false, | 				randomSuffix: false, | ||||||
| @ -497,7 +471,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "machine name with 73 chars", | 			name: "machine name with 73 chars", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", | 				suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", | ||||||
| 				randomSuffix: false, | 				randomSuffix: false, | ||||||
| @ -507,7 +480,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "machine name with random suffix", | 			name: "machine name with random suffix", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "test", | 				suppliedName: "test", | ||||||
| 				randomSuffix: true, | 				randomSuffix: true, | ||||||
| @ -517,7 +489,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "machine name with 63 chars with random suffix", | 			name: "machine name with 63 chars with random suffix", | ||||||
| 			db:   &HSDatabase{}, |  | ||||||
| 			args: args{ | 			args: args{ | ||||||
| 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | 				suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", | ||||||
| 				randomSuffix: true, | 				randomSuffix: true, | ||||||
| @ -528,7 +499,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	for _, tt := range tests { | 	for _, tt := range tests { | ||||||
| 		t.Run(tt.name, func(t *testing.T) { | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 			got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) | 			got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) | ||||||
| 			if (err != nil) != tt.wantErr { | 			if (err != nil) != tt.wantErr { | ||||||
| 				t.Errorf( | 				t.Errorf( | ||||||
| 					"Headscale.GenerateGivenName() error = %v, wantErr %v", | 					"Headscale.GenerateGivenName() error = %v, wantErr %v", | ||||||
|  | |||||||
| @ -28,6 +28,10 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | |||||||
| 	expiration *time.Time, | 	expiration *time.Time, | ||||||
| 	aclTags []string, | 	aclTags []string, | ||||||
| ) (*types.PreAuthKey, error) { | ) (*types.PreAuthKey, error) { | ||||||
|  | 	// TODO(kradalby): figure out this lock | ||||||
|  | 	// hsdb.mu.Lock() | ||||||
|  | 	// defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	user, err := hsdb.GetUser(userName) | 	user, err := hsdb.GetUser(userName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -92,7 +96,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( | |||||||
| 
 | 
 | ||||||
| // ListPreAuthKeys returns the list of PreAuthKeys for a user. | // ListPreAuthKeys returns the list of PreAuthKeys for a user. | ||||||
| func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | ||||||
| 	user, err := hsdb.GetUser(userName) | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.listPreAuthKeys(userName) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { | ||||||
|  | 	user, err := hsdb.getUser(userName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -107,6 +118,9 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er | |||||||
| 
 | 
 | ||||||
| // GetPreAuthKey returns a PreAuthKey for a given key. | // GetPreAuthKey returns a PreAuthKey for a given key. | ||||||
| func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { | func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	pak, err := hsdb.ValidatePreAuthKey(key) | 	pak, err := hsdb.ValidatePreAuthKey(key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -122,6 +136,13 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe | |||||||
| // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey | // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey | ||||||
| // does not exist. | // does not exist. | ||||||
| func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.destroyPreAuthKey(pak) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { | ||||||
| 	return hsdb.db.Transaction(func(db *gorm.DB) error { | 	return hsdb.db.Transaction(func(db *gorm.DB) error { | ||||||
| 		if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { | 		if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { | ||||||
| 			return result.Error | 			return result.Error | ||||||
| @ -137,6 +158,9 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { | |||||||
| 
 | 
 | ||||||
| // MarkExpirePreAuthKey marks a PreAuthKey as expired. | // MarkExpirePreAuthKey marks a PreAuthKey as expired. | ||||||
| func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { | 	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -146,6 +170,9 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { | |||||||
| 
 | 
 | ||||||
| // UsePreAuthKey marks a PreAuthKey as used. | // UsePreAuthKey marks a PreAuthKey as used. | ||||||
| func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	k.Used = true | 	k.Used = true | ||||||
| 	if err := hsdb.db.Save(k).Error; err != nil { | 	if err := hsdb.db.Save(k).Error; err != nil { | ||||||
| 		return fmt.Errorf("failed to update key used status in the database: %w", err) | 		return fmt.Errorf("failed to update key used status in the database: %w", err) | ||||||
| @ -157,6 +184,9 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { | |||||||
| // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node | // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node | ||||||
| // If returns no error and a PreAuthKey, it can be used. | // If returns no error and a PreAuthKey, it can be used. | ||||||
| func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { | func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	pak := types.PreAuthKey{} | 	pak := types.PreAuthKey{} | ||||||
| 	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( | 	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( | ||||||
| 		result.Error, | 		result.Error, | ||||||
| @ -174,7 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machines := types.Machines{} | 	machines := types.Machines{} | ||||||
| 	if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | 	if err := hsdb.db. | ||||||
|  | 		Preload("AuthKey"). | ||||||
|  | 		Where(&types.Machine{AuthKeyID: uint(pak.ID)}). | ||||||
|  | 		Find(&machines).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -13,6 +13,13 @@ import ( | |||||||
| var ErrRouteIsNotAvailable = errors.New("route is not available") | var ErrRouteIsNotAvailable = errors.New("route is not available") | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getRoutes() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { | ||||||
| 	var routes types.Routes | 	var routes types.Routes | ||||||
| 	err := hsdb.db.Preload("Machine").Find(&routes).Error | 	err := hsdb.db.Preload("Machine").Find(&routes).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -23,6 +30,13 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getMachineAdvertisedRoutes(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { | ||||||
| 	var routes types.Routes | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| @ -36,6 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { | func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getMachineRoutes(m) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) { | ||||||
| 	var routes types.Routes | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| @ -49,6 +70,13 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getRoute(id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { | ||||||
| 	var route types.Route | 	var route types.Route | ||||||
| 	err := hsdb.db.Preload("Machine").First(&route, id).Error | 	err := hsdb.db.Preload("Machine").First(&route, id).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -59,7 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) EnableRoute(id uint64) error { | func (hsdb *HSDatabase) EnableRoute(id uint64) error { | ||||||
| 	route, err := hsdb.GetRoute(id) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.enableRoute(id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) enableRoute(id uint64) error { | ||||||
|  | 	route, err := hsdb.getRoute(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -79,7 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) DisableRoute(id uint64) error { | func (hsdb *HSDatabase) DisableRoute(id uint64) error { | ||||||
| 	route, err := hsdb.GetRoute(id) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	route, err := hsdb.getRoute(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -95,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return hsdb.HandlePrimarySubnetFailover() | 		return hsdb.handlePrimarySubnetFailover() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | 	routes, err := hsdb.getMachineRoutes(&route.Machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -114,11 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.HandlePrimarySubnetFailover() | 	return hsdb.handlePrimarySubnetFailover() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | ||||||
| 	route, err := hsdb.GetRoute(id) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	route, err := hsdb.getRoute(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -131,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return hsdb.HandlePrimarySubnetFailover() | 		return hsdb.handlePrimarySubnetFailover() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routes, err := hsdb.GetMachineRoutes(&route.Machine) | 	routes, err := hsdb.getMachineRoutes(&route.Machine) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -150,11 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.HandlePrimarySubnetFailover() | 	return hsdb.handlePrimarySubnetFailover() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { | func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error { | ||||||
| 	routes, err := hsdb.GetMachineRoutes(m) | 	routes, err := hsdb.getMachineRoutes(m) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -165,7 +206,7 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return hsdb.HandlePrimarySubnetFailover() | 	return hsdb.handlePrimarySubnetFailover() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // isUniquePrefix returns if there is another machine providing the same route already. | // isUniquePrefix returns if there is another machine providing the same route already. | ||||||
| @ -201,6 +242,9 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro | |||||||
| // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) | // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) | ||||||
| // Exit nodes are not considered for this, as they are never marked as Primary. | // Exit nodes are not considered for this, as they are never marked as Primary. | ||||||
| func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { | func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
| 	var routes types.Routes | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| 		Preload("Machine"). | 		Preload("Machine"). | ||||||
| @ -214,6 +258,13 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.processMachineRoutes(machine) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) processMachineRoutes(machine *types.Machine) error { | ||||||
| 	currentRoutes := types.Routes{} | 	currentRoutes := types.Routes{} | ||||||
| 	err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error | 	err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -264,6 +315,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.handlePrimarySubnetFailover() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { | ||||||
| 	// first, get all the enabled routes | 	// first, get all the enabled routes | ||||||
| 	var routes types.Routes | 	var routes types.Routes | ||||||
| 	err := hsdb.db. | 	err := hsdb.db. | ||||||
| @ -388,11 +446,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | |||||||
| 	aclPolicy *policy.ACLPolicy, | 	aclPolicy *policy.ACLPolicy, | ||||||
| 	machine *types.Machine, | 	machine *types.Machine, | ||||||
| ) error { | ) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	if len(machine.IPAddresses) == 0 { | 	if len(machine.IPAddresses) == 0 { | ||||||
| 		return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs | 		return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routes, err := hsdb.GetMachineAdvertisedRoutes(machine) | 	routes, err := hsdb.getMachineAdvertisedRoutes(machine) | ||||||
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| @ -445,7 +506,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, approvedRoute := range approvedRoutes { | 	for _, approvedRoute := range approvedRoutes { | ||||||
| 		err := hsdb.EnableRoute(uint64(approvedRoute.ID)) | 		err := hsdb.enableRoute(uint64(approvedRoute.ID)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Err(err). | 			log.Err(err). | ||||||
| 				Str("approvedRoute", approvedRoute.String()). | 				Str("approvedRoute", approvedRoute.String()). | ||||||
|  | |||||||
| @ -18,6 +18,9 @@ var ( | |||||||
| // CreateUser creates a new User. Returns error if could not be created | // CreateUser creates a new User. Returns error if could not be created | ||||||
| // or another user already exists. | // or another user already exists. | ||||||
| func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	err := util.CheckForFQDNRules(name) | 	err := util.CheckForFQDNRules(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -42,12 +45,15 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { | |||||||
| // DestroyUser destroys a User. Returns error if the User does | // DestroyUser destroys a User. Returns error if the User does | ||||||
| // not exist or if there are machines associated with it. | // not exist or if there are machines associated with it. | ||||||
| func (hsdb *HSDatabase) DestroyUser(name string) error { | func (hsdb *HSDatabase) DestroyUser(name string) error { | ||||||
| 	user, err := hsdb.GetUser(name) | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	user, err := hsdb.getUser(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return ErrUserNotFound | 		return ErrUserNotFound | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machines, err := hsdb.ListMachinesByUser(name) | 	machines, err := hsdb.listMachinesByUser(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -55,12 +61,12 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { | |||||||
| 		return ErrUserStillHasNodes | 		return ErrUserStillHasNodes | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	keys, err := hsdb.ListPreAuthKeys(name) | 	keys, err := hsdb.listPreAuthKeys(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	for _, key := range keys { | 	for _, key := range keys { | ||||||
| 		err = hsdb.DestroyPreAuthKey(key) | 		err = hsdb.destroyPreAuthKey(key) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @ -76,8 +82,11 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { | |||||||
| // RenameUser renames a User. Returns error if the User does | // RenameUser renames a User. Returns error if the User does | ||||||
| // not exist or if another User exists with the new name. | // not exist or if another User exists with the new name. | ||||||
| func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	var err error | 	var err error | ||||||
| 	oldUser, err := hsdb.GetUser(oldName) | 	oldUser, err := hsdb.getUser(oldName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -85,7 +94,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	_, err = hsdb.GetUser(newName) | 	_, err = hsdb.getUser(newName) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		return ErrUserExists | 		return ErrUserExists | ||||||
| 	} | 	} | ||||||
| @ -104,6 +113,13 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { | |||||||
| 
 | 
 | ||||||
| // GetUser fetches a user by name. | // GetUser fetches a user by name. | ||||||
| func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.getUser(name) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { | ||||||
| 	user := types.User{} | 	user := types.User{} | ||||||
| 	if result := hsdb.db.First(&user, "name = ?", name); errors.Is( | 	if result := hsdb.db.First(&user, "name = ?", name); errors.Is( | ||||||
| 		result.Error, | 		result.Error, | ||||||
| @ -117,6 +133,13 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { | |||||||
| 
 | 
 | ||||||
| // ListUsers gets all the existing users. | // ListUsers gets all the existing users. | ||||||
| func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.listUsers() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) listUsers() ([]types.User, error) { | ||||||
| 	users := []types.User{} | 	users := []types.User{} | ||||||
| 	if err := hsdb.db.Find(&users).Error; err != nil { | 	if err := hsdb.db.Find(&users).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -127,11 +150,18 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { | |||||||
| 
 | 
 | ||||||
| // ListMachinesByUser gets all the nodes in a given user. | // ListMachinesByUser gets all the nodes in a given user. | ||||||
| func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { | func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { | ||||||
|  | 	hsdb.mu.RLock() | ||||||
|  | 	defer hsdb.mu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	return hsdb.listMachinesByUser(name) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) listMachinesByUser(name string) (types.Machines, error) { | ||||||
| 	err := util.CheckForFQDNRules(name) | 	err := util.CheckForFQDNRules(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	user, err := hsdb.GetUser(name) | 	user, err := hsdb.getUser(name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -144,13 +174,16 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) | |||||||
| 	return machines, nil | 	return machines, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SetMachineUser assigns a Machine to a user. | // AssignMachineToUser assigns a Machine to a user. | ||||||
| func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { | func (hsdb *HSDatabase) AssignMachineToUser(machine *types.Machine, username string) error { | ||||||
|  | 	hsdb.mu.Lock() | ||||||
|  | 	defer hsdb.mu.Unlock() | ||||||
|  | 
 | ||||||
| 	err := util.CheckForFQDNRules(username) | 	err := util.CheckForFQDNRules(username) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	user, err := hsdb.GetUser(username) | 	user, err := hsdb.getUser(username) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -114,15 +114,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { | |||||||
| 	db.db.Save(&machine) | 	db.db.Save(&machine) | ||||||
| 	c.Assert(machine.UserID, check.Equals, oldUser.ID) | 	c.Assert(machine.UserID, check.Equals, oldUser.ID) | ||||||
| 
 | 
 | ||||||
| 	err = db.SetMachineUser(&machine, newUser.Name) | 	err = db.AssignMachineToUser(&machine, newUser.Name) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||||
| 
 | 
 | ||||||
| 	err = db.SetMachineUser(&machine, "non-existing-user") | 	err = db.AssignMachineToUser(&machine, "non-existing-user") | ||||||
| 	c.Assert(err, check.Equals, ErrUserNotFound) | 	c.Assert(err, check.Equals, ErrUserNotFound) | ||||||
| 
 | 
 | ||||||
| 	err = db.SetMachineUser(&machine, newUser.Name) | 	err = db.AssignMachineToUser(&machine, newUser.Name) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| 	c.Assert(machine.UserID, check.Equals, newUser.ID) | 	c.Assert(machine.UserID, check.Equals, newUser.ID) | ||||||
| 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | 	c.Assert(machine.User.Name, check.Equals, newUser.Name) | ||||||
|  | |||||||
| @ -275,8 +275,11 @@ func (api headscaleV1APIServer) ExpireMachine( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	api.h.db.ExpireMachine( | 	now := time.Now() | ||||||
|  | 
 | ||||||
|  | 	api.h.db.MachineSetExpiry( | ||||||
| 		machine, | 		machine, | ||||||
|  | 		now, | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| @ -358,7 +361,7 @@ func (api headscaleV1APIServer) MoveMachine( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = api.h.db.SetMachineUser(machine, request.GetUser()) | 	err = api.h.db.AssignMachineToUser(machine, request.GetUser()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -523,7 +523,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | |||||||
| 			Str("machine", machine.Hostname). | 			Str("machine", machine.Hostname). | ||||||
| 			Msg("machine already registered, reauthenticating") | 			Msg("machine already registered, reauthenticating") | ||||||
| 
 | 
 | ||||||
| 		err := h.db.RefreshMachine(machine, expiry) | 		err := h.db.MachineSetExpiry(machine, expiry) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			util.LogErr(err, "Failed to refresh machine") | 			util.LogErr(err, "Failed to refresh machine") | ||||||
| 			http.Error( | 			http.Error( | ||||||
|  | |||||||
| @ -107,6 +107,7 @@ func (h *Headscale) handlePoll( | |||||||
| 		machine.LastSeen = &now | 		machine.LastSeen = &now | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// TODO(kradalby): Save specific stuff, not whole object. | ||||||
| 	if err := h.db.MachineSave(machine); err != nil { | 	if err := h.db.MachineSave(machine); err != nil { | ||||||
| 		logErr(err, "Failed to persist/update machine in the database") | 		logErr(err, "Failed to persist/update machine in the database") | ||||||
| 		http.Error(writer, "", http.StatusInternalServerError) | 		http.Error(writer, "", http.StatusInternalServerError) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user