mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-30 23:51:03 +01:00 
			
		
		
		
	This commit fixes the constraint syntax so it is valid for both sqlite and postgres. To validate this, I've added a new postgres testing library and a helper that will spin up a local postgres, set up a db and use it in the constraints tests. This should also help with testing db stuff in the future. postgres has been added to the nix dev shell and is now required for running the unit tests. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
		
			
				
	
	
		
			379 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			379 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package db
 | ||
| 
 | ||
| import (
 | ||
| 	"database/sql"
 | ||
| 	"fmt"
 | ||
| 	"io"
 | ||
| 	"net/netip"
 | ||
| 	"os"
 | ||
| 	"path/filepath"
 | ||
| 	"slices"
 | ||
| 	"sort"
 | ||
| 	"strings"
 | ||
| 	"testing"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"github.com/google/go-cmp/cmp"
 | ||
| 	"github.com/google/go-cmp/cmp/cmpopts"
 | ||
| 	"github.com/juanfont/headscale/hscontrol/types"
 | ||
| 	"github.com/juanfont/headscale/hscontrol/util"
 | ||
| 	"github.com/stretchr/testify/assert"
 | ||
| 	"github.com/stretchr/testify/require"
 | ||
| 	"gorm.io/gorm"
 | ||
| 	"zgo.at/zcache/v2"
 | ||
| )
 | ||
| 
 | ||
| func TestMigrations(t *testing.T) {
 | ||
| 	ipp := func(p string) netip.Prefix {
 | ||
| 		return netip.MustParsePrefix(p)
 | ||
| 	}
 | ||
| 	r := func(id uint64, p string, a, e, i bool) types.Route {
 | ||
| 		return types.Route{
 | ||
| 			NodeID:     id,
 | ||
| 			Prefix:     ipp(p),
 | ||
| 			Advertised: a,
 | ||
| 			Enabled:    e,
 | ||
| 			IsPrimary:  i,
 | ||
| 		}
 | ||
| 	}
 | ||
| 	tests := []struct {
 | ||
| 		dbPath   string
 | ||
| 		wantFunc func(*testing.T, *HSDatabase)
 | ||
| 		wantErr  string
 | ||
| 	}{
 | ||
| 		{
 | ||
| 			dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
 | ||
| 			wantFunc: func(t *testing.T, h *HSDatabase) {
 | ||
| 				routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
 | ||
| 					return GetRoutes(rx)
 | ||
| 				})
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				assert.Len(t, routes, 10)
 | ||
| 				want := types.Routes{
 | ||
| 					r(1, "0.0.0.0/0", true, true, false),
 | ||
| 					r(1, "::/0", true, true, false),
 | ||
| 					r(1, "10.9.110.0/24", true, true, true),
 | ||
| 					r(26, "172.100.100.0/24", true, true, true),
 | ||
| 					r(26, "172.100.100.0/24", true, false, false),
 | ||
| 					r(31, "0.0.0.0/0", true, true, false),
 | ||
| 					r(31, "0.0.0.0/0", true, false, false),
 | ||
| 					r(31, "::/0", true, true, false),
 | ||
| 					r(31, "::/0", true, false, false),
 | ||
| 					r(32, "192.168.0.24/32", true, true, true),
 | ||
| 				}
 | ||
| 				if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
 | ||
| 					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | ||
| 				}
 | ||
| 			},
 | ||
| 		},
 | ||
| 		{
 | ||
| 			dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
 | ||
| 			wantFunc: func(t *testing.T, h *HSDatabase) {
 | ||
| 				routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
 | ||
| 					return GetRoutes(rx)
 | ||
| 				})
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				assert.Len(t, routes, 4)
 | ||
