mirror of
				https://github.com/siderolabs/talos.git
				synced 2025-11-04 10:21:13 +01:00 
			
		
		
		
	This allows rolling all nodes over to a new CA, either to refresh it or, e.g., when the `talosconfig` was exposed accidentally. Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
		
			
				
	
	
		
			398 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			398 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// This Source Code Form is subject to the terms of the Mozilla Public
 | 
						|
// License, v. 2.0. If a copy of the MPL was not distributed with this
 | 
						|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
 | 
						|
 | 
						|
package backend_test
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"crypto/tls"
 | 
						|
	"errors"
 | 
						|
	"testing"
 | 
						|
 | 
						|
	"github.com/siderolabs/go-pointer"
 | 
						|
	"github.com/stretchr/testify/assert"
 | 
						|
	"github.com/stretchr/testify/require"
 | 
						|
	"github.com/stretchr/testify/suite"
 | 
						|
	"google.golang.org/grpc/metadata"
 | 
						|
	protobuf "google.golang.org/protobuf/proto" //nolint:depguard
 | 
						|
	"google.golang.org/protobuf/reflect/protoreflect"
 | 
						|
	"google.golang.org/protobuf/types/descriptorpb"
 | 
						|
 | 
						|
	"github.com/siderolabs/talos/internal/app/apid/pkg/backend"
 | 
						|
	"github.com/siderolabs/talos/pkg/grpc/middleware/authz"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/cluster"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/common"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/inspect"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/machine"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/security"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/storage"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/api/time"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/config"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/proto"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/role"
 | 
						|
	"github.com/siderolabs/talos/pkg/machinery/version"
 | 
						|
)
 | 
						|
 | 
						|
