omni/internal/backend/grpc/router/router.go
Artem Chernyshev 3810ccb03f
fix: properly clean up stale Talos gRPC backends
Fixes: https://github.com/siderolabs/omni/issues/432

Signed-off-by: Artem Chernyshev <artem.chernyshev@talos-systems.com>
2024-07-01 17:09:36 +03:00

// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package router defines gRPC proxy helpers.
package router

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"math"
	"net"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"time"

	"github.com/cosi-project/runtime/pkg/resource"
	"github.com/cosi-project/runtime/pkg/safe"
	"github.com/cosi-project/runtime/pkg/state"
	"github.com/hashicorp/golang-lru/v2/expirable"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/siderolabs/gen/xslices"
	"github.com/siderolabs/go-api-signature/pkg/message"
	"github.com/siderolabs/grpc-proxy/proxy"
	"github.com/siderolabs/talos/pkg/machinery/client/resolver"
	talosconstants "github.com/siderolabs/talos/pkg/machinery/constants"
	"github.com/siderolabs/talos/pkg/machinery/role"
	"go.uber.org/zap"
	"golang.org/x/sync/singleflight"
	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	"github.com/siderolabs/omni/client/api/common"
	"github.com/siderolabs/omni/client/pkg/constants"
	"github.com/siderolabs/omni/client/pkg/omni/resources"
	"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
	"github.com/siderolabs/omni/internal/backend/dns"
	"github.com/siderolabs/omni/internal/memconn"
	"github.com/siderolabs/omni/internal/pkg/auth/actor"
	"github.com/siderolabs/omni/internal/pkg/certs"
)

const (
	talosBackendLRUSize = 32
	talosBackendTTL     = time.Hour
)

// Router wraps grpc-proxy StreamDirector.
type Router struct {
	talosBackends *expirable.LRU[string, proxy.Backend]
	sf            singleflight.Group

	metricCacheSize, metricActiveClients prometheus.Gauge
	metricCacheHits, metricCacheMisses   prometheus.Counter

	omniBackend  proxy.Backend
	nodeResolver NodeResolver
	verifier     grpc.UnaryServerInterceptor
	cosiState    state.State
	authEnabled  bool
}

// NewRouter builds a new Router.
func NewRouter(
	transport *memconn.Transport,
	cosiState state.State,
	nodeResolver NodeResolver,
	authEnabled bool,
	verifier grpc.UnaryServerInterceptor,
) (*Router, error) {
	omniConn, err := grpc.NewClient(transport.Address(),
		grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
			return transport.Dial()
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithCodec(proxy.Codec()), //nolint:staticcheck
		// we are proxying requests to ourselves, so we don't need to impose a limit
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to dial omni backend: %w", err)
	}

	r := &Router{
		talosBackends: expirable.NewLRU[string, proxy.Backend](talosBackendLRUSize, nil, talosBackendTTL),
		omniBackend:   NewOmniBackend("omni", nodeResolver, omniConn),
		cosiState:     cosiState,
		nodeResolver:  nodeResolver,
		authEnabled:   authEnabled,
		verifier:      verifier,
		metricCacheSize: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "omni_grpc_proxy_talos_backend_cache_size",
			Help: "Number of Talos clients in the cache of gRPC Proxy.",
		}),
		metricActiveClients: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "omni_grpc_proxy_talos_backend_active_clients",
			Help: "Number of active Talos clients created by gRPC Proxy.",
		}),
		metricCacheHits: prometheus.NewCounter(prometheus.CounterOpts{
			Name: "omni_grpc_proxy_talos_backend_cache_hits_total",
			Help: "Number of gRPC Proxy Talos client cache hits.",
		}),
		metricCacheMisses: prometheus.NewCounter(prometheus.CounterOpts{
			Name: "omni_grpc_proxy_talos_backend_cache_misses_total",
			Help: "Number of gRPC Proxy Talos client cache misses.",
		}),
	}

	return r, nil
}

// removeBackend drops the cached Talos backend for the given cluster or machine ID.
func (r *Router) removeBackend(id string) {
	r.talosBackends.Remove(id)
}

// Director implements proxy.StreamDirector function.
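//
// Wiring sketch (assumed, not part of this file): the Director is typically registered
// as the handler for unknown services on a grpc.Server, roughly as:
//
//	srv := grpc.NewServer(
//		grpc.CustomCodec(proxy.Codec()), //nolint:staticcheck
//		grpc.UnknownServiceHandler(proxy.TransparentHandler(router.Director)),
//	)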
func (r *Router) Director(ctx context.Context, fullMethodName string) (proxy.Mode, []proxy.Backend, error) {
	fullMethodName = strings.TrimLeft(fullMethodName, "/")

	// Proxy explicitly local APIs to the local backend.
	switch {
	case strings.HasPrefix(fullMethodName, "auth."),
		strings.HasPrefix(fullMethodName, "config."),
		strings.HasPrefix(fullMethodName, "management."),
		strings.HasPrefix(fullMethodName, "oidc."),
		strings.HasPrefix(fullMethodName, "omni."):
		return proxy.One2One, []proxy.Backend{r.omniBackend}, nil
	default:
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return proxy.One2One, []proxy.Backend{r.omniBackend}, nil
	}

	if runtime := md.Get(message.RuntimeHeaderHey); runtime != nil && runtime[0] == common.Runtime_Talos.String() {
		backends, err := r.getTalosBackend(ctx, md)
		if err != nil {
			return proxy.One2One, nil, err
		}

		return proxy.One2One, backends, nil
	}

	return proxy.One2One, []proxy.Backend{r.omniBackend}, nil
}
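
// getTalosBackend returns the proxy.Backend for the Talos cluster or machine addressed
// by the request metadata, creating and caching the backend on first use.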
func (r *Router) getTalosBackend(ctx context.Context, md metadata.MD) ([]proxy.Backend, error) {
	clusterName := getClusterName(md)

	id := fmt.Sprintf("cluster-%s", clusterName)
	if clusterName == "" {
		id = fmt.Sprintf("machine-%s", getNodeID(md))
	}

	if backend, ok := r.talosBackends.Get(id); ok {
		r.metricCacheHits.Inc()

		return []proxy.Backend{backend}, nil
	}
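
	// Cache miss: build the backend once per id, deduplicating concurrent requests
	// for the same cluster/machine via singleflight.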
	ch := r.sf.DoChan(id, func() (any, error) {
		ctx = actor.MarkContextAsInternalActor(ctx)

		r.metricCacheMisses.Inc()

		conn, err := r.getConn(ctx, clusterName)
		if err != nil {
			return nil, err
		}

		r.metricActiveClients.Inc()

		backend := NewTalosBackend(id, clusterName, r.nodeResolver, conn, r.authEnabled, r.verifier)
		r.talosBackends.Add(id, backend)
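
		// The LRU above is created without an eviction callback, so evicted entries are not
		// closed there; instead, a finalizer closes the underlying connection once the evicted
		// backend is garbage collected, so stale Talos connections do not leak.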
		runtime.SetFinalizer(backend, func(backend *TalosBackend) {
			r.metricActiveClients.Dec()

			backend.conn.Close() //nolint:errcheck
		})

		return backend, nil
	})

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case res := <-ch:
		if res.Err != nil {
			return nil, res.Err
		}

		backend := res.Val.(proxy.Backend) //nolint:errcheck,forcetypeassert

		return []proxy.Backend{backend}, nil
	}
}
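
// getTransportCredentials builds the gRPC transport credentials and the list of apid
// endpoints to dial: cluster credentials when a cluster context is set, otherwise TLS
// with certificate verification disabled against the resolved node addresses.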
func (r *Router) getTransportCredentials(ctx context.Context, contextName string) (credentials.TransportCredentials, []string, error) {
	if contextName == "" {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return nil, nil, errors.New("failed to get node IP from the request")
		}

		var endpoints []dns.Info

		info := resolveNodes(r.nodeResolver, md)

		endpoints = info.nodes
		if info.nodeOk {
			endpoints = []dns.Info{info.node}
		}

		return credentials.NewTLS(&tls.Config{
			InsecureSkipVerify: true,
		}), xslices.Map(endpoints, func(info dns.Info) string {
			return net.JoinHostPort(info.GetAddress(), strconv.FormatInt(talosconstants.ApidPort, 10))
		}), nil
	}

	clusterCredentials, err := r.getClusterCredentials(ctx, contextName)
	if err != nil {
		return nil, nil, err
	}

	tlsConfig := &tls.Config{}

	tlsConfig.RootCAs = x509.NewCertPool()
	if ok := tlsConfig.RootCAs.AppendCertsFromPEM(clusterCredentials.CAPEM); !ok {
		return nil, nil, errors.New("failed to append CA certificate to RootCAs pool")
	}

	clientCert, err := tls.X509KeyPair(clusterCredentials.CertPEM, clusterCredentials.KeyPEM)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create TLS client certificate: %w", err)
	}

	tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
	tlsConfig.Certificates = append(tlsConfig.Certificates, clientCert)

	return credentials.NewTLS(tlsConfig), xslices.Map(clusterCredentials.Endpoints, func(endpoint string) string {
		return net.JoinHostPort(endpoint, strconv.FormatInt(talosconstants.ApidPort, 10))
	}), nil
}
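
// getConn dials a new gRPC client connection to the Talos apid endpoints resolved for
// the given context, load-balancing across all endpoints via the round-robin resolver.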
func (r *Router) getConn(ctx context.Context, contextName string) (*grpc.ClientConn, error) {
	creds, endpoints, err := r.getTransportCredentials(ctx, contextName)
	if err != nil {
		return nil, err
	}

	backoffConfig := backoff.DefaultConfig
	backoffConfig.MaxDelay = 15 * time.Second

	endpoint := fmt.Sprintf("%s:///%s", resolver.RoundRobinResolverScheme, strings.Join(endpoints, ","))

	opts := []grpc.DialOption{
		grpc.WithInitialWindowSize(65535 * 32),
		grpc.WithInitialConnWindowSize(65535 * 16),
		grpc.WithConnectParams(grpc.ConnectParams{
			Backoff:           backoffConfig,
			MinConnectTimeout: 20 * time.Second,
		}),
		grpc.WithTransportCredentials(creds),
		grpc.WithCodec(proxy.Codec()), //nolint:staticcheck
		grpc.WithSharedWriteBuffer(true),
	}

	return grpc.NewClient(
		endpoint,
		opts...,
	)
}
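
// talosClusterCredentials holds the PEM-encoded client certificate material and the
// management endpoints used to reach the Talos API of a cluster.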
type talosClusterCredentials struct {
	CAPEM     []byte
	CertPEM   []byte
	KeyPEM    []byte
	Endpoints []string
}
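
// getClusterCredentials looks up the cluster secrets and endpoint resources and
// generates a Talos API client certificate for the cluster from its secrets.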
func (r *Router) getClusterCredentials(ctx context.Context, clusterName string) (*talosClusterCredentials, error) {
	if clusterName == "" {
		return nil, status.Errorf(codes.InvalidArgument, "cluster name is not set")
	}

	secrets, err := safe.StateGet[*omni.ClusterSecrets](ctx, r.cosiState, omni.NewClusterSecrets(resources.DefaultNamespace, clusterName).Metadata())
	if err != nil {
		if state.IsNotFoundError(err) {
			return nil, status.Errorf(codes.NotFound, "cluster %q is not registered", clusterName)
		}

		return nil, err
	}

	// Use the `os:impersonator` role here; the required role is set directly in router.TalosBackend.GetConnection.
	clientCert, CA, err := certs.TalosAPIClientCertificateFromSecrets(secrets, constants.CertificateValidityTime, role.MakeSet(role.Impersonator))
	if err != nil {
		return nil, err
	}

	clusterEndpoint, err := safe.StateGet[*omni.ClusterEndpoint](ctx, r.cosiState, omni.NewClusterEndpoint(resources.DefaultNamespace, clusterName).Metadata())
	if err != nil {
		return nil, err
	}

	endpoints := clusterEndpoint.TypedSpec().Value.ManagementAddresses

	return &talosClusterCredentials{
		Endpoints: endpoints,
		CAPEM:     CA,
		CertPEM:   clientCert.Crt,
		KeyPEM:    clientCert.Key,
	}, nil
}

// ResourceWatcher watches the resource state and removes cached Talos API connections.
func (r *Router) ResourceWatcher(ctx context.Context, s state.State, logger *zap.Logger) error {
	events := make(chan state.Event)

	if err := s.WatchKind(ctx, resource.NewMetadata(resources.DefaultNamespace, omni.ClusterType, "", resource.VersionUndefined), events); err != nil {
		return err
	}

	if err := s.WatchKind(ctx, resource.NewMetadata(resources.DefaultNamespace, omni.ClusterSecretsType, "", resource.VersionUndefined), events); err != nil {
		return err
	}

	if err := s.WatchKind(ctx, resource.NewMetadata(resources.DefaultNamespace, omni.ClusterEndpointType, "", resource.VersionUndefined), events); err != nil {
		return err
	}

	if err := s.WatchKind(ctx, resource.NewMetadata(resources.DefaultNamespace, omni.MachineType, "", resource.VersionUndefined), events); err != nil {
		return err
	}

	for {
		select {
		case <-ctx.Done():
			return nil
		case e := <-events:
			switch e.Type {
			case state.Errored:
				return fmt.Errorf("talos backend cluster watch failed: %w", e.Error)
			case state.Bootstrapped:
				// ignore
			case state.Created, state.Updated, state.Destroyed:
				if e.Type == state.Destroyed && e.Resource.Metadata().Type() == omni.MachineType {
					id := fmt.Sprintf("machine-%s", e.Resource.Metadata().ID())

					r.removeBackend(id)

					logger.Info("remove machine talos backend", zap.String("id", id))

					continue
				}

				id := fmt.Sprintf("cluster-%s", e.Resource.Metadata().ID())

				// All watched cluster resources use the cluster name as their ID; drop the backend to make sure a new connection is established.
				r.removeBackend(id)

				logger.Info("remove cluster talos backend", zap.String("id", id))
			}
		}
	}
}

// ExtractContext reads cluster context from the supplied metadata.
func ExtractContext(ctx context.Context) *common.Context {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil
	}

	return &common.Context{Name: getClusterName(md)}
}
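
// getClusterName extracts the cluster name from the request metadata, preferring the
// cluster header over the context header.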
func getClusterName(md metadata.MD) string {
	get := func(key string) string {
		vals := md.Get(key)
		if vals == nil {
			return ""
		}

		return vals[0]
	}

	if clusterName := get(message.ClusterHeaderKey); clusterName != "" {
		return clusterName
	}

	return get(message.ContextHeaderKey)
}
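
// getNodeID builds a stable cache key from the node(s) addressed in the request
// metadata, sorting the nodes header so that ordering does not matter.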
func getNodeID(md metadata.MD) string {
	if nodes := md.Get(nodesHeaderKey); len(nodes) != 0 {
		slices.Sort(nodes)

		return strings.Join(nodes, ",")
	}

	return strings.Join(md.Get(nodeHeaderKey), ",")
}

// Describe implements prom.Collector interface.
func (r *Router) Describe(ch chan<- *prometheus.Desc) {
	prometheus.DescribeByCollect(r, ch)
}

// Collect implements prom.Collector interface.
func (r *Router) Collect(ch chan<- prometheus.Metric) {
	r.metricActiveClients.Collect(ch)

	r.metricCacheSize.Set(float64(r.talosBackends.Len()))
	r.metricCacheSize.Collect(ch)

	r.metricCacheHits.Collect(ch)
	r.metricCacheMisses.Collect(ch)
}

var _ prometheus.Collector = &Router{}