Refactor AWS Cloud Map provider to aws-sdk-go-v2

Signed-off-by: Michael Shen <mishen@umich.edu>
This commit is contained in:
Michael Shen 2024-07-29 22:19:02 -04:00
parent ce1ab808f2
commit c4a18a9cb6
No known key found for this signature in database
GPG Key ID: 12CC712F576BDFEE
5 changed files with 358 additions and 358 deletions

1
go.mod
View File

@ -22,6 +22,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.27
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.10
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.34.4
github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.31.3
github.com/aws/aws-sdk-go-v2/service/sts v1.30.3
github.com/bodgit/tsig v1.2.2
github.com/cenkalti/backoff/v4 v4.3.0

2
go.sum
View File

@ -145,6 +145,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.16 h1:lhAX
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.16/go.mod h1:AblAlCwvi7Q/SFowvckgN+8M3uFPlopSYeLlbNDArhA=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII=
github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.31.3 h1:EthA93BNgTnk36FoI9DCKtv4S0m63WzdGDYlBp/CvHQ=
github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.31.3/go.mod h1:4xh/h0pevPhBkA4b2iYosZaqrThccxFREQxiGuZpJlc=
github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM=
github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE=

View File

@ -26,8 +26,8 @@ import (
"time"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
"github.com/aws/aws-sdk-go/service/route53"
sd "github.com/aws/aws-sdk-go/service/servicediscovery"
"github.com/go-logr/logr"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
@ -235,7 +235,7 @@ func main() {
log.Infof("Registry \"%s\" cannot be used with AWS Cloud Map. Switching to \"aws-sd\".", cfg.Registry)
cfg.Registry = "aws-sd"
}
p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, sd.New(aws.CreateDefaultSession(cfg)))
p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, sd.NewFromConfig(aws.CreateDefaultV2Config(cfg)))
case "azure-dns", "azure":
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.DryRun)
case "azure-private-dns":

View File

