diff --git a/controller/controller.go b/controller/controller.go index fa5dba6ec..1be6c0fb9 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -103,8 +103,7 @@ type Controller struct { } // RunOnce runs a single iteration of a reconciliation loop. -func (c *Controller) RunOnce() error { - ctx := context.Background() +func (c *Controller) RunOnce(ctx context.Context) error { records, err := c.Registry.Records(ctx) if err != nil { registryErrorsTotal.Inc() @@ -141,11 +140,11 @@ func (c *Controller) RunOnce() error { } // Run runs RunOnce in a loop with a delay until stopChan receives a value. -func (c *Controller) Run(stopChan <-chan struct{}) { +func (c *Controller) Run(ctx context.Context, stopChan <-chan struct{}) { ticker := time.NewTicker(c.Interval) defer ticker.Stop() for { - err := c.RunOnce() + err := c.RunOnce(ctx) if err != nil { log.Error(err) } diff --git a/controller/controller_test.go b/controller/controller_test.go index c59795ef7..8fb644fdc 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -146,7 +146,7 @@ func TestRunOnce(t *testing.T) { Policy: &plan.SyncPolicy{}, } - assert.NoError(t, ctrl.RunOnce()) + assert.NoError(t, ctrl.RunOnce(context.Background())) // Validate that the mock source was called. source.AssertExpectations(t) diff --git a/main.go b/main.go index b3ed6a891..5a4a6ae15 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "context" "net/http" "os" "os/signal" @@ -60,6 +61,8 @@ func main() { } log.SetLevel(ll) + ctx := context.Background() + stopChan := make(chan struct{}, 1) go serveMetrics(cfg.MetricsAddress) @@ -144,9 +147,9 @@ func main() { case "rcodezero": p, err = provider.NewRcodeZeroProvider(domainFilter, cfg.DryRun, cfg.RcodezeroTXTEncrypt) case "google": - p, err = provider.NewGoogleProvider(cfg.GoogleProject, domainFilter, zoneIDFilter, cfg.GoogleBatchChangeSize, cfg.GoogleBatchChangeInterval, cfg.DryRun) + p, err = provider.NewGoogleProvider(ctx, cfg.GoogleProject, domainFilter, zoneIDFilter, cfg.GoogleBatchChangeSize, cfg.GoogleBatchChangeInterval, cfg.DryRun) case "digitalocean": - p, err = provider.NewDigitalOceanProvider(domainFilter, cfg.DryRun) + p, err = provider.NewDigitalOceanProvider(ctx, domainFilter, cfg.DryRun) case "linode": p, err = provider.NewLinodeProvider(domainFilter, cfg.DryRun, externaldns.Version) case "dnsimple": @@ -197,6 +200,7 @@ func main() { p, err = provider.NewDesignateProvider(domainFilter, cfg.DryRun) case "pdns": p, err = provider.NewPDNSProvider( + ctx, provider.PDNSConfig{ DomainFilter: domainFilter, DryRun: cfg.DryRun, @@ -266,14 +270,14 @@ func main() { } if cfg.Once { - err := ctrl.RunOnce() + err := ctrl.RunOnce(ctx) if err != nil { log.Fatal(err) } os.Exit(0) } - ctrl.Run(stopChan) + ctrl.Run(ctx, stopChan) } func handleSigterm(stopChan chan struct{}) { diff --git a/provider/cloudflare.go b/provider/cloudflare.go index b1ff72f5d..06796e04a 100644 --- a/provider/cloudflare.go +++ b/provider/cloudflare.go @@ -146,9 +146,8 @@ func NewCloudFlareProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, } // Zones returns the list of hosted zones. -func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) { +func (p *CloudFlareProvider) Zones(ctx context.Context) ([]cloudflare.Zone, error) { result := []cloudflare.Zone{} - ctx := context.TODO() p.PaginationOptions.Page = 1 for { @@ -177,7 +176,7 @@ func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) { // Records returns the list of records. func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return nil, err } @@ -208,17 +207,17 @@ func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Cha combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareUpdate, changes.UpdateNew, proxiedByDefault)...) combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareDelete, changes.Delete, proxiedByDefault)...) - return p.submitChanges(combinedChanges) + return p.submitChanges(ctx, combinedChanges) } // submitChanges takes a zone and a collection of Changes and sends them as a single transaction. -func (p *CloudFlareProvider) submitChanges(changes []*cloudFlareChange) error { +func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloudFlareChange) error { // return early if there is nothing to change if len(changes) == 0 { return nil } - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return err } diff --git a/provider/cloudflare_test.go b/provider/cloudflare_test.go index 6bbbb6f3d..01a9877a2 100644 --- a/provider/cloudflare_test.go +++ b/provider/cloudflare_test.go @@ -477,7 +477,7 @@ func TestCloudFlareZones(t *testing.T) { zoneIDFilter: NewZoneIDFilter([]string{""}), } - zones, err := provider.Zones() + zones, err := provider.Zones(context.Background()) if err != nil { t.Fatal(err) } diff --git a/provider/digital_ocean.go b/provider/digital_ocean.go index a97bfbb4e..668c1ee53 100644 --- a/provider/digital_ocean.go +++ b/provider/digital_ocean.go @@ -57,12 +57,12 @@ type DigitalOceanChange struct { } // NewDigitalOceanProvider initializes a new DigitalOcean DNS based Provider. -func NewDigitalOceanProvider(domainFilter DomainFilter, dryRun bool) (*DigitalOceanProvider, error) { +func NewDigitalOceanProvider(ctx context.Context, domainFilter DomainFilter, dryRun bool) (*DigitalOceanProvider, error) { token, ok := os.LookupEnv("DO_TOKEN") if !ok { return nil, fmt.Errorf("No token found") } - oauthClient := oauth2.NewClient(context.TODO(), oauth2.StaticTokenSource(&oauth2.Token{ + oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(&oauth2.Token{ AccessToken: token, })) client := godo.NewClient(oauthClient) @@ -76,10 +76,10 @@ func NewDigitalOceanProvider(domainFilter DomainFilter, dryRun bool) (*DigitalOc } // Zones returns the list of hosted zones. -func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) { +func (p *DigitalOceanProvider) Zones(ctx context.Context) ([]godo.Domain, error) { result := []godo.Domain{} - zones, err := p.fetchZones() + zones, err := p.fetchZones(ctx) if err != nil { return nil, err } @@ -95,13 +95,13 @@ func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) { // Records returns the list of records in a given zone. func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return nil, err } endpoints := []*endpoint.Endpoint{} for _, zone := range zones { - records, err := p.fetchRecords(zone.Name) + records, err := p.fetchRecords(ctx, zone.Name) if err != nil { return nil, err } @@ -124,11 +124,11 @@ func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoin return endpoints, nil } -func (p *DigitalOceanProvider) fetchRecords(zoneName string) ([]godo.DomainRecord, error) { +func (p *DigitalOceanProvider) fetchRecords(ctx context.Context, zoneName string) ([]godo.DomainRecord, error) { allRecords := []godo.DomainRecord{} listOptions := &godo.ListOptions{} for { - records, resp, err := p.Client.Records(context.TODO(), zoneName, listOptions) + records, resp, err := p.Client.Records(ctx, zoneName, listOptions) if err != nil { return nil, err } @@ -149,11 +149,11 @@ func (p *DigitalOceanProvider) fetchRecords(zoneName string) ([]godo.DomainRecor return allRecords, nil } -func (p *DigitalOceanProvider) fetchZones() ([]godo.Domain, error) { +func (p *DigitalOceanProvider) fetchZones(ctx context.Context) ([]godo.Domain, error) { allZones := []godo.Domain{} listOptions := &godo.ListOptions{} for { - zones, resp, err := p.Client.List(context.TODO(), listOptions) + zones, resp, err := p.Client.List(ctx, listOptions) if err != nil { return nil, err } @@ -175,13 +175,13 @@ func (p *DigitalOceanProvider) fetchZones() ([]godo.Domain, error) { } // submitChanges takes a zone and a collection of Changes and sends them as a single transaction. -func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) error { +func (p *DigitalOceanProvider) submitChanges(ctx context.Context, changes []*DigitalOceanChange) error { // return early if there is nothing to change if len(changes) == 0 { return nil } - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return err } @@ -189,7 +189,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro // separate into per-zone change sets to be passed to the API. changesByZone := digitalOceanChangesByZone(zones, changes) for zoneName, changes := range changesByZone { - records, err := p.fetchRecords(zoneName) + records, err := p.fetchRecords(ctx, zoneName) if err != nil { log.Errorf("Failed to list records in the zone: %s", zoneName) continue @@ -225,7 +225,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro switch change.Action { case DigitalOceanCreate: - _, _, err = p.Client.CreateRecord(context.TODO(), zoneName, + _, _, err = p.Client.CreateRecord(ctx, zoneName, &godo.DomainRecordEditRequest{ Data: change.ResourceRecordSet.Data, Name: change.ResourceRecordSet.Name, @@ -237,13 +237,13 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro } case DigitalOceanDelete: recordID := p.getRecordID(records, change.ResourceRecordSet) - _, err = p.Client.DeleteRecord(context.TODO(), zoneName, recordID) + _, err = p.Client.DeleteRecord(ctx, zoneName, recordID) if err != nil { return err } case DigitalOceanUpdate: recordID := p.getRecordID(records, change.ResourceRecordSet) - _, _, err = p.Client.EditRecord(context.TODO(), zoneName, recordID, + _, _, err = p.Client.EditRecord(ctx, zoneName, recordID, &godo.DomainRecordEditRequest{ Data: change.ResourceRecordSet.Data, Name: change.ResourceRecordSet.Name, @@ -267,7 +267,7 @@ func (p *DigitalOceanProvider) ApplyChanges(ctx context.Context, changes *plan.C combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanUpdate, changes.UpdateNew)...) combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanDelete, changes.Delete)...) - return p.submitChanges(combinedChanges) + return p.submitChanges(ctx, combinedChanges) } // newDigitalOceanChanges returns a collection of Changes based on the given records and action. diff --git a/provider/digital_ocean_test.go b/provider/digital_ocean_test.go index 8ca91ca15..9372a5f2f 100644 --- a/provider/digital_ocean_test.go +++ b/provider/digital_ocean_test.go @@ -413,7 +413,7 @@ func TestDigitalOceanZones(t *testing.T) { domainFilter: NewDomainFilter([]string{"com"}), } - zones, err := provider.Zones() + zones, err := provider.Zones(context.Background()) if err != nil { t.Fatal(err) } @@ -445,12 +445,12 @@ func TestDigitalOceanApplyChanges(t *testing.T) { func TestNewDigitalOceanProvider(t *testing.T) { _ = os.Setenv("DO_TOKEN", "xxxxxxxxxxxxxxxxx") - _, err := NewDigitalOceanProvider(NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) + _, err := NewDigitalOceanProvider(context.Background(), NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) if err != nil { t.Errorf("should not fail, %s", err) } _ = os.Unsetenv("DO_TOKEN") - _, err = NewDigitalOceanProvider(NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) + _, err = NewDigitalOceanProvider(context.Background(), NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) if err == nil { t.Errorf("expected to fail") } @@ -494,7 +494,7 @@ func TestDigitalOceanRecord(t *testing.T) { Client: &mockDigitalOceanClient{}, } - records, err := provider.fetchRecords("example.com") + records, err := provider.fetchRecords(context.Background(), "example.com") if err != nil { t.Fatal(err) } diff --git a/provider/exoscale.go b/provider/exoscale.go index 72eeb61a5..dcd66f181 100644 --- a/provider/exoscale.go +++ b/provider/exoscale.go @@ -175,13 +175,13 @@ func (ep *ExoscaleProvider) ApplyChanges(ctx context.Context, changes *plan.Chan func (ep *ExoscaleProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { endpoints := make([]*endpoint.Endpoint, 0) - domains, err := ep.client.GetDomains(context.TODO()) + domains, err := ep.client.GetDomains(ctx) if err != nil { return nil, err } for _, d := range domains { - record, err := ep.client.GetRecords(context.TODO(), d.Name) + record, err := ep.client.GetRecords(ctx, d.Name) if err != nil { return nil, err } diff --git a/provider/google.go b/provider/google.go index f1ce72dcb..2a8475b1f 100644 --- a/provider/google.go +++ b/provider/google.go @@ -116,11 +116,13 @@ type GoogleProvider struct { managedZonesClient managedZonesServiceInterface // A client for managing change sets changesClient changesServiceInterface + // The context parameter to be passed for gcloud API calls. + ctx context.Context } // NewGoogleProvider initializes a new Google CloudDNS based Provider. -func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, batchChangeSize int, batchChangeInterval time.Duration, dryRun bool) (*GoogleProvider, error) { - gcloud, err := google.DefaultClient(context.TODO(), dns.NdevClouddnsReadwriteScope) +func NewGoogleProvider(ctx context.Context, project string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, batchChangeSize int, batchChangeInterval time.Duration, dryRun bool) (*GoogleProvider, error) { + gcloud, err := google.DefaultClient(ctx, dns.NdevClouddnsReadwriteScope) if err != nil { return nil, err } @@ -132,7 +134,7 @@ func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter Z }, }) - dnsClient, err := dns.NewService(context.TODO(), option.WithHTTPClient(gcloud)) + dnsClient, err := dns.NewService(ctx, option.WithHTTPClient(gcloud)) if err != nil { return nil, err } @@ -155,13 +157,14 @@ func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter Z resourceRecordSetsClient: resourceRecordSetsService{dnsClient.ResourceRecordSets}, managedZonesClient: managedZonesService{dnsClient.ManagedZones}, changesClient: changesService{dnsClient.Changes}, + ctx: ctx, } return provider, nil } // Zones returns the list of hosted zones. -func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) { +func (p *GoogleProvider) Zones(ctx context.Context) (map[string]*dns.ManagedZone, error) { zones := make(map[string]*dns.ManagedZone) f := func(resp *dns.ManagedZonesListResponse) error { @@ -178,7 +181,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) { } log.Debugf("Matching zones against domain filters: %v", p.domainFilter.filters) - if err := p.managedZonesClient.List(p.project).Pages(context.TODO(), f); err != nil { + if err := p.managedZonesClient.List(p.project).Pages(ctx, f); err != nil { return nil, err } @@ -199,7 +202,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) { // Records returns the list of records in all relevant zones. func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return nil, err } @@ -230,7 +233,7 @@ func (p *GoogleProvider) CreateRecords(endpoints []*endpoint.Endpoint) error { change.Additions = append(change.Additions, p.newFilteredRecords(endpoints)...) - return p.submitChange(change) + return p.submitChange(p.ctx, change) } // UpdateRecords updates a given set of old records to a new set of records in a given hosted zone. @@ -240,7 +243,7 @@ func (p *GoogleProvider) UpdateRecords(records, oldRecords []*endpoint.Endpoint) change.Additions = append(change.Additions, p.newFilteredRecords(records)...) change.Deletions = append(change.Deletions, p.newFilteredRecords(oldRecords)...) - return p.submitChange(change) + return p.submitChange(p.ctx, change) } // DeleteRecords deletes a given set of DNS records in a given zone. @@ -249,7 +252,7 @@ func (p *GoogleProvider) DeleteRecords(endpoints []*endpoint.Endpoint) error { change.Deletions = append(change.Deletions, p.newFilteredRecords(endpoints)...) - return p.submitChange(change) + return p.submitChange(p.ctx, change) } // ApplyChanges applies a given set of changes in a given zone. @@ -263,7 +266,7 @@ func (p *GoogleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes change.Deletions = append(change.Deletions, p.newFilteredRecords(changes.Delete)...) - return p.submitChange(change) + return p.submitChange(ctx, change) } // newFilteredRecords returns a collection of RecordSets based on the given endpoints and domainFilter. @@ -280,13 +283,13 @@ func (p *GoogleProvider) newFilteredRecords(endpoints []*endpoint.Endpoint) []*d } // submitChange takes a zone and a Change and sends it to Google. -func (p *GoogleProvider) submitChange(change *dns.Change) error { +func (p *GoogleProvider) submitChange(ctx context.Context, change *dns.Change) error { if len(change.Additions) == 0 && len(change.Deletions) == 0 { log.Info("All records are already up to date") return nil } - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return err } diff --git a/provider/google_test.go b/provider/google_test.go index 618412a5b..7caedfe42 100644 --- a/provider/google_test.go +++ b/provider/google_test.go @@ -194,7 +194,7 @@ func hasTrailingDot(target string) bool { func TestGoogleZonesIDFilter(t *testing.T) { provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"10002"}), false, []*endpoint.Endpoint{}) - zones, err := provider.Zones() + zones, err := provider.Zones(context.Background()) require.NoError(t, err) validateZones(t, zones, map[string]*dns.ManagedZone{ @@ -205,7 +205,7 @@ func TestGoogleZonesIDFilter(t *testing.T) { func TestGoogleZonesNameFilter(t *testing.T) { provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"internal-2"}), false, []*endpoint.Endpoint{}) - zones, err := provider.Zones() + zones, err := provider.Zones(context.Background()) require.NoError(t, err) validateZones(t, zones, map[string]*dns.ManagedZone{ @@ -216,7 +216,7 @@ func TestGoogleZonesNameFilter(t *testing.T) { func TestGoogleZones(t *testing.T) { provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, []*endpoint.Endpoint{}) - zones, err := provider.Zones() + zones, err := provider.Zones(context.Background()) require.NoError(t, err) validateZones(t, zones, map[string]*dns.ManagedZone{ @@ -777,7 +777,7 @@ func setupGoogleRecords(t *testing.T, provider *GoogleProvider, endpoints []*end func clearGoogleRecords(t *testing.T, provider *GoogleProvider, zone string) { recordSets := []*dns.ResourceRecordSet{} - require.NoError(t, provider.resourceRecordSetsClient.List(provider.project, zone).Pages(context.TODO(), func(resp *dns.ResourceRecordSetsListResponse) error { + require.NoError(t, provider.resourceRecordSetsClient.List(provider.project, zone).Pages(context.Background(), func(resp *dns.ResourceRecordSetsListResponse) error { for _, r := range resp.Rrsets { switch r.Type { case endpoint.RecordTypeA, endpoint.RecordTypeCNAME: diff --git a/provider/linode.go b/provider/linode.go index d89b841a9..e32181c53 100644 --- a/provider/linode.go +++ b/provider/linode.go @@ -101,8 +101,8 @@ func NewLinodeProvider(domainFilter DomainFilter, dryRun bool, appVersion string } // Zones returns the list of hosted zones. -func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) { - zones, err := p.fetchZones() +func (p *LinodeProvider) Zones(ctx context.Context) ([]*linodego.Domain, error) { + zones, err := p.fetchZones(ctx) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) { // Records returns the list of records in a given zone. func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { - zones, err := p.Zones() + zones, err := p.Zones(ctx) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err var endpoints []*endpoint.Endpoint for _, zone := range zones { - records, err := p.fetchRecords(zone.ID) + records, err := p.fetchRecords(ctx, zone.ID) if err != nil { return nil, err } @@ -143,8 +143,8 @@ func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err return endpoints, nil } -func (p *LinodeProvider) fetchRecords(domainID int) ([]*linodego.DomainRecord, error) { - records, err := p.Client.ListDomainRecords(context.TODO(), domainID, nil) +func (p *LinodeProvider) fetchRecords(ctx context.Context, domainID int) ([]*linodego.DomainRecord, error) { + records, err := p.Client.ListDomainRecords(ctx, domainID, nil) if err != nil { return nil, err } @@ -152,10 +152,10 @@ func (p *LinodeProvider) fetchRecords(domainID int) ([]*linodego.DomainRecord, e return records, nil } -func (p *LinodeProvider) fetchZones() ([]*linodego.Domain, error) { +func (p *LinodeProvider) fetchZones(ctx context.Context) ([]*linodego.Domain, error) { var zones []*linodego.Domain - allZones, err := p.Client.ListDomains(context.TODO(), linodego.NewListOptions(0, "")) + allZones, err := p.Client.ListDomains(ctx, linodego.NewListOptions(0, "")) if err != nil { return nil, err @@ -173,7 +173,7 @@ func (p *LinodeProvider) fetchZones() ([]*linodego.Domain, error) { } // submitChanges takes a zone and a collection of Changes and sends them as a single transaction. -func (p *LinodeProvider) submitChanges(changes LinodeChanges) error { +func (p *LinodeProvider) submitChanges(ctx context.Context, changes LinodeChanges) error { for _, change := range changes.Creates { logFields := log.Fields{ "record": change.Options.Name, @@ -187,7 +187,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error { if p.DryRun { log.WithFields(logFields).Info("Would create record.") - } else if _, err := p.Client.CreateDomainRecord(context.TODO(), change.Domain.ID, change.Options); err != nil { + } else if _, err := p.Client.CreateDomainRecord(ctx, change.Domain.ID, change.Options); err != nil { log.WithFields(logFields).Errorf( "Failed to Create record: %v", err, @@ -208,7 +208,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error { if p.DryRun { log.WithFields(logFields).Info("Would delete record.") - } else if err := p.Client.DeleteDomainRecord(context.TODO(), change.Domain.ID, change.DomainRecord.ID); err != nil { + } else if err := p.Client.DeleteDomainRecord(ctx, change.Domain.ID, change.DomainRecord.ID); err != nil { log.WithFields(logFields).Errorf( "Failed to Delete record: %v", err, @@ -229,7 +229,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error { if p.DryRun { log.WithFields(logFields).Info("Would update record.") - } else if _, err := p.Client.UpdateDomainRecord(context.TODO(), change.Domain.ID, change.DomainRecord.ID, change.Options); err != nil { + } else if _, err := p.Client.UpdateDomainRecord(ctx, change.Domain.ID, change.DomainRecord.ID, change.Options); err != nil { log.WithFields(logFields).Errorf( "Failed to Update record: %v", err, @@ -259,7 +259,7 @@ func getPriority() *int { func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { recordsByZoneID := make(map[string][]*linodego.DomainRecord) - zones, err := p.fetchZones() + zones, err := p.fetchZones(ctx) if err != nil { return err @@ -276,7 +276,7 @@ func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes // Fetch records for each zone for _, zone := range zones { - records, err := p.fetchRecords(zone.ID) + records, err := p.fetchRecords(ctx, zone.ID) if err != nil { return err @@ -484,7 +484,7 @@ func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes } } - return p.submitChanges(LinodeChanges{ + return p.submitChanges(ctx, LinodeChanges{ Creates: linodeCreates, Deletes: linodeDeletes, Updates: linodeUpdates, diff --git a/provider/linode_test.go b/provider/linode_test.go index 5419cb147..a41d1a99d 100644 --- a/provider/linode_test.go +++ b/provider/linode_test.go @@ -160,7 +160,7 @@ func TestLinodeStripRecordName(t *testing.T) { })) } -func TestLinodeFetchZonesNoFiilters(t *testing.T) { +func TestLinodeFetchZonesNoFilters(t *testing.T) { mockDomainClient := MockDomainClient{} provider := &LinodeProvider{ @@ -176,7 +176,7 @@ func TestLinodeFetchZonesNoFiilters(t *testing.T) { ).Return(createZones(), nil).Once() expected := createZones() - actual, err := provider.fetchZones() + actual, err := provider.fetchZones(context.Background()) require.NoError(t, err) mockDomainClient.AssertExpectations(t) @@ -202,7 +202,7 @@ func TestLinodeFetchZonesWithFilter(t *testing.T) { {ID: 1, Domain: "foo.com"}, {ID: 3, Domain: "baz.com"}, } - actual, err := provider.fetchZones() + actual, err := provider.fetchZones(context.Background()) require.NoError(t, err) mockDomainClient.AssertExpectations(t) diff --git a/provider/pdns.go b/provider/pdns.go index afacdf33c..15591b66e 100644 --- a/provider/pdns.go +++ b/provider/pdns.go @@ -225,7 +225,7 @@ type PDNSProvider struct { } // NewPDNSProvider initializes a new PowerDNS based Provider. -func NewPDNSProvider(config PDNSConfig) (*PDNSProvider, error) { +func NewPDNSProvider(ctx context.Context, config PDNSConfig) (*PDNSProvider, error) { // Do some input validation @@ -252,7 +252,7 @@ func NewPDNSProvider(config PDNSConfig) (*PDNSProvider, error) { provider := &PDNSProvider{ client: &PDNSAPIClient{ dryRun: config.DryRun, - authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: config.APIKey}), + authCtx: context.WithValue(ctx, pgo.ContextAPIKey, pgo.APIKey{Key: config.APIKey}), client: pgo.NewAPIClient(pdnsClientConfig), domainFilter: config.DomainFilter, }, diff --git a/provider/pdns_test.go b/provider/pdns_test.go index de254b16f..6de83ecfe 100644 --- a/provider/pdns_test.go +++ b/provider/pdns_test.go @@ -495,21 +495,21 @@ var ( DomainFilterEmptyClient = &PDNSAPIClient{ dryRun: false, - authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), + authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), client: pgo.NewAPIClient(pgo.NewConfiguration()), domainFilter: DomainFilterListEmpty, } DomainFilterSingleClient = &PDNSAPIClient{ dryRun: false, - authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), + authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), client: pgo.NewAPIClient(pgo.NewConfiguration()), domainFilter: DomainFilterListSingle, } DomainFilterMultipleClient = &PDNSAPIClient{ dryRun: false, - authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), + authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), client: pgo.NewAPIClient(pgo.NewConfiguration()), domainFilter: DomainFilterListMultiple, } @@ -639,124 +639,148 @@ type NewPDNSProviderTestSuite struct { func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreate() { - _, err := NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - DomainFilter: NewDomainFilter([]string{""}), - }) + _, err := NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + DomainFilter: NewDomainFilter([]string{""}), + }) assert.Error(suite.T(), err, "--pdns-api-key should be specified") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{"example.com", "example.org"}), - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{"example.com", "example.org"}), + }) assert.Nil(suite.T(), err, "--domain-filter should raise no error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - DryRun: true, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + DryRun: true, + }) assert.Error(suite.T(), err, "--dry-run should raise an error") // This is our "regular" code path, no error should be thrown - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + }) assert.Nil(suite.T(), err, "Regular case should raise no error") } func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreateTLS() { - _, err := NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - }) + _, err := NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + }) assert.Nil(suite.T(), err, "Omitted TLS Config case should raise no error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: false, - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: false, + }, + }) assert.Nil(suite.T(), err, "Disabled TLS Config should raise no error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: false, - CAFilePath: "/path/to/ca.crt", - ClientCertFilePath: "/path/to/cert.pem", - ClientCertKeyFilePath: "/path/to/cert-key.pem", - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: false, + CAFilePath: "/path/to/ca.crt", + ClientCertFilePath: "/path/to/cert.pem", + ClientCertKeyFilePath: "/path/to/cert-key.pem", + }, + }) assert.Nil(suite.T(), err, "Disabled TLS Config with additional flags should raise no error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: true, - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: true, + }, + }) assert.Error(suite.T(), err, "Enabled TLS Config without --tls-ca should raise an error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: true, - CAFilePath: "../internal/testresources/ca.pem", - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: true, + CAFilePath: "../internal/testresources/ca.pem", + }, + }) assert.Nil(suite.T(), err, "Enabled TLS Config with --tls-ca should raise no error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: true, - CAFilePath: "../internal/testresources/ca.pem", - ClientCertFilePath: "../internal/testresources/client-cert.pem", - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: true, + CAFilePath: "../internal/testresources/ca.pem", + ClientCertFilePath: "../internal/testresources/client-cert.pem", + }, + }) assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert only should raise an error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: true, - CAFilePath: "../internal/testresources/ca.pem", - ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem", - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: true, + CAFilePath: "../internal/testresources/ca.pem", + ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem", + }, + }) assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert-key only should raise an error") - _, err = NewPDNSProvider(PDNSConfig{ - Server: "http://localhost:8081", - APIKey: "foo", - DomainFilter: NewDomainFilter([]string{""}), - TLSConfig: TLSConfig{ - TLSEnabled: true, - CAFilePath: "../internal/testresources/ca.pem", - ClientCertFilePath: "../internal/testresources/client-cert.pem", - ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem", - }, - }) + _, err = NewPDNSProvider( + context.Background(), + PDNSConfig{ + Server: "http://localhost:8081", + APIKey: "foo", + DomainFilter: NewDomainFilter([]string{""}), + TLSConfig: TLSConfig{ + TLSEnabled: true, + CAFilePath: "../internal/testresources/ca.pem", + ClientCertFilePath: "../internal/testresources/client-cert.pem", + ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem", + }, + }) assert.Nil(suite.T(), err, "Enabled TLS Config with all flags should raise no error") }