From 6b23deddcf5cc39911f58bdedd0a70a208d64445 Mon Sep 17 00:00:00 2001 From: Philipp Sauter Date: Fri, 5 Aug 2022 22:54:16 +0200 Subject: [PATCH] feat: support custom ports for connecting to apid from talosctl Users can now add a port suffix to the endpoints used by talosctl. Either in the CLI flag or the ~/.talos/config. The default port is still 50000. Signed-off-by: Philipp Sauter --- pkg/grpc/gen/remote.go | 10 +-- pkg/machinery/client/client.go | 76 ++----------------- pkg/machinery/client/resolver.go | 16 ---- pkg/machinery/client/resolver/roundrobin.go | 52 ++++++++----- .../client/resolver/roundrobin_test.go | 41 ++++++++++ 5 files changed, 86 insertions(+), 109 deletions(-) delete mode 100644 pkg/machinery/client/resolver.go create mode 100644 pkg/machinery/client/resolver/roundrobin_test.go diff --git a/pkg/grpc/gen/remote.go b/pkg/grpc/gen/remote.go index 9685f4777..c0381e90f 100644 --- a/pkg/grpc/gen/remote.go +++ b/pkg/grpc/gen/remote.go @@ -20,12 +20,6 @@ import ( "github.com/talos-systems/talos/pkg/machinery/constants" ) -var trustdResolverScheme string - -func init() { - trustdResolverScheme = resolver.RegisterRoundRobinResolver(constants.TrustdPort) -} - // RemoteGenerator represents the OS identity generator. type RemoteGenerator struct { conn *grpc.ClientConn @@ -38,9 +32,11 @@ func NewRemoteGenerator(token string, endpoints []string, ca *x509.PEMEncodedCer return nil, fmt.Errorf("at least one root of trust endpoint is required") } + endpoints = resolver.EnsureEndpointsHavePorts(endpoints, constants.TrustdPort) + g = &RemoteGenerator{} - conn, err := basic.NewConnection(fmt.Sprintf("%s:///%s", trustdResolverScheme, strings.Join(endpoints, ",")), basic.NewTokenCredentials(token), ca) + conn, err := basic.NewConnection(fmt.Sprintf("%s:///%s", resolver.RoundRobinResolverScheme, strings.Join(endpoints, ",")), basic.NewTokenCredentials(token), ca) if err != nil { return nil, err } diff --git a/pkg/machinery/client/client.go b/pkg/machinery/client/client.go index 66be40d4a..c0ff2c992 100644 --- a/pkg/machinery/client/client.go +++ b/pkg/machinery/client/client.go @@ -19,7 +19,6 @@ import ( "time" grpctls "github.com/talos-systems/crypto/tls" - "github.com/talos-systems/net" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -35,6 +34,7 @@ import ( storageapi "github.com/talos-systems/talos/pkg/machinery/api/storage" timeapi "github.com/talos-systems/talos/pkg/machinery/api/time" clientconfig "github.com/talos-systems/talos/pkg/machinery/client/config" + "github.com/talos-systems/talos/pkg/machinery/client/resolver" "github.com/talos-systems/talos/pkg/machinery/constants" ) @@ -147,7 +147,7 @@ func New(ctx context.Context, opts ...OptionFunc) (c *Client, err error) { return nil, errors.New("failed to determine endpoints") } - c.conn, err = c.GetConn(ctx) + c.conn, err = c.getConn(ctx) if err != nil { return nil, fmt.Errorf("failed to create client connection: %w", err) } @@ -165,9 +165,9 @@ func New(ctx context.Context, opts ...OptionFunc) (c *Client, err error) { return c, nil } -// GetConn creates new gRPC connection. -func (c *Client) GetConn(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - endpoints := c.GetEndpoints() +// getConn creates new gRPC connection. +func (c *Client) getConn(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + endpoints := resolver.EnsureEndpointsHavePorts(c.GetEndpoints(), constants.ApidPort) var target string @@ -175,14 +175,14 @@ func (c *Client) GetConn(ctx context.Context, opts ...grpc.DialOption) (*grpc.Cl case c.options.unixSocketPath != "": target = fmt.Sprintf("unix:///%s", c.options.unixSocketPath) case len(endpoints) > 1: - target = fmt.Sprintf("%s:///%s", talosListResolverScheme, strings.Join(endpoints, ",")) + target = fmt.Sprintf("%s:///%s", resolver.RoundRobinResolverScheme, strings.Join(endpoints, ",")) default: // NB: we use the `dns` scheme here in order to handle fancier situations // when there is a single endpoint. // Such possibilities include SRV records, multiple IPs from A and/or AAAA // records, and descriptive TXT records which include things like load // balancer specs. - target = fmt.Sprintf("dns:///%s:%d", net.FormatAddress(endpoints[0]), constants.ApidPort) + target = fmt.Sprintf("dns:///%s", endpoints[0]) } dialOpts := []grpc.DialOption(nil) @@ -252,68 +252,6 @@ func CredentialsFromConfigContext(context *clientconfig.Context) (*Credentials, }, nil } -// NewClientContextAndCredentialsFromConfig initializes Credentials from config file. -// -// Deprecated: use Option-based methods for client creation. -func NewClientContextAndCredentialsFromConfig(p, ctx string) (context *clientconfig.Context, creds *Credentials, err error) { - c, err := clientconfig.Open(p) - if err != nil { - return - } - - context, creds, err = NewClientContextAndCredentialsFromParsedConfig(c, ctx) - - return -} - -// NewClientContextAndCredentialsFromParsedConfig initializes Credentials from parsed configuration. -// -// Deprecated: use Option-based methods for client creation. -func NewClientContextAndCredentialsFromParsedConfig(c *clientconfig.Config, ctx string) (context *clientconfig.Context, creds *Credentials, err error) { - if ctx != "" { - c.Context = ctx - } - - if c.Context == "" { - return nil, nil, fmt.Errorf("'context' key is not set in the config") - } - - context = c.Contexts[c.Context] - if context == nil { - return nil, nil, fmt.Errorf("context %q is not defined in 'contexts' key in config", c.Context) - } - - creds, err = CredentialsFromConfigContext(context) - if err != nil { - return nil, nil, fmt.Errorf("failed to extract credentials from context: %w", err) - } - - return context, creds, nil -} - -// NewClient initializes a Client. -// -// Deprecated: use client.NewFromConfigContext() instead. -func NewClient(cfg *tls.Config, endpoints []string, port int, opts ...grpc.DialOption) (c *Client, err error) { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(cfg))) - - cfg.ServerName = endpoints[0] - - c = &Client{} - - // TODO(smira): endpoints[0] should be replaced with proper load-balancing - c.conn, err = grpc.DialContext(context.Background(), fmt.Sprintf("%s:%d", net.FormatAddress(endpoints[0]), port), opts...) - if err != nil { - return - } - - c.MachineClient = machineapi.NewMachineServiceClient(c.conn) - c.TimeClient = timeapi.NewTimeServiceClient(c.conn) - c.ClusterClient = clusterapi.NewClusterServiceClient(c.conn) - - return c, nil -} - // Close shuts down client protocol. func (c *Client) Close() error { return c.conn.Close() diff --git a/pkg/machinery/client/resolver.go b/pkg/machinery/client/resolver.go deleted file mode 100644 index 0a3680fd6..000000000 --- a/pkg/machinery/client/resolver.go +++ /dev/null @@ -1,16 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package client - -import ( - "github.com/talos-systems/talos/pkg/machinery/client/resolver" - "github.com/talos-systems/talos/pkg/machinery/constants" -) - -var talosListResolverScheme string - -func init() { - talosListResolverScheme = resolver.RegisterRoundRobinResolver(constants.ApidPort) -} diff --git a/pkg/machinery/client/resolver/roundrobin.go b/pkg/machinery/client/resolver/roundrobin.go index d6db7696d..e3ea8d8cc 100644 --- a/pkg/machinery/client/resolver/roundrobin.go +++ b/pkg/machinery/client/resolver/roundrobin.go @@ -5,30 +5,28 @@ package resolver import ( - "fmt" "math/rand" + "net" + "strconv" "strings" - "github.com/talos-systems/net" "google.golang.org/grpc/resolver" + + "github.com/talos-systems/talos/pkg/machinery/generic/slices" ) -// RegisterRoundRobinResolver registers round-robin gRPC resolver for specified port and returns scheme to use in grpc.Dial. -func RegisterRoundRobinResolver(port int) (scheme string) { - scheme = fmt.Sprintf(roundRobinResolverScheme, port) +// RoundRobinResolverScheme is a scheme to use in grpc.Dial for the round-robin gRPC resolver. +// This resolver requires that all endpoints have a port appended. +// To ensure this, use EnsureEndpointsHavePorts before constructing a connection string. +const RoundRobinResolverScheme = "talosroundrobin" +func init() { resolver.Register(&roundRobinResolverBuilder{ - port: port, - scheme: scheme, + scheme: RoundRobinResolverScheme, }) - - return } -const roundRobinResolverScheme = "taloslist-%d" - type roundRobinResolverBuilder struct { - port int scheme string } @@ -37,7 +35,6 @@ func (b *roundRobinResolverBuilder) Build(target resolver.Target, cc resolver.Cl r := &roundRobinResolver{ target: target, cc: cc, - port: b.port, } if err := r.start(); err != nil { @@ -55,16 +52,37 @@ func (b *roundRobinResolverBuilder) Scheme() string { type roundRobinResolver struct { target resolver.Target cc resolver.ClientConn - port int +} + +// EnsureEndpointsHavePorts returns the list of endpoints with default port appended to those addresses that don't have a port. +func EnsureEndpointsHavePorts(endpoints []string, defaultPort int) []string { + return slices.Map(endpoints, func(endpoint string) string { + _, _, err := net.SplitHostPort(endpoint) + if err != nil { + return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)) + } + + return endpoint + }) } func (r *roundRobinResolver) start() error { var addrs []resolver.Address //nolint:prealloc - for _, a := range strings.Split(r.target.Endpoint, ",") { //nolint:staticcheck + //nolint:staticcheck + endpoints := strings.Split(r.target.Endpoint, ",") + + for _, addr := range endpoints { + serverName := addr + + host, _, err := net.SplitHostPort(serverName) + if err == nil { + serverName = host + } + addrs = append(addrs, resolver.Address{ - ServerName: a, - Addr: fmt.Sprintf("%s:%d", net.FormatAddress(a), r.port), + ServerName: serverName, + Addr: addr, }) } diff --git a/pkg/machinery/client/resolver/roundrobin_test.go b/pkg/machinery/client/resolver/roundrobin_test.go new file mode 100644 index 000000000..efb275544 --- /dev/null +++ b/pkg/machinery/client/resolver/roundrobin_test.go @@ -0,0 +1,41 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package resolver_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/talos-systems/talos/pkg/machinery/client/resolver" + "github.com/talos-systems/talos/pkg/machinery/constants" +) + +func TestEnsureEndpointsHavePorts(t *testing.T) { + endpoints := []string{ + "123.123.123.123", + "exammple.com:111", + "234.234.234.234:4000", + "localhost", + "localhost:890", + "2001:db8:0:0:0:ff00:42:8329", + "www.company.com", + "[2001:db8:4006:812::200e]:8080", + } + expected := []string{ + "123.123.123.123:50000", + "exammple.com:111", + "234.234.234.234:4000", + "localhost:50000", + "localhost:890", + "[2001:db8:0:0:0:ff00:42:8329]:50000", + "www.company.com:50000", + "[2001:db8:4006:812::200e]:8080", + } + + actual := resolver.EnsureEndpointsHavePorts(endpoints, constants.ApidPort) + + assert.Equal(t, expected, actual) +}