mirror of
https://github.com/siderolabs/talos.git
synced 2025-08-06 14:47:05 +02:00
Bump all dependencies, many small changes due to new golangci-lint version. Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
404 lines
12 KiB
Go
404 lines
12 KiB
Go
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package backend_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/siderolabs/go-pointer"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
"google.golang.org/grpc/metadata"
|
|
protobuf "google.golang.org/protobuf/proto" //nolint:depguard
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/types/descriptorpb"
|
|
|
|
"github.com/siderolabs/talos/internal/app/apid/pkg/backend"
|
|
"github.com/siderolabs/talos/pkg/grpc/middleware/authz"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/cluster"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/common"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/inspect"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/machine"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/security"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/storage"
|
|
"github.com/siderolabs/talos/pkg/machinery/api/time"
|
|
"github.com/siderolabs/talos/pkg/machinery/config"
|
|
"github.com/siderolabs/talos/pkg/machinery/proto"
|
|
"github.com/siderolabs/talos/pkg/machinery/role"
|
|
"github.com/siderolabs/talos/pkg/machinery/version"
|
|
)
|
|
|
|
// APIDSuite groups the tests for backend.APID; the backend under test is
// constructed once in SetupSuite.
type APIDSuite struct {
	suite.Suite

	// b is the APID backend under test, created by SetupSuite for target "127.0.0.1".
	b *backend.APID
}
|
|
|
|
func (suite *APIDSuite) SetupSuite() {
|
|
tlsConfigProvider := func() (*tls.Config, error) {
|
|
return &tls.Config{}, nil
|
|
}
|
|
|
|
var err error
|
|
|
|
suite.b, err = backend.NewAPID("127.0.0.1", tlsConfigProvider)
|
|
suite.Require().NoError(err)
|
|
}
|
|
|
|
func (suite *APIDSuite) TestGetConnection() {
|
|
md1 := metadata.New(nil)
|
|
md1.Set(":authority", "127.0.0.2")
|
|
md1.Set("nodes", "127.0.0.1")
|
|
md1.Set("key", "value1", "value2")
|
|
ctx1 := metadata.NewIncomingContext(authz.ContextWithRoles(context.Background(), role.MakeSet(role.Admin)), md1)
|
|
|
|
outCtx1, conn1, err1 := suite.b.GetConnection(ctx1, "")
|
|
suite.Require().NoError(err1)
|
|
suite.Assert().NotNil(conn1)
|
|
suite.Assert().Equal(role.MakeSet(role.Admin), authz.GetRoles(outCtx1))
|
|
|
|
mdOut1, ok1 := metadata.FromOutgoingContext(outCtx1)
|
|
suite.Require().True(ok1)
|
|
suite.Assert().Equal([]string{"value1", "value2"}, mdOut1.Get("key"))
|
|
suite.Assert().Equal([]string{"127.0.0.2"}, mdOut1.Get("proxyfrom"))
|
|
suite.Assert().Equal([]string{"os:admin"}, mdOut1.Get("talos-role"))
|
|
|
|
suite.Run(
|
|
"Same context", func() {
|
|
ctx2 := ctx1
|
|
outCtx2, conn2, err2 := suite.b.GetConnection(ctx2, "")
|
|
suite.Require().NoError(err2)
|
|
suite.Assert().Equal(conn1, conn2) // connection is cached
|
|
suite.Assert().Equal(role.MakeSet(role.Admin), authz.GetRoles(outCtx2))
|
|
|
|
mdOut2, ok2 := metadata.FromOutgoingContext(outCtx2)
|
|
suite.Require().True(ok2)
|
|
suite.Assert().Equal([]string{"value1", "value2"}, mdOut2.Get("key"))
|
|
suite.Assert().Equal([]string{"127.0.0.2"}, mdOut2.Get("proxyfrom"))
|
|
suite.Assert().Equal([]string{"os:admin"}, mdOut2.Get("talos-role"))
|
|
},
|
|
)
|
|
|
|
suite.Run(
|
|
"Other context", func() {
|
|
md3 := metadata.New(nil)
|
|
md3.Set(":authority", "127.0.0.2")
|
|
md3.Set("nodes", "127.0.0.1")
|
|
md3.Set("key", "value3", "value4")
|
|
ctx3 := metadata.NewIncomingContext(
|
|
authz.ContextWithRoles(context.Background(), role.MakeSet(role.Reader)),
|
|
md3,
|
|
)
|
|
|
|
outCtx3, conn3, err3 := suite.b.GetConnection(ctx3, "")
|
|
suite.Require().NoError(err3)
|
|
suite.Assert().Equal(conn1, conn3) // connection is cached
|
|
suite.Assert().Equal(role.MakeSet(role.Reader), authz.GetRoles(outCtx3))
|
|
|
|
mdOut3, ok3 := metadata.FromOutgoingContext(outCtx3)
|
|
suite.Require().True(ok3)
|
|
suite.Assert().Equal([]string{"value3", "value4"}, mdOut3.Get("key"))
|
|
suite.Assert().Equal([]string{"127.0.0.2"}, mdOut3.Get("proxyfrom"))
|
|
suite.Assert().Equal([]string{"os:reader"}, mdOut3.Get("talos-role"))
|
|
},
|
|
)
|
|
}
|
|
|
|
func (suite *APIDSuite) TestAppendInfoUnary() {
|
|
reply := &common.DataResponse{
|
|
Messages: []*common.Data{
|
|
{
|
|
Bytes: []byte("foobar"),
|
|
},
|
|
},
|
|
}
|
|
|
|
resp, err := proto.Marshal(reply)
|
|
suite.Require().NoError(err)
|
|
|
|
newResp, err := suite.b.AppendInfo(false, resp)
|
|
suite.Require().NoError(err)
|
|
|
|
var newReply common.DataResponse
|
|
|
|
err = proto.Unmarshal(newResp, &newReply)
|
|
suite.Require().NoError(err)
|
|
|
|
suite.Assert().EqualValues([]byte("foobar"), newReply.Messages[0].Bytes)
|
|
suite.Assert().Equal(suite.b.String(), newReply.Messages[0].Metadata.Hostname)
|
|
suite.Assert().Empty(newReply.Messages[0].Metadata.Error)
|
|
}
|
|
|
|
func (suite *APIDSuite) TestAppendInfoStreaming() {
|
|
response := &common.Data{
|
|
Bytes: []byte("foobar"),
|
|
}
|
|
|
|
resp, err := proto.Marshal(response)
|
|
suite.Require().NoError(err)
|
|
|
|
newResp, err := suite.b.AppendInfo(true, resp)
|
|
suite.Require().NoError(err)
|
|
|
|
var newResponse common.Data
|
|
|
|
err = proto.Unmarshal(newResp, &newResponse)
|
|
suite.Require().NoError(err)
|
|
|
|
suite.Assert().EqualValues([]byte("foobar"), newResponse.Bytes)
|
|
suite.Assert().Equal(suite.b.String(), newResponse.Metadata.Hostname)
|
|
suite.Assert().Empty(newResponse.Metadata.Error)
|
|
}
|
|
|
|
func (suite *APIDSuite) TestAppendInfoStreamingMetadata() {
|
|
// this tests the case when metadata field is appended twice
|
|
// to the message, but protobuf merges definitions
|
|
response := &common.Data{
|
|
Metadata: &common.Metadata{
|
|
Error: "something went wrong",
|
|
},
|
|
}
|
|
|
|
resp, err := proto.Marshal(response)
|
|
suite.Require().NoError(err)
|
|
|
|
newResp, err := suite.b.AppendInfo(true, resp)
|
|
suite.Require().NoError(err)
|
|
|
|
var newResponse common.Data
|
|
|
|
err = proto.Unmarshal(newResp, &newResponse)
|
|
suite.Require().NoError(err)
|
|
|
|
suite.Assert().Nil(newResponse.Bytes)
|
|
suite.Assert().Equal(suite.b.String(), newResponse.Metadata.Hostname)
|
|
suite.Assert().Equal("something went wrong", newResponse.Metadata.Error)
|
|
}
|
|
|
|
func (suite *APIDSuite) TestBuildErrorUnary() {
|
|
resp, err := suite.b.BuildError(false, errors.New("some error"))
|
|
suite.Require().NoError(err)
|
|
|
|
var reply common.DataResponse
|
|
|
|
err = proto.Unmarshal(resp, &reply)
|
|
suite.Require().NoError(err)
|
|
|
|
suite.Assert().Nil(reply.Messages[0].Bytes)
|
|
suite.Assert().Equal(suite.b.String(), reply.Messages[0].Metadata.Hostname)
|
|
suite.Assert().Equal("some error", reply.Messages[0].Metadata.Error)
|
|
}
|
|
|
|
func (suite *APIDSuite) TestBuildErrorStreaming() {
|
|
resp, err := suite.b.BuildError(true, errors.New("some error"))
|
|
suite.Require().NoError(err)
|
|
|
|
var response common.Data
|
|
|
|
err = proto.Unmarshal(resp, &response)
|
|
suite.Require().NoError(err)
|
|
|
|
suite.Assert().Nil(response.Bytes)
|
|
suite.Assert().Equal(suite.b.String(), response.Metadata.Hostname)
|
|
suite.Assert().Equal("some error", response.Metadata.Error)
|
|
}
|
|
|
|
func TestAPIDSuite(t *testing.T) {
|
|
suite.Run(t, new(APIDSuite))
|
|
}
|
|
|
|
func TestAPIIdiosyncrasies(t *testing.T) {
|
|
for _, services := range []protoreflect.ServiceDescriptors{
|
|
common.File_common_common_proto.Services(),
|
|
cluster.File_cluster_cluster_proto.Services(),
|
|
inspect.File_inspect_inspect_proto.Services(),
|
|
machine.File_machine_machine_proto.Services(),
|
|
// security.File_security_security_proto.Services() is different
|
|
storage.File_storage_storage_proto.Services(),
|
|
time.File_time_time_proto.Services(),
|
|
} {
|
|
for i := range services.Len() {
|
|
service := services.Get(i)
|
|
methods := service.Methods()
|
|
|
|
for j := range methods.Len() {
|
|
method := methods.Get(j)
|
|
|
|
t.Run(
|
|
string(method.FullName()), func(t *testing.T) {
|
|
response := method.Output()
|
|
responseFields := response.Fields()
|
|
|
|
if method.IsStreamingServer() {
|
|
metadata := responseFields.Get(0)
|
|
assert.Equal(t, "metadata", metadata.TextName())
|
|
assert.Equal(t, 1, int(metadata.Number()))
|
|
} else {
|
|
require.Equal(t, 1, responseFields.Len(), "unary responses should have exactly one field")
|
|
|
|
messages := responseFields.Get(0)
|
|
assert.Equal(t, "messages", messages.TextName())
|
|
assert.Equal(t, 1, int(messages.Number()))
|
|
|
|
reply := messages.Message()
|
|
replyFields := reply.Fields()
|
|
require.GreaterOrEqual(
|
|
t,
|
|
replyFields.Len(),
|
|
1,
|
|
"unary replies should have at least one field",
|
|
)
|
|
|
|
metadata := replyFields.Get(0)
|
|
assert.Equal(t, "metadata", metadata.TextName())
|
|
assert.Equal(t, 1, int(metadata.Number()))
|
|
}
|
|
},
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
//nolint:nakedret,gocyclo,forcetypeassert
|
|
func getOptions(t *testing.T, descriptor protoreflect.Descriptor) (deprecated bool, version string) {
|
|
switch opts := descriptor.Options().(type) {
|
|
case *descriptorpb.EnumOptions:
|
|
if opts != nil {
|
|
deprecated = pointer.SafeDeref(opts.Deprecated)
|
|
version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnum).(string)
|
|
}
|
|
case *descriptorpb.EnumValueOptions:
|
|
if opts != nil {
|
|
deprecated = pointer.SafeDeref(opts.Deprecated)
|
|
version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnumValue).(string)
|
|
}
|
|
case *descriptorpb.MessageOptions:
|
|
if opts != nil {
|
|
deprecated = pointer.SafeDeref(opts.Deprecated)
|
|
version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMessage).(string)
|
|
}
|
|
case *descriptorpb.FieldOptions:
|
|
if opts != nil {
|
|
deprecated = pointer.SafeDeref(opts.Deprecated)
|
|
version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedField).(string)
|
|
}
|
|
case *descriptorpb.ServiceOptions:
|
|
if opts != nil {
|
|
deprecated = pointer.SafeDeref(opts.Deprecated)
|
|
version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedService).(string)
|
|
}
|
|
case *descriptorpb.MethodOptions:
|
|
if opts != nil {
|
|
deprecated = pointer.SafeDeref(opts.Deprecated)
|
|
version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMethod).(string)
|
|
}
|
|
|
|
default:
|
|
t.Fatalf("unhandled %T", opts)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func testDeprecated(t *testing.T, descriptor protoreflect.Descriptor, currentVersion *config.VersionContract) {
|
|
deprecated, version := getOptions(t, descriptor)
|
|
|
|
assert.Equal(
|
|
t, deprecated, version != "",
|
|
"%s: `deprecated` and `remove_deprecated_XXX_in` options should be used together", descriptor.FullName(),
|
|
)
|
|
|
|
if !deprecated || version == "" {
|
|
return
|
|
}
|
|
|
|
v, err := config.ParseContractFromVersion(version)
|
|
require.NoError(t, err, "%s", descriptor.FullName())
|
|
|
|
assert.True(t, v.Greater(currentVersion), "%s should be removed in this version", descriptor.FullName())
|
|
}
|
|
|
|
func testEnum(t *testing.T, enum protoreflect.EnumDescriptor, currentVersion *config.VersionContract) {
|
|
testDeprecated(t, enum, currentVersion)
|
|
|
|
values := enum.Values()
|
|
for i := range values.Len() {
|
|
testDeprecated(t, values.Get(i), currentVersion)
|
|
}
|
|
}
|
|
|
|
func testMessage(t *testing.T, message protoreflect.MessageDescriptor, currentVersion *config.VersionContract) {
|
|
testDeprecated(t, message, currentVersion)
|
|
|
|
fields := message.Fields()
|
|
for i := range fields.Len() {
|
|
testDeprecated(t, fields.Get(i), currentVersion)
|
|
}
|
|
|
|
oneofs := message.Oneofs()
|
|
for i := range oneofs.Len() {
|
|
testDeprecated(t, oneofs.Get(i), currentVersion)
|
|
}
|
|
|
|
enums := message.Enums()
|
|
for i := range enums.Len() {
|
|
testEnum(t, enums.Get(i), currentVersion)
|
|
}
|
|
|
|
// test nested messages
|
|
messages := message.Messages()
|
|
for i := range messages.Len() {
|
|
testMessage(t, messages.Get(i), currentVersion)
|
|
}
|
|
}
|
|
|
|
func TestDeprecatedAPIs(t *testing.T) {
|
|
currentVersion, err := config.ParseContractFromVersion(version.Tag)
|
|
require.NoError(t, err)
|
|
|
|
for _, file := range []protoreflect.FileDescriptor{
|
|
common.File_common_common_proto,
|
|
cluster.File_cluster_cluster_proto,
|
|
inspect.File_inspect_inspect_proto,
|
|
machine.File_machine_machine_proto,
|
|
security.File_security_security_proto,
|
|
storage.File_storage_storage_proto,
|
|
time.File_time_time_proto,
|
|
} {
|
|
enums := file.Enums()
|
|
for i := range enums.Len() {
|
|
testEnum(t, enums.Get(i), currentVersion)
|
|
}
|
|
|
|
messages := file.Messages()
|
|
for i := range messages.Len() {
|
|
testMessage(t, messages.Get(i), currentVersion)
|
|
}
|
|
|
|
services := file.Services()
|
|
for i := range services.Len() {
|
|
service := services.Get(i)
|
|
testDeprecated(t, service, currentVersion)
|
|
|
|
methods := service.Methods()
|
|
for j := range methods.Len() {
|
|
method := methods.Get(j)
|
|
testDeprecated(t, method, currentVersion)
|
|
|
|
message := method.Input()
|
|
testMessage(t, message, currentVersion)
|
|
|
|
message = method.Output()
|
|
testMessage(t, message, currentVersion)
|
|
}
|
|
}
|
|
}
|
|
}
|