feat(pihole): add support for IPv6 Dual format (#5253)

* Add support of ipv6 dual on pihole provider

* PiHoleV6 : Switch from instrumented_http to httpClient

* Add support of ipv6 dual on pihole provider - extends tests cases

* Switch to net/netip to check ipV6

* Fix linter

* ListRecords should not log filtered records

Should not log records reject by filter on listRecords because PiHole return A and AAAA records. It is normal to filter some records

* Update provider/pihole/clientV6.go

Co-authored-by: Michel Loiseleur <97035654+mloiseleur@users.noreply.github.com>

---------

Co-authored-by: Michel Loiseleur <97035654+mloiseleur@users.noreply.github.com>
This commit is contained in:
tJouve 2025-04-23 14:43:42 +02:00 committed by GitHub
parent bc96176ddc
commit c49322f7ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 136 additions and 39 deletions

View File

@ -25,6 +25,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@ -63,6 +64,7 @@ func newPiholeClientV6(cfg PiholeConfig) (piholeAPI, error) {
}, },
}, },
} }
cl := instrumented_http.NewClient(httpClient, &instrumented_http.Callbacks{}) cl := instrumented_http.NewClient(httpClient, &instrumented_http.Callbacks{})
p := &piholeClientV6{ p := &piholeClientV6{
@ -114,6 +116,32 @@ func (p *piholeClientV6) getConfigValue(ctx context.Context, rtype string) ([]st
return results, nil return results, nil
} }
/**
* isValidIPv4 checks if the given IP address is a valid IPv4 address.
* It returns true if the IP address is valid, false otherwise.
* If the IP address is in IPv6 format, it will return false.
*/
func isValidIPv4(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
return addr.Is4()
}
/**
* isValidIPv6 checks if the given IP address is a valid IPv6 address.
* It returns true if the IP address is valid, false otherwise.
* If the IP address is in IPv6 with dual format y:y:y:y:y:y:x.x.x.x. , it will return true.
*/
func isValidIPv6(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
return addr.Is6()
}
func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) { func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) {
out := make([]*endpoint.Endpoint, 0) out := make([]*endpoint.Endpoint, 0)
results, err := p.getConfigValue(ctx, rtype) results, err := p.getConfigValue(ctx, rtype)
@ -126,42 +154,39 @@ func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endp
return r == ' ' || r == ',' return r == ' ' || r == ','
}) })
if len(recs) < 2 { if len(recs) < 2 {
log.Warnf("skipping record %s: invalid format", rec) log.Warnf("skipping record %s: invalid format received from PiHole", rec)
continue continue
} }
var DNSName, Target string var DNSName, Target string
var Ttl endpoint.TTL = 0 var Ttl = endpoint.TTL(0)
// A/AAAA record format is target(IP) DNSName // A/AAAA record format is target(IP) DNSName
DNSName, Target = recs[1], recs[0] DNSName, Target = recs[1], recs[0]
switch rtype { switch rtype {
case endpoint.RecordTypeA: case endpoint.RecordTypeA:
if strings.Contains(Target, ":") { //PiHole return A and AAAA records. Filter to only keep the A records
if !isValidIPv4(Target) {
continue continue
} }
case endpoint.RecordTypeAAAA: case endpoint.RecordTypeAAAA:
if strings.Contains(Target, ".") { //PiHole return A and AAAA records. Filter to only keep the AAAA records
if !isValidIPv6(Target) {
continue continue
} }
case endpoint.RecordTypeCNAME: case endpoint.RecordTypeCNAME:
// CNAME format is DNSName,target //PiHole return only CNAME records.
// CNAME format is DNSName,target, ttl?
DNSName, Target = recs[0], recs[1] DNSName, Target = recs[0], recs[1]
if len(recs) == 3 { // TTL is present if len(recs) == 3 { // TTL is present
// Parse string to int64 first // Parse string to int64 first
if ttlInt, err := strconv.ParseInt(recs[2], 10, 64); err == nil { if ttlInt, err := strconv.ParseInt(recs[2], 10, 64); err == nil {
Ttl = endpoint.TTL(ttlInt) Ttl = endpoint.TTL(ttlInt)
} else { } else {
log.Warnf("failed to parse TTL value '%s': %v; using a TTL of %d", recs[2], err, Ttl) log.Warnf("failed to parse TTL value received from PiHole '%s': %v; using a TTL of %d", recs[2], err, Ttl)
} }
} }
} }
out = append(out, &endpoint.Endpoint{ out = append(out, endpoint.NewEndpointWithTTL(DNSName, rtype, Ttl, Target))
DNSName: DNSName,
Targets: []string{Target},
RecordTTL: Ttl,
RecordType: rtype,
})
} }
return out, nil return out, nil
} }
@ -375,7 +400,13 @@ func (p *piholeClientV6) do(req *http.Request) ([]byte, error) {
if err := json.Unmarshal(jRes, &apiError); err != nil { if err := json.Unmarshal(jRes, &apiError); err != nil {
return nil, fmt.Errorf("failed to unmarshal error response: %w", err) return nil, fmt.Errorf("failed to unmarshal error response: %w", err)
} }
log.Debugf("Error on request %s", req.Body) if log.IsLevelEnabled(log.DebugLevel) {
log.Debugf("Error on request %s", req.URL)
if req.Body != nil {
log.Debugf("Body of the request %s", req.Body)
}
}
if res.StatusCode == http.StatusUnauthorized && p.token != "" { if res.StatusCode == http.StatusUnauthorized && p.token != "" {
tryCount := 1 tryCount := 1
maxRetries := 3 maxRetries := 3

View File

@ -29,6 +29,62 @@ import (
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/endpoint"
) )
func TestIsValidIPv4(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"192.168.1.1", true},
{"255.255.255.255", true},
{"0.0.0.0", true},
{"", false},
{"256.256.256.256", false},
{"192.168.0.1/22", false},
{"192.168.1", false},
{"abc.def.ghi.jkl", false},
{"::ffff:192.168.20.3", false},
}
for _, test := range tests {
t.Run(test.ip, func(t *testing.T) {
if got := isValidIPv4(test.ip); got != test.expected {
t.Errorf("isValidIPv4(%s) = %v; want %v", test.ip, got, test.expected)
}
})
}
}
func TestIsValidIPv6(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"2001:0db8:85a3:0000:0000:8a2e:0370:7334", true},
{"2001:db8:85a3::8a2e:370:7334", true},
//IPV6 dual, the format is y:y:y:y:y:y:x.x.x.x.
{"::ffff:192.168.20.3", true},
{"::1", true},
{"::", true},
{"2001:db8::", true},
{"", false},
{":", false},
{"::ffff:", false},
{"192.168.20.3", false},
{"2001:db8:85a3:0:0:8a2e:370:7334:1234", false},
{"2001:db8:85a3::8a2e:370g:7334", false},
{"2001:db8:85a3::8a2e:370:7334::", false},
{"2001:db8:85a3::8a2e:370:7334::1", false},
}
for _, test := range tests {
t.Run(test.ip, func(t *testing.T) {
if got := isValidIPv6(test.ip); got != test.expected {
t.Errorf("isValidIPv6(%s) = %v; want %v", test.ip, got, test.expected)
}
})
}
}
func newTestServerV6(t *testing.T, hdlr http.HandlerFunc) *httptest.Server { func newTestServerV6(t *testing.T, hdlr http.HandlerFunc) *httptest.Server {
t.Helper() t.Helper()
svr := httptest.NewServer(hdlr) svr := httptest.NewServer(hdlr)
@ -137,7 +193,9 @@ func TestListRecordsV6(t *testing.T) {
"192.168.178.34 service3.example.com", "192.168.178.34 service3.example.com",
"fc00::1:192:168:1:1 service4.example.com", "fc00::1:192:168:1:1 service4.example.com",
"fc00::1:192:168:1:2 service5.example.com", "fc00::1:192:168:1:2 service5.example.com",
"fc00::1:192:168:1:3 service6.example.com" "fc00::1:192:168:1:3 service6.example.com",
"::ffff:192.168.20.3 service7.example.com",
"192.168.20.3 service7.example.com"
] ]
} }
}, },
@ -177,20 +235,22 @@ func TestListRecordsV6(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// Ensure A records were parsed correctly
expected := [][]string{
{"service1.example.com", "192.168.178.33"},
{"service2.example.com", "192.168.178.34"},
{"service3.example.com", "192.168.178.34"},
{"service7.example.com", "192.168.20.3"},
}
// Test retrieve A records unfiltered // Test retrieve A records unfiltered
arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA) arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(arecs) != 3 { if len(arecs) != len(expected) {
t.Fatal("Expected 3 A records returned, got:", len(arecs)) t.Fatalf("Expected %d A records returned, got: %d", len(expected), len(arecs))
}
// Ensure records were parsed correctly
expected := [][]string{
{"service1.example.com", "192.168.178.33"},
{"service2.example.com", "192.168.178.34"},
{"service3.example.com", "192.168.178.34"},
} }
for idx, rec := range arecs { for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] { if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
@ -200,20 +260,23 @@ func TestListRecordsV6(t *testing.T) {
} }
} }
// Ensure AAAA records were parsed correctly
expected = [][]string{
{"service4.example.com", "fc00::1:192:168:1:1"},
{"service5.example.com", "fc00::1:192:168:1:2"},
{"service6.example.com", "fc00::1:192:168:1:3"},
{"service7.example.com", "::ffff:192.168.20.3"},
}
// Test retrieve AAAA records unfiltered // Test retrieve AAAA records unfiltered
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA) arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(arecs) != 3 {
t.Fatal("Expected 3 AAAA records returned, got:", len(arecs)) if len(arecs) != len(expected) {
} t.Fatalf("Expected %d AAAA records returned, got: %d", len(expected), len(arecs))
// Ensure records were parsed correctly
expected = [][]string{
{"service4.example.com", "fc00::1:192:168:1:1"},
{"service5.example.com", "fc00::1:192:168:1:2"},
{"service6.example.com", "fc00::1:192:168:1:3"},
} }
for idx, rec := range arecs { for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] { if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
@ -223,20 +286,22 @@ func TestListRecordsV6(t *testing.T) {
} }
} }
// Ensure CNAME records were parsed correctly
expected = [][]string{
{"source1.example.com", "target1.domain.com", "1000"},
{"source2.example.com", "target2.domain.com", "50"},
{"source3.example.com", "target3.domain.com"},
}
// Test retrieve CNAME records unfiltered // Test retrieve CNAME records unfiltered
cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME) cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(cnamerecs) != 3 { if len(cnamerecs) != len(expected) {
t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs)) t.Fatalf("Expected %d CAME records returned, got: %d", len(expected), len(cnamerecs))
}
// Ensure records were parsed correctly
expected = [][]string{
{"source1.example.com", "target1.domain.com", "1000"},
{"source2.example.com", "target2.domain.com", "50"},
{"source3.example.com", "target3.domain.com"},
} }
for idx, rec := range cnamerecs { for idx, rec := range cnamerecs {
if rec.DNSName != expected[idx][0] { if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
@ -261,6 +326,7 @@ func TestListRecordsV6(t *testing.T) {
t.Fatal("Expected error for using unsupported record type") t.Fatal("Expected error for using unsupported record type")
} }
} }
func TestErrorsV6(t *testing.T) { func TestErrorsV6(t *testing.T) {
//Error test cases //Error test cases