| 				want := types.Routes{
 | ||
| 					// These routes exists, but have no nodes associated with them
 | ||
| 					// when the migration starts.
 | ||
| 					// r(1, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(1, "::/0", true, true, false),
 | ||
| 					// r(3, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(3, "::/0", true, true, false),
 | ||
| 					// r(5, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(5, "::/0", true, true, false),
 | ||
| 					// r(6, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(6, "::/0", true, true, false),
 | ||
| 					// r(6, "10.0.0.0/8", true, false, false),
 | ||
| 					// r(7, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(7, "::/0", true, true, false),
 | ||
| 					// r(7, "10.0.0.0/8", true, false, false),
 | ||
| 					// r(9, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(9, "::/0", true, true, false),
 | ||
| 					// r(9, "10.0.0.0/8", true, true, false),
 | ||
| 					// r(11, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(11, "::/0", true, true, false),
 | ||
| 					// r(11, "10.0.0.0/8", true, true, true),
 | ||
| 					// r(12, "0.0.0.0/0", true, true, false),
 | ||
| 					// r(12, "::/0", true, true, false),
 | ||
| 					// r(12, "10.0.0.0/8", true, false, false),
 | ||
| 					//
 | ||
| 					// These nodes exists, so routes should be kept.
 | ||
| 					r(13, "10.0.0.0/8", true, false, false),
 | ||
| 					r(13, "0.0.0.0/0", true, true, false),
 | ||
| 					r(13, "::/0", true, true, false),
 | ||
| 					r(13, "10.18.80.2/32", true, true, true),
 | ||
| 				}
 | ||
| 				if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
 | ||
| 					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | ||
| 				}
 | ||
| 			},
 | ||
| 		},
 | ||
| 		// at 14:15:06 ❯ go run ./cmd/headscale preauthkeys list
 | ||
| 		// ID | Key      | Reusable | Ephemeral | Used  | Expiration | Created    | Tags
 | ||
| 		// 1  | 09b28f.. | false    | false     | false | 2024-09-27 | 2024-09-27 | tag:derp
 | ||
| 		// 2  | 3112b9.. | false    | false     | false | 2024-09-27 | 2024-09-27 | tag:derp
 | ||
| 		// 3  | 7c23b9.. | false    | false     | false | 2024-09-27 | 2024-09-27 | tag:derp,tag:merp
 | ||
| 		// 4  | f20155.. | false    | false     | false | 2024-09-27 | 2024-09-27 | tag:test
 | ||
| 		// 5  | b212b9.. | false    | false     | false | 2024-09-27 | 2024-09-27 | tag:test,tag:woop,tag:dedu
 | ||
| 		{
 | ||
| 			dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
 | ||
| 			wantFunc: func(t *testing.T, h *HSDatabase) {
 | ||
| 				keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
 | ||
| 					kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest
 | ||
| 					if err != nil {
 | ||
| 						return nil, err
 | ||
| 					}
 | ||
| 
 | ||
| 					testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra
 | ||
| 					if err != nil {
 | ||
| 						return nil, err
 | ||
| 					}
 | ||
| 
 | ||
| 					return append(kratest, testkra...), nil
 | ||
| 				})
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				assert.Len(t, keys, 5)
 | ||
| 				want := []types.PreAuthKey{
 | ||
| 					{
 | ||
| 						ID:   1,
 | ||
| 						Tags: []string{"tag:derp"},
 | ||
| 					},
 | ||
| 					{
 | ||
| 						ID:   2,
 | ||
| 						Tags: []string{"tag:derp"},
 | ||
| 					},
 | ||
| 					{
 | ||
| 						ID:   3,
 | ||
| 						Tags: []string{"tag:derp", "tag:merp"},
 | ||
| 					},
 | ||
| 					{
 | ||
| 						ID:   4,
 | ||
| 						Tags: []string{"tag:test"},
 | ||
| 					},
 | ||
| 					{
 | ||
| 						ID:   5,
 | ||
| 						Tags: []string{"tag:test", "tag:woop", "tag:dedu"},
 | ||
| 					},
 | ||
| 				}
 | ||
| 
 | ||
| 				if diff := cmp.Diff(want, keys, cmp.Comparer(func(a, b []string) bool {
 | ||
| 					sort.Sort(sort.StringSlice(a))
 | ||
| 					sort.Sort(sort.StringSlice(b))
 | ||
| 					return slices.Equal(a, b)
 | ||
| 				}), cmpopts.IgnoreFields(types.PreAuthKey{}, "Key", "UserID", "User", "CreatedAt", "Expiration")); diff != "" {
 | ||
| 					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | ||
| 				}
 | ||
| 
 | ||
| 				if h.DB.Migrator().HasTable("pre_auth_key_acl_tags") {
 | ||
| 					t.Errorf("TestMigrations() table pre_auth_key_acl_tags should not exist")
 | ||
| 				}
 | ||
