omni/internal/pkg/siderolink/siderolink_test.go
Dmitriy Matrenichev e2f5795789
chore: allow multiple IPs for the siderolink-wireguard-advertised-addr flag
The client-side logic already exists: if Talos fails to connect to one address, it retries with the next one.
This change adds support for passing several comma-separated IPs in the `siderolink-wireguard-advertised-addr` flag.

Fixes #495

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
2024-08-28 20:41:29 +03:00

425 lines
13 KiB
Go

// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package siderolink_test
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/cosi-project/runtime/pkg/resource"
"github.com/cosi-project/runtime/pkg/resource/rtestutils"
"github.com/cosi-project/runtime/pkg/safe"
"github.com/cosi-project/runtime/pkg/state"
"github.com/cosi-project/runtime/pkg/state/impl/inmem"
"github.com/cosi-project/runtime/pkg/state/impl/namespaced"
"github.com/siderolabs/gen/xtesting/must"
"github.com/siderolabs/go-pointer"
"github.com/siderolabs/go-retry/retry"
pb "github.com/siderolabs/siderolink/api/siderolink"
"github.com/siderolabs/siderolink/pkg/wireguard"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/proto"
"github.com/siderolabs/omni/client/api/omni/specs"
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/client/pkg/omni/resources/siderolink"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/errgroup"
"github.com/siderolabs/omni/internal/pkg/grpcutil"
"github.com/siderolabs/omni/internal/pkg/machinestatus"
sideromanager "github.com/siderolabs/omni/internal/pkg/siderolink"
)
// fakeWireguardHandler is a test stand-in for the WireGuard handler: it sets up no
// device and only logs peer events through the logger it receives in Run.
type fakeWireguardHandler struct {
	// loggerMu guards logger: Run writes it and PeerEvent reads it.
	loggerMu sync.Mutex
	logger   *zap.Logger
}
// SetupDevice accepts any device configuration and does nothing with it.
func (h *fakeWireguardHandler) SetupDevice(_ wireguard.DeviceConfig) error {
	return nil
}
// Run stores logger for later use by PeerEvent, then blocks until ctx is done.
// It always returns nil.
func (h *fakeWireguardHandler) Run(ctx context.Context, logger *zap.Logger) error {
	// Keep the critical section inside a closure so the deferred unlock runs
	// (even on panic) before we start waiting on ctx.
	func() {
		h.loggerMu.Lock()
		defer h.loggerMu.Unlock()

		h.logger = logger
	}()

	<-ctx.Done()

	return nil
}
// Shutdown has nothing to release in the fake handler and always succeeds.
func (h *fakeWireguardHandler) Shutdown() error {
	return nil
}
// PeerEvent logs that the peer described by spec was updated or, when deleted is
// true, removed. It records the node public key and subnet and never fails.
//
// The logger is only available once Run has stored it; events arriving earlier are
// logged to a no-op logger instead of dereferencing a nil *zap.Logger.
func (h *fakeWireguardHandler) PeerEvent(_ context.Context, spec *specs.SiderolinkSpec, deleted bool) error {
	h.loggerMu.Lock()
	defer h.loggerMu.Unlock()

	logger := h.logger
	if logger == nil {
		// Run has not been called yet: drop the log line rather than panic.
		logger = zap.NewNop()
	}

	msg := "updated peer"
	if deleted {
		msg = "removed peer"
	}

	logger.Info(msg, zap.String("public_key", spec.NodePublicKey), zap.String("address", spec.NodeSubnet))

	return nil
}
// Peers reports that the fake device has no peers; the slice is empty but non-nil.
func (h *fakeWireguardHandler) Peers() ([]wgtypes.Peer, error) {
	return make([]wgtypes.Peer, 0), nil
}
// SiderolinkSuite drives the siderolink Manager against an in-memory resource state,
// talking to it through a real gRPC ProvisionService client.
type SiderolinkSuite struct {
	suite.Suite

	// ctx bounds a single test; ctxCancel is called in TearDownTest.
	ctx       context.Context //nolint:containedctx
	ctxCancel context.CancelFunc

	// state is shared by the manager and by the test assertions.
	state state.State

	// manager is the system under test.
	manager *sideromanager.Manager

	// address is the manager listener's address (lis.Addr().String()).
	address string

	// wg tracks the background goroutine that serves the manager.
	wg sync.WaitGroup
}
// SetupTest creates a fresh in-memory state and siderolink manager for every test.
// The manager advertises both the configured WireGuard address and TestIP as a
// comma-separated list, and is then started in the background.
func (suite *SiderolinkSuite) SetupTest() {
	suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute)
	suite.state = state.WrapCore(namespaced.NewState(inmem.Build))

	advertised := config.Config.SiderolinkWireguardAdvertisedAddress + "," + TestIP

	params := sideromanager.Params{
		WireguardEndpoint:  "127.0.0.1:0",
		AdvertisedEndpoint: advertised,
		APIEndpoint:        "127.0.0.1:0",
	}

	statusHandler := machinestatus.NewHandler(suite.state, zaptest.NewLogger(suite.T()), make(chan *omni.MachineStatusSnapshot))

	manager, err := sideromanager.NewManager(suite.ctx, suite.state, &fakeWireguardHandler{}, params, zaptest.NewLogger(suite.T()), nil, statusHandler, nil)
	suite.manager = manager

	suite.Require().NoError(err)

	suite.startManager(params)
}
// startManager opens the listener described by params, records its address for the
// test clients, and serves the manager's gRPC API together with manager.Run in a
// background goroutine tracked by suite.wg. That goroutine runs under an errgroup
// derived from suite.ctx, which TearDownTest cancels before waiting on suite.wg.
func (suite *SiderolinkSuite) startManager(params sideromanager.Params) {
	lis, err := params.NewListener()
	suite.Require().NoError(err)

	suite.address = lis.Addr().String()

	// Register with the WaitGroup only once the goroutine is guaranteed to start:
	// doing it before a failing NewListener would leave the counter at 1 and make
	// TearDownTest block forever in suite.wg.Wait().
	suite.wg.Add(1)

	go func() {
		defer suite.wg.Done()

		eg, groupCtx := errgroup.WithContext(suite.ctx)

		server := grpc.NewServer()
		suite.manager.Register(server)

		eg.Go(func() error {
			return suite.manager.Run(groupCtx, "127.0.0.1", "0", "0", "")
		})

		grpcutil.RunServer(groupCtx, server, lis, eg, zaptest.NewLogger(suite.T()))

		// Assert rather than Require: FailNow must only be called from the goroutine
		// running the test, and this is a background goroutine.
		suite.Assert().NoError(eg.Wait())
	}()
}
// TestNodes provisions a node over the ProvisionService gRPC API and verifies that:
//   - the siderolink Config and ConnectionParams resources are populated;
//   - provisioning produces a Link resource with a non-empty ID;
//   - re-provisioning with the same key, and then with a new key, returns a
//     response equal to the first one;
//   - the Link resource ends up holding the most recently provisioned public key.
func (suite *SiderolinkSuite) TestNodes() {
	var spec *specs.ConnectionParamsSpec

	ctx, cancel := context.WithTimeout(suite.ctx, time.Second*2)
	defer cancel()

	rtestutils.AssertResources[*siderolink.Config](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.Config, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.PrivateKey)
		assertion.NotEmpty(r.TypedSpec().Value.PublicKey)
	})

	rtestutils.AssertResources[*siderolink.ConnectionParams](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.ConnectionParams, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.Args)
		assertion.NotEmpty(r.TypedSpec().Value.ApiEndpoint)
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.WireguardEndpoint)

		// Keep the spec: its join token is used for provisioning below.
		spec = r.TypedSpec().Value
	})

	conn, err := grpc.NewClient(suite.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	suite.Require().NoError(err)

	// Release the client connection when the test returns.
	defer func() { suite.Assert().NoError(conn.Close()) }()

	client := pb.NewProvisionServiceClient(conn)

	privateKey, err := wgtypes.GeneratePrivateKey()
	suite.Require().NoError(err)

	resp, err := client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: privateKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})
	suite.Require().NoError(err)

	// Wait until at least one Link resource is visible, and make sure none has an empty ID.
	suite.Require().NoError(
		retry.Constant(time.Second * 2).Retry(func() error {
			list, err := safe.ReaderList[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "", resource.VersionUndefined)) //nolint:govet
			if err != nil {
				return err
			}

			if list.Len() == 0 {
				return retry.ExpectedErrorf("no links established yet")
			}

			for link := range list.All() {
				if link.Metadata().ID() == "" {
					return errors.New("empty id in the resource list")
				}
			}

			return nil
		}),
	)

	// Same node, same key: the response must not change.
	reprovision, err := client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: privateKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})
	suite.Require().NoError(err)
	suite.Require().True(proto.Equal(resp, reprovision))

	// Same node, new key: the response must still match the first one.
	privateKey, err = wgtypes.GeneratePrivateKey()
	suite.Require().NoError(err)

	reprovision, err = client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: privateKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})
	suite.Require().NoError(err)
	suite.Require().True(proto.Equal(resp, reprovision))

	// The Link resource must now carry the rotated public key.
	res, err := safe.StateGet[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "testnode", resource.VersionUndefined))
	suite.Require().NoError(err)
	suite.Require().Equal(privateKey.PublicKey().String(), res.TypedSpec().Value.NodePublicKey)
}
// TestNodeWithSeveralAdvertisedIPs checks that, with the manager advertising a
// comma-separated list of WireGuard addresses (the configured address followed by
// TestIP, as set up in SetupTest), the Provision response returns each of them as
// an endpoint, in the same order.
func (suite *SiderolinkSuite) TestNodeWithSeveralAdvertisedIPs() {
	var spec *specs.ConnectionParamsSpec

	ctx, cancel := context.WithTimeout(suite.ctx, time.Second*2)
	defer cancel()

	rtestutils.AssertResources[*siderolink.ConnectionParams](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.ConnectionParams, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.Args)
		assertion.NotEmpty(r.TypedSpec().Value.ApiEndpoint)
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.WireguardEndpoint)

		// Keep the spec: its join token is used for provisioning below.
		spec = r.TypedSpec().Value
	})

	conn := must.Value(grpc.NewClient(suite.address, grpc.WithTransportCredentials(insecure.NewCredentials())))(suite.T())

	// Release the client connection when the test returns.
	defer func() { suite.Assert().NoError(conn.Close()) }()

	client := pb.NewProvisionServiceClient(conn)
	privateKey := must.Value(wgtypes.GeneratePrivateKey())(suite.T())

	resp := must.Value(client.Provision(
		suite.ctx,
		&pb.ProvisionRequest{
			NodeUuid:      "testnode",
			NodePublicKey: privateKey.PublicKey().String(),
			JoinToken:     &spec.JoinToken,
		},
	))(suite.T())

	require.Equal(suite.T(), []string{config.Config.SiderolinkWireguardAdvertisedAddress, TestIP}, resp.GetEndpoints())
}
// TestVirtualNodes covers switching a node between WireGuard-over-gRPC and plain
// WireGuard provisioning:
//   - provisioning with WireguardOverGrpc produces Link resources with a virtual
//     address;
//   - re-provisioning without it (with the same key, then a new key) returns the
//     first response minus GrpcPeerAddrPort, with the advertised endpoints as
//     ServerEndpoint, and clears the Link's virtual address;
//   - re-enabling WireguardOverGrpc sets a virtual address again, matching the
//     GrpcPeerAddrPort returned to the node.
func (suite *SiderolinkSuite) TestVirtualNodes() {
	var spec *specs.ConnectionParamsSpec

	ctx, cancel := context.WithTimeout(suite.ctx, time.Second*2)
	defer cancel()

	rtestutils.AssertResources[*siderolink.Config](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.Config, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.PrivateKey)
		assertion.NotEmpty(r.TypedSpec().Value.PublicKey)
	})

	rtestutils.AssertResources[*siderolink.ConnectionParams](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.ConnectionParams, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.Args)
		assertion.NotEmpty(r.TypedSpec().Value.ApiEndpoint)
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.WireguardEndpoint)

		// Keep the spec: its join token is used for provisioning below.
		spec = r.TypedSpec().Value
	})

	conn, err := grpc.NewClient(suite.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	suite.Require().NoError(err)

	// Release the client connection when the test returns.
	defer func() { suite.Assert().NoError(conn.Close()) }()

	client := pb.NewProvisionServiceClient(conn)

	privateKey, err := wgtypes.GeneratePrivateKey()
	suite.Require().NoError(err)

	resp, err := client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:          "testnode",
		NodePublicKey:     privateKey.PublicKey().String(),
		JoinToken:         &spec.JoinToken,
		WireguardOverGrpc: pointer.To(true),
	})
	suite.Require().NoError(err)

	// Wait until Link resources exist; each must have an ID and a virtual address.
	suite.Require().NoError(
		retry.Constant(time.Second * 2).Retry(func() error {
			list, err := safe.ReaderList[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "", resource.VersionUndefined)) //nolint:govet
			if err != nil {
				return err
			}

			if list.Len() == 0 {
				return retry.ExpectedErrorf("no links established yet")
			}

			for link := range list.All() {
				if link.Metadata().ID() == "" {
					return errors.New("empty id in the resource list")
				}

				if link.TypedSpec().Value.VirtualAddrport == "" {
					return errors.New("empty virtual address in the resource list")
				}
			}

			return nil
		}),
	)

	// Re-provision without WireGuard-over-gRPC: no gRPC peer address, and the
	// advertised WireGuard endpoints are returned as the server endpoint.
	reprovision, err := client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: privateKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})
	suite.Require().NoError(err)

	expectedResp := resp.CloneVT()
	expectedResp.GrpcPeerAddrPort = ""
	expectedResp.ServerEndpoint = pb.MakeEndpoints(config.Config.SiderolinkWireguardAdvertisedAddress, TestIP)

	suite.Require().Equal(expectedResp.String(), reprovision.String())

	// Rotate the key, still without WireGuard-over-gRPC: same expected response.
	privateKey, err = wgtypes.GeneratePrivateKey()
	suite.Require().NoError(err)

	reprovision, err = client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: privateKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})
	suite.Require().NoError(err)
	suite.Require().Equal(expectedResp.String(), reprovision.String())

	res, err := safe.StateGet[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "testnode", resource.VersionUndefined))
	suite.Require().NoError(err)
	suite.Require().Equal(privateKey.PublicKey().String(), res.TypedSpec().Value.NodePublicKey)
	suite.Require().Zero(res.TypedSpec().Value.VirtualAddrport)

	// Switch back to WireGuard-over-gRPC.
	reprovision, err = client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:          "testnode",
		NodePublicKey:     privateKey.PublicKey().String(),
		JoinToken:         &spec.JoinToken,
		WireguardOverGrpc: pointer.To(true),
	})
	suite.Require().NoError(err)

	// Only GrpcPeerAddrPort and ServerEndpoint may differ from the initial response.
	resp.GrpcPeerAddrPort = reprovision.GrpcPeerAddrPort
	resp.ServerEndpoint = reprovision.ServerEndpoint
	suite.Require().Equal(resp.String(), reprovision.String())

	res, err = safe.StateGet[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "testnode", resource.VersionUndefined))
	suite.Require().NoError(err)
	suite.Require().Equal(privateKey.PublicKey().String(), res.TypedSpec().Value.NodePublicKey)
	suite.Require().NotZero(res.TypedSpec().Value.VirtualAddrport)
	suite.Require().Equal(reprovision.GrpcPeerAddrPort, res.TypedSpec().Value.VirtualAddrport)
}
// TestGenerateJoinToken checks that a generated join token is longer than 42 and
// shorter than 52 characters. A generator error fails the test immediately instead
// of surfacing as a misleading length mismatch on an empty token.
func (suite *SiderolinkSuite) TestGenerateJoinToken() {
	token, err := sideromanager.GenerateJoinToken()
	suite.Require().NoError(err)

	tokenLen := len(token)

	suite.Assert().Less(tokenLen, 52)
	suite.Assert().Greater(tokenLen, 42)
}
// TearDownTest cancels suite.ctx (the parent context of the manager's background
// errgroup) and then waits on suite.wg for that goroutine to finish.
func (suite *SiderolinkSuite) TearDownTest() {
	t := suite.T()
	t.Log("tear down")

	suite.ctxCancel()
	suite.wg.Wait()
}
func TestSiderolinkSuite(t *testing.T) {
t.Parallel()
suite.Run(t, new(SiderolinkSuite))
}
// TestIP is an address from TEST-NET-1 (192.0.2.0/24), a block RFC 5737 reserves
// for documentation and which is not routed on the public Internet, so nothing real
// should ever answer on it.
const TestIP = "192.0.2.2"