refactor(pihole): reduce cyclomatic complexity of TestProviderV6 (#5876)

* refactor(pihole): reduce cyclomatic complexity of TestProviderV6

* chore(pihole): increase coverage

* style: linting

* style: linting

* fix: remove coverage html
This commit is contained in:
Andrew Hay 2025-10-01 03:50:23 -04:00 committed by GitHub
parent 1f9edcb7fc
commit e22ceab66f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 318 additions and 336 deletions

1
.gitignore vendored
View File

@ -42,6 +42,7 @@ cscope.*
# coverage output
cover.out
coverage.html
*.coverprofile
external-dns

View File

@ -41,7 +41,7 @@ linters:
- name: confusing-naming
disabled: true
cyclop: # Lower cyclomatic complexity threshold after the max complexity is lowered
max-complexity: 43
max-complexity: 37 # Controller/execute.go:147:1: calculated cyclomatic complexity for function buildProvider is 37
testifylint:
# Enable all checkers (https://github.com/Antonboom/testifylint#checkers).
# Default: false

View File

@ -92,6 +92,12 @@ func newTestServerV6(t *testing.T, hdlr http.HandlerFunc) *httptest.Server {
return svr
}
type errorTransportV6 struct{}
func (t *errorTransportV6) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, errors.New("network error")
}
func TestNewPiholeClientV6(t *testing.T) {
// Test correct error on no server provided
_, err := newPiholeClientV6(PiholeConfig{APIVersion: "6"})
@ -117,7 +123,10 @@ func TestNewPiholeClientV6(t *testing.T) {
srvr := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/auth" && r.Method == http.MethodPost {
var requestData map[string]string
json.NewDecoder(r.Body).Decode(&requestData)
err := json.NewDecoder(r.Body).Decode(&requestData)
if err != nil {
t.Fatal(err)
}
defer r.Body.Close()
w.Header().Set("Content-Type", "application/json")
@ -125,7 +134,7 @@ func TestNewPiholeClientV6(t *testing.T) {
if requestData["password"] != "correct" {
// Return unsuccessful authentication response
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{
_, err = w.Write([]byte(`{
"session": {
"valid": false,
"totp": false,
@ -135,11 +144,14 @@ func TestNewPiholeClientV6(t *testing.T) {
},
"took": 0.2
}`))
if err != nil {
t.Fatal(err)
}
return
}
// Return successful authentication response
w.Write([]byte(`{
_, err = w.Write([]byte(`{
"session": {
"valid": true,
"totp": false,
@ -185,7 +197,7 @@ func TestListRecordsV6(t *testing.T) {
w.Header().Set("Content-Type", "application/json")
// Return A records
w.Write([]byte(`{
if _, err := w.Write([]byte(`{
"config": {
"dns": {
"hosts": [
@ -205,7 +217,9 @@ func TestListRecordsV6(t *testing.T) {
}
},
"took": 5
}`))
}`)); err != nil {
t.Fatal(err)
}
} else if r.URL.Path == "/api/config/dns/cnameRecords" && r.Method == http.MethodGet {
w.WriteHeader(http.StatusOK)
@ -384,9 +398,12 @@ func TestErrorsV6(t *testing.T) {
Server: "not an url",
APIVersion: "6",
}
clErrURL, _ := newPiholeClientV6(cfgErrURL)
clErrURL, err := newPiholeClientV6(cfgErrURL)
if err != nil {
t.Fatal(err)
}
_, err := clErrURL.listRecords(context.Background(), endpoint.RecordTypeCNAME)
_, err = clErrURL.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err == nil {
t.Fatal("Expected error for using invalid URL")
}
@ -785,6 +802,80 @@ func TestDoRetryOne(t *testing.T) {
}
func TestDoV6AdditionalCases(t *testing.T) {
t.Run("http client error", func(t *testing.T) {
client := &piholeClientV6{
httpClient: &http.Client{
Transport: &errorTransportV6{},
},
}
req, _ := http.NewRequest(http.MethodGet, "http://localhost", nil)
_, err := client.do(req)
if err == nil {
t.Fatal("expected an error, but got none")
}
if !strings.Contains(err.Error(), "network error") {
t.Fatalf("expected error to contain 'network error', but got '%v'", err)
}
})
t.Run("item already present", func(t *testing.T) {
server := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{
"error": {
"key": "bad_request",
"message": "Item already present",
"hint": "The item you're trying to add already exists"
},
"took": 0.1
}`))
})
defer server.Close()
client := &piholeClientV6{
httpClient: server.Client(),
token: "test-token",
}
req, _ := http.NewRequest(http.MethodPut, server.URL+"/api/test", nil)
resp, err := client.do(req)
if err != nil {
t.Fatalf("expected no error for 'Item already present', but got '%v'", err)
}
if resp == nil {
t.Fatal("expected response, but got nil")
}
})
t.Run("404 on DELETE", func(t *testing.T) {
server := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`{
"error": {
"key": "not_found",
"message": "Item not found",
"hint": "The item you're trying to delete does not exist"
},
"took": 0.1
}`))
})
defer server.Close()
client := &piholeClientV6{
httpClient: server.Client(),
token: "test-token",
}
req, _ := http.NewRequest(http.MethodDelete, server.URL+"/api/test", nil)
resp, err := client.do(req)
if err != nil {
t.Fatalf("expected no error for 404 on DELETE, but got '%v'", err)
}
if resp == nil {
t.Fatal("expected response, but got nil")
}
})
}
func TestCreateRecordV6(t *testing.T) {
var ep *endpoint.Endpoint
srvr := newTestServerV6(t, func(w http.ResponseWriter, r *http.Request) {

View File

@ -57,15 +57,21 @@ func TestNewPiholeClient(t *testing.T) {
// Create a test server for auth tests
srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
err := r.ParseForm()
if err != nil {
t.Fatal(err)
}
pw := r.Form.Get("pw")
if pw != "correct" {
// Pihole actually server side renders the fact that you failed, normal 200
w.Write([]byte("Invalid"))
_, err = w.Write([]byte("Invalid"))
if err != nil {
t.Fatal(err)
}
return
}
// This is a subset of what happens on successful login
w.Write([]byte(`
_, err = w.Write([]byte(`
<!doctype html>
<html lang="en">
<body>
@ -73,6 +79,9 @@ func TestNewPiholeClient(t *testing.T) {
</body>
</html>
`))
if err != nil {
t.Fatal(err)
}
})
defer srvr.Close()
@ -124,12 +133,15 @@ func CheckRecordRetrieval(t *testing.T, cl *piholeClient, recordType string, exp
func TestListRecords(t *testing.T) {
srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
err := r.ParseForm()
if err != nil {
t.Fatal(err)
}
if r.Form.Get("action") != "get" {
t.Error("Expected 'get' action in form from client")
}
if strings.Contains(r.URL.Path, "cname") {
w.Write([]byte(`
_, err = w.Write([]byte(`
{
"data": [
["test4.example.com", "cname.example.com"],
@ -138,10 +150,13 @@ func TestListRecords(t *testing.T) {
]
}
`))
if err != nil {
t.Fatal(err)
}
return
}
// Pihole makes no distinction between A and AAAA records
w.Write([]byte(`
_, err = w.Write([]byte(`
{
"data": [
["test1.example.com", "192.168.1.1"],
@ -153,6 +168,9 @@ func TestListRecords(t *testing.T) {
]
}
`))
if err != nil {
t.Fatal(err)
}
})
defer srvr.Close()
@ -243,41 +261,32 @@ func testErrorScenarios(t *testing.T, srvrErr *httptest.Server) {
func TestErrorScenarios(t *testing.T) {
// Test errors token
srvrErr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
err := r.ParseForm()
if err != nil {
t.Fatal(err)
}
pw := r.Form.Get("pw")
if pw != "" {
if pw != "correct" {
// Pihole actually server side renders the fact that you failed, normal 200
w.Write([]byte("Invalid"))
_, err = w.Write([]byte("Invalid"))
if err != nil {
t.Fatal(err)
}
return
}
}
if strings.Contains(r.URL.Path, "admin/scripts/pi-hole/php/customcname.php") && r.Form.Get("token") == "correct" {
w.Write([]byte(`
_, err = w.Write([]byte(`
{
"nodata": [
["nodata", "no"]
]
}
`))
return
}
if strings.Contains(r.URL.Path, "admin/index.php?login") {
w.Write([]byte(`
<!doctype html>
<html lang="en">
<body>
<div id="token" hidden>supersecret</div>
</body>
</html>
`))
if err != nil {
t.Fatal(err)
}
}
// Token Expired
w.Write([]byte(`
{
"auth": "expired"
}
`))
})
defer srvrErr.Close()

View File

@ -19,14 +19,32 @@ package pihole
import (
"context"
"errors"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/plan"
)
var (
endpointSort = cmpopts.SortSlices(func(x, y *endpoint.Endpoint) bool {
if x.DNSName < y.DNSName {
return true
}
if x.DNSName > y.DNSName {
return false
}
if x.RecordType < y.RecordType {
return true
}
if x.RecordType > y.RecordType {
return false
}
return x.Targets.String() < y.Targets.String()
})
)
type testPiholeClientV6 struct {
endpoints []*endpoint.Endpoint
requests *requestTrackerV6
@ -127,316 +145,179 @@ func TestProviderV6(t *testing.T) {
apiVersion: "6",
}
records, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatal("Expected empty list of records, got:", records)
}
// Populate the provider with records
records = []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test3.example.com",
Targets: []string{"192.168.1.3"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test3.example.com",
Targets: []string{"fc00::1:192:168:1:3"},
RecordType: endpoint.RecordTypeAAAA,
},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Create: records,
}); err != nil {
t.Fatal(err)
}
// Test records are correct on retrieval
newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if len(newRecords) != 6 {
t.Fatal("Expected list of 6 records, got:", records)
}
if len(requests.createRequests) != 6 {
t.Fatal("Expected 6 create requests, got:", requests.createRequests)
}
if len(requests.deleteRequests) != 0 {
t.Fatal("Expected no delete requests, got:", requests.deleteRequests)
}
for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
t.Run("Initial Records", func(t *testing.T) {
records, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
if len(records) != 0 {
t.Fatal("Expected empty list of records, got:", records)
}
})
t.Run("Create Records", func(t *testing.T) {
records := []*endpoint.Endpoint{
{DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test3.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, RecordType: endpoint.RecordTypeAAAA},
{DNSName: "test2.example.com", Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA},
{DNSName: "test3.example.com", Targets: []string{"fc00::1:192:168:1:3"}, RecordType: endpoint.RecordTypeAAAA},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{Create: records}); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(requests.createRequests[idx], record) {
t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
}
requests.clear()
// Test delete a record
records = []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
},
}
recordToDeleteA := endpoint.Endpoint{
DNSName: "test3.example.com",
Targets: []string{"192.168.1.3"},
RecordType: endpoint.RecordTypeA,
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Delete: []*endpoint.Endpoint{
&recordToDeleteA,
},
}); err != nil {
t.Fatal(err)
}
recordToDeleteAAAA := endpoint.Endpoint{
DNSName: "test3.example.com",
Targets: []string{"fc00::1:192:168:1:3"},
RecordType: endpoint.RecordTypeAAAA,
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Delete: []*endpoint.Endpoint{
&recordToDeleteAAAA,
},
}); err != nil {
t.Fatal(err)
}
// Test records are updated
newRecords, err = p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if len(newRecords) != 4 {
t.Fatal("Expected list of 4 records, got:", records)
}
if len(requests.createRequests) != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
if len(requests.deleteRequests) != 2 {
t.Fatal("Expected 2 delete request, got:", requests.deleteRequests)
}
for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
if !cmp.Equal(newRecords, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) {
t.Error("Records are not equal:", cmp.Diff(newRecords, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort))
}
if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
if !cmp.Equal(requests.createRequests, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) {
t.Error("Create requests are not equal:", cmp.Diff(requests.createRequests, records, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort))
}
}
if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDeleteA) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDeleteA)
}
if !reflect.DeepEqual(requests.deleteRequests[1], &recordToDeleteAAAA) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[1], "expected:", recordToDeleteAAAA)
}
requests.clear()
// Test update a record
records = []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:10:0:0:1"},
RecordType: endpoint.RecordTypeAAAA,
},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
UpdateOld: []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
},
},
UpdateNew: []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.2"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:10:0:0:1"},
RecordType: endpoint.RecordTypeAAAA,
},
},
}); err != nil {
t.Fatal(err)
}
// Test records are updated
newRecords, err = p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if len(newRecords) != 4 {
t.Fatal("Expected list of 4 records, got:", newRecords)
}
if len(requests.createRequests) != 2 {
t.Fatal("Expected 2 create request, got:", requests.createRequests)
}
if len(requests.deleteRequests) != 2 {
t.Fatal("Expected 2 delete request, got:", requests.deleteRequests)
}
for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
if len(requests.deleteRequests) != 0 {
t.Fatal("Expected no delete requests, got:", requests.deleteRequests)
}
if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
requests.clear()
})
t.Run("Delete Records", func(t *testing.T) {
recordToDeleteA := &endpoint.Endpoint{DNSName: "test3.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA}
if err := p.ApplyChanges(context.Background(), &plan.Changes{Delete: []*endpoint.Endpoint{recordToDeleteA}}); err != nil {
t.Fatal(err)
}
}
expectedCreateA := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1", "10.0.0.2"},
RecordType: endpoint.RecordTypeA,
}
expectedDeleteA := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
}
expectedCreateAAAA := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:10:0:0:1"},
RecordType: endpoint.RecordTypeAAAA,
}
expectedDeleteAAAA := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
}
for _, request := range requests.createRequests {
switch request.RecordType {
case endpoint.RecordTypeA:
if !reflect.DeepEqual(request, &expectedCreateA) {
t.Error("Unexpected create request, got:", request, "expected:", &expectedCreateA)
}
case endpoint.RecordTypeAAAA:
if !reflect.DeepEqual(request, &expectedCreateAAAA) {
t.Error("Unexpected create request, got:", request, "expected:", &expectedCreateAAAA)
}
default:
recordToDeleteAAAA := &endpoint.Endpoint{DNSName: "test3.example.com", Targets: []string{"fc00::1:192:168:1:3"}, RecordType: endpoint.RecordTypeAAAA}
if err := p.ApplyChanges(context.Background(), &plan.Changes{Delete: []*endpoint.Endpoint{recordToDeleteAAAA}}); err != nil {
t.Fatal(err)
}
}
for _, request := range requests.deleteRequests {
switch request.RecordType {
case endpoint.RecordTypeA:
if !reflect.DeepEqual(request, &expectedDeleteA) {
t.Error("Unexpected delete request, got:", request, "expected:", &expectedDeleteA)
}
case endpoint.RecordTypeAAAA:
if !reflect.DeepEqual(request, &expectedDeleteAAAA) {
t.Error("Unexpected delete request, got:", request, "expected:", &expectedDeleteAAAA)
}
default:
expectedRecords := []*endpoint.Endpoint{
{DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, RecordType: endpoint.RecordTypeAAAA},
{DNSName: "test2.example.com", Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA},
}
}
newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) {
t.Error("Records are not equal:", cmp.Diff(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort))
}
if len(requests.createRequests) != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
expectedDeletes := []*endpoint.Endpoint{recordToDeleteA, recordToDeleteAAAA}
if !cmp.Equal(requests.deleteRequests, expectedDeletes, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) {
t.Error("Delete requests are not equal:", cmp.Diff(requests.deleteRequests, expectedDeletes, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort))
}
requests.clear()
})
requests.clear()
t.Run("Update Records", func(t *testing.T) {
updateOld := []*endpoint.Endpoint{
{DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test2.example.com", Targets: []string{"fc00::1:192:168:1:2"}, RecordType: endpoint.RecordTypeAAAA},
}
updateNew := []*endpoint.Endpoint{
{DNSName: "test2.example.com", Targets: []string{"10.0.0.1"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test2.example.com", Targets: []string{"fc00::1:10:0:0:1"}, RecordType: endpoint.RecordTypeAAAA},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updateOld, UpdateNew: updateNew}); err != nil {
t.Fatal(err)
}
expectedRecords := []*endpoint.Endpoint{
{DNSName: "test1.example.com", Targets: []string{"192.168.1.1"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test2.example.com", Targets: []string{"10.0.0.1"}, RecordType: endpoint.RecordTypeA},
{DNSName: "test1.example.com", Targets: []string{"fc00::1:192:168:1:1"}, RecordType: endpoint.RecordTypeAAAA},
{DNSName: "test2.example.com", Targets: []string{"fc00::1:10:0:0:1"}, RecordType: endpoint.RecordTypeAAAA},
}
newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) {
t.Error("Records are not equal:", cmp.Diff(newRecords, expectedRecords, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort))
}
if !cmp.Equal(requests.createRequests, updateNew, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort) {
t.Error("Create requests are not equal:", cmp.Diff(requests.createRequests, updateNew, cmpopts.IgnoreUnexported(endpoint.Endpoint{}), endpointSort))
}
if !cmp.Equal(requests.deleteRequests, updateOld, cmpopts.IgnoreUnexported(endpoint.Endpoint{})) {
t.Error("Delete requests are not equal:", cmp.Diff(requests.deleteRequests, updateOld, cmpopts.IgnoreUnexported(endpoint.Endpoint{})))
}
requests.clear()
})
}
func TestProviderV6MultipleTargets(t *testing.T) {
requests := requestTrackerV6{}
p := &PiholeProvider{
api: &testPiholeClientV6{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests},
apiVersion: "6",
}
t.Run("Update with multiple targets - merge and deduplicate", func(t *testing.T) {
// Create initial records with multiple targets
initialRecords := []*endpoint.Endpoint{
{DNSName: "multi.example.com", Targets: []string{"192.168.1.1", "192.168.1.2"}, RecordType: endpoint.RecordTypeA},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{Create: initialRecords}); err != nil {
t.Fatal(err)
}
requests.clear()
// Update with new targets that should be merged
updateOld := []*endpoint.Endpoint{
{DNSName: "multi.example.com", Targets: []string{"192.168.1.1", "192.168.1.2"}, RecordType: endpoint.RecordTypeA},
}
updateNew := []*endpoint.Endpoint{
{DNSName: "multi.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA},
{DNSName: "multi.example.com", Targets: []string{"192.168.1.4"}, RecordType: endpoint.RecordTypeA},
{DNSName: "multi.example.com", Targets: []string{"192.168.1.3"}, RecordType: endpoint.RecordTypeA}, // Duplicate to test deduplication
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updateOld, UpdateNew: updateNew}); err != nil {
t.Fatal(err)
}
// Verify that targets were merged and deduplicated
expectedCreate := []*endpoint.Endpoint{
{DNSName: "multi.example.com", Targets: []string{"192.168.1.3", "192.168.1.4"}, RecordType: endpoint.RecordTypeA},
}
if len(requests.createRequests) != 1 {
t.Fatalf("Expected 1 create request, got %d", len(requests.createRequests))
}
if !cmp.Equal(requests.createRequests[0].Targets, expectedCreate[0].Targets) {
t.Error("Targets not merged correctly:", cmp.Diff(requests.createRequests[0].Targets, expectedCreate[0].Targets))
}
if len(requests.deleteRequests) != 1 {
t.Fatalf("Expected 1 delete request, got %d", len(requests.deleteRequests))
}
requests.clear()
})
t.Run("Update with exact match - should skip delete", func(t *testing.T) {
// Update where old and new have the same targets (exact match)
updateOld := []*endpoint.Endpoint{
{DNSName: "multi.example.com", Targets: []string{"192.168.1.3", "192.168.1.4"}, RecordType: endpoint.RecordTypeA},
}
updateNew := []*endpoint.Endpoint{
{DNSName: "multi.example.com", Targets: []string{"192.168.1.3", "192.168.1.4"}, RecordType: endpoint.RecordTypeA},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updateOld, UpdateNew: updateNew}); err != nil {
t.Fatal(err)
}
// Should not create or delete anything since targets are the same
if len(requests.createRequests) != 0 {
t.Fatalf("Expected no create requests for exact match, got %d", len(requests.createRequests))
}
if len(requests.deleteRequests) != 0 {
t.Fatalf("Expected no delete requests for exact match, got %d", len(requests.deleteRequests))
}
requests.clear()
})
}