mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 08:01:34 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			950 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			950 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package db
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net/netip"
 | |
| 	"sort"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/juanfont/headscale/hscontrol/types"
 | |
| 	"github.com/juanfont/headscale/hscontrol/util"
 | |
| 	"github.com/patrickmn/go-cache"
 | |
| 	"github.com/rs/zerolog/log"
 | |
| 	"gorm.io/gorm"
 | |
| 	"tailscale.com/tailcfg"
 | |
| 	"tailscale.com/types/key"
 | |
| )
 | |
| 
 | |
| const (
 | |
| 	NodeGivenNameHashLength = 8
 | |
| 	NodeGivenNameTrimSize   = 2
 | |
| )
 | |
| 
 | |
| var (
 | |
| 	ErrNodeNotFound                  = errors.New("node not found")
 | |
| 	ErrNodeRouteIsNotAvailable       = errors.New("route is not available on node")
 | |
| 	ErrNodeNotFoundRegistrationCache = errors.New(
 | |
| 		"node not found in registration cache",
 | |
| 	)
 | |
| 	ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
 | |
| 	ErrDifferentRegisteredUser      = errors.New(
 | |
| 		"node was previously registered with a different user",
 | |
| 	)
 | |
| )
 | |
| 
 | |
| // ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
 | |
| func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	return hsdb.listPeers(node)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
 | |
| 	log.Trace().
 | |
| 		Caller().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Msg("Finding direct peers")
 | |
| 
 | |
| 	nodes := types.Nodes{}
 | |
| 	if err := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		Where("node_key <> ?",
 | |
| 			node.NodeKey.String()).Find(&nodes).Error; err != nil {
 | |
| 		return types.Nodes{}, err
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })
 | |
| 
 | |
| 	return nodes, nil
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	return hsdb.listNodes()
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
 | |
| 	nodes := []types.Node{}
 | |
| 	if err := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		Find(&nodes).Error; err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return nodes, nil
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	return hsdb.listNodesByGivenName(givenName)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) {
 | |
| 	nodes := types.Nodes{}
 | |
| 	if err := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		Where("given_name = ?", givenName).Find(&nodes).Error; err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return nodes, nil
 | |
| }
 | |
| 
 | |
| // GetNode finds a Node by name and user and returns the Node struct.
 | |
| func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	nodes, err := hsdb.ListNodesByUser(user)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	for _, m := range nodes {
 | |
| 		if m.Hostname == name {
 | |
| 			return m, nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil, ErrNodeNotFound
 | |
| }
 | |
| 
 | |
| // GetNodeByGivenName finds a Node by given name and user and returns the Node struct.
 | |
| func (hsdb *HSDatabase) GetNodeByGivenName(
 | |
| 	user string,
 | |
| 	givenName string,
 | |
| ) (*types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	node := types.Node{}
 | |
| 	if err := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		Where("given_name = ?", givenName).First(&node).Error; err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return nil, ErrNodeNotFound
 | |
| }
 | |
| 
 | |
| // GetNodeByID finds a Node by ID and returns the Node struct.
 | |
| func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	mach := types.Node{}
 | |
| 	if result := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		Find(&types.Node{ID: id}).First(&mach); result.Error != nil {
 | |
| 		return nil, result.Error
 | |
| 	}
 | |
| 
 | |
| 	return &mach, nil
 | |
| }
 | |
| 
 | |
| // GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
 | |
| func (hsdb *HSDatabase) GetNodeByMachineKey(
 | |
| 	machineKey key.MachinePublic,
 | |
| ) (*types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	return hsdb.getNodeByMachineKey(machineKey)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) getNodeByMachineKey(
 | |
| 	machineKey key.MachinePublic,
 | |
| ) (*types.Node, error) {
 | |
| 	mach := types.Node{}
 | |
| 	if result := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
 | |
| 		return nil, result.Error
 | |
| 	}
 | |
| 
 | |
| 	return &mach, nil
 | |
| }
 | |
| 
 | |
| // GetNodeByNodeKey finds a Node by its current NodeKey.
 | |
| func (hsdb *HSDatabase) GetNodeByNodeKey(
 | |
| 	nodeKey key.NodePublic,
 | |
| ) (*types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	node := types.Node{}
 | |
| 	if result := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		First(&node, "node_key = ?",
 | |
| 			nodeKey.String()); result.Error != nil {
 | |
| 		return nil, result.Error
 | |
| 	}
 | |
| 
 | |
| 	return &node, nil
 | |
| }
 | |
