diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 1bd2d54fb..a8a385c57 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -20,6 +20,7 @@ import ( "fmt" "net/netip" "sort" + "strconv" "strings" log "github.com/sirupsen/logrus" @@ -396,3 +397,54 @@ func RemoveDuplicates(endpoints []*Endpoint) []*Endpoint { return result } + +// Check endpoint if is it properly formatted according to RFC standards +func (e *Endpoint) CheckEndpoint() bool { + switch recordType := e.RecordType; recordType { + case RecordTypeMX: + return e.Targets.ValidateMXRecord() + case RecordTypeSRV: + return e.Targets.ValidateSRVRecord() + } + return true +} + +func (t Targets) ValidateMXRecord() bool { + for _, target := range t { + // MX records must have a preference value to indicate priority, e.g. "10 example.com" + // as per https://www.rfc-editor.org/rfc/rfc974.txt + targetParts := strings.Fields(strings.TrimSpace(target)) + if len(targetParts) != 2 { + log.Debugf("Invalid MX record target: %s. MX records must have a preference value to indicate priority, e.g. '10 example.com'", target) + return false + } + preferenceRaw := targetParts[0] + _, err := strconv.ParseUint(preferenceRaw, 10, 16) + if err != nil { + log.Debugf("Invalid SRV record target: %s. Invalid integer value in target.", target) + return false + } + } + return true +} + +func (t Targets) ValidateSRVRecord() bool { + for _, target := range t { + // SRV records must have a priority, weight, and port value, e.g. "10 5 5060 example.com" + // as per https://www.rfc-editor.org/rfc/rfc2782.txt + targetParts := strings.Fields(strings.TrimSpace(target)) + if len(targetParts) != 4 { + log.Debugf("Invalid SRV record target: %s. SRV records must have a priority, weight, and port value, e.g. '10 5 5060 example.com'", target) + return false + } + + for _, part := range targetParts[:3] { + _, err := strconv.ParseUint(part, 10, 16) + if err != nil { + log.Debugf("Invalid SRV record target: %s. Invalid integer value in target.", target) + return false + } + } + } + return true +} diff --git a/endpoint/endpoint_test.go b/endpoint/endpoint_test.go index cd38208e4..6c805ecf0 100644 --- a/endpoint/endpoint_test.go +++ b/endpoint/endpoint_test.go @@ -19,6 +19,8 @@ package endpoint import ( "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestNewEndpoint(t *testing.T) { @@ -441,3 +443,98 @@ func TestDuplicatedEndpointsWithOverlappingZones(t *testing.T) { }) } } + +func TestPDNScheckEndpoint(t *testing.T) { + tests := []struct { + description string + endpoint Endpoint + expected bool + }{ + { + description: "Valid MX record target", + endpoint: Endpoint{ + DNSName: "example.com", + RecordType: RecordTypeMX, + Targets: Targets{"10 example.com"}, + }, + expected: true, + }, + { + description: "Valid MX record with multiple targets", + endpoint: Endpoint{ + DNSName: "example.com", + RecordType: RecordTypeMX, + Targets: Targets{"10 example.com", "20 backup.example.com"}, + }, + expected: true, + }, + { + description: "MX record with valid and invalid targets", + endpoint: Endpoint{ + DNSName: "example.com", + RecordType: RecordTypeMX, + Targets: Targets{"example.com", "backup.example.com"}, + }, + expected: false, + }, + { + description: "Invalid MX record with missing priority value", + endpoint: Endpoint{ + DNSName: "example.com", + RecordType: RecordTypeMX, + Targets: Targets{"example.com"}, + }, + expected: false, + }, + { + description: "Invalid MX record with too many arguments", + endpoint: Endpoint{ + DNSName: "example.com", + RecordType: RecordTypeMX, + Targets: Targets{"10 example.com abc"}, + }, + expected: false, + }, + { + description: "Invalid MX record with non-integer priority", + endpoint: Endpoint{ + DNSName: "example.com", + RecordType: RecordTypeMX, + Targets: Targets{"abc example.com"}, + }, + expected: false, + }, + { + description: "Valid SRV record target", + endpoint: Endpoint{ + DNSName: "_service._tls.example.com", + RecordType: RecordTypeSRV, + Targets: Targets{"10 20 5060 service.example.com"}, + }, + expected: true, + }, + { + description: "Invalid SRV record with missing part", + endpoint: Endpoint{ + DNSName: "_service._tls.example.com", + RecordType: RecordTypeSRV, + Targets: Targets{"10 20 5060"}, + }, + expected: false, + }, + { + description: "Invalid SRV record with non-integer part", + endpoint: Endpoint{ + DNSName: "_service._tls.example.com", + RecordType: RecordTypeSRV, + Targets: Targets{"10 20 abc service.example.com"}, + }, + expected: false, + }, + } + + for _, tt := range tests { + actual := tt.endpoint.CheckEndpoint() + assert.Equal(t, tt.expected, actual) + } +} diff --git a/provider/pdns/pdns.go b/provider/pdns/pdns.go index c0180725a..a37787f8a 100644 --- a/provider/pdns/pdns.go +++ b/provider/pdns/pdns.go @@ -433,6 +433,19 @@ func (p *PDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpo return endpoints, nil } +// AdjustEndpoints performs checks on the provided endpoints and will skip any potentially failing changes. +func (p *PDNSProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) { + var validEndpoints []*endpoint.Endpoint + for i := 0; i < len(endpoints); i++ { + if !endpoints[i].CheckEndpoint() { + log.Warnf("Ignoring Endpoint because of invalid %v record formatting: {Target: '%v'}", endpoints[i].RecordType, endpoints[i].Targets) + continue + } + validEndpoints = append(validEndpoints, endpoints[i]) + } + return validEndpoints, nil +} + // ApplyChanges takes a list of changes (endpoints) and updates the PDNS server // by sending the correct HTTP PATCH requests to a matching zone func (p *PDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { diff --git a/provider/pdns/pdns_test.go b/provider/pdns/pdns_test.go index 47df67cb1..773d4897c 100644 --- a/provider/pdns/pdns_test.go +++ b/provider/pdns/pdns_test.go @@ -161,6 +161,24 @@ var ( endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeA, endpoint.TTL(300), "8.8.8.8", "8.8.4.4", "4.4.4.4"), } + endpointsMXRecord = []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "10 example.com"), + } + + endpointsMXRecordInvalidFormatTooManyArgs = []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "10 example.com abc"), + } + + endpointsMultipleMXRecordsWithSingleInvalid = []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "abc example.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "20 backup.example.com"), + } + + endpointsMultipleInvalidMXRecords = []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "example.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "backup.example.com"), + } + endpointsMixedRecords = []*endpoint.Endpoint{ endpoint.NewEndpointWithTTL("cname.example.com", endpoint.RecordTypeCNAME, endpoint.TTL(300), "example.com"), endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeTXT, endpoint.TTL(300), "'would smell as sweet'"), @@ -1116,6 +1134,51 @@ func (suite *NewPDNSProviderTestSuite) TestPDNSClientPartitionZones() { assert.Equal(suite.T(), partitionResultResidualSingleFilter, residualZones) } +// Validate whether invalid endpoints are removed by AdjustEndpoints +func (suite *NewPDNSProviderTestSuite) TestPDNSAdjustEndpoints() { + // Function definition: AdjustEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint + + // Create a new provider to run tests against + p := &PDNSProvider{} + + tests := []struct { + description string + endpoints []*endpoint.Endpoint + expected []*endpoint.Endpoint + }{ + { + description: "Valid MX endpoint is not removed", + endpoints: endpointsMXRecord, + expected: []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "10 example.com"), + }, + }, + { + description: "Invalid MX endpoint with too many arguments is removed", + endpoints: endpointsMXRecordInvalidFormatTooManyArgs, + expected: []*endpoint.Endpoint([]*endpoint.Endpoint(nil)), + }, + { + description: "Invalid MX endpoint is removed among valid endpoints", + endpoints: endpointsMultipleMXRecordsWithSingleInvalid, + expected: []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(300), "20 backup.example.com"), + }, + }, + { + description: "Multiple invalid MX endpoints are removed", + endpoints: endpointsMultipleInvalidMXRecords, + expected: []*endpoint.Endpoint([]*endpoint.Endpoint(nil)), + }, + } + + for _, tt := range tests { + actual, err := p.AdjustEndpoints(tt.endpoints) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), tt.expected, actual) + } +} + func TestNewPDNSProviderTestSuite(t *testing.T) { suite.Run(t, new(NewPDNSProviderTestSuite)) }