type APIDSuite struct {
 | 
						|
	suite.Suite
 | 
						|
 | 
						|
	b *backend.APID
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) SetupSuite() {
 | 
						|
	tlsConfigProvider := func() (*tls.Config, error) {
 | 
						|
		return &tls.Config{}, nil
 | 
						|
	}
 | 
						|
 | 
						|
	var err error
 | 
						|
	suite.b, err = backend.NewAPID("127.0.0.1", tlsConfigProvider)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) TestGetConnection() {
 | 
						|
	md1 := metadata.New(nil)
 | 
						|
	md1.Set(":authority", "127.0.0.2")
 | 
						|
	md1.Set("nodes", "127.0.0.1")
 | 
						|
	md1.Set("key", "value1", "value2")
 | 
						|
	ctx1 := metadata.NewIncomingContext(authz.ContextWithRoles(context.Background(), role.MakeSet(role.Admin)), md1)
 | 
						|
 | 
						|
	outCtx1, conn1, err1 := suite.b.GetConnection(ctx1, "")
 | 
						|
	suite.Require().NoError(err1)
 | 
						|
	suite.Assert().NotNil(conn1)
 | 
						|
	suite.Assert().Equal(role.MakeSet(role.Admin), authz.GetRoles(outCtx1))
 | 
						|
 | 
						|
	mdOut1, ok1 := metadata.FromOutgoingContext(outCtx1)
 | 
						|
	suite.Require().True(ok1)
 | 
						|
	suite.Assert().Equal([]string{"value1", "value2"}, mdOut1.Get("key"))
 | 
						|
	suite.Assert().Equal([]string{"127.0.0.2"}, mdOut1.Get("proxyfrom"))
 | 
						|
	suite.Assert().Equal([]string{"os:admin"}, mdOut1.Get("talos-role"))
 | 
						|
 | 
						|
	suite.Run(
 | 
						|
		"Same context", func() {
 | 
						|
			ctx2 := ctx1
 | 
						|
			outCtx2, conn2, err2 := suite.b.GetConnection(ctx2, "")
 | 
						|
			suite.Require().NoError(err2)
 | 
						|
			suite.Assert().Equal(conn1, conn2) // connection is cached
 | 
						|
			suite.Assert().Equal(role.MakeSet(role.Admin), authz.GetRoles(outCtx2))
 | 
						|
 | 
						|
			mdOut2, ok2 := metadata.FromOutgoingContext(outCtx2)
 | 
						|
			suite.Require().True(ok2)
 | 
						|
			suite.Assert().Equal([]string{"value1", "value2"}, mdOut2.Get("key"))
 | 
						|
			suite.Assert().Equal([]string{"127.0.0.2"}, mdOut2.Get("proxyfrom"))
 | 
						|
			suite.Assert().Equal([]string{"os:admin"}, mdOut2.Get("talos-role"))
 | 
						|
		},
 | 
						|
	)
 | 
						|
 | 
						|
	suite.Run(
 | 
						|
		"Other context", func() {
 | 
						|
			md3 := metadata.New(nil)
 | 
						|
			md3.Set(":authority", "127.0.0.2")
 | 
						|
			md3.Set("nodes", "127.0.0.1")
 | 
						|
			md3.Set("key", "value3", "value4")
 | 
						|
			ctx3 := metadata.NewIncomingContext(
 | 
						|
				authz.ContextWithRoles(context.Background(), role.MakeSet(role.Reader)),
 | 
						|
				md3,
 | 
						|
			)
 | 
						|
 | 
						|
			outCtx3, conn3, err3 := suite.b.GetConnection(ctx3, "")
 | 
						|
			suite.Require().NoError(err3)
 | 
						|
			suite.Assert().Equal(conn1, conn3) // connection is cached
 | 
						|
			suite.Assert().Equal(role.MakeSet(role.Reader), authz.GetRoles(outCtx3))
 | 
						|
 | 
						|
			mdOut3, ok3 := metadata.FromOutgoingContext(outCtx3)
 | 
						|
			suite.Require().True(ok3)
 | 
						|
			suite.Assert().Equal([]string{"value3", "value4"}, mdOut3.Get("key"))
 | 
						|
			suite.Assert().Equal([]string{"127.0.0.2"}, mdOut3.Get("proxyfrom"))
 | 
						|
			suite.Assert().Equal([]string{"os:reader"}, mdOut3.Get("talos-role"))
 | 
						|
		},
 | 
						|
	)
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) TestAppendInfoUnary() {
 | 
						|
	reply := &common.DataResponse{
 | 
						|
		Messages: []*common.Data{
 | 
						|
			{
 | 
						|
				Bytes: []byte("foobar"),
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	resp, err := proto.Marshal(reply)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	newResp, err := suite.b.AppendInfo(false, resp)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	var newReply common.DataResponse
 | 
						|
	err = proto.Unmarshal(newResp, &newReply)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	suite.Assert().EqualValues([]byte("foobar"), newReply.Messages[0].Bytes)
 | 
						|
	suite.Assert().Equal(suite.b.String(), newReply.Messages[0].Metadata.Hostname)
 | 
						|
	suite.Assert().Empty(newReply.Messages[0].Metadata.Error)
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) TestAppendInfoStreaming() {
 | 
						|
	response := &common.Data{
 | 
						|
		Bytes: []byte("foobar"),
 | 
						|
	}
 | 
						|
 | 
						|
	resp, err := proto.Marshal(response)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	newResp, err := suite.b.AppendInfo(true, resp)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	var newResponse common.Data
 | 
						|
	err = proto.Unmarshal(newResp, &newResponse)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	suite.Assert().EqualValues([]byte("foobar"), newResponse.Bytes)
 | 
						|
	suite.Assert().Equal(suite.b.String(), newResponse.Metadata.Hostname)
 | 
						|
	suite.Assert().Empty(newResponse.Metadata.Error)
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) TestAppendInfoStreamingMetadata() {
 | 
						|
	// this tests the case when metadata field is appended twice
 | 
						|
	// to the message, but protobuf merges definitions
 | 
						|
	response := &common.Data{
 | 
						|
		Metadata: &common.Metadata{
 | 
						|
			Error: "something went wrong",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	resp, err := proto.Marshal(response)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	newResp, err := suite.b.AppendInfo(true, resp)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	var newResponse common.Data
 | 
						|
	err = proto.Unmarshal(newResp, &newResponse)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	suite.Assert().Nil(newResponse.Bytes)
 | 
						|
	suite.Assert().Equal(suite.b.String(), newResponse.Metadata.Hostname)
 | 
						|
	suite.Assert().Equal("something went wrong", newResponse.Metadata.Error)
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) TestBuildErrorUnary() {
 | 
						|
	resp, err := suite.b.BuildError(false, errors.New("some error"))
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	var reply common.DataResponse
 | 
						|
	err = proto.Unmarshal(resp, &reply)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	suite.Assert().Nil(reply.Messages[0].Bytes)
 | 
						|
	suite.Assert().Equal(suite.b.String(), reply.Messages[0].Metadata.Hostname)
 | 
						|
	suite.Assert().Equal("some error", reply.Messages[0].Metadata.Error)
 | 
						|
}
 | 
						|
 | 
						|
func (suite *APIDSuite) TestBuildErrorStreaming() {
 | 
						|
	resp, err := suite.b.BuildError(true, errors.New("some error"))
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	var response common.Data
 | 
						|
	err = proto.Unmarshal(resp, &response)
 | 
						|
	suite.Require().NoError(err)
 | 
						|
 | 
						|
	suite.Assert().Nil(response.Bytes)
 | 
						|
	suite.Assert().Equal(suite.b.String(), response.Metadata.Hostname)
 | 
						|
	suite.Assert().Equal("some error", response.Metadata.Error)
 | 
						|
}
 | 
						|
 | 
						|
func TestAPIDSuite(t *testing.T) {
 | 
						|
	suite.Run(t, new(APIDSuite))
 | 
						|
}
 | 
						|
 | 
						|
func TestAPIIdiosyncrasies(t *testing.T) {
 | 
						|
	for _, services := range []protoreflect.ServiceDescriptors{
 | 
						|
		common.File_common_common_proto.Services(),
 | 
						|
		cluster.File_cluster_cluster_proto.Services(),
 | 
						|
		inspect.File_inspect_inspect_proto.Services(),
 | 
						|
		machine.File_machine_machine_proto.Services(),
 | 
						|
		// security.File_security_security_proto.Services() is different
 | 
						|
		storage.File_storage_storage_proto.Services(),
 | 
						|
		time.File_time_time_proto.Services(),
 | 
						|
	} {
 | 
						|
		for i := range services.Len() {
 | 
						|
			service := services.Get(i)
 | 
						|
			methods := service.Methods()
 | 
						|
 | 
						|
			for j := range methods.Len() {
 | 
						|
				method := methods.Get(j)
 | 
						|
 | 
						|
				t.Run(
 | 
						|
					string(method.FullName()), func(t *testing.T) {
 | 
						|
						response := method.Output()
 | 
						|
						responseFields := response.Fields()
 | 
						|
 | 
						|
						if method.IsStreamingServer() {
 | 
						|
							metadata := responseFields.Get(0)
 | 
						|
							assert.Equal(t, "metadata", metadata.TextName())
 | 
						|
							assert.Equal(t, 1, int(metadata.Number()))
 | 
						|
						} else {
 | 
						|
							require.Equal(t, 1, responseFields.Len(), "unary responses should have exactly one field")
 | 
						|
 | 
						|
							messages := responseFields.Get(0)
 | 
						|
							assert.Equal(t, "messages", messages.TextName())
 | 
						|
							assert.Equal(t, 1, int(messages.Number()))
 | 
						|
 | 
						|
							reply := messages.Message()
 | 
						|
							replyFields := reply.Fields()
 | 
						|
							require.GreaterOrEqual(
 | 
						|
								t,
 | 
						|
								replyFields.Len(),
 | 
						|
								1,
 | 
						|
								"unary replies should have at least one field",
 | 
						|
							)
 | 
						|
 | 
						|
							metadata := replyFields.Get(0)
 | 
						|
							assert.Equal(t, "metadata", metadata.TextName())
 | 
						|
							assert.Equal(t, 1, int(metadata.Number()))
 | 
						|
						}
 | 
						|
					},
 | 
						|
				)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
//nolint:nakedret,gocyclo,errcheck,forcetypeassert
 | 
						|
func getOptions(t *testing.T, descriptor protoreflect.Descriptor) (deprecated bool, version string) {
 | 
						|
	switch opts := descriptor.Options().(type) {
 | 
						|
	case *descriptorpb.EnumOptions:
 | 
						|
		if opts != nil {
 | 
						|
			deprecated = pointer.SafeDeref(opts.Deprecated)
 | 
						|
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnum).(string)
 | 
						|
		}
 | 
						|
	case *descriptorpb.EnumValueOptions:
 | 
						|
		if opts != nil {
 | 
						|
			deprecated = pointer.SafeDeref(opts.Deprecated)
 | 
						|
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnumValue).(string)
 | 
						|
		}
 | 
						|
	case *descriptorpb.MessageOptions:
 | 
						|
		if opts != nil {
 | 
						|
			deprecated = pointer.SafeDeref(opts.Deprecated)
 | 
						|
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMessage).(string)
 | 
						|
		}
 | 
						|
	case *descriptorpb.FieldOptions:
 | 
						|
		if opts != nil {
 | 
						|
			deprecated = pointer.SafeDeref(opts.Deprecated)
 | 
						|
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedField).(string)
 | 
						|
		}
 | 
						|
	case *descriptorpb.ServiceOptions:
 | 
						|
		if opts != nil {
 | 
						|
			deprecated = pointer.SafeDeref(opts.Deprecated)
 | 
						|
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedService).(string)
 | 
						|
		}
 | 
						|
	case *descriptorpb.MethodOptions:
 | 
						|
		if opts != nil {
 | 
						|
			deprecated = pointer.SafeDeref(opts.Deprecated)
 | 
						|
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMethod).(string)
 | 
						|
		}
 | 
						|
 | 
						|
	default:
 | 
						|
		t.Fatalf("unhandled %T", opts)
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func testDeprecated(t *testing.T, descriptor protoreflect.Descriptor, currentVersion *config.VersionContract) {
 | 
						|
	deprecated, version := getOptions(t, descriptor)
 | 
						|
 | 
						|
	assert.Equal(
 | 
						|
		t, deprecated, version != "",
 | 
						|
		"%s: `deprecated` and `remove_deprecated_XXX_in` options should be used together", descriptor.FullName(),
 | 
						|
	)
 | 
						|
 | 
						|
	if !deprecated || version == "" {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	v, err := config.ParseContractFromVersion(version)
 | 
						|
	require.NoError(t, err, "%s", descriptor.FullName())
 | 
						|
 | 
						|
	assert.True(t, v.Greater(currentVersion), "%s should be removed in this version", descriptor.FullName())
 | 
						|
}
 | 
						|
 | 
						|
func testEnum(t *testing.T, enum protoreflect.EnumDescriptor, currentVersion *config.VersionContract) {
 | 
						|
	testDeprecated(t, enum, currentVersion)
 | 
						|
 | 
						|
	values := enum.Values()
 | 
						|
	for i := range values.Len() {
 | 
						|
		testDeprecated(t, values.Get(i), currentVersion)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func testMessage(t *testing.T, message protoreflect.MessageDescriptor, currentVersion *config.VersionContract) {
 | 
						|
	testDeprecated(t, message, currentVersion)
 | 
						|
 | 
						|
	fields := message.Fields()
 | 
						|
	for i := range fields.Len() {
 | 
						|
		testDeprecated(t, fields.Get(i), currentVersion)
 | 
						|
	}
 | 
						|
 | 
						|
	oneofs := message.Oneofs()
 | 
						|
	for i := range oneofs.Len() {
 | 
						|
		testDeprecated(t, oneofs.Get(i), currentVersion)
 | 
						|
	}
 | 
						|
 | 
						|
	enums := message.Enums()
 | 
						|
	for i := range enums.Len() {
 | 
						|
		testEnum(t, enums.Get(i), currentVersion)
 | 
						|
	}
 | 
						|
 | 
						|
	// test nested messages
 | 
						|
	messages := message.Messages()
 | 
						|
	for i := range messages.Len() {
 | 
						|
		testMessage(t, messages.Get(i), currentVersion)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDeprecatedAPIs(t *testing.T) {
 | 
						|
	currentVersion, err := config.ParseContractFromVersion(version.Tag)
 | 
						|
	require.NoError(t, err)
 | 
						|
 | 
						|
	for _, file := range []protoreflect.FileDescriptor{
 | 
						|
		common.File_common_common_proto,
 | 
						|
		cluster.File_cluster_cluster_proto,
 | 
						|
		inspect.File_inspect_inspect_proto,
 | 
						|
		machine.File_machine_machine_proto,
 | 
						|
		security.File_security_security_proto,
 | 
						|
		storage.File_storage_storage_proto,
 | 
						|
		time.File_time_time_proto,
 | 
						|
	} {
 | 
						|
		enums := file.Enums()
 | 
						|
		for i := range enums.Len() {
 | 
						|
			testEnum(t, enums.Get(i), currentVersion)
 | 
						|
		}
 | 
						|
 | 
						|
		messages := file.Messages()
 | 
						|
		for i := range messages.Len() {
 | 
						|
			testMessage(t, messages.Get(i), currentVersion)
 | 
						|
		}
 | 
						|
 | 
						|
		services := file.Services()
 | 
						|
		for i := range services.Len() {
 | 
						|
			service := services.Get(i)
 | 
						|
			testDeprecated(t, service, currentVersion)
 | 
						|
 | 
						|
			methods := service.Methods()
 | 
						|
			for j := range methods.Len() {
 | 
						|
				method := methods.Get(j)
 | 
						|
				testDeprecated(t, method, currentVersion)
 | 
						|
 | 
						|
				message := method.Input()
 | 
						|
				testMessage(t, message, currentVersion)
 | 
						|
 | 
						|
				message = method.Output()
 | 
						|
				testMessage(t, message, currentVersion)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 |