@ -24,9 +24,9 @@ import (
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
sd "github.com/aws/aws-sdk-go/service/servicediscovery"
"github.com/aws/aws-sdk-go-v2/aws"
sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
sdtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types"
log "github.com/sirupsen/logrus"
"sigs.k8s.io/external-dns/endpoint"
@ -54,16 +54,16 @@ var (
)
// AWSSDClient is the subset of the AWS Cloud Map API that we actually use. Add methods as required.
// Signatures must match exactly. Taken from https://github.com/aws/aws-sdk-go/blob/HEAD/service/servicediscovery/api.go
// Signatures must match exactly. Taken from https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/servicediscovery
type AWSSDClient interface {
CreateService(input *sd.CreateServiceInput) (*sd.CreateServiceOutput, error)
DeregisterInstance(input *sd.DeregisterInstanceInput) (*sd.DeregisterInstanceOutput, error)
DiscoverInstancesWithContext(ctx aws.Context, input *sd.DiscoverInstancesInput, opts ...request.Option) (*sd.DiscoverInstancesOutput, error)
ListNamespacesPages(input *sd.ListNamespacesInput, fn func(*sd.ListNamespacesOutput, bool) bool) error
ListServicesPages(input *sd.ListServicesInput, fn func(*sd.ListServicesOutput, bool) bool) error
RegisterInstance(input *sd.RegisterInstanceInput) (*sd.RegisterInstanceOutput, error)
UpdateService(input *sd.UpdateServiceInput) (*sd.UpdateServiceOutput, error)
DeleteService(input *sd.DeleteServiceInput) (*sd.DeleteServiceOutput, error)
CreateService(ctx context.Context, params *sd.CreateServiceInput, optFns ...func(*sd.Options)) (*sd.CreateServiceOutput, error)
DeregisterInstance(ctx context.Context, params *sd.DeregisterInstanceInput, optFns ...func(*sd.Options)) (*sd.DeregisterInstanceOutput, error)
DiscoverInstances(ctx context.Context, params *sd.DiscoverInstancesInput, optFns ...func(*sd.Options)) (*sd.DiscoverInstancesOutput, error)
ListNamespaces(ctx context.Context, params *sd.ListNamespacesInput, optFns ...func(*sd.Options)) (*sd.ListNamespacesOutput, error)
ListServices(ctx context.Context, params *sd.ListServicesInput, optFns ...func(*sd.Options)) (*sd.ListServicesOutput, error)
RegisterInstance(ctx context.Context, params *sd.RegisterInstanceInput, optFns ...func(*sd.Options)) (*sd.RegisterInstanceOutput, error)
UpdateService(ctx context.Context, params *sd.UpdateServiceInput, optFns ...func(*sd.Options)) (*sd.UpdateServiceOutput, error)
DeleteService(ctx context.Context, params *sd.DeleteServiceInput, optFns ...func(*sd.Options)) (*sd.DeleteServiceOutput, error)
}
// AWSSDProvider is an implementation of Provider for AWS Cloud Map.
@ -74,7 +74,7 @@ type AWSSDProvider struct {
// only consider namespaces ending in this suffix
namespaceFilter endpoint.DomainFilter
// filter namespace by type (private or public)
namespaceTypeFilter *sd.NamespaceFilter
namespaceTypeFilter sdtypes.NamespaceFilter
// enables service without instances cleanup
cleanEmptyService bool
// filter services for removal
@ -83,7 +83,7 @@ type AWSSDProvider struct {
// NewAWSSDProvider initializes a new AWS Cloud Map based Provider.
func NewAWSSDProvider(domainFilter endpoint.DomainFilter, namespaceType string, dryRun, cleanEmptyService bool, ownerID string, client AWSSDClient) (*AWSSDProvider, error) {
provider := &AWSSDProvider{
p := &AWSSDProvider{
client: client,
dryRun: dryRun,
namespaceFilter: domainFilter,
@ -92,42 +92,42 @@ func NewAWSSDProvider(domainFilter endpoint.DomainFilter, namespaceType string,
ownerID: ownerID,
}
return provider, nil
return p, nil
}
// newSdNamespaceFilter initialized AWS SD Namespace Filter based on given string config
func newSdNamespaceFilter(namespaceTypeConfig string) *sd.NamespaceFilter {
func newSdNamespaceFilter(namespaceTypeConfig string) sdtypes.NamespaceFilter {
switch namespaceTypeConfig {
case sdNamespaceTypePublic:
return &sd.NamespaceFilter{
Name: aws.String(sd.NamespaceFilterNameType),
Values: []*string{aws.String(sd.NamespaceTypeDnsPublic)},
return sdtypes.NamespaceFilter{
Name: sdtypes.NamespaceFilterNameType,
Values: []string{string(sdtypes.NamespaceTypeDnsPublic)},
}
case sdNamespaceTypePrivate:
return &sd.NamespaceFilter{
Name: aws.String(sd.NamespaceFilterNameType),
Values: []*string{aws.String(sd.NamespaceTypeDnsPrivate)},
return sdtypes.NamespaceFilter{
Name: sdtypes.NamespaceFilterNameType,
Values: []string{string(sdtypes.NamespaceTypeDnsPrivate)},
}
default:
return nil
return sdtypes.NamespaceFilter{}
}
}
// Records returns list of all endpoints.
func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) {
namespaces, err := p.ListNamespaces()
namespaces, err := p.ListNamespaces(ctx)
if err != nil {
return nil, err
}
for _, ns := range namespaces {
services, err := p.ListServicesByNamespaceID(ns.Id)
services, err := p.ListServicesByNamespaceID(ctx, ns.Id)
if err != nil {
return nil, err
}
for _, srv := range services {
resp, err := p.client.DiscoverInstancesWithContext(ctx, &sd.DiscoverInstancesInput{
resp, err := p.client.DiscoverInstances(ctx, &sd.DiscoverInstancesInput{
NamespaceName: ns.Name,
ServiceName: srv.Name,
})
@ -136,8 +136,8 @@ func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp
}
if len(resp.Instances) == 0 {
if err := p.DeleteService(srv); err != nil {
log.Errorf("Failed to delete service %q, error: %s", aws.StringValue(srv.Name), err)
if err := p.DeleteService(ctx, srv); err != nil {
log.Errorf("Failed to delete service %q, error: %s", *srv.Name, err)
}
continue
}
@ -149,35 +149,35 @@ func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp
return endpoints, nil
}
func (p *AWSSDProvider) instancesToEndpoint(ns *sd.NamespaceSummary, srv *sd.Service, instances []*sd.HttpInstanceSummary) *endpoint.Endpoint {
func (p *AWSSDProvider) instancesToEndpoint(ns *sdtypes.NamespaceSummary, srv *sdtypes.Service, instances []sdtypes.HttpInstanceSummary) *endpoint.Endpoint {
// DNS name of the record is a concatenation of service and namespace
recordName := *srv.Name + "." + *ns.Name
labels := endpoint.NewLabels()
labels[endpoint.AWSSDDescriptionLabel] = aws.StringValue(srv.Description)
labels[endpoint.AWSSDDescriptionLabel] = *srv.Description
newEndpoint := &endpoint.Endpoint{
DNSName: recordName,
RecordTTL: endpoint.TTL(aws.Int64Value(srv.DnsConfig.DnsRecords[0].TTL)),
RecordTTL: endpoint.TTL(*srv.DnsConfig.DnsRecords[0].TTL),
Targets: make(endpoint.Targets, 0, len(instances)),
Labels: labels,
}
for _, inst := range instances {
// CNAME
if inst.Attributes[sdInstanceAttrCname] != nil && aws.StringValue(srv.DnsConfig.DnsRecords[0].Type) == sd.RecordTypeCname {
if inst.Attributes[sdInstanceAttrCname] != "" && srv.DnsConfig.DnsRecords[0].Type == sdtypes.RecordTypeCname {
newEndpoint.RecordType = endpoint.RecordTypeCNAME
newEndpoint.Targets = append(newEndpoint.Targets, aws.StringValue(inst.Attributes[sdInstanceAttrCname]))
newEndpoint.Targets = append(newEndpoint.Targets, inst.Attributes[sdInstanceAttrCname])
// ALIAS
} else if inst.Attributes[sdInstanceAttrAlias] != nil {
} else if inst.Attributes[sdInstanceAttrAlias] != "" {
newEndpoint.RecordType = endpoint.RecordTypeCNAME
newEndpoint.Targets = append(newEndpoint.Targets, aws.StringValue(inst.Attributes[sdInstanceAttrAlias]))
newEndpoint.Targets = append(newEndpoint.Targets, inst.Attributes[sdInstanceAttrAlias])
// IP-based target
} else if inst.Attributes[sdInstanceAttrIPV4] != nil {
} else if inst.Attributes[sdInstanceAttrIPV4] != "" {
newEndpoint.RecordType = endpoint.RecordTypeA
newEndpoint.Targets = append(newEndpoint.Targets, aws.StringValue(inst.Attributes[sdInstanceAttrIPV4]))
newEndpoint.Targets = append(newEndpoint.Targets, inst.Attributes[sdInstanceAttrIPV4])
} else {
log.Warnf("Invalid instance \"%v\" found in service \"%v\"", inst, srv.Name)
}
@ -199,7 +199,7 @@ func (p *AWSSDProvider) ApplyChanges(ctx context.Context, changes *plan.Changes)
changes.Delete = append(changes.Delete, deletes...)
changes.Create = append(changes.Create, creates...)
namespaces, err := p.ListNamespaces()
namespaces, err := p.ListNamespaces(ctx)
if err != nil {
return err
}
@ -211,12 +211,12 @@ func (p *AWSSDProvider) ApplyChanges(ctx context.Context, changes *plan.Changes)
// creates = [1.2.3.4, 1.2.3.5]
// ```
// then when deletes are executed after creates it will miss the `1.2.3.4` instance.
err = p.submitDeletes(namespaces, changes.Delete)
err = p.submitDeletes(ctx, namespaces, changes.Delete)
if err != nil {
return err
}
err = p.submitCreates(namespaces, changes.Create)
err = p.submitCreates(ctx, namespaces, changes.Create)
if err != nil {
return err
}
@ -245,11 +245,11 @@ func (p *AWSSDProvider) updatesToCreates(changes *plan.Changes) (creates []*endp
return creates, deletes
}
func (p *AWSSDProvider) submitCreates(namespaces []*sd.NamespaceSummary, changes []*endpoint.Endpoint) error {
func (p *AWSSDProvider) submitCreates(ctx context.Context, namespaces []*sdtypes.NamespaceSummary, changes []*endpoint.Endpoint) error {
changesByNamespaceID := p.changesByNamespaceID(namespaces, changes)
for nsID, changeList := range changesByNamespaceID {
services, err := p.ListServicesByNamespaceID(aws.String(nsID))
services, err := p.ListServicesByNamespaceID(ctx, aws.String(nsID))
if err != nil {
return err
}
@ -260,7 +260,7 @@ func (p *AWSSDProvider) submitCreates(namespaces []*sd.NamespaceSummary, changes
srv := services[srvName]
if srv == nil {
// when service is missing create a new one
srv, err = p.CreateService(&nsID, &srvName, ch)
srv, err = p.CreateService(ctx, &nsID, &srvName, ch)
if err != nil {
return err
}
@ -268,13 +268,13 @@ func (p *AWSSDProvider) submitCreates(namespaces []*sd.NamespaceSummary, changes
services[*srv.Name] = srv
} else if ch.RecordTTL.IsConfigured() && *srv.DnsConfig.DnsRecords[0].TTL != int64(ch.RecordTTL) {
// update service when TTL differ
err = p.UpdateService(srv, ch)
err = p.UpdateService(ctx, srv, ch)
if err != nil {
return err
}
}
err = p.RegisterInstance(srv, ch)
err = p.RegisterInstance(ctx, srv, ch)
if err != nil {
return err
}
@ -284,11 +284,11 @@ func (p *AWSSDProvider) submitCreates(namespaces []*sd.NamespaceSummary, changes
return nil
}
func (p *AWSSDProvider) submitDeletes(namespaces []*sd.NamespaceSummary, changes []*endpoint.Endpoint) error {
func (p *AWSSDProvider) submitDeletes(ctx context.Context, namespaces []*sdtypes.NamespaceSummary, changes []*endpoint.Endpoint) error {
changesByNamespaceID := p.changesByNamespaceID(namespaces, changes)
for nsID, changeList := range changesByNamespaceID {
services, err := p.ListServicesByNamespaceID(aws.String(nsID))
services, err := p.ListServicesByNamespaceID(ctx, aws.String(nsID))
if err != nil {
return err
}
@ -302,7 +302,7 @@ func (p *AWSSDProvider) submitDeletes(namespaces []*sd.NamespaceSummary, changes
return fmt.Errorf("service \"%s\" is missing when trying to delete \"%v\"", srvName, hostname)
}
err := p.DeregisterInstance(srv, ch)
err := p.DeregisterInstance(ctx, srv, ch)
if err != nil {
return err
}
@ -313,53 +313,51 @@ func (p *AWSSDProvider) submitDeletes(namespaces []*sd.NamespaceSummary, changes
}
// ListNamespaces returns all namespaces matching defined namespace filter
func (p *AWSSDProvider) ListNamespaces() ([]*sd.NamespaceSummary, error) {
namespaces := make([]*sd.NamespaceSummary, 0)
func (p *AWSSDProvider) ListNamespaces(ctx context.Context) ([]*sdtypes.NamespaceSummary, error) {
namespaces := make([]*sdtypes.NamespaceSummary, 0)
f := func(resp *sd.ListNamespacesOutput, lastPage bool) bool {
for _, ns := range resp.Namespaces {
if !p.namespaceFilter.Match(aws.StringValue(ns.Name)) {
continue
}
namespaces = append(namespaces, ns)
paginator := sd.NewListNamespacesPaginator(p.client, &sd.ListNamespacesInput{
Filters: []sdtypes.NamespaceFilter{p.namespaceTypeFilter},
})
for paginator.HasMorePages() {
resp, err := paginator.NextPage(ctx)
if err != nil {
return nil, err
}
return true
}
err := p.client.ListNamespacesPages(&sd.ListNamespacesInput{
Filters: []*sd.NamespaceFilter{p.namespaceTypeFilter},
}, f)
if err != nil {
return nil, err
for _, ns := range resp.Namespaces {
if !p.namespaceFilter.Match(*ns.Name) {
continue
}
namespaces = append(namespaces, &ns)
}
}
return namespaces, nil
}
// ListServicesByNamespaceID returns list of services in given namespace. Returns map[srv_name]*sd.Service
func (p *AWSSDProvider) ListServicesByNamespaceID(namespaceID *string) (map[string]*sd.Service, error) {
services := make([]*sd.ServiceSummary, 0)
// ListServicesByNamespaceID returns list of services in given namespace.
func (p *AWSSDProvider) ListServicesByNamespaceID(ctx context.Context, namespaceID *string) (map[string]*sdtypes.Service, error) {
services := make([]sdtypes.ServiceSummary, 0)
f := func(resp *sd.ListServicesOutput, lastPage bool) bool {
services = append(services, resp.Services...)
return true
}
err := p.client.ListServicesPages(&sd.ListServicesInput{
Filters: []*sd.ServiceFilter{{
Name: aws.String(sd.ServiceFilterNameNamespaceId),
Values: []*string{namespaceID},
paginator := sd.NewListServicesPaginator(p.client, &sd.ListServicesInput{
Filters: []sdtypes.ServiceFilter{{
Name: sdtypes.ServiceFilterNameNamespaceId,
Values: []string{*namespaceID},
}},
MaxResults: aws.Int64(100),
}, f)
if err != nil {
return nil, err
MaxResults: aws.Int32(100),
})
for paginator.HasMorePages() {
resp, err := paginator.NextPage(ctx)
if err != nil {
return nil, err
}
services = append(services, resp.Services...)
}
servicesMap := make(map[string]*sd.Service)
servicesMap := make(map[string]*sdtypes.Service)
for _, serviceSummary := range services {
service := &sd.Service{
service := &sdtypes.Service{
Arn: serviceSummary.Arn,
CreateDate: serviceSummary.CreateDate,
Description: serviceSummary.Description,
@ -373,13 +371,13 @@ func (p *AWSSDProvider) ListServicesByNamespaceID(namespaceID *string) (map[stri
Type: serviceSummary.Type,
}
servicesMap[aws.StringValue(service.Name)] = service
servicesMap[*service.Name] = service
}
return servicesMap, nil
}
// CreateService creates a new service in AWS API. Returns the created service.
func (p *AWSSDProvider) CreateService(namespaceID *string, srvName *string, ep *endpoint.Endpoint) (*sd.Service, error) {
func (p *AWSSDProvider) CreateService(ctx context.Context, namespaceID *string, srvName *string, ep *endpoint.Endpoint) (*sdtypes.Service, error) {
log.Infof("Creating a new service \"%s\" in \"%s\" namespace", *srvName, *namespaceID)
srvType := p.serviceTypeFromEndpoint(ep)
@ -391,13 +389,13 @@ func (p *AWSSDProvider) CreateService(namespaceID *string, srvName *string, ep *
}
if !p.dryRun {
out, err := p.client.CreateService(&sd.CreateServiceInput{
out, err := p.client.CreateService(ctx, &sd.CreateServiceInput{
Name: srvName,
Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]),
DnsConfig: &sd.DnsConfig{
RoutingPolicy: aws.String(routingPolicy),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(srvType),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: routingPolicy,
DnsRecords: []sdtypes.DnsRecord{{
Type: srvType,
TTL: aws.Int64(ttl),
}},
},
@ -411,11 +409,11 @@ func (p *AWSSDProvider) CreateService(namespaceID *string, srvName *string, ep *
}
// return mock service summary in case of dry run
return &sd.Service{Id: aws.String("dry-run-service"), Name: aws.String("dry-run-service")}, nil
return &sdtypes.Service{Id: aws.String("dry-run-service"), Name: aws.String("dry-run-service")}, nil
}
// UpdateService updates the specified service with information from provided endpoint.
func (p *AWSSDProvider) UpdateService(service *sd.Service, ep *endpoint.Endpoint) error {
func (p *AWSSDProvider) UpdateService(ctx context.Context, service *sdtypes.Service, ep *endpoint.Endpoint) error {
log.Infof("Updating service \"%s\"", *service.Name)
srvType := p.serviceTypeFromEndpoint(ep)
@ -426,13 +424,13 @@ func (p *AWSSDProvider) UpdateService(service *sd.Service, ep *endpoint.Endpoint
}
if !p.dryRun {
_, err := p.client.UpdateService(&sd.UpdateServiceInput{
_, err := p.client.UpdateService(ctx, &sd.UpdateServiceInput{
Id: service.Id,
Service: &sd.ServiceChange{
Service: &sdtypes.ServiceChange{
Description: aws.String(ep.Labels[endpoint.AWSSDDescriptionLabel]),
DnsConfig: &sd.DnsConfigChange{
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(srvType),
DnsConfig: &sdtypes.DnsConfigChange{
DnsRecords: []sdtypes.DnsRecord{{
Type: srvType,
TTL: aws.Int64(ttl),
}},
},
@ -447,7 +445,7 @@ func (p *AWSSDProvider) UpdateService(service *sd.Service, ep *endpoint.Endpoint
}
// DeleteService deletes empty Service from AWS API if its owner id match
func (p *AWSSDProvider) DeleteService(service *sd.Service) error {
func (p *AWSSDProvider) DeleteService(ctx context.Context, service *sdtypes.Service) error {
log.Debugf("Check if service \"%s\" owner id match and it can be deleted", *service.Name)
if !p.dryRun && p.cleanEmptyService {
// convert ownerID string to service description format
@ -455,39 +453,39 @@ func (p *AWSSDProvider) DeleteService(service *sd.Service) error {
label[endpoint.OwnerLabelKey] = p.ownerID
label[endpoint.AWSSDDescriptionLabel] = label.SerializePlain(false)
if strings.HasPrefix(aws.StringValue(service.Description), label[endpoint.AWSSDDescriptionLabel]) {
if strings.HasPrefix(*service.Description, label[endpoint.AWSSDDescriptionLabel]) {
log.Infof("Deleting service \"%s\"", *service.Name)
_, err := p.client.DeleteService(&sd.DeleteServiceInput{
_, err := p.client.DeleteService(ctx, &sd.DeleteServiceInput{
Id: aws.String(*service.Id),
})
return err
}
log.Debugf("Skipping service removal %s because owner id does not match, found: \"%s\", required: \"%s\"", aws.StringValue(service.Name), aws.StringValue(service.Description), label[endpoint.AWSSDDescriptionLabel])
log.Debugf("Skipping service removal %s because owner id does not match, found: \"%s\", required: \"%s\"", *service.Name, *service.Description, label[endpoint.AWSSDDescriptionLabel])
}
return nil
}
// RegisterInstance creates a new instance in given service.
func (p *AWSSDProvider) RegisterInstance(service *sd.Service, ep *endpoint.Endpoint) error {
func (p *AWSSDProvider) RegisterInstance(ctx context.Context, service *sdtypes.Service, ep *endpoint.Endpoint) error {
for _, target := range ep.Targets {
log.Infof("Registering a new instance \"%s\" for service \"%s\" (%s)", target, *service.Name, *service.Id)
attr := make(map[string]*string)
attr := make(map[string]string)
if ep.RecordType == endpoint.RecordTypeCNAME {
if p.isAWSLoadBalancer(target) {
attr[sdInstanceAttrAlias] = aws.String(target)
attr[sdInstanceAttrAlias] = target
} else {
attr[sdInstanceAttrCname] = aws.String(target)
attr[sdInstanceAttrCname] = target
}
} else if ep.RecordType == endpoint.RecordTypeA {
attr[sdInstanceAttrIPV4] = aws.String(target)
attr[sdInstanceAttrIPV4] = target
} else {
return fmt.Errorf("invalid endpoint type (%v)", ep)
}
if !p.dryRun {
_, err := p.client.RegisterInstance(&sd.RegisterInstanceInput{
_, err := p.client.RegisterInstance(ctx, &sd.RegisterInstanceInput{
ServiceId: service.Id,
Attributes: attr,
InstanceId: aws.String(p.targetToInstanceID(target)),
@ -502,12 +500,12 @@ func (p *AWSSDProvider) RegisterInstance(service *sd.Service, ep *endpoint.Endpo
}
// DeregisterInstance removes an instance from given service.
func (p *AWSSDProvider) DeregisterInstance(service *sd.Service, ep *endpoint.Endpoint) error {
func (p *AWSSDProvider) DeregisterInstance(ctx context.Context, service *sdtypes.Service, ep *endpoint.Endpoint) error {
for _, target := range ep.Targets {
log.Infof("De-registering an instance \"%s\" for service \"%s\" (%s)", target, *service.Name, *service.Id)
if !p.dryRun {
_, err := p.client.DeregisterInstance(&sd.DeregisterInstanceInput{
_, err := p.client.DeregisterInstance(ctx, &sd.DeregisterInstanceInput{
InstanceId: aws.String(p.targetToInstanceID(target)),
ServiceId: service.Id,
})
@ -531,43 +529,7 @@ func (p *AWSSDProvider) targetToInstanceID(target string) string {
return strings.ToLower(target)
}
// nolint: deadcode
// used from unit test
func namespaceToNamespaceSummary(namespace *sd.Namespace) *sd.NamespaceSummary {
if namespace == nil {
return nil
}
return &sd.NamespaceSummary{
Id: namespace.Id,
Type: namespace.Type,
Name: namespace.Name,
Arn: namespace.Arn,
}
}
// nolint: deadcode
// used from unit test
func serviceToServiceSummary(service *sd.Service) *sd.ServiceSummary {
if service == nil {
return nil
}
return &sd.ServiceSummary{
Arn: service.Arn,
CreateDate: service.CreateDate,
Description: service.Description,
DnsConfig: service.DnsConfig,
HealthCheckConfig: service.HealthCheckConfig,
HealthCheckCustomConfig: service.HealthCheckCustomConfig,
Id: service.Id,
InstanceCount: service.InstanceCount,
Name: service.Name,
Type: service.Type,
}
}
func (p *AWSSDProvider) changesByNamespaceID(namespaces []*sd.NamespaceSummary, changes []*endpoint.Endpoint) map[string][]*endpoint.Endpoint {
func (p *AWSSDProvider) changesByNamespaceID(namespaces []*sdtypes.NamespaceSummary, changes []*endpoint.Endpoint) map[string][]*endpoint.Endpoint {
changesByNsID := make(map[string][]*endpoint.Endpoint)
for _, ns := range namespaces {
@ -600,8 +562,8 @@ func (p *AWSSDProvider) changesByNamespaceID(namespaces []*sd.NamespaceSummary,
}
// returns list of all namespaces matching given hostname
func matchingNamespaces(hostname string, namespaces []*sd.NamespaceSummary) []*sd.NamespaceSummary {
matchingNamespaces := make([]*sd.NamespaceSummary, 0)
func matchingNamespaces(hostname string, namespaces []*sdtypes.NamespaceSummary) []*sdtypes.NamespaceSummary {
matchingNamespaces := make([]*sdtypes.NamespaceSummary, 0)
for _, ns := range namespaces {
if *ns.Name == hostname {
@ -621,26 +583,26 @@ func (p *AWSSDProvider) parseHostname(hostname string) (namespace string, servic
}
// determine service routing policy based on endpoint type
func (p *AWSSDProvider) routingPolicyFromEndpoint(ep *endpoint.Endpoint) string {
func (p *AWSSDProvider) routingPolicyFromEndpoint(ep *endpoint.Endpoint) sdtypes.RoutingPolicy {
if ep.RecordType == endpoint.RecordTypeA {
return sd.RoutingPolicyMultivalue
return sdtypes.RoutingPolicyMultivalue
}
return sd.RoutingPolicyWeighted
return sdtypes.RoutingPolicyWeighted
}
// determine service type (A, CNAME) from given endpoint
func (p *AWSSDProvider) serviceTypeFromEndpoint(ep *endpoint.Endpoint) string {
func (p *AWSSDProvider) serviceTypeFromEndpoint(ep *endpoint.Endpoint) sdtypes.RecordType {
if ep.RecordType == endpoint.RecordTypeCNAME {
// FIXME service type is derived from the first target only. Theoretically this may be problem.
// But I don't see a scenario where one endpoint contains targets of different types.
if p.isAWSLoadBalancer(ep.Targets[0]) {
// ALIAS target uses DNS record type of A
return sd.RecordTypeA
// ALIAS target uses DNS record of type A
return sdtypes.RecordTypeA
}
return sd.RecordTypeCname
return sdtypes.RecordTypeCname
}
return sd.RecordTypeA
return sdtypes.RecordTypeA
}
// determine if a given hostname belongs to an AWS load balancer

View File

@ -25,9 +25,9 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
sd "github.com/aws/aws-sdk-go/service/servicediscovery"
"github.com/aws/aws-sdk-go-v2/aws"
sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
sdtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -45,17 +45,17 @@ var (
type AWSSDClientStub struct {
// map[namespace_id]namespace
namespaces map[string]*sd.Namespace
namespaces map[string]*sdtypes.Namespace
// map[namespace_id] => map[service_id]instance
services map[string]map[string]*sd.Service
services map[string]map[string]*sdtypes.Service
// map[service_id] => map[inst_id]instance
instances map[string]map[string]*sd.Instance
instances map[string]map[string]*sdtypes.Instance
}
func (s *AWSSDClientStub) CreateService(input *sd.CreateServiceInput) (*sd.CreateServiceOutput, error) {
srv := &sd.Service{
func (s *AWSSDClientStub) CreateService(ctx context.Context, input *sd.CreateServiceInput, optFns ...func(*sd.Options)) (*sd.CreateServiceOutput, error) {
srv := &sdtypes.Service{
Id: aws.String(strconv.Itoa(rand.Intn(10000))),
DnsConfig: input.DnsConfig,
Name: input.Name,
@ -66,7 +66,7 @@ func (s *AWSSDClientStub) CreateService(input *sd.CreateServiceInput) (*sd.Creat
nsServices, ok := s.services[*input.NamespaceId]
if !ok {
nsServices = make(map[string]*sd.Service)
nsServices = make(map[string]*sdtypes.Service)
s.services[*input.NamespaceId] = nsServices
}
nsServices[*srv.Id] = srv
@ -76,14 +76,14 @@ func (s *AWSSDClientStub) CreateService(input *sd.CreateServiceInput) (*sd.Creat
}, nil
}
func (s *AWSSDClientStub) DeregisterInstance(input *sd.DeregisterInstanceInput) (*sd.DeregisterInstanceOutput, error) {
func (s *AWSSDClientStub) DeregisterInstance(ctx context.Context, input *sd.DeregisterInstanceInput, optFns ...func(options *sd.Options)) (*sd.DeregisterInstanceOutput, error) {
serviceInstances := s.instances[*input.ServiceId]
delete(serviceInstances, *input.InstanceId)
return &sd.DeregisterInstanceOutput{}, nil
}
func (s *AWSSDClientStub) GetService(input *sd.GetServiceInput) (*sd.GetServiceOutput, error) {
func (s *AWSSDClientStub) GetService(ctx context.Context, input *sd.GetServiceInput, optFns ...func(options *sd.Options)) (*sd.GetServiceOutput, error) {
for _, entry := range s.services {
srv, ok := entry[*input.Id]
if ok {
@ -96,18 +96,18 @@ func (s *AWSSDClientStub) GetService(input *sd.GetServiceInput) (*sd.GetServiceO
return nil, errors.New("service not found")
}
func (s *AWSSDClientStub) DiscoverInstancesWithContext(ctx context.Context, input *sd.DiscoverInstancesInput, opts ...request.Option) (*sd.DiscoverInstancesOutput, error) {
instances := make([]*sd.HttpInstanceSummary, 0)
func (s *AWSSDClientStub) DiscoverInstances(ctx context.Context, input *sd.DiscoverInstancesInput, opts ...func(options *sd.Options)) (*sd.DiscoverInstancesOutput, error) {
instances := make([]sdtypes.HttpInstanceSummary, 0)
var foundNs bool
for _, ns := range s.namespaces {
if aws.StringValue(ns.Name) == aws.StringValue(input.NamespaceName) {
if *ns.Name == *input.NamespaceName {
foundNs = true
for _, srv := range s.services[*ns.Id] {
if aws.StringValue(srv.Name) == aws.StringValue(input.ServiceName) {
if *srv.Name == *input.ServiceName {
for _, inst := range s.instances[*srv.Id] {
instances = append(instances, instanceToHTTPInstanceSummary(inst))
instances = append(instances, *instanceToHTTPInstanceSummary(inst))
}
}
}
@ -123,57 +123,50 @@ func (s *AWSSDClientStub) DiscoverInstancesWithContext(ctx context.Context, inpu
}, nil
}
func (s *AWSSDClientStub) ListNamespacesPages(input *sd.ListNamespacesInput, fn func(*sd.ListNamespacesOutput, bool) bool) error {
namespaces := make([]*sd.NamespaceSummary, 0)
filter := input.Filters[0]
func (s *AWSSDClientStub) ListNamespaces(ctx context.Context, input *sd.ListNamespacesInput, optFns ...func(options *sd.Options)) (*sd.ListNamespacesOutput, error) {
namespaces := make([]sdtypes.NamespaceSummary, 0)
for _, ns := range s.namespaces {
if filter != nil && *filter.Name == sd.NamespaceFilterNameType {
if *ns.Type != *filter.Values[0] {
if len(input.Filters) > 0 && input.Filters[0].Name == sdtypes.NamespaceFilterNameType {
if ns.Type != sdtypes.NamespaceType(input.Filters[0].Values[0]) {
// skip namespaces not matching filter
continue
}
}
namespaces = append(namespaces, namespaceToNamespaceSummary(ns))
namespaces = append(namespaces, *namespaceToNamespaceSummary(ns))
}
fn(&sd.ListNamespacesOutput{
return &sd.ListNamespacesOutput{
Namespaces: namespaces,
}, true)
return nil
}, nil
}
func (s *AWSSDClientStub) ListServicesPages(input *sd.ListServicesInput, fn func(*sd.ListServicesOutput, bool) bool) error {
services := make([]*sd.ServiceSummary, 0)
func (s *AWSSDClientStub) ListServices(ctx context.Context, input *sd.ListServicesInput, optFns ...func(options *sd.Options)) (*sd.ListServicesOutput, error) {
services := make([]sdtypes.ServiceSummary, 0)
// get namespace filter
filter := input.Filters[0]
if filter == nil || *filter.Name != sd.ServiceFilterNameNamespaceId {
return errors.New("missing namespace filter")
if len(input.Filters) == 0 || input.Filters[0].Name != sdtypes.ServiceFilterNameNamespaceId {
return nil, errors.New("missing namespace filter")
}
nsID := filter.Values[0]
nsID := input.Filters[0].Values[0]
for _, srv := range s.services[*nsID] {
services = append(services, serviceToServiceSummary(srv))
for _, srv := range s.services[nsID] {
services = append(services, *serviceToServiceSummary(srv))
}
fn(&sd.ListServicesOutput{
return &sd.ListServicesOutput{
Services: services,
}, true)
return nil
}, nil
}
func (s *AWSSDClientStub) RegisterInstance(input *sd.RegisterInstanceInput) (*sd.RegisterInstanceOutput, error) {
func (s *AWSSDClientStub) RegisterInstance(ctx context.Context, input *sd.RegisterInstanceInput, optFns ...func(options *sd.Options)) (*sd.RegisterInstanceOutput, error) {
srvInstances, ok := s.instances[*input.ServiceId]
if !ok {
srvInstances = make(map[string]*sd.Instance)
srvInstances = make(map[string]*sdtypes.Instance)
s.instances[*input.ServiceId] = srvInstances
}
srvInstances[*input.InstanceId] = &sd.Instance{
srvInstances[*input.InstanceId] = &sdtypes.Instance{
Id: input.InstanceId,
Attributes: input.Attributes,
CreatorRequestId: input.CreatorRequestId,
@ -182,8 +175,8 @@ func (s *AWSSDClientStub) RegisterInstance(input *sd.RegisterInstanceInput) (*sd
return &sd.RegisterInstanceOutput{}, nil
}
func (s *AWSSDClientStub) UpdateService(input *sd.UpdateServiceInput) (*sd.UpdateServiceOutput, error) {
out, err := s.GetService(&sd.GetServiceInput{Id: input.Id})
func (s *AWSSDClientStub) UpdateService(ctx context.Context, input *sd.UpdateServiceInput, optFns ...func(options *sd.Options)) (*sd.UpdateServiceOutput, error) {
out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id})
if err != nil {
return nil, err
}
@ -197,8 +190,8 @@ func (s *AWSSDClientStub) UpdateService(input *sd.UpdateServiceInput) (*sd.Updat
return &sd.UpdateServiceOutput{}, nil
}
func (s *AWSSDClientStub) DeleteService(input *sd.DeleteServiceInput) (*sd.DeleteServiceOutput, error) {
out, err := s.GetService(&sd.GetServiceInput{Id: input.Id})
func (s *AWSSDClientStub) DeleteService(ctx context.Context, input *sd.DeleteServiceInput, optFns ...func(options *sd.Options)) (*sd.DeleteServiceOutput, error) {
out, err := s.GetService(ctx, &sd.GetServiceInput{Id: input.Id})
if err != nil {
return nil, err
}
@ -221,39 +214,69 @@ func newTestAWSSDProvider(api AWSSDClient, domainFilter endpoint.DomainFilter, n
}
}
// nolint: deadcode
// used for unit test
func instanceToHTTPInstanceSummary(instance *sd.Instance) *sd.HttpInstanceSummary {
func instanceToHTTPInstanceSummary(instance *sdtypes.Instance) *sdtypes.HttpInstanceSummary {
if instance == nil {
return nil
}
return &sd.HttpInstanceSummary{
return &sdtypes.HttpInstanceSummary{
InstanceId: instance.Id,
Attributes: instance.Attributes,
}
}
func namespaceToNamespaceSummary(namespace *sdtypes.Namespace) *sdtypes.NamespaceSummary {
if namespace == nil {
return nil
}
return &sdtypes.NamespaceSummary{
Id: namespace.Id,
Type: namespace.Type,
Name: namespace.Name,
Arn: namespace.Arn,
}
}
func serviceToServiceSummary(service *sdtypes.Service) *sdtypes.ServiceSummary {
if service == nil {
return nil
}
return &sdtypes.ServiceSummary{
Arn: service.Arn,
CreateDate: service.CreateDate,
Description: service.Description,
DnsConfig: service.DnsConfig,
HealthCheckConfig: service.HealthCheckConfig,
HealthCheckCustomConfig: service.HealthCheckCustomConfig,
Id: service.Id,
InstanceCount: service.InstanceCount,
Name: service.Name,
Type: service.Type,
}
}
func TestAWSSDProvider_Records(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sd.Service{
services := map[string]map[string]*sdtypes.Service{
"private": {
"a-srv": {
Id: aws.String("a-srv"),
Name: aws.String("service1"),
NamespaceId: aws.String("private"),
Description: aws.String("owner-id"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(100),
}},
},
@ -261,12 +284,12 @@ func TestAWSSDProvider_Records(t *testing.T) {
"alias-srv": {
Id: aws.String("alias-srv"),
Name: aws.String("service2"),
NamespaceId: aws.String("private"),
Description: aws.String("owner-id"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(100),
}},
},
@ -274,12 +297,12 @@ func TestAWSSDProvider_Records(t *testing.T) {
"cname-srv": {
Id: aws.String("cname-srv"),
Name: aws.String("service3"),
NamespaceId: aws.String("private"),
Description: aws.String("owner-id"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeCname),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeCname,
TTL: aws.Int64(80),
}},
},
@ -287,34 +310,34 @@ func TestAWSSDProvider_Records(t *testing.T) {
},
}
instances := map[string]map[string]*sd.Instance{
instances := map[string]map[string]*sdtypes.Instance{
"a-srv": {
"1.2.3.4": {
Id: aws.String("1.2.3.4"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.4"),
Attributes: map[string]string{
sdInstanceAttrIPV4: "1.2.3.4",
},
},
"1.2.3.5": {
Id: aws.String("1.2.3.5"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.5"),
Attributes: map[string]string{
sdInstanceAttrIPV4: "1.2.3.5",
},
},
},
"alias-srv": {
"load-balancer.us-east-1.elb.amazonaws.com": {
Id: aws.String("load-balancer.us-east-1.elb.amazonaws.com"),
Attributes: map[string]*string{
sdInstanceAttrAlias: aws.String("load-balancer.us-east-1.elb.amazonaws.com"),
Attributes: map[string]string{
sdInstanceAttrAlias: "load-balancer.us-east-1.elb.amazonaws.com",
},
},
},
"cname-srv": {
"cname.target.com": {
Id: aws.String("cname.target.com"),
Attributes: map[string]*string{
sdInstanceAttrCname: aws.String("cname.target.com"),
Attributes: map[string]string{
sdInstanceAttrCname: "cname.target.com",
},
},
},
@ -340,18 +363,18 @@ func TestAWSSDProvider_Records(t *testing.T) {
}
func TestAWSSDProvider_ApplyChanges(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: make(map[string]map[string]*sd.Service),
instances: make(map[string]map[string]*sd.Instance),
services: make(map[string]map[string]*sdtypes.Service),
instances: make(map[string]map[string]*sdtypes.Instance),
}
expectedEndpoints := []*endpoint.Endpoint{
@ -371,7 +394,7 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
// make sure services were created
assert.Len(t, api.services["private"], 3)
existingServices, _ := provider.ListServicesByNamespaceID(namespaces["private"].Id)
existingServices, _ := provider.ListServicesByNamespaceID(context.Background(), namespaces["private"].Id)
assert.NotNil(t, existingServices["service1"])
assert.NotNil(t, existingServices["service2"])
assert.NotNil(t, existingServices["service3"])
@ -392,16 +415,16 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
}
func TestAWSSDProvider_ListNamespaces(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
"public": {
Id: aws.String("public"),
Name: aws.String("public.com"),
Type: aws.String(sd.NamespaceTypeDnsPublic),
Type: sdtypes.NamespaceTypeDnsPublic,
},
}
@ -413,20 +436,20 @@ func TestAWSSDProvider_ListNamespaces(t *testing.T) {
msg string
domainFilter endpoint.DomainFilter
namespaceTypeFilter string
expectedNamespaces []*sd.NamespaceSummary
expectedNamespaces []*sdtypes.NamespaceSummary
}{
{"public filter", endpoint.NewDomainFilter([]string{}), "public", []*sd.NamespaceSummary{namespaceToNamespaceSummary(namespaces["public"])}},
{"private filter", endpoint.NewDomainFilter([]string{}), "private", []*sd.NamespaceSummary{namespaceToNamespaceSummary(namespaces["private"])}},
{"domain filter", endpoint.NewDomainFilter([]string{"public.com"}), "", []*sd.NamespaceSummary{namespaceToNamespaceSummary(namespaces["public"])}},
{"non-existing domain", endpoint.NewDomainFilter([]string{"xxx.com"}), "", []*sd.NamespaceSummary{}},
{"public filter", endpoint.NewDomainFilter([]string{}), "public", []*sdtypes.NamespaceSummary{namespaceToNamespaceSummary(namespaces["public"])}},
{"private filter", endpoint.NewDomainFilter([]string{}), "private", []*sdtypes.NamespaceSummary{namespaceToNamespaceSummary(namespaces["private"])}},
{"domain filter", endpoint.NewDomainFilter([]string{"public.com"}), "", []*sdtypes.NamespaceSummary{namespaceToNamespaceSummary(namespaces["public"])}},
{"non-existing domain", endpoint.NewDomainFilter([]string{"xxx.com"}), "", []*sdtypes.NamespaceSummary{}},
} {
provider := newTestAWSSDProvider(api, tc.domainFilter, tc.namespaceTypeFilter, "")
result, err := provider.ListNamespaces()
result, err := provider.ListNamespaces(context.Background())
require.NoError(t, err)
expectedMap := make(map[string]*sd.NamespaceSummary)
resultMap := make(map[string]*sd.NamespaceSummary)
expectedMap := make(map[string]*sdtypes.NamespaceSummary)
resultMap := make(map[string]*sdtypes.NamespaceSummary)
for _, ns := range tc.expectedNamespaces {
expectedMap[*ns.Id] = ns
}
@ -441,20 +464,20 @@ func TestAWSSDProvider_ListNamespaces(t *testing.T) {
}
func TestAWSSDProvider_ListServicesByNamespace(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
"public": {
Id: aws.String("public"),
Name: aws.String("public.com"),
Type: aws.String(sd.NamespaceTypeDnsPublic),
Type: sdtypes.NamespaceTypeDnsPublic,
},
}
services := map[string]map[string]*sd.Service{
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
@ -482,48 +505,52 @@ func TestAWSSDProvider_ListServicesByNamespace(t *testing.T) {
}
for _, tc := range []struct {
expectedServices map[string]*sd.Service
expectedServices map[string]*sdtypes.Service
}{
{map[string]*sd.Service{"service1": services["private"]["srv1"], "service2": services["private"]["srv2"]}},
{map[string]*sdtypes.Service{"service1": services["private"]["srv1"], "service2": services["private"]["srv2"]}},
} {
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
result, err := provider.ListServicesByNamespaceID(namespaces["private"].Id)
result, err := provider.ListServicesByNamespaceID(context.Background(), namespaces["private"].Id)
require.NoError(t, err)
assert.Equal(t, tc.expectedServices, result)
}
}
func TestAWSSDProvider_CreateService(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: make(map[string]map[string]*sd.Service),
services: make(map[string]map[string]*sdtypes.Service),
}
expectedServices := make(map[string]*sd.Service)
expectedServices := make(map[string]*sdtypes.Service)
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
// A type
provider.CreateService(aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{
provider.CreateService(context.Background(), aws.String("private"), aws.String("A-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "A-srv",
},
RecordType: endpoint.RecordTypeA,
RecordTTL: 60,
Targets: endpoint.Targets{"1.2.3.4"},
})
expectedServices["A-srv"] = &sd.Service{
Name: aws.String("A-srv"),
DnsConfig: &sd.DnsConfig{
RoutingPolicy: aws.String(sd.RoutingPolicyMultivalue),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
expectedServices["A-srv"] = &sdtypes.Service{
Name: aws.String("A-srv"),
Description: aws.String("A-srv"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyMultivalue,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(60),
}},
},
@ -531,17 +558,21 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
}
// CNAME type
provider.CreateService(aws.String("private"), aws.String("CNAME-srv"), &endpoint.Endpoint{
provider.CreateService(context.Background(), aws.String("private"), aws.String("CNAME-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "CNAME-srv",
},
RecordType: endpoint.RecordTypeCNAME,
RecordTTL: 80,
Targets: endpoint.Targets{"cname.target.com"},
})
expectedServices["CNAME-srv"] = &sd.Service{
Name: aws.String("CNAME-srv"),
DnsConfig: &sd.DnsConfig{
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeCname),
expectedServices["CNAME-srv"] = &sdtypes.Service{
Name: aws.String("CNAME-srv"),
Description: aws.String("CNAME-srv"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeCname,
TTL: aws.Int64(80),
}},
},
@ -549,17 +580,21 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
}
// ALIAS type
provider.CreateService(aws.String("private"), aws.String("ALIAS-srv"), &endpoint.Endpoint{
provider.CreateService(context.Background(), aws.String("private"), aws.String("ALIAS-srv"), &endpoint.Endpoint{
Labels: map[string]string{
endpoint.AWSSDDescriptionLabel: "ALIAS-srv",
},
RecordType: endpoint.RecordTypeCNAME,
RecordTTL: 100,
Targets: endpoint.Targets{"load-balancer.us-east-1.elb.amazonaws.com"},
})
expectedServices["ALIAS-srv"] = &sd.Service{
Name: aws.String("ALIAS-srv"),
DnsConfig: &sd.DnsConfig{
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
expectedServices["ALIAS-srv"] = &sdtypes.Service{
Name: aws.String("ALIAS-srv"),
Description: aws.String("ALIAS-srv"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(100),
}},
},
@ -569,7 +604,7 @@ func TestAWSSDProvider_CreateService(t *testing.T) {
validateAWSSDServicesMapsEqual(t, expectedServices, api.services["private"])
}
func validateAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sd.Service, services map[string]*sd.Service) {
func validateAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sdtypes.Service, services map[string]*sdtypes.Service) {
require.Len(t, services, len(expected))
for _, srv := range services {
@ -577,31 +612,31 @@ func validateAWSSDServicesMapsEqual(t *testing.T, expected map[string]*sd.Servic
}
}
func validateAWSSDServicesEqual(t *testing.T, expected *sd.Service, srv *sd.Service) {
assert.Equal(t, aws.StringValue(expected.Description), aws.StringValue(srv.Description))
assert.Equal(t, aws.StringValue(expected.Name), aws.StringValue(srv.Name))
func validateAWSSDServicesEqual(t *testing.T, expected *sdtypes.Service, srv *sdtypes.Service) {
assert.Equal(t, *expected.Description, *srv.Description)
assert.Equal(t, *expected.Name, *srv.Name)
assert.True(t, reflect.DeepEqual(*expected.DnsConfig, *srv.DnsConfig))
}
func TestAWSSDProvider_UpdateService(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sd.Service{
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
Name: aws.String("service1"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyMultivalue),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
Id: aws.String("srv1"),
Name: aws.String("service1"),
NamespaceId: aws.String("private"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyMultivalue,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(60),
}},
},
@ -617,7 +652,7 @@ func TestAWSSDProvider_UpdateService(t *testing.T) {
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
// update service with different TTL
provider.UpdateService(services["private"]["srv1"], &endpoint.Endpoint{
provider.UpdateService(context.Background(), services["private"]["srv1"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeA,
RecordTTL: 100,
})
@ -626,15 +661,15 @@ func TestAWSSDProvider_UpdateService(t *testing.T) {
}
func TestAWSSDProvider_DeleteService(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sd.Service{
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
@ -665,16 +700,16 @@ func TestAWSSDProvider_DeleteService(t *testing.T) {
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "owner-id")
// delete first service
err := provider.DeleteService(services["private"]["srv1"])
err := provider.DeleteService(context.Background(), services["private"]["srv1"])
assert.NoError(t, err)
assert.Len(t, api.services["private"], 2)
// delete third service
err1 := provider.DeleteService(services["private"]["srv3"])
err1 := provider.DeleteService(context.Background(), services["private"]["srv3"])
assert.NoError(t, err1)
assert.Len(t, api.services["private"], 1)
expectedServices := map[string]*sd.Service{
expectedServices := map[string]*sdtypes.Service{
"srv2": {
Id: aws.String("srv2"),
Description: aws.String("heritage=external-dns,external-dns/owner=owner-id"),
@ -687,48 +722,48 @@ func TestAWSSDProvider_DeleteService(t *testing.T) {
}
func TestAWSSDProvider_RegisterInstance(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sd.Service{
services := map[string]map[string]*sdtypes.Service{
"private": {
"a-srv": {
Id: aws.String("a-srv"),
Name: aws.String("service1"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
Id: aws.String("a-srv"),
Name: aws.String("service1"),
NamespaceId: aws.String("private"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(60),
}},
},
},
"cname-srv": {
Id: aws.String("cname-srv"),
Name: aws.String("service2"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeCname),
Id: aws.String("cname-srv"),
Name: aws.String("service2"),
NamespaceId: aws.String("private"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeCname,
TTL: aws.Int64(60),
}},
},
},
"alias-srv": {
Id: aws.String("alias-srv"),
Name: aws.String("service3"),
DnsConfig: &sd.DnsConfig{
NamespaceId: aws.String("private"),
RoutingPolicy: aws.String(sd.RoutingPolicyWeighted),
DnsRecords: []*sd.DnsRecord{{
Type: aws.String(sd.RecordTypeA),
Id: aws.String("alias-srv"),
Name: aws.String("service3"),
NamespaceId: aws.String("private"),
DnsConfig: &sdtypes.DnsConfig{
RoutingPolicy: sdtypes.RoutingPolicyWeighted,
DnsRecords: []sdtypes.DnsRecord{{
Type: sdtypes.RecordTypeA,
TTL: aws.Int64(60),
}},
},
@ -739,78 +774,78 @@ func TestAWSSDProvider_RegisterInstance(t *testing.T) {
api := &AWSSDClientStub{
namespaces: namespaces,
services: services,
instances: make(map[string]map[string]*sd.Instance),
instances: make(map[string]map[string]*sdtypes.Instance),
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
expectedInstances := make(map[string]*sd.Instance)
expectedInstances := make(map[string]*sdtypes.Instance)
// IP-based instance
provider.RegisterInstance(services["private"]["a-srv"], &endpoint.Endpoint{
provider.RegisterInstance(context.Background(), services["private"]["a-srv"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeA,
DNSName: "service1.private.com.",
RecordTTL: 300,
Targets: endpoint.Targets{"1.2.3.4", "1.2.3.5"},
})
expectedInstances["1.2.3.4"] = &sd.Instance{
expectedInstances["1.2.3.4"] = &sdtypes.Instance{
Id: aws.String("1.2.3.4"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.4"),
Attributes: map[string]string{
sdInstanceAttrIPV4: "1.2.3.4",
},
}
expectedInstances["1.2.3.5"] = &sd.Instance{
expectedInstances["1.2.3.5"] = &sdtypes.Instance{
Id: aws.String("1.2.3.5"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.5"),
Attributes: map[string]string{
sdInstanceAttrIPV4: "1.2.3.5",
},
}
// AWS ELB instance (ALIAS)
provider.RegisterInstance(services["private"]["alias-srv"], &endpoint.Endpoint{
provider.RegisterInstance(context.Background(), services["private"]["alias-srv"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeCNAME,
DNSName: "service1.private.com.",
RecordTTL: 300,
Targets: endpoint.Targets{"load-balancer.us-east-1.elb.amazonaws.com", "load-balancer.us-west-2.elb.amazonaws.com"},
})
expectedInstances["load-balancer.us-east-1.elb.amazonaws.com"] = &sd.Instance{
expectedInstances["load-balancer.us-east-1.elb.amazonaws.com"] = &sdtypes.Instance{
Id: aws.String("load-balancer.us-east-1.elb.amazonaws.com"),
Attributes: map[string]*string{
sdInstanceAttrAlias: aws.String("load-balancer.us-east-1.elb.amazonaws.com"),
Attributes: map[string]string{
sdInstanceAttrAlias: "load-balancer.us-east-1.elb.amazonaws.com",
},
}
expectedInstances["load-balancer.us-west-2.elb.amazonaws.com"] = &sd.Instance{
expectedInstances["load-balancer.us-west-2.elb.amazonaws.com"] = &sdtypes.Instance{
Id: aws.String("load-balancer.us-west-2.elb.amazonaws.com"),
Attributes: map[string]*string{
sdInstanceAttrAlias: aws.String("load-balancer.us-west-2.elb.amazonaws.com"),
Attributes: map[string]string{
sdInstanceAttrAlias: "load-balancer.us-west-2.elb.amazonaws.com",
},
}
// AWS NLB instance (ALIAS)
provider.RegisterInstance(services["private"]["alias-srv"], &endpoint.Endpoint{
provider.RegisterInstance(context.Background(), services["private"]["alias-srv"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeCNAME,
DNSName: "service1.private.com.",
RecordTTL: 300,
Targets: endpoint.Targets{"load-balancer.elb.us-west-2.amazonaws.com"},
})
expectedInstances["load-balancer.elb.us-west-2.amazonaws.com"] = &sd.Instance{
expectedInstances["load-balancer.elb.us-west-2.amazonaws.com"] = &sdtypes.Instance{
Id: aws.String("load-balancer.elb.us-west-2.amazonaws.com"),
Attributes: map[string]*string{
sdInstanceAttrAlias: aws.String("load-balancer.elb.us-west-2.amazonaws.com"),
Attributes: map[string]string{
sdInstanceAttrAlias: "load-balancer.elb.us-west-2.amazonaws.com",
},
}
// CNAME instance
provider.RegisterInstance(services["private"]["cname-srv"], &endpoint.Endpoint{
provider.RegisterInstance(context.Background(), services["private"]["cname-srv"], &endpoint.Endpoint{
RecordType: endpoint.RecordTypeCNAME,
DNSName: "service2.private.com.",
RecordTTL: 300,
Targets: endpoint.Targets{"cname.target.com"},
})
expectedInstances["cname.target.com"] = &sd.Instance{
expectedInstances["cname.target.com"] = &sdtypes.Instance{
Id: aws.String("cname.target.com"),
Attributes: map[string]*string{
sdInstanceAttrCname: aws.String("cname.target.com"),
Attributes: map[string]string{
sdInstanceAttrCname: "cname.target.com",
},
}
@ -825,15 +860,15 @@ func TestAWSSDProvider_RegisterInstance(t *testing.T) {
}
func TestAWSSDProvider_DeregisterInstance(t *testing.T) {
namespaces := map[string]*sd.Namespace{
namespaces := map[string]*sdtypes.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
Type: sdtypes.NamespaceTypeDnsPrivate,
},
}
services := map[string]map[string]*sd.Service{
services := map[string]map[string]*sdtypes.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
@ -842,12 +877,12 @@ func TestAWSSDProvider_DeregisterInstance(t *testing.T) {
},
}
instances := map[string]map[string]*sd.Instance{
instances := map[string]map[string]*sdtypes.Instance{
"srv1": {
"1.2.3.4": {
Id: aws.String("1.2.3.4"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.4"),
Attributes: map[string]string{
sdInstanceAttrIPV4: "1.2.3.4",
},
},
},
@ -861,7 +896,7 @@ func TestAWSSDProvider_DeregisterInstance(t *testing.T) {
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
provider.DeregisterInstance(services["private"]["srv1"], endpoint.NewEndpoint("srv1.private.com.", endpoint.RecordTypeA, "1.2.3.4"))
provider.DeregisterInstance(context.Background(), services["private"]["srv1"], endpoint.NewEndpoint("srv1.private.com.", endpoint.RecordTypeA, "1.2.3.4"))
assert.Len(t, instances["srv1"], 0)
}