mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 16:11:03 +01:00 
			
		
		
		
	Consolidate machine related lookups
This commit moves the routes lookup functions to be subcommands of Machine, making them a lot simpler and more specific/composable. It also moves the register command from cli.go into machine, so we can clear out the extra file. Finally a toProto function has been added to convert between the machine database model and the proto/rpc model.
This commit is contained in:
		
							parent
							
								
									67adea5cab
								
							
						
					
					
						commit
						787814ea89
					
				
							
								
								
									
										43
									
								
								cli.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								cli.go
									
									
									
									
									
								
							| @ -1,43 +0,0 @@ | ||||
| package headscale | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"tailscale.com/types/wgkey" | ||||
| ) | ||||
| 
 | ||||
| // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey | ||||
| func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { | ||||
| 	ns, err := h.GetNamespace(namespace) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	mKey, err := wgkey.ParseHex(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	m := Machine{} | ||||
| 	if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { | ||||
| 		return nil, errors.New("Machine not found") | ||||
| 	} | ||||
| 
 | ||||
| 	h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered | ||||
| 
 | ||||
| 	if m.isAlreadyRegistered() { | ||||
| 		return nil, errors.New("Machine already registered") | ||||
| 	} | ||||
| 
 | ||||
| 	ip, err := h.getAvailableIP() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	m.IPAddress = ip.String() | ||||
| 	m.NamespaceID = ns.ID | ||||
| 	m.Registered = true | ||||
| 	m.RegisterMethod = "cli" | ||||
| 	h.db.Save(&m) | ||||
| 
 | ||||
| 	return &m, nil | ||||
| } | ||||
							
								
								
									
										249
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										249
									
								
								machine.go
									
									
									
									
									
								
							| @ -2,6 +2,7 @@ package headscale | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| @ -10,8 +11,11 @@ import ( | ||||
| 
 | ||||
| 	"github.com/fatih/set" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"gorm.io/datatypes" | ||||
| 	"gorm.io/gorm" | ||||
| 	"inet.af/netaddr" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/wgkey" | ||||
| @ -91,7 +95,7 @@ func (h *Headscale) updateMachineExpiry(m *Machine) { | ||||
| 
 | ||||
| func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getDirectPeers"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msg("Finding direct peers") | ||||
| 
 | ||||
| @ -105,7 +109,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { | ||||
| 	sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID }) | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getDirectmachines"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msgf("Found direct machines: %s", machines.String()) | ||||
| 	return machines, nil | ||||
| @ -114,7 +118,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { | ||||
| // getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for | ||||
| func (h *Headscale) getShared(m *Machine) (Machines, error) { | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getShared"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msg("Finding shared peers") | ||||
| 
 | ||||
| @ -132,7 +136,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { | ||||
| 	sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getShared"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msgf("Found shared peers: %s", peers.String()) | ||||
| 	return peers, nil | ||||
| @ -141,7 +145,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { | ||||
| // getSharedTo fetches the machines of the namespaces this machine is shared in | ||||
| func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getSharedTo"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msg("Finding peers in namespaces this machine is shared with") | ||||
| 
 | ||||
| @ -157,13 +161,13 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { | ||||
| 		if err != nil { | ||||
| 			return Machines{}, err | ||||
| 		} | ||||
| 		peers = append(peers, *namespaceMachines...) | ||||
| 		peers = append(peers, namespaceMachines...) | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getSharedTo"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msgf("Found peers we are shared with: %s", peers.String()) | ||||
| 	return peers, nil | ||||
| @ -173,7 +177,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { | ||||
| 	direct, err := h.getDirectPeers(m) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Str("func", "getPeers"). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot fetch peers") | ||||
| 		return Machines{}, err | ||||
| @ -182,7 +186,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { | ||||
| 	shared, err := h.getShared(m) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Str("func", "getShared"). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot fetch peers") | ||||
| 		return Machines{}, err | ||||
| @ -191,7 +195,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { | ||||
| 	sharedTo, err := h.getSharedTo(m) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Str("func", "sharedTo"). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot fetch peers") | ||||
| 		return Machines{}, err | ||||
| @ -203,13 +207,21 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { | ||||
| 	sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Str("func", "getShared"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msgf("Found total peers: %s", peers.String()) | ||||
| 
 | ||||
| 	return peers, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) ListMachines() ([]Machine, error) { | ||||
| 	machines := []Machine{} | ||||
| 	if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return machines, nil | ||||
| } | ||||
| 
 | ||||
| // GetMachine finds a Machine by name and namespace and returns the Machine struct | ||||
| func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { | ||||
| 	machines, err := h.ListMachinesInNamespace(namespace) | ||||
| @ -217,7 +229,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, m := range *machines { | ||||
| 	for _, m := range machines { | ||||
| 		if m.Name == name { | ||||
| 			return &m, nil | ||||
| 		} | ||||
| @ -326,7 +338,7 @@ func (h *Headscale) isOutdated(m *Machine) bool { | ||||
| 
 | ||||
| 	lastChange := h.getLastStateChange(namespaces...) | ||||
| 	log.Trace(). | ||||
| 		Str("func", "keepAlive"). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Time("last_successful_update", *m.LastSuccessfulUpdate). | ||||
| 		Time("last_state_change", lastChange). | ||||
| @ -405,7 +417,7 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include | ||||
| 	ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress)) | ||||
| 	if err != nil { | ||||
| 		log.Trace(). | ||||
| 			Str("func", "toNode"). | ||||
| 			Caller(). | ||||
| 			Str("ip", m.IPAddress). | ||||
| 			Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) | ||||
| 		return nil, err | ||||
| @ -508,3 +520,212 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include | ||||
| 	} | ||||
| 	return &n, nil | ||||
| } | ||||
| 
 | ||||
| func (m *Machine) toProto() *v1.Machine { | ||||
| 	machine := &v1.Machine{ | ||||
| 		Id:         m.ID, | ||||
| 		MachineKey: m.MachineKey, | ||||
| 
 | ||||
| 		NodeKey:   m.NodeKey, | ||||
| 		DiscoKey:  m.DiscoKey, | ||||
| 		IpAddress: m.IPAddress, | ||||
| 		Name:      m.Name, | ||||
| 		Namespace: m.Namespace.toProto(), | ||||
| 
 | ||||
| 		Registered: m.Registered, | ||||
| 
 | ||||
| 		// TODO(kradalby): Implement register method enum converter | ||||
| 		// RegisterMethod: , | ||||
| 
 | ||||
| 		CreatedAt: timestamppb.New(m.CreatedAt), | ||||
| 	} | ||||
| 
 | ||||
| 	if m.AuthKey != nil { | ||||
| 		machine.PreAuthKey = m.AuthKey.toProto() | ||||
| 	} | ||||
| 
 | ||||
| 	if m.LastSeen != nil { | ||||
| 		machine.LastSeen = timestamppb.New(*m.LastSeen) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.LastSuccessfulUpdate != nil { | ||||
| 		machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.Expiry != nil { | ||||
| 		machine.Expiry = timestamppb.New(*m.Expiry) | ||||
| 	} | ||||
| 
 | ||||
| 	return machine | ||||
| } | ||||
| 
 | ||||
| // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey | ||||
| func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { | ||||
| 	ns, err := h.GetNamespace(namespace) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	mKey, err := wgkey.ParseHex(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	m := Machine{} | ||||
| 	if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { | ||||
| 		return nil, errors.New("Machine not found") | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Msg("Attempting to register machine") | ||||
| 
 | ||||
| 	if m.isAlreadyRegistered() { | ||||
| 		err := errors.New("Machine already registered") | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Str("machine", m.Name). | ||||
| 			Msg("Attempting to register machine") | ||||
| 
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	ip, err := h.getAvailableIP() | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Str("machine", m.Name). | ||||
| 			Msg("Could not find IP for the new machine") | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Str("ip", ip.String()). | ||||
| 		Msg("Found IP for host") | ||||
| 
 | ||||
| 	m.IPAddress = ip.String() | ||||
| 	m.NamespaceID = ns.ID | ||||
| 	m.Registered = true | ||||
| 	m.RegisterMethod = "cli" | ||||
| 	h.db.Save(&m) | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", m.Name). | ||||
| 		Str("ip", ip.String()). | ||||
| 		Msg("Machine registered with the database") | ||||
| 
 | ||||
| 	return &m, nil | ||||
| } | ||||
| 
 | ||||
| func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { | ||||
| 	hostInfo, err := m.GetHostInfo() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return hostInfo.RoutableIPs, nil | ||||
| } | ||||
| 
 | ||||
| func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { | ||||
| 	data, err := m.EnabledRoutes.MarshalJSON() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	routesStr := []string{} | ||||
| 	err = json.Unmarshal(data, &routesStr) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	routes := make([]netaddr.IPPrefix, len(routesStr)) | ||||
| 	for index, routeStr := range routesStr { | ||||
| 		route, err := netaddr.ParseIPPrefix(routeStr) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		routes[index] = route | ||||
| 	} | ||||
| 
 | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| func (m *Machine) IsRoutesEnabled(routeStr string) bool { | ||||
| 	route, err := netaddr.ParseIPPrefix(routeStr) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	enabledRoutes, err := m.GetEnabledRoutes() | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	for _, enabledRoute := range enabledRoutes { | ||||
| 		if route == enabledRoute { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the | ||||
| // previous list of routes. | ||||
| func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { | ||||
| 	newRoutes := make([]netaddr.IPPrefix, len(routeStrs)) | ||||
| 	for index, routeStr := range routeStrs { | ||||
| 		route, err := netaddr.ParseIPPrefix(routeStr) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		newRoutes[index] = route | ||||
| 	} | ||||
| 
 | ||||
| 	availableRoutes, err := m.GetAdvertisedRoutes() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, newRoute := range newRoutes { | ||||
| 		if !containsIpPrefix(availableRoutes, newRoute) { | ||||
| 			return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	routes, err := json.Marshal(newRoutes) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	m.EnabledRoutes = datatypes.JSON(routes) | ||||
| 	h.db.Save(&m) | ||||
| 
 | ||||
| 	err = h.RequestMapUpdates(m.NamespaceID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (m *Machine) RoutesToProto() (*v1.Routes, error) { | ||||
| 	availableRoutes, err := m.GetAdvertisedRoutes() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	enabledRoutes, err := m.GetEnabledRoutes() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &v1.Routes{ | ||||
| 		AdvertisedRoutes: ipPrefixToString(availableRoutes), | ||||
| 		EnabledRoutes:    ipPrefixToString(enabledRoutes), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
							
								
								
									
										18
									
								
								routes.go
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								routes.go
									
									
									
									
									
								
							| @ -3,13 +3,12 @@ package headscale | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 
 | ||||
| 	"github.com/pterm/pterm" | ||||
| 	"gorm.io/datatypes" | ||||
| 	"inet.af/netaddr" | ||||
| ) | ||||
| 
 | ||||
| // Deprecated: use machine function instead | ||||
| // GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by | ||||
| // namespace and node name) | ||||
| func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { | ||||
| @ -25,6 +24,7 @@ func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) ( | ||||
| 	return &hostInfo.RoutableIPs, nil | ||||
| } | ||||
| 
 | ||||
| // Deprecated: use machine function instead | ||||
| // GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by | ||||
| // namespace and node name) | ||||
| func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) { | ||||
| @ -56,6 +56,7 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| // Deprecated: use machine function instead | ||||
| // IsNodeRouteEnabled checks if a certain route has been enabled | ||||
| func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool { | ||||
| 	route, err := netaddr.ParseIPPrefix(routeStr) | ||||
| @ -76,6 +77,7 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // Deprecated: use EnableRoute in machine.go | ||||
| // EnableNodeRoute enables a subnet route advertised by a node (identified by | ||||
| // namespace and node name) | ||||
| func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { | ||||
| @ -129,15 +131,3 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RoutesToPtables converts the list of routes to a nice table | ||||
| func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData { | ||||
| 	d := pterm.TableData{{"Route", "Enabled"}} | ||||
| 
 | ||||
| 	for _, route := range availableRoutes { | ||||
| 		enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String()) | ||||
| 
 | ||||
| 		d = append(d, []string{route.String(), strconv.FormatBool(enabled)}) | ||||
| 	} | ||||
| 	return d | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user