Dedupe on DNS name and record type

This commit is contained in:
Scott Fleener 2023-01-08 20:48:15 -05:00
parent 27eb7c9ea0
commit e264fdecff
2 changed files with 64 additions and 28 deletions

View File

@ -49,6 +49,12 @@ type PiholeConfig struct {
DryRun bool
}
// Helper struct for de-duping DNS entry updates.
type piholeEntryKey struct {
Target string
RecordType string
}
// NewPiholeProvider initializes a new Pi-hole Local DNS based Provider.
func NewPiholeProvider(cfg PiholeConfig) (*PiholeProvider, error) {
api, err := newPiholeClient(cfg)
@ -82,17 +88,19 @@ func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
}
// Handle updated state - there are no endpoints for updating in place.
updateNew := make(map[string]*endpoint.Endpoint)
updateNew := make(map[piholeEntryKey]*endpoint.Endpoint)
for _, ep := range changes.UpdateNew {
updateNew[ep.DNSName] = ep
key := piholeEntryKey{ep.DNSName, ep.RecordType}
updateNew[key] = ep
}
for _, ep := range changes.UpdateOld {
// Check if this existing entry has an exact match for an updated entry, and skip it if so.
if newRecord := updateNew[ep.DNSName]; newRecord != nil {
// PiHole only has a single target and a record type, no need to compare other fields.
if newRecord.Targets.String() == ep.Targets.String() && newRecord.RecordType == ep.RecordType {
delete(updateNew, ep.DNSName)
// Check if this existing entry has an exact match for an updated entry and skip it if so.
key := piholeEntryKey{ep.DNSName, ep.RecordType}
if newRecord := updateNew[key]; newRecord != nil {
// PiHole only has a single target; no need to compare other fields.
if newRecord.Targets[0] == ep.Targets[0] {
delete(updateNew, key)
continue
}
}

View File

@ -18,6 +18,7 @@ package pihole
import (
"context"
"reflect"
"testing"
"sigs.k8s.io/external-dns/endpoint"
@ -41,7 +42,7 @@ func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*en
func (t *testPiholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error {
t.endpoints = append(t.endpoints, ep)
t.requests.createRequests += 1
t.requests.createRequests = append(t.requests.createRequests, ep)
return nil
}
@ -53,18 +54,18 @@ func (t *testPiholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoi
}
}
t.endpoints = newEPs
t.requests.deleteRequests += 1
t.requests.deleteRequests = append(t.requests.deleteRequests, ep)
return nil
}
type requestTracker struct {
createRequests int
deleteRequests int
createRequests []*endpoint.Endpoint
deleteRequests []*endpoint.Endpoint
}
func (r *requestTracker) clear() {
r.createRequests = 0
r.deleteRequests = 0
r.createRequests = nil
r.deleteRequests = nil
}
func TestNewPiholeProvider(t *testing.T) {
@ -127,10 +128,10 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 3 {
t.Fatal("Expected list of 3 records, got:", records)
}
if requests.createRequests != 3 {
if len(requests.createRequests) != 3 {
t.Fatal("Expected 3 create requests, got:", requests.createRequests)
}
if requests.deleteRequests != 0 {
if len(requests.deleteRequests) != 0 {
t.Fatal("Expected no delete requests, got:", requests.deleteRequests)
}
@ -141,6 +142,10 @@ func TestProvider(t *testing.T) {
if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
}
if !reflect.DeepEqual(requests.createRequests[idx], record) {
t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
}
}
requests.clear()
@ -159,23 +164,18 @@ func TestProvider(t *testing.T) {
RecordType: endpoint.RecordTypeA,
},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Delete: []*endpoint.Endpoint{
{
recordToDelete := endpoint.Endpoint{
DNSName: "test3.example.com",
Targets: []string{"192.168.1.3"},
RecordType: endpoint.RecordTypeA,
},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Delete: []*endpoint.Endpoint{
&recordToDelete,
},
}); err != nil {
t.Fatal(err)
}
if requests.createRequests != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
if requests.deleteRequests != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}
// Test records are updated
newRecords, err = p.Records(context.Background())
@ -185,6 +185,12 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 2 {
t.Fatal("Expected list of 2 records, got:", records)
}
if len(requests.createRequests) != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
if len(requests.deleteRequests) != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}
for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
@ -195,6 +201,10 @@ func TestProvider(t *testing.T) {
}
}
if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDelete) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDelete)
}
requests.clear()
// Test update a record
@ -248,10 +258,10 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 2 {
t.Fatal("Expected list of 2 records, got:", records)
}
if requests.createRequests != 1 {
if len(requests.createRequests) != 1 {
t.Fatal("Expected 1 create request, got:", requests.createRequests)
}
if requests.deleteRequests != 1 {
if len(requests.deleteRequests) != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}
@ -264,5 +274,23 @@ func TestProvider(t *testing.T) {
}
}
expectedCreate := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
RecordType: endpoint.RecordTypeA,
}
expectedDelete := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
}
if !reflect.DeepEqual(requests.createRequests[0], &expectedCreate) {
t.Error("Unexpected create request, got:", requests.createRequests[0], "expected:", &expectedCreate)
}
if !reflect.DeepEqual(requests.deleteRequests[0], &expectedDelete) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", &expectedDelete)
}
requests.clear()
}