Merge pull request #4871 from Demonware/pdns-validate-mx-srv

feat(pdns): add validation for MX and SRV records
This commit is contained in:
Kubernetes Prow Robot 2025-02-24 23:32:30 -08:00 committed by GitHub
commit 0e8f84662b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 225 additions and 0 deletions

View File

@ -20,6 +20,7 @@ import (
"fmt"
"net/netip"
"sort"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
@ -396,3 +397,54 @@ func RemoveDuplicates(endpoints []*Endpoint) []*Endpoint {
return result
}
// Check endpoint if is it properly formatted according to RFC standards
func (e *Endpoint) CheckEndpoint() bool {
switch recordType := e.RecordType; recordType {
case RecordTypeMX:
return e.Targets.ValidateMXRecord()
case RecordTypeSRV:
return e.Targets.ValidateSRVRecord()
}
return true
}
func (t Targets) ValidateMXRecord() bool {
for _, target := range t {
// MX records must have a preference value to indicate priority, e.g. "10 example.com"
// as per https://www.rfc-editor.org/rfc/rfc974.txt
targetParts := strings.Fields(strings.TrimSpace(target))
if len(targetParts) != 2 {
log.Debugf("Invalid MX record target: %s. MX records must have a preference value to indicate priority, e.g. '10 example.com'", target)
return false
}
preferenceRaw := targetParts[0]
_, err := strconv.ParseUint(preferenceRaw, 10, 16)
if err != nil {
log.Debugf("Invalid SRV record target: %s. Invalid integer value in target.", target)
return false
}
}
return true
}
func (t Targets) ValidateSRVRecord() bool {
for _, target := range t {
// SRV records must have a priority, weight, and port value, e.g. "10 5 5060 example.com"
// as per https://www.rfc-editor.org/rfc/rfc2782.txt
targetParts := strings.Fields(strings.TrimSpace(target))
if len(targetParts) != 4 {
log.Debugf("Invalid SRV record target: %s. SRV records must have a priority, weight, and port value, e.g. '10 5 5060 example.com'", target)
return false
}
for _, part := range targetParts[:3] {
_, err := strconv.ParseUint(part, 10, 16)
if err != nil {
log.Debugf("Invalid SRV record target: %s. Invalid integer value in target.", target)
return false
}
}
}
return true
}

View File

@ -19,6 +19,8 @@ package endpoint
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewEndpoint(t *testing.T) {
@ -441,3 +443,98 @@ func TestDuplicatedEndpointsWithOverlappingZones(t *testing.T) {
})
}
}
func TestPDNScheckEndpoint(t *testing.T) {
tests := []struct {
description string
endpoint Endpoint
expected bool
}{
{
description: "Valid MX record target",
endpoint: Endpoint{
DNSName: "example.com",
RecordType: RecordTypeMX,
Targets: Targets{"10 example.com"},
},
expected: true,
},
{
description: "Valid MX record with multiple targets",
endpoint: Endpoint{
DNSName: "example.com",
RecordType: RecordTypeMX,
Targets: Targets{"10 example.com", "20 backup.example.com"},
},
expected: true,
},
{
description: "MX record with valid and invalid targets",
endpoint: Endpoint{
DNSName: "example.com",
RecordType: RecordTypeMX,
Targets: Targets{"example.com", "backup.example.com"},
},
expected: false,
},
{
description: "Invalid MX record with missing priority value",
endpoint: Endpoint{
DNSName: "example.com",
RecordType: RecordTypeMX,
Targets: Targets{"example.com"},
},
expected: false,
},
{
description: "Invalid MX record with too many arguments",
endpoint: Endpoint{
DNSName: "example.com",
RecordType: RecordTypeMX,
Targets: Targets{"10 example.com abc"},
},
expected: false,
},
{
description: "Invalid MX record with non-integer priority",
endpoint: Endpoint{
DNSName: "example.com",
RecordType: RecordTypeMX,
Targets: Targets{"abc example.com"},
},
expected: false,
},
{
description: "Valid SRV record target",
endpoint: Endpoint{
DNSName: "_service._tls.example.com",
RecordType: RecordTypeSRV,
Targets: Targets{"10 20 5060 service.example.com"},
},
expected: true,
},
{
description: "Invalid SRV record with missing part",
endpoint: Endpoint{
DNSName: "_service._tls.example.com",
RecordType: RecordTypeSRV,
Targets: Targets{"10 20 5060"},
},
expected: false,
},
{
description: "Invalid SRV record with non-integer part",
endpoint: Endpoint{
DNSName: "_service._tls.example.com",
RecordType: RecordTypeSRV,
Targets: Targets{"10 20 abc service.example.com"},
},
expected: false,
},
}
for _, tt := range tests {
actual := tt.endpoint.CheckEndpoint()
assert.Equal(t, tt.expected, actual)
}
}

