omni/internal/pkg/siderolink/provision.go
// Copyright (c) 2025 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package siderolink
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/blang/semver/v4"
"github.com/cosi-project/runtime/pkg/controller/generic"
"github.com/cosi-project/runtime/pkg/resource"
"github.com/cosi-project/runtime/pkg/resource/protobuf"
"github.com/cosi-project/runtime/pkg/safe"
"github.com/cosi-project/runtime/pkg/state"
"github.com/siderolabs/go-pointer"
pb "github.com/siderolabs/siderolink/api/siderolink"
"github.com/siderolabs/siderolink/pkg/wireguard"
"go.uber.org/zap"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/siderolabs/omni/client/api/omni/specs"
"github.com/siderolabs/omni/client/pkg/jointoken"
"github.com/siderolabs/omni/client/pkg/omni/resources/infra"
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/client/pkg/omni/resources/siderolink"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/config"
)
var minSupportedSecureTokensVersion = semver.MustParse("1.6.0")
var errUUIDConflict = errors.New("UUID conflict")
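// provisionContext carries the resources, parsed tokens and feature flags gathered for a single provision request.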
type provisionContext struct {
siderolinkConfig *siderolink.Config
link *siderolink.Link
pendingMachine *siderolink.PendingMachine
pendingMachineStatus *siderolink.PendingMachineStatus
token *jointoken.JoinToken
requestNodeUniqueToken *jointoken.NodeUniqueToken
linkNodeUniqueToken *jointoken.NodeUniqueToken
request *pb.ProvisionRequest
// flags
hasValidJoinToken bool
hasValidNodeUniqueToken bool
nodeUniqueTokensEnabled bool
forceValidNodeUniqueToken bool
supportsSecureJoinTokens bool
tokenWasWiped bool
useWireguardOverGRPC bool
}
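// isAuthorizedLegacyJoin reports whether the machine is allowed to join using the legacy flow: it requires a valid join token
// and rejects machines without secure join token support when the link already has a node unique token set.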
func (pc *provisionContext) isAuthorizedLegacyJoin() bool {
// explicitly reject the legacy machine if the link already has a node unique token set
if !pc.supportsSecureJoinTokens && pc.linkNodeUniqueToken != nil && pc.nodeUniqueTokensEnabled {
return false
}
return pc.hasValidJoinToken
}
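// isAuthorizedSecureFlow reports whether the machine is allowed to join using the secure flow: either a valid join token or a valid node unique token is sufficient.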
func (pc *provisionContext) isAuthorizedSecureFlow() bool {
return pc.hasValidJoinToken || pc.hasValidNodeUniqueToken
}
// NewProvisionHandler creates a new ProvisionHandler.
func NewProvisionHandler(logger *zap.Logger, state state.State, joinTokenMode string, forceWireguardOverGRPC bool) *ProvisionHandler {
return &ProvisionHandler{
logger: logger,
state: state,
joinTokenMode: joinTokenMode,
forceWireguardOverGRPC: forceWireguardOverGRPC,
}
}
// ProvisionHandler is the gRPC service that handles provision requests coming from the Talos nodes.
type ProvisionHandler struct {
pb.UnimplementedProvisionServiceServer
logger *zap.Logger
state state.State
joinTokenMode string
forceWireguardOverGRPC bool
}
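// runCleanup periodically removes pending machines that have not been updated for more than 30 seconds or that are already being torn down.
// It blocks until the context is canceled.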
func (h *ProvisionHandler) runCleanup(ctx context.Context) error {
ticker := time.NewTicker(time.Second * 30)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
pendingMachines, err := safe.ReaderListAll[*siderolink.PendingMachine](ctx, h.state)
if err != nil {
h.logger.Error("pending machine cleanup failed", zap.Error(err))
}
for machine := range pendingMachines.All() {
if time.Since(machine.Metadata().Updated()) > time.Second*30 || machine.Metadata().Phase() == resource.PhaseTearingDown {
if err = h.removePendingMachine(ctx, machine); err != nil {
h.logger.Error("failed to remove pending machine", zap.Error(err), zap.String("id", machine.Metadata().ID()))
}
}
}
}
}
}
// Provision handles the requests from Talos nodes.
func (h *ProvisionHandler) Provision(ctx context.Context, req *pb.ProvisionRequest) (*pb.ProvisionResponse, error) {
ctx = actor.MarkContextAsInternalActor(ctx)
provisionContext, err := h.buildProvisionContext(ctx, req)
if err != nil {
return nil, err
}
if !provisionContext.supportsSecureJoinTokens && h.joinTokenMode == config.JoinTokensModeStrict {
return nil, status.Errorf(
codes.FailedPrecondition,
"Talos version %s is not supported on this Omni instance as '--join-tokens-mode' is set to 'strict'",
pointer.SafeDeref(req.TalosVersion),
)
}
resp, err := h.provision(ctx, provisionContext)
if err != nil && status.Code(err) == codes.Unknown {
h.logger.Error("failed to handle machine provision request", zap.Error(err))
return nil, status.Error(codes.Internal, "internal server error")
}
return resp, err
}
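// res constrains the link-like resources (siderolink.Link and siderolink.PendingMachine) that share the SiderolinkSpec.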
type res interface {
generic.ResourceWithRD
TypedSpec() *protobuf.ResourceSpec[specs.SiderolinkSpec, *specs.SiderolinkSpec]
}
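// updateAnnotations sets every annotation from annotationsToAdd (with an empty value) and deletes every annotation from annotationsToRemove on the resource metadata.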
func updateAnnotations(res resource.Resource, annotationsToAdd []string, annotationsToRemove []string) {
for _, annotation := range annotationsToAdd {
res.Metadata().Annotations().Set(annotation, "")
}
for _, annotation := range annotationsToRemove {
res.Metadata().Annotations().Delete(annotation)
}
}
// createResource creates the link resource (PendingMachine/Link) if it doesn't exist.
func createResource[T res](ctx context.Context, st state.State, provisionContext *provisionContext,
annotationsToAdd []string, annotationsToRemove []string,
) (T, error) {
var (
zero T
id string
)
rd := zero.ResourceDefinition()
if rd.Type == siderolink.LinkType {
id = provisionContext.request.NodeUuid
} else {
id = provisionContext.request.NodePublicKey
}
var (
res resource.Resource
err error
)
res, err = protobuf.CreateResource(rd.Type)
if err != nil {
return zero, err
}
*res.Metadata() = resource.NewMetadata(rd.DefaultNamespace, rd.Type, id, resource.VersionUndefined)
link, ok := res.(T)
if !ok {
return zero, fmt.Errorf("incorrect resource type")
}
if provisionContext.token != nil {
if value, ok := provisionContext.token.ExtraData[omni.LabelInfraProviderID]; ok {
link.Metadata().Annotations().Set(omni.LabelInfraProviderID, value)
}
if value, ok := provisionContext.token.ExtraData[omni.LabelMachineRequest]; ok {
link.Metadata().Labels().Set(omni.LabelMachineRequest, value)
}
}
link.TypedSpec().Value, err = generateLinkSpec(provisionContext)
if err != nil {
return zero, err
}
if link.Metadata().Type() == siderolink.PendingMachineType {
link.Metadata().Labels().Set(omni.MachineUUID, provisionContext.request.NodeUuid)
}
updateAnnotations(res, annotationsToAdd, annotationsToRemove)
return link, st.Create(ctx, res)
}
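// updateResourceWithMatchingToken updates the existing link resource when the node unique token from the request matches the stored one,
// or when the stored token may still be overwritten. A token with the same fingerprint but a different random part is rejected,
// and a token with a different fingerprint results in errUUIDConflict.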
func updateResourceWithMatchingToken[T res](ctx context.Context, logger *zap.Logger,
st state.State, provisionContext *provisionContext, r T, annotationsToAdd []string, annotationsToRemove []string,
) (T, error) {
return safe.StateUpdateWithConflicts(ctx, st, r.Metadata(), func(link T) error {
s := link.TypedSpec().Value
if link.Metadata().Type() == siderolink.PendingMachineType {
link.Metadata().Annotations().Set("timestamp", time.Now().String())
}
updateSpec := func() error {
s.NodeUniqueToken = pointer.SafeDeref(provisionContext.request.NodeUniqueToken)
if provisionContext.pendingMachine != nil {
s.NodeSubnet = provisionContext.pendingMachine.TypedSpec().Value.NodeSubnet
logger.Info("updated subnet", zap.String("subnet", s.NodeSubnet))
}
var err error
s.NodePublicKey = provisionContext.request.NodePublicKey
s.VirtualAddrport, err = generateVirtualAddrPort(provisionContext.useWireguardOverGRPC)
if err != nil {
return err
}
updateAnnotations(link, annotationsToAdd, annotationsToRemove)
return nil
}
if s.NodeUniqueToken == "" {
logger.Debug("set unique node token")
return updateSpec()
}
linkNodeUniqueToken, err := jointoken.ParseNodeUniqueToken(s.NodeUniqueToken)
if err != nil {
return err
}
if linkNodeUniqueToken.Equal(provisionContext.requestNodeUniqueToken) ||
!provisionContext.forceValidNodeUniqueToken && provisionContext.requestNodeUniqueToken.IsSameFingerprint(linkNodeUniqueToken) {
logger.Debug("overwrite the existing node unique token")
return updateSpec()
}
// the token has the same fingerprint, but the random part doesn't match, so return an error:
// this case might happen after a hardware failure, where the node has lost its META partition contents
if linkNodeUniqueToken.IsSameFingerprint(provisionContext.requestNodeUniqueToken) {
logger.Warn("machine connection rejected: the machine has the correct fingerprint, but the random token part doesn't match")
return status.Error(codes.PermissionDenied, "unauthorized")
}
return errUUIDConflict
})
}
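// provision implements the main provisioning logic: machines on the legacy flow are joined using the join token only,
// while machines supporting secure join tokens either get a Link right away or are kept in the limbo state as a PendingMachine
// until a node unique token is issued.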
func (h *ProvisionHandler) provision(ctx context.Context, provisionContext *provisionContext) (*pb.ProvisionResponse, error) {
logger := h.logger.With(zap.String("machine", provisionContext.request.NodeUuid))
// legacy flow: the machine joins using the join token only, no node unique token is required
if !provisionContext.nodeUniqueTokensEnabled || !provisionContext.supportsSecureJoinTokens {
if !provisionContext.isAuthorizedLegacyJoin() {
return nil, status.Error(codes.PermissionDenied, "unauthorized")
}
return establishLink[*siderolink.Link](ctx, h.logger, h.state, provisionContext, nil, nil)
}
if !provisionContext.isAuthorizedSecureFlow() {
return nil, status.Error(codes.PermissionDenied, "unauthorized")
}
// if the node unique token is not generated yet and the node supports secure join tokens,
// put the machine into the limbo state by creating the pending machine resource;
// the controller will then pick it up and create a WireGuard peer for it
if provisionContext.requestNodeUniqueToken == nil {
return establishLink[*siderolink.PendingMachine](ctx, h.logger, h.state, provisionContext, nil, nil)
}
annotationsToAdd := []string{}
annotationsToRemove := []string{}
switch {
// the token was wiped; reset the link annotation until Talos gets installed again
case provisionContext.tokenWasWiped:
annotationsToRemove = append(annotationsToRemove, siderolink.ForceValidNodeUniqueToken)
// if we detected that Talos was installed during provisioning,
// mark the link with the annotation to block using the link without the node unique token
case provisionContext.forceValidNodeUniqueToken:
annotationsToAdd = append(annotationsToAdd, siderolink.ForceValidNodeUniqueToken)
}
response, err := establishLink[*siderolink.Link](ctx, h.logger, h.state, provisionContext, annotationsToAdd, annotationsToRemove)
if err != nil {
if errors.Is(err, errUUIDConflict) {
logger.Info("detected UUID conflict", zap.String("peer", provisionContext.request.NodePublicKey))
// link is there, but the token doesn't match and the fingerprint differs, keep the machine in the limbo state
// mark pending machine as having the UUID conflict, PendingMachineStatus controller should inject the new UUID
// and the machine will re-join
return establishLink[*siderolink.PendingMachine](ctx, h.logger, h.state, provisionContext, []string{siderolink.PendingMachineUUIDConflict}, nil)
}
return nil, err
}
return response, nil
}
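// removePendingMachine tears down the pending machine resource and destroys it once the teardown has finished.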
func (h *ProvisionHandler) removePendingMachine(ctx context.Context, pendingMachine *siderolink.PendingMachine) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
ready, err := h.state.Teardown(ctx, pendingMachine.Metadata())
if err != nil {
if state.IsNotFoundError(err) {
return nil
}
return err
}
if !ready {
return nil
}
err = h.state.Destroy(ctx, pendingMachine.Metadata())
if err != nil {
if state.IsNotFoundError(err) {
return nil
}
return err
}
h.logger.Info("cleaned up the pending machine link after grace period", zap.String("id", pendingMachine.Metadata().ID()))
return nil
}
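// establishLink creates the link resource (Link or PendingMachine) for the machine, or updates the existing one if the node unique token allows it,
// and then generates the provision response.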
func establishLink[T res](ctx context.Context, logger *zap.Logger, st state.State, provisionContext *provisionContext,
annotationsToAdd []string, annotationsToRemove []string,
) (*pb.ProvisionResponse, error) {
link, err := createResource[T](ctx, st, provisionContext, annotationsToAdd, annotationsToRemove)
if err != nil {
if !state.IsConflictError(err) {
return nil, err
}
link, err = updateResourceWithMatchingToken[T](ctx, logger, st, provisionContext, link, annotationsToAdd, annotationsToRemove)
if err != nil {
if state.IsPhaseConflictError(err) {
return nil, status.Error(codes.AlreadyExists, "the machine with the same UUID is already registered in Omni and is in the tearing down phase")
}
return nil, err
}
}
return genProvisionResponse(ctx, logger, st, provisionContext, link, link.TypedSpec().Value)
}
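// genProvisionResponse waits for the WireGuard peer to be created on the Omni side and builds the provision response
// with the server endpoints, the server public key and the node address prefix.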
func genProvisionResponse(ctx context.Context, logger *zap.Logger, st state.State,
provisionContext *provisionContext, link resource.Resource, spec *specs.SiderolinkSpec,
) (*pb.ProvisionResponse, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
logger.Debug("waiting for the Wireguard peer to be created", zap.String("link", link.Metadata().String()))
_, err := st.WatchFor(ctx,
siderolink.NewLinkStatus(link).Metadata(),
state.WithPhases(resource.PhaseRunning),
state.WithCondition(func(r resource.Resource) (bool, error) {
return !resource.IsTombstone(r), nil
}),
)
if err != nil {
return nil, fmt.Errorf("timed out waiting for the wireguard peer to be created on Omni side %w", err)
}
wgConfig := provisionContext.siderolinkConfig.TypedSpec().Value
endpoint := wgConfig.WireguardEndpoint
if wgConfig.AdvertisedEndpoint != "" {
endpoint = wgConfig.AdvertisedEndpoint
}
// If the virtual address is set, use it as the endpoint to prevent the client from connecting to the actual WG endpoint
if spec.VirtualAddrport != "" {
endpoint = spec.VirtualAddrport
}
endpoints := strings.Split(endpoint, ",")
logger.Debug("generated response",
zap.String("node_address", spec.NodeSubnet),
zap.String("public_key", wgConfig.PublicKey),
zap.String("grpc_addr_port", spec.VirtualAddrport),
)
return &pb.ProvisionResponse{
ServerEndpoint: pb.MakeEndpoints(endpoints...),
ServerPublicKey: wgConfig.PublicKey,
NodeAddressPrefix: spec.NodeSubnet,
ServerAddress: wgConfig.ServerAddress,
GrpcPeerAddrPort: spec.VirtualAddrport,
}, nil
}
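// generateLinkSpec builds the SiderolinkSpec for the machine, reusing the node address of the existing link or pending machine
// when one exists and generating a random address in the SideroLink subnet otherwise.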
func generateLinkSpec(provisionContext *provisionContext) (*specs.SiderolinkSpec, error) {
nodePrefix := netip.MustParsePrefix(provisionContext.siderolinkConfig.TypedSpec().Value.Subnet)
var nodeAddress string
switch {
case provisionContext.link != nil:
nodeAddress = provisionContext.link.TypedSpec().Value.NodeSubnet
case provisionContext.pendingMachine != nil:
nodeAddress = provisionContext.pendingMachine.TypedSpec().Value.NodeSubnet
default:
// generate a random address for the node
addr, err := wireguard.GenerateRandomNodeAddr(nodePrefix)
if err != nil {
return nil, fmt.Errorf("error generating random node address: %w", err)
}
nodeAddress = addr.String()
}
pubKey, err := wgtypes.ParseKey(provisionContext.request.NodePublicKey)
if err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error parsing Wireguard key: %s", err))
}
virtualAddrPort, err := generateVirtualAddrPort(provisionContext.useWireguardOverGRPC)
if err != nil {
return nil, err
}
return &specs.SiderolinkSpec{
NodeSubnet: nodeAddress,
NodePublicKey: pubKey.String(),
VirtualAddrport: virtualAddrPort,
NodeUniqueToken: pointer.SafeDeref(provisionContext.request.NodeUniqueToken),
Connected: true,
}, nil
}
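// generateVirtualAddrPort returns a random address (with a fixed port) in the WireGuard-over-gRPC virtual network,
// or an empty string when WireGuard-over-gRPC is not used.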
func generateVirtualAddrPort(generate bool) (string, error) {
if !generate {
return "", nil
}
generated, err := wireguard.GenerateRandomNodeAddr(wireguard.VirtualNetworkPrefix())
if err != nil {
return "", fmt.Errorf("error generating random virtual node address: %w", err)
}
return net.JoinHostPort(generated.Addr().String(), "50889"), nil
}
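// buildProvisionContext loads the existing link, pending machine and SideroLink configuration, parses the join and node unique tokens
// from the request, and computes the authorization flags used by the provision flow.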
//nolint:gocyclo,cyclop
func (h *ProvisionHandler) buildProvisionContext(ctx context.Context, req *pb.ProvisionRequest) (*provisionContext, error) {
link, err := safe.StateGetByID[*siderolink.Link](ctx, h.state, req.NodeUuid)
if err != nil && !state.IsNotFoundError(err) {
return nil, err
}
// TODO: add support for several join tokens here
siderolinkConfig, err := safe.ReaderGetByID[*siderolink.Config](ctx, h.state, siderolink.ConfigID)
if err != nil {
return nil, err
}
var (
linkToken *jointoken.JoinToken
requestNodeUniqueToken *jointoken.NodeUniqueToken
linkNodeUniqueToken *jointoken.NodeUniqueToken
pendingMachineStatus *siderolink.PendingMachineStatus
forceValidUniqueToken bool
tokenWasWiped bool
)
if req.JoinToken != nil {
linkToken, err = h.getJoinToken(*req.JoinToken)
if err != nil {
return nil, err
}
}
// if the version is not set, consider the machine to be running a version below 1.6
talosVersion := semver.MustParse("1.5.0")
if pointer.SafeDeref(req.TalosVersion) != "" {
talosVersion, err = semver.ParseTolerant(*req.TalosVersion)
if err != nil {
return nil, fmt.Errorf("failed to parse Talos version %q from the provision request %w", *req.TalosVersion, err)
}
}
pendingMachine, err := safe.ReaderGetByID[*siderolink.PendingMachine](ctx, h.state, req.NodePublicKey)
if err != nil && !state.IsNotFoundError(err) {
return nil, err
}
if uniqueToken := pointer.SafeDeref(req.NodeUniqueToken); uniqueToken != "" {
requestNodeUniqueToken, err = jointoken.ParseNodeUniqueToken(uniqueToken)
if err != nil {
return nil, err
}
}
if pendingMachine != nil {
pendingMachineStatus, err = safe.StateWatchFor[*siderolink.PendingMachineStatus](ctx,
h.state,
siderolink.NewPendingMachineStatus(pendingMachine.Metadata().ID()).Metadata(),
state.WithPhases(resource.PhaseRunning),
state.WithCondition(func(r resource.Resource) (bool, error) {
return !resource.IsTombstone(r), nil
}),
)
if err != nil {
return nil, err
}
forceValidUniqueToken = pendingMachineStatus.TypedSpec().Value.TalosInstalled
}
if link != nil {
_, forceValidUniqueToken = link.Metadata().Annotations().Get(siderolink.ForceValidNodeUniqueToken)
linkNodeUniqueToken, err = jointoken.ParseNodeUniqueToken(link.TypedSpec().Value.NodeUniqueToken)
if err != nil {
return nil, err
}
var machineStatus *infra.MachineStatus
machineStatus, err = safe.ReaderGetByID[*infra.MachineStatus](ctx, h.state, req.NodeUuid)
if err != nil && !state.IsNotFoundError(err) {
return nil, err
}
if machineStatus != nil && link.TypedSpec().Value.NodeUniqueToken == machineStatus.TypedSpec().Value.WipedNodeUniqueToken {
forceValidUniqueToken = false
tokenWasWiped = true
}
}
supportsSecureJoinTokens := talosVersion.GTE(minSupportedSecureTokensVersion)
return &provisionContext{
siderolinkConfig: siderolinkConfig,
link: link,
pendingMachine: pendingMachine,
pendingMachineStatus: pendingMachineStatus,
token: linkToken,
request: req,
requestNodeUniqueToken: requestNodeUniqueToken,
linkNodeUniqueToken: linkNodeUniqueToken,
forceValidNodeUniqueToken: forceValidUniqueToken,
tokenWasWiped: tokenWasWiped,
hasValidJoinToken: linkToken != nil && linkToken.IsValid(siderolinkConfig.TypedSpec().Value.JoinToken),
hasValidNodeUniqueToken: linkNodeUniqueToken.Equal(requestNodeUniqueToken),
nodeUniqueTokensEnabled: h.joinTokenMode != config.JoinTokensModeLegacyOnly,
supportsSecureJoinTokens: supportsSecureJoinTokens,
useWireguardOverGRPC: h.forceWireguardOverGRPC || pointer.SafeDeref(req.WireguardOverGrpc),
}, nil
}
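// getJoinToken parses the join token string from the request, returning a PermissionDenied error when the token is malformed.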
func (h *ProvisionHandler) getJoinToken(tokenString string) (*jointoken.JoinToken, error) {
var token jointoken.JoinToken
token, err := jointoken.Parse(tokenString)
if err != nil {
return nil, status.Errorf(codes.PermissionDenied, "invalid join token: %s", err)
}
return &token, nil
}