remove context.TODO()s in external-dns

This commit is contained in:
Tariq Ibrahim 2020-01-15 13:59:20 -08:00
parent f400ded46c
commit a5896c2326
No known key found for this signature in database
GPG Key ID: DFC94E4A008B908A
14 changed files with 199 additions and 170 deletions

View File

@ -103,8 +103,7 @@ type Controller struct {
} }
// RunOnce runs a single iteration of a reconciliation loop. // RunOnce runs a single iteration of a reconciliation loop.
func (c *Controller) RunOnce() error { func (c *Controller) RunOnce(ctx context.Context) error {
ctx := context.Background()
records, err := c.Registry.Records(ctx) records, err := c.Registry.Records(ctx)
if err != nil { if err != nil {
registryErrorsTotal.Inc() registryErrorsTotal.Inc()
@ -141,11 +140,11 @@ func (c *Controller) RunOnce() error {
} }
// Run runs RunOnce in a loop with a delay until stopChan receives a value. // Run runs RunOnce in a loop with a delay until stopChan receives a value.
func (c *Controller) Run(stopChan <-chan struct{}) { func (c *Controller) Run(ctx context.Context, stopChan <-chan struct{}) {
ticker := time.NewTicker(c.Interval) ticker := time.NewTicker(c.Interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
err := c.RunOnce() err := c.RunOnce(ctx)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }

View File

@ -146,7 +146,7 @@ func TestRunOnce(t *testing.T) {
Policy: &plan.SyncPolicy{}, Policy: &plan.SyncPolicy{},
} }
assert.NoError(t, ctrl.RunOnce()) assert.NoError(t, ctrl.RunOnce(context.Background()))
// Validate that the mock source was called. // Validate that the mock source was called.
source.AssertExpectations(t) source.AssertExpectations(t)

12
main.go
View File

@ -17,6 +17,7 @@ limitations under the License.
package main package main
import ( import (
"context"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -60,6 +61,8 @@ func main() {
} }
log.SetLevel(ll) log.SetLevel(ll)
ctx := context.Background()
stopChan := make(chan struct{}, 1) stopChan := make(chan struct{}, 1)
go serveMetrics(cfg.MetricsAddress) go serveMetrics(cfg.MetricsAddress)
@ -144,9 +147,9 @@ func main() {
case "rcodezero": case "rcodezero":
p, err = provider.NewRcodeZeroProvider(domainFilter, cfg.DryRun, cfg.RcodezeroTXTEncrypt) p, err = provider.NewRcodeZeroProvider(domainFilter, cfg.DryRun, cfg.RcodezeroTXTEncrypt)
case "google": case "google":
p, err = provider.NewGoogleProvider(cfg.GoogleProject, domainFilter, zoneIDFilter, cfg.GoogleBatchChangeSize, cfg.GoogleBatchChangeInterval, cfg.DryRun) p, err = provider.NewGoogleProvider(ctx, cfg.GoogleProject, domainFilter, zoneIDFilter, cfg.GoogleBatchChangeSize, cfg.GoogleBatchChangeInterval, cfg.DryRun)
case "digitalocean": case "digitalocean":
p, err = provider.NewDigitalOceanProvider(domainFilter, cfg.DryRun) p, err = provider.NewDigitalOceanProvider(ctx, domainFilter, cfg.DryRun)
case "linode": case "linode":
p, err = provider.NewLinodeProvider(domainFilter, cfg.DryRun, externaldns.Version) p, err = provider.NewLinodeProvider(domainFilter, cfg.DryRun, externaldns.Version)
case "dnsimple": case "dnsimple":
@ -197,6 +200,7 @@ func main() {
p, err = provider.NewDesignateProvider(domainFilter, cfg.DryRun) p, err = provider.NewDesignateProvider(domainFilter, cfg.DryRun)
case "pdns": case "pdns":
p, err = provider.NewPDNSProvider( p, err = provider.NewPDNSProvider(
ctx,
provider.PDNSConfig{ provider.PDNSConfig{
DomainFilter: domainFilter, DomainFilter: domainFilter,
DryRun: cfg.DryRun, DryRun: cfg.DryRun,
@ -266,14 +270,14 @@ func main() {
} }
if cfg.Once { if cfg.Once {
err := ctrl.RunOnce() err := ctrl.RunOnce(ctx)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
os.Exit(0) os.Exit(0)
} }
ctrl.Run(stopChan) ctrl.Run(ctx, stopChan)
} }
func handleSigterm(stopChan chan struct{}) { func handleSigterm(stopChan chan struct{}) {

View File

@ -146,9 +146,8 @@ func NewCloudFlareProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter,
} }
// Zones returns the list of hosted zones. // Zones returns the list of hosted zones.
func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) { func (p *CloudFlareProvider) Zones(ctx context.Context) ([]cloudflare.Zone, error) {
result := []cloudflare.Zone{} result := []cloudflare.Zone{}
ctx := context.TODO()
p.PaginationOptions.Page = 1 p.PaginationOptions.Page = 1
for { for {
@ -177,7 +176,7 @@ func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) {
// Records returns the list of records. // Records returns the list of records.
func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -208,17 +207,17 @@ func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Cha
combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareUpdate, changes.UpdateNew, proxiedByDefault)...) combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareUpdate, changes.UpdateNew, proxiedByDefault)...)
combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareDelete, changes.Delete, proxiedByDefault)...) combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareDelete, changes.Delete, proxiedByDefault)...)
return p.submitChanges(combinedChanges) return p.submitChanges(ctx, combinedChanges)
} }
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction. // submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *CloudFlareProvider) submitChanges(changes []*cloudFlareChange) error { func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloudFlareChange) error {
// return early if there is nothing to change // return early if there is nothing to change
if len(changes) == 0 { if len(changes) == 0 {
return nil return nil
} }
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -477,7 +477,7 @@ func TestCloudFlareZones(t *testing.T) {
zoneIDFilter: NewZoneIDFilter([]string{""}), zoneIDFilter: NewZoneIDFilter([]string{""}),
} }
zones, err := provider.Zones() zones, err := provider.Zones(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -57,12 +57,12 @@ type DigitalOceanChange struct {
} }
// NewDigitalOceanProvider initializes a new DigitalOcean DNS based Provider. // NewDigitalOceanProvider initializes a new DigitalOcean DNS based Provider.
func NewDigitalOceanProvider(domainFilter DomainFilter, dryRun bool) (*DigitalOceanProvider, error) { func NewDigitalOceanProvider(ctx context.Context, domainFilter DomainFilter, dryRun bool) (*DigitalOceanProvider, error) {
token, ok := os.LookupEnv("DO_TOKEN") token, ok := os.LookupEnv("DO_TOKEN")
if !ok { if !ok {
return nil, fmt.Errorf("No token found") return nil, fmt.Errorf("No token found")
} }
oauthClient := oauth2.NewClient(context.TODO(), oauth2.StaticTokenSource(&oauth2.Token{ oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: token, AccessToken: token,
})) }))
client := godo.NewClient(oauthClient) client := godo.NewClient(oauthClient)
@ -76,10 +76,10 @@ func NewDigitalOceanProvider(domainFilter DomainFilter, dryRun bool) (*DigitalOc
} }
// Zones returns the list of hosted zones. // Zones returns the list of hosted zones.
func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) { func (p *DigitalOceanProvider) Zones(ctx context.Context) ([]godo.Domain, error) {
result := []godo.Domain{} result := []godo.Domain{}
zones, err := p.fetchZones() zones, err := p.fetchZones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,13 +95,13 @@ func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) {
// Records returns the list of records in a given zone. // Records returns the list of records in a given zone.
func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoints := []*endpoint.Endpoint{} endpoints := []*endpoint.Endpoint{}
for _, zone := range zones { for _, zone := range zones {
records, err := p.fetchRecords(zone.Name) records, err := p.fetchRecords(ctx, zone.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,11 +124,11 @@ func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoin
return endpoints, nil return endpoints, nil
} }
func (p *DigitalOceanProvider) fetchRecords(zoneName string) ([]godo.DomainRecord, error) { func (p *DigitalOceanProvider) fetchRecords(ctx context.Context, zoneName string) ([]godo.DomainRecord, error) {
allRecords := []godo.DomainRecord{} allRecords := []godo.DomainRecord{}
listOptions := &godo.ListOptions{} listOptions := &godo.ListOptions{}
for { for {
records, resp, err := p.Client.Records(context.TODO(), zoneName, listOptions) records, resp, err := p.Client.Records(ctx, zoneName, listOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -149,11 +149,11 @@ func (p *DigitalOceanProvider) fetchRecords(zoneName string) ([]godo.DomainRecor
return allRecords, nil return allRecords, nil
} }
func (p *DigitalOceanProvider) fetchZones() ([]godo.Domain, error) { func (p *DigitalOceanProvider) fetchZones(ctx context.Context) ([]godo.Domain, error) {
allZones := []godo.Domain{} allZones := []godo.Domain{}
listOptions := &godo.ListOptions{} listOptions := &godo.ListOptions{}
for { for {
zones, resp, err := p.Client.List(context.TODO(), listOptions) zones, resp, err := p.Client.List(ctx, listOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -175,13 +175,13 @@ func (p *DigitalOceanProvider) fetchZones() ([]godo.Domain, error) {
} }
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction. // submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) error { func (p *DigitalOceanProvider) submitChanges(ctx context.Context, changes []*DigitalOceanChange) error {
// return early if there is nothing to change // return early if there is nothing to change
if len(changes) == 0 { if len(changes) == 0 {
return nil return nil
} }
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -189,7 +189,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
// separate into per-zone change sets to be passed to the API. // separate into per-zone change sets to be passed to the API.
changesByZone := digitalOceanChangesByZone(zones, changes) changesByZone := digitalOceanChangesByZone(zones, changes)
for zoneName, changes := range changesByZone { for zoneName, changes := range changesByZone {
records, err := p.fetchRecords(zoneName) records, err := p.fetchRecords(ctx, zoneName)
if err != nil { if err != nil {
log.Errorf("Failed to list records in the zone: %s", zoneName) log.Errorf("Failed to list records in the zone: %s", zoneName)
continue continue
@ -225,7 +225,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
switch change.Action { switch change.Action {
case DigitalOceanCreate: case DigitalOceanCreate:
_, _, err = p.Client.CreateRecord(context.TODO(), zoneName, _, _, err = p.Client.CreateRecord(ctx, zoneName,
&godo.DomainRecordEditRequest{ &godo.DomainRecordEditRequest{
Data: change.ResourceRecordSet.Data, Data: change.ResourceRecordSet.Data,
Name: change.ResourceRecordSet.Name, Name: change.ResourceRecordSet.Name,
@ -237,13 +237,13 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
} }
case DigitalOceanDelete: case DigitalOceanDelete:
recordID := p.getRecordID(records, change.ResourceRecordSet) recordID := p.getRecordID(records, change.ResourceRecordSet)
_, err = p.Client.DeleteRecord(context.TODO(), zoneName, recordID) _, err = p.Client.DeleteRecord(ctx, zoneName, recordID)
if err != nil { if err != nil {
return err return err
} }
case DigitalOceanUpdate: case DigitalOceanUpdate:
recordID := p.getRecordID(records, change.ResourceRecordSet) recordID := p.getRecordID(records, change.ResourceRecordSet)
_, _, err = p.Client.EditRecord(context.TODO(), zoneName, recordID, _, _, err = p.Client.EditRecord(ctx, zoneName, recordID,
&godo.DomainRecordEditRequest{ &godo.DomainRecordEditRequest{
Data: change.ResourceRecordSet.Data, Data: change.ResourceRecordSet.Data,
Name: change.ResourceRecordSet.Name, Name: change.ResourceRecordSet.Name,
@ -267,7 +267,7 @@ func (p *DigitalOceanProvider) ApplyChanges(ctx context.Context, changes *plan.C
combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanUpdate, changes.UpdateNew)...) combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanUpdate, changes.UpdateNew)...)
combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanDelete, changes.Delete)...) combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanDelete, changes.Delete)...)
return p.submitChanges(combinedChanges) return p.submitChanges(ctx, combinedChanges)
} }
// newDigitalOceanChanges returns a collection of Changes based on the given records and action. // newDigitalOceanChanges returns a collection of Changes based on the given records and action.

View File

@ -413,7 +413,7 @@ func TestDigitalOceanZones(t *testing.T) {
domainFilter: NewDomainFilter([]string{"com"}), domainFilter: NewDomainFilter([]string{"com"}),
} }
zones, err := provider.Zones() zones, err := provider.Zones(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -445,12 +445,12 @@ func TestDigitalOceanApplyChanges(t *testing.T) {
func TestNewDigitalOceanProvider(t *testing.T) { func TestNewDigitalOceanProvider(t *testing.T) {
_ = os.Setenv("DO_TOKEN", "xxxxxxxxxxxxxxxxx") _ = os.Setenv("DO_TOKEN", "xxxxxxxxxxxxxxxxx")
_, err := NewDigitalOceanProvider(NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) _, err := NewDigitalOceanProvider(context.Background(), NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
if err != nil { if err != nil {
t.Errorf("should not fail, %s", err) t.Errorf("should not fail, %s", err)
} }
_ = os.Unsetenv("DO_TOKEN") _ = os.Unsetenv("DO_TOKEN")
_, err = NewDigitalOceanProvider(NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) _, err = NewDigitalOceanProvider(context.Background(), NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
if err == nil { if err == nil {
t.Errorf("expected to fail") t.Errorf("expected to fail")
} }
@ -494,7 +494,7 @@ func TestDigitalOceanRecord(t *testing.T) {
Client: &mockDigitalOceanClient{}, Client: &mockDigitalOceanClient{},
} }
records, err := provider.fetchRecords("example.com") records, err := provider.fetchRecords(context.Background(), "example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -175,13 +175,13 @@ func (ep *ExoscaleProvider) ApplyChanges(ctx context.Context, changes *plan.Chan
func (ep *ExoscaleProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (ep *ExoscaleProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
endpoints := make([]*endpoint.Endpoint, 0) endpoints := make([]*endpoint.Endpoint, 0)
domains, err := ep.client.GetDomains(context.TODO()) domains, err := ep.client.GetDomains(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, d := range domains { for _, d := range domains {
record, err := ep.client.GetRecords(context.TODO(), d.Name) record, err := ep.client.GetRecords(ctx, d.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -116,11 +116,13 @@ type GoogleProvider struct {
managedZonesClient managedZonesServiceInterface managedZonesClient managedZonesServiceInterface
// A client for managing change sets // A client for managing change sets
changesClient changesServiceInterface changesClient changesServiceInterface
// The context parameter to be passed for gcloud API calls.
ctx context.Context
} }
// NewGoogleProvider initializes a new Google CloudDNS based Provider. // NewGoogleProvider initializes a new Google CloudDNS based Provider.
func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, batchChangeSize int, batchChangeInterval time.Duration, dryRun bool) (*GoogleProvider, error) { func NewGoogleProvider(ctx context.Context, project string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, batchChangeSize int, batchChangeInterval time.Duration, dryRun bool) (*GoogleProvider, error) {
gcloud, err := google.DefaultClient(context.TODO(), dns.NdevClouddnsReadwriteScope) gcloud, err := google.DefaultClient(ctx, dns.NdevClouddnsReadwriteScope)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -132,7 +134,7 @@ func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter Z
}, },
}) })
dnsClient, err := dns.NewService(context.TODO(), option.WithHTTPClient(gcloud)) dnsClient, err := dns.NewService(ctx, option.WithHTTPClient(gcloud))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -155,13 +157,14 @@ func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter Z
resourceRecordSetsClient: resourceRecordSetsService{dnsClient.ResourceRecordSets}, resourceRecordSetsClient: resourceRecordSetsService{dnsClient.ResourceRecordSets},
managedZonesClient: managedZonesService{dnsClient.ManagedZones}, managedZonesClient: managedZonesService{dnsClient.ManagedZones},
changesClient: changesService{dnsClient.Changes}, changesClient: changesService{dnsClient.Changes},
ctx: ctx,
} }
return provider, nil return provider, nil
} }
// Zones returns the list of hosted zones. // Zones returns the list of hosted zones.
func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) { func (p *GoogleProvider) Zones(ctx context.Context) (map[string]*dns.ManagedZone, error) {
zones := make(map[string]*dns.ManagedZone) zones := make(map[string]*dns.ManagedZone)
f := func(resp *dns.ManagedZonesListResponse) error { f := func(resp *dns.ManagedZonesListResponse) error {
@ -178,7 +181,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) {
} }
log.Debugf("Matching zones against domain filters: %v", p.domainFilter.filters) log.Debugf("Matching zones against domain filters: %v", p.domainFilter.filters)
if err := p.managedZonesClient.List(p.project).Pages(context.TODO(), f); err != nil { if err := p.managedZonesClient.List(p.project).Pages(ctx, f); err != nil {
return nil, err return nil, err
} }
@ -199,7 +202,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) {
// Records returns the list of records in all relevant zones. // Records returns the list of records in all relevant zones.
func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -230,7 +233,7 @@ func (p *GoogleProvider) CreateRecords(endpoints []*endpoint.Endpoint) error {
change.Additions = append(change.Additions, p.newFilteredRecords(endpoints)...) change.Additions = append(change.Additions, p.newFilteredRecords(endpoints)...)
return p.submitChange(change) return p.submitChange(p.ctx, change)
} }
// UpdateRecords updates a given set of old records to a new set of records in a given hosted zone. // UpdateRecords updates a given set of old records to a new set of records in a given hosted zone.
@ -240,7 +243,7 @@ func (p *GoogleProvider) UpdateRecords(records, oldRecords []*endpoint.Endpoint)
change.Additions = append(change.Additions, p.newFilteredRecords(records)...) change.Additions = append(change.Additions, p.newFilteredRecords(records)...)
change.Deletions = append(change.Deletions, p.newFilteredRecords(oldRecords)...) change.Deletions = append(change.Deletions, p.newFilteredRecords(oldRecords)...)
return p.submitChange(change) return p.submitChange(p.ctx, change)
} }
// DeleteRecords deletes a given set of DNS records in a given zone. // DeleteRecords deletes a given set of DNS records in a given zone.
@ -249,7 +252,7 @@ func (p *GoogleProvider) DeleteRecords(endpoints []*endpoint.Endpoint) error {
change.Deletions = append(change.Deletions, p.newFilteredRecords(endpoints)...) change.Deletions = append(change.Deletions, p.newFilteredRecords(endpoints)...)
return p.submitChange(change) return p.submitChange(p.ctx, change)
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
@ -263,7 +266,7 @@ func (p *GoogleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
change.Deletions = append(change.Deletions, p.newFilteredRecords(changes.Delete)...) change.Deletions = append(change.Deletions, p.newFilteredRecords(changes.Delete)...)
return p.submitChange(change) return p.submitChange(ctx, change)
} }
// newFilteredRecords returns a collection of RecordSets based on the given endpoints and domainFilter. // newFilteredRecords returns a collection of RecordSets based on the given endpoints and domainFilter.
@ -280,13 +283,13 @@ func (p *GoogleProvider) newFilteredRecords(endpoints []*endpoint.Endpoint) []*d
} }
// submitChange takes a zone and a Change and sends it to Google. // submitChange takes a zone and a Change and sends it to Google.
func (p *GoogleProvider) submitChange(change *dns.Change) error { func (p *GoogleProvider) submitChange(ctx context.Context, change *dns.Change) error {
if len(change.Additions) == 0 && len(change.Deletions) == 0 { if len(change.Additions) == 0 && len(change.Deletions) == 0 {
log.Info("All records are already up to date") log.Info("All records are already up to date")
return nil return nil
} }
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -194,7 +194,7 @@ func hasTrailingDot(target string) bool {
func TestGoogleZonesIDFilter(t *testing.T) { func TestGoogleZonesIDFilter(t *testing.T) {
provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"10002"}), false, []*endpoint.Endpoint{}) provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"10002"}), false, []*endpoint.Endpoint{})
zones, err := provider.Zones() zones, err := provider.Zones(context.Background())
require.NoError(t, err) require.NoError(t, err)
validateZones(t, zones, map[string]*dns.ManagedZone{ validateZones(t, zones, map[string]*dns.ManagedZone{
@ -205,7 +205,7 @@ func TestGoogleZonesIDFilter(t *testing.T) {
func TestGoogleZonesNameFilter(t *testing.T) { func TestGoogleZonesNameFilter(t *testing.T) {
provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"internal-2"}), false, []*endpoint.Endpoint{}) provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"internal-2"}), false, []*endpoint.Endpoint{})
zones, err := provider.Zones() zones, err := provider.Zones(context.Background())
require.NoError(t, err) require.NoError(t, err)
validateZones(t, zones, map[string]*dns.ManagedZone{ validateZones(t, zones, map[string]*dns.ManagedZone{
@ -216,7 +216,7 @@ func TestGoogleZonesNameFilter(t *testing.T) {
func TestGoogleZones(t *testing.T) { func TestGoogleZones(t *testing.T) {
provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, []*endpoint.Endpoint{}) provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, []*endpoint.Endpoint{})
zones, err := provider.Zones() zones, err := provider.Zones(context.Background())
require.NoError(t, err) require.NoError(t, err)
validateZones(t, zones, map[string]*dns.ManagedZone{ validateZones(t, zones, map[string]*dns.ManagedZone{
@ -777,7 +777,7 @@ func setupGoogleRecords(t *testing.T, provider *GoogleProvider, endpoints []*end
func clearGoogleRecords(t *testing.T, provider *GoogleProvider, zone string) { func clearGoogleRecords(t *testing.T, provider *GoogleProvider, zone string) {
recordSets := []*dns.ResourceRecordSet{} recordSets := []*dns.ResourceRecordSet{}
require.NoError(t, provider.resourceRecordSetsClient.List(provider.project, zone).Pages(context.TODO(), func(resp *dns.ResourceRecordSetsListResponse) error { require.NoError(t, provider.resourceRecordSetsClient.List(provider.project, zone).Pages(context.Background(), func(resp *dns.ResourceRecordSetsListResponse) error {
for _, r := range resp.Rrsets { for _, r := range resp.Rrsets {
switch r.Type { switch r.Type {
case endpoint.RecordTypeA, endpoint.RecordTypeCNAME: case endpoint.RecordTypeA, endpoint.RecordTypeCNAME:

View File

@ -101,8 +101,8 @@ func NewLinodeProvider(domainFilter DomainFilter, dryRun bool, appVersion string
} }
// Zones returns the list of hosted zones. // Zones returns the list of hosted zones.
func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) { func (p *LinodeProvider) Zones(ctx context.Context) ([]*linodego.Domain, error) {
zones, err := p.fetchZones() zones, err := p.fetchZones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,7 +112,7 @@ func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) {
// Records returns the list of records in a given zone. // Records returns the list of records in a given zone.
func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones() zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,7 +120,7 @@ func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err
var endpoints []*endpoint.Endpoint var endpoints []*endpoint.Endpoint
for _, zone := range zones { for _, zone := range zones {
records, err := p.fetchRecords(zone.ID) records, err := p.fetchRecords(ctx, zone.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,8 +143,8 @@ func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err
return endpoints, nil return endpoints, nil
} }
func (p *LinodeProvider) fetchRecords(domainID int) ([]*linodego.DomainRecord, error) { func (p *LinodeProvider) fetchRecords(ctx context.Context, domainID int) ([]*linodego.DomainRecord, error) {
records, err := p.Client.ListDomainRecords(context.TODO(), domainID, nil) records, err := p.Client.ListDomainRecords(ctx, domainID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,10 +152,10 @@ func (p *LinodeProvider) fetchRecords(domainID int) ([]*linodego.DomainRecord, e
return records, nil return records, nil
} }
func (p *LinodeProvider) fetchZones() ([]*linodego.Domain, error) { func (p *LinodeProvider) fetchZones(ctx context.Context) ([]*linodego.Domain, error) {
var zones []*linodego.Domain var zones []*linodego.Domain
allZones, err := p.Client.ListDomains(context.TODO(), linodego.NewListOptions(0, "")) allZones, err := p.Client.ListDomains(ctx, linodego.NewListOptions(0, ""))
if err != nil { if err != nil {
return nil, err return nil, err
@ -173,7 +173,7 @@ func (p *LinodeProvider) fetchZones() ([]*linodego.Domain, error) {
} }
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction. // submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *LinodeProvider) submitChanges(changes LinodeChanges) error { func (p *LinodeProvider) submitChanges(ctx context.Context, changes LinodeChanges) error {
for _, change := range changes.Creates { for _, change := range changes.Creates {
logFields := log.Fields{ logFields := log.Fields{
"record": change.Options.Name, "record": change.Options.Name,
@ -187,7 +187,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
if p.DryRun { if p.DryRun {
log.WithFields(logFields).Info("Would create record.") log.WithFields(logFields).Info("Would create record.")
} else if _, err := p.Client.CreateDomainRecord(context.TODO(), change.Domain.ID, change.Options); err != nil { } else if _, err := p.Client.CreateDomainRecord(ctx, change.Domain.ID, change.Options); err != nil {
log.WithFields(logFields).Errorf( log.WithFields(logFields).Errorf(
"Failed to Create record: %v", "Failed to Create record: %v",
err, err,
@ -208,7 +208,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
if p.DryRun { if p.DryRun {
log.WithFields(logFields).Info("Would delete record.") log.WithFields(logFields).Info("Would delete record.")
} else if err := p.Client.DeleteDomainRecord(context.TODO(), change.Domain.ID, change.DomainRecord.ID); err != nil { } else if err := p.Client.DeleteDomainRecord(ctx, change.Domain.ID, change.DomainRecord.ID); err != nil {
log.WithFields(logFields).Errorf( log.WithFields(logFields).Errorf(
"Failed to Delete record: %v", "Failed to Delete record: %v",
err, err,
@ -229,7 +229,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
if p.DryRun { if p.DryRun {
log.WithFields(logFields).Info("Would update record.") log.WithFields(logFields).Info("Would update record.")
} else if _, err := p.Client.UpdateDomainRecord(context.TODO(), change.Domain.ID, change.DomainRecord.ID, change.Options); err != nil { } else if _, err := p.Client.UpdateDomainRecord(ctx, change.Domain.ID, change.DomainRecord.ID, change.Options); err != nil {
log.WithFields(logFields).Errorf( log.WithFields(logFields).Errorf(
"Failed to Update record: %v", "Failed to Update record: %v",
err, err,
@ -259,7 +259,7 @@ func getPriority() *int {
func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
recordsByZoneID := make(map[string][]*linodego.DomainRecord) recordsByZoneID := make(map[string][]*linodego.DomainRecord)
zones, err := p.fetchZones() zones, err := p.fetchZones(ctx)
if err != nil { if err != nil {
return err return err
@ -276,7 +276,7 @@ func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
// Fetch records for each zone // Fetch records for each zone
for _, zone := range zones { for _, zone := range zones {
records, err := p.fetchRecords(zone.ID) records, err := p.fetchRecords(ctx, zone.ID)
if err != nil { if err != nil {
return err return err
@ -484,7 +484,7 @@ func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
} }
} }
return p.submitChanges(LinodeChanges{ return p.submitChanges(ctx, LinodeChanges{
Creates: linodeCreates, Creates: linodeCreates,
Deletes: linodeDeletes, Deletes: linodeDeletes,
Updates: linodeUpdates, Updates: linodeUpdates,

View File

@ -160,7 +160,7 @@ func TestLinodeStripRecordName(t *testing.T) {
})) }))
} }
func TestLinodeFetchZonesNoFiilters(t *testing.T) { func TestLinodeFetchZonesNoFilters(t *testing.T) {
mockDomainClient := MockDomainClient{} mockDomainClient := MockDomainClient{}
provider := &LinodeProvider{ provider := &LinodeProvider{
@ -176,7 +176,7 @@ func TestLinodeFetchZonesNoFiilters(t *testing.T) {
).Return(createZones(), nil).Once() ).Return(createZones(), nil).Once()
expected := createZones() expected := createZones()
actual, err := provider.fetchZones() actual, err := provider.fetchZones(context.Background())
require.NoError(t, err) require.NoError(t, err)
mockDomainClient.AssertExpectations(t) mockDomainClient.AssertExpectations(t)
@ -202,7 +202,7 @@ func TestLinodeFetchZonesWithFilter(t *testing.T) {
{ID: 1, Domain: "foo.com"}, {ID: 1, Domain: "foo.com"},
{ID: 3, Domain: "baz.com"}, {ID: 3, Domain: "baz.com"},
} }
actual, err := provider.fetchZones() actual, err := provider.fetchZones(context.Background())
require.NoError(t, err) require.NoError(t, err)
mockDomainClient.AssertExpectations(t) mockDomainClient.AssertExpectations(t)

View File

@ -225,7 +225,7 @@ type PDNSProvider struct {
} }
// NewPDNSProvider initializes a new PowerDNS based Provider. // NewPDNSProvider initializes a new PowerDNS based Provider.
func NewPDNSProvider(config PDNSConfig) (*PDNSProvider, error) { func NewPDNSProvider(ctx context.Context, config PDNSConfig) (*PDNSProvider, error) {
// Do some input validation // Do some input validation
@ -252,7 +252,7 @@ func NewPDNSProvider(config PDNSConfig) (*PDNSProvider, error) {
provider := &PDNSProvider{ provider := &PDNSProvider{
client: &PDNSAPIClient{ client: &PDNSAPIClient{
dryRun: config.DryRun, dryRun: config.DryRun,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: config.APIKey}), authCtx: context.WithValue(ctx, pgo.ContextAPIKey, pgo.APIKey{Key: config.APIKey}),
client: pgo.NewAPIClient(pdnsClientConfig), client: pgo.NewAPIClient(pdnsClientConfig),
domainFilter: config.DomainFilter, domainFilter: config.DomainFilter,
}, },

View File

@ -495,21 +495,21 @@ var (
DomainFilterEmptyClient = &PDNSAPIClient{ DomainFilterEmptyClient = &PDNSAPIClient{
dryRun: false, dryRun: false,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
client: pgo.NewAPIClient(pgo.NewConfiguration()), client: pgo.NewAPIClient(pgo.NewConfiguration()),
domainFilter: DomainFilterListEmpty, domainFilter: DomainFilterListEmpty,
} }
DomainFilterSingleClient = &PDNSAPIClient{ DomainFilterSingleClient = &PDNSAPIClient{
dryRun: false, dryRun: false,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
client: pgo.NewAPIClient(pgo.NewConfiguration()), client: pgo.NewAPIClient(pgo.NewConfiguration()),
domainFilter: DomainFilterListSingle, domainFilter: DomainFilterListSingle,
} }
DomainFilterMultipleClient = &PDNSAPIClient{ DomainFilterMultipleClient = &PDNSAPIClient{
dryRun: false, dryRun: false,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}), authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
client: pgo.NewAPIClient(pgo.NewConfiguration()), client: pgo.NewAPIClient(pgo.NewConfiguration()),
domainFilter: DomainFilterListMultiple, domainFilter: DomainFilterListMultiple,
} }
@ -639,124 +639,148 @@ type NewPDNSProviderTestSuite struct {
func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreate() { func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreate() {
_, err := NewPDNSProvider(PDNSConfig{ _, err := NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
DomainFilter: NewDomainFilter([]string{""}), PDNSConfig{
}) Server: "http://localhost:8081",
DomainFilter: NewDomainFilter([]string{""}),
})
assert.Error(suite.T(), err, "--pdns-api-key should be specified") assert.Error(suite.T(), err, "--pdns-api-key should be specified")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{"example.com", "example.org"}), Server: "http://localhost:8081",
}) APIKey: "foo",
DomainFilter: NewDomainFilter([]string{"example.com", "example.org"}),
})
assert.Nil(suite.T(), err, "--domain-filter should raise no error") assert.Nil(suite.T(), err, "--domain-filter should raise no error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
DryRun: true, APIKey: "foo",
}) DomainFilter: NewDomainFilter([]string{""}),
DryRun: true,
})
assert.Error(suite.T(), err, "--dry-run should raise an error") assert.Error(suite.T(), err, "--dry-run should raise an error")
// This is our "regular" code path, no error should be thrown // This is our "regular" code path, no error should be thrown
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
}) APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
})
assert.Nil(suite.T(), err, "Regular case should raise no error") assert.Nil(suite.T(), err, "Regular case should raise no error")
} }
func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreateTLS() { func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreateTLS() {
_, err := NewPDNSProvider(PDNSConfig{ _, err := NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
}) APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
})
assert.Nil(suite.T(), err, "Omitted TLS Config case should raise no error") assert.Nil(suite.T(), err, "Omitted TLS Config case should raise no error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: false, DomainFilter: NewDomainFilter([]string{""}),
}, TLSConfig: TLSConfig{
}) TLSEnabled: false,
},
})
assert.Nil(suite.T(), err, "Disabled TLS Config should raise no error") assert.Nil(suite.T(), err, "Disabled TLS Config should raise no error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: false, DomainFilter: NewDomainFilter([]string{""}),
CAFilePath: "/path/to/ca.crt", TLSConfig: TLSConfig{
ClientCertFilePath: "/path/to/cert.pem", TLSEnabled: false,
ClientCertKeyFilePath: "/path/to/cert-key.pem", CAFilePath: "/path/to/ca.crt",
}, ClientCertFilePath: "/path/to/cert.pem",
}) ClientCertKeyFilePath: "/path/to/cert-key.pem",
},
})
assert.Nil(suite.T(), err, "Disabled TLS Config with additional flags should raise no error") assert.Nil(suite.T(), err, "Disabled TLS Config with additional flags should raise no error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: true, DomainFilter: NewDomainFilter([]string{""}),
}, TLSConfig: TLSConfig{
}) TLSEnabled: true,
},
})
assert.Error(suite.T(), err, "Enabled TLS Config without --tls-ca should raise an error") assert.Error(suite.T(), err, "Enabled TLS Config without --tls-ca should raise an error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: true, DomainFilter: NewDomainFilter([]string{""}),
CAFilePath: "../internal/testresources/ca.pem", TLSConfig: TLSConfig{
}, TLSEnabled: true,
}) CAFilePath: "../internal/testresources/ca.pem",
},
})
assert.Nil(suite.T(), err, "Enabled TLS Config with --tls-ca should raise no error") assert.Nil(suite.T(), err, "Enabled TLS Config with --tls-ca should raise no error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: true, DomainFilter: NewDomainFilter([]string{""}),
CAFilePath: "../internal/testresources/ca.pem", TLSConfig: TLSConfig{
ClientCertFilePath: "../internal/testresources/client-cert.pem", TLSEnabled: true,
}, CAFilePath: "../internal/testresources/ca.pem",
}) ClientCertFilePath: "../internal/testresources/client-cert.pem",
},
})
assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert only should raise an error") assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert only should raise an error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: true, DomainFilter: NewDomainFilter([]string{""}),
CAFilePath: "../internal/testresources/ca.pem", TLSConfig: TLSConfig{
ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem", TLSEnabled: true,
}, CAFilePath: "../internal/testresources/ca.pem",
}) ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem",
},
})
assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert-key only should raise an error") assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert-key only should raise an error")
_, err = NewPDNSProvider(PDNSConfig{ _, err = NewPDNSProvider(
Server: "http://localhost:8081", context.Background(),
APIKey: "foo", PDNSConfig{
DomainFilter: NewDomainFilter([]string{""}), Server: "http://localhost:8081",
TLSConfig: TLSConfig{ APIKey: "foo",
TLSEnabled: true, DomainFilter: NewDomainFilter([]string{""}),
CAFilePath: "../internal/testresources/ca.pem", TLSConfig: TLSConfig{
ClientCertFilePath: "../internal/testresources/client-cert.pem", TLSEnabled: true,
ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem", CAFilePath: "../internal/testresources/ca.pem",
}, ClientCertFilePath: "../internal/testresources/client-cert.pem",
}) ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem",
},
})
assert.Nil(suite.T(), err, "Enabled TLS Config with all flags should raise no error") assert.Nil(suite.T(), err, "Enabled TLS Config with all flags should raise no error")
} }