Allow AdjustEndpoints to return error

This commit is contained in:
John Gardiner Myers 2023-09-03 11:23:03 -07:00
parent b2de466aa5
commit c596611f9e
17 changed files with 95 additions and 77 deletions

View File

@ -18,6 +18,7 @@ package controller
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"time" "time"
@ -214,7 +215,10 @@ func (c *Controller) RunOnce(ctx context.Context) error {
vARecords, vAAAARecords := countMatchingAddressRecords(endpoints, records) vARecords, vAAAARecords := countMatchingAddressRecords(endpoints, records)
verifiedARecords.Set(float64(vARecords)) verifiedARecords.Set(float64(vARecords))
verifiedAAAARecords.Set(float64(vAAAARecords)) verifiedAAAARecords.Set(float64(vAAAARecords))
endpoints = c.Registry.AdjustEndpoints(endpoints) endpoints, err = c.Registry.AdjustEndpoints(endpoints)
if err != nil {
return fmt.Errorf("adjusting endpoints: %w", err)
}
registryFilter := c.Registry.GetDomainFilter() registryFilter := c.Registry.GetDomainFilter()
plan := &plan.Plan{ plan := &plan.Plan{

View File

@ -660,7 +660,7 @@ func (p *AWSProvider) newChanges(action string, endpoints []*endpoint.Endpoint)
// unneeded (potentially failing) changes. // unneeded (potentially failing) changes.
// Example: CNAME endpoints pointing to ELBs will have a `alias` provider-specific property // Example: CNAME endpoints pointing to ELBs will have a `alias` provider-specific property
// added to match the endpoints generated from existing alias records in Route53. // added to match the endpoints generated from existing alias records in Route53.
func (p *AWSProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (p *AWSProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
for _, ep := range endpoints { for _, ep := range endpoints {
alias := false alias := false
if ep.RecordType != endpoint.RecordTypeCNAME { if ep.RecordType != endpoint.RecordTypeCNAME {
@ -692,7 +692,7 @@ func (p *AWSProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoin
ep.DeleteProviderSpecificProperty(providerSpecificEvaluateTargetHealth) ep.DeleteProviderSpecificProperty(providerSpecificEvaluateTargetHealth)
} }
} }
return endpoints return endpoints, nil
} }
// newChange returns a route53 Change and a boolean indicating if there should also be a change to a AAAA record // newChange returns a route53 Change and a boolean indicating if there should also be a change to a AAAA record

View File

@ -529,7 +529,8 @@ func TestAWSAdjustEndpoints(t *testing.T) {
endpoint.NewEndpoint("cname-test-elb-no-eth.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "foo.eu-central-1.elb.amazonaws.com").WithProviderSpecific(providerSpecificEvaluateTargetHealth, "false"), // eth = evaluate target health endpoint.NewEndpoint("cname-test-elb-no-eth.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "foo.eu-central-1.elb.amazonaws.com").WithProviderSpecific(providerSpecificEvaluateTargetHealth, "false"), // eth = evaluate target health
} }
records = provider.AdjustEndpoints(records) records, err := provider.AdjustEndpoints(records)
assert.NoError(t, err)
validateEndpoints(t, provider, records, []*endpoint.Endpoint{ validateEndpoints(t, provider, records, []*endpoint.Endpoint{
endpoint.NewEndpoint("a-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"), endpoint.NewEndpoint("a-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"),
@ -1610,7 +1611,8 @@ func TestAWSBatchChangeSetExceedingNameChange(t *testing.T) {
func validateEndpoints(t *testing.T, provider *AWSProvider, endpoints []*endpoint.Endpoint, expected []*endpoint.Endpoint) { func validateEndpoints(t *testing.T, provider *AWSProvider, endpoints []*endpoint.Endpoint, expected []*endpoint.Endpoint) {
assert.True(t, testutils.SameEndpoints(endpoints, expected), "actual and expected endpoints don't match. %+v:%+v", endpoints, expected) assert.True(t, testutils.SameEndpoints(endpoints, expected), "actual and expected endpoints don't match. %+v:%+v", endpoints, expected)
normalized := provider.AdjustEndpoints(endpoints) normalized, err := provider.AdjustEndpoints(endpoints)
assert.NoError(t, err)
assert.True(t, testutils.SameEndpoints(normalized, expected), "actual and normalized endpoints don't match. %+v:%+v", endpoints, normalized) assert.True(t, testutils.SameEndpoints(normalized, expected), "actual and normalized endpoints don't match. %+v:%+v", endpoints, normalized)
} }

View File

@ -368,7 +368,7 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud
} }
// AdjustEndpoints modifies the endpoints as needed by the specific provider // AdjustEndpoints modifies the endpoints as needed by the specific provider
func (p *CloudFlareProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (p *CloudFlareProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
adjustedEndpoints := []*endpoint.Endpoint{} adjustedEndpoints := []*endpoint.Endpoint{}
for _, e := range endpoints { for _, e := range endpoints {
proxied := shouldBeProxied(e, p.proxiedByDefault) proxied := shouldBeProxied(e, p.proxiedByDefault)
@ -379,7 +379,7 @@ func (p *CloudFlareProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*
adjustedEndpoints = append(adjustedEndpoints, e) adjustedEndpoints = append(adjustedEndpoints, e)
} }
return adjustedEndpoints return adjustedEndpoints, nil
} }
// changesByZone separates a multi-zone change into a single change per zone. // changesByZone separates a multi-zone change into a single change per zone.

View File

@ -302,7 +302,8 @@ func AssertActions(t *testing.T, provider *CloudFlareProvider, endpoints []*endp
t.Fatalf("cannot fetch records, %s", err) t.Fatalf("cannot fetch records, %s", err)
} }
endpoints = provider.AdjustEndpoints(endpoints) endpoints, err = provider.AdjustEndpoints(endpoints)
assert.NoError(t, err)
domainFilter := endpoint.NewDomainFilter([]string{"bar.com"}) domainFilter := endpoint.NewDomainFilter([]string{"bar.com"})
plan := &plan.Plan{ plan := &plan.Plan{
Current: records, Current: records,
@ -1147,7 +1148,8 @@ func TestProviderPropertiesIdempotency(t *testing.T) {
}) })
} }
desired = provider.AdjustEndpoints(desired) desired, err = provider.AdjustEndpoints(desired)
assert.NoError(t, err)
plan := plan.Plan{ plan := plan.Plan{
Current: current, Current: current,
@ -1190,9 +1192,7 @@ func TestCloudflareComplexUpdate(t *testing.T) {
} }
domainFilter := endpoint.NewDomainFilter([]string{"bar.com"}) domainFilter := endpoint.NewDomainFilter([]string{"bar.com"})
plan := &plan.Plan{ endpoints, err := provider.AdjustEndpoints([]*endpoint.Endpoint{
Current: records,
Desired: provider.AdjustEndpoints([]*endpoint.Endpoint{
{ {
DNSName: "foobar.bar.com", DNSName: "foobar.bar.com",
Targets: endpoint.Targets{"1.2.3.4", "2.3.4.5"}, Targets: endpoint.Targets{"1.2.3.4", "2.3.4.5"},
@ -1206,7 +1206,11 @@ func TestCloudflareComplexUpdate(t *testing.T) {
}, },
}, },
}, },
}), })
assert.NoError(t, err)
plan := &plan.Plan{
Current: records,
Desired: endpoints,
DomainFilter: endpoint.MatchAllDomainFilters{&domainFilter}, DomainFilter: endpoint.MatchAllDomainFilters{&domainFilter},
ManagedRecords: []string{endpoint.RecordTypeA, endpoint.RecordTypeCNAME}, ManagedRecords: []string{endpoint.RecordTypeA, endpoint.RecordTypeCNAME},
} }

View File

@ -386,7 +386,7 @@ func (p *IBMCloudProvider) ApplyChanges(ctx context.Context, changes *plan.Chang
} }
// AdjustEndpoints modifies the endpoints as needed by the specific provider // AdjustEndpoints modifies the endpoints as needed by the specific provider
func (p *IBMCloudProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (p *IBMCloudProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
adjustedEndpoints := []*endpoint.Endpoint{} adjustedEndpoints := []*endpoint.Endpoint{}
for _, e := range endpoints { for _, e := range endpoints {
log.Debugf("adjusting endpont: %v", *e) log.Debugf("adjusting endpont: %v", *e)
@ -398,7 +398,7 @@ func (p *IBMCloudProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*en
adjustedEndpoints = append(adjustedEndpoints, e) adjustedEndpoints = append(adjustedEndpoints, e)
} }
return adjustedEndpoints return adjustedEndpoints, nil
} }
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction. // submitChanges takes a zone and a collection of Changes and sends them as a single transaction.

View File

@ -276,8 +276,7 @@ func TestPublic_ApplyChanges(t *testing.T) {
func TestPrivate_ApplyChanges(t *testing.T) { func TestPrivate_ApplyChanges(t *testing.T) {
p := newTestIBMCloudProvider(true) p := newTestIBMCloudProvider(true)
changes := plan.Changes{ endpointsCreate, err := p.AdjustEndpoints([]*endpoint.Endpoint{
Create: p.AdjustEndpoints([]*endpoint.Endpoint{
{ {
DNSName: "newA.example.com", DNSName: "newA.example.com",
RecordType: "A", RecordType: "A",
@ -302,7 +301,21 @@ func TestPrivate_ApplyChanges(t *testing.T) {
RecordTTL: 240, RecordTTL: 240,
Targets: endpoint.NewTargets("\"heritage=external-dns,external-dns/owner=tower-pdns\""), Targets: endpoint.NewTargets("\"heritage=external-dns,external-dns/owner=tower-pdns\""),
}, },
}), })
assert.NoError(t, err)
endpointsUpdate, err := p.AdjustEndpoints([]*endpoint.Endpoint{
{
DNSName: "test.example.com",
RecordType: "A",
RecordTTL: 180,
Targets: endpoint.NewTargets("1.2.3.4", "5.6.7.8"),
},
})
assert.NoError(t, err)
changes := plan.Changes{
Create: endpointsCreate,
UpdateOld: []*endpoint.Endpoint{ UpdateOld: []*endpoint.Endpoint{
{ {
DNSName: "test.example.com", DNSName: "test.example.com",
@ -311,14 +324,7 @@ func TestPrivate_ApplyChanges(t *testing.T) {
Targets: endpoint.NewTargets("1.2.3.4"), Targets: endpoint.NewTargets("1.2.3.4"),
}, },
}, },
UpdateNew: p.AdjustEndpoints([]*endpoint.Endpoint{ UpdateNew: endpointsUpdate,
{
DNSName: "test.example.com",
RecordType: "A",
RecordTTL: 180,
Targets: endpoint.NewTargets("1.2.3.4", "5.6.7.8"),
},
}),
Delete: []*endpoint.Endpoint{ Delete: []*endpoint.Endpoint{
{ {
DNSName: "test.example.com", DNSName: "test.example.com",
@ -329,7 +335,7 @@ func TestPrivate_ApplyChanges(t *testing.T) {
}, },
} }
ctx := context.Background() ctx := context.Background()
err := p.ApplyChanges(ctx, &changes) err = p.ApplyChanges(ctx, &changes)
if err != nil { if err != nil {
t.Errorf("should not fail, %s", err) t.Errorf("should not fail, %s", err)
} }
@ -353,7 +359,8 @@ func TestAdjustEndpoints(t *testing.T) {
}, },
} }
ep := p.AdjustEndpoints(endpoints) ep, err := p.AdjustEndpoints(endpoints)
assert.NoError(t, err)
assert.Equal(t, endpoint.TTL(0), ep[0].RecordTTL) assert.Equal(t, endpoint.TTL(0), ep[0].RecordTTL)
assert.Equal(t, "test.example.com", ep[0].DNSName) assert.Equal(t, "test.example.com", ep[0].DNSName)

View File

@ -376,14 +376,14 @@ func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.End
return endpoints, nil return endpoints, nil
} }
func (p *ProviderConfig) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (p *ProviderConfig) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
// Update user specified TTL (0 == disabled) // Update user specified TTL (0 == disabled)
for i := range endpoints { for i := range endpoints {
endpoints[i].RecordTTL = endpoint.TTL(p.cacheDuration) endpoints[i].RecordTTL = endpoint.TTL(p.cacheDuration)
} }
if !p.createPTR { if !p.createPTR {
return endpoints return endpoints, nil
} }
// for all A records, we want to create PTR records // for all A records, we want to create PTR records
@ -403,7 +403,7 @@ func (p *ProviderConfig) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endp
} }
} }
return endpoints return endpoints, nil
} }
// ApplyChanges applies the given changes. // ApplyChanges applies the given changes.

