diff --git a/provider/cloudflare/cloudflare.go b/provider/cloudflare/cloudflare.go index 2edbaf624..d910226fc 100644 --- a/provider/cloudflare/cloudflare.go +++ b/provider/cloudflare/cloudflare.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "os" - "sort" "strconv" "strings" @@ -112,8 +111,8 @@ type CloudFlareProvider struct { // cloudFlareChange differentiates between ChangActions type cloudFlareChange struct { - Action string - ResourceRecordSet []cloudflare.DNSRecord + Action string + ResourceRecord cloudflare.DNSRecord } // NewCloudFlareProvider initializes a new CloudFlare DNS based Provider. @@ -200,15 +199,39 @@ func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, // ApplyChanges applies a given set of changes in a given zone. func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { - proxiedByDefault := p.proxiedByDefault + cloudflareChanges := []*cloudFlareChange{} - combinedChanges := make([]*cloudFlareChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete)) + for _, endpoint := range changes.Create { + for _, target := range endpoint.Targets { + cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareCreate, endpoint, target)) + } + } - combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareCreate, changes.Create, proxiedByDefault)...) - combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareUpdate, changes.UpdateNew, proxiedByDefault)...) - combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareDelete, changes.Delete, proxiedByDefault)...) + for i, desired := range changes.UpdateNew { + current := changes.UpdateOld[i] - return p.submitChanges(ctx, combinedChanges) + add, remove, leave := provider.Difference(current.Targets, desired.Targets) + + for _, a := range add { + cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareCreate, desired, a)) + } + + for _, a := range leave { + cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareUpdate, desired, a)) + } + + for _, a := range remove { + cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareDelete, current, a)) + } + } + + for _, endpoint := range changes.Delete { + for _, target := range endpoint.Targets { + cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareDelete, endpoint, target)) + } + } + + return p.submitChanges(ctx, cloudflareChanges) } // submitChanges takes a zone and a collection of Changes and sends them as a single transaction. @@ -232,12 +255,11 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud } for _, change := range changes { logFields := log.Fields{ - "record": change.ResourceRecordSet[0].Name, - "type": change.ResourceRecordSet[0].Type, - "ttl": change.ResourceRecordSet[0].TTL, - "targets": len(change.ResourceRecordSet), - "action": change.Action, - "zone": zoneID, + "record": change.ResourceRecord.Name, + "type": change.ResourceRecord.Type, + "ttl": change.ResourceRecord.TTL, + "action": change.Action, + "zone": zoneID, } log.WithFields(logFields).Info("Changing record.") @@ -246,24 +268,30 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud continue } - recordIDs := p.getRecordIDs(records, change.ResourceRecordSet[0]) - - // to simplify bookkeeping for multiple records, an update is executed as delete+create - if change.Action == cloudFlareDelete || change.Action == cloudFlareUpdate { - for _, recordID := range recordIDs { - err := p.Client.DeleteDNSRecord(zoneID, recordID) - if err != nil { - log.WithFields(logFields).Errorf("failed to delete record: %v", err) - } + if change.Action == cloudFlareUpdate { + recordID := p.getRecordID(records, change.ResourceRecord) + if recordID == "" { + log.WithFields(logFields).Errorf("failed to find previous record: %v", change.ResourceRecord) + continue } - } - - if change.Action == cloudFlareCreate || change.Action == cloudFlareUpdate { - for _, record := range change.ResourceRecordSet { - _, err := p.Client.CreateDNSRecord(zoneID, record) - if err != nil { - log.WithFields(logFields).Errorf("failed to create record: %v", err) - } + err := p.Client.UpdateDNSRecord(zoneID, recordID, change.ResourceRecord) + if err != nil { + log.WithFields(logFields).Errorf("failed to delete record: %v", err) + } + } else if change.Action == cloudFlareDelete { + recordID := p.getRecordID(records, change.ResourceRecord) + if recordID == "" { + log.WithFields(logFields).Errorf("failed to find previous record: %v", change.ResourceRecord) + continue + } + err := p.Client.DeleteDNSRecord(zoneID, recordID) + if err != nil { + log.WithFields(logFields).Errorf("failed to delete record: %v", err) + } + } else if change.Action == cloudFlareCreate { + _, err := p.Client.CreateDNSRecord(zoneID, change.ResourceRecord) + if err != nil { + log.WithFields(logFields).Errorf("failed to create record: %v", err) } } } @@ -282,9 +310,9 @@ func (p *CloudFlareProvider) changesByZone(zones []cloudflare.Zone, changeSet [] } for _, c := range changeSet { - zoneID, _ := zoneNameIDMapper.FindZone(c.ResourceRecordSet[0].Name) + zoneID, _ := zoneNameIDMapper.FindZone(c.ResourceRecord.Name) if zoneID == "" { - log.Debugf("Skipping record %s because no hosted zone matching record DNS Name was detected", c.ResourceRecordSet[0].Name) + log.Debugf("Skipping record %s because no hosted zone matching record DNS Name was detected", c.ResourceRecord.Name) continue } changes[zoneID] = append(changes[zoneID], c) @@ -293,51 +321,36 @@ func (p *CloudFlareProvider) changesByZone(zones []cloudflare.Zone, changeSet [] return changes } -func (p *CloudFlareProvider) getRecordIDs(records []cloudflare.DNSRecord, record cloudflare.DNSRecord) []string { - recordIDs := make([]string, 0) +func (p *CloudFlareProvider) getRecordID(records []cloudflare.DNSRecord, record cloudflare.DNSRecord) string { for _, zoneRecord := range records { - if zoneRecord.Name == record.Name && zoneRecord.Type == record.Type { - recordIDs = append(recordIDs, zoneRecord.ID) + if zoneRecord.Name == record.Name && zoneRecord.Type == record.Type && zoneRecord.Content == record.Content { + return zoneRecord.ID } } - sort.Strings(recordIDs) - return recordIDs + return "" } -// newCloudFlareChanges returns a collection of Changes based on the given records and action. -func newCloudFlareChanges(action string, endpoints []*endpoint.Endpoint, proxiedByDefault bool) []*cloudFlareChange { - changes := make([]*cloudFlareChange, 0, len(endpoints)) - - for _, endpoint := range endpoints { - changes = append(changes, newCloudFlareChange(action, endpoint, proxiedByDefault)) - } - - return changes -} - -func newCloudFlareChange(action string, endpoint *endpoint.Endpoint, proxiedByDefault bool) *cloudFlareChange { +func (p *CloudFlareProvider) newCloudFlareChange(action string, endpoint *endpoint.Endpoint, target string) *cloudFlareChange { ttl := defaultCloudFlareRecordTTL - proxied := shouldBeProxied(endpoint, proxiedByDefault) + proxied := shouldBeProxied(endpoint, p.proxiedByDefault) if endpoint.RecordTTL.IsConfigured() { ttl = int(endpoint.RecordTTL) } - resourceRecordSet := make([]cloudflare.DNSRecord, len(endpoint.Targets)) + if len(endpoint.Targets) > 1 { + log.Errorf("Updates should have just one target") + } - for i := range endpoint.Targets { - resourceRecordSet[i] = cloudflare.DNSRecord{ + return &cloudFlareChange{ + Action: action, + ResourceRecord: cloudflare.DNSRecord{ Name: endpoint.DNSName, TTL: ttl, Proxied: proxied, Type: endpoint.RecordType, - Content: endpoint.Targets[i], - } - } - - return &cloudFlareChange{ - Action: action, - ResourceRecordSet: resourceRecordSet, + Content: target, + }, } } diff --git a/provider/cloudflare/cloudflare_test.go b/provider/cloudflare/cloudflare_test.go index 521cb3ba1..51e6f7ea6 100644 --- a/provider/cloudflare/cloudflare_test.go +++ b/provider/cloudflare/cloudflare_test.go @@ -57,6 +57,15 @@ var ExampleDomain = []cloudflare.DNSRecord{ Content: "1.2.3.4", Proxied: false, }, + { + ID: "2345678901", + ZoneID: "001", + Name: "foobar.bar.com", + Type: endpoint.RecordTypeA, + TTL: 120, + Content: "3.4.5.6", + Proxied: false, + }, { ID: "1231231233", ZoneID: "002", @@ -656,29 +665,51 @@ func TestCloudflareGetRecordID(t *testing.T) { p := &CloudFlareProvider{} records := []cloudflare.DNSRecord{ { - Name: "foo.com", - Type: endpoint.RecordTypeCNAME, - ID: "1", + Name: "foo.com", + Type: endpoint.RecordTypeCNAME, + Content: "foobar", + ID: "1", }, { Name: "bar.de", Type: endpoint.RecordTypeA, ID: "2", }, + { + Name: "bar.de", + Type: endpoint.RecordTypeA, + Content: "1.2.3.4", + ID: "2", + }, } - assert.Len(t, p.getRecordIDs(records, cloudflare.DNSRecord{ - Name: "foo.com", - Type: endpoint.RecordTypeA, - }), 0) - assert.Len(t, p.getRecordIDs(records, cloudflare.DNSRecord{ - Name: "bar.de", - Type: endpoint.RecordTypeA, - }), 1) - assert.Equal(t, "2", p.getRecordIDs(records, cloudflare.DNSRecord{ - Name: "bar.de", - Type: endpoint.RecordTypeA, - })[0]) + assert.Equal(t, "", p.getRecordID(records, cloudflare.DNSRecord{ + Name: "foo.com", + Type: endpoint.RecordTypeA, + Content: "foobar", + })) + + assert.Equal(t, "", p.getRecordID(records, cloudflare.DNSRecord{ + Name: "foo.com", + Type: endpoint.RecordTypeCNAME, + Content: "fizfuz", + })) + + assert.Equal(t, "1", p.getRecordID(records, cloudflare.DNSRecord{ + Name: "foo.com", + Type: endpoint.RecordTypeCNAME, + Content: "foobar", + })) + assert.Equal(t, "", p.getRecordID(records, cloudflare.DNSRecord{ + Name: "bar.de", + Type: endpoint.RecordTypeA, + Content: "2.3.4.5", + })) + assert.Equal(t, "2", p.getRecordID(records, cloudflare.DNSRecord{ + Name: "bar.de", + Type: endpoint.RecordTypeA, + Content: "1.2.3.4", + })) } func TestCloudflareGroupByNameAndType(t *testing.T) { @@ -948,23 +979,7 @@ func TestCloudflareComplexUpdate(t *testing.T) { } td.CmpDeeply(t, client.Actions, []MockAction{ - { - Name: "Delete", - ZoneId: "001", - RecordId: "1234567890", - }, - { - Name: "Create", - ZoneId: "001", - RecordData: cloudflare.DNSRecord{ - Name: "foobar.bar.com", - Type: "A", - Content: "1.2.3.4", - TTL: 1, - Proxied: true, - }, - }, - { + MockAction{ Name: "Create", ZoneId: "001", RecordData: cloudflare.DNSRecord{ @@ -975,5 +990,22 @@ func TestCloudflareComplexUpdate(t *testing.T) { Proxied: true, }, }, + MockAction{ + Name: "Update", + ZoneId: "001", + RecordId: "1234567890", + RecordData: cloudflare.DNSRecord{ + Name: "foobar.bar.com", + Type: "A", + Content: "1.2.3.4", + TTL: 1, + Proxied: true, + }, + }, + MockAction{ + Name: "Delete", + ZoneId: "001", + RecordId: "2345678901", + }, }) } diff --git a/provider/provider.go b/provider/provider.go index c16bab748..fbe8945f0 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -50,3 +50,26 @@ func EnsureTrailingDot(hostname string) string { return strings.TrimSuffix(hostname, ".") + "." } + +// Difference tells which entries need to be respectively +// added, removed, or left untouched for "current" to be transformed to "desired" +func Difference(current, desired []string) ([]string, []string, []string) { + add, remove, leave := []string{}, []string{}, []string{} + index := make(map[string]struct{}, len(current)) + for _, x := range current { + index[x] = struct{}{} + } + for _, x := range desired { + if _, found := index[x]; found { + leave = append(leave, x) + delete(index, x) + } else { + add = append(add, x) + delete(index, x) + } + } + for x := range index { + remove = append(remove, x) + } + return add, remove, leave +} diff --git a/provider/provider_test.go b/provider/provider_test.go index 2b10bc79c..57d48e70e 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -22,6 +22,8 @@ import ( "testing" log "github.com/sirupsen/logrus" + + "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { @@ -44,3 +46,12 @@ func TestEnsureTrailingDot(t *testing.T) { } } } + +func TestDifference(t *testing.T) { + current := []string{"foo", "bar"} + desired := []string{"bar", "baz"} + add, remove, leave := Difference(current, desired) + assert.Equal(t, add, []string{"baz"}) + assert.Equal(t, remove, []string{"foo"}) + assert.Equal(t, leave, []string{"bar"}) +}