mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 10:01:05 +01:00 
			
		
		
		
	move MapResponse peer logic into function and reuse
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									387aa03adb
								
							
						
					
					
						commit
						432e975a7f
					
				@ -92,6 +92,8 @@ type Headscale struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	shutdownChan       chan struct{}
 | 
						shutdownChan       chan struct{}
 | 
				
			||||||
	pollNetMapStreamWG sync.WaitGroup
 | 
						pollNetMapStreamWG sync.WaitGroup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pollStreamOpenMu sync.Mutex
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
					func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			||||||
 | 
				
			|||||||
@ -340,6 +340,8 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							machine := &route.Machine
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !route.IsPrimary {
 | 
							if !route.IsPrimary {
 | 
				
			||||||
			_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
 | 
								_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
 | 
				
			||||||
			if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
 | 
								if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
 | 
				
			||||||
@ -355,7 +357,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
 | 
				
			|||||||
					return err
 | 
										return err
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				changedMachines = append(changedMachines, &route.Machine)
 | 
									changedMachines = append(changedMachines, machine)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -429,7 +431,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
 | 
				
			|||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			changedMachines = append(changedMachines, &route.Machine)
 | 
								changedMachines = append(changedMachines, machine)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -38,6 +38,16 @@ const (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
 | 
					var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO: Optimise
 | 
				
			||||||
 | 
					// As this work continues, the idea is that there will be one Mapper instance
 | 
				
			||||||
 | 
					// per node, attached to the open stream between the control and client.
 | 
				
			||||||
 | 
					// This means that this can hold a state per machine and we can use that to
 | 
				
			||||||
 | 
					// improve the mapresponses sent.
 | 
				
			||||||
 | 
					// We could:
 | 
				
			||||||
 | 
					// - Keep information about the previous mapresponse so we can send a diff
 | 
				
			||||||
 | 
					// - Store hashes
 | 
				
			||||||
 | 
					// - Create a "minifier" that removes info not needed for the node
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Mapper struct {
 | 
					type Mapper struct {
 | 
				
			||||||
	privateKey2019 *key.MachinePrivate
 | 
						privateKey2019 *key.MachinePrivate
 | 
				
			||||||
	isNoise        bool
 | 
						isNoise        bool
 | 
				
			||||||
@ -102,105 +112,6 @@ func (m *Mapper) String() string {
 | 
				
			|||||||
	return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
 | 
						return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO: Optimise
 | 
					 | 
				
			||||||
// As this work continues, the idea is that there will be one Mapper instance
 | 
					 | 
				
			||||||
// per node, attached to the open stream between the control and client.
 | 
					 | 
				
			||||||
// This means that this can hold a state per machine and we can use that to
 | 
					 | 
				
			||||||
// improve the mapresponses sent.
 | 
					 | 
				
			||||||
// We could:
 | 
					 | 
				
			||||||
// - Keep information about the previous mapresponse so we can send a diff
 | 
					 | 
				
			||||||
// - Store hashes
 | 
					 | 
				
			||||||
// - Create a "minifier" that removes info not needed for the node
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// fullMapResponse is the internal function for generating a MapResponse
 | 
					 | 
				
			||||||
// for a machine.
 | 
					 | 
				
			||||||
func fullMapResponse(
 | 
					 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
	machine *types.Machine,
 | 
					 | 
				
			||||||
	peers types.Machines,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	baseDomain string,
 | 
					 | 
				
			||||||
	dnsCfg *tailcfg.DNSConfig,
 | 
					 | 
				
			||||||
	derpMap *tailcfg.DERPMap,
 | 
					 | 
				
			||||||
	logtail bool,
 | 
					 | 
				
			||||||
	randomClientPort bool,
 | 
					 | 
				
			||||||
) (*tailcfg.MapResponse, error) {
 | 
					 | 
				
			||||||
	tailnode, err := tailNode(machine, pol, dnsCfg, baseDomain)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	now := time.Now()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp := tailcfg.MapResponse{
 | 
					 | 
				
			||||||
		Node: tailnode,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		DERPMap: derpMap,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		Domain: baseDomain,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Do not instruct clients to collect services we do not
 | 
					 | 
				
			||||||
		// support or do anything with them
 | 
					 | 
				
			||||||
		CollectServices: "false",
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		ControlTime:  &now,
 | 
					 | 
				
			||||||
		KeepAlive:    false,
 | 
					 | 
				
			||||||
		OnlineChange: db.OnlineMachineMap(peers),
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		Debug: &tailcfg.Debug{
 | 
					 | 
				
			||||||
			DisableLogTail:      !logtail,
 | 
					 | 
				
			||||||
			RandomizeClientPort: randomClientPort,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if peers != nil || len(peers) > 0 {
 | 
					 | 
				
			||||||
		rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
 | 
					 | 
				
			||||||
			pol,
 | 
					 | 
				
			||||||
			machine,
 | 
					 | 
				
			||||||
			peers,
 | 
					 | 
				
			||||||
		)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Filter out peers that have expired.
 | 
					 | 
				
			||||||
		peers = filterExpiredAndNotReady(peers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// If there are filter rules present, see if there are any machines that cannot
 | 
					 | 
				
			||||||
		// access eachother at all and remove them from the peers.
 | 
					 | 
				
			||||||
		if len(rules) > 0 {
 | 
					 | 
				
			||||||
			peers = policy.FilterMachinesByACL(machine, peers, rules)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		profiles := generateUserProfiles(machine, peers, baseDomain)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		dnsConfig := generateDNSConfig(
 | 
					 | 
				
			||||||
			dnsCfg,
 | 
					 | 
				
			||||||
			baseDomain,
 | 
					 | 
				
			||||||
			machine,
 | 
					 | 
				
			||||||
			peers,
 | 
					 | 
				
			||||||
		)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Peers is always returned sorted by Node.ID.
 | 
					 | 
				
			||||||
		sort.SliceStable(tailPeers, func(x, y int) bool {
 | 
					 | 
				
			||||||
			return tailPeers[x].ID < tailPeers[y].ID
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		resp.Peers = tailPeers
 | 
					 | 
				
			||||||
		resp.DNSConfig = dnsConfig
 | 
					 | 
				
			||||||
		resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
 | 
					 | 
				
			||||||
		resp.UserProfiles = profiles
 | 
					 | 
				
			||||||
		resp.SSHPolicy = sshPolicy
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return &resp, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func generateUserProfiles(
 | 
					func generateUserProfiles(
 | 
				
			||||||
	machine *types.Machine,
 | 
						machine *types.Machine,
 | 
				
			||||||
	peers types.Machines,
 | 
						peers types.Machines,
 | 
				
			||||||
@ -294,6 +205,38 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// fullMapResponse creates a complete MapResponse for a node.
 | 
				
			||||||
 | 
					// It is a separate function to make testing easier.
 | 
				
			||||||
 | 
					func (m *Mapper) fullMapResponse(
 | 
				
			||||||
 | 
						machine *types.Machine,
 | 
				
			||||||
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
 | 
					) (*tailcfg.MapResponse, error) {
 | 
				
			||||||
 | 
						peers := machineMapToList(m.peers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp, err := m.baseWithConfigMapResponse(machine, pol)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO(kradalby): Move this into appendPeerChanges?
 | 
				
			||||||
 | 
						resp.OnlineChange = db.OnlineMachineMap(peers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = appendPeerChanges(
 | 
				
			||||||
 | 
							resp,
 | 
				
			||||||
 | 
							pol,
 | 
				
			||||||
 | 
							machine,
 | 
				
			||||||
 | 
							peers,
 | 
				
			||||||
 | 
							peers,
 | 
				
			||||||
 | 
							m.baseDomain,
 | 
				
			||||||
 | 
							m.dnsCfg,
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return resp, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// FullMapResponse returns a MapResponse for the given machine.
 | 
					// FullMapResponse returns a MapResponse for the given machine.
 | 
				
			||||||
func (m *Mapper) FullMapResponse(
 | 
					func (m *Mapper) FullMapResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
@ -303,25 +246,16 @@ func (m *Mapper) FullMapResponse(
 | 
				
			|||||||
	m.mu.Lock()
 | 
						m.mu.Lock()
 | 
				
			||||||
	defer m.mu.Unlock()
 | 
						defer m.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mapResponse, err := fullMapResponse(
 | 
						resp, err := m.fullMapResponse(machine, pol)
 | 
				
			||||||
		pol,
 | 
					 | 
				
			||||||
		machine,
 | 
					 | 
				
			||||||
		machineMapToList(m.peers),
 | 
					 | 
				
			||||||
		m.baseDomain,
 | 
					 | 
				
			||||||
		m.dnsCfg,
 | 
					 | 
				
			||||||
		m.derpMap,
 | 
					 | 
				
			||||||
		m.logtail,
 | 
					 | 
				
			||||||
		m.randomClientPort,
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if m.isNoise {
 | 
						if m.isNoise {
 | 
				
			||||||
		return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
 | 
							return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// LiteMapResponse returns a MapResponse for the given machine.
 | 
					// LiteMapResponse returns a MapResponse for the given machine.
 | 
				
			||||||
@ -332,32 +266,23 @@ func (m *Mapper) LiteMapResponse(
 | 
				
			|||||||
	machine *types.Machine,
 | 
						machine *types.Machine,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	mapResponse, err := fullMapResponse(
 | 
						resp, err := m.baseWithConfigMapResponse(machine, pol)
 | 
				
			||||||
		pol,
 | 
					 | 
				
			||||||
		machine,
 | 
					 | 
				
			||||||
		nil,
 | 
					 | 
				
			||||||
		m.baseDomain,
 | 
					 | 
				
			||||||
		m.dnsCfg,
 | 
					 | 
				
			||||||
		m.derpMap,
 | 
					 | 
				
			||||||
		m.logtail,
 | 
					 | 
				
			||||||
		m.randomClientPort,
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if m.isNoise {
 | 
						if m.isNoise {
 | 
				
			||||||
		return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
 | 
							return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *Mapper) KeepAliveResponse(
 | 
					func (m *Mapper) KeepAliveResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	machine *types.Machine,
 | 
						machine *types.Machine,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	resp := m.baseMapResponse(machine)
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
	resp.KeepAlive = true
 | 
						resp.KeepAlive = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
				
			||||||
@ -368,7 +293,7 @@ func (m *Mapper) DERPMapResponse(
 | 
				
			|||||||
	machine *types.Machine,
 | 
						machine *types.Machine,
 | 
				
			||||||
	derpMap tailcfg.DERPMap,
 | 
						derpMap tailcfg.DERPMap,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	resp := m.baseMapResponse(machine)
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
	resp.DERPMap = &derpMap
 | 
						resp.DERPMap = &derpMap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
				
			||||||
@ -383,7 +308,6 @@ func (m *Mapper) PeerChangedResponse(
 | 
				
			|||||||
	m.mu.Lock()
 | 
						m.mu.Lock()
 | 
				
			||||||
	defer m.mu.Unlock()
 | 
						defer m.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var err error
 | 
					 | 
				
			||||||
	lastSeen := make(map[tailcfg.NodeID]bool)
 | 
						lastSeen := make(map[tailcfg.NodeID]bool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Update our internal map.
 | 
						// Update our internal map.
 | 
				
			||||||
@ -394,37 +318,21 @@ func (m *Mapper) PeerChangedResponse(
 | 
				
			|||||||
		lastSeen[tailcfg.NodeID(machine.ID)] = true
 | 
							lastSeen[tailcfg.NodeID(machine.ID)] = true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := appendPeerChanges(
 | 
				
			||||||
 | 
							&resp,
 | 
				
			||||||
		pol,
 | 
							pol,
 | 
				
			||||||
		machine,
 | 
							machine,
 | 
				
			||||||
		machineMapToList(m.peers),
 | 
							machineMapToList(m.peers),
 | 
				
			||||||
 | 
							changed,
 | 
				
			||||||
 | 
							m.baseDomain,
 | 
				
			||||||
 | 
							m.dnsCfg,
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	changed = filterExpiredAndNotReady(changed)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If there are filter rules present, see if there are any machines that cannot
 | 
					 | 
				
			||||||
	// access eachother at all and remove them from the changed.
 | 
					 | 
				
			||||||
	if len(rules) > 0 {
 | 
					 | 
				
			||||||
		changed = policy.FilterMachinesByACL(machine, changed, rules)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Peers is always returned sorted by Node.ID.
 | 
					 | 
				
			||||||
	sort.SliceStable(tailPeers, func(x, y int) bool {
 | 
					 | 
				
			||||||
		return tailPeers[x].ID < tailPeers[y].ID
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp := m.baseMapResponse(machine)
 | 
					 | 
				
			||||||
	resp.PeersChanged = tailPeers
 | 
					 | 
				
			||||||
	resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
 | 
					 | 
				
			||||||
	resp.SSHPolicy = sshPolicy
 | 
					 | 
				
			||||||
	// resp.PeerSeenChange = lastSeen
 | 
						// resp.PeerSeenChange = lastSeen
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
				
			||||||
@ -443,7 +351,7 @@ func (m *Mapper) PeerRemovedResponse(
 | 
				
			|||||||
		delete(m.peers, uint64(id))
 | 
							delete(m.peers, uint64(id))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp := m.baseMapResponse(machine)
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
	resp.PeersRemoved = removed
 | 
						resp.PeersRemoved = removed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
 | 
				
			||||||
@ -497,7 +405,7 @@ func (m *Mapper) marshalMapResponse(
 | 
				
			|||||||
			panic(err)
 | 
								panic(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		now := time.Now().Unix()
 | 
							now := time.Now().UnixNano()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		mapResponsePath := path.Join(
 | 
							mapResponsePath := path.Join(
 | 
				
			||||||
			mPath,
 | 
								mPath,
 | 
				
			||||||
@ -583,7 +491,9 @@ var zstdEncoderPool = &sync.Pool{
 | 
				
			|||||||
	},
 | 
						},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
 | 
					// baseMapResponse returns a tailcfg.MapResponse with
 | 
				
			||||||
 | 
					// KeepAlive false and ControlTime set to now.
 | 
				
			||||||
 | 
					func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
 | 
				
			||||||
	now := time.Now()
 | 
						now := time.Now()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp := tailcfg.MapResponse{
 | 
						resp := tailcfg.MapResponse{
 | 
				
			||||||
@ -591,14 +501,43 @@ func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
 | 
				
			|||||||
		ControlTime: &now,
 | 
							ControlTime: &now,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// online, err := m.db.ListOnlineMachines(machine)
 | 
					 | 
				
			||||||
	// if err == nil {
 | 
					 | 
				
			||||||
	// 	resp.OnlineChange = online
 | 
					 | 
				
			||||||
	// }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return resp
 | 
						return resp
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
 | 
				
			||||||
 | 
					// with the basic configuration from headscale set.
 | 
				
			||||||
 | 
					// It is used in for bigger updates, such as full and lite, not
 | 
				
			||||||
 | 
					// incremental.
 | 
				
			||||||
 | 
					func (m *Mapper) baseWithConfigMapResponse(
 | 
				
			||||||
 | 
						machine *types.Machine,
 | 
				
			||||||
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
 | 
					) (*tailcfg.MapResponse, error) {
 | 
				
			||||||
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Node = tailnode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.DERPMap = m.derpMap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.Domain = m.baseDomain
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Do not instruct clients to collect services we do not
 | 
				
			||||||
 | 
						// support or do anything with them
 | 
				
			||||||
 | 
						resp.CollectServices = "false"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.KeepAlive = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.Debug = &tailcfg.Debug{
 | 
				
			||||||
 | 
							DisableLogTail:      !m.logtail,
 | 
				
			||||||
 | 
							RandomizeClientPort: m.randomClientPort,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &resp, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
 | 
					func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
 | 
				
			||||||
	ret := make(types.Machines, 0)
 | 
						ret := make(types.Machines, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -617,3 +556,67 @@ func filterExpiredAndNotReady(peers types.Machines) types.Machines {
 | 
				
			|||||||
		return !item.IsExpired() || len(item.Endpoints) > 0
 | 
							return !item.IsExpired() || len(item.Endpoints) > 0
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// appendPeerChanges mutates a tailcfg.MapResponse with all the
 | 
				
			||||||
 | 
					// necessary changes when peers have changed.
 | 
				
			||||||
 | 
					func appendPeerChanges(
 | 
				
			||||||
 | 
						resp *tailcfg.MapResponse,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
 | 
						machine *types.Machine,
 | 
				
			||||||
 | 
						peers types.Machines,
 | 
				
			||||||
 | 
						changed types.Machines,
 | 
				
			||||||
 | 
						baseDomain string,
 | 
				
			||||||
 | 
						dnsCfg *tailcfg.DNSConfig,
 | 
				
			||||||
 | 
					) error {
 | 
				
			||||||
 | 
						fullChange := len(peers) == len(changed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
 | 
				
			||||||
 | 
							pol,
 | 
				
			||||||
 | 
							machine,
 | 
				
			||||||
 | 
							peers,
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Filter out peers that have expired.
 | 
				
			||||||
 | 
						changed = filterExpiredAndNotReady(changed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If there are filter rules present, see if there are any machines that cannot
 | 
				
			||||||
 | 
						// access eachother at all and remove them from the peers.
 | 
				
			||||||
 | 
						if len(rules) > 0 {
 | 
				
			||||||
 | 
							changed = policy.FilterMachinesByACL(machine, changed, rules)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						profiles := generateUserProfiles(machine, changed, baseDomain)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dnsConfig := generateDNSConfig(
 | 
				
			||||||
 | 
							dnsCfg,
 | 
				
			||||||
 | 
							baseDomain,
 | 
				
			||||||
 | 
							machine,
 | 
				
			||||||
 | 
							peers,
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Peers is always returned sorted by Node.ID.
 | 
				
			||||||
 | 
						sort.SliceStable(tailPeers, func(x, y int) bool {
 | 
				
			||||||
 | 
							return tailPeers[x].ID < tailPeers[y].ID
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if fullChange {
 | 
				
			||||||
 | 
							resp.Peers = tailPeers
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							resp.PeersChanged = tailPeers
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.DNSConfig = dnsConfig
 | 
				
			||||||
 | 
						resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
 | 
				
			||||||
 | 
						resp.UserProfiles = profiles
 | 
				
			||||||
 | 
						resp.SSHPolicy = sshPolicy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -441,7 +441,9 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
						},
 | 
											},
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
 | 
									UserProfiles: []tailcfg.UserProfile{
 | 
				
			||||||
 | 
										{LoginName: "mini", DisplayName: "mini"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
				SSHPolicy:   &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
 | 
									SSHPolicy:   &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
 | 
				
			||||||
				ControlTime: &time.Time{},
 | 
									ControlTime: &time.Time{},
 | 
				
			||||||
				Debug: &tailcfg.Debug{
 | 
									Debug: &tailcfg.Debug{
 | 
				
			||||||
@ -454,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			got, err := fullMapResponse(
 | 
								mappy := NewMapper(
 | 
				
			||||||
				tt.pol,
 | 
					 | 
				
			||||||
				tt.machine,
 | 
									tt.machine,
 | 
				
			||||||
				tt.peers,
 | 
									tt.peers,
 | 
				
			||||||
 | 
									nil,
 | 
				
			||||||
 | 
									false,
 | 
				
			||||||
 | 
									tt.derpMap,
 | 
				
			||||||
				tt.baseDomain,
 | 
									tt.baseDomain,
 | 
				
			||||||
				tt.dnsConfig,
 | 
									tt.dnsConfig,
 | 
				
			||||||
				tt.derpMap,
 | 
					 | 
				
			||||||
				tt.logtail,
 | 
									tt.logtail,
 | 
				
			||||||
				tt.randomClientPort,
 | 
									tt.randomClientPort,
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								got, err := mappy.fullMapResponse(
 | 
				
			||||||
 | 
									tt.machine,
 | 
				
			||||||
 | 
									tt.pol,
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if (err != nil) != tt.wantErr {
 | 
								if (err != nil) != tt.wantErr {
 | 
				
			||||||
				t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
 | 
									t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,8 @@ func logPollFunc(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// handlePoll is the common code for the legacy and Noise protocols to
 | 
					// handlePoll is the common code for the legacy and Noise protocols to
 | 
				
			||||||
// managed the poll loop.
 | 
					// managed the poll loop.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//nolint:gocyclo
 | 
				
			||||||
func (h *Headscale) handlePoll(
 | 
					func (h *Headscale) handlePoll(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
@ -67,6 +69,7 @@ func (h *Headscale) handlePoll(
 | 
				
			|||||||
	// following updates missing
 | 
						// following updates missing
 | 
				
			||||||
	var updateChan chan types.StateUpdate
 | 
						var updateChan chan types.StateUpdate
 | 
				
			||||||
	if mapRequest.Stream {
 | 
						if mapRequest.Stream {
 | 
				
			||||||
 | 
							h.pollStreamOpenMu.Lock()
 | 
				
			||||||
		h.pollNetMapStreamWG.Add(1)
 | 
							h.pollNetMapStreamWG.Add(1)
 | 
				
			||||||
		defer h.pollNetMapStreamWG.Done()
 | 
							defer h.pollNetMapStreamWG.Done()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -251,6 +254,8 @@ func (h *Headscale) handlePoll(
 | 
				
			|||||||
	ctx, cancel := context.WithCancel(ctx)
 | 
						ctx, cancel := context.WithCancel(ctx)
 | 
				
			||||||
	defer cancel()
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						h.pollStreamOpenMu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		logInfo("Waiting for update on stream channel")
 | 
							logInfo("Waiting for update on stream channel")
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
 | 
				
			|||||||
@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) {
 | 
				
			|||||||
	defer scenario.Shutdown()
 | 
						defer scenario.Shutdown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	spec := map[string]int{
 | 
						spec := map[string]int{
 | 
				
			||||||
		// Omit 1.16.2 (-1) because it does not have the FQDN field
 | 
							"magicdns1": len(MustTestVersions),
 | 
				
			||||||
		"magicdns1": len(MustTestVersions) - 1,
 | 
							"magicdns2": len(MustTestVersions),
 | 
				
			||||||
		"magicdns2": len(MustTestVersions) - 1,
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
 | 
						err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
 | 
				
			||||||
 | 
				
			|||||||
@ -21,6 +21,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	tsicHashLength     = 6
 | 
						tsicHashLength     = 6
 | 
				
			||||||
 | 
						defaultPingTimeout = 300 * time.Millisecond
 | 
				
			||||||
	defaultPingCount   = 10
 | 
						defaultPingCount   = 10
 | 
				
			||||||
	dockerContextPath  = "../."
 | 
						dockerContextPath  = "../."
 | 
				
			||||||
	headscaleCertPath  = "/usr/local/share/ca-certificates/headscale.crt"
 | 
						headscaleCertPath  = "/usr/local/share/ca-certificates/headscale.crt"
 | 
				
			||||||
@ -591,7 +592,7 @@ func WithPingUntilDirect(direct bool) PingOption {
 | 
				
			|||||||
// TODO(kradalby): Make multiping, go routine magic.
 | 
					// TODO(kradalby): Make multiping, go routine magic.
 | 
				
			||||||
func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
 | 
					func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
 | 
				
			||||||
	args := pingArgs{
 | 
						args := pingArgs{
 | 
				
			||||||
		timeout: 300 * time.Millisecond,
 | 
							timeout: defaultPingTimeout,
 | 
				
			||||||
		count:   defaultPingCount,
 | 
							count:   defaultPingCount,
 | 
				
			||||||
		direct:  true,
 | 
							direct:  true,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user