From ef3e3bc1cc4e3fab31ec5fa852e5d1b61e2d22a2 Mon Sep 17 00:00:00 2001 From: Utku Ozdemir Date: Wed, 11 Feb 2026 13:12:01 +0100 Subject: [PATCH] test: use automation sa directly in integration tests Instead of doing the fake user auth flow in the integration tests via the `clientconfig` package, use the automation service account directly. Remove all other usages of that package as well, and drop it completely. The package predates the initial service account token feature of Omni; its purpose was to authenticate to the Omni API in the integration tests. We have the automation key now, so we don't need that anymore. Signed-off-by: Utku Ozdemir --- cmd/make-cookies/main.go | 34 +- internal/integration/auth_test.go | 142 ++++++-- internal/integration/common_test.go | 9 +- internal/integration/image_test.go | 20 +- internal/integration/integration_test.go | 68 +++- internal/integration/suites_test.go | 24 +- .../workloadproxy/workloadproxy.go | 40 ++- internal/pkg/clientconfig/clientconfig.go | 302 ------------------ .../pkg/clientconfig/register_key_debug.go | 43 --- .../pkg/clientconfig/register_key_no_debug.go | 20 -- 10 files changed, 235 insertions(+), 467 deletions(-) delete mode 100644 internal/pkg/clientconfig/clientconfig.go delete mode 100644 internal/pkg/clientconfig/register_key_debug.go delete mode 100644 internal/pkg/clientconfig/register_key_no_debug.go diff --git a/cmd/make-cookies/main.go b/cmd/make-cookies/main.go index 866dcfee..389214f6 100644 --- a/cmd/make-cookies/main.go +++ b/cmd/make-cookies/main.go @@ -7,49 +7,43 @@ package main import ( - "context" + "encoding/base64" "fmt" + "log" "net/http" - "os" + + "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/siderolabs/omni/internal/backend/services/workloadproxy" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) func main() { if err := app(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + log.Fatalf("failed to create cookies: %v", err) } } func app() 
error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // don't forget to build this with the -tags=sidero.debug - if len(os.Args) != 2 { - return fmt.Errorf("usage: %s ", os.Args[0]) + _, saKey := serviceaccount.GetFromEnv() + if saKey == "" { + return fmt.Errorf("no service account key found in environment variables") } - cfg := clientconfig.New(os.Args[1], os.Getenv("OMNI_SERVICE_ACCOUNT_KEY")) - defer cfg.Close() //nolint:errcheck - - client, err := cfg.GetClient(ctx) + sa, err := serviceaccount.Decode(saKey) if err != nil { - return fmt.Errorf("error getting client: %w", err) + return fmt.Errorf("error decoding service account key: %w", err) } - defer client.Close() //nolint:errcheck + keyID := sa.Key.Fingerprint() - keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, client) + signedIDBytes, err := sa.Key.Sign([]byte(keyID)) if err != nil { - return fmt.Errorf("error registering key: %w", err) + return fmt.Errorf("error signing key ID: %w", err) } cookies := []*http.Cookie{ {Name: workloadproxy.PublicKeyIDCookie, Value: keyID}, - {Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: keyIDSignatureBase64}, + {Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: base64.StdEncoding.EncodeToString(signedIDBytes)}, } for _, cookie := range cookies { diff --git a/internal/integration/auth_test.go b/internal/integration/auth_test.go index 921243d0..7998f34e 100644 --- a/internal/integration/auth_test.go +++ b/internal/integration/auth_test.go @@ -10,6 +10,7 @@ package integration_test import ( "bytes" "context" + "crypto/md5" _ "embed" "encoding/base64" "encoding/json" @@ -23,6 +24,7 @@ import ( "slices" "strconv" "strings" + "sync" "testing" "time" @@ -36,7 +38,6 @@ import ( "github.com/google/uuid" "github.com/siderolabs/gen/maps" "github.com/siderolabs/gen/xslices" - authcli "github.com/siderolabs/go-api-signature/pkg/client/auth" 
"github.com/siderolabs/go-api-signature/pkg/client/interceptor" "github.com/siderolabs/go-api-signature/pkg/message" "github.com/siderolabs/go-api-signature/pkg/pgp" @@ -73,10 +74,114 @@ import ( "github.com/siderolabs/omni/internal/backend/runtime/omni/validated" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" - "github.com/siderolabs/omni/internal/pkg/clientconfig" "github.com/siderolabs/omni/internal/pkg/grpcutil" ) +// testClientFactory creates test clients with specific roles for authorization testing. +// It uses the root client (automation SA) to create new service accounts with the specified role, +// then returns a client authenticated as that SA. Clients are cached by role. +type testClientFactory struct { + endpoint string + serviceAccountKey string + rootCli *client.Client + + mu sync.Mutex + clients map[role.Role]*client.Client +} + +func newTestClientFactory(endpoint string, rootCli *client.Client) *testClientFactory { + return &testClientFactory{ + endpoint: endpoint, + rootCli: rootCli, + clients: make(map[role.Role]*client.Client), + } +} + +func (f *testClientFactory) getClient(ctx context.Context, r role.Role) (*client.Client, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if cli, ok := f.clients[r]; ok { + return cli, nil + } + + cli, err := f.createClientForRole(ctx, r) + if err != nil { + return nil, err + } + + f.clients[r] = cli + + return cli, nil +} + +func (f *testClientFactory) createClientForRole(ctx context.Context, r role.Role) (*client.Client, error) { + name := fmt.Sprintf("%x", md5.Sum([]byte(string(r)))) + + comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) + + suffix := access.ServiceAccountNameSuffix + if r == role.InfraProvider { + suffix = access.InfraProviderServiceAccountNameSuffix + } + + serviceAccountEmail := name + suffix + + key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) + if err != nil { + return 
nil, err + } + + if r == role.InfraProvider { + name = access.InfraProviderServiceAccountPrefix + name + } + + armoredPublicKey, err := key.ArmorPublic() + if err != nil { + return nil, err + } + + serviceAccounts, err := f.rootCli.Management().ListServiceAccounts(ctx) + if err != nil { + return nil, err + } + + if slices.IndexFunc(serviceAccounts, func(account *management.ListServiceAccountsResponse_ServiceAccount) bool { + return account.Name == name + }) != -1 { + if err = f.rootCli.Management().DestroyServiceAccount(ctx, name); err != nil { + return nil, err + } + } + + _, err = f.rootCli.Management().CreateServiceAccount(ctx, name, armoredPublicKey, string(r), false) + if err != nil { + return nil, err + } + + encodedKey, err := serviceaccount.Encode(name, key) + if err != nil { + return nil, err + } + + return client.New(f.endpoint, client.WithServiceAccount(encodedKey)) +} + +func (f *testClientFactory) close() error { + f.mu.Lock() + defer f.mu.Unlock() + + var errs []error + + for _, cli := range f.clients { + if err := cli.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + // AssertAnonymousAuthentication tests the authentication without any credentials. func AssertAnonymousAuthentication(testCtx context.Context, client *client.Client) TestFunc { return func(t *testing.T) { @@ -336,7 +441,7 @@ type apiAuthzTestCase struct { // AssertAPIAuthz tests the authorization checks of the API endpoints. 
// //nolint:gocognit,gocyclo,cyclop,maintidx -func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig, clusterName string) TestFunc { +func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientFactory *testClientFactory, clusterName string) TestFunc { rootCtx = metadata.NewOutgoingContext(rootCtx, metadata.Pairs(grpcutil.LogLevelOverrideMetadataKey, zapcore.PanicLevel.String())) assertSuccess := func(t *testing.T, err error) { @@ -555,13 +660,10 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi for _, tc := range testCases { // test each test case without signature t.Run(fmt.Sprintf("%s-no-signature", tc.namePrefix), func(t *testing.T) { - scopedClient, testErr := clientConfig.GetClient(rootCtx) - require.NoError(t, testErr) - // skip signing the request ctx := context.WithValue(rootCtx, interceptor.SkipInterceptorContextKey{}, struct{}{}) - testErr = tc.fn(ctx, scopedClient) + testErr := tc.fn(ctx, rootCli) // public resources will either succeed or fail with a permission denied if they are read-only resources if tc.isPublic { @@ -586,11 +688,7 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi // test with the role which should succeed t.Run(fmt.Sprintf("%s-success", tc.namePrefix), func(t *testing.T) { - scopedClient, testErr := clientConfig.GetClient( - rootCtx, - authcli.WithRole(string(tc.requiredRole)), - authcli.WithSkipUserRole(true), - ) + scopedClient, testErr := clientFactory.getClient(rootCtx, tc.requiredRole) require.NoError(t, testErr) assertCurrentUserRole(rootCtx, t, scopedClient.Omni().State(), tc.requiredRole) @@ -613,10 +711,7 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi require.NoError(t, err) t.Run(fmt.Sprintf("%s-failure", tc.namePrefix), func(t *testing.T) { - scopedClient, testErr := clientConfig.GetClient( - rootCtx, - authcli.WithRole(string(failureRole)), - 
authcli.WithSkipUserRole(true)) + scopedClient, testErr := clientFactory.getClient(rootCtx, failureRole) require.NoError(t, testErr) assertCurrentUserRole(rootCtx, t, scopedClient.Omni().State(), failureRole) @@ -641,7 +736,7 @@ type resourceAuthzTestCase struct { // AssertResourceAuthz tests the authorization checks of the resources (state). // //nolint:gocognit,gocyclo,cyclop,maintidx -func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig) TestFunc { +func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, clientFactory *testClientFactory) TestFunc { rootCtx = metadata.NewOutgoingContext(rootCtx, metadata.Pairs(grpcutil.LogLevelOverrideMetadataKey, zapcore.PanicLevel.String())) return func(t *testing.T) { @@ -1247,11 +1342,7 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client delete(untestedResourceTypes, tc.resource.Metadata().Type()) t.Run(name, func(t *testing.T) { - scopedCli, testErr := clientConfig.GetClient( - rootCtx, - authcli.WithRole(string(testRole)), - authcli.WithSkipUserRole(true), - ) + scopedCli, testErr := clientFactory.getClient(rootCtx, testRole) require.NoError(t, testErr) // ensure that scopedCli is operating with the correct role @@ -1377,12 +1468,13 @@ var ( const grpcMetadataPrefix = "Grpc-Metadata-" -func AssertFrontendResourceAPI(ctx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig, httpEndpoint, clusterName string) TestFunc { +func AssertFrontendResourceAPI(ctx context.Context, rootCli *client.Client, serviceAccountKey, httpEndpoint, clusterName string) TestFunc { return func(t *testing.T) { - key, err := clientConfig.GetKey(ctx) + sa, err := serviceaccount.Decode(serviceAccountKey) require.NoError(t, err) - email := clientconfig.DefaultServiceAccount + key := sa.Key + email := sa.Name + access.ServiceAccountNameSuffix // do the same flow for the signature as in the JS code signRequest := 
func(request *http.Request) error { diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index e9cbb006..7864da26 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "net" - "net/http" "strings" "testing" "time" @@ -29,7 +28,6 @@ import ( "github.com/siderolabs/omni/client/pkg/client" "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/internal/backend/runtime/talos" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) func resourceDetails(res resource.Resource) string { @@ -266,9 +264,6 @@ type WipeAMachineFunc func(ctx context.Context, uuid string) error // FreezeAMachineFunc is a function to freeze a machine by UUID. type FreezeAMachineFunc func(ctx context.Context, uuid string) error -// HTTPRequestSignerFunc is function to sign the HTTP request. -type HTTPRequestSignerFunc func(ctx context.Context, req *http.Request) error - // Options for the test runner. // //nolint:govet @@ -340,8 +335,8 @@ type MachineProviderConfig struct { // TestOptions constains all common data that might be required to run the tests. 
type TestOptions struct { Options - omniClient *client.Client - clientConfig *clientconfig.ClientConfig + omniClient *client.Client + serviceAccountKey string machineSemaphore *semaphore.Weighted } diff --git a/internal/integration/image_test.go b/internal/integration/image_test.go index 77f3695a..b3a99dab 100644 --- a/internal/integration/image_test.go +++ b/internal/integration/image_test.go @@ -16,21 +16,26 @@ import ( "time" "github.com/cosi-project/runtime/pkg/safe" + "github.com/siderolabs/go-api-signature/pkg/message" + "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/stretchr/testify/require" "github.com/siderolabs/omni/client/api/omni/management" - "github.com/siderolabs/omni/client/pkg/client" + "github.com/siderolabs/omni/client/pkg/access" clientconsts "github.com/siderolabs/omni/client/pkg/constants" "github.com/siderolabs/omni/client/pkg/omni/resources/omni" ) // AssertSomeImagesAreDownloadable verifies generated image download. -func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Client, signer HTTPRequestSignerFunc, httpEndpoint string) TestFunc { - st := client.Omni().State() +func AssertSomeImagesAreDownloadable(testCtx context.Context, options *TestOptions) TestFunc { + st := options.omniClient.Omni().State() return func(t *testing.T) { t.Parallel() + sa, err := serviceaccount.Decode(options.serviceAccountKey) + require.NoError(t, err) + media, err := safe.StateListAll[*omni.InstallationMedia](testCtx, st) require.NoError(t, err) @@ -60,10 +65,10 @@ func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Cli ctx, cancel := context.WithTimeout(testCtx, time.Minute*5) defer cancel() - u, err := url.Parse(httpEndpoint) + u, err := url.Parse(options.HTTPEndpoint) require.NoError(t, err) - schematic, err := client.Management().CreateSchematic(ctx, &management.CreateSchematicRequest{ + schematic, err := options.omniClient.Management().CreateSchematic(ctx, 
&management.CreateSchematicRequest{ MediaId: image.Metadata().ID(), TalosVersion: clientconsts.DefaultTalosVersion, }) @@ -75,7 +80,10 @@ func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Cli req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) require.NoError(t, err) - require.NoError(t, signer(ctx, req)) + msg, err := message.NewHTTP(req) + require.NoError(t, err) + + require.NoError(t, msg.Sign(sa.Name+access.ServiceAccountNameSuffix, sa.Key)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index b305d009..9fdb2c53 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -16,14 +16,16 @@ import ( "net/url" "os" "os/exec" + "runtime" "testing" "time" "github.com/cosi-project/runtime/pkg/resource/rtestutils" "github.com/cosi-project/runtime/pkg/safe" - "github.com/cosi-project/runtime/pkg/state" + cosistate "github.com/cosi-project/runtime/pkg/state" "github.com/mattn/go-shellwords" "github.com/prometheus/client_golang/prometheus" + "github.com/siderolabs/go-api-signature/pkg/pgp" "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,14 +36,19 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" + "github.com/siderolabs/omni/client/pkg/access" + "github.com/siderolabs/omni/client/pkg/client" clientconsts "github.com/siderolabs/omni/client/pkg/constants" + authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" omnires "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/client/pkg/omni/resources/siderolink" _ "github.com/siderolabs/omni/cmd/acompat" // this package should always be imported first for init->set env to work "github.com/siderolabs/omni/cmd/omni/pkg/app" 
"github.com/siderolabs/omni/internal/backend/runtime/omni" + "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" - "github.com/siderolabs/omni/internal/pkg/clientconfig" + "github.com/siderolabs/omni/internal/pkg/auth/role" + omnisa "github.com/siderolabs/omni/internal/pkg/auth/serviceaccount" "github.com/siderolabs/omni/internal/pkg/config" "github.com/siderolabs/omni/internal/pkg/constants" ) @@ -197,13 +204,7 @@ func TestIntegration(t *testing.T) { // Talos API calls try to use user auth if the service account var is not set os.Setenv(serviceaccount.OmniServiceAccountKeyEnvVar, serviceAccount) - clientConfig := clientconfig.New(omniEndpoint, serviceAccount) - - t.Cleanup(func() { - clientConfig.Close() //nolint:errcheck - }) - - rootClient, err := clientConfig.GetClient(t.Context()) + rootClient, err := client.New(omniEndpoint, client.WithServiceAccount(serviceAccount)) require.NoError(t, err) t.Cleanup(func() { @@ -211,10 +212,10 @@ func TestIntegration(t *testing.T) { }) testOptions := &TestOptions{ - omniClient: rootClient, - Options: options, - machineSemaphore: semaphore.NewWeighted(int64(options.ExpectedMachines)), - clientConfig: clientConfig, + omniClient: rootClient, + Options: options, + machineSemaphore: semaphore.NewWeighted(int64(options.ExpectedMachines)), + serviceAccountKey: serviceAccount, } preRunHooks(t, testOptions) @@ -390,7 +391,7 @@ func postRunHooks(t *testing.T, options *TestOptions) { } } -func cleanupLinksFunc(ctx context.Context, st state.State) error { +func cleanupLinksFunc(ctx context.Context, st cosistate.State) error { links, err := safe.ReaderListAll[*siderolink.Link](ctx, st) if err != nil { return err @@ -403,7 +404,7 @@ func cleanupLinksFunc(ctx context.Context, st state.State) error { return links.ForEachErr(func(r *siderolink.Link) error { err := st.TeardownAndDestroy(ctx, r.Metadata()) - if err != nil && !state.IsNotFoundError(err) { + if err != nil && 
!cosistate.IsNotFoundError(err) { return err } @@ -503,7 +504,7 @@ func runOmni(t *testing.T) (string, error) { rtestutils.AssertResources(ctx, t, state.Default(), []string{talosVersion}, func(*omnires.TalosVersion, *assert.Assertions) {}) - sa, err := clientconfig.CreateServiceAccount(omniCtx, "root", state.Default()) + sa, err := createBootstrapServiceAccount(omniCtx, "root", state.Default()) if err != nil { return "", err } @@ -513,3 +514,38 @@ func runOmni(t *testing.T) (string, error) { return sa, nil } + +func createBootstrapServiceAccount(ctx context.Context, name string, st cosistate.State) (string, error) { + comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) + + serviceAccountEmail := name + access.ServiceAccountNameSuffix + + key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) + if err != nil { + return "", err + } + + armoredPublicKey, err := key.ArmorPublic() + if err != nil { + return "", err + } + + identity, err := safe.ReaderGetByID[*authres.Identity](ctx, st, serviceAccountEmail) + if err != nil && !cosistate.IsNotFoundError(err) { + return "", err + } + + if identity != nil { + err = omnisa.Destroy(ctx, st, name) + if err != nil { + return "", err + } + } + + _, err = omnisa.Create(ctx, st, name, string(role.Admin), false, []byte(armoredPublicKey)) + if err != nil { + return "", err + } + + return serviceaccount.Encode(name, key) +} diff --git a/internal/integration/suites_test.go b/internal/integration/suites_test.go index 6a234367..f2f4d7b2 100644 --- a/internal/integration/suites_test.go +++ b/internal/integration/suites_test.go @@ -8,8 +8,6 @@ package integration_test import ( - "context" - "net/http" "testing" "time" @@ -19,7 +17,6 @@ import ( "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/internal/backend/extensions" "github.com/siderolabs/omni/internal/integration/workloadproxy" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) 
type assertClusterReadyOptions struct { @@ -128,9 +125,7 @@ Generate various Talos images with Omni and try to download them.`) t.Run( "TalosImagesShouldBeDownloadable", - AssertSomeImagesAreDownloadable(t.Context(), options.omniClient, func(ctx context.Context, req *http.Request) error { - return clientconfig.SignHTTPRequest(ctx, options.omniClient, req) - }, options.HTTPEndpoint), + AssertSomeImagesAreDownloadable(t.Context(), options), ) } } @@ -1159,9 +1154,14 @@ Test authorization on accessing Omni API, some tests run without a cluster, some AssertServiceAccountAPIFlow(t.Context(), options.omniClient), ) + clientFactory := newTestClientFactory(omniEndpoint, options.omniClient) + t.Cleanup(func() { + clientFactory.close() //nolint:errcheck + }) + t.Run( "ResourceAuthzShouldWork", - AssertResourceAuthz(t.Context(), options.omniClient, options.clientConfig), + AssertResourceAuthz(t.Context(), options.omniClient, clientFactory), ) t.Run( @@ -1193,12 +1193,12 @@ Test authorization on accessing Omni API, some tests run without a cluster, some t.Run( "APIAuthorizationShouldBeTested", - AssertAPIAuthz(t.Context(), options.omniClient, options.clientConfig, clusterName), + AssertAPIAuthz(t.Context(), options.omniClient, clientFactory, clusterName), ) t.Run( "FrontendAPIShouldBeTested", - AssertFrontendResourceAPI(t.Context(), options.omniClient, options.clientConfig, options.HTTPEndpoint, clusterName), + AssertFrontendResourceAPI(t.Context(), options.omniClient, options.serviceAccountKey, options.HTTPEndpoint, clusterName), ) t.Run( @@ -1278,7 +1278,7 @@ Test workload service proxying feature`) parentCtx := t.Context() t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) { - workloadproxy.Test(parentCtx, t, omniClient, cluster1, cluster2) + workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, cluster1, cluster2) }) t.Run("ClusterShouldBeDestroyed-"+cluster1, AssertDestroyCluster(t.Context(), options.omniClient.Omni().State(), cluster1, false, 
false)) @@ -1457,7 +1457,7 @@ Test Omni upgrades, the first half that runs on the previous Omni version parentCtx := t.Context() t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) { - workloadproxy.Test(parentCtx, t, omniClient, clusterName) + workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, clusterName) }) t.Run("SaveClusterSnapshot", SaveClusterSnapshot(t.Context(), omniClient, clusterName)) @@ -1494,7 +1494,7 @@ Test Omni upgrades, the second half that runs on the current Omni version t.Run("AssertMachinesNotRebootedConfigUnchanged", AssertClusterSnapshot(t.Context(), omniClient, clusterName)) t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) { - workloadproxy.Test(parentCtx, t, omniClient, clusterName) + workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, clusterName) }) t.Run( diff --git a/internal/integration/workloadproxy/workloadproxy.go b/internal/integration/workloadproxy/workloadproxy.go index 5c9ffb66..fd9b2706 100644 --- a/internal/integration/workloadproxy/workloadproxy.go +++ b/internal/integration/workloadproxy/workloadproxy.go @@ -31,6 +31,8 @@ import ( "github.com/hashicorp/go-cleanhttp" "github.com/siderolabs/gen/maps" "github.com/siderolabs/gen/xslices" + "github.com/siderolabs/go-api-signature/pkg/pgp" + "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/siderolabs/go-pointer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,7 +50,6 @@ import ( "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/internal/backend/services/workloadproxy" "github.com/siderolabs/omni/internal/integration/kubernetes" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) type serviceContext struct { @@ -75,7 +76,7 @@ var sideroLabsIconSVG []byte // Test tests the exposed services functionality in Omni. 
// //nolint:prealloc -func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterIDs ...string) { +func Test(ctx context.Context, t *testing.T, omniClient *client.Client, serviceAccountKey string, clusterIDs ...string) { ctx, cancel := context.WithTimeout(ctx, 20*time.Minute) t.Cleanup(cancel) @@ -83,6 +84,9 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI require.Fail(t, "no cluster IDs provided for the test, please provide at least one cluster ID") } + sa, err := serviceaccount.Decode(serviceAccountKey) + require.NoError(t, err) + ctx = kubernetes.WrapContext(ctx, t) logger := zaptest.NewLogger(t) @@ -118,7 +122,7 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI allExposedServices[i], allExposedServices[j] = allExposedServices[j], allExposedServices[i] }) - testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK) + testAccess(ctx, t, logger, sa.Key, allExposedServices, http.StatusOK) inaccessibleExposedServices := make([]*omni.ExposedService, 0, len(allExposedServices)) @@ -132,7 +136,7 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI } } - testAccess(ctx, t, logger, omniClient, inaccessibleExposedServices, http.StatusBadGateway) + testAccess(ctx, t, logger, sa.Key, inaccessibleExposedServices, http.StatusBadGateway) for _, deployment := range deploymentsToScaleDown { logger.Info("scale deployment back up", zap.String("deployment", deployment.deployment.Name), zap.String("clusterID", deployment.cluster.clusterID)) @@ -140,13 +144,13 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI kubernetes.ScaleDeployment(ctx, t, deployment.cluster.kubeClient, deployment.deployment.Namespace, deployment.deployment.Name, 1) } - testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK) - testToggleFeature(ctx, t, logger, omniClient, clusters[0]) - testToggleKubernetesServiceAnnotation(ctx, 
t, logger, omniClient, allServices[:len(allServices)/2]) + testAccess(ctx, t, logger, sa.Key, allExposedServices, http.StatusOK) + testToggleFeature(ctx, t, logger, omniClient, sa.Key, clusters[0]) + testToggleKubernetesServiceAnnotation(ctx, t, logger, omniClient, sa.Key, allServices[:len(allServices)/2]) } // testToggleFeature tests toggling off/on the workload proxy feature for a cluster. -func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, cluster clusterContext) { +func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, saKey *pgp.Key, cluster clusterContext) { logger.Info("test turning off and on the feature for the cluster", zap.String("clusterID", cluster.clusterID)) setFeatureToggle := func(enabled bool) { @@ -172,14 +176,14 @@ func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, om services = services[:4] } - testAccess(ctx, t, logger, omniClient, services[:4], http.StatusNotFound) + testAccess(ctx, t, logger, saKey, services[:4], http.StatusNotFound) setFeatureToggle(true) - testAccess(ctx, t, logger, omniClient, services[:4], http.StatusOK) + testAccess(ctx, t, logger, saKey, services[:4], http.StatusOK) } -func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, services []serviceContext) { +func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, saKey *pgp.Key, services []serviceContext) { logger.Info("test toggling Kubernetes service annotation for exposed services", zap.Int("numServices", len(services))) for _, service := range services { @@ -195,7 +199,7 @@ func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, lo exposedServices := xslices.Map(services, func(svc serviceContext) *omni.ExposedService { return svc.res }) - testAccess(ctx, t, logger, omniClient, 
exposedServices, http.StatusNotFound) + testAccess(ctx, t, logger, saKey, exposedServices, http.StatusNotFound) for _, service := range services { kubernetes.UpdateService(ctx, t, service.deployment.cluster.kubeClient, service.svc.Namespace, service.svc.Name, func(svc *corev1.Service) { @@ -223,7 +227,7 @@ func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, lo updatedServices := maps.Values(updatedServicesMap) - testAccess(ctx, t, logger, omniClient, updatedServices, http.StatusOK) + testAccess(ctx, t, logger, saKey, updatedServices, http.StatusOK) } func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, clusterID string) clusterContext { @@ -299,11 +303,15 @@ func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omni return cluster } -func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, exposedServices []*omni.ExposedService, expectedStatusCode int) { - keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, omniClient) +func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, saKey *pgp.Key, exposedServices []*omni.ExposedService, expectedStatusCode int) { + keyID := saKey.Fingerprint() + + signedIDBytes, err := saKey.Sign([]byte(keyID)) require.NoError(t, err) - logger.Debug("registered public key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64)) + keyIDSignatureBase64 := base64.StdEncoding.EncodeToString(signedIDBytes) + + logger.Debug("using SA key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64)) cookies := []*http.Cookie{ {Name: workloadproxy.PublicKeyIDCookie, Value: keyID}, diff --git a/internal/pkg/clientconfig/clientconfig.go b/internal/pkg/clientconfig/clientconfig.go deleted file mode 100644 index 57b7dec6..00000000 --- 
a/internal/pkg/clientconfig/clientconfig.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright (c) 2026 Sidero Labs, Inc. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. - -// Package clientconfig holds the configuration for the test client for Omni API. -package clientconfig - -import ( - "context" - "crypto/md5" - "encoding/base64" - "fmt" - "net/http" - "runtime" - "slices" - "time" - - "github.com/cosi-project/runtime/pkg/safe" - "github.com/cosi-project/runtime/pkg/state" - "github.com/hashicorp/go-multierror" - "github.com/siderolabs/gen/containers" - authpb "github.com/siderolabs/go-api-signature/api/auth" - authcli "github.com/siderolabs/go-api-signature/pkg/client/auth" - "github.com/siderolabs/go-api-signature/pkg/message" - "github.com/siderolabs/go-api-signature/pkg/pgp" - "github.com/siderolabs/go-api-signature/pkg/serviceaccount" - - "github.com/siderolabs/omni/client/api/omni/management" - "github.com/siderolabs/omni/client/pkg/access" - "github.com/siderolabs/omni/client/pkg/client" - authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" - "github.com/siderolabs/omni/internal/pkg/auth" - "github.com/siderolabs/omni/internal/pkg/auth/role" - omnisa "github.com/siderolabs/omni/internal/pkg/auth/serviceaccount" -) - -const ( - DefaultServiceAccount = "integration" + access.ServiceAccountNameSuffix - defaultEmail = "test-user@siderolabs.com" -) - -type clientCacheKey struct { - role string - email string - skipUserRole bool -} - -type clientCacheValue struct { - client *client.Client - key *pgp.Key - err error -} - -// ClientConfig is a test client. -type ClientConfig struct { - endpoint string - serviceAccountKey string - clientCache containers.ConcurrentMap[clientCacheKey, clientCacheValue] -} - -// New creates a new test client config. 
-func New(endpoint, serviceAccountKey string) *ClientConfig { - return &ClientConfig{ - endpoint: endpoint, - serviceAccountKey: serviceAccountKey, - } -} - -// GetClient returns a test client for the default test email. -// -// Clients are cached by their configuration, so if a client with the -// given configuration was created before, the cached one will be returned. -func (t *ClientConfig) GetClient(ctx context.Context, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*client.Client, error) { - return t.GetClientForEmail(ctx, DefaultServiceAccount, publicKeyOpts...) -} - -// GetClientForEmail returns a test client for the given email. -// -// Clients are cached by their configuration, so if a client with the -// given configuration was created before, the cached one will be returned. -func (t *ClientConfig) GetClientForEmail(ctx context.Context, email string, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*client.Client, error) { - cacheKey := t.buildCacheKey(email, publicKeyOpts) - - // The client is created by the cache callback, and will be closed by the cache on [ClientConfig.Close]. - cliValue, _ := t.clientCache.GetOrCall(cacheKey, func() clientCacheValue { - cli, key, err := createServiceAccountClient(ctx, t.endpoint, t.serviceAccountKey, cacheKey) - - return clientCacheValue{ - client: cli, - key: key, - err: err, - } - }) - - return cliValue.client, cliValue.err -} - -// GetKey fetches service account key for the default email. -func (t *ClientConfig) GetKey(ctx context.Context, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*pgp.Key, error) { - return t.GetKeyForEmail(ctx, DefaultServiceAccount, publicKeyOpts...) -} - -// GetKeyForEmail fetches service account key for the specified email. 
-func (t *ClientConfig) GetKeyForEmail(ctx context.Context, email string, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*pgp.Key, error) { - cacheKey := t.buildCacheKey(email, publicKeyOpts) - - // The client is created by the cache callback, and will be closed by the cache on [ClientConfig.Close]. - cliOrErr, _ := t.clientCache.GetOrCall(cacheKey, func() clientCacheValue { - cli, key, err := createServiceAccountClient(ctx, t.endpoint, t.serviceAccountKey, cacheKey) - - return clientCacheValue{ - client: cli, - key: key, - err: err, - } - }) - - return cliOrErr.key, cliOrErr.err -} - -// Close closes all the clients created by this config. -func (t *ClientConfig) Close() error { - var multiErr error - - t.clientCache.ForEach(func(_ clientCacheKey, cliOrErr clientCacheValue) { - if cliOrErr.client != nil { - if err := cliOrErr.client.Close(); err != nil { - multiErr = multierror.Append(multiErr, err) - } - } - }) - - return multiErr -} - -func (t *ClientConfig) buildCacheKey(email string, publicKeyOpts []authcli.RegisterPGPPublicKeyOption) clientCacheKey { - var req authpb.RegisterPublicKeyRequest - - for _, o := range publicKeyOpts { - o(&req) - } - - return clientCacheKey{ - role: req.Role, - email: email, - skipUserRole: req.SkipUserRole, - } -} - -// SignHTTPRequest signs the regular HTTP request using the default test email. -func SignHTTPRequest(ctx context.Context, client *client.Client, req *http.Request) error { - return SignHTTPRequestWithEmail(ctx, client, req, defaultEmail) -} - -// SignHTTPRequestWithEmail signs the regular HTTP request using the given email. 
-func SignHTTPRequestWithEmail(ctx context.Context, client *client.Client, req *http.Request, email string) error { - newKey, err := pgp.GenerateKey("", "", email, 4*time.Hour) - if err != nil { - return err - } - - err = registerKey(ctx, client.Auth(), newKey, email) - if err != nil { - return err - } - - msg, err := message.NewHTTP(req) - if err != nil { - return err - } - - return msg.Sign(email, newKey) -} - -// RegisterKeyGetIDSignatureBase64 registers a new public key with the default test email and returns its ID and the base-64 encoded signature of the same ID. -func RegisterKeyGetIDSignatureBase64(ctx context.Context, client *client.Client) (id, idSignatureBase66 string, err error) { - newKey, err := pgp.GenerateKey("", "", defaultEmail, 4*time.Hour) - if err != nil { - return "", "", err - } - - err = registerKey(ctx, client.Auth(), newKey, defaultEmail) - if err != nil { - return "", "", err - } - - id = newKey.Fingerprint() - - signedIDBytes, err := newKey.Sign([]byte(id)) - if err != nil { - return "", "", err - } - - idSignatureBase66 = base64.StdEncoding.EncodeToString(signedIDBytes) - - return id, idSignatureBase66, nil -} - -// CreateServiceAccount using the direct access to the Omni state. 
-func CreateServiceAccount(ctx context.Context, name string, st state.State) (string, error) { - // generate a new PGP key with long lifetime - comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) - - serviceAccountEmail := name + access.ServiceAccountNameSuffix - - key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) - if err != nil { - return "", err - } - - armoredPublicKey, err := key.ArmorPublic() - if err != nil { - return "", err - } - - identity, err := safe.ReaderGetByID[*authres.Identity](ctx, st, serviceAccountEmail) - if err != nil && !state.IsNotFoundError(err) { - return "", err - } - - if identity != nil { - err = omnisa.Destroy(ctx, st, name) - if err != nil { - return "", err - } - } - - _, err = omnisa.Create(ctx, st, name, string(role.Admin), false, []byte(armoredPublicKey)) - if err != nil { - return "", err - } - - return serviceaccount.Encode(name, key) -} - -func createServiceAccountClient(ctx context.Context, endpoint, serviceAccountKey string, cacheKey clientCacheKey) (*client.Client, *pgp.Key, error) { - rootClient, err := client.New(endpoint, client.WithServiceAccount(serviceAccountKey)) - if err != nil { - return nil, nil, err - } - - defer rootClient.Close() //nolint:errcheck - - name := fmt.Sprintf("%x", md5.Sum([]byte(cacheKey.email+cacheKey.role))) - - // generate a new PGP key with long lifetime - comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) - - suffix := access.ServiceAccountNameSuffix - - if cacheKey.role == string(role.InfraProvider) { - suffix = access.InfraProviderServiceAccountNameSuffix - } - - serviceAccountEmail := name + suffix - - key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) - if err != nil { - return nil, nil, err - } - - if cacheKey.role == string(role.InfraProvider) { - name = access.InfraProviderServiceAccountPrefix + name - } - - armoredPublicKey, err := key.ArmorPublic() - if err != 
nil { - return nil, nil, err - } - - serviceAccounts, err := rootClient.Management().ListServiceAccounts(ctx) - if err != nil { - return nil, nil, err - } - - if slices.IndexFunc(serviceAccounts, func(account *management.ListServiceAccountsResponse_ServiceAccount) bool { - return account.Name == name - }) != -1 { - if err = rootClient.Management().DestroyServiceAccount(ctx, name); err != nil { - return nil, nil, err - } - } - - // create service account with the generated key - _, err = rootClient.Management().CreateServiceAccount(ctx, name, armoredPublicKey, cacheKey.role, cacheKey.role == "") - if err != nil { - return nil, nil, err - } - - encodedKey, err := serviceaccount.Encode(name, key) - if err != nil { - return nil, nil, err - } - - cli, err := client.New(endpoint, client.WithServiceAccount(encodedKey)) - if err != nil { - return nil, nil, err - } - - return cli, key, nil -} diff --git a/internal/pkg/clientconfig/register_key_debug.go b/internal/pkg/clientconfig/register_key_debug.go deleted file mode 100644 index 7583c741..00000000 --- a/internal/pkg/clientconfig/register_key_debug.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2026 Sidero Labs, Inc. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. - -//go:build sidero.debug - -package clientconfig - -import ( - "context" - "time" - - "github.com/siderolabs/go-api-signature/pkg/client/auth" - "github.com/siderolabs/go-api-signature/pkg/pgp" - "google.golang.org/grpc/metadata" - - grpcomni "github.com/siderolabs/omni/internal/backend/grpc" -) - -func registerKey(ctx context.Context, cli *auth.Client, key *pgp.Key, email string, opts ...auth.RegisterPGPPublicKeyOption) error { - armoredPublicKey, err := key.ArmorPublic() - if err != nil { - return err - } - - _, err = cli.RegisterPGPPublicKey(ctx, email, []byte(armoredPublicKey), opts...) 
- if err != nil { - return err - } - - debugCtx := metadata.AppendToOutgoingContext(ctx, grpcomni.DebugVerifiedEmailHeaderKey, email) - - err = cli.ConfirmPublicKey(debugCtx, key.Fingerprint()) - if err != nil { - return err - } - - timeoutCtx, timeoutCtxCancel := context.WithTimeout(ctx, 10*time.Second) - defer timeoutCtxCancel() - - return cli.AwaitPublicKeyConfirmation(timeoutCtx, key.Fingerprint()) -} diff --git a/internal/pkg/clientconfig/register_key_no_debug.go b/internal/pkg/clientconfig/register_key_no_debug.go deleted file mode 100644 index c55266e5..00000000 --- a/internal/pkg/clientconfig/register_key_no_debug.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2026 Sidero Labs, Inc. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. - -//go:build !sidero.debug - -package clientconfig - -import ( - "context" - "errors" - - "github.com/siderolabs/go-api-signature/pkg/client/auth" - "github.com/siderolabs/go-api-signature/pkg/pgp" -) - -func registerKey(context.Context, *auth.Client, *pgp.Key, string, ...auth.RegisterPGPPublicKeyOption) error { - return errors.New("registerKey is not implemented in non-debug builds") -}