mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 01:51:04 +01:00 
			
		
		
		
	wrap policy in policy manager interface (#2255)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									2c1ad6d11a
								
							
						
					
					
						commit
						f7b0cbbbea
					
				
							
								
								
									
										3
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							@ -31,8 +31,7 @@ jobs:
 | 
				
			|||||||
          - TestPreAuthKeyCorrectUserLoggedInCommand
 | 
					          - TestPreAuthKeyCorrectUserLoggedInCommand
 | 
				
			||||||
          - TestApiKeyCommand
 | 
					          - TestApiKeyCommand
 | 
				
			||||||
          - TestNodeTagCommand
 | 
					          - TestNodeTagCommand
 | 
				
			||||||
          - TestNodeAdvertiseTagNoACLCommand
 | 
					          - TestNodeAdvertiseTagCommand
 | 
				
			||||||
          - TestNodeAdvertiseTagWithACLCommand
 | 
					 | 
				
			||||||
          - TestNodeCommand
 | 
					          - TestNodeCommand
 | 
				
			||||||
          - TestNodeExpireCommand
 | 
					          - TestNodeExpireCommand
 | 
				
			||||||
          - TestNodeRenameCommand
 | 
					          - TestNodeRenameCommand
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										163
									
								
								hscontrol/app.go
									
									
									
									
									
								
							
							
						
						
									
										163
									
								
								hscontrol/app.go
									
									
									
									
									
								
							@ -88,7 +88,8 @@ type Headscale struct {
 | 
				
			|||||||
	DERPMap    *tailcfg.DERPMap
 | 
						DERPMap    *tailcfg.DERPMap
 | 
				
			||||||
	DERPServer *derpServer.DERPServer
 | 
						DERPServer *derpServer.DERPServer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ACLPolicy *policy.ACLPolicy
 | 
						polManOnce sync.Once
 | 
				
			||||||
 | 
						polMan     policy.PolicyManager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mapper       *mapper.Mapper
 | 
						mapper       *mapper.Mapper
 | 
				
			||||||
	nodeNotifier *notifier.Notifier
 | 
						nodeNotifier *notifier.Notifier
 | 
				
			||||||
@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = app.loadPolicyManager(); err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("failed to load ACL policy: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var authProvider AuthProvider
 | 
						var authProvider AuthProvider
 | 
				
			||||||
	authProvider = NewAuthProviderWeb(cfg.ServerURL)
 | 
						authProvider = NewAuthProviderWeb(cfg.ServerURL)
 | 
				
			||||||
	if cfg.OIDC.Issuer != "" {
 | 
						if cfg.OIDC.Issuer != "" {
 | 
				
			||||||
@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
			app.db,
 | 
								app.db,
 | 
				
			||||||
			app.nodeNotifier,
 | 
								app.nodeNotifier,
 | 
				
			||||||
			app.ipAlloc,
 | 
								app.ipAlloc,
 | 
				
			||||||
 | 
								app.polMan,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
 | 
								if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
 | 
				
			||||||
@ -475,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 | 
				
			|||||||
	return router
 | 
						return router
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
 | 
				
			||||||
 | 
					// Maybe we should attempt a new in memory state and not go via the DB?
 | 
				
			||||||
 | 
					func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
 | 
				
			||||||
 | 
						users, err := db.ListUsers()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						changed, err := polMan.SetUsers(users)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if changed {
 | 
				
			||||||
 | 
							ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
 | 
				
			||||||
 | 
							notif.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
								Type: types.StateFullUpdate,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
 | 
				
			||||||
 | 
					// Maybe we should attempt a new in memory state and not go via the DB?
 | 
				
			||||||
 | 
					func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
 | 
				
			||||||
 | 
						nodes, err := db.ListNodes()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						changed, err := polMan.SetNodes(nodes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if changed {
 | 
				
			||||||
 | 
							ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
 | 
				
			||||||
 | 
							notif.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
 | 
								Type: types.StateFullUpdate,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
 | 
					// Serve launches the HTTP and gRPC server service Headscale and the API.
 | 
				
			||||||
func (h *Headscale) Serve() error {
 | 
					func (h *Headscale) Serve() error {
 | 
				
			||||||
	if profilingEnabled {
 | 
						if profilingEnabled {
 | 
				
			||||||
@ -490,19 +542,13 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var err error
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err = h.loadACLPolicy(); err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("failed to load ACL policy: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if dumpConfig {
 | 
						if dumpConfig {
 | 
				
			||||||
		spew.Dump(h.cfg)
 | 
							spew.Dump(h.cfg)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Fetch an initial DERP Map before we start serving
 | 
						// Fetch an initial DERP Map before we start serving
 | 
				
			||||||
	h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
 | 
						h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
 | 
				
			||||||
	h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
 | 
						h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if h.cfg.DERP.ServerEnabled {
 | 
						if h.cfg.DERP.ServerEnabled {
 | 
				
			||||||
		// When embedded DERP is enabled we always need a STUN server
 | 
							// When embedded DERP is enabled we always need a STUN server
 | 
				
			||||||
@ -772,12 +818,21 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
					Str("signal", sig.String()).
 | 
										Str("signal", sig.String()).
 | 
				
			||||||
					Msg("Received SIGHUP, reloading ACL and Config")
 | 
										Msg("Received SIGHUP, reloading ACL and Config")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// TODO(kradalby): Reload config on SIGHUP
 | 
									if err := h.loadPolicyManager(); err != nil {
 | 
				
			||||||
				if err := h.loadACLPolicy(); err != nil {
 | 
										log.Error().Err(err).Msg("failed to reload Policy")
 | 
				
			||||||
					log.Error().Err(err).Msg("failed to reload ACL policy")
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if h.ACLPolicy != nil {
 | 
									pol, err := h.policyBytes()
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										log.Error().Err(err).Msg("failed to get policy blob")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									changed, err := h.polMan.SetPolicy(pol)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										log.Error().Err(err).Msg("failed to set new policy")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if changed {
 | 
				
			||||||
					log.Info().
 | 
										log.Info().
 | 
				
			||||||
						Msg("ACL policy successfully reloaded, notifying nodes of change")
 | 
											Msg("ACL policy successfully reloaded, notifying nodes of change")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -996,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
 | 
				
			|||||||
	return &machineKey, nil
 | 
						return &machineKey, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) loadACLPolicy() error {
 | 
					// policyBytes returns the appropriate policy for the
 | 
				
			||||||
	var (
 | 
					// current configuration as a []byte array.
 | 
				
			||||||
		pol *policy.ACLPolicy
 | 
					func (h *Headscale) policyBytes() ([]byte, error) {
 | 
				
			||||||
		err error
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	switch h.cfg.Policy.Mode {
 | 
						switch h.cfg.Policy.Mode {
 | 
				
			||||||
	case types.PolicyModeFile:
 | 
						case types.PolicyModeFile:
 | 
				
			||||||
		path := h.cfg.Policy.Path
 | 
							path := h.cfg.Policy.Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// It is fine to start headscale without a policy file.
 | 
							// It is fine to start headscale without a policy file.
 | 
				
			||||||
		if len(path) == 0 {
 | 
							if len(path) == 0 {
 | 
				
			||||||
			return nil
 | 
								return nil, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		absPath := util.AbsolutePathFromConfigPath(path)
 | 
							absPath := util.AbsolutePathFromConfigPath(path)
 | 
				
			||||||
		pol, err = policy.LoadACLPolicyFromPath(absPath)
 | 
							policyFile, err := os.Open(absPath)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return fmt.Errorf("failed to load ACL policy from file: %w", err)
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer policyFile.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return io.ReadAll(policyFile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						case types.PolicyModeDB:
 | 
				
			||||||
 | 
							p, err := h.db.GetPolicy()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								if errors.Is(err, types.ErrPolicyNotFound) {
 | 
				
			||||||
 | 
									return nil, nil
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return []byte(p.Data), err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) loadPolicyManager() error {
 | 
				
			||||||
 | 
						var errOut error
 | 
				
			||||||
 | 
						h.polManOnce.Do(func() {
 | 
				
			||||||
		// Validate and reject configuration that would error when applied
 | 
							// Validate and reject configuration that would error when applied
 | 
				
			||||||
		// when creating a map response. This requires nodes, so there is still
 | 
							// when creating a map response. This requires nodes, so there is still
 | 
				
			||||||
		// a scenario where they might be allowed if the server has no nodes
 | 
							// a scenario where they might be allowed if the server has no nodes
 | 
				
			||||||
@ -1027,46 +1101,35 @@ func (h *Headscale) loadACLPolicy() error {
 | 
				
			|||||||
		// allowed to be written to the database.
 | 
							// allowed to be written to the database.
 | 
				
			||||||
		nodes, err := h.db.ListNodes()
 | 
							nodes, err := h.db.ListNodes()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return fmt.Errorf("loading nodes from database to validate policy: %w", err)
 | 
								errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		users, err := h.db.ListUsers()
 | 
							users, err := h.db.ListUsers()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return fmt.Errorf("loading users from database to validate policy: %w", err)
 | 
								errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		_, err = pol.CompileFilterRules(users, nodes)
 | 
							pol, err := h.policyBytes()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return fmt.Errorf("verifying policy rules: %w", err)
 | 
								errOut = fmt.Errorf("loading policy bytes: %w", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								errOut = fmt.Errorf("creating policy manager: %w", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(nodes) > 0 {
 | 
							if len(nodes) > 0 {
 | 
				
			||||||
			_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
 | 
								_, err = h.polMan.SSHPolicy(nodes[0])
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return fmt.Errorf("verifying SSH rules: %w", err)
 | 
									errOut = fmt.Errorf("verifying SSH rules: %w", err)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	case types.PolicyModeDB:
 | 
						return errOut
 | 
				
			||||||
		p, err := h.db.GetPolicy()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			if errors.Is(err, types.ErrPolicyNotFound) {
 | 
					 | 
				
			||||||
				return nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			return fmt.Errorf("failed to get policy from database: %w", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return fmt.Errorf("failed to parse policy: %w", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
		log.Fatal().
 | 
					 | 
				
			||||||
			Str("mode", string(h.cfg.Policy.Mode)).
 | 
					 | 
				
			||||||
			Msg("Unknown ACL policy mode")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	h.ACLPolicy = pol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								http.Error(writer, "Internal server error", http.StatusInternalServerError)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = h.db.Write(func(tx *gorm.DB) error {
 | 
						err = h.db.Write(func(tx *gorm.DB) error {
 | 
				
			||||||
 | 
				
			|||||||
@ -563,7 +563,7 @@ func TestAutoApproveRoutes(t *testing.T) {
 | 
				
			|||||||
			pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
 | 
								pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			require.NoError(t, err)
 | 
								require.NoError(t, err)
 | 
				
			||||||
			assert.NotNil(t, pol)
 | 
								require.NotNil(t, pol)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			user, err := adb.CreateUser("test")
 | 
								user, err := adb.CreateUser("test")
 | 
				
			||||||
			require.NoError(t, err)
 | 
								require.NoError(t, err)
 | 
				
			||||||
@ -600,8 +600,17 @@ func TestAutoApproveRoutes(t *testing.T) {
 | 
				
			|||||||
			node0ByID, err := adb.GetNodeByID(0)
 | 
								node0ByID, err := adb.GetNodeByID(0)
 | 
				
			||||||
			require.NoError(t, err)
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								users, err := adb.ListUsers()
 | 
				
			||||||
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								nodes, err := adb.ListNodes()
 | 
				
			||||||
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
 | 
				
			||||||
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// TODO(kradalby): Check state update
 | 
								// TODO(kradalby): Check state update
 | 
				
			||||||
			err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
 | 
								err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
 | 
				
			||||||
			require.NoError(t, err)
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
 | 
								enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
 | 
				
			||||||
 | 
				
			|||||||
@ -598,18 +598,18 @@ func failoverRoute(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 | 
					func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 | 
				
			||||||
	aclPolicy *policy.ACLPolicy,
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	return hsdb.Write(func(tx *gorm.DB) error {
 | 
						return hsdb.Write(func(tx *gorm.DB) error {
 | 
				
			||||||
		return EnableAutoApprovedRoutes(tx, aclPolicy, node)
 | 
							return EnableAutoApprovedRoutes(tx, polMan, node)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
 | 
					// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
 | 
				
			||||||
func EnableAutoApprovedRoutes(
 | 
					func EnableAutoApprovedRoutes(
 | 
				
			||||||
	tx *gorm.DB,
 | 
						tx *gorm.DB,
 | 
				
			||||||
	aclPolicy *policy.ACLPolicy,
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	if node.IPv4 == nil && node.IPv6 == nil {
 | 
						if node.IPv4 == nil && node.IPv6 == nil {
 | 
				
			||||||
@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
 | 
							routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
 | 
				
			||||||
			netip.Prefix(advertisedRoute.Prefix),
 | 
					 | 
				
			||||||
		)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Trace().
 | 
							log.Trace().
 | 
				
			||||||
			Str("node", node.Hostname).
 | 
								Str("node", node.Hostname).
 | 
				
			||||||
@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes(
 | 
				
			|||||||
			if approvedAlias == node.User.Username() {
 | 
								if approvedAlias == node.User.Username() {
 | 
				
			||||||
				approvedRoutes = append(approvedRoutes, advertisedRoute)
 | 
									approvedRoutes = append(approvedRoutes, advertisedRoute)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				users, err := ListUsers(tx)
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					return fmt.Errorf("looking up users to expand route alias: %w", err)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// TODO(kradalby): figure out how to get this to depend on less stuff
 | 
									// TODO(kradalby): figure out how to get this to depend on less stuff
 | 
				
			||||||
				approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
 | 
									approvedIps, err := polMan.ExpandAlias(approvedAlias)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
 | 
										return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
				
			|||||||
@ -21,7 +21,6 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
						v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/db"
 | 
						"github.com/juanfont/headscale/hscontrol/db"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/policy"
 | 
					 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("updating resources using user: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.CreateUserResponse{User: user.Proto()}, nil
 | 
						return &v1.CreateUserResponse{User: user.Proto()}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -97,6 +101,11 @@ func (api headscaleV1APIServer) DeleteUser(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("updating resources using user: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.DeleteUserResponse{}, nil
 | 
						return &v1.DeleteUserResponse{}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -241,6 +250,11 @@ func (api headscaleV1APIServer) RegisterNode(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("updating resources using node: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 | 
						return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -480,10 +494,7 @@ func (api headscaleV1APIServer) ListNodes(
 | 
				
			|||||||
			resp.Online = true
 | 
								resp.Online = true
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
 | 
							validTags := api.h.polMan.Tags(node)
 | 
				
			||||||
			node,
 | 
					 | 
				
			||||||
		)
 | 
					 | 
				
			||||||
		resp.InvalidTags = invalidTags
 | 
					 | 
				
			||||||
		resp.ValidTags = validTags
 | 
							resp.ValidTags = validTags
 | 
				
			||||||
		response[index] = resp
 | 
							response[index] = resp
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -759,11 +770,6 @@ func (api headscaleV1APIServer) SetPolicy(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	p := request.GetPolicy()
 | 
						p := request.GetPolicy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("loading ACL policy file: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Validate and reject configuration that would error when applied
 | 
						// Validate and reject configuration that would error when applied
 | 
				
			||||||
	// when creating a map response. This requires nodes, so there is still
 | 
						// when creating a map response. This requires nodes, so there is still
 | 
				
			||||||
	// a scenario where they might be allowed if the server has no nodes
 | 
						// a scenario where they might be allowed if the server has no nodes
 | 
				
			||||||
@ -773,18 +779,13 @@ func (api headscaleV1APIServer) SetPolicy(
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
 | 
							return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	users, err := api.h.db.ListUsers()
 | 
						changed, err := api.h.polMan.SetPolicy([]byte(p))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
 | 
							return nil, fmt.Errorf("setting policy: %w", err)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err = pol.CompileFilterRules(users, nodes)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("verifying policy rules: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(nodes) > 0 {
 | 
						if len(nodes) > 0 {
 | 
				
			||||||
		_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
 | 
							_, err = api.h.polMan.SSHPolicy(nodes[0])
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, fmt.Errorf("verifying SSH rules: %w", err)
 | 
								return nil, fmt.Errorf("verifying SSH rules: %w", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -795,12 +796,13 @@ func (api headscaleV1APIServer) SetPolicy(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	api.h.ACLPolicy = pol
 | 
						// Only send update if the packet filter has changed.
 | 
				
			||||||
 | 
						if changed {
 | 
				
			||||||
		ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
 | 
							ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
 | 
				
			||||||
		api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
							api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | 
				
			||||||
			Type: types.StateFullUpdate,
 | 
								Type: types.StateFullUpdate,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	response := &v1.SetPolicyResponse{
 | 
						response := &v1.SetPolicyResponse{
 | 
				
			||||||
		Policy:    updated.Data,
 | 
							Policy:    updated.Data,
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,7 @@ type Mapper struct {
 | 
				
			|||||||
	cfg     *types.Config
 | 
						cfg     *types.Config
 | 
				
			||||||
	derpMap *tailcfg.DERPMap
 | 
						derpMap *tailcfg.DERPMap
 | 
				
			||||||
	notif   *notifier.Notifier
 | 
						notif   *notifier.Notifier
 | 
				
			||||||
 | 
						polMan  policy.PolicyManager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	uid     string
 | 
						uid     string
 | 
				
			||||||
	created time.Time
 | 
						created time.Time
 | 
				
			||||||
@ -71,6 +72,7 @@ func NewMapper(
 | 
				
			|||||||
	cfg *types.Config,
 | 
						cfg *types.Config,
 | 
				
			||||||
	derpMap *tailcfg.DERPMap,
 | 
						derpMap *tailcfg.DERPMap,
 | 
				
			||||||
	notif *notifier.Notifier,
 | 
						notif *notifier.Notifier,
 | 
				
			||||||
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
) *Mapper {
 | 
					) *Mapper {
 | 
				
			||||||
	uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
 | 
						uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -79,6 +81,7 @@ func NewMapper(
 | 
				
			|||||||
		cfg:     cfg,
 | 
							cfg:     cfg,
 | 
				
			||||||
		derpMap: derpMap,
 | 
							derpMap: derpMap,
 | 
				
			||||||
		notif:   notif,
 | 
							notif:   notif,
 | 
				
			||||||
 | 
							polMan:  polMan,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		uid:     uid,
 | 
							uid:     uid,
 | 
				
			||||||
		created: time.Now(),
 | 
							created: time.Now(),
 | 
				
			||||||
@ -153,11 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
 | 
				
			|||||||
func (m *Mapper) fullMapResponse(
 | 
					func (m *Mapper) fullMapResponse(
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	peers types.Nodes,
 | 
						peers types.Nodes,
 | 
				
			||||||
	users []types.User,
 | 
					 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
) (*tailcfg.MapResponse, error) {
 | 
					) (*tailcfg.MapResponse, error) {
 | 
				
			||||||
	resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
 | 
						resp, err := m.baseWithConfigMapResponse(node, capVer)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -165,11 +166,9 @@ func (m *Mapper) fullMapResponse(
 | 
				
			|||||||
	err = appendPeerChanges(
 | 
						err = appendPeerChanges(
 | 
				
			||||||
		resp,
 | 
							resp,
 | 
				
			||||||
		true, // full change
 | 
							true, // full change
 | 
				
			||||||
		pol,
 | 
							m.polMan,
 | 
				
			||||||
		node,
 | 
							node,
 | 
				
			||||||
		capVer,
 | 
							capVer,
 | 
				
			||||||
		users,
 | 
					 | 
				
			||||||
		peers,
 | 
					 | 
				
			||||||
		peers,
 | 
							peers,
 | 
				
			||||||
		m.cfg,
 | 
							m.cfg,
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
@ -184,19 +183,14 @@ func (m *Mapper) fullMapResponse(
 | 
				
			|||||||
func (m *Mapper) FullMapResponse(
 | 
					func (m *Mapper) FullMapResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
	messages ...string,
 | 
						messages ...string,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	peers, err := m.ListPeers(node.ID)
 | 
						peers, err := m.ListPeers(node.ID)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	users, err := m.db.ListUsers()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
 | 
						resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -210,10 +204,9 @@ func (m *Mapper) FullMapResponse(
 | 
				
			|||||||
func (m *Mapper) ReadOnlyMapResponse(
 | 
					func (m *Mapper) ReadOnlyMapResponse(
 | 
				
			||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
	messages ...string,
 | 
						messages ...string,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
 | 
						resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -249,7 +242,6 @@ func (m *Mapper) PeerChangedResponse(
 | 
				
			|||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	changed map[types.NodeID]bool,
 | 
						changed map[types.NodeID]bool,
 | 
				
			||||||
	patches []*tailcfg.PeerChange,
 | 
						patches []*tailcfg.PeerChange,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
	messages ...string,
 | 
						messages ...string,
 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
@ -259,11 +251,6 @@ func (m *Mapper) PeerChangedResponse(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	users, err := m.db.ListUsers()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("listing users for map response: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var removedIDs []tailcfg.NodeID
 | 
						var removedIDs []tailcfg.NodeID
 | 
				
			||||||
	var changedIDs []types.NodeID
 | 
						var changedIDs []types.NodeID
 | 
				
			||||||
	for nodeID, nodeChanged := range changed {
 | 
						for nodeID, nodeChanged := range changed {
 | 
				
			||||||
@ -284,11 +271,9 @@ func (m *Mapper) PeerChangedResponse(
 | 
				
			|||||||
	err = appendPeerChanges(
 | 
						err = appendPeerChanges(
 | 
				
			||||||
		&resp,
 | 
							&resp,
 | 
				
			||||||
		false, // partial change
 | 
							false, // partial change
 | 
				
			||||||
		pol,
 | 
							m.polMan,
 | 
				
			||||||
		node,
 | 
							node,
 | 
				
			||||||
		mapRequest.Version,
 | 
							mapRequest.Version,
 | 
				
			||||||
		users,
 | 
					 | 
				
			||||||
		peers,
 | 
					 | 
				
			||||||
		changedNodes,
 | 
							changedNodes,
 | 
				
			||||||
		m.cfg,
 | 
							m.cfg,
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
@ -315,7 +300,7 @@ func (m *Mapper) PeerChangedResponse(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Add the node itself, it might have changed, and particularly
 | 
						// Add the node itself, it might have changed, and particularly
 | 
				
			||||||
	// if there are no patches or changes, this is a self update.
 | 
						// if there are no patches or changes, this is a self update.
 | 
				
			||||||
	tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
 | 
						tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -330,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse(
 | 
				
			|||||||
	mapRequest tailcfg.MapRequest,
 | 
						mapRequest tailcfg.MapRequest,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	changed []*tailcfg.PeerChange,
 | 
						changed []*tailcfg.PeerChange,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
) ([]byte, error) {
 | 
					) ([]byte, error) {
 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
	resp.PeersChangedPatch = changed
 | 
						resp.PeersChangedPatch = changed
 | 
				
			||||||
@ -459,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
 | 
				
			|||||||
// incremental.
 | 
					// incremental.
 | 
				
			||||||
func (m *Mapper) baseWithConfigMapResponse(
 | 
					func (m *Mapper) baseWithConfigMapResponse(
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
					 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
) (*tailcfg.MapResponse, error) {
 | 
					) (*tailcfg.MapResponse, error) {
 | 
				
			||||||
	resp := m.baseMapResponse()
 | 
						resp := m.baseMapResponse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tailnode, err := tailNode(node, capVer, pol, m.cfg)
 | 
						tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -517,35 +500,30 @@ func appendPeerChanges(
 | 
				
			|||||||
	resp *tailcfg.MapResponse,
 | 
						resp *tailcfg.MapResponse,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fullChange bool,
 | 
						fullChange bool,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
	users []types.User,
 | 
					 | 
				
			||||||
	peers types.Nodes,
 | 
					 | 
				
			||||||
	changed types.Nodes,
 | 
						changed types.Nodes,
 | 
				
			||||||
	cfg *types.Config,
 | 
						cfg *types.Config,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
 | 
						filter := polMan.Filter()
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
 | 
						sshPolicy, err := polMan.SSHPolicy(node)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If there are filter rules present, see if there are any nodes that cannot
 | 
						// If there are filter rules present, see if there are any nodes that cannot
 | 
				
			||||||
	// access each-other at all and remove them from the peers.
 | 
						// access each-other at all and remove them from the peers.
 | 
				
			||||||
	if len(packetFilter) > 0 {
 | 
						if len(filter) > 0 {
 | 
				
			||||||
		changed = policy.FilterNodesByACL(node, changed, packetFilter)
 | 
							changed = policy.FilterNodesByACL(node, changed, filter)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	profiles := generateUserProfiles(node, changed)
 | 
						profiles := generateUserProfiles(node, changed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dnsConfig := generateDNSConfig(cfg, node)
 | 
						dnsConfig := generateDNSConfig(cfg, node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tailPeers, err := tailNodes(changed, capVer, pol, cfg)
 | 
						tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -570,7 +548,7 @@ func appendPeerChanges(
 | 
				
			|||||||
		// new PacketFilters field and "base" allows us to send a full update when we
 | 
							// new PacketFilters field and "base" allows us to send a full update when we
 | 
				
			||||||
		// have to send an empty list, avoiding the hack in the else block.
 | 
							// have to send an empty list, avoiding the hack in the else block.
 | 
				
			||||||
		resp.PacketFilters = map[string][]tailcfg.FilterRule{
 | 
							resp.PacketFilters = map[string][]tailcfg.FilterRule{
 | 
				
			||||||
			"base": policy.ReduceFilterRules(node, packetFilter),
 | 
								"base": policy.ReduceFilterRules(node, filter),
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		// This is a hack to avoid sending an empty list of packet filters.
 | 
							// This is a hack to avoid sending an empty list of packet filters.
 | 
				
			||||||
@ -578,11 +556,11 @@ func appendPeerChanges(
 | 
				
			|||||||
		// be omitted, causing the client to consider it unchanged, keeping the
 | 
							// be omitted, causing the client to consider it unchanged, keeping the
 | 
				
			||||||
		// previous packet filter. Worst case, this can cause a node that previously
 | 
							// previous packet filter. Worst case, this can cause a node that previously
 | 
				
			||||||
		// has access to a node to _not_ loose access if an empty (allow none) is sent.
 | 
							// has access to a node to _not_ loose access if an empty (allow none) is sent.
 | 
				
			||||||
		reduced := policy.ReduceFilterRules(node, packetFilter)
 | 
							reduced := policy.ReduceFilterRules(node, filter)
 | 
				
			||||||
		if len(reduced) > 0 {
 | 
							if len(reduced) > 0 {
 | 
				
			||||||
			resp.PacketFilter = reduced
 | 
								resp.PacketFilter = reduced
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			resp.PacketFilter = packetFilter
 | 
								resp.PacketFilter = filter
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -461,18 +461,19 @@ func Test_fullMapResponse(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			mappy := NewMapper(
 | 
								mappy := NewMapper(
 | 
				
			||||||
				nil,
 | 
									nil,
 | 
				
			||||||
				tt.cfg,
 | 
									tt.cfg,
 | 
				
			||||||
				tt.derpMap,
 | 
									tt.derpMap,
 | 
				
			||||||
				nil,
 | 
									nil,
 | 
				
			||||||
 | 
									polMan,
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			got, err := mappy.fullMapResponse(
 | 
								got, err := mappy.fullMapResponse(
 | 
				
			||||||
				tt.node,
 | 
									tt.node,
 | 
				
			||||||
				tt.peers,
 | 
									tt.peers,
 | 
				
			||||||
				[]types.User{user1, user2},
 | 
					 | 
				
			||||||
				tt.pol,
 | 
					 | 
				
			||||||
				0,
 | 
									0,
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -14,7 +14,7 @@ import (
 | 
				
			|||||||
func tailNodes(
 | 
					func tailNodes(
 | 
				
			||||||
	nodes types.Nodes,
 | 
						nodes types.Nodes,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
	cfg *types.Config,
 | 
						cfg *types.Config,
 | 
				
			||||||
) ([]*tailcfg.Node, error) {
 | 
					) ([]*tailcfg.Node, error) {
 | 
				
			||||||
	tNodes := make([]*tailcfg.Node, len(nodes))
 | 
						tNodes := make([]*tailcfg.Node, len(nodes))
 | 
				
			||||||
@ -23,7 +23,7 @@ func tailNodes(
 | 
				
			|||||||
		node, err := tailNode(
 | 
							node, err := tailNode(
 | 
				
			||||||
			node,
 | 
								node,
 | 
				
			||||||
			capVer,
 | 
								capVer,
 | 
				
			||||||
			pol,
 | 
								polMan,
 | 
				
			||||||
			cfg,
 | 
								cfg,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@ -40,7 +40,7 @@ func tailNodes(
 | 
				
			|||||||
func tailNode(
 | 
					func tailNode(
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	capVer tailcfg.CapabilityVersion,
 | 
						capVer tailcfg.CapabilityVersion,
 | 
				
			||||||
	pol *policy.ACLPolicy,
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
	cfg *types.Config,
 | 
						cfg *types.Config,
 | 
				
			||||||
) (*tailcfg.Node, error) {
 | 
					) (*tailcfg.Node, error) {
 | 
				
			||||||
	addrs := node.Prefixes()
 | 
						addrs := node.Prefixes()
 | 
				
			||||||
@ -81,7 +81,7 @@ func tailNode(
 | 
				
			|||||||
		return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 | 
							return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tags, _ := pol.TagsOfNode(node)
 | 
						tags := polMan.Tags(node)
 | 
				
			||||||
	tags = lo.Uniq(append(tags, node.ForcedTags...))
 | 
						tags = lo.Uniq(append(tags, node.ForcedTags...))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tNode := tailcfg.Node{
 | 
						tNode := tailcfg.Node{
 | 
				
			||||||
 | 
				
			|||||||
@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
 | 
				
			||||||
			cfg := &types.Config{
 | 
								cfg := &types.Config{
 | 
				
			||||||
				BaseDomain:          tt.baseDomain,
 | 
									BaseDomain:          tt.baseDomain,
 | 
				
			||||||
				DNSConfig:           tt.dnsConfig,
 | 
									DNSConfig:           tt.dnsConfig,
 | 
				
			||||||
@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) {
 | 
				
			|||||||
			got, err := tailNode(
 | 
								got, err := tailNode(
 | 
				
			||||||
				tt.node,
 | 
									tt.node,
 | 
				
			||||||
				0,
 | 
									0,
 | 
				
			||||||
				tt.pol,
 | 
									polMan,
 | 
				
			||||||
				cfg,
 | 
									cfg,
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) {
 | 
				
			|||||||
			tn, err := tailNode(
 | 
								tn, err := tailNode(
 | 
				
			||||||
				node,
 | 
									node,
 | 
				
			||||||
				0,
 | 
									0,
 | 
				
			||||||
				&policy.ACLPolicy{},
 | 
									&policy.PolicyManagerV1{},
 | 
				
			||||||
				&types.Config{},
 | 
									&types.Config{},
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -18,6 +18,7 @@ import (
 | 
				
			|||||||
	"github.com/gorilla/mux"
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/db"
 | 
						"github.com/juanfont/headscale/hscontrol/db"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/notifier"
 | 
						"github.com/juanfont/headscale/hscontrol/notifier"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
@ -53,6 +54,7 @@ type AuthProviderOIDC struct {
 | 
				
			|||||||
	registrationCache *zcache.Cache[string, key.MachinePublic]
 | 
						registrationCache *zcache.Cache[string, key.MachinePublic]
 | 
				
			||||||
	notifier          *notifier.Notifier
 | 
						notifier          *notifier.Notifier
 | 
				
			||||||
	ipAlloc           *db.IPAllocator
 | 
						ipAlloc           *db.IPAllocator
 | 
				
			||||||
 | 
						polMan            policy.PolicyManager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oidcProvider *oidc.Provider
 | 
						oidcProvider *oidc.Provider
 | 
				
			||||||
	oauth2Config *oauth2.Config
 | 
						oauth2Config *oauth2.Config
 | 
				
			||||||
@ -65,6 +67,7 @@ func NewAuthProviderOIDC(
 | 
				
			|||||||
	db *db.HSDatabase,
 | 
						db *db.HSDatabase,
 | 
				
			||||||
	notif *notifier.Notifier,
 | 
						notif *notifier.Notifier,
 | 
				
			||||||
	ipAlloc *db.IPAllocator,
 | 
						ipAlloc *db.IPAllocator,
 | 
				
			||||||
 | 
						polMan policy.PolicyManager,
 | 
				
			||||||
) (*AuthProviderOIDC, error) {
 | 
					) (*AuthProviderOIDC, error) {
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	// grab oidc config if it hasn't been already
 | 
						// grab oidc config if it hasn't been already
 | 
				
			||||||
@ -96,6 +99,7 @@ func NewAuthProviderOIDC(
 | 
				
			|||||||
		registrationCache: registrationCache,
 | 
							registrationCache: registrationCache,
 | 
				
			||||||
		notifier:          notif,
 | 
							notifier:          notif,
 | 
				
			||||||
		ipAlloc:           ipAlloc,
 | 
							ipAlloc:           ipAlloc,
 | 
				
			||||||
 | 
							polMan:            polMan,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		oidcProvider: oidcProvider,
 | 
							oidcProvider: oidcProvider,
 | 
				
			||||||
		oauth2Config: oauth2Config,
 | 
							oauth2Config: oauth2Config,
 | 
				
			||||||
@ -478,6 +482,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 | 
				
			|||||||
		return nil, fmt.Errorf("creating or updating user: %w", err)
 | 
							return nil, fmt.Errorf("creating or updating user: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = usersChangedHook(a.db, a.polMan, a.notifier)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("updating resources using user: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return user, nil
 | 
						return user, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -501,6 +510,11 @@ func (a *AuthProviderOIDC) registerNode(
 | 
				
			|||||||
		return fmt.Errorf("could not register node: %w", err)
 | 
							return fmt.Errorf("could not register node: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = nodesChangedHook(a.db, a.polMan, a.notifier)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("updating resources using node: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										181
									
								
								hscontrol/policy/pm.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								hscontrol/policy/pm.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,181 @@
 | 
				
			|||||||
 | 
					package policy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/netip"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
 | 
						"go4.org/netipx"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
						"tailscale.com/util/deephash"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PolicyManager interface {
 | 
				
			||||||
 | 
						Filter() []tailcfg.FilterRule
 | 
				
			||||||
 | 
						SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
 | 
				
			||||||
 | 
						Tags(*types.Node) []string
 | 
				
			||||||
 | 
						ApproversForRoute(netip.Prefix) []string
 | 
				
			||||||
 | 
						ExpandAlias(string) (*netipx.IPSet, error)
 | 
				
			||||||
 | 
						SetPolicy([]byte) (bool, error)
 | 
				
			||||||
 | 
						SetUsers(users []types.User) (bool, error)
 | 
				
			||||||
 | 
						SetNodes(nodes types.Nodes) (bool, error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
 | 
				
			||||||
 | 
						policyFile, err := os.Open(path)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer policyFile.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						policyBytes, err := io.ReadAll(policyFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return NewPolicyManager(policyBytes, users, nodes)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
 | 
				
			||||||
 | 
						var pol *ACLPolicy
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						if polB != nil && len(polB) > 0 {
 | 
				
			||||||
 | 
							pol, err = LoadACLPolicyFromBytes(polB)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, fmt.Errorf("parsing policy: %w", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pm := PolicyManagerV1{
 | 
				
			||||||
 | 
							pol:   pol,
 | 
				
			||||||
 | 
							users: users,
 | 
				
			||||||
 | 
							nodes: nodes,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = pm.updateLocked()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &pm, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
 | 
				
			||||||
 | 
						pm := PolicyManagerV1{
 | 
				
			||||||
 | 
							pol:   pol,
 | 
				
			||||||
 | 
							users: users,
 | 
				
			||||||
 | 
							nodes: nodes,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err := pm.updateLocked()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &pm, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PolicyManagerV1 struct {
 | 
				
			||||||
 | 
						mu  sync.Mutex
 | 
				
			||||||
 | 
						pol *ACLPolicy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						users []types.User
 | 
				
			||||||
 | 
						nodes types.Nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						filterHash deephash.Sum
 | 
				
			||||||
 | 
						filter     []tailcfg.FilterRule
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// updateLocked updates the filter rules based on the current policy and nodes.
 | 
				
			||||||
 | 
					// It must be called with the lock held.
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) updateLocked() (bool, error) {
 | 
				
			||||||
 | 
						filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, fmt.Errorf("compiling filter rules: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						filterHash := deephash.Hash(&filter)
 | 
				
			||||||
 | 
						if filterHash == pm.filterHash {
 | 
				
			||||||
 | 
							return false, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pm.filter = filter
 | 
				
			||||||
 | 
						pm.filterHash = filterHash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
 | 
				
			||||||
 | 
						pm.mu.Lock()
 | 
				
			||||||
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
 | 
						return pm.filter
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
 | 
				
			||||||
 | 
						pm.mu.Lock()
 | 
				
			||||||
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
 | 
				
			||||||
 | 
						pol, err := LoadACLPolicyFromBytes(polB)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, fmt.Errorf("parsing policy: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pm.mu.Lock()
 | 
				
			||||||
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pm.pol = pol
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return pm.updateLocked()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetUsers updates the users in the policy manager and updates the filter rules.
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
 | 
				
			||||||
 | 
						pm.mu.Lock()
 | 
				
			||||||
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pm.users = users
 | 
				
			||||||
 | 
						return pm.updateLocked()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetNodes updates the nodes in the policy manager and updates the filter rules.
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
 | 
				
			||||||
 | 
						pm.mu.Lock()
 | 
				
			||||||
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
 | 
						pm.nodes = nodes
 | 
				
			||||||
 | 
						return pm.updateLocked()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
 | 
				
			||||||
 | 
						if pm == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tags, _ := pm.pol.TagsOfNode(node)
 | 
				
			||||||
 | 
						return tags
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
 | 
				
			||||||
 | 
						// TODO(kradalby): This can be a parse error of the address in the policy,
 | 
				
			||||||
 | 
						// in the new policy this will be typed and not a problem, in this policy
 | 
				
			||||||
 | 
						// we will just return empty list
 | 
				
			||||||
 | 
						if pm.pol == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
 | 
				
			||||||
 | 
						return approvers
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
 | 
				
			||||||
 | 
						ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ips, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										158
									
								
								hscontrol/policy/pm_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								hscontrol/policy/pm_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,158 @@
 | 
				
			|||||||
 | 
					package policy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/google/go-cmp/cmp"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestPolicySetChange(t *testing.T) {
 | 
				
			||||||
 | 
						users := []types.User{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Model: gorm.Model{ID: 1},
 | 
				
			||||||
 | 
								Name:  "testuser",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name             string
 | 
				
			||||||
 | 
							users            []types.User
 | 
				
			||||||
 | 
							nodes            types.Nodes
 | 
				
			||||||
 | 
							policy           []byte
 | 
				
			||||||
 | 
							wantUsersChange  bool
 | 
				
			||||||
 | 
							wantNodesChange  bool
 | 
				
			||||||
 | 
							wantPolicyChange bool
 | 
				
			||||||
 | 
							wantFilter       []tailcfg.FilterRule
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "set-nodes",
 | 
				
			||||||
 | 
								nodes: types.Nodes{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										IPv4: iap("100.64.0.2"),
 | 
				
			||||||
 | 
										User: users[0],
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								wantNodesChange: false,
 | 
				
			||||||
 | 
								wantFilter: []tailcfg.FilterRule{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:            "set-users",
 | 
				
			||||||
 | 
								users:           users,
 | 
				
			||||||
 | 
								wantUsersChange: false,
 | 
				
			||||||
 | 
								wantFilter: []tailcfg.FilterRule{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "set-users-and-node",
 | 
				
			||||||
 | 
								users: users,
 | 
				
			||||||
 | 
								nodes: types.Nodes{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										IPv4: iap("100.64.0.2"),
 | 
				
			||||||
 | 
										User: users[0],
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								wantUsersChange: false,
 | 
				
			||||||
 | 
								wantNodesChange: true,
 | 
				
			||||||
 | 
								wantFilter: []tailcfg.FilterRule{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										SrcIPs:   []string{"100.64.0.2/32"},
 | 
				
			||||||
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "set-policy",
 | 
				
			||||||
 | 
								policy: []byte(`
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					"acls": [
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"action": "accept",
 | 
				
			||||||
 | 
								"src": [
 | 
				
			||||||
 | 
									"100.64.0.61",
 | 
				
			||||||
 | 
								],
 | 
				
			||||||
 | 
								"dst": [
 | 
				
			||||||
 | 
									"100.64.0.62:*",
 | 
				
			||||||
 | 
								],
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							],
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
									`),
 | 
				
			||||||
 | 
								wantPolicyChange: true,
 | 
				
			||||||
 | 
								wantFilter: []tailcfg.FilterRule{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										SrcIPs:   []string{"100.64.0.61/32"},
 | 
				
			||||||
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								pol := `
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						"groups": {
 | 
				
			||||||
 | 
							"group:example": [
 | 
				
			||||||
 | 
								"testuser",
 | 
				
			||||||
 | 
							],
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"hosts": {
 | 
				
			||||||
 | 
							"host-1": "100.64.0.1",
 | 
				
			||||||
 | 
							"subnet-1": "100.100.101.100/24",
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"acls": [
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"action": "accept",
 | 
				
			||||||
 | 
								"src": [
 | 
				
			||||||
 | 
									"group:example",
 | 
				
			||||||
 | 
								],
 | 
				
			||||||
 | 
								"dst": [
 | 
				
			||||||
 | 
									"host-1:*",
 | 
				
			||||||
 | 
								],
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						],
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					`
 | 
				
			||||||
 | 
								pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if tt.policy != nil {
 | 
				
			||||||
 | 
									change, err := pm.SetPolicy(tt.policy)
 | 
				
			||||||
 | 
									require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									assert.Equal(t, tt.wantPolicyChange, change)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if tt.users != nil {
 | 
				
			||||||
 | 
									change, err := pm.SetUsers(tt.users)
 | 
				
			||||||
 | 
									require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									assert.Equal(t, tt.wantUsersChange, change)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if tt.nodes != nil {
 | 
				
			||||||
 | 
									change, err := pm.SetNodes(tt.nodes)
 | 
				
			||||||
 | 
									require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									assert.Equal(t, tt.wantNodesChange, change)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
 | 
				
			||||||
 | 
									t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() {
 | 
				
			|||||||
			switch update.Type {
 | 
								switch update.Type {
 | 
				
			||||||
			case types.StateFullUpdate:
 | 
								case types.StateFullUpdate:
 | 
				
			||||||
				m.tracef("Sending Full MapResponse")
 | 
									m.tracef("Sending Full MapResponse")
 | 
				
			||||||
				data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
 | 
									data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
 | 
				
			||||||
			case types.StatePeerChanged:
 | 
								case types.StatePeerChanged:
 | 
				
			||||||
				changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
 | 
									changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
				lastMessage = update.Message
 | 
									lastMessage = update.Message
 | 
				
			||||||
				m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
									m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
				
			||||||
				data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
 | 
									data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
 | 
				
			||||||
				updateType = "change"
 | 
									updateType = "change"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			case types.StatePeerChangedPatch:
 | 
								case types.StatePeerChangedPatch:
 | 
				
			||||||
				m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
 | 
									m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
 | 
				
			||||||
				data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
 | 
									data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
 | 
				
			||||||
				updateType = "patch"
 | 
									updateType = "patch"
 | 
				
			||||||
			case types.StatePeerRemoved:
 | 
								case types.StatePeerRemoved:
 | 
				
			||||||
				changed := make(map[types.NodeID]bool, len(update.Removed))
 | 
									changed := make(map[types.NodeID]bool, len(update.Removed))
 | 
				
			||||||
@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() {
 | 
				
			|||||||
					changed[nodeID] = false
 | 
										changed[nodeID] = false
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
									m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
				
			||||||
				data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
 | 
									data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
 | 
				
			||||||
				updateType = "remove"
 | 
									updateType = "remove"
 | 
				
			||||||
			case types.StateSelfUpdate:
 | 
								case types.StateSelfUpdate:
 | 
				
			||||||
				lastMessage = update.Message
 | 
									lastMessage = update.Message
 | 
				
			||||||
				m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
									m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
				
			||||||
				// create the map so an empty (self) update is sent
 | 
									// create the map so an empty (self) update is sent
 | 
				
			||||||
				data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
 | 
									data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
 | 
				
			||||||
				updateType = "remove"
 | 
									updateType = "remove"
 | 
				
			||||||
			case types.StateDERPUpdated:
 | 
								case types.StateDERPUpdated:
 | 
				
			||||||
				m.tracef("Sending DERPUpdate MapResponse")
 | 
									m.tracef("Sending DERPUpdate MapResponse")
 | 
				
			||||||
@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if m.h.ACLPolicy != nil {
 | 
							// TODO(kradalby): Only update the node that has actually changed
 | 
				
			||||||
 | 
							nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if m.h.polMan != nil {
 | 
				
			||||||
			// update routes with peer information
 | 
								// update routes with peer information
 | 
				
			||||||
			err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
 | 
								err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				m.errf(err, "Error running auto approved routes")
 | 
									m.errf(err, "Error running auto approved routes")
 | 
				
			||||||
				mapResponseEndpointUpdates.WithLabelValues("error").Inc()
 | 
									mapResponseEndpointUpdates.WithLabelValues("error").Inc()
 | 
				
			||||||
@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() {
 | 
				
			|||||||
func (m *mapSession) handleReadOnlyRequest() {
 | 
					func (m *mapSession) handleReadOnlyRequest() {
 | 
				
			||||||
	m.tracef("Client asked for a lite update, responding without peers")
 | 
						m.tracef("Client asked for a lite update, responding without peers")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy)
 | 
						mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		m.errf(err, "Failed to create MapResponse")
 | 
							m.errf(err, "Failed to create MapResponse")
 | 
				
			||||||
		http.Error(m.w, "", http.StatusInternalServerError)
 | 
							http.Error(m.w, "", http.StatusInternalServerError)
 | 
				
			||||||
 | 
				
			|||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -79,6 +79,10 @@ type Option = func(c *HeadscaleInContainer)
 | 
				
			|||||||
// HeadscaleInContainer instance.
 | 
					// HeadscaleInContainer instance.
 | 
				
			||||||
func WithACLPolicy(acl *policy.ACLPolicy) Option {
 | 
					func WithACLPolicy(acl *policy.ACLPolicy) Option {
 | 
				
			||||||
	return func(hsic *HeadscaleInContainer) {
 | 
						return func(hsic *HeadscaleInContainer) {
 | 
				
			||||||
 | 
							if acl == nil {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// TODO(kradalby): Move somewhere appropriate
 | 
							// TODO(kradalby): Move somewhere appropriate
 | 
				
			||||||
		hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath
 | 
							hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user