omni/internal/backend/server.go
Utku Ozdemir 176f9d9f57
feat: compute schematic id only from the extensions
When determining the schematic ID of a machine, instead of relying on the schematic ID meta-extension, compute the ID by gathering the extensions installed on the machine. This way, the computed ID will not contain META values, labels, or kernel args.

This is actually the ID we need: when we compare the desired schematic with the actual one during a Talos upgrade, we are only interested in changes to the list of extensions.

This does not cause the kernel args, labels, etc. to disappear, as they are used at installation time and preserved afterward (e.g., during upgrades).
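
To illustrate the idea: the ID is derived purely from the extension list, so META values, labels, and kernel args cannot influence it. Below is a minimal sketch (a hypothetical helper, not the actual schematic ID computation used by Omni or the image factory):

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"slices"
	"strings"
)

// computeExtensionsID is a hypothetical helper: it derives an ID from the
// sorted extension list only, so META values, labels and kernel args cannot
// change it. The real implementation derives the ID from a schematic
// definition; this only illustrates that nothing but the extensions feed
// into it.
func computeExtensionsID(extensions []string) string {
	sorted := slices.Clone(extensions)
	slices.Sort(sorted)

	sum := sha256.Sum256([]byte(strings.Join(sorted, "\n")))

	return hex.EncodeToString(sum[:])
}

func main() {
	fmt.Println(computeExtensionsID([]string{"siderolabs/iscsi-tools", "siderolabs/intel-ucode"}))
}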

Additionally:
- Remove the list of extensions from the `Schematic` resource, as it relied on the schematics always being created through Omni. This is not always the case, e.g., when a partial join config is used. Instead, we read the list of extensions directly from the machine and store it on the `MachineStatus` resource.
- Skip setting the schematic META section entirely if no labels are set on the Download Installation Media screen.

Closes siderolabs/omni#55.

Signed-off-by: Utku Ozdemir <utku.ozdemir@siderolabs.com>
2024-03-22 14:58:19 +03:00

// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package backend contains all internal backend code.
package backend

import (
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"strconv"
"strings"
"time"
pgpcrypto "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/cosi-project/runtime/api/v1alpha1"
"github.com/cosi-project/runtime/pkg/resource"
"github.com/cosi-project/runtime/pkg/resource/meta"
"github.com/cosi-project/runtime/pkg/resource/protobuf"
"github.com/cosi-project/runtime/pkg/safe"
"github.com/cosi-project/runtime/pkg/state"
protobufserver "github.com/cosi-project/runtime/pkg/state/protobuf/server"
"github.com/crewjam/saml/samlsp"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/siderolabs/gen/value"
"github.com/siderolabs/go-api-signature/pkg/pgp"
talosconstants "github.com/siderolabs/talos/pkg/machinery/constants"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
resapi "github.com/siderolabs/omni/client/api/omni/resources"
"github.com/siderolabs/omni/client/pkg/constants"
"github.com/siderolabs/omni/client/pkg/omni/resources"
authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth"
omnires "github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/internal/backend/debug"
"github.com/siderolabs/omni/internal/backend/dns"
"github.com/siderolabs/omni/internal/backend/factory"
grpcomni "github.com/siderolabs/omni/internal/backend/grpc"
"github.com/siderolabs/omni/internal/backend/grpc/router"
"github.com/siderolabs/omni/internal/backend/health"
"github.com/siderolabs/omni/internal/backend/imagefactory"
"github.com/siderolabs/omni/internal/backend/k8sproxy"
"github.com/siderolabs/omni/internal/backend/logging"
"github.com/siderolabs/omni/internal/backend/monitoring"
"github.com/siderolabs/omni/internal/backend/oidc"
"github.com/siderolabs/omni/internal/backend/runtime"
"github.com/siderolabs/omni/internal/backend/runtime/kubernetes"
"github.com/siderolabs/omni/internal/backend/runtime/omni"
"github.com/siderolabs/omni/internal/backend/runtime/talos"
"github.com/siderolabs/omni/internal/backend/saml"
"github.com/siderolabs/omni/internal/backend/workloadproxy"
"github.com/siderolabs/omni/internal/frontend"
"github.com/siderolabs/omni/internal/memconn"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/auth0"
"github.com/siderolabs/omni/internal/pkg/auth/handler"
"github.com/siderolabs/omni/internal/pkg/auth/interceptor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/cache"
"github.com/siderolabs/omni/internal/pkg/compress"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/errgroup"
"github.com/siderolabs/omni/internal/pkg/grpcutil"
"github.com/siderolabs/omni/internal/pkg/kms"
"github.com/siderolabs/omni/internal/pkg/siderolink"
)
// Server is the main backend entrypoint: it starts the REST API, the WebSocket server, and serves static content.
type Server struct {
omniRuntime *omni.Runtime
logger *zap.Logger
logHandler *siderolink.LogHandler
authConfig *authres.Config
dnsService *dns.Service
workloadProxyServiceRegistry *workloadproxy.ServiceRegistry
imageFactoryClient *imagefactory.Client
linkCounterDeltaCh chan<- siderolink.LinkCounterDeltas
proxyServer Proxy
bindAddress string
metricsBindAddress string
pprofBindAddress string
k8sProxyBindAddress string
keyFile string
certFile string
}
// NewServer creates a new HTTP server.
func NewServer(
bindAddress, metricsBindAddress, k8sProxyBindAddress, pprofBindAddress string,
dnsService *dns.Service,
workloadProxyServiceRegistry *workloadproxy.ServiceRegistry,
imageFactoryClient *imagefactory.Client,
linkCounterDeltaCh chan<- siderolink.LinkCounterDeltas,
omniRuntime *omni.Runtime,
talosRuntime *talos.Runtime,
logHandler *siderolink.LogHandler,
authConfig *authres.Config,
keyFile, certFile string,
proxyServer Proxy,
logger *zap.Logger,
) (*Server, error) {
s := &Server{
omniRuntime: omniRuntime,
logger: logger.With(logging.Component("server")),
logHandler: logHandler,
authConfig: authConfig,
dnsService: dnsService,
workloadProxyServiceRegistry: workloadProxyServiceRegistry,
imageFactoryClient: imageFactoryClient,
linkCounterDeltaCh: linkCounterDeltaCh,
proxyServer: proxyServer,
bindAddress: bindAddress,
metricsBindAddress: metricsBindAddress,
k8sProxyBindAddress: k8sProxyBindAddress,
pprofBindAddress: pprofBindAddress,
keyFile: keyFile,
certFile: certFile,
}
k8sruntime, err := kubernetes.New(omniRuntime.State())
if err != nil {
return nil, err
}
prometheus.MustRegister(k8sruntime)
runtime.Install(kubernetes.Name, k8sruntime)
runtime.Install(talos.Name, talosRuntime)
runtime.Install(omni.Name, s.omniRuntime)
return s, nil
}
// RegisterRuntime adds a runtime.
func (s *Server) RegisterRuntime(name string, r runtime.Runtime) {
runtime.Install(name, r)
}
// Run runs the HTTP server.
func (s *Server) Run(ctx context.Context) error {
eg, ctx := errgroup.WithContext(ctx)
s.omniRuntime.Run(ctx, eg)
runtimeState := s.omniRuntime.State()
oidcStorage := oidc.NewStorage(runtimeState, s.logger)
oidcProvider, err := oidc.NewProvider(ctx, oidcStorage)
if err != nil {
return err
}
imageFactoryHandler := handler.NewAuthConfig(
handler.NewSignature(
&factory.Handler{
State: runtimeState,
Logger: s.logger.With(logging.Component("factory_proxy")),
},
s.authenticatorFunc(),
s.logger,
),
authres.Enabled(s.authConfig),
s.logger,
)
var samlHandler *samlsp.Middleware
if s.authConfig.TypedSpec().Value.Saml.Enabled {
samlHandler, err = saml.NewHandler(s.omniRuntime.State(), s.authConfig.TypedSpec().Value.Saml, s.logger) //nolint:contextcheck
if err != nil {
return err
}
}
mux, err := makeMux(imageFactoryHandler, oidcProvider.HttpHandler(), samlHandler, s.omniRuntime, s.logger)
if err != nil {
return fmt.Errorf("failed to create mux: %w", err)
}
serverOptions, err := s.buildServerOptions() //nolint:contextcheck
if err != nil {
return err
}
serviceServers, err := grpcomni.MakeServiceServers(runtimeState, s.logHandler, oidcProvider, oidcStorage, s.dnsService, s.imageFactoryClient, s.logger)
if err != nil {
return err
}
gatewayTransport := &memconn.Transport{Address: "gateway-conn"}
grpcServer, err := grpcomni.New(ctx, mux, serviceServers, gatewayTransport, s.logger, serverOptions...)
if err != nil {
return err
}
grpcTransport := &memconn.Transport{Address: "grpc-conn"}
rtr, err := router.NewRouter(
grpcTransport,
runtimeState,
s.dnsService,
authres.Enabled(s.authConfig),
interceptor.NewSignature(s.authenticatorFunc(), s.logger).Unary(),
)
if err != nil {
return err
}
prometheus.MustRegister(rtr)
eg.Go(func() error { return rtr.ClusterWatcher(ctx, runtimeState) })
grpcProxyServer := router.NewServer(rtr,
router.Interceptors(s.logger),
grpc.MaxRecvMsgSize(constants.GRPCMaxMessageSize),
)
crtData := certData{certFile: s.certFile, keyFile: s.keyFile}
workloadProxyHandler, err := s.workloadProxyHandler(mux)
if err != nil {
return fmt.Errorf("failed to create workload proxy handler: %w", err)
}
unifiedHandler := unifyHandler(workloadProxyHandler, grpcProxyServer, crtData)
fns := []func() error{
func() error { return runGRPCServer(ctx, grpcProxyServer, gatewayTransport, s.logger) },
func() error { return runAPIServer(ctx, unifiedHandler, s.bindAddress, crtData, s.logger) },
func() error { return runGRPCServer(ctx, grpcServer, grpcTransport, s.logger) },
func() error { return runMetricsServer(ctx, s.metricsBindAddress, s.logger) },
func() error {
return runK8sProxyServer(ctx, s.k8sProxyBindAddress, oidcStorage, crtData, runtimeState, s.logger)
},
func() error { return s.proxyServer.Run(ctx, unifiedHandler, s.logger) },
func() error { return s.logHandler.Start(ctx) },
func() error { return s.runMachineAPI(ctx) },
}
if s.pprofBindAddress != "" {
fns = append(fns, func() error { return runPprofServer(ctx, s.pprofBindAddress, s.logger) })
}
for _, fn := range fns {
eg.Go(fn)
}
if err = runLocalResourceServer(ctx, runtimeState, serverOptions, eg, s.logger); err != nil {
return fmt.Errorf("failed to run local resource server: %w", err)
}
return eg.Wait()
}
// buildServerOptions builds the gRPC server options.
//
// Recovery is installed as the first middleware in the chain to handle panics (via defer and recover()) in all subsequent middlewares.
//
// Logging is installed as the first middleware (even before recovery middleware) in the chain
// so that the request in the form it was received and the status sent on the wire are logged (error/success).
// It also tracks the whole duration of the request, including other middleware overhead.
func (s *Server) buildServerOptions() ([]grpc.ServerOption, error) {
recoveryOpt := grpc_recovery.WithRecoveryHandler(recoveryHandler(s.logger))
messageProducer := grpcutil.LogLevelOverridingMessageProducer(grpc_zap.DefaultMessageProducer)
logLevelOverrideUnaryInterceptor, logLevelOverrideStreamInterceptor := grpcutil.LogLevelInterceptors()
grpc_prometheus.EnableHandlingTimeHistogram(grpc_prometheus.WithHistogramBuckets([]float64{0.001, 0.01, 0.1, 1, 10, 30, 60, 120, 300, 600}))
unaryInterceptors := []grpc.UnaryServerInterceptor{
grpc_ctxtags.UnaryServerInterceptor(),
logLevelOverrideUnaryInterceptor,
grpc_zap.UnaryServerInterceptor(s.logger, grpc_zap.WithMessageProducer(messageProducer)),
grpcutil.SetUserAgent(),
grpcutil.SetRealPeerAddress(),
grpcutil.InterceptBodyToTags(
grpcutil.NewHook(
grpcutil.NewRewriter(resourceServerCreate),
grpcutil.NewRewriter(resourceServerUpdate),
grpcutil.NewRewriter(cosiResourceServerCreate),
grpcutil.NewRewriter(cosiResourceServerUpdate),
),
1024,
),
grpc_prometheus.UnaryServerInterceptor,
grpc_recovery.UnaryServerInterceptor(recoveryOpt),
}
streamInterceptors := []grpc.StreamServerInterceptor{
grpc_ctxtags.StreamServerInterceptor(),
logLevelOverrideStreamInterceptor,
grpc_zap.StreamServerInterceptor(s.logger, grpc_zap.WithMessageProducer(messageProducer)),
grpcutil.StreamSetUserAgent(),
grpcutil.StreamSetRealPeerAddress(),
grpcutil.StreamIntercept(
grpcutil.StreamHooks{
RecvMsg: grpcutil.StreamInterceptRequestBodyToTags(
grpcutil.NewHook(
grpcutil.NewRewriter(resourceServerCreate),
grpcutil.NewRewriter(resourceServerUpdate),
grpcutil.NewRewriter(cosiResourceServerCreate),
grpcutil.NewRewriter(cosiResourceServerUpdate),
),
1024,
),
},
),
grpc_prometheus.StreamServerInterceptor,
grpc_recovery.StreamServerInterceptor(recoveryOpt),
}
unaryAuthInterceptors, streamAuthInterceptors, err := s.getAuthInterceptors()
if err != nil {
return nil, err
}
unaryInterceptors = append(unaryInterceptors, unaryAuthInterceptors...)
streamInterceptors = append(streamInterceptors, streamAuthInterceptors...)
return []grpc.ServerOption{
grpc.MaxRecvMsgSize(constants.GRPCMaxMessageSize),
grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...),
grpc.SharedWriteBuffer(true),
}, nil
}
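// getAuthInterceptors builds the unary and stream authentication interceptors: the auth config interceptor is always installed,
// and when auth is enabled, signature verification plus either Auth0 JWT or SAML interceptors are added.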
func (s *Server) getAuthInterceptors() ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor, error) {
authEnabled := authres.Enabled(s.authConfig)
authConfigInterceptor := interceptor.NewAuthConfig(authEnabled, s.logger)
unaryInterceptors := []grpc.UnaryServerInterceptor{
authConfigInterceptor.Unary(),
}
streamInterceptors := []grpc.StreamServerInterceptor{
authConfigInterceptor.Stream(),
}
if !authEnabled {
return unaryInterceptors, streamInterceptors, nil
}
// auth is enabled, add signature and jwt interceptors
signatureInterceptor := interceptor.NewSignature(s.authenticatorFunc(), s.logger)
unaryInterceptors = append(unaryInterceptors, signatureInterceptor.Unary())
streamInterceptors = append(streamInterceptors, signatureInterceptor.Stream())
switch {
case s.authConfig.TypedSpec().Value.Auth0.Enabled:
verifier, err := auth0.NewIDTokenVerifier(s.authConfig.TypedSpec().Value.GetAuth0().Domain)
if err != nil {
return nil, nil, err
}
jwtInterceptor := interceptor.NewJWT(verifier, s.logger)
unaryInterceptors = append(unaryInterceptors, jwtInterceptor.Unary())
streamInterceptors = append(streamInterceptors, jwtInterceptor.Stream())
case s.authConfig.TypedSpec().Value.Saml.Enabled:
samlInterceptor := interceptor.NewSAML(s.omniRuntime.State(), s.logger)
unaryInterceptors = append(unaryInterceptors, samlInterceptor.Unary())
streamInterceptors = append(streamInterceptors, samlInterceptor.Stream())
}
return unaryInterceptors, streamInterceptors, nil
}
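// authenticatorFunc returns an auth.AuthenticatorFunc that resolves a public key by its fingerprint, checks that it is
// confirmed and not expired, and builds an Authenticator carrying the user's effective (minimum) role and PGP verifier.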
func (s *Server) authenticatorFunc() auth.AuthenticatorFunc {
return func(ctx context.Context, fingerprint string) (*auth.Authenticator, error) {
ctx = actor.MarkContextAsInternalActor(ctx)
ptr := authres.NewPublicKey(resources.DefaultNamespace, fingerprint).Metadata()
pubKey, err := safe.StateGet[*authres.PublicKey](ctx, s.omniRuntime.State(), ptr)
if err != nil {
return nil, err
}
if pubKey.TypedSpec().Value.Expiration.AsTime().Before(time.Now()) {
return nil, errors.New("public key expired")
}
if !pubKey.TypedSpec().Value.Confirmed {
return nil, errors.New("public key not confirmed")
}
userID, labelExists := pubKey.Metadata().Labels().Get(authres.LabelPublicKeyUserID)
if !labelExists {
return nil, errors.New("public key has no user ID label")
}
key, err := pgpcrypto.NewKeyFromArmored(string(pubKey.TypedSpec().Value.GetPublicKey()))
if err != nil {
return nil, err
}
verifier, err := pgp.NewKey(key)
if err != nil {
return nil, err
}
user, err := safe.StateGet[*authres.User](ctx, s.omniRuntime.State(), resource.NewMetadata(resources.DefaultNamespace, authres.UserType, userID, resource.VersionUndefined))
if err != nil {
return nil, err
}
finalRole, err := role.Min(role.Role(user.TypedSpec().Value.GetRole()), role.Role(pubKey.TypedSpec().Value.GetRole()))
if err != nil {
return nil, err
}
if config.Config.Auth.Suspended {
finalRole = role.Reader
}
return &auth.Authenticator{
UserID: userID,
Identity: pubKey.TypedSpec().Value.GetIdentity().GetEmail(),
Role: finalRole,
Verifier: verifier,
}, nil
}
}
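// runMachineAPI starts the SideroLink machine API: the WireGuard-based SideroLink manager and the KMS key manager,
// served over a dedicated gRPC listener.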
func (s *Server) runMachineAPI(ctx context.Context) error {
wgAddress := config.Config.SiderolinkWireguardBindAddress
params := siderolink.Params{
WireguardEndpoint: wgAddress,
AdvertisedEndpoint: config.Config.SiderolinkWireguardAdvertisedAddress,
APIEndpoint: config.Config.MachineAPIBindAddress,
Cert: config.Config.MachineAPICertFile,
Key: config.Config.MachineAPIKeyFile,
EventSinkPort: strconv.Itoa(config.Config.EventSinkPort),
}
slink, err := siderolink.NewManager(
ctx,
s.omniRuntime.State(),
siderolink.DefaultWireguardHandler,
params,
s.logger.With(logging.Component("siderolink")).WithOptions(
zap.AddStacktrace(zapcore.ErrorLevel), // prevent warn level from printing stack traces
),
s.logHandler,
s.linkCounterDeltaCh,
)
if err != nil {
return err
}
kms := kms.NewManager(
s.omniRuntime.State(),
s.logger.With(logging.Component("kms")).WithOptions(
zap.AddStacktrace(zapcore.ErrorLevel), // prevent warn level from printing stack traces
),
)
prometheus.MustRegister(slink)
// start API listener
lis, err := params.NewListener()
if err != nil {
return fmt.Errorf("error listening for Siderolink gRPC API: %w", err)
}
eg, groupCtx := errgroup.WithContext(ctx)
server := grpc.NewServer(
grpc.SharedWriteBuffer(true),
)
slink.Register(server)
kms.Register(server)
eg.Go(func() error {
return slink.Run(groupCtx,
"",
strconv.Itoa(config.Config.EventSinkPort),
strconv.Itoa(talosconstants.TrustdPort),
strconv.Itoa(config.Config.LogServerPort),
)
})
grpcutil.RunServer(groupCtx, server, lis, eg)
return eg.Wait()
}
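// workloadProxyHandler wraps next with the workload proxy handler, which validates PGP-signed requests against
// access policy roles before routing them to exposed workload services.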
func (s *Server) workloadProxyHandler(next http.Handler) (http.Handler, error) {
roleProvider, err := workloadproxy.NewAccessPolicyRoleProvider(s.omniRuntime.State())
if err != nil {
return nil, fmt.Errorf("failed to create access policy role provider: %w", err)
}
pgpSignatureValidator, err := workloadproxy.NewPGPAccessValidator(s.omniRuntime.State(), roleProvider,
s.logger.With(logging.Component("pgp_access_validator")))
if err != nil {
return nil, fmt.Errorf("failed to create pgp signature validator: %w", err)
}
mainURL, err := url.Parse(config.Config.APIURL)
if err != nil {
return nil, fmt.Errorf("failed to parse API URL: %w", err)
}
return workloadproxy.NewHTTPHandler(
next,
s.workloadProxyServiceRegistry,
pgpSignatureValidator,
mainURL,
s.logger.With(logging.Component("workload_proxy_handler")),
)
}
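// recoveryHandler converts gRPC handler panics into codes.Internal errors, logging the panic value and stack trace.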
func recoveryHandler(logger *zap.Logger) grpc_recovery.RecoveryHandlerFunc {
return func(p any) error {
if logger != nil {
logger.Error("grpc panic", zap.Any("panic", p), zap.Stack("stack"))
}
return status.Errorf(codes.Internal, "%v", p)
}
}
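// cosiResourceServerCreate strips the spec from sensitive resources in COSI create requests before the request body is logged.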
func cosiResourceServerCreate(req *v1alpha1.CreateRequest) (*v1alpha1.CreateRequest, bool) {
if isSensitiveResource(req.Resource) {
req.Resource.Spec = nil
return req, true
}
return nil, false
}
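// cosiResourceServerUpdate strips the spec from sensitive resources in COSI update requests before the request body is logged.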
func cosiResourceServerUpdate(req *v1alpha1.UpdateRequest) (*v1alpha1.UpdateRequest, bool) {
if isSensitiveResource(req.NewResource) {
req.NewResource.Spec = nil
return req, true
}
return nil, false
}
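// resourceServerCreate strips the spec from sensitive resources in Omni API create requests before the request body is logged.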
func resourceServerCreate(resCopy *resapi.CreateRequest) (*resapi.CreateRequest, bool) {
if isSensitiveSpec(resCopy.Resource) {
resCopy.Resource.Spec = ""
return resCopy, true
}
return nil, false
}
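// resourceServerUpdate strips the spec from sensitive resources in Omni API update requests before the request body is logged.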
func resourceServerUpdate(resCopy *resapi.UpdateRequest) (*resapi.UpdateRequest, bool) {
if isSensitiveSpec(resCopy.Resource) {
resCopy.Resource.Spec = ""
return resCopy, true
}
return nil, false
}
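// isSensitiveResource reports whether the given COSI resource is marked as sensitive; resources of unknown type are treated as sensitive.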
func isSensitiveResource(res *v1alpha1.Resource) bool {
protoR, err := protobuf.Unmarshal(res)
if err != nil {
return false
}
properResource, err := protobuf.UnmarshalResource(protoR)
if err != nil {
return false
}
resDef, ok := properResource.(meta.ResourceDefinitionProvider)
if !ok || resDef.ResourceDefinition().Sensitivity == meta.Sensitive {
// If !ok, we do not know whether this resource has a Sensitive field, so we mask it anyway.
return true
}
return false
}
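// isSensitiveSpec reports whether the given Omni API resource is marked as sensitive; resources of unknown type are treated as sensitive.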
func isSensitiveSpec(resource *resapi.Resource) bool {
res, err := grpcomni.CreateResource(resource)
if err != nil {
return false
}
resDef, ok := res.(meta.ResourceDefinitionProvider)
if !ok || resDef.ResourceDefinition().Sensitivity == meta.Sensitive {
// If !ok, we do not know whether this resource has a Sensitive field, so we mask it anyway.
return true
}
return false
}
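// makeMux builds the HTTP mux serving the static frontend, the image factory proxy, omnictl and talosctl downloads,
// the debug endpoints, the OIDC provider, and the health check.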
func makeMux(imageHandler, oidcHandler http.Handler, samlHandler *samlsp.Middleware, omniRuntime *omni.Runtime, logger *zap.Logger) (*http.ServeMux, error) {
mux := http.NewServeMux()
muxHandle := func(route string, handler http.Handler, value string) {
mux.Handle(route, monitoring.NewHandler(
logging.NewHandler(handler, logger.With(zap.String("handler", value))),
prometheus.Labels{"handler": value},
))
}
muxHandle(
"/",
compress.Handler(
frontend.NewStaticHandler(7200),
gzip.BestCompression,
),
"static",
)
if samlHandler != nil {
saml.RegisterHandlers(samlHandler, mux, logger)
}
muxHandle("/image/", imageHandler, "image")
omnictlHndlr, err := getOmnictlDownloads("./omnictl/")
if err != nil {
return nil, err
}
talosctlHandler, err := makeTalosctlHandler(omniRuntime.State(), logger)
if err != nil {
return nil, err
}
muxHandle("/omnictl/", http.StripPrefix("/omnictl/", omnictlHndlr), "files")
muxHandle("/talosctl/downloads", talosctlHandler, "talosctl-downloads")
// actually enabled only in debug build
muxHandle("/debug/", debug.NewHandler(omniRuntime.GetCOSIRuntime(), omniRuntime.State()), "debug")
// OIDC Provider
mux.Handle("/oidc/",
http.StripPrefix("/oidc",
monitoring.NewHandler(
logging.NewHandler(
oidcHandler,
logger.With(zap.String("handler", "debug")),
),
prometheus.Labels{"handler": "debug"},
),
),
)
// Health checks
muxHandle("/healthz", health.NewHandler(omniRuntime.State(), logger), "health")
return mux, nil
}
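// getOmnictlDownloads returns a file server for the omnictl binaries in dir, failing if the directory contains anything other than regular files.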
func getOmnictlDownloads(dir string) (http.Handler, error) {
readDir, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("failed to read directory %q: %w", dir, err)
}
for _, entry := range readDir {
name := entry.Name()
if !entry.Type().IsRegular() {
return nil, fmt.Errorf("entry %q is not a regular file in %q", name, dir)
}
}
return http.FileServer(http.Dir(dir)), nil
}
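// runMetricsServer serves the Prometheus /metrics endpoint on the given bind address.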
func runMetricsServer(ctx context.Context, bindAddress string, logger *zap.Logger) error {
var metricsMux http.ServeMux
metricsMux.Handle("/metrics", promhttp.Handler())
metricsServer := &http.Server{
Addr: bindAddress,
Handler: &metricsMux,
}
logger = logger.With(zap.String("server", bindAddress), zap.String("server_type", "metrics"))
return runServer(ctx, &server{
server: metricsServer,
}, logger)
}
type oidcStore interface {
GetPublicKeyByID(keyID string) (any, error)
}
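// runK8sProxyServer serves the Kubernetes API proxy, using the OIDC public keys for request authentication
// and resolving cluster IDs to cluster UUIDs.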
func runK8sProxyServer(ctx context.Context, bindAddress string, oidcStorage oidcStore, data certData,
runtimeState state.State, logger *zap.Logger,
) error {
keyFunc := func(_ context.Context, keyID string) (any, error) {
return oidcStorage.GetPublicKeyByID(keyID)
}
clusterUUIDResolver := func(ctx context.Context, clusterID string) (resource.ID, error) {
ctx = actor.MarkContextAsInternalActor(ctx)
uuid, resolveErr := safe.StateGetByID[*omnires.ClusterUUID](ctx, runtimeState, clusterID)
if resolveErr != nil {
return "", fmt.Errorf("failed to resolve cluster ID to UUID: %w", resolveErr)
}
return uuid.TypedSpec().Value.Uuid, nil
}
k8sProxyHandler, err := k8sproxy.NewHandler(keyFunc, clusterUUIDResolver, logger)
if err != nil {
return err
}
prometheus.MustRegister(k8sProxyHandler)
k8sProxy := monitoring.NewHandler(
logging.NewHandler(
k8sProxyHandler,
logger.With(zap.String("handler", "k8s_proxy")),
),
prometheus.Labels{"handler": "k8s-proxy"},
)
k8sProxyServer := &http.Server{
Addr: bindAddress,
Handler: k8sProxy,
}
logger = logger.With(zap.String("server", bindAddress), zap.String("server_type", "k8s_proxy"))
return runServer(ctx, &server{
server: k8sProxyServer,
certData: data,
}, logger)
}
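// runAPIServer serves the unified HTTP and gRPC API handler on the given bind address.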
func runAPIServer(ctx context.Context, handler http.Handler, bindAddress string, data certData, logger *zap.Logger) error {
srv := &http.Server{
Addr: bindAddress,
Handler: handler,
}
logger = logger.With(zap.String("server", bindAddress), zap.String("server_type", "api"))
return runServer(ctx, &server{
server: srv,
certData: data,
}, logger)
}
// setRealIPRequest extracts the IP address from the request and sets it as the X-Real-IP header if neither X-Real-IP nor
// X-Forwarded-For is already set.
func setRealIPRequest(req *http.Request) *http.Request {
if req.Header.Get("X-Real-IP") != "" || req.Header.Get("X-Forwarded-For") != "" {
return req
}
actualIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
return req
}
newReq := req.Clone(req.Context())
newReq.Header.Set("X-Real-IP", actualIP)
return newReq
}
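// server wraps http.Server with optional TLS certificate data.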
type server struct {
server *http.Server
certData
}
type certData struct {
certFile string
keyFile string
}
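// ListenAndServe serves with TLS when certificate data is configured, and plain HTTP otherwise.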
func (s *server) ListenAndServe() error {
if s.certFile != "" || s.keyFile != "" {
return s.server.ListenAndServeTLS(s.certFile, s.keyFile)
}
return s.server.ListenAndServe()
}
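// Shutdown gracefully shuts down the underlying server; if the graceful shutdown is cut short by the context, the server is closed forcefully.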
func (s *server) Shutdown(ctx context.Context) error {
err := s.server.Shutdown(ctx)
if errors.Is(ctx.Err(), err) {
closeErr := s.server.Close()
if closeErr != nil {
return fmt.Errorf("failed to close server: %w", closeErr)
}
}
return err
}
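// runServer runs srv until it fails or ctx is canceled, then attempts a graceful shutdown with a 5-second timeout.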
func runServer(ctx context.Context, srv *server, logger *zap.Logger) error {
logger.Info("server starting")
defer logger.Info("server stopped")
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()
select {
case err := <-errCh:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to serve: %w", err)
}
return nil
case <-ctx.Done():
logger.Info("server stopping")
}
shutdownCtx, shutdownCtxCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCtxCancel()
//nolint:contextcheck
err := srv.Shutdown(shutdownCtx)
if err != nil {
logger.Error("failed to gracefully stop server", zap.Error(err))
}
return err
}
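// runLocalResourceServer starts a read-only COSI state gRPC server on localhost, marking all incoming requests as internal actor requests.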
func runLocalResourceServer(ctx context.Context, st state.CoreState, serverOptions []grpc.ServerOption, eg *errgroup.Group, logger *zap.Logger) error {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.Config.LocalResourceServerPort))
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
unaryInterceptor := grpc.UnaryServerInterceptor(func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
return handler(actor.MarkContextAsInternalActor(ctx), req)
})
streamInterceptor := grpc.StreamServerInterceptor(func(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return handler(srv, &grpc_middleware.WrappedServerStream{
ServerStream: ss,
WrappedContext: actor.MarkContextAsInternalActor(ss.Context()),
})
})
serverOptions = append([]grpc.ServerOption{
grpc.ChainUnaryInterceptor(unaryInterceptor),
grpc.ChainStreamInterceptor(streamInterceptor),
grpc.SharedWriteBuffer(true),
}, serverOptions...)
grpcServer := grpc.NewServer(serverOptions...)
readOnlyState := state.WrapCore(state.Filter(st, func(_ context.Context, access state.Access) error {
if !access.Verb.Readonly() {
return status.Error(codes.PermissionDenied, "only read-only access is permitted")
}
return nil
}))
v1alpha1.RegisterStateServer(grpcServer, protobufserver.NewState(readOnlyState))
logger.Info("starting local resource server")
grpcutil.RunServer(ctx, grpcServer, listener, eg)
return nil
}
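// runGRPCServer serves the gRPC server on an in-memory (memconn) listener until ctx is canceled, then stops it.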
func runGRPCServer(ctx context.Context, server *grpc.Server, transport *memconn.Transport, logger *zap.Logger) error {
grpcListener, err := transport.Listener()
if err != nil {
return fmt.Errorf("failed to create listener: %w", err)
}
logger.Info("internal API server starting", zap.String("address", grpcListener.Addr().String()))
defer logger.Info("internal API server stopped")
errCh := make(chan error, 1)
go func() { errCh <- server.Serve(grpcListener) }()
select {
case err := <-errCh:
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
return fmt.Errorf("failed to serve: %w", err)
}
return nil
case <-ctx.Done():
logger.Info("grpc server stopping")
}
// Since we use a memconn transport and ServeHTTP, we can't use the graceful shutdown
server.Stop()
return nil
}
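// unifyHandler dispatches HTTP/2 gRPC requests to grpcServer and all other requests to handler; when no TLS data is set,
// it wraps the result with h2c so plaintext HTTP/2 is supported.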
func unifyHandler(handler http.Handler, grpcServer *grpc.Server, data certData) http.Handler {
h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.ProtoMajor == 2 && strings.HasPrefix(
req.Header.Get("Content-Type"), "application/grpc") {
// grpcProxyServer provides top-level gRPC proxy handler.
grpcServer.ServeHTTP(w, setRealIPRequest(req))
return
}
// handler contains "regular" HTTP handlers
handler.ServeHTTP(w, req)
}))
if value.IsZero(data) {
// If we don't have TLS data, wrap the handler in http2.Server
h = h2c.NewHandler(h, &http2.Server{})
}
return h
}
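// runPprofServer serves the net/http/pprof endpoints on the given bind address.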
func runPprofServer(ctx context.Context, bindAddress string, l *zap.Logger) error {
mux := &http.ServeMux{}
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
srv := &http.Server{
Addr: bindAddress,
Handler: mux,
}
l = l.With(zap.String("server", bindAddress), zap.String("server_type", "pprof"))
return runServer(ctx, &server{server: srv}, l)
}
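// makeTalosctlHandler returns a handler that responds with the available talosctl downloads as JSON, caching the release data for an hour.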
//nolint:unparam
func makeTalosctlHandler(state state.State, logger *zap.Logger) (http.Handler, error) {
// The list of versions does not update very often, so we can cache it.
cacher := cache.Value[releaseData]{Duration: time.Hour}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
type result struct {
ReleaseData *releaseData `json:"release_data,omitempty"`
Status string `json:"status"`
}
writeResult := func(a any, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(a); err != nil {
logger.Error("failed to encode result", zap.Error(err))
}
}
ctx := actor.MarkContextAsInternalActor(r.Context())
data, err := cacher.GetOrUpdate(func() (releaseData, error) { return getReleaseData(ctx, state) })
if err != nil {
logger.Error("failed to get latest talosctl release", zap.Error(err))
writeResult(result{Status: "failed to get latest talosctl release"}, http.StatusInternalServerError)
return
}
writeResult(result{
ReleaseData: &data,
Status: "ok",
}, http.StatusOK)
}), nil
}
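// getReleaseData builds the talosctl release data from the TalosVersion resources known to Omni.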
func getReleaseData(ctx context.Context, state state.State) (releaseData, error) {
all, err := safe.StateListAll[*omnires.TalosVersion](ctx, state)
if err != nil {
return releaseData{}, fmt.Errorf("failed to list all talos versions: %w", err)
}
if all.Len() == 0 {
return releaseData{}, errors.New("no talos versions found")
}
versionNames := make([]string, 0, all.Len())
for it := all.Iterator(); it.Next(); {
version := it.Value().TypedSpec().Value.Version
if !strings.HasPrefix(version, "v") {
version = "v" + version
}
versionNames = append(versionNames, version)
}
releases, err := getGithubReleases(versionNames...)
if err != nil {
return releaseData{}, err
}
return releases, nil
}
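// getGithubReleases builds GitHub download links for the talosctl assets of the given version tags, using the last tag as the default version.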
func getGithubReleases(tags ...string) (releaseData, error) {
if len(tags) == 0 {
return releaseData{}, errors.New("no tags provided")
}
versions := make(map[string][]talosctlAsset, len(tags))
for _, tag := range tags {
assets := make([]talosctlAsset, 0, len(assetsData))
for _, asset := range assetsData {
assets = append(assets, talosctlAsset{
Name: asset.name,
URL: fmt.Sprintf("https://github.com/siderolabs/talos/releases/download/%s/%s", tag, asset.urlPart),
})
}
versions[tag] = assets
}
return releaseData{
AvailableVersions: versions,
DefaultVersion: tags[len(tags)-1],
}, nil
}
type releaseData struct {
AvailableVersions map[string][]talosctlAsset `json:"available_versions"`
DefaultVersion string `json:"default_version,omitempty"`
}
type talosctlAsset struct {
Name string `json:"name"`
URL string `json:"url"`
}
var assetsData = []struct {
name string
urlPart string
}{
{
"Apple",
"talosctl-darwin-amd64",
},
{
"Apple Silicon",
"talosctl-darwin-arm64",
},
{
"Linux",
"talosctl-linux-amd64",
},
{
"Linux ARM",
"talosctl-linux-armv7",
},
{
"Linux ARM64",
"talosctl-linux-arm64",
},
{
"Windows",
"talosctl-windows-amd64.exe",
},
}