Add test for ZoneNameFilter

This commit is contained in:
Thibault Cohen 2020-09-08 11:18:04 -04:00
parent dac21e3aff
commit a706ba32ab
2 changed files with 143 additions and 6 deletions

View File

@ -210,7 +210,7 @@ func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp
if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) {
log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name)
return false
return true
}
targets := extractAzureTargets(&recordSet)
if len(targets) == 0 {

View File

@ -207,7 +207,7 @@ func (client *mockRecordSetsClient) CreateOrUpdate(ctx context.Context, resource
}
// newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) {
func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) {
// init zone-related parts of the mock-client
pageIterator := mockZoneListResultPageIterator{
results: []dns.ZoneListResult{
@ -237,12 +237,13 @@ func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneIDFilter pro
mockRecordSetListIterator: &mockRecordSetListIterator,
}
return newAzureProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil
return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil
}
func newAzureProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider {
func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider {
return &AzureProvider{
domainFilter: domainFilter,
zoneNameFilter: zoneNameFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: resourceGroup,
@ -257,7 +258,7 @@ func validateAzureEndpoints(t *testing.T, endpoints []*endpoint.Endpoint, expect
}
func TestAzureRecord(t *testing.T) {
provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "",
provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "",
&[]dns.Zone{
createMockZone("example.com", "/dnszones/example.com"),
},
@ -294,7 +295,7 @@ func TestAzureRecord(t *testing.T) {
}
func TestAzureMultiRecord(t *testing.T) {
provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "",
provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "",
&[]dns.Zone{
createMockZone("example.com", "/dnszones/example.com"),
},
@ -393,6 +394,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC
}
provider := newAzureProvider(
endpoint.NewDomainFilter([]string{""}),
endpoint.NewDomainFilter([]string{""}),
provider.NewZoneIDFilter([]string{""}),
dryRun,
@ -496,3 +498,138 @@ func TestAzureGetAccessToken(t *testing.T) {
t.Fatalf("expect the clientID of the token is SPNClientID, but got token %s", string(innerToken))
}
}
func TestAzureNameFilter(t *testing.T) {
provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"nginx.example.com"}), endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "",
&[]dns.Zone{
createMockZone("example.com", "/dnszones/example.com"),
},
&[]dns.RecordSet{
createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."),
createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"),
createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"),
createMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"),
createMockRecordSetWithTTL("test.nginx", endpoint.RecordTypeA, "123.123.123.123", 3600),
createMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600),
createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
})
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
actual, err := provider.Records(ctx)
if err != nil {
t.Fatal(err)
}
expected := []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("test.nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"),
endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"),
endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"),
}
validateAzureEndpoints(t, actual, expected)
}
func TestAzureApplyChangesZoneName(t *testing.T) {
recordsClient := mockRecordSetsClient{}
testAzureApplyChangesInternalZoneName(t, false, &recordsClient)
validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{
endpoint.NewEndpoint("old.foo.example.com", endpoint.RecordTypeA, ""),
endpoint.NewEndpoint("oldcname.foo.example.com", endpoint.RecordTypeCNAME, ""),
endpoint.NewEndpoint("deleted.foo.example.com", endpoint.RecordTypeA, ""),
endpoint.NewEndpoint("deletedcname.foo.example.com", endpoint.RecordTypeCNAME, ""),
})
validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "1.2.3.5"),
endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"),
endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"),
endpoint.NewEndpointWithTTL("newcname.foo.example.com", endpoint.RecordTypeCNAME, 10, "other.com"),
})
}
func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client RecordSetsClient) {
zlr := dns.ZoneListResult{
Value: &[]dns.Zone{
createMockZone("example.com", "/dnszones/example.com"),
},
}
results := []dns.ZoneListResult{
zlr,
}
mockZoneListResultPage := dns.NewZoneListResultPage(func(ctxParam context.Context, zlrParam dns.ZoneListResult) (dns.ZoneListResult, error) {
if len(results) > 0 {
result := results[0]
results = nil
return result, nil
}
return dns.ZoneListResult{}, nil
})
mockZoneClientIterator := dns.NewZoneListResultIterator(mockZoneListResultPage)
zonesClient := mockZonesClient{
mockZonesClientIterator: &mockZoneClientIterator,
}
provider := newAzureProvider(
endpoint.NewDomainFilter([]string{"foo.example.com"}),
endpoint.NewDomainFilter([]string{"example.com"}),
provider.NewZoneIDFilter([]string{""}),
dryRun,
"group",
"",
&zonesClient,
client,
)
createRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "1.2.3.4"),
endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "tag"),
endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeA, "1.2.3.5", "1.2.3.4"),
endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeTXT, "tag"),
endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeCNAME, "other.com"),
endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeTXT, "tag"),
endpoint.NewEndpoint("other.com", endpoint.RecordTypeA, "5.6.7.8"),
endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"),
endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"),
endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"),
}
currentRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("old.foo.example.com", endpoint.RecordTypeA, "121.212.121.212"),
endpoint.NewEndpoint("oldcname.foo.example.com", endpoint.RecordTypeCNAME, "other.com"),
endpoint.NewEndpoint("old.nope.example.com", endpoint.RecordTypeA, "121.212.121.212"),
}
updatedRecords := []*endpoint.Endpoint{
endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"),
endpoint.NewEndpointWithTTL("newcname.foo.example.com", endpoint.RecordTypeCNAME, 10, "other.com"),
endpoint.NewEndpoint("new.nope.example.com", endpoint.RecordTypeA, "222.111.222.111"),
}
deleteRecords := []*endpoint.Endpoint{
endpoint.NewEndpoint("deleted.foo.example.com", endpoint.RecordTypeA, "111.222.111.222"),
endpoint.NewEndpoint("deletedcname.foo.example.com", endpoint.RecordTypeCNAME, "other.com"),
endpoint.NewEndpoint("deleted.nope.example.com", endpoint.RecordTypeA, "222.111.222.111"),
}
changes := &plan.Changes{
Create: createRecords,
UpdateNew: updatedRecords,
UpdateOld: currentRecords,
Delete: deleteRecords,
}
if err := provider.ApplyChanges(context.Background(), changes); err != nil {
t.Fatal(err)
}
}