mirror of
https://github.com/kubernetes-sigs/external-dns.git
synced 2025-08-06 01:26:59 +02:00
fix(azure): enhance retry logic using azure SDK (#5361)
* fix(azure): Enhance retry logic using azure SDK * Added the changes for flag based maxretries configuration * Fixed types.go, flags.md and delected unneccesary comments * Added the correct image for the Azure Private DNS tutorial * Following the go naming convention for maxRetriesCount * Added the correct flag information to the --azure-maxretries-count * Made the required changes to accept the --azure-maxretries-count flag value from cli/env
This commit is contained in:
parent
09d49a106a
commit
7b9d8d9355
@ -186,9 +186,9 @@ func Execute() {
|
||||
}
|
||||
p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, cfg.AWSSDCreateTag, sd.NewFromConfig(aws.CreateDefaultV2Config(cfg)))
|
||||
case "azure-dns", "azure":
|
||||
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.DryRun)
|
||||
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.AzureMaxRetriesCount, cfg.DryRun)
|
||||
case "azure-private-dns":
|
||||
p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.DryRun)
|
||||
p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.AzureMaxRetriesCount, cfg.DryRun)
|
||||
case "ultradns":
|
||||
p, err = ultradns.NewUltraDNSProvider(domainFilter, cfg.DryRun)
|
||||
case "civo":
|
||||
|
@ -86,6 +86,7 @@
|
||||
| `--azure-subscription-id=""` | When using the Azure provider, override the Azure subscription to use (optional) |
|
||||
| `--azure-user-assigned-identity-client-id=""` | When using the Azure provider, override the client id of user assigned identity in config file (optional) |
|
||||
| `--azure-zones-cache-duration=0s` | When using the Azure provider, set the zones list cache TTL (0s to disable). |
|
||||
| `--azure-maxretries-count=3` | When using the Azure provider, set the number of retries for API calls (When less than 0, it disables retries). (optional) |
|
||||
| `--tencent-cloud-config-file="/etc/kubernetes/tencent-cloud.json"` | When using the Tencent Cloud provider, specify the Tencent Cloud configuration file (required when --provider=tencentcloud) |
|
||||
| `--tencent-cloud-zone-type=` | When using the Tencent Cloud provider, filter for zones with visibility (optional, options: public, private) |
|
||||
| `--[no-]cloudflare-proxied` | When using the Cloudflare provider, specify if the proxy mode must be enabled (default: disabled) |
|
||||
|
@ -108,6 +108,7 @@ $ az role assignment create --role "Private DNS Zone Contributor" --assignee <ap
|
||||
## Throttling
|
||||
|
||||
When the ExternalDNS managed zones list doesn't change frequently, one can set `--azure-zones-cache-duration` (zones list cache time-to-live). The zones list cache is disabled by default, with a value of 0s.
|
||||
Also, one can leverage the built-in retry policies of the Azure SDK. The flag --azure-maxretries-count can be specified in the manifest yaml to configure behavior. The default value of Azure SDK retry is 3.
|
||||
|
||||
## Deploy ExternalDNS
|
||||
|
||||
@ -151,6 +152,7 @@ spec:
|
||||
- --provider=azure-private-dns
|
||||
- --azure-resource-group=externaldns
|
||||
- --azure-subscription-id=<use the id of your subscription>
|
||||
- --azure-maxretries-count=1 # (optional) specifies the maxRetires value to be used by the Azure SDK. Default is 3.
|
||||
volumeMounts:
|
||||
- name: azure-config-file
|
||||
mountPath: /etc/kubernetes
|
||||
@ -223,6 +225,7 @@ spec:
|
||||
- --provider=azure-private-dns
|
||||
- --azure-resource-group=externaldns
|
||||
- --azure-subscription-id=<use the id of your subscription>
|
||||
- --azure-maxretries-count=1 # (optional) specifies the maxRetires value to be used by the Azure SDK. Default is 3.
|
||||
volumeMounts:
|
||||
- name: azure-config-file
|
||||
mountPath: /etc/kubernetes
|
||||
@ -295,6 +298,7 @@ spec:
|
||||
- --provider=azure-private-dns
|
||||
- --azure-resource-group=externaldns
|
||||
- --azure-subscription-id=<use the id of your subscription>
|
||||
- --azure-maxretries-count=1 # (optional) specifies the maxRetires value to be used by the Azure SDK. Default is 3.
|
||||
volumeMounts:
|
||||
- name: azure-config-file
|
||||
mountPath: /etc/kubernetes
|
||||
|
@ -493,6 +493,7 @@ NOTE: make sure the pod is restarted whenever you make a configuration change.
|
||||
## Throttling
|
||||
|
||||
When the ExternalDNS managed zones list doesn't change frequently, one can set `--azure-zones-cache-duration` (zones list cache time-to-live). The zones list cache is disabled by default, with a value of 0s.
|
||||
Also, one can leverage the built-in retry policies of the Azure SDK with a tunable maxRetries value. Environment variable AZURE_SDK_MAX_RETRIES can be specified in the manifest yaml to configure behavior. The defualt value of Azure SDK retry is 3.
|
||||
|
||||
## Ingress used with ExternalDNS
|
||||
|
||||
@ -540,6 +541,7 @@ spec:
|
||||
- --domain-filter=example.com # (optional) limit to only example.com domains; change to match the zone created above.
|
||||
- --provider=azure
|
||||
- --azure-resource-group=MyDnsResourceGroup # (optional) use the DNS zones from the tutorial's resource group
|
||||
- --azure-maxretries-count=1 # (optional) specifies the maxRetires value to be used by the Azure SDK. Default is 3.
|
||||
volumeMounts:
|
||||
- name: azure-config-file
|
||||
mountPath: /etc/kubernetes
|
||||
@ -609,6 +611,7 @@ spec:
|
||||
- --provider=azure
|
||||
- --azure-resource-group=MyDnsResourceGroup # (optional) use the DNS zones from the tutorial's resource group
|
||||
- --txt-prefix=externaldns-
|
||||
- --azure-maxretries-count=1 # (optional) specifies the maxRetires value to be used by the Azure SDK. Default is 3.
|
||||
volumeMounts:
|
||||
- name: azure-config-file
|
||||
mountPath: /etc/kubernetes
|
||||
@ -680,6 +683,7 @@ spec:
|
||||
- --domain-filter=example.com # (optional) limit to only example.com domains; change to match the zone created above.
|
||||
- --provider=azure
|
||||
- --azure-resource-group=MyDnsResourceGroup # (optional) use the DNS zones from the tutorial's resource group
|
||||
- --azure-maxretries-count=1 # (optional) specifies the maxRetires value to be used by the Azure SDK. Default is 3.
|
||||
volumeMounts:
|
||||
- name: azure-config-file
|
||||
mountPath: /etc/kubernetes
|
||||
|
@ -106,6 +106,7 @@ type Config struct {
|
||||
AzureUserAssignedIdentityClientID string
|
||||
AzureActiveDirectoryAuthorityHost string
|
||||
AzureZonesCacheDuration time.Duration
|
||||
AzureMaxRetriesCount int
|
||||
CloudflareProxied bool
|
||||
CloudflareCustomHostnames bool
|
||||
CloudflareCustomHostnamesMinTLSVersion string
|
||||
@ -247,6 +248,7 @@ var defaultConfig = &Config{
|
||||
AzureResourceGroup: "",
|
||||
AzureSubscriptionID: "",
|
||||
AzureZonesCacheDuration: 0 * time.Second,
|
||||
AzureMaxRetriesCount: 3,
|
||||
CFAPIEndpoint: "",
|
||||
CFPassword: "",
|
||||
CFUsername: "",
|
||||
@ -527,6 +529,7 @@ func App(cfg *Config) *kingpin.Application {
|
||||
app.Flag("azure-subscription-id", "When using the Azure provider, override the Azure subscription to use (optional)").Default(defaultConfig.AzureSubscriptionID).StringVar(&cfg.AzureSubscriptionID)
|
||||
app.Flag("azure-user-assigned-identity-client-id", "When using the Azure provider, override the client id of user assigned identity in config file (optional)").Default("").StringVar(&cfg.AzureUserAssignedIdentityClientID)
|
||||
app.Flag("azure-zones-cache-duration", "When using the Azure provider, set the zones list cache TTL (0s to disable).").Default(defaultConfig.AzureZonesCacheDuration.String()).DurationVar(&cfg.AzureZonesCacheDuration)
|
||||
app.Flag("azure-maxretries-count", "When using the Azure provider, set the number of retries for API calls (When less than 0, it disables retries). (optional)").Default(strconv.Itoa(defaultConfig.AzureMaxRetriesCount)).IntVar(&cfg.AzureMaxRetriesCount)
|
||||
app.Flag("tencent-cloud-config-file", "When using the Tencent Cloud provider, specify the Tencent Cloud configuration file (required when --provider=tencentcloud)").Default(defaultConfig.TencentCloudConfigFile).StringVar(&cfg.TencentCloudConfigFile)
|
||||
app.Flag("tencent-cloud-zone-type", "When using the Tencent Cloud provider, filter for zones with visibility (optional, options: public, private)").Default(defaultConfig.TencentCloudZoneType).EnumVar(&cfg.TencentCloudZoneType, "", "public", "private")
|
||||
|
||||
|
@ -73,6 +73,7 @@ var (
|
||||
AzureConfigFile: "/etc/kubernetes/azure.json",
|
||||
AzureResourceGroup: "",
|
||||
AzureSubscriptionID: "",
|
||||
AzureMaxRetriesCount: 3,
|
||||
CloudflareProxied: false,
|
||||
CloudflareCustomHostnames: false,
|
||||
CloudflareCustomHostnamesMinTLSVersion: "1.0",
|
||||
@ -183,6 +184,7 @@ var (
|
||||
AzureConfigFile: "azure.json",
|
||||
AzureResourceGroup: "arg",
|
||||
AzureSubscriptionID: "arg",
|
||||
AzureMaxRetriesCount: 4,
|
||||
CloudflareProxied: true,
|
||||
CloudflareCustomHostnames: true,
|
||||
CloudflareCustomHostnamesMinTLSVersion: "1.3",
|
||||
@ -296,6 +298,7 @@ func TestParseFlags(t *testing.T) {
|
||||
"--azure-config-file=azure.json",
|
||||
"--azure-resource-group=arg",
|
||||
"--azure-subscription-id=arg",
|
||||
"--azure-maxretries-count=4",
|
||||
"--cloudflare-proxied",
|
||||
"--cloudflare-custom-hostnames",
|
||||
"--cloudflare-custom-hostnames-min-tls-version=1.3",
|
||||
@ -427,6 +430,7 @@ func TestParseFlags(t *testing.T) {
|
||||
"EXTERNAL_DNS_AZURE_CONFIG_FILE": "azure.json",
|
||||
"EXTERNAL_DNS_AZURE_RESOURCE_GROUP": "arg",
|
||||
"EXTERNAL_DNS_AZURE_SUBSCRIPTION_ID": "arg",
|
||||
"EXTERNAL_DNS_AZURE_MAXRETRIES_COUNT": "4",
|
||||
"EXTERNAL_DNS_CLOUDFLARE_PROXIED": "1",
|
||||
"EXTERNAL_DNS_CLOUDFLARE_CUSTOM_HOSTNAMES": "1",
|
||||
"EXTERNAL_DNS_CLOUDFLARE_CUSTOM_HOSTNAMES_MIN_TLS_VERSION": "1.3",
|
||||
|
@ -63,17 +63,19 @@ type AzureProvider struct {
|
||||
zonesClient ZonesClient
|
||||
zonesCache *zonesCache[dns.Zone]
|
||||
recordSetsClient RecordSetsClient
|
||||
maxRetriesCount int
|
||||
}
|
||||
|
||||
// NewAzureProvider creates a new Azure provider.
|
||||
//
|
||||
// Returns the provider or an error if a provider could not be created.
|
||||
func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, dryRun bool) (*AzureProvider, error) {
|
||||
func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, maxRetriesCount int, dryRun bool) (*AzureProvider, error) {
|
||||
cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
|
||||
}
|
||||
cred, clientOpts, err := getCredentials(*cfg)
|
||||
|
||||
cred, clientOpts, err := getCredentials(*cfg, maxRetriesCount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@ -97,6 +99,7 @@ func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zon
|
||||
zonesClient: zonesClient,
|
||||
zonesCache: &zonesCache[dns.Zone]{duration: zonesCacheDuration},
|
||||
recordSetsClient: recordSetsClient,
|
||||
maxRetriesCount: maxRetriesCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -58,17 +58,19 @@ type AzurePrivateDNSProvider struct {
|
||||
zonesClient PrivateZonesClient
|
||||
zonesCache *zonesCache[privatedns.PrivateZone]
|
||||
recordSetsClient PrivateRecordSetsClient
|
||||
maxRetriesCount int
|
||||
}
|
||||
|
||||
// NewAzurePrivateDNSProvider creates a new Azure Private DNS provider.
|
||||
//
|
||||
// Returns the provider or an error if a provider could not be created.
|
||||
func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, dryRun bool) (*AzurePrivateDNSProvider, error) {
|
||||
func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, maxRetriesCount int, dryRun bool) (*AzurePrivateDNSProvider, error) {
|
||||
cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
|
||||
}
|
||||
cred, clientOpts, err := getCredentials(*cfg)
|
||||
|
||||
cred, clientOpts, err := getCredentials(*cfg, maxRetriesCount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@ -92,6 +94,7 @@ func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainF
|
||||
zonesClient: zonesClient,
|
||||
zonesCache: &zonesCache[privatedns.PrivateZone]{duration: zonesCacheDuration},
|
||||
recordSetsClient: recordSetsClient,
|
||||
maxRetriesCount: maxRetriesCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -224,13 +224,13 @@ func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64,
|
||||
}
|
||||
|
||||
// newMockedAzurePrivateDNSProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
|
||||
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet) (*AzurePrivateDNSProvider, error) {
|
||||
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet, maxRetriesCount int) (*AzurePrivateDNSProvider, error) {
|
||||
zonesClient := newMockPrivateZonesClient(zones)
|
||||
recordSetsClient := newMockPrivateRecordSectsClient(recordSets)
|
||||
return newAzurePrivateDNSProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil
|
||||
return newAzurePrivateDNSProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient, maxRetriesCount), nil
|
||||
}
|
||||
|
||||
func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient) *AzurePrivateDNSProvider {
|
||||
func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient, maxRetriesCount int) *AzurePrivateDNSProvider {
|
||||
return &AzurePrivateDNSProvider{
|
||||
domainFilter: domainFilter,
|
||||
zoneNameFilter: zoneNameFilter,
|
||||
@ -240,6 +240,7 @@ func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilt
|
||||
zonesClient: privateZonesClient,
|
||||
zonesCache: &zonesCache[privatedns.PrivateZone]{duration: 0},
|
||||
recordSetsClient: privateRecordsClient,
|
||||
maxRetriesCount: maxRetriesCount,
|
||||
}
|
||||
}
|
||||
|
||||
@ -259,7 +260,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) {
|
||||
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
|
||||
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
|
||||
createPrivateMockRecordSetWithTTL("mail", endpoint.RecordTypeMX, "10 example.com", 4000),
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -298,7 +299,7 @@ func TestAzurePrivateDNSMultiRecord(t *testing.T) {
|
||||
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
|
||||
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
|
||||
createPrivateMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"),
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -378,6 +379,7 @@ func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client P
|
||||
"group",
|
||||
&zonesClient,
|
||||
client,
|
||||
3,
|
||||
)
|
||||
|
||||
createRecords := []*endpoint.Endpoint{
|
||||
@ -450,7 +452,7 @@ func TestAzurePrivateDNSNameFilter(t *testing.T) {
|
||||
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
|
||||
createPrivateMockRecordSetWithTTL("mail.nginx", endpoint.RecordTypeMX, "20 example.com", recordTTL),
|
||||
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -505,6 +507,7 @@ func testAzurePrivateDNSApplyChangesInternalZoneName(t *testing.T, dryRun bool,
|
||||
"group",
|
||||
&zonesClient,
|
||||
client,
|
||||
3,
|
||||
)
|
||||
|
||||
createRecords := []*endpoint.Endpoint{
|
||||
|
@ -237,13 +237,13 @@ func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values
|
||||
}
|
||||
|
||||
// newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
|
||||
func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zones []*dns.Zone, recordSets []*dns.RecordSet) (*AzureProvider, error) {
|
||||
func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zones []*dns.Zone, recordSets []*dns.RecordSet, maxRetriesCount int) (*AzureProvider, error) {
|
||||
zonesClient := newMockZonesClient(zones)
|
||||
recordSetsClient := newMockRecordSetsClient(recordSets)
|
||||
return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost, &zonesClient, &recordSetsClient), nil
|
||||
return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost, &zonesClient, &recordSetsClient, maxRetriesCount), nil
|
||||
}
|
||||
|
||||
func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider {
|
||||
func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesClient ZonesClient, recordsClient RecordSetsClient, maxRetriesCount int) *AzureProvider {
|
||||
return &AzureProvider{
|
||||
domainFilter: domainFilter,
|
||||
zoneNameFilter: zoneNameFilter,
|
||||
@ -255,6 +255,7 @@ func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoin
|
||||
zonesClient: zonesClient,
|
||||
zonesCache: &zonesCache[dns.Zone]{duration: 0},
|
||||
recordSetsClient: recordsClient,
|
||||
maxRetriesCount: maxRetriesCount,
|
||||
}
|
||||
}
|
||||
|
||||
@ -280,7 +281,7 @@ func TestAzureRecord(t *testing.T) {
|
||||
createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
|
||||
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
|
||||
createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com"),
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -325,7 +326,7 @@ func TestAzureMultiRecord(t *testing.T) {
|
||||
createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
|
||||
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
|
||||
createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"),
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -415,6 +416,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC
|
||||
"",
|
||||
&zonesClient,
|
||||
client,
|
||||
3,
|
||||
)
|
||||
|
||||
createRecords := []*endpoint.Endpoint{
|
||||
@ -497,7 +499,7 @@ func TestAzureNameFilter(t *testing.T) {
|
||||
createMockRecordSetWithTTL("mail.nginx", endpoint.RecordTypeMX, "20 example.com", recordTTL),
|
||||
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
|
||||
createMockRecordSetWithTTL("hack", endpoint.RecordTypeNS, "ns1.example.com.", 3600),
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -555,6 +557,7 @@ func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client Rec
|
||||
"",
|
||||
&zonesClient,
|
||||
client,
|
||||
3,
|
||||
)
|
||||
|
||||
createRecords := []*endpoint.Endpoint{
|
||||
|
@ -17,15 +17,19 @@ limitations under the License.
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -72,15 +76,57 @@ func getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityCl
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ctxKey is a type for context keys
|
||||
// This is used to avoid collisions with other packages that may use the same key in the context.
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
// Context key for request ID
|
||||
clientRequestIDKey ctxKey = "client-request-id"
|
||||
// Azure API Headers
|
||||
msRequestIDHeader = "x-ms-request-id"
|
||||
msCorrelationRequestHeader = "x-ms-correlation-request-id"
|
||||
msClientRequestIDHeader = "x-ms-client-request-id"
|
||||
)
|
||||
|
||||
// customHeaderPolicy adds UUID to request headers
|
||||
type customHeaderPolicy struct{}
|
||||
|
||||
func (p *customHeaderPolicy) Do(req *policy.Request) (*http.Response, error) {
|
||||
id := req.Raw().Header.Get(msClientRequestIDHeader)
|
||||
if id == "" {
|
||||
id = uuid.New().String()
|
||||
req.Raw().Header.Set(msClientRequestIDHeader, id)
|
||||
newCtx := context.WithValue(req.Raw().Context(), clientRequestIDKey, id)
|
||||
*req.Raw() = *req.Raw().WithContext(newCtx)
|
||||
}
|
||||
return req.Next()
|
||||
}
|
||||
func CustomHeaderPolicynew() policy.Policy { return &customHeaderPolicy{} }
|
||||
|
||||
// getCredentials retrieves Azure API credentials.
|
||||
func getCredentials(cfg config) (azcore.TokenCredential, *arm.ClientOptions, error) {
|
||||
func getCredentials(cfg config, maxRetries int) (azcore.TokenCredential, *arm.ClientOptions, error) {
|
||||
cloudCfg, err := getCloudConfiguration(cfg.Cloud)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get cloud configuration: %w", err)
|
||||
}
|
||||
clientOpts := azcore.ClientOptions{
|
||||
Cloud: cloudCfg,
|
||||
Retry: policy.RetryOptions{
|
||||
MaxRetries: int32(maxRetries),
|
||||
},
|
||||
Logging: policy.LogOptions{
|
||||
AllowedHeaders: []string{
|
||||
msRequestIDHeader,
|
||||
msCorrelationRequestHeader,
|
||||
msClientRequestIDHeader,
|
||||
},
|
||||
},
|
||||
PerCallPolicies: []policy.Policy{
|
||||
CustomHeaderPolicynew(),
|
||||
},
|
||||
}
|
||||
log.Debugf("Configured Azure client with maxRetries: %d", clientOpts.Retry.MaxRetries)
|
||||
armClientOpts := &arm.ClientOptions{
|
||||
ClientOptions: clientOpts,
|
||||
}
|
||||
|
@ -17,11 +17,19 @@ limitations under the License.
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -59,3 +67,268 @@ func TestOverrideConfiguration(t *testing.T) {
|
||||
assert.Equal(t, cfg.ResourceGroup, "rg-override")
|
||||
assert.Equal(t, cfg.ActiveDirectoryAuthorityHost, "aad-endpoint-override")
|
||||
}
|
||||
|
||||
// Test for custom header policy
|
||||
type transportFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f transportFunc) Do(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestCustomHeaderPolicyWithRetries(t *testing.T) {
|
||||
// Set up test environment
|
||||
defaultRetries := 3
|
||||
flagValue := "-6"
|
||||
isSet := true
|
||||
|
||||
retries, err := parseMaxRetries(flagValue, defaultRetries)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse retries: %v", err)
|
||||
}
|
||||
maxRetries := int32(retries)
|
||||
if !isSet || (isSet && flagValue == "0") {
|
||||
// Use default if flag not provided OR if flag is "0"
|
||||
maxRetries = int32(defaultRetries)
|
||||
t.Logf("Using default value: %d (flag provided: %v, value: %q)",
|
||||
defaultRetries, isSet, flagValue)
|
||||
} else {
|
||||
// Flag was provided with non-zero value
|
||||
retries, err := parseMaxRetries(flagValue, defaultRetries)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse retries: %v", err)
|
||||
}
|
||||
maxRetries = int32(retries)
|
||||
t.Logf("Using provided flag value: %d", retries)
|
||||
}
|
||||
|
||||
var attempt int32
|
||||
var firstRequestID string
|
||||
|
||||
// Create mock transport that simulates 429 responses
|
||||
mockTransport := transportFunc(func(req *http.Request) (*http.Response, error) {
|
||||
attempt++
|
||||
|
||||
// Get the request ID from header
|
||||
requestID := req.Header.Get("x-ms-client-request-id")
|
||||
if requestID == "" {
|
||||
t.Fatalf("Request ID missing on attempt %d", attempt)
|
||||
}
|
||||
|
||||
// On first attempt, store the request ID
|
||||
if attempt == 1 {
|
||||
firstRequestID = requestID
|
||||
t.Logf("Initial request ID: %s", firstRequestID)
|
||||
} else {
|
||||
// On subsequent attempts, verify it matches the first request ID
|
||||
if requestID != firstRequestID {
|
||||
t.Fatalf("Request ID changed on retry %d: got %s, want %s",
|
||||
attempt, requestID, firstRequestID)
|
||||
} else {
|
||||
t.Logf("Request ID preserved on attempt %d: %s", attempt, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the ID is also in the context
|
||||
if ctxID, ok := req.Context().Value(clientRequestIDKey).(string); !ok || ctxID != requestID {
|
||||
t.Errorf("Context ID mismatch on attempt %d: got %v, want %s",
|
||||
attempt, ctxID, requestID)
|
||||
}
|
||||
|
||||
// Return 429 for all but the last attempt
|
||||
if maxRetries < 0 || attempt <= maxRetries {
|
||||
t.Logf("Attempt %d: THROTTLED (429) - Request ID: %s", attempt, requestID)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Body: io.NopCloser(strings.NewReader("Too many requests")),
|
||||
Request: req,
|
||||
Header: http.Header{
|
||||
"x-ms-client-request-id": []string{requestID},
|
||||
"Retry-After": []string{"1"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Return 200 on final attempt
|
||||
t.Logf("Attempt %d: SUCCESS (200) - Request ID: %s", attempt, requestID)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("Success")),
|
||||
Request: req,
|
||||
Header: http.Header{
|
||||
"x-ms-client-request-id": []string{requestID},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
|
||||
// Create pipeline with retry policy and custom header policy
|
||||
mockPipeline := azruntime.NewPipeline(
|
||||
"testmodule",
|
||||
"1.0",
|
||||
azruntime.PipelineOptions{
|
||||
PerCall: []policy.Policy{
|
||||
CustomHeaderPolicynew(),
|
||||
},
|
||||
},
|
||||
&policy.ClientOptions{
|
||||
Retry: policy.RetryOptions{
|
||||
MaxRetries: maxRetries,
|
||||
},
|
||||
Transport: mockTransport,
|
||||
},
|
||||
)
|
||||
// Create request and execute
|
||||
req, err := azruntime.NewRequest(context.Background(), http.MethodGet, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := mockPipeline.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Verify we got the expected number of attempts
|
||||
var expectedAttempts int32
|
||||
if maxRetries < 0 {
|
||||
expectedAttempts = 1 // For negative retries, only one attempt should be made
|
||||
} else {
|
||||
expectedAttempts = maxRetries + 1 // For zero or positive retries, attempts = retries + 1
|
||||
}
|
||||
|
||||
if attempt != expectedAttempts {
|
||||
t.Errorf("Wrong number of attempts: got %d, want %d", attempt, expectedAttempts)
|
||||
}
|
||||
|
||||
t.Logf("Test completed with %d attempts, all with request ID: %s", attempt, firstRequestID)
|
||||
}
|
||||
|
||||
func TestMaxRetriesCount(t *testing.T) {
|
||||
defaultRetries := 3
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
isSet bool // indicates if flag was provided
|
||||
expected int
|
||||
shouldError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "FlagNotProvided",
|
||||
input: "",
|
||||
isSet: false,
|
||||
expected: defaultRetries,
|
||||
shouldError: false,
|
||||
description: "When flag is not provided, should use default value",
|
||||
},
|
||||
{
|
||||
name: "FlagProvidedEmpty",
|
||||
input: "",
|
||||
isSet: true,
|
||||
expected: 0,
|
||||
shouldError: true,
|
||||
description: "When flag is provided but empty, should error",
|
||||
},
|
||||
{
|
||||
name: "ValidPositive",
|
||||
input: "5",
|
||||
isSet: true,
|
||||
expected: 5,
|
||||
shouldError: false,
|
||||
description: "Valid positive number should be accepted",
|
||||
},
|
||||
{
|
||||
name: "ZeroRetries",
|
||||
input: "0",
|
||||
isSet: true,
|
||||
expected: 0,
|
||||
shouldError: false,
|
||||
description: "Zero should be accepted and handled by SDK",
|
||||
},
|
||||
{
|
||||
name: "NegativeRetries",
|
||||
input: "-2",
|
||||
isSet: true,
|
||||
expected: -2,
|
||||
shouldError: false,
|
||||
description: "Negative values should be accepted and handled by SDK",
|
||||
},
|
||||
{
|
||||
name: "InvalidString",
|
||||
input: "abc",
|
||||
isSet: true,
|
||||
expected: 0,
|
||||
shouldError: true,
|
||||
description: "Non-numeric string should error",
|
||||
},
|
||||
{
|
||||
name: "Whitespace",
|
||||
input: " ",
|
||||
isSet: true,
|
||||
expected: 0,
|
||||
shouldError: true,
|
||||
description: "Whitespace should error",
|
||||
},
|
||||
{
|
||||
name: "SpecialChars",
|
||||
input: "@#$%",
|
||||
isSet: true,
|
||||
expected: 0,
|
||||
shouldError: true,
|
||||
description: "Special characters should error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Logf("=== Test Case: %s ===", tt.name)
|
||||
t.Logf("Description: %s", tt.description)
|
||||
t.Logf("Input: %q (flag provided: %v)", tt.input, tt.isSet)
|
||||
|
||||
// Handle flag not provided case
|
||||
if !tt.isSet {
|
||||
t.Logf("Using default value: %d", defaultRetries)
|
||||
return
|
||||
}
|
||||
|
||||
retries, err := parseMaxRetries(tt.input, defaultRetries)
|
||||
|
||||
// Check error condition
|
||||
if tt.shouldError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for input %q but got none", tt.input)
|
||||
} else {
|
||||
t.Logf("Got expected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if retries != tt.expected {
|
||||
t.Errorf("Got %d retries, want %d", retries, tt.expected)
|
||||
} else {
|
||||
t.Logf("Got expected value: %d", retries)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to parse max retries value
|
||||
func parseMaxRetries(value string, defaultValue int) (int, error) {
|
||||
// Trim whitespace
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Empty string or whitespace should error
|
||||
if value == "" {
|
||||
return 0, fmt.Errorf("retry count must be provided when flag is set")
|
||||
}
|
||||
|
||||
retries, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid retry count %q: %v", value, err)
|
||||
}
|
||||
|
||||
return retries, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user