| 			},
 | ||
| 		},
 | ||
| 		{
 | ||
| 			dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite",
 | ||
| 			wantFunc: func(t *testing.T, h *HSDatabase) {
 | ||
| 				nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
 | ||
| 					return ListNodes(rx)
 | ||
| 				})
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				for _, node := range nodes {
 | ||
| 					assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
 | ||
| 					assert.Contains(t, node.MachineKey.String(), "mkey:")
 | ||
| 					assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
 | ||
| 					assert.Contains(t, node.NodeKey.String(), "nodekey:")
 | ||
| 					assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
 | ||
| 					assert.Contains(t, node.DiscoKey.String(), "discokey:")
 | ||
| 					assert.NotNil(t, node.IPv4)
 | ||
| 					assert.NotNil(t, node.IPv4)
 | ||
| 					assert.Len(t, node.Endpoints, 1)
 | ||
| 					assert.NotNil(t, node.Hostinfo)
 | ||
| 					assert.NotNil(t, node.MachineKey)
 | ||
| 				}
 | ||
| 			},
 | ||
| 		},
 | ||
| 	}
 | ||
| 
 | ||
| 	for _, tt := range tests {
 | ||
| 		t.Run(tt.dbPath, func(t *testing.T) {
 | ||
| 			dbPath, err := testCopyOfDatabase(tt.dbPath)
 | ||
| 			if err != nil {
 | ||
| 				t.Fatalf("copying db for test: %s", err)
 | ||
| 			}
 | ||
| 
 | ||
| 			hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{
 | ||
| 				Type: "sqlite3",
 | ||
| 				Sqlite: types.SqliteConfig{
 | ||
| 					Path: dbPath,
 | ||
| 				},
 | ||
| 			}, "", emptyCache())
 | ||
| 			if err != nil && tt.wantErr != err.Error() {
 | ||
| 				t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr)
 | ||
| 			}
 | ||
| 
 | ||
| 			if tt.wantFunc != nil {
 | ||
| 				tt.wantFunc(t, hsdb)
 | ||
| 			}
 | ||
| 		})
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
// testCopyOfDatabase copies the regular file at src into a newly created
// temporary directory, keeping the file's base name, and returns the path of
// the copy.
//
// An error is returned when src cannot be stat'ed or opened, is not a regular
// file, or when creating, writing or closing the copy fails. On any failure
// after the temporary directory has been created, that directory is removed
// again and the returned path is empty.
//
// On success the caller owns the temporary directory (filepath.Dir of the
// returned path); this function does not remove it.
func testCopyOfDatabase(src string) (string, error) {
	sourceFileStat, err := os.Stat(src)
	if err != nil {
		return "", err
	}

	if !sourceFileStat.Mode().IsRegular() {
		return "", fmt.Errorf("%s is not a regular file", src)
	}

	source, err := os.Open(src)
	if err != nil {
		return "", err
	}
	defer source.Close()

	tmpDir, err := os.MkdirTemp("", "hsdb-test-*")
	if err != nil {
		return "", err
	}

	// fail discards the partially populated temporary directory (best effort)
	// and reports cause to the caller.
	fail := func(cause error) (string, error) {
		_ = os.RemoveAll(tmpDir)

		return "", cause
	}

	dst := filepath.Join(tmpDir, filepath.Base(src))

	destination, err := os.Create(dst)
	if err != nil {
		return fail(err)
	}

	if _, err := io.Copy(destination, source); err != nil {
		_ = destination.Close() // the copy error is the one worth reporting

		return fail(err)
	}

	// Close explicitly instead of deferring: a failing Close means the copy
	// may be incomplete and must not be reported as a success.
	if err := destination.Close(); err != nil {
		return fail(err)
	}

	return dst, nil
}
 | ||
| 
 | ||
| func emptyCache() *zcache.Cache[string, types.Node] {
 | ||
| 	return zcache.New[string, types.Node](time.Minute, time.Hour)
 | ||
| }
 | ||
| 
 | ||
| // requireConstraintFailed checks if the error is a constraint failure with
 | ||
| // either SQLite and PostgreSQL error messages.
 | ||
