diff --git a/main.go b/main.go index 22fefec95..eb6cd24d8 100644 --- a/main.go +++ b/main.go @@ -253,9 +253,9 @@ func main() { } p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, sd.New(awsSession)) case "azure-dns", "azure": - p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun) + p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.DryRun) case "azure-private-dns": - p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun) + p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.DryRun) case "bluecat": p, err = bluecat.NewBluecatProvider(cfg.BluecatConfigFile, cfg.BluecatDNSConfiguration, cfg.BluecatDNSServerName, cfg.BluecatDNSDeployType, cfg.BluecatDNSView, cfg.BluecatGatewayHost, cfg.BluecatRootZone, cfg.TXTPrefix, cfg.TXTSuffix, domainFilter, zoneIDFilter, cfg.DryRun, cfg.BluecatSkipTLSVerify) case "vinyldns": diff --git a/pkg/apis/externaldns/types.go b/pkg/apis/externaldns/types.go index 0b25dcafe..26b579f82 100644 --- a/pkg/apis/externaldns/types.go +++ b/pkg/apis/externaldns/types.go @@ -101,6 +101,7 @@ type Config struct { AzureResourceGroup string AzureSubscriptionID string AzureUserAssignedIdentityClientID string + AzureActiveDirectoryAuthorityHost string BluecatDNSConfiguration string BluecatConfigFile string BluecatDNSView string diff --git a/provider/azure/azure.go b/provider/azure/azure.go index a7021d192..904fa6223 100644 --- a/provider/azure/azure.go +++ b/provider/azure/azure.go @@ -58,6 +58,7 @@ type AzureProvider struct { dryRun bool resourceGroup string userAssignedIdentityClientID string + activeDirectoryAuthorityHost string zonesClient ZonesClient recordSetsClient RecordSetsClient } @@ -65,8 +66,8 @@ type AzureProvider struct { // NewAzureProvider creates a new Azure provider. // // Returns the provider or an error if a provider could not be created. -func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzureProvider, error) { - cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID) +func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, dryRun bool) (*AzureProvider, error) { + cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost) if err != nil { return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) } @@ -90,6 +91,7 @@ func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zon dryRun: dryRun, resourceGroup: cfg.ResourceGroup, userAssignedIdentityClientID: cfg.UserAssignedIdentityID, + activeDirectoryAuthorityHost: cfg.ActiveDirectoryAuthorityHost, zonesClient: zonesClient, recordSetsClient: recordSetsClient, }, nil diff --git a/provider/azure/azure_private_dns.go b/provider/azure/azure_private_dns.go index 43e3bdc43..7b56cb6a0 100644 --- a/provider/azure/azure_private_dns.go +++ b/provider/azure/azure_private_dns.go @@ -52,6 +52,7 @@ type AzurePrivateDNSProvider struct { dryRun bool resourceGroup string userAssignedIdentityClientID string + activeDirectoryAuthorityHost string zonesClient PrivateZonesClient recordSetsClient PrivateRecordSetsClient } @@ -59,8 +60,8 @@ type AzurePrivateDNSProvider struct { // NewAzurePrivateDNSProvider creates a new Azure Private DNS provider. // // Returns the provider or an error if a provider could not be created. -func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzurePrivateDNSProvider, error) { - cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID) +func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, dryRun bool) (*AzurePrivateDNSProvider, error) { + cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost) if err != nil { return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) } @@ -83,6 +84,7 @@ func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainF dryRun: dryRun, resourceGroup: cfg.ResourceGroup, userAssignedIdentityClientID: cfg.UserAssignedIdentityID, + activeDirectoryAuthorityHost: cfg.ActiveDirectoryAuthorityHost, zonesClient: zonesClient, recordSetsClient: recordSetsClient, }, nil diff --git a/provider/azure/azure_test.go b/provider/azure/azure_test.go index 4fd9fa8fc..f2031d139 100644 --- a/provider/azure/azure_test.go +++ b/provider/azure/azure_test.go @@ -222,13 +222,13 @@ func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values } // newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets -func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones []*dns.Zone, recordSets []*dns.RecordSet) (*AzureProvider, error) { +func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zones []*dns.Zone, recordSets []*dns.RecordSet) (*AzureProvider, error) { zonesClient := newMockZonesClient(zones) recordSetsClient := newMockRecordSetsClient(recordSets) - return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil + return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost, &zonesClient, &recordSetsClient), nil } -func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider { +func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider { return &AzureProvider{ domainFilter: domainFilter, zoneNameFilter: zoneNameFilter, @@ -236,6 +236,7 @@ func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoin dryRun: dryRun, resourceGroup: resourceGroup, userAssignedIdentityClientID: userAssignedIdentityClientID, + activeDirectoryAuthorityHost: activeDirectoryAuthorityHost, zonesClient: zonesClient, recordSetsClient: recordsClient, } @@ -246,7 +247,7 @@ func validateAzureEndpoints(t *testing.T, endpoints []*endpoint.Endpoint, expect } func TestAzureRecord(t *testing.T) { - provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", "", []*dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, @@ -286,7 +287,7 @@ func TestAzureRecord(t *testing.T) { } func TestAzureMultiRecord(t *testing.T) { - provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", "", []*dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, @@ -381,6 +382,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC dryRun, "group", "", + "", &zonesClient, client, ) @@ -440,7 +442,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC } func TestAzureNameFilter(t *testing.T) { - provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"nginx.example.com"}), endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"nginx.example.com"}), endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", "", []*dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, @@ -506,6 +508,7 @@ func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client Rec dryRun, "group", "", + "", &zonesClient, client, ) diff --git a/provider/azure/config.go b/provider/azure/config.go index 7df0e7667..eca4ae0d6 100644 --- a/provider/azure/config.go +++ b/provider/azure/config.go @@ -41,9 +41,10 @@ type config struct { UseManagedIdentityExtension bool `json:"useManagedIdentityExtension" yaml:"useManagedIdentityExtension"` UseWorkloadIdentityExtension bool `json:"useWorkloadIdentityExtension" yaml:"useWorkloadIdentityExtension"` UserAssignedIdentityID string `json:"userAssignedIdentityID" yaml:"userAssignedIdentityID"` + ActiveDirectoryAuthorityHost string `json:"activeDirectoryAuthorityHost" yaml:"activeDirectoryAuthorityHost"` } -func getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID string) (*config, error) { +func getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost string) (*config, error) { contents, err := os.ReadFile(configFile) if err != nil { return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) @@ -65,6 +66,10 @@ func getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityCl if userAssignedIdentityClientID != "" { cfg.UserAssignedIdentityID = userAssignedIdentityClientID } + // If activeDirectoryAuthorityHost is provided explicitly, override existing one in config file + if activeDirectoryAuthorityHost != "" { + cfg.ActiveDirectoryAuthorityHost = activeDirectoryAuthorityHost + } return cfg, nil } @@ -152,17 +157,6 @@ func getCloudConfiguration(name string) (cloud.Configuration, error) { return cloud.AzureGovernment, nil case "AZURECHINACLOUD": return cloud.AzureChina, nil - case "AZURECUSTOMCLOUD": - azureAdEndpoint := os.Getenv("AZURE_AD_ENDPOINT") - - if azureAdEndpoint == "" { - return cloud.Configuration{}, fmt.Errorf("AD Endpoint Not set: %s", name) - } else { - customCloud := cloud.Configuration{ - ActiveDirectoryAuthorityHost: os.Getenv("AZURE_AD_ENDPOINT"), - } - return customCloud, nil - } } return cloud.Configuration{}, fmt.Errorf("unknown cloud name: %s", name) } diff --git a/provider/azure/config_test.go b/provider/azure/config_test.go index ab52c72c2..1099515d3 100644 --- a/provider/azure/config_test.go +++ b/provider/azure/config_test.go @@ -17,7 +17,6 @@ limitations under the License. package azure import ( - "os" "path" "runtime" "testing" @@ -30,23 +29,14 @@ func TestGetCloudConfiguration(t *testing.T) { tests := map[string]struct { cloudName string expected cloud.Configuration - setEnv map[string]string }{ - "AzureChinaCloud": {"AzureChinaCloud", cloud.AzureChina, nil}, - "AzurePublicCloud": {"", cloud.AzurePublic, nil}, - "AzureUSGovernment": {"AzureUSGovernmentCloud", cloud.AzureGovernment, nil}, - "AzureCustomCloud": {"AzureCustomCloud", cloud.Configuration{ActiveDirectoryAuthorityHost: "https://custom.microsoftonline.com/"}, map[string]string{"AZURE_AD_ENDPOINT": "https://custom.microsoftonline.com/"}}, + "AzureChinaCloud": {"AzureChinaCloud", cloud.AzureChina}, + "AzurePublicCloud": {"", cloud.AzurePublic}, + "AzureUSGovernment": {"AzureUSGovernmentCloud", cloud.AzureGovernment}, } for name, test := range tests { t.Run(name, func(t *testing.T) { - if test.setEnv != nil { - for key, value := range test.setEnv { - os.Setenv(key, value) - defer os.Unsetenv(key) - } - } - cloudCfg, err := getCloudConfiguration(test.cloudName) if err != nil { t.Errorf("got unexpected err %v", err) @@ -61,10 +51,11 @@ func TestGetCloudConfiguration(t *testing.T) { func TestOverrideConfiguration(t *testing.T) { _, filename, _, _ := runtime.Caller(0) configFile := path.Join(path.Dir(filename), "config_test.json") - cfg, err := getConfig(configFile, "subscription-override", "rg-override", "") + cfg, err := getConfig(configFile, "subscription-override", "rg-override", "", "aad-endpoint-override") if err != nil { t.Errorf("got unexpected err %v", err) } assert.Equal(t, cfg.SubscriptionID, "subscription-override") assert.Equal(t, cfg.ResourceGroup, "rg-override") + assert.Equal(t, cfg.ActiveDirectoryAuthorityHost, "aad-endpoint-override") }