talos/internal/app/apid/pkg/backend/apid_test.go
Andrey Smirnov 7f0300f108
feat: update dependencies, Kubernetes 1.34.0-alpha.2
Bump all dependencies, many small changes due to new golangci-lint
version.

Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
2025-06-30 19:05:22 +04:00

404 lines
12 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package backend_test
import (
"context"
"crypto/tls"
"errors"
"testing"
"github.com/siderolabs/go-pointer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc/metadata"
protobuf "google.golang.org/protobuf/proto" //nolint:depguard
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"github.com/siderolabs/talos/internal/app/apid/pkg/backend"
"github.com/siderolabs/talos/pkg/grpc/middleware/authz"
"github.com/siderolabs/talos/pkg/machinery/api/cluster"
"github.com/siderolabs/talos/pkg/machinery/api/common"
"github.com/siderolabs/talos/pkg/machinery/api/inspect"
"github.com/siderolabs/talos/pkg/machinery/api/machine"
"github.com/siderolabs/talos/pkg/machinery/api/security"
"github.com/siderolabs/talos/pkg/machinery/api/storage"
"github.com/siderolabs/talos/pkg/machinery/api/time"
"github.com/siderolabs/talos/pkg/machinery/config"
"github.com/siderolabs/talos/pkg/machinery/proto"
"github.com/siderolabs/talos/pkg/machinery/role"
"github.com/siderolabs/talos/pkg/machinery/version"
)
// APIDSuite groups the tests for the APID backend.
type APIDSuite struct {
	suite.Suite

	// b is the backend under test; it is created once in SetupSuite for 127.0.0.1.
	b *backend.APID
}
// SetupSuite creates the APID backend shared by all tests in the suite.
// It targets 127.0.0.1, and the TLS provider returns an empty tls.Config.
func (suite *APIDSuite) SetupSuite() {
	emptyTLS := func() (*tls.Config, error) {
		return &tls.Config{}, nil
	}

	b, err := backend.NewAPID("127.0.0.1", emptyTLS)
	suite.b = b

	suite.Require().NoError(err)
}
// TestGetConnection checks that GetConnection forwards the caller's roles and
// incoming metadata into the outgoing context, and adds proxyfrom/talos-role.
// It also checks that the same connection is returned for repeated calls,
// including calls with different metadata and roles.
func (suite *APIDSuite) TestGetConnection() {
	// incoming attaches to base the metadata of a request for node 127.0.0.1
	// arriving via authority 127.0.0.2, with keyValues stored under "key".
	incoming := func(base context.Context, keyValues ...string) context.Context {
		md := metadata.New(nil)
		md.Set(":authority", "127.0.0.2")
		md.Set("nodes", "127.0.0.1")
		md.Set("key", keyValues...)

		return metadata.NewIncomingContext(base, md)
	}

	// checkOutgoing verifies the roles and metadata carried by an outgoing context.
	checkOutgoing := func(outCtx context.Context, wantRoles any, wantKey []string, wantRole string) {
		suite.Assert().Equal(wantRoles, authz.GetRoles(outCtx))

		md, ok := metadata.FromOutgoingContext(outCtx)
		suite.Require().True(ok)

		suite.Assert().Equal(wantKey, md.Get("key"))
		suite.Assert().Equal([]string{"127.0.0.2"}, md.Get("proxyfrom"))
		suite.Assert().Equal([]string{wantRole}, md.Get("talos-role"))
	}

	adminCtx := incoming(
		authz.ContextWithRoles(context.Background(), role.MakeSet(role.Admin)),
		"value1", "value2",
	)

	outCtx, conn, err := suite.b.GetConnection(adminCtx, "")
	suite.Require().NoError(err)
	suite.Assert().NotNil(conn)
	checkOutgoing(outCtx, role.MakeSet(role.Admin), []string{"value1", "value2"}, "os:admin")

	suite.Run("Same context", func() {
		sameOutCtx, sameConn, sameErr := suite.b.GetConnection(adminCtx, "")
		suite.Require().NoError(sameErr)
		suite.Assert().Equal(conn, sameConn) // connection is cached
		checkOutgoing(sameOutCtx, role.MakeSet(role.Admin), []string{"value1", "value2"}, "os:admin")
	})

	suite.Run("Other context", func() {
		readerCtx := incoming(
			authz.ContextWithRoles(context.Background(), role.MakeSet(role.Reader)),
			"value3", "value4",
		)

		otherOutCtx, otherConn, otherErr := suite.b.GetConnection(readerCtx, "")
		suite.Require().NoError(otherErr)
		suite.Assert().Equal(conn, otherConn) // connection is cached
		checkOutgoing(otherOutCtx, role.MakeSet(role.Reader), []string{"value3", "value4"}, "os:reader")
	})
}
// TestAppendInfoUnary checks that AppendInfo, for a unary reply, keeps the
// payload of the first message and fills in the backend hostname without an error.
func (suite *APIDSuite) TestAppendInfoUnary() {
	encoded, err := proto.Marshal(&common.DataResponse{
		Messages: []*common.Data{{Bytes: []byte("foobar")}},
	})
	suite.Require().NoError(err)

	augmented, err := suite.b.AppendInfo(false, encoded)
	suite.Require().NoError(err)

	var decoded common.DataResponse
	suite.Require().NoError(proto.Unmarshal(augmented, &decoded))

	first := decoded.Messages[0]
	suite.Assert().EqualValues([]byte("foobar"), first.Bytes)
	suite.Assert().Equal(suite.b.String(), first.Metadata.Hostname)
	suite.Assert().Empty(first.Metadata.Error)
}
// TestAppendInfoStreaming checks that AppendInfo, for a streamed message, keeps
// the payload and fills in the backend hostname without an error.
func (suite *APIDSuite) TestAppendInfoStreaming() {
	encoded, err := proto.Marshal(&common.Data{Bytes: []byte("foobar")})
	suite.Require().NoError(err)

	augmented, err := suite.b.AppendInfo(true, encoded)
	suite.Require().NoError(err)

	var decoded common.Data
	suite.Require().NoError(proto.Unmarshal(augmented, &decoded))

	suite.Assert().EqualValues([]byte("foobar"), decoded.Bytes)
	suite.Assert().Equal(suite.b.String(), decoded.Metadata.Hostname)
	suite.Assert().Empty(decoded.Metadata.Error)
}
// TestAppendInfoStreamingMetadata covers a streamed message that already carries
// metadata. AppendInfo then appends a second metadata field to the message.
// Protobuf merges both occurrences when decoding, so the pre-existing error must
// survive alongside the injected hostname.
func (suite *APIDSuite) TestAppendInfoStreamingMetadata() {
	encoded, err := proto.Marshal(&common.Data{
		Metadata: &common.Metadata{Error: "something went wrong"},
	})
	suite.Require().NoError(err)

	augmented, err := suite.b.AppendInfo(true, encoded)
	suite.Require().NoError(err)

	var decoded common.Data
	suite.Require().NoError(proto.Unmarshal(augmented, &decoded))

	suite.Assert().Nil(decoded.Bytes)
	suite.Assert().Equal(suite.b.String(), decoded.Metadata.Hostname)
	suite.Assert().Equal("something went wrong", decoded.Metadata.Error)
}
// TestBuildErrorUnary checks that BuildError, for a unary call, produces a
// DataResponse. Its first message must carry the hostname and error text but no payload.
func (suite *APIDSuite) TestBuildErrorUnary() {
	encoded, err := suite.b.BuildError(false, errors.New("some error"))
	suite.Require().NoError(err)

	var decoded common.DataResponse
	suite.Require().NoError(proto.Unmarshal(encoded, &decoded))

	first := decoded.Messages[0]
	suite.Assert().Nil(first.Bytes)
	suite.Assert().Equal(suite.b.String(), first.Metadata.Hostname)
	suite.Assert().Equal("some error", first.Metadata.Error)
}
// TestBuildErrorStreaming checks that BuildError, for a streaming call, produces
// a single Data message carrying the hostname and error text but no payload.
func (suite *APIDSuite) TestBuildErrorStreaming() {
	encoded, err := suite.b.BuildError(true, errors.New("some error"))
	suite.Require().NoError(err)

	var decoded common.Data
	suite.Require().NoError(proto.Unmarshal(encoded, &decoded))

	suite.Assert().Nil(decoded.Bytes)
	suite.Assert().Equal(suite.b.String(), decoded.Metadata.Hostname)
	suite.Assert().Equal("some error", decoded.Metadata.Error)
}
// TestAPIDSuite runs every test method of APIDSuite.
func TestAPIDSuite(t *testing.T) {
	suite.Run(t, &APIDSuite{})
}
// TestAPIIdiosyncrasies enforces the response layout of the listed Talos services
// (the security service is excluded):
//
//   - streaming methods: the first field of the streamed message is `metadata`,
//     field number 1;
//   - unary methods: the response has exactly one field, `messages` (field number 1).
//     That field's message type must start with a `metadata` field (field number 1).
//
// NOTE(review): AppendInfo adds a metadata field to already-encoded responses
// (see TestAppendInfoStreamingMetadata). This is presumably why the field numbers
// are pinned here — confirm against the backend implementation.
func TestAPIIdiosyncrasies(t *testing.T) {
	for _, services := range []protoreflect.ServiceDescriptors{
		common.File_common_common_proto.Services(),
		cluster.File_cluster_cluster_proto.Services(),
		inspect.File_inspect_inspect_proto.Services(),
		machine.File_machine_machine_proto.Services(),
		// security.File_security_security_proto.Services() is different
		storage.File_storage_storage_proto.Services(),
		time.File_time_time_proto.Services(),
	} {
		for i := range services.Len() {
			methods := services.Get(i).Methods()

			for j := range methods.Len() {
				method := methods.Get(j)

				t.Run(string(method.FullName()), func(t *testing.T) {
					responseFields := method.Output().Fields()

					if method.IsStreamingServer() {
						// Get(0) panics on an empty field list; fail with a clear message instead.
						require.GreaterOrEqual(
							t,
							responseFields.Len(),
							1,
							"streaming responses should have at least one field",
						)

						metadataField := responseFields.Get(0)
						assert.Equal(t, "metadata", metadataField.TextName())
						assert.Equal(t, 1, int(metadataField.Number()))

						return
					}

					require.Equal(t, 1, responseFields.Len(), "unary responses should have exactly one field")

					messagesField := responseFields.Get(0)
					assert.Equal(t, "messages", messagesField.TextName())
					assert.Equal(t, 1, int(messagesField.Number()))

					// Message() is nil for non-message fields; fail instead of panicking on Fields().
					reply := messagesField.Message()
					require.NotNil(t, reply, "unary `messages` field should be a message")

					replyFields := reply.Fields()
					require.GreaterOrEqual(
						t,
						replyFields.Len(),
						1,
						"unary replies should have at least one field",
					)

					metadataField := replyFields.Get(0)
					assert.Equal(t, "metadata", metadataField.TextName())
					assert.Equal(t, 1, int(metadataField.Number()))
				})
			}
		}
	}
}
// getOptions reports whether descriptor carries the standard `deprecated` option.
// It also returns the version stored in the matching Talos `remove_deprecated_*`
// extension, or "" when that extension is unset.
//
// descriptor.proto defines no `deprecated` option for oneofs, so a oneof is
// reported as not deprecated instead of failing the test. Any other unexpected
// options type fails the test.
//
//nolint:gocyclo,forcetypeassert
func getOptions(t *testing.T, descriptor protoreflect.Descriptor) (deprecated bool, removeIn string) {
	t.Helper()

	// Options() may return a typed nil pointer when no options are set, hence the nil checks.
	switch opts := descriptor.Options().(type) {
	case *descriptorpb.EnumOptions:
		if opts != nil {
			deprecated = pointer.SafeDeref(opts.Deprecated)
			removeIn = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnum).(string)
		}
	case *descriptorpb.EnumValueOptions:
		if opts != nil {
			deprecated = pointer.SafeDeref(opts.Deprecated)
			removeIn = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnumValue).(string)
		}
	case *descriptorpb.MessageOptions:
		if opts != nil {
			deprecated = pointer.SafeDeref(opts.Deprecated)
			removeIn = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMessage).(string)
		}
	case *descriptorpb.FieldOptions:
		if opts != nil {
			deprecated = pointer.SafeDeref(opts.Deprecated)
			removeIn = protobuf.GetExtension(opts, common.E_RemoveDeprecatedField).(string)
		}
	case *descriptorpb.ServiceOptions:
		if opts != nil {
			deprecated = pointer.SafeDeref(opts.Deprecated)
			removeIn = protobuf.GetExtension(opts, common.E_RemoveDeprecatedService).(string)
		}
	case *descriptorpb.MethodOptions:
		if opts != nil {
			deprecated = pointer.SafeDeref(opts.Deprecated)
			removeIn = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMethod).(string)
		}
	case *descriptorpb.OneofOptions:
		// Oneofs (including synthetic ones generated for proto3 `optional` fields) have no
		// `deprecated` option and no removal extension handled here, so they are never deprecated.
	default:
		t.Fatalf("unhandled %T", opts)
	}

	return deprecated, removeIn
}
// testDeprecated checks that descriptor uses the `deprecated` option and its
// `remove_deprecated_XXX_in` extension together. When both are set, it also checks
// that the removal version parses and is greater than currentVersion. Otherwise the
// API should already have been removed.
func testDeprecated(t *testing.T, descriptor protoreflect.Descriptor, currentVersion *config.VersionContract) {
	t.Helper()

	// removeIn (not `version`) avoids shadowing the imported version package.
	deprecated, removeIn := getOptions(t, descriptor)

	assert.Equal(
		t, deprecated, removeIn != "",
		"%s: `deprecated` and `remove_deprecated_XXX_in` options should be used together", descriptor.FullName(),
	)

	if !deprecated || removeIn == "" {
		return
	}

	v, err := config.ParseContractFromVersion(removeIn)
	require.NoError(t, err, "%s", descriptor.FullName())

	assert.True(t, v.Greater(currentVersion), "%s should be removed in this version", descriptor.FullName())
}
// testEnum runs the deprecation checks on an enum and on each of its values.
func testEnum(t *testing.T, enum protoreflect.EnumDescriptor, currentVersion *config.VersionContract) {
	testDeprecated(t, enum, currentVersion)

	for values, idx := enum.Values(), 0; idx < values.Len(); idx++ {
		testDeprecated(t, values.Get(idx), currentVersion)
	}
}
// testMessage runs the deprecation checks on a message and, in order, on its
// fields, oneofs, nested enums and (recursively) nested messages.
func testMessage(t *testing.T, message protoreflect.MessageDescriptor, currentVersion *config.VersionContract) {
	testDeprecated(t, message, currentVersion)

	for fields, idx := message.Fields(), 0; idx < fields.Len(); idx++ {
		testDeprecated(t, fields.Get(idx), currentVersion)
	}

	for oneofs, idx := message.Oneofs(), 0; idx < oneofs.Len(); idx++ {
		testDeprecated(t, oneofs.Get(idx), currentVersion)
	}

	for enums, idx := message.Enums(), 0; idx < enums.Len(); idx++ {
		testEnum(t, enums.Get(idx), currentVersion)
	}

	// nested message types are checked recursively
	for nested, idx := message.Messages(), 0; idx < nested.Len(); idx++ {
		testMessage(t, nested.Get(idx), currentVersion)
	}
}
// TestDeprecatedAPIs checks the deprecation annotations of every top-level enum,
// message, service and method in the Talos API files. For each method it also
// checks the request and response messages. All checks run against the version
// parsed from version.Tag.
func TestDeprecatedAPIs(t *testing.T) {
	currentVersion, err := config.ParseContractFromVersion(version.Tag)
	require.NoError(t, err)

	apiFiles := []protoreflect.FileDescriptor{
		common.File_common_common_proto,
		cluster.File_cluster_cluster_proto,
		inspect.File_inspect_inspect_proto,
		machine.File_machine_machine_proto,
		security.File_security_security_proto,
		storage.File_storage_storage_proto,
		time.File_time_time_proto,
	}

	for _, file := range apiFiles {
		for enums, idx := file.Enums(), 0; idx < enums.Len(); idx++ {
			testEnum(t, enums.Get(idx), currentVersion)
		}

		for messages, idx := file.Messages(), 0; idx < messages.Len(); idx++ {
			testMessage(t, messages.Get(idx), currentVersion)
		}

		for services, idx := file.Services(), 0; idx < services.Len(); idx++ {
			svc := services.Get(idx)
			testDeprecated(t, svc, currentVersion)

			for methods, m := svc.Methods(), 0; m < methods.Len(); m++ {
				rpc := methods.Get(m)
				testDeprecated(t, rpc, currentVersion)
				testMessage(t, rpc.Input(), currentVersion)
				testMessage(t, rpc.Output(), currentVersion)
			}
		}
	}
}