diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 428a7961b..dba536dcd 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -40,6 +40,8 @@ const ( RecordTypeNS = "NS" // RecordTypePTR is a RecordType enum value RecordTypePTR = "PTR" + // RecordTypeMX is a RecordType enum value + RecordTypeMX = "MX" ) // TTL is a structure defining the TTL of a DNS record diff --git a/provider/aws/aws.go b/provider/aws/aws.go index 5f7457420..4a21e39dd 100644 --- a/provider/aws/aws.go +++ b/provider/aws/aws.go @@ -373,7 +373,7 @@ func (p *AWSProvider) records(ctx context.Context, zones map[string]*route53.Hos for _, r := range resp.ResourceRecordSets { newEndpoints := make([]*endpoint.Endpoint, 0) - if !provider.SupportedRecordType(aws.StringValue(r.Type)) { + if !p.SupportedRecordType(aws.StringValue(r.Type)) { continue } @@ -1059,3 +1059,12 @@ func canonicalHostedZone(hostname string) string { func cleanZoneID(id string) string { return strings.TrimPrefix(id, "/hostedzone/") } + +func (p *AWSProvider) SupportedRecordType(recordType string) bool { + switch recordType { + case "MX": + return true + default: + return provider.SupportedRecordType(recordType) + } +} diff --git a/provider/azure/azure.go b/provider/azure/azure.go index dfab96eff..24a0e0ecb 100644 --- a/provider/azure/azure.go +++ b/provider/azure/azure.go @@ -109,7 +109,7 @@ func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp return true } recordType := strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/dnszones/") - if !provider.SupportedRecordType(recordType) { + if !p.SupportedRecordType(recordType) { return true } name := formatAzureDNSName(*recordSet.Name, *zone.Name) @@ -190,6 +190,15 @@ func (p *AzureProvider) zones(ctx context.Context) ([]dns.Zone, error) { return zones, nil } +func (p *AzureProvider) SupportedRecordType(recordType string) bool { + switch recordType { + case "MX": + return true + default: + return provider.SupportedRecordType(recordType) + } +} + func (p *AzureProvider) iterateRecords(ctx context.Context, zoneName string, callback func(dns.RecordSet) bool) error { log.Debugf("Retrieving Azure DNS records for zone '%s'.", zoneName) @@ -377,6 +386,21 @@ func (p *AzureProvider) newRecordSet(endpoint *endpoint.Endpoint) (dns.RecordSet }, }, }, nil + case dns.MX: + mxRecords := make([]dns.MxRecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + mxRecord, err := parseMxTarget[dns.MxRecord](target) + if err != nil { + return dns.RecordSet{}, err + } + mxRecords[i] = mxRecord + } + return dns.RecordSet{ + RecordSetProperties: &dns.RecordSetProperties{ + TTL: to.Int64Ptr(ttl), + MxRecords: &mxRecords, + }, + }, nil case dns.TXT: return dns.RecordSet{ RecordSetProperties: &dns.RecordSetProperties{ @@ -425,6 +449,16 @@ func extractAzureTargets(recordSet *dns.RecordSet) []string { return []string{*cnameRecord.Cname} } + // Check for MX records + mxRecords := properties.MxRecords + if mxRecords != nil && len(*mxRecords) > 0 && (*mxRecords)[0].Exchange != nil { + targets := make([]string, len(*mxRecords)) + for i, mxRecord := range *mxRecords { + targets[i] = fmt.Sprintf("%d %s", *mxRecord.Preference, *mxRecord.Exchange) + } + return targets + } + // Check for TXT records txtRecords := properties.TxtRecords if txtRecords != nil && len(*txtRecords) > 0 && (*txtRecords)[0].Value != nil { diff --git a/provider/azure/azure_private_dns.go b/provider/azure/azure_private_dns.go index 374b5e5d6..fd5733bfe 100644 --- a/provider/azure/azure_private_dns.go +++ b/provider/azure/azure_private_dns.go @@ -367,6 +367,21 @@ func (p *AzurePrivateDNSProvider) newRecordSet(endpoint *endpoint.Endpoint) (pri }, }, }, nil + case privatedns.MX: + mxRecords := make([]privatedns.MxRecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + mxRecord, err := parseMxTarget[privatedns.MxRecord](target) + if err != nil { + return privatedns.RecordSet{}, err + } + mxRecords[i] = mxRecord + } + return privatedns.RecordSet{ + RecordSetProperties: &privatedns.RecordSetProperties{ + TTL: to.Int64Ptr(ttl), + MxRecords: &mxRecords, + }, + }, nil case privatedns.TXT: return privatedns.RecordSet{ RecordSetProperties: &privatedns.RecordSetProperties{ @@ -407,6 +422,16 @@ func extractAzurePrivateDNSTargets(recordSet *privatedns.RecordSet) []string { return []string{*cnameRecord.Cname} } + // Check for MX records + mxRecords := properties.MxRecords + if mxRecords != nil && len(*mxRecords) > 0 && (*mxRecords)[0].Exchange != nil { + targets := make([]string, len(*mxRecords)) + for i, mxRecord := range *mxRecords { + targets[i] = fmt.Sprintf("%d %s", *mxRecord.Preference, *mxRecord.Exchange) + } + return targets + } + // Check for TXT records txtRecords := properties.TxtRecords if txtRecords != nil && len(*txtRecords) > 0 && (*txtRecords)[0].Value != nil { diff --git a/provider/azure/azure_privatedns_test.go b/provider/azure/azure_privatedns_test.go index f35720154..e728659c4 100644 --- a/provider/azure/azure_privatedns_test.go +++ b/provider/azure/azure_privatedns_test.go @@ -123,6 +123,17 @@ func privateCNameRecordSetPropertiesGetter(values []string, ttl int64) *privated } } +func privateMXRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + mxRecords := make([]privatedns.MxRecord, len(values)) + for i, target := range values { + mxRecords[i], _ = parseMxTarget[privatedns.MxRecord](target) + } + return &privatedns.RecordSetProperties{ + TTL: to.Int64Ptr(ttl), + MxRecords: &mxRecords, + } +} + func privateTxtRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { return &privatedns.RecordSetProperties{ TTL: to.Int64Ptr(ttl), @@ -156,6 +167,8 @@ func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64, getterFunc = privateARecordSetPropertiesGetter case endpoint.RecordTypeCNAME: getterFunc = privateCNameRecordSetPropertiesGetter + case endpoint.RecordTypeMX: + getterFunc = privateMXRecordSetPropertiesGetter case endpoint.RecordTypeTXT: getterFunc = privateTxtRecordSetPropertiesGetter default: @@ -266,6 +279,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) { createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createPrivateMockRecordSetWithTTL("mail", endpoint.RecordTypeMX, "10 example.com", 4000), }) if err != nil { t.Fatal(err) @@ -281,6 +295,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) { endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com"), } validateAzureEndpoints(t, actual, expected) @@ -299,6 +314,7 @@ func TestAzurePrivateDNSMultiRecord(t *testing.T) { createPrivateMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createPrivateMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), }) if err != nil { t.Fatal(err) @@ -314,6 +330,7 @@ func TestAzurePrivateDNSMultiRecord(t *testing.T) { endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), } validateAzureEndpoints(t, actual, expected) @@ -329,6 +346,7 @@ func TestAzurePrivateDNSApplyChanges(t *testing.T) { endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, ""), endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, ""), endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, ""), + endpoint.NewEndpoint("oldmail.example.com", endpoint.RecordTypeMX, ""), }) validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{ @@ -342,6 +360,9 @@ func TestAzurePrivateDNSApplyChanges(t *testing.T) { endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(recordTTL), "10 other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), }) } @@ -401,17 +422,21 @@ func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client P endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"), endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"), endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeMX, "10 other.com"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeTXT, "tag"), } currentRecords := []*endpoint.Endpoint{ endpoint.NewEndpoint("old.example.com", endpoint.RecordTypeA, "121.212.121.212"), endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, "other.com"), endpoint.NewEndpoint("old.nope.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldmail.example.com", endpoint.RecordTypeMX, "20 foo.other.com"), } updatedRecords := []*endpoint.Endpoint{ endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), } deleteRecords := []*endpoint.Endpoint{ diff --git a/provider/azure/azure_test.go b/provider/azure/azure_test.go index 0598dd4f3..050ab3d92 100644 --- a/provider/azure/azure_test.go +++ b/provider/azure/azure_test.go @@ -122,6 +122,17 @@ func cNameRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetPr } } +func mxRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + mxRecords := make([]dns.MxRecord, len(values)) + for i, target := range values { + mxRecords[i], _ = parseMxTarget[dns.MxRecord](target) + } + return &dns.RecordSetProperties{ + TTL: to.Int64Ptr(ttl), + MxRecords: &mxRecords, + } +} + func txtRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { return &dns.RecordSetProperties{ TTL: to.Int64Ptr(ttl), @@ -155,6 +166,8 @@ func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values getterFunc = aRecordSetPropertiesGetter case endpoint.RecordTypeCNAME: getterFunc = cNameRecordSetPropertiesGetter + case endpoint.RecordTypeMX: + getterFunc = mxRecordSetPropertiesGetter case endpoint.RecordTypeTXT: getterFunc = txtRecordSetPropertiesGetter default: @@ -271,6 +284,7 @@ func TestAzureRecord(t *testing.T) { createMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com"), }) if err != nil { t.Fatal(err) @@ -287,6 +301,7 @@ func TestAzureRecord(t *testing.T) { endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com"), } validateAzureEndpoints(t, actual, expected) @@ -305,6 +320,7 @@ func TestAzureMultiRecord(t *testing.T) { createMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), }) if err != nil { t.Fatal(err) @@ -321,6 +337,7 @@ func TestAzureMultiRecord(t *testing.T) { endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), } validateAzureEndpoints(t, actual, expected) @@ -336,6 +353,7 @@ func TestAzureApplyChanges(t *testing.T) { endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, ""), endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, ""), endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, ""), + endpoint.NewEndpoint("oldmail.example.com", endpoint.RecordTypeMX, ""), }) validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{ @@ -349,6 +367,9 @@ func TestAzureApplyChanges(t *testing.T) { endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(recordTTL), "10 other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), }) } @@ -410,17 +431,21 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"), endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"), endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeMX, "10 other.com"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeTXT, "tag"), } currentRecords := []*endpoint.Endpoint{ endpoint.NewEndpoint("old.example.com", endpoint.RecordTypeA, "121.212.121.212"), endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, "other.com"), endpoint.NewEndpoint("old.nope.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldmail.example.com", endpoint.RecordTypeMX, "20 foo.other.com"), } updatedRecords := []*endpoint.Endpoint{ endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), } deleteRecords := []*endpoint.Endpoint{ @@ -455,6 +480,7 @@ func TestAzureNameFilter(t *testing.T) { createMockRecordSetWithTTL("test.nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), createMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createMockRecordSetWithTTL("mail.nginx", endpoint.RecordTypeMX, "20 example.com", recordTTL), createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), }) if err != nil { @@ -470,6 +496,7 @@ func TestAzureNameFilter(t *testing.T) { endpoint.NewEndpointWithTTL("test.nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("mail.nginx.example.com", endpoint.RecordTypeMX, recordTTL, "20 example.com"), } validateAzureEndpoints(t, actual, expected) diff --git a/provider/azure/common.go b/provider/azure/common.go new file mode 100644 index 000000000..15ecf64cf --- /dev/null +++ b/provider/azure/common.go @@ -0,0 +1,33 @@ +package azure + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2018-05-01/dns" + "github.com/Azure/azure-sdk-for-go/services/privatedns/mgmt/2018-09-01/privatedns" + "github.com/Azure/go-autorest/autorest/to" +) + +// Helper function (shared with test code) +func parseMxTarget[T dns.MxRecord | privatedns.MxRecord](mxTarget string) (T, error) { + targetParts := strings.SplitN(mxTarget, " ", 2) + + if len(targetParts) != 2 { + return T{}, fmt.Errorf("mx target needs to be of form '10 example.com'") + } + + preferenceRaw, exchange := targetParts[0], targetParts[1] + preference, err := strconv.ParseInt(preferenceRaw, 10, 32) + + if err != nil { + return T{}, fmt.Errorf("invalid preference specified") + } + res := T{ + Preference: to.Int32Ptr(int32(preference)), + Exchange: to.StringPtr(exchange), + } + + return res, nil +} diff --git a/provider/azure/common_test.go b/provider/azure/common_test.go new file mode 100644 index 000000000..b7f4d8dfd --- /dev/null +++ b/provider/azure/common_test.go @@ -0,0 +1,71 @@ +package azure + +import ( + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2018-05-01/dns" + "github.com/Azure/azure-sdk-for-go/services/privatedns/mgmt/2018-09-01/privatedns" + "github.com/Azure/go-autorest/autorest/to" + + "github.com/stretchr/testify/assert" +) + +func Test_parseMxTarget(t *testing.T) { + type testCase[T interface { + dns.MxRecord | privatedns.MxRecord + }] struct { + name string + args string + want T + wantErr assert.ErrorAssertionFunc + } + + tests := []testCase[dns.MxRecord]{ + { + name: "valid mx target", + args: "10 example.com", + want: dns.MxRecord{ + Preference: to.Int32Ptr(int32(10)), + Exchange: to.StringPtr("example.com"), + }, + wantErr: assert.NoError, + }, + { + name: "valid mx target with a subdomain", + args: "99 foo-bar.example.com", + want: dns.MxRecord{ + Preference: to.Int32Ptr(int32(99)), + Exchange: to.StringPtr("foo-bar.example.com"), + }, + wantErr: assert.NoError, + }, + { + name: "invalid mx target with misplaced preference and exchange", + args: "example.com 10", + want: dns.MxRecord{}, + wantErr: assert.Error, + }, + { + name: "invalid mx target without preference", + args: "example.com", + want: dns.MxRecord{}, + wantErr: assert.Error, + }, + { + name: "invalid mx target with non numeric preference", + args: "aa example.com", + want: dns.MxRecord{}, + wantErr: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMxTarget[dns.MxRecord](tt.args) + if !tt.wantErr(t, err, fmt.Sprintf("parseMxTarget(%v)", tt.args)) { + return + } + assert.Equalf(t, tt.want, got, "parseMxTarget(%v)", tt.args) + }) + } +} diff --git a/provider/google/google.go b/provider/google/google.go index d6c53ef6a..5fce653f5 100644 --- a/provider/google/google.go +++ b/provider/google/google.go @@ -213,7 +213,7 @@ func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.End f := func(resp *dns.ResourceRecordSetsListResponse) error { for _, r := range resp.Rrsets { - if !provider.SupportedRecordType(r.Type) { + if !p.SupportedRecordType(r.Type) { continue } endpoints = append(endpoints, endpoint.NewEndpointWithTTL(r.Name, r.Type, endpoint.TTL(r.Ttl), r.Rrdatas...)) @@ -273,6 +273,16 @@ func (p *GoogleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes return p.submitChange(ctx, change) } +// SupportedRecordType returns true if the record type is supported by the provider +func (p *GoogleProvider) SupportedRecordType(recordType string) bool { + switch recordType { + case "MX": + return true + default: + return provider.SupportedRecordType(recordType) + } +} + // newFilteredRecords returns a collection of RecordSets based on the given endpoints and domainFilter. func (p *GoogleProvider) newFilteredRecords(endpoints []*endpoint.Endpoint) []*dns.ResourceRecordSet { records := []*dns.ResourceRecordSet{} @@ -447,6 +457,12 @@ func newRecord(ep *endpoint.Endpoint) *dns.ResourceRecordSet { targets[0] = provider.EnsureTrailingDot(targets[0]) } + if ep.RecordType == endpoint.RecordTypeMX { + for i, mxRecord := range ep.Targets { + targets[i] = provider.EnsureTrailingDot(mxRecord) + } + } + // no annotation results in a Ttl of 0, default to 300 for backwards-compatibility var ttl int64 = googleRecordTTL if ep.RecordTTL.IsConfigured() { diff --git a/provider/google/google_test.go b/provider/google/google_test.go index 3fe610091..bc321200b 100644 --- a/provider/google/google_test.go +++ b/provider/google/google_test.go @@ -512,6 +512,7 @@ func TestNewFilteredRecords(t *testing.T) { endpoint.NewEndpointWithTTL("update-test-cname.zone-1.ext-dns-test-2.gcp.zalan.do", endpoint.RecordTypeCNAME, 4000, "bar.elb.amazonaws.com"), // test fallback to Ttl:300 when Ttl==0 : endpoint.NewEndpointWithTTL("update-test.zone-1.ext-dns-test-2.gcp.zalan.do", endpoint.RecordTypeA, 0, "8.8.8.8"), + endpoint.NewEndpointWithTTL("update-test-mx.zone-1.ext-dns-test-2.gcp.zalan.do", endpoint.RecordTypeMX, 6000, "10 mail.elb.amazonaws.com"), endpoint.NewEndpoint("delete-test.zone-1.ext-dns-test-2.gcp.zalan.do", endpoint.RecordTypeA, "8.8.8.8"), endpoint.NewEndpoint("delete-test-cname.zone-1.ext-dns-test-2.gcp.zalan.do", endpoint.RecordTypeCNAME, "qux.elb.amazonaws.com"), }) @@ -521,6 +522,7 @@ func TestNewFilteredRecords(t *testing.T) { {Name: "delete-test.zone-2.ext-dns-test-2.gcp.zalan.do.", Rrdatas: []string{"8.8.4.4"}, Type: "A", Ttl: 120}, {Name: "update-test-cname.zone-1.ext-dns-test-2.gcp.zalan.do.", Rrdatas: []string{"bar.elb.amazonaws.com."}, Type: "CNAME", Ttl: 4000}, {Name: "update-test.zone-1.ext-dns-test-2.gcp.zalan.do.", Rrdatas: []string{"8.8.8.8"}, Type: "A", Ttl: 300}, + {Name: "update-test-mx.zone-1.ext-dns-test-2.gcp.zalan.do.", Rrdatas: []string{"10 mail.elb.amazonaws.com."}, Type: "MX", Ttl: 6000}, {Name: "delete-test.zone-1.ext-dns-test-2.gcp.zalan.do.", Rrdatas: []string{"8.8.8.8"}, Type: "A", Ttl: 300}, {Name: "delete-test-cname.zone-1.ext-dns-test-2.gcp.zalan.do.", Rrdatas: []string{"qux.elb.amazonaws.com."}, Type: "CNAME", Ttl: 300}, })