chore(source/crd): add labels without looping over (#5492)

This commit is contained in:
Ivan Ka 2025-06-03 08:02:38 +01:00 committed by GitHub
parent 2819c2f05c
commit 00fde1e510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 137 additions and 15 deletions

View File

@ -234,7 +234,7 @@ func NewEndpointWithTTL(dnsName, recordType string, ttl TTL, targets ...string)
cleanTargets[idx] = strings.TrimSuffix(target, ".") cleanTargets[idx] = strings.TrimSuffix(target, ".")
} }
for _, label := range strings.Split(dnsName, ".") { for label := range strings.SplitSeq(dnsName, ".") {
if len(label) > 63 { if len(label) > 63 {
log.Errorf("label %s in %s is longer than 63 characters. Cannot create endpoint", label, dnsName) log.Errorf("label %s in %s is longer than 63 characters. Cannot create endpoint", label, dnsName)
return nil return nil
@ -301,6 +301,19 @@ func (e *Endpoint) DeleteProviderSpecificProperty(key string) {
} }
} }
// WithLabel adds or updates a label for the Endpoint.
//
// Example usage:
//
// ep.WithLabel("owner", "user123")
func (e *Endpoint) WithLabel(key, value string) *Endpoint {
if e.Labels == nil {
e.Labels = NewLabels()
}
e.Labels[key] = value
return e
}
// Key returns the EndpointKey of the Endpoint. // Key returns the EndpointKey of the Endpoint.
func (e *Endpoint) Key() EndpointKey { func (e *Endpoint) Key() EndpointKey {
return EndpointKey{ return EndpointKey{

View File

@ -469,3 +469,23 @@ func TestNewTargetsFromAddr(t *testing.T) {
}) })
} }
} }
func TestWithLabel(t *testing.T) {
e := &endpoint.Endpoint{}
// should initialize Labels and set the key
returned := e.WithLabel("foo", "bar")
assert.Equal(t, e, returned)
assert.NotNil(t, e.Labels)
assert.Equal(t, "bar", e.Labels["foo"])
// overriding an existing key
e2 := e.WithLabel("foo", "baz")
assert.Equal(t, e, e2)
assert.Equal(t, "baz", e.Labels["foo"])
// adding a new key without wiping others
e.Labels["existing"] = "orig"
e.WithLabel("new", "val")
assert.Equal(t, "orig", e.Labels["existing"])
assert.Equal(t, "val", e.Labels["new"])
}

View File

