mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 08:01:34 +01:00 
			
		
		
		
	move to use tailscfg types over strings/custom types (#1612)
* rename database only fields Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * use correct endpoint type over string list Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * remove HostInfo wrapper Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * wrap errors in database hooks Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									ed4e19996b
								
							
						
					
					
						commit
						b918aa03fc
					
				| @ -12,6 +12,7 @@ import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"gopkg.in/check.v1" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/key" | ||||
| ) | ||||
| 
 | ||||
| @ -593,7 +594,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo: types.HostInfo{ | ||||
| 		Hostinfo: &tailcfg.Hostinfo{ | ||||
| 			RequestTags: []string{"tag:exit"}, | ||||
| 			RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, | ||||
| 		}, | ||||
|  | ||||
| @ -274,7 +274,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { | ||||
| 	} | ||||
| 
 | ||||
| 	advertisedRoutes := map[netip.Prefix]bool{} | ||||
| 	for _, prefix := range node.HostInfo.RoutableIPs { | ||||
| 	for _, prefix := range node.Hostinfo.RoutableIPs { | ||||
| 		advertisedRoutes[prefix] = false | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| 	db.db.Save(&node) | ||||
| 
 | ||||
| @ -81,7 +81,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| 	db.db.Save(&node) | ||||
| 
 | ||||
| @ -152,7 +152,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 		Hostinfo:       &hostInfo1, | ||||
| 	} | ||||
| 	db.db.Save(&node1) | ||||
| 
 | ||||
| @ -174,7 +174,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo2), | ||||
| 		Hostinfo:       &hostInfo2, | ||||
| 	} | ||||
| 	db.db.Save(&node2) | ||||
| 
 | ||||