View File

@ -83,8 +83,8 @@ func (p *PluralProvider) Records(_ context.Context) (endpoints []*endpoint.Endpo
return return
} }
func (p *PluralProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (p *PluralProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
return endpoints return endpoints, nil
} }
func (p *PluralProvider) ApplyChanges(_ context.Context, diffs *plan.Changes) error { func (p *PluralProvider) ApplyChanges(_ context.Context, diffs *plan.Changes) error {

View File

@ -36,14 +36,14 @@ type Provider interface {
// the endpoints that the provider returns in `Records` so that the change plan will not have // the endpoints that the provider returns in `Records` so that the change plan will not have
// unnecessary (potentially failing) changes. It may also modify other fields, add, or remove // unnecessary (potentially failing) changes. It may also modify other fields, add, or remove
// Endpoints. It is permitted to modify the supplied endpoints. // Endpoints. It is permitted to modify the supplied endpoints.
AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error)
GetDomainFilter() endpoint.DomainFilter GetDomainFilter() endpoint.DomainFilter
} }
type BaseProvider struct{} type BaseProvider struct{}
func (b BaseProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (b BaseProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
return endpoints return endpoints, nil
} }
func (b BaseProvider) GetDomainFilter() endpoint.DomainFilter { func (b BaseProvider) GetDomainFilter() endpoint.DomainFilter {

View File

@ -92,7 +92,7 @@ func NewScalewayProvider(ctx context.Context, domainFilter endpoint.DomainFilter
} }
// AdjustEndpoints is used to normalize the endoints // AdjustEndpoints is used to normalize the endoints
func (p *ScalewayProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (p *ScalewayProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
eps := make([]*endpoint.Endpoint, len(endpoints)) eps := make([]*endpoint.Endpoint, len(endpoints))
for i := range endpoints { for i := range endpoints {
eps[i] = endpoints[i] eps[i] = endpoints[i]
@ -103,7 +103,7 @@ func (p *ScalewayProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*en
eps[i] = eps[i].WithProviderSpecific(scalewayPriorityKey, fmt.Sprintf("%d", scalewayDefaultPriority)) eps[i] = eps[i].WithProviderSpecific(scalewayPriorityKey, fmt.Sprintf("%d", scalewayDefaultPriority))
} }
} }
return eps return eps, nil
} }
// Zones returns the list of hosted zones. // Zones returns the list of hosted zones.

