add ctx parameter to provider interface and aws API

This commit is contained in:
tariqibrahim 2020-01-12 13:16:55 -08:00
parent a1ff7d75ea
commit 672b00d821
56 changed files with 220 additions and 187 deletions

View File

@ -104,7 +104,8 @@ type Controller struct {
// RunOnce runs a single iteration of a reconciliation loop.
func (c *Controller) RunOnce() error {
records, err := c.Registry.Records()
ctx := context.Background()
records, err := c.Registry.Records(ctx)
if err != nil {
registryErrorsTotal.Inc()
deprecatedRegistryErrors.Inc()
@ -112,7 +113,7 @@ func (c *Controller) RunOnce() error {
}
registryEndpointsTotal.Set(float64(len(records)))
ctx := context.WithValue(context.Background(), provider.RecordsContextKey, records)
ctx = context.WithValue(ctx, provider.RecordsContextKey, records)
endpoints, err := c.Source.Endpoints()
if err != nil {

View File

@ -39,7 +39,7 @@ type mockProvider struct {
}
// Records returns the desired mock endpoints.
func (p *mockProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *mockProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
return p.RecordsStore, nil
}

View File

@ -281,7 +281,7 @@ func (p *AlibabaCloudProvider) refreshStsToken(sleepTime time.Duration) {
// Records gets the current records.
//
// Returns the current records or an error if the operation failed.
func (p *AlibabaCloudProvider) Records() (endpoints []*endpoint.Endpoint, err error) {
func (p *AlibabaCloudProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) {
if p.privateZone {
endpoints, err = p.privateZoneRecords()
} else {

View File

@ -246,7 +246,7 @@ func newTestAlibabaCloudProvider(private bool) *AlibabaCloudProvider {
func TestAlibabaCloudPrivateProvider_Records(t *testing.T) {
p := newTestAlibabaCloudProvider(true)
endpoints, err := p.Records()
endpoints, err := p.Records(context.Background())
if err != nil {
t.Errorf("Failed to get records: %v", err)
} else {
@ -261,7 +261,7 @@ func TestAlibabaCloudPrivateProvider_Records(t *testing.T) {
func TestAlibabaCloudProvider_Records(t *testing.T) {
p := newTestAlibabaCloudProvider(false)
endpoints, err := p.Records()
endpoints, err := p.Records(context.Background())
if err != nil {
t.Errorf("Failed to get records: %v", err)
} else {
@ -302,8 +302,9 @@ func TestAlibabaCloudProvider_ApplyChanges(t *testing.T) {
},
},
}
p.ApplyChanges(context.Background(), &changes)
endpoints, err := p.Records()
ctx := context.Background()
p.ApplyChanges(ctx, &changes)
endpoints, err := p.Records(ctx)
if err != nil {
t.Errorf("Failed to get records: %v", err)
} else {
@ -318,7 +319,7 @@ func TestAlibabaCloudProvider_ApplyChanges(t *testing.T) {
func TestAlibabaCloudProvider_Records_PrivateZone(t *testing.T) {
p := newTestAlibabaCloudProvider(true)
endpoints, err := p.Records()
endpoints, err := p.Records(context.Background())
if err != nil {
t.Errorf("Failed to get records: %v", err)
} else {
@ -359,8 +360,9 @@ func TestAlibabaCloudProvider_ApplyChanges_PrivateZone(t *testing.T) {
},
},
}
p.ApplyChanges(context.Background(), &changes)
endpoints, err := p.Records()
ctx := context.Background()
p.ApplyChanges(ctx, &changes)
endpoints, err := p.Records(ctx)
if err != nil {
t.Errorf("Failed to get records: %v", err)
} else {

View File

@ -26,6 +26,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/route53"
"github.com/linki/instrumented_http"
@ -101,11 +102,11 @@ var (
// Route53API is the subset of the AWS Route53 API that we actually use. Add methods as required. Signatures must match exactly.
// mostly taken from: https://github.com/kubernetes/kubernetes/blob/853167624edb6bc0cfdcdfb88e746e178f5db36c/federation/pkg/dnsprovider/providers/aws/route53/stubs/route53api.go
type Route53API interface {
ListResourceRecordSetsPages(input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool)) error
ChangeResourceRecordSets(*route53.ChangeResourceRecordSetsInput) (*route53.ChangeResourceRecordSetsOutput, error)
CreateHostedZone(*route53.CreateHostedZoneInput) (*route53.CreateHostedZoneOutput, error)
ListHostedZonesPages(input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool)) error
ListTagsForResource(input *route53.ListTagsForResourceInput) (*route53.ListTagsForResourceOutput, error)
ListResourceRecordSetsPagesWithContext(ctx context.Context, input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error
ChangeResourceRecordSetsWithContext(ctx context.Context, input *route53.ChangeResourceRecordSetsInput, opts ...request.Option) (*route53.ChangeResourceRecordSetsOutput, error)
CreateHostedZoneWithContext(ctx context.Context, input *route53.CreateHostedZoneInput, opts ...request.Option) (*route53.CreateHostedZoneOutput, error)
ListHostedZonesPagesWithContext(ctx context.Context, input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error
ListTagsForResourceWithContext(ctx context.Context, input *route53.ListTagsForResourceInput, opts ...request.Option) (*route53.ListTagsForResourceOutput, error)
}
// AWSProvider is an implementation of Provider for AWS Route53.
@ -184,7 +185,7 @@ func NewAWSProvider(awsConfig AWSConfig) (*AWSProvider, error) {
}
// Zones returns the list of hosted zones.
func (p *AWSProvider) Zones() (map[string]*route53.HostedZone, error) {
func (p *AWSProvider) Zones(ctx context.Context) (map[string]*route53.HostedZone, error) {
zones := make(map[string]*route53.HostedZone)
var tagErr error
@ -204,7 +205,7 @@ func (p *AWSProvider) Zones() (map[string]*route53.HostedZone, error) {
// Only fetch tags if a tag filter was specified
if !p.zoneTagFilter.IsEmpty() {
tags, err := p.tagsForZone(*zone.Id)
tags, err := p.tagsForZone(ctx, *zone.Id)
if err != nil {
tagErr = err
return false
@ -220,7 +221,7 @@ func (p *AWSProvider) Zones() (map[string]*route53.HostedZone, error) {
return true
}
err := p.client.ListHostedZonesPages(&route53.ListHostedZonesInput{}, f)
err := p.client.ListHostedZonesPagesWithContext(ctx, &route53.ListHostedZonesInput{}, f)
if err != nil {
return nil, err
}
@ -245,16 +246,16 @@ func wildcardUnescape(s string) string {
}
// Records returns the list of records in a given hosted zone.
func (p *AWSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.Zones()
func (p *AWSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.Zones(ctx)
if err != nil {
return nil, err
}
return p.records(zones)
return p.records(ctx, zones)
}
func (p *AWSProvider) records(zones map[string]*route53.HostedZone) ([]*endpoint.Endpoint, error) {
func (p *AWSProvider) records(ctx context.Context, zones map[string]*route53.HostedZone) ([]*endpoint.Endpoint, error) {
endpoints := make([]*endpoint.Endpoint, 0)
f := func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool) {
for _, r := range resp.ResourceRecordSets {
@ -331,7 +332,7 @@ func (p *AWSProvider) records(zones map[string]*route53.HostedZone) ([]*endpoint
HostedZoneId: z.Id,
}
if err := p.client.ListResourceRecordSetsPages(params, f); err != nil {
if err := p.client.ListResourceRecordSetsPagesWithContext(ctx, params, f); err != nil {
return nil, err
}
}
@ -340,36 +341,36 @@ func (p *AWSProvider) records(zones map[string]*route53.HostedZone) ([]*endpoint
}
// CreateRecords creates a given set of DNS records in the given hosted zone.
func (p *AWSProvider) CreateRecords(endpoints []*endpoint.Endpoint) error {
return p.doRecords(route53.ChangeActionCreate, endpoints)
func (p *AWSProvider) CreateRecords(ctx context.Context, endpoints []*endpoint.Endpoint) error {
return p.doRecords(ctx, route53.ChangeActionCreate, endpoints)
}
// UpdateRecords updates a given set of old records to a new set of records in a given hosted zone.
func (p *AWSProvider) UpdateRecords(endpoints, _ []*endpoint.Endpoint) error {
return p.doRecords(route53.ChangeActionUpsert, endpoints)
func (p *AWSProvider) UpdateRecords(ctx context.Context, endpoints, _ []*endpoint.Endpoint) error {
return p.doRecords(ctx, route53.ChangeActionUpsert, endpoints)
}
// DeleteRecords deletes a given set of DNS records in a given zone.
func (p *AWSProvider) DeleteRecords(endpoints []*endpoint.Endpoint) error {
return p.doRecords(route53.ChangeActionDelete, endpoints)
func (p *AWSProvider) DeleteRecords(ctx context.Context, endpoints []*endpoint.Endpoint) error {
return p.doRecords(ctx, route53.ChangeActionDelete, endpoints)
}
func (p *AWSProvider) doRecords(action string, endpoints []*endpoint.Endpoint) error {
zones, err := p.Zones()
func (p *AWSProvider) doRecords(ctx context.Context, action string, endpoints []*endpoint.Endpoint) error {
zones, err := p.Zones(ctx)
if err != nil {
return err
}
records, err := p.records(zones)
records, err := p.records(ctx, zones)
if err != nil {
log.Errorf("getting records failed: %v", err)
}
return p.submitChanges(p.newChanges(action, endpoints, records, zones), zones)
return p.submitChanges(ctx, p.newChanges(action, endpoints, records, zones), zones)
}
// ApplyChanges applies a given set of changes in a given zone.
func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return err
}
@ -377,7 +378,7 @@ func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e
records, ok := ctx.Value(RecordsContextKey).([]*endpoint.Endpoint)
if !ok {
var err error
records, err = p.records(zones)
records, err = p.records(ctx, zones)
if err != nil {
log.Errorf("getting records failed: %v", err)
}
@ -389,11 +390,11 @@ func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e
combinedChanges = append(combinedChanges, p.newChanges(route53.ChangeActionUpsert, changes.UpdateNew, records, zones)...)
combinedChanges = append(combinedChanges, p.newChanges(route53.ChangeActionDelete, changes.Delete, records, zones)...)
return p.submitChanges(combinedChanges, zones)
return p.submitChanges(ctx, combinedChanges, zones)
}
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *AWSProvider) submitChanges(changes []*route53.Change, zones map[string]*route53.HostedZone) error {
func (p *AWSProvider) submitChanges(ctx context.Context, changes []*route53.Change, zones map[string]*route53.HostedZone) error {
// return early if there is nothing to change
if len(changes) == 0 {
log.Info("All records are already up to date")
@ -425,7 +426,7 @@ func (p *AWSProvider) submitChanges(changes []*route53.Change, zones map[string]
},
}
if _, err := p.client.ChangeResourceRecordSets(params); err != nil {
if _, err := p.client.ChangeResourceRecordSetsWithContext(ctx, params); err != nil {
log.Errorf("Failure in zone %s [Id: %s]", aws.StringValue(zones[z].Name), z)
log.Error(err) //TODO(ideahitme): consider changing the interface in cases when this error might be a concern for other components
failedUpdate = true
@ -568,8 +569,8 @@ func (p *AWSProvider) newChange(action string, ep *endpoint.Endpoint, recordsCac
return change, dualstack
}
func (p *AWSProvider) tagsForZone(zoneID string) (map[string]string, error) {
response, err := p.client.ListTagsForResource(&route53.ListTagsForResourceInput{
func (p *AWSProvider) tagsForZone(ctx context.Context, zoneID string) (map[string]string, error) {
response, err := p.client.ListTagsForResourceWithContext(ctx, &route53.ListTagsForResourceInput{
ResourceType: aws.String("hostedzone"),
ResourceId: aws.String(zoneID),
})

View File

@ -138,7 +138,7 @@ func newSdNamespaceFilter(namespaceTypeConfig string) *sd.NamespaceFilter {
}
// Records returns list of all endpoints.
func (p *AWSSDProvider) Records() (endpoints []*endpoint.Endpoint, err error) {
func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) {
namespaces, err := p.ListNamespaces()
if err != nil {
return nil, err

View File

@ -289,7 +289,7 @@ func TestAWSSDProvider_Records(t *testing.T) {
provider := newTestAWSSDProvider(api, NewDomainFilter([]string{}), "")
endpoints, _ := provider.Records()
endpoints, _ := provider.Records(context.Background())
assert.True(t, testutils.SameEndpoints(expectedEndpoints, endpoints), "expected and actual endpoints don't match, expected=%v, actual=%v", expectedEndpoints, endpoints)
}
@ -317,8 +317,10 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
provider := newTestAWSSDProvider(api, NewDomainFilter([]string{}), "")
ctx := context.Background()
// apply creates
provider.ApplyChanges(context.Background(), &plan.Changes{
provider.ApplyChanges(ctx, &plan.Changes{
Create: expectedEndpoints,
})
@ -330,16 +332,17 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) {
assert.NotNil(t, existingServices["service3"])
// make sure instances were registered
endpoints, _ := provider.Records()
endpoints, _ := provider.Records(ctx)
assert.True(t, testutils.SameEndpoints(expectedEndpoints, endpoints), "expected and actual endpoints don't match, expected=%v, actual=%v", expectedEndpoints, endpoints)
ctx = context.Background()
// apply deletes
provider.ApplyChanges(context.Background(), &plan.Changes{
provider.ApplyChanges(ctx, &plan.Changes{
Delete: expectedEndpoints,
})
// make sure all instances are gone
endpoints, _ = provider.Records()
endpoints, _ = provider.Records(ctx)
assert.Empty(t, endpoints)
}

View File

@ -26,6 +26,7 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/route53"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -73,7 +74,7 @@ func NewRoute53APIStub() *Route53APIStub {
}
}
func (r *Route53APIStub) ListResourceRecordSetsPages(input *route53.ListResourceRecordSetsInput, fn func(p *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool)) error {
func (r *Route53APIStub) ListResourceRecordSetsPagesWithContext(ctx context.Context, input *route53.ListResourceRecordSetsInput, fn func(p *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error {
output := route53.ListResourceRecordSetsOutput{} // TODO: Support optional input args.
if len(r.recordSets) == 0 {
output.ResourceRecordSets = []*route53.ResourceRecordSet{}
@ -103,29 +104,29 @@ func NewRoute53APICounter(w Route53API) *Route53APICounter {
}
}
func (c *Route53APICounter) ListResourceRecordSetsPages(input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool)) error {
func (c *Route53APICounter) ListResourceRecordSetsPagesWithContext(ctx context.Context, input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error {
c.calls["ListResourceRecordSetsPages"]++
return c.wrapped.ListResourceRecordSetsPages(input, fn)
return c.wrapped.ListResourceRecordSetsPagesWithContext(ctx, input, fn)
}
func (c *Route53APICounter) ChangeResourceRecordSets(input *route53.ChangeResourceRecordSetsInput) (*route53.ChangeResourceRecordSetsOutput, error) {
func (c *Route53APICounter) ChangeResourceRecordSetsWithContext(ctx context.Context, input *route53.ChangeResourceRecordSetsInput, opts ...request.Option) (*route53.ChangeResourceRecordSetsOutput, error) {
c.calls["ChangeResourceRecordSets"]++
return c.wrapped.ChangeResourceRecordSets(input)
return c.wrapped.ChangeResourceRecordSetsWithContext(ctx, input)
}
func (c *Route53APICounter) CreateHostedZone(input *route53.CreateHostedZoneInput) (*route53.CreateHostedZoneOutput, error) {
func (c *Route53APICounter) CreateHostedZoneWithContext(ctx context.Context, input *route53.CreateHostedZoneInput, opts ...request.Option) (*route53.CreateHostedZoneOutput, error) {
c.calls["CreateHostedZone"]++
return c.wrapped.CreateHostedZone(input)
return c.wrapped.CreateHostedZoneWithContext(ctx, input)
}
func (c *Route53APICounter) ListHostedZonesPages(input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool)) error {
func (c *Route53APICounter) ListHostedZonesPagesWithContext(ctx context.Context, input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error {
c.calls["ListHostedZonesPages"]++
return c.wrapped.ListHostedZonesPages(input, fn)
return c.wrapped.ListHostedZonesPagesWithContext(ctx, input, fn)
}
func (c *Route53APICounter) ListTagsForResource(input *route53.ListTagsForResourceInput) (*route53.ListTagsForResourceOutput, error) {
func (c *Route53APICounter) ListTagsForResourceWithContext(ctx context.Context, input *route53.ListTagsForResourceInput, opts ...request.Option) (*route53.ListTagsForResourceOutput, error) {
c.calls["ListTagsForResource"]++
return c.wrapped.ListTagsForResource(input)
return c.wrapped.ListTagsForResourceWithContext(ctx, input)
}
// Route53 stores wildcards escaped: http://docs.aws.amazon.com/Route53/latest/DeveloperGuide/DomainNameFormat.html?shortFooter=true#domain-name-format-asterisk
@ -136,7 +137,7 @@ func wildcardEscape(s string) string {
return s
}
func (r *Route53APIStub) ListTagsForResource(input *route53.ListTagsForResourceInput) (*route53.ListTagsForResourceOutput, error) {
func (r *Route53APIStub) ListTagsForResourceWithContext(ctx context.Context, input *route53.ListTagsForResourceInput, opts ...request.Option) (*route53.ListTagsForResourceOutput, error) {
if aws.StringValue(input.ResourceType) == "hostedzone" {
tags := r.zoneTags[aws.StringValue(input.ResourceId)]
return &route53.ListTagsForResourceOutput{
@ -150,7 +151,7 @@ func (r *Route53APIStub) ListTagsForResource(input *route53.ListTagsForResourceI
return &route53.ListTagsForResourceOutput{}, nil
}
func (r *Route53APIStub) ChangeResourceRecordSets(input *route53.ChangeResourceRecordSetsInput) (*route53.ChangeResourceRecordSetsOutput, error) {
func (r *Route53APIStub) ChangeResourceRecordSetsWithContext(ctx context.Context, input *route53.ChangeResourceRecordSetsInput, opts ...request.Option) (*route53.ChangeResourceRecordSetsOutput, error) {
if r.m.isMocked("ChangeResourceRecordSets", input) {
return r.m.ChangeResourceRecordSets(input)
}
@ -209,7 +210,7 @@ func (r *Route53APIStub) ChangeResourceRecordSets(input *route53.ChangeResourceR
return output, nil // TODO: We should ideally return status etc, but we don't' use that yet.
}
func (r *Route53APIStub) ListHostedZonesPages(input *route53.ListHostedZonesInput, fn func(p *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool)) error {
func (r *Route53APIStub) ListHostedZonesPagesWithContext(ctx context.Context, input *route53.ListHostedZonesInput, fn func(p *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error {
output := &route53.ListHostedZonesOutput{}
for _, zone := range r.zones {
output.HostedZones = append(output.HostedZones, zone)
@ -219,7 +220,7 @@ func (r *Route53APIStub) ListHostedZonesPages(input *route53.ListHostedZonesInpu
return nil
}
func (r *Route53APIStub) CreateHostedZone(input *route53.CreateHostedZoneInput) (*route53.CreateHostedZoneOutput, error) {
func (r *Route53APIStub) CreateHostedZoneWithContext(ctx context.Context, input *route53.CreateHostedZoneInput, opts ...request.Option) (*route53.CreateHostedZoneOutput, error) {
name := aws.StringValue(input.Name)
id := "/hostedzone/" + name
if _, ok := r.zones[id]; ok {
@ -302,7 +303,7 @@ func TestAWSZones(t *testing.T) {
} {
provider, _ := newAWSProviderWithTagFilter(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), ti.zoneIDFilter, ti.zoneTypeFilter, ti.zoneTagFilter, defaultEvaluateTargetHealth, false, []*endpoint.Endpoint{})
zones, err := provider.Zones()
zones, err := provider.Zones(context.Background())
require.NoError(t, err)
validateAWSZones(t, zones, ti.expectedZones)
@ -328,7 +329,7 @@ func TestAWSRecords(t *testing.T) {
endpoint.NewEndpointWithTTL("geolocation-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "4.3.2.1").WithSetIdentifier("test-set-2").WithProviderSpecific(providerSpecificGeolocationCountryCode, "DE"),
})
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -362,9 +363,9 @@ func TestAWSCreateRecords(t *testing.T) {
endpoint.NewEndpoint("create-test-multiple.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8", "8.8.4.4"),
}
require.NoError(t, provider.CreateRecords(records))
require.NoError(t, provider.CreateRecords(context.Background(), records))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -397,9 +398,9 @@ func TestAWSUpdateRecords(t *testing.T) {
endpoint.NewEndpoint("create-test-multiple.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4", "4.3.2.1"),
}
require.NoError(t, provider.UpdateRecords(updatedRecords, currentRecords))
require.NoError(t, provider.UpdateRecords(context.Background(), updatedRecords, currentRecords))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -422,9 +423,9 @@ func TestAWSDeleteRecords(t *testing.T) {
provider, _ := newAWSProvider(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), NewZoneIDFilter([]string{}), NewZoneTypeFilter(""), false, false, originalEndpoints)
require.NoError(t, provider.DeleteRecords(originalEndpoints))
require.NoError(t, provider.DeleteRecords(context.Background(), originalEndpoints))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
@ -439,9 +440,10 @@ func TestAWSApplyChanges(t *testing.T) {
}{
{"no cache", func(p *AWSProvider) context.Context { return context.Background() }, 3},
{"cached", func(p *AWSProvider) context.Context {
records, err := p.Records()
ctx := context.Background()
records, err := p.Records(ctx)
require.NoError(t, err)
return context.WithValue(context.Background(), RecordsContextKey, records)
return context.WithValue(ctx, RecordsContextKey, records)
}, 0},
}
@ -506,7 +508,7 @@ func TestAWSApplyChanges(t *testing.T) {
assert.Equal(t, 1, counter.calls["ListHostedZonesPages"], tt.name)
assert.Equal(t, tt.listRRSets, counter.calls["ListResourceRecordSetsPages"], tt.name)
records, err := provider.Records()
records, err := provider.Records(ctx)
require.NoError(t, err, tt.name)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -578,9 +580,11 @@ func TestAWSApplyChangesDryRun(t *testing.T) {
Delete: deleteRecords,
}
require.NoError(t, provider.ApplyChanges(context.Background(), changes))
ctx := context.Background()
records, err := provider.Records()
require.NoError(t, provider.ApplyChanges(ctx, changes))
records, err := provider.Records(ctx)
require.NoError(t, err)
validateEndpoints(t, records, originalEndpoints)
@ -698,14 +702,15 @@ func TestAWSsubmitChanges(t *testing.T) {
}
}
zones, _ := provider.Zones()
records, _ := provider.Records()
ctx := context.Background()
zones, _ := provider.Zones(ctx)
records, _ := provider.Records(ctx)
cs := make([]*route53.Change, 0, len(endpoints))
cs = append(cs, provider.newChanges(route53.ChangeActionCreate, endpoints, records, zones)...)
require.NoError(t, provider.submitChanges(cs, zones))
require.NoError(t, provider.submitChanges(ctx, cs, zones))
records, err := provider.Records()
records, err := provider.Records(ctx)
require.NoError(t, err)
validateEndpoints(t, records, endpoints)
@ -715,15 +720,16 @@ func TestAWSsubmitChangesError(t *testing.T) {
provider, clientStub := newAWSProvider(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), NewZoneIDFilter([]string{}), NewZoneTypeFilter(""), defaultEvaluateTargetHealth, false, []*endpoint.Endpoint{})
clientStub.MockMethod("ChangeResourceRecordSets", mock.Anything).Return(nil, fmt.Errorf("Mock route53 failure"))
zones, err := provider.Zones()
ctx := context.Background()
zones, err := provider.Zones(ctx)
require.NoError(t, err)
records, err := provider.Records()
records, err := provider.Records(ctx)
require.NoError(t, err)
ep := endpoint.NewEndpointWithTTL("fail.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.0.0.1")
cs := provider.newChanges(route53.ChangeActionCreate, []*endpoint.Endpoint{ep}, records, zones)
require.Error(t, provider.submitChanges(cs, zones))
require.Error(t, provider.submitChanges(ctx, cs, zones))
}
func TestAWSBatchChangeSet(t *testing.T) {
@ -853,7 +859,7 @@ func TestAWSCreateRecordsWithCNAME(t *testing.T) {
{DNSName: "create-test.zone-1.ext-dns-test-2.teapot.zalan.do", Targets: endpoint.Targets{"foo.example.org"}, RecordType: endpoint.RecordTypeCNAME},
}
require.NoError(t, provider.CreateRecords(records))
require.NoError(t, provider.CreateRecords(context.Background(), records))
recordSets := listAWSRecords(t, provider.client, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.")
@ -906,7 +912,7 @@ func TestAWSCreateRecordsWithALIAS(t *testing.T) {
},
}
require.NoError(t, provider.CreateRecords(records))
require.NoError(t, provider.CreateRecords(context.Background(), records))
recordSets := listAWSRecords(t, provider.client, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.")
@ -1093,7 +1099,7 @@ func createAWSZone(t *testing.T, provider *AWSProvider, zone *route53.HostedZone
HostedZoneConfig: zone.Config,
}
if _, err := provider.client.CreateHostedZone(params); err != nil {
if _, err := provider.client.CreateHostedZoneWithContext(context.Background(), params); err != nil {
require.EqualError(t, err, route53.ErrCodeHostedZoneAlreadyExists)
}
}
@ -1103,25 +1109,26 @@ func setupAWSRecords(t *testing.T, provider *AWSProvider, endpoints []*endpoint.
clearAWSRecords(t, provider, "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do.")
clearAWSRecords(t, provider, "/hostedzone/zone-3.ext-dns-test-2.teapot.zalan.do.")
records, err := provider.Records()
ctx := context.Background()
records, err := provider.Records(ctx)
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{})
require.NoError(t, provider.CreateRecords(endpoints))
require.NoError(t, provider.CreateRecords(context.Background(), endpoints))
escapeAWSRecords(t, provider, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.")
escapeAWSRecords(t, provider, "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do.")
escapeAWSRecords(t, provider, "/hostedzone/zone-3.ext-dns-test-2.teapot.zalan.do.")
_, err = provider.Records()
_, err = provider.Records(ctx)
require.NoError(t, err)
}
func listAWSRecords(t *testing.T, client Route53API, zone string) []*route53.ResourceRecordSet {
recordSets := []*route53.ResourceRecordSet{}
require.NoError(t, client.ListResourceRecordSetsPages(&route53.ListResourceRecordSetsInput{
require.NoError(t, client.ListResourceRecordSetsPagesWithContext(context.Background(), &route53.ListResourceRecordSetsInput{
HostedZoneId: aws.String(zone),
}, func(resp *route53.ListResourceRecordSetsOutput, _ bool) bool {
for _, recordSet := range resp.ResourceRecordSets {
@ -1145,7 +1152,7 @@ func clearAWSRecords(t *testing.T, provider *AWSProvider, zone string) {
}
if len(changes) != 0 {
_, err := provider.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{
_, err := provider.client.ChangeResourceRecordSetsWithContext(context.Background(), &route53.ChangeResourceRecordSetsInput{
HostedZoneId: aws.String(zone),
ChangeBatch: &route53.ChangeBatch{
Changes: changes,
@ -1168,7 +1175,7 @@ func escapeAWSRecords(t *testing.T, provider *AWSProvider, zone string) {
}
if len(changes) != 0 {
_, err := provider.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{
_, err := provider.client.ChangeResourceRecordSetsWithContext(context.Background(), &route53.ChangeResourceRecordSetsInput{
HostedZoneId: aws.String(zone),
ChangeBatch: &route53.ChangeBatch{
Changes: changes,

View File

@ -180,8 +180,7 @@ func getAccessToken(cfg config, environment azure.Environment) (*adal.ServicePri
// Records gets the current records.
//
// Returns the current records or an error if the operation failed.
func (p *AzureProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
ctx := context.Background()
func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.zones(ctx)
if err != nil {
return nil, err

View File

@ -90,8 +90,7 @@ func NewAzurePrivateDNSProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFi
// Records gets the current records.
//
// Returns the current records or an error if the operation failed.
func (p *AzurePrivateDNSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
ctx := context.Background()
func (p *AzurePrivateDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.zones(ctx)
if err != nil {
return nil, err

View File

@ -266,7 +266,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) {
t.Fatal(err)
}
actual, err := provider.Records()
actual, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
@ -302,7 +302,7 @@ func TestAzurePrivateDNSMultiRecord(t *testing.T) {
t.Fatal(err)
}
actual, err := provider.Records()
actual, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)

View File

@ -273,7 +273,8 @@ func TestAzureRecord(t *testing.T) {
t.Fatal(err)
}
actual, err := provider.Records()
ctx := context.Background()
actual, err := provider.Records(ctx)
if err != nil {
t.Fatal(err)
@ -309,7 +310,8 @@ func TestAzureMultiRecord(t *testing.T) {
t.Fatal(err)
}
actual, err := provider.Records()
ctx := context.Background()
actual, err := provider.Records(ctx)
if err != nil {
t.Fatal(err)

View File

@ -176,7 +176,7 @@ func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) {
}
// Records returns the list of records.
func (p *CloudFlareProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
if err != nil {
return nil, err

View File

@ -491,19 +491,21 @@ func TestRecords(t *testing.T) {
provider := &CloudFlareProvider{
Client: &mockCloudFlareClient{},
}
records, err := provider.Records()
ctx := context.Background()
records, err := provider.Records(ctx)
if err != nil {
t.Errorf("should not fail, %s", err)
}
assert.Equal(t, 1, len(records))
provider.Client = &mockCloudFlareDNSRecordsFail{}
_, err = provider.Records()
_, err = provider.Records(ctx)
if err == nil {
t.Errorf("expected to fail")
}
provider.Client = &mockCloudFlareListZonesFail{}
_, err = provider.Records()
_, err = provider.Records(ctx)
if err == nil {
t.Errorf("expected to fail")
}

View File

@ -260,7 +260,7 @@ func NewCoreDNSProvider(domainFilter DomainFilter, prefix string, dryRun bool) (
// Records returns all DNS records found in CoreDNS etcd backend. Depending on the record fields
// it may be mapped to one or two records of type A, CNAME, TXT, A+TXT, CNAME+TXT
func (p coreDNSProvider) Records() ([]*endpoint.Endpoint, error) {
func (p coreDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
var result []*endpoint.Endpoint
services, err := p.client.GetServices(p.coreDNSPrefix)
if err != nil {

View File

@ -66,7 +66,7 @@ func TestAServiceTranslation(t *testing.T) {
client: client,
coreDNSPrefix: defaultCoreDNSPrefix,
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -98,7 +98,7 @@ func TestCNAMEServiceTranslation(t *testing.T) {
client: client,
coreDNSPrefix: defaultCoreDNSPrefix,
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -130,7 +130,7 @@ func TestTXTServiceTranslation(t *testing.T) {
client: client,
coreDNSPrefix: defaultCoreDNSPrefix,
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -164,7 +164,7 @@ func TestAWithTXTServiceTranslation(t *testing.T) {
client: client,
coreDNSPrefix: defaultCoreDNSPrefix,
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -206,7 +206,7 @@ func TestCNAMEWithTXTServiceTranslation(t *testing.T) {
client: client,
coreDNSPrefix: defaultCoreDNSPrefix,
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -264,7 +264,7 @@ func TestCoreDNSApplyChanges(t *testing.T) {
endpoint.NewEndpoint("domain1.local", "A", "6.6.6.6"),
},
}
records, _ := coredns.Records()
records, _ := coredns.Records(context.Background())
for _, ep := range records {
if ep.DNSName == "domain1.local" {
changes2.UpdateOld = append(changes2.UpdateOld, ep)
@ -296,7 +296,8 @@ func TestCoreDNSApplyChanges(t *testing.T) {
}
func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) {
records, _ := provider.Records()
ctx := context.Background()
records, _ := provider.Records(ctx)
for _, col := range [][]*endpoint.Endpoint{changes.Create, changes.UpdateNew, changes.Delete} {
for _, record := range col {
for _, existingRecord := range records {
@ -306,7 +307,7 @@ func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) {
}
}
}
provider.ApplyChanges(context.Background(), changes)
provider.ApplyChanges(ctx, changes)
}
func validateServices(services, expectedServices map[string]*Service, t *testing.T, step int) {

View File

@ -308,7 +308,7 @@ func (p designateProvider) getHostZoneID(hostname string, managedZones map[strin
}
// Records returns the list of records.
func (p designateProvider) Records() ([]*endpoint.Endpoint, error) {
func (p designateProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
var result []*endpoint.Endpoint
managedZones, err := p.getZones()
if err != nil {

View File

@ -303,7 +303,7 @@ func TestDesignateRecords(t *testing.T) {
},
}
endpoints, err := client.ToProvider().Records()
endpoints, err := client.ToProvider().Records(context.Background())
if err != nil {
t.Fatal(err)
}

View File

@ -94,7 +94,7 @@ func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) {
}
// Records returns the list of records in a given zone.
func (p *DigitalOceanProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
if err != nil {
return nil, err

View File

@ -509,15 +509,16 @@ func TestDigitalOceanAllRecords(t *testing.T) {
provider := &DigitalOceanProvider{
Client: &mockDigitalOceanClient{},
}
ctx := context.Background()
records, err := provider.Records()
records, err := provider.Records(ctx)
if err != nil {
t.Errorf("should not fail, %s", err)
}
require.Equal(t, 5, len(records))
provider.Client = &mockDigitalOceanRecordsFail{}
_, err = provider.Records()
_, err = provider.Records(ctx)
if err == nil {
t.Errorf("expected to fail, %s", err)
}

View File

@ -157,7 +157,7 @@ func (p *dnsimpleProvider) Zones() (map[string]dnsimple.Zone, error) {
}
// Records returns a list of endpoints in a given zone
func (p *dnsimpleProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
func (p *dnsimpleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.Zones()
if err != nil {
return nil, err

View File

@ -152,13 +152,14 @@ func testDnsimpleProviderZones(t *testing.T) {
}
func testDnsimpleProviderRecords(t *testing.T) {
ctx := context.Background()
mockProvider.accountID = "1"
result, err := mockProvider.Records()
result, err := mockProvider.Records(ctx)
assert.Nil(t, err)
assert.Equal(t, len(dnsimpleListRecordsResponse.Data), len(result))
mockProvider.accountID = "2"
_, err = mockProvider.Records()
_, err = mockProvider.Records(ctx)
assert.NotNil(t, err)
}
func testDnsimpleProviderApplyChanges(t *testing.T) {

View File

@ -588,7 +588,7 @@ func (d *dynProviderState) commit(client *dynect.Client) error {
// Records makes on average C + 2*Z requests (Z = number of zones): 1 login + 1 fetchAllRecords
// A cache is used to avoid querying for every single record found. C is proportional to the number
// of expired/changed records
func (d *dynProviderState) Records() ([]*endpoint.Endpoint, error) {
func (d *dynProviderState) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
client, err := d.login()
if err != nil {
return nil, err

View File

@ -172,7 +172,7 @@ func (ep *ExoscaleProvider) ApplyChanges(ctx context.Context, changes *plan.Chan
}
// Records returns the list of endpoints
func (ep *ExoscaleProvider) Records() ([]*endpoint.Endpoint, error) {
func (ep *ExoscaleProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
endpoints := make([]*endpoint.Endpoint, 0)
domains, err := ep.client.GetDomains(context.TODO())

View File

@ -107,7 +107,7 @@ func contains(arr []*endpoint.Endpoint, name string) bool {
func TestExoscaleGetRecords(t *testing.T) {
provider := NewExoscaleProviderWithClient("", "", "", NewExoscaleClientStub(), false)
if recs, err := provider.Records(); err == nil {
if recs, err := provider.Records(context.Background()); err == nil {
assert.Equal(t, 3, len(recs))
assert.True(t, contains(recs, "v1.foo.com"))
assert.True(t, contains(recs, "v2.bar.com"))

View File

@ -199,7 +199,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) {
}
// Records returns the list of records in all relevant zones.
func (p *GoogleProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.Zones()
if err != nil {
return nil, err

View File

@ -235,7 +235,7 @@ func TestGoogleRecords(t *testing.T) {
provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, originalEndpoints)
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, originalEndpoints)
@ -278,7 +278,7 @@ func TestGoogleRecordsFilter(t *testing.T) {
require.NoError(t, provider.CreateRecords(ignoredEndpoints))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
// assert that due to filtering no changes were made.
@ -296,7 +296,7 @@ func TestGoogleCreateRecords(t *testing.T) {
require.NoError(t, provider.CreateRecords(records))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -321,7 +321,7 @@ func TestGoogleUpdateRecords(t *testing.T) {
require.NoError(t, provider.UpdateRecords(updatedRecords, currentRecords))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -342,7 +342,7 @@ func TestGoogleDeleteRecords(t *testing.T) {
require.NoError(t, provider.DeleteRecords(originalEndpoints))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{})
@ -410,7 +410,7 @@ func TestGoogleApplyChanges(t *testing.T) {
require.NoError(t, provider.ApplyChanges(context.Background(), changes))
records, err := provider.Records()
records, err := provider.Records(context.Background())
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{
@ -465,9 +465,10 @@ func TestGoogleApplyChangesDryRun(t *testing.T) {
Delete: deleteRecords,
}
require.NoError(t, provider.ApplyChanges(context.Background(), changes))
ctx := context.Background()
require.NoError(t, provider.ApplyChanges(ctx, changes))
records, err := provider.Records()
records, err := provider.Records(ctx)
require.NoError(t, err)
validateEndpoints(t, records, originalEndpoints)
@ -760,14 +761,15 @@ func setupGoogleRecords(t *testing.T, provider *GoogleProvider, endpoints []*end
clearGoogleRecords(t, provider, "zone-2-ext-dns-test-2-gcp-zalan-do")
clearGoogleRecords(t, provider, "zone-3-ext-dns-test-2-gcp-zalan-do")
records, err := provider.Records()
ctx := context.Background()
records, err := provider.Records(ctx)
require.NoError(t, err)
validateEndpoints(t, records, []*endpoint.Endpoint{})
require.NoError(t, provider.CreateRecords(endpoints))
records, err = provider.Records()
records, err = provider.Records(ctx)
require.NoError(t, err)
validateEndpoints(t, records, endpoints)

View File

@ -135,7 +135,7 @@ func NewInfobloxProvider(infobloxConfig InfobloxConfig) (*InfobloxProvider, erro
}
// Records gets the current records.
func (p *InfobloxProvider) Records() (endpoints []*endpoint.Endpoint, err error) {
func (p *InfobloxProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) {
zones, err := p.zones()
if err != nil {
return nil, fmt.Errorf("could not fetch zones: %s", err)

View File

@ -355,7 +355,7 @@ func TestInfobloxRecords(t *testing.T) {
}
provider := newInfobloxProvider(NewDomainFilter([]string{"example.com"}), NewZoneIDFilter([]string{""}), true, &client)
actual, err := provider.Records()
actual, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)

View File

@ -119,7 +119,7 @@ func (im *InMemoryProvider) Zones() map[string]string {
}
// Records returns the list of endpoints
func (im *InMemoryProvider) Records() ([]*endpoint.Endpoint, error) {
func (im *InMemoryProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
defer im.OnRecords()
endpoints := make([]*endpoint.Endpoint, 0)

View File

@ -233,7 +233,7 @@ func testInMemoryRecords(t *testing.T) {
im.client = c
f := filter{domain: ti.zone}
im.filter = &f
records, err := im.Records()
records, err := im.Records(context.Background())
if ti.expectError {
assert.Nil(t, records)
assert.EqualError(t, err, ErrZoneNotFound.Error())

View File

@ -111,7 +111,7 @@ func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) {
}
// Records returns the list of records in a given zone.
func (p *LinodeProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
if err != nil {
return nil, err

View File

@ -257,7 +257,7 @@ func TestLinodeRecords(t *testing.T) {
mock.Anything,
).Return(createBazRecords(), nil).Once()
actual, err := provider.Records()
actual, err := provider.Records(context.Background())
require.NoError(t, err)
expected := []*endpoint.Endpoint{

View File

@ -141,7 +141,7 @@ func newNS1ProviderWithHTTPClient(config NS1Config, client *http.Client) (*NS1Pr
}
// Records returns the endpoints this provider knows about
func (p *NS1Provider) Records() ([]*endpoint.Endpoint, error) {
func (p *NS1Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.zonesFiltered()
if err != nil {
return nil, err

View File

@ -132,16 +132,18 @@ func TestNS1Records(t *testing.T) {
domainFilter: NewDomainFilter([]string{"foo.com."}),
zoneIDFilter: NewZoneIDFilter([]string{""}),
}
records, err := provider.Records()
ctx := context.Background()
records, err := provider.Records(ctx)
require.NoError(t, err)
assert.Equal(t, 1, len(records))
provider.client = &MockNS1GetZoneFail{}
_, err = provider.Records()
_, err = provider.Records(ctx)
require.Error(t, err)
provider.client = &MockNS1ListZonesFail{}
_, err = provider.Records()
_, err = provider.Records(ctx)
require.Error(t, err)
}

View File

@ -157,8 +157,7 @@ func (p *OCIProvider) newFilteredRecordOperations(endpoints []*endpoint.Endpoint
}
// Records returns the list of records in a given hosted zone.
func (p *OCIProvider) Records() ([]*endpoint.Endpoint, error) {
ctx := context.Background()
func (p *OCIProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.zones(ctx)
if err != nil {
return nil, errors.Wrap(err, "getting zones")

View File

@ -287,7 +287,7 @@ func TestOCIRecords(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provider := newOCIProvider(&mockOCIDNSClient{}, tc.domainFilter, tc.zoneIDFilter, false)
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
require.NoError(t, err)
require.ElementsMatch(t, tc.expected, endpoints)
})
@ -829,9 +829,11 @@ func TestOCIApplyChanges(t *testing.T) {
NewZoneIDFilter([]string{""}),
tc.dryRun,
)
err := provider.ApplyChanges(context.Background(), tc.changes)
ctx := context.Background()
err := provider.ApplyChanges(ctx, tc.changes)
require.Equal(t, tc.err, err)
endpoints, err := provider.Records()
endpoints, err := provider.Records(ctx)
require.NoError(t, err)
require.ElementsMatch(t, tc.expectedEndpoints, endpoints)
})

View File

@ -412,7 +412,7 @@ func (p *PDNSProvider) mutateRecords(endpoints []*endpoint.Endpoint, changetype
}
// Records returns all DNS records controlled by the configured PDNS server (for all zones)
func (p *PDNSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
func (p *PDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, _, err := p.client.ListZones()
if err != nil {

View File

@ -793,9 +793,11 @@ func (suite *NewPDNSProviderTestSuite) TestPDNSRecords() {
client: &PDNSAPIClientStub{},
}
ctx := context.Background()
/* We test that endpoints are returned correctly for a Zone when Records() is called
*/
eps, err := p.Records()
eps, err := p.Records(ctx)
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), endpointsMixedRecords, eps)
@ -804,13 +806,13 @@ func (suite *NewPDNSProviderTestSuite) TestPDNSRecords() {
p = &PDNSProvider{
client: &PDNSAPIClientStubListZoneFailure{},
}
_, err = p.Records()
_, err = p.Records(ctx)
assert.NotNil(suite.T(), err)
p = &PDNSProvider{
client: &PDNSAPIClientStubListZonesFailure{},
}
_, err = p.Records()
_, err = p.Records(ctx)
assert.NotNil(suite.T(), err)
}

View File

@ -27,7 +27,7 @@ import (
// Provider defines the interface DNS providers should implement.
type Provider interface {
Records() ([]*endpoint.Endpoint, error)
Records(ctx context.Context) ([]*endpoint.Endpoint, error)
ApplyChanges(ctx context.Context, changes *plan.Changes) error
}

View File

@ -96,7 +96,7 @@ func (p *RcodeZeroProvider) Zones() ([]*rc0.Zone, error) {
// Records returns resource records
//
// Decrypts TXT records if TXT-Encrypt flag is set and key is provided
func (p *RcodeZeroProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *RcodeZeroProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
if err != nil {

View File

@ -73,7 +73,9 @@ func TestRcodeZeroProvider_Records(t *testing.T) {
}),
}
endpoints, err := provider.Records() // should return 6 rrs
ctx := context.Background()
endpoints, err := provider.Records(ctx) // should return 6 rrs
if err != nil {
t.Errorf("should not fail, %s", err)
@ -82,7 +84,7 @@ func TestRcodeZeroProvider_Records(t *testing.T) {
mockRRSetService.TestErrorReturned = true
_, err = provider.Records()
_, err = provider.Records(ctx)
if err == nil {
t.Errorf("expected to fail, %s", err)
}

View File

@ -113,7 +113,7 @@ func NewRDNSProvider(config RDNSConfig) (*RDNSProvider, error) {
// Records returns all DNS records found in Rancher DNS(RDNS) etcdv3 backend. Depending on the record fields
// it may be mapped to one or two records of type A, TXT, A+TXT.
func (p RDNSProvider) Records() ([]*endpoint.Endpoint, error) {
func (p RDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
var result []*endpoint.Endpoint
rs, err := p.client.List(p.rootDomain)

View File

@ -113,7 +113,7 @@ func TestARecordTranslation(t *testing.T) {
rootDomain: "lb.rancher.cloud",
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -148,7 +148,7 @@ func TestTXTRecordTranslation(t *testing.T) {
rootDomain: "lb.rancher.cloud",
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -184,7 +184,7 @@ func TestAWithTXTRecordTranslation(t *testing.T) {
rootDomain: "lb.rancher.cloud",
}
endpoints, err := provider.Records()
endpoints, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
@ -248,7 +248,7 @@ func TestRDNSApplyChanges(t *testing.T) {
},
}
records, _ := provider.Records()
records, _ := provider.Records(context.Background())
for _, ep := range records {
if ep.DNSName == "p1xaf1.lb.rancher.cloud" {
changes2.UpdateOld = append(changes2.UpdateOld, ep)

View File

@ -95,7 +95,7 @@ func NewRfc2136Provider(host string, port int, zoneName string, insecure bool, k
}
// Records returns the list of records.
func (r rfc2136Provider) Records() ([]*endpoint.Endpoint, error) {
func (r rfc2136Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
rrs, err := r.List()
if err != nil {
return nil, err

View File

@ -113,7 +113,7 @@ func TestRfc2136GetRecordsMultipleTargets(t *testing.T) {
provider, err := createRfc2136StubProvider(stub)
assert.NoError(t, err)
recs, err := provider.Records()
recs, err := provider.Records(context.Background())
assert.NoError(t, err)
assert.Equal(t, 1, len(recs), "expected single record")
@ -142,7 +142,7 @@ func TestRfc2136GetRecords(t *testing.T) {
provider, err := createRfc2136StubProvider(stub)
assert.NoError(t, err)
recs, err := provider.Records()
recs, err := provider.Records(context.Background())
assert.NoError(t, err)
assert.Equal(t, 6, len(recs))

View File

@ -221,7 +221,7 @@ func (p *TransIPProvider) Zones() ([]transip.Domain, error) {
}
// Records returns the list of records in a given zone.
func (p *TransIPProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *TransIPProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
if err != nil {
return nil, err

View File

@ -75,7 +75,7 @@ func NewVinylDNSProvider(domainFilter DomainFilter, zoneFilter ZoneIDFilter, dry
}, nil
}
func (p *vinyldnsProvider) Records() (endpoints []*endpoint.Endpoint, _ error) {
func (p *vinyldnsProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.client.Zones()
if err != nil {
return nil, err

View File

@ -90,18 +90,20 @@ func TestVinylDNSServices(t *testing.T) {
}
func testVinylDNSProviderRecords(t *testing.T) {
ctx := context.Background()
mockVinylDNSProvider.domainFilter = NewDomainFilter([]string{"example.com"})
result, err := mockVinylDNSProvider.Records()
result, err := mockVinylDNSProvider.Records(ctx)
assert.Nil(t, err)
assert.Equal(t, len(vinylDNSRecords), len(result))
mockVinylDNSProvider.zoneFilter = NewZoneIDFilter([]string{"0"})
result, err = mockVinylDNSProvider.Records()
result, err = mockVinylDNSProvider.Records(ctx)
assert.Nil(t, err)
assert.Equal(t, len(vinylDNSRecords), len(result))
mockVinylDNSProvider.zoneFilter = NewZoneIDFilter([]string{"1"})
result, err = mockVinylDNSProvider.Records()
result, err = mockVinylDNSProvider.Records(ctx)
assert.Nil(t, err)
assert.Equal(t, 0, len(result))
}

View File

@ -44,8 +44,8 @@ func NewAWSSDRegistry(provider provider.Provider, ownerID string) (*AWSSDRegistr
// Records calls AWS SD API and expects AWS SD provider to provider Owner/Resource information as a serialized
// value in the AWSSDDescriptionLabel value in the Labels map
func (sdr *AWSSDRegistry) Records() ([]*endpoint.Endpoint, error) {
records, err := sdr.provider.Records()
func (sdr *AWSSDRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
records, err := sdr.provider.Records(ctx)
if err != nil {
return nil, err
}

View File

@ -33,7 +33,7 @@ type inMemoryProvider struct {
onApplyChanges func(changes *plan.Changes)
}
func (p *inMemoryProvider) Records() ([]*endpoint.Endpoint, error) {
func (p *inMemoryProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
return p.endpoints, nil
}
@ -101,7 +101,7 @@ func TestAWSSDRegistryTest_Records(t *testing.T) {
}
r, _ := NewAWSSDRegistry(p, "owner")
records, _ := r.Records()
records, _ := r.Records(context.Background())
assert.True(t, testutils.SameEndpoints(records, expectedRecords))
}

View File

@ -37,8 +37,8 @@ func NewNoopRegistry(provider provider.Provider) (*NoopRegistry, error) {
}
// Records returns the current records from the dns provider
func (im *NoopRegistry) Records() ([]*endpoint.Endpoint, error) {
return im.provider.Records()
func (im *NoopRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
return im.provider.Records(ctx)
}
// ApplyChanges propagates changes to the dns provider

View File

@ -45,6 +45,7 @@ func testNoopInit(t *testing.T) {
}
func testNoopRecords(t *testing.T) {
ctx := context.Background()
p := provider.NewInMemoryProvider()
p.CreateZone("org")
providerRecords := []*endpoint.Endpoint{
@ -54,13 +55,13 @@ func testNoopRecords(t *testing.T) {
RecordType: endpoint.RecordTypeCNAME,
},
}
p.ApplyChanges(context.Background(), &plan.Changes{
p.ApplyChanges(ctx, &plan.Changes{
Create: providerRecords,
})
r, _ := NewNoopRegistry(p)
eps, err := r.Records()
eps, err := r.Records(ctx)
require.NoError(t, err)
assert.True(t, testutils.SameEndpoints(eps, providerRecords))
}
@ -131,6 +132,6 @@ func testNoopApplyChanges(t *testing.T) {
},
},
}))
res, _ := p.Records()
res, _ := p.Records(ctx)
assert.True(t, testutils.SameEndpoints(res, expectedUpdate))
}

View File

@ -30,7 +30,7 @@ import (
// each entry includes owner information
// ApplyChanges(changes *plan.Changes) propagates the changes to the DNS Provider API and correspondingly updates ownership depending on type of registry being used
type Registry interface {
Records() ([]*endpoint.Endpoint, error)
Records(ctx context.Context) ([]*endpoint.Endpoint, error)
ApplyChanges(ctx context.Context, changes *plan.Changes) error
}

View File

@ -61,7 +61,7 @@ func NewTXTRegistry(provider provider.Provider, txtPrefix, ownerID string, cache
// Records returns the current records from the registry excluding TXT Records
// If TXT records was created previously to indicate ownership its corresponding value
// will be added to the endpoints Labels map
func (im *TXTRegistry) Records() ([]*endpoint.Endpoint, error) {
func (im *TXTRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
// If we have the zones cached AND we have refreshed the cache since the
// last given interval, then just use the cached results.
if im.recordsCache != nil && time.Since(im.recordsCacheRefreshTime) < im.cacheInterval {
@ -69,7 +69,7 @@ func (im *TXTRegistry) Records() ([]*endpoint.Endpoint, error) {
return im.recordsCache, nil
}
records, err := im.provider.Records()
records, err := im.provider.Records(ctx)
if err != nil {
return nil, err
}

View File

@ -67,9 +67,10 @@ func testTXTRegistryRecords(t *testing.T) {
}
func testTXTRegistryRecordsPrefixed(t *testing.T) {
ctx := context.Background()
p := provider.NewInMemoryProvider()
p.CreateZone(testZone)
p.ApplyChanges(context.Background(), &plan.Changes{
p.ApplyChanges(ctx, &plan.Changes{
Create: []*endpoint.Endpoint{
newEndpointWithOwnerAndLabels("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, "", endpoint.Labels{"foo": "somefoo"}),
newEndpointWithOwnerAndLabels("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, "", endpoint.Labels{"bar": "somebar"}),
@ -159,21 +160,22 @@ func testTXTRegistryRecordsPrefixed(t *testing.T) {
}
r, _ := NewTXTRegistry(p, "txt.", "owner", time.Hour)
records, _ := r.Records()
records, _ := r.Records(ctx)
assert.True(t, testutils.SameEndpoints(records, expectedRecords))
// Ensure prefix is case-insensitive
r, _ = NewTXTRegistry(p, "TxT.", "owner", time.Hour)
records, _ = r.Records()
records, _ = r.Records(ctx)
assert.True(t, testutils.SameEndpointLabels(records, expectedRecords))
}
func testTXTRegistryRecordsNoPrefix(t *testing.T) {
p := provider.NewInMemoryProvider()
ctx := context.Background()
p.CreateZone(testZone)
p.ApplyChanges(context.Background(), &plan.Changes{
p.ApplyChanges(ctx, &plan.Changes{
Create: []*endpoint.Endpoint{
newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""),
newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""),
@ -239,7 +241,7 @@ func testTXTRegistryRecordsNoPrefix(t *testing.T) {
}
r, _ := NewTXTRegistry(p, "", "owner", time.Hour)
records, _ := r.Records()
records, _ := r.Records(ctx)
assert.True(t, testutils.SameEndpoints(records, expectedRecords))
}