mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 16:11:03 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			169 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			169 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package db
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net/netip"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"testing"
 | |
| 
 | |
| 	"github.com/google/go-cmp/cmp"
 | |
| 	"github.com/google/go-cmp/cmp/cmpopts"
 | |
| 	"github.com/juanfont/headscale/hscontrol/types"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"gorm.io/gorm"
 | |
| )
 | |
| 
 | |
| func TestMigrations(t *testing.T) {
 | |
| 	ipp := func(p string) types.IPPrefix {
 | |
| 		return types.IPPrefix(netip.MustParsePrefix(p))
 | |
| 	}
 | |
| 	r := func(id uint64, p string, a, e, i bool) types.Route {
 | |
| 		return types.Route{
 | |
| 			NodeID:     id,
 | |
| 			Prefix:     ipp(p),
 | |
| 			Advertised: a,
 | |
| 			Enabled:    e,
 | |
| 			IsPrimary:  i,
 | |
| 		}
 | |
| 	}
 | |
| 	tests := []struct {
 | |
| 		dbPath   string
 | |
| 		wantFunc func(*testing.T, *HSDatabase)
 | |
| 		wantErr  string
 | |
| 	}{
 | |
| 		{
 | |
| 			dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
 | |
| 			wantFunc: func(t *testing.T, h *HSDatabase) {
 | |
| 				routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
 | |
| 					return GetRoutes(rx)
 | |
| 				})
 | |
| 				assert.NoError(t, err)
 | |
| 
 | |
| 				assert.Len(t, routes, 10)
 | |
| 				want := types.Routes{
 | |
| 					r(1, "0.0.0.0/0", true, true, false),
 | |
| 					r(1, "::/0", true, true, false),
 | |
| 					r(1, "10.9.110.0/24", true, true, true),
 | |
| 					r(26, "172.100.100.0/24", true, true, true),
 | |
| 					r(26, "172.100.100.0/24", true, false, false),
 | |
| 					r(31, "0.0.0.0/0", true, true, false),
 | |
| 					r(31, "0.0.0.0/0", true, false, false),
 | |
| 					r(31, "::/0", true, true, false),
 | |
| 					r(31, "::/0", true, false, false),
 | |
| 					r(32, "192.168.0.24/32", true, true, true),
 | |
| 				}
 | |
| 				if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
 | |
| 					return x == y
 | |
| 				})); diff != "" {
 | |
| 					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
 | |
| 			wantFunc: func(t *testing.T, h *HSDatabase) {
 | |
| 				routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
 | |
| 					return GetRoutes(rx)
 | |
| 				})
 | |
| 				assert.NoError(t, err)
 | |
| 
 | |
| 				assert.Len(t, routes, 4)
 | |
| 				want := types.Routes{
 | |
| 					// These routes exists, but have no nodes associated with them
 | |
| 					// when the migration starts.
 | |
| 					// r(1, "0.0.0.0/0", true, true, false),
 | |
| 					// r(1, "::/0", true, true, false),
 | |
| 					// r(3, "0.0.0.0/0", true, true, false),
 | |
| 					// r(3, "::/0", true, true, false),
 | |
| 					// r(5, "0.0.0.0/0", true, true, false),
 | |
| 					// r(5, "::/0", true, true, false),
 | |
| 					// r(6, "0.0.0.0/0", true, true, false),
 | |
| 					// r(6, "::/0", true, true, false),
 | |
| 					// r(6, "10.0.0.0/8", true, false, false),
 | |
| 					// r(7, "0.0.0.0/0", true, true, false),
 | |
| 					// r(7, "::/0", true, true, false),
 | |
| 					// r(7, "10.0.0.0/8", true, false, false),
 | |
| 					// r(9, "0.0.0.0/0", true, true, false),
 | |
| 					// r(9, "::/0", true, true, false),
 | |
| 					// r(9, "10.0.0.0/8", true, true, false),
 | |
| 					// r(11, "0.0.0.0/0", true, true, false),
 | |
| 					// r(11, "::/0", true, true, false),
 | |
| 					// r(11, "10.0.0.0/8", true, true, true),
 | |
| 					// r(12, "0.0.0.0/0", true, true, false),
 | |
| 					// r(12, "::/0", true, true, false),
 | |
| 					// r(12, "10.0.0.0/8", true, false, false),
 | |
| 					//
 | |
| 					// These nodes exists, so routes should be kept.
 | |
| 					r(13, "10.0.0.0/8", true, false, false),
 | |
| 					r(13, "0.0.0.0/0", true, true, false),
 | |
| 					r(13, "::/0", true, true, false),
 | |
| 					r(13, "10.18.80.2/32", true, true, true),
 | |
| 				}
 | |
| 				if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
 | |
| 					return x == y
 | |
| 				})); diff != "" {
 | |
| 					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.dbPath, func(t *testing.T) {
 | |
| 			dbPath, err := testCopyOfDatabase(tt.dbPath)
 | |
| 			if err != nil {
 | |
| 				t.Fatalf("copying db for test: %s", err)
 | |
| 			}
 | |
| 
 | |
| 			hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{
 | |
| 				Type: "sqlite3",
 | |
| 				Sqlite: types.SqliteConfig{
 | |
| 					Path: dbPath,
 | |
| 				},
 | |
| 			}, "")
 | |
| 			if err != nil && tt.wantErr != err.Error() {
 | |
| 				t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr)
 | |
| 			}
 | |
| 
 | |
| 			if tt.wantFunc != nil {
 | |
| 				tt.wantFunc(t, hsdb)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
// testCopyOfDatabase copies the regular file at src into a freshly created
// temporary directory and returns the path of the copy. The copy keeps the
// original file name. Tests use it so a fixture under testdata is never
// opened directly.
//
// The caller owns the copy and should remove its parent directory
// (filepath.Dir of the returned path) when finished. On error, the returned
// path is empty and any temporary directory created along the way has been
// removed.
func testCopyOfDatabase(src string) (string, error) {
	info, err := os.Stat(src)
	if err != nil {
		return "", err
	}
	if !info.Mode().IsRegular() {
		return "", fmt.Errorf("%s is not a regular file", src)
	}

	source, err := os.Open(src)
	if err != nil {
		return "", err
	}
	defer source.Close() // read-only handle; its Close error carries no data-loss risk

	tmpDir, err := os.MkdirTemp("", "hsdb-test-*")
	if err != nil {
		return "", err
	}
	// fail removes the half-built temporary directory so error paths do not
	// leak files, and returns err with no path.
	fail := func(err error) (string, error) {
		_ = os.RemoveAll(tmpDir) // best effort; err is the error worth reporting
		return "", err
	}

	dst := filepath.Join(tmpDir, filepath.Base(src))
	destination, err := os.Create(dst)
	if err != nil {
		return fail(err)
	}
	if _, err := io.Copy(destination, source); err != nil {
		_ = destination.Close() // the copy error is the one worth reporting
		return fail(fmt.Errorf("copying %s to %s: %w", src, dst, err))
	}
	// Close explicitly rather than deferring: Close on a file that was
	// written to can report a failed write, and a deferred Close would
	// silently drop that error.
	if err := destination.Close(); err != nil {
		return fail(fmt.Errorf("closing %s: %w", dst, err))
	}

	return dst, nil
}
 |