mirror of
https://github.com/siderolabs/talos.git
synced 2025-11-07 20:01:46 +01:00
There's a cyclic dependency on siderolink library which imports talos machinery back. We will fix that after we get talos pushed under a new name. Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
212 lines
5.6 KiB
Go
212 lines
5.6 KiB
Go
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package client
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/siderolabs/gen/slices"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
|
|
clientconfig "github.com/siderolabs/talos/pkg/machinery/client/config"
|
|
"github.com/siderolabs/talos/pkg/machinery/client/resolver"
|
|
"github.com/siderolabs/talos/pkg/machinery/constants"
|
|
)
|
|
|
|
// Conn returns underlying client connection.
|
|
func (c *Client) Conn() *grpc.ClientConn {
|
|
return c.conn.ClientConn
|
|
}
|
|
|
|
// getConn creates new gRPC connection.
|
|
func (c *Client) getConn(ctx context.Context, opts ...grpc.DialOption) (*grpcConnectionWrapper, error) {
|
|
endpoints := c.GetEndpoints()
|
|
|
|
target := c.getTarget(
|
|
resolver.EnsureEndpointsHavePorts(
|
|
reduceURLsToAddresses(endpoints),
|
|
constants.ApidPort),
|
|
)
|
|
|
|
dialOpts := []grpc.DialOption(nil)
|
|
|
|
dialOpts = append(dialOpts, c.options.grpcDialOptions...)
|
|
|
|
dialOpts = append(dialOpts, opts...)
|
|
|
|
if c.options.unixSocketPath != "" {
|
|
conn, err := grpc.DialContext(ctx, target, dialOpts...)
|
|
|
|
return newGRPCConnectionWrapper(c.GetClusterName(), conn), err
|
|
}
|
|
|
|
tlsConfig := c.options.tlsConfig
|
|
|
|
if tlsConfig != nil {
|
|
return c.makeConnection(ctx, target, credentials.NewTLS(tlsConfig), dialOpts)
|
|
}
|
|
|
|
if err := c.resolveConfigContext(); err != nil {
|
|
return nil, fmt.Errorf("failed to resolve configuration context: %w", err)
|
|
}
|
|
|
|
basicAuth := c.options.configContext.Auth.Basic
|
|
if basicAuth != nil {
|
|
dialOpts = append(dialOpts, WithGRPCBasicAuth(basicAuth.Username, basicAuth.Password))
|
|
}
|
|
|
|
sideroV1 := c.options.configContext.Auth.SideroV1
|
|
if sideroV1 != nil {
|
|
var contextName string
|
|
|
|
if c.options.config != nil {
|
|
contextName = c.options.config.Context
|
|
}
|
|
|
|
if c.options.contextOverrideSet {
|
|
contextName = c.options.contextOverride
|
|
}
|
|
|
|
authInterceptor := newAuthInterceptorConfig(contextName, sideroV1.Identity)
|
|
dialOpts = append(dialOpts,
|
|
grpc.WithUnaryInterceptor(authInterceptor.Interceptor().Unary()),
|
|
grpc.WithStreamInterceptor(authInterceptor.Interceptor().Stream()),
|
|
)
|
|
}
|
|
|
|
creds, err := buildCredentials(c.options.configContext, endpoints)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return c.makeConnection(ctx, target, creds, dialOpts)
|
|
}
|
|
|
|
func buildTLSConfig(configContext *clientconfig.Context) (*tls.Config, error) {
|
|
tlsConfig := &tls.Config{}
|
|
|
|
caBytes, err := getCA(configContext)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get CA: %w", err)
|
|
}
|
|
|
|
if len(caBytes) > 0 {
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
|
|
if ok := tlsConfig.RootCAs.AppendCertsFromPEM(caBytes); !ok {
|
|
return nil, fmt.Errorf("failed to append CA certificate to RootCAs pool")
|
|
}
|
|
}
|
|
|
|
crt, err := CertificateFromConfigContext(configContext)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to acquire credentials: %w", err)
|
|
}
|
|
|
|
if crt != nil {
|
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
tlsConfig.Certificates = append(tlsConfig.Certificates, *crt)
|
|
}
|
|
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
func (c *Client) makeConnection(ctx context.Context, target string, creds credentials.TransportCredentials, dialOpts []grpc.DialOption) (*grpcConnectionWrapper, error) {
|
|
dialOpts = append(dialOpts,
|
|
grpc.WithTransportCredentials(creds),
|
|
grpc.WithInitialWindowSize(65535*32),
|
|
grpc.WithInitialConnWindowSize(65535*16))
|
|
|
|
conn, err := grpc.DialContext(ctx, target, dialOpts...)
|
|
|
|
return newGRPCConnectionWrapper(c.GetClusterName(), conn), err
|
|
}
|
|
|
|
func (c *Client) getTarget(endpoints []string) string {
|
|
switch {
|
|
case c.options.unixSocketPath != "":
|
|
return fmt.Sprintf("unix:///%s", c.options.unixSocketPath)
|
|
case len(endpoints) > 1:
|
|
return fmt.Sprintf("%s:///%s", resolver.RoundRobinResolverScheme, strings.Join(endpoints, ","))
|
|
default:
|
|
// NB: we use the `dns` scheme here in order to handle fancier situations
|
|
// when there is a single endpoint.
|
|
// Such possibilities include SRV records, multiple IPs from A and/or AAAA
|
|
// records, and descriptive TXT records which include things like load
|
|
// balancer specs.
|
|
return fmt.Sprintf("dns:///%s", endpoints[0])
|
|
}
|
|
}
|
|
|
|
func getCA(context *clientconfig.Context) ([]byte, error) {
|
|
if context.CA == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
caBytes, err := base64.StdEncoding.DecodeString(context.CA)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decoding CA: %w", err)
|
|
}
|
|
|
|
return caBytes, err
|
|
}
|
|
|
|
// CertificateFromConfigContext constructs the client Credentials from the given configuration Context.
|
|
func CertificateFromConfigContext(context *clientconfig.Context) (*tls.Certificate, error) {
|
|
if context.Crt == "" && context.Key == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
crtBytes, err := base64.StdEncoding.DecodeString(context.Crt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decoding certificate: %w", err)
|
|
}
|
|
|
|
keyBytes, err := base64.StdEncoding.DecodeString(context.Key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decoding key: %w", err)
|
|
}
|
|
|
|
crt, err := tls.X509KeyPair(crtBytes, keyBytes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load client key pair: %s", err)
|
|
}
|
|
|
|
return &crt, nil
|
|
}
|
|
|
|
func reduceURLsToAddresses(endpoints []string) []string {
|
|
return slices.Map(endpoints, func(endpoint string) string {
|
|
u, err := url.Parse(endpoint)
|
|
if err != nil {
|
|
return endpoint
|
|
}
|
|
|
|
if u.Scheme == "https" && u.Port() == "" {
|
|
return net.JoinHostPort(u.Hostname(), "443")
|
|
}
|
|
|
|
if u.Scheme != "" {
|
|
if u.Port() != "" {
|
|
return net.JoinHostPort(u.Hostname(), u.Port())
|
|
}
|
|
|
|
if u.Opaque == "" {
|
|
return u.Host
|
|
}
|
|
}
|
|
|
|
return endpoint
|
|
})
|
|
}
|