| @ -232,7 +232,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 		Hostinfo:       &hostInfo1, | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	db.db.Save(&node1) | ||||
| @ -266,7 +266,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo2), | ||||
| 		Hostinfo:       &hostInfo2, | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	db.db.Save(&node2) | ||||
| @ -313,9 +313,9 @@ func (s *Suite) TestSubnetFailover(c *check.C) { | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(len(routes), check.Equals, 1) | ||||
| 
 | ||||
| 	node2.HostInfo = types.HostInfo(tailcfg.Hostinfo{ | ||||
| 	node2.Hostinfo = &tailcfg.Hostinfo{ | ||||
| 		RoutableIPs: []netip.Prefix{prefix, prefix2}, | ||||
| 	}) | ||||
| 	} | ||||
| 	err = db.db.Save(&node2).Error | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| @ -368,7 +368,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { | ||||
| 		UserID:         user.ID, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 		HostInfo:       types.HostInfo(hostInfo1), | ||||
| 		Hostinfo:       &hostInfo1, | ||||
| 		LastSeen:       &now, | ||||
| 	} | ||||
| 	db.db.Save(&node1) | ||||
|  | ||||
| @ -550,7 +550,7 @@ func (api headscaleV1APIServer) DebugCreateNode( | ||||
| 		Expiry:   &time.Time{}, | ||||
| 		LastSeen: &time.Time{}, | ||||
| 
 | ||||
| 		HostInfo: types.HostInfo(hostinfo), | ||||
| 		Hostinfo: &hostinfo, | ||||
| 	} | ||||
| 
 | ||||
| 	log.Debug(). | ||||
|  | ||||
| @ -195,7 +195,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { | ||||
| 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { | ||||
| 			attrs := url.Values{ | ||||
| 				"device_name":  []string{node.Hostname}, | ||||
| 				"device_model": []string{node.HostInfo.OS}, | ||||
| 				"device_model": []string{node.Hostinfo.OS}, | ||||
| 			} | ||||
| 
 | ||||
| 			if len(node.IPAddresses) > 0 { | ||||
|  | ||||
| @ -186,8 +186,7 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		AuthKey:     &types.PreAuthKey{}, | ||||
| 		LastSeen:    &lastSeen, | ||||
| 		Expiry:      &expire, | ||||
| 		HostInfo:    types.HostInfo{}, | ||||
| 		Endpoints:   []string{}, | ||||
| 		Hostinfo:    &tailcfg.Hostinfo{}, | ||||
| 		Routes: []types.Route{ | ||||
| 			{ | ||||
| 				Prefix:     types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")), | ||||
| @ -267,8 +266,7 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		ForcedTags:  []string{}, | ||||
| 		LastSeen:    &lastSeen, | ||||
| 		Expiry:      &expire, | ||||
| 		HostInfo:    types.HostInfo{}, | ||||
| 		Endpoints:   []string{}, | ||||
| 		Hostinfo:    &tailcfg.Hostinfo{}, | ||||
| 		Routes:      []types.Route{}, | ||||
| 		CreatedAt:   created, | ||||
| 	} | ||||
| @ -324,8 +322,7 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		ForcedTags:  []string{}, | ||||
| 		LastSeen:    &lastSeen, | ||||
| 		Expiry:      &expire, | ||||
| 		HostInfo:    types.HostInfo{}, | ||||
| 		Endpoints:   []string{}, | ||||
| 		Hostinfo:    &tailcfg.Hostinfo{}, | ||||
| 		Routes:      []types.Route{}, | ||||
| 		CreatedAt:   created, | ||||
| 	} | ||||
|  | ||||
| @ -72,8 +72,8 @@ func tailNode( | ||||
| 	} | ||||
| 
 | ||||
| 	var derp string | ||||
| 	if node.HostInfo.NetInfo != nil { | ||||
| 		derp = fmt.Sprintf("127.3.3.40:%d", node.HostInfo.NetInfo.PreferredDERP) | ||||
| 	if node.Hostinfo.NetInfo != nil { | ||||
| 		derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) | ||||
| 	} else { | ||||
| 		derp = "127.3.3.40:0" // Zero means disconnected or unknown. | ||||
| 	} | ||||
| @ -90,18 +90,11 @@ func tailNode( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	hostInfo := node.GetHostInfo() | ||||
| 
 | ||||
| 	online := node.IsOnline() | ||||
| 
 | ||||
| 	tags, _ := pol.TagsOfNode(node) | ||||
| 	tags = lo.Uniq(append(tags, node.ForcedTags...)) | ||||
| 
 | ||||
| 	endpoints, err := node.EndpointsToAddrPort() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	tNode := tailcfg.Node{ | ||||
| 		ID: tailcfg.NodeID(node.ID), // this is the actual ID | ||||
| 		StableID: tailcfg.StableNodeID( | ||||
| @ -118,9 +111,9 @@ func tailNode( | ||||
| 		DiscoKey:   node.DiscoKey, | ||||
| 		Addresses:  addrs, | ||||
| 		AllowedIPs: allowedIPs, | ||||
| 		Endpoints:  endpoints, | ||||
| 		Endpoints:  node.Endpoints, | ||||
| 		DERP:       derp, | ||||
| 		Hostinfo:   hostInfo.View(), | ||||
| 		Hostinfo:   node.Hostinfo.View(), | ||||
| 		Created:    node.CreatedAt, | ||||
| 
 | ||||
| 		Tags: tags, | ||||
|  | ||||
| @ -54,7 +54,9 @@ func TestTailNode(t *testing.T) { | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "empty-node", | ||||
| 			node:       &types.Node{}, | ||||
| 			node: &types.Node{ | ||||
| 				Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 			}, | ||||
| 			pol:        &policy.ACLPolicy{}, | ||||
| 			dnsConfig:  &tailcfg.DNSConfig{}, | ||||
| 			baseDomain: "", | ||||
| @ -102,8 +104,7 @@ func TestTailNode(t *testing.T) { | ||||
| 				AuthKey:    &types.PreAuthKey{}, | ||||
| 				LastSeen:   &lastSeen, | ||||
| 				Expiry:     &expire, | ||||
| 				HostInfo:   types.HostInfo{}, | ||||
| 				Endpoints:  []string{}, | ||||
| 				Hostinfo:   &tailcfg.Hostinfo{}, | ||||
| 				Routes: []types.Route{ | ||||
| 					{ | ||||
| 						Prefix:     types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")), | ||||
|  | ||||
| @ -596,10 +596,13 @@ func excludeCorrectlyTaggedNodes( | ||||
| 	} | ||||
| 	// for each node if tag is in tags list, don't append it. | ||||
| 	for _, node := range nodes { | ||||
| 		hi := node.GetHostInfo() | ||||
| 
 | ||||
| 		found := false | ||||
| 		for _, t := range hi.RequestTags { | ||||
| 
 | ||||
| 		if node.Hostinfo == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		for _, t := range node.Hostinfo.RequestTags { | ||||
| 			if util.StringOrPrefixListContains(tags, t) { | ||||
| 				found = true | ||||
| 
 | ||||
| @ -787,8 +790,11 @@ func (pol *ACLPolicy) expandIPsFromTag( | ||||
| 	for _, user := range owners { | ||||
| 		nodes := filterNodesByUser(nodes, user) | ||||
| 		for _, node := range nodes { | ||||
| 			hi := node.GetHostInfo() | ||||
| 			if util.StringOrPrefixListContains(hi.RequestTags, alias) { | ||||
| 			if node.Hostinfo == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) { | ||||
| 				node.IPAddresses.AppendToIPSet(&build) | ||||
| 			} | ||||
| 		} | ||||
| @ -882,7 +888,7 @@ func (pol *ACLPolicy) TagsOfNode( | ||||
| 
 | ||||
| 	validTagMap := make(map[string]bool) | ||||
| 	invalidTagMap := make(map[string]bool) | ||||
| 	for _, tag := range node.HostInfo.RequestTags { | ||||
| 	for _, tag := range node.Hostinfo.RequestTags { | ||||
| 		owners, err := expandOwnersFromTag(pol, tag) | ||||
| 		if errors.Is(err, ErrInvalidTag) { | ||||
| 			invalidTagMap[tag] = true | ||||
|  | ||||
| @ -418,6 +418,7 @@ acls: | ||||
| 					User: types.User{ | ||||
| 						Name: "testuser", | ||||
| 					}, | ||||
| 					Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 				}, | ||||
| 			}) | ||||
| 
 | ||||
| @ -1264,7 +1265,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.1"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1275,7 +1276,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.2"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1405,7 +1406,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.2"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1443,7 +1444,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.1"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1454,7 +1455,7 @@ func Test_expandAlias(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.2"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1465,12 +1466,14 @@ func Test_expandAlias(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.3"), | ||||
| 						}, | ||||
| 						User:     types.User{Name: "marc"}, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPAddresses: types.NodeAddresses{ | ||||
| 							netip.MustParseAddr("100.64.0.4"), | ||||
| 						}, | ||||
| 						User:     types.User{Name: "joe"}, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @ -1520,7 +1523,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.1"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1531,7 +1534,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.2"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1542,6 +1545,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.4"), | ||||
| 						}, | ||||
| 						User:     types.User{Name: "joe"}, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user: "joe", | ||||
| @ -1550,6 +1554,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, | ||||
| 					User:        types.User{Name: "joe"}, | ||||
| 					Hostinfo:    &tailcfg.Hostinfo{}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @ -1570,7 +1575,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.1"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1581,7 +1586,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.2"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1592,6 +1597,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.4"), | ||||
| 						}, | ||||
| 						User:     types.User{Name: "joe"}, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user: "joe", | ||||
| @ -1600,6 +1606,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, | ||||
| 					User:        types.User{Name: "joe"}, | ||||
| 					Hostinfo:    &tailcfg.Hostinfo{}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @ -1615,7 +1622,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.1"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "foo", | ||||
| 							RequestTags: []string{"tag:accountant-webserver"}, | ||||
| @ -1627,12 +1634,14 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						}, | ||||
| 						User:       types.User{Name: "joe"}, | ||||
| 						ForcedTags: []string{"tag:accountant-webserver"}, | ||||
| 						Hostinfo:   &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 					&types.Node{ | ||||
| 						IPAddresses: types.NodeAddresses{ | ||||
| 							netip.MustParseAddr("100.64.0.4"), | ||||
| 						}, | ||||
| 						User:     types.User{Name: "joe"}, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user: "joe", | ||||
| @ -1641,6 +1650,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 				&types.Node{ | ||||
| 					IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, | ||||
| 					User:        types.User{Name: "joe"}, | ||||
| 					Hostinfo:    &tailcfg.Hostinfo{}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @ -1656,7 +1666,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.1"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "hr-web1", | ||||
| 							RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1667,7 +1677,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.2"), | ||||
| 						}, | ||||
| 						User: types.User{Name: "joe"}, | ||||
| 						HostInfo: types.HostInfo{ | ||||
| 						Hostinfo: &tailcfg.Hostinfo{ | ||||
| 							OS:          "centos", | ||||
| 							Hostname:    "hr-web2", | ||||
| 							RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1678,6 +1688,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 							netip.MustParseAddr("100.64.0.4"), | ||||
| 						}, | ||||
| 						User:     types.User{Name: "joe"}, | ||||
| 						Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				user: "joe", | ||||
| @ -1688,7 +1699,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						netip.MustParseAddr("100.64.0.1"), | ||||
| 					}, | ||||
| 					User: types.User{Name: "joe"}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						OS:          "centos", | ||||
| 						Hostname:    "hr-web1", | ||||
| 						RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1699,7 +1710,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						netip.MustParseAddr("100.64.0.2"), | ||||
| 					}, | ||||
| 					User: types.User{Name: "joe"}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						OS:          "centos", | ||||
| 						Hostname:    "hr-web2", | ||||
| 						RequestTags: []string{"tag:hr-webserver"}, | ||||
| @ -1710,6 +1721,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 						netip.MustParseAddr("100.64.0.4"), | ||||
| 					}, | ||||
| 					User:     types.User{Name: "joe"}, | ||||
| 					Hostinfo: &tailcfg.Hostinfo{}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @ -1952,7 +1964,7 @@ func Test_getTags(t *testing.T) { | ||||
| 					User: types.User{ | ||||
| 						Name: "joe", | ||||
| 					}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						RequestTags: []string{"tag:valid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| @ -1972,7 +1984,7 @@ func Test_getTags(t *testing.T) { | ||||
| 					User: types.User{ | ||||
| 						Name: "joe", | ||||
| 					}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						RequestTags: []string{"tag:valid", "tag:invalid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| @ -1992,7 +2004,7 @@ func Test_getTags(t *testing.T) { | ||||
| 					User: types.User{ | ||||
| 						Name: "joe", | ||||
| 					}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						RequestTags: []string{ | ||||
| 							"tag:invalid", | ||||
| 							"tag:valid", | ||||
| @ -2016,7 +2028,7 @@ func Test_getTags(t *testing.T) { | ||||
| 					User: types.User{ | ||||
| 						Name: "joe", | ||||
| 					}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						RequestTags: []string{"tag:invalid", "very-invalid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| @ -2032,7 +2044,7 @@ func Test_getTags(t *testing.T) { | ||||
| 					User: types.User{ | ||||
| 						Name: "joe", | ||||
| 					}, | ||||
| 					HostInfo: types.HostInfo{ | ||||
| 					Hostinfo: &tailcfg.Hostinfo{ | ||||
| 						RequestTags: []string{"tag:invalid", "very-invalid"}, | ||||
| 					}, | ||||
| 				}, | ||||
| @ -3010,7 +3022,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| 
 | ||||
| 	pol := &ACLPolicy{ | ||||
| @ -3062,7 +3074,7 @@ func TestInvalidTagValidUser(t *testing.T) { | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| 
 | ||||
| 	pol := &ACLPolicy{ | ||||
| @ -3113,7 +3125,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| 
 | ||||
| 	pol := &ACLPolicy{ | ||||
| @ -3174,7 +3186,7 @@ func TestValidTagInvalidUser(t *testing.T) { | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		HostInfo:       types.HostInfo(hostInfo), | ||||
| 		Hostinfo:       &hostInfo, | ||||
| 	} | ||||
| 
 | ||||
| 	hostInfo2 := tailcfg.Hostinfo{ | ||||
| @ -3191,7 +3203,7 @@ func TestValidTagInvalidUser(t *testing.T) { | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		RegisterMethod: util.RegisterMethodAuthKey, | ||||
| 		HostInfo:       types.HostInfo(hostInfo2), | ||||
| 		Hostinfo:       &hostInfo2, | ||||
| 	} | ||||
| 
 | ||||
| 	pol := &ACLPolicy{ | ||||
|  | ||||
| @ -83,15 +83,14 @@ func (h *Headscale) handlePoll( | ||||
| 			Bool("stream", mapRequest.Stream). | ||||
| 			Str("node_key", node.NodeKey.ShortString()). | ||||
| 			Str("node", node.Hostname). | ||||
| 			Strs("endpoints", node.Endpoints). | ||||
| 			Msg("Received endpoint update") | ||||
| 
 | ||||
| 		now := time.Now().UTC() | ||||
| 		node.LastSeen = &now | ||||
| 		node.Hostname = mapRequest.Hostinfo.Hostname | ||||
| 		node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) | ||||
| 		node.Hostinfo = mapRequest.Hostinfo | ||||
| 		node.DiscoKey = mapRequest.DiscoKey | ||||
| 		node.SetEndpointsFromAddrPorts(mapRequest.Endpoints) | ||||
| 		node.Endpoints = mapRequest.Endpoints | ||||
| 
 | ||||
| 		if err := h.db.NodeSave(node); err != nil { | ||||
| 			logErr(err, "Failed to persist/update node in the database") | ||||
| @ -142,9 +141,9 @@ func (h *Headscale) handlePoll( | ||||
| 	now := time.Now().UTC() | ||||
| 	node.LastSeen = &now | ||||
| 	node.Hostname = mapRequest.Hostinfo.Hostname | ||||
| 	node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) | ||||
| 	node.Hostinfo = mapRequest.Hostinfo | ||||
| 	node.DiscoKey = mapRequest.DiscoKey | ||||
| 	node.SetEndpointsFromAddrPorts(mapRequest.Endpoints) | ||||
| 	node.Endpoints = mapRequest.Endpoints | ||||
| 
 | ||||
| 	// When a node connects to control, list the peers it has at | ||||
| 	// that given point, further updates are kept in memory in | ||||
|  | ||||
| @ -12,33 +12,6 @@ import ( | ||||
| 
 | ||||
| var ErrCannotParsePrefix = errors.New("cannot parse prefix") | ||||
| 
 | ||||
| // This is a "wrapper" type around tailscales | ||||
| // Hostinfo to allow us to add database "serialization" | ||||
| // methods. This allows us to use a typed values throughout | ||||
| // the code and not have to marshal/unmarshal and error | ||||
| // check all over the code. | ||||
| type HostInfo tailcfg.Hostinfo | ||||
| 
 | ||||
| func (hi *HostInfo) Scan(destination interface{}) error { | ||||
| 	switch value := destination.(type) { | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, hi) | ||||
| 
 | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), hi) | ||||
| 
 | ||||
| 	default: | ||||
| 		return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Value return json value, implement driver.Valuer interface. | ||||
| func (hi HostInfo) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(hi) | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| type IPPrefix netip.Prefix | ||||
| 
 | ||||
| func (i *IPPrefix) Scan(destination interface{}) error { | ||||
|  | ||||
| @ -2,6 +2,7 @@ package types | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| @ -27,28 +28,41 @@ var ( | ||||
| type Node struct { | ||||
| 	ID uint64 `gorm:"primary_key"` | ||||
| 
 | ||||
| 	// MachineKeyValue is the string representation of MachineKey | ||||
| 	// MachineKeyDatabaseField is the string representation of MachineKey | ||||
| 	// it is _only_ used for reading and writing the key to the | ||||
| 	// database and should not be used. | ||||
| 	// Use MachineKey instead. | ||||
| 	MachineKeyValue string `gorm:"column:machine_key;unique_index"` | ||||
| 	MachineKeyDatabaseField string            `gorm:"column:machine_key;unique_index"` | ||||
| 	MachineKey              key.MachinePublic `gorm:"-"` | ||||
| 
 | ||||
| 	// NodeKeyValue is the string representation of NodeKey | ||||
| 	// NodeKeyDatabaseField is the string representation of NodeKey | ||||
| 	// it is _only_ used for reading and writing the key to the | ||||
| 	// database and should not be used. | ||||
| 	// Use NodeKey instead. | ||||
| 	NodeKeyValue string `gorm:"column:node_key"` | ||||
| 	NodeKeyDatabaseField string         `gorm:"column:node_key"` | ||||
| 	NodeKey              key.NodePublic `gorm:"-"` | ||||
| 
 | ||||
| 	// DiscoKeyValue is the string representation of DiscoKey | ||||
| 	// DiscoKeyDatabaseField is the string representation of DiscoKey | ||||
| 	// it is _only_ used for reading and writing the key to the | ||||
| 	// database and should not be used. | ||||
| 	// Use DiscoKey instead. | ||||
| 	DiscoKeyValue string `gorm:"column:disco_key"` | ||||
| 
 | ||||
| 	MachineKey key.MachinePublic `gorm:"-"` | ||||
| 	NodeKey    key.NodePublic    `gorm:"-"` | ||||
| 	DiscoKeyDatabaseField string          `gorm:"column:disco_key"` | ||||
| 	DiscoKey              key.DiscoPublic `gorm:"-"` | ||||
| 
 | ||||
| 	// EndpointsDatabaseField is the string list representation of Endpoints | ||||
| 	// it is _only_ used for reading and writing the key to the | ||||
| 	// database and should not be used. | ||||
| 	// Use Endpoints instead. | ||||
| 	EndpointsDatabaseField StringList       `gorm:"column:endpoints"` | ||||
| 	Endpoints              []netip.AddrPort `gorm:"-"` | ||||
| 
 | ||||
| 	// EndpointsDatabaseField is the string list representation of Endpoints | ||||
| 	// it is _only_ used for reading and writing the key to the | ||||
| 	// database and should not be used. | ||||
| 	// Use Endpoints instead. | ||||
| 	HostinfoDatabaseField string            `gorm:"column:hostinfo"` | ||||
| 	Hostinfo              *tailcfg.Hostinfo `gorm:"-"` | ||||
| 
 | ||||
| 	IPAddresses NodeAddresses | ||||
| 
 | ||||
| 	// Hostname represents the name given by the Tailscale | ||||
| @ -76,9 +90,6 @@ type Node struct { | ||||
| 	LastSeen *time.Time | ||||
| 	Expiry   *time.Time | ||||
| 
 | ||||
| 	HostInfo  HostInfo | ||||
| 	Endpoints StringList | ||||
| 
 | ||||
| 	Routes []Route | ||||
| 
 | ||||
| 	CreatedAt time.Time | ||||
| @ -195,31 +206,6 @@ func (node Node) IsExpired() bool { | ||||
| 	return time.Now().UTC().After(*node.Expiry) | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby): Try to replace the types in the DB to be correct. | ||||
| func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) { | ||||
| 	var ret []netip.AddrPort | ||||
| 	for _, ep := range node.Endpoints { | ||||
| 		addrPort, err := netip.ParseAddrPort(ep) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		ret = append(ret, addrPort) | ||||
| 	} | ||||
| 
 | ||||
| 	return ret, nil | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby): Try to replace the types in the DB to be correct. | ||||
| func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) { | ||||
| 	var strs StringList | ||||
| 	for _, addrPort := range in { | ||||
| 		strs = append(strs, addrPort.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	node.Endpoints = strs | ||||
| } | ||||
| 
 | ||||
| // IsOnline returns if the node is connected to Headscale. | ||||
| // This is really a naive implementation, as we don't really see | ||||
| // if there is a working connection between the client and the server. | ||||
| @ -277,9 +263,22 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { | ||||
| // correctly in the database. | ||||
| // This currently means storing the keys as strings. | ||||
| func (n *Node) BeforeSave(tx *gorm.DB) (err error) { | ||||
| 	n.MachineKeyValue = n.MachineKey.String() | ||||
| 	n.NodeKeyValue = n.NodeKey.String() | ||||
| 	n.DiscoKeyValue = n.DiscoKey.String() | ||||
| 	n.MachineKeyDatabaseField = n.MachineKey.String() | ||||
| 	n.NodeKeyDatabaseField = n.NodeKey.String() | ||||
| 	n.DiscoKeyDatabaseField = n.DiscoKey.String() | ||||
| 
 | ||||
| 	var endpoints StringList | ||||
| 	for _, addrPort := range n.Endpoints { | ||||
| 		endpoints = append(endpoints, addrPort.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	n.EndpointsDatabaseField = endpoints | ||||
| 
 | ||||
| 	hi, err := json.Marshal(n.Hostinfo) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err) | ||||
| 	} | ||||
| 	n.HostinfoDatabaseField = string(hi) | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| @ -291,23 +290,40 @@ func (n *Node) BeforeSave(tx *gorm.DB) (err error) { | ||||
| // the proper types. | ||||
| func (n *Node) AfterFind(tx *gorm.DB) (err error) { | ||||
| 	var machineKey key.MachinePublic | ||||
| 	if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil { | ||||
| 		return err | ||||
| 	if err := machineKey.UnmarshalText([]byte(n.MachineKeyDatabaseField)); err != nil { | ||||
| 		return fmt.Errorf("failed to unmarshal machine key from db: %w", err) | ||||
| 	} | ||||
| 	n.MachineKey = machineKey | ||||
| 
 | ||||
| 	var nodeKey key.NodePublic | ||||
| 	if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil { | ||||
| 		return err | ||||
| 	if err := nodeKey.UnmarshalText([]byte(n.NodeKeyDatabaseField)); err != nil { | ||||
| 		return fmt.Errorf("failed to unmarshal node key from db: %w", err) | ||||
| 	} | ||||
| 	n.NodeKey = nodeKey | ||||
| 
 | ||||
| 	var discoKey key.DiscoPublic | ||||
| 	if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil { | ||||
| 		return err | ||||
| 	if err := discoKey.UnmarshalText([]byte(n.DiscoKeyDatabaseField)); err != nil { | ||||
| 		return fmt.Errorf("failed to unmarshal disco key from db: %w", err) | ||||
| 	} | ||||
| 	n.DiscoKey = discoKey | ||||
| 
 | ||||
| 	var endpoints []netip.AddrPort | ||||
| 	for _, ep := range n.EndpointsDatabaseField { | ||||
| 		addrPort, err := netip.ParseAddrPort(ep) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("failed to parse endpoint from db: %w", err) | ||||
| 		} | ||||
| 
 | ||||
| 		endpoints = append(endpoints, addrPort) | ||||
| 	} | ||||
| 	n.Endpoints = endpoints | ||||
| 
 | ||||
| 	var hi tailcfg.Hostinfo | ||||
| 	if err := json.Unmarshal([]byte(n.HostinfoDatabaseField), &hi); err != nil { | ||||
| 		return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err) | ||||
| 	} | ||||
| 	n.Hostinfo = &hi | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -346,11 +362,6 @@ func (node *Node) Proto() *v1.Node { | ||||
| 	return nodeProto | ||||
| } | ||||
| 
 | ||||
| // GetHostInfo returns a Hostinfo struct for the node. | ||||
| func (node *Node) GetHostInfo() tailcfg.Hostinfo { | ||||
| 	return tailcfg.Hostinfo(node.HostInfo) | ||||
| } | ||||
| 
 | ||||
| func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) { | ||||
| 	var hostname string | ||||
| 	if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user