// Copyright (c) 2025 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package workloadproxy provides tests for the workload proxy (exposed services) functionality in Omni.
package workloadproxy

import (
    "bytes"
    "compress/gzip"
    "context"
    "crypto/tls"
    _ "embed"
    "encoding/base64"
    "errors"
    "fmt"
    "io"
    "math/rand/v2"
    "net/http"
    "net/url"
    "strconv"
    "strings"
    "sync"
    "testing"
    "time"

    "github.com/cosi-project/runtime/pkg/resource"
    "github.com/cosi-project/runtime/pkg/resource/rtestutils"
    "github.com/cosi-project/runtime/pkg/safe"
    "github.com/hashicorp/go-cleanhttp"
    "github.com/siderolabs/gen/maps"
    "github.com/siderolabs/gen/xslices"
    "github.com/siderolabs/go-pointer"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "go.uber.org/zap"
    "go.uber.org/zap/zaptest"
    appsv1 "k8s.io/api/apps/v1"
    corev1 "k8s.io/api/core/v1"
    apierrors "k8s.io/apimachinery/pkg/api/errors"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/util/intstr"
    clientgokubernetes "k8s.io/client-go/kubernetes"

    "github.com/siderolabs/omni/client/pkg/client"
    "github.com/siderolabs/omni/client/pkg/constants"
    "github.com/siderolabs/omni/client/pkg/omni/resources"
    "github.com/siderolabs/omni/client/pkg/omni/resources/omni"
    "github.com/siderolabs/omni/internal/backend/services/workloadproxy"
    "github.com/siderolabs/omni/internal/integration/kubernetes"
    "github.com/siderolabs/omni/internal/pkg/clientconfig"
)

type serviceContext struct {
    res        *omni.ExposedService
    svc        *corev1.Service
    deployment *deploymentContext
}

type deploymentContext struct {
    deployment *appsv1.Deployment
    cluster    *clusterContext
    services   []serviceContext
}

type clusterContext struct {
    kubeClient  clientgokubernetes.Interface
    clusterID   string
    deployments []deploymentContext
}

//go:embed testdata/sidero-labs-icon.svg
var sideroLabsIconSVG []byte

// Test tests the exposed services functionality in Omni.
func Test(ctx context.Context, t *testing.T, omniClient *client.Client, clusterIDs ...string) {
    ctx, cancel := context.WithTimeout(ctx, 20*time.Minute)
    t.Cleanup(cancel)

    if len(clusterIDs) == 0 {
        require.Fail(t, "no cluster IDs provided for the test, please provide at least one cluster ID")
    }

    ctx = kubernetes.WrapContext(ctx, t)

    logger := zaptest.NewLogger(t)

    clusters := make([]clusterContext, 0, len(clusterIDs))

    for _, clusterID := range clusterIDs {
        cluster := prepareServices(ctx, t, logger, omniClient, clusterID)

        clusters = append(clusters, cluster)
    }

    var (
        allServices            []serviceContext
        allExposedServices     []*omni.ExposedService
        deploymentsToScaleDown []deploymentContext
    )

    for _, cluster := range clusters {
        for i, deployment := range cluster.deployments {
            allServices = append(allServices, deployment.services...)
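            // additionally collect the exposed service resources, to be shuffled and probed for access below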
            for _, service := range deployment.services {
                allExposedServices = append(allExposedServices, service.res)
            }

            if i%2 == 0 { // scale down every second deployment
                deploymentsToScaleDown = append(deploymentsToScaleDown, deployment)
            }
        }
    }

    rand.Shuffle(len(allExposedServices), func(i, j int) {
        allExposedServices[i], allExposedServices[j] = allExposedServices[j], allExposedServices[i]
    })

    testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK)

    inaccessibleExposedServices := make([]*omni.ExposedService, 0, len(allExposedServices))

    for _, deployment := range deploymentsToScaleDown {
        logger.Info("scale deployment down to zero replicas",
            zap.String("deployment", deployment.deployment.Name),
            zap.String("clusterID", deployment.cluster.clusterID))

        kubernetes.ScaleDeployment(ctx, t, deployment.cluster.kubeClient, deployment.deployment.Namespace, deployment.deployment.Name, 0)

        for _, service := range deployment.services {
            inaccessibleExposedServices = append(inaccessibleExposedServices, service.res)
        }
    }

    testAccess(ctx, t, logger, omniClient, inaccessibleExposedServices, http.StatusBadGateway)

    for _, deployment := range deploymentsToScaleDown {
        logger.Info("scale deployment back up",
            zap.String("deployment", deployment.deployment.Name),
            zap.String("clusterID", deployment.cluster.clusterID))

        kubernetes.ScaleDeployment(ctx, t, deployment.cluster.kubeClient, deployment.deployment.Namespace, deployment.deployment.Name, 1)
    }

    testAccess(ctx, t, logger, omniClient, allExposedServices, http.StatusOK)

    testToggleFeature(ctx, t, logger, omniClient, clusters[0])

    testToggleKubernetesServiceAnnotation(ctx, t, logger, omniClient, allServices[:len(allServices)/2])
}

// testToggleFeature tests toggling the workload proxy feature off and back on for a cluster.
func testToggleFeature(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, cluster clusterContext) {
    logger.Info("test turning off and on the feature for the cluster", zap.String("clusterID", cluster.clusterID))

    setFeatureToggle := func(enabled bool) {
        _, err := safe.StateUpdateWithConflicts(ctx, omniClient.Omni().State(), omni.NewCluster(resources.DefaultNamespace, cluster.clusterID).Metadata(),
            func(res *omni.Cluster) error {
                res.TypedSpec().Value.Features.EnableWorkloadProxy = enabled

                return nil
            })
        require.NoErrorf(t, err, "failed to toggle workload proxy feature for cluster %q", cluster.clusterID)
    }

    setFeatureToggle(false)

    var services []*omni.ExposedService

    for _, deployment := range cluster.deployments {
        for _, service := range deployment.services {
            services = append(services, service.res)
        }
    }

    // check at most 4 of the services
    if len(services) > 4 {
        services = services[:4]
    }

    testAccess(ctx, t, logger, omniClient, services, http.StatusNotFound)

    setFeatureToggle(true)

    testAccess(ctx, t, logger, omniClient, services, http.StatusOK)
}

func testToggleKubernetesServiceAnnotation(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, services []serviceContext) {
    logger.Info("test toggling Kubernetes service annotation for exposed services", zap.Int("numServices", len(services)))

    for _, service := range services {
        kubernetes.UpdateService(ctx, t, service.deployment.cluster.kubeClient, service.svc.Namespace, service.svc.Name, func(svc *corev1.Service) {
            delete(svc.Annotations, constants.ExposedServicePortAnnotationKey)
        })
    }

    omniState := omniClient.Omni().State()

    for _, service := range services {
        rtestutils.AssertNoResource[*omni.ExposedService](ctx, t, omniState, service.res.Metadata().ID())
    }
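    // with the port annotation removed, the services should no longer be served by the workload proxy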
    exposedServices := xslices.Map(services, func(svc serviceContext) *omni.ExposedService { return svc.res })

    testAccess(ctx, t, logger, omniClient, exposedServices, http.StatusNotFound)

    for _, service := range services {
        kubernetes.UpdateService(ctx, t, service.deployment.cluster.kubeClient, service.svc.Namespace, service.svc.Name, func(svc *corev1.Service) {
            svc.Annotations[constants.ExposedServicePortAnnotationKey] = service.svc.Annotations[constants.ExposedServicePortAnnotationKey]
        })
    }

    updatedServicesMap := make(map[resource.ID]*omni.ExposedService)

    for _, service := range services {
        rtestutils.AssertResource[*omni.ExposedService](ctx, t, omniState, service.res.Metadata().ID(), func(r *omni.ExposedService, assertion *assert.Assertions) {
            assertion.Equal(service.res.TypedSpec().Value.Port, r.TypedSpec().Value.Port,
                "exposed service port should be restored after toggling the annotation back on")
            assertion.Equal(service.res.TypedSpec().Value.Label, r.TypedSpec().Value.Label,
                "exposed service label should be restored after toggling the annotation back on")
            assertion.Equal(service.res.TypedSpec().Value.IconBase64, r.TypedSpec().Value.IconBase64,
                "exposed service icon should be restored after toggling the annotation back on")
            assertion.Equal(service.res.TypedSpec().Value.HasExplicitAlias, r.TypedSpec().Value.HasExplicitAlias,
                "exposed service has explicit alias")
            assertion.Empty(r.TypedSpec().Value.Error,
                "exposed service should not have an error after toggling the annotation back on")

            if r.TypedSpec().Value.HasExplicitAlias {
                assertion.Equal(service.res.TypedSpec().Value.Url, r.TypedSpec().Value.Url,
                    "exposed service URL should be restored after toggling the annotation back on")
            }

            updatedServicesMap[r.Metadata().ID()] = r
        })
    }

    updatedServices := maps.Values(updatedServicesMap)

    testAccess(ctx, t, logger, omniClient, updatedServices, http.StatusOK)
}

func prepareServices(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, clusterID string) clusterContext {
    ctx, cancel := context.WithTimeout(ctx, 150*time.Second)
    t.Cleanup(cancel)

    omniState := omniClient.Omni().State()
    kubeClient := kubernetes.GetClient(ctx, t, omniClient.Management(), clusterID)

    iconBase64 := base64.StdEncoding.EncodeToString(doGzip(t, sideroLabsIconSVG)) // base64(gzip(Sidero Labs icon SVG))
    expectedIconBase64 := base64.StdEncoding.EncodeToString(sideroLabsIconSVG)

    numWorkloads := 5
    numReplicasPerWorkload := 4
    numServicePerWorkload := 10
    startPort := 12345

    cluster := clusterContext{
        clusterID:   clusterID,
        deployments: make([]deploymentContext, 0, numWorkloads),
        kubeClient:  kubeClient,
    }

    for i := range numWorkloads {
        identifier := fmt.Sprintf("%s-w%02d", clusterID, i)
        deployment := deploymentContext{
            cluster: &cluster,
        }

        firstPort := startPort + i*numServicePerWorkload

        var services []*corev1.Service

        deployment.deployment, services = createKubernetesResources(ctx, t, logger, kubeClient, firstPort, numReplicasPerWorkload, numServicePerWorkload, identifier, iconBase64)

        for _, service := range services {
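            // the exposed service resource ID is expected to be "<cluster-id>/<service-name>.<namespace>"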
            expectedID := clusterID + "/" + service.Name + "." + service.Namespace

            expectedPort, err := strconv.Atoi(service.Annotations[constants.ExposedServicePortAnnotationKey])
            require.NoError(t, err)

            expectedLabel := service.Annotations[constants.ExposedServiceLabelAnnotationKey]
            explicitAlias, hasExplicitAlias := service.Annotations[constants.ExposedServicePrefixAnnotationKey]

            var res *omni.ExposedService

            rtestutils.AssertResource[*omni.ExposedService](ctx, t, omniState, expectedID, func(r *omni.ExposedService, assertion *assert.Assertions) {
                assertion.Equal(expectedPort, int(r.TypedSpec().Value.Port))
                assertion.Equal(expectedLabel, r.TypedSpec().Value.Label)
                assertion.Equal(expectedIconBase64, r.TypedSpec().Value.IconBase64)
                assertion.NotEmpty(r.TypedSpec().Value.Url)
                assertion.Empty(r.TypedSpec().Value.Error)
                assertion.Equal(hasExplicitAlias, r.TypedSpec().Value.HasExplicitAlias)

                if hasExplicitAlias {
                    assertion.Contains(r.TypedSpec().Value.Url, explicitAlias+"-")
                }

                res = r
            })

            deployment.services = append(deployment.services, serviceContext{
                res:        res,
                svc:        service,
                deployment: &deployment,
            })
        }

        cluster.deployments = append(cluster.deployments, deployment)
    }

    return cluster
}

func testAccess(ctx context.Context, t *testing.T, logger *zap.Logger, omniClient *client.Client, exposedServices []*omni.ExposedService, expectedStatusCode int) {
    keyID, keyIDSignatureBase64, err := clientconfig.RegisterKeyGetIDSignatureBase64(ctx, omniClient)
    require.NoError(t, err)

    logger.Debug("registered public key for workload proxy", zap.String("keyID", keyID), zap.String("keyIDSignatureBase64", keyIDSignatureBase64))

    cookies := []*http.Cookie{
        {Name: workloadproxy.PublicKeyIDCookie, Value: keyID},
        {Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: keyIDSignatureBase64},
    }

    clientTransport := cleanhttp.DefaultTransport()
    clientTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}

    httpClient := &http.Client{
        CheckRedirect: func(*http.Request, []*http.Request) error {
            return http.ErrUseLastResponse // do not follow redirects
        },
        Timeout:   30 * time.Second,
        Transport: clientTransport,
    }

    t.Cleanup(httpClient.CloseIdleConnections)

    parallelTestBatchSize := 8

    var wg sync.WaitGroup

    errs := make([]error, parallelTestBatchSize)

    for i, exposedService := range exposedServices {
        logger.Info("test exposed service",
            zap.String("id", exposedService.Metadata().ID()),
            zap.String("url", exposedService.TypedSpec().Value.Url),
            zap.Int("expectedStatusCode", expectedStatusCode),
        )

        wg.Add(1)

        go func() {
            defer wg.Done()

            if testErr := testAccessParallel(ctx, httpClient, exposedService, expectedStatusCode, cookies...); testErr != nil {
                errs[i%parallelTestBatchSize] = fmt.Errorf("failed to access exposed service %q over %q [%d]: %w",
                    exposedService.Metadata().ID(), exposedService.TypedSpec().Value.Url, i, testErr)
            }
        }()

        if i == len(exposedServices)-1 || ((i+1)%parallelTestBatchSize == 0) {
            logger.Info("wait for the batch of exposed service tests to finish", zap.Int("batchSize", (i%parallelTestBatchSize)+1))

            wg.Wait()

            for _, testErr := range errs {
                assert.NoError(t, testErr)
            }

            // reset the errors for the next batch
            for j := range errs {
                errs[j] = nil
            }
        }
    }
}
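
// testAccessParallel sends a batch of parallel requests to a single exposed service and returns the combined set of distinct errors.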
func testAccessParallel(ctx context.Context, httpClient *http.Client, exposedService *omni.ExposedService, expectedStatusCode int, cookies ...*http.Cookie) error {
    svcURL := exposedService.TypedSpec().Value.Url

    // test redirect to the authentication page when cookies are not set
    if err := testAccessSingle(ctx, httpClient, svcURL, http.StatusSeeOther, ""); err != nil {
        return fmt.Errorf("failed to access exposed service %q over %q: %w", exposedService.Metadata().ID(), svcURL, err)
    }

    expectedBodyContent := ""
    reqPerExposedService := 128
    numRetries := 9
    label := exposedService.TypedSpec().Value.Label

    if expectedStatusCode == http.StatusOK {
        expectedBodyContent = label[:strings.LastIndex(label, "-")] // for the label "integration-workload-proxy-2-w03-s01", the content must be "integration-workload-proxy-2-w03"
    }

    var wg sync.WaitGroup

    wg.Add(reqPerExposedService)

    lock := sync.Mutex{}
    svcErrs := make(map[string]error, reqPerExposedService)

    for range reqPerExposedService {
        go func() {
            defer wg.Done()

            if err := testAccessSingleWithRetries(ctx, httpClient, svcURL, expectedStatusCode, expectedBodyContent, numRetries, cookies...); err != nil {
                lock.Lock()
                svcErrs[err.Error()] = err
                lock.Unlock()
            }
        }()
    }

    wg.Wait()

    return errors.Join(maps.Values(svcErrs)...)
}

func testAccessSingleWithRetries(ctx context.Context, httpClient *http.Client, svcURL string, expectedStatusCode int, expectedBodyContent string, retries int, cookies ...*http.Cookie) error {
    if retries <= 0 {
        return fmt.Errorf("retries must be greater than 0, got %d", retries)
    }

    var err error

    for range retries {
        if err = testAccessSingle(ctx, httpClient, svcURL, expectedStatusCode, expectedBodyContent, cookies...); err == nil {
            return nil
        }

        select {
        case <-ctx.Done():
            return fmt.Errorf("context canceled while trying to access %q: %w", svcURL, ctx.Err())
        case <-time.After(500 * time.Millisecond):
        }
    }

    return fmt.Errorf("failed to access %q after %d retries: %w", svcURL, retries, err)
}

func testAccessSingle(ctx context.Context, httpClient *http.Client, svcURL string, expectedStatusCode int, expectedBodyContent string, cookies ...*http.Cookie) error {
    req, err := prepareRequest(ctx, svcURL)
    if err != nil {
        return fmt.Errorf("failed to prepare request: %w", err)
    }

    for _, cookie := range cookies {
        req.AddCookie(cookie)
    }

    resp, err := httpClient.Do(req)
    if err != nil {
        return fmt.Errorf("failed to do request: %w", err)
    }

    defer resp.Body.Close() //nolint:errcheck

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return fmt.Errorf("failed to read response body: %w", err)
    }

    bodyStr := string(body)

    if resp.StatusCode != expectedStatusCode {
        return fmt.Errorf("unexpected status code: expected %d, got %d, body: %q", expectedStatusCode, resp.StatusCode, bodyStr)
    }

    if expectedBodyContent == "" {
        return nil
    }

    if !strings.Contains(bodyStr, expectedBodyContent) {
        return fmt.Errorf("content %q not found in body: %q", expectedBodyContent, bodyStr)
    }

    return nil
}

// prepareRequest prepares a request to the workload proxy.
// It uses the Omni base URL to access Omni with proper name resolution, but sets the Host header to the original host to exercise the workload proxy logic.
//
// sample svcURL: https://j2s7hf-local.proxy-us.localhost:8099/
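// For the sample URL above, the request is sent to https://localhost:8099/ with the Host header set to "j2s7hf-local.proxy-us.localhost:8099".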
func prepareRequest(ctx context.Context, svcURL string) (*http.Request, error) {
    parsedURL, err := url.Parse(svcURL)
    if err != nil {
        return nil, fmt.Errorf("failed to parse URL %q: %w", svcURL, err)
    }

    hostParts := strings.SplitN(parsedURL.Host, ".", 3)
    if len(hostParts) < 3 {
        return nil, fmt.Errorf("failed to parse host %q: expected at least 3 parts", parsedURL.Host)
    }

    svcHost := parsedURL.Host
    baseHost := hostParts[2]

    parsedURL.Host = baseHost

    baseURL := parsedURL.String()

    req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
    if err != nil {
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    req.Host = svcHost

    return req, nil
}

func createKubernetesResources(ctx context.Context, t *testing.T, logger *zap.Logger, kubeClient clientgokubernetes.Interface,
    firstPort, numReplicas, numServices int, identifier, icon string,
) (*appsv1.Deployment, []*corev1.Service) {
    namespace := "default"

    _, err := kubeClient.CoreV1().ConfigMaps(namespace).Create(ctx, &corev1.ConfigMap{
        ObjectMeta: metav1.ObjectMeta{
            Name:      identifier,
            Namespace: namespace,
        },
        Data: map[string]string{
            "index.html": fmt.Sprintf("