diff --git a/controller/controller.go b/controller/controller.go index aa12b6edd..fa5dba6ec 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -104,7 +104,8 @@ type Controller struct { // RunOnce runs a single iteration of a reconciliation loop. func (c *Controller) RunOnce() error { - records, err := c.Registry.Records() + ctx := context.Background() + records, err := c.Registry.Records(ctx) if err != nil { registryErrorsTotal.Inc() deprecatedRegistryErrors.Inc() @@ -112,7 +113,7 @@ func (c *Controller) RunOnce() error { } registryEndpointsTotal.Set(float64(len(records))) - ctx := context.WithValue(context.Background(), provider.RecordsContextKey, records) + ctx = context.WithValue(ctx, provider.RecordsContextKey, records) endpoints, err := c.Source.Endpoints() if err != nil { diff --git a/controller/controller_test.go b/controller/controller_test.go index 7bebe395e..c59795ef7 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -39,7 +39,7 @@ type mockProvider struct { } // Records returns the desired mock endpoints. -func (p *mockProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *mockProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { return p.RecordsStore, nil } diff --git a/provider/alibaba_cloud.go b/provider/alibaba_cloud.go index 181a4431b..cd07813f4 100644 --- a/provider/alibaba_cloud.go +++ b/provider/alibaba_cloud.go @@ -281,7 +281,7 @@ func (p *AlibabaCloudProvider) refreshStsToken(sleepTime time.Duration) { // Records gets the current records. // // Returns the current records or an error if the operation failed. -func (p *AlibabaCloudProvider) Records() (endpoints []*endpoint.Endpoint, err error) { +func (p *AlibabaCloudProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) { if p.privateZone { endpoints, err = p.privateZoneRecords() } else { diff --git a/provider/alibaba_cloud_test.go b/provider/alibaba_cloud_test.go index 87b629bef..dfe9a6a60 100644 --- a/provider/alibaba_cloud_test.go +++ b/provider/alibaba_cloud_test.go @@ -246,7 +246,7 @@ func newTestAlibabaCloudProvider(private bool) *AlibabaCloudProvider { func TestAlibabaCloudPrivateProvider_Records(t *testing.T) { p := newTestAlibabaCloudProvider(true) - endpoints, err := p.Records() + endpoints, err := p.Records(context.Background()) if err != nil { t.Errorf("Failed to get records: %v", err) } else { @@ -261,7 +261,7 @@ func TestAlibabaCloudPrivateProvider_Records(t *testing.T) { func TestAlibabaCloudProvider_Records(t *testing.T) { p := newTestAlibabaCloudProvider(false) - endpoints, err := p.Records() + endpoints, err := p.Records(context.Background()) if err != nil { t.Errorf("Failed to get records: %v", err) } else { @@ -302,8 +302,9 @@ func TestAlibabaCloudProvider_ApplyChanges(t *testing.T) { }, }, } - p.ApplyChanges(context.Background(), &changes) - endpoints, err := p.Records() + ctx := context.Background() + p.ApplyChanges(ctx, &changes) + endpoints, err := p.Records(ctx) if err != nil { t.Errorf("Failed to get records: %v", err) } else { @@ -318,7 +319,7 @@ func TestAlibabaCloudProvider_ApplyChanges(t *testing.T) { func TestAlibabaCloudProvider_Records_PrivateZone(t *testing.T) { p := newTestAlibabaCloudProvider(true) - endpoints, err := p.Records() + endpoints, err := p.Records(context.Background()) if err != nil { t.Errorf("Failed to get records: %v", err) } else { @@ -359,8 +360,9 @@ func TestAlibabaCloudProvider_ApplyChanges_PrivateZone(t *testing.T) { }, }, } - p.ApplyChanges(context.Background(), &changes) - endpoints, err := p.Records() + ctx := context.Background() + p.ApplyChanges(ctx, &changes) + endpoints, err := p.Records(ctx) if err != nil { t.Errorf("Failed to get records: %v", err) } else { diff --git a/provider/aws.go b/provider/aws.go index 2e23885cb..2cbfee31e 100644 --- a/provider/aws.go +++ b/provider/aws.go @@ -26,6 +26,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/route53" "github.com/linki/instrumented_http" @@ -101,11 +102,11 @@ var ( // Route53API is the subset of the AWS Route53 API that we actually use. Add methods as required. Signatures must match exactly. // mostly taken from: https://github.com/kubernetes/kubernetes/blob/853167624edb6bc0cfdcdfb88e746e178f5db36c/federation/pkg/dnsprovider/providers/aws/route53/stubs/route53api.go type Route53API interface { - ListResourceRecordSetsPages(input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool)) error - ChangeResourceRecordSets(*route53.ChangeResourceRecordSetsInput) (*route53.ChangeResourceRecordSetsOutput, error) - CreateHostedZone(*route53.CreateHostedZoneInput) (*route53.CreateHostedZoneOutput, error) - ListHostedZonesPages(input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool)) error - ListTagsForResource(input *route53.ListTagsForResourceInput) (*route53.ListTagsForResourceOutput, error) + ListResourceRecordSetsPagesWithContext(ctx context.Context, input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error + ChangeResourceRecordSetsWithContext(ctx context.Context, input *route53.ChangeResourceRecordSetsInput, opts ...request.Option) (*route53.ChangeResourceRecordSetsOutput, error) + CreateHostedZoneWithContext(ctx context.Context, input *route53.CreateHostedZoneInput, opts ...request.Option) (*route53.CreateHostedZoneOutput, error) + ListHostedZonesPagesWithContext(ctx context.Context, input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error + ListTagsForResourceWithContext(ctx context.Context, input *route53.ListTagsForResourceInput, opts ...request.Option) (*route53.ListTagsForResourceOutput, error) } // AWSProvider is an implementation of Provider for AWS Route53. @@ -184,7 +185,7 @@ func NewAWSProvider(awsConfig AWSConfig) (*AWSProvider, error) { } // Zones returns the list of hosted zones. -func (p *AWSProvider) Zones() (map[string]*route53.HostedZone, error) { +func (p *AWSProvider) Zones(ctx context.Context) (map[string]*route53.HostedZone, error) { zones := make(map[string]*route53.HostedZone) var tagErr error @@ -204,7 +205,7 @@ func (p *AWSProvider) Zones() (map[string]*route53.HostedZone, error) { // Only fetch tags if a tag filter was specified if !p.zoneTagFilter.IsEmpty() { - tags, err := p.tagsForZone(*zone.Id) + tags, err := p.tagsForZone(ctx, *zone.Id) if err != nil { tagErr = err return false @@ -220,7 +221,7 @@ func (p *AWSProvider) Zones() (map[string]*route53.HostedZone, error) { return true } - err := p.client.ListHostedZonesPages(&route53.ListHostedZonesInput{}, f) + err := p.client.ListHostedZonesPagesWithContext(ctx, &route53.ListHostedZonesInput{}, f) if err != nil { return nil, err } @@ -245,16 +246,16 @@ func wildcardUnescape(s string) string { } // Records returns the list of records in a given hosted zone. -func (p *AWSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { - zones, err := p.Zones() +func (p *AWSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { + zones, err := p.Zones(ctx) if err != nil { return nil, err } - return p.records(zones) + return p.records(ctx, zones) } -func (p *AWSProvider) records(zones map[string]*route53.HostedZone) ([]*endpoint.Endpoint, error) { +func (p *AWSProvider) records(ctx context.Context, zones map[string]*route53.HostedZone) ([]*endpoint.Endpoint, error) { endpoints := make([]*endpoint.Endpoint, 0) f := func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool) { for _, r := range resp.ResourceRecordSets { @@ -331,7 +332,7 @@ func (p *AWSProvider) records(zones map[string]*route53.HostedZone) ([]*endpoint HostedZoneId: z.Id, } - if err := p.client.ListResourceRecordSetsPages(params, f); err != nil { + if err := p.client.ListResourceRecordSetsPagesWithContext(ctx, params, f); err != nil { return nil, err } } @@ -340,36 +341,36 @@ func (p *AWSProvider) records(zones map[string]*route53.HostedZone) ([]*endpoint } // CreateRecords creates a given set of DNS records in the given hosted zone. -func (p *AWSProvider) CreateRecords(endpoints []*endpoint.Endpoint) error { - return p.doRecords(route53.ChangeActionCreate, endpoints) +func (p *AWSProvider) CreateRecords(ctx context.Context, endpoints []*endpoint.Endpoint) error { + return p.doRecords(ctx, route53.ChangeActionCreate, endpoints) } // UpdateRecords updates a given set of old records to a new set of records in a given hosted zone. -func (p *AWSProvider) UpdateRecords(endpoints, _ []*endpoint.Endpoint) error { - return p.doRecords(route53.ChangeActionUpsert, endpoints) +func (p *AWSProvider) UpdateRecords(ctx context.Context, endpoints, _ []*endpoint.Endpoint) error { + return p.doRecords(ctx, route53.ChangeActionUpsert, endpoints) } // DeleteRecords deletes a given set of DNS records in a given zone. -func (p *AWSProvider) DeleteRecords(endpoints []*endpoint.Endpoint) error { - return p.doRecords(route53.ChangeActionDelete, endpoints) +func (p *AWSProvider) DeleteRecords(ctx context.Context, endpoints []*endpoint.Endpoint) error { + return p.doRecords(ctx, route53.ChangeActionDelete, endpoints) } -func (p *AWSProvider) doRecords(action string, endpoints []*endpoint.Endpoint) error { - zones, err := p.Zones() +func (p *AWSProvider) doRecords(ctx context.Context, action string, endpoints []*endpoint.Endpoint) error { + zones, err := p.Zones(ctx) if err != nil { return err } - records, err := p.records(zones) + records, err := p.records(ctx, zones) if err != nil { log.Errorf("getting records failed: %v", err) } - return p.submitChanges(p.newChanges(action, endpoints, records, zones), zones) + return p.submitChanges(ctx, p.newChanges(action, endpoints, records, zones), zones) } // ApplyChanges applies a given set of changes in a given zone. func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return err } @@ -377,7 +378,7 @@ func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e records, ok := ctx.Value(RecordsContextKey).([]*endpoint.Endpoint) if !ok { var err error - records, err = p.records(zones) + records, err = p.records(ctx, zones) if err != nil { log.Errorf("getting records failed: %v", err) } @@ -389,11 +390,11 @@ func (p *AWSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e combinedChanges = append(combinedChanges, p.newChanges(route53.ChangeActionUpsert, changes.UpdateNew, records, zones)...) combinedChanges = append(combinedChanges, p.newChanges(route53.ChangeActionDelete, changes.Delete, records, zones)...) - return p.submitChanges(combinedChanges, zones) + return p.submitChanges(ctx, combinedChanges, zones) } // submitChanges takes a zone and a collection of Changes and sends them as a single transaction. -func (p *AWSProvider) submitChanges(changes []*route53.Change, zones map[string]*route53.HostedZone) error { +func (p *AWSProvider) submitChanges(ctx context.Context, changes []*route53.Change, zones map[string]*route53.HostedZone) error { // return early if there is nothing to change if len(changes) == 0 { log.Info("All records are already up to date") @@ -425,7 +426,7 @@ func (p *AWSProvider) submitChanges(changes []*route53.Change, zones map[string] }, } - if _, err := p.client.ChangeResourceRecordSets(params); err != nil { + if _, err := p.client.ChangeResourceRecordSetsWithContext(ctx, params); err != nil { log.Errorf("Failure in zone %s [Id: %s]", aws.StringValue(zones[z].Name), z) log.Error(err) //TODO(ideahitme): consider changing the interface in cases when this error might be a concern for other components failedUpdate = true @@ -568,8 +569,8 @@ func (p *AWSProvider) newChange(action string, ep *endpoint.Endpoint, recordsCac return change, dualstack } -func (p *AWSProvider) tagsForZone(zoneID string) (map[string]string, error) { - response, err := p.client.ListTagsForResource(&route53.ListTagsForResourceInput{ +func (p *AWSProvider) tagsForZone(ctx context.Context, zoneID string) (map[string]string, error) { + response, err := p.client.ListTagsForResourceWithContext(ctx, &route53.ListTagsForResourceInput{ ResourceType: aws.String("hostedzone"), ResourceId: aws.String(zoneID), }) diff --git a/provider/aws_sd.go b/provider/aws_sd.go index ff2075b71..55ac9fc3e 100644 --- a/provider/aws_sd.go +++ b/provider/aws_sd.go @@ -138,7 +138,7 @@ func newSdNamespaceFilter(namespaceTypeConfig string) *sd.NamespaceFilter { } // Records returns list of all endpoints. -func (p *AWSSDProvider) Records() (endpoints []*endpoint.Endpoint, err error) { +func (p *AWSSDProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) { namespaces, err := p.ListNamespaces() if err != nil { return nil, err diff --git a/provider/aws_sd_test.go b/provider/aws_sd_test.go index d66963fce..131cb2e51 100644 --- a/provider/aws_sd_test.go +++ b/provider/aws_sd_test.go @@ -289,7 +289,7 @@ func TestAWSSDProvider_Records(t *testing.T) { provider := newTestAWSSDProvider(api, NewDomainFilter([]string{}), "") - endpoints, _ := provider.Records() + endpoints, _ := provider.Records(context.Background()) assert.True(t, testutils.SameEndpoints(expectedEndpoints, endpoints), "expected and actual endpoints don't match, expected=%v, actual=%v", expectedEndpoints, endpoints) } @@ -317,8 +317,10 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) { provider := newTestAWSSDProvider(api, NewDomainFilter([]string{}), "") + ctx := context.Background() + // apply creates - provider.ApplyChanges(context.Background(), &plan.Changes{ + provider.ApplyChanges(ctx, &plan.Changes{ Create: expectedEndpoints, }) @@ -330,16 +332,17 @@ func TestAWSSDProvider_ApplyChanges(t *testing.T) { assert.NotNil(t, existingServices["service3"]) // make sure instances were registered - endpoints, _ := provider.Records() + endpoints, _ := provider.Records(ctx) assert.True(t, testutils.SameEndpoints(expectedEndpoints, endpoints), "expected and actual endpoints don't match, expected=%v, actual=%v", expectedEndpoints, endpoints) + ctx = context.Background() // apply deletes - provider.ApplyChanges(context.Background(), &plan.Changes{ + provider.ApplyChanges(ctx, &plan.Changes{ Delete: expectedEndpoints, }) // make sure all instances are gone - endpoints, _ = provider.Records() + endpoints, _ = provider.Records(ctx) assert.Empty(t, endpoints) } diff --git a/provider/aws_test.go b/provider/aws_test.go index c109a4c2d..1dc112010 100644 --- a/provider/aws_test.go +++ b/provider/aws_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/route53" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -73,7 +74,7 @@ func NewRoute53APIStub() *Route53APIStub { } } -func (r *Route53APIStub) ListResourceRecordSetsPages(input *route53.ListResourceRecordSetsInput, fn func(p *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool)) error { +func (r *Route53APIStub) ListResourceRecordSetsPagesWithContext(ctx context.Context, input *route53.ListResourceRecordSetsInput, fn func(p *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error { output := route53.ListResourceRecordSetsOutput{} // TODO: Support optional input args. if len(r.recordSets) == 0 { output.ResourceRecordSets = []*route53.ResourceRecordSet{} @@ -103,29 +104,29 @@ func NewRoute53APICounter(w Route53API) *Route53APICounter { } } -func (c *Route53APICounter) ListResourceRecordSetsPages(input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool)) error { +func (c *Route53APICounter) ListResourceRecordSetsPagesWithContext(ctx context.Context, input *route53.ListResourceRecordSetsInput, fn func(resp *route53.ListResourceRecordSetsOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error { c.calls["ListResourceRecordSetsPages"]++ - return c.wrapped.ListResourceRecordSetsPages(input, fn) + return c.wrapped.ListResourceRecordSetsPagesWithContext(ctx, input, fn) } -func (c *Route53APICounter) ChangeResourceRecordSets(input *route53.ChangeResourceRecordSetsInput) (*route53.ChangeResourceRecordSetsOutput, error) { +func (c *Route53APICounter) ChangeResourceRecordSetsWithContext(ctx context.Context, input *route53.ChangeResourceRecordSetsInput, opts ...request.Option) (*route53.ChangeResourceRecordSetsOutput, error) { c.calls["ChangeResourceRecordSets"]++ - return c.wrapped.ChangeResourceRecordSets(input) + return c.wrapped.ChangeResourceRecordSetsWithContext(ctx, input) } -func (c *Route53APICounter) CreateHostedZone(input *route53.CreateHostedZoneInput) (*route53.CreateHostedZoneOutput, error) { +func (c *Route53APICounter) CreateHostedZoneWithContext(ctx context.Context, input *route53.CreateHostedZoneInput, opts ...request.Option) (*route53.CreateHostedZoneOutput, error) { c.calls["CreateHostedZone"]++ - return c.wrapped.CreateHostedZone(input) + return c.wrapped.CreateHostedZoneWithContext(ctx, input) } -func (c *Route53APICounter) ListHostedZonesPages(input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool)) error { +func (c *Route53APICounter) ListHostedZonesPagesWithContext(ctx context.Context, input *route53.ListHostedZonesInput, fn func(resp *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error { c.calls["ListHostedZonesPages"]++ - return c.wrapped.ListHostedZonesPages(input, fn) + return c.wrapped.ListHostedZonesPagesWithContext(ctx, input, fn) } -func (c *Route53APICounter) ListTagsForResource(input *route53.ListTagsForResourceInput) (*route53.ListTagsForResourceOutput, error) { +func (c *Route53APICounter) ListTagsForResourceWithContext(ctx context.Context, input *route53.ListTagsForResourceInput, opts ...request.Option) (*route53.ListTagsForResourceOutput, error) { c.calls["ListTagsForResource"]++ - return c.wrapped.ListTagsForResource(input) + return c.wrapped.ListTagsForResourceWithContext(ctx, input) } // Route53 stores wildcards escaped: http://docs.aws.amazon.com/Route53/latest/DeveloperGuide/DomainNameFormat.html?shortFooter=true#domain-name-format-asterisk @@ -136,7 +137,7 @@ func wildcardEscape(s string) string { return s } -func (r *Route53APIStub) ListTagsForResource(input *route53.ListTagsForResourceInput) (*route53.ListTagsForResourceOutput, error) { +func (r *Route53APIStub) ListTagsForResourceWithContext(ctx context.Context, input *route53.ListTagsForResourceInput, opts ...request.Option) (*route53.ListTagsForResourceOutput, error) { if aws.StringValue(input.ResourceType) == "hostedzone" { tags := r.zoneTags[aws.StringValue(input.ResourceId)] return &route53.ListTagsForResourceOutput{ @@ -150,7 +151,7 @@ func (r *Route53APIStub) ListTagsForResource(input *route53.ListTagsForResourceI return &route53.ListTagsForResourceOutput{}, nil } -func (r *Route53APIStub) ChangeResourceRecordSets(input *route53.ChangeResourceRecordSetsInput) (*route53.ChangeResourceRecordSetsOutput, error) { +func (r *Route53APIStub) ChangeResourceRecordSetsWithContext(ctx context.Context, input *route53.ChangeResourceRecordSetsInput, opts ...request.Option) (*route53.ChangeResourceRecordSetsOutput, error) { if r.m.isMocked("ChangeResourceRecordSets", input) { return r.m.ChangeResourceRecordSets(input) } @@ -209,7 +210,7 @@ func (r *Route53APIStub) ChangeResourceRecordSets(input *route53.ChangeResourceR return output, nil // TODO: We should ideally return status etc, but we don't' use that yet. } -func (r *Route53APIStub) ListHostedZonesPages(input *route53.ListHostedZonesInput, fn func(p *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool)) error { +func (r *Route53APIStub) ListHostedZonesPagesWithContext(ctx context.Context, input *route53.ListHostedZonesInput, fn func(p *route53.ListHostedZonesOutput, lastPage bool) (shouldContinue bool), opts ...request.Option) error { output := &route53.ListHostedZonesOutput{} for _, zone := range r.zones { output.HostedZones = append(output.HostedZones, zone) @@ -219,7 +220,7 @@ func (r *Route53APIStub) ListHostedZonesPages(input *route53.ListHostedZonesInpu return nil } -func (r *Route53APIStub) CreateHostedZone(input *route53.CreateHostedZoneInput) (*route53.CreateHostedZoneOutput, error) { +func (r *Route53APIStub) CreateHostedZoneWithContext(ctx context.Context, input *route53.CreateHostedZoneInput, opts ...request.Option) (*route53.CreateHostedZoneOutput, error) { name := aws.StringValue(input.Name) id := "/hostedzone/" + name if _, ok := r.zones[id]; ok { @@ -302,7 +303,7 @@ func TestAWSZones(t *testing.T) { } { provider, _ := newAWSProviderWithTagFilter(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), ti.zoneIDFilter, ti.zoneTypeFilter, ti.zoneTagFilter, defaultEvaluateTargetHealth, false, []*endpoint.Endpoint{}) - zones, err := provider.Zones() + zones, err := provider.Zones(context.Background()) require.NoError(t, err) validateAWSZones(t, zones, ti.expectedZones) @@ -328,7 +329,7 @@ func TestAWSRecords(t *testing.T) { endpoint.NewEndpointWithTTL("geolocation-test.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "4.3.2.1").WithSetIdentifier("test-set-2").WithProviderSpecific(providerSpecificGeolocationCountryCode, "DE"), }) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -362,9 +363,9 @@ func TestAWSCreateRecords(t *testing.T) { endpoint.NewEndpoint("create-test-multiple.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "8.8.8.8", "8.8.4.4"), } - require.NoError(t, provider.CreateRecords(records)) + require.NoError(t, provider.CreateRecords(context.Background(), records)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -397,9 +398,9 @@ func TestAWSUpdateRecords(t *testing.T) { endpoint.NewEndpoint("create-test-multiple.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, "1.2.3.4", "4.3.2.1"), } - require.NoError(t, provider.UpdateRecords(updatedRecords, currentRecords)) + require.NoError(t, provider.UpdateRecords(context.Background(), updatedRecords, currentRecords)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -422,9 +423,9 @@ func TestAWSDeleteRecords(t *testing.T) { provider, _ := newAWSProvider(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), NewZoneIDFilter([]string{}), NewZoneTypeFilter(""), false, false, originalEndpoints) - require.NoError(t, provider.DeleteRecords(originalEndpoints)) + require.NoError(t, provider.DeleteRecords(context.Background(), originalEndpoints)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) @@ -439,9 +440,10 @@ func TestAWSApplyChanges(t *testing.T) { }{ {"no cache", func(p *AWSProvider) context.Context { return context.Background() }, 3}, {"cached", func(p *AWSProvider) context.Context { - records, err := p.Records() + ctx := context.Background() + records, err := p.Records(ctx) require.NoError(t, err) - return context.WithValue(context.Background(), RecordsContextKey, records) + return context.WithValue(ctx, RecordsContextKey, records) }, 0}, } @@ -506,7 +508,7 @@ func TestAWSApplyChanges(t *testing.T) { assert.Equal(t, 1, counter.calls["ListHostedZonesPages"], tt.name) assert.Equal(t, tt.listRRSets, counter.calls["ListResourceRecordSetsPages"], tt.name) - records, err := provider.Records() + records, err := provider.Records(ctx) require.NoError(t, err, tt.name) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -578,9 +580,11 @@ func TestAWSApplyChangesDryRun(t *testing.T) { Delete: deleteRecords, } - require.NoError(t, provider.ApplyChanges(context.Background(), changes)) + ctx := context.Background() - records, err := provider.Records() + require.NoError(t, provider.ApplyChanges(ctx, changes)) + + records, err := provider.Records(ctx) require.NoError(t, err) validateEndpoints(t, records, originalEndpoints) @@ -698,14 +702,15 @@ func TestAWSsubmitChanges(t *testing.T) { } } - zones, _ := provider.Zones() - records, _ := provider.Records() + ctx := context.Background() + zones, _ := provider.Zones(ctx) + records, _ := provider.Records(ctx) cs := make([]*route53.Change, 0, len(endpoints)) cs = append(cs, provider.newChanges(route53.ChangeActionCreate, endpoints, records, zones)...) - require.NoError(t, provider.submitChanges(cs, zones)) + require.NoError(t, provider.submitChanges(ctx, cs, zones)) - records, err := provider.Records() + records, err := provider.Records(ctx) require.NoError(t, err) validateEndpoints(t, records, endpoints) @@ -715,15 +720,16 @@ func TestAWSsubmitChangesError(t *testing.T) { provider, clientStub := newAWSProvider(t, NewDomainFilter([]string{"ext-dns-test-2.teapot.zalan.do."}), NewZoneIDFilter([]string{}), NewZoneTypeFilter(""), defaultEvaluateTargetHealth, false, []*endpoint.Endpoint{}) clientStub.MockMethod("ChangeResourceRecordSets", mock.Anything).Return(nil, fmt.Errorf("Mock route53 failure")) - zones, err := provider.Zones() + ctx := context.Background() + zones, err := provider.Zones(ctx) require.NoError(t, err) - records, err := provider.Records() + records, err := provider.Records(ctx) require.NoError(t, err) ep := endpoint.NewEndpointWithTTL("fail.zone-1.ext-dns-test-2.teapot.zalan.do", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.0.0.1") cs := provider.newChanges(route53.ChangeActionCreate, []*endpoint.Endpoint{ep}, records, zones) - require.Error(t, provider.submitChanges(cs, zones)) + require.Error(t, provider.submitChanges(ctx, cs, zones)) } func TestAWSBatchChangeSet(t *testing.T) { @@ -853,7 +859,7 @@ func TestAWSCreateRecordsWithCNAME(t *testing.T) { {DNSName: "create-test.zone-1.ext-dns-test-2.teapot.zalan.do", Targets: endpoint.Targets{"foo.example.org"}, RecordType: endpoint.RecordTypeCNAME}, } - require.NoError(t, provider.CreateRecords(records)) + require.NoError(t, provider.CreateRecords(context.Background(), records)) recordSets := listAWSRecords(t, provider.client, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") @@ -906,7 +912,7 @@ func TestAWSCreateRecordsWithALIAS(t *testing.T) { }, } - require.NoError(t, provider.CreateRecords(records)) + require.NoError(t, provider.CreateRecords(context.Background(), records)) recordSets := listAWSRecords(t, provider.client, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") @@ -1093,7 +1099,7 @@ func createAWSZone(t *testing.T, provider *AWSProvider, zone *route53.HostedZone HostedZoneConfig: zone.Config, } - if _, err := provider.client.CreateHostedZone(params); err != nil { + if _, err := provider.client.CreateHostedZoneWithContext(context.Background(), params); err != nil { require.EqualError(t, err, route53.ErrCodeHostedZoneAlreadyExists) } } @@ -1103,25 +1109,26 @@ func setupAWSRecords(t *testing.T, provider *AWSProvider, endpoints []*endpoint. clearAWSRecords(t, provider, "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do.") clearAWSRecords(t, provider, "/hostedzone/zone-3.ext-dns-test-2.teapot.zalan.do.") - records, err := provider.Records() + ctx := context.Background() + records, err := provider.Records(ctx) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{}) - require.NoError(t, provider.CreateRecords(endpoints)) + require.NoError(t, provider.CreateRecords(context.Background(), endpoints)) escapeAWSRecords(t, provider, "/hostedzone/zone-1.ext-dns-test-2.teapot.zalan.do.") escapeAWSRecords(t, provider, "/hostedzone/zone-2.ext-dns-test-2.teapot.zalan.do.") escapeAWSRecords(t, provider, "/hostedzone/zone-3.ext-dns-test-2.teapot.zalan.do.") - _, err = provider.Records() + _, err = provider.Records(ctx) require.NoError(t, err) } func listAWSRecords(t *testing.T, client Route53API, zone string) []*route53.ResourceRecordSet { recordSets := []*route53.ResourceRecordSet{} - require.NoError(t, client.ListResourceRecordSetsPages(&route53.ListResourceRecordSetsInput{ + require.NoError(t, client.ListResourceRecordSetsPagesWithContext(context.Background(), &route53.ListResourceRecordSetsInput{ HostedZoneId: aws.String(zone), }, func(resp *route53.ListResourceRecordSetsOutput, _ bool) bool { for _, recordSet := range resp.ResourceRecordSets { @@ -1145,7 +1152,7 @@ func clearAWSRecords(t *testing.T, provider *AWSProvider, zone string) { } if len(changes) != 0 { - _, err := provider.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{ + _, err := provider.client.ChangeResourceRecordSetsWithContext(context.Background(), &route53.ChangeResourceRecordSetsInput{ HostedZoneId: aws.String(zone), ChangeBatch: &route53.ChangeBatch{ Changes: changes, @@ -1168,7 +1175,7 @@ func escapeAWSRecords(t *testing.T, provider *AWSProvider, zone string) { } if len(changes) != 0 { - _, err := provider.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{ + _, err := provider.client.ChangeResourceRecordSetsWithContext(context.Background(), &route53.ChangeResourceRecordSetsInput{ HostedZoneId: aws.String(zone), ChangeBatch: &route53.ChangeBatch{ Changes: changes, diff --git a/provider/azure.go b/provider/azure.go index efc949357..af7c037ce 100644 --- a/provider/azure.go +++ b/provider/azure.go @@ -180,8 +180,7 @@ func getAccessToken(cfg config, environment azure.Environment) (*adal.ServicePri // Records gets the current records. // // Returns the current records or an error if the operation failed. -func (p *AzureProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { - ctx := context.Background() +func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { zones, err := p.zones(ctx) if err != nil { return nil, err diff --git a/provider/azure_private_dns.go b/provider/azure_private_dns.go index 693156feb..74e1697ab 100644 --- a/provider/azure_private_dns.go +++ b/provider/azure_private_dns.go @@ -90,8 +90,7 @@ func NewAzurePrivateDNSProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFi // Records gets the current records. // // Returns the current records or an error if the operation failed. -func (p *AzurePrivateDNSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { - ctx := context.Background() +func (p *AzurePrivateDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { zones, err := p.zones(ctx) if err != nil { return nil, err diff --git a/provider/azure_privatedns_test.go b/provider/azure_privatedns_test.go index 03c468615..b4d311e90 100644 --- a/provider/azure_privatedns_test.go +++ b/provider/azure_privatedns_test.go @@ -266,7 +266,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) { t.Fatal(err) } - actual, err := provider.Records() + actual, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) @@ -302,7 +302,7 @@ func TestAzurePrivateDNSMultiRecord(t *testing.T) { t.Fatal(err) } - actual, err := provider.Records() + actual, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) diff --git a/provider/azure_test.go b/provider/azure_test.go index 5f48eb8f7..6fbb88fcb 100644 --- a/provider/azure_test.go +++ b/provider/azure_test.go @@ -273,7 +273,8 @@ func TestAzureRecord(t *testing.T) { t.Fatal(err) } - actual, err := provider.Records() + ctx := context.Background() + actual, err := provider.Records(ctx) if err != nil { t.Fatal(err) @@ -309,7 +310,8 @@ func TestAzureMultiRecord(t *testing.T) { t.Fatal(err) } - actual, err := provider.Records() + ctx := context.Background() + actual, err := provider.Records(ctx) if err != nil { t.Fatal(err) diff --git a/provider/cloudflare.go b/provider/cloudflare.go index 3fd131871..b1ff72f5d 100644 --- a/provider/cloudflare.go +++ b/provider/cloudflare.go @@ -176,7 +176,7 @@ func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) { } // Records returns the list of records. -func (p *CloudFlareProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.Zones() if err != nil { return nil, err diff --git a/provider/cloudflare_test.go b/provider/cloudflare_test.go index 4f29f6615..6bbbb6f3d 100644 --- a/provider/cloudflare_test.go +++ b/provider/cloudflare_test.go @@ -491,19 +491,21 @@ func TestRecords(t *testing.T) { provider := &CloudFlareProvider{ Client: &mockCloudFlareClient{}, } - records, err := provider.Records() + ctx := context.Background() + + records, err := provider.Records(ctx) if err != nil { t.Errorf("should not fail, %s", err) } assert.Equal(t, 1, len(records)) provider.Client = &mockCloudFlareDNSRecordsFail{} - _, err = provider.Records() + _, err = provider.Records(ctx) if err == nil { t.Errorf("expected to fail") } provider.Client = &mockCloudFlareListZonesFail{} - _, err = provider.Records() + _, err = provider.Records(ctx) if err == nil { t.Errorf("expected to fail") } diff --git a/provider/coredns.go b/provider/coredns.go index 3c6418943..d630c9234 100644 --- a/provider/coredns.go +++ b/provider/coredns.go @@ -260,7 +260,7 @@ func NewCoreDNSProvider(domainFilter DomainFilter, prefix string, dryRun bool) ( // Records returns all DNS records found in CoreDNS etcd backend. Depending on the record fields // it may be mapped to one or two records of type A, CNAME, TXT, A+TXT, CNAME+TXT -func (p coreDNSProvider) Records() ([]*endpoint.Endpoint, error) { +func (p coreDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { var result []*endpoint.Endpoint services, err := p.client.GetServices(p.coreDNSPrefix) if err != nil { diff --git a/provider/coredns_test.go b/provider/coredns_test.go index 3fd64a47c..588acc549 100644 --- a/provider/coredns_test.go +++ b/provider/coredns_test.go @@ -66,7 +66,7 @@ func TestAServiceTranslation(t *testing.T) { client: client, coreDNSPrefix: defaultCoreDNSPrefix, } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -98,7 +98,7 @@ func TestCNAMEServiceTranslation(t *testing.T) { client: client, coreDNSPrefix: defaultCoreDNSPrefix, } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -130,7 +130,7 @@ func TestTXTServiceTranslation(t *testing.T) { client: client, coreDNSPrefix: defaultCoreDNSPrefix, } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -164,7 +164,7 @@ func TestAWithTXTServiceTranslation(t *testing.T) { client: client, coreDNSPrefix: defaultCoreDNSPrefix, } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -206,7 +206,7 @@ func TestCNAMEWithTXTServiceTranslation(t *testing.T) { client: client, coreDNSPrefix: defaultCoreDNSPrefix, } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -264,7 +264,7 @@ func TestCoreDNSApplyChanges(t *testing.T) { endpoint.NewEndpoint("domain1.local", "A", "6.6.6.6"), }, } - records, _ := coredns.Records() + records, _ := coredns.Records(context.Background()) for _, ep := range records { if ep.DNSName == "domain1.local" { changes2.UpdateOld = append(changes2.UpdateOld, ep) @@ -296,7 +296,8 @@ func TestCoreDNSApplyChanges(t *testing.T) { } func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) { - records, _ := provider.Records() + ctx := context.Background() + records, _ := provider.Records(ctx) for _, col := range [][]*endpoint.Endpoint{changes.Create, changes.UpdateNew, changes.Delete} { for _, record := range col { for _, existingRecord := range records { @@ -306,7 +307,7 @@ func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) { } } } - provider.ApplyChanges(context.Background(), changes) + provider.ApplyChanges(ctx, changes) } func validateServices(services, expectedServices map[string]*Service, t *testing.T, step int) { diff --git a/provider/designate.go b/provider/designate.go index 1144beb5f..58a428e6d 100644 --- a/provider/designate.go +++ b/provider/designate.go @@ -308,7 +308,7 @@ func (p designateProvider) getHostZoneID(hostname string, managedZones map[strin } // Records returns the list of records. -func (p designateProvider) Records() ([]*endpoint.Endpoint, error) { +func (p designateProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { var result []*endpoint.Endpoint managedZones, err := p.getZones() if err != nil { diff --git a/provider/designate_test.go b/provider/designate_test.go index 9d1675d84..357dd51e4 100644 --- a/provider/designate_test.go +++ b/provider/designate_test.go @@ -303,7 +303,7 @@ func TestDesignateRecords(t *testing.T) { }, } - endpoints, err := client.ToProvider().Records() + endpoints, err := client.ToProvider().Records(context.Background()) if err != nil { t.Fatal(err) } diff --git a/provider/digital_ocean.go b/provider/digital_ocean.go index b4233aaa7..a97bfbb4e 100644 --- a/provider/digital_ocean.go +++ b/provider/digital_ocean.go @@ -94,7 +94,7 @@ func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) { } // Records returns the list of records in a given zone. -func (p *DigitalOceanProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.Zones() if err != nil { return nil, err diff --git a/provider/digital_ocean_test.go b/provider/digital_ocean_test.go index 34e77ca69..8ca91ca15 100644 --- a/provider/digital_ocean_test.go +++ b/provider/digital_ocean_test.go @@ -509,15 +509,16 @@ func TestDigitalOceanAllRecords(t *testing.T) { provider := &DigitalOceanProvider{ Client: &mockDigitalOceanClient{}, } + ctx := context.Background() - records, err := provider.Records() + records, err := provider.Records(ctx) if err != nil { t.Errorf("should not fail, %s", err) } require.Equal(t, 5, len(records)) provider.Client = &mockDigitalOceanRecordsFail{} - _, err = provider.Records() + _, err = provider.Records(ctx) if err == nil { t.Errorf("expected to fail, %s", err) } diff --git a/provider/dnsimple.go b/provider/dnsimple.go index 0b01c6f20..54f0da4b1 100644 --- a/provider/dnsimple.go +++ b/provider/dnsimple.go @@ -157,7 +157,7 @@ func (p *dnsimpleProvider) Zones() (map[string]dnsimple.Zone, error) { } // Records returns a list of endpoints in a given zone -func (p *dnsimpleProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { +func (p *dnsimpleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { zones, err := p.Zones() if err != nil { return nil, err diff --git a/provider/dnsimple_test.go b/provider/dnsimple_test.go index 42149609f..d391e27b6 100644 --- a/provider/dnsimple_test.go +++ b/provider/dnsimple_test.go @@ -152,13 +152,14 @@ func testDnsimpleProviderZones(t *testing.T) { } func testDnsimpleProviderRecords(t *testing.T) { + ctx := context.Background() mockProvider.accountID = "1" - result, err := mockProvider.Records() + result, err := mockProvider.Records(ctx) assert.Nil(t, err) assert.Equal(t, len(dnsimpleListRecordsResponse.Data), len(result)) mockProvider.accountID = "2" - _, err = mockProvider.Records() + _, err = mockProvider.Records(ctx) assert.NotNil(t, err) } func testDnsimpleProviderApplyChanges(t *testing.T) { diff --git a/provider/dyn.go b/provider/dyn.go index 0a825fb0a..b08a84529 100644 --- a/provider/dyn.go +++ b/provider/dyn.go @@ -588,7 +588,7 @@ func (d *dynProviderState) commit(client *dynect.Client) error { // Records makes on average C + 2*Z requests (Z = number of zones): 1 login + 1 fetchAllRecords // A cache is used to avoid querying for every single record found. C is proportional to the number // of expired/changed records -func (d *dynProviderState) Records() ([]*endpoint.Endpoint, error) { +func (d *dynProviderState) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { client, err := d.login() if err != nil { return nil, err diff --git a/provider/exoscale.go b/provider/exoscale.go index d709137cb..72eeb61a5 100644 --- a/provider/exoscale.go +++ b/provider/exoscale.go @@ -172,7 +172,7 @@ func (ep *ExoscaleProvider) ApplyChanges(ctx context.Context, changes *plan.Chan } // Records returns the list of endpoints -func (ep *ExoscaleProvider) Records() ([]*endpoint.Endpoint, error) { +func (ep *ExoscaleProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { endpoints := make([]*endpoint.Endpoint, 0) domains, err := ep.client.GetDomains(context.TODO()) diff --git a/provider/exoscale_test.go b/provider/exoscale_test.go index fd6e313c3..7a33bd9a1 100644 --- a/provider/exoscale_test.go +++ b/provider/exoscale_test.go @@ -107,7 +107,7 @@ func contains(arr []*endpoint.Endpoint, name string) bool { func TestExoscaleGetRecords(t *testing.T) { provider := NewExoscaleProviderWithClient("", "", "", NewExoscaleClientStub(), false) - if recs, err := provider.Records(); err == nil { + if recs, err := provider.Records(context.Background()); err == nil { assert.Equal(t, 3, len(recs)) assert.True(t, contains(recs, "v1.foo.com")) assert.True(t, contains(recs, "v2.bar.com")) diff --git a/provider/google.go b/provider/google.go index aee2152bc..97d46bd13 100644 --- a/provider/google.go +++ b/provider/google.go @@ -199,7 +199,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) { } // Records returns the list of records in all relevant zones. -func (p *GoogleProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { +func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { zones, err := p.Zones() if err != nil { return nil, err diff --git a/provider/google_test.go b/provider/google_test.go index f45acd5b4..618412a5b 100644 --- a/provider/google_test.go +++ b/provider/google_test.go @@ -235,7 +235,7 @@ func TestGoogleRecords(t *testing.T) { provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, originalEndpoints) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, originalEndpoints) @@ -278,7 +278,7 @@ func TestGoogleRecordsFilter(t *testing.T) { require.NoError(t, provider.CreateRecords(ignoredEndpoints)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) // assert that due to filtering no changes were made. @@ -296,7 +296,7 @@ func TestGoogleCreateRecords(t *testing.T) { require.NoError(t, provider.CreateRecords(records)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -321,7 +321,7 @@ func TestGoogleUpdateRecords(t *testing.T) { require.NoError(t, provider.UpdateRecords(updatedRecords, currentRecords)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -342,7 +342,7 @@ func TestGoogleDeleteRecords(t *testing.T) { require.NoError(t, provider.DeleteRecords(originalEndpoints)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{}) @@ -410,7 +410,7 @@ func TestGoogleApplyChanges(t *testing.T) { require.NoError(t, provider.ApplyChanges(context.Background(), changes)) - records, err := provider.Records() + records, err := provider.Records(context.Background()) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{ @@ -465,9 +465,10 @@ func TestGoogleApplyChangesDryRun(t *testing.T) { Delete: deleteRecords, } - require.NoError(t, provider.ApplyChanges(context.Background(), changes)) + ctx := context.Background() + require.NoError(t, provider.ApplyChanges(ctx, changes)) - records, err := provider.Records() + records, err := provider.Records(ctx) require.NoError(t, err) validateEndpoints(t, records, originalEndpoints) @@ -760,14 +761,15 @@ func setupGoogleRecords(t *testing.T, provider *GoogleProvider, endpoints []*end clearGoogleRecords(t, provider, "zone-2-ext-dns-test-2-gcp-zalan-do") clearGoogleRecords(t, provider, "zone-3-ext-dns-test-2-gcp-zalan-do") - records, err := provider.Records() + ctx := context.Background() + records, err := provider.Records(ctx) require.NoError(t, err) validateEndpoints(t, records, []*endpoint.Endpoint{}) require.NoError(t, provider.CreateRecords(endpoints)) - records, err = provider.Records() + records, err = provider.Records(ctx) require.NoError(t, err) validateEndpoints(t, records, endpoints) diff --git a/provider/infoblox.go b/provider/infoblox.go index 5cae02bbb..208e93a83 100644 --- a/provider/infoblox.go +++ b/provider/infoblox.go @@ -135,7 +135,7 @@ func NewInfobloxProvider(infobloxConfig InfobloxConfig) (*InfobloxProvider, erro } // Records gets the current records. -func (p *InfobloxProvider) Records() (endpoints []*endpoint.Endpoint, err error) { +func (p *InfobloxProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) { zones, err := p.zones() if err != nil { return nil, fmt.Errorf("could not fetch zones: %s", err) diff --git a/provider/infoblox_test.go b/provider/infoblox_test.go index 2d7beb46d..b7cc8903b 100644 --- a/provider/infoblox_test.go +++ b/provider/infoblox_test.go @@ -355,7 +355,7 @@ func TestInfobloxRecords(t *testing.T) { } provider := newInfobloxProvider(NewDomainFilter([]string{"example.com"}), NewZoneIDFilter([]string{""}), true, &client) - actual, err := provider.Records() + actual, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) diff --git a/provider/inmemory.go b/provider/inmemory.go index e8a42d6e4..470393391 100644 --- a/provider/inmemory.go +++ b/provider/inmemory.go @@ -119,7 +119,7 @@ func (im *InMemoryProvider) Zones() map[string]string { } // Records returns the list of endpoints -func (im *InMemoryProvider) Records() ([]*endpoint.Endpoint, error) { +func (im *InMemoryProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { defer im.OnRecords() endpoints := make([]*endpoint.Endpoint, 0) diff --git a/provider/inmemory_test.go b/provider/inmemory_test.go index b31fd2c08..bb13cb354 100644 --- a/provider/inmemory_test.go +++ b/provider/inmemory_test.go @@ -233,7 +233,7 @@ func testInMemoryRecords(t *testing.T) { im.client = c f := filter{domain: ti.zone} im.filter = &f - records, err := im.Records() + records, err := im.Records(context.Background()) if ti.expectError { assert.Nil(t, records) assert.EqualError(t, err, ErrZoneNotFound.Error()) diff --git a/provider/linode.go b/provider/linode.go index e80ead9f8..d89b841a9 100644 --- a/provider/linode.go +++ b/provider/linode.go @@ -111,7 +111,7 @@ func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) { } // Records returns the list of records in a given zone. -func (p *LinodeProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.Zones() if err != nil { return nil, err diff --git a/provider/linode_test.go b/provider/linode_test.go index c033f70c6..5419cb147 100644 --- a/provider/linode_test.go +++ b/provider/linode_test.go @@ -257,7 +257,7 @@ func TestLinodeRecords(t *testing.T) { mock.Anything, ).Return(createBazRecords(), nil).Once() - actual, err := provider.Records() + actual, err := provider.Records(context.Background()) require.NoError(t, err) expected := []*endpoint.Endpoint{ diff --git a/provider/ns1.go b/provider/ns1.go index b203be797..2bda2e907 100644 --- a/provider/ns1.go +++ b/provider/ns1.go @@ -141,7 +141,7 @@ func newNS1ProviderWithHTTPClient(config NS1Config, client *http.Client) (*NS1Pr } // Records returns the endpoints this provider knows about -func (p *NS1Provider) Records() ([]*endpoint.Endpoint, error) { +func (p *NS1Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.zonesFiltered() if err != nil { return nil, err diff --git a/provider/ns1_test.go b/provider/ns1_test.go index f5bc2b944..2bb5fb239 100644 --- a/provider/ns1_test.go +++ b/provider/ns1_test.go @@ -132,16 +132,18 @@ func TestNS1Records(t *testing.T) { domainFilter: NewDomainFilter([]string{"foo.com."}), zoneIDFilter: NewZoneIDFilter([]string{""}), } - records, err := provider.Records() + ctx := context.Background() + + records, err := provider.Records(ctx) require.NoError(t, err) assert.Equal(t, 1, len(records)) provider.client = &MockNS1GetZoneFail{} - _, err = provider.Records() + _, err = provider.Records(ctx) require.Error(t, err) provider.client = &MockNS1ListZonesFail{} - _, err = provider.Records() + _, err = provider.Records(ctx) require.Error(t, err) } diff --git a/provider/oci.go b/provider/oci.go index 933bf67de..29bce2cd2 100644 --- a/provider/oci.go +++ b/provider/oci.go @@ -157,8 +157,7 @@ func (p *OCIProvider) newFilteredRecordOperations(endpoints []*endpoint.Endpoint } // Records returns the list of records in a given hosted zone. -func (p *OCIProvider) Records() ([]*endpoint.Endpoint, error) { - ctx := context.Background() +func (p *OCIProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.zones(ctx) if err != nil { return nil, errors.Wrap(err, "getting zones") diff --git a/provider/oci_test.go b/provider/oci_test.go index b073eb79b..9840025b3 100644 --- a/provider/oci_test.go +++ b/provider/oci_test.go @@ -287,7 +287,7 @@ func TestOCIRecords(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { provider := newOCIProvider(&mockOCIDNSClient{}, tc.domainFilter, tc.zoneIDFilter, false) - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) require.NoError(t, err) require.ElementsMatch(t, tc.expected, endpoints) }) @@ -829,9 +829,11 @@ func TestOCIApplyChanges(t *testing.T) { NewZoneIDFilter([]string{""}), tc.dryRun, ) - err := provider.ApplyChanges(context.Background(), tc.changes) + + ctx := context.Background() + err := provider.ApplyChanges(ctx, tc.changes) require.Equal(t, tc.err, err) - endpoints, err := provider.Records() + endpoints, err := provider.Records(ctx) require.NoError(t, err) require.ElementsMatch(t, tc.expectedEndpoints, endpoints) }) diff --git a/provider/pdns.go b/provider/pdns.go index 52d3bdee7..afacdf33c 100644 --- a/provider/pdns.go +++ b/provider/pdns.go @@ -412,7 +412,7 @@ func (p *PDNSProvider) mutateRecords(endpoints []*endpoint.Endpoint, changetype } // Records returns all DNS records controlled by the configured PDNS server (for all zones) -func (p *PDNSProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { +func (p *PDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { zones, _, err := p.client.ListZones() if err != nil { diff --git a/provider/pdns_test.go b/provider/pdns_test.go index 6dfda9f3e..de254b16f 100644 --- a/provider/pdns_test.go +++ b/provider/pdns_test.go @@ -793,9 +793,11 @@ func (suite *NewPDNSProviderTestSuite) TestPDNSRecords() { client: &PDNSAPIClientStub{}, } + ctx := context.Background() + /* We test that endpoints are returned correctly for a Zone when Records() is called */ - eps, err := p.Records() + eps, err := p.Records(ctx) assert.Nil(suite.T(), err) assert.Equal(suite.T(), endpointsMixedRecords, eps) @@ -804,13 +806,13 @@ func (suite *NewPDNSProviderTestSuite) TestPDNSRecords() { p = &PDNSProvider{ client: &PDNSAPIClientStubListZoneFailure{}, } - _, err = p.Records() + _, err = p.Records(ctx) assert.NotNil(suite.T(), err) p = &PDNSProvider{ client: &PDNSAPIClientStubListZonesFailure{}, } - _, err = p.Records() + _, err = p.Records(ctx) assert.NotNil(suite.T(), err) } diff --git a/provider/provider.go b/provider/provider.go index 96876d781..fcd12018c 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -27,7 +27,7 @@ import ( // Provider defines the interface DNS providers should implement. type Provider interface { - Records() ([]*endpoint.Endpoint, error) + Records(ctx context.Context) ([]*endpoint.Endpoint, error) ApplyChanges(ctx context.Context, changes *plan.Changes) error } diff --git a/provider/rcode0.go b/provider/rcode0.go index 707278b41..3311437aa 100644 --- a/provider/rcode0.go +++ b/provider/rcode0.go @@ -96,7 +96,7 @@ func (p *RcodeZeroProvider) Zones() ([]*rc0.Zone, error) { // Records returns resource records // // Decrypts TXT records if TXT-Encrypt flag is set and key is provided -func (p *RcodeZeroProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *RcodeZeroProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.Zones() if err != nil { diff --git a/provider/rcode0_test.go b/provider/rcode0_test.go index 9c10b594f..a977b8fe8 100644 --- a/provider/rcode0_test.go +++ b/provider/rcode0_test.go @@ -73,7 +73,9 @@ func TestRcodeZeroProvider_Records(t *testing.T) { }), } - endpoints, err := provider.Records() // should return 6 rrs + ctx := context.Background() + + endpoints, err := provider.Records(ctx) // should return 6 rrs if err != nil { t.Errorf("should not fail, %s", err) @@ -82,7 +84,7 @@ func TestRcodeZeroProvider_Records(t *testing.T) { mockRRSetService.TestErrorReturned = true - _, err = provider.Records() + _, err = provider.Records(ctx) if err == nil { t.Errorf("expected to fail, %s", err) } diff --git a/provider/rdns.go b/provider/rdns.go index 5826da8d6..414bd9c0c 100644 --- a/provider/rdns.go +++ b/provider/rdns.go @@ -113,7 +113,7 @@ func NewRDNSProvider(config RDNSConfig) (*RDNSProvider, error) { // Records returns all DNS records found in Rancher DNS(RDNS) etcdv3 backend. Depending on the record fields // it may be mapped to one or two records of type A, TXT, A+TXT. -func (p RDNSProvider) Records() ([]*endpoint.Endpoint, error) { +func (p RDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { var result []*endpoint.Endpoint rs, err := p.client.List(p.rootDomain) diff --git a/provider/rdns_test.go b/provider/rdns_test.go index 6abcd5add..b600d7665 100644 --- a/provider/rdns_test.go +++ b/provider/rdns_test.go @@ -113,7 +113,7 @@ func TestARecordTranslation(t *testing.T) { rootDomain: "lb.rancher.cloud", } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -148,7 +148,7 @@ func TestTXTRecordTranslation(t *testing.T) { rootDomain: "lb.rancher.cloud", } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -184,7 +184,7 @@ func TestAWithTXTRecordTranslation(t *testing.T) { rootDomain: "lb.rancher.cloud", } - endpoints, err := provider.Records() + endpoints, err := provider.Records(context.Background()) if err != nil { t.Fatal(err) } @@ -248,7 +248,7 @@ func TestRDNSApplyChanges(t *testing.T) { }, } - records, _ := provider.Records() + records, _ := provider.Records(context.Background()) for _, ep := range records { if ep.DNSName == "p1xaf1.lb.rancher.cloud" { changes2.UpdateOld = append(changes2.UpdateOld, ep) diff --git a/provider/rfc2136.go b/provider/rfc2136.go index ca17ba62d..b07ff970c 100644 --- a/provider/rfc2136.go +++ b/provider/rfc2136.go @@ -95,7 +95,7 @@ func NewRfc2136Provider(host string, port int, zoneName string, insecure bool, k } // Records returns the list of records. -func (r rfc2136Provider) Records() ([]*endpoint.Endpoint, error) { +func (r rfc2136Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { rrs, err := r.List() if err != nil { return nil, err diff --git a/provider/rfc2136_test.go b/provider/rfc2136_test.go index dc4abe767..164f3587a 100644 --- a/provider/rfc2136_test.go +++ b/provider/rfc2136_test.go @@ -113,7 +113,7 @@ func TestRfc2136GetRecordsMultipleTargets(t *testing.T) { provider, err := createRfc2136StubProvider(stub) assert.NoError(t, err) - recs, err := provider.Records() + recs, err := provider.Records(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(recs), "expected single record") @@ -142,7 +142,7 @@ func TestRfc2136GetRecords(t *testing.T) { provider, err := createRfc2136StubProvider(stub) assert.NoError(t, err) - recs, err := provider.Records() + recs, err := provider.Records(context.Background()) assert.NoError(t, err) assert.Equal(t, 6, len(recs)) diff --git a/provider/transip.go b/provider/transip.go index c3f8f7b16..68cdc099f 100644 --- a/provider/transip.go +++ b/provider/transip.go @@ -221,7 +221,7 @@ func (p *TransIPProvider) Zones() ([]transip.Domain, error) { } // Records returns the list of records in a given zone. -func (p *TransIPProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *TransIPProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := p.Zones() if err != nil { return nil, err diff --git a/provider/vinyldns.go b/provider/vinyldns.go index 6ffd84a48..812c1bdb7 100644 --- a/provider/vinyldns.go +++ b/provider/vinyldns.go @@ -75,7 +75,7 @@ func NewVinylDNSProvider(domainFilter DomainFilter, zoneFilter ZoneIDFilter, dry }, nil } -func (p *vinyldnsProvider) Records() (endpoints []*endpoint.Endpoint, _ error) { +func (p *vinyldnsProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { zones, err := p.client.Zones() if err != nil { return nil, err diff --git a/provider/vinyldns_test.go b/provider/vinyldns_test.go index ada49071d..9dfb43f61 100644 --- a/provider/vinyldns_test.go +++ b/provider/vinyldns_test.go @@ -90,18 +90,20 @@ func TestVinylDNSServices(t *testing.T) { } func testVinylDNSProviderRecords(t *testing.T) { + ctx := context.Background() + mockVinylDNSProvider.domainFilter = NewDomainFilter([]string{"example.com"}) - result, err := mockVinylDNSProvider.Records() + result, err := mockVinylDNSProvider.Records(ctx) assert.Nil(t, err) assert.Equal(t, len(vinylDNSRecords), len(result)) mockVinylDNSProvider.zoneFilter = NewZoneIDFilter([]string{"0"}) - result, err = mockVinylDNSProvider.Records() + result, err = mockVinylDNSProvider.Records(ctx) assert.Nil(t, err) assert.Equal(t, len(vinylDNSRecords), len(result)) mockVinylDNSProvider.zoneFilter = NewZoneIDFilter([]string{"1"}) - result, err = mockVinylDNSProvider.Records() + result, err = mockVinylDNSProvider.Records(ctx) assert.Nil(t, err) assert.Equal(t, 0, len(result)) } diff --git a/registry/aws_sd_registry.go b/registry/aws_sd_registry.go index 1422db0b9..f9a0f0d65 100644 --- a/registry/aws_sd_registry.go +++ b/registry/aws_sd_registry.go @@ -44,8 +44,8 @@ func NewAWSSDRegistry(provider provider.Provider, ownerID string) (*AWSSDRegistr // Records calls AWS SD API and expects AWS SD provider to provider Owner/Resource information as a serialized // value in the AWSSDDescriptionLabel value in the Labels map -func (sdr *AWSSDRegistry) Records() ([]*endpoint.Endpoint, error) { - records, err := sdr.provider.Records() +func (sdr *AWSSDRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { + records, err := sdr.provider.Records(ctx) if err != nil { return nil, err } diff --git a/registry/aws_sd_registry_test.go b/registry/aws_sd_registry_test.go index 0658877eb..7aca319d7 100644 --- a/registry/aws_sd_registry_test.go +++ b/registry/aws_sd_registry_test.go @@ -33,7 +33,7 @@ type inMemoryProvider struct { onApplyChanges func(changes *plan.Changes) } -func (p *inMemoryProvider) Records() ([]*endpoint.Endpoint, error) { +func (p *inMemoryProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { return p.endpoints, nil } @@ -101,7 +101,7 @@ func TestAWSSDRegistryTest_Records(t *testing.T) { } r, _ := NewAWSSDRegistry(p, "owner") - records, _ := r.Records() + records, _ := r.Records(context.Background()) assert.True(t, testutils.SameEndpoints(records, expectedRecords)) } diff --git a/registry/noop.go b/registry/noop.go index 1a49b8ed0..4b91fbaf5 100644 --- a/registry/noop.go +++ b/registry/noop.go @@ -37,8 +37,8 @@ func NewNoopRegistry(provider provider.Provider) (*NoopRegistry, error) { } // Records returns the current records from the dns provider -func (im *NoopRegistry) Records() ([]*endpoint.Endpoint, error) { - return im.provider.Records() +func (im *NoopRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { + return im.provider.Records(ctx) } // ApplyChanges propagates changes to the dns provider diff --git a/registry/noop_test.go b/registry/noop_test.go index b56b8eeca..7e7598807 100644 --- a/registry/noop_test.go +++ b/registry/noop_test.go @@ -45,6 +45,7 @@ func testNoopInit(t *testing.T) { } func testNoopRecords(t *testing.T) { + ctx := context.Background() p := provider.NewInMemoryProvider() p.CreateZone("org") providerRecords := []*endpoint.Endpoint{ @@ -54,13 +55,13 @@ func testNoopRecords(t *testing.T) { RecordType: endpoint.RecordTypeCNAME, }, } - p.ApplyChanges(context.Background(), &plan.Changes{ + p.ApplyChanges(ctx, &plan.Changes{ Create: providerRecords, }) r, _ := NewNoopRegistry(p) - eps, err := r.Records() + eps, err := r.Records(ctx) require.NoError(t, err) assert.True(t, testutils.SameEndpoints(eps, providerRecords)) } @@ -131,6 +132,6 @@ func testNoopApplyChanges(t *testing.T) { }, }, })) - res, _ := p.Records() + res, _ := p.Records(ctx) assert.True(t, testutils.SameEndpoints(res, expectedUpdate)) } diff --git a/registry/registry.go b/registry/registry.go index be87d2046..746e7fdd3 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -30,7 +30,7 @@ import ( // each entry includes owner information // ApplyChanges(changes *plan.Changes) propagates the changes to the DNS Provider API and correspondingly updates ownership depending on type of registry being used type Registry interface { - Records() ([]*endpoint.Endpoint, error) + Records(ctx context.Context) ([]*endpoint.Endpoint, error) ApplyChanges(ctx context.Context, changes *plan.Changes) error } diff --git a/registry/txt.go b/registry/txt.go index 544355b4d..2d99b0660 100644 --- a/registry/txt.go +++ b/registry/txt.go @@ -61,7 +61,7 @@ func NewTXTRegistry(provider provider.Provider, txtPrefix, ownerID string, cache // Records returns the current records from the registry excluding TXT Records // If TXT records was created previously to indicate ownership its corresponding value // will be added to the endpoints Labels map -func (im *TXTRegistry) Records() ([]*endpoint.Endpoint, error) { +func (im *TXTRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { // If we have the zones cached AND we have refreshed the cache since the // last given interval, then just use the cached results. if im.recordsCache != nil && time.Since(im.recordsCacheRefreshTime) < im.cacheInterval { @@ -69,7 +69,7 @@ func (im *TXTRegistry) Records() ([]*endpoint.Endpoint, error) { return im.recordsCache, nil } - records, err := im.provider.Records() + records, err := im.provider.Records(ctx) if err != nil { return nil, err } diff --git a/registry/txt_test.go b/registry/txt_test.go index c5920ad7e..64b31f0f0 100644 --- a/registry/txt_test.go +++ b/registry/txt_test.go @@ -67,9 +67,10 @@ func testTXTRegistryRecords(t *testing.T) { } func testTXTRegistryRecordsPrefixed(t *testing.T) { + ctx := context.Background() p := provider.NewInMemoryProvider() p.CreateZone(testZone) - p.ApplyChanges(context.Background(), &plan.Changes{ + p.ApplyChanges(ctx, &plan.Changes{ Create: []*endpoint.Endpoint{ newEndpointWithOwnerAndLabels("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, "", endpoint.Labels{"foo": "somefoo"}), newEndpointWithOwnerAndLabels("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, "", endpoint.Labels{"bar": "somebar"}), @@ -159,21 +160,22 @@ func testTXTRegistryRecordsPrefixed(t *testing.T) { } r, _ := NewTXTRegistry(p, "txt.", "owner", time.Hour) - records, _ := r.Records() + records, _ := r.Records(ctx) assert.True(t, testutils.SameEndpoints(records, expectedRecords)) // Ensure prefix is case-insensitive r, _ = NewTXTRegistry(p, "TxT.", "owner", time.Hour) - records, _ = r.Records() + records, _ = r.Records(ctx) assert.True(t, testutils.SameEndpointLabels(records, expectedRecords)) } func testTXTRegistryRecordsNoPrefix(t *testing.T) { p := provider.NewInMemoryProvider() + ctx := context.Background() p.CreateZone(testZone) - p.ApplyChanges(context.Background(), &plan.Changes{ + p.ApplyChanges(ctx, &plan.Changes{ Create: []*endpoint.Endpoint{ newEndpointWithOwner("foo.test-zone.example.org", "foo.loadbalancer.com", endpoint.RecordTypeCNAME, ""), newEndpointWithOwner("bar.test-zone.example.org", "my-domain.com", endpoint.RecordTypeCNAME, ""), @@ -239,7 +241,7 @@ func testTXTRegistryRecordsNoPrefix(t *testing.T) { } r, _ := NewTXTRegistry(p, "", "owner", time.Hour) - records, _ := r.Records() + records, _ := r.Records(ctx) assert.True(t, testutils.SameEndpoints(records, expectedRecords)) }