| func requireConstraintFailed(t *testing.T, err error) {
 | ||
| 	t.Helper()
 | ||
| 	require.Error(t, err)
 | ||
| 	if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
 | ||
| 		require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| func TestConstraints(t *testing.T) {
 | ||
| 	tests := []struct {
 | ||
| 		name string
 | ||
| 		run  func(*testing.T, *gorm.DB)
 | ||
| 	}{
 | ||
| 		{
 | ||
| 			name: "no-duplicate-username-if-no-oidc",
 | ||
| 			run: func(t *testing.T, db *gorm.DB) {
 | ||
| 				_, err := CreateUser(db, "user1")
 | ||
| 				require.NoError(t, err)
 | ||
| 				_, err = CreateUser(db, "user1")
 | ||
| 				requireConstraintFailed(t, err)
 | ||
| 			},
 | ||
| 		},
 | ||
| 		{
 | ||
| 			name: "no-oidc-duplicate-username-and-id",
 | ||
| 			run: func(t *testing.T, db *gorm.DB) {
 | ||
| 				user := types.User{
 | ||
| 					Model: gorm.Model{ID: 1},
 | ||
| 					Name:  "user1",
 | ||
| 				}
 | ||
| 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
 | ||
| 
 | ||
| 				err := db.Save(&user).Error
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				user = types.User{
 | ||
| 					Model: gorm.Model{ID: 2},
 | ||
| 					Name:  "user1",
 | ||
| 				}
 | ||
| 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
 | ||
| 
 | ||
| 				err = db.Save(&user).Error
 | ||
| 				requireConstraintFailed(t, err)
 | ||
| 			},
 | ||
| 		},
 | ||
| 		{
 | ||
| 			name: "no-oidc-duplicate-id",
 | ||
| 			run: func(t *testing.T, db *gorm.DB) {
 | ||
| 				user := types.User{
 | ||
| 					Model: gorm.Model{ID: 1},
 | ||
| 					Name:  "user1",
 | ||
| 				}
 | ||
| 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
 | ||
| 
 | ||
| 				err := db.Save(&user).Error
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				user = types.User{
 | ||
| 					Model: gorm.Model{ID: 2},
 | ||
| 					Name:  "user1.1",
 | ||
| 				}
 | ||
| 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
 | ||
| 
 | ||
| 				err = db.Save(&user).Error
 | ||
| 				requireConstraintFailed(t, err)
 | ||
| 			},
 | ||
| 		},
 | ||
| 		{
 | ||
| 			name: "allow-duplicate-username-cli-then-oidc",
 | ||
| 			run: func(t *testing.T, db *gorm.DB) {
 | ||
| 				_, err := CreateUser(db, "user1") // Create CLI username
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				user := types.User{
 | ||
| 					Name:               "user1",
 | ||
| 					ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
 | ||
| 				}
 | ||
| 
 | ||
| 				err = db.Save(&user).Error
 | ||
| 				require.NoError(t, err)
 | ||
| 			},
 | ||
| 		},
 | ||
| 		{
 | ||
| 			name: "allow-duplicate-username-oidc-then-cli",
 | ||
| 			run: func(t *testing.T, db *gorm.DB) {
 | ||
| 				user := types.User{
 | ||
| 					Name:               "user1",
 | ||
| 					ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
 | ||
| 				}
 | ||
| 
 | ||
| 				err := db.Save(&user).Error
 | ||
| 				require.NoError(t, err)
 | ||
| 
 | ||
| 				_, err = CreateUser(db, "user1") // Create CLI username
 | ||
| 				require.NoError(t, err)
 | ||
| 			},
 | ||
| 		},
 | ||
| 	}
 | ||
| 
 | ||
| 	for _, tt := range tests {
 | ||
| 		t.Run(tt.name+"-postgres", func(t *testing.T) {
 | ||
| 			db := newPostgresTestDB(t)
 | ||
| 			tt.run(t, db.DB.Debug())
 | ||
| 		})
 | ||
| 		t.Run(tt.name+"-sqlite", func(t *testing.T) {
 | ||
| 			db, err := newSQLiteTestDB()
 | ||
| 			if err != nil {
 | ||
| 				t.Fatalf("creating database: %s", err)
 | ||
| 			}
 | ||
| 
 | ||
| 			tt.run(t, db.DB.Debug())
 | ||
| 		})
 | ||
| 
 | ||
| 	}
 | ||
| }
 |