View File

@ -433,6 +433,19 @@ func (p *PDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpo
return endpoints, nil
}
// AdjustEndpoints performs checks on the provided endpoints and will skip any potentially failing changes.
func (p *PDNSProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
var validEndpoints []*endpoint.Endpoint
for i := 0; i < len(endpoints); i++ {
if !endpoints[i].CheckEndpoint() {
log.Warnf("Ignoring Endpoint because of invalid %v record formatting: {Target: '%v'}", endpoints[i].RecordType, endpoints[i].Targets)
continue
}
validEndpoints = append(validEndpoints, endpoints[i])
}
return validEndpoints, nil
}
// ApplyChanges takes a list of changes (endpoints) and updates the PDNS server
// by sending the correct HTTP PATCH requests to a matching zone
func (p *PDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {

View File

@ -161,6 +161,24 @@ var (
endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeA, endpoint.TTL(300), "8.8.8.8", "8.8.4.4", "4.4.4.4"),
}
endpointsMXRecord = []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "10 example.com"),
}
endpointsMXRecordInvalidFormatTooManyArgs = []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "10 example.com abc"),
}
endpointsMultipleMXRecordsWithSingleInvalid = []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "abc example.com"),
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "20 backup.example.com"),
}
endpointsMultipleInvalidMXRecords = []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "example.com"),
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "backup.example.com"),
}
endpointsMixedRecords = []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("cname.example.com", endpoint.RecordTypeCNAME, endpoint.TTL(300), "example.com"),
endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeTXT, endpoint.TTL(300), "'would smell as sweet'"),
@ -1116,6 +1134,51 @@ func (suite *NewPDNSProviderTestSuite) TestPDNSClientPartitionZones() {
assert.Equal(suite.T(), partitionResultResidualSingleFilter, residualZones)
}
// Validate whether invalid endpoints are removed by AdjustEndpoints
func (suite *NewPDNSProviderTestSuite) TestPDNSAdjustEndpoints() {
// Function definition: AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint
// Create a new provider to run tests against
p := &PDNSProvider{}
tests := []struct {
description string
endpoints []*endpoint.Endpoint
expected []*endpoint.Endpoint
}{
{
description: "Valid MX endpoint is not removed",
endpoints: endpointsMXRecord,
expected: []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "10 example.com"),
},
},
{
description: "Invalid MX endpoint with too many arguments is removed",
endpoints: endpointsMXRecordInvalidFormatTooManyArgs,
expected: []*endpoint.Endpoint([]*endpoint.Endpoint(nil)),
},
{
description: "Invalid MX endpoint is removed among valid endpoints",
endpoints: endpointsMultipleMXRecordsWithSingleInvalid,
expected: []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "20 backup.example.com"),
},
},
{
description: "Multiple invalid MX endpoints are removed",
endpoints: endpointsMultipleInvalidMXRecords,
expected: []*endpoint.Endpoint([]*endpoint.Endpoint(nil)),
},
}
for _, tt := range tests {
actual, err := p.AdjustEndpoints(tt.endpoints)
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), tt.expected, actual)
}
}
func TestNewPDNSProviderTestSuite(t *testing.T) {
suite.Run(t, new(NewPDNSProviderTestSuite))
}