test: use automation sa directly in integration tests

Instead of doing the fake user auth flow in the integration tests via the `clientconfig` package, use the automation service account directly. Remove all other usages of that package as well, and drop it completely.

The package predates the initial service account token feature of Omni, its purpose was to authenticate to the Omni API in the integration tests. We have the automation key now, so we don't need that anymore.

Signed-off-by: Utku Ozdemir <utku.ozdemir@siderolabs.com>
This commit is contained in:
Utku Ozdemir 2026-02-11 13:12:01 +01:00
parent 6102db4e1d
commit ef3e3bc1cc
No known key found for this signature in database
GPG Key ID: DBD13117B0A14E93
10 changed files with 235 additions and 467 deletions

View File

@ -7,49 +7,43 @@
package main
import (
"context"
"encoding/base64"
"fmt"
"log"
"net/http"
"os"
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
"github.com/siderolabs/omni/internal/backend/services/workloadproxy"
"github.com/siderolabs/omni/internal/pkg/clientconfig"
)
func main() {
if err := app(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
log.Fatalf("failed to create cookies: %v", err)
}
}
func app() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// don't forget to build this with the -tags=sidero.debug
if len(os.Args) != 2 {
return fmt.Errorf("usage: %s <endpoint>", os.Args[0])
_, saKey := serviceaccount.GetFromEnv()
if saKey == "" {
return fmt.Errorf("no service account key found in environment variables")
}
cfg := clientconfig.New(os.Args[1], os.Getenv("OMNI_SERVICE_ACCOUNT_KEY"))
defer cfg.Close() //nolint:errcheck
client, err := cfg.GetClient(ctx)
sa, err := serviceaccount.Decode(saKey)
if err != nil {
return fmt.Errorf("error getting client: %w", err)
return fmt.Errorf("error decoding service account key: %w", err)
}
defer client.Close() //nolint:errcheck
keyID := sa.Key.Fingerprint()
keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, client)
signedIDBytes, err := sa.Key.Sign([]byte(keyID))
if err != nil {
return fmt.Errorf("error registering key: %w", err)
return fmt.Errorf("error signing key ID: %w", err)
}
cookies := []*http.Cookie{
{Name: workloadproxy.PublicKeyIDCookie, Value: keyID},
{Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: keyIDSignatureBase64},
{Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: base64.StdEncoding.EncodeToString(signedIDBytes)},
}
for _, cookie := range cookies {

View File

@ -10,6 +10,7 @@ package integration_test
import (
"bytes"
"context"
"crypto/md5"
_ "embed"
"encoding/base64"
"encoding/json"
@ -23,6 +24,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -36,7 +38,6 @@ import (
"github.com/google/uuid"
"github.com/siderolabs/gen/maps"
"github.com/siderolabs/gen/xslices"
authcli "github.com/siderolabs/go-api-signature/pkg/client/auth"
"github.com/siderolabs/go-api-signature/pkg/client/interceptor"
"github.com/siderolabs/go-api-signature/pkg/message"
"github.com/siderolabs/go-api-signature/pkg/pgp"
@ -73,10 +74,114 @@ import (
"github.com/siderolabs/omni/internal/backend/runtime/omni/validated"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/clientconfig"
"github.com/siderolabs/omni/internal/pkg/grpcutil"
)
// testClientFactory creates test clients with specific roles for authorization testing.
// It uses the root client (automation SA) to create new service accounts with the specified role,
// then returns a client authenticated as that SA. Clients are cached by role.
type testClientFactory struct {
endpoint string
serviceAccountKey string
rootCli *client.Client
mu sync.Mutex
clients map[role.Role]*client.Client
}
func newTestClientFactory(endpoint string, rootCli *client.Client) *testClientFactory {
return &testClientFactory{
endpoint: endpoint,
rootCli: rootCli,
clients: make(map[role.Role]*client.Client),
}
}
func (f *testClientFactory) getClient(ctx context.Context, r role.Role) (*client.Client, error) {
f.mu.Lock()
defer f.mu.Unlock()
if cli, ok := f.clients[r]; ok {
return cli, nil
}
cli, err := f.createClientForRole(ctx, r)
if err != nil {
return nil, err
}
f.clients[r] = cli
return cli, nil
}
func (f *testClientFactory) createClientForRole(ctx context.Context, r role.Role) (*client.Client, error) {
name := fmt.Sprintf("%x", md5.Sum([]byte(string(r))))
comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
suffix := access.ServiceAccountNameSuffix
if r == role.InfraProvider {
suffix = access.InfraProviderServiceAccountNameSuffix
}
serviceAccountEmail := name + suffix
key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime)
if err != nil {
return nil, err
}
if r == role.InfraProvider {
name = access.InfraProviderServiceAccountPrefix + name
}
armoredPublicKey, err := key.ArmorPublic()
if err != nil {
return nil, err
}
serviceAccounts, err := f.rootCli.Management().ListServiceAccounts(ctx)
if err != nil {
return nil, err
}
if slices.IndexFunc(serviceAccounts, func(account *management.ListServiceAccountsResponse_ServiceAccount) bool {
return account.Name == name
}) != -1 {
if err = f.rootCli.Management().DestroyServiceAccount(ctx, name); err != nil {
return nil, err
}
}
_, err = f.rootCli.Management().CreateServiceAccount(ctx, name, armoredPublicKey, string(r), false)
if err != nil {
return nil, err
}
encodedKey, err := serviceaccount.Encode(name, key)
if err != nil {
return nil, err
}
return client.New(f.endpoint, client.WithServiceAccount(encodedKey))
}
func (f *testClientFactory) close() error {
f.mu.Lock()
defer f.mu.Unlock()
var errs []error
for _, cli := range f.clients {
if err := cli.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
// AssertAnonymousAuthentication tests the authentication without any credentials.
func AssertAnonymousAuthentication(testCtx context.Context, client *client.Client) TestFunc {
return func(t *testing.T) {
@ -336,7 +441,7 @@ type apiAuthzTestCase struct {
// AssertAPIAuthz tests the authorization checks of the API endpoints.
//
//nolint:gocognit,gocyclo,cyclop,maintidx
func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig, clusterName string) TestFunc {
func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientFactory *testClientFactory, clusterName string) TestFunc {
rootCtx = metadata.NewOutgoingContext(rootCtx, metadata.Pairs(grpcutil.LogLevelOverrideMetadataKey, zapcore.PanicLevel.String()))
assertSuccess := func(t *testing.T, err error) {
@ -555,13 +660,10 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi
for _, tc := range testCases {
// test each test case without signature
t.Run(fmt.Sprintf("%s-no-signature", tc.namePrefix), func(t *testing.T) {
scopedClient, testErr := clientConfig.GetClient(rootCtx)
require.NoError(t, testErr)
// skip signing the request
ctx := context.WithValue(rootCtx, interceptor.SkipInterceptorContextKey{}, struct{}{})
testErr = tc.fn(ctx, scopedClient)
testErr := tc.fn(ctx, rootCli)
// public resources will either succeed or fail with a permission denied if they are read-only resources
if tc.isPublic {
@ -586,11 +688,7 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi
// test with the role which should succeed
t.Run(fmt.Sprintf("%s-success", tc.namePrefix), func(t *testing.T) {
scopedClient, testErr := clientConfig.GetClient(
rootCtx,
authcli.WithRole(string(tc.requiredRole)),
authcli.WithSkipUserRole(true),
)
scopedClient, testErr := clientFactory.getClient(rootCtx, tc.requiredRole)
require.NoError(t, testErr)
assertCurrentUserRole(rootCtx, t, scopedClient.Omni().State(), tc.requiredRole)
@ -613,10 +711,7 @@ func AssertAPIAuthz(rootCtx context.Context, rootCli *client.Client, clientConfi
require.NoError(t, err)
t.Run(fmt.Sprintf("%s-failure", tc.namePrefix), func(t *testing.T) {
scopedClient, testErr := clientConfig.GetClient(
rootCtx,
authcli.WithRole(string(failureRole)),
authcli.WithSkipUserRole(true))
scopedClient, testErr := clientFactory.getClient(rootCtx, failureRole)
require.NoError(t, testErr)
assertCurrentUserRole(rootCtx, t, scopedClient.Omni().State(), failureRole)
@ -641,7 +736,7 @@ type resourceAuthzTestCase struct {
// AssertResourceAuthz tests the authorization checks of the resources (state).
//
//nolint:gocognit,gocyclo,cyclop,maintidx
func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig) TestFunc {
func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, clientFactory *testClientFactory) TestFunc {
rootCtx = metadata.NewOutgoingContext(rootCtx, metadata.Pairs(grpcutil.LogLevelOverrideMetadataKey, zapcore.PanicLevel.String()))
return func(t *testing.T) {
@ -1247,11 +1342,7 @@ func AssertResourceAuthz(rootCtx context.Context, rootCli *client.Client, client
delete(untestedResourceTypes, tc.resource.Metadata().Type())
t.Run(name, func(t *testing.T) {
scopedCli, testErr := clientConfig.GetClient(
rootCtx,
authcli.WithRole(string(testRole)),
authcli.WithSkipUserRole(true),
)
scopedCli, testErr := clientFactory.getClient(rootCtx, testRole)
require.NoError(t, testErr)
// ensure that scopedCli is operating with the correct role
@ -1377,12 +1468,13 @@ var (
const grpcMetadataPrefix = "Grpc-Metadata-"
func AssertFrontendResourceAPI(ctx context.Context, rootCli *client.Client, clientConfig *clientconfig.ClientConfig, httpEndpoint, clusterName string) TestFunc {
func AssertFrontendResourceAPI(ctx context.Context, rootCli *client.Client, serviceAccountKey, httpEndpoint, clusterName string) TestFunc {
return func(t *testing.T) {
key, err := clientConfig.GetKey(ctx)
sa, err := serviceaccount.Decode(serviceAccountKey)
require.NoError(t, err)
email := clientconfig.DefaultServiceAccount
key := sa.Key
email := sa.Name + access.ServiceAccountNameSuffix
// do the same flow for the signature as in the JS code
signRequest := func(request *http.Request) error {

View File

@ -13,7 +13,6 @@ import (
"errors"
"fmt"
"net"
"net/http"
"strings"
"testing"
"time"
@ -29,7 +28,6 @@ import (
"github.com/siderolabs/omni/client/pkg/client"
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/internal/backend/runtime/talos"
"github.com/siderolabs/omni/internal/pkg/clientconfig"
)
func resourceDetails(res resource.Resource) string {
@ -266,9 +264,6 @@ type WipeAMachineFunc func(ctx context.Context, uuid string) error
// FreezeAMachineFunc is a function to freeze a machine by UUID.
type FreezeAMachineFunc func(ctx context.Context, uuid string) error
// HTTPRequestSignerFunc is function to sign the HTTP request.
type HTTPRequestSignerFunc func(ctx context.Context, req *http.Request) error
// Options for the test runner.
//
//nolint:govet
@ -340,8 +335,8 @@ type MachineProviderConfig struct {
// TestOptions constains all common data that might be required to run the tests.
type TestOptions struct {
Options
omniClient *client.Client
clientConfig *clientconfig.ClientConfig
omniClient *client.Client
serviceAccountKey string
machineSemaphore *semaphore.Weighted
}

View File

@ -16,21 +16,26 @@ import (
"time"
"github.com/cosi-project/runtime/pkg/safe"
"github.com/siderolabs/go-api-signature/pkg/message"
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
"github.com/stretchr/testify/require"
"github.com/siderolabs/omni/client/api/omni/management"
"github.com/siderolabs/omni/client/pkg/client"
"github.com/siderolabs/omni/client/pkg/access"
clientconsts "github.com/siderolabs/omni/client/pkg/constants"
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
)
// AssertSomeImagesAreDownloadable verifies generated image download.
func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Client, signer HTTPRequestSignerFunc, httpEndpoint string) TestFunc {
st := client.Omni().State()
func AssertSomeImagesAreDownloadable(testCtx context.Context, options *TestOptions) TestFunc {
st := options.omniClient.Omni().State()
return func(t *testing.T) {
t.Parallel()
sa, err := serviceaccount.Decode(options.serviceAccountKey)
require.NoError(t, err)
media, err := safe.StateListAll[*omni.InstallationMedia](testCtx, st)
require.NoError(t, err)
@ -60,10 +65,10 @@ func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Cli
ctx, cancel := context.WithTimeout(testCtx, time.Minute*5)
defer cancel()
u, err := url.Parse(httpEndpoint)
u, err := url.Parse(options.HTTPEndpoint)
require.NoError(t, err)
schematic, err := client.Management().CreateSchematic(ctx, &management.CreateSchematicRequest{
schematic, err := options.omniClient.Management().CreateSchematic(ctx, &management.CreateSchematicRequest{
MediaId: image.Metadata().ID(),
TalosVersion: clientconsts.DefaultTalosVersion,
})
@ -75,7 +80,10 @@ func AssertSomeImagesAreDownloadable(testCtx context.Context, client *client.Cli
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
require.NoError(t, err)
require.NoError(t, signer(ctx, req))
msg, err := message.NewHTTP(req)
require.NoError(t, err)
require.NoError(t, msg.Sign(sa.Name+access.ServiceAccountNameSuffix, sa.Key))
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)

View File

@ -16,14 +16,16 @@ import (
"net/url"
"os"
"os/exec"
"runtime"
"testing"
"time"
"github.com/cosi-project/runtime/pkg/resource/rtestutils"
"github.com/cosi-project/runtime/pkg/safe"
"github.com/cosi-project/runtime/pkg/state"
cosistate "github.com/cosi-project/runtime/pkg/state"
"github.com/mattn/go-shellwords"
"github.com/prometheus/client_golang/prometheus"
"github.com/siderolabs/go-api-signature/pkg/pgp"
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -34,14 +36,19 @@ import (
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"github.com/siderolabs/omni/client/pkg/access"
"github.com/siderolabs/omni/client/pkg/client"
clientconsts "github.com/siderolabs/omni/client/pkg/constants"
authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth"
omnires "github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/client/pkg/omni/resources/siderolink"
_ "github.com/siderolabs/omni/cmd/acompat" // this package should always be imported first for init->set env to work
"github.com/siderolabs/omni/cmd/omni/pkg/app"
"github.com/siderolabs/omni/internal/backend/runtime/omni"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/clientconfig"
"github.com/siderolabs/omni/internal/pkg/auth/role"
omnisa "github.com/siderolabs/omni/internal/pkg/auth/serviceaccount"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/constants"
)
@ -197,13 +204,7 @@ func TestIntegration(t *testing.T) {
// Talos API calls try to use user auth if the service account var is not set
os.Setenv(serviceaccount.OmniServiceAccountKeyEnvVar, serviceAccount)
clientConfig := clientconfig.New(omniEndpoint, serviceAccount)
t.Cleanup(func() {
clientConfig.Close() //nolint:errcheck
})
rootClient, err := clientConfig.GetClient(t.Context())
rootClient, err := client.New(omniEndpoint, client.WithServiceAccount(serviceAccount))
require.NoError(t, err)
t.Cleanup(func() {
@ -211,10 +212,10 @@ func TestIntegration(t *testing.T) {
})
testOptions := &TestOptions{
omniClient: rootClient,
Options: options,
machineSemaphore: semaphore.NewWeighted(int64(options.ExpectedMachines)),
clientConfig: clientConfig,
omniClient: rootClient,
Options: options,
machineSemaphore: semaphore.NewWeighted(int64(options.ExpectedMachines)),
serviceAccountKey: serviceAccount,
}
preRunHooks(t, testOptions)
@ -390,7 +391,7 @@ func postRunHooks(t *testing.T, options *TestOptions) {
}
}
func cleanupLinksFunc(ctx context.Context, st state.State) error {
func cleanupLinksFunc(ctx context.Context, st cosistate.State) error {
links, err := safe.ReaderListAll[*siderolink.Link](ctx, st)
if err != nil {
return err
@ -403,7 +404,7 @@ func cleanupLinksFunc(ctx context.Context, st state.State) error {
return links.ForEachErr(func(r *siderolink.Link) error {
err := st.TeardownAndDestroy(ctx, r.Metadata())
if err != nil && !state.IsNotFoundError(err) {
if err != nil && !cosistate.IsNotFoundError(err) {
return err
}
@ -503,7 +504,7 @@ func runOmni(t *testing.T) (string, error) {
rtestutils.AssertResources(ctx, t, state.Default(), []string{talosVersion}, func(*omnires.TalosVersion, *assert.Assertions) {})
sa, err := clientconfig.CreateServiceAccount(omniCtx, "root", state.Default())
sa, err := createBootstrapServiceAccount(omniCtx, "root", state.Default())
if err != nil {
return "", err
}
@ -513,3 +514,38 @@ func runOmni(t *testing.T) (string, error) {
return sa, nil
}
func createBootstrapServiceAccount(ctx context.Context, name string, st cosistate.State) (string, error) {
comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
serviceAccountEmail := name + access.ServiceAccountNameSuffix
key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime)
if err != nil {
return "", err
}
armoredPublicKey, err := key.ArmorPublic()
if err != nil {
return "", err
}
identity, err := safe.ReaderGetByID[*authres.Identity](ctx, st, serviceAccountEmail)
if err != nil && !cosistate.IsNotFoundError(err) {
return "", err
}
if identity != nil {
err = omnisa.Destroy(ctx, st, name)
if err != nil {
return "", err
}
}
_, err = omnisa.Create(ctx, st, name, string(role.Admin), false, []byte(armoredPublicKey))
if err != nil {
return "", err
}
return serviceaccount.Encode(name, key)
}

View File

@ -8,8 +8,6 @@
package integration_test
import (
"context"
"net/http"
"testing"
"time"
@ -19,7 +17,6 @@ import (
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/internal/backend/extensions"
"github.com/siderolabs/omni/internal/integration/workloadproxy"
"github.com/siderolabs/omni/internal/pkg/clientconfig"
)
type assertClusterReadyOptions struct {
@ -128,9 +125,7 @@ Generate various Talos images with Omni and try to download them.`)
t.Run(
"TalosImagesShouldBeDownloadable",
AssertSomeImagesAreDownloadable(t.Context(), options.omniClient, func(ctx context.Context, req *http.Request) error {
return clientconfig.SignHTTPRequest(ctx, options.omniClient, req)
}, options.HTTPEndpoint),
AssertSomeImagesAreDownloadable(t.Context(), options),
)
}
}
@ -1159,9 +1154,14 @@ Test authorization on accessing Omni API, some tests run without a cluster, some
AssertServiceAccountAPIFlow(t.Context(), options.omniClient),
)
clientFactory := newTestClientFactory(omniEndpoint, options.omniClient)
t.Cleanup(func() {
clientFactory.close() //nolint:errcheck
})
t.Run(
"ResourceAuthzShouldWork",
AssertResourceAuthz(t.Context(), options.omniClient, options.clientConfig),
AssertResourceAuthz(t.Context(), options.omniClient, clientFactory),
)
t.Run(
@ -1193,12 +1193,12 @@ Test authorization on accessing Omni API, some tests run without a cluster, some
t.Run(
"APIAuthorizationShouldBeTested",
AssertAPIAuthz(t.Context(), options.omniClient, options.clientConfig, clusterName),
AssertAPIAuthz(t.Context(), options.omniClient, clientFactory, clusterName),
)
t.Run(
"FrontendAPIShouldBeTested",
AssertFrontendResourceAPI(t.Context(), options.omniClient, options.clientConfig, options.HTTPEndpoint, clusterName),
AssertFrontendResourceAPI(t.Context(), options.omniClient, options.serviceAccountKey, options.HTTPEndpoint, clusterName),
)
t.Run(
@ -1278,7 +1278,7 @@ Test workload service proxying feature`)
parentCtx := t.Context()
t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) {
workloadproxy.Test(parentCtx, t, omniClient, cluster1, cluster2)
workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, cluster1, cluster2)
})
t.Run("ClusterShouldBeDestroyed-"+cluster1, AssertDestroyCluster(t.Context(), options.omniClient.Omni().State(), cluster1, false, false))
@ -1457,7 +1457,7 @@ Test Omni upgrades, the first half that runs on the previous Omni version
parentCtx := t.Context()
t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) {
workloadproxy.Test(parentCtx, t, omniClient, clusterName)
workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, clusterName)
})
t.Run("SaveClusterSnapshot", SaveClusterSnapshot(t.Context(), omniClient, clusterName))
@ -1494,7 +1494,7 @@ Test Omni upgrades, the second half that runs on the current Omni version
t.Run("AssertMachinesNotRebootedConfigUnchanged", AssertClusterSnapshot(t.Context(), omniClient, clusterName))
t.Run("WorkloadProxyShouldBeTested", func(t *testing.T) {
workloadproxy.Test(parentCtx, t, omniClient, clusterName)
workloadproxy.Test(parentCtx, t, omniClient, options.serviceAccountKey, clusterName)
})
t.Run(

View File

@ -31,6 +31,8 @@ import (
"github.com/hashicorp/go-cleanhttp"
"github.com/siderolabs/gen/maps"
"github.com/siderolabs/gen/xslices"
"github.com/siderolabs/go-api-signature/pkg/pgp"
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
"github.com/siderolabs/go-pointer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -48,7 +50,6 @@ import (
"github.com/siderolabs/omni/client/pkg/omni/resources/omni"
"github.com/siderolabs/omni/internal/backend/services/workloadproxy"
"github.com/siderolabs/omni/internal/integration/kubernetes"
"github.com/siderolabs/omni/internal/pkg/clientconfig"
)
type serviceContext struct {
@ -75,7 +76,7 @@ var sideroLabsIconSVG []byte
// Test tests the exposed services functionality in Omni.
//
//nolint:prealloc
func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterIDs ...string) {
func Test(ctx context.Context, t *testing.T, omniClient *client.Client, serviceAccountKey string, clusterIDs ...string) {
ctx, cancel := context.WithTimeout(ctx, 20*time.Minute)
t.Cleanup(cancel)
@ -83,6 +84,9 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI
require.Fail(t, "no cluster IDs provided for the test, please provide at least one cluster ID")
}
sa, err := serviceaccount.Decode(serviceAccountKey)
require.NoError(t, err)
ctx = kubernetes.WrapContext(ctx, t)
logger := zaptest.NewLogger(t)
@ -118,7 +122,7 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI
allExposedServices[i], allExposedServices[j] = allExposedServices[j], allExposedServices[i]
})
testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK)
testAccess(ctx, t, logger, sa.Key, allExposedServices, http.StatusOK)
inaccessibleExposedServices := make([]*omni.ExposedService, 0, len(allExposedServices))
@ -132,7 +136,7 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI
}
}
testAccess(ctx, t, logger, omniClient, inaccessibleExposedServices, http.StatusBadGateway)
testAccess(ctx, t, logger, sa.Key, inaccessibleExposedServices, http.StatusBadGateway)
for _, deployment := range deploymentsToScaleDown {
logger.Info("scale deployment back up", zap.String("deployment", deployment.deployment.Name), zap.String("clusterID", deployment.cluster.clusterID))
@ -140,13 +144,13 @@ func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterI
kubernetes.ScaleDeployment(ctx, t, deployment.cluster.kubeClient, deployment.deployment.Namespace, deployment.deployment.Name, 1)
}
testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK)
testToggleFeature(ctx, t, logger, omniClient, clusters[0])
testToggleKubernetesServiceAnnotation(ctx, t, logger, omniClient, allServices[:len(allServices)/2])
testAccess(ctx, t, logger, sa.Key, allExposedServices, http.StatusOK)
testToggleFeature(ctx, t, logger, omniClient, sa.Key, clusters[0])
testToggleKubernetesServiceAnnotation(ctx, t, logger, omniClient, sa.Key, allServices[:len(allServices)/2])
}
// testToggleFeature tests toggling off/on the workload proxy feature for a cluster.
func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, cluster clusterContext) {
func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, saKey *pgp.Key, cluster clusterContext) {
logger.Info("test turning off and on the feature for the cluster", zap.String("clusterID", cluster.clusterID))
setFeatureToggle := func(enabled bool) {
@ -172,14 +176,14 @@ func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, om
services = services[:4]
}
testAccess(ctx, t, logger, omniClient, services[:4], http.StatusNotFound)
testAccess(ctx, t, logger, saKey, services[:4], http.StatusNotFound)
setFeatureToggle(true)
testAccess(ctx, t, logger, omniClient, services[:4], http.StatusOK)
testAccess(ctx, t, logger, saKey, services[:4], http.StatusOK)
}
func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, services []serviceContext) {
func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, saKey *pgp.Key, services []serviceContext) {
logger.Info("test toggling Kubernetes service annotation for exposed services", zap.Int("numServices", len(services)))
for _, service := range services {
@ -195,7 +199,7 @@ func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, lo
exposedServices := xslices.Map(services, func(svc serviceContext) *omni.ExposedService { return svc.res })
testAccess(ctx, t, logger, omniClient, exposedServices, http.StatusNotFound)
testAccess(ctx, t, logger, saKey, exposedServices, http.StatusNotFound)
for _, service := range services {
kubernetes.UpdateService(ctx, t, service.deployment.cluster.kubeClient, service.svc.Namespace, service.svc.Name, func(svc *corev1.Service) {
@ -223,7 +227,7 @@ func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, lo
updatedServices := maps.Values(updatedServicesMap)
testAccess(ctx, t, logger, omniClient, updatedServices, http.StatusOK)
testAccess(ctx, t, logger, saKey, updatedServices, http.StatusOK)
}
func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, clusterID string) clusterContext {
@ -299,11 +303,15 @@ func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omni
return cluster
}
func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, exposedServices []*omni.ExposedService, expectedStatusCode int) {
keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, omniClient)
func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, saKey *pgp.Key, exposedServices []*omni.ExposedService, expectedStatusCode int) {
keyID := saKey.Fingerprint()
signedIDBytes, err := saKey.Sign([]byte(keyID))
require.NoError(t, err)
logger.Debug("registered public key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64))
keyIDSignatureBase64 := base64.StdEncoding.EncodeToString(signedIDBytes)
logger.Debug("using SA key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64))
cookies := []*http.Cookie{
{Name: workloadproxy.PublicKeyIDCookie, Value: keyID},

View File

@ -1,302 +0,0 @@
// Copyright (c) 2026 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
// Package clientconfig holds the configuration for the test client for Omni API.
package clientconfig
import (
"context"
"crypto/md5"
"encoding/base64"
"fmt"
"net/http"
"runtime"
"slices"
"time"
"github.com/cosi-project/runtime/pkg/safe"
"github.com/cosi-project/runtime/pkg/state"
"github.com/hashicorp/go-multierror"
"github.com/siderolabs/gen/containers"
authpb "github.com/siderolabs/go-api-signature/api/auth"
authcli "github.com/siderolabs/go-api-signature/pkg/client/auth"
"github.com/siderolabs/go-api-signature/pkg/message"
"github.com/siderolabs/go-api-signature/pkg/pgp"
"github.com/siderolabs/go-api-signature/pkg/serviceaccount"
"github.com/siderolabs/omni/client/api/omni/management"
"github.com/siderolabs/omni/client/pkg/access"
"github.com/siderolabs/omni/client/pkg/client"
authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
omnisa "github.com/siderolabs/omni/internal/pkg/auth/serviceaccount"
)
const (
DefaultServiceAccount = "integration" + access.ServiceAccountNameSuffix
defaultEmail = "test-user@siderolabs.com"
)
type clientCacheKey struct {
role string
email string
skipUserRole bool
}
type clientCacheValue struct {
client *client.Client
key *pgp.Key
err error
}
// ClientConfig is a test client.
type ClientConfig struct {
endpoint string
serviceAccountKey string
clientCache containers.ConcurrentMap[clientCacheKey, clientCacheValue]
}
// New creates a new test client config.
func New(endpoint, serviceAccountKey string) *ClientConfig {
return &ClientConfig{
endpoint: endpoint,
serviceAccountKey: serviceAccountKey,
}
}
// GetClient returns a test client for the default test email.
//
// Clients are cached by their configuration, so if a client with the
// given configuration was created before, the cached one will be returned.
func (t *ClientConfig) GetClient(ctx context.Context, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*client.Client, error) {
return t.GetClientForEmail(ctx, DefaultServiceAccount, publicKeyOpts...)
}
// GetClientForEmail returns a test client for the given email.
//
// Clients are cached by their configuration, so if a client with the
// given configuration was created before, the cached one will be returned.
func (t *ClientConfig) GetClientForEmail(ctx context.Context, email string, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*client.Client, error) {
cacheKey := t.buildCacheKey(email, publicKeyOpts)
// The client is created by the cache callback, and will be closed by the cache on [ClientConfig.Close].
cliValue, _ := t.clientCache.GetOrCall(cacheKey, func() clientCacheValue {
cli, key, err := createServiceAccountClient(ctx, t.endpoint, t.serviceAccountKey, cacheKey)
return clientCacheValue{
client: cli,
key: key,
err: err,
}
})
return cliValue.client, cliValue.err
}
// GetKey fetches service account key for the default email.
func (t *ClientConfig) GetKey(ctx context.Context, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*pgp.Key, error) {
return t.GetKeyForEmail(ctx, DefaultServiceAccount, publicKeyOpts...)
}
// GetKeyForEmail fetches service account key for the specified email.
func (t *ClientConfig) GetKeyForEmail(ctx context.Context, email string, publicKeyOpts ...authcli.RegisterPGPPublicKeyOption) (*pgp.Key, error) {
cacheKey := t.buildCacheKey(email, publicKeyOpts)
// The client is created by the cache callback, and will be closed by the cache on [ClientConfig.Close].
cliOrErr, _ := t.clientCache.GetOrCall(cacheKey, func() clientCacheValue {
cli, key, err := createServiceAccountClient(ctx, t.endpoint, t.serviceAccountKey, cacheKey)
return clientCacheValue{
client: cli,
key: key,
err: err,
}
})
return cliOrErr.key, cliOrErr.err
}
// Close closes all the clients created by this config.
func (t *ClientConfig) Close() error {
var multiErr error
t.clientCache.ForEach(func(_ clientCacheKey, cliOrErr clientCacheValue) {
if cliOrErr.client != nil {
if err := cliOrErr.client.Close(); err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
})
return multiErr
}
func (t *ClientConfig) buildCacheKey(email string, publicKeyOpts []authcli.RegisterPGPPublicKeyOption) clientCacheKey {
var req authpb.RegisterPublicKeyRequest
for _, o := range publicKeyOpts {
o(&req)
}
return clientCacheKey{
role: req.Role,
email: email,
skipUserRole: req.SkipUserRole,
}
}
// SignHTTPRequest signs the regular HTTP request using the default test email.
func SignHTTPRequest(ctx context.Context, client *client.Client, req *http.Request) error {
return SignHTTPRequestWithEmail(ctx, client, req, defaultEmail)
}
// SignHTTPRequestWithEmail signs the regular HTTP request using the given email.
func SignHTTPRequestWithEmail(ctx context.Context, client *client.Client, req *http.Request, email string) error {
newKey, err := pgp.GenerateKey("", "", email, 4*time.Hour)
if err != nil {
return err
}
err = registerKey(ctx, client.Auth(), newKey, email)
if err != nil {
return err
}
msg, err := message.NewHTTP(req)
if err != nil {
return err
}
return msg.Sign(email, newKey)
}
// RegisterKeyGetIDSignatureBase64 registers a new public key with the default test email and returns its ID and the base-64 encoded signature of the same ID.
func RegisterKeyGetIDSignatureBase64(ctx context.Context, client *client.Client) (id, idSignatureBase66 string, err error) {
newKey, err := pgp.GenerateKey("", "", defaultEmail, 4*time.Hour)
if err != nil {
return "", "", err
}
err = registerKey(ctx, client.Auth(), newKey, defaultEmail)
if err != nil {
return "", "", err
}
id = newKey.Fingerprint()
signedIDBytes, err := newKey.Sign([]byte(id))
if err != nil {
return "", "", err
}
idSignatureBase66 = base64.StdEncoding.EncodeToString(signedIDBytes)
return id, idSignatureBase66, nil
}
// CreateServiceAccount using the direct access to the Omni state.
func CreateServiceAccount(ctx context.Context, name string, st state.State) (string, error) {
// generate a new PGP key with long lifetime
comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
serviceAccountEmail := name + access.ServiceAccountNameSuffix
key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime)
if err != nil {
return "", err
}
armoredPublicKey, err := key.ArmorPublic()
if err != nil {
return "", err
}
identity, err := safe.ReaderGetByID[*authres.Identity](ctx, st, serviceAccountEmail)
if err != nil && !state.IsNotFoundError(err) {
return "", err
}
if identity != nil {
err = omnisa.Destroy(ctx, st, name)
if err != nil {
return "", err
}
}
_, err = omnisa.Create(ctx, st, name, string(role.Admin), false, []byte(armoredPublicKey))
if err != nil {
return "", err
}
return serviceaccount.Encode(name, key)
}
func createServiceAccountClient(ctx context.Context, endpoint, serviceAccountKey string, cacheKey clientCacheKey) (*client.Client, *pgp.Key, error) {
rootClient, err := client.New(endpoint, client.WithServiceAccount(serviceAccountKey))
if err != nil {
return nil, nil, err
}
defer rootClient.Close() //nolint:errcheck
name := fmt.Sprintf("%x", md5.Sum([]byte(cacheKey.email+cacheKey.role)))
// generate a new PGP key with long lifetime
comment := fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
suffix := access.ServiceAccountNameSuffix
if cacheKey.role == string(role.InfraProvider) {
suffix = access.InfraProviderServiceAccountNameSuffix
}
serviceAccountEmail := name + suffix
key, err := pgp.GenerateKey(name, comment, serviceAccountEmail, auth.ServiceAccountMaxAllowedLifetime)
if err != nil {
return nil, nil, err
}
if cacheKey.role == string(role.InfraProvider) {
name = access.InfraProviderServiceAccountPrefix + name
}
armoredPublicKey, err := key.ArmorPublic()
if err != nil {
return nil, nil, err
}
serviceAccounts, err := rootClient.Management().ListServiceAccounts(ctx)
if err != nil {
return nil, nil, err
}
if slices.IndexFunc(serviceAccounts, func(account *management.ListServiceAccountsResponse_ServiceAccount) bool {
return account.Name == name
}) != -1 {
if err = rootClient.Management().DestroyServiceAccount(ctx, name); err != nil {
return nil, nil, err
}
}
// create service account with the generated key
_, err = rootClient.Management().CreateServiceAccount(ctx, name, armoredPublicKey, cacheKey.role, cacheKey.role == "")
if err != nil {
return nil, nil, err
}
encodedKey, err := serviceaccount.Encode(name, key)
if err != nil {
return nil, nil, err
}
cli, err := client.New(endpoint, client.WithServiceAccount(encodedKey))
if err != nil {
return nil, nil, err
}
return cli, key, nil
}

View File

@ -1,43 +0,0 @@
// Copyright (c) 2026 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//go:build sidero.debug
package clientconfig
import (
"context"
"time"
"github.com/siderolabs/go-api-signature/pkg/client/auth"
"github.com/siderolabs/go-api-signature/pkg/pgp"
"google.golang.org/grpc/metadata"
grpcomni "github.com/siderolabs/omni/internal/backend/grpc"
)
func registerKey(ctx context.Context, cli *auth.Client, key *pgp.Key, email string, opts ...auth.RegisterPGPPublicKeyOption) error {
armoredPublicKey, err := key.ArmorPublic()
if err != nil {
return err
}
_, err = cli.RegisterPGPPublicKey(ctx, email, []byte(armoredPublicKey), opts...)
if err != nil {
return err
}
debugCtx := metadata.AppendToOutgoingContext(ctx, grpcomni.DebugVerifiedEmailHeaderKey, email)
err = cli.ConfirmPublicKey(debugCtx, key.Fingerprint())
if err != nil {
return err
}
timeoutCtx, timeoutCtxCancel := context.WithTimeout(ctx, 10*time.Second)
defer timeoutCtxCancel()
return cli.AwaitPublicKeyConfirmation(timeoutCtx, key.Fingerprint())
}

View File

@ -1,20 +0,0 @@
// Copyright (c) 2026 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//go:build !sidero.debug
package clientconfig
import (
"context"
"errors"
"github.com/siderolabs/go-api-signature/pkg/client/auth"
"github.com/siderolabs/go-api-signature/pkg/pgp"
)
func registerKey(context.Context, *auth.Client, *pgp.Key, string, ...auth.RegisterPGPPublicKeyOption) error {
return errors.New("registerKey is not implemented in non-debug builds")
}