mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-25 14:11:06 +02:00 
			
		
		
		
	Rework getAvailableIp
This commit reworks getAvailableIp with a "simpler" version that will look for the first available IP address in our IP Prefix. There is a couple of ideas behind this: * Make the host IPs reasonably predictable and in within similar subnets, which should simplify ACLs for subnets * The code is not random, but deterministic so we can have tests * The code is a bit more understandable (no bit shift magic)
This commit is contained in:
		
							parent
							
								
									309f868a21
								
							
						
					
					
						commit
						b5841c8a8b
					
				| @ -38,7 +38,7 @@ func (s *Suite) ResetDB(c *check.C) { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| 	cfg := Config{ | ||||
| 		IPPrefix: netaddr.MustParseIPPrefix("127.0.0.1/32"), | ||||
| 		IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), | ||||
| 	} | ||||
| 
 | ||||
| 	h = Headscale{ | ||||
|  | ||||
| @ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) { | ||||
| 		DiscoKey:    "faa", | ||||
| 		Name:        "testmachine", | ||||
| 		NamespaceID: n.ID, | ||||
| 		IPAddress:   "10.0.0.1", | ||||
| 	} | ||||
| 	h.db.Save(&m) | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										107
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										107
									
								
								utils.go
									
									
									
									
									
								
							| @ -7,18 +7,11 @@ package headscale | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/binary" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"time" | ||||
| 
 | ||||
| 	mathrand "math/rand" | ||||
| 
 | ||||
