From e22ceab66f8339b4ace2cd4f9b63a53ea7a779ea Mon Sep 17 00:00:00 2001 From: Andrew Hay <39sumer3939@gmail.com> Date: Wed, 1 Oct 2025 03:50:23 -0400 Subject: [PATCH] refactor(pihole): reduce cyclomatic complexity of TestProviderV6 (#5876) * refactor(pihole): reduce cyclomatic complexity of TestProviderV6 * chore(pihole): increase coverage * style: linting * style: linting * fix: remove coverage html --- .gitignore | 1 + .golangci.yml | 2 +- provider/pihole/clientV6_test.go | 105 ++++++- provider/pihole/client_test.go | 63 ++-- provider/pihole/piholeV6_test.go | 483 ++++++++++++------------------- 5 files changed, 318 insertions(+), 336 deletions(-) diff --git a/.gitignore b/.gitignore index 29561b906..47232c849 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ cscope.* # coverage output cover.out +coverage.html *.coverprofile external-dns diff --git a/.golangci.yml b/.golangci.yml index 424196b17..508e4474d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -41,7 +41,7 @@ linters: - name: confusing-naming disabled: true cyclop: # Lower cyclomatic complexity threshold after the max complexity is lowered - max-complexity: 43 + max-complexity: 37 # Controller/execute.go:147:1: calculated cyclomatic complexity for function buildProvider is 37 testifylint: # Enable all checkers (https://github.com/Antonboom/testifylint#checkers). # Default: false diff --git a/provider/pihole/clientV6_test.go b/provider/pihole/clientV6_test.go index d474f8eba..5b5393774 100644 --- a/provider/pihole/clientV6_test.go +++ b/provider/pihole/clientV6_test.go @@ -92,6 +92,12 @@ func newTestServerV6(t *testing.T, hdlr http.HandlerFunc) *httptest.Server { return svr } +type errorTransportV6 struct{} + +func (t *errorTransportV6) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, errors.New("network error") +} + func TestNewPiholeClientV6(t *testing.T) { // Test correct error on no server provided _, err := newPiholeClientV6(PiholeConfig{APIVersion: "6"}) @@ -117,7 +123,10 @@ func TestNewPiholeClientV6(t *testing.T) { srvr := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/api/auth" && r.Method == http.MethodPost { var requestData map[string]string - json.NewDecoder(r.Body).Decode(&requestData) + err := json.NewDecoder(r.Body).Decode(&requestData) + if err != nil { + t.Fatal(err) + } defer r.Body.Close() w.Header().Set("Content-Type", "application/json") @@ -125,7 +134,7 @@ func TestNewPiholeClientV6(t *testing.T) { if requestData["password"] != "correct" { // Return unsuccessful authentication response w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{ + _, err = w.Write([]byte(`{ "session": { "valid": false, "totp": false, @@ -135,11 +144,14 @@ func TestNewPiholeClientV6(t *testing.T) { }, "took": 0.2 }`)) + if err != nil { + t.Fatal(err) + } return } // Return successful authentication response - w.Write([]byte(`{ + _, err = w.Write([]byte(`{ "session": { "valid": true, "totp": false, @@ -185,7 +197,7 @@ func TestListRecordsV6(t *testing.T) { w.Header().Set("Content-Type", "application/json") // Return A records - w.Write([]byte(`{ + if _, err := w.Write([]byte(`{ "config": { "dns": { "hosts": [ @@ -205,7 +217,9 @@ func TestListRecordsV6(t *testing.T) { } }, "took": 5 - }`)) + }`)); err != nil { + t.Fatal(err) + } } else if r.URL.Path == "/api/config/dns/cnameRecords" && r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) @@ -384,9 +398,12 @@ func TestErrorsV6(t *testing.T) { Server: "not an url", APIVersion: "6", } - clErrURL, _ := newPiholeClientV6(cfgErrURL) + clErrURL, err := newPiholeClientV6(cfgErrURL) + if err != nil { + t.Fatal(err) + } - _, err := clErrURL.listRecords(context.Background(), endpoint.RecordTypeCNAME) + _, err = clErrURL.listRecords(context.Background(), endpoint.RecordTypeCNAME) if err == nil { t.Fatal("Expected error for using invalid URL") } @@ -785,6 +802,80 @@ func TestDoRetryOne(t *testing.T) { } +func TestDoV6AdditionalCases(t *testing.T) { + t.Run("http client error", func(t *testing.T) { + client := &piholeClientV6{ + httpClient: &http.Client{ + Transport: &errorTransportV6{}, + }, + } + req, _ := http.NewRequest(http.MethodGet, "http://localhost", nil) + _, err := client.do(req) + if err == nil { + t.Fatal("expected an error, but got none") + } + if !strings.Contains(err.Error(), "network error") { + t.Fatalf("expected error to contain 'network error', but got '%v'", err) + } + }) + + t.Run("item already present", func(t *testing.T) { + server := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{ + "error": { + "key": "bad_request", + "message": "Item already present", + "hint": "The item you're trying to add already exists" + }, + "took": 0.1 + }`)) + }) + defer server.Close() + + client := &piholeClientV6{ + httpClient: server.Client(), + token: "test-token", + } + req, _ := http.NewRequest(http.MethodPut, server.URL+"/api/test", nil) + resp, err := client.do(req) + if err != nil { + t.Fatalf("expected no error for 'Item already present', but got '%v'", err) + } + if resp == nil { + t.Fatal("expected response, but got nil") + } + }) + + t.Run("404 on DELETE", func(t *testing.T) { + server := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{ + "error": { + "key": "not_found", + "message": "Item not found", + "hint": "The item you're trying to delete does not exist" + }, + "took": 0.1 + }`)) + }) + defer server.Close() + + client := &piholeClientV6{ + httpClient: server.Client(), + token: "test-token", + } + req, _ := http.NewRequest(http.MethodDelete, server.URL+"/api/test", nil) + resp, err := client.do(req) + if err != nil { + t.Fatalf("expected no error for 404 on DELETE, but got '%v'", err) + } + if resp == nil { + t.Fatal("expected response, but got nil") + } + }) +} + func TestCreateRecordV6(t *testing.T) { var ep *endpoint.Endpoint srvr := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) { diff --git a/provider/pihole/client_test.go b/provider/pihole/client_test.go index b94545b40..912b631b5 100644 --- a/provider/pihole/client_test.go +++ b/provider/pihole/client_test.go @@ -57,15 +57,21 @@ func TestNewPiholeClient(t *testing.T) { // Create a test server for auth tests srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } pw := r.Form.Get("pw") if pw != "correct" { // Pihole actually server side renders the fact that you failed, normal 200 - w.Write([]byte("Invalid")) + _, err = w.Write([]byte("Invalid")) + if err != nil { + t.Fatal(err) + } return } // This is a subset of what happens on successful login - w.Write([]byte(` + _, err = w.Write([]byte(` @@ -73,6 +79,9 @@ func TestNewPiholeClient(t *testing.T) { `)) + if err != nil { + t.Fatal(err) + } }) defer srvr.Close() @@ -124,12 +133,15 @@ func CheckRecordRetrieval(t *testing.T, cl *piholeClient, recordType string, exp func TestListRecords(t *testing.T) { srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } if r.Form.Get("action") != "get" { t.Error("Expected 'get' action in form from client") } if strings.Contains(r.URL.Path, "cname") { - w.Write([]byte(` + _, err = w.Write([]byte(` { "data": [ ["test4.example.com", "cname.example.com"], @@ -138,10 +150,13 @@ func TestListRecords(t *testing.T) { ] } `)) + if err != nil { + t.Fatal(err) + } return } // Pihole makes no distinction between A and AAAA records - w.Write([]byte(` + _, err = w.Write([]byte(` { "data": [ ["test1.example.com", "192.168.1.1"], @@ -153,6 +168,9 @@ func TestListRecords(t *testing.T) { ] } `)) + if err != nil { + t.Fatal(err) + } }) defer srvr.Close() @@ -243,41 +261,32 @@ func testErrorScenarios(t *testing.T, srvrErr *httptest.Server) { func TestErrorScenarios(t *testing.T) { // Test errors token srvrErr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } pw := r.Form.Get("pw") if pw != "" { if pw != "correct" { - // Pihole actually server side renders the fact that you failed, normal 200 - w.Write([]byte("Invalid")) + _, err = w.Write([]byte("Invalid")) + if err != nil { + t.Fatal(err) + } return } } if strings.Contains(r.URL.Path, "admin/scripts/pi-hole/php/customcname.php") && r.Form.Get("token") == "correct" { - w.Write([]byte(` + _, err = w.Write([]byte(` { "nodata": [ ["nodata", "no"] ] } - `)) - return - } - if strings.Contains(r.URL.Path, "admin/index.php?login") { - w.Write([]byte(` - - - - - - `)) + if err != nil { + t.Fatal(err) + } } - // Token Expired - w.Write([]byte(` - { - "auth": "expired" - } - `)) }) defer srvrErr.Close() diff --git a/provider/pihole/piholeV6_test.go b/provider/pihole/piholeV6_test.go index b14f77bc0..68c62453c 100644 --- a/provider/pihole/piholeV6_test.go +++ b/provider/pihole/piholeV6_test.go @@ -19,14 +19,32 @@ package pihole import ( "context" "errors" - "reflect" "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" ) +var ( + endpointSort = cmpopts.SortSlices(func(x, y *endpoint.Endpoint) bool { + if x.DNSName < y.DNSName { + return true + } + if x.DNSName > y.DNSName { + return false + } + if x.RecordType < y.RecordType { + return true + } + if x.RecordType > y.RecordType { + return false + } + return x.Targets.String() < y.Targets.String() + }) +) + type testPiholeClientV6 struct { endpoints []*endpoint.Endpoint requests *requestTrackerV6 @@ -127,316 +145,179 @@ func TestProviderV6(t *testing.T) { apiVersion: "6", } - records, err := p.Records(context.Background()) - if err != nil { - t.Fatal(err) - } - if len(records) != 0 { - t.Fatal("Expected empty list of records, got:", records) - } - - // Populate the provider with records - records = []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"192.168.1.2"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test3.example.com", - Targets: []string{"192.168.1.3"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:192:168:1:2"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test3.example.com", - Targets: []string{"fc00::1:192:168:1:3"}, - RecordType: endpoint.RecordTypeAAAA, - }, - } - if err := p.ApplyChanges(context.Background(), &plan.Changes{ - Create: records, - }); err != nil { - t.Fatal(err) - } - - // Test records are correct on retrieval - - newRecords, err := p.Records(context.Background()) - if err != nil { - t.Fatal(err) - } - if len(newRecords) != 6 { - t.Fatal("Expected list of 6 records, got:", records) - } - if len(requests.createRequests) != 6 { - t.Fatal("Expected 6 create requests, got:", requests.createRequests) - } - if len(requests.deleteRequests) != 0 { - t.Fatal("Expected no delete requests, got:", requests.deleteRequests) - } - - for idx, record := range records { - if newRecords[idx].DNSName != record.DNSName { - t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName) + t.Run("Initial Records", func(t *testing.T) { + records, err := p.Records(context.Background()) + if err != nil { + t.Fatal(err) } - if newRecords[idx].Targets[0] != record.Targets[0] { - t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) + if len(records) != 0 { + t.Fatal("Expected empty list of records, got:", records) + } + }) + + t.Run("Create Records", func(t *testing.T) { + records := []*endpoint.Endpoint{ + {DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test3.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, RecordType: endpoint.RecordTypeAAAA}, + {DNSName: "test2.example.com", Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA}, + {DNSName: "test3.example.com", Targets: []string{"fc00::1:192:168:1:3"}, RecordType: endpoint.RecordTypeAAAA}, + } + if err := p.ApplyChanges(context.Background(), &plan.Changes{Create: records}); err != nil { + t.Fatal(err) } - if !reflect.DeepEqual(requests.createRequests[idx], record) { - t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName) + newRecords, err := p.Records(context.Background()) + if err != nil { + t.Fatal(err) } - } - - requests.clear() - - // Test delete a record - - records = []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"192.168.1.2"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:192:168:1:2"}, - RecordType: endpoint.RecordTypeAAAA, - }, - } - recordToDeleteA := endpoint.Endpoint{ - DNSName: "test3.example.com", - Targets: []string{"192.168.1.3"}, - RecordType: endpoint.RecordTypeA, - } - if err := p.ApplyChanges(context.Background(), &plan.Changes{ - Delete: []*endpoint.Endpoint{ - &recordToDeleteA, - }, - }); err != nil { - t.Fatal(err) - } - recordToDeleteAAAA := endpoint.Endpoint{ - DNSName: "test3.example.com", - Targets: []string{"fc00::1:192:168:1:3"}, - RecordType: endpoint.RecordTypeAAAA, - } - if err := p.ApplyChanges(context.Background(), &plan.Changes{ - Delete: []*endpoint.Endpoint{ - &recordToDeleteAAAA, - }, - }); err != nil { - t.Fatal(err) - } - - // Test records are updated - newRecords, err = p.Records(context.Background()) - if err != nil { - t.Fatal(err) - } - if len(newRecords) != 4 { - t.Fatal("Expected list of 4 records, got:", records) - } - if len(requests.createRequests) != 0 { - t.Fatal("Expected no create requests, got:", requests.createRequests) - } - if len(requests.deleteRequests) != 2 { - t.Fatal("Expected 2 delete request, got:", requests.deleteRequests) - } - - for idx, record := range records { - if newRecords[idx].DNSName != record.DNSName { - t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName) + if !cmp.Equal(newRecords, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) { + t.Error("Records are not equal:", cmp.Diff(newRecords, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort)) } - if newRecords[idx].Targets[0] != record.Targets[0] { - t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) + if !cmp.Equal(requests.createRequests, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) { + t.Error("Create requests are not equal:", cmp.Diff(requests.createRequests, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort)) } - } - - if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDeleteA) { - t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDeleteA) - } - if !reflect.DeepEqual(requests.deleteRequests[1], &recordToDeleteAAAA) { - t.Error("Unexpected delete request, got:", requests.deleteRequests[1], "expected:", recordToDeleteAAAA) - } - - requests.clear() - - // Test update a record - - records = []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"10.0.0.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:10:0:0:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - } - if err := p.ApplyChanges(context.Background(), &plan.Changes{ - UpdateOld: []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"192.168.1.2"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:192:168:1:2"}, - RecordType: endpoint.RecordTypeAAAA, - }, - }, - UpdateNew: []*endpoint.Endpoint{ - { - DNSName: "test1.example.com", - Targets: []string{"192.168.1.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"10.0.0.1"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"10.0.0.2"}, - RecordType: endpoint.RecordTypeA, - }, - { - DNSName: "test1.example.com", - Targets: []string{"fc00::1:192:168:1:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - { - DNSName: "test2.example.com", - Targets: []string{"fc00::1:10:0:0:1"}, - RecordType: endpoint.RecordTypeAAAA, - }, - }, - }); err != nil { - t.Fatal(err) - } - - // Test records are updated - newRecords, err = p.Records(context.Background()) - if err != nil { - t.Fatal(err) - } - if len(newRecords) != 4 { - t.Fatal("Expected list of 4 records, got:", newRecords) - } - if len(requests.createRequests) != 2 { - t.Fatal("Expected 2 create request, got:", requests.createRequests) - } - if len(requests.deleteRequests) != 2 { - t.Fatal("Expected 2 delete request, got:", requests.deleteRequests) - } - - for idx, record := range records { - if newRecords[idx].DNSName != record.DNSName { - t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName) + if len(requests.deleteRequests) != 0 { + t.Fatal("Expected no delete requests, got:", requests.deleteRequests) } - if newRecords[idx].Targets[0] != record.Targets[0] { - t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) + requests.clear() + }) + + t.Run("Delete Records", func(t *testing.T) { + recordToDeleteA := &endpoint.Endpoint{DNSName: "test3.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA} + if err := p.ApplyChanges(context.Background(), &plan.Changes{Delete: []*endpoint.Endpoint{recordToDeleteA}}); err != nil { + t.Fatal(err) } - } - - expectedCreateA := endpoint.Endpoint{ - DNSName: "test2.example.com", - Targets: []string{"10.0.0.1", "10.0.0.2"}, - RecordType: endpoint.RecordTypeA, - } - expectedDeleteA := endpoint.Endpoint{ - DNSName: "test2.example.com", - Targets: []string{"192.168.1.2"}, - RecordType: endpoint.RecordTypeA, - } - expectedCreateAAAA := endpoint.Endpoint{ - DNSName: "test2.example.com", - Targets: []string{"fc00::1:10:0:0:1"}, - RecordType: endpoint.RecordTypeAAAA, - } - expectedDeleteAAAA := endpoint.Endpoint{ - DNSName: "test2.example.com", - Targets: []string{"fc00::1:192:168:1:2"}, - RecordType: endpoint.RecordTypeAAAA, - } - - for _, request := range requests.createRequests { - switch request.RecordType { - case endpoint.RecordTypeA: - if !reflect.DeepEqual(request, &expectedCreateA) { - t.Error("Unexpected create request, got:", request, "expected:", &expectedCreateA) - } - case endpoint.RecordTypeAAAA: - if !reflect.DeepEqual(request, &expectedCreateAAAA) { - t.Error("Unexpected create request, got:", request, "expected:", &expectedCreateAAAA) - } - default: + recordToDeleteAAAA := &endpoint.Endpoint{DNSName: "test3.example.com", Targets: []string{"fc00::1:192:168:1:3"}, RecordType: endpoint.RecordTypeAAAA} + if err := p.ApplyChanges(context.Background(), &plan.Changes{Delete: []*endpoint.Endpoint{recordToDeleteAAAA}}); err != nil { + t.Fatal(err) } - } - for _, request := range requests.deleteRequests { - switch request.RecordType { - case endpoint.RecordTypeA: - if !reflect.DeepEqual(request, &expectedDeleteA) { - t.Error("Unexpected delete request, got:", request, "expected:", &expectedDeleteA) - } - case endpoint.RecordTypeAAAA: - if !reflect.DeepEqual(request, &expectedDeleteAAAA) { - t.Error("Unexpected delete request, got:", request, "expected:", &expectedDeleteAAAA) - } - default: + expectedRecords := []*endpoint.Endpoint{ + {DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, RecordType: endpoint.RecordTypeAAAA}, + {DNSName: "test2.example.com", Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA}, } - } + newRecords, err := p.Records(context.Background()) + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) { + t.Error("Records are not equal:", cmp.Diff(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort)) + } + if len(requests.createRequests) != 0 { + t.Fatal("Expected no create requests, got:", requests.createRequests) + } + expectedDeletes := []*endpoint.Endpoint{recordToDeleteA, recordToDeleteAAAA} + if !cmp.Equal(requests.deleteRequests, expectedDeletes, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) { + t.Error("Delete requests are not equal:", cmp.Diff(requests.deleteRequests, expectedDeletes, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort)) + } + requests.clear() + }) - requests.clear() + t.Run("Update Records", func(t *testing.T) { + updateOld := []*endpoint.Endpoint{ + {DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test2.example.com", Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA}, + } + updateNew := []*endpoint.Endpoint{ + {DNSName: "test2.example.com", Targets: []string{"10.0.0.1"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test2.example.com", Targets: []string{"fc00::1:10:0:0:1"}, RecordType: endpoint.RecordTypeAAAA}, + } + if err := p.ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updateOld, UpdateNew: updateNew}); err != nil { + t.Fatal(err) + } + + expectedRecords := []*endpoint.Endpoint{ + {DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test2.example.com", Targets: []string{"10.0.0.1"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, RecordType: endpoint.RecordTypeAAAA}, + {DNSName: "test2.example.com", Targets: []string{"fc00::1:10:0:0:1"}, RecordType: endpoint.RecordTypeAAAA}, + } + newRecords, err := p.Records(context.Background()) + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) { + t.Error("Records are not equal:", cmp.Diff(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort)) + } + if !cmp.Equal(requests.createRequests, updateNew, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) { + t.Error("Create requests are not equal:", cmp.Diff(requests.createRequests, updateNew, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort)) + } + if !cmp.Equal(requests.deleteRequests, updateOld, cmpopts.IgnoreUnexported(endpoint.Endpoint{})) { + t.Error("Delete requests are not equal:", cmp.Diff(requests.deleteRequests, updateOld, cmpopts.IgnoreUnexported(endpoint.Endpoint{}))) + } + requests.clear() + }) +} + +func TestProviderV6MultipleTargets(t *testing.T) { + requests := requestTrackerV6{} + p := &PiholeProvider{ + api: &testPiholeClientV6{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, + apiVersion: "6", + } + + t.Run("Update with multiple targets - merge and deduplicate", func(t *testing.T) { + // Create initial records with multiple targets + initialRecords := []*endpoint.Endpoint{ + {DNSName: "multi.example.com", Targets: []string{"192.168.1.1", "192.168.1.2"}, RecordType: endpoint.RecordTypeA}, + } + if err := p.ApplyChanges(context.Background(), &plan.Changes{Create: initialRecords}); err != nil { + t.Fatal(err) + } + requests.clear() + + // Update with new targets that should be merged + updateOld := []*endpoint.Endpoint{ + {DNSName: "multi.example.com", Targets: []string{"192.168.1.1", "192.168.1.2"}, RecordType: endpoint.RecordTypeA}, + } + updateNew := []*endpoint.Endpoint{ + {DNSName: "multi.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "multi.example.com", Targets: []string{"192.168.1.4"}, RecordType: endpoint.RecordTypeA}, + {DNSName: "multi.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA}, // Duplicate to test deduplication + } + if err := p.ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updateOld, UpdateNew: updateNew}); err != nil { + t.Fatal(err) + } + + // Verify that targets were merged and deduplicated + expectedCreate := []*endpoint.Endpoint{ + {DNSName: "multi.example.com", Targets: []string{"192.168.1.3", "192.168.1.4"}, RecordType: endpoint.RecordTypeA}, + } + if len(requests.createRequests) != 1 { + t.Fatalf("Expected 1 create request, got %d", len(requests.createRequests)) + } + if !cmp.Equal(requests.createRequests[0].Targets, expectedCreate[0].Targets) { + t.Error("Targets not merged correctly:", cmp.Diff(requests.createRequests[0].Targets, expectedCreate[0].Targets)) + } + if len(requests.deleteRequests) != 1 { + t.Fatalf("Expected 1 delete request, got %d", len(requests.deleteRequests)) + } + requests.clear() + }) + + t.Run("Update with exact match - should skip delete", func(t *testing.T) { + // Update where old and new have the same targets (exact match) + updateOld := []*endpoint.Endpoint{ + {DNSName: "multi.example.com", Targets: []string{"192.168.1.3", "192.168.1.4"}, RecordType: endpoint.RecordTypeA}, + } + updateNew := []*endpoint.Endpoint{ + {DNSName: "multi.example.com", Targets: []string{"192.168.1.3", "192.168.1.4"}, RecordType: endpoint.RecordTypeA}, + } + if err := p.ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updateOld, UpdateNew: updateNew}); err != nil { + t.Fatal(err) + } + + // Should not create or delete anything since targets are the same + if len(requests.createRequests) != 0 { + t.Fatalf("Expected no create requests for exact match, got %d", len(requests.createRequests)) + } + if len(requests.deleteRequests) != 0 { + t.Fatalf("Expected no delete requests for exact match, got %d", len(requests.deleteRequests)) + } + requests.clear() + }) }