Mirror of https://github.com/juanfont/headscale.git, synced 2025-11-04 10:01:05 +01:00.
			
		
		
		
	
		
			
				
	
	
		
			293 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			293 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package hscontrol
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/juanfont/headscale/hscontrol/db"
 | 
						|
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						|
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						|
	"gorm.io/gorm"
 | 
						|
	"tailscale.com/tailcfg"
 | 
						|
	"tailscale.com/types/key"
 | 
						|
	"tailscale.com/types/ptr"
 | 
						|
)
 | 
						|
 | 
						|
type AuthProvider interface {
 | 
						|
	RegisterHandler(http.ResponseWriter, *http.Request)
 | 
						|
	AuthURL(types.RegistrationID) string
 | 
						|
}
 | 
						|
 | 
						|
func (h *Headscale) handleRegister(
 | 
						|
	ctx context.Context,
 | 
						|
	regReq tailcfg.RegisterRequest,
 | 
						|
	machineKey key.MachinePublic,
 | 
						|
) (*tailcfg.RegisterResponse, error) {
 | 
						|
	node, err := h.db.GetNodeByNodeKey(regReq.NodeKey)
 | 
						|
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 | 
						|
		return nil, fmt.Errorf("looking up node in database: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if node != nil {
 | 
						|
		resp, err := h.handleExistingNode(node, regReq, machineKey)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("handling existing node: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		return resp, nil
 | 
						|
	}
 | 
						|
 | 
						|
	if regReq.Followup != "" {
 | 
						|
		return h.waitForFollowup(ctx, regReq)
 | 
						|
	}
 | 
						|
 | 
						|
	if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
 | 
						|
		resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("handling register with auth key: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		return resp, nil
 | 
						|
	}
 | 
						|
 | 
						|
	resp, err := h.handleRegisterInteractive(regReq, machineKey)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("handling register interactive: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return resp, nil
 | 
						|
}
 | 
						|
 | 
						|
func (h *Headscale) handleExistingNode(
 | 
						|
	node *types.Node,
 | 
						|
	regReq tailcfg.RegisterRequest,
 | 
						|
	machineKey key.MachinePublic,
 | 
						|
) (*tailcfg.RegisterResponse, error) {
 | 
						|
	if node.MachineKey != machineKey {
 | 
						|
		return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
 | 
						|
	}
 | 
						|
 | 
						|
	expired := node.IsExpired()
 | 
						|
	if !expired && !regReq.Expiry.IsZero() {
 | 
						|
		requestExpiry := regReq.Expiry
 | 
						|
 | 
						|
		// The client is trying to extend their key, this is not allowed.
 | 
						|
		if requestExpiry.After(time.Now()) {
 | 
						|
			return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
 | 
						|
		}
 | 
						|
 | 
						|
		// If the request expiry is in the past, we consider it a logout.
 | 
						|
		if requestExpiry.Before(time.Now()) {
 | 
						|
			if node.IsEphemeral() {
 | 
						|
				err := h.db.DeleteNode(node)
 | 
						|
				if err != nil {
 | 
						|
					return nil, fmt.Errorf("deleting ephemeral node: %w", err)
 | 
						|
				}
 | 
						|
 | 
						|
				ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
 | 
						|
				h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
 | 
						|
			}
 | 
						|
 | 
						|
			expired = true
 | 
						|
		}
 | 
						|
 | 
						|
		err := h.db.NodeSetExpiry(node.ID, requestExpiry)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("setting node expiry: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
 | 
						|
		h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
 | 
						|
	}
 | 
						|
 | 
						|
	return nodeToRegisterResponse(node), nil
 | 
						|
}
 | 
						|
 | 
						|
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
 | 
						|
	return &tailcfg.RegisterResponse{
 | 
						|
		// TODO(kradalby): Only send for user-owned nodes
 | 
						|
		// and not tagged nodes when tags is working.
 | 
						|
		User:           *node.User.TailscaleUser(),
 | 
						|
		Login:          *node.User.TailscaleLogin(),
 | 
						|
		NodeKeyExpired: node.IsExpired(),
 | 
						|
 | 
						|
		// Headscale does not implement the concept of machine authorization
 | 
						|
		// so we always return true here.
 | 
						|
		// Revisit this if #2176 gets implemented.
 | 
						|
		MachineAuthorized: true,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (h *Headscale) waitForFollowup(
 | 
						|
	ctx context.Context,
 | 
						|
	regReq tailcfg.RegisterRequest,
 | 
						|
) (*tailcfg.RegisterResponse, error) {
 | 
						|
	fu, err := url.Parse(regReq.Followup)
 | 
						|
	if err != nil {
 | 
						|
		return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
 | 
						|
	}
 | 
						|
 | 
						|
	followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
 | 
						|
	if err != nil {
 | 
						|
		return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if reg, ok := h.registrationCache.Get(followupReg); ok {
 | 
						|
		select {
 | 
						|
		case <-ctx.Done():
 | 
						|
			return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
 | 
						|
		case node := <-reg.Registered:
 | 
						|
			if node == nil {
 | 
						|
				return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
 | 
						|
			}
 | 
						|
			return nodeToRegisterResponse(node), nil
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
 | 
						|
}
 | 
						|
 | 
						|
// canUsePreAuthKey checks if a pre auth key can be used.
 | 
						|
func canUsePreAuthKey(pak *types.PreAuthKey) error {
 | 
						|
	if pak == nil {
 | 
						|
		return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
 | 
						|
	}
 | 
						|
	if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
 | 
						|
		return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
 | 
						|
	}
 | 
						|
 | 
						|
	// we don't need to check if has been used before
 | 
						|
	if pak.Reusable {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	if pak.Used {
 | 
						|
		return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (h *Headscale) handleRegisterWithAuthKey(
 | 
						|
	regReq tailcfg.RegisterRequest,
 | 
						|
	machineKey key.MachinePublic,
 | 
						|
) (*tailcfg.RegisterResponse, error) {
 | 
						|
	pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
 | 
						|
	if err != nil {
 | 
						|
		if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
						|
			return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
 | 
						|
		}
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	err = canUsePreAuthKey(pak)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	nodeToRegister := types.Node{
 | 
						|
		Hostname:       regReq.Hostinfo.Hostname,
 | 
						|
		UserID:         pak.User.ID,
 | 
						|
		User:           pak.User,
 | 
						|
		MachineKey:     machineKey,
 | 
						|
		NodeKey:        regReq.NodeKey,
 | 
						|
		Hostinfo:       regReq.Hostinfo,
 | 
						|
		LastSeen:       ptr.To(time.Now()),
 | 
						|
		RegisterMethod: util.RegisterMethodAuthKey,
 | 
						|
 | 
						|
		// TODO(kradalby): This should not be set on the node,
 | 
						|
		// they should be looked up through the key, which is
 | 
						|
		// attached to the node.
 | 
						|
		ForcedTags: pak.Proto().GetAclTags(),
 | 
						|
		AuthKey:    pak,
 | 
						|
		AuthKeyID:  &pak.ID,
 | 
						|
	}
 | 
						|
 | 
						|
	if !regReq.Expiry.IsZero() {
 | 
						|
		nodeToRegister.Expiry = ®Req.Expiry
 | 
						|
	}
 | 
						|
 | 
						|
	ipv4, ipv6, err := h.ipAlloc.Next()
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("allocating IPs: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	node, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | 
						|
		node, err := db.RegisterNode(tx,
 | 
						|
			nodeToRegister,
 | 
						|
			ipv4, ipv6,
 | 
						|
		)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("registering node: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		if !pak.Reusable {
 | 
						|
			err = db.UsePreAuthKey(tx, pak)
 | 
						|
			if err != nil {
 | 
						|
				return nil, fmt.Errorf("using pre auth key: %w", err)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		return node, nil
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("nodes changed hook: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if !updateSent {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
 | 
						|
		h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
 | 
						|
	}
 | 
						|
 | 
						|
	return &tailcfg.RegisterResponse{
 | 
						|
		MachineAuthorized: true,
 | 
						|
		NodeKeyExpired:    node.IsExpired(),
 | 
						|
		User:              *pak.User.TailscaleUser(),
 | 
						|
		Login:             *pak.User.TailscaleLogin(),
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (h *Headscale) handleRegisterInteractive(
 | 
						|
	regReq tailcfg.RegisterRequest,
 | 
						|
	machineKey key.MachinePublic,
 | 
						|
) (*tailcfg.RegisterResponse, error) {
 | 
						|
	registrationId, err := types.NewRegistrationID()
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("generating registration ID: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	newNode := types.RegisterNode{
 | 
						|
		Node: types.Node{
 | 
						|
			Hostname:   regReq.Hostinfo.Hostname,
 | 
						|
			MachineKey: machineKey,
 | 
						|
			NodeKey:    regReq.NodeKey,
 | 
						|
			Hostinfo:   regReq.Hostinfo,
 | 
						|
			LastSeen:   ptr.To(time.Now()),
 | 
						|
		},
 | 
						|
		Registered: make(chan *types.Node),
 | 
						|
	}
 | 
						|
 | 
						|
	if !regReq.Expiry.IsZero() {
 | 
						|
		newNode.Node.Expiry = ®Req.Expiry
 | 
						|
	}
 | 
						|
 | 
						|
	h.registrationCache.Set(
 | 
						|
		registrationId,
 | 
						|
		newNode,
 | 
						|
	)
 | 
						|
 | 
						|
	return &tailcfg.RegisterResponse{
 | 
						|
		AuthURL: h.authProvider.AuthURL(registrationId),
 | 
						|
	}, nil
 | 
						|
}
 |