AWSSD: Utilize DiscoverInstances instead of ListInstances (#2506)

* AWSSD: Utilize DiscoverInstances instead of ListInstances

* Fixed stylecheck

Renamed instanceToHttpInstanceSummary to instanceToHTTPInstanceSummary

* awssd use DiscoverInstancesWithContext from client directly

* updated awssd tests

fix DiscoverInstancesWithContext to implement AWSSDClient interface
drop old test, no need to cover direct calls to aws clent methods
moved instanceToHTTPInstanceSummary to _test file

* awssd log error on failed DeleteService

* updated awssd tests

* fix missing import

* awssd tests handle not found NS with DiscoverInstancesWithContext

* Update provider/awssd/aws_sd_test.go

Co-authored-by: John Gardiner Myers <jgmyers@proofpoint.com>

* Update provider/awssd/aws_sd_test.go

Co-authored-by: John Gardiner Myers <jgmyers@proofpoint.com>

---------

Co-authored-by: John Gardiner Myers <jgmyers@proofpoint.com>
This commit is contained in:
Artem Voronin 2023-10-03 17:33:43 -07:00 committed by GitHub
parent 17e9637f11
commit 4eb7e7513a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 118 deletions

View File

@ -25,6 +25,7 @@ import (
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
sd "github.com/aws/aws-sdk-go/service/servicediscovery"
log "github.com/sirupsen/logrus"
@ -57,7 +58,7 @@ var (
type AWSSDClient interface {
CreateService(input *sd.CreateServiceInput) (*sd.CreateServiceOutput, error)
DeregisterInstance(input *sd.DeregisterInstanceInput) (*sd.DeregisterInstanceOutput, error)
ListInstancesPages(input *sd.ListInstancesInput, fn func(*sd.ListInstancesOutput, bool) bool) error
DiscoverInstancesWithContext(ctx aws.Context, input *sd.DiscoverInstancesInput, opts ...request.Option) (*sd.DiscoverInstancesOutput, error)
ListNamespacesPages(input *sd.ListNamespacesInput, fn func(*sd.ListNamespacesOutput, bool) bool) error
ListServicesPages(input *sd.ListServicesInput, fn func(*sd.ListServicesOutput, bool) bool) error
RegisterInstance(input *sd.RegisterInstanceInput) (*sd.RegisterInstanceOutput, error)
@ -126,28 +127,29 @@ func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp
}
for _, srv := range services {
instances, err := p.ListInstancesByServiceID(srv.Id)
resp, err := p.client.DiscoverInstancesWithContext(ctx, &sd.DiscoverInstancesInput{
NamespaceName: ns.Name,
ServiceName: srv.Name,
})
if err != nil {
return nil, err
}
if len(instances) > 0 {
ep := p.instancesToEndpoint(ns, srv, instances)
endpoints = append(endpoints, ep)
}
if len(instances) == 0 {
err = p.DeleteService(srv)
if err != nil {
log.Warnf("Failed to delete service \"%s\", error: %s", aws.StringValue(srv.Name), err)
if len(resp.Instances) == 0 {
if err := p.DeleteService(srv); err != nil {
log.Errorf("Failed to delete service %q, error: %s", aws.StringValue(srv.Name), err)
}
continue
}
endpoints = append(endpoints, p.instancesToEndpoint(ns, srv, resp.Instances))
}
}
return endpoints, nil
}
func (p *AWSSDProvider) instancesToEndpoint(ns *sd.NamespaceSummary, srv *sd.Service, instances []*sd.InstanceSummary) *endpoint.Endpoint {
func (p *AWSSDProvider) instancesToEndpoint(ns *sd.NamespaceSummary, srv *sd.Service, instances []*sd.HttpInstanceSummary) *endpoint.Endpoint {
// DNS name of the record is a concatenation of service and namespace
recordName := *srv.Name + "." + *ns.Name
@ -376,26 +378,6 @@ func (p *AWSSDProvider) ListServicesByNamespaceID(namespaceID *string) (map[stri
return servicesMap, nil
}
// ListInstancesByServiceID returns list of instances registered in given service.
func (p *AWSSDProvider) ListInstancesByServiceID(serviceID *string) ([]*sd.InstanceSummary, error) {
instances := make([]*sd.InstanceSummary, 0)
f := func(resp *sd.ListInstancesOutput, lastPage bool) bool {
instances = append(instances, resp.Instances...)
return true
}
err := p.client.ListInstancesPages(&sd.ListInstancesInput{
ServiceId: serviceID,
}, f)
if err != nil {
return nil, err
}
return instances, nil
}
// CreateService creates a new service in AWS API. Returns the created service.
func (p *AWSSDProvider) CreateService(namespaceID *string, srvName *string, ep *endpoint.Endpoint) (*sd.Service, error) {
log.Infof("Creating a new service \"%s\" in \"%s\" namespace", *srvName, *namespaceID)
@ -585,19 +567,6 @@ func serviceToServiceSummary(service *sd.Service) *sd.ServiceSummary {
}
}
// nolint: deadcode
// used from unit test
func instanceToInstanceSummary(instance *sd.Instance) *sd.InstanceSummary {
if instance == nil {
return nil
}
return &sd.InstanceSummary{
Id: instance.Id,
Attributes: instance.Attributes,
}
}
func (p *AWSSDProvider) changesByNamespaceID(namespaces []*sd.NamespaceSummary, changes []*endpoint.Endpoint) map[string][]*endpoint.Endpoint {
changesByNsID := make(map[string][]*endpoint.Endpoint)

View File

@ -26,6 +26,7 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
sd "github.com/aws/aws-sdk-go/service/servicediscovery"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -38,6 +39,10 @@ import (
// Compile time check for interface conformance
var _ AWSSDClient = &AWSSDClientStub{}
var (
ErrNamespaceNotFound = errors.New("Namespace not found")
)
type AWSSDClientStub struct {
// map[namespace_id]namespace
namespaces map[string]*sd.Namespace
@ -91,18 +96,31 @@ func (s *AWSSDClientStub) GetService(input *sd.GetServiceInput) (*sd.GetServiceO
return nil, errors.New("service not found")
}
func (s *AWSSDClientStub) ListInstancesPages(input *sd.ListInstancesInput, fn func(*sd.ListInstancesOutput, bool) bool) error {
instances := make([]*sd.InstanceSummary, 0)
func (s *AWSSDClientStub) DiscoverInstancesWithContext(ctx context.Context, input *sd.DiscoverInstancesInput, opts ...request.Option) (*sd.DiscoverInstancesOutput, error) {
instances := make([]*sd.HttpInstanceSummary, 0)
for _, inst := range s.instances[*input.ServiceId] {
instances = append(instances, instanceToInstanceSummary(inst))
var foundNs bool
for _, ns := range s.namespaces {
if aws.StringValue(ns.Name) == aws.StringValue(input.NamespaceName) {
foundNs = true
for _, srv := range s.services[*ns.Id] {
if aws.StringValue(srv.Name) == aws.StringValue(input.ServiceName) {
for _, inst := range s.instances[*srv.Id] {
instances = append(instances, instanceToHTTPInstanceSummary(inst))
}
}
}
}
}
fn(&sd.ListInstancesOutput{
Instances: instances,
}, true)
if !foundNs {
return nil, ErrNamespaceNotFound
}
return nil
return &sd.DiscoverInstancesOutput{
Instances: instances,
}, nil
}
func (s *AWSSDClientStub) ListNamespacesPages(input *sd.ListNamespacesInput, fn func(*sd.ListNamespacesOutput, bool) bool) error {
@ -203,6 +221,19 @@ func newTestAWSSDProvider(api AWSSDClient, domainFilter endpoint.DomainFilter, n
}
}
// nolint: deadcode
// used for unit test
func instanceToHTTPInstanceSummary(instance *sd.Instance) *sd.HttpInstanceSummary {
if instance == nil {
return nil
}
return &sd.HttpInstanceSummary{
InstanceId: instance.Id,
Attributes: instance.Attributes,
}
}
func TestAWSSDProvider_Records(t *testing.T) {
namespaces := map[string]*sd.Namespace{
"private": {
@ -463,72 +494,6 @@ func TestAWSSDProvider_ListServicesByNamespace(t *testing.T) {
}
}
func TestAWSSDProvider_ListInstancesByService(t *testing.T) {
namespaces := map[string]*sd.Namespace{
"private": {
Id: aws.String("private"),
Name: aws.String("private.com"),
Type: aws.String(sd.NamespaceTypeDnsPrivate),
},
}
services := map[string]map[string]*sd.Service{
"private": {
"srv1": {
Id: aws.String("srv1"),
Name: aws.String("service1"),
},
"srv2": {
Id: aws.String("srv2"),
Name: aws.String("service2"),
},
},
}
instances := map[string]map[string]*sd.Instance{
"srv1": {
"inst1": {
Id: aws.String("inst1"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.4"),
},
},
"inst2": {
Id: aws.String("inst2"),
Attributes: map[string]*string{
sdInstanceAttrIPV4: aws.String("1.2.3.5"),
},
},
},
}
api := &AWSSDClientStub{
namespaces: namespaces,
services: services,
instances: instances,
}
provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{}), "", "")
result, err := provider.ListInstancesByServiceID(services["private"]["srv1"].Id)
require.NoError(t, err)
expectedInstances := []*sd.InstanceSummary{instanceToInstanceSummary(instances["srv1"]["inst1"]), instanceToInstanceSummary(instances["srv1"]["inst2"])}
expectedMap := make(map[string]*sd.InstanceSummary)
resultMap := make(map[string]*sd.InstanceSummary)
for _, inst := range expectedInstances {
expectedMap[*inst.Id] = inst
}
for _, inst := range result {
resultMap[*inst.Id] = inst
}
if !reflect.DeepEqual(resultMap, expectedMap) {
t.Errorf("AWSSDProvider.ListInstancesByServiceID() error = %v, wantErr %v", result, expectedInstances)
}
}
func TestAWSSDProvider_CreateService(t *testing.T) {
namespaces := map[string]*sd.Namespace{
"private": {