mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 01:51:04 +01:00 
			
		
		
		
	add versioned migrations (#1644)
This commit is contained in:
		
							parent
							
								
									ac910fd44c
								
							
						
					
					
						commit
						6049ec758c
					
				@ -34,6 +34,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
### Changes
 | 
					### Changes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
 | 
				
			||||||
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
 | 
					Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
 | 
				
			||||||
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
 | 
					SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
 | 
				
			||||||
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
 | 
					State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
          # When updating go.mod or go.sum, a new sha will need to be calculated,
 | 
					          # When updating go.mod or go.sum, a new sha will need to be calculated,
 | 
				
			||||||
          # update this if you have a mismatch after doing a change to thos files.
 | 
					          # update this if you have a mismatch after doing a change to thos files.
 | 
				
			||||||
          vendorHash = "sha256-7yqJbF0GkKa3wjiGWJ8BZSJyckrpwmCiX77/aoPGmRc=";
 | 
					          vendorHash = "sha256-u9AmJguQ5dnJpfhOeLN43apvMHuraOrJhvlEIp9RoIc=";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
          ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
 | 
					          ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							@ -75,6 +75,7 @@ require (
 | 
				
			|||||||
	github.com/fsnotify/fsnotify v1.7.0 // indirect
 | 
						github.com/fsnotify/fsnotify v1.7.0 // indirect
 | 
				
			||||||
	github.com/fxamacker/cbor/v2 v2.5.0 // indirect
 | 
						github.com/fxamacker/cbor/v2 v2.5.0 // indirect
 | 
				
			||||||
	github.com/glebarez/go-sqlite v1.21.2 // indirect
 | 
						github.com/glebarez/go-sqlite v1.21.2 // indirect
 | 
				
			||||||
 | 
						github.com/go-gormigrate/gormigrate/v2 v2.1.1 // indirect
 | 
				
			||||||
	github.com/go-jose/go-jose/v3 v3.0.1 // indirect
 | 
						github.com/go-jose/go-jose/v3 v3.0.1 // indirect
 | 
				
			||||||
	github.com/gogo/protobuf v1.3.2 // indirect
 | 
						github.com/gogo/protobuf v1.3.2 // indirect
 | 
				
			||||||
	github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
 | 
						github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@ -101,6 +101,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
 | 
				
			|||||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
 | 
					github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
 | 
				
			||||||
github.com/glebarez/sqlite v1.10.0 h1:u4gt8y7OND/cCei/NMHmfbLxF6xP2wgKcT/BJf2pYkc=
 | 
					github.com/glebarez/sqlite v1.10.0 h1:u4gt8y7OND/cCei/NMHmfbLxF6xP2wgKcT/BJf2pYkc=
 | 
				
			||||||
github.com/glebarez/sqlite v1.10.0/go.mod h1:IJ+lfSOmiekhQsFTJRx/lHtGYmCdtAiTaf5wI9u5uHA=
 | 
					github.com/glebarez/sqlite v1.10.0/go.mod h1:IJ+lfSOmiekhQsFTJRx/lHtGYmCdtAiTaf5wI9u5uHA=
 | 
				
			||||||
 | 
					github.com/go-gormigrate/gormigrate/v2 v2.1.1 h1:eGS0WTFRV30r103lU8JNXY27KbviRnqqIDobW3EV3iY=
 | 
				
			||||||
 | 
					github.com/go-gormigrate/gormigrate/v2 v2.1.1/go.mod h1:L7nJ620PFDKei9QOhJzqA8kRCk+E3UbV2f5gv+1ndLc=
 | 
				
			||||||
github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA=
 | 
					github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA=
 | 
				
			||||||
github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
 | 
					github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
 | 
				
			||||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
 | 
					github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
 | 
				
			||||||
 | 
				
			|||||||
@ -11,6 +11,7 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/glebarez/sqlite"
 | 
						"github.com/glebarez/sqlite"
 | 
				
			||||||
 | 
						"github.com/go-gormigrate/gormigrate/v2"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/notifier"
 | 
						"github.com/juanfont/headscale/hscontrol/notifier"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
@ -21,15 +22,11 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	dbVersion = "1"
 | 
						Postgres = "postgres"
 | 
				
			||||||
	Postgres  = "postgres"
 | 
						Sqlite   = "sqlite3"
 | 
				
			||||||
	Sqlite    = "sqlite3"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var errDatabaseNotSupported = errors.New("database type not supported")
 | 
				
			||||||
	errValueNotFound        = errors.New("not found")
 | 
					 | 
				
			||||||
	errDatabaseNotSupported = errors.New("database type not supported")
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// KV is a key-value store in a psql table. For future use...
 | 
					// KV is a key-value store in a psql table. For future use...
 | 
				
			||||||
// TODO(kradalby): Is this used for anything?
 | 
					// TODO(kradalby): Is this used for anything?
 | 
				
			||||||
@ -64,6 +61,261 @@ func NewHeadscaleDatabase(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{
 | 
				
			||||||
 | 
							// New migrations should be added as transactions at the end of this list.
 | 
				
			||||||
 | 
							// The initial commit here is quite messy, completely out of order and
 | 
				
			||||||
 | 
							// has no versioning and is the tech debt of not having versioned migrations
 | 
				
			||||||
 | 
							// prior to this point. This first migration is all DB changes to bring a DB
 | 
				
			||||||
 | 
							// up to 0.23.0.
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								ID: "202312101416",
 | 
				
			||||||
 | 
								Migrate: func(tx *gorm.DB) error {
 | 
				
			||||||
 | 
									if dbType == Postgres {
 | 
				
			||||||
 | 
										tx.Exec(`create extension if not exists "uuid-ossp";`)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameTable("namespaces", "users")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// the big rename from Machine to Node
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameTable("machines", "nodes")
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(types.User{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id")
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// GivenName is used as the primary source of DNS names, make sure
 | 
				
			||||||
 | 
									// the field is populated and normalized if it was not when the
 | 
				
			||||||
 | 
									// node was registered.
 | 
				
			||||||
 | 
									_ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// If the Node table has a column for registered,
 | 
				
			||||||
 | 
									// find all occourences of "false" and drop them. Then
 | 
				
			||||||
 | 
									// remove the column.
 | 
				
			||||||
 | 
									if tx.Migrator().HasColumn(&types.Node{}, "registered") {
 | 
				
			||||||
 | 
										log.Info().
 | 
				
			||||||
 | 
											Msg(`Database has legacy "registered" column in node, removing...`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										nodes := types.Nodes{}
 | 
				
			||||||
 | 
										if err := tx.Not("registered").Find(&nodes).Error; err != nil {
 | 
				
			||||||
 | 
											log.Error().Err(err).Msg("Error accessing db")
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										for _, node := range nodes {
 | 
				
			||||||
 | 
											log.Info().
 | 
				
			||||||
 | 
												Str("node", node.Hostname).
 | 
				
			||||||
 | 
												Str("machine_key", node.MachineKey.ShortString()).
 | 
				
			||||||
 | 
												Msg("Deleting unregistered node")
 | 
				
			||||||
 | 
											if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil {
 | 
				
			||||||
 | 
												log.Error().
 | 
				
			||||||
 | 
													Err(err).
 | 
				
			||||||
 | 
													Str("node", node.Hostname).
 | 
				
			||||||
 | 
													Str("machine_key", node.MachineKey.ShortString()).
 | 
				
			||||||
 | 
													Msg("Error deleting unregistered node")
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										err := tx.Migrator().DropColumn(&types.Node{}, "registered")
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											log.Error().Err(err).Msg("Error dropping registered column")
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(&types.Route{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(&types.Node{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Ensure all keys have correct prefixes
 | 
				
			||||||
 | 
									// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
 | 
				
			||||||
 | 
									type result struct {
 | 
				
			||||||
 | 
										ID         uint64
 | 
				
			||||||
 | 
										MachineKey string
 | 
				
			||||||
 | 
										NodeKey    string
 | 
				
			||||||
 | 
										DiscoKey   string
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									var results []result
 | 
				
			||||||
 | 
									err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									for _, node := range results {
 | 
				
			||||||
 | 
										mKey := node.MachineKey
 | 
				
			||||||
 | 
										if !strings.HasPrefix(node.MachineKey, "mkey:") {
 | 
				
			||||||
 | 
											mKey = "mkey:" + node.MachineKey
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										nKey := node.NodeKey
 | 
				
			||||||
 | 
										if !strings.HasPrefix(node.NodeKey, "nodekey:") {
 | 
				
			||||||
 | 
											nKey = "nodekey:" + node.NodeKey
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										dKey := node.DiscoKey
 | 
				
			||||||
 | 
										if !strings.HasPrefix(node.DiscoKey, "discokey:") {
 | 
				
			||||||
 | 
											dKey = "discokey:" + node.DiscoKey
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										err := tx.Exec(
 | 
				
			||||||
 | 
											"UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
 | 
				
			||||||
 | 
											sql.Named("mKey", mKey),
 | 
				
			||||||
 | 
											sql.Named("nKey", nKey),
 | 
				
			||||||
 | 
											sql.Named("dKey", dKey),
 | 
				
			||||||
 | 
											sql.Named("id", node.ID),
 | 
				
			||||||
 | 
										).Error
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											return err
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
 | 
				
			||||||
 | 
										log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										type NodeAux struct {
 | 
				
			||||||
 | 
											ID            uint64
 | 
				
			||||||
 | 
											EnabledRoutes types.IPPrefixes
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										nodesAux := []NodeAux{}
 | 
				
			||||||
 | 
										err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											log.Fatal().Err(err).Msg("Error accessing db")
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										for _, node := range nodesAux {
 | 
				
			||||||
 | 
											for _, prefix := range node.EnabledRoutes {
 | 
				
			||||||
 | 
												if err != nil {
 | 
				
			||||||
 | 
													log.Error().
 | 
				
			||||||
 | 
														Err(err).
 | 
				
			||||||
 | 
														Str("enabled_route", prefix.String()).
 | 
				
			||||||
 | 
														Msg("Error parsing enabled_route")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
													continue
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
												err = tx.Preload("Node").
 | 
				
			||||||
 | 
													Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
 | 
				
			||||||
 | 
													First(&types.Route{}).
 | 
				
			||||||
 | 
													Error
 | 
				
			||||||
 | 
												if err == nil {
 | 
				
			||||||
 | 
													log.Info().
 | 
				
			||||||
 | 
														Str("enabled_route", prefix.String()).
 | 
				
			||||||
 | 
														Msg("Route already migrated to new table, skipping")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
													continue
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
												route := types.Route{
 | 
				
			||||||
 | 
													NodeID:     node.ID,
 | 
				
			||||||
 | 
													Advertised: true,
 | 
				
			||||||
 | 
													Enabled:    true,
 | 
				
			||||||
 | 
													Prefix:     types.IPPrefix(prefix),
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
												if err := tx.Create(&route).Error; err != nil {
 | 
				
			||||||
 | 
													log.Error().Err(err).Msg("Error creating route")
 | 
				
			||||||
 | 
												} else {
 | 
				
			||||||
 | 
													log.Info().
 | 
				
			||||||
 | 
														Uint64("node_id", route.NodeID).
 | 
				
			||||||
 | 
														Str("prefix", prefix.String()).
 | 
				
			||||||
 | 
														Msg("Route migrated")
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes")
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											log.Error().Err(err).Msg("Error dropping enabled_routes column")
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if tx.Migrator().HasColumn(&types.Node{}, "given_name") {
 | 
				
			||||||
 | 
										nodes := types.Nodes{}
 | 
				
			||||||
 | 
										if err := tx.Find(&nodes).Error; err != nil {
 | 
				
			||||||
 | 
											log.Error().Err(err).Msg("Error accessing db")
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										for item, node := range nodes {
 | 
				
			||||||
 | 
											if node.GivenName == "" {
 | 
				
			||||||
 | 
												normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
 | 
				
			||||||
 | 
													node.Hostname,
 | 
				
			||||||
 | 
												)
 | 
				
			||||||
 | 
												if err != nil {
 | 
				
			||||||
 | 
													log.Error().
 | 
				
			||||||
 | 
														Caller().
 | 
				
			||||||
 | 
														Str("hostname", node.Hostname).
 | 
				
			||||||
 | 
														Err(err).
 | 
				
			||||||
 | 
														Msg("Failed to normalize node hostname in DB migration")
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
												err = tx.Model(nodes[item]).Updates(types.Node{
 | 
				
			||||||
 | 
													GivenName: normalizedHostname,
 | 
				
			||||||
 | 
												}).Error
 | 
				
			||||||
 | 
												if err != nil {
 | 
				
			||||||
 | 
													log.Error().
 | 
				
			||||||
 | 
														Caller().
 | 
				
			||||||
 | 
														Str("hostname", node.Hostname).
 | 
				
			||||||
 | 
														Err(err).
 | 
				
			||||||
 | 
														Msg("Failed to save normalized node name in DB migration")
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(&KV{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(&types.PreAuthKey{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(&types.PreAuthKeyACLTag{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									_ = tx.Migrator().DropTable("shared_machines")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tx.AutoMigrate(&types.APIKey{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									return nil
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								Rollback: func(tx *gorm.DB) error {
 | 
				
			||||||
 | 
									return nil
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								// drop key-value table, it is not used, and has not contained
 | 
				
			||||||
 | 
								// useful data for a long time or ever.
 | 
				
			||||||
 | 
								ID: "202312101430",
 | 
				
			||||||
 | 
								Migrate: func(tx *gorm.DB) error {
 | 
				
			||||||
 | 
									return tx.Migrator().DropTable("kvs")
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								Rollback: func(tx *gorm.DB) error {
 | 
				
			||||||
 | 
									return nil
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = migrations.Migrate(); err != nil {
 | 
				
			||||||
 | 
							log.Fatal().Err(err).Msgf("Migration failed: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db := HSDatabase{
 | 
						db := HSDatabase{
 | 
				
			||||||
		db:       dbConn,
 | 
							db:       dbConn,
 | 
				
			||||||
		notifier: notifier,
 | 
							notifier: notifier,
 | 
				
			||||||
@ -72,232 +324,6 @@ func NewHeadscaleDatabase(
 | 
				
			|||||||
		baseDomain: baseDomain,
 | 
							baseDomain: baseDomain,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msgf("database %#v", dbConn)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if dbType == Postgres {
 | 
					 | 
				
			||||||
		dbConn.Exec(`create extension if not exists "uuid-ossp";`)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameTable("namespaces", "users")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// the big rename from Machine to Node
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameTable("machines", "nodes")
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(types.User{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id")
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameColumn(&types.Node{}, "name", "hostname")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// GivenName is used as the primary source of DNS names, make sure
 | 
					 | 
				
			||||||
	// the field is populated and normalized if it was not when the
 | 
					 | 
				
			||||||
	// node was registered.
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If the Node table has a column for registered,
 | 
					 | 
				
			||||||
	// find all occourences of "false" and drop them. Then
 | 
					 | 
				
			||||||
	// remove the column.
 | 
					 | 
				
			||||||
	if dbConn.Migrator().HasColumn(&types.Node{}, "registered") {
 | 
					 | 
				
			||||||
		log.Info().
 | 
					 | 
				
			||||||
			Msg(`Database has legacy "registered" column in node, removing...`)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		nodes := types.Nodes{}
 | 
					 | 
				
			||||||
		if err := dbConn.Not("registered").Find(&nodes).Error; err != nil {
 | 
					 | 
				
			||||||
			log.Error().Err(err).Msg("Error accessing db")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for _, node := range nodes {
 | 
					 | 
				
			||||||
			log.Info().
 | 
					 | 
				
			||||||
				Str("node", node.Hostname).
 | 
					 | 
				
			||||||
				Str("machine_key", node.MachineKey.ShortString()).
 | 
					 | 
				
			||||||
				Msg("Deleting unregistered node")
 | 
					 | 
				
			||||||
			if err := dbConn.Delete(&types.Node{}, node.ID).Error; err != nil {
 | 
					 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Err(err).
 | 
					 | 
				
			||||||
					Str("node", node.Hostname).
 | 
					 | 
				
			||||||
					Str("machine_key", node.MachineKey.ShortString()).
 | 
					 | 
				
			||||||
					Msg("Error deleting unregistered node")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		err := dbConn.Migrator().DropColumn(&types.Node{}, "registered")
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			log.Error().Err(err).Msg("Error dropping registered column")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(&types.Route{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(&types.Node{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Ensure all keys have correct prefixes
 | 
					 | 
				
			||||||
	// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
 | 
					 | 
				
			||||||
	type result struct {
 | 
					 | 
				
			||||||
		ID         uint64
 | 
					 | 
				
			||||||
		MachineKey string
 | 
					 | 
				
			||||||
		NodeKey    string
 | 
					 | 
				
			||||||
		DiscoKey   string
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var results []result
 | 
					 | 
				
			||||||
	err = db.db.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, node := range results {
 | 
					 | 
				
			||||||
		mKey := node.MachineKey
 | 
					 | 
				
			||||||
		if !strings.HasPrefix(node.MachineKey, "mkey:") {
 | 
					 | 
				
			||||||
			mKey = "mkey:" + node.MachineKey
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		nKey := node.NodeKey
 | 
					 | 
				
			||||||
		if !strings.HasPrefix(node.NodeKey, "nodekey:") {
 | 
					 | 
				
			||||||
			nKey = "nodekey:" + node.NodeKey
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		dKey := node.DiscoKey
 | 
					 | 
				
			||||||
		if !strings.HasPrefix(node.DiscoKey, "discokey:") {
 | 
					 | 
				
			||||||
			dKey = "discokey:" + node.DiscoKey
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		err := db.db.Exec(
 | 
					 | 
				
			||||||
			"UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
 | 
					 | 
				
			||||||
			sql.Named("mKey", mKey),
 | 
					 | 
				
			||||||
			sql.Named("nKey", nKey),
 | 
					 | 
				
			||||||
			sql.Named("dKey", dKey),
 | 
					 | 
				
			||||||
			sql.Named("id", node.ID),
 | 
					 | 
				
			||||||
		).Error
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if dbConn.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
 | 
					 | 
				
			||||||
		log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		type NodeAux struct {
 | 
					 | 
				
			||||||
			ID            uint64
 | 
					 | 
				
			||||||
			EnabledRoutes types.IPPrefixes
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		nodesAux := []NodeAux{}
 | 
					 | 
				
			||||||
		err := dbConn.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			log.Fatal().Err(err).Msg("Error accessing db")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		for _, node := range nodesAux {
 | 
					 | 
				
			||||||
			for _, prefix := range node.EnabledRoutes {
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					log.Error().
 | 
					 | 
				
			||||||
						Err(err).
 | 
					 | 
				
			||||||
						Str("enabled_route", prefix.String()).
 | 
					 | 
				
			||||||
						Msg("Error parsing enabled_route")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				err = dbConn.Preload("Node").
 | 
					 | 
				
			||||||
					Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
 | 
					 | 
				
			||||||
					First(&types.Route{}).
 | 
					 | 
				
			||||||
					Error
 | 
					 | 
				
			||||||
				if err == nil {
 | 
					 | 
				
			||||||
					log.Info().
 | 
					 | 
				
			||||||
						Str("enabled_route", prefix.String()).
 | 
					 | 
				
			||||||
						Msg("Route already migrated to new table, skipping")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				route := types.Route{
 | 
					 | 
				
			||||||
					NodeID:     node.ID,
 | 
					 | 
				
			||||||
					Advertised: true,
 | 
					 | 
				
			||||||
					Enabled:    true,
 | 
					 | 
				
			||||||
					Prefix:     types.IPPrefix(prefix),
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if err := dbConn.Create(&route).Error; err != nil {
 | 
					 | 
				
			||||||
					log.Error().Err(err).Msg("Error creating route")
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					log.Info().
 | 
					 | 
				
			||||||
						Uint64("node_id", route.NodeID).
 | 
					 | 
				
			||||||
						Str("prefix", prefix.String()).
 | 
					 | 
				
			||||||
						Msg("Route migrated")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		err = dbConn.Migrator().DropColumn(&types.Node{}, "enabled_routes")
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			log.Error().Err(err).Msg("Error dropping enabled_routes column")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if dbConn.Migrator().HasColumn(&types.Node{}, "given_name") {
 | 
					 | 
				
			||||||
		nodes := types.Nodes{}
 | 
					 | 
				
			||||||
		if err := dbConn.Find(&nodes).Error; err != nil {
 | 
					 | 
				
			||||||
			log.Error().Err(err).Msg("Error accessing db")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for item, node := range nodes {
 | 
					 | 
				
			||||||
			if node.GivenName == "" {
 | 
					 | 
				
			||||||
				normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
 | 
					 | 
				
			||||||
					node.Hostname,
 | 
					 | 
				
			||||||
				)
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					log.Error().
 | 
					 | 
				
			||||||
						Caller().
 | 
					 | 
				
			||||||
						Str("hostname", node.Hostname).
 | 
					 | 
				
			||||||
						Err(err).
 | 
					 | 
				
			||||||
						Msg("Failed to normalize node hostname in DB migration")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				err = db.RenameNode(nodes[item], normalizedHostname)
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					log.Error().
 | 
					 | 
				
			||||||
						Caller().
 | 
					 | 
				
			||||||
						Str("hostname", node.Hostname).
 | 
					 | 
				
			||||||
						Err(err).
 | 
					 | 
				
			||||||
						Msg("Failed to save normalized node name in DB migration")
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(&KV{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(&types.PreAuthKey{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_ = dbConn.Migrator().DropTable("shared_machines")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = dbConn.AutoMigrate(&types.APIKey{})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// TODO(kradalby): is this needed?
 | 
					 | 
				
			||||||
	err = db.setValue("db_version", dbVersion)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return &db, err
 | 
						return &db, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -347,39 +373,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// getValue returns the value for the given key in KV.
 | 
					 | 
				
			||||||
func (hsdb *HSDatabase) getValue(key string) (string, error) {
 | 
					 | 
				
			||||||
	var row KV
 | 
					 | 
				
			||||||
	if result := hsdb.db.First(&row, "key = ?", key); errors.Is(
 | 
					 | 
				
			||||||
		result.Error,
 | 
					 | 
				
			||||||
		gorm.ErrRecordNotFound,
 | 
					 | 
				
			||||||
	) {
 | 
					 | 
				
			||||||
		return "", errValueNotFound
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return row.Value, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// setValue sets value for the given key in KV.
 | 
					 | 
				
			||||||
func (hsdb *HSDatabase) setValue(key string, value string) error {
 | 
					 | 
				
			||||||
	keyValue := KV{
 | 
					 | 
				
			||||||
		Key:   key,
 | 
					 | 
				
			||||||
		Value: value,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, err := hsdb.getValue(key); err == nil {
 | 
					 | 
				
			||||||
		hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := hsdb.db.Create(keyValue).Error; err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("failed to create key value pair in the database: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
 | 
					func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
 | 
				
			||||||
	ctx, cancel := context.WithTimeout(ctx, time.Second)
 | 
						ctx, cancel := context.WithTimeout(ctx, time.Second)
 | 
				
			||||||
	defer cancel()
 | 
						defer cancel()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user