mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 01:51:04 +01:00 
			
		
		
		
	Rework map session
This commit restructures the map session in to a struct holding the state of what is needed during its lifetime. For streaming sessions, the event loop is structured a bit differently not hammering the clients with updates but rather batching them over a short, configurable time which should significantly improve cpu usage, and potentially flakyness. The use of Patch updates has been dialed back a little as it does not look like its a 100% ready for prime time. Nodes are now updated with full changes, except for a few things like online status. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									dd693c444c
								
							
						
					
					
						commit
						58c94d2bd3
					
				
							
								
								
									
										3
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							@ -43,7 +43,8 @@ jobs:
 | 
				
			|||||||
          - TestTaildrop
 | 
					          - TestTaildrop
 | 
				
			||||||
          - TestResolveMagicDNS
 | 
					          - TestResolveMagicDNS
 | 
				
			||||||
          - TestExpireNode
 | 
					          - TestExpireNode
 | 
				
			||||||
          - TestNodeOnlineLastSeenStatus
 | 
					          - TestNodeOnlineStatus
 | 
				
			||||||
 | 
					          - TestPingAllByIPManyUpDown
 | 
				
			||||||
          - TestEnablingRoutes
 | 
					          - TestEnablingRoutes
 | 
				
			||||||
          - TestHASubnetRouterFailover
 | 
					          - TestHASubnetRouterFailover
 | 
				
			||||||
          - TestEnableDisableAutoApprovedRoute
 | 
					          - TestEnableDisableAutoApprovedRoute
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@ -150,6 +150,7 @@ require (
 | 
				
			|||||||
	github.com/opencontainers/image-spec v1.1.0-rc6 // indirect
 | 
						github.com/opencontainers/image-spec v1.1.0-rc6 // indirect
 | 
				
			||||||
	github.com/opencontainers/runc v1.1.12 // indirect
 | 
						github.com/opencontainers/runc v1.1.12 // indirect
 | 
				
			||||||
	github.com/pelletier/go-toml/v2 v2.1.1 // indirect
 | 
						github.com/pelletier/go-toml/v2 v2.1.1 // indirect
 | 
				
			||||||
 | 
						github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
 | 
				
			||||||
	github.com/pierrec/lz4/v4 v4.1.21 // indirect
 | 
						github.com/pierrec/lz4/v4 v4.1.21 // indirect
 | 
				
			||||||
	github.com/pkg/errors v0.9.1 // indirect
 | 
						github.com/pkg/errors v0.9.1 // indirect
 | 
				
			||||||
	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 | 
						github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 | 
				
			||||||
@ -161,6 +162,7 @@ require (
 | 
				
			|||||||
	github.com/safchain/ethtool v0.3.0 // indirect
 | 
						github.com/safchain/ethtool v0.3.0 // indirect
 | 
				
			||||||
	github.com/sagikazarmark/locafero v0.4.0 // indirect
 | 
						github.com/sagikazarmark/locafero v0.4.0 // indirect
 | 
				
			||||||
	github.com/sagikazarmark/slog-shim v0.1.0 // indirect
 | 
						github.com/sagikazarmark/slog-shim v0.1.0 // indirect
 | 
				
			||||||
 | 
						github.com/sasha-s/go-deadlock v0.3.1 // indirect
 | 
				
			||||||
	github.com/sirupsen/logrus v1.9.3 // indirect
 | 
						github.com/sirupsen/logrus v1.9.3 // indirect
 | 
				
			||||||
	github.com/sourcegraph/conc v0.3.0 // indirect
 | 
						github.com/sourcegraph/conc v0.3.0 // indirect
 | 
				
			||||||
	github.com/spf13/afero v1.11.0 // indirect
 | 
						github.com/spf13/afero v1.11.0 // indirect
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@ -336,6 +336,8 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
 | 
				
			|||||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
 | 
					github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
 | 
				
			||||||
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
 | 
					github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
 | 
				
			||||||
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
 | 
					github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
 | 
				
			||||||
 | 
					github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
 | 
				
			||||||
 | 
					github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
 | 
				
			||||||
github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA=
 | 
					github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA=
 | 
				
			||||||
github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ=
 | 
					github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ=
 | 
				
			||||||
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
 | 
					github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
 | 
				
			||||||
@ -392,6 +394,8 @@ github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6g
 | 
				
			|||||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
 | 
					github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
 | 
				
			||||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
 | 
					github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
 | 
				
			||||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 | 
					github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 | 
				
			||||||
 | 
					github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0=
 | 
				
			||||||
 | 
					github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM=
 | 
				
			||||||
github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
 | 
					github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
 | 
				
			||||||
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
 | 
					github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
 | 
				
			||||||
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
 | 
					github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
 | 
				
			||||||
 | 
				
			|||||||
@ -28,6 +28,7 @@ import (
 | 
				
			|||||||
	"github.com/juanfont/headscale/hscontrol/db"
 | 
						"github.com/juanfont/headscale/hscontrol/db"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/derp"
 | 
						"github.com/juanfont/headscale/hscontrol/derp"
 | 
				
			||||||
	derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
 | 
						derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/mapper"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/notifier"
 | 
						"github.com/juanfont/headscale/hscontrol/notifier"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/policy"
 | 
						"github.com/juanfont/headscale/hscontrol/policy"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
@ -38,6 +39,7 @@ import (
 | 
				
			|||||||
	"github.com/prometheus/client_golang/prometheus/promhttp"
 | 
						"github.com/prometheus/client_golang/prometheus/promhttp"
 | 
				
			||||||
	zl "github.com/rs/zerolog"
 | 
						zl "github.com/rs/zerolog"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
 | 
						"github.com/sasha-s/go-deadlock"
 | 
				
			||||||
	"golang.org/x/crypto/acme"
 | 
						"golang.org/x/crypto/acme"
 | 
				
			||||||
	"golang.org/x/crypto/acme/autocert"
 | 
						"golang.org/x/crypto/acme/autocert"
 | 
				
			||||||
	"golang.org/x/oauth2"
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
@ -77,6 +79,11 @@ const (
 | 
				
			|||||||
	registerCacheCleanup    = time.Minute * 20
 | 
						registerCacheCleanup    = time.Minute * 20
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						deadlock.Opts.DeadlockTimeout = 15 * time.Second
 | 
				
			||||||
 | 
						deadlock.Opts.PrintAllCurrentGoroutines = true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Headscale represents the base app of the service.
 | 
					// Headscale represents the base app of the service.
 | 
				
			||||||
type Headscale struct {
 | 
					type Headscale struct {
 | 
				
			||||||
	cfg             *types.Config
 | 
						cfg             *types.Config
 | 
				
			||||||
@ -89,6 +96,7 @@ type Headscale struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	ACLPolicy *policy.ACLPolicy
 | 
						ACLPolicy *policy.ACLPolicy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mapper       *mapper.Mapper
 | 
				
			||||||
	nodeNotifier *notifier.Notifier
 | 
						nodeNotifier *notifier.Notifier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oidcProvider *oidc.Provider
 | 
						oidcProvider *oidc.Provider
 | 
				
			||||||
@ -96,8 +104,10 @@ type Headscale struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	registrationCache *cache.Cache
 | 
						registrationCache *cache.Cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	shutdownChan       chan struct{}
 | 
					 | 
				
			||||||
	pollNetMapStreamWG sync.WaitGroup
 | 
						pollNetMapStreamWG sync.WaitGroup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mapSessions  map[types.NodeID]*mapSession
 | 
				
			||||||
 | 
						mapSessionMu deadlock.Mutex
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
@ -129,6 +139,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
		registrationCache:  registrationCache,
 | 
							registrationCache:  registrationCache,
 | 
				
			||||||
		pollNetMapStreamWG: sync.WaitGroup{},
 | 
							pollNetMapStreamWG: sync.WaitGroup{},
 | 
				
			||||||
		nodeNotifier:       notifier.NewNotifier(),
 | 
							nodeNotifier:       notifier.NewNotifier(),
 | 
				
			||||||
 | 
							mapSessions:        make(map[types.NodeID]*mapSession),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	app.db, err = db.NewHeadscaleDatabase(
 | 
						app.db, err = db.NewHeadscaleDatabase(
 | 
				
			||||||
@ -199,16 +210,16 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
 | 
				
			|||||||
	http.Redirect(w, req, target, http.StatusFound)
 | 
						http.Redirect(w, req, target, http.StatusFound)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// expireEphemeralNodes deletes ephemeral node records that have not been
 | 
					// deleteExpireEphemeralNodes deletes ephemeral node records that have not been
 | 
				
			||||||
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
 | 
					// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
 | 
				
			||||||
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
 | 
					func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
 | 
				
			||||||
	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
 | 
						ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var update types.StateUpdate
 | 
					 | 
				
			||||||
	var changed bool
 | 
					 | 
				
			||||||
	for range ticker.C {
 | 
						for range ticker.C {
 | 
				
			||||||
 | 
							var removed []types.NodeID
 | 
				
			||||||
 | 
							var changed []types.NodeID
 | 
				
			||||||
		if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
 | 
							if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
 | 
				
			||||||
			update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
 | 
								removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		}); err != nil {
 | 
							}); err != nil {
 | 
				
			||||||
@ -216,9 +227,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if changed && update.Valid() {
 | 
							if removed != nil {
 | 
				
			||||||
			ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
 | 
								ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
 | 
				
			||||||
			h.nodeNotifier.NotifyAll(ctx, update)
 | 
								h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
									Type:    types.StatePeerRemoved,
 | 
				
			||||||
 | 
									Removed: removed,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if changed != nil {
 | 
				
			||||||
 | 
								ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
 | 
				
			||||||
 | 
								h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
									Type:        types.StatePeerChanged,
 | 
				
			||||||
 | 
									ChangeNodes: changed,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -243,8 +265,9 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
 | 
							if changed {
 | 
				
			||||||
		if changed && update.Valid() {
 | 
								log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
 | 
								ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
 | 
				
			||||||
			h.nodeNotifier.NotifyAll(ctx, update)
 | 
								h.nodeNotifier.NotifyAll(ctx, update)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -272,14 +295,11 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
 | 
				
			|||||||
				h.DERPMap.Regions[region.RegionID] = ®ion
 | 
									h.DERPMap.Regions[region.RegionID] = ®ion
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			stateUpdate := types.StateUpdate{
 | 
								ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
 | 
				
			||||||
 | 
								h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
				Type:    types.StateDERPUpdated,
 | 
									Type:    types.StateDERPUpdated,
 | 
				
			||||||
				DERPMap: h.DERPMap,
 | 
									DERPMap: h.DERPMap,
 | 
				
			||||||
			}
 | 
								})
 | 
				
			||||||
			if stateUpdate.Valid() {
 | 
					 | 
				
			||||||
				ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
 | 
					 | 
				
			||||||
				h.nodeNotifier.NotifyAll(ctx, stateUpdate)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -502,6 +522,7 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Fetch an initial DERP Map before we start serving
 | 
						// Fetch an initial DERP Map before we start serving
 | 
				
			||||||
	h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
 | 
						h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
 | 
				
			||||||
 | 
						h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if h.cfg.DERP.ServerEnabled {
 | 
						if h.cfg.DERP.ServerEnabled {
 | 
				
			||||||
		// When embedded DERP is enabled we always need a STUN server
 | 
							// When embedded DERP is enabled we always need a STUN server
 | 
				
			||||||
@ -533,7 +554,7 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): These should have cancel channels and be cleaned
 | 
						// TODO(kradalby): These should have cancel channels and be cleaned
 | 
				
			||||||
	// up on shutdown.
 | 
						// up on shutdown.
 | 
				
			||||||
	go h.expireEphemeralNodes(updateInterval)
 | 
						go h.deleteExpireEphemeralNodes(updateInterval)
 | 
				
			||||||
	go h.expireExpiredMachines(updateInterval)
 | 
						go h.expireExpiredMachines(updateInterval)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if zl.GlobalLevel() == zl.TraceLevel {
 | 
						if zl.GlobalLevel() == zl.TraceLevel {
 | 
				
			||||||
@ -686,6 +707,9 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
		// no good way to handle streaming timeouts, therefore we need to
 | 
							// no good way to handle streaming timeouts, therefore we need to
 | 
				
			||||||
		// keep this at unlimited and be careful to clean up connections
 | 
							// keep this at unlimited and be careful to clean up connections
 | 
				
			||||||
		// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
 | 
							// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
 | 
				
			||||||
 | 
							// TODO(kradalby): this timeout can now be set per handler with http.ResponseController:
 | 
				
			||||||
 | 
							// https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type
 | 
				
			||||||
 | 
							// replace this so only the longpoller has no timeout.
 | 
				
			||||||
		WriteTimeout: 0,
 | 
							WriteTimeout: 0,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -742,7 +766,6 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Handle common process-killing signals so we can gracefully shut down:
 | 
						// Handle common process-killing signals so we can gracefully shut down:
 | 
				
			||||||
	h.shutdownChan = make(chan struct{})
 | 
					 | 
				
			||||||
	sigc := make(chan os.Signal, 1)
 | 
						sigc := make(chan os.Signal, 1)
 | 
				
			||||||
	signal.Notify(sigc,
 | 
						signal.Notify(sigc,
 | 
				
			||||||
		syscall.SIGHUP,
 | 
							syscall.SIGHUP,
 | 
				
			||||||
@ -785,8 +808,6 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
					Str("signal", sig.String()).
 | 
										Str("signal", sig.String()).
 | 
				
			||||||
					Msg("Received signal to stop, shutting down gracefully")
 | 
										Msg("Received signal to stop, shutting down gracefully")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				close(h.shutdownChan)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				h.pollNetMapStreamWG.Wait()
 | 
									h.pollNetMapStreamWG.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// Gracefully shut down servers
 | 
									// Gracefully shut down servers
 | 
				
			||||||
 | 
				
			|||||||
@ -352,13 +352,8 @@ func (h *Headscale) handleAuthKey(
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		mkey := node.MachineKey
 | 
					 | 
				
			||||||
		update := types.StateUpdateExpire(node.ID, registerRequest.Expiry)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if update.Valid() {
 | 
					 | 
				
			||||||
		ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
 | 
							ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
 | 
				
			||||||
			h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
 | 
							h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, registerRequest.Expiry), node.ID)
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		now := time.Now().UTC()
 | 
							now := time.Now().UTC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -538,11 +533,8 @@ func (h *Headscale) handleNodeLogOut(
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stateUpdate := types.StateUpdateExpire(node.ID, now)
 | 
					 | 
				
			||||||
	if stateUpdate.Valid() {
 | 
					 | 
				
			||||||
	ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
 | 
						ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
 | 
				
			||||||
		h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
 | 
						h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp.AuthURL = ""
 | 
						resp.AuthURL = ""
 | 
				
			||||||
	resp.MachineAuthorized = false
 | 
						resp.MachineAuthorized = false
 | 
				
			||||||
@ -572,7 +564,7 @@ func (h *Headscale) handleNodeLogOut(
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if node.IsEphemeral() {
 | 
						if node.IsEphemeral() {
 | 
				
			||||||
		err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
 | 
							changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Err(err).
 | 
									Err(err).
 | 
				
			||||||
@ -580,13 +572,16 @@ func (h *Headscale) handleNodeLogOut(
 | 
				
			|||||||
				Msg("Cannot delete ephemeral node from the database")
 | 
									Msg("Cannot delete ephemeral node from the database")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		stateUpdate := types.StateUpdate{
 | 
					 | 
				
			||||||
			Type:    types.StatePeerRemoved,
 | 
					 | 
				
			||||||
			Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if stateUpdate.Valid() {
 | 
					 | 
				
			||||||
		ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
 | 
							ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
 | 
				
			||||||
			h.nodeNotifier.NotifyAll(ctx, stateUpdate)
 | 
							h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
								Type:    types.StatePeerRemoved,
 | 
				
			||||||
 | 
								Removed: []types.NodeID{node.ID},
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							if changedNodes != nil {
 | 
				
			||||||
 | 
								h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
									Type:        types.StatePeerChanged,
 | 
				
			||||||
 | 
									ChangeNodes: changedNodes,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
 | 
				
			|||||||
@ -34,27 +34,22 @@ var (
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
 | 
					func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
 | 
				
			||||||
	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
 | 
						return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
 | 
				
			||||||
		return ListPeers(rx, node)
 | 
							return ListPeers(rx, nodeID)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
 | 
					// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
 | 
				
			||||||
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
 | 
					func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Caller().
 | 
					 | 
				
			||||||
		Str("node", node.Hostname).
 | 
					 | 
				
			||||||
		Msg("Finding direct peers")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	nodes := types.Nodes{}
 | 
						nodes := types.Nodes{}
 | 
				
			||||||
	if err := tx.
 | 
						if err := tx.
 | 
				
			||||||
		Preload("AuthKey").
 | 
							Preload("AuthKey").
 | 
				
			||||||
		Preload("AuthKey.User").
 | 
							Preload("AuthKey.User").
 | 
				
			||||||
		Preload("User").
 | 
							Preload("User").
 | 
				
			||||||
		Preload("Routes").
 | 
							Preload("Routes").
 | 
				
			||||||
		Where("node_key <> ?",
 | 
							Where("id <> ?",
 | 
				
			||||||
			node.NodeKey.String()).Find(&nodes).Error; err != nil {
 | 
								nodeID).Find(&nodes).Error; err != nil {
 | 
				
			||||||
		return types.Nodes{}, err
 | 
							return types.Nodes{}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -119,14 +114,14 @@ func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
 | 
				
			|||||||
	return nil, ErrNodeNotFound
 | 
						return nil, ErrNodeNotFound
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
 | 
					func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
 | 
				
			||||||
	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
 | 
						return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
 | 
				
			||||||
		return GetNodeByID(rx, id)
 | 
							return GetNodeByID(rx, id)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetNodeByID finds a Node by ID and returns the Node struct.
 | 
					// GetNodeByID finds a Node by ID and returns the Node struct.
 | 
				
			||||||
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
 | 
					func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
 | 
				
			||||||
	mach := types.Node{}
 | 
						mach := types.Node{}
 | 
				
			||||||
	if result := tx.
 | 
						if result := tx.
 | 
				
			||||||
		Preload("AuthKey").
 | 
							Preload("AuthKey").
 | 
				
			||||||
@ -197,7 +192,7 @@ func GetNodeByAnyKey(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) SetTags(
 | 
					func (hsdb *HSDatabase) SetTags(
 | 
				
			||||||
	nodeID uint64,
 | 
						nodeID types.NodeID,
 | 
				
			||||||
	tags []string,
 | 
						tags []string,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	return hsdb.Write(func(tx *gorm.DB) error {
 | 
						return hsdb.Write(func(tx *gorm.DB) error {
 | 
				
			||||||
@ -208,7 +203,7 @@ func (hsdb *HSDatabase) SetTags(
 | 
				
			|||||||
// SetTags takes a Node struct pointer and update the forced tags.
 | 
					// SetTags takes a Node struct pointer and update the forced tags.
 | 
				
			||||||
func SetTags(
 | 
					func SetTags(
 | 
				
			||||||
	tx *gorm.DB,
 | 
						tx *gorm.DB,
 | 
				
			||||||
	nodeID uint64,
 | 
						nodeID types.NodeID,
 | 
				
			||||||
	tags []string,
 | 
						tags []string,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	if len(tags) == 0 {
 | 
						if len(tags) == 0 {
 | 
				
			||||||
@ -256,7 +251,7 @@ func RenameNode(tx *gorm.DB,
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
 | 
					func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error {
 | 
				
			||||||
	return hsdb.Write(func(tx *gorm.DB) error {
 | 
						return hsdb.Write(func(tx *gorm.DB) error {
 | 
				
			||||||
		return NodeSetExpiry(tx, nodeID, expiry)
 | 
							return NodeSetExpiry(tx, nodeID, expiry)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -264,13 +259,13 @@ func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// NodeSetExpiry takes a Node struct and  a new expiry time.
 | 
					// NodeSetExpiry takes a Node struct and  a new expiry time.
 | 
				
			||||||
func NodeSetExpiry(tx *gorm.DB,
 | 
					func NodeSetExpiry(tx *gorm.DB,
 | 
				
			||||||
	nodeID uint64, expiry time.Time,
 | 
						nodeID types.NodeID, expiry time.Time,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
 | 
						return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
 | 
					func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
 | 
				
			||||||
	return hsdb.Write(func(tx *gorm.DB) error {
 | 
						return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | 
				
			||||||
		return DeleteNode(tx, node, isConnected)
 | 
							return DeleteNode(tx, node, isConnected)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -279,24 +274,24 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.Machine
 | 
				
			|||||||
// Caller is responsible for notifying all of change.
 | 
					// Caller is responsible for notifying all of change.
 | 
				
			||||||
func DeleteNode(tx *gorm.DB,
 | 
					func DeleteNode(tx *gorm.DB,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
) error {
 | 
					) ([]types.NodeID, error) {
 | 
				
			||||||
	err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
 | 
						changed, err := deleteNodeRoutes(tx, node, isConnected)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return changed, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Unscoped causes the node to be fully removed from the database.
 | 
						// Unscoped causes the node to be fully removed from the database.
 | 
				
			||||||
	if err := tx.Unscoped().Delete(&node).Error; err != nil {
 | 
						if err := tx.Unscoped().Delete(&node).Error; err != nil {
 | 
				
			||||||
		return err
 | 
							return changed, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return changed, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UpdateLastSeen sets a node's last seen field indicating that we
 | 
					// SetLastSeen sets a node's last seen field indicating that we
 | 
				
			||||||
// have recently communicating with this node.
 | 
					// have recently communicating with this node.
 | 
				
			||||||
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
 | 
					func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
 | 
				
			||||||
	return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
 | 
						return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -606,7 +601,7 @@ func enableRoutes(tx *gorm.DB,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return &types.StateUpdate{
 | 
						return &types.StateUpdate{
 | 
				
			||||||
		Type:        types.StatePeerChanged,
 | 
							Type:        types.StatePeerChanged,
 | 
				
			||||||
		ChangeNodes: types.Nodes{node},
 | 
							ChangeNodes: []types.NodeID{node.ID},
 | 
				
			||||||
		Message:     "created in db.enableRoutes",
 | 
							Message:     "created in db.enableRoutes",
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -681,17 +676,18 @@ func GenerateGivenName(
 | 
				
			|||||||
	return givenName, nil
 | 
						return givenName, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ExpireEphemeralNodes(tx *gorm.DB,
 | 
					func DeleteExpiredEphemeralNodes(tx *gorm.DB,
 | 
				
			||||||
	inactivityThreshhold time.Duration,
 | 
						inactivityThreshhold time.Duration,
 | 
				
			||||||
) (types.StateUpdate, bool) {
 | 
					) ([]types.NodeID, []types.NodeID) {
 | 
				
			||||||
	users, err := ListUsers(tx)
 | 
						users, err := ListUsers(tx)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().Err(err).Msg("Error listing users")
 | 
							log.Error().Err(err).Msg("Error listing users")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return types.StateUpdate{}, false
 | 
							return nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	expired := make([]tailcfg.NodeID, 0)
 | 
						var expired []types.NodeID
 | 
				
			||||||
 | 
						var changedNodes []types.NodeID
 | 
				
			||||||
	for _, user := range users {
 | 
						for _, user := range users {
 | 
				
			||||||
		nodes, err := ListNodesByUser(tx, user.Name)
 | 
							nodes, err := ListNodesByUser(tx, user.Name)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@ -700,40 +696,36 @@ func ExpireEphemeralNodes(tx *gorm.DB,
 | 
				
			|||||||
				Str("user", user.Name).
 | 
									Str("user", user.Name).
 | 
				
			||||||
				Msg("Error listing nodes in user")
 | 
									Msg("Error listing nodes in user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return types.StateUpdate{}, false
 | 
								return nil, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for idx, node := range nodes {
 | 
							for idx, node := range nodes {
 | 
				
			||||||
			if node.IsEphemeral() && node.LastSeen != nil &&
 | 
								if node.IsEphemeral() && node.LastSeen != nil &&
 | 
				
			||||||
				time.Now().
 | 
									time.Now().
 | 
				
			||||||
					After(node.LastSeen.Add(inactivityThreshhold)) {
 | 
										After(node.LastSeen.Add(inactivityThreshhold)) {
 | 
				
			||||||
				expired = append(expired, tailcfg.NodeID(node.ID))
 | 
									expired = append(expired, node.ID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				log.Info().
 | 
									log.Info().
 | 
				
			||||||
					Str("node", node.Hostname).
 | 
										Str("node", node.Hostname).
 | 
				
			||||||
					Msg("Ephemeral client removed from database")
 | 
										Msg("Ephemeral client removed from database")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					// empty isConnected map as ephemeral nodes are not routes
 | 
										// empty isConnected map as ephemeral nodes are not routes
 | 
				
			||||||
				err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
 | 
									changed, err := DeleteNode(tx, nodes[idx], nil)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					log.Error().
 | 
										log.Error().
 | 
				
			||||||
						Err(err).
 | 
											Err(err).
 | 
				
			||||||
						Str("node", node.Hostname).
 | 
											Str("node", node.Hostname).
 | 
				
			||||||
						Msg("🤮 Cannot delete ephemeral node from the database")
 | 
											Msg("🤮 Cannot delete ephemeral node from the database")
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									changedNodes = append(changedNodes, changed...)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// TODO(kradalby): needs to be moved out of transaction
 | 
							// TODO(kradalby): needs to be moved out of transaction
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(expired) > 0 {
 | 
					 | 
				
			||||||
		return types.StateUpdate{
 | 
					 | 
				
			||||||
			Type:    types.StatePeerRemoved,
 | 
					 | 
				
			||||||
			Removed: expired,
 | 
					 | 
				
			||||||
		}, true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return types.StateUpdate{}, false
 | 
						return expired, changedNodes
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ExpireExpiredNodes(tx *gorm.DB,
 | 
					func ExpireExpiredNodes(tx *gorm.DB,
 | 
				
			||||||
@ -754,35 +746,12 @@ func ExpireExpiredNodes(tx *gorm.DB,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		return time.Unix(0, 0), types.StateUpdate{}, false
 | 
							return time.Unix(0, 0), types.StateUpdate{}, false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for index, node := range nodes {
 | 
						for _, node := range nodes {
 | 
				
			||||||
		if node.IsExpired() &&
 | 
							if node.IsExpired() && node.Expiry.After(lastCheck) {
 | 
				
			||||||
			// TODO(kradalby): Replace this, it is very spammy
 | 
					 | 
				
			||||||
			// It will notify about all nodes that has been expired.
 | 
					 | 
				
			||||||
			// It should only notify about expired nodes since _last check_.
 | 
					 | 
				
			||||||
			node.Expiry.After(lastCheck) {
 | 
					 | 
				
			||||||
			expired = append(expired, &tailcfg.PeerChange{
 | 
								expired = append(expired, &tailcfg.PeerChange{
 | 
				
			||||||
				NodeID:    tailcfg.NodeID(node.ID),
 | 
									NodeID:    tailcfg.NodeID(node.ID),
 | 
				
			||||||
				KeyExpiry: node.Expiry,
 | 
									KeyExpiry: node.Expiry,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
 | 
					 | 
				
			||||||
			now := time.Now()
 | 
					 | 
				
			||||||
			// Do not use setNodeExpiry as that has a notifier hook, which
 | 
					 | 
				
			||||||
			// can cause a deadlock, we are updating all changed nodes later
 | 
					 | 
				
			||||||
			// and there is no point in notifiying twice.
 | 
					 | 
				
			||||||
			if err := tx.Model(&nodes[index]).Updates(types.Node{
 | 
					 | 
				
			||||||
				Expiry: &now,
 | 
					 | 
				
			||||||
			}).Error; err != nil {
 | 
					 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Err(err).
 | 
					 | 
				
			||||||
					Str("node", node.Hostname).
 | 
					 | 
				
			||||||
					Str("name", node.GivenName).
 | 
					 | 
				
			||||||
					Msg("🤮 Cannot expire node")
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				log.Info().
 | 
					 | 
				
			||||||
					Str("node", node.Hostname).
 | 
					 | 
				
			||||||
					Str("name", node.GivenName).
 | 
					 | 
				
			||||||
					Msg("Node successfully expired")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -120,7 +120,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	db.DB.Save(&node)
 | 
						db.DB.Save(&node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
 | 
						_, err = db.DeleteNode(&node, types.NodeConnectedMap{})
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err = db.getNode(user.Name, "testnode3")
 | 
						_, err = db.getNode(user.Name, "testnode3")
 | 
				
			||||||
@ -142,7 +142,7 @@ func (s *Suite) TestListPeers(c *check.C) {
 | 
				
			|||||||
		machineKey := key.NewMachine()
 | 
							machineKey := key.NewMachine()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		node := types.Node{
 | 
							node := types.Node{
 | 
				
			||||||
			ID:             uint64(index),
 | 
								ID:             types.NodeID(index),
 | 
				
			||||||
			MachineKey:     machineKey.Public(),
 | 
								MachineKey:     machineKey.Public(),
 | 
				
			||||||
			NodeKey:        nodeKey.Public(),
 | 
								NodeKey:        nodeKey.Public(),
 | 
				
			||||||
			Hostname:       "testnode" + strconv.Itoa(index),
 | 
								Hostname:       "testnode" + strconv.Itoa(index),
 | 
				
			||||||
@ -156,7 +156,7 @@ func (s *Suite) TestListPeers(c *check.C) {
 | 
				
			|||||||
	node0ByID, err := db.GetNodeByID(0)
 | 
						node0ByID, err := db.GetNodeByID(0)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	peersOfNode0, err := db.ListPeers(node0ByID)
 | 
						peersOfNode0, err := db.ListPeers(node0ByID.ID)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.Assert(len(peersOfNode0), check.Equals, 9)
 | 
						c.Assert(len(peersOfNode0), check.Equals, 9)
 | 
				
			||||||
@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
 | 
				
			|||||||
		machineKey := key.NewMachine()
 | 
							machineKey := key.NewMachine()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		node := types.Node{
 | 
							node := types.Node{
 | 
				
			||||||
			ID:         uint64(index),
 | 
								ID:         types.NodeID(index),
 | 
				
			||||||
			MachineKey: machineKey.Public(),
 | 
								MachineKey: machineKey.Public(),
 | 
				
			||||||
			NodeKey:    nodeKey.Public(),
 | 
								NodeKey:    nodeKey.Public(),
 | 
				
			||||||
			IPAddresses: types.NodeAddresses{
 | 
								IPAddresses: types.NodeAddresses{
 | 
				
			||||||
@ -232,16 +232,16 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
 | 
				
			|||||||
	c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
 | 
						c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	adminPeers, err := db.ListPeers(adminNode)
 | 
						adminPeers, err := db.ListPeers(adminNode.ID)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testPeers, err := db.ListPeers(testNode)
 | 
						testPeers, err := db.ListPeers(testNode.ID)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminNode, adminPeers)
 | 
						adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testNode, testPeers)
 | 
						testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
 | 
						peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
 | 
				
			||||||
@ -586,7 +586,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
 | 
				
			|||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): Check state update
 | 
						// TODO(kradalby): Check state update
 | 
				
			||||||
	_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
 | 
						err = db.EnableAutoApprovedRoutes(pol, node0ByID)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	enabledRoutes, err := db.GetEnabledRoutes(node0ByID)
 | 
						enabledRoutes, err := db.GetEnabledRoutes(node0ByID)
 | 
				
			||||||
 | 
				
			|||||||
@ -92,10 +92,6 @@ func CreatePreAuthKey(
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return &key, nil
 | 
						return &key, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -148,7 +148,7 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) {
 | 
				
			|||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db.DB.Transaction(func(tx *gorm.DB) error {
 | 
						db.DB.Transaction(func(tx *gorm.DB) error {
 | 
				
			||||||
		ExpireEphemeralNodes(tx, time.Second*20)
 | 
							DeleteExpiredEphemeralNodes(tx, time.Second*20)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -182,7 +182,7 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
 | 
				
			|||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db.DB.Transaction(func(tx *gorm.DB) error {
 | 
						db.DB.Transaction(func(tx *gorm.DB) error {
 | 
				
			||||||
		ExpireEphemeralNodes(tx, time.Second*20)
 | 
							DeleteExpiredEphemeralNodes(tx, time.Second*20)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,6 @@ import (
 | 
				
			|||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"tailscale.com/types/key"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
 | 
					var ErrRouteIsNotAvailable = errors.New("route is not available")
 | 
				
			||||||
@ -124,8 +123,8 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func DisableRoute(tx *gorm.DB,
 | 
					func DisableRoute(tx *gorm.DB,
 | 
				
			||||||
	id uint64,
 | 
						id uint64,
 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					) ([]types.NodeID, error) {
 | 
				
			||||||
	route, err := GetRoute(tx, id)
 | 
						route, err := GetRoute(tx, id)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@ -137,16 +136,15 @@ func DisableRoute(tx *gorm.DB,
 | 
				
			|||||||
	// Tailscale requires both IPv4 and IPv6 exit routes to
 | 
						// Tailscale requires both IPv4 and IPv6 exit routes to
 | 
				
			||||||
	// be enabled at the same time, as per
 | 
						// be enabled at the same time, as per
 | 
				
			||||||
	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
 | 
						// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
 | 
				
			||||||
	var update *types.StateUpdate
 | 
						var update []types.NodeID
 | 
				
			||||||
	if !route.IsExitRoute() {
 | 
						if !route.IsExitRoute() {
 | 
				
			||||||
		update, err = failoverRouteReturnUpdate(tx, isConnected, route)
 | 
							route.Enabled = false
 | 
				
			||||||
 | 
							err = tx.Save(route).Error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		route.Enabled = false
 | 
							update, err = failoverRouteTx(tx, isConnected, route)
 | 
				
			||||||
		route.IsPrimary = false
 | 
					 | 
				
			||||||
		err = tx.Save(route).Error
 | 
					 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -160,6 +158,7 @@ func DisableRoute(tx *gorm.DB,
 | 
				
			|||||||
			if routes[i].IsExitRoute() {
 | 
								if routes[i].IsExitRoute() {
 | 
				
			||||||
				routes[i].Enabled = false
 | 
									routes[i].Enabled = false
 | 
				
			||||||
				routes[i].IsPrimary = false
 | 
									routes[i].IsPrimary = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				err = tx.Save(&routes[i]).Error
 | 
									err = tx.Save(&routes[i]).Error
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					return nil, err
 | 
										return nil, err
 | 
				
			||||||
@ -168,26 +167,11 @@ func DisableRoute(tx *gorm.DB,
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if routes == nil {
 | 
					 | 
				
			||||||
		routes, err = GetNodeRoutes(tx, &node)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	node.Routes = routes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If update is empty, it means that one was not created
 | 
						// If update is empty, it means that one was not created
 | 
				
			||||||
	// by failover (as a failover was not necessary), create
 | 
						// by failover (as a failover was not necessary), create
 | 
				
			||||||
	// one and return to the caller.
 | 
						// one and return to the caller.
 | 
				
			||||||
	if update == nil {
 | 
						if update == nil {
 | 
				
			||||||
		update = &types.StateUpdate{
 | 
							update = []types.NodeID{node.ID}
 | 
				
			||||||
			Type: types.StatePeerChanged,
 | 
					 | 
				
			||||||
			ChangeNodes: types.Nodes{
 | 
					 | 
				
			||||||
				&node,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			Message: "called from db.DisableRoute",
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return update, nil
 | 
						return update, nil
 | 
				
			||||||
@ -195,9 +179,9 @@ func DisableRoute(tx *gorm.DB,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) DeleteRoute(
 | 
					func (hsdb *HSDatabase) DeleteRoute(
 | 
				
			||||||
	id uint64,
 | 
						id uint64,
 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					) ([]types.NodeID, error) {
 | 
				
			||||||
	return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | 
						return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | 
				
			||||||
		return DeleteRoute(tx, id, isConnected)
 | 
							return DeleteRoute(tx, id, isConnected)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -205,8 +189,8 @@ func (hsdb *HSDatabase) DeleteRoute(
 | 
				
			|||||||
func DeleteRoute(
 | 
					func DeleteRoute(
 | 
				
			||||||
	tx *gorm.DB,
 | 
						tx *gorm.DB,
 | 
				
			||||||
	id uint64,
 | 
						id uint64,
 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					) ([]types.NodeID, error) {
 | 
				
			||||||
	route, err := GetRoute(tx, id)
 | 
						route, err := GetRoute(tx, id)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@ -218,9 +202,9 @@ func DeleteRoute(
 | 
				
			|||||||
	// Tailscale requires both IPv4 and IPv6 exit routes to
 | 
						// Tailscale requires both IPv4 and IPv6 exit routes to
 | 
				
			||||||
	// be enabled at the same time, as per
 | 
						// be enabled at the same time, as per
 | 
				
			||||||
	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
 | 
						// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
 | 
				
			||||||
	var update *types.StateUpdate
 | 
						var update []types.NodeID
 | 
				
			||||||
	if !route.IsExitRoute() {
 | 
						if !route.IsExitRoute() {
 | 
				
			||||||
		update, err = failoverRouteReturnUpdate(tx, isConnected, route)
 | 
							update, err = failoverRouteTx(tx, isConnected, route)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, nil
 | 
								return nil, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -229,7 +213,7 @@ func DeleteRoute(
 | 
				
			|||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		routes, err := GetNodeRoutes(tx, &node)
 | 
							routes, err = GetNodeRoutes(tx, &node)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -259,35 +243,37 @@ func DeleteRoute(
 | 
				
			|||||||
	node.Routes = routes
 | 
						node.Routes = routes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if update == nil {
 | 
						if update == nil {
 | 
				
			||||||
		update = &types.StateUpdate{
 | 
							update = []types.NodeID{node.ID}
 | 
				
			||||||
			Type: types.StatePeerChanged,
 | 
					 | 
				
			||||||
			ChangeNodes: types.Nodes{
 | 
					 | 
				
			||||||
				&node,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			Message: "called from db.DeleteRoute",
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return update, nil
 | 
						return update, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
 | 
					func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
 | 
				
			||||||
	routes, err := GetNodeRoutes(tx, node)
 | 
						routes, err := GetNodeRoutes(tx, node)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var changed []types.NodeID
 | 
				
			||||||
	for i := range routes {
 | 
						for i := range routes {
 | 
				
			||||||
		if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
 | 
							if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
 | 
				
			||||||
			return err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// TODO(kradalby): This is a bit too aggressive, we could probably
 | 
							// TODO(kradalby): This is a bit too aggressive, we could probably
 | 
				
			||||||
		// figure out which routes needs to be failed over rather than all.
 | 
							// figure out which routes needs to be failed over rather than all.
 | 
				
			||||||
		failoverRouteReturnUpdate(tx, isConnected, &routes[i])
 | 
							chn, err := failoverRouteTx(tx, isConnected, &routes[i])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return changed, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
							if chn != nil {
 | 
				
			||||||
 | 
								changed = append(changed, chn...)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return changed, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// isUniquePrefix returns if there is another node providing the same route already.
 | 
					// isUniquePrefix returns if there is another node providing the same route already.
 | 
				
			||||||
@ -400,7 +386,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
 | 
				
			|||||||
	for prefix, exists := range advertisedRoutes {
 | 
						for prefix, exists := range advertisedRoutes {
 | 
				
			||||||
		if !exists {
 | 
							if !exists {
 | 
				
			||||||
			route := types.Route{
 | 
								route := types.Route{
 | 
				
			||||||
				NodeID:     node.ID,
 | 
									NodeID:     node.ID.Uint64(),
 | 
				
			||||||
				Prefix:     types.IPPrefix(prefix),
 | 
									Prefix:     types.IPPrefix(prefix),
 | 
				
			||||||
				Advertised: true,
 | 
									Advertised: true,
 | 
				
			||||||
				Enabled:    false,
 | 
									Enabled:    false,
 | 
				
			||||||
@ -415,19 +401,23 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
 | 
				
			|||||||
	return sendUpdate, nil
 | 
						return sendUpdate, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
 | 
					// FailoverRouteIfAvailable takes a node and checks if the node's route
 | 
				
			||||||
// currently have a functioning host that exposes the network.
 | 
					// currently have a functioning host that exposes the network.
 | 
				
			||||||
func EnsureFailoverRouteIsAvailable(
 | 
					// If it does not, it is failed over to another suitable route if there
 | 
				
			||||||
 | 
					// is one.
 | 
				
			||||||
 | 
					func FailoverRouteIfAvailable(
 | 
				
			||||||
	tx *gorm.DB,
 | 
						tx *gorm.DB,
 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					) (*types.StateUpdate, error) {
 | 
				
			||||||
 | 
						log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Msgf("ROUTE DEBUG ENTERED FAILOVER")
 | 
				
			||||||
	nodeRoutes, err := GetNodeRoutes(tx, node)
 | 
						nodeRoutes, err := GetNodeRoutes(tx, node)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("nodeRoutes", nodeRoutes).Msgf("ROUTE DEBUG NO ROUTES")
 | 
				
			||||||
		return nil, nil
 | 
							return nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var changedNodes types.Nodes
 | 
						var changedNodes []types.NodeID
 | 
				
			||||||
	for _, nodeRoute := range nodeRoutes {
 | 
						for _, nodeRoute := range nodeRoutes {
 | 
				
			||||||
		routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
 | 
							routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@ -438,71 +428,39 @@ func EnsureFailoverRouteIsAvailable(
 | 
				
			|||||||
			if route.IsPrimary {
 | 
								if route.IsPrimary {
 | 
				
			||||||
				// if we have a primary route, and the node is connected
 | 
									// if we have a primary route, and the node is connected
 | 
				
			||||||
				// nothing needs to be done.
 | 
									// nothing needs to be done.
 | 
				
			||||||
				if isConnected[route.Node.MachineKey] {
 | 
									log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG CHECKING IF ONLINE")
 | 
				
			||||||
					continue
 | 
									if isConnected[route.Node.ID] {
 | 
				
			||||||
 | 
										log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG IS ONLINE")
 | 
				
			||||||
 | 
										return nil, nil
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG NOT ONLINE, FAILING OVER")
 | 
				
			||||||
				// if not, we need to failover the route
 | 
									// if not, we need to failover the route
 | 
				
			||||||
				update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
 | 
									changedIDs, err := failoverRouteTx(tx, isConnected, &route)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					return nil, err
 | 
										return nil, err
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if update != nil {
 | 
									if changedIDs != nil {
 | 
				
			||||||
					changedNodes = append(changedNodes, update.ChangeNodes...)
 | 
										changedNodes = append(changedNodes, changedIDs...)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("changedNodes", changedNodes).Msgf("ROUTE DEBUG")
 | 
				
			||||||
	if len(changedNodes) != 0 {
 | 
						if len(changedNodes) != 0 {
 | 
				
			||||||
		return &types.StateUpdate{
 | 
							return &types.StateUpdate{
 | 
				
			||||||
			Type:        types.StatePeerChanged,
 | 
								Type:        types.StatePeerChanged,
 | 
				
			||||||
			ChangeNodes: changedNodes,
 | 
								ChangeNodes: changedNodes,
 | 
				
			||||||
			Message:     "called from db.EnsureFailoverRouteIsAvailable",
 | 
								Message:     "called from db.FailoverRouteIfAvailable",
 | 
				
			||||||
		}, nil
 | 
							}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil, nil
 | 
						return nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func failoverRouteReturnUpdate(
 | 
					// failoverRouteTx takes a route that is no longer available,
 | 
				
			||||||
	tx *gorm.DB,
 | 
					 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
					 | 
				
			||||||
	r *types.Route,
 | 
					 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					 | 
				
			||||||
	changedKeys, err := failoverRoute(tx, isConnected, r)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Interface("isConnected", isConnected).
 | 
					 | 
				
			||||||
		Interface("changedKeys", changedKeys).
 | 
					 | 
				
			||||||
		Msg("building route failover")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(changedKeys) == 0 {
 | 
					 | 
				
			||||||
		return nil, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var nodes types.Nodes
 | 
					 | 
				
			||||||
	for _, key := range changedKeys {
 | 
					 | 
				
			||||||
		node, err := GetNodeByMachineKey(tx, key)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		nodes = append(nodes, node)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return &types.StateUpdate{
 | 
					 | 
				
			||||||
		Type:        types.StatePeerChanged,
 | 
					 | 
				
			||||||
		ChangeNodes: nodes,
 | 
					 | 
				
			||||||
		Message:     "called from db.failoverRouteReturnUpdate",
 | 
					 | 
				
			||||||
	}, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// failoverRoute takes a route that is no longer available,
 | 
					 | 
				
			||||||
// this can be either from:
 | 
					// this can be either from:
 | 
				
			||||||
// - being disabled
 | 
					// - being disabled
 | 
				
			||||||
// - being deleted
 | 
					// - being deleted
 | 
				
			||||||
@ -510,11 +468,11 @@ func failoverRouteReturnUpdate(
 | 
				
			|||||||
//
 | 
					//
 | 
				
			||||||
// and tries to find a new route to take over its place.
 | 
					// and tries to find a new route to take over its place.
 | 
				
			||||||
// If the given route was not primary, it returns early.
 | 
					// If the given route was not primary, it returns early.
 | 
				
			||||||
func failoverRoute(
 | 
					func failoverRouteTx(
 | 
				
			||||||
	tx *gorm.DB,
 | 
						tx *gorm.DB,
 | 
				
			||||||
	isConnected map[key.MachinePublic]bool,
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
	r *types.Route,
 | 
						r *types.Route,
 | 
				
			||||||
) ([]key.MachinePublic, error) {
 | 
					) ([]types.NodeID, error) {
 | 
				
			||||||
	if r == nil {
 | 
						if r == nil {
 | 
				
			||||||
		return nil, nil
 | 
							return nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -535,11 +493,64 @@ func failoverRoute(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fo := failoverRoute(isConnected, r, routes)
 | 
				
			||||||
 | 
						if fo == nil {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = tx.Save(fo.old).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().Err(err).Msg("disabling old primary route")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = tx.Save(fo.new).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().Err(err).Msg("saving new primary route")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("hostname", fo.new.Node.Hostname).
 | 
				
			||||||
 | 
							Msgf("set primary to new route, was: id(%d), host(%s), now: id(%d), host(%s)", fo.old.ID, fo.old.Node.Hostname, fo.new.ID, fo.new.Node.Hostname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return a list of the machinekeys of the changed nodes.
 | 
				
			||||||
 | 
						return []types.NodeID{fo.old.Node.ID, fo.new.Node.ID}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type failover struct {
 | 
				
			||||||
 | 
						old *types.Route
 | 
				
			||||||
 | 
						new *types.Route
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func failoverRoute(
 | 
				
			||||||
 | 
						isConnected types.NodeConnectedMap,
 | 
				
			||||||
 | 
						routeToReplace *types.Route,
 | 
				
			||||||
 | 
						altRoutes types.Routes,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					) *failover {
 | 
				
			||||||
 | 
						if routeToReplace == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// This route is not a primary route, and it is not
 | 
				
			||||||
 | 
						// being served to nodes.
 | 
				
			||||||
 | 
						if !routeToReplace.IsPrimary {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// We do not have to failover exit nodes
 | 
				
			||||||
 | 
						if routeToReplace.IsExitRoute() {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var newPrimary *types.Route
 | 
						var newPrimary *types.Route
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Find a new suitable route
 | 
						// Find a new suitable route
 | 
				
			||||||
	for idx, route := range routes {
 | 
						for idx, route := range altRoutes {
 | 
				
			||||||
		if r.ID == route.ID {
 | 
							if routeToReplace.ID == route.ID {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -547,8 +558,8 @@ func failoverRoute(
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if isConnected[route.Node.MachineKey] {
 | 
							if isConnected != nil && isConnected[route.Node.ID] {
 | 
				
			||||||
			newPrimary = &routes[idx]
 | 
								newPrimary = &altRoutes[idx]
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -559,48 +570,23 @@ func failoverRoute(
 | 
				
			|||||||
	// the one currently marked as primary is the
 | 
						// the one currently marked as primary is the
 | 
				
			||||||
	// best we got.
 | 
						// best we got.
 | 
				
			||||||
	if newPrimary == nil {
 | 
						if newPrimary == nil {
 | 
				
			||||||
		return nil, nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						routeToReplace.IsPrimary = false
 | 
				
			||||||
		Str("hostname", newPrimary.Node.Hostname).
 | 
					 | 
				
			||||||
		Msg("found new primary, updating db")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Remove primary from the old route
 | 
					 | 
				
			||||||
	r.IsPrimary = false
 | 
					 | 
				
			||||||
	err = tx.Save(&r).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Error().Err(err).Msg("error disabling new primary route")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("hostname", newPrimary.Node.Hostname).
 | 
					 | 
				
			||||||
		Msg("removed primary from old route")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Set primary for the new primary
 | 
					 | 
				
			||||||
	newPrimary.IsPrimary = true
 | 
						newPrimary.IsPrimary = true
 | 
				
			||||||
	err = tx.Save(&newPrimary).Error
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Error().Err(err).Msg("error enabling new primary route")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, err
 | 
						return &failover{
 | 
				
			||||||
 | 
							old: routeToReplace,
 | 
				
			||||||
 | 
							new: newPrimary,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("hostname", newPrimary.Node.Hostname).
 | 
					 | 
				
			||||||
		Msg("set primary to new route")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Return a list of the machinekeys of the changed nodes.
 | 
					 | 
				
			||||||
	return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 | 
					func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 | 
				
			||||||
	aclPolicy *policy.ACLPolicy,
 | 
						aclPolicy *policy.ACLPolicy,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					) error {
 | 
				
			||||||
	return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | 
						return hsdb.Write(func(tx *gorm.DB) error {
 | 
				
			||||||
		return EnableAutoApprovedRoutes(tx, aclPolicy, node)
 | 
							return EnableAutoApprovedRoutes(tx, aclPolicy, node)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -610,9 +596,9 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
	tx *gorm.DB,
 | 
						tx *gorm.DB,
 | 
				
			||||||
	aclPolicy *policy.ACLPolicy,
 | 
						aclPolicy *policy.ACLPolicy,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
) (*types.StateUpdate, error) {
 | 
					) error {
 | 
				
			||||||
	if len(node.IPAddresses) == 0 {
 | 
						if len(node.IPAddresses) == 0 {
 | 
				
			||||||
		return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
 | 
							return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	routes, err := GetNodeAdvertisedRoutes(tx, node)
 | 
						routes, err := GetNodeAdvertisedRoutes(tx, node)
 | 
				
			||||||
@ -623,7 +609,7 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
			Str("node", node.Hostname).
 | 
								Str("node", node.Hostname).
 | 
				
			||||||
			Msg("Could not get advertised routes for node")
 | 
								Msg("Could not get advertised routes for node")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
 | 
						log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
 | 
				
			||||||
@ -641,10 +627,10 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Err(err).
 | 
								log.Err(err).
 | 
				
			||||||
				Str("advertisedRoute", advertisedRoute.String()).
 | 
									Str("advertisedRoute", advertisedRoute.String()).
 | 
				
			||||||
				Uint64("nodeId", node.ID).
 | 
									Uint64("nodeId", node.ID.Uint64()).
 | 
				
			||||||
				Msg("Failed to resolve autoApprovers for advertised route")
 | 
									Msg("Failed to resolve autoApprovers for advertised route")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return nil, err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Trace().
 | 
							log.Trace().
 | 
				
			||||||
@ -665,7 +651,7 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
						Str("alias", approvedAlias).
 | 
											Str("alias", approvedAlias).
 | 
				
			||||||
						Msg("Failed to expand alias when processing autoApprovers policy")
 | 
											Msg("Failed to expand alias when processing autoApprovers policy")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					return nil, err
 | 
										return err
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
 | 
									// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
 | 
				
			||||||
@ -676,25 +662,17 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	update := &types.StateUpdate{
 | 
					 | 
				
			||||||
		Type:        types.StatePeerChanged,
 | 
					 | 
				
			||||||
		ChangeNodes: types.Nodes{},
 | 
					 | 
				
			||||||
		Message:     "created in db.EnableAutoApprovedRoutes",
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, approvedRoute := range approvedRoutes {
 | 
						for _, approvedRoute := range approvedRoutes {
 | 
				
			||||||
		perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
 | 
							_, err := EnableRoute(tx, uint64(approvedRoute.ID))
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Err(err).
 | 
								log.Err(err).
 | 
				
			||||||
				Str("approvedRoute", approvedRoute.String()).
 | 
									Str("approvedRoute", approvedRoute.String()).
 | 
				
			||||||
				Uint64("nodeId", node.ID).
 | 
									Uint64("nodeId", node.ID.Uint64()).
 | 
				
			||||||
				Msg("Failed to enable approved route")
 | 
									Msg("Failed to enable approved route")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return nil, err
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
 | 
						return nil
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return update, nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -13,7 +13,6 @@ import (
 | 
				
			|||||||
	"gopkg.in/check.v1"
 | 
						"gopkg.in/check.v1"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
	"tailscale.com/types/key"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *Suite) TestGetRoutes(c *check.C) {
 | 
					func (s *Suite) TestGetRoutes(c *check.C) {
 | 
				
			||||||
@ -262,7 +261,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
 | 
				
			|||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): check stateupdate
 | 
						// TODO(kradalby): check stateupdate
 | 
				
			||||||
	_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
 | 
						_, err = db.DeleteRoute(uint64(routes[0].ID), nil)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	enabledRoutes1, err := db.GetEnabledRoutes(&node1)
 | 
						enabledRoutes1, err := db.GetEnabledRoutes(&node1)
 | 
				
			||||||
@ -272,20 +271,13 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
 | 
					var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestFailoverRoute(t *testing.T) {
 | 
					func TestFailoverRouteTx(t *testing.T) {
 | 
				
			||||||
	machineKeys := []key.MachinePublic{
 | 
					 | 
				
			||||||
		key.NewMachine().Public(),
 | 
					 | 
				
			||||||
		key.NewMachine().Public(),
 | 
					 | 
				
			||||||
		key.NewMachine().Public(),
 | 
					 | 
				
			||||||
		key.NewMachine().Public(),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
						tests := []struct {
 | 
				
			||||||
		name         string
 | 
							name         string
 | 
				
			||||||
		failingRoute types.Route
 | 
							failingRoute types.Route
 | 
				
			||||||
		routes       types.Routes
 | 
							routes       types.Routes
 | 
				
			||||||
		isConnected  map[key.MachinePublic]bool
 | 
							isConnected  types.NodeConnectedMap
 | 
				
			||||||
		want         []key.MachinePublic
 | 
							want         []types.NodeID
 | 
				
			||||||
		wantErr      bool
 | 
							wantErr      bool
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
@ -302,9 +294,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					ID: 1,
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix:    ipp("10.0.0.0/24"),
 | 
									Prefix:    ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node:      types.Node{},
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				IsPrimary: false,
 | 
									IsPrimary: false,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			routes:  types.Routes{},
 | 
								routes:  types.Routes{},
 | 
				
			||||||
@ -318,9 +308,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					ID: 1,
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix:    ipp("0.0.0.0/0"),
 | 
									Prefix:    ipp("0.0.0.0/0"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node:      types.Node{},
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			routes:  types.Routes{},
 | 
								routes:  types.Routes{},
 | 
				
			||||||
@ -335,7 +323,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
@ -346,7 +334,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
@ -362,7 +350,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
				Enabled:   true,
 | 
									Enabled:   true,
 | 
				
			||||||
@ -374,7 +362,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -385,19 +373,19 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[1],
 | 
											ID: 2,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			isConnected: map[key.MachinePublic]bool{
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
				machineKeys[0]: false,
 | 
									1: false,
 | 
				
			||||||
				machineKeys[1]: true,
 | 
									2: true,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: []key.MachinePublic{
 | 
								want: []types.NodeID{
 | 
				
			||||||
				machineKeys[0],
 | 
									1,
 | 
				
			||||||
				machineKeys[1],
 | 
									2,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			wantErr: false,
 | 
								wantErr: false,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
@ -409,7 +397,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: false,
 | 
									IsPrimary: false,
 | 
				
			||||||
				Enabled:   true,
 | 
									Enabled:   true,
 | 
				
			||||||
@ -421,7 +409,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -432,7 +420,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[1],
 | 
											ID: 2,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -449,7 +437,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[1],
 | 
										ID: 2,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
				Enabled:   true,
 | 
									Enabled:   true,
 | 
				
			||||||
@ -461,7 +449,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -472,7 +460,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[1],
 | 
											ID: 2,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -483,20 +471,19 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[2],
 | 
											ID: 3,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			isConnected: map[key.MachinePublic]bool{
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
				machineKeys[0]: true,
 | 
									1: true,
 | 
				
			||||||
				machineKeys[1]: true,
 | 
									2: true,
 | 
				
			||||||
				machineKeys[2]: true,
 | 
									3: true,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: []key.MachinePublic{
 | 
								want: []types.NodeID{
 | 
				
			||||||
				machineKeys[1],
 | 
									2, 1,
 | 
				
			||||||
				machineKeys[0],
 | 
					 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			wantErr: false,
 | 
								wantErr: false,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
@ -508,7 +495,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
				Enabled:   true,
 | 
									Enabled:   true,
 | 
				
			||||||
@ -520,7 +507,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -532,15 +519,15 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[3],
 | 
											ID: 4,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			isConnected: map[key.MachinePublic]bool{
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
				machineKeys[0]: true,
 | 
									1: true,
 | 
				
			||||||
				machineKeys[3]: false,
 | 
									4: false,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want:    nil,
 | 
								want:    nil,
 | 
				
			||||||
			wantErr: false,
 | 
								wantErr: false,
 | 
				
			||||||
@ -553,7 +540,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
				Enabled:   true,
 | 
									Enabled:   true,
 | 
				
			||||||
@ -565,7 +552,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -577,7 +564,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[3],
 | 
											ID: 4,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -588,20 +575,20 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[1],
 | 
											ID: 2,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			isConnected: map[key.MachinePublic]bool{
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
				machineKeys[0]: false,
 | 
									1: false,
 | 
				
			||||||
				machineKeys[1]: true,
 | 
									2: true,
 | 
				
			||||||
				machineKeys[3]: false,
 | 
									4: false,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: []key.MachinePublic{
 | 
								want: []types.NodeID{
 | 
				
			||||||
				machineKeys[0],
 | 
									1,
 | 
				
			||||||
				machineKeys[1],
 | 
									2,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			wantErr: false,
 | 
								wantErr: false,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
@ -613,7 +600,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
				Prefix: ipp("10.0.0.0/24"),
 | 
									Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
				Node: types.Node{
 | 
									Node: types.Node{
 | 
				
			||||||
					MachineKey: machineKeys[0],
 | 
										ID: 1,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				IsPrimary: true,
 | 
									IsPrimary: true,
 | 
				
			||||||
				Enabled:   true,
 | 
									Enabled:   true,
 | 
				
			||||||
@ -625,7 +612,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[0],
 | 
											ID: 1,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: true,
 | 
										IsPrimary: true,
 | 
				
			||||||
					Enabled:   true,
 | 
										Enabled:   true,
 | 
				
			||||||
@ -637,7 +624,7 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Prefix: ipp("10.0.0.0/24"),
 | 
										Prefix: ipp("10.0.0.0/24"),
 | 
				
			||||||
					Node: types.Node{
 | 
										Node: types.Node{
 | 
				
			||||||
						MachineKey: machineKeys[1],
 | 
											ID: 2,
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					IsPrimary: false,
 | 
										IsPrimary: false,
 | 
				
			||||||
					Enabled:   false,
 | 
										Enabled:   false,
 | 
				
			||||||
@ -670,8 +657,8 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
 | 
								got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | 
				
			||||||
				return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
 | 
									return failoverRouteTx(tx, tt.isConnected, &tt.failingRoute)
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if (err != nil) != tt.wantErr {
 | 
								if (err != nil) != tt.wantErr {
 | 
				
			||||||
@ -687,230 +674,177 @@ func TestFailoverRoute(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// func TestDisableRouteFailover(t *testing.T) {
 | 
					func TestFailoverRoute(t *testing.T) {
 | 
				
			||||||
// 	machineKeys := []key.MachinePublic{
 | 
						r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
 | 
				
			||||||
// 		key.NewMachine().Public(),
 | 
							return types.Route{
 | 
				
			||||||
// 		key.NewMachine().Public(),
 | 
								Model: gorm.Model{
 | 
				
			||||||
// 		key.NewMachine().Public(),
 | 
									ID: id,
 | 
				
			||||||
// 		key.NewMachine().Public(),
 | 
								},
 | 
				
			||||||
// 	}
 | 
								Node: types.Node{
 | 
				
			||||||
 | 
									ID: nid,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								Prefix:    prefix,
 | 
				
			||||||
 | 
								Enabled:   enabled,
 | 
				
			||||||
 | 
								IsPrimary: primary,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
 | 
				
			||||||
 | 
							ro := r(id, nid, prefix, enabled, primary)
 | 
				
			||||||
 | 
							return &ro
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name         string
 | 
				
			||||||
 | 
							failingRoute types.Route
 | 
				
			||||||
 | 
							routes       types.Routes
 | 
				
			||||||
 | 
							isConnected  types.NodeConnectedMap
 | 
				
			||||||
 | 
							want         *failover
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "no-route",
 | 
				
			||||||
 | 
								failingRoute: types.Route{},
 | 
				
			||||||
 | 
								routes:       types.Routes{},
 | 
				
			||||||
 | 
								want:         nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "no-prime",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), false, false),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 	tests := []struct {
 | 
								routes: types.Routes{},
 | 
				
			||||||
// 		name  string
 | 
								want:   nil,
 | 
				
			||||||
// 		nodes types.Nodes
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "exit-node",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("0.0.0.0/0"), false, true),
 | 
				
			||||||
 | 
								routes:       types.Routes{},
 | 
				
			||||||
 | 
								want:         nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "no-failover-single-route",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), false, true),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), false, true),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "failover-primary",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
									r(2, 2, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
 | 
									1: false,
 | 
				
			||||||
 | 
									2: true,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: &failover{
 | 
				
			||||||
 | 
									old: rp(1, 1, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
									new: rp(2, 2, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "failover-none-primary",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
									r(2, 2, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "failover-primary-multi-route",
 | 
				
			||||||
 | 
								failingRoute: r(2, 2, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
									r(2, 2, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
									r(3, 3, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
 | 
									1: true,
 | 
				
			||||||
 | 
									2: true,
 | 
				
			||||||
 | 
									3: true,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: &failover{
 | 
				
			||||||
 | 
									old: rp(2, 2, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
									new: rp(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "failover-primary-no-online",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
									r(2, 4, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
 | 
									1: true,
 | 
				
			||||||
 | 
									4: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "failover-primary-one-not-online",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
									r(2, 4, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
									r(3, 2, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								isConnected: types.NodeConnectedMap{
 | 
				
			||||||
 | 
									1: false,
 | 
				
			||||||
 | 
									2: true,
 | 
				
			||||||
 | 
									4: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: &failover{
 | 
				
			||||||
 | 
									old: rp(1, 1, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
									new: rp(3, 2, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:         "failover-primary-none-enabled",
 | 
				
			||||||
 | 
								failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
 | 
				
			||||||
 | 
								routes: types.Routes{
 | 
				
			||||||
 | 
									r(1, 1, ipp("10.0.0.0/24"), true, false),
 | 
				
			||||||
 | 
									r(2, 2, ipp("10.0.0.0/24"), false, true),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 		routeID     uint64
 | 
						cmps := append(
 | 
				
			||||||
// 		isConnected map[key.MachinePublic]bool
 | 
							util.Comparers,
 | 
				
			||||||
 | 
							cmp.Comparer(func(x, y types.IPPrefix) bool {
 | 
				
			||||||
 | 
								return netip.Prefix(x) == netip.Prefix(y)
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 		wantMachineKey key.MachinePublic
 | 
						for _, tt := range tests {
 | 
				
			||||||
// 		wantErr        string
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
// 	}{
 | 
								gotf := failoverRoute(tt.isConnected, &tt.failingRoute, tt.routes)
 | 
				
			||||||
// 		{
 | 
					 | 
				
			||||||
// 			name: "single-route",
 | 
					 | 
				
			||||||
// 			nodes: types.Nodes{
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         0,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 1,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix: ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							Node: types.Node{
 | 
					 | 
				
			||||||
// 								MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							IsPrimary: true,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 			},
 | 
					 | 
				
			||||||
// 			routeID:        1,
 | 
					 | 
				
			||||||
// 			wantMachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
// 		},
 | 
					 | 
				
			||||||
// 		{
 | 
					 | 
				
			||||||
// 			name: "failover-simple",
 | 
					 | 
				
			||||||
// 			nodes: types.Nodes{
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         0,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 1,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix:    ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							IsPrimary: true,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         1,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[1],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 2,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix:    ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							IsPrimary: false,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 			},
 | 
					 | 
				
			||||||
// 			routeID:        1,
 | 
					 | 
				
			||||||
// 			wantMachineKey: machineKeys[1],
 | 
					 | 
				
			||||||
// 		},
 | 
					 | 
				
			||||||
// 		{
 | 
					 | 
				
			||||||
// 			name: "no-failover-offline",
 | 
					 | 
				
			||||||
// 			nodes: types.Nodes{
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         0,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 1,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix:    ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							IsPrimary: true,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         1,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[1],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 2,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix:    ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							IsPrimary: false,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 			},
 | 
					 | 
				
			||||||
// 			isConnected: map[key.MachinePublic]bool{
 | 
					 | 
				
			||||||
// 				machineKeys[0]: true,
 | 
					 | 
				
			||||||
// 				machineKeys[1]: false,
 | 
					 | 
				
			||||||
// 			},
 | 
					 | 
				
			||||||
// 			routeID:        1,
 | 
					 | 
				
			||||||
// 			wantMachineKey: machineKeys[1],
 | 
					 | 
				
			||||||
// 		},
 | 
					 | 
				
			||||||
// 		{
 | 
					 | 
				
			||||||
// 			name: "failover-to-online",
 | 
					 | 
				
			||||||
// 			nodes: types.Nodes{
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         0,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[0],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 1,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix:    ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							IsPrimary: true,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 				&types.Node{
 | 
					 | 
				
			||||||
// 					ID:         1,
 | 
					 | 
				
			||||||
// 					MachineKey: machineKeys[1],
 | 
					 | 
				
			||||||
// 					Routes: []types.Route{
 | 
					 | 
				
			||||||
// 						{
 | 
					 | 
				
			||||||
// 							Model: gorm.Model{
 | 
					 | 
				
			||||||
// 								ID: 2,
 | 
					 | 
				
			||||||
// 							},
 | 
					 | 
				
			||||||
// 							Prefix:    ipp("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 							IsPrimary: false,
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 					Hostinfo: &tailcfg.Hostinfo{
 | 
					 | 
				
			||||||
// 						RoutableIPs: []netip.Prefix{
 | 
					 | 
				
			||||||
// 							netip.MustParsePrefix("10.0.0.0/24"),
 | 
					 | 
				
			||||||
// 						},
 | 
					 | 
				
			||||||
// 					},
 | 
					 | 
				
			||||||
// 				},
 | 
					 | 
				
			||||||
// 			},
 | 
					 | 
				
			||||||
// 			isConnected: map[key.MachinePublic]bool{
 | 
					 | 
				
			||||||
// 				machineKeys[0]: true,
 | 
					 | 
				
			||||||
// 				machineKeys[1]: true,
 | 
					 | 
				
			||||||
// 			},
 | 
					 | 
				
			||||||
// 			routeID:        1,
 | 
					 | 
				
			||||||
// 			wantMachineKey: machineKeys[1],
 | 
					 | 
				
			||||||
// 		},
 | 
					 | 
				
			||||||
// 	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 	for _, tt := range tests {
 | 
								if tt.want == nil && gotf != nil {
 | 
				
			||||||
// 		t.Run(tt.name, func(t *testing.T) {
 | 
									t.Fatalf("expected nil, got %+v", gotf)
 | 
				
			||||||
// 			datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
 | 
								}
 | 
				
			||||||
// 			assert.NoError(t, err)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 			// bootstrap db
 | 
								if gotf == nil && tt.want != nil {
 | 
				
			||||||
// 			datab.DB.Transaction(func(tx *gorm.DB) error {
 | 
									t.Fatalf("expected %+v, got nil", tt.want)
 | 
				
			||||||
// 				for _, node := range tt.nodes {
 | 
								}
 | 
				
			||||||
// 					err := tx.Save(node).Error
 | 
					 | 
				
			||||||
// 					if err != nil {
 | 
					 | 
				
			||||||
// 						return err
 | 
					 | 
				
			||||||
// 					}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 					_, err = SaveNodeRoutes(tx, node)
 | 
								if tt.want != nil && gotf != nil {
 | 
				
			||||||
// 					if err != nil {
 | 
									want := map[string]*types.Route{
 | 
				
			||||||
// 						return err
 | 
										"new": tt.want.new,
 | 
				
			||||||
// 					}
 | 
										"old": tt.want.old,
 | 
				
			||||||
// 				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 				return nil
 | 
									got := map[string]*types.Route{
 | 
				
			||||||
// 			})
 | 
										"new": gotf.new,
 | 
				
			||||||
 | 
										"old": gotf.old,
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 			got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | 
									if diff := cmp.Diff(want, got, cmps...); diff != "" {
 | 
				
			||||||
// 				return DisableRoute(tx, tt.routeID, tt.isConnected)
 | 
										t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
 | 
				
			||||||
// 			})
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
// 			// if (err.Error() != "") != tt.wantErr {
 | 
							})
 | 
				
			||||||
// 			// 	t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
// 			// 	return
 | 
					 | 
				
			||||||
// 			// }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 			if len(got.ChangeNodes) != 1 {
 | 
					 | 
				
			||||||
// 				t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
 | 
					 | 
				
			||||||
// 			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 			if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
 | 
					 | 
				
			||||||
// 				t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
 | 
					 | 
				
			||||||
// 			}
 | 
					 | 
				
			||||||
// 		})
 | 
					 | 
				
			||||||
// 	}
 | 
					 | 
				
			||||||
// }
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -222,7 +222,7 @@ func (api headscaleV1APIServer) GetNode(
 | 
				
			|||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	request *v1.GetNodeRequest,
 | 
						request *v1.GetNodeRequest,
 | 
				
			||||||
) (*v1.GetNodeResponse, error) {
 | 
					) (*v1.GetNodeResponse, error) {
 | 
				
			||||||
	node, err := api.h.db.GetNodeByID(request.GetNodeId())
 | 
						node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -231,7 +231,7 @@ func (api headscaleV1APIServer) GetNode(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Populate the online field based on
 | 
						// Populate the online field based on
 | 
				
			||||||
	// currently connected nodes.
 | 
						// currently connected nodes.
 | 
				
			||||||
	resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
 | 
						resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.GetNodeResponse{Node: resp}, nil
 | 
						return &v1.GetNodeResponse{Node: resp}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -248,12 +248,12 @@ func (api headscaleV1APIServer) SetTags(
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | 
						node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | 
				
			||||||
		err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
 | 
							err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return db.GetNodeByID(tx, request.GetNodeId())
 | 
							return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return &v1.SetTagsResponse{
 | 
							return &v1.SetTagsResponse{
 | 
				
			||||||
@ -261,15 +261,12 @@ func (api headscaleV1APIServer) SetTags(
 | 
				
			|||||||
		}, status.Error(codes.InvalidArgument, err.Error())
 | 
							}, status.Error(codes.InvalidArgument, err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stateUpdate := types.StateUpdate{
 | 
						ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
 | 
				
			||||||
 | 
						api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
 | 
				
			||||||
		Type:        types.StatePeerChanged,
 | 
							Type:        types.StatePeerChanged,
 | 
				
			||||||
		ChangeNodes: types.Nodes{node},
 | 
							ChangeNodes: []types.NodeID{node.ID},
 | 
				
			||||||
		Message:     "called from api.SetTags",
 | 
							Message:     "called from api.SetTags",
 | 
				
			||||||
	}
 | 
						}, node.ID)
 | 
				
			||||||
	if stateUpdate.Valid() {
 | 
					 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
 | 
					 | 
				
			||||||
		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("node", node.Hostname).
 | 
							Str("node", node.Hostname).
 | 
				
			||||||
@ -296,12 +293,12 @@ func (api headscaleV1APIServer) DeleteNode(
 | 
				
			|||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	request *v1.DeleteNodeRequest,
 | 
						request *v1.DeleteNodeRequest,
 | 
				
			||||||
) (*v1.DeleteNodeResponse, error) {
 | 
					) (*v1.DeleteNodeResponse, error) {
 | 
				
			||||||
	node, err := api.h.db.GetNodeByID(request.GetNodeId())
 | 
						node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = api.h.db.DeleteNode(
 | 
						changedNodes, err := api.h.db.DeleteNode(
 | 
				
			||||||
		node,
 | 
							node,
 | 
				
			||||||
		api.h.nodeNotifier.ConnectedMap(),
 | 
							api.h.nodeNotifier.ConnectedMap(),
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
@ -309,13 +306,17 @@ func (api headscaleV1APIServer) DeleteNode(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stateUpdate := types.StateUpdate{
 | 
						ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
 | 
				
			||||||
 | 
						api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
		Type:    types.StatePeerRemoved,
 | 
							Type:    types.StatePeerRemoved,
 | 
				
			||||||
		Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
 | 
							Removed: []types.NodeID{node.ID},
 | 
				
			||||||
	}
 | 
						})
 | 
				
			||||||
	if stateUpdate.Valid() {
 | 
					
 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
 | 
						if changedNodes != nil {
 | 
				
			||||||
		api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
 | 
							api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
								Type:        types.StatePeerChanged,
 | 
				
			||||||
 | 
								ChangeNodes: changedNodes,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.DeleteNodeResponse{}, nil
 | 
						return &v1.DeleteNodeResponse{}, nil
 | 
				
			||||||
@ -330,33 +331,27 @@ func (api headscaleV1APIServer) ExpireNode(
 | 
				
			|||||||
	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | 
						node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | 
				
			||||||
		db.NodeSetExpiry(
 | 
							db.NodeSetExpiry(
 | 
				
			||||||
			tx,
 | 
								tx,
 | 
				
			||||||
			request.GetNodeId(),
 | 
								types.NodeID(request.GetNodeId()),
 | 
				
			||||||
			now,
 | 
								now,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return db.GetNodeByID(tx, request.GetNodeId())
 | 
							return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	selfUpdate := types.StateUpdate{
 | 
						ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
 | 
				
			||||||
		Type:        types.StateSelfUpdate,
 | 
					 | 
				
			||||||
		ChangeNodes: types.Nodes{node},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if selfUpdate.Valid() {
 | 
					 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
 | 
					 | 
				
			||||||
	api.h.nodeNotifier.NotifyByMachineKey(
 | 
						api.h.nodeNotifier.NotifyByMachineKey(
 | 
				
			||||||
		ctx,
 | 
							ctx,
 | 
				
			||||||
			selfUpdate,
 | 
							types.StateUpdate{
 | 
				
			||||||
			node.MachineKey)
 | 
								Type:        types.StateSelfUpdate,
 | 
				
			||||||
	}
 | 
								ChangeNodes: []types.NodeID{node.ID},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							node.ID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stateUpdate := types.StateUpdateExpire(node.ID, now)
 | 
						ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
 | 
				
			||||||
	if stateUpdate.Valid() {
 | 
						api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
 | 
					 | 
				
			||||||
		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("node", node.Hostname).
 | 
							Str("node", node.Hostname).
 | 
				
			||||||
@ -380,21 +375,18 @@ func (api headscaleV1APIServer) RenameNode(
 | 
				
			|||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return db.GetNodeByID(tx, request.GetNodeId())
 | 
							return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stateUpdate := types.StateUpdate{
 | 
						ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
 | 
				
			||||||
 | 
						api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
 | 
				
			||||||
		Type:        types.StatePeerChanged,
 | 
							Type:        types.StatePeerChanged,
 | 
				
			||||||
		ChangeNodes: types.Nodes{node},
 | 
							ChangeNodes: []types.NodeID{node.ID},
 | 
				
			||||||
		Message:     "called from api.RenameNode",
 | 
							Message:     "called from api.RenameNode",
 | 
				
			||||||
	}
 | 
						}, node.ID)
 | 
				
			||||||
	if stateUpdate.Valid() {
 | 
					 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
 | 
					 | 
				
			||||||
		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("node", node.Hostname).
 | 
							Str("node", node.Hostname).
 | 
				
			||||||
@ -423,7 +415,7 @@ func (api headscaleV1APIServer) ListNodes(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// Populate the online field based on
 | 
								// Populate the online field based on
 | 
				
			||||||
			// currently connected nodes.
 | 
								// currently connected nodes.
 | 
				
			||||||
			resp.Online = isConnected[node.MachineKey]
 | 
								resp.Online = isConnected[node.ID]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			response[index] = resp
 | 
								response[index] = resp
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -446,7 +438,7 @@ func (api headscaleV1APIServer) ListNodes(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		// Populate the online field based on
 | 
							// Populate the online field based on
 | 
				
			||||||
		// currently connected nodes.
 | 
							// currently connected nodes.
 | 
				
			||||||
		resp.Online = isConnected[node.MachineKey]
 | 
							resp.Online = isConnected[node.ID]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
 | 
							validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
 | 
				
			||||||
			node,
 | 
								node,
 | 
				
			||||||
@ -463,7 +455,7 @@ func (api headscaleV1APIServer) MoveNode(
 | 
				
			|||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	request *v1.MoveNodeRequest,
 | 
						request *v1.MoveNodeRequest,
 | 
				
			||||||
) (*v1.MoveNodeResponse, error) {
 | 
					) (*v1.MoveNodeResponse, error) {
 | 
				
			||||||
	node, err := api.h.db.GetNodeByID(request.GetNodeId())
 | 
						node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -503,7 +495,7 @@ func (api headscaleV1APIServer) EnableRoute(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if update != nil && update.Valid() {
 | 
						if update != nil {
 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
 | 
							ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
 | 
				
			||||||
		api.h.nodeNotifier.NotifyAll(
 | 
							api.h.nodeNotifier.NotifyAll(
 | 
				
			||||||
			ctx, *update)
 | 
								ctx, *update)
 | 
				
			||||||
@ -516,17 +508,19 @@ func (api headscaleV1APIServer) DisableRoute(
 | 
				
			|||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	request *v1.DisableRouteRequest,
 | 
						request *v1.DisableRouteRequest,
 | 
				
			||||||
) (*v1.DisableRouteResponse, error) {
 | 
					) (*v1.DisableRouteResponse, error) {
 | 
				
			||||||
	isConnected := api.h.nodeNotifier.ConnectedMap()
 | 
						update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | 
				
			||||||
	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | 
							return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.ConnectedMap())
 | 
				
			||||||
		return db.DisableRoute(tx, request.GetRouteId(), isConnected)
 | 
					 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if update != nil && update.Valid() {
 | 
						if update != nil {
 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
 | 
							ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
 | 
				
			||||||
		api.h.nodeNotifier.NotifyAll(ctx, *update)
 | 
							api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
								Type:        types.StatePeerChanged,
 | 
				
			||||||
 | 
								ChangeNodes: update,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.DisableRouteResponse{}, nil
 | 
						return &v1.DisableRouteResponse{}, nil
 | 
				
			||||||
@ -536,7 +530,7 @@ func (api headscaleV1APIServer) GetNodeRoutes(
 | 
				
			|||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	request *v1.GetNodeRoutesRequest,
 | 
						request *v1.GetNodeRoutesRequest,
 | 
				
			||||||
) (*v1.GetNodeRoutesResponse, error) {
 | 
					) (*v1.GetNodeRoutesResponse, error) {
 | 
				
			||||||
	node, err := api.h.db.GetNodeByID(request.GetNodeId())
 | 
						node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -556,16 +550,19 @@ func (api headscaleV1APIServer) DeleteRoute(
 | 
				
			|||||||
	request *v1.DeleteRouteRequest,
 | 
						request *v1.DeleteRouteRequest,
 | 
				
			||||||
) (*v1.DeleteRouteResponse, error) {
 | 
					) (*v1.DeleteRouteResponse, error) {
 | 
				
			||||||
	isConnected := api.h.nodeNotifier.ConnectedMap()
 | 
						isConnected := api.h.nodeNotifier.ConnectedMap()
 | 
				
			||||||
	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | 
						update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | 
				
			||||||
		return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
 | 
							return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if update != nil && update.Valid() {
 | 
						if update != nil {
 | 
				
			||||||
		ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
 | 
							ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
 | 
				
			||||||
		api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
 | 
							api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
								Type:        types.StatePeerChanged,
 | 
				
			||||||
 | 
								ChangeNodes: update,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.DeleteRouteResponse{}, nil
 | 
						return &v1.DeleteRouteResponse{}, nil
 | 
				
			||||||
 | 
				
			|||||||
@ -68,12 +68,6 @@ func (h *Headscale) KeyHandler(
 | 
				
			|||||||
			Msg("could not get capability version")
 | 
								Msg("could not get capability version")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusInternalServerError)
 | 
							writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			log.Error().
 | 
					 | 
				
			||||||
				Caller().
 | 
					 | 
				
			||||||
				Err(err).
 | 
					 | 
				
			||||||
				Msg("Failed to write response")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -82,19 +76,6 @@ func (h *Headscale) KeyHandler(
 | 
				
			|||||||
		Str("handler", "/key").
 | 
							Str("handler", "/key").
 | 
				
			||||||
		Int("cap_ver", int(capVer)).
 | 
							Int("cap_ver", int(capVer)).
 | 
				
			||||||
		Msg("New noise client")
 | 
							Msg("New noise client")
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
					 | 
				
			||||||
		writer.WriteHeader(http.StatusBadRequest)
 | 
					 | 
				
			||||||
		_, err := writer.Write([]byte("Wrong params"))
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			log.Error().
 | 
					 | 
				
			||||||
				Caller().
 | 
					 | 
				
			||||||
				Err(err).
 | 
					 | 
				
			||||||
				Msg("Failed to write response")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TS2021 (Tailscale v2 protocol) requires to have a different key
 | 
						// TS2021 (Tailscale v2 protocol) requires to have a different key
 | 
				
			||||||
	if capVer >= NoiseCapabilityVersion {
 | 
						if capVer >= NoiseCapabilityVersion {
 | 
				
			||||||
 | 
				
			|||||||
@ -16,12 +16,12 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mapset "github.com/deckarep/golang-set/v2"
 | 
						mapset "github.com/deckarep/golang-set/v2"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/db"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/policy"
 | 
						"github.com/juanfont/headscale/hscontrol/policy"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
	"github.com/klauspost/compress/zstd"
 | 
						"github.com/klauspost/compress/zstd"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
	"golang.org/x/exp/maps"
 | 
					 | 
				
			||||||
	"tailscale.com/envknob"
 | 
						"tailscale.com/envknob"
 | 
				
			||||||
	"tailscale.com/smallzstd"
 | 
						"tailscale.com/smallzstd"
 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
@ -51,21 +51,14 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
 | 
				
			|||||||
type Mapper struct {
 | 
					type Mapper struct {
 | 
				
			||||||
	// Configuration
 | 
						// Configuration
 | 
				
			||||||
	// TODO(kradalby): figure out if this is the format we want this in
 | 
						// TODO(kradalby): figure out if this is the format we want this in
 | 
				
			||||||
 | 
						db                *db.HSDatabase
 | 
				
			||||||
 | 
						cfg               *types.Config
 | 
				
			||||||
	derpMap           *tailcfg.DERPMap
 | 
						derpMap           *tailcfg.DERPMap
 | 
				
			||||||
	baseDomain       string
 | 
						isLikelyConnected types.NodeConnectedMap
 | 
				
			||||||
	dnsCfg           *tailcfg.DNSConfig
 | 
					 | 
				
			||||||
	logtail          bool
 | 
					 | 
				
			||||||
	randomClientPort bool
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	uid     string
 | 
						uid     string
 | 
				
			||||||
	created time.Time
 | 
						created time.Time
 | 
				
			||||||
	seq     uint64
 | 
						seq     uint64
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Map isnt concurrency safe, so we need to ensure
 | 
					 | 
				
			||||||
	// only one func is accessing it over time.
 | 
					 | 
				
			||||||
	mu      sync.Mutex
 | 
					 | 
				
			||||||
	peers   map[uint64]*types.Node
 | 
					 | 
				
			||||||
	patches map[uint64][]patch
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type patch struct {
 | 
					type patch struct {
 | 
				
			||||||
@ -74,35 +67,22 @@ type patch struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewMapper(
 | 
					func NewMapper(
 | 
				
			||||||
	node *types.Node,
 | 
						db *db.HSDatabase,
 | 
				
			||||||
	peers types.Nodes,
 | 
						cfg *types.Config,
 | 
				
			||||||
	derpMap *tailcfg.DERPMap,
 | 
						derpMap *tailcfg.DERPMap,
 | 
				
			||||||
	baseDomain string,
 | 
						isLikelyConnected types.NodeConnectedMap,
 | 
				
			||||||
	dnsCfg *tailcfg.DNSConfig,
 | 
					 | 
				
			||||||
	logtail bool,
 | 
					 | 
				
			||||||
	randomClientPort bool,
 | 
					 | 
				
			||||||
) *Mapper {
 | 
					) *Mapper {
 | 
				
			||||||
	log.Debug().
 | 
					 | 
				
			||||||
		Caller().
 | 
					 | 
				
			||||||
		Str("node", node.Hostname).
 | 
					 | 
				
			||||||
		Msg("creating new mapper")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
 | 
						uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &Mapper{
 | 
						return &Mapper{
 | 
				
			||||||
 | 
							db:                db,
 | 
				
			||||||
 | 
							cfg:               cfg,
 | 
				
			||||||
		derpMap:           derpMap,
 | 
							derpMap:           derpMap,
 | 
				
			||||||
		baseDomain:       baseDomain,
 | 
							isLikelyConnected: isLikelyConnected,
 | 
				
			||||||
		dnsCfg:           dnsCfg,
 | 
					 | 
				
			||||||
		logtail:          logtail,
 | 
					 | 
				
			||||||
		randomClientPort: randomClientPort,
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		uid:     uid,
 | 
							uid:     uid,
 | 
				
			||||||
		created: time.Now(),
 | 
							created: time.Now(),
 | 
				
			||||||
		seq:     0,
 | 
							seq:     0,
 | 
				
			||||||
 | 
					 | 
				
			||||||
		// TODO: populate
 | 
					 | 
				
			||||||
		peers:   peers.IDMap(),
 | 
					 | 
				
			||||||
		patches: make(map[uint64][]patch),
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -207,11 +187,10 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
 | 
				
			|||||||
// It is a separate function to make testing easier.
 | 
					// It is a separate function to make testing easier.
 | 
				
			||||||
func (m *Mapper) fullMapResponse(
 | 
					func (m *Mapper) fullMapResponse(
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
 | 
						peers types.Nodes,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
) (*tailcfg.MapResponse, error) {
 | 
					) (*tailcfg.MapResponse, error) {
 | 
				
			||||||
	peers := nodeMapToList(m.peers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
 | 
						resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@ -219,14 +198,13 @@ func (m *Mapper) fullMapResponse(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	err = appendPeerChanges(
 | 
						err = appendPeerChanges(
 | 
				
			||||||
		resp,
 | 
							resp,
 | 
				
			||||||
 | 
							true, // full change
 | 
				
			||||||
		pol,
 | 
							pol,
 | 
				
			||||||
		node,
 | 
							node,
 | 
				
			||||||
		capVer,
 | 
							capVer,
 | 
				
			||||||
		peers,
 | 
							peers,
 | 
				
			||||||
		peers,
 | 
							peers,
 | 
				
			||||||
		m.baseDomain,
 | 
							m.cfg,
 | 
				
			||||||
		m.dnsCfg,
 | 
					 | 
				
			||||||
		m.randomClientPort,
 | 
					 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@ -240,35 +218,25 @@ func (m *Mapper) FullMapResponse(
 | 
				
			|||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
 | 
						messages ...string,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	m.mu.Lock()
 | 
						peers, err := m.ListPeers(node.ID)
 | 
				
			||||||
	defer m.mu.Unlock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	peers := maps.Keys(m.peers)
 | 
					 | 
				
			||||||
	peersWithPatches := maps.Keys(m.patches)
 | 
					 | 
				
			||||||
	slices.Sort(peers)
 | 
					 | 
				
			||||||
	slices.Sort(peersWithPatches)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(peersWithPatches) > 0 {
 | 
					 | 
				
			||||||
		log.Debug().
 | 
					 | 
				
			||||||
			Str("node", node.Hostname).
 | 
					 | 
				
			||||||
			Uints64("peers", peers).
 | 
					 | 
				
			||||||
			Uints64("pending_patches", peersWithPatches).
 | 
					 | 
				
			||||||
			Msgf("node requested full map response, but has pending patches")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
 | 
						resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// LiteMapResponse returns a MapResponse for the given node.
 | 
						return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ReadOnlyResponse returns a MapResponse for the given node.
 | 
				
			||||||
// Lite means that the peers has been omitted, this is intended
 | 
					// Lite means that the peers has been omitted, this is intended
 | 
				
			||||||
// to be used to answer MapRequests with OmitPeers set to true.
 | 
					// to be used to answer MapRequests with OmitPeers set to true.
 | 
				
			||||||
func (m *Mapper) LiteMapResponse(
 | 
					func (m *Mapper) ReadOnlyMapResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
@ -279,18 +247,6 @@ func (m *Mapper) LiteMapResponse(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
 | 
					 | 
				
			||||||
		pol,
 | 
					 | 
				
			||||||
		node,
 | 
					 | 
				
			||||||
		nodeMapToList(m.peers),
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp.PacketFilter = policy.ReduceFilterRules(node, rules)
 | 
					 | 
				
			||||||
	resp.SSHPolicy = sshPolicy
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
 | 
						return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -320,50 +276,74 @@ func (m *Mapper) DERPMapResponse(
 | 
				
			|||||||
func (m *Mapper) PeerChangedResponse(
 | 
					func (m *Mapper) PeerChangedResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	changed types.Nodes,
 | 
						changed map[types.NodeID]bool,
 | 
				
			||||||
 | 
						patches []*tailcfg.PeerChange,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
	messages ...string,
 | 
						messages ...string,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	m.mu.Lock()
 | 
					 | 
				
			||||||
	defer m.mu.Unlock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Update our internal map.
 | 
					 | 
				
			||||||
	for _, node := range changed {
 | 
					 | 
				
			||||||
		if patches, ok := m.patches[node.ID]; ok {
 | 
					 | 
				
			||||||
			// preserve online status in case the patch has an outdated one
 | 
					 | 
				
			||||||
			online := node.IsOnline
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for _, p := range patches {
 | 
					 | 
				
			||||||
				// TODO(kradalby): Figure if this needs to be sorted by timestamp
 | 
					 | 
				
			||||||
				node.ApplyPeerChange(p.change)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// Ensure the patches are not applied again later
 | 
					 | 
				
			||||||
			delete(m.patches, node.ID)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			node.IsOnline = online
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		m.peers[node.ID] = node
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := appendPeerChanges(
 | 
						peers, err := m.ListPeers(node.ID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var removedIDs []tailcfg.NodeID
 | 
				
			||||||
 | 
						var changedIDs []types.NodeID
 | 
				
			||||||
 | 
						for nodeID, nodeChanged := range changed {
 | 
				
			||||||
 | 
							if nodeChanged {
 | 
				
			||||||
 | 
								changedIDs = append(changedIDs, nodeID)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								removedIDs = append(removedIDs, nodeID.NodeID())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						changedNodes := make(types.Nodes, 0, len(changedIDs))
 | 
				
			||||||
 | 
						for _, peer := range peers {
 | 
				
			||||||
 | 
							if slices.Contains(changedIDs, peer.ID) {
 | 
				
			||||||
 | 
								changedNodes = append(changedNodes, peer)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = appendPeerChanges(
 | 
				
			||||||
		&resp,
 | 
							&resp,
 | 
				
			||||||
 | 
							false, // partial change
 | 
				
			||||||
		pol,
 | 
							pol,
 | 
				
			||||||
		node,
 | 
							node,
 | 
				
			||||||
		mapRequest.Version,
 | 
							mapRequest.Version,
 | 
				
			||||||
		nodeMapToList(m.peers),
 | 
							peers,
 | 
				
			||||||
		changed,
 | 
							changedNodes,
 | 
				
			||||||
		m.baseDomain,
 | 
							m.cfg,
 | 
				
			||||||
		m.dnsCfg,
 | 
					 | 
				
			||||||
		m.randomClientPort,
 | 
					 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.PeersRemoved = removedIDs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Sending patches as a part of a PeersChanged response
 | 
				
			||||||
 | 
						// is technically not suppose to be done, but they are
 | 
				
			||||||
 | 
						// applied after the PeersChanged. The patch list
 | 
				
			||||||
 | 
						// should _only_ contain Nodes that are not in the
 | 
				
			||||||
 | 
						// PeersChanged or PeersRemoved list and the caller
 | 
				
			||||||
 | 
						// should filter them out.
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// From tailcfg docs:
 | 
				
			||||||
 | 
						// These are applied after Peers* above, but in practice the
 | 
				
			||||||
 | 
						// control server should only send these on their own, without
 | 
				
			||||||
 | 
						// the Peers* fields also set.
 | 
				
			||||||
 | 
						if patches != nil {
 | 
				
			||||||
 | 
							resp.PeersChangedPatch = patches
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Add the node itself, it might have changed, and particularly
 | 
				
			||||||
 | 
						// if there are no patches or changes, this is a self update.
 | 
				
			||||||
 | 
						tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Node = tailnode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
 | 
						return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -375,71 +355,12 @@ func (m *Mapper) PeerChangedPatchResponse(
 | 
				
			|||||||
	changed []*tailcfg.PeerChange,
 | 
						changed []*tailcfg.PeerChange,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	m.mu.Lock()
 | 
					 | 
				
			||||||
	defer m.mu.Unlock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sendUpdate := false
 | 
					 | 
				
			||||||
	// patch the internal map
 | 
					 | 
				
			||||||
	for _, change := range changed {
 | 
					 | 
				
			||||||
		if peer, ok := m.peers[uint64(change.NodeID)]; ok {
 | 
					 | 
				
			||||||
			peer.ApplyPeerChange(change)
 | 
					 | 
				
			||||||
			sendUpdate = true
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			p := patch{
 | 
					 | 
				
			||||||
				timestamp: time.Now(),
 | 
					 | 
				
			||||||
				change:    change,
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if patches, ok := m.patches[uint64(change.NodeID)]; ok {
 | 
					 | 
				
			||||||
				m.patches[uint64(change.NodeID)] = append(patches, p)
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				m.patches[uint64(change.NodeID)] = []patch{p}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !sendUpdate {
 | 
					 | 
				
			||||||
		return nil, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
	resp.PeersChangedPatch = changed
 | 
						resp.PeersChangedPatch = changed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
 | 
						return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(kradalby): We need some integration tests for this.
 | 
					 | 
				
			||||||
func (m *Mapper) PeerRemovedResponse(
 | 
					 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
					 | 
				
			||||||
	node *types.Node,
 | 
					 | 
				
			||||||
	removed []tailcfg.NodeID,
 | 
					 | 
				
			||||||
) ([]byte, error) {
 | 
					 | 
				
			||||||
	m.mu.Lock()
 | 
					 | 
				
			||||||
	defer m.mu.Unlock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Some nodes might have been removed already
 | 
					 | 
				
			||||||
	// so we dont want to ask downstream to remove
 | 
					 | 
				
			||||||
	// twice, than can cause a panic in tailscaled.
 | 
					 | 
				
			||||||
	notYetRemoved := []tailcfg.NodeID{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// remove from our internal map
 | 
					 | 
				
			||||||
	for _, id := range removed {
 | 
					 | 
				
			||||||
		if _, ok := m.peers[uint64(id)]; ok {
 | 
					 | 
				
			||||||
			notYetRemoved = append(notYetRemoved, id)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		delete(m.peers, uint64(id))
 | 
					 | 
				
			||||||
		delete(m.patches, uint64(id))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
					 | 
				
			||||||
	resp.PeersRemoved = notYetRemoved
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m *Mapper) marshalMapResponse(
 | 
					func (m *Mapper) marshalMapResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	resp *tailcfg.MapResponse,
 | 
						resp *tailcfg.MapResponse,
 | 
				
			||||||
@ -469,10 +390,8 @@ func (m *Mapper) marshalMapResponse(
 | 
				
			|||||||
		switch {
 | 
							switch {
 | 
				
			||||||
		case resp.Peers != nil && len(resp.Peers) > 0:
 | 
							case resp.Peers != nil && len(resp.Peers) > 0:
 | 
				
			||||||
			responseType = "full"
 | 
								responseType = "full"
 | 
				
			||||||
		case isSelfUpdate(messages...):
 | 
							case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
 | 
				
			||||||
			responseType = "self"
 | 
								responseType = "self"
 | 
				
			||||||
		case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
 | 
					 | 
				
			||||||
			responseType = "lite"
 | 
					 | 
				
			||||||
		case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
 | 
							case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
 | 
				
			||||||
			responseType = "changed"
 | 
								responseType = "changed"
 | 
				
			||||||
		case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
 | 
							case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
 | 
				
			||||||
@ -496,11 +415,11 @@ func (m *Mapper) marshalMapResponse(
 | 
				
			|||||||
			panic(err)
 | 
								panic(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		now := time.Now().UnixNano()
 | 
							now := time.Now().Format("2006-01-02T15-04-05.999999999")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		mapResponsePath := path.Join(
 | 
							mapResponsePath := path.Join(
 | 
				
			||||||
			mPath,
 | 
								mPath,
 | 
				
			||||||
			fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
 | 
								fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
 | 
							log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
 | 
				
			||||||
@ -574,7 +493,7 @@ func (m *Mapper) baseWithConfigMapResponse(
 | 
				
			|||||||
) (*tailcfg.MapResponse, error) {
 | 
					) (*tailcfg.MapResponse, error) {
 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort)
 | 
						tailnode, err := tailNode(node, capVer, pol, m.cfg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -582,7 +501,7 @@ func (m *Mapper) baseWithConfigMapResponse(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	resp.DERPMap = m.derpMap
 | 
						resp.DERPMap = m.derpMap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp.Domain = m.baseDomain
 | 
						resp.Domain = m.cfg.BaseDomain
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Do not instruct clients to collect services we do not
 | 
						// Do not instruct clients to collect services we do not
 | 
				
			||||||
	// support or do anything with them
 | 
						// support or do anything with them
 | 
				
			||||||
@ -591,12 +510,26 @@ func (m *Mapper) baseWithConfigMapResponse(
 | 
				
			|||||||
	resp.KeepAlive = false
 | 
						resp.KeepAlive = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp.Debug = &tailcfg.Debug{
 | 
						resp.Debug = &tailcfg.Debug{
 | 
				
			||||||
		DisableLogTail: !m.logtail,
 | 
							DisableLogTail: !m.cfg.LogTail.Enabled,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &resp, nil
 | 
						return &resp, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
 | 
				
			||||||
 | 
						peers, err := m.db.ListPeers(nodeID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, peer := range peers {
 | 
				
			||||||
 | 
							online := m.isLikelyConnected[peer.ID]
 | 
				
			||||||
 | 
							peer.IsOnline = &online
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return peers, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
 | 
					func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
 | 
				
			||||||
	ret := make(types.Nodes, 0)
 | 
						ret := make(types.Nodes, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -612,42 +545,41 @@ func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
 | 
				
			|||||||
func appendPeerChanges(
 | 
					func appendPeerChanges(
 | 
				
			||||||
	resp *tailcfg.MapResponse,
 | 
						resp *tailcfg.MapResponse,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fullChange bool,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
	peers types.Nodes,
 | 
						peers types.Nodes,
 | 
				
			||||||
	changed types.Nodes,
 | 
						changed types.Nodes,
 | 
				
			||||||
	baseDomain string,
 | 
						cfg *types.Config,
 | 
				
			||||||
	dnsCfg *tailcfg.DNSConfig,
 | 
					 | 
				
			||||||
	randomClientPort bool,
 | 
					 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	fullChange := len(peers) == len(changed)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
 | 
						packetFilter, err := pol.CompileFilterRules(append(peers, node))
 | 
				
			||||||
		pol,
 | 
						if err != nil {
 | 
				
			||||||
		node,
 | 
							return err
 | 
				
			||||||
		peers,
 | 
						}
 | 
				
			||||||
	)
 | 
					
 | 
				
			||||||
 | 
						sshPolicy, err := pol.CompileSSHPolicy(node, peers)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If there are filter rules present, see if there are any nodes that cannot
 | 
						// If there are filter rules present, see if there are any nodes that cannot
 | 
				
			||||||
	// access eachother at all and remove them from the peers.
 | 
						// access eachother at all and remove them from the peers.
 | 
				
			||||||
	if len(rules) > 0 {
 | 
						if len(packetFilter) > 0 {
 | 
				
			||||||
		changed = policy.FilterNodesByACL(node, changed, rules)
 | 
							changed = policy.FilterNodesByACL(node, changed, packetFilter)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	profiles := generateUserProfiles(node, changed, baseDomain)
 | 
						profiles := generateUserProfiles(node, changed, cfg.BaseDomain)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dnsConfig := generateDNSConfig(
 | 
						dnsConfig := generateDNSConfig(
 | 
				
			||||||
		dnsCfg,
 | 
							cfg.DNSConfig,
 | 
				
			||||||
		baseDomain,
 | 
							cfg.BaseDomain,
 | 
				
			||||||
		node,
 | 
							node,
 | 
				
			||||||
		peers,
 | 
							peers,
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort)
 | 
						tailPeers, err := tailNodes(changed, capVer, pol, cfg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -663,19 +595,9 @@ func appendPeerChanges(
 | 
				
			|||||||
		resp.PeersChanged = tailPeers
 | 
							resp.PeersChanged = tailPeers
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	resp.DNSConfig = dnsConfig
 | 
						resp.DNSConfig = dnsConfig
 | 
				
			||||||
	resp.PacketFilter = policy.ReduceFilterRules(node, rules)
 | 
						resp.PacketFilter = policy.ReduceFilterRules(node, packetFilter)
 | 
				
			||||||
	resp.UserProfiles = profiles
 | 
						resp.UserProfiles = profiles
 | 
				
			||||||
	resp.SSHPolicy = sshPolicy
 | 
						resp.SSHPolicy = sshPolicy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func isSelfUpdate(messages ...string) bool {
 | 
					 | 
				
			||||||
	for _, message := range messages {
 | 
					 | 
				
			||||||
		if strings.Contains(message, types.SelfUpdateIdentifier) {
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -331,11 +331,8 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
		node  *types.Node
 | 
							node  *types.Node
 | 
				
			||||||
		peers types.Nodes
 | 
							peers types.Nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		baseDomain       string
 | 
					 | 
				
			||||||
		dnsConfig        *tailcfg.DNSConfig
 | 
					 | 
				
			||||||
		derpMap *tailcfg.DERPMap
 | 
							derpMap *tailcfg.DERPMap
 | 
				
			||||||
		logtail          bool
 | 
							cfg     *types.Config
 | 
				
			||||||
		randomClientPort bool
 | 
					 | 
				
			||||||
		want    *tailcfg.MapResponse
 | 
							want    *tailcfg.MapResponse
 | 
				
			||||||
		wantErr bool
 | 
							wantErr bool
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
@ -353,11 +350,13 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
			pol:     &policy.ACLPolicy{},
 | 
								pol:     &policy.ACLPolicy{},
 | 
				
			||||||
			node:    mini,
 | 
								node:    mini,
 | 
				
			||||||
			peers:   types.Nodes{},
 | 
								peers:   types.Nodes{},
 | 
				
			||||||
			baseDomain:       "",
 | 
					 | 
				
			||||||
			dnsConfig:        &tailcfg.DNSConfig{},
 | 
					 | 
				
			||||||
			derpMap: &tailcfg.DERPMap{},
 | 
								derpMap: &tailcfg.DERPMap{},
 | 
				
			||||||
			logtail:          false,
 | 
								cfg: &types.Config{
 | 
				
			||||||
			randomClientPort: false,
 | 
									BaseDomain:          "",
 | 
				
			||||||
 | 
									DNSConfig:           &tailcfg.DNSConfig{},
 | 
				
			||||||
 | 
									LogTail:             types.LogTailConfig{Enabled: false},
 | 
				
			||||||
 | 
									RandomizeClientPort: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
			want: &tailcfg.MapResponse{
 | 
								want: &tailcfg.MapResponse{
 | 
				
			||||||
				Node:            tailMini,
 | 
									Node:            tailMini,
 | 
				
			||||||
				KeepAlive:       false,
 | 
									KeepAlive:       false,
 | 
				
			||||||
@ -383,11 +382,13 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
			peers: types.Nodes{
 | 
								peers: types.Nodes{
 | 
				
			||||||
				peer1,
 | 
									peer1,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			baseDomain:       "",
 | 
					 | 
				
			||||||
			dnsConfig:        &tailcfg.DNSConfig{},
 | 
					 | 
				
			||||||
			derpMap: &tailcfg.DERPMap{},
 | 
								derpMap: &tailcfg.DERPMap{},
 | 
				
			||||||
			logtail:          false,
 | 
								cfg: &types.Config{
 | 
				
			||||||
			randomClientPort: false,
 | 
									BaseDomain:          "",
 | 
				
			||||||
 | 
									DNSConfig:           &tailcfg.DNSConfig{},
 | 
				
			||||||
 | 
									LogTail:             types.LogTailConfig{Enabled: false},
 | 
				
			||||||
 | 
									RandomizeClientPort: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
			want: &tailcfg.MapResponse{
 | 
								want: &tailcfg.MapResponse{
 | 
				
			||||||
				KeepAlive: false,
 | 
									KeepAlive: false,
 | 
				
			||||||
				Node:      tailMini,
 | 
									Node:      tailMini,
 | 
				
			||||||
@ -424,11 +425,13 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
				peer1,
 | 
									peer1,
 | 
				
			||||||
				peer2,
 | 
									peer2,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			baseDomain:       "",
 | 
					 | 
				
			||||||
			dnsConfig:        &tailcfg.DNSConfig{},
 | 
					 | 
				
			||||||
			derpMap: &tailcfg.DERPMap{},
 | 
								derpMap: &tailcfg.DERPMap{},
 | 
				
			||||||
			logtail:          false,
 | 
								cfg: &types.Config{
 | 
				
			||||||
			randomClientPort: false,
 | 
									BaseDomain:          "",
 | 
				
			||||||
 | 
									DNSConfig:           &tailcfg.DNSConfig{},
 | 
				
			||||||
 | 
									LogTail:             types.LogTailConfig{Enabled: false},
 | 
				
			||||||
 | 
									RandomizeClientPort: false,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
			want: &tailcfg.MapResponse{
 | 
								want: &tailcfg.MapResponse{
 | 
				
			||||||
				KeepAlive: false,
 | 
									KeepAlive: false,
 | 
				
			||||||
				Node:      tailMini,
 | 
									Node:      tailMini,
 | 
				
			||||||
@ -463,17 +466,15 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			mappy := NewMapper(
 | 
								mappy := NewMapper(
 | 
				
			||||||
				tt.node,
 | 
									nil,
 | 
				
			||||||
				tt.peers,
 | 
									tt.cfg,
 | 
				
			||||||
				tt.derpMap,
 | 
									tt.derpMap,
 | 
				
			||||||
				tt.baseDomain,
 | 
									nil,
 | 
				
			||||||
				tt.dnsConfig,
 | 
					 | 
				
			||||||
				tt.logtail,
 | 
					 | 
				
			||||||
				tt.randomClientPort,
 | 
					 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			got, err := mappy.fullMapResponse(
 | 
								got, err := mappy.fullMapResponse(
 | 
				
			||||||
				tt.node,
 | 
									tt.node,
 | 
				
			||||||
 | 
									tt.peers,
 | 
				
			||||||
				tt.pol,
 | 
									tt.pol,
 | 
				
			||||||
				0,
 | 
									0,
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
				
			|||||||
@ -3,12 +3,10 @@ package mapper
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/policy"
 | 
						"github.com/juanfont/headscale/hscontrol/policy"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
					 | 
				
			||||||
	"github.com/samber/lo"
 | 
						"github.com/samber/lo"
 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -17,9 +15,7 @@ func tailNodes(
 | 
				
			|||||||
	nodes types.Nodes,
 | 
						nodes types.Nodes,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
	dnsConfig *tailcfg.DNSConfig,
 | 
						cfg *types.Config,
 | 
				
			||||||
	baseDomain string,
 | 
					 | 
				
			||||||
	randomClientPort bool,
 | 
					 | 
				
			||||||
) ([]*tailcfg.Node, error) {
 | 
					) ([]*tailcfg.Node, error) {
 | 
				
			||||||
	tNodes := make([]*tailcfg.Node, len(nodes))
 | 
						tNodes := make([]*tailcfg.Node, len(nodes))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,9 +24,7 @@ func tailNodes(
 | 
				
			|||||||
			node,
 | 
								node,
 | 
				
			||||||
			capVer,
 | 
								capVer,
 | 
				
			||||||
			pol,
 | 
								pol,
 | 
				
			||||||
			dnsConfig,
 | 
								cfg,
 | 
				
			||||||
			baseDomain,
 | 
					 | 
				
			||||||
			randomClientPort,
 | 
					 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
@ -48,9 +42,7 @@ func tailNode(
 | 
				
			|||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						pol *policy.ACLPolicy,
 | 
				
			||||||
	dnsConfig *tailcfg.DNSConfig,
 | 
						cfg *types.Config,
 | 
				
			||||||
	baseDomain string,
 | 
					 | 
				
			||||||
	randomClientPort bool,
 | 
					 | 
				
			||||||
) (*tailcfg.Node, error) {
 | 
					) (*tailcfg.Node, error) {
 | 
				
			||||||
	addrs := node.IPAddresses.Prefixes()
 | 
						addrs := node.IPAddresses.Prefixes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -85,7 +77,7 @@ func tailNode(
 | 
				
			|||||||
		keyExpiry = time.Time{}
 | 
							keyExpiry = time.Time{}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hostname, err := node.GetFQDN(dnsConfig, baseDomain)
 | 
						hostname, err := node.GetFQDN(cfg.DNSConfig, cfg.BaseDomain)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 | 
							return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -95,9 +87,7 @@ func tailNode(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	tNode := tailcfg.Node{
 | 
						tNode := tailcfg.Node{
 | 
				
			||||||
		ID:       tailcfg.NodeID(node.ID), // this is the actual ID
 | 
							ID:       tailcfg.NodeID(node.ID), // this is the actual ID
 | 
				
			||||||
		StableID: tailcfg.StableNodeID(
 | 
							StableID: node.ID.StableID(),
 | 
				
			||||||
			strconv.FormatUint(node.ID, util.Base10),
 | 
					 | 
				
			||||||
		), // in headscale, unlike tailcontrol server, IDs are permanent
 | 
					 | 
				
			||||||
		Name:     hostname,
 | 
							Name:     hostname,
 | 
				
			||||||
		Cap:      capVer,
 | 
							Cap:      capVer,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -133,7 +123,7 @@ func tailNode(
 | 
				
			|||||||
			tailcfg.CapabilitySSH:         []tailcfg.RawMessage{},
 | 
								tailcfg.CapabilitySSH:         []tailcfg.RawMessage{},
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if randomClientPort {
 | 
							if cfg.RandomizeClientPort {
 | 
				
			||||||
			tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
 | 
								tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
@ -143,7 +133,7 @@ func tailNode(
 | 
				
			|||||||
			tailcfg.CapabilitySSH,
 | 
								tailcfg.CapabilitySSH,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if randomClientPort {
 | 
							if cfg.RandomizeClientPort {
 | 
				
			||||||
			tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort)
 | 
								tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -182,13 +182,16 @@ func TestTailNode(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								cfg := &types.Config{
 | 
				
			||||||
 | 
									BaseDomain:          tt.baseDomain,
 | 
				
			||||||
 | 
									DNSConfig:           tt.dnsConfig,
 | 
				
			||||||
 | 
									RandomizeClientPort: false,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			got, err := tailNode(
 | 
								got, err := tailNode(
 | 
				
			||||||
				tt.node,
 | 
									tt.node,
 | 
				
			||||||
				0,
 | 
									0,
 | 
				
			||||||
				tt.pol,
 | 
									tt.pol,
 | 
				
			||||||
				tt.dnsConfig,
 | 
									cfg,
 | 
				
			||||||
				tt.baseDomain,
 | 
					 | 
				
			||||||
				false,
 | 
					 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if (err != nil) != tt.wantErr {
 | 
								if (err != nil) != tt.wantErr {
 | 
				
			||||||
 | 
				
			|||||||
@ -3,6 +3,7 @@ package hscontrol
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -11,6 +12,7 @@ import (
 | 
				
			|||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
	"golang.org/x/net/http2"
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
	"golang.org/x/net/http2/h2c"
 | 
						"golang.org/x/net/http2/h2c"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"tailscale.com/control/controlbase"
 | 
						"tailscale.com/control/controlbase"
 | 
				
			||||||
	"tailscale.com/control/controlhttp"
 | 
						"tailscale.com/control/controlhttp"
 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
@ -163,3 +165,135 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						MinimumCapVersion tailcfg.CapabilityVersion = 58
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This is the busiest endpoint, as it keeps the HTTP long poll that updates
 | 
				
			||||||
 | 
					// the clients when something in the network changes.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The clients POST stuff like HostInfo and their Endpoints here, but
 | 
				
			||||||
 | 
					// only after their first request (marked with the ReadOnly field).
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// At this moment the updates are sent in a quite horrendous way, but they kinda work.
 | 
				
			||||||
 | 
					func (ns *noiseServer) NoisePollNetMapHandler(
 | 
				
			||||||
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
 | 
						req *http.Request,
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("handler", "NoisePollNetMap").
 | 
				
			||||||
 | 
							Msg("PollNetMapHandler called")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Any("headers", req.Header).
 | 
				
			||||||
 | 
							Caller().
 | 
				
			||||||
 | 
							Msg("Headers")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, _ := io.ReadAll(req.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mapRequest := tailcfg.MapRequest{}
 | 
				
			||||||
 | 
						if err := json.Unmarshal(body, &mapRequest); err != nil {
 | 
				
			||||||
 | 
							log.Error().
 | 
				
			||||||
 | 
								Caller().
 | 
				
			||||||
 | 
								Err(err).
 | 
				
			||||||
 | 
								Msg("Cannot parse MapRequest")
 | 
				
			||||||
 | 
							http.Error(writer, "Internal error", http.StatusInternalServerError)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Reject unsupported versions
 | 
				
			||||||
 | 
						if mapRequest.Version < MinimumCapVersion {
 | 
				
			||||||
 | 
							log.Info().
 | 
				
			||||||
 | 
								Caller().
 | 
				
			||||||
 | 
								Int("min_version", int(MinimumCapVersion)).
 | 
				
			||||||
 | 
								Int("client_version", int(mapRequest.Version)).
 | 
				
			||||||
 | 
								Msg("unsupported client connected")
 | 
				
			||||||
 | 
							http.Error(writer, "Internal error", http.StatusBadRequest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ns.nodeKey = mapRequest.NodeKey
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						node, err := ns.headscale.db.GetNodeByAnyKey(
 | 
				
			||||||
 | 
							ns.conn.Peer(),
 | 
				
			||||||
 | 
							mapRequest.NodeKey,
 | 
				
			||||||
 | 
							key.NodePublic{},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
				
			||||||
 | 
								log.Warn().
 | 
				
			||||||
 | 
									Str("handler", "NoisePollNetMap").
 | 
				
			||||||
 | 
									Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
									Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
 | 
				
			||||||
 | 
								http.Error(writer, "Internal error", http.StatusNotFound)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							log.Error().
 | 
				
			||||||
 | 
								Str("handler", "NoisePollNetMap").
 | 
				
			||||||
 | 
								Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
								Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
 | 
				
			||||||
 | 
							http.Error(writer, "Internal error", http.StatusInternalServerError)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						log.Debug().
 | 
				
			||||||
 | 
							Str("handler", "NoisePollNetMap").
 | 
				
			||||||
 | 
							Str("node", node.Hostname).
 | 
				
			||||||
 | 
							Int("cap_ver", int(mapRequest.Version)).
 | 
				
			||||||
 | 
							Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
							Msg("A node sending a MapRequest with Noise protocol")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If a streaming mapSession exists for this node, close it
 | 
				
			||||||
 | 
						// and start a new one.
 | 
				
			||||||
 | 
						if session.isStreaming() {
 | 
				
			||||||
 | 
							log.Debug().
 | 
				
			||||||
 | 
								Caller().
 | 
				
			||||||
 | 
								Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
								Int("cap_ver", int(mapRequest.Version)).
 | 
				
			||||||
 | 
								Msg("Aquiring lock to check stream")
 | 
				
			||||||
 | 
							ns.headscale.mapSessionMu.Lock()
 | 
				
			||||||
 | 
							if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok {
 | 
				
			||||||
 | 
								log.Info().
 | 
				
			||||||
 | 
									Caller().
 | 
				
			||||||
 | 
									Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
									Msg("Node has an open streaming session, replacing")
 | 
				
			||||||
 | 
								oldSession.close()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							ns.headscale.mapSessions[node.ID] = session
 | 
				
			||||||
 | 
							ns.headscale.mapSessionMu.Unlock()
 | 
				
			||||||
 | 
							log.Debug().
 | 
				
			||||||
 | 
								Caller().
 | 
				
			||||||
 | 
								Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
								Int("cap_ver", int(mapRequest.Version)).
 | 
				
			||||||
 | 
								Msg("Releasing lock to check stream")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session.serve()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if session.isStreaming() {
 | 
				
			||||||
 | 
							log.Debug().
 | 
				
			||||||
 | 
								Caller().
 | 
				
			||||||
 | 
								Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
								Int("cap_ver", int(mapRequest.Version)).
 | 
				
			||||||
 | 
								Msg("Aquiring lock to remove stream")
 | 
				
			||||||
 | 
							ns.headscale.mapSessionMu.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							delete(ns.headscale.mapSessions, node.ID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							ns.headscale.mapSessionMu.Unlock()
 | 
				
			||||||
 | 
							log.Debug().
 | 
				
			||||||
 | 
								Caller().
 | 
				
			||||||
 | 
								Uint64("node.id", node.ID.Uint64()).
 | 
				
			||||||
 | 
								Int("cap_ver", int(mapRequest.Version)).
 | 
				
			||||||
 | 
								Msg("Releasing lock to remove stream")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -3,52 +3,51 @@ package notifier
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"slices"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
					 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
	"tailscale.com/types/key"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Notifier struct {
 | 
					type Notifier struct {
 | 
				
			||||||
	l         sync.RWMutex
 | 
						l         sync.RWMutex
 | 
				
			||||||
	nodes     map[string]chan<- types.StateUpdate
 | 
						nodes     map[types.NodeID]chan<- types.StateUpdate
 | 
				
			||||||
	connected map[key.MachinePublic]bool
 | 
						connected types.NodeConnectedMap
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewNotifier() *Notifier {
 | 
					func NewNotifier() *Notifier {
 | 
				
			||||||
	return &Notifier{
 | 
						return &Notifier{
 | 
				
			||||||
		nodes:     make(map[string]chan<- types.StateUpdate),
 | 
							nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
 | 
				
			||||||
		connected: make(map[key.MachinePublic]bool),
 | 
							connected: make(types.NodeConnectedMap),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
 | 
					func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
 | 
				
			||||||
	log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
 | 
						log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
 | 
				
			||||||
	defer log.Trace().
 | 
						defer log.Trace().
 | 
				
			||||||
		Caller().
 | 
							Caller().
 | 
				
			||||||
		Str("key", machineKey.ShortString()).
 | 
							Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
		Msg("releasing lock to add node")
 | 
							Msg("releasing lock to add node")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	n.l.Lock()
 | 
						n.l.Lock()
 | 
				
			||||||
	defer n.l.Unlock()
 | 
						defer n.l.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	n.nodes[machineKey.String()] = c
 | 
						n.nodes[nodeID] = c
 | 
				
			||||||
	n.connected[machineKey] = true
 | 
						n.connected[nodeID] = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("machine_key", machineKey.ShortString()).
 | 
							Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
		Int("open_chans", len(n.nodes)).
 | 
							Int("open_chans", len(n.nodes)).
 | 
				
			||||||
		Msg("Added new channel")
 | 
							Msg("Added new channel")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
 | 
					func (n *Notifier) RemoveNode(nodeID types.NodeID) {
 | 
				
			||||||
	log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
 | 
						log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node")
 | 
				
			||||||
	defer log.Trace().
 | 
						defer log.Trace().
 | 
				
			||||||
		Caller().
 | 
							Caller().
 | 
				
			||||||
		Str("key", machineKey.ShortString()).
 | 
							Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
		Msg("releasing lock to remove node")
 | 
							Msg("releasing lock to remove node")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	n.l.Lock()
 | 
						n.l.Lock()
 | 
				
			||||||
@ -58,26 +57,32 @@ func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	delete(n.nodes, machineKey.String())
 | 
						delete(n.nodes, nodeID)
 | 
				
			||||||
	n.connected[machineKey] = false
 | 
						n.connected[nodeID] = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("machine_key", machineKey.ShortString()).
 | 
							Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
		Int("open_chans", len(n.nodes)).
 | 
							Int("open_chans", len(n.nodes)).
 | 
				
			||||||
		Msg("Removed channel")
 | 
							Msg("Removed channel")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// IsConnected reports if a node is connected to headscale and has a
 | 
					// IsConnected reports if a node is connected to headscale and has a
 | 
				
			||||||
// poll session open.
 | 
					// poll session open.
 | 
				
			||||||
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
 | 
					func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
 | 
				
			||||||
	n.l.RLock()
 | 
						n.l.RLock()
 | 
				
			||||||
	defer n.l.RUnlock()
 | 
						defer n.l.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return n.connected[machineKey]
 | 
						return n.connected[nodeID]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IsLikelyConnected reports if a node is connected to headscale and has a
 | 
				
			||||||
 | 
					// poll session open, but doesnt lock, so might be wrong.
 | 
				
			||||||
 | 
					func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
 | 
				
			||||||
 | 
						return n.connected[nodeID]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(kradalby): This returns a pointer and can be dangerous.
 | 
					// TODO(kradalby): This returns a pointer and can be dangerous.
 | 
				
			||||||
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
 | 
					func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
 | 
				
			||||||
	return n.connected
 | 
						return n.connected
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -88,19 +93,23 @@ func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
 | 
				
			|||||||
func (n *Notifier) NotifyWithIgnore(
 | 
					func (n *Notifier) NotifyWithIgnore(
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	update types.StateUpdate,
 | 
						update types.StateUpdate,
 | 
				
			||||||
	ignore ...string,
 | 
						ignoreNodeIDs ...types.NodeID,
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
	log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
 | 
						log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
 | 
				
			||||||
	defer log.Trace().
 | 
						defer log.Trace().
 | 
				
			||||||
		Caller().
 | 
							Caller().
 | 
				
			||||||
		Interface("type", update.Type).
 | 
							Str("type", update.Type.String()).
 | 
				
			||||||
		Msg("releasing lock, finished notifying")
 | 
							Msg("releasing lock, finished notifying")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	n.l.RLock()
 | 
						n.l.RLock()
 | 
				
			||||||
	defer n.l.RUnlock()
 | 
						defer n.l.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for key, c := range n.nodes {
 | 
						if update.Type == types.StatePeerChangedPatch {
 | 
				
			||||||
		if util.IsStringInSlice(ignore, key) {
 | 
							log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for nodeID, c := range n.nodes {
 | 
				
			||||||
 | 
							if slices.Contains(ignoreNodeIDs, nodeID) {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -108,17 +117,17 @@ func (n *Notifier) NotifyWithIgnore(
 | 
				
			|||||||
		case <-ctx.Done():
 | 
							case <-ctx.Done():
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Err(ctx.Err()).
 | 
									Err(ctx.Err()).
 | 
				
			||||||
				Str("mkey", key).
 | 
									Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
				Any("origin", ctx.Value("origin")).
 | 
									Any("origin", ctx.Value("origin")).
 | 
				
			||||||
				Any("hostname", ctx.Value("hostname")).
 | 
									Any("origin-hostname", ctx.Value("hostname")).
 | 
				
			||||||
				Msgf("update not sent, context cancelled")
 | 
									Msgf("update not sent, context cancelled")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		case c <- update:
 | 
							case c <- update:
 | 
				
			||||||
			log.Trace().
 | 
								log.Trace().
 | 
				
			||||||
				Str("mkey", key).
 | 
									Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
				Any("origin", ctx.Value("origin")).
 | 
									Any("origin", ctx.Value("origin")).
 | 
				
			||||||
				Any("hostname", ctx.Value("hostname")).
 | 
									Any("origin-hostname", ctx.Value("hostname")).
 | 
				
			||||||
				Msgf("update successfully sent on chan")
 | 
									Msgf("update successfully sent on chan")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -127,33 +136,33 @@ func (n *Notifier) NotifyWithIgnore(
 | 
				
			|||||||
func (n *Notifier) NotifyByMachineKey(
 | 
					func (n *Notifier) NotifyByMachineKey(
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	update types.StateUpdate,
 | 
						update types.StateUpdate,
 | 
				
			||||||
	mKey key.MachinePublic,
 | 
						nodeID types.NodeID,
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
	log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
 | 
						log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
 | 
				
			||||||
	defer log.Trace().
 | 
						defer log.Trace().
 | 
				
			||||||
		Caller().
 | 
							Caller().
 | 
				
			||||||
		Interface("type", update.Type).
 | 
							Str("type", update.Type.String()).
 | 
				
			||||||
		Msg("releasing lock, finished notifying")
 | 
							Msg("releasing lock, finished notifying")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	n.l.RLock()
 | 
						n.l.RLock()
 | 
				
			||||||
	defer n.l.RUnlock()
 | 
						defer n.l.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c, ok := n.nodes[mKey.String()]; ok {
 | 
						if c, ok := n.nodes[nodeID]; ok {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case <-ctx.Done():
 | 
							case <-ctx.Done():
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Err(ctx.Err()).
 | 
									Err(ctx.Err()).
 | 
				
			||||||
				Str("mkey", mKey.String()).
 | 
									Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
				Any("origin", ctx.Value("origin")).
 | 
									Any("origin", ctx.Value("origin")).
 | 
				
			||||||
				Any("hostname", ctx.Value("hostname")).
 | 
									Any("origin-hostname", ctx.Value("hostname")).
 | 
				
			||||||
				Msgf("update not sent, context cancelled")
 | 
									Msgf("update not sent, context cancelled")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		case c <- update:
 | 
							case c <- update:
 | 
				
			||||||
			log.Trace().
 | 
								log.Trace().
 | 
				
			||||||
				Str("mkey", mKey.String()).
 | 
									Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
				Any("origin", ctx.Value("origin")).
 | 
									Any("origin", ctx.Value("origin")).
 | 
				
			||||||
				Any("hostname", ctx.Value("hostname")).
 | 
									Any("origin-hostname", ctx.Value("hostname")).
 | 
				
			||||||
				Msgf("update successfully sent on chan")
 | 
									Msgf("update successfully sent on chan")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -166,7 +175,7 @@ func (n *Notifier) String() string {
 | 
				
			|||||||
	str := []string{"Notifier, in map:\n"}
 | 
						str := []string{"Notifier, in map:\n"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for k, v := range n.nodes {
 | 
						for k, v := range n.nodes {
 | 
				
			||||||
		str = append(str, fmt.Sprintf("\t%s: %v\n", k, v))
 | 
							str = append(str, fmt.Sprintf("\t%d: %v\n", k, v))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return strings.Join(str, "")
 | 
						return strings.Join(str, "")
 | 
				
			||||||
 | 
				
			|||||||
@ -537,11 +537,8 @@ func (h *Headscale) validateNodeForOIDCCallback(
 | 
				
			|||||||
			util.LogErr(err, "Failed to write response")
 | 
								util.LogErr(err, "Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		stateUpdate := types.StateUpdateExpire(node.ID, expiry)
 | 
					 | 
				
			||||||
		if stateUpdate.Valid() {
 | 
					 | 
				
			||||||
		ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
 | 
							ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
 | 
				
			||||||
			h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
 | 
							h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, true, nil
 | 
							return nil, true, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -114,7 +114,7 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
 | 
				
			|||||||
	return &policy, nil
 | 
						return &policy, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GenerateFilterAndSSHRules(
 | 
					func GenerateFilterAndSSHRulesForTests(
 | 
				
			||||||
	policy *ACLPolicy,
 | 
						policy *ACLPolicy,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	peers types.Nodes,
 | 
						peers types.Nodes,
 | 
				
			||||||
@ -124,40 +124,31 @@ func GenerateFilterAndSSHRules(
 | 
				
			|||||||
		return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
 | 
							return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules, err := policy.generateFilterRules(node, peers)
 | 
						rules, err := policy.CompileFilterRules(append(peers, node))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
 | 
							return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
 | 
						log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var sshPolicy *tailcfg.SSHPolicy
 | 
						sshPolicy, err := policy.CompileSSHPolicy(node, peers)
 | 
				
			||||||
	sshRules, err := policy.generateSSHRules(node, peers)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
 | 
							return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Interface("SSH", sshRules).
 | 
					 | 
				
			||||||
		Str("node", node.GivenName).
 | 
					 | 
				
			||||||
		Msg("SSH rules")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if sshPolicy == nil {
 | 
					 | 
				
			||||||
		sshPolicy = &tailcfg.SSHPolicy{}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	sshPolicy.Rules = sshRules
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return rules, sshPolicy, nil
 | 
						return rules, sshPolicy, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// generateFilterRules takes a set of nodes and an ACLPolicy and generates a
 | 
					// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
 | 
				
			||||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
 | 
					// set of Tailscale compatible FilterRules used to allow traffic on clients.
 | 
				
			||||||
func (pol *ACLPolicy) generateFilterRules(
 | 
					func (pol *ACLPolicy) CompileFilterRules(
 | 
				
			||||||
	node *types.Node,
 | 
						nodes types.Nodes,
 | 
				
			||||||
	peers types.Nodes,
 | 
					 | 
				
			||||||
) ([]tailcfg.FilterRule, error) {
 | 
					) ([]tailcfg.FilterRule, error) {
 | 
				
			||||||
 | 
						if pol == nil {
 | 
				
			||||||
 | 
							return tailcfg.FilterAllowAll, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules := []tailcfg.FilterRule{}
 | 
						rules := []tailcfg.FilterRule{}
 | 
				
			||||||
	nodes := append(peers, node)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for index, acl := range pol.ACLs {
 | 
						for index, acl := range pol.ACLs {
 | 
				
			||||||
		if acl.Action != "accept" {
 | 
							if acl.Action != "accept" {
 | 
				
			||||||
@ -279,10 +270,14 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
 | 
				
			|||||||
	return ret
 | 
						return ret
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (pol *ACLPolicy) generateSSHRules(
 | 
					func (pol *ACLPolicy) CompileSSHPolicy(
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	peers types.Nodes,
 | 
						peers types.Nodes,
 | 
				
			||||||
) ([]*tailcfg.SSHRule, error) {
 | 
					) (*tailcfg.SSHPolicy, error) {
 | 
				
			||||||
 | 
						if pol == nil {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules := []*tailcfg.SSHRule{}
 | 
						rules := []*tailcfg.SSHRule{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	acceptAction := tailcfg.SSHAction{
 | 
						acceptAction := tailcfg.SSHAction{
 | 
				
			||||||
@ -393,7 +388,9 @@ func (pol *ACLPolicy) generateSSHRules(
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return rules, nil
 | 
						return &tailcfg.SSHPolicy{
 | 
				
			||||||
 | 
							Rules: rules,
 | 
				
			||||||
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
 | 
					func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
 | 
				
			||||||
 | 
				
			|||||||
@ -385,11 +385,12 @@ acls:
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			rules, err := pol.generateFilterRules(&types.Node{
 | 
								rules, err := pol.CompileFilterRules(types.Nodes{
 | 
				
			||||||
 | 
									&types.Node{
 | 
				
			||||||
					IPAddresses: types.NodeAddresses{
 | 
										IPAddresses: types.NodeAddresses{
 | 
				
			||||||
						netip.MustParseAddr("100.100.100.100"),
 | 
											netip.MustParseAddr("100.100.100.100"),
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
			}, types.Nodes{
 | 
									},
 | 
				
			||||||
				&types.Node{
 | 
									&types.Node{
 | 
				
			||||||
					IPAddresses: types.NodeAddresses{
 | 
										IPAddresses: types.NodeAddresses{
 | 
				
			||||||
						netip.MustParseAddr("200.200.200.200"),
 | 
											netip.MustParseAddr("200.200.200.200"),
 | 
				
			||||||
@ -546,7 +547,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
 | 
				
			|||||||
	c.Assert(pol.ACLs, check.HasLen, 6)
 | 
						c.Assert(pol.ACLs, check.HasLen, 6)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rules, err := pol.generateFilterRules(&types.Node{}, types.Nodes{})
 | 
						rules, err := pol.CompileFilterRules(types.Nodes{})
 | 
				
			||||||
	c.Assert(err, check.NotNil)
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
	c.Assert(rules, check.IsNil)
 | 
						c.Assert(rules, check.IsNil)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -562,7 +563,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
 | 
				
			|||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{})
 | 
						_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
 | 
				
			||||||
	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
 | 
						c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -581,7 +582,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
 | 
				
			|||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{})
 | 
						_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
 | 
				
			||||||
	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
 | 
						c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -597,7 +598,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{})
 | 
						_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
 | 
				
			||||||
	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
 | 
						c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1724,8 +1725,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
 | 
				
			|||||||
		pol ACLPolicy
 | 
							pol ACLPolicy
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	type args struct {
 | 
						type args struct {
 | 
				
			||||||
		node  *types.Node
 | 
							nodes types.Nodes
 | 
				
			||||||
		peers types.Nodes
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	tests := []struct {
 | 
						tests := []struct {
 | 
				
			||||||
		name    string
 | 
							name    string
 | 
				
			||||||
@ -1755,13 +1755,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			args: args{
 | 
								args: args{
 | 
				
			||||||
				node: &types.Node{
 | 
									nodes: types.Nodes{
 | 
				
			||||||
 | 
										&types.Node{
 | 
				
			||||||
						IPAddresses: types.NodeAddresses{
 | 
											IPAddresses: types.NodeAddresses{
 | 
				
			||||||
							netip.MustParseAddr("100.64.0.1"),
 | 
												netip.MustParseAddr("100.64.0.1"),
 | 
				
			||||||
							netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
 | 
												netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
 | 
				
			||||||
						},
 | 
											},
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				peers: types.Nodes{},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: []tailcfg.FilterRule{
 | 
								want: []tailcfg.FilterRule{
 | 
				
			||||||
				{
 | 
									{
 | 
				
			||||||
@ -1800,14 +1801,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			args: args{
 | 
								args: args{
 | 
				
			||||||
				node: &types.Node{
 | 
									nodes: types.Nodes{
 | 
				
			||||||
 | 
										&types.Node{
 | 
				
			||||||
						IPAddresses: types.NodeAddresses{
 | 
											IPAddresses: types.NodeAddresses{
 | 
				
			||||||
							netip.MustParseAddr("100.64.0.1"),
 | 
												netip.MustParseAddr("100.64.0.1"),
 | 
				
			||||||
							netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
 | 
												netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
 | 
				
			||||||
						},
 | 
											},
 | 
				
			||||||
						User: types.User{Name: "mickael"},
 | 
											User: types.User{Name: "mickael"},
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				peers: types.Nodes{
 | 
					 | 
				
			||||||
					&types.Node{
 | 
										&types.Node{
 | 
				
			||||||
						IPAddresses: types.NodeAddresses{
 | 
											IPAddresses: types.NodeAddresses{
 | 
				
			||||||
							netip.MustParseAddr("100.64.0.2"),
 | 
												netip.MustParseAddr("100.64.0.2"),
 | 
				
			||||||
@ -1846,9 +1847,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			got, err := tt.field.pol.generateFilterRules(
 | 
								got, err := tt.field.pol.CompileFilterRules(
 | 
				
			||||||
				tt.args.node,
 | 
									tt.args.nodes,
 | 
				
			||||||
				tt.args.peers,
 | 
					 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
			if (err != nil) != tt.wantErr {
 | 
								if (err != nil) != tt.wantErr {
 | 
				
			||||||
				t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
 | 
									t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
 | 
				
			||||||
@ -1980,9 +1980,8 @@ func TestReduceFilterRules(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			rules, _ := tt.pol.generateFilterRules(
 | 
								rules, _ := tt.pol.CompileFilterRules(
 | 
				
			||||||
				tt.node,
 | 
									append(tt.peers, tt.node),
 | 
				
			||||||
				tt.peers,
 | 
					 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			got := ReduceFilterRules(tt.node, rules)
 | 
								got := ReduceFilterRules(tt.node, rules)
 | 
				
			||||||
@ -2883,7 +2882,7 @@ func TestSSHRules(t *testing.T) {
 | 
				
			|||||||
		node  types.Node
 | 
							node  types.Node
 | 
				
			||||||
		peers types.Nodes
 | 
							peers types.Nodes
 | 
				
			||||||
		pol   ACLPolicy
 | 
							pol   ACLPolicy
 | 
				
			||||||
		want  []*tailcfg.SSHRule
 | 
							want  *tailcfg.SSHPolicy
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "peers-can-connect",
 | 
								name: "peers-can-connect",
 | 
				
			||||||
@ -2946,7 +2945,7 @@ func TestSSHRules(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: []*tailcfg.SSHRule{
 | 
								want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
 | 
				
			||||||
				{
 | 
									{
 | 
				
			||||||
					Principals: []*tailcfg.SSHPrincipal{
 | 
										Principals: []*tailcfg.SSHPrincipal{
 | 
				
			||||||
						{
 | 
											{
 | 
				
			||||||
@ -2991,7 +2990,7 @@ func TestSSHRules(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
					Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
 | 
										Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "peers-cannot-connect",
 | 
								name: "peers-cannot-connect",
 | 
				
			||||||
@ -3042,13 +3041,13 @@ func TestSSHRules(t *testing.T) {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			want: []*tailcfg.SSHRule{},
 | 
								want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			got, err := tt.pol.generateSSHRules(&tt.node, tt.peers)
 | 
								got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
 | 
				
			||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if diff := cmp.Diff(tt.want, got); diff != "" {
 | 
								if diff := cmp.Diff(tt.want, got); diff != "" {
 | 
				
			||||||
@ -3155,7 +3154,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{})
 | 
						got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	want := []tailcfg.FilterRule{
 | 
						want := []tailcfg.FilterRule{
 | 
				
			||||||
@ -3206,7 +3205,7 @@ func TestInvalidTagValidUser(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{})
 | 
						got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	want := []tailcfg.FilterRule{
 | 
						want := []tailcfg.FilterRule{
 | 
				
			||||||
@ -3265,7 +3264,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
 | 
				
			|||||||
	// c.Assert(rules[0].DstPorts, check.HasLen, 1)
 | 
						// c.Assert(rules[0].DstPorts, check.HasLen, 1)
 | 
				
			||||||
	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
 | 
						// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{})
 | 
						got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	want := []tailcfg.FilterRule{
 | 
						want := []tailcfg.FilterRule{
 | 
				
			||||||
@ -3335,7 +3334,7 @@ func TestValidTagInvalidUser(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{nodes2})
 | 
						got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	want := []tailcfg.FilterRule{
 | 
						want := []tailcfg.FilterRule{
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										1066
									
								
								hscontrol/poll.go
									
									
									
									
									
								
							
							
						
						
									
										1066
									
								
								hscontrol/poll.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -1,96 +0,0 @@
 | 
				
			|||||||
package hscontrol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"encoding/json"
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
					 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
					 | 
				
			||||||
	"tailscale.com/types/key"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	MinimumCapVersion tailcfg.CapabilityVersion = 58
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
 | 
					 | 
				
			||||||
// the clients when something in the network changes.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// The clients POST stuff like HostInfo and their Endpoints here, but
 | 
					 | 
				
			||||||
// only after their first request (marked with the ReadOnly field).
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
 | 
					 | 
				
			||||||
func (ns *noiseServer) NoisePollNetMapHandler(
 | 
					 | 
				
			||||||
	writer http.ResponseWriter,
 | 
					 | 
				
			||||||
	req *http.Request,
 | 
					 | 
				
			||||||
) {
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("handler", "NoisePollNetMap").
 | 
					 | 
				
			||||||
		Msg("PollNetMapHandler called")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Any("headers", req.Header).
 | 
					 | 
				
			||||||
		Caller().
 | 
					 | 
				
			||||||
		Msg("Headers")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	body, _ := io.ReadAll(req.Body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	mapRequest := tailcfg.MapRequest{}
 | 
					 | 
				
			||||||
	if err := json.Unmarshal(body, &mapRequest); err != nil {
 | 
					 | 
				
			||||||
		log.Error().
 | 
					 | 
				
			||||||
			Caller().
 | 
					 | 
				
			||||||
			Err(err).
 | 
					 | 
				
			||||||
			Msg("Cannot parse MapRequest")
 | 
					 | 
				
			||||||
		http.Error(writer, "Internal error", http.StatusInternalServerError)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Reject unsupported versions
 | 
					 | 
				
			||||||
	if mapRequest.Version < MinimumCapVersion {
 | 
					 | 
				
			||||||
		log.Info().
 | 
					 | 
				
			||||||
			Caller().
 | 
					 | 
				
			||||||
			Int("min_version", int(MinimumCapVersion)).
 | 
					 | 
				
			||||||
			Int("client_version", int(mapRequest.Version)).
 | 
					 | 
				
			||||||
			Msg("unsupported client connected")
 | 
					 | 
				
			||||||
		http.Error(writer, "Internal error", http.StatusBadRequest)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ns.nodeKey = mapRequest.NodeKey
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	node, err := ns.headscale.db.GetNodeByAnyKey(
 | 
					 | 
				
			||||||
		ns.conn.Peer(),
 | 
					 | 
				
			||||||
		mapRequest.NodeKey,
 | 
					 | 
				
			||||||
		key.NodePublic{},
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
					 | 
				
			||||||
			log.Warn().
 | 
					 | 
				
			||||||
				Str("handler", "NoisePollNetMap").
 | 
					 | 
				
			||||||
				Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
 | 
					 | 
				
			||||||
			http.Error(writer, "Internal error", http.StatusNotFound)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		log.Error().
 | 
					 | 
				
			||||||
			Str("handler", "NoisePollNetMap").
 | 
					 | 
				
			||||||
			Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
 | 
					 | 
				
			||||||
		http.Error(writer, "Internal error", http.StatusInternalServerError)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	log.Debug().
 | 
					 | 
				
			||||||
		Str("handler", "NoisePollNetMap").
 | 
					 | 
				
			||||||
		Str("node", node.Hostname).
 | 
					 | 
				
			||||||
		Int("cap_ver", int(mapRequest.Version)).
 | 
					 | 
				
			||||||
		Msg("A node sending a MapRequest with Noise protocol")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ns.headscale.handlePoll(writer, req.Context(), node, mapRequest)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -90,6 +90,25 @@ func (i StringList) Value() (driver.Value, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type StateUpdateType int
 | 
					type StateUpdateType int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (su StateUpdateType) String() string {
 | 
				
			||||||
 | 
						switch su {
 | 
				
			||||||
 | 
						case StateFullUpdate:
 | 
				
			||||||
 | 
							return "StateFullUpdate"
 | 
				
			||||||
 | 
						case StatePeerChanged:
 | 
				
			||||||
 | 
							return "StatePeerChanged"
 | 
				
			||||||
 | 
						case StatePeerChangedPatch:
 | 
				
			||||||
 | 
							return "StatePeerChangedPatch"
 | 
				
			||||||
 | 
						case StatePeerRemoved:
 | 
				
			||||||
 | 
							return "StatePeerRemoved"
 | 
				
			||||||
 | 
						case StateSelfUpdate:
 | 
				
			||||||
 | 
							return "StateSelfUpdate"
 | 
				
			||||||
 | 
						case StateDERPUpdated:
 | 
				
			||||||
 | 
							return "StateDERPUpdated"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return "unknown state update type"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	StateFullUpdate StateUpdateType = iota
 | 
						StateFullUpdate StateUpdateType = iota
 | 
				
			||||||
	// StatePeerChanged is used for updates that needs
 | 
						// StatePeerChanged is used for updates that needs
 | 
				
			||||||
@ -118,7 +137,7 @@ type StateUpdate struct {
 | 
				
			|||||||
	// ChangeNodes must be set when Type is StatePeerAdded
 | 
						// ChangeNodes must be set when Type is StatePeerAdded
 | 
				
			||||||
	// and StatePeerChanged and contains the full node
 | 
						// and StatePeerChanged and contains the full node
 | 
				
			||||||
	// object for added nodes.
 | 
						// object for added nodes.
 | 
				
			||||||
	ChangeNodes Nodes
 | 
						ChangeNodes []NodeID
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ChangePatches must be set when Type is StatePeerChangedPatch
 | 
						// ChangePatches must be set when Type is StatePeerChangedPatch
 | 
				
			||||||
	// and contains a populated PeerChange object.
 | 
						// and contains a populated PeerChange object.
 | 
				
			||||||
@ -127,7 +146,7 @@ type StateUpdate struct {
 | 
				
			|||||||
	// Removed must be set when Type is StatePeerRemoved and
 | 
						// Removed must be set when Type is StatePeerRemoved and
 | 
				
			||||||
	// contain a list of the nodes that has been removed from
 | 
						// contain a list of the nodes that has been removed from
 | 
				
			||||||
	// the network.
 | 
						// the network.
 | 
				
			||||||
	Removed []tailcfg.NodeID
 | 
						Removed []NodeID
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// DERPMap must be set when Type is StateDERPUpdated and
 | 
						// DERPMap must be set when Type is StateDERPUpdated and
 | 
				
			||||||
	// contain the new DERP Map.
 | 
						// contain the new DERP Map.
 | 
				
			||||||
@ -138,39 +157,6 @@ type StateUpdate struct {
 | 
				
			|||||||
	Message string
 | 
						Message string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Valid reports if a StateUpdate is correctly filled and
 | 
					 | 
				
			||||||
// panics if the mandatory fields for a type is not
 | 
					 | 
				
			||||||
// filled.
 | 
					 | 
				
			||||||
// Reports true if valid.
 | 
					 | 
				
			||||||
func (su *StateUpdate) Valid() bool {
 | 
					 | 
				
			||||||
	switch su.Type {
 | 
					 | 
				
			||||||
	case StatePeerChanged:
 | 
					 | 
				
			||||||
		if su.ChangeNodes == nil {
 | 
					 | 
				
			||||||
			panic("Mandatory field ChangeNodes is not set on StatePeerChanged update")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case StatePeerChangedPatch:
 | 
					 | 
				
			||||||
		if su.ChangePatches == nil {
 | 
					 | 
				
			||||||
			panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case StatePeerRemoved:
 | 
					 | 
				
			||||||
		if su.Removed == nil {
 | 
					 | 
				
			||||||
			panic("Mandatory field Removed is not set on StatePeerRemove update")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case StateSelfUpdate:
 | 
					 | 
				
			||||||
		if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
 | 
					 | 
				
			||||||
			panic(
 | 
					 | 
				
			||||||
				"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
 | 
					 | 
				
			||||||
			)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case StateDERPUpdated:
 | 
					 | 
				
			||||||
		if su.DERPMap == nil {
 | 
					 | 
				
			||||||
			panic("Mandatory field DERPMap is not set on StateDERPUpdated update")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Empty reports if there are any updates in the StateUpdate.
 | 
					// Empty reports if there are any updates in the StateUpdate.
 | 
				
			||||||
func (su *StateUpdate) Empty() bool {
 | 
					func (su *StateUpdate) Empty() bool {
 | 
				
			||||||
	switch su.Type {
 | 
						switch su.Type {
 | 
				
			||||||
@ -185,12 +171,12 @@ func (su *StateUpdate) Empty() bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
 | 
					func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
 | 
				
			||||||
	return StateUpdate{
 | 
						return StateUpdate{
 | 
				
			||||||
		Type: StatePeerChangedPatch,
 | 
							Type: StatePeerChangedPatch,
 | 
				
			||||||
		ChangePatches: []*tailcfg.PeerChange{
 | 
							ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				NodeID:    tailcfg.NodeID(nodeID),
 | 
									NodeID:    nodeID.NodeID(),
 | 
				
			||||||
				KeyExpiry: &expiry,
 | 
									KeyExpiry: &expiry,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
				
			|||||||
@ -69,6 +69,8 @@ type Config struct {
 | 
				
			|||||||
	CLI CLIConfig
 | 
						CLI CLIConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ACL ACLConfig
 | 
						ACL ACLConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Tuning Tuning
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type SqliteConfig struct {
 | 
					type SqliteConfig struct {
 | 
				
			||||||
@ -161,6 +163,11 @@ type LogConfig struct {
 | 
				
			|||||||
	Level  zerolog.Level
 | 
						Level  zerolog.Level
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Tuning struct {
 | 
				
			||||||
 | 
						BatchChangeDelay               time.Duration
 | 
				
			||||||
 | 
						NodeMapSessionBufferedChanSize int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func LoadConfig(path string, isFile bool) error {
 | 
					func LoadConfig(path string, isFile bool) error {
 | 
				
			||||||
	if isFile {
 | 
						if isFile {
 | 
				
			||||||
		viper.SetConfigFile(path)
 | 
							viper.SetConfigFile(path)
 | 
				
			||||||
@ -220,6 +227,9 @@ func LoadConfig(path string, isFile bool) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	viper.SetDefault("node_update_check_interval", "10s")
 | 
						viper.SetDefault("node_update_check_interval", "10s")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						viper.SetDefault("tuning.batch_change_delay", "800ms")
 | 
				
			||||||
 | 
						viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if IsCLIConfigured() {
 | 
						if IsCLIConfigured() {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -719,6 +729,12 @@ func GetHeadscaleConfig() (*Config, error) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Log: GetLogConfig(),
 | 
							Log: GetLogConfig(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// TODO(kradalby): Document these settings when more stable
 | 
				
			||||||
 | 
							Tuning: Tuning{
 | 
				
			||||||
 | 
								BatchChangeDelay:               viper.GetDuration("tuning.batch_change_delay"),
 | 
				
			||||||
 | 
								NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -7,11 +7,13 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
						v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
	"go4.org/netipx"
 | 
						"go4.org/netipx"
 | 
				
			||||||
	"google.golang.org/protobuf/types/known/timestamppb"
 | 
						"google.golang.org/protobuf/types/known/timestamppb"
 | 
				
			||||||
@ -27,9 +29,24 @@ var (
 | 
				
			|||||||
	ErrNodeUserHasNoName    = errors.New("node user has no name")
 | 
						ErrNodeUserHasNoName    = errors.New("node user has no name")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type NodeID uint64
 | 
				
			||||||
 | 
					type NodeConnectedMap map[NodeID]bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (id NodeID) StableID() tailcfg.StableNodeID {
 | 
				
			||||||
 | 
						return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (id NodeID) NodeID() tailcfg.NodeID {
 | 
				
			||||||
 | 
						return tailcfg.NodeID(id)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (id NodeID) Uint64() uint64 {
 | 
				
			||||||
 | 
						return uint64(id)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Node is a Headscale client.
 | 
					// Node is a Headscale client.
 | 
				
			||||||
type Node struct {
 | 
					type Node struct {
 | 
				
			||||||
	ID uint64 `gorm:"primary_key"`
 | 
						ID NodeID `gorm:"primary_key"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// MachineKeyDatabaseField is the string representation of MachineKey
 | 
						// MachineKeyDatabaseField is the string representation of MachineKey
 | 
				
			||||||
	// it is _only_ used for reading and writing the key to the
 | 
						// it is _only_ used for reading and writing the key to the
 | 
				
			||||||
@ -198,7 +215,7 @@ func (node Node) IsExpired() bool {
 | 
				
			|||||||
		return false
 | 
							return false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return time.Now().UTC().After(*node.Expiry)
 | 
						return time.Since(*node.Expiry) > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// IsEphemeral returns if the node is registered as an Ephemeral node.
 | 
					// IsEphemeral returns if the node is registered as an Ephemeral node.
 | 
				
			||||||
@ -319,7 +336,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (node *Node) Proto() *v1.Node {
 | 
					func (node *Node) Proto() *v1.Node {
 | 
				
			||||||
	nodeProto := &v1.Node{
 | 
						nodeProto := &v1.Node{
 | 
				
			||||||
		Id:         node.ID,
 | 
							Id:         uint64(node.ID),
 | 
				
			||||||
		MachineKey: node.MachineKey.String(),
 | 
							MachineKey: node.MachineKey.String(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		NodeKey:     node.NodeKey.String(),
 | 
							NodeKey:     node.NodeKey.String(),
 | 
				
			||||||
@ -486,8 +503,8 @@ func (nodes Nodes) String() string {
 | 
				
			|||||||
	return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
 | 
						return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (nodes Nodes) IDMap() map[uint64]*Node {
 | 
					func (nodes Nodes) IDMap() map[NodeID]*Node {
 | 
				
			||||||
	ret := map[uint64]*Node{}
 | 
						ret := map[NodeID]*Node{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, node := range nodes {
 | 
						for _, node := range nodes {
 | 
				
			||||||
		ret[node.ID] = node
 | 
							ret[node.ID] = node
 | 
				
			||||||
 | 
				
			|||||||
@ -83,7 +83,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
@ -142,7 +142,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
 | 
				
			|||||||
@ -53,7 +53,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
@ -92,7 +92,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
 | 
				
			|||||||
@ -65,7 +65,7 @@ func TestPingAllByIP(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
@ -103,7 +103,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
@ -135,7 +135,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	clientIPs := make(map[TailscaleClient][]netip.Addr)
 | 
						clientIPs := make(map[TailscaleClient][]netip.Addr)
 | 
				
			||||||
	for _, client := range allClients {
 | 
						for _, client := range allClients {
 | 
				
			||||||
@ -176,7 +176,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allClients, err = scenario.ListTailscaleClients()
 | 
						allClients, err = scenario.ListTailscaleClients()
 | 
				
			||||||
	assertNoErrListClients(t, err)
 | 
						assertNoErrListClients(t, err)
 | 
				
			||||||
@ -329,7 +329,7 @@ func TestPingAllByHostname(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allHostnames, err := scenario.ListTailscaleClientsFQDNs()
 | 
						allHostnames, err := scenario.ListTailscaleClientsFQDNs()
 | 
				
			||||||
	assertNoErrListFQDN(t, err)
 | 
						assertNoErrListFQDN(t, err)
 | 
				
			||||||
@ -539,7 +539,7 @@ func TestResolveMagicDNS(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Poor mans cache
 | 
						// Poor mans cache
 | 
				
			||||||
	_, err = scenario.ListTailscaleClientsFQDNs()
 | 
						_, err = scenario.ListTailscaleClientsFQDNs()
 | 
				
			||||||
@ -609,7 +609,7 @@ func TestExpireNode(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
@ -711,7 +711,7 @@ func TestExpireNode(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
					func TestNodeOnlineStatus(t *testing.T) {
 | 
				
			||||||
	IntegrationSkip(t)
 | 
						IntegrationSkip(t)
 | 
				
			||||||
	t.Parallel()
 | 
						t.Parallel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -723,7 +723,7 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
		"user1": len(MustTestVersions),
 | 
							"user1": len(MustTestVersions),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen"))
 | 
						err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online"))
 | 
				
			||||||
	assertNoErrHeadscaleEnv(t, err)
 | 
						assertNoErrHeadscaleEnv(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allClients, err := scenario.ListTailscaleClients()
 | 
						allClients, err := scenario.ListTailscaleClients()
 | 
				
			||||||
@ -735,7 +735,7 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
	err = scenario.WaitForTailscaleSync()
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
	assertNoErrSync(t, err)
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertClientsState(t, allClients)
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
		return x.String()
 | 
							return x.String()
 | 
				
			||||||
@ -755,8 +755,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
	headscale, err := scenario.Headscale()
 | 
						headscale, err := scenario.Headscale()
 | 
				
			||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	keepAliveInterval := 60 * time.Second
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Duration is chosen arbitrarily, 10m is reported in #1561
 | 
						// Duration is chosen arbitrarily, 10m is reported in #1561
 | 
				
			||||||
	testDuration := 12 * time.Minute
 | 
						testDuration := 12 * time.Minute
 | 
				
			||||||
	start := time.Now()
 | 
						start := time.Now()
 | 
				
			||||||
@ -780,11 +778,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
		err = json.Unmarshal([]byte(result), &nodes)
 | 
							err = json.Unmarshal([]byte(result), &nodes)
 | 
				
			||||||
		assertNoErr(t, err)
 | 
							assertNoErr(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		now := time.Now()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Threshold with some leeway
 | 
					 | 
				
			||||||
		lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Verify that headscale reports the nodes as online
 | 
							// Verify that headscale reports the nodes as online
 | 
				
			||||||
		for _, node := range nodes {
 | 
							for _, node := range nodes {
 | 
				
			||||||
			// All nodes should be online
 | 
								// All nodes should be online
 | 
				
			||||||
@ -795,18 +788,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
				node.GetName(),
 | 
									node.GetName(),
 | 
				
			||||||
				time.Since(start),
 | 
									time.Since(start),
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					 | 
				
			||||||
			lastSeen := node.GetLastSeen().AsTime()
 | 
					 | 
				
			||||||
			// All nodes should have been last seen between now and the keepAliveInterval
 | 
					 | 
				
			||||||
			assert.Truef(
 | 
					 | 
				
			||||||
				t,
 | 
					 | 
				
			||||||
				lastSeen.After(lastSeenThreshold),
 | 
					 | 
				
			||||||
				"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
 | 
					 | 
				
			||||||
				node.GetName(),
 | 
					 | 
				
			||||||
				lastSeen,
 | 
					 | 
				
			||||||
				keepAliveInterval,
 | 
					 | 
				
			||||||
				lastSeenThreshold,
 | 
					 | 
				
			||||||
			)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Verify that all nodes report all nodes to be online
 | 
							// Verify that all nodes report all nodes to be online
 | 
				
			||||||
@ -834,15 +815,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
					client.Hostname(),
 | 
										client.Hostname(),
 | 
				
			||||||
					time.Since(start),
 | 
										time.Since(start),
 | 
				
			||||||
				)
 | 
									)
 | 
				
			||||||
 | 
					 | 
				
			||||||
				// from docs: last seen to tailcontrol; only present if offline
 | 
					 | 
				
			||||||
				// assert.Nilf(
 | 
					 | 
				
			||||||
				// 	t,
 | 
					 | 
				
			||||||
				// 	peerStatus.LastSeen,
 | 
					 | 
				
			||||||
				// 	"expected node %s to not have LastSeen set, got %s",
 | 
					 | 
				
			||||||
				// 	peerStatus.HostName,
 | 
					 | 
				
			||||||
				// 	peerStatus.LastSeen,
 | 
					 | 
				
			||||||
				// )
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -850,3 +822,87 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 | 
				
			|||||||
		time.Sleep(time.Second)
 | 
							time.Sleep(time.Second)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TestPingAllByIPManyUpDown is a variant of the PingAll
 | 
				
			||||||
 | 
					// test which will take the tailscale node up and down
 | 
				
			||||||
 | 
					// five times ensuring they are able to restablish connectivity.
 | 
				
			||||||
 | 
					func TestPingAllByIPManyUpDown(t *testing.T) {
 | 
				
			||||||
 | 
						IntegrationSkip(t)
 | 
				
			||||||
 | 
						t.Parallel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scenario, err := NewScenario()
 | 
				
			||||||
 | 
						assertNoErr(t, err)
 | 
				
			||||||
 | 
						defer scenario.Shutdown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO(kradalby): it does not look like the user thing works, only second
 | 
				
			||||||
 | 
						// get created? maybe only when many?
 | 
				
			||||||
 | 
						spec := map[string]int{
 | 
				
			||||||
 | 
							"user1": len(MustTestVersions),
 | 
				
			||||||
 | 
							"user2": len(MustTestVersions),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						headscaleConfig := map[string]string{
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_URLS":                    "",
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_SERVER_ENABLED":          "true",
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_SERVER_REGION_ID":        "999",
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_SERVER_REGION_CODE":      "headscale",
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_SERVER_REGION_NAME":      "Headscale Embedded DERP",
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
 | 
				
			||||||
 | 
							"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Envknob for enabling DERP debug logs
 | 
				
			||||||
 | 
							"DERP_DEBUG_LOGS":        "true",
 | 
				
			||||||
 | 
							"DERP_PROBER_DEBUG_LOGS": "true",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = scenario.CreateHeadscaleEnv(spec,
 | 
				
			||||||
 | 
							[]tsic.Option{},
 | 
				
			||||||
 | 
							hsic.WithTestName("pingallbyip"),
 | 
				
			||||||
 | 
							hsic.WithConfigEnv(headscaleConfig),
 | 
				
			||||||
 | 
							hsic.WithTLS(),
 | 
				
			||||||
 | 
							hsic.WithHostnameAsServerURL(),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						assertNoErrHeadscaleEnv(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						allClients, err := scenario.ListTailscaleClients()
 | 
				
			||||||
 | 
						assertNoErrListClients(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						allIps, err := scenario.ListTailscaleClientsIPs()
 | 
				
			||||||
 | 
						assertNoErrListClientIPs(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
 | 
						assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// assertClientsState(t, allClients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 | 
				
			||||||
 | 
							return x.String()
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						success := pingAllHelper(t, allClients, allAddrs)
 | 
				
			||||||
 | 
						t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for run := range 3 {
 | 
				
			||||||
 | 
							t.Logf("Starting DownUpPing run %d", run+1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, client := range allClients {
 | 
				
			||||||
 | 
								t.Logf("taking down %q", client.Hostname())
 | 
				
			||||||
 | 
								client.Down()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							time.Sleep(5 * time.Second)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, client := range allClients {
 | 
				
			||||||
 | 
								t.Logf("bringing up %q", client.Hostname())
 | 
				
			||||||
 | 
								client.Up()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							time.Sleep(5 * time.Second)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
 | 
							assertNoErrSync(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							success := pingAllHelper(t, allClients, allAddrs)
 | 
				
			||||||
 | 
							t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -212,7 +212,11 @@ func TestEnablingRoutes(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		if route.GetId() == routeToBeDisabled.GetId() {
 | 
							if route.GetId() == routeToBeDisabled.GetId() {
 | 
				
			||||||
			assert.Equal(t, false, route.GetEnabled())
 | 
								assert.Equal(t, false, route.GetEnabled())
 | 
				
			||||||
			assert.Equal(t, false, route.GetIsPrimary())
 | 
					
 | 
				
			||||||
 | 
								// since this is the only route of this cidr,
 | 
				
			||||||
 | 
								// it will not failover, and remain Primary
 | 
				
			||||||
 | 
								// until something can replace it.
 | 
				
			||||||
 | 
								assert.Equal(t, true, route.GetIsPrimary())
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			assert.Equal(t, true, route.GetEnabled())
 | 
								assert.Equal(t, true, route.GetEnabled())
 | 
				
			||||||
			assert.Equal(t, true, route.GetIsPrimary())
 | 
								assert.Equal(t, true, route.GetIsPrimary())
 | 
				
			||||||
@ -291,6 +295,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	client := allClients[2]
 | 
						client := allClients[2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Logf("Advertise route from r1 (%s) and r2 (%s), making it HA, n1 is primary", subRouter1.Hostname(), subRouter2.Hostname())
 | 
				
			||||||
	// advertise HA route on node 1 and 2
 | 
						// advertise HA route on node 1 and 2
 | 
				
			||||||
	// ID 1 will be primary
 | 
						// ID 1 will be primary
 | 
				
			||||||
	// ID 2 will be secondary
 | 
						// ID 2 will be secondary
 | 
				
			||||||
@ -384,12 +389,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	// Node 1 is primary
 | 
						// Node 1 is primary
 | 
				
			||||||
	assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
 | 
						assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, enablingRoutes[0].GetEnabled())
 | 
						assert.Equal(t, true, enablingRoutes[0].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
 | 
						assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 2 is not primary
 | 
						// Node 2 is not primary
 | 
				
			||||||
	assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
 | 
						assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, enablingRoutes[1].GetEnabled())
 | 
						assert.Equal(t, true, enablingRoutes[1].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, false, enablingRoutes[1].GetIsPrimary())
 | 
						assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Verify that the client has routes from the primary machine
 | 
						// Verify that the client has routes from the primary machine
 | 
				
			||||||
	srs1, err := subRouter1.Status()
 | 
						srs1, err := subRouter1.Status()
 | 
				
			||||||
@ -401,6 +406,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
 | 
						srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
 | 
				
			||||||
	srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
 | 
						srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
 | 
				
			||||||
 | 
						assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
 | 
						assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
 | 
				
			||||||
	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
 | 
						assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -411,7 +419,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Take down the current primary
 | 
						// Take down the current primary
 | 
				
			||||||
	t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname())
 | 
						t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
 | 
				
			||||||
 | 
						t.Logf("expecting r2 (%s) to take over as primary", subRouter2.Hostname())
 | 
				
			||||||
	err = subRouter1.Down()
 | 
						err = subRouter1.Down()
 | 
				
			||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -435,15 +444,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	// Node 1 is not primary
 | 
						// Node 1 is not primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfterMove[0].GetEnabled())
 | 
						assert.Equal(t, true, routesAfterMove[0].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, false, routesAfterMove[0].GetIsPrimary())
 | 
						assert.Equal(t, false, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 2 is primary
 | 
						// Node 2 is primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfterMove[1].GetEnabled())
 | 
						assert.Equal(t, true, routesAfterMove[1].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, true, routesAfterMove[1].GetIsPrimary())
 | 
						assert.Equal(t, true, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary")
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// TODO(kradalby): Check client status
 | 
					 | 
				
			||||||
	// Route is expected to be on SR2
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	srs2, err = subRouter2.Status()
 | 
						srs2, err = subRouter2.Status()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -453,6 +459,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
						srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
				
			||||||
	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
						srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
 | 
				
			||||||
 | 
						assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 | 
						assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 | 
				
			||||||
	assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
 | 
						assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -465,7 +474,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Take down subnet router 2, leaving none available
 | 
						// Take down subnet router 2, leaving none available
 | 
				
			||||||
	t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname())
 | 
						t.Logf("taking down subnet router r2 (%s)", subRouter2.Hostname())
 | 
				
			||||||
 | 
						t.Logf("expecting r2 (%s) to remain primary, no other available", subRouter2.Hostname())
 | 
				
			||||||
	err = subRouter2.Down()
 | 
						err = subRouter2.Down()
 | 
				
			||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -489,14 +499,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	// Node 1 is not primary
 | 
						// Node 1 is not primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
 | 
						assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary())
 | 
						assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 2 is primary
 | 
						// Node 2 is primary
 | 
				
			||||||
	// if the node goes down, but no other suitable route is
 | 
						// if the node goes down, but no other suitable route is
 | 
				
			||||||
	// available, keep the last known good route.
 | 
						// available, keep the last known good route.
 | 
				
			||||||
	assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
 | 
						assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary())
 | 
						assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): Check client status
 | 
						// TODO(kradalby): Check client status
 | 
				
			||||||
	// Both are expected to be down
 | 
						// Both are expected to be down
 | 
				
			||||||
@ -508,6 +518,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
						srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
				
			||||||
	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
						srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
 | 
				
			||||||
 | 
						assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 | 
						assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 | 
				
			||||||
	assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
 | 
						assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -520,7 +533,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Bring up subnet router 1, making the route available from there.
 | 
						// Bring up subnet router 1, making the route available from there.
 | 
				
			||||||
	t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname())
 | 
						t.Logf("bringing up subnet router r1 (%s)", subRouter1.Hostname())
 | 
				
			||||||
 | 
						t.Logf("expecting r1 (%s) to take over as primary (only one online)", subRouter1.Hostname())
 | 
				
			||||||
	err = subRouter1.Up()
 | 
						err = subRouter1.Up()
 | 
				
			||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -544,12 +558,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	// Node 1 is primary
 | 
						// Node 1 is primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
 | 
						assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary())
 | 
						assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 2 is not primary
 | 
						// Node 2 is not primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
 | 
						assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary())
 | 
						assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Verify that the route is announced from subnet router 1
 | 
						// Verify that the route is announced from subnet router 1
 | 
				
			||||||
	clientStatus, err = client.Status()
 | 
						clientStatus, err = client.Status()
 | 
				
			||||||
@ -558,6 +572,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
						srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
				
			||||||
	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
						srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
 | 
				
			||||||
 | 
						assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
 | 
						assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
 | 
				
			||||||
	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
 | 
						assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -570,7 +587,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Bring up subnet router 2, should result in no change.
 | 
						// Bring up subnet router 2, should result in no change.
 | 
				
			||||||
	t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname())
 | 
						t.Logf("bringing up subnet router r2 (%s)", subRouter2.Hostname())
 | 
				
			||||||
 | 
						t.Logf("both online, expecting r1 (%s) to still be primary (no flapping)", subRouter1.Hostname())
 | 
				
			||||||
	err = subRouter2.Up()
 | 
						err = subRouter2.Up()
 | 
				
			||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -594,12 +612,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	// Node 1 is not primary
 | 
						// Node 1 is not primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
 | 
						assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary())
 | 
						assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 2 is primary
 | 
						// Node 2 is primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
 | 
				
			||||||
	assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
 | 
						assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
 | 
				
			||||||
	assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary())
 | 
						assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Verify that the route is announced from subnet router 1
 | 
						// Verify that the route is announced from subnet router 1
 | 
				
			||||||
	clientStatus, err = client.Status()
 | 
						clientStatus, err = client.Status()
 | 
				
			||||||
@ -608,6 +626,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
						srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
 | 
				
			||||||
	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
						srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
 | 
				
			||||||
 | 
						assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
 | 
						assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
 | 
				
			||||||
	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
 | 
						assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -620,7 +641,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Disable the route of subnet router 1, making it failover to 2
 | 
						// Disable the route of subnet router 1, making it failover to 2
 | 
				
			||||||
	t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname())
 | 
						t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
 | 
				
			||||||
 | 
						t.Logf("expecting route to failover to r2 (%s), which is still available", subRouter2.Hostname())
 | 
				
			||||||
	_, err = headscale.Execute(
 | 
						_, err = headscale.Execute(
 | 
				
			||||||
		[]string{
 | 
							[]string{
 | 
				
			||||||
			"headscale",
 | 
								"headscale",
 | 
				
			||||||
@ -648,7 +670,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
	assert.Len(t, routesAfterDisabling1, 2)
 | 
						assert.Len(t, routesAfterDisabling1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
 | 
						t.Logf("routes after disabling r1 %#v", routesAfterDisabling1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 1 is not primary
 | 
						// Node 1 is not primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
 | 
				
			||||||
@ -680,6 +702,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// enable the route of subnet router 1, no change expected
 | 
						// enable the route of subnet router 1, no change expected
 | 
				
			||||||
	t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
 | 
						t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
 | 
				
			||||||
 | 
						t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
 | 
				
			||||||
	_, err = headscale.Execute(
 | 
						_, err = headscale.Execute(
 | 
				
			||||||
		[]string{
 | 
							[]string{
 | 
				
			||||||
			"headscale",
 | 
								"headscale",
 | 
				
			||||||
@ -736,7 +759,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// delete the route of subnet router 2, failover to one expected
 | 
						// delete the route of subnet router 2, failover to one expected
 | 
				
			||||||
	t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname())
 | 
						t.Logf("deleting route in subnet router r2 (%s)", subRouter2.Hostname())
 | 
				
			||||||
 | 
						t.Logf("expecting route to failover to r1 (%s)", subRouter1.Hostname())
 | 
				
			||||||
	_, err = headscale.Execute(
 | 
						_, err = headscale.Execute(
 | 
				
			||||||
		[]string{
 | 
							[]string{
 | 
				
			||||||
			"headscale",
 | 
								"headscale",
 | 
				
			||||||
@ -764,7 +788,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
	assertNoErr(t, err)
 | 
						assertNoErr(t, err)
 | 
				
			||||||
	assert.Len(t, routesAfterDeleting2, 1)
 | 
						assert.Len(t, routesAfterDeleting2, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Logf("routes after deleting2 %#v", routesAfterDeleting2)
 | 
						t.Logf("routes after deleting r2 %#v", routesAfterDeleting2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Node 1 is primary
 | 
						// Node 1 is primary
 | 
				
			||||||
	assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
 | 
						assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
 | 
				
			||||||
 | 
				
			|||||||
@ -50,6 +50,8 @@ var (
 | 
				
			|||||||
	tailscaleVersions2021 = map[string]bool{
 | 
						tailscaleVersions2021 = map[string]bool{
 | 
				
			||||||
		"head":     true,
 | 
							"head":     true,
 | 
				
			||||||
		"unstable": true,
 | 
							"unstable": true,
 | 
				
			||||||
 | 
							"1.60":     true,  // CapVer: 82
 | 
				
			||||||
 | 
							"1.58":     true,  // CapVer: 82
 | 
				
			||||||
		"1.56":     true,  // CapVer: 82
 | 
							"1.56":     true,  // CapVer: 82
 | 
				
			||||||
		"1.54":     true,  // CapVer: 79
 | 
							"1.54":     true,  // CapVer: 79
 | 
				
			||||||
		"1.52":     true,  // CapVer: 79
 | 
							"1.52":     true,  // CapVer: 79
 | 
				
			||||||
 | 
				
			|||||||
@ -27,7 +27,7 @@ type TailscaleClient interface {
 | 
				
			|||||||
	Down() error
 | 
						Down() error
 | 
				
			||||||
	IPs() ([]netip.Addr, error)
 | 
						IPs() ([]netip.Addr, error)
 | 
				
			||||||
	FQDN() (string, error)
 | 
						FQDN() (string, error)
 | 
				
			||||||
	Status() (*ipnstate.Status, error)
 | 
						Status(...bool) (*ipnstate.Status, error)
 | 
				
			||||||
	Netmap() (*netmap.NetworkMap, error)
 | 
						Netmap() (*netmap.NetworkMap, error)
 | 
				
			||||||
	Netcheck() (*netcheck.Report, error)
 | 
						Netcheck() (*netcheck.Report, error)
 | 
				
			||||||
	WaitForNeedsLogin() error
 | 
						WaitForNeedsLogin() error
 | 
				
			||||||
 | 
				
			|||||||
@ -9,6 +9,7 @@ import (
 | 
				
			|||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@ -503,7 +504,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Status returns the ipnstate.Status of the Tailscale instance.
 | 
					// Status returns the ipnstate.Status of the Tailscale instance.
 | 
				
			||||||
func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
 | 
					func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
 | 
				
			||||||
	command := []string{
 | 
						command := []string{
 | 
				
			||||||
		"tailscale",
 | 
							"tailscale",
 | 
				
			||||||
		"status",
 | 
							"status",
 | 
				
			||||||
@ -521,60 +522,70 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
 | 
				
			|||||||
		return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err)
 | 
							return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_status.json", t.hostname), []byte(result), 0o755)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("status netmap to /tmp/control: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &status, err
 | 
						return &status, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
 | 
					// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
 | 
				
			||||||
// Only works with Tailscale 1.56 and newer.
 | 
					// Only works with Tailscale 1.56 and newer.
 | 
				
			||||||
// Panics if version is lower then minimum.
 | 
					// Panics if version is lower then minimum.
 | 
				
			||||||
// func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
 | 
					func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
 | 
				
			||||||
// 	if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
 | 
						if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
 | 
				
			||||||
// 		panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
 | 
							panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
 | 
				
			||||||
// 	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 	command := []string{
 | 
						command := []string{
 | 
				
			||||||
// 		"tailscale",
 | 
							"tailscale",
 | 
				
			||||||
// 		"debug",
 | 
							"debug",
 | 
				
			||||||
// 		"netmap",
 | 
							"netmap",
 | 
				
			||||||
// 	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 	result, stderr, err := t.Execute(command)
 | 
						result, stderr, err := t.Execute(command)
 | 
				
			||||||
// 	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
// 		fmt.Printf("stderr: %s\n", stderr)
 | 
							fmt.Printf("stderr: %s\n", stderr)
 | 
				
			||||||
// 		return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
 | 
							return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
 | 
				
			||||||
// 	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 	var nm netmap.NetworkMap
 | 
						var nm netmap.NetworkMap
 | 
				
			||||||
// 	err = json.Unmarshal([]byte(result), &nm)
 | 
						err = json.Unmarshal([]byte(result), &nm)
 | 
				
			||||||
// 	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
// 		return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
 | 
							return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
 | 
				
			||||||
// 	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 	return &nm, err
 | 
						err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_netmap.json", t.hostname), []byte(result), 0o755)
 | 
				
			||||||
// }
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("saving netmap to /tmp/control: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &nm, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
 | 
					// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
 | 
				
			||||||
// This implementation is based on getting the netmap from `tailscale debug watch-ipn`
 | 
					// This implementation is based on getting the netmap from `tailscale debug watch-ipn`
 | 
				
			||||||
// as there seem to be some weirdness omitting endpoint and DERP info if we use
 | 
					// as there seem to be some weirdness omitting endpoint and DERP info if we use
 | 
				
			||||||
// Patch updates.
 | 
					// Patch updates.
 | 
				
			||||||
// This implementation works on all supported versions.
 | 
					// This implementation works on all supported versions.
 | 
				
			||||||
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
 | 
					// func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
 | 
				
			||||||
	// watch-ipn will only give an update if something is happening,
 | 
					// 	// watch-ipn will only give an update if something is happening,
 | 
				
			||||||
	// since we send keep alives, the worst case for this should be
 | 
					// 	// since we send keep alives, the worst case for this should be
 | 
				
			||||||
	// 1 minute, but set a slightly more conservative time.
 | 
					// 	// 1 minute, but set a slightly more conservative time.
 | 
				
			||||||
	ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute)
 | 
					// 	ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	notify, err := t.watchIPN(ctx)
 | 
					// 	notify, err := t.watchIPN(ctx)
 | 
				
			||||||
	if err != nil {
 | 
					// 	if err != nil {
 | 
				
			||||||
		return nil, err
 | 
					// 		return nil, err
 | 
				
			||||||
	}
 | 
					// 	}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if notify.NetMap == nil {
 | 
					// 	if notify.NetMap == nil {
 | 
				
			||||||
		return nil, fmt.Errorf("no netmap present in ipn.Notify")
 | 
					// 		return nil, fmt.Errorf("no netmap present in ipn.Notify")
 | 
				
			||||||
	}
 | 
					// 	}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return notify.NetMap, nil
 | 
					// 	return notify.NetMap, nil
 | 
				
			||||||
}
 | 
					// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until
 | 
					// watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until
 | 
				
			||||||
// it gets one that has a netmap.NetworkMap.
 | 
					// it gets one that has a netmap.NetworkMap.
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,7 @@ import (
 | 
				
			|||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
	"github.com/juanfont/headscale/integration/tsic"
 | 
						"github.com/juanfont/headscale/integration/tsic"
 | 
				
			||||||
	"github.com/stretchr/testify/assert"
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -154,11 +155,11 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) {
 | 
				
			|||||||
func assertValidNetmap(t *testing.T, client TailscaleClient) {
 | 
					func assertValidNetmap(t *testing.T, client TailscaleClient) {
 | 
				
			||||||
	t.Helper()
 | 
						t.Helper()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
 | 
						if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
 | 
				
			||||||
	// 	t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
 | 
							t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 	return
 | 
							return
 | 
				
			||||||
	// }
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Logf("Checking netmap of %q", client.Hostname())
 | 
						t.Logf("Checking netmap of %q", client.Hostname())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -175,7 +176,11 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
 | 
				
			|||||||
	assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
 | 
						assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
 | 
				
			||||||
	assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
 | 
						assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if netmap.SelfNode.Online() != nil {
 | 
				
			||||||
		assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
 | 
							assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							t.Errorf("Online should not be nil for %s", client.Hostname())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
 | 
						assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
 | 
				
			||||||
	assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
 | 
						assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
 | 
				
			||||||
@ -213,7 +218,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
 | 
				
			|||||||
// This test is not suitable for ACL/partial connection tests.
 | 
					// This test is not suitable for ACL/partial connection tests.
 | 
				
			||||||
func assertValidStatus(t *testing.T, client TailscaleClient) {
 | 
					func assertValidStatus(t *testing.T, client TailscaleClient) {
 | 
				
			||||||
	t.Helper()
 | 
						t.Helper()
 | 
				
			||||||
	status, err := client.Status()
 | 
						status, err := client.Status(true)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
 | 
							t.Fatalf("getting status for %q: %s", client.Hostname(), err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user