mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 01:51:04 +01:00 
			
		
		
		
	Restructure database config (#1700)
This commit is contained in:
		
							parent
							
								
									00e7550e76
								
							
						
					
					
						commit
						94b30abf56
					
				
							
								
								
									
										22
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								CHANGELOG.md
									
									
									
									
									
								
							@ -34,16 +34,18 @@ after improving the test harness as part of adopting [#1460](https://github.com/
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
### Changes
 | 
					### Changes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
 | 
					- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
 | 
				
			||||||
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
 | 
					- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
 | 
				
			||||||
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
 | 
					- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
 | 
				
			||||||
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
 | 
					- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
 | 
				
			||||||
Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
 | 
					- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
 | 
				
			||||||
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
 | 
					- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
 | 
				
			||||||
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
 | 
					- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
 | 
				
			||||||
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
 | 
					- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
 | 
				
			||||||
Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
 | 
					- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
 | 
				
			||||||
Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565)
 | 
					- Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565)
 | 
				
			||||||
 | 
					- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700)
 | 
				
			||||||
 | 
					  - Old structure is now considered deprecated and will be removed in the future.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 0.22.3 (2023-05-12)
 | 
					## 0.22.3 (2023-05-12)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
 | 
				
			|||||||
	c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
 | 
						c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
 | 
				
			||||||
	c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
 | 
						c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
 | 
				
			||||||
	c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
 | 
						c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
 | 
				
			||||||
	c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
 | 
						c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
 | 
				
			||||||
	c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
 | 
						c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
 | 
				
			||||||
 | 
						c.Assert(viper.GetString("database.type"), check.Equals, "sqlite")
 | 
				
			||||||
 | 
						c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
 | 
				
			||||||
@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
 | 
				
			|||||||
	c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
 | 
						c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
 | 
				
			||||||
	c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
 | 
						c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
 | 
				
			||||||
	c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
 | 
						c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
 | 
				
			||||||
	c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
 | 
						c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
 | 
				
			||||||
	c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
 | 
						c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
				
			||||||
 | 
				
			|||||||
@ -138,24 +138,25 @@ ephemeral_node_inactivity_timeout: 30m
 | 
				
			|||||||
# In case of doubts, do not touch the default 10s.
 | 
					# In case of doubts, do not touch the default 10s.
 | 
				
			||||||
node_update_check_interval: 10s
 | 
					node_update_check_interval: 10s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# SQLite config
 | 
					database:
 | 
				
			||||||
db_type: sqlite3
 | 
					  type: sqlite
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# For production:
 | 
					  # SQLite config
 | 
				
			||||||
db_path: /var/lib/headscale/db.sqlite
 | 
					  sqlite:
 | 
				
			||||||
 | 
					    path: /var/lib/headscale/db.sqlite
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# # Postgres config
 | 
					  # # Postgres config
 | 
				
			||||||
# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
 | 
					  # postgres:
 | 
				
			||||||
# db_type: postgres
 | 
					  #   # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
 | 
				
			||||||
# db_host: localhost
 | 
					  #   host: localhost
 | 
				
			||||||
# db_port: 5432
 | 
					  #   port: 5432
 | 
				
			||||||
# db_name: headscale
 | 
					  #   name: headscale
 | 
				
			||||||
# db_user: foo
 | 
					  #   user: foo
 | 
				
			||||||
# db_pass: bar
 | 
					  #   pass: bar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
 | 
					  #   # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
 | 
				
			||||||
# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
 | 
					  #   # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
 | 
				
			||||||