@ -183,7 +183,7 @@ func (cs *crdSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error
for _, dnsEndpoint := range result.Items { for _, dnsEndpoint := range result.Items {
// Make sure that all endpoints have targets for A or CNAME type // Make sure that all endpoints have targets for A or CNAME type
crdEndpoints := []*endpoint.Endpoint{} var crdEndpoints []*endpoint.Endpoint
for _, ep := range dnsEndpoint.Spec.Endpoints { for _, ep := range dnsEndpoint.Spec.Endpoints {
if (ep.RecordType == endpoint.RecordTypeCNAME || ep.RecordType == endpoint.RecordTypeA || ep.RecordType == endpoint.RecordTypeAAAA) && len(ep.Targets) < 1 { if (ep.RecordType == endpoint.RecordTypeCNAME || ep.RecordType == endpoint.RecordTypeA || ep.RecordType == endpoint.RecordTypeAAAA) && len(ep.Targets) < 1 {
log.Warnf("Endpoint %s with DNSName %s has an empty list of targets", dnsEndpoint.Name, ep.DNSName) log.Warnf("Endpoint %s with DNSName %s has an empty list of targets", dnsEndpoint.Name, ep.DNSName)
@ -206,14 +206,11 @@ func (cs *crdSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error
continue continue
} }
if ep.Labels == nil { ep.WithLabel(endpoint.ResourceLabelKey, fmt.Sprintf("crd/%s/%s", dnsEndpoint.Namespace, dnsEndpoint.Name))
ep.Labels = endpoint.NewLabels()
}
crdEndpoints = append(crdEndpoints, ep) crdEndpoints = append(crdEndpoints, ep)
} }
cs.setResourceLabel(&dnsEndpoint, crdEndpoints)
endpoints = append(endpoints, crdEndpoints...) endpoints = append(endpoints, crdEndpoints...)
if dnsEndpoint.Status.ObservedGeneration == dnsEndpoint.Generation { if dnsEndpoint.Status.ObservedGeneration == dnsEndpoint.Generation {
@ -231,12 +228,6 @@ func (cs *crdSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error
return endpoints, nil return endpoints, nil
} }
func (cs *crdSource) setResourceLabel(crd *apiv1alpha1.DNSEndpoint, endpoints []*endpoint.Endpoint) {
for _, ep := range endpoints {
ep.Labels[endpoint.ResourceLabelKey] = fmt.Sprintf("crd/%s/%s", crd.Namespace, crd.Name)
}
}
func (cs *crdSource) watch(ctx context.Context, opts *metav1.ListOptions) (watch.Interface, error) { func (cs *crdSource) watch(ctx context.Context, opts *metav1.ListOptions) (watch.Interface, error) {
opts.Watch = true opts.Watch = true
return cs.crdClient.Get(). return cs.crdClient.Get().

View File

@ -22,6 +22,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -64,7 +65,7 @@ func objBody(codec runtime.Encoder, obj runtime.Object) io.ReadCloser {
func fakeRESTClient(endpoints []*endpoint.Endpoint, apiVersion, kind, namespace, name string, annotations map[string]string, labels map[string]string, _ *testing.T) rest.Interface { func fakeRESTClient(endpoints []*endpoint.Endpoint, apiVersion, kind, namespace, name string, annotations map[string]string, labels map[string]string, _ *testing.T) rest.Interface {
groupVersion, _ := schema.ParseGroupVersion(apiVersion) groupVersion, _ := schema.ParseGroupVersion(apiVersion)
scheme := runtime.NewScheme() scheme := runtime.NewScheme()
addKnownTypes(scheme, groupVersion) _ = addKnownTypes(scheme, groupVersion)
dnsEndpointList := apiv1alpha1.DNSEndpointList{} dnsEndpointList := apiv1alpha1.DNSEndpointList{}
dnsEndpoint := &apiv1alpha1.DNSEndpoint{ dnsEndpoint := &apiv1alpha1.DNSEndpoint{
@ -513,6 +514,12 @@ func testCRDSourceEndpoints(t *testing.T) {
// Validate received endpoints against expected endpoints. // Validate received endpoints against expected endpoints.
validateEndpoints(t, receivedEndpoints, ti.endpoints) validateEndpoints(t, receivedEndpoints, ti.endpoints)
for _, e := range receivedEndpoints {
// TODO: at the moment not all sources apply ResourceLabelKey
require.GreaterOrEqual(t, len(e.Labels), 1, "endpoint must have at least one label")
require.Contains(t, e.Labels, endpoint.ResourceLabelKey, "endpoint must include the ResourceLabelKey label")
}
}) })
} }
} }
@ -659,6 +666,61 @@ func validateCRDResource(t *testing.T, src Source, expectError bool) {
} }
} }
func TestDNSEndpointsWithSetResourceLabels(t *testing.T) {
typeCounts := map[string]int{
endpoint.RecordTypeA: 3,
endpoint.RecordTypeCNAME: 2,
endpoint.RecordTypeNS: 7,
endpoint.RecordTypeNAPTR: 1,
}
crds := generateTestFixtureDNSEndpointsByType("test-ns", typeCounts)
for _, crd := range crds.Items {
for _, ep := range crd.Spec.Endpoints {
require.Empty(t, ep.Labels, "endpoint not have labels set")
require.NotContains(t, ep.Labels, endpoint.ResourceLabelKey, "endpoint must not include the ResourceLabelKey label")
}
}
scheme := runtime.NewScheme()
err := apiv1alpha1.AddToScheme(scheme)
require.NoError(t, err)
codecFactory := serializer.WithoutConversionCodecFactory{
CodecFactory: serializer.NewCodecFactory(scheme),
}
client := &fake.RESTClient{
GroupVersion: apiv1alpha1.GroupVersion,
VersionedAPIPath: fmt.Sprintf("/apis/%s", apiv1alpha1.GroupVersion.String()),
NegotiatedSerializer: codecFactory,
Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: objBody(codecFactory.LegacyCodec(apiv1alpha1.GroupVersion), &crds),
}, nil
}),
}
cs := &crdSource{
crdClient: client,
namespace: "test-ns",
crdResource: "dnsendpoints",
codec: runtime.NewParameterCodec(scheme),
labelSelector: labels.Everything(),
}
res, err := cs.Endpoints(t.Context())
require.NoError(t, err)
for _, ep := range res {
require.Contains(t, ep.Labels, endpoint.ResourceLabelKey)
}
}
func helperCreateWatcherWithInformer(t *testing.T) (*cachetesting.FakeControllerSource, crdSource) { func helperCreateWatcherWithInformer(t *testing.T) (*cachetesting.FakeControllerSource, crdSource) {
t.Helper() t.Helper()
ctx := t.Context() ctx := t.Context()
@ -679,3 +741,39 @@ func helperCreateWatcherWithInformer(t *testing.T) (*cachetesting.FakeController
return watcher, *cs return watcher, *cs
} }
// generateTestFixtureDNSEndpointsByType generates DNSEndpoint CRDs according to the provided counts per RecordType.
func generateTestFixtureDNSEndpointsByType(namespace string, typeCounts map[string]int) apiv1alpha1.DNSEndpointList {
var result []apiv1alpha1.DNSEndpoint
idx := 0
for rt, count := range typeCounts {
for i := 0; i < count; i++ {
result = append(result, apiv1alpha1.DNSEndpoint{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("dnsendpoint-%s-%d", rt, idx),
Namespace: namespace,
},
Spec: apiv1alpha1.DNSEndpointSpec{
Endpoints: []*endpoint.Endpoint{
{
DNSName: strings.ToLower(fmt.Sprintf("%s-%d.example.com", rt, idx)),
RecordType: rt,
Targets: endpoint.Targets{fmt.Sprintf("192.0.2.%d", idx)},
RecordTTL: 300,
},
},
},
})
idx++
}
}
// Shuffle the result to ensure randomness in the order.
rand.New(rand.NewSource(time.Now().UnixNano()))
rand.Shuffle(len(result), func(i, j int) {
result[i], result[j] = result[j], result[i]
})
return apiv1alpha1.DNSEndpointList{
Items: result,
}
}

View File

@ -81,12 +81,12 @@ func validateEndpoint(t *testing.T, endpoint, expected *endpoint.Endpoint) {
t.Errorf("RecordTTL expected %v, got %v", expected.RecordTTL, endpoint.RecordTTL) t.Errorf("RecordTTL expected %v, got %v", expected.RecordTTL, endpoint.RecordTTL)
} }
// if non-empty record type is expected, check that it matches. // if a non-empty record type is expected, check that it matches.
if endpoint.RecordType != expected.RecordType { if endpoint.RecordType != expected.RecordType {
t.Errorf("RecordType expected %q, got %q", expected.RecordType, endpoint.RecordType) t.Errorf("RecordType expected %q, got %q", expected.RecordType, endpoint.RecordType)
} }
// if non-empty labels are expected, check that they matches. // if non-empty labels are expected, check that they match.
if expected.Labels != nil && !reflect.DeepEqual(endpoint.Labels, expected.Labels) { if expected.Labels != nil && !reflect.DeepEqual(endpoint.Labels, expected.Labels) {
t.Errorf("Labels expected %s, got %s", expected.Labels, endpoint.Labels) t.Errorf("Labels expected %s, got %s", expected.Labels, endpoint.Labels)
} }