View File

@ -220,7 +220,8 @@ func TestScalewayProvider_AdjustEndpoints(t *testing.T) {
}, },
} }
after := provider.AdjustEndpoints(before) after, err := provider.AdjustEndpoints(before)
assert.NoError(t, err)
for i := range after { for i := range after {
if !checkRecordEquality(after[i], expected[i]) { if !checkRecordEquality(after[i], expected[i]) {
t.Errorf("got record %s instead of %s", after[i], expected[i]) t.Errorf("got record %s instead of %s", after[i], expected[i])

View File

@ -96,6 +96,6 @@ func (sdr *AWSSDRegistry) updateLabels(endpoints []*endpoint.Endpoint) {
} }
// AdjustEndpoints modifies the endpoints as needed by the specific provider // AdjustEndpoints modifies the endpoints as needed by the specific provider
func (sdr *AWSSDRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (sdr *AWSSDRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
return sdr.provider.AdjustEndpoints(endpoints) return sdr.provider.AdjustEndpoints(endpoints)
} }

View File

@ -333,7 +333,7 @@ func (im *DynamoDBRegistry) ApplyChanges(ctx context.Context, changes *plan.Chan
} }
// AdjustEndpoints modifies the endpoints as needed by the specific provider. // AdjustEndpoints modifies the endpoints as needed by the specific provider.
func (im *DynamoDBRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (im *DynamoDBRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
return im.provider.AdjustEndpoints(endpoints) return im.provider.AdjustEndpoints(endpoints)
} }

View File

@ -51,6 +51,6 @@ func (im *NoopRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes)
} }
// AdjustEndpoints modifies the endpoints as needed by the specific provider // AdjustEndpoints modifies the endpoints as needed by the specific provider
func (im *NoopRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (im *NoopRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
return im.provider.AdjustEndpoints(endpoints) return im.provider.AdjustEndpoints(endpoints)
} }

View File

@ -32,7 +32,7 @@ import (
type Registry interface { type Registry interface {
Records(ctx context.Context) ([]*endpoint.Endpoint, error) Records(ctx context.Context) ([]*endpoint.Endpoint, error)
ApplyChanges(ctx context.Context, changes *plan.Changes) error ApplyChanges(ctx context.Context, changes *plan.Changes) error
AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error)
GetDomainFilter() endpoint.DomainFilter GetDomainFilter() endpoint.DomainFilter
} }

View File

@ -292,7 +292,7 @@ func (im *TXTRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes)
} }
// AdjustEndpoints modifies the endpoints as needed by the specific provider // AdjustEndpoints modifies the endpoints as needed by the specific provider
func (im *TXTRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { func (im *TXTRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
return im.provider.AdjustEndpoints(endpoints) return im.provider.AdjustEndpoints(endpoints)
} }