| 
 | |
| // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
 | |
| func (hsdb *HSDatabase) GetNodeByAnyKey(
 | |
| 	machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
 | |
| ) (*types.Node, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	node := types.Node{}
 | |
| 	if result := hsdb.db.
 | |
| 		Preload("AuthKey").
 | |
| 		Preload("AuthKey.User").
 | |
| 		Preload("User").
 | |
| 		Preload("Routes").
 | |
| 		First(&node, "machine_key = ? OR node_key = ? OR node_key = ?",
 | |
| 			machineKey.String(),
 | |
| 			nodeKey.String(),
 | |
| 			oldNodeKey.String()); result.Error != nil {
 | |
| 		return nil, result.Error
 | |
| 	}
 | |
| 
 | |
| 	return &node, nil
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	if result := hsdb.db.Find(node).First(&node); result.Error != nil {
 | |
| 		return result.Error
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // SetTags takes a Node struct pointer and update the forced tags.
 | |
| func (hsdb *HSDatabase) SetTags(
 | |
| 	node *types.Node,
 | |
| 	tags []string,
 | |
| ) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	if len(tags) == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	newTags := []string{}
 | |
| 	for _, tag := range tags {
 | |
| 		if !util.StringOrPrefixListContains(newTags, tag) {
 | |
| 			newTags = append(newTags, tag)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if err := hsdb.db.Model(node).Updates(types.Node{
 | |
| 		ForcedTags: newTags,
 | |
| 	}).Error; err != nil {
 | |
| 		return fmt.Errorf("failed to update tags for node in the database: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	stateUpdate := types.StateUpdate{
 | |
| 		Type:        types.StatePeerChanged,
 | |
| 		ChangeNodes: types.Nodes{node},
 | |
| 		Message:     "called from db.SetTags",
 | |
| 	}
 | |
| 	if stateUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // RenameNode takes a Node struct and a new GivenName for the nodes
 | |
| // and renames it.
 | |
| func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	err := util.CheckForFQDNRules(
 | |
| 		newName,
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		log.Error().
 | |
| 			Caller().
 | |
| 			Str("func", "RenameNode").
 | |
| 			Str("node", node.Hostname).
 | |
| 			Str("newName", newName).
 | |
| 			Err(err).
 | |
| 			Msg("failed to rename node")
 | |
| 
 | |
| 		return err
 | |
| 	}
 | |
| 	node.GivenName = newName
 | |
| 
 | |
| 	if err := hsdb.db.Model(node).Updates(types.Node{
 | |
| 		GivenName: newName,
 | |
| 	}).Error; err != nil {
 | |
| 		return fmt.Errorf("failed to rename node in the database: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	stateUpdate := types.StateUpdate{
 | |
| 		Type:        types.StatePeerChanged,
 | |
| 		ChangeNodes: types.Nodes{node},
 | |
| 		Message:     "called from db.RenameNode",
 | |
| 	}
 | |
| 	if stateUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // NodeSetExpiry takes a Node struct and  a new expiry time.
 | |
| func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	return hsdb.nodeSetExpiry(node, expiry)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error {
 | |
| 	if err := hsdb.db.Model(node).Updates(types.Node{
 | |
| 		Expiry: &expiry,
 | |
| 	}).Error; err != nil {
 | |
| 		return fmt.Errorf(
 | |
| 			"failed to refresh node (update expiration) in the database: %w",
 | |
| 			err,
 | |
| 		)
 | |
| 	}
 | |
| 
 | |
| 	node.Expiry = &expiry
 | |
| 
 | |
| 	stateSelfUpdate := types.StateUpdate{
 | |
| 		Type:        types.StateSelfUpdate,
 | |
| 		ChangeNodes: types.Nodes{node},
 | |
| 	}
 | |
| 	if stateSelfUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
 | |
| 	}
 | |
| 
 | |
| 	stateUpdate := types.StateUpdate{
 | |
| 		Type: types.StatePeerChangedPatch,
 | |
| 		ChangePatches: []*tailcfg.PeerChange{
 | |
| 			{
 | |
| 				NodeID:    tailcfg.NodeID(node.ID),
 | |
| 				KeyExpiry: &expiry,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 	if stateUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // DeleteNode deletes a Node from the database.
 | |
| func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	return hsdb.deleteNode(node)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
 | |
| 	err := hsdb.deleteNodeRoutes(node)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Unscoped causes the node to be fully removed from the database.
 | |
| 	if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	stateUpdate := types.StateUpdate{
 | |
| 		Type:    types.StatePeerRemoved,
 | |
| 		Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
 | |
| 	}
 | |
| 	if stateUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyAll(stateUpdate)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // UpdateLastSeen sets a node's last seen field indicating that we
 | |
| // have recently communicating with this node.
 | |
| // This is mostly used to indicate if a node is online and is not
 | |
| // extremely important to make sure is fully correct and to avoid
 | |
| // holding up the hot path, does not contain any locks and isnt
 | |
| // concurrency safe. But that should be ok.
 | |
| func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
 | |
| 	return hsdb.db.Model(node).Updates(types.Node{
 | |
| 		LastSeen: node.LastSeen,
 | |
| 	}).Error
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
 | |
| 	cache *cache.Cache,
 | |
| 	mkey key.MachinePublic,
 | |
| 	userName string,
 | |
| 	nodeExpiry *time.Time,
 | |
| 	registrationMethod string,
 | |
| ) (*types.Node, error) {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	log.Debug().
 | |
| 		Str("machine_key", mkey.ShortString()).
 | |
| 		Str("userName", userName).
 | |
| 		Str("registrationMethod", registrationMethod).
 | |
| 		Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
 | |
| 		Msg("Registering node from API/CLI or auth callback")
 | |
| 
 | |
| 	if nodeInterface, ok := cache.Get(mkey.String()); ok {
 | |
| 		if registrationNode, ok := nodeInterface.(types.Node); ok {
 | |
| 			user, err := hsdb.getUser(userName)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf(
 | |
| 					"failed to find user in register node from auth callback, %w",
 | |
| 					err,
 | |
| 				)
 | |
| 			}
 | |
| 
 | |
| 			// Registration of expired node with different user
 | |
| 			if registrationNode.ID != 0 &&
 | |
| 				registrationNode.UserID != user.ID {
 | |
| 				return nil, ErrDifferentRegisteredUser
 | |
| 			}
 | |
| 
 | |
| 			registrationNode.UserID = user.ID
 | |
| 			registrationNode.RegisterMethod = registrationMethod
 | |
| 
 | |
| 			if nodeExpiry != nil {
 | |
| 				registrationNode.Expiry = nodeExpiry
 | |
| 			}
 | |
| 
 | |
| 			node, err := hsdb.registerNode(
 | |
| 				registrationNode,
 | |
| 			)
 | |
| 
 | |
| 			if err == nil {
 | |
| 				cache.Delete(mkey.String())
 | |
| 			}
 | |
| 
 | |
| 			return node, err
 | |
| 		} else {
 | |
| 			return nil, ErrCouldNotConvertNodeInterface
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil, ErrNodeNotFoundRegistrationCache
 | |
| }
 | |
| 
 | |
| // RegisterNode is executed from the CLI to register a new Node using its MachineKey.
 | |
| func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	return hsdb.registerNode(node)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
 | |
| 	log.Debug().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Str("machine_key", node.MachineKey.ShortString()).
 | |
| 		Str("node_key", node.NodeKey.ShortString()).
 | |
| 		Str("user", node.User.Name).
 | |
| 		Msg("Registering node")
 | |
| 
 | |
| 	// If the node exists and we had already IPs for it, we just save it
 | |
| 	// so we store the node.Expire and node.Nodekey that has been set when
 | |
| 	// adding it to the registrationCache
 | |
| 	if len(node.IPAddresses) > 0 {
 | |
| 		if err := hsdb.db.Save(&node).Error; err != nil {
 | |
| 			return nil, fmt.Errorf("failed register existing node in the database: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		log.Trace().
 | |
| 			Caller().
 | |
| 			Str("node", node.Hostname).
 | |
| 			Str("machine_key", node.MachineKey.ShortString()).
 | |
| 			Str("node_key", node.NodeKey.ShortString()).
 | |
| 			Str("user", node.User.Name).
 | |
| 			Msg("Node authorized again")
 | |
| 
 | |
| 		return &node, nil
 | |
| 	}
 | |
| 
 | |
| 	hsdb.ipAllocationMutex.Lock()
 | |
| 	defer hsdb.ipAllocationMutex.Unlock()
 | |
| 
 | |
| 	ips, err := hsdb.getAvailableIPs()
 | |
| 	if err != nil {
 | |
| 		log.Error().
 | |
| 			Caller().
 | |
| 			Err(err).
 | |
| 			Str("node", node.Hostname).
 | |
| 			Msg("Could not find IP for the new node")
 | |
| 
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	node.IPAddresses = ips
 | |
| 
 | |
| 	if err := hsdb.db.Save(&node).Error; err != nil {
 | |
| 		return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	log.Trace().
 | |
| 		Caller().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Str("ip", strings.Join(ips.StringSlice(), ",")).
 | |
| 		Msg("Node registered with the database")
 | |
| 
 | |
| 	return &node, nil
 | |
| }
 | |
| 
 | |
| // NodeSetNodeKey sets the node key of a node and saves it to the database.
 | |
| func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	if err := hsdb.db.Model(node).Updates(types.Node{
 | |
| 		NodeKey: nodeKey,
 | |
| 	}).Error; err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // NodeSetMachineKey sets the node key of a node and saves it to the database.
 | |
| func (hsdb *HSDatabase) NodeSetMachineKey(
 | |
| 	node *types.Node,
 | |
| 	machineKey key.MachinePublic,
 | |
| ) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	if err := hsdb.db.Model(node).Updates(types.Node{
 | |
| 		MachineKey: machineKey,
 | |
| 	}).Error; err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // NodeSave saves a node object to the database, prefer to use a specific save method rather
 | |
| // than this. It is intended to be used when we are changing or.
 | |
| func (hsdb *HSDatabase) NodeSave(node *types.Node) error {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	if err := hsdb.db.Save(node).Error; err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // GetAdvertisedRoutes returns the routes that are be advertised by the given node.
 | |
| func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	return hsdb.getAdvertisedRoutes(node)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
 | |
| 	routes := types.Routes{}
 | |
| 
 | |
| 	err := hsdb.db.
 | |
| 		Preload("Node").
 | |
| 		Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
 | |
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 | |
| 		log.Error().
 | |
| 			Caller().
 | |
| 			Err(err).
 | |
| 			Str("node", node.Hostname).
 | |
| 			Msg("Could not get advertised routes for node")
 | |
| 
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	prefixes := []netip.Prefix{}
 | |
| 	for _, route := range routes {
 | |
| 		prefixes = append(prefixes, netip.Prefix(route.Prefix))
 | |
| 	}
 | |
| 
 | |
| 	return prefixes, nil
 | |
| }
 | |
| 
 | |
| // GetEnabledRoutes returns the routes that are enabled for the node.
 | |
| func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	return hsdb.getEnabledRoutes(node)
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
 | |
| 	routes := types.Routes{}
 | |
| 
 | |
| 	err := hsdb.db.
 | |
| 		Preload("Node").
 | |
| 		Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
 | |
| 		Find(&routes).Error
 | |
| 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 | |
| 		log.Error().
 | |
| 			Caller().
 | |
| 			Err(err).
 | |
| 			Str("node", node.Hostname).
 | |
| 			Msg("Could not get enabled routes for node")
 | |
| 
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	prefixes := []netip.Prefix{}
 | |
| 	for _, route := range routes {
 | |
| 		prefixes = append(prefixes, netip.Prefix(route.Prefix))
 | |
| 	}
 | |
| 
 | |
| 	return prefixes, nil
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	route, err := netip.ParsePrefix(routeStr)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	enabledRoutes, err := hsdb.getEnabledRoutes(node)
 | |
| 	if err != nil {
 | |
| 		log.Error().Err(err).Msg("Could not get enabled routes")
 | |
| 
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	for _, enabledRoute := range enabledRoutes {
 | |
| 		if route == enabledRoute {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| // enableRoutes enables new routes based on a list of new routes.
 | |
| func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
 | |
| 	newRoutes := make([]netip.Prefix, len(routeStrs))
 | |
| 	for index, routeStr := range routeStrs {
 | |
| 		route, err := netip.ParsePrefix(routeStr)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		newRoutes[index] = route
 | |
| 	}
 | |
| 
 | |
| 	advertisedRoutes, err := hsdb.getAdvertisedRoutes(node)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	for _, newRoute := range newRoutes {
 | |
| 		if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
 | |
| 			return fmt.Errorf(
 | |
| 				"route (%s) is not available on node %s: %w",
 | |
| 				node.Hostname,
 | |
| 				newRoute, ErrNodeRouteIsNotAvailable,
 | |
| 			)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Separate loop so we don't leave things in a half-updated state
 | |
| 	for _, prefix := range newRoutes {
 | |
| 		route := types.Route{}
 | |
| 		err := hsdb.db.Preload("Node").
 | |
| 			Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
 | |
| 			First(&route).Error
 | |
| 		if err == nil {
 | |
| 			route.Enabled = true
 | |
| 
 | |
| 			// Mark already as primary if there is only this node offering this subnet
 | |
| 			// (and is not an exit route)
 | |
| 			if !route.IsExitRoute() {
 | |
| 				route.IsPrimary = hsdb.isUniquePrefix(route)
 | |
| 			}
 | |
| 
 | |
| 			err = hsdb.db.Save(&route).Error
 | |
| 			if err != nil {
 | |
| 				return fmt.Errorf("failed to enable route: %w", err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			return fmt.Errorf("failed to find route: %w", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Ensure the node has the latest routes when notifying the other
 | |
| 	// nodes
 | |
| 	nRoutes, err := hsdb.getNodeRoutes(node)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("failed to read back routes: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	node.Routes = nRoutes
 | |
| 
 | |
| 	log.Trace().
 | |
| 		Caller().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Strs("routes", routeStrs).
 | |
| 		Msg("enabling routes")
 | |
| 
 | |
| 	stateUpdate := types.StateUpdate{
 | |
| 		Type:        types.StatePeerChanged,
 | |
| 		ChangeNodes: types.Nodes{node},
 | |
| 		Message:     "called from db.enableRoutes",
 | |
| 	}
 | |
| 	if stateUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyWithIgnore(
 | |
| 			stateUpdate, node.MachineKey.String())
 | |
| 	}
 | |
| 
 | |
| 	// Send an update to the node itself with to ensure it
 | |
| 	// has an updated packetfilter allowing the new route
 | |
| 	// if it is defined in the ACL.
 | |
| 	selfUpdate := types.StateUpdate{
 | |
| 		Type:        types.StateSelfUpdate,
 | |
| 		ChangeNodes: types.Nodes{node},
 | |
| 	}
 | |
| 	if selfUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyByMachineKey(
 | |
| 			selfUpdate,
 | |
| 			node.MachineKey)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
 | |
| 	normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
 | |
| 		suppliedName,
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	if randomSuffix {
 | |
| 		// Trim if a hostname will be longer than 63 chars after adding the hash.
 | |
| 		trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize
 | |
| 		if len(normalizedHostname) > trimmedHostnameLength {
 | |
| 			normalizedHostname = normalizedHostname[:trimmedHostnameLength]
 | |
| 		}
 | |
| 
 | |
| 		suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength)
 | |
| 		if err != nil {
 | |
| 			return "", err
 | |
| 		}
 | |
| 
 | |
| 		normalizedHostname += "-" + suffix
 | |
| 	}
 | |
| 
 | |
| 	return normalizedHostname, nil
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) GenerateGivenName(
 | |
| 	mkey key.MachinePublic,
 | |
| 	suppliedName string,
 | |
| ) (string, error) {
 | |
| 	hsdb.mu.RLock()
 | |
| 	defer hsdb.mu.RUnlock()
 | |
| 
 | |
| 	givenName, err := generateGivenName(suppliedName, false)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
 | |
| 	nodes, err := hsdb.listNodesByGivenName(givenName)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	var nodeFound *types.Node
 | |
| 	for idx, node := range nodes {
 | |
| 		if node.GivenName == givenName {
 | |
| 			nodeFound = nodes[idx]
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
 | |
| 		postfixedName, err := generateGivenName(suppliedName, true)
 | |
| 		if err != nil {
 | |
| 			return "", err
 | |
| 		}
 | |
| 
 | |
| 		givenName = postfixedName
 | |
| 	}
 | |
| 
 | |
| 	return givenName, nil
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	users, err := hsdb.listUsers()
 | |
| 	if err != nil {
 | |
| 		log.Error().Err(err).Msg("Error listing users")
 | |
| 
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for _, user := range users {
 | |
| 		nodes, err := hsdb.listNodesByUser(user.Name)
 | |
| 		if err != nil {
 | |
| 			log.Error().
 | |
| 				Err(err).
 | |
| 				Str("user", user.Name).
 | |
| 				Msg("Error listing nodes in user")
 | |
| 
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		expired := make([]tailcfg.NodeID, 0)
 | |
| 		for idx, node := range nodes {
 | |
| 			if node.IsEphemeral() && node.LastSeen != nil &&
 | |
| 				time.Now().
 | |
| 					After(node.LastSeen.Add(inactivityThreshhold)) {
 | |
| 				expired = append(expired, tailcfg.NodeID(node.ID))
 | |
| 
 | |
| 				log.Info().
 | |
| 					Str("node", node.Hostname).
 | |
| 					Msg("Ephemeral client removed from database")
 | |
| 
 | |
| 				err = hsdb.deleteNode(nodes[idx])
 | |
| 				if err != nil {
 | |
| 					log.Error().
 | |
| 						Err(err).
 | |
| 						Str("node", node.Hostname).
 | |
| 						Msg("🤮 Cannot delete ephemeral node from the database")
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if len(expired) > 0 {
 | |
| 			hsdb.notifier.NotifyAll(types.StateUpdate{
 | |
| 				Type:    types.StatePeerRemoved,
 | |
| 				Removed: expired,
 | |
| 			})
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
 | |
| 	hsdb.mu.Lock()
 | |
| 	defer hsdb.mu.Unlock()
 | |
| 
 | |
| 	// use the time of the start of the function to ensure we
 | |
| 	// dont miss some nodes by returning it _after_ we have
 | |
| 	// checked everything.
 | |
| 	started := time.Now()
 | |
| 
 | |
| 	expiredNodes := make([]*types.Node, 0)
 | |
| 
 | |
| 	nodes, err := hsdb.listNodes()
 | |
| 	if err != nil {
 | |
| 		log.Error().
 | |
| 			Err(err).
 | |
| 			Msg("Error listing nodes to find expired nodes")
 | |
| 
 | |
| 		return time.Unix(0, 0)
 | |
| 	}
 | |
| 	for index, node := range nodes {
 | |
| 		if node.IsExpired() &&
 | |
| 			// TODO(kradalby): Replace this, it is very spammy
 | |
| 			// It will notify about all nodes that has been expired.
 | |
| 			// It should only notify about expired nodes since _last check_.
 | |
| 			node.Expiry.After(lastCheck) {
 | |
| 			expiredNodes = append(expiredNodes, &nodes[index])
 | |
| 
 | |
| 			// Do not use setNodeExpiry as that has a notifier hook, which
 | |
| 			// can cause a deadlock, we are updating all changed nodes later
 | |
| 			// and there is no point in notifiying twice.
 | |
| 			if err := hsdb.db.Model(&nodes[index]).Updates(types.Node{
 | |
| 				Expiry: &started,
 | |
| 			}).Error; err != nil {
 | |
| 				log.Error().
 | |
| 					Err(err).
 | |
| 					Str("node", node.Hostname).
 | |
| 					Str("name", node.GivenName).
 | |
| 					Msg("🤮 Cannot expire node")
 | |
| 			} else {
 | |
| 				log.Info().
 | |
| 					Str("node", node.Hostname).
 | |
| 					Str("name", node.GivenName).
 | |
| 					Msg("Node successfully expired")
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	expired := make([]*tailcfg.PeerChange, len(expiredNodes))
 | |
| 	for idx, node := range expiredNodes {
 | |
| 		expired[idx] = &tailcfg.PeerChange{
 | |
| 			NodeID:    tailcfg.NodeID(node.ID),
 | |
| 			KeyExpiry: &started,
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Inform the peers of a node with a lightweight update.
 | |
| 	stateUpdate := types.StateUpdate{
 | |
| 		Type:          types.StatePeerChangedPatch,
 | |
| 		ChangePatches: expired,
 | |
| 	}
 | |
| 	if stateUpdate.Valid() {
 | |
| 		hsdb.notifier.NotifyAll(stateUpdate)
 | |
| 	}
 | |
| 
 | |
| 	// Inform the node itself that it has expired.
 | |
| 	for _, node := range expiredNodes {
 | |
| 		stateSelfUpdate := types.StateUpdate{
 | |
| 			Type:        types.StateSelfUpdate,
 | |
| 			ChangeNodes: types.Nodes{node},
 | |
| 		}
 | |
| 		if stateSelfUpdate.Valid() {
 | |
| 			hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return started
 | |
| }
 |