From 00fde1e510f1615a0cc3ec4736d4155246535a08 Mon Sep 17 00:00:00 2001 From: Ivan Ka <5395690+ivankatliarchuk@users.noreply.github.com> Date: Tue, 3 Jun 2025 08:02:38 +0100 Subject: [PATCH] chore(source/crd): add labels without looping over (#5492) --- endpoint/endpoint.go | 15 ++++- internal/testutils/endpoint_test.go | 20 ++++++ source/crd.go | 13 +--- source/crd_test.go | 100 +++++++++++++++++++++++++++- source/shared_test.go | 4 +- 5 files changed, 137 insertions(+), 15 deletions(-) diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 86034fed6..e6e1cc41a 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -234,7 +234,7 @@ func NewEndpointWithTTL(dnsName, recordType string, ttl TTL, targets ...string) cleanTargets[idx] = strings.TrimSuffix(target, ".") } - for _, label := range strings.Split(dnsName, ".") { + for label := range strings.SplitSeq(dnsName, ".") { if len(label) > 63 { log.Errorf("label %s in %s is longer than 63 characters. Cannot create endpoint", label, dnsName) return nil @@ -301,6 +301,19 @@ func (e *Endpoint) DeleteProviderSpecificProperty(key string) { } } +// WithLabel adds or updates a label for the Endpoint. +// +// Example usage: +// +// ep.WithLabel("owner", "user123") +func (e *Endpoint) WithLabel(key, value string) *Endpoint { + if e.Labels == nil { + e.Labels = NewLabels() + } + e.Labels[key] = value + return e +} + // Key returns the EndpointKey of the Endpoint. func (e *Endpoint) Key() EndpointKey { return EndpointKey{ diff --git a/internal/testutils/endpoint_test.go b/internal/testutils/endpoint_test.go index 8ddb1f75d..6f109def0 100644 --- a/internal/testutils/endpoint_test.go +++ b/internal/testutils/endpoint_test.go @@ -469,3 +469,23 @@ func TestNewTargetsFromAddr(t *testing.T) { }) } } + +func TestWithLabel(t *testing.T) { + e := &endpoint.Endpoint{} + // should initialize Labels and set the key + returned := e.WithLabel("foo", "bar") + assert.Equal(t, e, returned) + assert.NotNil(t, e.Labels) + assert.Equal(t, "bar", e.Labels["foo"]) + + // overriding an existing key + e2 := e.WithLabel("foo", "baz") + assert.Equal(t, e, e2) + assert.Equal(t, "baz", e.Labels["foo"]) + + // adding a new key without wiping others + e.Labels["existing"] = "orig" + e.WithLabel("new", "val") + assert.Equal(t, "orig", e.Labels["existing"]) + assert.Equal(t, "val", e.Labels["new"]) +} diff --git a/source/crd.go b/source/crd.go index 187d19525..9c46b6c74 100644 --- a/source/crd.go +++ b/source/crd.go @@ -183,7 +183,7 @@ func (cs *crdSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error for _, dnsEndpoint := range result.Items { // Make sure that all endpoints have targets for A or CNAME type - crdEndpoints := []*endpoint.Endpoint{} + var crdEndpoints []*endpoint.Endpoint for _, ep := range dnsEndpoint.Spec.Endpoints { if (ep.RecordType == endpoint.RecordTypeCNAME || ep.RecordType == endpoint.RecordTypeA || ep.RecordType == endpoint.RecordTypeAAAA) && len(ep.Targets) < 1 { log.Warnf("Endpoint %s with DNSName %s has an empty list of targets", dnsEndpoint.Name, ep.DNSName) @@ -206,14 +206,11 @@ func (cs *crdSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error continue } - if ep.Labels == nil { - ep.Labels = endpoint.NewLabels() - } + ep.WithLabel(endpoint.ResourceLabelKey, fmt.Sprintf("crd/%s/%s", dnsEndpoint.Namespace, dnsEndpoint.Name)) crdEndpoints = append(crdEndpoints, ep) } - cs.setResourceLabel(&dnsEndpoint, crdEndpoints) endpoints = append(endpoints, crdEndpoints...) if dnsEndpoint.Status.ObservedGeneration == dnsEndpoint.Generation { @@ -231,12 +228,6 @@ func (cs *crdSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error return endpoints, nil } -func (cs *crdSource) setResourceLabel(crd *apiv1alpha1.DNSEndpoint, endpoints []*endpoint.Endpoint) { - for _, ep := range endpoints { - ep.Labels[endpoint.ResourceLabelKey] = fmt.Sprintf("crd/%s/%s", crd.Namespace, crd.Name) - } -} - func (cs *crdSource) watch(ctx context.Context, opts *metav1.ListOptions) (watch.Interface, error) { opts.Watch = true return cs.crdClient.Get(). diff --git a/source/crd_test.go b/source/crd_test.go index 5d33851fa..f08fc01ac 100644 --- a/source/crd_test.go +++ b/source/crd_test.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "io" + "math/rand" "net/http" "strings" "sync/atomic" @@ -64,7 +65,7 @@ func objBody(codec runtime.Encoder, obj runtime.Object) io.ReadCloser { func fakeRESTClient(endpoints []*endpoint.Endpoint, apiVersion, kind, namespace, name string, annotations map[string]string, labels map[string]string, _ *testing.T) rest.Interface { groupVersion, _ := schema.ParseGroupVersion(apiVersion) scheme := runtime.NewScheme() - addKnownTypes(scheme, groupVersion) + _ = addKnownTypes(scheme, groupVersion) dnsEndpointList := apiv1alpha1.DNSEndpointList{} dnsEndpoint := &apiv1alpha1.DNSEndpoint{ @@ -513,6 +514,12 @@ func testCRDSourceEndpoints(t *testing.T) { // Validate received endpoints against expected endpoints. validateEndpoints(t, receivedEndpoints, ti.endpoints) + + for _, e := range receivedEndpoints { + // TODO: at the moment not all sources apply ResourceLabelKey + require.GreaterOrEqual(t, len(e.Labels), 1, "endpoint must have at least one label") + require.Contains(t, e.Labels, endpoint.ResourceLabelKey, "endpoint must include the ResourceLabelKey label") + } }) } } @@ -659,6 +666,61 @@ func validateCRDResource(t *testing.T, src Source, expectError bool) { } } +func TestDNSEndpointsWithSetResourceLabels(t *testing.T) { + + typeCounts := map[string]int{ + endpoint.RecordTypeA: 3, + endpoint.RecordTypeCNAME: 2, + endpoint.RecordTypeNS: 7, + endpoint.RecordTypeNAPTR: 1, + } + + crds := generateTestFixtureDNSEndpointsByType("test-ns", typeCounts) + + for _, crd := range crds.Items { + for _, ep := range crd.Spec.Endpoints { + require.Empty(t, ep.Labels, "endpoint not have labels set") + require.NotContains(t, ep.Labels, endpoint.ResourceLabelKey, "endpoint must not include the ResourceLabelKey label") + } + } + + scheme := runtime.NewScheme() + err := apiv1alpha1.AddToScheme(scheme) + require.NoError(t, err) + + codecFactory := serializer.WithoutConversionCodecFactory{ + CodecFactory: serializer.NewCodecFactory(scheme), + } + + client := &fake.RESTClient{ + GroupVersion: apiv1alpha1.GroupVersion, + VersionedAPIPath: fmt.Sprintf("/apis/%s", apiv1alpha1.GroupVersion.String()), + NegotiatedSerializer: codecFactory, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: objBody(codecFactory.LegacyCodec(apiv1alpha1.GroupVersion), &crds), + }, nil + }), + } + + cs := &crdSource{ + crdClient: client, + namespace: "test-ns", + crdResource: "dnsendpoints", + codec: runtime.NewParameterCodec(scheme), + labelSelector: labels.Everything(), + } + + res, err := cs.Endpoints(t.Context()) + require.NoError(t, err) + + for _, ep := range res { + require.Contains(t, ep.Labels, endpoint.ResourceLabelKey) + } +} + func helperCreateWatcherWithInformer(t *testing.T) (*cachetesting.FakeControllerSource, crdSource) { t.Helper() ctx := t.Context() @@ -679,3 +741,39 @@ func helperCreateWatcherWithInformer(t *testing.T) (*cachetesting.FakeController return watcher, *cs } + +// generateTestFixtureDNSEndpointsByType generates DNSEndpoint CRDs according to the provided counts per RecordType. +func generateTestFixtureDNSEndpointsByType(namespace string, typeCounts map[string]int) apiv1alpha1.DNSEndpointList { + var result []apiv1alpha1.DNSEndpoint + idx := 0 + for rt, count := range typeCounts { + for i := 0; i < count; i++ { + result = append(result, apiv1alpha1.DNSEndpoint{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("dnsendpoint-%s-%d", rt, idx), + Namespace: namespace, + }, + Spec: apiv1alpha1.DNSEndpointSpec{ + Endpoints: []*endpoint.Endpoint{ + { + DNSName: strings.ToLower(fmt.Sprintf("%s-%d.example.com", rt, idx)), + RecordType: rt, + Targets: endpoint.Targets{fmt.Sprintf("192.0.2.%d", idx)}, + RecordTTL: 300, + }, + }, + }, + }) + idx++ + } + } + // Shuffle the result to ensure randomness in the order. + rand.New(rand.NewSource(time.Now().UnixNano())) + rand.Shuffle(len(result), func(i, j int) { + result[i], result[j] = result[j], result[i] + }) + + return apiv1alpha1.DNSEndpointList{ + Items: result, + } +} diff --git a/source/shared_test.go b/source/shared_test.go index 11828dbe2..62280ec77 100644 --- a/source/shared_test.go +++ b/source/shared_test.go @@ -81,12 +81,12 @@ func validateEndpoint(t *testing.T, endpoint, expected *endpoint.Endpoint) { t.Errorf("RecordTTL expected %v, got %v", expected.RecordTTL, endpoint.RecordTTL) } - // if non-empty record type is expected, check that it matches. + // if a non-empty record type is expected, check that it matches. if endpoint.RecordType != expected.RecordType { t.Errorf("RecordType expected %q, got %q", expected.RecordType, endpoint.RecordType) } - // if non-empty labels are expected, check that they matches. + // if non-empty labels are expected, check that they match. if expected.Labels != nil && !reflect.DeepEqual(endpoint.Labels, expected.Labels) { t.Errorf("Labels expected %s, got %s", expected.Labels, endpoint.Labels) }