diff --git a/cmd/make-cookies/main.go b/cmd/make-cookies/main.go index 866dcfee..389214f6 100644 --- a/cmd/make-cookies/main.go +++ b/cmd/make-cookies/main.go @@ -7,49 +7,43 @@ package main import ( - "context" + "encoding/base64" "fmt" + "log" "net/http" - "os" + + "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/siderolabs/omni/internal/backend/services/workloadproxy" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) func main() { if err := app(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + log.Fatalf("failed to create cookies: %v", err) } } func app() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // don't forget to build this with the -tags=sidero.debug - if len(os.Args) != 2 { - return fmt.Errorf("usage: %s ", os.Args[0]) + _, saKey := serviceaccount.GetFromEnv() + if saKey == "" { + return fmt.Errorf("no service account key found in environment variables") } - cfg := clientconfig.New(os.Args[1], os.Getenv("OMNI_SERVICE_ACCOUNT_KEY")) - defer cfg.Close() //nolint:errcheck - - client, err := cfg.GetClient(ctx) + sa, err := serviceaccount.Decode(saKey) if err != nil { - return fmt.Errorf("error getting client: %w", err) + return fmt.Errorf("error decoding service account key: %w", err) } - defer client.Close() //nolint:errcheck + keyID := sa.Key.Fingerprint() - keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, client) + signedIDBytes, err := sa.Key.Sign([]byte(keyID)) if err != nil { - return fmt.Errorf("error registering key: %w", err) + return fmt.Errorf("error signing key ID: %w", err) } cookies := []*http.Cookie{ {Name: workloadproxy.PublicKeyIDCookie, Value: keyID}, - {Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: keyIDSignatureBase64}, + {Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: base64.StdEncoding.EncodeToString(signedIDBytes)}, } for _, cookie := range cookies { diff --git a/internal/integration/auth_test.go b/internal/integration/auth_test.go index 921243d0..7998f34e 100644 --- a/internal/integration/auth_test.go +++ b/internal/integration/auth_test.go @@ -10,6 +10,7 @@ package integration_test import ( "bytes" "context" + "crypto/md5" _ "embed" "encoding/base64" "encoding/json" @@ -23,6 +24,7 @@ import ( "slices" "strconv" "strings" + "sync" "testing" "time" @@ -36,7 +38,6 @@ import ( "github.com/google/uuid" "github.com/siderolabs/gen/maps" "github.com/siderolabs/gen/xslices" - authcli "github.com/siderolabs/go-api-signature/pkg/client/auth" "github.com/siderolabs/go-api-signature/pkg/client/interceptor" "github.com/siderolabs/go-api-signature/pkg/message" "github.com/siderolabs/go-api-signature/pkg/pgp" @@ -73,10 +74,114 @@ import ( "github.com/siderolabs/omni/internal/backend/runtime/omni/validated" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/role" - "github.com/siderolabs/omni/internal/pkg/clientconfig" "github.com/siderolabs/omni/internal/pkg/grpcutil" ) +// testClientFactory creates test clients with specific roles for authorization testing. +// It uses the root client (automation SA) to create new service accounts with the specified role, +// then returns a client authenticated as that SA. Clients are cached by role. +type testClientFactory struct { + endpoint string + serviceAccountKey string + rootCli *client.Client + + mu sync.Mutex + clients map[role.Role]*client.Client +} + +func newTestClientFactory(endpoint string, rootCli *client.Client) *testClientFactory { + return &testClientFactory{ + endpoint: endpoint, + rootCli: rootCli, + clients: make(map[role.Role]*client.Client), + } +} + +func (f *testClientFactory) getClient(ctx context.Context, r role.Role) (*client.Client, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if cli, ok := f.clients[r]; ok { + return cli, nil + } + + cli, err := f.createClientForRole(ctx, r) + if err != nil { + return nil, err + } + + f.clients[r] = cli + + return cli, nil +} + +func (f *testClientFactory) createClientForRole(ctx context.Context, r role.Role) (*client.Client, error) { + name := fmt.Sprintf("%x", md5.Sum([]byte(string(r)))) + + comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) + + suffix := access.ServiceAccountNameSuffix + if r == role.InfraProvider { + suffix = access.InfraProviderServiceAccountNameSuffix + } + + serviceAccountEmail := name + suffix + + key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) + if err != nil { + return nil, err + } + + if r == role.InfraProvider { + name = access.InfraProviderServiceAccountPrefix + name + } + + armoredPublicKey, err := key.ArmorPublic() + if err != nil { + return nil, err + } + + serviceAccounts, err := f.rootCli.Management().ListServiceAccounts(ctx) + if err != nil { + return nil, err + } + + if slices.IndexFunc(serviceAccounts, func(account *management.ListServiceAccountsResponse_ServiceAccount) bool { + return account.Name == name + }) != -1 { + if err = f.rootCli.Management().DestroyServiceAccount(ctx, name); err != nil { + return nil, err + } + } + + _, err = f.rootCli.Management().CreateServiceAccount(ctx, name, armoredPublicKey, string(r), false) + if err != nil { + return nil, err + } + + encodedKey, err := serviceaccount.Encode(name, key) + if err != nil { + return nil, err + } + + return client.New(f.endpoint, client.WithServiceAccount(encodedKey)) +} + +func (f *testClientFactory) close() error { + f.mu.Lock() + defer f.mu.Unlock() + + var errs []error + + for _, cli := range f.clients { + if err := cli.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + // AssertAnonymousAuthentication tests the authentication without any credentials. func AssertAnonymousAuthentication(testCtx context.Context, client *client.Client) TestFunc { return func(t *testing.T) { @@ -336,7 +441,7 @@ type apiAuthzTestCase struct { // AssertAPIAuthz tests the authorization checks of the API endpoints. // //nolint:gocognit,gocyclo,cyclop,maintidx -func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig, clusterName string) TestFunc { +func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientFactory *testClientFactory, clusterName string) TestFunc { rootCtx = metadata.NewOutgoingContext(rootCtx, metadata.Pairs(grpcutil.LogLevelOverrideMetadataKey, zapcore.PanicLevel.String())) assertSuccess := func(t *testing.T, err error) { @@ -555,13 +660,10 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi for _, tc := range testCases { // test each test case without signature t.Run(fmt.Sprintf("%s-no-signature", tc.namePrefix), func(t *testing.T) { - scopedClient, testErr := clientConfig.GetClient(rootCtx) - require.NoError(t, testErr) - // skip signing the request ctx := context.WithValue(rootCtx, interceptor.SkipInterceptorContextKey{}, struct{}{}) - testErr = tc.fn(ctx, scopedClient) + testErr := tc.fn(ctx, rootCli) // public resources will either succeed or fail with a permission denied if they are read-only resources if tc.isPublic { @@ -586,11 +688,7 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi // test with the role which should succeed t.Run(fmt.Sprintf("%s-success", tc.namePrefix), func(t *testing.T) { - scopedClient, testErr := clientConfig.GetClient( - rootCtx, - authcli.WithRole(string(tc.requiredRole)), - authcli.WithSkipUserRole(true), - ) + scopedClient, testErr := clientFactory.getClient(rootCtx, tc.requiredRole) require.NoError(t, testErr) assertCurrentUserRole(rootCtx, t, scopedClient.Omni().State(), tc.requiredRole) @@ -613,10 +711,7 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi require.NoError(t, err) t.Run(fmt.Sprintf("%s-failure", tc.namePrefix), func(t *testing.T) { - scopedClient, testErr := clientConfig.GetClient( - rootCtx, - authcli.WithRole(string(failureRole)), - authcli.WithSkipUserRole(true)) + scopedClient, testErr := clientFactory.getClient(rootCtx, failureRole) require.NoError(t, testErr) assertCurrentUserRole(rootCtx, t, scopedClient.Omni().State(), failureRole) @@ -641,7 +736,7 @@ type resourceAuthzTestCase struct { // AssertResourceAuthz tests the authorization checks of the resources (state). // //nolint:gocognit,gocyclo,cyclop,maintidx -func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig) TestFunc { +func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, clientFactory *testClientFactory) TestFunc { rootCtx = metadata.NewOutgoingContext(rootCtx, metadata.Pairs(grpcutil.LogLevelOverrideMetadataKey, zapcore.PanicLevel.String())) return func(t *testing.T) { @@ -1247,11 +1342,7 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client delete(untestedResourceTypes, tc.resource.Metadata().Type()) t.Run(name, func(t *testing.T) { - scopedCli, testErr := clientConfig.GetClient( - rootCtx, - authcli.WithRole(string(testRole)), - authcli.WithSkipUserRole(true), - ) + scopedCli, testErr := clientFactory.getClient(rootCtx, testRole) require.NoError(t, testErr) // ensure that scopedCli is operating with the correct role @@ -1377,12 +1468,13 @@ var ( const grpcMetadataPrefix = "Grpc-Metadata-" -func AssertFrontendResourceAPI(ctx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig, httpEndpoint, clusterName string) TestFunc { +func AssertFrontendResourceAPI(ctx context.Context, rootCli *client.Client, serviceAccountKey, httpEndpoint, clusterName string) TestFunc { return func(t *testing.T) { - key, err := clientConfig.GetKey(ctx) + sa, err := serviceaccount.Decode(serviceAccountKey) require.NoError(t, err) - email := clientconfig.DefaultServiceAccount + key := sa.Key + email := sa.Name + access.ServiceAccountNameSuffix // do the same flow for the signature as in the JS code signRequest := func(request *http.Request) error { diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index e9cbb006..7864da26 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "net" - "net/http" "strings" "testing" "time" @@ -29,7 +28,6 @@ import ( "github.com/siderolabs/omni/client/pkg/client" "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/internal/backend/runtime/talos" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) func resourceDetails(res resource.Resource) string { @@ -266,9 +264,6 @@ type WipeAMachineFunc func(ctx context.Context, uuid string) error // FreezeAMachineFunc is a function to freeze a machine by UUID. type FreezeAMachineFunc func(ctx context.Context, uuid string) error -// HTTPRequestSignerFunc is function to sign the HTTP request. -type HTTPRequestSignerFunc func(ctx context.Context, req *http.Request) error - // Options for the test runner. // //nolint:govet @@ -340,8 +335,8 @@ type MachineProviderConfig struct { // TestOptions constains all common data that might be required to run the tests. type TestOptions struct { Options - omniClient *client.Client - clientConfig *clientconfig.ClientConfig + omniClient *client.Client + serviceAccountKey string machineSemaphore *semaphore.Weighted } diff --git a/internal/integration/image_test.go b/internal/integration/image_test.go index 77f3695a..b3a99dab 100644 --- a/internal/integration/image_test.go +++ b/internal/integration/image_test.go @@ -16,21 +16,26 @@ import ( "time" "github.com/cosi-project/runtime/pkg/safe" + "github.com/siderolabs/go-api-signature/pkg/message" + "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/stretchr/testify/require" "github.com/siderolabs/omni/client/api/omni/management" - "github.com/siderolabs/omni/client/pkg/client" + "github.com/siderolabs/omni/client/pkg/access" clientconsts "github.com/siderolabs/omni/client/pkg/constants" "github.com/siderolabs/omni/client/pkg/omni/resources/omni" ) // AssertSomeImagesAreDownloadable verifies generated image download. -func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Client, signer HTTPRequestSignerFunc, httpEndpoint string) TestFunc { - st := client.Omni().State() +func AssertSomeImagesAreDownloadable(testCtx context.Context, options *TestOptions) TestFunc { + st := options.omniClient.Omni().State() return func(t *testing.T) { t.Parallel() + sa, err := serviceaccount.Decode(options.serviceAccountKey) + require.NoError(t, err) + media, err := safe.StateListAll[*omni.InstallationMedia](testCtx, st) require.NoError(t, err) @@ -60,10 +65,10 @@ func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Cli ctx, cancel := context.WithTimeout(testCtx, time.Minute*5) defer cancel() - u, err := url.Parse(httpEndpoint) + u, err := url.Parse(options.HTTPEndpoint) require.NoError(t, err) - schematic, err := client.Management().CreateSchematic(ctx, &management.CreateSchematicRequest{ + schematic, err := options.omniClient.Management().CreateSchematic(ctx, &management.CreateSchematicRequest{ MediaId: image.Metadata().ID(), TalosVersion: clientconsts.DefaultTalosVersion, }) @@ -75,7 +80,10 @@ func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Cli req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) require.NoError(t, err) - require.NoError(t, signer(ctx, req)) + msg, err := message.NewHTTP(req) + require.NoError(t, err) + + require.NoError(t, msg.Sign(sa.Name+access.ServiceAccountNameSuffix, sa.Key)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index b305d009..9fdb2c53 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -16,14 +16,16 @@ import ( "net/url" "os" "os/exec" + "runtime" "testing" "time" "github.com/cosi-project/runtime/pkg/resource/rtestutils" "github.com/cosi-project/runtime/pkg/safe" - "github.com/cosi-project/runtime/pkg/state" + cosistate "github.com/cosi-project/runtime/pkg/state" "github.com/mattn/go-shellwords" "github.com/prometheus/client_golang/prometheus" + "github.com/siderolabs/go-api-signature/pkg/pgp" "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,14 +36,19 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" + "github.com/siderolabs/omni/client/pkg/access" + "github.com/siderolabs/omni/client/pkg/client" clientconsts "github.com/siderolabs/omni/client/pkg/constants" + authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" omnires "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/client/pkg/omni/resources/siderolink" _ "github.com/siderolabs/omni/cmd/acompat" // this package should always be imported first for init->set env to work "github.com/siderolabs/omni/cmd/omni/pkg/app" "github.com/siderolabs/omni/internal/backend/runtime/omni" + "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" - "github.com/siderolabs/omni/internal/pkg/clientconfig" + "github.com/siderolabs/omni/internal/pkg/auth/role" + omnisa "github.com/siderolabs/omni/internal/pkg/auth/serviceaccount" "github.com/siderolabs/omni/internal/pkg/config" "github.com/siderolabs/omni/internal/pkg/constants" ) @@ -197,13 +204,7 @@ func TestIntegration(t *testing.T) { // Talos API calls try to use user auth if the service account var is not set os.Setenv(serviceaccount.OmniServiceAccountKeyEnvVar, serviceAccount) - clientConfig := clientconfig.New(omniEndpoint, serviceAccount) - - t.Cleanup(func() { - clientConfig.Close() //nolint:errcheck - }) - - rootClient, err := clientConfig.GetClient(t.Context()) + rootClient, err := client.New(omniEndpoint, client.WithServiceAccount(serviceAccount)) require.NoError(t, err) t.Cleanup(func() { @@ -211,10 +212,10 @@ func TestIntegration(t *testing.T) { }) testOptions := &TestOptions{ - omniClient: rootClient, - Options: options, - machineSemaphore: semaphore.NewWeighted(int64(options.ExpectedMachines)), - clientConfig: clientConfig, + omniClient: rootClient, + Options: options, + machineSemaphore: semaphore.NewWeighted(int64(options.ExpectedMachines)), + serviceAccountKey: serviceAccount, } preRunHooks(t, testOptions) @@ -390,7 +391,7 @@ func postRunHooks(t *testing.T, options *TestOptions) { } } -func cleanupLinksFunc(ctx context.Context, st state.State) error { +func cleanupLinksFunc(ctx context.Context, st cosistate.State) error { links, err := safe.ReaderListAll[*siderolink.Link](ctx, st) if err != nil { return err @@ -403,7 +404,7 @@ func cleanupLinksFunc(ctx context.Context, st state.State) error { return links.ForEachErr(func(r *siderolink.Link) error { err := st.TeardownAndDestroy(ctx, r.Metadata()) - if err != nil && !state.IsNotFoundError(err) { + if err != nil && !cosistate.IsNotFoundError(err) { return err } @@ -503,7 +504,7 @@ func runOmni(t *testing.T) (string, error) { rtestutils.AssertResources(ctx, t, state.Default(), []string{talosVersion}, func(*omnires.TalosVersion, *assert.Assertions) {}) - sa, err := clientconfig.CreateServiceAccount(omniCtx, "root", state.Default()) + sa, err := createBootstrapServiceAccount(omniCtx, "root", state.Default()) if err != nil { return "", err } @@ -513,3 +514,38 @@ func runOmni(t *testing.T) (string, error) { return sa, nil } + +func createBootstrapServiceAccount(ctx context.Context, name string, st cosistate.State) (string, error) { + comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) + + serviceAccountEmail := name + access.ServiceAccountNameSuffix + + key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) + if err != nil { + return "", err + } + + armoredPublicKey, err := key.ArmorPublic() + if err != nil { + return "", err + } + + identity, err := safe.ReaderGetByID[*authres.Identity](ctx, st, serviceAccountEmail) + if err != nil && !cosistate.IsNotFoundError(err) { + return "", err + } + + if identity != nil { + err = omnisa.Destroy(ctx, st, name) + if err != nil { + return "", err + } + } + + _, err = omnisa.Create(ctx, st, name, string(role.Admin), false, []byte(armoredPublicKey)) + if err != nil { + return "", err + } + + return serviceaccount.Encode(name, key) +} diff --git a/internal/integration/suites_test.go b/internal/integration/suites_test.go index 6a234367..f2f4d7b2 100644 --- a/internal/integration/suites_test.go +++ b/internal/integration/suites_test.go @@ -8,8 +8,6 @@ package integration_test import ( - "context" - "net/http" "testing" "time" @@ -19,7 +17,6 @@ import ( "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/internal/backend/extensions" "github.com/siderolabs/omni/internal/integration/workloadproxy" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) type assertClusterReadyOptions struct { @@ -128,9 +125,7 @@ Generate various Talos images with Omni and try to download them.`) t.Run( "TalosImagesShouldBeDownloadable", - AssertSomeImagesAreDownloadable(t.Context(), options.omniClient, func(ctx context.Context, req *http.Request) error { - return clientconfig.SignHTTPRequest(ctx, options.omniClient, req) - }, options.HTTPEndpoint), + AssertSomeImagesAreDownloadable(t.Context(), options), ) } } @@ -1159,9 +1154,14 @@ Test authorization on accessing Omni API, some tests run without a cluster, some AssertServiceAccountAPIFlow(t.Context(), options.omniClient), ) + clientFactory := newTestClientFactory(omniEndpoint, options.omniClient) + t.Cleanup(func() { + clientFactory.close() //nolint:errcheck + }) + t.Run( "ResourceAuthzShouldWork", - AssertResourceAuthz(t.Context(), options.omniClient, options.clientConfig), + AssertResourceAuthz(t.Context(), options.omniClient, clientFactory), ) t.Run( @@ -1193,12 +1193,12 @@ Test authorization on accessing Omni API, some tests run without a cluster, some t.Run( "APIAuthorizationShouldBeTested", - AssertAPIAuthz(t.Context(), options.omniClient, options.clientConfig, clusterName), + AssertAPIAuthz(t.Context(), options.omniClient, clientFactory, clusterName), ) t.Run( "FrontendAPIShouldBeTested", - AssertFrontendResourceAPI(t.Context(), options.omniClient, options.clientConfig, options.HTTPEndpoint, clusterName), + AssertFrontendResourceAPI(t.Context(), options.omniClient, options.serviceAccountKey, options.HTTPEndpoint, clusterName), ) t.Run( @@ -1278,7 +1278,7 @@ Test workload service proxying feature`) parentCtx := t.Context() t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) { - workloadproxy.Test(parentCtx, t, omniClient, cluster1, cluster2) + workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, cluster1, cluster2) }) t.Run("ClusterShouldBeDestroyed-"+cluster1, AssertDestroyCluster(t.Context(), options.omniClient.Omni().State(), cluster1, false, false)) @@ -1457,7 +1457,7 @@ Test Omni upgrades, the first half that runs on the previous Omni version parentCtx := t.Context() t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) { - workloadproxy.Test(parentCtx, t, omniClient, clusterName) + workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, clusterName) }) t.Run("SaveClusterSnapshot", SaveClusterSnapshot(t.Context(), omniClient, clusterName)) @@ -1494,7 +1494,7 @@ Test Omni upgrades, the second half that runs on the current Omni version t.Run("AssertMachinesNotRebootedConfigUnchanged", AssertClusterSnapshot(t.Context(), omniClient, clusterName)) t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) { - workloadproxy.Test(parentCtx, t, omniClient, clusterName) + workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, clusterName) }) t.Run( diff --git a/internal/integration/workloadproxy/workloadproxy.go b/internal/integration/workloadproxy/workloadproxy.go index 5c9ffb66..fd9b2706 100644 --- a/internal/integration/workloadproxy/workloadproxy.go +++ b/internal/integration/workloadproxy/workloadproxy.go @@ -31,6 +31,8 @@ import ( "github.com/hashicorp/go-cleanhttp" "github.com/siderolabs/gen/maps" "github.com/siderolabs/gen/xslices" + "github.com/siderolabs/go-api-signature/pkg/pgp" + "github.com/siderolabs/go-api-signature/pkg/serviceaccount" "github.com/siderolabs/go-pointer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,7 +50,6 @@ import ( "github.com/siderolabs/omni/client/pkg/omni/resources/omni" "github.com/siderolabs/omni/internal/backend/services/workloadproxy" "github.com/siderolabs/omni/internal/integration/kubernetes" - "github.com/siderolabs/omni/internal/pkg/clientconfig" ) type serviceContext struct { @@ -75,7 +76,7 @@ var sideroLabsIconSVG []byte // Test tests the exposed services functionality in Omni. // //nolint:prealloc -func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterIDs ...string) { +func Test(ctx context.Context, t *testing.T, omniClient *client.Client, serviceAccountKey string, clusterIDs ...string) { ctx, cancel := context.WithTimeout(ctx, 20*time.Minute) t.Cleanup(cancel) @@ -83,6 +84,9 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI require.Fail(t, "no cluster IDs provided for the test, please provide at least one cluster ID") } + sa, err := serviceaccount.Decode(serviceAccountKey) + require.NoError(t, err) + ctx = kubernetes.WrapContext(ctx, t) logger := zaptest.NewLogger(t) @@ -118,7 +122,7 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI allExposedServices[i], allExposedServices[j] = allExposedServices[j], allExposedServices[i] }) - testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK) + testAccess(ctx, t, logger, sa.Key, allExposedServices, http.StatusOK) inaccessibleExposedServices := make([]*omni.ExposedService, 0, len(allExposedServices)) @@ -132,7 +136,7 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI } } - testAccess(ctx, t, logger, omniClient, inaccessibleExposedServices, http.StatusBadGateway) + testAccess(ctx, t, logger, sa.Key, inaccessibleExposedServices, http.StatusBadGateway) for _, deployment := range deploymentsToScaleDown { logger.Info("scale deployment back up", zap.String("deployment", deployment.deployment.Name), zap.String("clusterID", deployment.cluster.clusterID)) @@ -140,13 +144,13 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI kubernetes.ScaleDeployment(ctx, t, deployment.cluster.kubeClient, deployment.deployment.Namespace, deployment.deployment.Name, 1) } - testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK) - testToggleFeature(ctx, t, logger, omniClient, clusters[0]) - testToggleKubernetesServiceAnnotation(ctx, t, logger, omniClient, allServices[:len(allServices)/2]) + testAccess(ctx, t, logger, sa.Key, allExposedServices, http.StatusOK) + testToggleFeature(ctx, t, logger, omniClient, sa.Key, clusters[0]) + testToggleKubernetesServiceAnnotation(ctx, t, logger, omniClient, sa.Key, allServices[:len(allServices)/2]) } // testToggleFeature tests toggling off/on the workload proxy feature for a cluster. -func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, cluster clusterContext) { +func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, saKey *pgp.Key, cluster clusterContext) { logger.Info("test turning off and on the feature for the cluster", zap.String("clusterID", cluster.clusterID)) setFeatureToggle := func(enabled bool) { @@ -172,14 +176,14 @@ func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, om services = services[:4] } - testAccess(ctx, t, logger, omniClient, services[:4], http.StatusNotFound) + testAccess(ctx, t, logger, saKey, services[:4], http.StatusNotFound) setFeatureToggle(true) - testAccess(ctx, t, logger, omniClient, services[:4], http.StatusOK) + testAccess(ctx, t, logger, saKey, services[:4], http.StatusOK) } -func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, services []serviceContext) { +func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, saKey *pgp.Key, services []serviceContext) { logger.Info("test toggling Kubernetes service annotation for exposed services", zap.Int("numServices", len(services))) for _, service := range services { @@ -195,7 +199,7 @@ func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, lo exposedServices := xslices.Map(services, func(svc serviceContext) *omni.ExposedService { return svc.res }) - testAccess(ctx, t, logger, omniClient, exposedServices, http.StatusNotFound) + testAccess(ctx, t, logger, saKey, exposedServices, http.StatusNotFound) for _, service := range services { kubernetes.UpdateService(ctx, t, service.deployment.cluster.kubeClient, service.svc.Namespace, service.svc.Name, func(svc *corev1.Service) { @@ -223,7 +227,7 @@ func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, lo updatedServices := maps.Values(updatedServicesMap) - testAccess(ctx, t, logger, omniClient, updatedServices, http.StatusOK) + testAccess(ctx, t, logger, saKey, updatedServices, http.StatusOK) } func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, clusterID string) clusterContext { @@ -299,11 +303,15 @@ func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omni return cluster } -func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, exposedServices []*omni.ExposedService, expectedStatusCode int) { - keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, omniClient) +func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, saKey *pgp.Key, exposedServices []*omni.ExposedService, expectedStatusCode int) { + keyID := saKey.Fingerprint() + + signedIDBytes, err := saKey.Sign([]byte(keyID)) require.NoError(t, err) - logger.Debug("registered public key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64)) + keyIDSignatureBase64 := base64.StdEncoding.EncodeToString(signedIDBytes) + + logger.Debug("using SA key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64)) cookies := []*http.Cookie{ {Name: workloadproxy.PublicKeyIDCookie, Value: keyID}, diff --git a/internal/pkg/clientconfig/clientconfig.go b/internal/pkg/clientconfig/clientconfig.go deleted file mode 100644 index 57b7dec6..00000000 --- a/internal/pkg/clientconfig/clientconfig.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright (c) 2026 Sidero Labs, Inc. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. - -// Package clientconfig holds the configuration for the test client for Omni API. -package clientconfig - -import ( - "context" - "crypto/md5" - "encoding/base64" - "fmt" - "net/http" - "runtime" - "slices" - "time" - - "github.com/cosi-project/runtime/pkg/safe" - "github.com/cosi-project/runtime/pkg/state" - "github.com/hashicorp/go-multierror" - "github.com/siderolabs/gen/containers" - authpb "github.com/siderolabs/go-api-signature/api/auth" - authcli "github.com/siderolabs/go-api-signature/pkg/client/auth" - "github.com/siderolabs/go-api-signature/pkg/message" - "github.com/siderolabs/go-api-signature/pkg/pgp" - "github.com/siderolabs/go-api-signature/pkg/serviceaccount" - - "github.com/siderolabs/omni/client/api/omni/management" - "github.com/siderolabs/omni/client/pkg/access" - "github.com/siderolabs/omni/client/pkg/client" - authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" - "github.com/siderolabs/omni/internal/pkg/auth" - "github.com/siderolabs/omni/internal/pkg/auth/role" - omnisa "github.com/siderolabs/omni/internal/pkg/auth/serviceaccount" -) - -const ( - DefaultServiceAccount = "integration" + access.ServiceAccountNameSuffix - defaultEmail = "test-user@siderolabs.com" -) - -type clientCacheKey struct { - role string - email string - skipUserRole bool -} - -type clientCacheValue struct { - client *client.Client - key *pgp.Key - err error -} - -// ClientConfig is a test client. -type ClientConfig struct { - endpoint string - serviceAccountKey string - clientCache containers.ConcurrentMap[clientCacheKey, clientCacheValue] -} - -// New creates a new test client config. -func New(endpoint, serviceAccountKey string) *ClientConfig { - return &ClientConfig{ - endpoint: endpoint, - serviceAccountKey: serviceAccountKey, - } -} - -// GetClient returns a test client for the default test email. -// -// Clients are cached by their configuration, so if a client with the -// given configuration was created before, the cached one will be returned. -func (t *ClientConfig) GetClient(ctx context.Context, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*client.Client, error) { - return t.GetClientForEmail(ctx, DefaultServiceAccount, publicKeyOpts...) -} - -// GetClientForEmail returns a test client for the given email. -// -// Clients are cached by their configuration, so if a client with the -// given configuration was created before, the cached one will be returned. -func (t *ClientConfig) GetClientForEmail(ctx context.Context, email string, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*client.Client, error) { - cacheKey := t.buildCacheKey(email, publicKeyOpts) - - // The client is created by the cache callback, and will be closed by the cache on [ClientConfig.Close]. - cliValue, _ := t.clientCache.GetOrCall(cacheKey, func() clientCacheValue { - cli, key, err := createServiceAccountClient(ctx, t.endpoint, t.serviceAccountKey, cacheKey) - - return clientCacheValue{ - client: cli, - key: key, - err: err, - } - }) - - return cliValue.client, cliValue.err -} - -// GetKey fetches service account key for the default email. -func (t *ClientConfig) GetKey(ctx context.Context, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*pgp.Key, error) { - return t.GetKeyForEmail(ctx, DefaultServiceAccount, publicKeyOpts...) -} - -// GetKeyForEmail fetches service account key for the specified email. -func (t *ClientConfig) GetKeyForEmail(ctx context.Context, email string, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*pgp.Key, error) { - cacheKey := t.buildCacheKey(email, publicKeyOpts) - - // The client is created by the cache callback, and will be closed by the cache on [ClientConfig.Close]. - cliOrErr, _ := t.clientCache.GetOrCall(cacheKey, func() clientCacheValue { - cli, key, err := createServiceAccountClient(ctx, t.endpoint, t.serviceAccountKey, cacheKey) - - return clientCacheValue{ - client: cli, - key: key, - err: err, - } - }) - - return cliOrErr.key, cliOrErr.err -} - -// Close closes all the clients created by this config. -func (t *ClientConfig) Close() error { - var multiErr error - - t.clientCache.ForEach(func(_ clientCacheKey, cliOrErr clientCacheValue) { - if cliOrErr.client != nil { - if err := cliOrErr.client.Close(); err != nil { - multiErr = multierror.Append(multiErr, err) - } - } - }) - - return multiErr -} - -func (t *ClientConfig) buildCacheKey(email string, publicKeyOpts []authcli.RegisterPGPPublicKeyOption) clientCacheKey { - var req authpb.RegisterPublicKeyRequest - - for _, o := range publicKeyOpts { - o(&req) - } - - return clientCacheKey{ - role: req.Role, - email: email, - skipUserRole: req.SkipUserRole, - } -} - -// SignHTTPRequest signs the regular HTTP request using the default test email. -func SignHTTPRequest(ctx context.Context, client *client.Client, req *http.Request) error { - return SignHTTPRequestWithEmail(ctx, client, req, defaultEmail) -} - -// SignHTTPRequestWithEmail signs the regular HTTP request using the given email. -func SignHTTPRequestWithEmail(ctx context.Context, client *client.Client, req *http.Request, email string) error { - newKey, err := pgp.GenerateKey("", "", email, 4*time.Hour) - if err != nil { - return err - } - - err = registerKey(ctx, client.Auth(), newKey, email) - if err != nil { - return err - } - - msg, err := message.NewHTTP(req) - if err != nil { - return err - } - - return msg.Sign(email, newKey) -} - -// RegisterKeyGetIDSignatureBase64 registers a new public key with the default test email and returns its ID and the base-64 encoded signature of the same ID. -func RegisterKeyGetIDSignatureBase64(ctx context.Context, client *client.Client) (id, idSignatureBase66 string, err error) { - newKey, err := pgp.GenerateKey("", "", defaultEmail, 4*time.Hour) - if err != nil { - return "", "", err - } - - err = registerKey(ctx, client.Auth(), newKey, defaultEmail) - if err != nil { - return "", "", err - } - - id = newKey.Fingerprint() - - signedIDBytes, err := newKey.Sign([]byte(id)) - if err != nil { - return "", "", err - } - - idSignatureBase66 = base64.StdEncoding.EncodeToString(signedIDBytes) - - return id, idSignatureBase66, nil -} - -// CreateServiceAccount using the direct access to the Omni state. -func CreateServiceAccount(ctx context.Context, name string, st state.State) (string, error) { - // generate a new PGP key with long lifetime - comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) - - serviceAccountEmail := name + access.ServiceAccountNameSuffix - - key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) - if err != nil { - return "", err - } - - armoredPublicKey, err := key.ArmorPublic() - if err != nil { - return "", err - } - - identity, err := safe.ReaderGetByID[*authres.Identity](ctx, st, serviceAccountEmail) - if err != nil && !state.IsNotFoundError(err) { - return "", err - } - - if identity != nil { - err = omnisa.Destroy(ctx, st, name) - if err != nil { - return "", err - } - } - - _, err = omnisa.Create(ctx, st, name, string(role.Admin), false, []byte(armoredPublicKey)) - if err != nil { - return "", err - } - - return serviceaccount.Encode(name, key) -} - -func createServiceAccountClient(ctx context.Context, endpoint, serviceAccountKey string, cacheKey clientCacheKey) (*client.Client, *pgp.Key, error) { - rootClient, err := client.New(endpoint, client.WithServiceAccount(serviceAccountKey)) - if err != nil { - return nil, nil, err - } - - defer rootClient.Close() //nolint:errcheck - - name := fmt.Sprintf("%x", md5.Sum([]byte(cacheKey.email+cacheKey.role))) - - // generate a new PGP key with long lifetime - comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) - - suffix := access.ServiceAccountNameSuffix - - if cacheKey.role == string(role.InfraProvider) { - suffix = access.InfraProviderServiceAccountNameSuffix - } - - serviceAccountEmail := name + suffix - - key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime) - if err != nil { - return nil, nil, err - } - - if cacheKey.role == string(role.InfraProvider) { - name = access.InfraProviderServiceAccountPrefix + name - } - - armoredPublicKey, err := key.ArmorPublic() - if err != nil { - return nil, nil, err - } - - serviceAccounts, err := rootClient.Management().ListServiceAccounts(ctx) - if err != nil { - return nil, nil, err - } - - if slices.IndexFunc(serviceAccounts, func(account *management.ListServiceAccountsResponse_ServiceAccount) bool { - return account.Name == name - }) != -1 { - if err = rootClient.Management().DestroyServiceAccount(ctx, name); err != nil { - return nil, nil, err - } - } - - // create service account with the generated key - _, err = rootClient.Management().CreateServiceAccount(ctx, name, armoredPublicKey, cacheKey.role, cacheKey.role == "") - if err != nil { - return nil, nil, err - } - - encodedKey, err := serviceaccount.Encode(name, key) - if err != nil { - return nil, nil, err - } - - cli, err := client.New(endpoint, client.WithServiceAccount(encodedKey)) - if err != nil { - return nil, nil, err - } - - return cli, key, nil -} diff --git a/internal/pkg/clientconfig/register_key_debug.go b/internal/pkg/clientconfig/register_key_debug.go deleted file mode 100644 index 7583c741..00000000 --- a/internal/pkg/clientconfig/register_key_debug.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2026 Sidero Labs, Inc. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. - -//go:build sidero.debug - -package clientconfig - -import ( - "context" - "time" - - "github.com/siderolabs/go-api-signature/pkg/client/auth" - "github.com/siderolabs/go-api-signature/pkg/pgp" - "google.golang.org/grpc/metadata" - - grpcomni "github.com/siderolabs/omni/internal/backend/grpc" -) - -func registerKey(ctx context.Context, cli *auth.Client, key *pgp.Key, email string, opts ...auth.RegisterPGPPublicKeyOption) error { - armoredPublicKey, err := key.ArmorPublic() - if err != nil { - return err - } - - _, err = cli.RegisterPGPPublicKey(ctx, email, []byte(armoredPublicKey), opts...) - if err != nil { - return err - } - - debugCtx := metadata.AppendToOutgoingContext(ctx, grpcomni.DebugVerifiedEmailHeaderKey, email) - - err = cli.ConfirmPublicKey(debugCtx, key.Fingerprint()) - if err != nil { - return err - } - - timeoutCtx, timeoutCtxCancel := context.WithTimeout(ctx, 10*time.Second) - defer timeoutCtxCancel() - - return cli.AwaitPublicKeyConfirmation(timeoutCtx, key.Fingerprint()) -} diff --git a/internal/pkg/clientconfig/register_key_no_debug.go b/internal/pkg/clientconfig/register_key_no_debug.go deleted file mode 100644 index c55266e5..00000000 --- a/internal/pkg/clientconfig/register_key_no_debug.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2026 Sidero Labs, Inc. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. - -//go:build !sidero.debug - -package clientconfig - -import ( - "context" - "errors" - - "github.com/siderolabs/go-api-signature/pkg/client/auth" - "github.com/siderolabs/go-api-signature/pkg/pgp" -) - -func registerKey(context.Context, *auth.Client, *pgp.Key, string, ...auth.RegisterPGPPublicKeyOption) error { - return errors.New("registerKey is not implemented in non-debug builds") -}