# db_ssl: false
 | 
					  #   ssl: false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### TLS configuration
 | 
					### TLS configuration
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,6 @@ import (
 | 
				
			|||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"os/signal"
 | 
						"os/signal"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"syscall"
 | 
						"syscall"
 | 
				
			||||||
@ -118,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
		return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
 | 
							return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var dbString string
 | 
					 | 
				
			||||||
	switch cfg.DBtype {
 | 
					 | 
				
			||||||
	case db.Postgres:
 | 
					 | 
				
			||||||
		dbString = fmt.Sprintf(
 | 
					 | 
				
			||||||
			"host=%s dbname=%s user=%s",
 | 
					 | 
				
			||||||
			cfg.DBhost,
 | 
					 | 
				
			||||||
			cfg.DBname,
 | 
					 | 
				
			||||||
			cfg.DBuser,
 | 
					 | 
				
			||||||
		)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil {
 | 
					 | 
				
			||||||
			if !sslEnabled {
 | 
					 | 
				
			||||||
				dbString += " sslmode=disable"
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if cfg.DBport != 0 {
 | 
					 | 
				
			||||||
			dbString += fmt.Sprintf(" port=%d", cfg.DBport)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if cfg.DBpass != "" {
 | 
					 | 
				
			||||||
			dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case db.Sqlite:
 | 
					 | 
				
			||||||
		dbString = cfg.DBpath
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		return nil, errUnsupportedDatabase
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	registrationCache := cache.New(
 | 
						registrationCache := cache.New(
 | 
				
			||||||
		registerCacheExpiration,
 | 
							registerCacheExpiration,
 | 
				
			||||||
		registerCacheCleanup,
 | 
							registerCacheCleanup,
 | 
				
			||||||
@ -156,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	app := Headscale{
 | 
						app := Headscale{
 | 
				
			||||||
		cfg:                cfg,
 | 
							cfg:                cfg,
 | 
				
			||||||
		dbType:             cfg.DBtype,
 | 
					 | 
				
			||||||
		dbString:           dbString,
 | 
					 | 
				
			||||||
		noisePrivateKey:    noisePrivateKey,
 | 
							noisePrivateKey:    noisePrivateKey,
 | 
				
			||||||
		registrationCache:  registrationCache,
 | 
							registrationCache:  registrationCache,
 | 
				
			||||||
		pollNetMapStreamWG: sync.WaitGroup{},
 | 
							pollNetMapStreamWG: sync.WaitGroup{},
 | 
				
			||||||
@ -165,9 +131,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	database, err := db.NewHeadscaleDatabase(
 | 
						database, err := db.NewHeadscaleDatabase(
 | 
				
			||||||
		cfg.DBtype,
 | 
							cfg.Database,
 | 
				
			||||||
		dbString,
 | 
							app.nodeNotifier,
 | 
				
			||||||
		app.dbDebug,
 | 
					 | 
				
			||||||
		cfg.IPPrefixes,
 | 
							cfg.IPPrefixes,
 | 
				
			||||||
		cfg.BaseDomain)
 | 
							cfg.BaseDomain)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -755,14 +720,16 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var tailsqlContext context.Context
 | 
						var tailsqlContext context.Context
 | 
				
			||||||
	if tailsqlEnabled {
 | 
						if tailsqlEnabled {
 | 
				
			||||||
		if h.cfg.DBtype != db.Sqlite {
 | 
							if h.cfg.Database.Type != types.DatabaseSqlite {
 | 
				
			||||||
			log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite)
 | 
								log.Fatal().
 | 
				
			||||||
 | 
									Str("type", h.cfg.Database.Type).
 | 
				
			||||||
 | 
									Msgf("tailsql only support %q", types.DatabaseSqlite)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if tailsqlTSKey == "" {
 | 
							if tailsqlTSKey == "" {
 | 
				
			||||||
			log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
 | 
								log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		tailsqlContext = context.Background()
 | 
							tailsqlContext = context.Background()
 | 
				
			||||||
		go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath)
 | 
							go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Handle common process-killing signals so we can gracefully shut down:
 | 
						// Handle common process-killing signals so we can gracefully shut down:
 | 
				
			||||||
 | 
				
			|||||||
@ -6,11 +6,13 @@ import (
 | 
				
			|||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/glebarez/sqlite"
 | 
						"github.com/glebarez/sqlite"
 | 
				
			||||||
	"github.com/go-gormigrate/gormigrate/v2"
 | 
						"github.com/go-gormigrate/gormigrate/v2"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/notifier"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
@ -19,11 +21,6 @@ import (
 | 
				
			|||||||
	"gorm.io/gorm/logger"
 | 
						"gorm.io/gorm/logger"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	Postgres = "postgres"
 | 
					 | 
				
			||||||
	Sqlite   = "sqlite3"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var errDatabaseNotSupported = errors.New("database type not supported")
 | 
					var errDatabaseNotSupported = errors.New("database type not supported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// KV is a key-value store in a psql table. For future use...
 | 
					// KV is a key-value store in a psql table. For future use...
 | 
				
			||||||
@ -43,12 +40,12 @@ type HSDatabase struct {
 | 
				
			|||||||
// TODO(kradalby): assemble this struct from toptions or something typed
 | 
					// TODO(kradalby): assemble this struct from toptions or something typed
 | 
				
			||||||
// rather than arguments.
 | 
					// rather than arguments.
 | 
				
			||||||
func NewHeadscaleDatabase(
 | 
					func NewHeadscaleDatabase(
 | 
				
			||||||
	dbType, connectionAddr string,
 | 
						cfg types.DatabaseConfig,
 | 
				
			||||||
	debug bool,
 | 
						notifier *notifier.Notifier,
 | 
				
			||||||
	ipPrefixes []netip.Prefix,
 | 
						ipPrefixes []netip.Prefix,
 | 
				
			||||||
	baseDomain string,
 | 
						baseDomain string,
 | 
				
			||||||
) (*HSDatabase, error) {
 | 
					) (*HSDatabase, error) {
 | 
				
			||||||
	dbConn, err := openDB(dbType, connectionAddr, debug)
 | 
						dbConn, err := openDB(cfg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -62,7 +59,7 @@ func NewHeadscaleDatabase(
 | 
				
			|||||||
		{
 | 
							{
 | 
				
			||||||
			ID: "202312101416",
 | 
								ID: "202312101416",
 | 
				
			||||||
			Migrate: func(tx *gorm.DB) error {
 | 
								Migrate: func(tx *gorm.DB) error {
 | 
				
			||||||
				if dbType == Postgres {
 | 
									if cfg.Type == types.DatabasePostgres {
 | 
				
			||||||
					tx.Exec(`create extension if not exists "uuid-ossp";`)
 | 
										tx.Exec(`create extension if not exists "uuid-ossp";`)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -321,20 +318,20 @@ func NewHeadscaleDatabase(
 | 
				
			|||||||
	return &db, err
 | 
						return &db, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
 | 
					func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
 | 
				
			||||||
	log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO(kradalby): Integrate this with zerolog
 | 
				
			||||||
	var dbLogger logger.Interface
 | 
						var dbLogger logger.Interface
 | 
				
			||||||
	if debug {
 | 
						if cfg.Debug {
 | 
				
			||||||
		dbLogger = logger.Default
 | 
							dbLogger = logger.Default
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		dbLogger = logger.Default.LogMode(logger.Silent)
 | 
							dbLogger = logger.Default.LogMode(logger.Silent)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	switch dbType {
 | 
						switch cfg.Type {
 | 
				
			||||||
	case Sqlite:
 | 
						case types.DatabaseSqlite:
 | 
				
			||||||
		db, err := gorm.Open(
 | 
							db, err := gorm.Open(
 | 
				
			||||||
			sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
 | 
								sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
 | 
				
			||||||
			&gorm.Config{
 | 
								&gorm.Config{
 | 
				
			||||||
				DisableForeignKeyConstraintWhenMigrating: true,
 | 
									DisableForeignKeyConstraintWhenMigrating: true,
 | 
				
			||||||
				Logger:                                   dbLogger,
 | 
									Logger:                                   dbLogger,
 | 
				
			||||||
@ -353,8 +350,31 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		return db, err
 | 
							return db, err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	case Postgres:
 | 
						case types.DatabasePostgres:
 | 
				
			||||||
		return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
 | 
							dbString := fmt.Sprintf(
 | 
				
			||||||
 | 
								"host=%s dbname=%s user=%s",
 | 
				
			||||||
 | 
								cfg.Postgres.Host,
 | 
				
			||||||
 | 
								cfg.Postgres.Name,
 | 
				
			||||||
 | 
								cfg.Postgres.User,
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
 | 
				
			||||||
 | 
								if !sslEnabled {
 | 
				
			||||||
 | 
									dbString += " sslmode=disable"
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if cfg.Postgres.Port != 0 {
 | 
				
			||||||
 | 
								dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if cfg.Postgres.Pass != "" {
 | 
				
			||||||
 | 
								dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return gorm.Open(postgres.Open(dbString), &gorm.Config{
 | 
				
			||||||
			DisableForeignKeyConstraintWhenMigrating: true,
 | 
								DisableForeignKeyConstraintWhenMigrating: true,
 | 
				
			||||||
			Logger:                                   dbLogger,
 | 
								Logger:                                   dbLogger,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
@ -362,7 +382,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return nil, fmt.Errorf(
 | 
						return nil, fmt.Errorf(
 | 
				
			||||||
		"database of type %s is not supported: %w",
 | 
							"database of type %s is not supported: %w",
 | 
				
			||||||
		dbType,
 | 
							cfg.Type,
 | 
				
			||||||
		errDatabaseNotSupported,
 | 
							errDatabaseNotSupported,
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,7 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/google/go-cmp/cmp"
 | 
						"github.com/google/go-cmp/cmp"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/notifier"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
	"github.com/stretchr/testify/assert"
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
@ -654,9 +655,13 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			db, err = NewHeadscaleDatabase(
 | 
								db, err = NewHeadscaleDatabase(
 | 
				
			||||||
				"sqlite3",
 | 
									types.DatabaseConfig{
 | 
				
			||||||
				tmpDir+"/headscale_test.db",
 | 
										Type: "sqlite3",
 | 
				
			||||||
				false,
 | 
										Sqlite: types.SqliteConfig{
 | 
				
			||||||
 | 
											Path: tmpDir + "/headscale_test.db",
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									notifier.NewNotifier(),
 | 
				
			||||||
				[]netip.Prefix{
 | 
									[]netip.Prefix{
 | 
				
			||||||
					netip.MustParsePrefix("10.27.0.0/23"),
 | 
										netip.MustParsePrefix("10.27.0.0/23"),
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
 | 
				
			|||||||
@ -6,6 +6,8 @@ import (
 | 
				
			|||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/notifier"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"gopkg.in/check.v1"
 | 
						"gopkg.in/check.v1"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -44,9 +46,13 @@ func (s *Suite) ResetDB(c *check.C) {
 | 
				
			|||||||
	log.Printf("database path: %s", tmpDir+"/headscale_test.db")
 | 
						log.Printf("database path: %s", tmpDir+"/headscale_test.db")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db, err = NewHeadscaleDatabase(
 | 
						db, err = NewHeadscaleDatabase(
 | 
				
			||||||
		"sqlite3",
 | 
							types.DatabaseConfig{
 | 
				
			||||||
		tmpDir+"/headscale_test.db",
 | 
								Type: "sqlite3",
 | 
				
			||||||
		false,
 | 
								Sqlite: types.SqliteConfig{
 | 
				
			||||||
 | 
									Path: tmpDir + "/headscale_test.db",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							notifier.NewNotifier(),
 | 
				
			||||||
		[]netip.Prefix{
 | 
							[]netip.Prefix{
 | 
				
			||||||
			netip.MustParsePrefix("10.27.0.0/23"),
 | 
								netip.MustParsePrefix("10.27.0.0/23"),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
				
			|||||||
@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	cfg := types.Config{
 | 
						cfg := types.Config{
 | 
				
			||||||
		NoisePrivateKeyPath: tmpDir + "/noise_private.key",
 | 
							NoisePrivateKeyPath: tmpDir + "/noise_private.key",
 | 
				
			||||||
		DBtype:              "sqlite3",
 | 
							Database: types.DatabaseConfig{
 | 
				
			||||||
		DBpath:              tmpDir + "/headscale_test.db",
 | 
								Type: "sqlite3",
 | 
				
			||||||
 | 
								Sqlite: types.SqliteConfig{
 | 
				
			||||||
 | 
									Path: tmpDir + "/headscale_test.db",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
		IPPrefixes: []netip.Prefix{
 | 
							IPPrefixes: []netip.Prefix{
 | 
				
			||||||
			netip.MustParsePrefix("10.27.0.0/23"),
 | 
								netip.MustParsePrefix("10.27.0.0/23"),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,11 @@ import (
 | 
				
			|||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const SelfUpdateIdentifier = "self-update"
 | 
					const (
 | 
				
			||||||
 | 
						SelfUpdateIdentifier = "self-update"
 | 
				
			||||||
 | 
						DatabasePostgres     = "postgres"
 | 
				
			||||||
 | 
						DatabaseSqlite       = "sqlite3"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
 | 
					var ErrCannotParsePrefix = errors.New("cannot parse prefix")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -154,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case StateSelfUpdate:
 | 
						case StateSelfUpdate:
 | 
				
			||||||
		if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
 | 
							if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
 | 
				
			||||||
			panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node")
 | 
								panic(
 | 
				
			||||||
 | 
									"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case StateDERPUpdated:
 | 
						case StateDERPUpdated:
 | 
				
			||||||
		if su.DERPMap == nil {
 | 
							if su.DERPMap == nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -46,16 +46,9 @@ type Config struct {
 | 
				
			|||||||
	Log                            LogConfig
 | 
						Log                            LogConfig
 | 
				
			||||||
	DisableUpdateCheck             bool
 | 
						DisableUpdateCheck             bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DERP DERPConfig
 | 
						Database DatabaseConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DBtype string
 | 
						DERP DERPConfig
 | 
				
			||||||
	DBpath string
 | 
					 | 
				
			||||||
	DBhost string
 | 
					 | 
				
			||||||
	DBport int
 | 
					 | 
				
			||||||
	DBname string
 | 
					 | 
				
			||||||
	DBuser string
 | 
					 | 
				
			||||||
	DBpass string
 | 
					 | 
				
			||||||
	DBssl  string
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	TLS TLSConfig
 | 
						TLS TLSConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -77,6 +70,28 @@ type Config struct {
 | 
				
			|||||||
	ACL ACLConfig
 | 
						ACL ACLConfig
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SqliteConfig struct {
 | 
				
			||||||
 | 
						Path string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PostgresConfig struct {
 | 
				
			||||||
 | 
						Host string
 | 
				
			||||||
 | 
						Port int
 | 
				
			||||||
 | 
						Name string
 | 
				
			||||||
 | 
						User string
 | 
				
			||||||
 | 
						Pass string
 | 
				
			||||||
 | 
						Ssl  string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type DatabaseConfig struct {
 | 
				
			||||||
 | 
						// Type sets the database type, either "sqlite3" or "postgres"
 | 
				
			||||||
 | 
						Type  string
 | 
				
			||||||
 | 
						Debug bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Sqlite   SqliteConfig
 | 
				
			||||||
 | 
						Postgres PostgresConfig
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TLSConfig struct {
 | 
					type TLSConfig struct {
 | 
				
			||||||
	CertPath string
 | 
						CertPath string
 | 
				
			||||||
	KeyPath  string
 | 
						KeyPath  string
 | 
				
			||||||
@ -161,6 +176,19 @@ func LoadConfig(path string, isFile bool) error {
 | 
				
			|||||||
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
 | 
						viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
 | 
				
			||||||
	viper.AutomaticEnv()
 | 
						viper.AutomaticEnv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_type", "database.type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// SQLite aliases
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_path", "database.sqlite.path")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Postgres aliases
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_host", "database.postgres.host")
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_port", "database.postgres.port")
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_name", "database.postgres.name")
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_user", "database.postgres.user")
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_pass", "database.postgres.pass")
 | 
				
			||||||
 | 
						viper.RegisterAlias("db_ssl", "database.postgres.ssl")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
 | 
						viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
 | 
				
			||||||
	viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
 | 
						viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -184,6 +212,7 @@ func LoadConfig(path string, isFile bool) error {
 | 
				
			|||||||
	viper.SetDefault("cli.insecure", false)
 | 
						viper.SetDefault("cli.insecure", false)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	viper.SetDefault("db_ssl", false)
 | 
						viper.SetDefault("db_ssl", false)
 | 
				
			||||||
 | 
						viper.SetDefault("database.postgres.ssl", false)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
 | 
						viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
 | 
				
			||||||
	viper.SetDefault("oidc.strip_email_domain", true)
 | 
						viper.SetDefault("oidc.strip_email_domain", true)
 | 
				
			||||||
@ -389,6 +418,37 @@ func GetLogConfig() LogConfig {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetDatabaseConfig() DatabaseConfig {
 | 
				
			||||||
 | 
						debug := viper.GetBool("database.debug")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type_ := viper.GetString("database.type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch type_ {
 | 
				
			||||||
 | 
						case DatabaseSqlite, DatabasePostgres:
 | 
				
			||||||
 | 
							break
 | 
				
			||||||
 | 
						case "sqlite":
 | 
				
			||||||
 | 
							type_ = "sqlite3"
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							log.Fatal().Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return DatabaseConfig{
 | 
				
			||||||
 | 
							Type:  type_,
 | 
				
			||||||
 | 
							Debug: debug,
 | 
				
			||||||
 | 
							Sqlite: SqliteConfig{
 | 
				
			||||||
 | 
								Path: util.AbsolutePathFromConfigPath(viper.GetString("database.sqlite.path")),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							Postgres: PostgresConfig{
 | 
				
			||||||
 | 
								Host: viper.GetString("database.postgres.host"),
 | 
				
			||||||
 | 
								Port: viper.GetInt("database.postgres.port"),
 | 
				
			||||||
 | 
								Name: viper.GetString("database.postgres.name"),
 | 
				
			||||||
 | 
								User: viper.GetString("database.postgres.user"),
 | 
				
			||||||
 | 
								Pass: viper.GetString("database.postgres.pass"),
 | 
				
			||||||
 | 
								Ssl:  viper.GetString("database.postgres.ssl"),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
 | 
					func GetDNSConfig() (*tailcfg.DNSConfig, string) {
 | 
				
			||||||
	if viper.IsSet("dns_config") {
 | 
						if viper.IsSet("dns_config") {
 | 
				
			||||||
		dnsConfig := &tailcfg.DNSConfig{}
 | 
							dnsConfig := &tailcfg.DNSConfig{}
 | 
				
			||||||
@ -617,14 +677,7 @@ func GetHeadscaleConfig() (*Config, error) {
 | 
				
			|||||||
			"node_update_check_interval",
 | 
								"node_update_check_interval",
 | 
				
			||||||
		),
 | 
							),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		DBtype: viper.GetString("db_type"),
 | 
							Database: GetDatabaseConfig(),
 | 
				
			||||||
		DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")),
 | 
					 | 
				
			||||||
		DBhost: viper.GetString("db_host"),
 | 
					 | 
				
			||||||
		DBport: viper.GetInt("db_port"),
 | 
					 | 
				
			||||||
		DBname: viper.GetString("db_name"),
 | 
					 | 
				
			||||||
		DBuser: viper.GetString("db_user"),
 | 
					 | 
				
			||||||
		DBpass: viper.GetString("db_pass"),
 | 
					 | 
				
			||||||
		DBssl:  viper.GetString("db_ssl"),
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		TLS: GetTLSConfig(),
 | 
							TLS: GetTLSConfig(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string {
 | 
				
			|||||||
	return map[string]string{
 | 
						return map[string]string{
 | 
				
			||||||
		"HEADSCALE_LOG_LEVEL":                         "trace",
 | 
							"HEADSCALE_LOG_LEVEL":                         "trace",
 | 
				
			||||||
		"HEADSCALE_ACL_POLICY_PATH":                   "",
 | 
							"HEADSCALE_ACL_POLICY_PATH":                   "",
 | 
				
			||||||
		"HEADSCALE_DB_TYPE":                           "sqlite3",
 | 
							"HEADSCALE_DATABASE_TYPE":                     "sqlite",
 | 
				
			||||||
		"HEADSCALE_DB_PATH":                           "/tmp/integration_test_db.sqlite3",
 | 
							"HEADSCALE_DATABASE_SQLITE_PATH":              "/tmp/integration_test_db.sqlite3",
 | 
				
			||||||
		"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
 | 
							"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
 | 
				
			||||||
		"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL":        "10s",
 | 
							"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL":        "10s",
 | 
				
			||||||
		"HEADSCALE_IP_PREFIXES":                       "fd7a:115c:a1e0::/48 100.64.0.0/10",
 | 
							"HEADSCALE_IP_PREFIXES":                       "fd7a:115c:a1e0::/48 100.64.0.0/10",
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user