diff --git a/client/pkg/client/management/management.go b/client/pkg/client/management/management.go index 3191f45f..57d14bee 100644 --- a/client/pkg/client/management/management.go +++ b/client/pkg/client/management/management.go @@ -20,6 +20,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/siderolabs/omni/client/api/omni/management" ) @@ -400,3 +401,21 @@ func (client *ClusterClient) KubernetesSyncManifests(ctx context.Context, dryRun } } } + +// CreateJoinToken creates a join token and returns it's ID. +func (client *Client) CreateJoinToken(ctx context.Context, name string, ttl time.Duration) (string, error) { + var expirationTime *timestamppb.Timestamp + if ttl > 0 { + expirationTime = timestamppb.New(time.Now().Add(ttl)) + } + + resp, err := client.conn.CreateJoinToken(ctx, &management.CreateJoinTokenRequest{ + Name: name, + ExpirationTime: expirationTime, + }) + if err != nil { + return "", err + } + + return resp.Id, nil +} diff --git a/client/pkg/omnictl/jointoken.go b/client/pkg/omnictl/jointoken.go new file mode 100644 index 00000000..b6e6fa5b --- /dev/null +++ b/client/pkg/omnictl/jointoken.go @@ -0,0 +1,274 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package omnictl + +import ( + "context" + "fmt" + "os" + "text/tabwriter" + "time" + + "github.com/cosi-project/runtime/pkg/safe" + "github.com/spf13/cobra" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/siderolabs/omni/client/pkg/client" + "github.com/siderolabs/omni/client/pkg/omni/resources" + "github.com/siderolabs/omni/client/pkg/omni/resources/siderolink" + "github.com/siderolabs/omni/client/pkg/omnictl/internal/access" +) + +var ( + joinTokenCreateFlags struct { + role string + + useUserRole bool + ttl time.Duration + } + + joinTokenRenewFlags struct { + ttl time.Duration + } + + // joinTokenCmd represents the jointoken command. + joinTokenCmd = &cobra.Command{ + Use: "jointoken", + Aliases: []string{"jt"}, + Short: "Manage join tokens", + } + + joinTokenCreateCmd = &cobra.Command{ + Use: "create ", + Aliases: []string{"c"}, + Short: "Create a join token", + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + name := args[0] + + return access.WithClient(func(ctx context.Context, client *client.Client) error { + token, err := client.Management().CreateJoinToken(ctx, name, joinTokenCreateFlags.ttl) + if err != nil { + return err + } + + fmt.Println(token) + + return nil + }) + }, + } + + joinTokenRevokeCmd = &cobra.Command{ + Use: "revoke ", + Aliases: []string{"r"}, + Short: "Revoke a join token", + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + id := args[0] + + return access.WithClient(func(ctx context.Context, client *client.Client) error { + _, err := safe.StateUpdateWithConflicts( + ctx, + client.Omni().State(), + siderolink.NewJoinToken(resources.DefaultNamespace, id).Metadata(), + func(res *siderolink.JoinToken) error { + res.TypedSpec().Value.Revoked = true + + return nil + }, + ) + if err != nil { + return err + } + + fmt.Printf("token %q was revoked\n", id) + + return nil + }) + }, + } + + joinTokenUnrevokeCmd = &cobra.Command{ + Use: "unrevoke ", + Aliases: []string{"ur"}, + Short: "Unrevoke a join token", + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + id := args[0] + + return access.WithClient(func(ctx context.Context, client *client.Client) error { + _, err := safe.StateUpdateWithConflicts( + ctx, + client.Omni().State(), + siderolink.NewJoinToken(resources.DefaultNamespace, id).Metadata(), + func(res *siderolink.JoinToken) error { + res.TypedSpec().Value.Revoked = false + + return nil + }, + ) + if err != nil { + return err + } + + fmt.Printf("token %q was unrevoked\n", id) + + return nil + }) + }, + } + + joinTokenMakeDefaultCmd = &cobra.Command{ + Use: "make-default ", + Aliases: []string{"md"}, + Short: "Make the token default one", + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + id := args[0] + + return access.WithClient(func(ctx context.Context, client *client.Client) error { + _, err := safe.StateUpdateWithConflicts( + ctx, + client.Omni().State(), + siderolink.NewDefaultJoinToken().Metadata(), + func(res *siderolink.DefaultJoinToken) error { + res.TypedSpec().Value.TokenId = id + + return nil + }, + ) + if err != nil { + return err + } + + fmt.Printf("token %q is now default\n", id) + + return nil + }) + }, + } + + joinTokenRenewCmd = &cobra.Command{ + Use: "renew ", + Aliases: []string{"r"}, + Short: "Renew a join token", + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + id := args[0] + + if joinTokenRenewFlags.ttl == 0 { + return fmt.Errorf("ttl should be greater than 0") + } + + return access.WithClient(func(ctx context.Context, client *client.Client) error { + _, err := safe.StateUpdateWithConflicts( + ctx, + client.Omni().State(), + siderolink.NewJoinToken(resources.DefaultNamespace, id).Metadata(), + func(res *siderolink.JoinToken) error { + res.TypedSpec().Value.ExpirationTime = timestamppb.New(time.Now().Add(joinTokenRenewFlags.ttl)) + + return nil + }, + ) + if err != nil { + return err + } + + fmt.Printf("token %q was renewed, new ttl is %s\n", id, joinTokenRenewFlags.ttl) + + return nil + }) + }, + } + + joinTokenListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"l"}, + Short: "List join tokens", + Args: cobra.NoArgs, + RunE: func(*cobra.Command, []string) error { + return access.WithClient(func(ctx context.Context, client *client.Client) error { + joinTokens, err := safe.ReaderListAll[*siderolink.JoinTokenStatus](ctx, client.Omni().State()) + if err != nil { + return err + } + + writer := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + + fmt.Fprintf(writer, "ID\tNAME\tSTATE\tEXPIRATION\tUSE COUNT\tDEFAULT\n") //nolint:errcheck + + for token := range joinTokens.All() { + var isDefault string + + if token.TypedSpec().Value.IsDefault { + isDefault = "*" + } + + expirationTime := "never" + + if token.TypedSpec().Value.ExpirationTime != nil { + expirationTime = token.TypedSpec().Value.ExpirationTime.AsTime().String() + } + + if _, err = fmt.Fprintf( + writer, + "%s\t%s\t%s\t%s\t%d\t%s\n", + token.Metadata().ID(), + token.TypedSpec().Value.Name, + token.TypedSpec().Value.State.String(), + expirationTime, + token.TypedSpec().Value.UseCount, + isDefault, + ); err != nil { + return err + } + } + + return writer.Flush() + }) + }, + } + + joinTokenDeleteCmd = &cobra.Command{ + Use: "delete ", + Aliases: []string{"d"}, + Short: "Delete a join token", + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + id := args[0] + + return access.WithClient(func(ctx context.Context, client *client.Client) error { + err := client.Omni().State().TeardownAndDestroy(ctx, siderolink.NewJoinToken(resources.DefaultNamespace, id).Metadata()) + if err != nil { + return fmt.Errorf("failed to delete a join token: %w", err) + } + + fmt.Printf("deleted join token: %s\n", id) + + return nil + }) + }, + } +) + +func init() { + RootCmd.AddCommand(joinTokenCmd) + + joinTokenCmd.AddCommand(joinTokenCreateCmd) + joinTokenCmd.AddCommand(joinTokenListCmd) + joinTokenCmd.AddCommand(joinTokenDeleteCmd) + joinTokenCmd.AddCommand(joinTokenRevokeCmd) + joinTokenCmd.AddCommand(joinTokenMakeDefaultCmd) + joinTokenCmd.AddCommand(joinTokenUnrevokeCmd) + joinTokenCmd.AddCommand(joinTokenRenewCmd) + + joinTokenCreateCmd.Flags().DurationVarP(&joinTokenCreateFlags.ttl, "ttl", "t", 0, "TTL for the join token") + + joinTokenRenewCmd.Flags().DurationVarP(&joinTokenRenewFlags.ttl, "ttl", "t", 0, "TTL for the join token") + + joinTokenRenewCmd.MarkFlagRequired("ttl") //nolint:errcheck +} diff --git a/internal/backend/runtime/omni/export_test.go b/internal/backend/runtime/omni/export_test.go index 9fb402b5..ee2d904d 100644 --- a/internal/backend/runtime/omni/export_test.go +++ b/internal/backend/runtime/omni/export_test.go @@ -94,8 +94,8 @@ func JoinTokenValidationOptions(st state.State) []validated.StateOption { return joinTokenValidationOptions(st) } -func DefaultJoinTokenValidationOptions() []validated.StateOption { - return defaultJoinTokenValidationOptions() +func DefaultJoinTokenValidationOptions(st state.State) []validated.StateOption { + return defaultJoinTokenValidationOptions(st) } func ImportedClusterSecretValidationOptions(st state.State, clusterImportEnabled bool) []validated.StateOption { diff --git a/internal/backend/runtime/omni/omni.go b/internal/backend/runtime/omni/omni.go index 799638e6..e68318f7 100644 --- a/internal/backend/runtime/omni/omni.go +++ b/internal/backend/runtime/omni/omni.go @@ -400,7 +400,7 @@ func NewRuntime(talosClientFactory *talos.ClientFactory, dnsService *dns.Service infraMachineConfigValidationOptions(cachedState), nodeForceDestroyRequestValidationOptions(cachedState), joinTokenValidationOptions(cachedState), - defaultJoinTokenValidationOptions(), + defaultJoinTokenValidationOptions(cachedState), importedClusterSecretValidationOptions(cachedState, config.Config.Features.EnableClusterImport), ) diff --git a/internal/backend/runtime/omni/state_validation.go b/internal/backend/runtime/omni/state_validation.go index 263e2e0f..6e39f984 100644 --- a/internal/backend/runtime/omni/state_validation.go +++ b/internal/backend/runtime/omni/state_validation.go @@ -1176,9 +1176,26 @@ func joinTokenValidationOptions(st state.State) []validated.StateOption { } } -func defaultJoinTokenValidationOptions() []validated.StateOption { +func defaultJoinTokenValidationOptions(st state.State) []validated.StateOption { + validateToken := func(ctx context.Context, id string) error { + _, err := safe.ReaderGetByID[*siderolink.JoinToken](ctx, st, id) + if err != nil { + if state.IsNotFoundError(err) { + return fmt.Errorf("no token with id %q exists", id) + } + + return err + } + + return nil + } + return []validated.StateOption{ - validated.WithUpdateValidations(validated.NewUpdateValidationForType(func(_ context.Context, _, res *siderolink.DefaultJoinToken, _ ...state.UpdateOption) error { + validated.WithUpdateValidations(validated.NewUpdateValidationForType(func(ctx context.Context, _, res *siderolink.DefaultJoinToken, _ ...state.UpdateOption) error { + if err := validateToken(ctx, res.TypedSpec().Value.TokenId); err != nil { + return err + } + if res.Metadata().Phase() == resource.PhaseTearingDown { if res.Metadata().ID() != siderolink.DefaultJoinTokenID { return nil @@ -1190,7 +1207,11 @@ func defaultJoinTokenValidationOptions() []validated.StateOption { return nil })), validated.WithDestroyValidations(validated.NewDestroyValidationForType( - func(_ context.Context, _ resource.Pointer, res *siderolink.DefaultJoinToken, _ ...state.DestroyOption) error { + func(ctx context.Context, _ resource.Pointer, res *siderolink.DefaultJoinToken, _ ...state.DestroyOption) error { + if err := validateToken(ctx, res.TypedSpec().Value.TokenId); err != nil { + return err + } + if res.Metadata().ID() != siderolink.DefaultJoinTokenID { return nil } diff --git a/internal/backend/runtime/omni/state_validation_test.go b/internal/backend/runtime/omni/state_validation_test.go index c076e05a..277203a6 100644 --- a/internal/backend/runtime/omni/state_validation_test.go +++ b/internal/backend/runtime/omni/state_validation_test.go @@ -1474,11 +1474,19 @@ func TestDefaultJoinTokenValidation(t *testing.T) { t.Cleanup(cancel) innerSt := state.WrapCore(namespaced.NewState(inmem.Build)) - st := validated.NewState(innerSt, omni.DefaultJoinTokenValidationOptions()...) + st := validated.NewState(innerSt, omni.DefaultJoinTokenValidationOptions(innerSt)...) wrappedState := state.WrapCore(st) defaultToken := siderolink.NewDefaultJoinToken() + joinToken := siderolink.NewJoinToken(resources.DefaultNamespace, "mm") + + require.NoError(t, st.Create(ctx, joinToken)) + + joinToken = siderolink.NewJoinToken(resources.DefaultNamespace, "mmmm") + + require.NoError(t, st.Create(ctx, joinToken)) + defaultToken.TypedSpec().Value.TokenId = "mm" require.NoError(t, wrappedState.Create(ctx, defaultToken)) @@ -1491,6 +1499,14 @@ func TestDefaultJoinTokenValidation(t *testing.T) { assert.NoError(t, err) + _, err = safe.StateUpdateWithConflicts(ctx, wrappedState, defaultToken.Metadata(), func(token *siderolink.DefaultJoinToken) error { + token.TypedSpec().Value.TokenId = "mmmmmm" + + return nil + }) + + assert.Error(t, err) + _, err = wrappedState.Teardown(ctx, defaultToken.Metadata()) assert.ErrorContains(t, err, "destroying") diff --git a/internal/integration/auth_test.go b/internal/integration/auth_test.go index 2cdb30fb..f37dd83a 100644 --- a/internal/integration/auth_test.go +++ b/internal/integration/auth_test.go @@ -638,8 +638,9 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client joinToken := siderolink.NewJoinToken(resources.DefaultNamespace, uuid.New().String()) - defaultJoinToken := siderolink.NewDefaultJoinToken() - *defaultJoinToken.Metadata() = resource.NewMetadata(resources.DefaultNamespace, siderolink.DefaultJoinTokenType, uuid.New().String(), resource.VersionUndefined) + defaultJoinToken, err := safe.StateGetByID[*siderolink.DefaultJoinToken](rootCtx, rootCli.Omni().State(), siderolink.DefaultJoinTokenID) + + require.NoError(t, err) importedClusterSecret := omni.NewImportedClusterSecrets(resources.DefaultNamespace, cluster.Metadata().ID()) @@ -1222,11 +1223,12 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client default: if accessErr != nil { toleratedErrors := map[string]string{ - "NotFoundError": "doesn't exist", - "ValidationError": "failed to validate", - "UnsupportedError": "unsupported resource type", - "AlreadyExists(AccessPolicy)": "resource AccessPolicies.omni.sidero.dev(default/access-policy@undefined) already exists", - "VersionConflict(AccessPolicy)": "failed to update: resource AccessPolicies.omni.sidero.dev(default/access-policy@1) update conflict: expected version", + "NotFoundError": "doesn't exist", + "ValidationError": "failed to validate", + "UnsupportedError": "unsupported resource type", + "AlreadyExists(DefaultJoinToken)": "resource DefaultJoinTokens.omni.sidero.dev(default/default@1) already exists", + "AlreadyExists(AccessPolicy)": "resource AccessPolicies.omni.sidero.dev(default/access-policy@undefined) already exists", + "VersionConflict(AccessPolicy)": "failed to update: resource AccessPolicies.omni.sidero.dev(default/access-policy@1) update conflict: expected version", } isExpectedError := false