From e264fdecff7cefa1c58eb9bad5e21937f773d3e2 Mon Sep 17 00:00:00 2001 From: Scott Fleener Date: Sun, 8 Jan 2023 20:48:15 -0500 Subject: [PATCH] Dedupe on DNS name and record type --- provider/pihole/pihole.go | 22 +++++++---- provider/pihole/pihole_test.go | 70 ++++++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 28 deletions(-) diff --git a/provider/pihole/pihole.go b/provider/pihole/pihole.go index 0aabb7e5b..6b1f352d9 100644 --- a/provider/pihole/pihole.go +++ b/provider/pihole/pihole.go @@ -49,6 +49,12 @@ type PiholeConfig struct { DryRun bool } +// Helper struct for de-duping DNS entry updates. +type piholeEntryKey struct { + Target string + RecordType string +} + // NewPiholeProvider initializes a new Pi-hole Local DNS based Provider. func NewPiholeProvider(cfg PiholeConfig) (*PiholeProvider, error) { api, err := newPiholeClient(cfg) @@ -82,17 +88,19 @@ func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes } // Handle updated state - there are no endpoints for updating in place. - updateNew := make(map[string]*endpoint.Endpoint) + updateNew := make(map[piholeEntryKey]*endpoint.Endpoint) for _, ep := range changes.UpdateNew { - updateNew[ep.DNSName] = ep + key := piholeEntryKey{ep.DNSName, ep.RecordType} + updateNew[key] = ep } for _, ep := range changes.UpdateOld { - // Check if this existing entry has an exact match for an updated entry, and skip it if so. - if newRecord := updateNew[ep.DNSName]; newRecord != nil { - // PiHole only has a single target and a record type, no need to compare other fields. - if newRecord.Targets.String() == ep.Targets.String() && newRecord.RecordType == ep.RecordType { - delete(updateNew, ep.DNSName) + // Check if this existing entry has an exact match for an updated entry and skip it if so. + key := piholeEntryKey{ep.DNSName, ep.RecordType} + if newRecord := updateNew[key]; newRecord != nil { + // PiHole only has a single target; no need to compare other fields. + if newRecord.Targets[0] == ep.Targets[0] { + delete(updateNew, key) continue } } diff --git a/provider/pihole/pihole_test.go b/provider/pihole/pihole_test.go index 99e1fb92e..bd19196af 100644 --- a/provider/pihole/pihole_test.go +++ b/provider/pihole/pihole_test.go @@ -18,6 +18,7 @@ package pihole import ( "context" + "reflect" "testing" "sigs.k8s.io/external-dns/endpoint" @@ -41,7 +42,7 @@ func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*en func (t *testPiholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error { t.endpoints = append(t.endpoints, ep) - t.requests.createRequests += 1 + t.requests.createRequests = append(t.requests.createRequests, ep) return nil } @@ -53,18 +54,18 @@ func (t *testPiholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoi } } t.endpoints = newEPs - t.requests.deleteRequests += 1 + t.requests.deleteRequests = append(t.requests.deleteRequests, ep) return nil } type requestTracker struct { - createRequests int - deleteRequests int + createRequests []*endpoint.Endpoint + deleteRequests []*endpoint.Endpoint } func (r *requestTracker) clear() { - r.createRequests = 0 - r.deleteRequests = 0 + r.createRequests = nil + r.deleteRequests = nil } func TestNewPiholeProvider(t *testing.T) { @@ -127,10 +128,10 @@ func TestProvider(t *testing.T) { if len(newRecords) != 3 { t.Fatal("Expected list of 3 records, got:", records) } - if requests.createRequests != 3 { + if len(requests.createRequests) != 3 { t.Fatal("Expected 3 create requests, got:", requests.createRequests) } - if requests.deleteRequests != 0 { + if len(requests.deleteRequests) != 0 { t.Fatal("Expected no delete requests, got:", requests.deleteRequests) } @@ -141,6 +142,10 @@ func TestProvider(t *testing.T) { if newRecords[idx].Targets[0] != record.Targets[0] { t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) } + + if !reflect.DeepEqual(requests.createRequests[idx], record) { + t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName) + } } requests.clear() @@ -159,23 +164,18 @@ func TestProvider(t *testing.T) { RecordType: endpoint.RecordTypeA, }, } + recordToDelete := endpoint.Endpoint{ + DNSName: "test3.example.com", + Targets: []string{"192.168.1.3"}, + RecordType: endpoint.RecordTypeA, + } if err := p.ApplyChanges(context.Background(), &plan.Changes{ Delete: []*endpoint.Endpoint{ - { - DNSName: "test3.example.com", - Targets: []string{"192.168.1.3"}, - RecordType: endpoint.RecordTypeA, - }, + &recordToDelete, }, }); err != nil { t.Fatal(err) } - if requests.createRequests != 0 { - t.Fatal("Expected no create requests, got:", requests.createRequests) - } - if requests.deleteRequests != 1 { - t.Fatal("Expected 1 delete request, got:", requests.deleteRequests) - } // Test records are updated newRecords, err = p.Records(context.Background()) @@ -185,6 +185,12 @@ func TestProvider(t *testing.T) { if len(newRecords) != 2 { t.Fatal("Expected list of 2 records, got:", records) } + if len(requests.createRequests) != 0 { + t.Fatal("Expected no create requests, got:", requests.createRequests) + } + if len(requests.deleteRequests) != 1 { + t.Fatal("Expected 1 delete request, got:", requests.deleteRequests) + } for idx, record := range records { if newRecords[idx].DNSName != record.DNSName { @@ -195,6 +201,10 @@ func TestProvider(t *testing.T) { } } + if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDelete) { + t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDelete) + } + requests.clear() // Test update a record @@ -248,10 +258,10 @@ func TestProvider(t *testing.T) { if len(newRecords) != 2 { t.Fatal("Expected list of 2 records, got:", records) } - if requests.createRequests != 1 { + if len(requests.createRequests) != 1 { t.Fatal("Expected 1 create request, got:", requests.createRequests) } - if requests.deleteRequests != 1 { + if len(requests.deleteRequests) != 1 { t.Fatal("Expected 1 delete request, got:", requests.deleteRequests) } @@ -264,5 +274,23 @@ func TestProvider(t *testing.T) { } } + expectedCreate := endpoint.Endpoint{ + DNSName: "test2.example.com", + Targets: []string{"10.0.0.1"}, + RecordType: endpoint.RecordTypeA, + } + expectedDelete := endpoint.Endpoint{ + DNSName: "test2.example.com", + Targets: []string{"192.168.1.2"}, + RecordType: endpoint.RecordTypeA, + } + + if !reflect.DeepEqual(requests.createRequests[0], &expectedCreate) { + t.Error("Unexpected create request, got:", requests.createRequests[0], "expected:", &expectedCreate) + } + if !reflect.DeepEqual(requests.deleteRequests[0], &expectedDelete) { + t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", &expectedDelete) + } + requests.clear() }