mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-30 23:51:03 +01:00 
			
		
		
		
	Only read relevant nodes from database in PeerChangedResponse (#2509)
* Only read relevant nodes from database in PeerChangedResponse * Rework to ensure transactional consistency in PeerChangedResponse again * An empty nodeIDs list should return an empty nodes list * Add test to ListNodesSubset * Link PR in CHANGELOG.md * combine ListNodes and ListNodesSubset into one function * query for all nodes in ListNodes if no parameter is given * also add optional filtering for relevant nodes to ListPeers
This commit is contained in:
		
							parent
							
								
									d2a6356d89
								
							
						
					
					
						commit
						0d3134720b
					
				| @ -87,6 +87,7 @@ The new policy can be used by setting the environment variable | ||||
|   [#2493](https://github.com/juanfont/headscale/pull/2493) | ||||
|   - If a OIDC provider doesn't include the `email_verified` claim in its ID | ||||
|     tokens, Headscale will attempt to get it from the UserInfo endpoint. | ||||
| - Improve performance by only querying relevant nodes from the database for node updates [#2509](https://github.com/juanfont/headscale/pull/2509) | ||||
| 
 | ||||
| ## 0.25.1 (2025-02-25) | ||||
| 
 | ||||
|  | ||||
| @ -35,21 +35,26 @@ var ( | ||||
| 	) | ||||
| ) | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) { | ||||
| // ListPeers returns peers of node, regardless of any Policy or if the node is expired. | ||||
| // If no peer IDs are given, all peers are returned. | ||||
| // If at least one peer ID is given, only these peer nodes will be returned. | ||||
| func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { | ||||
| 	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { | ||||
| 		return ListPeers(rx, nodeID) | ||||
| 		return ListPeers(rx, nodeID, peerIDs...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // ListPeers returns all peers of node, regardless of any Policy or if the node is expired. | ||||
| func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { | ||||
| // ListPeers returns peers of node, regardless of any Policy or if the node is expired. | ||||
| // If no peer IDs are given, all peers are returned. | ||||
| // If at least one peer ID is given, only these peer nodes will be returned. | ||||
| func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { | ||||
| 	nodes := types.Nodes{} | ||||
| 	if err := tx. | ||||
| 		Preload("AuthKey"). | ||||
| 		Preload("AuthKey.User"). | ||||
| 		Preload("User"). | ||||
| 		Where("id <> ?", | ||||
| 			nodeID).Find(&nodes).Error; err != nil { | ||||
| 		Where("id <> ?", nodeID). | ||||
| 		Where(peerIDs).Find(&nodes).Error; err != nil { | ||||
| 		return types.Nodes{}, err | ||||
| 	} | ||||
| 
 | ||||
| @ -58,19 +63,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { | ||||
| 	return nodes, nil | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { | ||||
| // ListNodes queries the database for either all nodes if no parameters are given | ||||
| // or for the given nodes if at least one node ID is given as parameter | ||||
| func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { | ||||
| 	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { | ||||
| 		return ListNodes(rx) | ||||
| 		return ListNodes(rx, nodeIDs...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func ListNodes(tx *gorm.DB) (types.Nodes, error) { | ||||
| // ListNodes queries the database for either all nodes if no parameters are given | ||||
| // or for the given nodes if at least one node ID is given as parameter | ||||
| func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { | ||||
| 	nodes := types.Nodes{} | ||||
| 	if err := tx. | ||||
| 		Preload("AuthKey"). | ||||
| 		Preload("AuthKey.User"). | ||||
| 		Preload("User"). | ||||
| 		Find(&nodes).Error; err != nil { | ||||
| 		Where(nodeIDs).Find(&nodes).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -747,3 +747,174 @@ func TestRenameNode(t *testing.T) { | ||||
| 	}) | ||||
| 	assert.ErrorContains(t, err, "name is not unique") | ||||
| } | ||||
| 
 | ||||
| func TestListPeers(t *testing.T) { | ||||
| 	// Setup test database | ||||
| 	db, err := newSQLiteTestDB() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("creating db: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := db.CreateUser(types.User{Name: "test"}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	user2, err := db.CreateUser(types.User{Name: "user2"}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	node1 := types.Node{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     key.NewMachine().Public(), | ||||
| 		NodeKey:        key.NewNode().Public(), | ||||
| 		Hostname:       "test1", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||
| 	} | ||||
| 
 | ||||
| 	node2 := types.Node{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     key.NewMachine().Public(), | ||||
| 		NodeKey:        key.NewNode().Public(), | ||||
| 		Hostname:       "test2", | ||||
| 		UserID:         user2.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.DB.Save(&node1).Error | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	err = db.DB.Save(&node2).Error | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	err = db.DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		_, err := RegisterNode(tx, node1, nil, nil) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		_, err = RegisterNode(tx, node2, nil, nil) | ||||
| 		return err | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	nodes, err := db.ListNodes() | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	assert.Len(t, nodes, 2) | ||||
| 
 | ||||
| 	// No parameter means no filter, should return all peers | ||||
| 	nodes, err = db.ListPeers(1) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 1) | ||||
| 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||
| 
 | ||||
| 	// Empty node list should return all peers | ||||
| 	nodes, err = db.ListPeers(1, types.NodeIDs{}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 1) | ||||
| 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||
| 
 | ||||
| 	// No match in IDs should return empty list and no error | ||||
| 	nodes, err = db.ListPeers(1, types.NodeIDs{3, 4, 5}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 0) | ||||
| 
 | ||||
| 	// Partial match in IDs | ||||
| 	nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 1) | ||||
| 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||
| 
 | ||||
| 	// Several matched IDs, but node ID is still filtered out | ||||
| 	nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 1) | ||||
| 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||
| } | ||||
| 
 | ||||
| func TestListNodes(t *testing.T) { | ||||
| 	// Setup test database | ||||
| 	db, err := newSQLiteTestDB() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("creating db: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user, err := db.CreateUser(types.User{Name: "test"}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	user2, err := db.CreateUser(types.User{Name: "user2"}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	node1 := types.Node{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     key.NewMachine().Public(), | ||||
| 		NodeKey:        key.NewNode().Public(), | ||||
| 		Hostname:       "test1", | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||
| 	} | ||||
| 
 | ||||
| 	node2 := types.Node{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     key.NewMachine().Public(), | ||||
| 		NodeKey:        key.NewNode().Public(), | ||||
| 		Hostname:       "test2", | ||||
| 		UserID:         user2.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.DB.Save(&node1).Error | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	err = db.DB.Save(&node2).Error | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	err = db.DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		_, err := RegisterNode(tx, node1, nil, nil) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		_, err = RegisterNode(tx, node2, nil, nil) | ||||
| 		return err | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	nodes, err := db.ListNodes() | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	assert.Len(t, nodes, 2) | ||||
| 
 | ||||
| 	// No parameter means no filter, should return all nodes | ||||
| 	nodes, err = db.ListNodes() | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 2) | ||||
| 	assert.Equal(t, "test1", nodes[0].Hostname) | ||||
| 	assert.Equal(t, "test2", nodes[1].Hostname) | ||||
| 
 | ||||
| 	// Empty node list should return all nodes | ||||
| 	nodes, err = db.ListNodes(types.NodeIDs{}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 2) | ||||
| 	assert.Equal(t, "test1", nodes[0].Hostname) | ||||
| 	assert.Equal(t, "test2", nodes[1].Hostname) | ||||
| 
 | ||||
| 	// No match in IDs should return empty list and no error | ||||
| 	nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 0) | ||||
| 
 | ||||
| 	// Partial match in IDs | ||||
| 	nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 1) | ||||
| 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||
| 
 | ||||
| 	// Several matched IDs | ||||
| 	nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, len(nodes), 2) | ||||
| 	assert.Equal(t, "test1", nodes[0].Hostname) | ||||
| 	assert.Equal(t, "test2", nodes[1].Hostname) | ||||
| } | ||||
|  | ||||
| @ -255,27 +255,25 @@ func (m *Mapper) PeerChangedResponse( | ||||
| 	patches []*tailcfg.PeerChange, | ||||
| 	messages ...string, | ||||
| ) ([]byte, error) { | ||||
| 	var err error | ||||
| 	resp := m.baseMapResponse() | ||||
| 
 | ||||
| 	peers, err := m.ListPeers(node.ID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	var removedIDs []tailcfg.NodeID | ||||
| 	var changedIDs []types.NodeID | ||||
| 	for nodeID, nodeChanged := range changed { | ||||
| 		if nodeChanged { | ||||
| 			changedIDs = append(changedIDs, nodeID) | ||||
| 			if nodeID != node.ID { | ||||
| 				changedIDs = append(changedIDs, nodeID) | ||||
| 			} | ||||
| 		} else { | ||||
| 			removedIDs = append(removedIDs, nodeID.NodeID()) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	changedNodes := make(types.Nodes, 0, len(changedIDs)) | ||||
| 	for _, peer := range peers { | ||||
| 		if slices.Contains(changedIDs, peer.ID) { | ||||
| 			changedNodes = append(changedNodes, peer) | ||||
| 	changedNodes := types.Nodes{} | ||||
| 	if len(changedIDs) > 0 { | ||||
| 		changedNodes, err = m.ListNodes(changedIDs...) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| @ -482,8 +480,11 @@ func (m *Mapper) baseWithConfigMapResponse( | ||||
| 	return &resp, nil | ||||
| } | ||||
| 
 | ||||
| func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { | ||||
| 	peers, err := m.db.ListPeers(nodeID) | ||||
| // ListPeers returns peers of node, regardless of any Policy or if the node is expired. | ||||
| // If no peer IDs are given, all peers are returned. | ||||
| // If at least one peer ID is given, only these peer nodes will be returned. | ||||
| func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { | ||||
| 	peers, err := m.db.ListPeers(nodeID, peerIDs...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -496,6 +497,22 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { | ||||
| 	return peers, nil | ||||
| } | ||||
| 
 | ||||
| // ListNodes queries the database for either all nodes if no parameters are given | ||||
| // or for the given nodes if at least one node ID is given as parameter | ||||
| func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { | ||||
| 	nodes, err := m.db.ListNodes(nodeIDs...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, node := range nodes { | ||||
| 		online := m.notif.IsLikelyConnected(node.ID) | ||||
| 		node.IsOnline = &online | ||||
| 	} | ||||
| 
 | ||||
| 	return nodes, nil | ||||
| } | ||||
| 
 | ||||
| func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { | ||||
| 	ret := make(types.Nodes, 0) | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user