From 95c2c72d22f71815e6689d176992b5c1ca6ccaf2 Mon Sep 17 00:00:00 2001 From: Ivan Ka <5395690+ivankatliarchuk@users.noreply.github.com> Date: Thu, 29 May 2025 21:04:18 +0100 Subject: [PATCH] fix(provider): aws-sd provider null pointer (#5404) Signed-off-by: ivan katliarchuk --- endpoint/labels.go | 4 +- provider/awssd/aws_sd.go | 130 ++++---- provider/awssd/aws_sd_test.go | 488 ++++++++++++++--------------- provider/awssd/fixtures_test.go | 275 ++++++++++++++++ source/informers/informers_test.go | 6 +- 5 files changed, 585 insertions(+), 318 deletions(-) create mode 100644 provider/awssd/fixtures_test.go diff --git a/endpoint/labels.go b/endpoint/labels.go index 7ee9cf8d8..f5e9ee33d 100644 --- a/endpoint/labels.go +++ b/endpoint/labels.go @@ -90,8 +90,8 @@ func NewLabelsFromStringPlain(labelText string) (Labels, error) { func NewLabelsFromString(labelText string, aesKey []byte) (Labels, error) { if len(aesKey) != 0 { decryptedText, encryptionNonce, err := DecryptText(strings.Trim(labelText, "\""), aesKey) - // in case if we have decryption error, just try process original text - // decryption errors should be ignored here, because we can already have plain-text labels in registry + // in case if we have a decryption error, try process original text + // decryption errors should be ignored here, because we can already have plain-text labels in the registry if err == nil { labels, err := NewLabelsFromStringPlain(decryptedText) if err == nil { diff --git a/provider/awssd/aws_sd.go b/provider/awssd/aws_sd.go index 44511f0b9..7fb9f7cd8 100644 --- a/provider/awssd/aws_sd.go +++ b/provider/awssd/aws_sd.go @@ -37,6 +37,9 @@ import ( const ( defaultTTL = 300 + // https://github.com/aws/aws-sdk-go-v2/blob/cf8509382340d6afdc93612550d56d685181bbb3/service/servicediscovery/api_op_ListServices.go#L42 + maxResults = 100 + sdNamespaceTypePublic = "public" sdNamespaceTypePrivate = "private" @@ -117,7 +120,7 @@ func newSdNamespaceFilter(namespaceTypeConfig string) sdtypes.NamespaceFilter { } } -// awsTags converts user supplied tags to AWS format +// awsTags converts user-supplied tags to AWS format func awsTags(tags map[string]string) []sdtypes.Tag { awsTags := make([]sdtypes.Tag, 0, len(tags)) for k, v := range tags { @@ -155,6 +158,11 @@ func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp continue } + if srv.Description == nil { + log.Warnf("Skipping service %q as owner id not configured", *srv.Name) + continue + } + endpoints = append(endpoints, p.instancesToEndpoint(ns, srv, resp.Instances)) } } @@ -167,6 +175,7 @@ func (p *AWSSDProvider) instancesToEndpoint(ns *sdtypes.NamespaceSummary, srv *s recordName := *srv.Name + "." + *ns.Name labels := endpoint.NewLabels() + labels[endpoint.AWSSDDescriptionLabel] = *srv.Description newEndpoint := &endpoint.Endpoint{ @@ -288,7 +297,7 @@ func (p *AWSSDProvider) submitCreates(ctx context.Context, namespaces []*sdtypes if err != nil { return err } - // update local list of services + // update a local list of services services[*srv.Name] = srv } else if ch.RecordTTL.IsConfigured() && *srv.DnsConfig.DnsRecords[0].TTL != int64(ch.RecordTTL) { // update service when TTL differ @@ -360,7 +369,7 @@ func (p *AWSSDProvider) ListNamespaces(ctx context.Context) ([]*sdtypes.Namespac return namespaces, nil } -// ListServicesByNamespaceID returns list of services in given namespace. +// ListServicesByNamespaceID returns a list of services in a given namespace. func (p *AWSSDProvider) ListServicesByNamespaceID(ctx context.Context, namespaceID *string) (map[string]*sdtypes.Service, error) { services := make([]sdtypes.ServiceSummary, 0) @@ -369,7 +378,7 @@ func (p *AWSSDProvider) ListServicesByNamespaceID(ctx context.Context, namespace Name: sdtypes.ServiceFilterNameNamespaceId, Values: []string{*namespaceID}, }}, - MaxResults: aws.Int32(100), + MaxResults: aws.Int32(maxResults), }) for paginator.HasMorePages() { resp, err := paginator.NextPage(ctx) @@ -412,32 +421,32 @@ func (p *AWSSDProvider) CreateService(ctx context.Context, namespaceID *string, ttl = int64(ep.RecordTTL) } - if !p.dryRun { - out, err := p.client.CreateService(ctx, &sd.CreateServiceInput{ - Name: srvName, - Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]), - DnsConfig: &sdtypes.DnsConfig{ - RoutingPolicy: routingPolicy, - DnsRecords: []sdtypes.DnsRecord{{ - Type: srvType, - TTL: aws.Int64(ttl), - }}, - }, - NamespaceId: namespaceID, - Tags: p.tags, - }) - if err != nil { - return nil, err - } - - return out.Service, nil + if p.dryRun { + // return a mock service summary in case of a dry run + return &sdtypes.Service{Id: aws.String("dry-run-service"), Name: aws.String("dry-run-service")}, nil } - // return mock service summary in case of dry run - return &sdtypes.Service{Id: aws.String("dry-run-service"), Name: aws.String("dry-run-service")}, nil + out, err := p.client.CreateService(ctx, &sd.CreateServiceInput{ + Name: srvName, + Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]), + DnsConfig: &sdtypes.DnsConfig{ + RoutingPolicy: routingPolicy, + DnsRecords: []sdtypes.DnsRecord{{ + Type: srvType, + TTL: aws.Int64(ttl), + }}, + }, + NamespaceId: namespaceID, + Tags: p.tags, + }) + if err != nil { + return nil, err + } + + return out.Service, nil } -// UpdateService updates the specified service with information from provided endpoint. +// UpdateService updates the specified service with information from the provided endpoint. func (p *AWSSDProvider) UpdateService(ctx context.Context, service *sdtypes.Service, ep *endpoint.Endpoint) error { log.Infof("Updating service \"%s\"", *service.Name) @@ -448,45 +457,52 @@ func (p *AWSSDProvider) UpdateService(ctx context.Context, service *sdtypes.Serv ttl = int64(ep.RecordTTL) } - if !p.dryRun { - _, err := p.client.UpdateService(ctx, &sd.UpdateServiceInput{ - Id: service.Id, - Service: &sdtypes.ServiceChange{ - Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]), - DnsConfig: &sdtypes.DnsConfigChange{ - DnsRecords: []sdtypes.DnsRecord{{ - Type: srvType, - TTL: aws.Int64(ttl), - }}, - }, - }, - }) - if err != nil { - return err - } + if p.dryRun { + return nil } - return nil + _, err := p.client.UpdateService(ctx, &sd.UpdateServiceInput{ + Id: service.Id, + Service: &sdtypes.ServiceChange{ + Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]), + DnsConfig: &sdtypes.DnsConfigChange{ + DnsRecords: []sdtypes.DnsRecord{{ + Type: srvType, + TTL: aws.Int64(ttl), + }}, + }, + }, + }) + return err } // DeleteService deletes empty Service from AWS API if its owner id match func (p *AWSSDProvider) DeleteService(ctx context.Context, service *sdtypes.Service) error { log.Debugf("Check if service \"%s\" owner id match and it can be deleted", *service.Name) - if !p.dryRun && p.cleanEmptyService { - // convert ownerID string to service description format - label := endpoint.NewLabels() - label[endpoint.OwnerLabelKey] = p.ownerID - label[endpoint.AWSSDDescriptionLabel] = label.SerializePlain(false) - if strings.HasPrefix(*service.Description, label[endpoint.AWSSDDescriptionLabel]) { - log.Infof("Deleting service \"%s\"", *service.Name) - _, err := p.client.DeleteService(ctx, &sd.DeleteServiceInput{ - Id: aws.String(*service.Id), - }) - return err - } - log.Debugf("Skipping service removal %s because owner id does not match, found: \"%s\", required: \"%s\"", *service.Name, *service.Description, label[endpoint.AWSSDDescriptionLabel]) + if p.dryRun || !p.cleanEmptyService { + return nil } + + // convert ownerID string to the service description format + label := endpoint.NewLabels() + label[endpoint.OwnerLabelKey] = p.ownerID + label[endpoint.AWSSDDescriptionLabel] = label.SerializePlain(false) + + if service.Description == nil { + log.Debugf("Skipping service removal %q because owner id (service.Description) not set, when should be %q", *service.Name, label[endpoint.AWSSDDescriptionLabel]) + return nil + } + + if strings.HasPrefix(*service.Description, label[endpoint.AWSSDDescriptionLabel]) { + log.Infof("Deleting service \"%s\"", *service.Name) + _, err := p.client.DeleteService(ctx, &sd.DeleteServiceInput{ + Id: aws.String(*service.Id), + }) + return err + } + log.Debugf("Skipping service removal %q because owner id does not match, found: %q, required: %q", *service.Name, *service.Description, label[endpoint.AWSSDDescriptionLabel]) + return nil } @@ -619,7 +635,7 @@ func (p *AWSSDProvider) routingPolicyFromEndpoint(ep *endpoint.Endpoint) sdtypes return sdtypes.RoutingPolicyWeighted } -// determine service type (A, AAAA, CNAME) from given endpoint +// determine the service type (A, AAAA, CNAME) from a given endpoint func (p *AWSSDProvider) serviceTypeFromEndpoint(ep *endpoint.Endpoint) sdtypes.RecordType { switch ep.RecordType { case endpoint.RecordTypeCNAME: diff --git a/provider/awssd/aws_sd_test.go b/provider/awssd/aws_sd_test.go index d49760415..f3cacb8c9 100644 --- a/provider/awssd/aws_sd_test.go +++ b/provider/awssd/aws_sd_test.go @@ -18,16 +18,12 @@ package awssd import ( "context" - "errors" - "math/rand" "reflect" - "strconv" "testing" - "time" "github.com/aws/aws-sdk-go-v2/aws" - sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery" sdtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,231 +32,6 @@ import ( "sigs.k8s.io/external-dns/plan" ) -// Compile time checks for interface conformance -var _ AWSSDClient = &AWSSDClientStub{} - -var ( - ErrNamespaceNotFound = errors.New("Namespace not found") -) - -type AWSSDClientStub struct { - // map[namespace_id]namespace - namespaces map[string]*sdtypes.Namespace - - // map[namespace_id] => map[service_id]instance - services map[string]map[string]*sdtypes.Service - - // map[service_id] => map[inst_id]instance - instances map[string]map[string]*sdtypes.Instance - - // []inst_id - deregistered []string -} - -func (s *AWSSDClientStub) CreateService(ctx context.Context, input *sd.CreateServiceInput, optFns ...func(*sd.Options)) (*sd.CreateServiceOutput, error) { - srv := &sdtypes.Service{ - Id: aws.String(strconv.Itoa(rand.Intn(10000))), - DnsConfig: input.DnsConfig, - Name: input.Name, - Description: input.Description, - CreateDate: aws.Time(time.Now()), - CreatorRequestId: input.CreatorRequestId, - } - - nsServices, ok := s.services[*input.NamespaceId] - if !ok { - nsServices = make(map[string]*sdtypes.Service) - s.services[*input.NamespaceId] = nsServices - } - nsServices[*srv.Id] = srv - - return &sd.CreateServiceOutput{ - Service: srv, - }, nil -} - -func (s *AWSSDClientStub) DeregisterInstance(ctx context.Context, input *sd.DeregisterInstanceInput, optFns ...func(options *sd.Options)) (*sd.DeregisterInstanceOutput, error) { - serviceInstances := s.instances[*input.ServiceId] - delete(serviceInstances, *input.InstanceId) - s.deregistered = append(s.deregistered, *input.InstanceId) - - return &sd.DeregisterInstanceOutput{}, nil -} - -func (s *AWSSDClientStub) GetService(ctx context.Context, input *sd.GetServiceInput, optFns ...func(options *sd.Options)) (*sd.GetServiceOutput, error) { - for _, entry := range s.services { - srv, ok := entry[*input.Id] - if ok { - return &sd.GetServiceOutput{ - Service: srv, - }, nil - } - } - - return nil, errors.New("service not found") -} - -func (s *AWSSDClientStub) DiscoverInstances(ctx context.Context, input *sd.DiscoverInstancesInput, opts ...func(options *sd.Options)) (*sd.DiscoverInstancesOutput, error) { - instances := make([]sdtypes.HttpInstanceSummary, 0) - - var foundNs bool - for _, ns := range s.namespaces { - if *ns.Name == *input.NamespaceName { - foundNs = true - - for _, srv := range s.services[*ns.Id] { - if *srv.Name == *input.ServiceName { - for _, inst := range s.instances[*srv.Id] { - instances = append(instances, *instanceToHTTPInstanceSummary(inst)) - } - } - } - } - } - - if !foundNs { - return nil, ErrNamespaceNotFound - } - - return &sd.DiscoverInstancesOutput{ - Instances: instances, - }, nil -} - -func (s *AWSSDClientStub) ListNamespaces(ctx context.Context, input *sd.ListNamespacesInput, optFns ...func(options *sd.Options)) (*sd.ListNamespacesOutput, error) { - namespaces := make([]sdtypes.NamespaceSummary, 0) - - for _, ns := range s.namespaces { - if len(input.Filters) > 0 && input.Filters[0].Name == sdtypes.NamespaceFilterNameType { - if ns.Type != sdtypes.NamespaceType(input.Filters[0].Values[0]) { - // skip namespaces not matching filter - continue - } - } - namespaces = append(namespaces, *namespaceToNamespaceSummary(ns)) - } - - return &sd.ListNamespacesOutput{ - Namespaces: namespaces, - }, nil -} - -func (s *AWSSDClientStub) ListServices(ctx context.Context, input *sd.ListServicesInput, optFns ...func(options *sd.Options)) (*sd.ListServicesOutput, error) { - services := make([]sdtypes.ServiceSummary, 0) - - // get namespace filter - if len(input.Filters) == 0 || input.Filters[0].Name != sdtypes.ServiceFilterNameNamespaceId { - return nil, errors.New("missing namespace filter") - } - nsID := input.Filters[0].Values[0] - - for _, srv := range s.services[nsID] { - services = append(services, *serviceToServiceSummary(srv)) - } - - return &sd.ListServicesOutput{ - Services: services, - }, nil -} - -func (s *AWSSDClientStub) RegisterInstance(ctx context.Context, input *sd.RegisterInstanceInput, optFns ...func(options *sd.Options)) (*sd.RegisterInstanceOutput, error) { - srvInstances, ok := s.instances[*input.ServiceId] - if !ok { - srvInstances = make(map[string]*sdtypes.Instance) - s.instances[*input.ServiceId] = srvInstances - } - - srvInstances[*input.InstanceId] = &sdtypes.Instance{ - Id: input.InstanceId, - Attributes: input.Attributes, - CreatorRequestId: input.CreatorRequestId, - } - - return &sd.RegisterInstanceOutput{}, nil -} - -func (s *AWSSDClientStub) UpdateService(ctx context.Context, input *sd.UpdateServiceInput, optFns ...func(options *sd.Options)) (*sd.UpdateServiceOutput, error) { - out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id}) - if err != nil { - return nil, err - } - - origSrv := out.Service - updateSrv := input.Service - - origSrv.Description = updateSrv.Description - origSrv.DnsConfig.DnsRecords = updateSrv.DnsConfig.DnsRecords - - return &sd.UpdateServiceOutput{}, nil -} - -func (s *AWSSDClientStub) DeleteService(ctx context.Context, input *sd.DeleteServiceInput, optFns ...func(options *sd.Options)) (*sd.DeleteServiceOutput, error) { - out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id}) - if err != nil { - return nil, err - } - - service := out.Service - namespace := s.services[*service.NamespaceId] - delete(namespace, *input.Id) - - return &sd.DeleteServiceOutput{}, nil -} - -func newTestAWSSDProvider(api AWSSDClient, domainFilter endpoint.DomainFilter, namespaceTypeFilter, ownerID string) *AWSSDProvider { - return &AWSSDProvider{ - client: api, - dryRun: false, - namespaceFilter: domainFilter, - namespaceTypeFilter: newSdNamespaceFilter(namespaceTypeFilter), - cleanEmptyService: true, - ownerID: ownerID, - } -} - -func instanceToHTTPInstanceSummary(instance *sdtypes.Instance) *sdtypes.HttpInstanceSummary { - if instance == nil { - return nil - } - - return &sdtypes.HttpInstanceSummary{ - InstanceId: instance.Id, - Attributes: instance.Attributes, - } -} - -func namespaceToNamespaceSummary(namespace *sdtypes.Namespace) *sdtypes.NamespaceSummary { - if namespace == nil { - return nil - } - - return &sdtypes.NamespaceSummary{ - Id: namespace.Id, - Type: namespace.Type, - Name: namespace.Name, - Arn: namespace.Arn, - } -} - -func serviceToServiceSummary(service *sdtypes.Service) *sdtypes.ServiceSummary { - if service == nil { - return nil - } - - return &sdtypes.ServiceSummary{ - Arn: service.Arn, - CreateDate: service.CreateDate, - Description: service.Description, - DnsConfig: service.DnsConfig, - HealthCheckConfig: service.HealthCheckConfig, - HealthCheckCustomConfig: service.HealthCheckCustomConfig, - Id: service.Id, - InstanceCount: service.InstanceCount, - Name: service.Name, - Type: service.Type, - } -} - func TestAWSSDProvider_Records(t *testing.T) { namespaces := map[string]*sdtypes.Namespace{ "private": { @@ -324,6 +95,19 @@ func TestAWSSDProvider_Records(t *testing.T) { }}, }, }, + "aaaa-srv-not-managed-without-owner-id": { + Id: aws.String("aaaa-srv"), + Name: aws.String("service5"), + Description: nil, + DnsConfig: &sdtypes.DnsConfig{ + NamespaceId: aws.String("private"), + RoutingPolicy: sdtypes.RoutingPolicyWeighted, + DnsRecords: []sdtypes.DnsRecord{{ + Type: sdtypes.RecordTypeAaaa, + TTL: aws.Int64(100), + }}, + }, + }, }, } @@ -414,9 +198,10 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) { ctx := context.Background() // apply creates - provider.ApplyChanges(ctx, &plan.Changes{ + err := provider.ApplyChanges(ctx, &plan.Changes{ Create: expectedEndpoints, }) + assert.NoError(t, err) // make sure services were created assert.Len(t, api.services["private"], 3) @@ -431,9 +216,10 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) { ctx = context.Background() // apply deletes - provider.ApplyChanges(ctx, &plan.Changes{ + err = provider.ApplyChanges(ctx, &plan.Changes{ Delete: expectedEndpoints, }) + assert.NoError(t, err) // make sure all instances are gone endpoints, _ = provider.Records(ctx) @@ -616,7 +402,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "") // A type - provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{ + _, err := provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{ Labels: map[string]string{ endpoint.AWSSDDescriptionLabel: "A-srv", }, @@ -624,6 +410,8 @@ func TestAWSSDProvider_CreateService(t *testing.T) { RecordTTL: 60, Targets: endpoint.Targets{"1.2.3.4"}, }) + assert.NoError(t, err) + expectedServices["A-srv"] = &sdtypes.Service{ Name: aws.String("A-srv"), Description: aws.String("A-srv"), @@ -638,7 +426,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { } // AAAA type - provider.CreateService(context.Background(), aws.String("private"), aws.String("AAAA-srv"), &endpoint.Endpoint{ + _, err = provider.CreateService(context.Background(), aws.String("private"), aws.String("AAAA-srv"), &endpoint.Endpoint{ Labels: map[string]string{ endpoint.AWSSDDescriptionLabel: "AAAA-srv", }, @@ -646,6 +434,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { RecordTTL: 60, Targets: endpoint.Targets{"::1234:5678:"}, }) + assert.NoError(t, err) expectedServices["AAAA-srv"] = &sdtypes.Service{ Name: aws.String("AAAA-srv"), Description: aws.String("AAAA-srv"), @@ -660,7 +449,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { } // CNAME type - provider.CreateService(context.Background(), aws.String("private"), aws.String("CNAME-srv"), &endpoint.Endpoint{ + _, err = provider.CreateService(context.Background(), aws.String("private"), aws.String("CNAME-srv"), &endpoint.Endpoint{ Labels: map[string]string{ endpoint.AWSSDDescriptionLabel: "CNAME-srv", }, @@ -668,6 +457,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { RecordTTL: 80, Targets: endpoint.Targets{"cname.target.com"}, }) + assert.NoError(t, err) expectedServices["CNAME-srv"] = &sdtypes.Service{ Name: aws.String("CNAME-srv"), Description: aws.String("CNAME-srv"), @@ -682,7 +472,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { } // ALIAS type - provider.CreateService(context.Background(), aws.String("private"), aws.String("ALIAS-srv"), &endpoint.Endpoint{ + _, err = provider.CreateService(context.Background(), aws.String("private"), aws.String("ALIAS-srv"), &endpoint.Endpoint{ Labels: map[string]string{ endpoint.AWSSDDescriptionLabel: "ALIAS-srv", }, @@ -690,6 +480,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) { RecordTTL: 100, Targets: endpoint.Targets{"load-balancer.us-east-1.elb.amazonaws.com"}, }) + assert.NoError(t, err) expectedServices["ALIAS-srv"] = &sdtypes.Service{ Name: aws.String("ALIAS-srv"), Description: aws.String("ALIAS-srv"), @@ -703,21 +494,68 @@ func TestAWSSDProvider_CreateService(t *testing.T) { NamespaceId: aws.String("private"), } - validateAWSSDServicesMapsEqual(t, expectedServices, api.services["private"]) + testHelperAWSSDServicesMapsEqual(t, expectedServices, api.services["private"]) } -func validateAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sdtypes.Service, services map[string]*sdtypes.Service) { - require.Len(t, services, len(expected)) - - for _, srv := range services { - validateAWSSDServicesEqual(t, expected[*srv.Name], srv) +func TestAWSSDProvider_CreateServiceDryRun(t *testing.T) { + namespaces := map[string]*sdtypes.Namespace{ + "private": { + Id: aws.String("private"), + Name: aws.String("private.com"), + Type: sdtypes.NamespaceTypeDnsPrivate, + }, } + + api := &AWSSDClientStub{ + namespaces: namespaces, + services: make(map[string]map[string]*sdtypes.Service), + } + + provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "") + provider.dryRun = true + + service, err := provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{ + Labels: map[string]string{ + endpoint.AWSSDDescriptionLabel: "A-srv", + }, + RecordType: endpoint.RecordTypeA, + RecordTTL: 60, + Targets: endpoint.Targets{"1.2.3.4"}, + }) + assert.NoError(t, err) + + assert.NotNil(t, service) + assert.Equal(t, "dry-run-service", *service.Name) } -func validateAWSSDServicesEqual(t *testing.T, expected *sdtypes.Service, srv *sdtypes.Service) { - assert.Equal(t, *expected.Description, *srv.Description) - assert.Equal(t, *expected.Name, *srv.Name) - assert.True(t, reflect.DeepEqual(*expected.DnsConfig, *srv.DnsConfig)) +func TestAWSSDProvider_CreateService_LabelNotSet(t *testing.T) { + namespaces := map[string]*sdtypes.Namespace{ + "private": { + Id: aws.String("private"), + Name: aws.String("private.com"), + Type: sdtypes.NamespaceTypeDnsPrivate, + }, + } + + api := &AWSSDClientStub{ + namespaces: namespaces, + services: make(map[string]map[string]*sdtypes.Service), + } + + provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-123") + + service, err := provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{ + Labels: map[string]string{ + "wrong-unsupported-label": "A-srv", + }, + RecordType: endpoint.RecordTypeA, + RecordTTL: 60, + Targets: endpoint.Targets{"1.2.3.4"}, + }) + + assert.NoError(t, err) + assert.NotNil(t, service) + assert.Empty(t, *service.Description) } func TestAWSSDProvider_UpdateService(t *testing.T) { @@ -754,14 +592,63 @@ func TestAWSSDProvider_UpdateService(t *testing.T) { provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "") // update service with different TTL - provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{ + err := provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{ RecordType: endpoint.RecordTypeA, RecordTTL: 100, }) + assert.NoError(t, err) + assert.Len(t, api.services["private"], 1) assert.Equal(t, int64(100), *api.services["private"]["srv1"].DnsConfig.DnsRecords[0].TTL) } +func TestAWSSDProvider_UpdateService_DryRun(t *testing.T) { + namespaces := map[string]*sdtypes.Namespace{ + "private": { + Id: aws.String("private"), + Name: aws.String("private.com"), + Type: sdtypes.NamespaceTypeDnsPrivate, + }, + } + + services := map[string]map[string]*sdtypes.Service{ + "private": { + "srv1": { + Id: aws.String("srv1"), + Name: aws.String("service1"), + NamespaceId: aws.String("private"), + DnsConfig: &sdtypes.DnsConfig{ + RoutingPolicy: sdtypes.RoutingPolicyMultivalue, + DnsRecords: []sdtypes.DnsRecord{{ + Type: sdtypes.RecordTypeA, + TTL: aws.Int64(60), + }}, + }, + }, + }, + } + + api := &AWSSDClientStub{ + namespaces: namespaces, + services: services, + } + + provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "") + provider.dryRun = true + + // update service with different TTL + err := provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{ + RecordType: endpoint.RecordTypeAAAA, + RecordTTL: 100, + }) + + assert.NoError(t, err) + assert.Len(t, api.services["private"], 1) + // records should not be updated + assert.NotEqual(t, 100, api.services["private"]["srv1"].DnsConfig.DnsRecords[0].TTL) + assert.NotEqual(t, endpoint.RecordTypeAAAA, api.services["private"]["srv1"].DnsConfig.DnsRecords[0].Type) +} + func TestAWSSDProvider_DeleteService(t *testing.T) { namespaces := map[string]*sdtypes.Namespace{ "private": { @@ -791,6 +678,12 @@ func TestAWSSDProvider_DeleteService(t *testing.T) { Name: aws.String("service3"), NamespaceId: aws.String("private"), }, + "srv4": { + Id: aws.String("srv4"), + Description: nil, + Name: aws.String("service4"), + NamespaceId: aws.String("private"), + }, }, } @@ -803,24 +696,105 @@ func TestAWSSDProvider_DeleteService(t *testing.T) { // delete first service err := provider.DeleteService(context.Background(), services["private"]["srv1"]) - require.NoError(t, err) - assert.Len(t, api.services["private"], 2) + assert.NoError(t, err) + assert.Len(t, api.services["private"], 3) // delete third service - err1 := provider.DeleteService(context.Background(), services["private"]["srv3"]) - require.NoError(t, err1) - assert.Len(t, api.services["private"], 1) + err = provider.DeleteService(context.Background(), services["private"]["srv3"]) + assert.NoError(t, err) + assert.Len(t, api.services["private"], 2) - expectedServices := map[string]*sdtypes.Service{ + // delete service with no description + err = provider.DeleteService(context.Background(), services["private"]["srv4"]) + assert.NoError(t, err) + + expected := map[string]*sdtypes.Service{ "srv2": { Id: aws.String("srv2"), Description: aws.String("heritage=external-dns,external-dns/owner=owner-id"), Name: aws.String("service2"), NamespaceId: aws.String("private"), }, + "srv4": { + Id: aws.String("srv4"), + Description: nil, + Name: aws.String("service4"), + NamespaceId: aws.String("private"), + }, } - assert.Equal(t, expectedServices, api.services["private"]) + assert.Equal(t, expected, api.services["private"]) +} + +func TestAWSSDProvider_DeleteServiceEmptyDescription_Logging(t *testing.T) { + namespaces := map[string]*sdtypes.Namespace{ + "private": { + Id: aws.String("private"), + Name: aws.String("private.com"), + Type: sdtypes.NamespaceTypeDnsPrivate, + }, + } + + services := map[string]map[string]*sdtypes.Service{ + "private": { + "srv1": { + Id: aws.String("srv1"), + Description: nil, + Name: aws.String("service1"), + NamespaceId: aws.String("private"), + }, + }, + } + + logs := testutils.LogsUnderTestWithLogLevel(log.DebugLevel, t) + + api := &AWSSDClientStub{ + namespaces: namespaces, + services: services, + } + + provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-id") + + // delete service + err := provider.DeleteService(context.Background(), services["private"]["srv1"]) + assert.NoError(t, err) + assert.Len(t, api.services["private"], 1) + + testutils.TestHelperLogContainsWithLogLevel("Skipping service removal \"service1\" because owner id (service.Description) not set, when should be", log.DebugLevel, logs, t) +} + +func TestAWSSDProvider_DeleteServiceDryRun(t *testing.T) { + namespaces := map[string]*sdtypes.Namespace{ + "private": { + Id: aws.String("private"), + Name: aws.String("private.com"), + Type: sdtypes.NamespaceTypeDnsPrivate, + }, + } + + services := map[string]map[string]*sdtypes.Service{ + "private": { + "srv1": { + Id: aws.String("srv1"), + Description: aws.String("heritage=external-dns,external-dns/owner=owner-id"), + Name: aws.String("service1"), + NamespaceId: aws.String("private"), + }, + }, + } + + api := &AWSSDClientStub{ + namespaces: namespaces, + services: services, + } + + provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-id") + provider.dryRun = true + + // delete first service + err := provider.DeleteService(context.Background(), services["private"]["srv1"]) + assert.NoError(t, err) + assert.Len(t, api.services["private"], 1) } func TestAWSSDProvider_RegisterInstance(t *testing.T) { @@ -897,12 +871,13 @@ func TestAWSSDProvider_RegisterInstance(t *testing.T) { expectedInstances := make(map[string]*sdtypes.Instance) // IPv4-based instance - provider.RegisterInstance(context.Background(), services["private"]["a-srv"], &endpoint.Endpoint{ + err := provider.RegisterInstance(context.Background(), services["private"]["a-srv"], &endpoint.Endpoint{ RecordType: endpoint.RecordTypeA, DNSName: "service1.private.com.", RecordTTL: 300, Targets: endpoint.Targets{"1.2.3.4", "1.2.3.5"}, }) + assert.NoError(t, err) expectedInstances["1.2.3.4"] = &sdtypes.Instance{ Id: aws.String("1.2.3.4"), Attributes: map[string]string{ @@ -917,12 +892,13 @@ func TestAWSSDProvider_RegisterInstance(t *testing.T) { } // AWS ELB instance (ALIAS) - provider.RegisterInstance(context.Background(), services["private"]["alias-srv"], &endpoint.Endpoint{ + err = provider.RegisterInstance(context.Background(), services["private"]["alias-srv"], &endpoint.Endpoint{ RecordType: endpoint.RecordTypeCNAME, DNSName: "service1.private.com.", RecordTTL: 300, Targets: endpoint.Targets{"load-balancer.us-east-1.elb.amazonaws.com", "load-balancer.us-west-2.elb.amazonaws.com"}, }) + assert.NoError(t, err) expectedInstances["load-balancer.us-east-1.elb.amazonaws.com"] = &sdtypes.Instance{ Id: aws.String("load-balancer.us-east-1.elb.amazonaws.com"), Attributes: map[string]string{ diff --git a/provider/awssd/fixtures_test.go b/provider/awssd/fixtures_test.go new file mode 100644 index 000000000..8d949632f --- /dev/null +++ b/provider/awssd/fixtures_test.go @@ -0,0 +1,275 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package awssd + +import ( + "context" + "errors" + "math/rand" + "reflect" + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "sigs.k8s.io/external-dns/endpoint" + + sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + sdtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" +) + +var ( + // Compile time checks for interface conformance + _ AWSSDClient = &AWSSDClientStub{} + ErrNamespaceNotFound = errors.New("namespace not found") +) + +type AWSSDClientStub struct { + // map[namespace_id]namespace + namespaces map[string]*types.Namespace + + // map[namespace_id] => map[service_id]instance + services map[string]map[string]*types.Service + + // map[service_id] => map[inst_id]instance + instances map[string]map[string]*types.Instance + + // []inst_id + deregistered []string +} + +func (s *AWSSDClientStub) CreateService(_ context.Context, input *servicediscovery.CreateServiceInput, _ ...func(*servicediscovery.Options)) (*servicediscovery.CreateServiceOutput, error) { + srv := &types.Service{ + Id: aws.String(strconv.Itoa(rand.Intn(10000))), + DnsConfig: input.DnsConfig, + Name: input.Name, + Description: input.Description, + CreateDate: aws.Time(time.Now()), + CreatorRequestId: input.CreatorRequestId, + } + + nsServices, ok := s.services[*input.NamespaceId] + if !ok { + nsServices = make(map[string]*types.Service) + s.services[*input.NamespaceId] = nsServices + } + nsServices[*srv.Id] = srv + + return &servicediscovery.CreateServiceOutput{ + Service: srv, + }, nil +} + +func (s *AWSSDClientStub) DeregisterInstance(_ context.Context, input *servicediscovery.DeregisterInstanceInput, _ ...func(options *servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) { + serviceInstances := s.instances[*input.ServiceId] + delete(serviceInstances, *input.InstanceId) + s.deregistered = append(s.deregistered, *input.InstanceId) + + return &servicediscovery.DeregisterInstanceOutput{}, nil +} + +func (s *AWSSDClientStub) GetService(_ context.Context, input *servicediscovery.GetServiceInput, _ ...func(options *servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) { + for _, entry := range s.services { + srv, ok := entry[*input.Id] + if ok { + return &servicediscovery.GetServiceOutput{ + Service: srv, + }, nil + } + } + + return nil, errors.New("service not found") +} + +func (s *AWSSDClientStub) DiscoverInstances(_ context.Context, input *sd.DiscoverInstancesInput, _ ...func(options *sd.Options)) (*sd.DiscoverInstancesOutput, error) { + instances := make([]sdtypes.HttpInstanceSummary, 0) + + var foundNs bool + for _, ns := range s.namespaces { + if *ns.Name == *input.NamespaceName { + foundNs = true + + for _, srv := range s.services[*ns.Id] { + if *srv.Name == *input.ServiceName { + for _, inst := range s.instances[*srv.Id] { + instances = append(instances, *instanceToHTTPInstanceSummary(inst)) + } + } + } + } + } + + if !foundNs { + return nil, ErrNamespaceNotFound + } + + return &sd.DiscoverInstancesOutput{ + Instances: instances, + }, nil +} + +func (s *AWSSDClientStub) ListNamespaces(_ context.Context, input *sd.ListNamespacesInput, _ ...func(options *sd.Options)) (*sd.ListNamespacesOutput, error) { + namespaces := make([]sdtypes.NamespaceSummary, 0) + + for _, ns := range s.namespaces { + if len(input.Filters) > 0 && input.Filters[0].Name == sdtypes.NamespaceFilterNameType { + if ns.Type != sdtypes.NamespaceType(input.Filters[0].Values[0]) { + // skip namespaces not matching filter + continue + } + } + namespaces = append(namespaces, *namespaceToNamespaceSummary(ns)) + } + + return &sd.ListNamespacesOutput{ + Namespaces: namespaces, + }, nil +} + +func (s *AWSSDClientStub) ListServices(_ context.Context, input *sd.ListServicesInput, _ ...func(options *sd.Options)) (*sd.ListServicesOutput, error) { + services := make([]sdtypes.ServiceSummary, 0) + + // get namespace filter + if len(input.Filters) == 0 || input.Filters[0].Name != sdtypes.ServiceFilterNameNamespaceId { + return nil, errors.New("missing namespace filter") + } + nsID := input.Filters[0].Values[0] + + for _, srv := range s.services[nsID] { + services = append(services, *serviceToServiceSummary(srv)) + } + + return &sd.ListServicesOutput{ + Services: services, + }, nil +} + +func (s *AWSSDClientStub) RegisterInstance(ctx context.Context, input *sd.RegisterInstanceInput, _ ...func(options *sd.Options)) (*sd.RegisterInstanceOutput, error) { + srvInstances, ok := s.instances[*input.ServiceId] + if !ok { + srvInstances = make(map[string]*sdtypes.Instance) + s.instances[*input.ServiceId] = srvInstances + } + + srvInstances[*input.InstanceId] = &sdtypes.Instance{ + Id: input.InstanceId, + Attributes: input.Attributes, + CreatorRequestId: input.CreatorRequestId, + } + + return &sd.RegisterInstanceOutput{}, nil +} + +func (s *AWSSDClientStub) UpdateService(ctx context.Context, input *sd.UpdateServiceInput, _ ...func(options *sd.Options)) (*sd.UpdateServiceOutput, error) { + out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id}) + if err != nil { + return nil, err + } + + origSrv := out.Service + updateSrv := input.Service + + origSrv.Description = updateSrv.Description + origSrv.DnsConfig.DnsRecords = updateSrv.DnsConfig.DnsRecords + + return &sd.UpdateServiceOutput{}, nil +} + +func (s *AWSSDClientStub) DeleteService(ctx context.Context, input *sd.DeleteServiceInput, _ ...func(options *sd.Options)) (*sd.DeleteServiceOutput, error) { + out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id}) + if err != nil { + return nil, err + } + + service := out.Service + namespace := s.services[*service.NamespaceId] + delete(namespace, *input.Id) + + return &sd.DeleteServiceOutput{}, nil +} + +func newTestAWSSDProvider(api AWSSDClient, domainFilter endpoint.DomainFilter, namespaceTypeFilter, ownerID string) *AWSSDProvider { + return &AWSSDProvider{ + client: api, + dryRun: false, + namespaceFilter: domainFilter, + namespaceTypeFilter: newSdNamespaceFilter(namespaceTypeFilter), + cleanEmptyService: true, + ownerID: ownerID, + } +} + +func instanceToHTTPInstanceSummary(instance *sdtypes.Instance) *sdtypes.HttpInstanceSummary { + if instance == nil { + return nil + } + + return &sdtypes.HttpInstanceSummary{ + InstanceId: instance.Id, + Attributes: instance.Attributes, + } +} + +func namespaceToNamespaceSummary(namespace *sdtypes.Namespace) *sdtypes.NamespaceSummary { + if namespace == nil { + return nil + } + + return &sdtypes.NamespaceSummary{ + Id: namespace.Id, + Type: namespace.Type, + Name: namespace.Name, + Arn: namespace.Arn, + } +} + +func serviceToServiceSummary(service *sdtypes.Service) *sdtypes.ServiceSummary { + if service == nil { + return nil + } + + return &sdtypes.ServiceSummary{ + Arn: service.Arn, + CreateDate: service.CreateDate, + Description: service.Description, + DnsConfig: service.DnsConfig, + HealthCheckConfig: service.HealthCheckConfig, + HealthCheckCustomConfig: service.HealthCheckCustomConfig, + Id: service.Id, + InstanceCount: service.InstanceCount, + Name: service.Name, + Type: service.Type, + } +} + +func testHelperAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sdtypes.Service, services map[string]*sdtypes.Service) { + require.Len(t, services, len(expected)) + + for _, srv := range services { + testHelperAWSSDServicesEqual(t, expected[*srv.Name], srv) + } +} + +func testHelperAWSSDServicesEqual(t *testing.T, expected *sdtypes.Service, srv *sdtypes.Service) { + assert.Equal(t, *expected.Description, *srv.Description) + assert.Equal(t, *expected.Name, *srv.Name) + assert.True(t, reflect.DeepEqual(*expected.DnsConfig, *srv.DnsConfig)) +} diff --git a/source/informers/informers_test.go b/source/informers/informers_test.go index 2268efb38..8d5fd05b0 100644 --- a/source/informers/informers_test.go +++ b/source/informers/informers_test.go @@ -92,17 +92,17 @@ func TestWaitForDynamicCacheSync(t *testing.T) { }{ { name: "all caches synced", - syncResults: map[schema.GroupVersionResource]bool{schema.GroupVersionResource{}: true}, + syncResults: map[schema.GroupVersionResource]bool{{}: true}, }, { name: "some caches not synced", - syncResults: map[schema.GroupVersionResource]bool{schema.GroupVersionResource{}: false}, + syncResults: map[schema.GroupVersionResource]bool{{}: false}, expectError: true, errorMsg: "failed to sync string with timeout 1m0s", }, { name: "context timeout", - syncResults: map[schema.GroupVersionResource]bool{schema.GroupVersionResource{}: false}, + syncResults: map[schema.GroupVersionResource]bool{{}: false}, expectError: true, errorMsg: "failed to sync string with timeout 1m0s", },