fix(provider): aws-sd provider null pointer (#5404)

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>
This commit is contained in:
Ivan Ka 2025-05-29 21:04:18 +01:00 committed by GitHub
parent 060ded1520
commit 95c2c72d22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 585 additions and 318 deletions

View File

@ -90,8 +90,8 @@ func NewLabelsFromStringPlain(labelText string) (Labels, error) {
func NewLabelsFromString(labelText string, aesKey []byte) (Labels, error) {
if len(aesKey) != 0 {
decryptedText, encryptionNonce, err := DecryptText(strings.Trim(labelText, "\""), aesKey)
// in case if we have decryption error, just try process original text
// decryption errors should be ignored here, because we can already have plain-text labels in registry
// in case if we have a decryption error, try process original text
// decryption errors should be ignored here, because we can already have plain-text labels in the registry
if err == nil {
labels, err := NewLabelsFromStringPlain(decryptedText)
if err == nil {

View File

@ -37,6 +37,9 @@ import (
const (
defaultTTL = 300
// https://github.com/aws/aws-sdk-go-v2/blob/cf8509382340d6afdc93612550d56d685181bbb3/service/servicediscovery/api_op_ListServices.go#L42
maxResults = 100
sdNamespaceTypePublic = "public"
sdNamespaceTypePrivate = "private"
@ -117,7 +120,7 @@ func newSdNamespaceFilter(namespaceTypeConfig string) sdtypes.NamespaceFilter {
}
}
// awsTags converts user supplied tags to AWS format
// awsTags converts user-supplied tags to AWS format
func awsTags(tags map[string]string) []sdtypes.Tag {
awsTags := make([]sdtypes.Tag, 0, len(tags))
for k, v := range tags {
@ -155,6 +158,11 @@ func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp
continue
}
if srv.Description == nil {
log.Warnf("Skipping service %q as owner id not configured", *srv.Name)
continue
}
endpoints = append(endpoints, p.instancesToEndpoint(ns, srv, resp.Instances))
}
}
@ -167,6 +175,7 @@ func (p *AWSSDProvider) instancesToEndpoint(ns *sdtypes.NamespaceSummary, srv *s
recordName := *srv.Name + "." + *ns.Name
labels := endpoint.NewLabels()
labels[endpoint.AWSSDDescriptionLabel] = *srv.Description
newEndpoint := &endpoint.Endpoint{
@ -288,7 +297,7 @@ func (p *AWSSDProvider) submitCreates(ctx context.Context, namespaces []*sdtypes
if err != nil {
return err
}
// update local list of services
// update a local list of services
services[*srv.Name] = srv
} else if ch.RecordTTL.IsConfigured() && *srv.DnsConfig.DnsRecords[0].TTL != int64(ch.RecordTTL) {
// update service when TTL differ
@ -360,7 +369,7 @@ func (p *AWSSDProvider) ListNamespaces(ctx context.Context) ([]*sdtypes.Namespac
return namespaces, nil
}
// ListServicesByNamespaceID returns list of services in given namespace.
// ListServicesByNamespaceID returns a list of services in a given namespace.
func (p *AWSSDProvider) ListServicesByNamespaceID(ctx context.Context, namespaceID *string) (map[string]*sdtypes.Service, error) {
services := make([]sdtypes.ServiceSummary, 0)
@ -369,7 +378,7 @@ func (p *AWSSDProvider) ListServicesByNamespaceID(ctx context.Context, namespace
Name: sdtypes.ServiceFilterNameNamespaceId,
Values: []string{*namespaceID},
}},
MaxResults: aws.Int32(100),
MaxResults: aws.Int32(maxResults),
})
for paginator.HasMorePages() {
resp, err := paginator.NextPage(ctx)
@ -412,32 +421,32 @@ func (p *AWSSDProvider) CreateService(ctx context.Context, namespaceID *string,
ttl = int64(ep.RecordTTL)
}
if !p.dryRun {
out, err := p.client.CreateService(ctx, &sd.CreateServiceInput{
Name: srvName,
Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: routingPolicy,
DnsRecords: []sdtypes.DnsRecord{{
Type: srvType,
TTL: aws.Int64(ttl),
}},
},
NamespaceId: namespaceID,
Tags: p.tags,
})
if err != nil {
return nil, err
}
return out.Service, nil
if p.dryRun {
// return a mock service summary in case of a dry run
return &sdtypes.Service{Id: aws.String("dry-run-service"), Name: aws.String("dry-run-service")}, nil
}
// return mock service summary in case of dry run
return &sdtypes.Service{Id: aws.String("dry-run-service"), Name: aws.String("dry-run-service")}, nil
out, err := p.client.CreateService(ctx, &sd.CreateServiceInput{
Name: srvName,
Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: routingPolicy,
DnsRecords: []sdtypes.DnsRecord{{
Type: srvType,
TTL: aws.Int64(ttl),
}},
},
NamespaceId: namespaceID,
Tags: p.tags,
})
if err != nil {
return nil, err
}
return out.Service, nil
}
// UpdateService updates the specified service with information from provided endpoint.
// UpdateService updates the specified service with information from the provided endpoint.
func (p *AWSSDProvider) UpdateService(ctx context.Context, service *sdtypes.Service, ep *endpoint.Endpoint) error {
log.Infof("Updating service \"%s\"", *service.Name)
@ -448,45 +457,52 @@ func (p *AWSSDProvider) UpdateService(ctx context.Context, service *sdtypes.Serv
ttl = int64(ep.RecordTTL)
}
if !p.dryRun {
_, err := p.client.UpdateService(ctx, &sd.UpdateServiceInput{
Id: service.Id,
Service: &sdtypes.ServiceChange{
Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]),
DnsConfig: &sdtypes.DnsConfigChange{
DnsRecords: []sdtypes.DnsRecord{{
Type: srvType,
TTL: aws.Int64(ttl),
}},
},
},
})
if err != nil {
return err
}
if p.dryRun {
return nil
}
return nil
_, err := p.client.UpdateService(ctx, &sd.UpdateServiceInput{
Id: service.Id,
Service: &sdtypes.ServiceChange{
Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]),
DnsConfig: &sdtypes.DnsConfigChange{
DnsRecords: []sdtypes.DnsRecord{{
Type: srvType,
TTL: aws.Int64(ttl),
}},
},
},
})
return err
}
// DeleteService deletes empty Service from AWS API if its owner id match
func (p *AWSSDProvider) DeleteService(ctx context.Context, service *sdtypes.Service) error {
log.Debugf("Check if service \"%s\" owner id match and it can be deleted", *service.Name)
if !p.dryRun && p.cleanEmptyService {
// convert ownerID string to service description format
label := endpoint.NewLabels()
label[endpoint.OwnerLabelKey] = p.ownerID
label[endpoint.AWSSDDescriptionLabel] = label.SerializePlain(false)
if strings.HasPrefix(*service.Description, label[endpoint.AWSSDDescriptionLabel]) {
log.Infof("Deleting service \"%s\"", *service.Name)
_, err := p.client.DeleteService(ctx, &sd.DeleteServiceInput{
Id: aws.String(*service.Id),
})
return err
}
log.Debugf("Skipping service removal %s because owner id does not match, found: \"%s\", required: \"%s\"", *service.Name, *service.Description, label[endpoint.AWSSDDescriptionLabel])
if p.dryRun || !p.cleanEmptyService {
return nil
}
// convert ownerID string to the service description format
label := endpoint.NewLabels()
label[endpoint.OwnerLabelKey] = p.ownerID
label[endpoint.AWSSDDescriptionLabel] = label.SerializePlain(false)
if service.Description == nil {
log.Debugf("Skipping service removal %q because owner id (service.Description) not set, when should be %q", *service.Name, label[endpoint.AWSSDDescriptionLabel])
return nil
}
if strings.HasPrefix(*service.Description, label[endpoint.AWSSDDescriptionLabel]) {
log.Infof("Deleting service \"%s\"", *service.Name)
_, err := p.client.DeleteService(ctx, &sd.DeleteServiceInput{
Id: aws.String(*service.Id),
})
return err
}
log.Debugf("Skipping service removal %q because owner id does not match, found: %q, required: %q", *service.Name, *service.Description, label[endpoint.AWSSDDescriptionLabel])
return nil
}
@ -619,7 +635,7 @@ func (p *AWSSDProvider) routingPolicyFromEndpoint(ep *endpoint.Endpoint) sdtypes
return sdtypes.RoutingPolicyWeighted
}
// determine service type (A, AAAA, CNAME) from given endpoint
// determine the service type (A, AAAA, CNAME) from a given endpoint
func (p *AWSSDProvider) serviceTypeFromEndpoint(ep *endpoint.Endpoint) sdtypes.RecordType {
switch ep.RecordType {
case endpoint.RecordTypeCNAME:

View File

@ -18,16 +18,12 @@ package awssd
import (
"context"
"errors"
"math/rand"
"reflect"
"strconv"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
sdtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -36,231 +32,6 @@ import (
"sigs.k8s.io/external-dns/plan"
)
// Compile time checks for interface conformance
var _ AWSSDClient = &AWSSDClientStub{}
var (
ErrNamespaceNotFound = errors.New("Namespace not found")
)
type AWSSDClientStub struct {
// map[namespace_id]namespace
namespaces map[string]*sdtypes.Namespace
// map[namespace_id] => map[service_id]instance
services map[string]map[string]*sdtypes.Service
// map[service_id] => map[inst_id]instance
instances map[string]map[string]*sdtypes.Instance
// []inst_id
deregistered []string
}
func (s *AWSSDClientStub) CreateService(ctx context.Context, input *sd.CreateServiceInput, optFns ...func(*sd.Options)) (*sd.CreateServiceOutput, error) {
srv := &sdtypes.Service{
Id: aws.String(strconv.Itoa(rand.Intn(10000))),
DnsConfig: input.DnsConfig,
Name: input.Name,
Description: input.Description,
CreateDate: aws.Time(time.Now()),
CreatorRequestId: input.CreatorRequestId,
}
nsServices, ok := s.services[*input.NamespaceId]
if !ok {
nsServices = make(map[string]*sdtypes.Service)
s.services[*input.NamespaceId] = nsServices
}
nsServices[*srv.Id] = srv
return &sd.CreateServiceOutput{
Service: srv,
}, nil
}
func (s *AWSSDClientStub) DeregisterInstance(ctx context.Context, input *sd.DeregisterInstanceInput, optFns ...func(options *sd.Options)) (*sd.DeregisterInstanceOutput, error) {
serviceInstances := s.instances[*input.ServiceId]
delete(serviceInstances, *input.InstanceId)
s.deregistered = append(s.deregistered, *input.InstanceId)
return &sd.DeregisterInstanceOutput{}, nil
}
func (s *AWSSDClientStub) GetService(ctx context.Context, input *sd.GetServiceInput, optFns ...func(options *sd.Options)) (*sd.GetServiceOutput, error) {
for _, entry := range s.services {
srv, ok := entry[*input.Id]
if ok {
return &sd.GetServiceOutput{
Service: srv,
}, nil
}
}
return nil, errors.New("service not found")
}
func (s *AWSSDClientStub) DiscoverInstances(ctx context.Context, input *sd.DiscoverInstancesInput, opts ...func(options *sd.Options)) (*sd.DiscoverInstancesOutput, error) {
instances := make([]sdtypes.HttpInstanceSummary, 0)
var foundNs bool
for _, ns := range s.namespaces {
if *ns.Name == *input.NamespaceName {
foundNs = true
for _, srv := range s.services[*ns.Id] {
if *srv.Name == *input.ServiceName {
for _, inst := range s.instances[*srv.Id] {
instances = append(instances, *instanceToHTTPInstanceSummary(inst))
}
}
}
}
}
if !foundNs {
return nil, ErrNamespaceNotFound
}
return &sd.DiscoverInstancesOutput{
Instances: instances,
}, nil
}
func (s *AWSSDClientStub) ListNamespaces(ctx context.Context, input *sd.ListNamespacesInput, optFns ...func(options *sd.Options)) (*sd.ListNamespacesOutput, error) {
namespaces := make([]sdtypes.NamespaceSummary, 0)
for _, ns := range s.namespaces {
if len(input.Filters) > 0 && input.Filters[0].Name == sdtypes.NamespaceFilterNameType {
if ns.Type != sdtypes.NamespaceType(input.Filters[0].Values[0]) {
// skip namespaces not matching filter
continue
}
}
namespaces = append(namespaces, *namespaceToNamespaceSummary(ns))
}
return &sd.ListNamespacesOutput{
Namespaces: namespaces,
}, nil
}
func (s *AWSSDClientStub) ListServices(ctx context.Context, input *sd.ListServicesInput, optFns ...func(options *sd.Options)) (*sd.ListServicesOutput, error) {
services := make([]sdtypes.ServiceSummary, 0)
// get namespace filter
if len(input.Filters) == 0 || input.Filters[0].Name != sdtypes.ServiceFilterNameNamespaceId {
return nil, errors.New("missing namespace filter")
}
nsID := input.Filters[0].Values[0]
for _, srv := range s.services[nsID] {
services = append(services, *serviceToServiceSummary(srv))
}
return &sd.ListServicesOutput{
Services: services,
}, nil
}
func (s *AWSSDClientStub) RegisterInstance(ctx context.Context, input *sd.RegisterInstanceInput, optFns ...func(options *sd.Options)) (*sd.RegisterInstanceOutput, error) {
srvInstances, ok := s.instances[*input.ServiceId]
if !ok {
srvInstances = make(map[string]*sdtypes.Instance)
s.instances[*input.ServiceId] = srvInstances
}
srvInstances[*input.InstanceId] = &sdtypes.Instance{
Id: input.InstanceId,
Attributes: input.Attributes,
CreatorRequestId: input.CreatorRequestId,
}
return &sd.RegisterInstanceOutput{}, nil
}
func (s *AWSSDClientStub) UpdateService(ctx context.Context, input *sd.UpdateServiceInput, optFns ...func(options *sd.Options)) (*sd.UpdateServiceOutput, error) {
out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id})
if err != nil {
return nil, err
}
origSrv := out.Service
updateSrv := input.Service
origSrv.Description = updateSrv.Description
origSrv.DnsConfig.DnsRecords = updateSrv.DnsConfig.DnsRecords
return &sd.UpdateServiceOutput{}, nil
}
func (s *AWSSDClientStub) DeleteService(ctx context.Context, input *sd.DeleteServiceInput, optFns ...func(options *sd.Options)) (*sd.DeleteServiceOutput, error) {
out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id})
if err != nil {
return nil, err
}
service := out.Service
namespace := s.services[*service.NamespaceId]
delete(namespace, *input.Id)
return &sd.DeleteServiceOutput{}, nil
}
func newTestAWSSDProvider(api AWSSDClient, domainFilter endpoint.DomainFilter, namespaceTypeFilter, ownerID string) *AWSSDProvider {
return &AWSSDProvider{
client: api,
dryRun: false,
namespaceFilter: domainFilter,
namespaceTypeFilter: newSdNamespaceFilter(namespaceTypeFilter),
cleanEmptyService: true,
ownerID: ownerID,
}
}
func instanceToHTTPInstanceSummary(instance *sdtypes.Instance) *sdtypes.HttpInstanceSummary {
if instance == nil {
return nil
}
return &sdtypes.HttpInstanceSummary{
InstanceId: instance.Id,
Attributes: instance.Attributes,
}
}
func namespaceToNamespaceSummary(namespace *sdtypes.Namespace) *sdtypes.NamespaceSummary {
if namespace == nil {
return nil
}
return &sdtypes.NamespaceSummary{
Id: namespace.Id,
Type: namespace.Type,
Name: namespace.Name,
Arn: namespace.Arn,
}
}
func serviceToServiceSummary(service *sdtypes.Service) *sdtypes.ServiceSummary {
if service == nil {
return nil
}
return &sdtypes.ServiceSummary{
Arn: service.Arn,
CreateDate: service.CreateDate,
Description: service.Description,
DnsConfig: service.DnsConfig,
HealthCheckConfig: service.HealthCheckConfig,
HealthCheckCustomConfig: service.HealthCheckCustomConfig,
Id: service.Id,
InstanceCount: service.InstanceCount,
Name: service.Name,
Type: service.Type,
}
}
func TestAWSSDProvider_Records(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
@ -324,6 +95,19 @@ func TestAWSSDProvider_Records(t *testing.T) {
}},
},
},
"aaaa-srv-not-managed-without-owner-id": {
Id: aws.String("aaaa-srv"),
Name: aws.String("service5"),
Description: nil,
DnsConfig: &sdtypes.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeAaaa,
TTL: aws.Int64(100),
}},
},
},
},
}
@ -414,9 +198,10 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
ctx := context.Background()
// apply creates
provider.ApplyChanges(ctx, &plan.Changes{
err := provider.ApplyChanges(ctx, &plan.Changes{
Create: expectedEndpoints,
})
assert.NoError(t, err)
// make sure services were created
assert.Len(t, api.services["private"], 3)
@ -431,9 +216,10 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
ctx = context.Background()
// apply deletes
provider.ApplyChanges(ctx, &plan.Changes{
err = provider.ApplyChanges(ctx, &plan.Changes{
Delete: expectedEndpoints,
})
assert.NoError(t, err)
// make sure all instances are gone
endpoints, _ = provider.Records(ctx)
@ -616,7 +402,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
// A type
provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{
_, err := provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "A-srv",
},
@ -624,6 +410,8 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
RecordTTL: 60,
Targets: endpoint.Targets{"1.2.3.4"},
})
assert.NoError(t, err)
expectedServices["A-srv"] = &sdtypes.Service{
Name: aws.String("A-srv"),
Description: aws.String("A-srv"),
@ -638,7 +426,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
}
// AAAA type
provider.CreateService(context.Background(), aws.String("private"), aws.String("AAAA-srv"), &endpoint.Endpoint{
_, err = provider.CreateService(context.Background(), aws.String("private"), aws.String("AAAA-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "AAAA-srv",
},
@ -646,6 +434,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
RecordTTL: 60,
Targets: endpoint.Targets{"::1234:5678:"},
})
assert.NoError(t, err)
expectedServices["AAAA-srv"] = &sdtypes.Service{
Name: aws.String("AAAA-srv"),
Description: aws.String("AAAA-srv"),
@ -660,7 +449,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
}
// CNAME type
provider.CreateService(context.Background(), aws.String("private"), aws.String("CNAME-srv"), &endpoint.Endpoint{
_, err = provider.CreateService(context.Background(), aws.String("private"), aws.String("CNAME-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "CNAME-srv",
},
@ -668,6 +457,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
RecordTTL: 80,
Targets: endpoint.Targets{"cname.target.com"},
})
assert.NoError(t, err)
expectedServices["CNAME-srv"] = &sdtypes.Service{
Name: aws.String("CNAME-srv"),
Description: aws.String("CNAME-srv"),
@ -682,7 +472,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
}
// ALIAS type
provider.CreateService(context.Background(), aws.String("private"), aws.String("ALIAS-srv"), &endpoint.Endpoint{
_, err = provider.CreateService(context.Background(), aws.String("private"), aws.String("ALIAS-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "ALIAS-srv",
},
@ -690,6 +480,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
RecordTTL: 100,
Targets: endpoint.Targets{"load-balancer.us-east-1.elb.amazonaws.com"},
})
assert.NoError(t, err)
expectedServices["ALIAS-srv"] = &sdtypes.Service{
Name: aws.String("ALIAS-srv"),
Description: aws.String("ALIAS-srv"),
@ -703,21 +494,68 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
NamespaceId: aws.String("private"),
}
validateAWSSDServicesMapsEqual(t, expectedServices, api.services["private"])
testHelperAWSSDServicesMapsEqual(t, expectedServices, api.services["private"])
}
func validateAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sdtypes.Service, services map[string]*sdtypes.Service) {
require.Len(t, services, len(expected))
for _, srv := range services {
validateAWSSDServicesEqual(t, expected[*srv.Name], srv)
func TestAWSSDProvider_CreateServiceDryRun(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: make(map[string]map[string]*sdtypes.Service),
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
provider.dryRun = true
service, err := provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "A-srv",
},
RecordType: endpoint.RecordTypeA,
RecordTTL: 60,
Targets: endpoint.Targets{"1.2.3.4"},
})
assert.NoError(t, err)
assert.NotNil(t, service)
assert.Equal(t, "dry-run-service", *service.Name)
}
func validateAWSSDServicesEqual(t *testing.T, expected *sdtypes.Service, srv *sdtypes.Service) {
assert.Equal(t, *expected.Description, *srv.Description)
assert.Equal(t, *expected.Name, *srv.Name)
assert.True(t, reflect.DeepEqual(*expected.DnsConfig, *srv.DnsConfig))
func TestAWSSDProvider_CreateService_LabelNotSet(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: make(map[string]map[string]*sdtypes.Service),
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-123")
service, err := provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{
Labels: map[string]string{
"wrong-unsupported-label": "A-srv",
},
RecordType: endpoint.RecordTypeA,
RecordTTL: 60,
Targets: endpoint.Targets{"1.2.3.4"},
})
assert.NoError(t, err)
assert.NotNil(t, service)
assert.Empty(t, *service.Description)
}
func TestAWSSDProvider_UpdateService(t *testing.T) {
@ -754,14 +592,63 @@ func TestAWSSDProvider_UpdateService(t *testing.T) {
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
// update service with different TTL
provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{
err := provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeA,
RecordTTL: 100,
})
assert.NoError(t, err)
assert.Len(t, api.services["private"], 1)
assert.Equal(t, int64(100), *api.services["private"]["srv1"].DnsConfig.DnsRecords[0].TTL)
}
func TestAWSSDProvider_UpdateService_DryRun(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
Name: aws.String("service1"),
NamespaceId: aws.String("private"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyMultivalue,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(60),
}},
},
},
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: services,
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
provider.dryRun = true
// update service with different TTL
err := provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeAAAA,
RecordTTL: 100,
})
assert.NoError(t, err)
assert.Len(t, api.services["private"], 1)
// records should not be updated
assert.NotEqual(t, 100, api.services["private"]["srv1"].DnsConfig.DnsRecords[0].TTL)
assert.NotEqual(t, endpoint.RecordTypeAAAA, api.services["private"]["srv1"].DnsConfig.DnsRecords[0].Type)
}
func TestAWSSDProvider_DeleteService(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
@ -791,6 +678,12 @@ func TestAWSSDProvider_DeleteService(t *testing.T) {
Name: aws.String("service3"),
NamespaceId: aws.String("private"),
},
"srv4": {
Id: aws.String("srv4"),
Description: nil,
Name: aws.String("service4"),
NamespaceId: aws.String("private"),
},
},
}
@ -803,24 +696,105 @@ func TestAWSSDProvider_DeleteService(t *testing.T) {
// delete first service
err := provider.DeleteService(context.Background(), services["private"]["srv1"])
require.NoError(t, err)
assert.Len(t, api.services["private"], 2)
assert.NoError(t, err)
assert.Len(t, api.services["private"], 3)
// delete third service
err1 := provider.DeleteService(context.Background(), services["private"]["srv3"])
require.NoError(t, err1)
assert.Len(t, api.services["private"], 1)
err = provider.DeleteService(context.Background(), services["private"]["srv3"])
assert.NoError(t, err)
assert.Len(t, api.services["private"], 2)
expectedServices := map[string]*sdtypes.Service{
// delete service with no description
err = provider.DeleteService(context.Background(), services["private"]["srv4"])
assert.NoError(t, err)
expected := map[string]*sdtypes.Service{
"srv2": {
Id: aws.String("srv2"),
Description: aws.String("heritage=external-dns,external-dns/owner=owner-id"),
Name: aws.String("service2"),
NamespaceId: aws.String("private"),
},
"srv4": {
Id: aws.String("srv4"),
Description: nil,
Name: aws.String("service4"),
NamespaceId: aws.String("private"),
},
}
assert.Equal(t, expectedServices, api.services["private"])
assert.Equal(t, expected, api.services["private"])
}
func TestAWSSDProvider_DeleteServiceEmptyDescription_Logging(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
Description: nil,
Name: aws.String("service1"),
NamespaceId: aws.String("private"),
},
},
}
logs := testutils.LogsUnderTestWithLogLevel(log.DebugLevel, t)
api := &AWSSDClientStub{
namespaces: namespaces,
services: services,
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-id")
// delete service
err := provider.DeleteService(context.Background(), services["private"]["srv1"])
assert.NoError(t, err)
assert.Len(t, api.services["private"], 1)
testutils.TestHelperLogContainsWithLogLevel("Skipping service removal \"service1\" because owner id (service.Description) not set, when should be", log.DebugLevel, logs, t)
}
func TestAWSSDProvider_DeleteServiceDryRun(t *testing.T) {
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
Description: aws.String("heritage=external-dns,external-dns/owner=owner-id"),
Name: aws.String("service1"),
NamespaceId: aws.String("private"),
},
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: services,
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-id")
provider.dryRun = true
// delete first service
err := provider.DeleteService(context.Background(), services["private"]["srv1"])
assert.NoError(t, err)
assert.Len(t, api.services["private"], 1)
}
func TestAWSSDProvider_RegisterInstance(t *testing.T) {
@ -897,12 +871,13 @@ func TestAWSSDProvider_RegisterInstance(t *testing.T) {
expectedInstances := make(map[string]*sdtypes.Instance)
// IPv4-based instance
provider.RegisterInstance(context.Background(), services["private"]["a-srv"], &endpoint.Endpoint{
err := provider.RegisterInstance(context.Background(), services["private"]["a-srv"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeA,
DNSName: "service1.private.com.",
RecordTTL: 300,
Targets: endpoint.Targets{"1.2.3.4", "1.2.3.5"},
})
assert.NoError(t, err)
expectedInstances["1.2.3.4"] = &sdtypes.Instance{
Id: aws.String("1.2.3.4"),
Attributes: map[string]string{
@ -917,12 +892,13 @@ func TestAWSSDProvider_RegisterInstance(t *testing.T) {
}
// AWS ELB instance (ALIAS)
provider.RegisterInstance(context.Background(), services["private"]["alias-srv"], &endpoint.Endpoint{
err = provider.RegisterInstance(context.Background(), services["private"]["alias-srv"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeCNAME,
DNSName: "service1.private.com.",
RecordTTL: 300,
Targets: endpoint.Targets{"load-balancer.us-east-1.elb.amazonaws.com", "load-balancer.us-west-2.elb.amazonaws.com"},
})
assert.NoError(t, err)
expectedInstances["load-balancer.us-east-1.elb.amazonaws.com"] = &sdtypes.Instance{
Id: aws.String("load-balancer.us-east-1.elb.amazonaws.com"),
Attributes: map[string]string{

View File

@ -0,0 +1,275 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package awssd
import (
"context"
"errors"
"math/rand"
"reflect"
"strconv"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/servicediscovery"
"github.com/aws/aws-sdk-go-v2/service/servicediscovery/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sigs.k8s.io/external-dns/endpoint"
sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
sdtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types"
)
var (
// Compile time checks for interface conformance
_ AWSSDClient = &AWSSDClientStub{}
ErrNamespaceNotFound = errors.New("namespace not found")
)
type AWSSDClientStub struct {
// map[namespace_id]namespace
namespaces map[string]*types.Namespace
// map[namespace_id] => map[service_id]instance
services map[string]map[string]*types.Service
// map[service_id] => map[inst_id]instance
instances map[string]map[string]*types.Instance
// []inst_id
deregistered []string
}
func (s *AWSSDClientStub) CreateService(_ context.Context, input *servicediscovery.CreateServiceInput, _ ...func(*servicediscovery.Options)) (*servicediscovery.CreateServiceOutput, error) {
srv := &types.Service{
Id: aws.String(strconv.Itoa(rand.Intn(10000))),
DnsConfig: input.DnsConfig,
Name: input.Name,
Description: input.Description,
CreateDate: aws.Time(time.Now()),
CreatorRequestId: input.CreatorRequestId,
}
nsServices, ok := s.services[*input.NamespaceId]
if !ok {
nsServices = make(map[string]*types.Service)
s.services[*input.NamespaceId] = nsServices
}
nsServices[*srv.Id] = srv
return &servicediscovery.CreateServiceOutput{
Service: srv,
}, nil
}
func (s *AWSSDClientStub) DeregisterInstance(_ context.Context, input *servicediscovery.DeregisterInstanceInput, _ ...func(options *servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) {
serviceInstances := s.instances[*input.ServiceId]
delete(serviceInstances, *input.InstanceId)
s.deregistered = append(s.deregistered, *input.InstanceId)
return &servicediscovery.DeregisterInstanceOutput{}, nil
}
func (s *AWSSDClientStub) GetService(_ context.Context, input *servicediscovery.GetServiceInput, _ ...func(options *servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) {
for _, entry := range s.services {
srv, ok := entry[*input.Id]
if ok {
return &servicediscovery.GetServiceOutput{
Service: srv,
}, nil
}
}
return nil, errors.New("service not found")
}
func (s *AWSSDClientStub) DiscoverInstances(_ context.Context, input *sd.DiscoverInstancesInput, _ ...func(options *sd.Options)) (*sd.DiscoverInstancesOutput, error) {
instances := make([]sdtypes.HttpInstanceSummary, 0)
var foundNs bool
for _, ns := range s.namespaces {
if *ns.Name == *input.NamespaceName {
foundNs = true
for _, srv := range s.services[*ns.Id] {
if *srv.Name == *input.ServiceName {
for _, inst := range s.instances[*srv.Id] {
instances = append(instances, *instanceToHTTPInstanceSummary(inst))
}
}
}
}
}
if !foundNs {
return nil, ErrNamespaceNotFound
}
return &sd.DiscoverInstancesOutput{
Instances: instances,
}, nil
}
func (s *AWSSDClientStub) ListNamespaces(_ context.Context, input *sd.ListNamespacesInput, _ ...func(options *sd.Options)) (*sd.ListNamespacesOutput, error) {
namespaces := make([]sdtypes.NamespaceSummary, 0)
for _, ns := range s.namespaces {
if len(input.Filters) > 0 && input.Filters[0].Name == sdtypes.NamespaceFilterNameType {
if ns.Type != sdtypes.NamespaceType(input.Filters[0].Values[0]) {
// skip namespaces not matching filter
continue
}
}
namespaces = append(namespaces, *namespaceToNamespaceSummary(ns))
}
return &sd.ListNamespacesOutput{
Namespaces: namespaces,
}, nil
}
func (s *AWSSDClientStub) ListServices(_ context.Context, input *sd.ListServicesInput, _ ...func(options *sd.Options)) (*sd.ListServicesOutput, error) {
services := make([]sdtypes.ServiceSummary, 0)
// get namespace filter
if len(input.Filters) == 0 || input.Filters[0].Name != sdtypes.ServiceFilterNameNamespaceId {
return nil, errors.New("missing namespace filter")
}
nsID := input.Filters[0].Values[0]
for _, srv := range s.services[nsID] {
services = append(services, *serviceToServiceSummary(srv))
}
return &sd.ListServicesOutput{
Services: services,
}, nil
}
func (s *AWSSDClientStub) RegisterInstance(ctx context.Context, input *sd.RegisterInstanceInput, _ ...func(options *sd.Options)) (*sd.RegisterInstanceOutput, error) {
srvInstances, ok := s.instances[*input.ServiceId]
if !ok {
srvInstances = make(map[string]*sdtypes.Instance)
s.instances[*input.ServiceId] = srvInstances
}
srvInstances[*input.InstanceId] = &sdtypes.Instance{
Id: input.InstanceId,
Attributes: input.Attributes,
CreatorRequestId: input.CreatorRequestId,
}
return &sd.RegisterInstanceOutput{}, nil
}
func (s *AWSSDClientStub) UpdateService(ctx context.Context, input *sd.UpdateServiceInput, _ ...func(options *sd.Options)) (*sd.UpdateServiceOutput, error) {
out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id})
if err != nil {
return nil, err
}
origSrv := out.Service
updateSrv := input.Service
origSrv.Description = updateSrv.Description
origSrv.DnsConfig.DnsRecords = updateSrv.DnsConfig.DnsRecords
return &sd.UpdateServiceOutput{}, nil
}
func (s *AWSSDClientStub) DeleteService(ctx context.Context, input *sd.DeleteServiceInput, _ ...func(options *sd.Options)) (*sd.DeleteServiceOutput, error) {
out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id})
if err != nil {
return nil, err
}
service := out.Service
namespace := s.services[*service.NamespaceId]
delete(namespace, *input.Id)
return &sd.DeleteServiceOutput{}, nil
}
func newTestAWSSDProvider(api AWSSDClient, domainFilter endpoint.DomainFilter, namespaceTypeFilter, ownerID string) *AWSSDProvider {
return &AWSSDProvider{
client: api,
dryRun: false,
namespaceFilter: domainFilter,
namespaceTypeFilter: newSdNamespaceFilter(namespaceTypeFilter),
cleanEmptyService: true,
ownerID: ownerID,
}
}
func instanceToHTTPInstanceSummary(instance *sdtypes.Instance) *sdtypes.HttpInstanceSummary {
if instance == nil {
return nil
}
return &sdtypes.HttpInstanceSummary{
InstanceId: instance.Id,
Attributes: instance.Attributes,
}
}
func namespaceToNamespaceSummary(namespace *sdtypes.Namespace) *sdtypes.NamespaceSummary {
if namespace == nil {
return nil
}
return &sdtypes.NamespaceSummary{
Id: namespace.Id,
Type: namespace.Type,
Name: namespace.Name,
Arn: namespace.Arn,
}
}
func serviceToServiceSummary(service *sdtypes.Service) *sdtypes.ServiceSummary {
if service == nil {
return nil
}
return &sdtypes.ServiceSummary{
Arn: service.Arn,
CreateDate: service.CreateDate,
Description: service.Description,
DnsConfig: service.DnsConfig,
HealthCheckConfig: service.HealthCheckConfig,
HealthCheckCustomConfig: service.HealthCheckCustomConfig,
Id: service.Id,
InstanceCount: service.InstanceCount,
Name: service.Name,
Type: service.Type,
}
}
func testHelperAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sdtypes.Service, services map[string]*sdtypes.Service) {
require.Len(t, services, len(expected))
for _, srv := range services {
testHelperAWSSDServicesEqual(t, expected[*srv.Name], srv)
}
}
func testHelperAWSSDServicesEqual(t *testing.T, expected *sdtypes.Service, srv *sdtypes.Service) {
assert.Equal(t, *expected.Description, *srv.Description)
assert.Equal(t, *expected.Name, *srv.Name)
assert.True(t, reflect.DeepEqual(*expected.DnsConfig, *srv.DnsConfig))
}

View File

@ -92,17 +92,17 @@ func TestWaitForDynamicCacheSync(t *testing.T) {
}{
{
name: "all caches synced",
syncResults: map[schema.GroupVersionResource]bool{schema.GroupVersionResource{}: true},
syncResults: map[schema.GroupVersionResource]bool{{}: true},
},
{
name: "some caches not synced",
syncResults: map[schema.GroupVersionResource]bool{schema.GroupVersionResource{}: false},
syncResults: map[schema.GroupVersionResource]bool{{}: false},
expectError: true,
errorMsg: "failed to sync string with timeout 1m0s",
},
{
name: "context timeout",
syncResults: map[schema.GroupVersionResource]bool{schema.GroupVersionResource{}: false},
syncResults: map[schema.GroupVersionResource]bool{{}: false},
expectError: true,
errorMsg: "failed to sync string with timeout 1m0s",
},