| 	"golang.org/x/crypto/nacl/box" | ||||
| 	"gorm.io/gorm" | ||||
| 	"inet.af/netaddr" | ||||
| 	"tailscale.com/types/wgkey" | ||||
| ) | ||||
| @ -78,47 +71,73 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err | ||||
| 	return msg, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) getAvailableIP() (*net.IP, error) { | ||||
| 	i := 0 | ||||
| func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { | ||||
| 	ipPrefix := h.cfg.IPPrefix | ||||
| 
 | ||||
| 	usedIps, err := h.getUsedIPs() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// for _, ip := range usedIps { | ||||
| 	// 	nextIP := ip.Next() | ||||
| 
 | ||||
| 	// 	if !containsIPs(usedIps, nextIP) && ipPrefix.Contains(nextIP) { | ||||
| 	// 		return &nextIP, nil | ||||
| 	// 	} | ||||
| 	// } | ||||
| 
 | ||||
| 	// // If there are no IPs in use, we are starting fresh and | ||||
| 	// // can issue IPs from the beginning of the prefix. | ||||
| 	// ip := ipPrefix.IP() | ||||
| 	// return &ip, nil | ||||
| 
 | ||||
| 	// return nil, fmt.Errorf("failed to find any available IP in %s", ipPrefix) | ||||
| 
 | ||||
| 	// Get the first IP in our prefix | ||||
| 	ip := ipPrefix.IP() | ||||
| 
 | ||||
| 	for { | ||||
| 		ip, err := getRandomIP(h.cfg.IPPrefix) | ||||
| 		if !ipPrefix.Contains(ip) { | ||||
| 			return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix) | ||||
| 		} | ||||
| 
 | ||||
| 		if ip.IsZero() && | ||||
| 			ip.IsLoopback() { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if !containsIPs(usedIps, ip) { | ||||
| 			return &ip, nil | ||||
| 		} | ||||
| 
 | ||||
| 		ip = ip.Next() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { | ||||
| 	var addresses []string | ||||
| 	h.db.Model(&Machine{}).Pluck("ip_address", &addresses) | ||||
| 
 | ||||
| 	ips := make([]netaddr.IP, len(addresses)) | ||||
| 	for index, addr := range addresses { | ||||
| 		ip, err := netaddr.ParseIP(addr) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 			return nil, fmt.Errorf("failed to parse ip from database, %w", err) | ||||
| 		} | ||||
| 		m := Machine{} | ||||
| 		if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { | ||||
| 			return ip, nil | ||||
| 		} | ||||
| 		i++ | ||||
| 		if i == 100 { // really random number | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, errors.New(fmt.Sprintf("Could not find an available IP address in %s", h.cfg.IPPrefix.String())) | ||||
| } | ||||
| 
 | ||||
| func getRandomIP(ipPrefix netaddr.IPPrefix) (*net.IP, error) { | ||||
| 	mathrand.Seed(time.Now().Unix()) | ||||
| 	ipo, ipnet, err := net.ParseCIDR(ipPrefix.String()) | ||||
| 	if err == nil { | ||||
| 		ip := ipo.To4() | ||||
| 		// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet) | ||||
| 		// fmt.Println("Final address is ", ip) | ||||
| 		// fmt.Println("Broadcast address is ", ipb) | ||||
| 		// fmt.Println("Network address is ", ipn) | ||||
| 		r := mathrand.Uint32() | ||||
| 		ipRaw := make([]byte, 4) | ||||
| 		binary.LittleEndian.PutUint32(ipRaw, r) | ||||
| 		// ipRaw[3] = 254 | ||||
| 		// fmt.Println("ipRaw is ", ipRaw) | ||||
| 		for i, v := range ipRaw { | ||||
| 			// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i]) | ||||
| 			ip[i] = ip[i] + (v &^ ipnet.Mask[i]) | ||||
| 			// fmt.Println("IP After: ", ip[i]) | ||||
| 		} | ||||
| 		// fmt.Println("FINAL IP: ", ip.String()) | ||||
| 		return &ip, nil | ||||
| 		ips[index] = ip | ||||
| 	} | ||||
| 
 | ||||
| 	return nil, err | ||||
| 	return ips, nil | ||||
| } | ||||
| 
 | ||||
| func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { | ||||
| 	for _, v := range ips { | ||||
| 		if v == ip { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
|  | ||||
							
								
								
									
										105
									
								
								utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								utils_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,105 @@ | ||||
| package headscale | ||||
| 
 | ||||
| import ( | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"inet.af/netaddr" | ||||
| ) | ||||
| 
 | ||||
| func (s *Suite) TestGetAvailableIp(c *check.C) { | ||||
| 	ip, err := h.getAvailableIP() | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	expected := netaddr.MustParseIP("10.27.0.0") | ||||
| 
 | ||||
| 	c.Assert(ip.String(), check.Equals, expected.String()) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetUsedIps(c *check.C) { | ||||
| 	ip, err := h.getAvailableIP() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	n, err := h.CreateNamespace("test_ip") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	_, err = h.GetMachine("test", "testmachine") | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	m := Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Name:           "testmachine", | ||||
| 		NamespaceID:    n.ID, | ||||
| 		Registered:     true, | ||||
| 		RegisterMethod: "authKey", | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		IPAddress:      ip.String(), | ||||
| 	} | ||||
| 	h.db.Save(&m) | ||||
| 
 | ||||
| 	ips, err := h.getUsedIPs() | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	expected := netaddr.MustParseIP("10.27.0.0") | ||||
| 
 | ||||
| 	c.Assert(ips[0], check.Equals, expected) | ||||
| } | ||||
| 
 | ||||
| func (s *Suite) TestGetMultiIp(c *check.C) { | ||||
| 	n, err := h.CreateNamespace("test-ip-multi") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	for i := 1; i <= 350; i++ { | ||||
| 		ip, err := h.getAvailableIP() | ||||
| 		c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 		pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) | ||||
| 		c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 		_, err = h.GetMachine("test", "testmachine") | ||||
| 		c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 		m := Machine{ | ||||
| 			ID:             0, | ||||
| 			MachineKey:     "foo", | ||||
| 			NodeKey:        "bar", | ||||
| 			DiscoKey:       "faa", | ||||
| 			Name:           "testmachine", | ||||
| 			NamespaceID:    n.ID, | ||||
| 			Registered:     true, | ||||
| 			RegisterMethod: "authKey", | ||||
| 			AuthKeyID:      uint(pak.ID), | ||||
| 			IPAddress:      ip.String(), | ||||
| 		} | ||||
| 		h.db.Save(&m) | ||||
| 	} | ||||
| 
 | ||||
| 	ips, err := h.getUsedIPs() | ||||
| 
 | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(len(ips), check.Equals, 350) | ||||
| 
 | ||||
| 	c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.0")) | ||||
| 	c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.9")) | ||||
| 	c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.44")) | ||||
| 
 | ||||
| 	expectedNextIP := netaddr.MustParseIP("10.27.1.94") | ||||
| 	nextIP, err := h.getAvailableIP() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) | ||||
| 
 | ||||
| 	// If we call get Available again, we should receive | ||||
| 	// the same IP, as it has not been reserved. | ||||
| 	nextIP2, err := h.getAvailableIP() | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user