Dedupe on DNS name and record type

This commit is contained in:
Scott Fleener 2023-01-08 20:48:15 -05:00
parent 27eb7c9ea0
commit e264fdecff
2 changed files with 64 additions and 28 deletions

View File

@ -49,6 +49,12 @@ type PiholeConfig struct {
DryRun bool DryRun bool
} }
// Helper struct for de-duping DNS entry updates.
type piholeEntryKey struct {
Target string
RecordType string
}
// NewPiholeProvider initializes a new Pi-hole Local DNS based Provider. // NewPiholeProvider initializes a new Pi-hole Local DNS based Provider.
func NewPiholeProvider(cfg PiholeConfig) (*PiholeProvider, error) { func NewPiholeProvider(cfg PiholeConfig) (*PiholeProvider, error) {
api, err := newPiholeClient(cfg) api, err := newPiholeClient(cfg)
@ -82,17 +88,19 @@ func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
} }
// Handle updated state - there are no endpoints for updating in place. // Handle updated state - there are no endpoints for updating in place.
updateNew := make(map[string]*endpoint.Endpoint) updateNew := make(map[piholeEntryKey]*endpoint.Endpoint)
for _, ep := range changes.UpdateNew { for _, ep := range changes.UpdateNew {
updateNew[ep.DNSName] = ep key := piholeEntryKey{ep.DNSName, ep.RecordType}
updateNew[key] = ep
} }
for _, ep := range changes.UpdateOld { for _, ep := range changes.UpdateOld {
// Check if this existing entry has an exact match for an updated entry, and skip it if so. // Check if this existing entry has an exact match for an updated entry and skip it if so.
if newRecord := updateNew[ep.DNSName]; newRecord != nil { key := piholeEntryKey{ep.DNSName, ep.RecordType}
// PiHole only has a single target and a record type, no need to compare other fields. if newRecord := updateNew[key]; newRecord != nil {
if newRecord.Targets.String() == ep.Targets.String() && newRecord.RecordType == ep.RecordType { // PiHole only has a single target; no need to compare other fields.
delete(updateNew, ep.DNSName) if newRecord.Targets[0] == ep.Targets[0] {
delete(updateNew, key)
continue continue
} }
} }

View File

@ -18,6 +18,7 @@ package pihole
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/endpoint"
@ -41,7 +42,7 @@ func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*en
func (t *testPiholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error { func (t *testPiholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error {
t.endpoints = append(t.endpoints, ep) t.endpoints = append(t.endpoints, ep)
t.requests.createRequests += 1 t.requests.createRequests = append(t.requests.createRequests, ep)
return nil return nil
} }
@ -53,18 +54,18 @@ func (t *testPiholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoi
} }
} }
t.endpoints = newEPs t.endpoints = newEPs
t.requests.deleteRequests += 1 t.requests.deleteRequests = append(t.requests.deleteRequests, ep)
return nil return nil
} }
type requestTracker struct { type requestTracker struct {
createRequests int createRequests []*endpoint.Endpoint
deleteRequests int deleteRequests []*endpoint.Endpoint
} }
func (r *requestTracker) clear() { func (r *requestTracker) clear() {
r.createRequests = 0 r.createRequests = nil
r.deleteRequests = 0 r.deleteRequests = nil
} }
func TestNewPiholeProvider(t *testing.T) { func TestNewPiholeProvider(t *testing.T) {
@ -127,10 +128,10 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 3 { if len(newRecords) != 3 {
t.Fatal("Expected list of 3 records, got:", records) t.Fatal("Expected list of 3 records, got:", records)
} }
if requests.createRequests != 3 { if len(requests.createRequests) != 3 {
t.Fatal("Expected 3 create requests, got:", requests.createRequests) t.Fatal("Expected 3 create requests, got:", requests.createRequests)
} }
if requests.deleteRequests != 0 { if len(requests.deleteRequests) != 0 {
t.Fatal("Expected no delete requests, got:", requests.deleteRequests) t.Fatal("Expected no delete requests, got:", requests.deleteRequests)
} }
@ -141,6 +142,10 @@ func TestProvider(t *testing.T) {
if newRecords[idx].Targets[0] != record.Targets[0] { if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
} }
if !reflect.DeepEqual(requests.createRequests[idx], record) {
t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
}
} }
requests.clear() requests.clear()
@ -159,23 +164,18 @@ func TestProvider(t *testing.T) {
RecordType: endpoint.RecordTypeA, RecordType: endpoint.RecordTypeA,
}, },
} }
if err := p.ApplyChanges(context.Background(), &plan.Changes{ recordToDelete := endpoint.Endpoint{
Delete: []*endpoint.Endpoint{
{
DNSName: "test3.example.com", DNSName: "test3.example.com",
Targets: []string{"192.168.1.3"}, Targets: []string{"192.168.1.3"},
RecordType: endpoint.RecordTypeA, RecordType: endpoint.RecordTypeA,
}, }
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Delete: []*endpoint.Endpoint{
&recordToDelete,
}, },
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if requests.createRequests != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
if requests.deleteRequests != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}
// Test records are updated // Test records are updated
newRecords, err = p.Records(context.Background()) newRecords, err = p.Records(context.Background())
@ -185,6 +185,12 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 2 { if len(newRecords) != 2 {
t.Fatal("Expected list of 2 records, got:", records) t.Fatal("Expected list of 2 records, got:", records)
} }
if len(requests.createRequests) != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
if len(requests.deleteRequests) != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}
for idx, record := range records { for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName { if newRecords[idx].DNSName != record.DNSName {
@ -195,6 +201,10 @@ func TestProvider(t *testing.T) {
} }
} }
if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDelete) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDelete)
}
requests.clear() requests.clear()
// Test update a record // Test update a record
@ -248,10 +258,10 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 2 { if len(newRecords) != 2 {
t.Fatal("Expected list of 2 records, got:", records) t.Fatal("Expected list of 2 records, got:", records)
} }
if requests.createRequests != 1 { if len(requests.createRequests) != 1 {
t.Fatal("Expected 1 create request, got:", requests.createRequests) t.Fatal("Expected 1 create request, got:", requests.createRequests)
} }
if requests.deleteRequests != 1 { if len(requests.deleteRequests) != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests) t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
} }
@ -264,5 +274,23 @@ func TestProvider(t *testing.T) {
} }
} }
expectedCreate := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
RecordType: endpoint.RecordTypeA,
}
expectedDelete := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
}
if !reflect.DeepEqual(requests.createRequests[0], &expectedCreate) {
t.Error("Unexpected create request, got:", requests.createRequests[0], "expected:", &expectedCreate)
}
if !reflect.DeepEqual(requests.deleteRequests[0], &expectedDelete) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", &expectedDelete)
}
requests.clear() requests.clear()
} }