talos/pkg/machinery/client/connection.go
Andrey Smirnov 96aa9638f7
chore: rename talos-systems/talos to siderolabs/talos
There's a cyclic dependency on siderolink library which imports talos
machinery back. We will fix that after we get talos pushed under a new
name.

Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
2022-11-03 16:50:32 +04:00

212 lines
5.6 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package client
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net"
"net/url"
"strings"
"github.com/siderolabs/gen/slices"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
clientconfig "github.com/siderolabs/talos/pkg/machinery/client/config"
"github.com/siderolabs/talos/pkg/machinery/client/resolver"
"github.com/siderolabs/talos/pkg/machinery/constants"
)
// Conn returns underlying client connection.
func (c *Client) Conn() *grpc.ClientConn {
return c.conn.ClientConn
}
// getConn creates new gRPC connection.
func (c *Client) getConn(ctx context.Context, opts ...grpc.DialOption) (*grpcConnectionWrapper, error) {
endpoints := c.GetEndpoints()
target := c.getTarget(
resolver.EnsureEndpointsHavePorts(
reduceURLsToAddresses(endpoints),
constants.ApidPort),
)
dialOpts := []grpc.DialOption(nil)
dialOpts = append(dialOpts, c.options.grpcDialOptions...)
dialOpts = append(dialOpts, opts...)
if c.options.unixSocketPath != "" {
conn, err := grpc.DialContext(ctx, target, dialOpts...)
return newGRPCConnectionWrapper(c.GetClusterName(), conn), err
}
tlsConfig := c.options.tlsConfig
if tlsConfig != nil {
return c.makeConnection(ctx, target, credentials.NewTLS(tlsConfig), dialOpts)
}
if err := c.resolveConfigContext(); err != nil {
return nil, fmt.Errorf("failed to resolve configuration context: %w", err)
}
basicAuth := c.options.configContext.Auth.Basic
if basicAuth != nil {
dialOpts = append(dialOpts, WithGRPCBasicAuth(basicAuth.Username, basicAuth.Password))
}
sideroV1 := c.options.configContext.Auth.SideroV1
if sideroV1 != nil {
var contextName string
if c.options.config != nil {
contextName = c.options.config.Context
}
if c.options.contextOverrideSet {
contextName = c.options.contextOverride
}
authInterceptor := newAuthInterceptorConfig(contextName, sideroV1.Identity)
dialOpts = append(dialOpts,
grpc.WithUnaryInterceptor(authInterceptor.Interceptor().Unary()),
grpc.WithStreamInterceptor(authInterceptor.Interceptor().Stream()),
)
}
creds, err := buildCredentials(c.options.configContext, endpoints)
if err != nil {
return nil, err
}
return c.makeConnection(ctx, target, creds, dialOpts)
}
func buildTLSConfig(configContext *clientconfig.Context) (*tls.Config, error) {
tlsConfig := &tls.Config{}
caBytes, err := getCA(configContext)
if err != nil {
return nil, fmt.Errorf("failed to get CA: %w", err)
}
if len(caBytes) > 0 {
tlsConfig.RootCAs = x509.NewCertPool()
if ok := tlsConfig.RootCAs.AppendCertsFromPEM(caBytes); !ok {
return nil, fmt.Errorf("failed to append CA certificate to RootCAs pool")
}
}
crt, err := CertificateFromConfigContext(configContext)
if err != nil {
return nil, fmt.Errorf("failed to acquire credentials: %w", err)
}
if crt != nil {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.Certificates = append(tlsConfig.Certificates, *crt)
}
return tlsConfig, nil
}
func (c *Client) makeConnection(ctx context.Context, target string, creds credentials.TransportCredentials, dialOpts []grpc.DialOption) (*grpcConnectionWrapper, error) {
dialOpts = append(dialOpts,
grpc.WithTransportCredentials(creds),
grpc.WithInitialWindowSize(65535*32),
grpc.WithInitialConnWindowSize(65535*16))
conn, err := grpc.DialContext(ctx, target, dialOpts...)
return newGRPCConnectionWrapper(c.GetClusterName(), conn), err
}
func (c *Client) getTarget(endpoints []string) string {
switch {
case c.options.unixSocketPath != "":
return fmt.Sprintf("unix:///%s", c.options.unixSocketPath)
case len(endpoints) > 1:
return fmt.Sprintf("%s:///%s", resolver.RoundRobinResolverScheme, strings.Join(endpoints, ","))
default:
// NB: we use the `dns` scheme here in order to handle fancier situations
// when there is a single endpoint.
// Such possibilities include SRV records, multiple IPs from A and/or AAAA
// records, and descriptive TXT records which include things like load
// balancer specs.
return fmt.Sprintf("dns:///%s", endpoints[0])
}
}
func getCA(context *clientconfig.Context) ([]byte, error) {
if context.CA == "" {
return nil, nil
}
caBytes, err := base64.StdEncoding.DecodeString(context.CA)
if err != nil {
return nil, fmt.Errorf("error decoding CA: %w", err)
}
return caBytes, err
}
// CertificateFromConfigContext constructs the client Credentials from the given configuration Context.
func CertificateFromConfigContext(context *clientconfig.Context) (*tls.Certificate, error) {
if context.Crt == "" && context.Key == "" {
return nil, nil
}
crtBytes, err := base64.StdEncoding.DecodeString(context.Crt)
if err != nil {
return nil, fmt.Errorf("error decoding certificate: %w", err)
}
keyBytes, err := base64.StdEncoding.DecodeString(context.Key)
if err != nil {
return nil, fmt.Errorf("error decoding key: %w", err)
}
crt, err := tls.X509KeyPair(crtBytes, keyBytes)
if err != nil {
return nil, fmt.Errorf("could not load client key pair: %s", err)
}
return &crt, nil
}
func reduceURLsToAddresses(endpoints []string) []string {
return slices.Map(endpoints, func(endpoint string) string {
u, err := url.Parse(endpoint)
if err != nil {
return endpoint
}
if u.Scheme == "https" && u.Port() == "" {
return net.JoinHostPort(u.Hostname(), "443")
}
if u.Scheme != "" {
if u.Port() != "" {
return net.JoinHostPort(u.Hostname(), u.Port())
}
if u.Opaque == "" {
return u.Host
}
}
return endpoint
})
}