mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-30 23:51:03 +01:00 
			
		
		
		
	Only read relevant nodes from database in PeerChangedResponse (#2509)
* Only read relevant nodes from database in PeerChangedResponse * Rework to ensure transactional consistency in PeerChangedResponse again * An empty nodeIDs list should return an empty nodes list * Add test to ListNodesSubset * Link PR in CHANGELOG.md * combine ListNodes and ListNodesSubset into one function * query for all nodes in ListNodes if no parameter is given * also add optional filtering for relevant nodes to ListPeers
This commit is contained in:
		
							parent
							
								
									d2a6356d89
								
							
						
					
					
						commit
						0d3134720b
					
				| @ -87,6 +87,7 @@ The new policy can be used by setting the environment variable | |||||||
|   [#2493](https://github.com/juanfont/headscale/pull/2493) |   [#2493](https://github.com/juanfont/headscale/pull/2493) | ||||||
|   - If a OIDC provider doesn't include the `email_verified` claim in its ID |   - If a OIDC provider doesn't include the `email_verified` claim in its ID | ||||||
|     tokens, Headscale will attempt to get it from the UserInfo endpoint. |     tokens, Headscale will attempt to get it from the UserInfo endpoint. | ||||||
|  | - Improve performance by only querying relevant nodes from the database for node updates [#2509](https://github.com/juanfont/headscale/pull/2509) | ||||||
| 
 | 
 | ||||||
| ## 0.25.1 (2025-02-25) | ## 0.25.1 (2025-02-25) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -35,21 +35,26 @@ var ( | |||||||
| 	) | 	) | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) { | // ListPeers returns peers of node, regardless of any Policy or if the node is expired. | ||||||
|  | // If no peer IDs are given, all peers are returned. | ||||||
|  | // If at least one peer ID is given, only these peer nodes will be returned. | ||||||
|  | func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { | ||||||
| 	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { | 	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { | ||||||
| 		return ListPeers(rx, nodeID) | 		return ListPeers(rx, nodeID, peerIDs...) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListPeers returns all peers of node, regardless of any Policy or if the node is expired. | // ListPeers returns peers of node, regardless of any Policy or if the node is expired. | ||||||
| func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { | // If no peer IDs are given, all peers are returned. | ||||||
|  | // If at least one peer ID is given, only these peer nodes will be returned. | ||||||
|  | func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { | ||||||
| 	nodes := types.Nodes{} | 	nodes := types.Nodes{} | ||||||
| 	if err := tx. | 	if err := tx. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| 		Preload("AuthKey.User"). | 		Preload("AuthKey.User"). | ||||||
| 		Preload("User"). | 		Preload("User"). | ||||||
| 		Where("id <> ?", | 		Where("id <> ?", nodeID). | ||||||
| 			nodeID).Find(&nodes).Error; err != nil { | 		Where(peerIDs).Find(&nodes).Error; err != nil { | ||||||
| 		return types.Nodes{}, err | 		return types.Nodes{}, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -58,19 +63,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { | |||||||
| 	return nodes, nil | 	return nodes, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { | // ListNodes queries the database for either all nodes if no parameters are given | ||||||
|  | // or for the given nodes if at least one node ID is given as parameter | ||||||
|  | func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { | ||||||
| 	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { | 	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { | ||||||
| 		return ListNodes(rx) | 		return ListNodes(rx, nodeIDs...) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func ListNodes(tx *gorm.DB) (types.Nodes, error) { | // ListNodes queries the database for either all nodes if no parameters are given | ||||||
|  | // or for the given nodes if at least one node ID is given as parameter | ||||||
|  | func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { | ||||||
| 	nodes := types.Nodes{} | 	nodes := types.Nodes{} | ||||||
| 	if err := tx. | 	if err := tx. | ||||||
| 		Preload("AuthKey"). | 		Preload("AuthKey"). | ||||||
| 		Preload("AuthKey.User"). | 		Preload("AuthKey.User"). | ||||||
| 		Preload("User"). | 		Preload("User"). | ||||||
| 		Find(&nodes).Error; err != nil { | 		Where(nodeIDs).Find(&nodes).Error; err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -747,3 +747,174 @@ func TestRenameNode(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	assert.ErrorContains(t, err, "name is not unique") | 	assert.ErrorContains(t, err, "name is not unique") | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestListPeers(t *testing.T) { | ||||||
|  | 	// Setup test database | ||||||
|  | 	db, err := newSQLiteTestDB() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("creating db: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user, err := db.CreateUser(types.User{Name: "test"}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	user2, err := db.CreateUser(types.User{Name: "user2"}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	node1 := types.Node{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     key.NewMachine().Public(), | ||||||
|  | 		NodeKey:        key.NewNode().Public(), | ||||||
|  | 		Hostname:       "test1", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	node2 := types.Node{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     key.NewMachine().Public(), | ||||||
|  | 		NodeKey:        key.NewNode().Public(), | ||||||
|  | 		Hostname:       "test2", | ||||||
|  | 		UserID:         user2.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = db.DB.Save(&node1).Error | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	err = db.DB.Save(&node2).Error | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	err = db.DB.Transaction(func(tx *gorm.DB) error { | ||||||
|  | 		_, err := RegisterNode(tx, node1, nil, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		_, err = RegisterNode(tx, node2, nil, nil) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	nodes, err := db.ListNodes() | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	assert.Len(t, nodes, 2) | ||||||
|  | 
 | ||||||
|  | 	// No parameter means no filter, should return all peers | ||||||
|  | 	nodes, err = db.ListPeers(1) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 1) | ||||||
|  | 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||||
|  | 
 | ||||||
|  | 	// Empty node list should return all peers | ||||||
|  | 	nodes, err = db.ListPeers(1, types.NodeIDs{}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 1) | ||||||
|  | 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||||
|  | 
 | ||||||
|  | 	// No match in IDs should return empty list and no error | ||||||
|  | 	nodes, err = db.ListPeers(1, types.NodeIDs{3, 4, 5}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 0) | ||||||
|  | 
 | ||||||
|  | 	// Partial match in IDs | ||||||
|  | 	nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 1) | ||||||
|  | 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||||
|  | 
 | ||||||
|  | 	// Several matched IDs, but node ID is still filtered out | ||||||
|  | 	nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 1) | ||||||
|  | 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestListNodes(t *testing.T) { | ||||||
|  | 	// Setup test database | ||||||
|  | 	db, err := newSQLiteTestDB() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("creating db: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user, err := db.CreateUser(types.User{Name: "test"}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	user2, err := db.CreateUser(types.User{Name: "user2"}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	node1 := types.Node{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     key.NewMachine().Public(), | ||||||
|  | 		NodeKey:        key.NewNode().Public(), | ||||||
|  | 		Hostname:       "test1", | ||||||
|  | 		UserID:         user.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	node2 := types.Node{ | ||||||
|  | 		ID:             0, | ||||||
|  | 		MachineKey:     key.NewMachine().Public(), | ||||||
|  | 		NodeKey:        key.NewNode().Public(), | ||||||
|  | 		Hostname:       "test2", | ||||||
|  | 		UserID:         user2.ID, | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  | 		Hostinfo:       &tailcfg.Hostinfo{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = db.DB.Save(&node1).Error | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	err = db.DB.Save(&node2).Error | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	err = db.DB.Transaction(func(tx *gorm.DB) error { | ||||||
|  | 		_, err := RegisterNode(tx, node1, nil, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		_, err = RegisterNode(tx, node2, nil, nil) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	nodes, err := db.ListNodes() | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	assert.Len(t, nodes, 2) | ||||||
|  | 
 | ||||||
|  | 	// No parameter means no filter, should return all nodes | ||||||
|  | 	nodes, err = db.ListNodes() | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 2) | ||||||
|  | 	assert.Equal(t, "test1", nodes[0].Hostname) | ||||||
|  | 	assert.Equal(t, "test2", nodes[1].Hostname) | ||||||
|  | 
 | ||||||
|  | 	// Empty node list should return all nodes | ||||||
|  | 	nodes, err = db.ListNodes(types.NodeIDs{}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 2) | ||||||
|  | 	assert.Equal(t, "test1", nodes[0].Hostname) | ||||||
|  | 	assert.Equal(t, "test2", nodes[1].Hostname) | ||||||
|  | 
 | ||||||
|  | 	// No match in IDs should return empty list and no error | ||||||
|  | 	nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 0) | ||||||
|  | 
 | ||||||
|  | 	// Partial match in IDs | ||||||
|  | 	nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 1) | ||||||
|  | 	assert.Equal(t, "test2", nodes[0].Hostname) | ||||||
|  | 
 | ||||||
|  | 	// Several matched IDs | ||||||
|  | 	nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, len(nodes), 2) | ||||||
|  | 	assert.Equal(t, "test1", nodes[0].Hostname) | ||||||
|  | 	assert.Equal(t, "test2", nodes[1].Hostname) | ||||||
|  | } | ||||||
|  | |||||||
| @ -255,27 +255,25 @@ func (m *Mapper) PeerChangedResponse( | |||||||
| 	patches []*tailcfg.PeerChange, | 	patches []*tailcfg.PeerChange, | ||||||
| 	messages ...string, | 	messages ...string, | ||||||
| ) ([]byte, error) { | ) ([]byte, error) { | ||||||
|  | 	var err error | ||||||
| 	resp := m.baseMapResponse() | 	resp := m.baseMapResponse() | ||||||
| 
 | 
 | ||||||
| 	peers, err := m.ListPeers(node.ID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var removedIDs []tailcfg.NodeID | 	var removedIDs []tailcfg.NodeID | ||||||
| 	var changedIDs []types.NodeID | 	var changedIDs []types.NodeID | ||||||
| 	for nodeID, nodeChanged := range changed { | 	for nodeID, nodeChanged := range changed { | ||||||
| 		if nodeChanged { | 		if nodeChanged { | ||||||
| 			changedIDs = append(changedIDs, nodeID) | 			if nodeID != node.ID { | ||||||
|  | 				changedIDs = append(changedIDs, nodeID) | ||||||
|  | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			removedIDs = append(removedIDs, nodeID.NodeID()) | 			removedIDs = append(removedIDs, nodeID.NodeID()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 	changedNodes := types.Nodes{} | ||||||
| 	changedNodes := make(types.Nodes, 0, len(changedIDs)) | 	if len(changedIDs) > 0 { | ||||||
| 	for _, peer := range peers { | 		changedNodes, err = m.ListNodes(changedIDs...) | ||||||
| 		if slices.Contains(changedIDs, peer.ID) { | 		if err != nil { | ||||||
| 			changedNodes = append(changedNodes, peer) | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -482,8 +480,11 @@ func (m *Mapper) baseWithConfigMapResponse( | |||||||
| 	return &resp, nil | 	return &resp, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { | // ListPeers returns peers of node, regardless of any Policy or if the node is expired. | ||||||
| 	peers, err := m.db.ListPeers(nodeID) | // If no peer IDs are given, all peers are returned. | ||||||
|  | // If at least one peer ID is given, only these peer nodes will be returned. | ||||||
|  | func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { | ||||||
|  | 	peers, err := m.db.ListPeers(nodeID, peerIDs...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -496,6 +497,22 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { | |||||||
| 	return peers, nil | 	return peers, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ListNodes queries the database for either all nodes if no parameters are given | ||||||
|  | // or for the given nodes if at least one node ID is given as parameter | ||||||
|  | func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { | ||||||
|  | 	nodes, err := m.db.ListNodes(nodeIDs...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, node := range nodes { | ||||||
|  | 		online := m.notif.IsLikelyConnected(node.ID) | ||||||
|  | 		node.IsOnline = &online | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nodes, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { | func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { | ||||||
| 	ret := make(types.Nodes, 0) | 	ret := make(types.Nodes, 0) | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user