Cache the endpoints on the controller loop

The controller will retrieve all the endpoints at the beginning of its
loop. When changes need to be applied, the provider may need to query
the endpoints again. Allow the provider to skip the queries if its data was
cached.
This commit is contained in:
Michael Fraenkel 2019-04-24 21:18:26 -04:00 committed by Michael Fraenkel
parent ad68fb8daf
commit fab942f486
49 changed files with 252 additions and 156 deletions

View File

@ -17,12 +17,14 @@ limitations under the License.
package controller package controller
import ( import (
"context"
"time" "time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/kubernetes-incubator/external-dns/plan" "github.com/kubernetes-incubator/external-dns/plan"
"github.com/kubernetes-incubator/external-dns/provider"
"github.com/kubernetes-incubator/external-dns/registry" "github.com/kubernetes-incubator/external-dns/registry"
"github.com/kubernetes-incubator/external-dns/source" "github.com/kubernetes-incubator/external-dns/source"
) )
@ -89,6 +91,8 @@ func (c *Controller) RunOnce() error {
} }
registryEndpointsTotal.Set(float64(len(records))) registryEndpointsTotal.Set(float64(len(records)))
ctx := context.WithValue(context.Background(), provider.RecordsContextKey, records)
endpoints, err := c.Source.Endpoints() endpoints, err := c.Source.Endpoints()
if err != nil { if err != nil {
sourceErrors.Inc() sourceErrors.Inc()
@ -104,7 +108,7 @@ func (c *Controller) RunOnce() error {
plan = plan.Calculate() plan = plan.Calculate()
err = c.Registry.ApplyChanges(plan.Changes) err = c.Registry.ApplyChanges(ctx, plan.Changes)
if err != nil { if err != nil {
registryErrors.Inc() registryErrors.Inc()
return err return err

View File

@ -17,7 +17,9 @@ limitations under the License.
package controller package controller
import ( import (
"context"
"errors" "errors"
"reflect"
"testing" "testing"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
@ -42,7 +44,7 @@ func (p *mockProvider) Records() ([]*endpoint.Endpoint, error) {
} }
// ApplyChanges validates that the passed in changes satisfy the assumtions. // ApplyChanges validates that the passed in changes satisfy the assumtions.
func (p *mockProvider) ApplyChanges(changes *plan.Changes) error { func (p *mockProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
if len(changes.Create) != len(p.ExpectChanges.Create) { if len(changes.Create) != len(p.ExpectChanges.Create) {
return errors.New("number of created records is wrong") return errors.New("number of created records is wrong")
} }
@ -71,6 +73,9 @@ func (p *mockProvider) ApplyChanges(changes *plan.Changes) error {
} }
} }
if !reflect.DeepEqual(ctx.Value(provider.RecordsContextKey), p.RecordsStore) {
return errors.New("context is wrong")
}
return nil return nil
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"strings" "strings"
@ -291,7 +292,7 @@ func (p *AlibabaCloudProvider) Records() (endpoints []*endpoint.Endpoint, err er
// ApplyChanges applies the given changes. // ApplyChanges applies the given changes.
// //
// Returns nil if the operation was successful or an error if the operation failed. // Returns nil if the operation was successful or an error if the operation failed.
func (p *AlibabaCloudProvider) ApplyChanges(changes *plan.Changes) error { func (p *AlibabaCloudProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
if changes == nil || len(changes.Create)+len(changes.Delete)+len(changes.UpdateNew) == 0 { if changes == nil || len(changes.Create)+len(changes.Delete)+len(changes.UpdateNew) == 0 {
// No op // No op
return nil return nil

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"testing" "testing"
"github.com/aliyun/alibaba-cloud-sdk-go/services/alidns" "github.com/aliyun/alibaba-cloud-sdk-go/services/alidns"
@ -301,7 +302,7 @@ func TestAlibabaCloudProvider_ApplyChanges(t *testing.T) {
}, },
}, },
} }
p.ApplyChanges(&changes) p.ApplyChanges(context.Background(), &changes)
endpoints, err := p.Records() endpoints, err := p.Records()
if err != nil { if err != nil {
t.Errorf("Failed to get records: %v", err) t.Errorf("Failed to get records: %v", err)
@ -358,7 +359,7 @@ func TestAlibabaCloudProvider_ApplyChanges_PrivateZone(t *testing.T) {
}, },
}, },
} }
p.ApplyChanges(&changes) p.ApplyChanges(context.Background(), &changes)
endpoints, err := p.Records() endpoints, err := p.Records()
if err != nil { if err != nil {
t.Errorf("Failed to get records: %v", err) t.Errorf("Failed to get records: %v", err)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@ -319,15 +320,19 @@ func (p *AWSProvider) doRecords(action string, endpoints []*endpoint.Endpoint) e
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *AWSProvider) ApplyChanges(changes *plan.Changes) error { func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
zones, err := p.Zones() zones, err := p.Zones()
if err != nil { if err != nil {
return err return err
} }
records, err := p.records(zones) records, ok := ctx.Value(RecordsContextKey).([]*endpoint.Endpoint)
if err != nil { if !ok {
log.Errorf("getting records failed: %v", err) var err error
records, err = p.records(zones)
if err != nil {
log.Errorf("getting records failed: %v", err)
}
} }
combinedChanges := make([]*route53.Change, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) combinedChanges := make([]*route53.Change, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"strings" "strings"
"crypto/sha256" "crypto/sha256"
@ -193,7 +194,7 @@ func (p *AWSSDProvider) instancesToEndpoint(ns *sd.NamespaceSummary, srv *sd.Ser
} }
// ApplyChanges applies Kubernetes changes in endpoints to AWS API // ApplyChanges applies Kubernetes changes in endpoints to AWS API
func (p *AWSSDProvider) ApplyChanges(changes *plan.Changes) error { func (p *AWSSDProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
// return early if there is nothing to change // return early if there is nothing to change
if len(changes.Create) == 0 && len(changes.Delete) == 0 && len(changes.UpdateNew) == 0 { if len(changes.Create) == 0 && len(changes.Delete) == 0 && len(changes.UpdateNew) == 0 {
log.Info("All records are already up to date") log.Info("All records are already up to date")

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"errors" "errors"
"math/rand" "math/rand"
"reflect" "reflect"
@ -316,7 +317,7 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
provider := newTestAWSSDProvider(api, NewDomainFilter([]string{}), "") provider := newTestAWSSDProvider(api, NewDomainFilter([]string{}), "")
// apply creates // apply creates
provider.ApplyChanges(&plan.Changes{ provider.ApplyChanges(context.Background(), &plan.Changes{
Create: expectedEndpoints, Create: expectedEndpoints,
}) })
@ -332,7 +333,7 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
assert.True(t, testutils.SameEndpoints(expectedEndpoints, endpoints), "expected and actual endpoints don't match, expected=%v, actual=%v", expectedEndpoints, endpoints) assert.True(t, testutils.SameEndpoints(expectedEndpoints, endpoints), "expected and actual endpoints don't match, expected=%v, actual=%v", expectedEndpoints, endpoints)
// apply deletes // apply deletes
provider.ApplyChanges(&plan.Changes{ provider.ApplyChanges(context.Background(), &plan.Changes{
Delete: expectedEndpoints, Delete: expectedEndpoints,
}) })

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"sort" "sort"
@ -412,79 +413,96 @@ func TestAWSDeleteRecords(t *testing.T) {
} }
func TestAWSApplyChanges(t *testing.T) { func TestAWSApplyChanges(t *testing.T) {
provider, _ := newAWSProvider(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), NewZoneIDFilter([]string{}), NewZoneTypeFilter(""), defaultEvaluateTargetHealth, false, []*endpoint.Endpoint{ tests := []struct {
endpoint.NewEndpointWithTTL("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8"), name string
endpoint.NewEndpointWithTTL("delete-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8"), setup func(p *AWSProvider) context.Context
endpoint.NewEndpointWithTTL("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.4.4"), listRRSets int
endpoint.NewEndpointWithTTL("delete-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.4.4"), }{
endpoint.NewEndpointWithTTL("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "bar.elb.amazonaws.com"), {"no cache", func(p *AWSProvider) context.Context { return context.Background() }, 3},
endpoint.NewEndpointWithTTL("delete-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "qux.elb.amazonaws.com"), {"cached", func(p *AWSProvider) context.Context {
endpoint.NewEndpointWithTTL("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "bar.elb.amazonaws.com"), records, err := p.Records()
endpoint.NewEndpointWithTTL("delete-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "qux.elb.amazonaws.com"), require.NoError(t, err)
endpoint.NewEndpointWithTTL("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8", "8.8.4.4"), return context.WithValue(context.Background(), RecordsContextKey, records)
endpoint.NewEndpointWithTTL("delete-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "4.3.2.1"), }, 0},
})
createRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("create-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"),
endpoint.NewEndpoint("create-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.4.4"),
endpoint.NewEndpoint("create-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "foo.elb.amazonaws.com"),
endpoint.NewEndpoint("create-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "foo.elb.amazonaws.com"),
endpoint.NewEndpoint("create-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8", "8.8.4.4"),
} }
currentRecords := []*endpoint.Endpoint{ for _, tt := range tests {
endpoint.NewEndpoint("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"), provider, _ := newAWSProvider(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), NewZoneIDFilter([]string{}), NewZoneTypeFilter(""), defaultEvaluateTargetHealth, false, []*endpoint.Endpoint{
endpoint.NewEndpoint("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.4.4"), endpoint.NewEndpointWithTTL("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8"),
endpoint.NewEndpoint("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "bar.elb.amazonaws.com"), endpoint.NewEndpointWithTTL("delete-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8"),
endpoint.NewEndpoint("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "bar.elb.amazonaws.com"), endpoint.NewEndpointWithTTL("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.4.4"),
endpoint.NewEndpoint("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8", "8.8.4.4"), endpoint.NewEndpointWithTTL("delete-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.4.4"),
endpoint.NewEndpointWithTTL("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "bar.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("delete-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "qux.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "bar.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("delete-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "qux.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8", "8.8.4.4"),
endpoint.NewEndpointWithTTL("delete-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "4.3.2.1"),
})
createRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("create-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"),
endpoint.NewEndpoint("create-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.4.4"),
endpoint.NewEndpoint("create-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "foo.elb.amazonaws.com"),
endpoint.NewEndpoint("create-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "foo.elb.amazonaws.com"),
endpoint.NewEndpoint("create-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8", "8.8.4.4"),
}
currentRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"),
endpoint.NewEndpoint("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.4.4"),
endpoint.NewEndpoint("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "bar.elb.amazonaws.com"),
endpoint.NewEndpoint("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "bar.elb.amazonaws.com"),
endpoint.NewEndpoint("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8", "8.8.4.4"),
}
updatedRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4"),
endpoint.NewEndpoint("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "4.3.2.1"),
endpoint.NewEndpoint("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "baz.elb.amazonaws.com"),
endpoint.NewEndpoint("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "baz.elb.amazonaws.com"),
endpoint.NewEndpoint("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4", "4.3.2.1"),
}
deleteRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("delete-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"),
endpoint.NewEndpoint("delete-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.4.4"),
endpoint.NewEndpoint("delete-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "qux.elb.amazonaws.com"),
endpoint.NewEndpoint("delete-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "qux.elb.amazonaws.com"),
endpoint.NewEndpoint("delete-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4", "4.3.2.1"),
}
changes := &plan.Changes{
Create: createRecords,
UpdateNew: updatedRecords,
UpdateOld: currentRecords,
Delete: deleteRecords,
}
ctx := tt.setup(provider)
counter := NewRoute53APICounter(provider.client)
provider.client = counter
require.NoError(t, provider.ApplyChanges(ctx, changes))
assert.Equal(t, 1, counter.calls["ListHostedZonesPages"], tt.name)
assert.Equal(t, tt.listRRSets, counter.calls["ListResourceRecordSetsPages"], tt.name)
records, err := provider.Records()
require.NoError(t, err, tt.name)
validateEndpoints(t, records, []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("create-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8"),
endpoint.NewEndpointWithTTL("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4"),
endpoint.NewEndpointWithTTL("create-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.4.4"),
endpoint.NewEndpointWithTTL("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "4.3.2.1"),
endpoint.NewEndpointWithTTL("create-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "foo.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "baz.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("create-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "foo.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "baz.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("create-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8", "8.8.4.4"),
endpoint.NewEndpointWithTTL("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "4.3.2.1"),
})
} }
updatedRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4"),
endpoint.NewEndpoint("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "4.3.2.1"),
endpoint.NewEndpoint("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "baz.elb.amazonaws.com"),
endpoint.NewEndpoint("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "baz.elb.amazonaws.com"),
endpoint.NewEndpoint("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4", "4.3.2.1"),
}
deleteRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("delete-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8"),
endpoint.NewEndpoint("delete-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.4.4"),
endpoint.NewEndpoint("delete-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "qux.elb.amazonaws.com"),
endpoint.NewEndpoint("delete-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, "qux.elb.amazonaws.com"),
endpoint.NewEndpoint("delete-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4", "4.3.2.1"),
}
changes := &plan.Changes{
Create: createRecords,
UpdateNew: updatedRecords,
UpdateOld: currentRecords,
Delete: deleteRecords,
}
counter := NewRoute53APICounter(provider.client)
provider.client = counter
require.NoError(t, provider.ApplyChanges(changes))
assert.Equal(t, 1, counter.calls["ListHostedZonesPages"])
assert.Equal(t, 3, counter.calls["ListResourceRecordSetsPages"])
records, err := provider.Records()
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("create-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8"),
endpoint.NewEndpointWithTTL("update-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4"),
endpoint.NewEndpointWithTTL("create-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.4.4"),
endpoint.NewEndpointWithTTL("update-test.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "4.3.2.1"),
endpoint.NewEndpointWithTTL("create-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "foo.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("update-test-cname.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "baz.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("create-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "foo.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("update-test-cname-alias.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "baz.elb.amazonaws.com"),
endpoint.NewEndpointWithTTL("create-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "8.8.8.8", "8.8.4.4"),
endpoint.NewEndpointWithTTL("update-test-multiple.zone-2.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "4.3.2.1"),
})
} }
func TestAWSApplyChangesDryRun(t *testing.T) { func TestAWSApplyChangesDryRun(t *testing.T) {
@ -541,7 +559,7 @@ func TestAWSApplyChangesDryRun(t *testing.T) {
Delete: deleteRecords, Delete: deleteRecords,
} }
require.NoError(t, provider.ApplyChanges(changes)) require.NoError(t, provider.ApplyChanges(context.Background(), changes))
records, err := provider.Records() records, err := provider.Records()
require.NoError(t, err) require.NoError(t, err)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"strings" "strings"
@ -209,7 +210,7 @@ func (p *AzureProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
// ApplyChanges applies the given changes. // ApplyChanges applies the given changes.
// //
// Returns nil if the operation was successful or an error if the operation failed. // Returns nil if the operation was successful or an error if the operation failed.
func (p *AzureProvider) ApplyChanges(changes *plan.Changes) error { func (p *AzureProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
zones, err := p.zones() zones, err := p.zones()
if err != nil { if err != nil {
return err return err

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"testing" "testing"
"github.com/Azure/azure-sdk-for-go/arm/dns" "github.com/Azure/azure-sdk-for-go/arm/dns"
@ -344,7 +345,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordsClie
Delete: deleteRecords, Delete: deleteRecords,
} }
if err := provider.ApplyChanges(changes); err != nil { if err := provider.ApplyChanges(context.Background(), changes); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -192,7 +192,7 @@ func (p *CloudFlareProvider) Records() ([]*endpoint.Endpoint, error) {
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *CloudFlareProvider) ApplyChanges(changes *plan.Changes) error { func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
proxiedByDefault := p.proxiedByDefault proxiedByDefault := p.proxiedByDefault
combinedChanges := make([]*cloudFlareChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) combinedChanges := make([]*cloudFlareChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))

View File

@ -542,7 +542,7 @@ func TestApplyChanges(t *testing.T) {
changes.Delete = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.zalando.to.", Targets: endpoint.Targets{"target"}}} changes.Delete = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.zalando.to.", Targets: endpoint.Targets{"target"}}}
changes.UpdateOld = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.zalando.to.", Targets: endpoint.Targets{"target-old"}}} changes.UpdateOld = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.zalando.to.", Targets: endpoint.Targets{"target-old"}}}
changes.UpdateNew = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.zalando.to.", Targets: endpoint.Targets{"target-new"}}} changes.UpdateNew = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.zalando.to.", Targets: endpoint.Targets{"target-new"}}}
err := provider.ApplyChanges(changes) err := provider.ApplyChanges(context.Background(), changes)
if err != nil { if err != nil {
t.Errorf("should not fail, %s", err) t.Errorf("should not fail, %s", err)
} }
@ -553,7 +553,7 @@ func TestApplyChanges(t *testing.T) {
changes.UpdateOld = []*endpoint.Endpoint{} changes.UpdateOld = []*endpoint.Endpoint{}
changes.UpdateNew = []*endpoint.Endpoint{} changes.UpdateNew = []*endpoint.Endpoint{}
err = provider.ApplyChanges(changes) err = provider.ApplyChanges(context.Background(), changes)
if err != nil { if err != nil {
t.Errorf("should not fail, %s", err) t.Errorf("should not fail, %s", err)
} }

View File

@ -298,7 +298,7 @@ func (p coreDNSProvider) Records() ([]*endpoint.Endpoint, error) {
} }
// ApplyChanges stores changes back to etcd converting them to CoreDNS format and aggregating A/CNAME and TXT records // ApplyChanges stores changes back to etcd converting them to CoreDNS format and aggregating A/CNAME and TXT records
func (p coreDNSProvider) ApplyChanges(changes *plan.Changes) error { func (p coreDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
grouped := map[string][]*endpoint.Endpoint{} grouped := map[string][]*endpoint.Endpoint{}
for _, ep := range changes.Create { for _, ep := range changes.Create {
grouped[ep.DNSName] = append(grouped[ep.DNSName], ep) grouped[ep.DNSName] = append(grouped[ep.DNSName], ep)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"strings" "strings"
"testing" "testing"
@ -227,7 +228,7 @@ func TestCoreDNSApplyChanges(t *testing.T) {
endpoint.NewEndpoint("domain2.local", endpoint.RecordTypeCNAME, "site.local"), endpoint.NewEndpoint("domain2.local", endpoint.RecordTypeCNAME, "site.local"),
}, },
} }
coredns.ApplyChanges(changes1) coredns.ApplyChanges(context.Background(), changes1)
expectedServices1 := map[string]*Service{ expectedServices1 := map[string]*Service{
"/skydns/local/domain1": {Host: "5.5.5.5", Text: "string1"}, "/skydns/local/domain1": {Host: "5.5.5.5", Text: "string1"},
@ -285,7 +286,7 @@ func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) {
} }
} }
} }
provider.ApplyChanges(changes) provider.ApplyChanges(context.Background(), changes)
} }
func validateServices(services, expectedServices map[string]*Service, t *testing.T, step int) { func validateServices(services, expectedServices map[string]*Service, t *testing.T, step int) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -379,7 +380,7 @@ func addEndpoint(ep *endpoint.Endpoint, recordSets map[string]*recordSet, delete
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p designateProvider) ApplyChanges(changes *plan.Changes) error { func (p designateProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
managedZones, err := p.getZones() managedZones, err := p.getZones()
if err != nil { if err != nil {
return err return err

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -407,7 +408,7 @@ func testDesignateCreateRecords(t *testing.T, client *fakeDesignateClient) []*re
expectedCopy := make([]*recordsets.RecordSet, len(expected)) expectedCopy := make([]*recordsets.RecordSet, len(expected))
copy(expectedCopy, expected) copy(expectedCopy, expected)
err := client.ToProvider().ApplyChanges(&plan.Changes{Create: endpoints}) err := client.ToProvider().ApplyChanges(context.Background(), &plan.Changes{Create: endpoints})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -495,7 +496,7 @@ func testDesignateUpdateRecords(t *testing.T, client *fakeDesignateClient) []*re
expected[2].Records = []string{"10.3.3.1"} expected[2].Records = []string{"10.3.3.1"}
expected[3].Records = []string{"10.2.1.1", "10.3.3.2"} expected[3].Records = []string{"10.2.1.1", "10.3.3.2"}
err := client.ToProvider().ApplyChanges(&plan.Changes{UpdateOld: updatesOld, UpdateNew: updatesNew}) err := client.ToProvider().ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updatesOld, UpdateNew: updatesNew})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -553,7 +554,7 @@ func testDesignateDeleteRecords(t *testing.T, client *fakeDesignateClient) {
expected[3].Records = []string{"10.3.3.2"} expected[3].Records = []string{"10.3.3.2"}
expected = expected[1:] expected = expected[1:]
err := client.ToProvider().ApplyChanges(&plan.Changes{Delete: deletes}) err := client.ToProvider().ApplyChanges(context.Background(), &plan.Changes{Delete: deletes})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
goctx "context"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@ -261,7 +262,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *DigitalOceanProvider) ApplyChanges(changes *plan.Changes) error { func (p *DigitalOceanProvider) ApplyChanges(ctx goctx.Context, changes *plan.Changes) error {
combinedChanges := make([]*DigitalOceanChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) combinedChanges := make([]*DigitalOceanChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))
combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanCreate, changes.Create)...) combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanCreate, changes.Create)...)

View File

@ -438,7 +438,7 @@ func TestDigitalOceanApplyChanges(t *testing.T) {
changes.Delete = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.bar.com", Targets: endpoint.Targets{"target"}}} changes.Delete = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.bar.com", Targets: endpoint.Targets{"target"}}}
changes.UpdateOld = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.bar.de", Targets: endpoint.Targets{"target-old"}}} changes.UpdateOld = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.bar.de", Targets: endpoint.Targets{"target-old"}}}
changes.UpdateNew = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.foo.com", Targets: endpoint.Targets{"target-new"}, RecordType: "CNAME", RecordTTL: 100}} changes.UpdateNew = []*endpoint.Endpoint{{DNSName: "foobar.ext-dns-test.foo.com", Targets: endpoint.Targets{"target-new"}, RecordType: "CNAME", RecordTTL: 100}}
err := provider.ApplyChanges(changes) err := provider.ApplyChanges(context.Background(), changes)
if err != nil { if err != nil {
t.Errorf("should not fail, %s", err) t.Errorf("should not fail, %s", err)
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@ -332,7 +333,7 @@ func (p *dnsimpleProvider) UpdateRecords(endpoints []*endpoint.Endpoint) error {
} }
// ApplyChanges applies a given set of changes // ApplyChanges applies a given set of changes
func (p *dnsimpleProvider) ApplyChanges(changes *plan.Changes) error { func (p *dnsimpleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
combinedChanges := make([]*dnsimpleChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) combinedChanges := make([]*dnsimpleChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))
combinedChanges = append(combinedChanges, newDnsimpleChanges(dnsimpleCreate, changes.Create)...) combinedChanges = append(combinedChanges, newDnsimpleChanges(dnsimpleCreate, changes.Create)...)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -172,7 +173,7 @@ func testDnsimpleProviderApplyChanges(t *testing.T) {
} }
mockProvider.accountID = "1" mockProvider.accountID = "1"
err := mockProvider.ApplyChanges(changes) err := mockProvider.ApplyChanges(context.Background(), changes)
if err != nil { if err != nil {
t.Errorf("Failed to apply changes: %v", err) t.Errorf("Failed to apply changes: %v", err)
} }
@ -185,7 +186,7 @@ func testDnsimpleProviderApplyChangesSkipsUnknown(t *testing.T) {
} }
mockProvider.accountID = "1" mockProvider.accountID = "1"
err := mockProvider.ApplyChanges(changes) err := mockProvider.ApplyChanges(context.Background(), changes)
if err != nil { if err != nil {
t.Errorf("Failed to ignore unknown zones: %v", err) t.Errorf("Failed to ignore unknown zones: %v", err)
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@ -637,7 +638,7 @@ func (d *dynProviderState) Records() ([]*endpoint.Endpoint, error) {
// this method does C + 2*Z requests: C=total number of changes, Z = number of // this method does C + 2*Z requests: C=total number of changes, Z = number of
// affected zones (1 login + 1 commit) // affected zones (1 login + 1 commit)
func (d *dynProviderState) ApplyChanges(changes *plan.Changes) error { func (d *dynProviderState) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
log.Debugf("Processing chages: %+v", changes) log.Debugf("Processing chages: %+v", changes)
if d.DryRun { if d.DryRun {

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"strings" "strings"
"github.com/exoscale/egoscale" "github.com/exoscale/egoscale"
@ -81,7 +82,7 @@ func (ep *ExoscaleProvider) getZones() (map[int64]string, error) {
} }
// ApplyChanges simply modifies DNS via exoscale API // ApplyChanges simply modifies DNS via exoscale API
func (ep *ExoscaleProvider) ApplyChanges(changes *plan.Changes) error { func (ep *ExoscaleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
ep.OnApplyChanges(changes) ep.OnApplyChanges(changes)
if ep.dryRun { if ep.dryRun {

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"strings" "strings"
"testing" "testing"
@ -173,7 +174,7 @@ func TestExoscaleApplyChanges(t *testing.T) {
createExoscale = make([]createRecordExoscale, 0) createExoscale = make([]createRecordExoscale, 0)
deleteExoscale = make([]deleteRecordExoscale, 0) deleteExoscale = make([]deleteRecordExoscale, 0)
provider.ApplyChanges(plan) provider.ApplyChanges(context.Background(), plan)
assert.Equal(t, 1, len(createExoscale)) assert.Equal(t, 1, len(createExoscale))
assert.Equal(t, "foo.com", createExoscale[0].name) assert.Equal(t, "foo.com", createExoscale[0].name)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
goctx "context"
"fmt" "fmt"
"strings" "strings"
@ -247,7 +248,7 @@ func (p *GoogleProvider) DeleteRecords(endpoints []*endpoint.Endpoint) error {
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *GoogleProvider) ApplyChanges(changes *plan.Changes) error { func (p *GoogleProvider) ApplyChanges(ctx goctx.Context, changes *plan.Changes) error {
change := &dns.Change{} change := &dns.Change{}
change.Additions = append(change.Additions, p.newFilteredRecords(changes.Create)...) change.Additions = append(change.Additions, p.newFilteredRecords(changes.Create)...)

View File

@ -387,7 +387,7 @@ func TestGoogleApplyChanges(t *testing.T) {
Delete: deleteRecords, Delete: deleteRecords,
} }
require.NoError(t, provider.ApplyChanges(changes)) require.NoError(t, provider.ApplyChanges(context.Background(), changes))
records, err := provider.Records() records, err := provider.Records()
require.NoError(t, err) require.NoError(t, err)
@ -444,7 +444,7 @@ func TestGoogleApplyChangesDryRun(t *testing.T) {
Delete: deleteRecords, Delete: deleteRecords,
} }
require.NoError(t, provider.ApplyChanges(changes)) require.NoError(t, provider.ApplyChanges(context.Background(), changes))
records, err := provider.Records() records, err := provider.Records()
require.NoError(t, err) require.NoError(t, err)
@ -454,7 +454,7 @@ func TestGoogleApplyChangesDryRun(t *testing.T) {
func TestGoogleApplyChangesEmpty(t *testing.T) { func TestGoogleApplyChangesEmpty(t *testing.T) {
provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, []*endpoint.Endpoint{}) provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, []*endpoint.Endpoint{})
assert.NoError(t, provider.ApplyChanges(&plan.Changes{})) assert.NoError(t, provider.ApplyChanges(context.Background(), &plan.Changes{}))
} }
func TestNewFilteredRecords(t *testing.T) { func TestNewFilteredRecords(t *testing.T) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@ -177,7 +178,7 @@ func (p *InfobloxProvider) Records() (endpoints []*endpoint.Endpoint, err error)
} }
// ApplyChanges applies the given changes. // ApplyChanges applies the given changes.
func (p *InfobloxProvider) ApplyChanges(changes *plan.Changes) error { func (p *InfobloxProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
zones, err := p.zones() zones, err := p.zones()
if err != nil { if err != nil {
return err return err

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"regexp" "regexp"
@ -469,7 +470,7 @@ func testInfobloxApplyChangesInternal(t *testing.T, dryRun bool, client ibclient
Delete: deleteRecords, Delete: deleteRecords,
} }
if err := provider.ApplyChanges(changes); err != nil { if err := provider.ApplyChanges(context.Background(), changes); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"errors" "errors"
"strings" "strings"
@ -45,7 +46,7 @@ type InMemoryProvider struct {
domain DomainFilter domain DomainFilter
client *inMemoryClient client *inMemoryClient
filter *filter filter *filter
OnApplyChanges func(changes *plan.Changes) OnApplyChanges func(ctx context.Context, changes *plan.Changes)
OnRecords func() OnRecords func()
} }
@ -55,7 +56,7 @@ type InMemoryOption func(*InMemoryProvider)
// InMemoryWithLogging injects logging when ApplyChanges is called // InMemoryWithLogging injects logging when ApplyChanges is called
func InMemoryWithLogging() InMemoryOption { func InMemoryWithLogging() InMemoryOption {
return func(p *InMemoryProvider) { return func(p *InMemoryProvider) {
p.OnApplyChanges = func(changes *plan.Changes) { p.OnApplyChanges = func(ctx context.Context, changes *plan.Changes) {
for _, v := range changes.Create { for _, v := range changes.Create {
log.Infof("CREATE: %v", v) log.Infof("CREATE: %v", v)
} }
@ -94,7 +95,7 @@ func InMemoryInitZones(zones []string) InMemoryOption {
func NewInMemoryProvider(opts ...InMemoryOption) *InMemoryProvider { func NewInMemoryProvider(opts ...InMemoryOption) *InMemoryProvider {
im := &InMemoryProvider{ im := &InMemoryProvider{
filter: &filter{}, filter: &filter{},
OnApplyChanges: func(changes *plan.Changes) {}, OnApplyChanges: func(ctx context.Context, changes *plan.Changes) {},
OnRecords: func() {}, OnRecords: func() {},
domain: NewDomainFilter([]string{""}), domain: NewDomainFilter([]string{""}),
client: newInMemoryClient(), client: newInMemoryClient(),
@ -142,8 +143,8 @@ func (im *InMemoryProvider) Records() ([]*endpoint.Endpoint, error) {
// create record - record should not exist // create record - record should not exist
// update/delete record - record should exist // update/delete record - record should exist
// create/update/delete lists should not have overlapping records // create/update/delete lists should not have overlapping records
func (im *InMemoryProvider) ApplyChanges(changes *plan.Changes) error { func (im *InMemoryProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
defer im.OnApplyChanges(changes) defer im.OnApplyChanges(ctx, changes)
perZoneChanges := map[string]*plan.Changes{} perZoneChanges := map[string]*plan.Changes{}
@ -188,7 +189,7 @@ func (im *InMemoryProvider) ApplyChanges(changes *plan.Changes) error {
UpdateOld: convertToInMemoryRecord(perZoneChanges[zoneID].UpdateOld), UpdateOld: convertToInMemoryRecord(perZoneChanges[zoneID].UpdateOld),
Delete: convertToInMemoryRecord(perZoneChanges[zoneID].Delete), Delete: convertToInMemoryRecord(perZoneChanges[zoneID].Delete),
} }
err := im.client.ApplyChanges(zoneID, change) err := im.client.ApplyChanges(ctx, zoneID, change)
if err != nil { if err != nil {
return err return err
} }
@ -293,7 +294,7 @@ func (c *inMemoryClient) CreateZone(zone string) error {
return nil return nil
} }
func (c *inMemoryClient) ApplyChanges(zoneID string, changes *inMemoryChange) error { func (c *inMemoryClient) ApplyChanges(ctx context.Context, zoneID string, changes *inMemoryChange) error {
if err := c.validateChangeBatch(zoneID, changes); err != nil { if err := c.validateChangeBatch(zoneID, changes); err != nil {
return err return err
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"testing" "testing"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
@ -773,7 +774,7 @@ func testInMemoryApplyChanges(t *testing.T) {
c.zones = getInitData() c.zones = getInitData()
im.client = c im.client = c
err := im.ApplyChanges(ti.changes) err := im.ApplyChanges(context.Background(), ti.changes)
if ti.expectError { if ti.expectError {
assert.Error(t, err) assert.Error(t, err)
} else { } else {

View File

@ -263,7 +263,7 @@ func getPriority() *int {
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *LinodeProvider) ApplyChanges(changes *plan.Changes) error { func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
recordsByZoneID := make(map[string][]*linodego.DomainRecord) recordsByZoneID := make(map[string][]*linodego.DomainRecord)
zones, err := p.fetchZones() zones, err := p.fetchZones()

View File

@ -353,7 +353,7 @@ func TestLinodeApplyChanges(t *testing.T) {
}, },
).Return(&linodego.DomainRecord{}, nil).Once() ).Return(&linodego.DomainRecord{}, nil).Once()
err := provider.ApplyChanges(&plan.Changes{ err := provider.ApplyChanges(context.Background(), &plan.Changes{
Create: []*endpoint.Endpoint{{ Create: []*endpoint.Endpoint{{
DNSName: "create.bar.io", DNSName: "create.bar.io",
RecordType: "A", RecordType: "A",
@ -428,7 +428,7 @@ func TestLinodeApplyChangesTargetAdded(t *testing.T) {
}, },
).Return(&linodego.DomainRecord{}, nil).Once() ).Return(&linodego.DomainRecord{}, nil).Once()
err := provider.ApplyChanges(&plan.Changes{ err := provider.ApplyChanges(context.Background(), &plan.Changes{
// From 1 target to 2 // From 1 target to 2
UpdateNew: []*endpoint.Endpoint{{ UpdateNew: []*endpoint.Endpoint{{
DNSName: "example.com", DNSName: "example.com",
@ -484,7 +484,7 @@ func TestLinodeApplyChangesTargetRemoved(t *testing.T) {
11, 11,
).Return(nil).Once() ).Return(nil).Once()
err := provider.ApplyChanges(&plan.Changes{ err := provider.ApplyChanges(context.Background(), &plan.Changes{
// From 2 targets to 1 // From 2 targets to 1
UpdateNew: []*endpoint.Endpoint{{ UpdateNew: []*endpoint.Endpoint{{
DNSName: "example.com", DNSName: "example.com",
@ -521,7 +521,7 @@ func TestLinodeApplyChangesNoChanges(t *testing.T) {
mock.Anything, mock.Anything,
).Return([]*linodego.DomainRecord{{ID: 11, Name: "", Type: "A", Target: "targetA"}}, nil).Once() ).Return([]*linodego.DomainRecord{{ID: 11, Name: "", Type: "A", Target: "targetA"}}, nil).Once()
err := provider.ApplyChanges(&plan.Changes{}) err := provider.ApplyChanges(context.Background(), &plan.Changes{})
require.NoError(t, err) require.NoError(t, err)
mockDomainClient.AssertExpectations(t) mockDomainClient.AssertExpectations(t)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
@ -271,7 +272,7 @@ type ns1Change struct {
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *NS1Provider) ApplyChanges(changes *plan.Changes) error { func (p *NS1Provider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
combinedChanges := make([]*ns1Change, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) combinedChanges := make([]*ns1Change, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))
combinedChanges = append(combinedChanges, newNS1Changes(ns1Create, changes.Create)...) combinedChanges = append(combinedChanges, newNS1Changes(ns1Create, changes.Create)...)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@ -221,14 +222,14 @@ func TestNS1ApplyChanges(t *testing.T) {
} }
changes.Delete = []*endpoint.Endpoint{{DNSName: "test.foo.com", Targets: endpoint.Targets{"target"}}} changes.Delete = []*endpoint.Endpoint{{DNSName: "test.foo.com", Targets: endpoint.Targets{"target"}}}
changes.UpdateNew = []*endpoint.Endpoint{{DNSName: "test.foo.com", Targets: endpoint.Targets{"target-new"}}} changes.UpdateNew = []*endpoint.Endpoint{{DNSName: "test.foo.com", Targets: endpoint.Targets{"target-new"}}}
err := provider.ApplyChanges(changes) err := provider.ApplyChanges(context.Background(), changes)
require.NoError(t, err) require.NoError(t, err)
// empty changes // empty changes
changes.Create = []*endpoint.Endpoint{} changes.Create = []*endpoint.Endpoint{}
changes.Delete = []*endpoint.Endpoint{} changes.Delete = []*endpoint.Endpoint{}
changes.UpdateNew = []*endpoint.Endpoint{} changes.UpdateNew = []*endpoint.Endpoint{}
err = provider.ApplyChanges(changes) err = provider.ApplyChanges(context.Background(), changes)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -201,7 +201,7 @@ func (p *OCIProvider) Records() ([]*endpoint.Endpoint, error) {
} }
// ApplyChanges applies a given set of changes to a given zone. // ApplyChanges applies a given set of changes to a given zone.
func (p *OCIProvider) ApplyChanges(changes *plan.Changes) error { func (p *OCIProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
log.Debugf("Processing chages: %+v", changes) log.Debugf("Processing chages: %+v", changes)
ops := []dns.RecordOperation{} ops := []dns.RecordOperation{}
@ -217,7 +217,6 @@ func (p *OCIProvider) ApplyChanges(changes *plan.Changes) error {
return nil return nil
} }
ctx := context.Background()
zones, err := p.zones(ctx) zones, err := p.zones(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching zones") return errors.Wrap(err, "fetching zones")

View File

@ -829,7 +829,7 @@ func TestOCIApplyChanges(t *testing.T) {
NewZoneIDFilter([]string{""}), NewZoneIDFilter([]string{""}),
tc.dryRun, tc.dryRun,
) )
err := provider.ApplyChanges(tc.changes) err := provider.ApplyChanges(context.Background(), tc.changes)
require.Equal(t, tc.err, err) require.Equal(t, tc.err, err)
endpoints, err := provider.Records() endpoints, err := provider.Records()
require.NoError(t, err) require.NoError(t, err)

View File

@ -443,7 +443,7 @@ func (p *PDNSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
// ApplyChanges takes a list of changes (endpoints) and updates the PDNS server // ApplyChanges takes a list of changes (endpoints) and updates the PDNS server
// by sending the correct HTTP PATCH requests to a matching zone // by sending the correct HTTP PATCH requests to a matching zone
func (p *PDNSProvider) ApplyChanges(changes *plan.Changes) error { func (p *PDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
startTime := time.Now() startTime := time.Now()

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"net" "net"
"strings" "strings"
@ -27,9 +28,20 @@ import (
// Provider defines the interface DNS providers should implement. // Provider defines the interface DNS providers should implement.
type Provider interface { type Provider interface {
Records() ([]*endpoint.Endpoint, error) Records() ([]*endpoint.Endpoint, error)
ApplyChanges(changes *plan.Changes) error ApplyChanges(ctx context.Context, changes *plan.Changes) error
} }
type contextKey struct {
name string
}
func (k *contextKey) String() string { return "provider context value " + k.name }
// RecordsContextKey is a context key. It can be used during ApplyChanges
// to access previously cached records. The associated value will be of
// type []*endpoint.Endpoint.
var RecordsContextKey = &contextKey{"records"}
// ensureTrailingDot ensures that the hostname receives a trailing dot if it hasn't already. // ensureTrailingDot ensures that the hostname receives a trailing dot if it hasn't already.
func ensureTrailingDot(hostname string) string { func ensureTrailingDot(hostname string) string {
if net.ParseIP(hostname) != nil { if net.ParseIP(hostname) != nil {

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
@ -141,7 +142,7 @@ func (p *RcodeZeroProvider) Records() ([]*endpoint.Endpoint, error) {
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *RcodeZeroProvider) ApplyChanges(changes *plan.Changes) error { func (p *RcodeZeroProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
combinedChanges := make([]*rc0.RRSetChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) combinedChanges := make([]*rc0.RRSetChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -102,7 +103,7 @@ func TestRcodeZeroProvider_ApplyChanges(t *testing.T) {
changes := mockChanges() changes := mockChanges()
err := provider.ApplyChanges(changes) err := provider.ApplyChanges(context.Background(), changes)
if err != nil { if err != nil {
t.Errorf("should not fail, %s", err) t.Errorf("should not fail, %s", err)

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -195,7 +196,7 @@ func (r rfc2136Provider) List() ([]dns.RR, error) {
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (r rfc2136Provider) ApplyChanges(changes *plan.Changes) error { func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
log.Debugf("ApplyChanges") log.Debugf("ApplyChanges")
for _, ep := range changes.Create { for _, ep := range changes.Create {

View File

@ -17,6 +17,7 @@ limitations under the License.
package provider package provider
import ( import (
"context"
"strings" "strings"
"testing" "testing"
@ -149,7 +150,7 @@ func TestRfc2136ApplyChanges(t *testing.T) {
}, },
} }
err = provider.ApplyChanges(p) err = provider.ApplyChanges(context.Background(), p)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(stub.createMsgs)) assert.Equal(t, 2, len(stub.createMsgs))

View File

@ -1,6 +1,7 @@
package provider package provider
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -62,7 +63,7 @@ func NewTransIPProvider(accountName, privateKeyFile string, domainFilter DomainF
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *TransIPProvider) ApplyChanges(changes *plan.Changes) error { func (p *TransIPProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
// build zonefinder with all our zones so we can use FindZone // build zonefinder with all our zones so we can use FindZone
// and a mapping of zones and their domain name // and a mapping of zones and their domain name
zones, err := p.fetchZones() zones, err := p.fetchZones()

View File

@ -17,6 +17,7 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"errors" "errors"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
@ -64,7 +65,7 @@ func (sdr *AWSSDRegistry) Records() ([]*endpoint.Endpoint, error) {
// ApplyChanges filters out records not owned the External-DNS, additionally it adds the required label // ApplyChanges filters out records not owned the External-DNS, additionally it adds the required label
// inserted in the AWS SD instance as a CreateID field // inserted in the AWS SD instance as a CreateID field
func (sdr *AWSSDRegistry) ApplyChanges(changes *plan.Changes) error { func (sdr *AWSSDRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
filteredChanges := &plan.Changes{ filteredChanges := &plan.Changes{
Create: changes.Create, Create: changes.Create,
UpdateNew: filterOwnedRecords(sdr.ownerID, changes.UpdateNew), UpdateNew: filterOwnedRecords(sdr.ownerID, changes.UpdateNew),
@ -77,7 +78,7 @@ func (sdr *AWSSDRegistry) ApplyChanges(changes *plan.Changes) error {
sdr.updateLabels(filteredChanges.UpdateOld) sdr.updateLabels(filteredChanges.UpdateOld)
sdr.updateLabels(filteredChanges.Delete) sdr.updateLabels(filteredChanges.Delete)
return sdr.provider.ApplyChanges(filteredChanges) return sdr.provider.ApplyChanges(ctx, filteredChanges)
} }
func (sdr *AWSSDRegistry) updateLabels(endpoints []*endpoint.Endpoint) { func (sdr *AWSSDRegistry) updateLabels(endpoints []*endpoint.Endpoint) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"testing" "testing"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
@ -35,7 +36,7 @@ func (p *inMemoryProvider) Records() ([]*endpoint.Endpoint, error) {
return p.endpoints, nil return p.endpoints, nil
} }
func (p *inMemoryProvider) ApplyChanges(changes *plan.Changes) error { func (p *inMemoryProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
p.onApplyChanges(changes) p.onApplyChanges(changes)
return nil return nil
} }
@ -151,7 +152,7 @@ func TestAWSSDRegistry_Records_ApplyChanges(t *testing.T) {
r, err := NewAWSSDRegistry(p, "owner") r, err := NewAWSSDRegistry(p, "owner")
require.NoError(t, err) require.NoError(t, err)
err = r.ApplyChanges(changes) err = r.ApplyChanges(context.Background(), changes)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -17,6 +17,8 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
"github.com/kubernetes-incubator/external-dns/plan" "github.com/kubernetes-incubator/external-dns/plan"
"github.com/kubernetes-incubator/external-dns/provider" "github.com/kubernetes-incubator/external-dns/provider"
@ -40,6 +42,6 @@ func (im *NoopRegistry) Records() ([]*endpoint.Endpoint, error) {
} }
// ApplyChanges propagates changes to the dns provider // ApplyChanges propagates changes to the dns provider
func (im *NoopRegistry) ApplyChanges(changes *plan.Changes) error { func (im *NoopRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
return im.provider.ApplyChanges(changes) return im.provider.ApplyChanges(ctx, changes)
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"testing" "testing"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
@ -53,7 +54,7 @@ func testNoopRecords(t *testing.T) {
RecordType: endpoint.RecordTypeCNAME, RecordType: endpoint.RecordTypeCNAME,
}, },
} }
p.ApplyChanges(&plan.Changes{ p.ApplyChanges(context.Background(), &plan.Changes{
Create: providerRecords, Create: providerRecords,
}) })
@ -88,13 +89,14 @@ func testNoopApplyChanges(t *testing.T) {
}, },
} }
p.ApplyChanges(&plan.Changes{ ctx := context.Background()
p.ApplyChanges(ctx, &plan.Changes{
Create: providerRecords, Create: providerRecords,
}) })
// wrong changes // wrong changes
r, _ := NewNoopRegistry(p) r, _ := NewNoopRegistry(p)
err := r.ApplyChanges(&plan.Changes{ err := r.ApplyChanges(ctx, &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
{ {
DNSName: "example.org", DNSName: "example.org",
@ -106,7 +108,7 @@ func testNoopApplyChanges(t *testing.T) {
assert.EqualError(t, err, provider.ErrRecordAlreadyExists.Error()) assert.EqualError(t, err, provider.ErrRecordAlreadyExists.Error())
//correct changes //correct changes
require.NoError(t, r.ApplyChanges(&plan.Changes{ require.NoError(t, r.ApplyChanges(ctx, &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
{ {
DNSName: "new-record.org", DNSName: "new-record.org",

View File

@ -17,6 +17,8 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"github.com/kubernetes-incubator/external-dns/endpoint" "github.com/kubernetes-incubator/external-dns/endpoint"
"github.com/kubernetes-incubator/external-dns/plan" "github.com/kubernetes-incubator/external-dns/plan"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -28,7 +30,7 @@ import (
// ApplyChanges(changes *plan.Changes) propagates the changes to the DNS Provider API and correspondingly updates ownership depending on type of registry being used // ApplyChanges(changes *plan.Changes) propagates the changes to the DNS Provider API and correspondingly updates ownership depending on type of registry being used
type Registry interface { type Registry interface {
Records() ([]*endpoint.Endpoint, error) Records() ([]*endpoint.Endpoint, error)
ApplyChanges(changes *plan.Changes) error ApplyChanges(ctx context.Context, changes *plan.Changes) error
} }
//TODO(ideahitme): consider moving this to Plan //TODO(ideahitme): consider moving this to Plan

View File

@ -17,6 +17,7 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"errors" "errors"
"time" "time"
@ -117,7 +118,7 @@ func (im *TXTRegistry) Records() ([]*endpoint.Endpoint, error) {
// ApplyChanges updates dns provider with the changes // ApplyChanges updates dns provider with the changes
// for each created/deleted record it will also take into account TXT records for creation/deletion // for each created/deleted record it will also take into account TXT records for creation/deletion
func (im *TXTRegistry) ApplyChanges(changes *plan.Changes) error { func (im *TXTRegistry) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
filteredChanges := &plan.Changes{ filteredChanges := &plan.Changes{
Create: changes.Create, Create: changes.Create,
UpdateNew: filterOwnedRecords(im.ownerID, changes.UpdateNew), UpdateNew: filterOwnedRecords(im.ownerID, changes.UpdateNew),
@ -171,7 +172,11 @@ func (im *TXTRegistry) ApplyChanges(changes *plan.Changes) error {
} }
} }
return im.provider.ApplyChanges(filteredChanges) // when caching is enabled, disable the provider from using the cache
if im.cacheInterval > 0 {
ctx = context.WithValue(ctx, provider.RecordsContextKey, nil)
}
return im.provider.ApplyChanges(ctx, filteredChanges)
} }
/** /**

View File

@ -17,6 +17,7 @@ limitations under the License.
package registry package registry
import ( import (
"context"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -68,7 +69,7 @@ func testTXTRegistryRecords(t *testing.T) {
func testTXTRegistryRecordsPrefixed(t *testing.T) { func testTXTRegistryRecordsPrefixed(t *testing.T) {
p := provider.NewInMemoryProvider() p := provider.NewInMemoryProvider()
p.CreateZone(testZone) p.CreateZone(testZone)
p.ApplyChanges(&plan.Changes{ p.ApplyChanges(context.Background(), &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""),
newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""),
@ -141,7 +142,7 @@ func testTXTRegistryRecordsPrefixed(t *testing.T) {
func testTXTRegistryRecordsNoPrefix(t *testing.T) { func testTXTRegistryRecordsNoPrefix(t *testing.T) {
p := provider.NewInMemoryProvider() p := provider.NewInMemoryProvider()
p.CreateZone(testZone) p.CreateZone(testZone)
p.ApplyChanges(&plan.Changes{ p.ApplyChanges(context.Background(), &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""),
newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""),
@ -220,7 +221,12 @@ func testTXTRegistryApplyChanges(t *testing.T) {
func testTXTRegistryApplyChangesWithPrefix(t *testing.T) { func testTXTRegistryApplyChangesWithPrefix(t *testing.T) {
p := provider.NewInMemoryProvider() p := provider.NewInMemoryProvider()
p.CreateZone(testZone) p.CreateZone(testZone)
p.ApplyChanges(&plan.Changes{ ctxEndpoints := []*endpoint.Endpoint{}
ctx := context.WithValue(context.Background(), provider.RecordsContextKey, ctxEndpoints)
p.OnApplyChanges = func(ctx context.Context, got *plan.Changes) {
assert.Equal(t, ctxEndpoints, ctx.Value(provider.RecordsContextKey))
}
p.ApplyChanges(ctx, &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""),
newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""),
@ -267,7 +273,7 @@ func testTXTRegistryApplyChangesWithPrefix(t *testing.T) {
newEndpointWithOwner("txt.tar.test-zone.example.org", "\"heritage=external-dns,external-dns/owner=owner\"", endpoint.RecordTypeTXT, ""), newEndpointWithOwner("txt.tar.test-zone.example.org", "\"heritage=external-dns,external-dns/owner=owner\"", endpoint.RecordTypeTXT, ""),
}, },
} }
p.OnApplyChanges = func(got *plan.Changes) { p.OnApplyChanges = func(ctx context.Context, got *plan.Changes) {
mExpected := map[string][]*endpoint.Endpoint{ mExpected := map[string][]*endpoint.Endpoint{
"Create": expected.Create, "Create": expected.Create,
"UpdateNew": expected.UpdateNew, "UpdateNew": expected.UpdateNew,
@ -281,15 +287,21 @@ func testTXTRegistryApplyChangesWithPrefix(t *testing.T) {
"Delete": got.Delete, "Delete": got.Delete,
} }
assert.True(t, testutils.SamePlanChanges(mGot, mExpected)) assert.True(t, testutils.SamePlanChanges(mGot, mExpected))
assert.Equal(t, nil, ctx.Value(provider.RecordsContextKey))
} }
err := r.ApplyChanges(changes) err := r.ApplyChanges(ctx, changes)
require.NoError(t, err) require.NoError(t, err)
} }
func testTXTRegistryApplyChangesNoPrefix(t *testing.T) { func testTXTRegistryApplyChangesNoPrefix(t *testing.T) {
p := provider.NewInMemoryProvider() p := provider.NewInMemoryProvider()
p.CreateZone(testZone) p.CreateZone(testZone)
p.ApplyChanges(&plan.Changes{ ctxEndpoints := []*endpoint.Endpoint{}
ctx := context.WithValue(context.Background(), provider.RecordsContextKey, ctxEndpoints)
p.OnApplyChanges = func(ctx context.Context, got *plan.Changes) {
assert.Equal(t, ctxEndpoints, ctx.Value(provider.RecordsContextKey))
}
p.ApplyChanges(ctx, &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""),
newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""),
@ -330,7 +342,7 @@ func testTXTRegistryApplyChangesNoPrefix(t *testing.T) {
UpdateNew: []*endpoint.Endpoint{}, UpdateNew: []*endpoint.Endpoint{},
UpdateOld: []*endpoint.Endpoint{}, UpdateOld: []*endpoint.Endpoint{},
} }
p.OnApplyChanges = func(got *plan.Changes) { p.OnApplyChanges = func(ctx context.Context, got *plan.Changes) {
mExpected := map[string][]*endpoint.Endpoint{ mExpected := map[string][]*endpoint.Endpoint{
"Create": expected.Create, "Create": expected.Create,
"UpdateNew": expected.UpdateNew, "UpdateNew": expected.UpdateNew,
@ -344,8 +356,9 @@ func testTXTRegistryApplyChangesNoPrefix(t *testing.T) {
"Delete": got.Delete, "Delete": got.Delete,
} }
assert.True(t, testutils.SamePlanChanges(mGot, mExpected)) assert.True(t, testutils.SamePlanChanges(mGot, mExpected))
assert.Equal(t, nil, ctx.Value(provider.RecordsContextKey))
} }
err := r.ApplyChanges(changes) err := r.ApplyChanges(ctx, changes)
require.NoError(t, err) require.NoError(t, err)
} }