diff --git a/provider/pihole/clientV6.go b/provider/pihole/clientV6.go index ebaaabd6a..0d220b03c 100644 --- a/provider/pihole/clientV6.go +++ b/provider/pihole/clientV6.go @@ -143,12 +143,13 @@ func isValidIPv6(ip string) bool { } func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) { - out := make([]*endpoint.Endpoint, 0) results, err := p.getConfigValue(ctx, rtype) if err != nil { return nil, err } + endpoints := make(map[string]*endpoint.Endpoint) + for _, rec := range results { recs := strings.FieldsFunc(rec, func(r rune) bool { return r == ' ' || r == ',' @@ -186,7 +187,18 @@ func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endp } } - out = append(out, endpoint.NewEndpointWithTTL(DNSName, rtype, Ttl, Target)) + ep := endpoint.NewEndpointWithTTL(DNSName, rtype, Ttl, Target) + + if oldEp, ok := endpoints[DNSName]; ok { + ep.Targets = append(oldEp.Targets, Target) + } + + endpoints[DNSName] = ep + } + + out := make([]*endpoint.Endpoint, 0, len(endpoints)) + for _, ep := range endpoints { + out = append(out, ep) } return out, nil } @@ -272,37 +284,44 @@ func (p *piholeClientV6) apply(ctx context.Context, action string, ep *endpoint. return nil } - if p.cfg.DryRun { - log.Infof("DRY RUN: %s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, ep.Targets[0]) - return nil - } - - log.Infof("%s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, ep.Targets[0]) - // Get the current record if strings.Contains(ep.DNSName, "*") { return provider.NewSoftError(errors.New("UNSUPPORTED: Pihole DNS names cannot return wildcard")) } - switch ep.RecordType { - case endpoint.RecordTypeA, endpoint.RecordTypeAAAA: - apiUrl = p.generateApiUrl(apiUrl, fmt.Sprintf("%s %s", ep.Targets, ep.DNSName)) - case endpoint.RecordTypeCNAME: - if ep.RecordTTL.IsConfigured() { - apiUrl = p.generateApiUrl(apiUrl, fmt.Sprintf("%s,%s,%d", ep.DNSName, ep.Targets, ep.RecordTTL)) - } else { - apiUrl = p.generateApiUrl(apiUrl, fmt.Sprintf("%s,%s", ep.DNSName, ep.Targets)) + if ep.RecordType == endpoint.RecordTypeCNAME && len(ep.Targets) > 1 { + return provider.NewSoftError(errors.New("UNSUPPORTED: Pihole CNAME records cannot have multiple targets")) + } + + for _, target := range ep.Targets { + if p.cfg.DryRun { + log.Infof("DRY RUN: %s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, target) + continue } - } - req, err := http.NewRequestWithContext(ctx, action, apiUrl, nil) - if err != nil { - return err - } + log.Infof("%s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, target) - _, err = p.do(req) - if err != nil { - return err + targetApiUrl := apiUrl + + switch ep.RecordType { + case endpoint.RecordTypeA, endpoint.RecordTypeAAAA: + targetApiUrl = p.generateApiUrl(targetApiUrl, fmt.Sprintf("%s %s", target, ep.DNSName)) + case endpoint.RecordTypeCNAME: + if ep.RecordTTL.IsConfigured() { + targetApiUrl = p.generateApiUrl(targetApiUrl, fmt.Sprintf("%s,%s,%d", ep.DNSName, target, ep.RecordTTL)) + } else { + targetApiUrl = p.generateApiUrl(targetApiUrl, fmt.Sprintf("%s,%s", ep.DNSName, target)) + } + } + req, err := http.NewRequestWithContext(ctx, action, targetApiUrl, nil) + if err != nil { + return err + } + + _, err = p.do(req) + if err != nil { + return err + } } return nil @@ -400,6 +419,14 @@ func (p *piholeClientV6) do(req *http.Request) ([]byte, error) { if err := json.Unmarshal(jRes, &apiError); err != nil { return nil, fmt.Errorf("failed to unmarshal error response: %w", err) } + // Ignore if the entry already exists when adding a record + if strings.Contains(apiError.Error.Message, "Item already present") { + return jRes, nil + } + // Ignore if the entry does not exist when deleting a record + if res.StatusCode == http.StatusNotFound && req.Method == http.MethodDelete { + return jRes, nil + } if log.IsLevelEnabled(log.DebugLevel) { log.Debugf("Error on request %s", req.URL) if req.Body != nil { diff --git a/provider/pihole/clientV6_test.go b/provider/pihole/clientV6_test.go index 255032dd0..d474f8eba 100644 --- a/provider/pihole/clientV6_test.go +++ b/provider/pihole/clientV6_test.go @@ -23,10 +23,10 @@ import ( "fmt" "net/http" "net/http/httptest" - "strconv" "strings" "testing" + "github.com/google/go-cmp/cmp" "sigs.k8s.io/external-dns/endpoint" ) @@ -192,10 +192,14 @@ func TestListRecordsV6(t *testing.T) { "192.168.178.33 service1.example.com", "192.168.178.34 service2.example.com", "192.168.178.34 service3.example.com", + "192.168.178.35 service8.example.com", + "192.168.178.36 service8.example.com", "fc00::1:192:168:1:1 service4.example.com", "fc00::1:192:168:1:2 service5.example.com", "fc00::1:192:168:1:3 service6.example.com", "::ffff:192.168.20.3 service7.example.com", + "fc00::1:192:168:1:4 service9.example.com", + "fc00::1:192:168:1:5 service9.example.com", "192.168.20.3 service7.example.com" ] } @@ -237,37 +241,70 @@ func TestListRecordsV6(t *testing.T) { } // Ensure A records were parsed correctly - expected := [][]string{ - {"service1.example.com", "192.168.178.33"}, - {"service2.example.com", "192.168.178.34"}, - {"service3.example.com", "192.168.178.34"}, - {"service7.example.com", "192.168.20.3"}, + expected := []*endpoint.Endpoint{ + { + DNSName: "service1.example.com", + Targets: []string{"192.168.178.33"}, + }, + { + DNSName: "service2.example.com", + Targets: []string{"192.168.178.34"}, + }, + { + DNSName: "service3.example.com", + Targets: []string{"192.168.178.34"}, + }, + { + DNSName: "service7.example.com", + Targets: []string{"192.168.20.3"}, + }, + { + DNSName: "service8.example.com", + Targets: []string{"192.168.178.35", "192.168.178.36"}, + }, } // Test retrieve A records unfiltered arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA) if err != nil { t.Fatal(err) } - if len(arecs) != len(expected) { - t.Fatalf("Expected %d A records returned, got: %d", len(expected), len(arecs)) - } - for idx, rec := range arecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) + expectedMap := make(map[string]*endpoint.Endpoint) + for _, ep := range expected { + expectedMap[ep.DNSName] = ep + } + for _, rec := range arecs { + if ep, ok := expectedMap[rec.DNSName]; ok { + if cmp.Diff(ep.Targets, rec.Targets) != "" { + t.Errorf("Got invalid targets for %s: %v, expected: %v", rec.DNSName, rec.Targets, ep.Targets) + } } } // Ensure AAAA records were parsed correctly - expected = [][]string{ - {"service4.example.com", "fc00::1:192:168:1:1"}, - {"service5.example.com", "fc00::1:192:168:1:2"}, - {"service6.example.com", "fc00::1:192:168:1:3"}, - {"service7.example.com", "::ffff:192.168.20.3"}, + expected = []*endpoint.Endpoint{ + { + DNSName: "service4.example.com", + Targets: []string{"fc00::1:192:168:1:1"}, + }, + { + DNSName: "service5.example.com", + Targets: []string{"fc00::1:192:168:1:2"}, + }, + { + DNSName: "service6.example.com", + Targets: []string{"fc00::1:192:168:1:3"}, + }, + { + DNSName: "service7.example.com", + Targets: []string{"::ffff:192.168.20.3"}, + }, + { + DNSName: "service9.example.com", + Targets: []string{"fc00::1:192:168:1:4", "fc00::1:192:168:1:5"}, + }, } + // Test retrieve AAAA records unfiltered arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA) if err != nil { @@ -278,20 +315,34 @@ func TestListRecordsV6(t *testing.T) { t.Fatalf("Expected %d AAAA records returned, got: %d", len(expected), len(arecs)) } - for idx, rec := range arecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) + expectedMap = make(map[string]*endpoint.Endpoint) + for _, ep := range expected { + expectedMap[ep.DNSName] = ep + } + for _, rec := range arecs { + if ep, ok := expectedMap[rec.DNSName]; ok { + if cmp.Diff(ep.Targets, rec.Targets) != "" { + t.Errorf("Got invalid targets for %s: %v, expected: %v", rec.DNSName, rec.Targets, ep.Targets) + } } } // Ensure CNAME records were parsed correctly - expected = [][]string{ - {"source1.example.com", "target1.domain.com", "1000"}, - {"source2.example.com", "target2.domain.com", "50"}, - {"source3.example.com", "target3.domain.com"}, + expected = []*endpoint.Endpoint{ + { + DNSName: "source1.example.com", + Targets: []string{"target1.domain.com"}, + RecordTTL: 1000, + }, + { + DNSName: "source2.example.com", + Targets: []string{"target2.domain.com"}, + RecordTTL: 50, + }, + { + DNSName: "source3.example.com", + Targets: []string{"target3.domain.com"}, + }, } // Test retrieve CNAME records unfiltered @@ -303,17 +354,14 @@ func TestListRecordsV6(t *testing.T) { t.Fatalf("Expected %d CAME records returned, got: %d", len(expected), len(cnamerecs)) } - for idx, rec := range cnamerecs { - if rec.DNSName != expected[idx][0] { - t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) - } - if rec.Targets[0] != expected[idx][1] { - t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) - } - if len(expected[idx]) == 3 { - expectedTTL, _ := strconv.ParseInt(expected[idx][2], 10, 64) - if int64(rec.RecordTTL) != expectedTTL { - t.Error("Got invalid TTL:", rec.RecordTTL, "expected:", expected[idx][2]) + expectedMap = make(map[string]*endpoint.Endpoint) + for _, ep := range expected { + expectedMap[ep.DNSName] = ep + } + for _, rec := range arecs { + if ep, ok := expectedMap[rec.DNSName]; ok { + if cmp.Diff(ep.Targets, rec.Targets) != "" { + t.Errorf("Got invalid targets for %s: %v, expected: %v", rec.DNSName, rec.Targets, ep.Targets) } } } @@ -432,8 +480,34 @@ func TestErrorsV6(t *testing.T) { if len(resp) != 2 { t.Fatal("Expected one records returned, got:", len(resp)) } - if resp[1].RecordTTL != 0 { - t.Fatal("Expected no TTL returned, got:", resp[0].RecordTTL) + + expected := []*endpoint.Endpoint{ + { + DNSName: "source1.example.com", + Targets: []string{"target1.domain.com"}, + RecordTTL: 100, + }, + { + DNSName: "source2.example.com", + Targets: []string{"target2.domain.com"}, + }, + } + + expectedMap := make(map[string]*endpoint.Endpoint) + for _, ep := range expected { + expectedMap[ep.DNSName] = ep + } + for _, rec := range resp { + if ep, ok := expectedMap[rec.DNSName]; ok { + if cmp.Diff(ep.Targets, rec.Targets) != "" { + t.Errorf("Got invalid targets for %s: %v, expected: %v", rec.DNSName, rec.Targets, ep.Targets) + } + if ep.RecordTTL != rec.RecordTTL { + t.Errorf("Got invalid TTL for %s: %d, expected: %d", rec.DNSName, rec.RecordTTL, ep.RecordTTL) + } + } else { + t.Errorf("Unexpected record found: %s", rec.DNSName) + } } } @@ -717,6 +791,10 @@ func TestCreateRecordV6(t *testing.T) { if r.Method == http.MethodPut && (r.URL.Path == "/api/config/dns/hosts/192.168.1.1 test.example.com" || r.URL.Path == "/api/config/dns/hosts/fc00::1:192:168:1:1 test.example.com" || r.URL.Path == "/api/config/dns/cnameRecords/source1.example.com,target1.domain.com" || + r.URL.Path == "/api/config/dns/hosts/192.168.1.2 test.example.com" || + r.URL.Path == "/api/config/dns/hosts/192.168.1.3 test.example.com" || + r.URL.Path == "/api/config/dns/hosts/fc00::1:192:168:1:2 test.example.com" || + r.URL.Path == "/api/config/dns/hosts/fc00::1:192:168:1:3 test.example.com" || r.URL.Path == "/api/config/dns/cnameRecords/source2.example.com,target2.domain.com,500") { // Return A records @@ -748,6 +826,16 @@ func TestCreateRecordV6(t *testing.T) { t.Fatal(err) } + // Test create multiple A records + ep = &endpoint.Endpoint{ + DNSName: "test.example.com", + Targets: []string{"192.168.1.2", "192.168.1.3"}, + RecordType: endpoint.RecordTypeA, + } + if err := cl.createRecord(context.Background(), ep); err != nil { + t.Fatal(err) + } + // Test create AAAA record ep = &endpoint.Endpoint{ DNSName: "test.example.com", @@ -758,6 +846,16 @@ func TestCreateRecordV6(t *testing.T) { t.Fatal(err) } + // Test create multiple AAAA records + ep = &endpoint.Endpoint{ + DNSName: "test.example.com", + Targets: []string{"fc00::1:192:168:1:2", "fc00::1:192:168:1:3"}, + RecordType: endpoint.RecordTypeAAAA, + } + if err := cl.createRecord(context.Background(), ep); err != nil { + t.Fatal(err) + } + // Test create CNAME record ep = &endpoint.Endpoint{ DNSName: "source1.example.com", @@ -779,6 +877,16 @@ func TestCreateRecordV6(t *testing.T) { t.Fatal(err) } + // Test create CNAME record with multiple targets and ensure it fails + ep = &endpoint.Endpoint{ + DNSName: "source3.example.com", + Targets: []string{"target3.domain.com", "target4.domain.com"}, + RecordType: endpoint.RecordTypeCNAME, + } + if err := cl.createRecord(context.Background(), ep); err == nil { + t.Fatal(err) + } + // Test create a wildcard record and ensure it fails ep = &endpoint.Endpoint{ DNSName: "*.example.com", diff --git a/provider/pihole/pihole.go b/provider/pihole/pihole.go index 53a7bf2af..cf401f374 100644 --- a/provider/pihole/pihole.go +++ b/provider/pihole/pihole.go @@ -19,6 +19,9 @@ package pihole import ( "context" "errors" + "slices" + + "github.com/google/go-cmp/cmp" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" @@ -32,7 +35,8 @@ var ErrNoPiholeServer = errors.New("no pihole server found in the environment or // PiholeProvider is an implementation of Provider for Pi-hole Local DNS. type PiholeProvider struct { provider.BaseProvider - api piholeAPI + api piholeAPI + apiVersion string } // PiholeConfig is used for configuring a PiholeProvider. @@ -70,7 +74,7 @@ func NewPiholeProvider(cfg PiholeConfig) (*PiholeProvider, error) { if err != nil { return nil, err } - return &PiholeProvider{api: api}, nil + return &PiholeProvider{api: api, apiVersion: cfg.APIVersion}, nil } // Records implements Provider, populating a slice of endpoints from @@ -105,6 +109,19 @@ func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes updateNew := make(map[piholeEntryKey]*endpoint.Endpoint) for _, ep := range changes.UpdateNew { key := piholeEntryKey{ep.DNSName, ep.RecordType} + + // If the API version is 6, we need to handle multiple targets for the same DNS name. + if p.apiVersion == "6" { + if existing, ok := updateNew[key]; ok { + existing.Targets = append(existing.Targets, ep.Targets...) + + // Deduplicate targets + slices.Sort(existing.Targets) + existing.Targets = slices.Compact(existing.Targets) + + ep = existing + } + } updateNew[key] = ep } @@ -112,14 +129,23 @@ func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes // Check if this existing entry has an exact match for an updated entry and skip it if so. key := piholeEntryKey{ep.DNSName, ep.RecordType} if newRecord := updateNew[key]; newRecord != nil { - // PiHole only has a single target; no need to compare other fields. - if newRecord.Targets[0] == ep.Targets[0] { - delete(updateNew, key) - continue + // If the API version is 6, we need to handle multiple targets for the same DNS name. + if p.apiVersion == "6" { + if cmp.Diff(ep.Targets, newRecord.Targets) == "" { + delete(updateNew, key) + continue + } + } else { + // For API version <= 5, we only check the first target. + if newRecord.Targets[0] == ep.Targets[0] { + delete(updateNew, key) + continue + } + } + + if err := p.api.deleteRecord(ctx, ep); err != nil { + return err } - } - if err := p.api.deleteRecord(ctx, ep); err != nil { - return err } } diff --git a/provider/pihole/piholeV6_test.go b/provider/pihole/piholeV6_test.go index 22e0a2005..b14f77bc0 100644 --- a/provider/pihole/piholeV6_test.go +++ b/provider/pihole/piholeV6_test.go @@ -22,6 +22,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" ) @@ -60,7 +61,7 @@ func (t *testPiholeClientV6) createRecord(_ context.Context, ep *endpoint.Endpoi func (t *testPiholeClientV6) deleteRecord(_ context.Context, ep *endpoint.Endpoint) error { newEPs := make([]*endpoint.Endpoint, 0) for _, existing := range t.endpoints { - if existing.DNSName != ep.DNSName && existing.Targets[0] != ep.Targets[0] { + if existing.DNSName != ep.DNSName || cmp.Diff(existing.Targets, ep.Targets) != "" || existing.RecordType != ep.RecordType { newEPs = append(newEPs, existing) } } @@ -82,7 +83,8 @@ func (r *requestTrackerV6) clear() { func TestErrorHandling(t *testing.T) { requests := requestTrackerV6{} p := &PiholeProvider{ - api: &testPiholeClientV6{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + api: &testPiholeClientV6{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + apiVersion: "6", } p.api.(*testPiholeClientV6).trigger = "AERROR" @@ -121,7 +123,8 @@ func TestNewPiholeProviderV6(t *testing.T) { func TestProviderV6(t *testing.T) { requests := requestTrackerV6{} p := &PiholeProvider{ - api: &testPiholeClientV6{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + api: &testPiholeClientV6{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + apiVersion: "6", } records, err := p.Records(context.Background()) @@ -342,6 +345,11 @@ func TestProviderV6(t *testing.T) { Targets: []string{"10.0.0.1"}, RecordType: endpoint.RecordTypeA, }, + { + DNSName: "test2.example.com", + Targets: []string{"10.0.0.2"}, + RecordType: endpoint.RecordTypeA, + }, { DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, @@ -383,7 +391,7 @@ func TestProviderV6(t *testing.T) { expectedCreateA := endpoint.Endpoint{ DNSName: "test2.example.com", - Targets: []string{"10.0.0.1"}, + Targets: []string{"10.0.0.1", "10.0.0.2"}, RecordType: endpoint.RecordTypeA, } expectedDeleteA := endpoint.Endpoint{