From a706ba32ab173bf9839d6040d5a6973a57384a0d Mon Sep 17 00:00:00 2001 From: Thibault Cohen Date: Tue, 8 Sep 2020 11:18:04 -0400 Subject: [PATCH] Add test for ZoneNameFilter --- provider/azure/azure.go | 2 +- provider/azure/azure_test.go | 147 +++++++++++++++++++++++++++++++++-- 2 files changed, 143 insertions(+), 6 deletions(-) diff --git a/provider/azure/azure.go b/provider/azure/azure.go index 0584bd0c2..1c82b0136 100644 --- a/provider/azure/azure.go +++ b/provider/azure/azure.go @@ -210,7 +210,7 @@ func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) { log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name) - return false + return true } targets := extractAzureTargets(&recordSet) if len(targets) == 0 { diff --git a/provider/azure/azure_test.go b/provider/azure/azure_test.go index 30a21f2d6..fa5e2341d 100644 --- a/provider/azure/azure_test.go +++ b/provider/azure/azure_test.go @@ -207,7 +207,7 @@ func (client *mockRecordSetsClient) CreateOrUpdate(ctx context.Context, resource } // newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets -func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) { +func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) { // init zone-related parts of the mock-client pageIterator := mockZoneListResultPageIterator{ results: []dns.ZoneListResult{ @@ -237,12 +237,13 @@ func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneIDFilter pro mockRecordSetListIterator: &mockRecordSetListIterator, } - return newAzureProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil + return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil } -func newAzureProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider { +func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider { return &AzureProvider{ domainFilter: domainFilter, + zoneNameFilter: zoneNameFilter, zoneIDFilter: zoneIDFilter, dryRun: dryRun, resourceGroup: resourceGroup, @@ -257,7 +258,7 @@ func validateAzureEndpoints(t *testing.T, endpoints []*endpoint.Endpoint, expect } func TestAzureRecord(t *testing.T) { - provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", &[]dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, @@ -294,7 +295,7 @@ func TestAzureRecord(t *testing.T) { } func TestAzureMultiRecord(t *testing.T) { - provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", &[]dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, @@ -393,6 +394,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC } provider := newAzureProvider( + endpoint.NewDomainFilter([]string{""}), endpoint.NewDomainFilter([]string{""}), provider.NewZoneIDFilter([]string{""}), dryRun, @@ -496,3 +498,138 @@ func TestAzureGetAccessToken(t *testing.T) { t.Fatalf("expect the clientID of the token is SPNClientID, but got token %s", string(innerToken)) } } + +func TestAzureNameFilter(t *testing.T) { + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"nginx.example.com"}), endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + &[]dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + }, + + &[]dns.RecordSet{ + createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), + createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), + createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), + createMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + createMockRecordSetWithTTL("test.nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + }) + + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + actual, err := provider.Records(ctx) + + if err != nil { + t.Fatal(err) + } + expected := []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("test.nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + } + + validateAzureEndpoints(t, actual, expected) + +} + +func TestAzureApplyChangesZoneName(t *testing.T) { + recordsClient := mockRecordSetsClient{} + + testAzureApplyChangesInternalZoneName(t, false, &recordsClient) + + validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpoint("old.foo.example.com", endpoint.RecordTypeA, ""), + endpoint.NewEndpoint("oldcname.foo.example.com", endpoint.RecordTypeCNAME, ""), + endpoint.NewEndpoint("deleted.foo.example.com", endpoint.RecordTypeA, ""), + endpoint.NewEndpoint("deletedcname.foo.example.com", endpoint.RecordTypeCNAME, ""), + }) + + validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "1.2.3.5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("newcname.foo.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + }) +} + +func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client RecordSetsClient) { + zlr := dns.ZoneListResult{ + Value: &[]dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + }, + } + + results := []dns.ZoneListResult{ + zlr, + } + + mockZoneListResultPage := dns.NewZoneListResultPage(func(ctxParam context.Context, zlrParam dns.ZoneListResult) (dns.ZoneListResult, error) { + if len(results) > 0 { + result := results[0] + results = nil + return result, nil + } + return dns.ZoneListResult{}, nil + }) + mockZoneClientIterator := dns.NewZoneListResultIterator(mockZoneListResultPage) + + zonesClient := mockZonesClient{ + mockZonesClientIterator: &mockZoneClientIterator, + } + + provider := newAzureProvider( + endpoint.NewDomainFilter([]string{"foo.example.com"}), + endpoint.NewDomainFilter([]string{"example.com"}), + provider.NewZoneIDFilter([]string{""}), + dryRun, + "group", + "", + &zonesClient, + client, + ) + + createRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeA, "1.2.3.5", "1.2.3.4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeA, "5.6.7.8"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"), + } + + currentRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("old.foo.example.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldcname.foo.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("old.nope.example.com", endpoint.RecordTypeA, "121.212.121.212"), + } + updatedRecords := []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("newcname.foo.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpoint("new.nope.example.com", endpoint.RecordTypeA, "222.111.222.111"), + } + + deleteRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.foo.example.com", endpoint.RecordTypeA, "111.222.111.222"), + endpoint.NewEndpoint("deletedcname.foo.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("deleted.nope.example.com", endpoint.RecordTypeA, "222.111.222.111"), + } + + changes := &plan.Changes{ + Create: createRecords, + UpdateNew: updatedRecords, + UpdateOld: currentRecords, + Delete: deleteRecords, + } + + if err := provider.ApplyChanges(context.Background(), changes); err != nil { + t.Fatal(err) + } +}