external-dns/provider/pihole/client_test.go
Ivan Ka e21f1389fb
linter(usetesting): enable usetesting (#6266)
* linter(usetesting): enable usetesting

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>

* linter(usetesting): enable usetesting

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>

* linter(usetesting): enable usetesting

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>

* linter(usetesting): enable usetesting

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>

* linter(usetesting): enable usetesting

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>

* linter(usetesting): enable usetesting

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>

---------

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>
2026-03-14 22:07:35 +05:30

496 lines
13 KiB
Go

/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package pihole
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"sigs.k8s.io/external-dns/endpoint"
)
// newTestServer starts an httptest.Server backed by hdlr and registers its
// shutdown with t.Cleanup, so the server is closed when the test (or subtest)
// that created it finishes, even if the caller never closes it explicitly.
// httptest.Server.Close is safe to call more than once, so callers that also
// defer svr.Close() keep working.
func newTestServer(t *testing.T, hdlr http.HandlerFunc) *httptest.Server {
	t.Helper()
	svr := httptest.NewServer(hdlr)
	t.Cleanup(svr.Close)
	return svr
}
// TestNewPiholeClient covers client construction:
//   - a config without a server is rejected with ErrNoPiholeServer;
//   - a config without a password builds a *piholeClient;
//   - with a password, the test server's login response decides whether
//     construction succeeds, and the token is read from the returned page.
func TestNewPiholeClient(t *testing.T) {
	// No server configured: construction must fail with the sentinel error.
	_, err := newPiholeClient(PiholeConfig{})
	if err == nil {
		t.Error("Expected error from config with no server")
	} else if !errors.Is(err, ErrNoPiholeServer) {
		t.Error("Expected ErrNoPiholeServer, got", err)
	}

	// No password: the client should be created cleanly.
	cl, err := newPiholeClient(PiholeConfig{
		Server: "test",
	})
	if err != nil {
		t.Fatal(err)
	}
	if _, ok := cl.(*piholeClient); !ok {
		t.Error("Did not create a new pihole client")
	}

	// Test server for the login tests. The handler runs on a server
	// goroutine, where t.Fatal must not be called (FailNow has to run on the
	// test goroutine), so failures are reported with t.Errorf instead.
	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing login form: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Pi-hole renders a failed login server-side with a normal 200, so a
		// wrong password gets an "Invalid" page rather than an error status.
		body := "Invalid"
		if r.Form.Get("pw") == "correct" {
			// This is a subset of what happens on successful login.
			body = `
<!doctype html>
<html lang="en">
<body>
<div id="token" hidden>supersecret</div>
</body>
</html>
`
		}
		if _, err := w.Write([]byte(body)); err != nil {
			t.Errorf("writing login response: %v", err)
		}
	})
	defer srvr.Close()

	// Wrong password: construction must fail.
	_, err = newPiholeClient(
		PiholeConfig{Server: srvr.URL, Password: "wrong"},
	)
	if err == nil {
		t.Error("Expected error for creating client with invalid password")
	}

	// Correct password: the token must be parsed from the login page.
	cl, err = newPiholeClient(
		PiholeConfig{Server: srvr.URL, Password: "correct"},
	)
	if err != nil {
		t.Fatal(err)
	}
	pc, ok := cl.(*piholeClient)
	if !ok {
		t.Fatalf("Expected *piholeClient, got %T", cl)
	}
	if pc.token != "supersecret" {
		t.Error("Parsed invalid token from login response:", pc.token)
	}
}
// ValidateRecords checks that records holds exactly expectedCount entries.
// Position by position, each record's DNSName and first target must equal
// expected[i][0] and expected[i][1]. Only the first target of each record is
// compared. recordType is used only in failure messages.
//
// A malformed expected table, or a record with no targets, is reported as a
// test failure instead of panicking on an out-of-range index.
func ValidateRecords(t *testing.T, records []*endpoint.Endpoint, expected [][]string, expectedCount int, recordType string) {
	t.Helper()
	if len(records) != expectedCount {
		t.Fatalf("Expected %d %s records returned, got: %d", expectedCount, recordType, len(records))
	}
	if len(expected) < len(records) {
		t.Fatalf("test table lists %d expected %s records but %d were returned", len(expected), recordType, len(records))
	}
	for idx, rec := range records {
		want := expected[idx]
		if len(want) < 2 {
			t.Fatalf("test table entry %d for %s records must be [name, target], got %v", idx, recordType, want)
		}
		if rec.DNSName != want[0] {
			t.Errorf("Got invalid DNS Name: %s, expected: %s", rec.DNSName, want[0])
		}
		// Indexing Targets[0] on an empty target list would panic.
		if len(rec.Targets) == 0 {
			t.Errorf("Got record %s with no targets, expected: %s", rec.DNSName, want[1])
			continue
		}
		if rec.Targets[0] != want[1] {
			t.Errorf("Got invalid target: %s, expected: %s", rec.Targets[0], want[1])
		}
	}
}
// CheckRecordRetrieval fetches the recordType records through cl and hands
// them to ValidateRecords for comparison against expected. It aborts the test
// immediately if listRecords itself returns an error.
func CheckRecordRetrieval(t *testing.T, cl *piholeClient, recordType string, expected [][]string, expectedCount int) {
	t.Helper()
	got, listErr := cl.listRecords(t.Context(), recordType)
	if listErr != nil {
		t.Fatal(listErr)
	}
	ValidateRecords(t, got, expected, expectedCount, recordType)
}
// TestListRecords checks listRecords for A, AAAA and CNAME records, with and
// without a domain filter. The test server answers requests whose path
// contains "cname" with CNAME pairs. All other requests get a combined list
// of IPv4 and IPv6 entries, which the client must split by record type.
func TestListRecords(t *testing.T) {
	// The handler runs on a server goroutine, where t.Fatal must not be
	// called, so failures are reported with t.Error*.
	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing form: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		if r.Form.Get("action") != "get" {
			t.Error("Expected 'get' action in form from client")
		}
		// Pihole makes no distinction between A and AAAA records.
		body := `
{
"data": [
["test1.example.com", "192.168.1.1"],
["test2.example.com", "192.168.1.2"],
["test3.match.com", "192.168.1.3"],
["test1.example.com", "fc00::1:192:168:1:1"],
["test2.example.com", "fc00::1:192:168:1:2"],
["test3.match.com", "fc00::1:192:168:1:3"]
]
}
`
		if strings.Contains(r.URL.Path, "cname") {
			body = `
{
"data": [
["test4.example.com", "cname.example.com"],
["test5.example.com", "cname.example.com"],
["test6.match.com", "cname.example.com"]
]
}
`
		}
		if _, err := w.Write([]byte(body)); err != nil {
			t.Errorf("writing response: %v", err)
		}
	})
	defer srvr.Close()

	// newClient builds a client for cfg and fails the test if it is not a
	// *piholeClient.
	newClient := func(cfg PiholeConfig) *piholeClient {
		t.Helper()
		cl, err := newPiholeClient(cfg)
		if err != nil {
			t.Fatal(err)
		}
		pc, ok := cl.(*piholeClient)
		if !ok {
			t.Fatalf("Expected *piholeClient, got %T", cl)
		}
		return pc
	}

	cfg := PiholeConfig{
		Server: srvr.URL,
	}
	unfiltered := newClient(cfg)
	cfg.DomainFilter = endpoint.NewDomainFilter([]string{"match.com"})
	filtered := newClient(cfg)

	tests := []struct {
		name       string
		cl         *piholeClient
		recordType string
		expected   [][]string
	}{
		{
			name:       "A unfiltered",
			cl:         unfiltered,
			recordType: endpoint.RecordTypeA,
			expected: [][]string{
				{"test1.example.com", "192.168.1.1"},
				{"test2.example.com", "192.168.1.2"},
				{"test3.match.com", "192.168.1.3"},
			},
		},
		{
			name:       "AAAA unfiltered",
			cl:         unfiltered,
			recordType: endpoint.RecordTypeAAAA,
			expected: [][]string{
				{"test1.example.com", "fc00::1:192:168:1:1"},
				{"test2.example.com", "fc00::1:192:168:1:2"},
				{"test3.match.com", "fc00::1:192:168:1:3"},
			},
		},
		{
			name:       "CNAME unfiltered",
			cl:         unfiltered,
			recordType: endpoint.RecordTypeCNAME,
			expected: [][]string{
				{"test4.example.com", "cname.example.com"},
				{"test5.example.com", "cname.example.com"},
				{"test6.match.com", "cname.example.com"},
			},
		},
		{
			name:       "A filtered",
			cl:         filtered,
			recordType: endpoint.RecordTypeA,
			expected: [][]string{
				{"test3.match.com", "192.168.1.3"},
			},
		},
		{
			name:       "AAAA filtered",
			cl:         filtered,
			recordType: endpoint.RecordTypeAAAA,
			expected: [][]string{
				{"test3.match.com", "fc00::1:192:168:1:3"},
			},
		},
		{
			name:       "CNAME filtered",
			cl:         filtered,
			recordType: endpoint.RecordTypeCNAME,
			expected: [][]string{
				{"test6.match.com", "cname.example.com"},
			},
		},
	}
	// Each case runs as its own subtest, so a fatal mismatch in one record
	// type does not hide the results of the others.
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			CheckRecordRetrieval(t, tc.cl, tc.recordType, tc.expected, len(tc.expected))
		})
	}
}
// TestErrorScenarios covers listRecords failure paths:
//   - an unknown record type must return an error;
//   - a CNAME listing with a token the test server does not accept, and a
//     wrong password, must return an error;
//   - a response that lacks the "data" key must yield no records.
func TestErrorScenarios(t *testing.T) {
	// The handler runs on a server goroutine, where t.Fatal must not be
	// called, so failures are reported with t.Errorf.
	srvrErr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing form: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Login with a wrong password: answer 200 with an "Invalid" page,
		// mirroring Pi-hole's server-rendered login failure.
		if pw := r.Form.Get("pw"); pw != "" && pw != "correct" {
			if _, err := w.Write([]byte("Invalid")); err != nil {
				t.Errorf("writing login response: %v", err)
			}
			return
		}
		// CNAME listing with the accepted token: a well-formed JSON body that
		// has no "data" key. Every other request gets an empty 200 response.
		if strings.Contains(r.URL.Path, "admin/scripts/pi-hole/php/customcname.php") && r.Form.Get("token") == "correct" {
			if _, err := w.Write([]byte(`
{
"nodata": [
["nodata", "no"]
]
}
`)); err != nil {
				t.Errorf("writing response: %v", err)
			}
		}
	})
	defer srvrErr.Close()

	cfgExpired := PiholeConfig{
		Server: srvrErr.URL,
	}
	clExpired, err := newPiholeClient(cfgExpired)
	if err != nil {
		t.Fatal(err)
	}
	pc, ok := clExpired.(*piholeClient)
	if !ok {
		t.Fatalf("Expected *piholeClient, got %T", clExpired)
	}

	// Use a token the server rejects together with a wrong password.
	pc.token = "expired"
	pc.cfg.Password = "notcorrect"
	_, err = clExpired.listRecords(t.Context(), "notarealrecordtype")
	if err == nil {
		t.Fatal("Should return error, type is unknown ! ")
	}
	_, err = clExpired.listRecords(t.Context(), endpoint.RecordTypeCNAME)
	if err == nil {
		t.Fatal("Should return error on failed auth ! ")
	}

	// With an accepted token, the server replies without a "data" key, and
	// the client must return an empty result rather than an error.
	pc.token = "correct"
	pc.cfg.Password = "correct"
	cnamerecs, err := clExpired.listRecords(t.Context(), endpoint.RecordTypeCNAME)
	if err != nil {
		t.Fatal(err)
	}
	if len(cnamerecs) != 0 {
		t.Fatal("Should return empty on missing data in response ! ")
	}
}
// TestCreateRecord checks the form fields createRecord sends for A, AAAA and
// CNAME endpoints. It also checks four edge cases:
//   - a wildcard name returns an error;
//   - a name outside the domain filter returns no error;
//   - an unsupported record type returns no error;
//   - a dry-run client returns no error.
func TestCreateRecord(t *testing.T) {
	// ep is the endpoint currently being created. The handler compares the
	// submitted form against it.
	var ep *endpoint.Endpoint
	// The handler runs on a server goroutine, where t.Fatal must not be
	// called, so failures are reported with t.Error*.
	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing form: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		if r.Form.Get("action") != "add" {
			t.Error("Expected 'add' action in form from client")
		}
		if r.Form.Get("domain") != ep.DNSName {
			t.Error("Invalid domain in form:", r.Form.Get("domain"), "Expected:", ep.DNSName)
		}
		switch ep.RecordType {
		// Pihole makes no distinction between A and AAAA records.
		case endpoint.RecordTypeA, endpoint.RecordTypeAAAA:
			if r.Form.Get("ip") != ep.Targets[0] {
				t.Error("Invalid ip in form:", r.Form.Get("ip"), "Expected:", ep.Targets[0])
			}
		case endpoint.RecordTypeCNAME:
			if r.Form.Get("target") != ep.Targets[0] {
				t.Error("Invalid target in form:", r.Form.Get("target"), "Expected:", ep.Targets[0])
			}
		}
		out, err := json.Marshal(actionResponse{
			Success: true,
			Message: "",
		})
		if err != nil {
			t.Errorf("marshalling response: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(out); err != nil {
			t.Errorf("writing response: %v", err)
		}
	})
	defer srvr.Close()

	cfg := PiholeConfig{
		Server:       srvr.URL,
		DomainFilter: endpoint.NewDomainFilter([]string{"example.com"}),
	}
	cl, err := newPiholeClient(cfg)
	if err != nil {
		t.Fatal(err)
	}

	// Supported record types must be created without error.
	for _, e := range []*endpoint.Endpoint{
		{
			DNSName:    "test.example.com",
			Targets:    []string{"192.168.1.1"},
			RecordType: endpoint.RecordTypeA,
		},
		{
			DNSName:    "test.example.com",
			Targets:    []string{"fc00::1:192:168:1:1"},
			RecordType: endpoint.RecordTypeAAAA,
		},
		{
			DNSName:    "test.example.com",
			Targets:    []string{"test.cname.com"},
			RecordType: endpoint.RecordTypeCNAME,
		},
	} {
		ep = e
		if err := cl.createRecord(t.Context(), ep); err != nil {
			t.Fatal(err)
		}
	}

	// A wildcard record must be rejected.
	ep = &endpoint.Endpoint{
		DNSName:    "*.example.com",
		Targets:    []string{"192.168.1.1"},
		RecordType: endpoint.RecordTypeA,
	}
	pc, ok := cl.(*piholeClient)
	if !ok {
		t.Fatalf("Expected *piholeClient, got %T", cl)
	}
	pc.token = "correct"
	if err := cl.createRecord(t.Context(), ep); err == nil {
		t.Fatal("Expected error when creating a wildcard record")
	}

	// A name outside the domain filter returns no error.
	ep = &endpoint.Endpoint{
		DNSName:    "foo.bar.com",
		Targets:    []string{"192.168.1.1"},
		RecordType: endpoint.RecordTypeA,
	}
	if err := cl.createRecord(t.Context(), ep); err != nil {
		t.Fatal("Should not return error on non filtered domain")
	}

	// An unsupported record type returns no error.
	ep = &endpoint.Endpoint{
		DNSName:    "test.example.com",
		Targets:    []string{"192.168.1.1"},
		RecordType: "not a type",
	}
	if err := cl.createRecord(t.Context(), ep); err != nil {
		t.Fatal("Should not return error on unsupported type")
	}

	// A dry-run client returns no error.
	cfgDr := PiholeConfig{
		Server:       srvr.URL,
		DomainFilter: endpoint.NewDomainFilter([]string{"example.com"}),
		DryRun:       true,
	}
	clDr, err := newPiholeClient(cfgDr)
	if err != nil {
		t.Fatal(err)
	}
	ep = &endpoint.Endpoint{
		DNSName:    "test.example.com",
		Targets:    []string{"192.168.1.1"},
		RecordType: endpoint.RecordTypeA,
	}
	if err := clDr.createRecord(t.Context(), ep); err != nil {
		t.Fatal("Should not return error on dry run")
	}
}
// TestDeleteRecord checks the form fields deleteRecord sends to the test
// server for A, AAAA and CNAME endpoints. Each deletion must succeed.
func TestDeleteRecord(t *testing.T) {
	// ep is the endpoint currently being deleted. The handler compares the
	// submitted form against it.
	var ep *endpoint.Endpoint
	// The handler runs on a server goroutine, where t.Fatal must not be
	// called, so failures are reported with t.Error*.
	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing form: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		if r.Form.Get("action") != "delete" {
			t.Error("Expected 'delete' action in form from client")
		}
		if r.Form.Get("domain") != ep.DNSName {
			t.Error("Invalid domain in form:", r.Form.Get("domain"), "Expected:", ep.DNSName)
		}
		switch ep.RecordType {
		// Pihole makes no distinction between A and AAAA records.
		case endpoint.RecordTypeA, endpoint.RecordTypeAAAA:
			if r.Form.Get("ip") != ep.Targets[0] {
				t.Error("Invalid ip in form:", r.Form.Get("ip"), "Expected:", ep.Targets[0])
			}
		case endpoint.RecordTypeCNAME:
			if r.Form.Get("target") != ep.Targets[0] {
				t.Error("Invalid target in form:", r.Form.Get("target"), "Expected:", ep.Targets[0])
			}
		}
		out, err := json.Marshal(actionResponse{
			Success: true,
			Message: "",
		})
		if err != nil {
			t.Errorf("marshalling response: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(out); err != nil {
			t.Errorf("writing response: %v", err)
		}
	})
	defer srvr.Close()

	cfg := PiholeConfig{
		Server: srvr.URL,
	}
	cl, err := newPiholeClient(cfg)
	if err != nil {
		t.Fatal(err)
	}

	for _, e := range []*endpoint.Endpoint{
		{
			DNSName:    "test.example.com",
			Targets:    []string{"192.168.1.1"},
			RecordType: endpoint.RecordTypeA,
		},
		{
			DNSName:    "test.example.com",
			Targets:    []string{"fc00::1:192:168:1:1"},
			RecordType: endpoint.RecordTypeAAAA,
		},
		{
			DNSName:    "test.example.com",
			Targets:    []string{"test.cname.com"},
			RecordType: endpoint.RecordTypeCNAME,
		},
	} {
		ep = e
		if err := cl.deleteRecord(t.Context(), ep); err != nil {
			t.Fatal(err)
		}
	}
}