chore: support getting multiple endpoints from the Provision rpc call

The code will rotate through the endpoints, until it reaches the end, and only then it will try to do the provisioning again.

Closes #7973

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
This commit is contained in:
Dmitriy Matrenichev 2023-11-24 19:19:34 +03:00
parent dd45dd06cf
commit ba827bf8b8
No known key found for this signature in database
GPG Key ID: D3363CF894E68892
4 changed files with 161 additions and 91 deletions

2
go.mod
View File

@ -124,7 +124,7 @@ require (
github.com/siderolabs/grpc-proxy v0.4.0
github.com/siderolabs/kms-client v0.1.0
github.com/siderolabs/net v0.4.0
github.com/siderolabs/siderolink v0.3.2
github.com/siderolabs/siderolink v0.3.3
github.com/siderolabs/talos/pkg/machinery v1.6.0-alpha.2
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5

4
go.sum
View File

@ -691,8 +691,8 @@ github.com/siderolabs/net v0.4.0 h1:1bOgVay/ijPkJz4qct98nHsiB/ysLQU0KLoBC4qLm7I=
github.com/siderolabs/net v0.4.0/go.mod h1:/ibG+Hm9HU27agp5r9Q3eZicEfjquzNzQNux5uEk0kM=
github.com/siderolabs/protoenc v0.2.1 h1:BqxEmeWQeMpNP3R6WrPqDatX8sM/r4t97OP8mFmg6GA=
github.com/siderolabs/protoenc v0.2.1/go.mod h1:StTHxjet1g11GpNAWiATgc8K0HMKiFSEVVFOa/H0otc=
github.com/siderolabs/siderolink v0.3.2 h1:ULFHQAgxtVCU7Sd+GLP7bDSQBXrwTtppaI4TKl/YqZc=
github.com/siderolabs/siderolink v0.3.2/go.mod h1:juxlSF9cBzeBHsOjS7hVS3s0NDpC034i/OZunVReqmo=
github.com/siderolabs/siderolink v0.3.3 h1:rnsN4K4TPtk38Ygs/oKQsiVe8iYUi9RRS8gh4U7mbGM=
github.com/siderolabs/siderolink v0.3.3/go.mod h1:juxlSF9cBzeBHsOjS7hVS3s0NDpC034i/OZunVReqmo=
github.com/siderolabs/tcpproxy v0.1.0 h1:IbkS9vRhjMOscc1US3M5P1RnsGKFgB6U5IzUk+4WkKA=
github.com/siderolabs/tcpproxy v0.1.0/go.mod h1:onn6CPPj/w1UNqQ0U97oRPF0CqbrgEApYCw4P9IiCW8=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=

View File

@ -44,6 +44,7 @@ import (
// ManagerController interacts with SideroLink API and brings up the SideroLink Wireguard interface.
type ManagerController struct {
nodeKey wgtypes.Key
pd provisionData
}
// Name implements controller.Controller interface.
@ -146,99 +147,25 @@ func (ctrl *ManagerController) Run(ctx context.Context, r controller.Runtime, lo
case <-r.EventCh():
}
cfg, err := safe.ReaderGetByID[*siderolink.Config](ctx, r, siderolink.ConfigID)
if err != nil {
if state.IsNotFoundError(err) {
if cleanupErr := ctrl.cleanup(ctx, r, nil, nil, logger); cleanupErr != nil {
return fmt.Errorf("failed to do cleanup: %w", cleanupErr)
}
if ctrl.pd.IsEmpty() {
provision, err := ctrl.provision(ctx, r, logger)
if err != nil {
return fmt.Errorf("error provisioning: %w", err)
}
// no config
if !provision.IsPresent() {
continue
}
return fmt.Errorf("failed to get siderolink config: %w", err)
ctrl.pd = provision.ValueOrZero()
}
sysInfo, err := safe.ReaderGetByID[*hardware.SystemInformation](ctx, r, hardware.SystemInformationID)
if err != nil {
if state.IsNotFoundError(err) {
// no system information
continue
}
return fmt.Errorf("failed to get system information: %w", err)
}
nodeUUID := sysInfo.TypedSpec().UUID
stringEndpoint := cfg.TypedSpec().APIEndpoint
parsedEndpoint, err := endpoint.Parse(stringEndpoint)
if err != nil {
return fmt.Errorf("failed to parse siderolink endpoint: %w", err)
}
var transportCredentials credentials.TransportCredentials
if parsedEndpoint.Insecure {
transportCredentials = insecure.NewCredentials()
} else {
transportCredentials = credentials.NewTLS(&tls.Config{})
}
provision := func() (*pb.ProvisionResponse, error) {
connCtx, connCtxCancel := context.WithTimeout(ctx, 10*time.Second)
defer connCtxCancel()
conn, connErr := grpc.DialContext(
connCtx,
parsedEndpoint.Host,
grpc.WithTransportCredentials(transportCredentials),
grpc.WithSharedWriteBuffer(true),
)
if connErr != nil {
return nil, fmt.Errorf("error dialing SideroLink endpoint %q: %w", stringEndpoint, connErr)
}
defer func() {
if closeErr := conn.Close(); closeErr != nil {
logger.Error("failed to close SideroLink provisioning GRPC connection", zap.Error(closeErr))
}
}()
uniqTokenRes, rdrErr := safe.ReaderGetByID[*runtime.UniqueMachineToken](ctx, r, runtime.UniqueMachineTokenID)
if rdrErr != nil {
return nil, fmt.Errorf("failed to get unique token: %w", rdrErr)
}
sideroLinkClient := pb.NewProvisionServiceClient(conn)
request := &pb.ProvisionRequest{
NodeUuid: nodeUUID,
NodePublicKey: ctrl.nodeKey.PublicKey().String(),
NodeUniqueToken: pointer.To(uniqTokenRes.TypedSpec().Token),
TalosVersion: pointer.To(version.Tag),
}
token := parsedEndpoint.GetParam("jointoken")
if token != "" {
request.JoinToken = pointer.To(token)
}
return sideroLinkClient.Provision(ctx, request)
}
resp, err := provision()
if err != nil {
return err
}
serverAddress, err := netip.ParseAddr(resp.ServerAddress)
serverAddress, err := netip.ParseAddr(ctrl.pd.ServerAddress)
if err != nil {
return fmt.Errorf("error parsing server address: %w", err)
}
nodeAddress, err := netip.ParsePrefix(resp.NodeAddressPrefix)
nodeAddress, err := netip.ParsePrefix(ctrl.pd.NodeAddressPrefix)
if err != nil {
return fmt.Errorf("error parsing node address: %w", err)
}
@ -246,6 +173,18 @@ func (ctrl *ManagerController) Run(ctx context.Context, r controller.Runtime, lo
linkSpec := network.NewLinkSpec(network.ConfigNamespaceName, network.LayeredID(network.ConfigOperator, network.LinkID(constants.SideroLinkName)))
addressSpec := network.NewAddressSpec(network.ConfigNamespaceName, network.LayeredID(network.ConfigOperator, network.AddressID(constants.SideroLinkName, nodeAddress)))
// Rotate through the endpoints.
ep, ok := ctrl.pd.TakeEndpoint()
if !ok {
return errors.New("host returned no endpoints")
}
logger.Info(
"configuring siderolink connection",
zap.String("peer_endpoint", ep),
zap.String("next_peer_endpoint", ctrl.pd.PeekNextEndpoint()),
)
if err := safe.WriterModify(ctx, r, linkSpec,
func(res *network.LinkSpec) error {
spec := res.TypedSpec()
@ -262,8 +201,8 @@ func (ctrl *ManagerController) Run(ctx context.Context, r controller.Runtime, lo
PrivateKey: ctrl.nodeKey.String(),
Peers: []network.WireguardPeer{
{
PublicKey: resp.ServerPublicKey,
Endpoint: resp.ServerEndpoint,
PublicKey: ctrl.pd.ServerPublicKey,
Endpoint: ep,
AllowedIPs: []netip.Prefix{
netip.PrefixFrom(serverAddress, serverAddress.BitLen()),
},
@ -310,13 +249,144 @@ func (ctrl *ManagerController) Run(ctx context.Context, r controller.Runtime, lo
logger.Info(
"siderolink connection configured",
zap.String("endpoint", stringEndpoint),
zap.String("node_uuid", nodeUUID),
zap.String("endpoint", ctrl.pd.apiEndpont),
zap.String("node_uuid", ctrl.pd.nodeUUID),
zap.String("node_address", nodeAddress.String()),
)
}
}
//nolint:gocyclo
func (ctrl *ManagerController) provision(ctx context.Context, r controller.Runtime, logger *zap.Logger) (optional.Optional[provisionData], error) {
cfg, err := safe.ReaderGetByID[*siderolink.Config](ctx, r, siderolink.ConfigID)
if err != nil {
if state.IsNotFoundError(err) {
if cleanupErr := ctrl.cleanup(ctx, r, nil, nil, logger); cleanupErr != nil {
return optional.None[provisionData](), fmt.Errorf("failed to do cleanup: %w", cleanupErr)
}
// no config
return optional.None[provisionData](), nil
}
return optional.None[provisionData](), fmt.Errorf("failed to get siderolink config: %w", err)
}
sysInfo, err := safe.ReaderGetByID[*hardware.SystemInformation](ctx, r, hardware.SystemInformationID)
if err != nil {
if state.IsNotFoundError(err) {
// no system information
return optional.None[provisionData](), nil
}
return optional.None[provisionData](), fmt.Errorf("failed to get system information: %w", err)
}
nodeUUID := sysInfo.TypedSpec().UUID
stringEndpoint := cfg.TypedSpec().APIEndpoint
parsedEndpoint, err := endpoint.Parse(stringEndpoint)
if err != nil {
return optional.None[provisionData](), fmt.Errorf("failed to parse siderolink endpoint: %w", err)
}
var transportCredentials credentials.TransportCredentials
if parsedEndpoint.Insecure {
transportCredentials = insecure.NewCredentials()
} else {
transportCredentials = credentials.NewTLS(&tls.Config{})
}
provision := func() (*pb.ProvisionResponse, error) {
connCtx, connCtxCancel := context.WithTimeout(ctx, 10*time.Second)
defer connCtxCancel()
conn, connErr := grpc.DialContext(
connCtx,
parsedEndpoint.Host,
grpc.WithTransportCredentials(transportCredentials),
grpc.WithSharedWriteBuffer(true),
)
if connErr != nil {
return nil, fmt.Errorf("error dialing SideroLink endpoint %q: %w", stringEndpoint, connErr)
}
defer func() {
if closeErr := conn.Close(); closeErr != nil {
logger.Error("failed to close SideroLink provisioning GRPC connection", zap.Error(closeErr))
}
}()
uniqTokenRes, rdrErr := safe.ReaderGetByID[*runtime.UniqueMachineToken](ctx, r, runtime.UniqueMachineTokenID)
if rdrErr != nil {
return nil, fmt.Errorf("failed to get unique token: %w", rdrErr)
}
sideroLinkClient := pb.NewProvisionServiceClient(conn)
request := &pb.ProvisionRequest{
NodeUuid: nodeUUID,
NodePublicKey: ctrl.nodeKey.PublicKey().String(),
NodeUniqueToken: pointer.To(uniqTokenRes.TypedSpec().Token),
TalosVersion: pointer.To(version.Tag),
}
token := parsedEndpoint.GetParam("jointoken")
if token != "" {
request.JoinToken = pointer.To(token)
}
return sideroLinkClient.Provision(ctx, request)
}
resp, err := provision()
if err != nil {
return optional.None[provisionData](), err
}
return optional.Some(provisionData{
nodeUUID: nodeUUID,
apiEndpont: stringEndpoint,
ServerAddress: resp.ServerAddress,
ServerPublicKey: resp.ServerPublicKey,
NodeAddressPrefix: resp.NodeAddressPrefix,
endpoints: resp.GetEndpoints(),
}), nil
}
type provisionData struct {
nodeUUID string
apiEndpont string
ServerAddress string
ServerPublicKey string
NodeAddressPrefix string
endpoints []string
}
func (d *provisionData) IsEmpty() bool {
return d == nil || len(d.endpoints) == 0
}
func (d *provisionData) TakeEndpoint() (string, bool) {
if d.IsEmpty() {
return "", false
}
ep := d.endpoints[0]
d.endpoints = d.endpoints[1:]
return ep, true
}
func (d *provisionData) PeekNextEndpoint() string {
if d.IsEmpty() {
return ""
}
return d.endpoints[0]
}
func (ctrl *ManagerController) cleanup(
ctx context.Context,
r controller.Runtime,

View File

@ -72,7 +72,7 @@ const (
func (srv mockServer) Provision(_ context.Context, _ *pb.ProvisionRequest) (*pb.ProvisionResponse, error) {
return &pb.ProvisionResponse{
ServerEndpoint: mockServerEndpoint,
ServerEndpoint: pb.MakeEndpoints(mockServerEndpoint),
ServerAddress: mockServerAddress,
ServerPublicKey: mockServerPublicKey,
NodeAddressPrefix: mockNodeAddressPrefix,