external-dns/provider/azure/azure_privatedns_test.go
Andy Bursavich f6392be41e
gofumpt
2022-09-22 10:44:50 +01:00

434 lines
18 KiB
Go

/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"context"
"testing"
"github.com/Azure/azure-sdk-for-go/services/privatedns/mgmt/2018-09-01/privatedns"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/to"
"sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/plan"
"sigs.k8s.io/external-dns/provider"
)
const (
// recordTTL is the TTL value the tests in this file assign to some mock record
// sets and expect on the endpoints the mock client records.
// NOTE(review): the apply-changes test expects this TTL on endpoints created
// without an explicit TTL, so it presumably mirrors the provider's default
// TTL — confirm against the provider.
recordTTL = 300
)
// mockPrivateZonesClient implements the methods of the Azure Private DNS Zones Client which are used in the Azure Private DNS Provider
// and returns static results which are defined per test.
type mockPrivateZonesClient struct {
// mockZonesClientIterator is the pre-built zone iterator that
// ListByResourceGroupComplete advances by one item and returns a copy of.
mockZonesClientIterator *privatedns.PrivateZoneListResultIterator
}
// mockPrivateRecordSetsClient implements the methods of the Azure Private DNS RecordSet Client which are used in the Azure Private DNS Provider
// and returns static results which are defined per test.
type mockPrivateRecordSetsClient struct {
// mockRecordSetListIterator is the pre-built record set iterator that
// ListComplete advances by one item and returns a copy of.
mockRecordSetListIterator *privatedns.RecordSetListResultIterator
// deletedEndpoints collects one endpoint per Delete call, in call order.
deletedEndpoints []*endpoint.Endpoint
// updatedEndpoints collects one endpoint per CreateOrUpdate call, in call order.
updatedEndpoints []*endpoint.Endpoint
}
// mockPrivateZoneListResultPageIterator is used to paginate forward through a list of zones.
type mockPrivateZoneListResultPageIterator struct {
// offset is the index into results of the next page getNextPage will return.
offset int
// results holds the pages of zones, in the order they are handed out.
results []privatedns.PrivateZoneListResult
}
// getNextPage returns the page at the current offset of the
// mockPrivateZoneListResultPageIterator and advances the offset by one. Once
// every page has been handed out, or when there are no pages at all, it
// returns an empty result. It never returns an error; both arguments are ignored.
//
// Instances are assumed to be walked through only once per test. Anything
// beyond that would need a real implementation, e.g. based on a linked list.
func (m *mockPrivateZoneListResultPageIterator) getNextPage(context.Context, privatedns.PrivateZoneListResult) (privatedns.PrivateZoneListResult, error) {
	if m.offset >= len(m.results) {
		// paged past the last page, or there were no pages to begin with
		return privatedns.PrivateZoneListResult{}, nil
	}
	page := m.results[m.offset]
	m.offset++
	return page, nil
}
// mockPrivateRecordSetListResultPageIterator is used to paginate forward through a list of recordsets.
type mockPrivateRecordSetListResultPageIterator struct {
// offset is the index into results of the next page getNextPage will return.
offset int
// results holds the pages of record sets, in the order they are handed out.
results []privatedns.RecordSetListResult
}
// getNextPage returns the page at the current offset of the
// mockPrivateRecordSetListResultPageIterator and advances the offset by one.
// Once every page has been handed out, or when there are no pages at all, it
// returns an empty result. It never returns an error; both arguments are ignored.
//
// Instances are assumed to be walked through only once per test. Anything
// beyond that would need a real implementation, e.g. based on a linked list.
func (m *mockPrivateRecordSetListResultPageIterator) getNextPage(context.Context, privatedns.RecordSetListResult) (privatedns.RecordSetListResult, error) {
	if m.offset >= len(m.results) {
		// paged past the last page, or there were no pages to begin with
		return privatedns.RecordSetListResult{}, nil
	}
	page := m.results[m.offset]
	m.offset++
	return page, nil
}
// createMockPrivateZone builds a privatedns.PrivateZone whose Name is zone and
// whose ID is id; all other fields are left at their zero value.
func createMockPrivateZone(zone string, id string) privatedns.PrivateZone {
	var pz privatedns.PrivateZone
	pz.ID = to.StringPtr(id)
	pz.Name = to.StringPtr(zone)
	return pz
}
// ListByResourceGroupComplete advances the client's pre-built zone iterator by
// one item and returns a copy of it together with whatever error that advance
// produced. resourceGroupName and top are ignored.
func (client *mockPrivateZonesClient) ListByResourceGroupComplete(ctx context.Context, resourceGroupName string, top *int32) (result privatedns.PrivateZoneListResultIterator, err error) {
	iterator := client.mockZonesClientIterator
	// Move to the first item before handing the iterator out, to emulate the
	// behaviour of the Azure SDK. The copy must be taken after advancing.
	err = iterator.NextWithContext(ctx)
	result = *iterator
	return result, err
}
// privateARecordSetPropertiesGetter builds record set properties carrying the
// given TTL and one A record per entry of values, in the same order.
func privateARecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties {
	records := make([]privatedns.ARecord, 0, len(values))
	for _, address := range values {
		records = append(records, privatedns.ARecord{Ipv4Address: to.StringPtr(address)})
	}
	props := new(privatedns.RecordSetProperties)
	props.TTL = to.Int64Ptr(ttl)
	props.ARecords = &records
	return props
}
// privateCNameRecordSetPropertiesGetter builds record set properties carrying
// the given TTL and a CNAME record pointing at values[0]. Any further values
// are ignored; an empty values slice panics on the index access.
func privateCNameRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties {
	cname := privatedns.CnameRecord{Cname: to.StringPtr(values[0])}
	props := privatedns.RecordSetProperties{
		TTL:         to.Int64Ptr(ttl),
		CnameRecord: &cname,
	}
	return &props
}
// privateTxtRecordSetPropertiesGetter builds record set properties carrying
// the given TTL and a single TXT record whose value list holds only values[0].
// Any further values are ignored; an empty values slice panics on the index access.
func privateTxtRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties {
	text := []string{values[0]}
	txtRecords := []privatedns.TxtRecord{{Value: &text}}
	props := privatedns.RecordSetProperties{
		TTL:        to.Int64Ptr(ttl),
		TxtRecords: &txtRecords,
	}
	return &props
}
// privateOthersRecordSetPropertiesGetter builds record set properties that
// carry only the given TTL. values is accepted to match the signature of the
// other getters and is not used.
func privateOthersRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties {
	props := new(privatedns.RecordSetProperties)
	props.TTL = to.Int64Ptr(ttl)
	return props
}
// createPrivateMockRecordSet builds a mock record set of the given name and
// record type holding values, with a TTL of 0.
func createPrivateMockRecordSet(name, recordType string, values ...string) privatedns.RecordSet {
	const zeroTTL int64 = 0
	return createPrivateMockRecordSetMultiWithTTL(name, recordType, zeroTTL, values...)
}
// createPrivateMockRecordSetWithTTL builds a mock record set of the given name
// and record type holding the single value, with the given TTL.
func createPrivateMockRecordSetWithTTL(name, recordType, value string, ttl int64) privatedns.RecordSet {
	single := []string{value}
	return createPrivateMockRecordSetMultiWithTTL(name, recordType, ttl, single...)
}
// createPrivateMockRecordSetMultiWithTTL builds a mock record set named name
// whose Type is "Microsoft.Network/privateDnsZones/" followed by recordType.
// The properties are produced by the getter matching recordType (A, CNAME or
// TXT); any other record type gets properties that carry only the TTL.
func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) privatedns.RecordSet {
	type propertiesGetter func(values []string, ttl int64) *privatedns.RecordSetProperties

	getters := map[string]propertiesGetter{
		endpoint.RecordTypeA:     privateARecordSetPropertiesGetter,
		endpoint.RecordTypeCNAME: privateCNameRecordSetPropertiesGetter,
		endpoint.RecordTypeTXT:   privateTxtRecordSetPropertiesGetter,
	}
	buildProperties, known := getters[recordType]
	if !known {
		buildProperties = privateOthersRecordSetPropertiesGetter
	}

	var recordSet privatedns.RecordSet
	recordSet.Name = to.StringPtr(name)
	recordSet.Type = to.StringPtr("Microsoft.Network/privateDnsZones/" + recordType)
	recordSet.RecordSetProperties = buildProperties(values, ttl)
	return recordSet
}
// ListComplete advances the client's pre-built record set iterator by one item
// and returns a copy of it together with whatever error that advance produced.
// resourceGroupName, zoneName, top and recordSetNameSuffix are ignored.
func (client *mockPrivateRecordSetsClient) ListComplete(ctx context.Context, resourceGroupName string, zoneName string, top *int32, recordSetNameSuffix string) (result privatedns.RecordSetListResultIterator, err error) {
	iterator := client.mockRecordSetListIterator
	// Move to the first item before handing the iterator out, to emulate the
	// behaviour of the Azure SDK. The copy must be taken after advancing.
	err = iterator.NextWithContext(ctx)
	result = *iterator
	return result, err
}
// Delete records the deletion instead of performing it: it appends to
// deletedEndpoints an endpoint named after the record set within the zone,
// typed as recordType, with a single empty-string target. It always returns an
// empty response and a nil error. ctx, resourceGroupName and ifMatch are ignored.
func (client *mockPrivateRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, ifMatch string) (result autorest.Response, err error) {
	dnsName := formatAzureDNSName(relativeRecordSetName, privateZoneName)
	deleted := endpoint.NewEndpoint(dnsName, string(recordType), "")
	client.deletedEndpoints = append(client.deletedEndpoints, deleted)
	return autorest.Response{}, nil
}
// CreateOrUpdate records the write instead of performing it: it appends to
// updatedEndpoints an endpoint named after the record set within the zone,
// typed as recordType, carrying the targets extracted from parameters and the
// TTL from parameters (zero when parameters.TTL is nil). It returns parameters
// unchanged and a nil error. ctx, resourceGroupName, ifMatch and ifNoneMatch
// are ignored.
func (client *mockPrivateRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, ifMatch string, ifNoneMatch string) (result privatedns.RecordSet, err error) {
	var effectiveTTL endpoint.TTL
	if requested := parameters.TTL; requested != nil {
		effectiveTTL = endpoint.TTL(*requested)
	}

	dnsName := formatAzureDNSName(relativeRecordSetName, privateZoneName)
	targets := extractAzurePrivateDNSTargets(&parameters)
	updated := endpoint.NewEndpointWithTTL(dnsName, string(recordType), effectiveTTL, targets...)

	client.updatedEndpoints = append(client.updatedEndpoints, updated)
	return parameters, nil
}
// newMockedAzurePrivateDNSProvider creates an AzurePrivateDNSProvider wired to
// mocked zone and record set clients. The zones are served as one single page
// and so are the record sets. The returned error is always nil.
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones *[]privatedns.PrivateZone, recordSets *[]privatedns.RecordSet) (*AzurePrivateDNSProvider, error) {
	// zone side of the mock: one page holding every zone handed in
	zonePages := &mockPrivateZoneListResultPageIterator{
		results: []privatedns.PrivateZoneListResult{{Value: zones}},
	}
	zoneIterator := privatedns.NewPrivateZoneListResultIterator(
		privatedns.NewPrivateZoneListResultPage(privatedns.PrivateZoneListResult{}, zonePages.getNextPage),
	)

	// record side of the mock: one page holding every record set handed in
	recordSetPages := &mockPrivateRecordSetListResultPageIterator{
		results: []privatedns.RecordSetListResult{{Value: recordSets}},
	}
	recordSetIterator := privatedns.NewRecordSetListResultIterator(
		privatedns.NewRecordSetListResultPage(privatedns.RecordSetListResult{}, recordSetPages.getNextPage),
	)

	mocked := newAzurePrivateDNSProvider(
		domainFilter,
		zoneIDFilter,
		dryRun,
		resourceGroup,
		&mockPrivateZonesClient{mockZonesClientIterator: &zoneIterator},
		&mockPrivateRecordSetsClient{mockRecordSetListIterator: &recordSetIterator},
	)
	return mocked, nil
}
// newAzurePrivateDNSProvider assembles an AzurePrivateDNSProvider from the
// given filters, dry-run flag, resource group and clients. It sets no other
// fields.
func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient) *AzurePrivateDNSProvider {
	p := new(AzurePrivateDNSProvider)
	p.domainFilter = domainFilter
	p.zoneIDFilter = zoneIDFilter
	p.dryRun = dryRun
	p.resourceGroup = resourceGroup
	p.zonesClient = privateZonesClient
	p.recordSetsClient = privateRecordsClient
	return p
}
func TestAzurePrivateDNSRecord(t *testing.T) {
provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s",
&[]privatedns.PrivateZone{
createMockPrivateZone("example.com", "/privateDnsZones/example.com"),
},
&[]privatedns.RecordSet{
createPrivateMockRecordSet("@", "NS", "ns1-03.azure-dns.com."),
createPrivateMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"),
createPrivateMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"),
createPrivateMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"),
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600),
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
})
if err != nil {
t.Fatal(err)
}
actual, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
expected := []*endpoint.Endpoint{
endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "123.123.123.122"),
endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"),
endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"),
endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"),
endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"),
}
validateAzureEndpoints(t, actual, expected)
}
func TestAzurePrivateDNSMultiRecord(t *testing.T) {
provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s",
&[]privatedns.PrivateZone{
createMockPrivateZone("example.com", "/privateDnsZones/example.com"),
},
&[]privatedns.RecordSet{
createPrivateMockRecordSet("@", "NS", "ns1-03.azure-dns.com."),
createPrivateMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"),
createPrivateMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"),
createPrivateMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"),
createPrivateMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"),
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
})
if err != nil {
t.Fatal(err)
}
actual, err := provider.Records(context.Background())
if err != nil {
t.Fatal(err)
}
expected := []*endpoint.Endpoint{
endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"),
endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"),
endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"),
endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"),
endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"),
}
validateAzureEndpoints(t, actual, expected)
}
// TestAzurePrivateDNSApplyChanges applies the shared change plan with dry-run
// disabled and checks which deletions and writes reached the mocked record
// sets client. No "nope.com" record is expected among them.
func TestAzurePrivateDNSApplyChanges(t *testing.T) {
	recordsClient := &mockPrivateRecordSetsClient{}
	testAzurePrivateDNSApplyChangesInternal(t, false, recordsClient)

	// the mock records every deletion with a single empty target
	wantDeleted := []*endpoint.Endpoint{
		endpoint.NewEndpoint("old.example.com", endpoint.RecordTypeA, ""),
		endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, ""),
		endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, ""),
		endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, ""),
	}
	validateAzureEndpoints(t, recordsClient.deletedEndpoints, wantDeleted)

	defaultTTL := endpoint.TTL(recordTTL)
	wantUpdated := []*endpoint.Endpoint{
		endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeA, defaultTTL, "1.2.3.4"),
		endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeTXT, defaultTTL, "tag"),
		endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeA, defaultTTL, "1.2.3.4", "1.2.3.5"),
		endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeTXT, defaultTTL, "tag"),
		endpoint.NewEndpointWithTTL("bar.example.com", endpoint.RecordTypeCNAME, defaultTTL, "other.com"),
		endpoint.NewEndpointWithTTL("bar.example.com", endpoint.RecordTypeTXT, defaultTTL, "tag"),
		endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeA, defaultTTL, "5.6.7.8"),
		endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeTXT, defaultTTL, "tag"),
		endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"),
		endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"),
	}
	validateAzureEndpoints(t, recordsClient.updatedEndpoints, wantUpdated)
}
func TestAzurePrivateDNSApplyChangesDryRun(t *testing.T) {
recordsClient := mockRecordSetsClient{}
testAzureApplyChangesInternal(t, true, &recordsClient)
validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{})
validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{})
}
// testAzurePrivateDNSApplyChangesInternal builds an AzurePrivateDNSProvider
// over two mocked private zones ("example.com" and "other.com") and the given
// record sets client, then applies a fixed plan of creates, updates and
// deletes. The plan also contains records under "nope.com", for which no zone
// is mocked. The test fails fatally if ApplyChanges returns an error; callers
// inspect client afterwards to see which calls were made.
func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client PrivateRecordSetsClient) {
	zonePage := privatedns.PrivateZoneListResult{
		Value: &[]privatedns.PrivateZone{
			createMockPrivateZone("example.com", "/privateDnsZones/example.com"),
			createMockPrivateZone("other.com", "/privateDnsZones/other.com"),
		},
	}

	// The single zone page is handed out exactly once; every later request
	// yields an empty page.
	served := false
	nextZonePage := func(context.Context, privatedns.PrivateZoneListResult) (privatedns.PrivateZoneListResult, error) {
		if served {
			return privatedns.PrivateZoneListResult{}, nil
		}
		served = true
		return zonePage, nil
	}
	zoneIterator := privatedns.NewPrivateZoneListResultIterator(
		privatedns.NewPrivateZoneListResultPage(privatedns.PrivateZoneListResult{}, nextZonePage),
	)

	p := newAzurePrivateDNSProvider(
		endpoint.NewDomainFilter([]string{""}),
		provider.NewZoneIDFilter([]string{""}),
		dryRun,
		"group",
		&mockPrivateZonesClient{mockZonesClientIterator: &zoneIterator},
		client,
	)

	changes := &plan.Changes{
		Create: []*endpoint.Endpoint{
			endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "1.2.3.4"),
			endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "tag"),
			endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeA, "1.2.3.5", "1.2.3.4"),
			endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeTXT, "tag"),
			endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeCNAME, "other.com"),
			endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeTXT, "tag"),
			endpoint.NewEndpoint("other.com", endpoint.RecordTypeA, "5.6.7.8"),
			endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"),
			endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"),
			endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"),
		},
		UpdateNew: []*endpoint.Endpoint{
			endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"),
			endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"),
			endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeA, "222.111.222.111"),
		},
		UpdateOld: []*endpoint.Endpoint{
			endpoint.NewEndpoint("old.example.com", endpoint.RecordTypeA, "121.212.121.212"),
			endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, "other.com"),
			endpoint.NewEndpoint("old.nope.com", endpoint.RecordTypeA, "121.212.121.212"),
		},
		Delete: []*endpoint.Endpoint{
			endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, "111.222.111.222"),
			endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, "other.com"),
			endpoint.NewEndpoint("deleted.nope.com", endpoint.RecordTypeA, "222.111.222.111"),
		},
	}

	err := p.ApplyChanges(context.Background(), changes)
	if err != nil {
		t.Fatal(err)
	}
}