fix(azure): enhance retry logic using azure SDK (#5361)

* fix(azure): Enhance retry logic using azure SDK

* Added the changes for flag based maxretries configuration

* Fixed types.go, flags.md and deleted unnecessary comments

* Added the correct image for the Azure Private DNS tutorial

* Following the go naming convention for maxRetriesCount

* Added the correct flag information to the --azure-maxretries-count

* Made the required changes to accept the --azure-maxretries-count flag value from cli/env
This commit is contained in:
Shruti Panapana 2025-05-17 14:39:14 +05:30 committed by GitHub
parent 09d49a106a
commit 7b9d8d9355
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 366 additions and 19 deletions

View File

@ -186,9 +186,9 @@ func Execute() {
}
p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, cfg.AWSSDCreateTag, sd.NewFromConfig(aws.CreateDefaultV2Config(cfg)))
case "azure-dns", "azure":
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.DryRun)
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.AzureMaxRetriesCount, cfg.DryRun)
case "azure-private-dns":
p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.DryRun)
p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.AzureActiveDirectoryAuthorityHost, cfg.AzureZonesCacheDuration, cfg.AzureMaxRetriesCount, cfg.DryRun)
case "ultradns":
p, err = ultradns.NewUltraDNSProvider(domainFilter, cfg.DryRun)
case "civo":

View File

@ -86,6 +86,7 @@
| `--azure-subscription-id=""` | When using the Azure provider, override the Azure subscription to use (optional) |
| `--azure-user-assigned-identity-client-id=""` | When using the Azure provider, override the client id of user assigned identity in config file (optional) |
| `--azure-zones-cache-duration=0s` | When using the Azure provider, set the zones list cache TTL (0s to disable). |
| `--azure-maxretries-count=3` | When using the Azure provider, set the number of retries for API calls (When less than 0, it disables retries). (optional) |
| `--tencent-cloud-config-file="/etc/kubernetes/tencent-cloud.json"` | When using the Tencent Cloud provider, specify the Tencent Cloud configuration file (required when --provider=tencentcloud) |
| `--tencent-cloud-zone-type=` | When using the Tencent Cloud provider, filter for zones with visibility (optional, options: public, private) |
| `--[no-]cloudflare-proxied` | When using the Cloudflare provider, specify if the proxy mode must be enabled (default: disabled) |

View File

@ -108,6 +108,7 @@ $ az role assignment create --role "Private DNS Zone Contributor" --assignee <ap
## Throttling
When the ExternalDNS managed zones list doesn't change frequently, one can set `--azure-zones-cache-duration` (zones list cache time-to-live). The zones list cache is disabled by default, with a value of 0s.
Also, one can leverage the built-in retry policies of the Azure SDK. The `--azure-maxretries-count` flag can be set in the manifest YAML to configure the maximum number of retries per API call. The default is 3; a negative value disables retries.
## Deploy ExternalDNS
@ -151,6 +152,7 @@ spec:
- --provider=azure-private-dns
- --azure-resource-group=externaldns
- --azure-subscription-id=<use the id of your subscription>
- --azure-maxretries-count=1 # (optional) specifies the maxRetries value to be used by the Azure SDK. Default is 3.
volumeMounts:
- name: azure-config-file
mountPath: /etc/kubernetes
@ -223,6 +225,7 @@ spec:
- --provider=azure-private-dns
- --azure-resource-group=externaldns
- --azure-subscription-id=<use the id of your subscription>
- --azure-maxretries-count=1 # (optional) specifies the maxRetries value to be used by the Azure SDK. Default is 3.
volumeMounts:
- name: azure-config-file
mountPath: /etc/kubernetes
@ -295,6 +298,7 @@ spec:
- --provider=azure-private-dns
- --azure-resource-group=externaldns
- --azure-subscription-id=<use the id of your subscription>
- --azure-maxretries-count=1 # (optional) specifies the maxRetries value to be used by the Azure SDK. Default is 3.
volumeMounts:
- name: azure-config-file
mountPath: /etc/kubernetes

View File

@ -493,6 +493,7 @@ NOTE: make sure the pod is restarted whenever you make a configuration change.
## Throttling
When the ExternalDNS managed zones list doesn't change frequently, one can set `--azure-zones-cache-duration` (zones list cache time-to-live). The zones list cache is disabled by default, with a value of 0s.
Also, one can leverage the built-in retry policies of the Azure SDK with a tunable maxRetries value. The `--azure-maxretries-count` flag (or the `EXTERNAL_DNS_AZURE_MAXRETRIES_COUNT` environment variable) can be specified in the manifest YAML to configure this behavior. The default value is 3; a negative value disables retries.
## Ingress used with ExternalDNS
@ -540,6 +541,7 @@ spec:
- --domain-filter=example.com # (optional) limit to only example.com domains; change to match the zone created above.
- --provider=azure
- --azure-resource-group=MyDnsResourceGroup # (optional) use the DNS zones from the tutorial's resource group
- --azure-maxretries-count=1 # (optional) specifies the maxRetries value to be used by the Azure SDK. Default is 3.
volumeMounts:
- name: azure-config-file
mountPath: /etc/kubernetes
@ -609,6 +611,7 @@ spec:
- --provider=azure
- --azure-resource-group=MyDnsResourceGroup # (optional) use the DNS zones from the tutorial's resource group
- --txt-prefix=externaldns-
- --azure-maxretries-count=1 # (optional) specifies the maxRetries value to be used by the Azure SDK. Default is 3.
volumeMounts:
- name: azure-config-file
mountPath: /etc/kubernetes
@ -680,6 +683,7 @@ spec:
- --domain-filter=example.com # (optional) limit to only example.com domains; change to match the zone created above.
- --provider=azure
- --azure-resource-group=MyDnsResourceGroup # (optional) use the DNS zones from the tutorial's resource group
- --azure-maxretries-count=1 # (optional) specifies the maxRetries value to be used by the Azure SDK. Default is 3.
volumeMounts:
- name: azure-config-file
mountPath: /etc/kubernetes

View File

@ -106,6 +106,7 @@ type Config struct {
AzureUserAssignedIdentityClientID string
AzureActiveDirectoryAuthorityHost string
AzureZonesCacheDuration time.Duration
AzureMaxRetriesCount int
CloudflareProxied bool
CloudflareCustomHostnames bool
CloudflareCustomHostnamesMinTLSVersion string
@ -247,6 +248,7 @@ var defaultConfig = &Config{
AzureResourceGroup: "",
AzureSubscriptionID: "",
AzureZonesCacheDuration: 0 * time.Second,
AzureMaxRetriesCount: 3,
CFAPIEndpoint: "",
CFPassword: "",
CFUsername: "",
@ -527,6 +529,7 @@ func App(cfg *Config) *kingpin.Application {
app.Flag("azure-subscription-id", "When using the Azure provider, override the Azure subscription to use (optional)").Default(defaultConfig.AzureSubscriptionID).StringVar(&cfg.AzureSubscriptionID)
app.Flag("azure-user-assigned-identity-client-id", "When using the Azure provider, override the client id of user assigned identity in config file (optional)").Default("").StringVar(&cfg.AzureUserAssignedIdentityClientID)
app.Flag("azure-zones-cache-duration", "When using the Azure provider, set the zones list cache TTL (0s to disable).").Default(defaultConfig.AzureZonesCacheDuration.String()).DurationVar(&cfg.AzureZonesCacheDuration)
app.Flag("azure-maxretries-count", "When using the Azure provider, set the number of retries for API calls (When less than 0, it disables retries). (optional)").Default(strconv.Itoa(defaultConfig.AzureMaxRetriesCount)).IntVar(&cfg.AzureMaxRetriesCount)
app.Flag("tencent-cloud-config-file", "When using the Tencent Cloud provider, specify the Tencent Cloud configuration file (required when --provider=tencentcloud)").Default(defaultConfig.TencentCloudConfigFile).StringVar(&cfg.TencentCloudConfigFile)
app.Flag("tencent-cloud-zone-type", "When using the Tencent Cloud provider, filter for zones with visibility (optional, options: public, private)").Default(defaultConfig.TencentCloudZoneType).EnumVar(&cfg.TencentCloudZoneType, "", "public", "private")

View File

@ -73,6 +73,7 @@ var (
AzureConfigFile: "/etc/kubernetes/azure.json",
AzureResourceGroup: "",
AzureSubscriptionID: "",
AzureMaxRetriesCount: 3,
CloudflareProxied: false,
CloudflareCustomHostnames: false,
CloudflareCustomHostnamesMinTLSVersion: "1.0",
@ -183,6 +184,7 @@ var (
AzureConfigFile: "azure.json",
AzureResourceGroup: "arg",
AzureSubscriptionID: "arg",
AzureMaxRetriesCount: 4,
CloudflareProxied: true,
CloudflareCustomHostnames: true,
CloudflareCustomHostnamesMinTLSVersion: "1.3",
@ -296,6 +298,7 @@ func TestParseFlags(t *testing.T) {
"--azure-config-file=azure.json",
"--azure-resource-group=arg",
"--azure-subscription-id=arg",
"--azure-maxretries-count=4",
"--cloudflare-proxied",
"--cloudflare-custom-hostnames",
"--cloudflare-custom-hostnames-min-tls-version=1.3",
@ -427,6 +430,7 @@ func TestParseFlags(t *testing.T) {
"EXTERNAL_DNS_AZURE_CONFIG_FILE": "azure.json",
"EXTERNAL_DNS_AZURE_RESOURCE_GROUP": "arg",
"EXTERNAL_DNS_AZURE_SUBSCRIPTION_ID": "arg",
"EXTERNAL_DNS_AZURE_MAXRETRIES_COUNT": "4",
"EXTERNAL_DNS_CLOUDFLARE_PROXIED": "1",
"EXTERNAL_DNS_CLOUDFLARE_CUSTOM_HOSTNAMES": "1",
"EXTERNAL_DNS_CLOUDFLARE_CUSTOM_HOSTNAMES_MIN_TLS_VERSION": "1.3",

View File

@ -63,17 +63,19 @@ type AzureProvider struct {
zonesClient ZonesClient
zonesCache *zonesCache[dns.Zone]
recordSetsClient RecordSetsClient
maxRetriesCount int
}
// NewAzureProvider creates a new Azure provider.
//
// Returns the provider or an error if a provider could not be created.
func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, dryRun bool) (*AzureProvider, error) {
func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, maxRetriesCount int, dryRun bool) (*AzureProvider, error) {
cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost)
if err != nil {
return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
}
cred, clientOpts, err := getCredentials(*cfg)
cred, clientOpts, err := getCredentials(*cfg, maxRetriesCount)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
@ -97,6 +99,7 @@ func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zon
zonesClient: zonesClient,
zonesCache: &zonesCache[dns.Zone]{duration: zonesCacheDuration},
recordSetsClient: recordSetsClient,
maxRetriesCount: maxRetriesCount,
}, nil
}

View File

@ -58,17 +58,19 @@ type AzurePrivateDNSProvider struct {
zonesClient PrivateZonesClient
zonesCache *zonesCache[privatedns.PrivateZone]
recordSetsClient PrivateRecordSetsClient
maxRetriesCount int
}
// NewAzurePrivateDNSProvider creates a new Azure Private DNS provider.
//
// Returns the provider or an error if a provider could not be created.
func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, dryRun bool) (*AzurePrivateDNSProvider, error) {
func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesCacheDuration time.Duration, maxRetriesCount int, dryRun bool) (*AzurePrivateDNSProvider, error) {
cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost)
if err != nil {
return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
}
cred, clientOpts, err := getCredentials(*cfg)
cred, clientOpts, err := getCredentials(*cfg, maxRetriesCount)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
@ -92,6 +94,7 @@ func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainF
zonesClient: zonesClient,
zonesCache: &zonesCache[privatedns.PrivateZone]{duration: zonesCacheDuration},
recordSetsClient: recordSetsClient,
maxRetriesCount: maxRetriesCount,
}, nil
}

View File

@ -224,13 +224,13 @@ func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64,
}
// newMockedAzurePrivateDNSProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet) (*AzurePrivateDNSProvider, error) {
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet, maxRetriesCount int) (*AzurePrivateDNSProvider, error) {
zonesClient := newMockPrivateZonesClient(zones)
recordSetsClient := newMockPrivateRecordSectsClient(recordSets)
return newAzurePrivateDNSProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil
return newAzurePrivateDNSProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient, maxRetriesCount), nil
}
func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient) *AzurePrivateDNSProvider {
func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient, maxRetriesCount int) *AzurePrivateDNSProvider {
return &AzurePrivateDNSProvider{
domainFilter: domainFilter,
zoneNameFilter: zoneNameFilter,
@ -240,6 +240,7 @@ func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilt
zonesClient: privateZonesClient,
zonesCache: &zonesCache[privatedns.PrivateZone]{duration: 0},
recordSetsClient: privateRecordsClient,
maxRetriesCount: maxRetriesCount,
}
}
@ -259,7 +260,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) {
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
createPrivateMockRecordSetWithTTL("mail", endpoint.RecordTypeMX, "10 example.com", 4000),
})
}, 3)
if err != nil {
t.Fatal(err)
}
@ -298,7 +299,7 @@ func TestAzurePrivateDNSMultiRecord(t *testing.T) {
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
createPrivateMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"),
})
}, 3)
if err != nil {
t.Fatal(err)
}
@ -378,6 +379,7 @@ func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client P
"group",
&zonesClient,
client,
3,
)
createRecords := []*endpoint.Endpoint{
@ -450,7 +452,7 @@ func TestAzurePrivateDNSNameFilter(t *testing.T) {
createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createPrivateMockRecordSetWithTTL("mail.nginx", endpoint.RecordTypeMX, "20 example.com", recordTTL),
createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
})
}, 3)
if err != nil {
t.Fatal(err)
}
@ -505,6 +507,7 @@ func testAzurePrivateDNSApplyChangesInternalZoneName(t *testing.T, dryRun bool,
"group",
&zonesClient,
client,
3,
)
createRecords := []*endpoint.Endpoint{

View File

@ -237,13 +237,13 @@ func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values
}
// newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zones []*dns.Zone, recordSets []*dns.RecordSet) (*AzureProvider, error) {
func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zones []*dns.Zone, recordSets []*dns.RecordSet, maxRetriesCount int) (*AzureProvider, error) {
zonesClient := newMockZonesClient(zones)
recordSetsClient := newMockRecordSetsClient(recordSets)
return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost, &zonesClient, &recordSetsClient), nil
return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, activeDirectoryAuthorityHost, &zonesClient, &recordSetsClient, maxRetriesCount), nil
}
func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider {
func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, activeDirectoryAuthorityHost string, zonesClient ZonesClient, recordsClient RecordSetsClient, maxRetriesCount int) *AzureProvider {
return &AzureProvider{
domainFilter: domainFilter,
zoneNameFilter: zoneNameFilter,
@ -255,6 +255,7 @@ func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoin
zonesClient: zonesClient,
zonesCache: &zonesCache[dns.Zone]{duration: 0},
recordSetsClient: recordsClient,
maxRetriesCount: maxRetriesCount,
}
}
@ -280,7 +281,7 @@ func TestAzureRecord(t *testing.T) {
createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com"),
})
}, 3)
if err != nil {
t.Fatal(err)
}
@ -325,7 +326,7 @@ func TestAzureMultiRecord(t *testing.T) {
createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL),
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"),
})
}, 3)
if err != nil {
t.Fatal(err)
}
@ -415,6 +416,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC
"",
&zonesClient,
client,
3,
)
createRecords := []*endpoint.Endpoint{
@ -497,7 +499,7 @@ func TestAzureNameFilter(t *testing.T) {
createMockRecordSetWithTTL("mail.nginx", endpoint.RecordTypeMX, "20 example.com", recordTTL),
createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10),
createMockRecordSetWithTTL("hack", endpoint.RecordTypeNS, "ns1.example.com.", 3600),
})
}, 3)
if err != nil {
t.Fatal(err)
}
@ -555,6 +557,7 @@ func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client Rec
"",
&zonesClient,
client,
3,
)
createRecords := []*endpoint.Endpoint{

View File

@ -17,15 +17,19 @@ limitations under the License.
package azure
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
)
@ -72,15 +76,57 @@ func getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityCl
return cfg, nil
}
// ctxKey is the unexported string type used for this package's context keys.
// A dedicated type keeps these keys from colliding with keys that other
// packages store in the same context.
type ctxKey string
const (
// clientRequestIDKey is the context key under which customHeaderPolicy
// stores the client request ID it generates.
clientRequestIDKey ctxKey = "client-request-id"
// Azure API header names. getCredentials lists all three in the logging
// AllowedHeaders; msClientRequestIDHeader is also the header that
// customHeaderPolicy reads and sets.
msRequestIDHeader = "x-ms-request-id"
msCorrelationRequestHeader = "x-ms-correlation-request-id"
msClientRequestIDHeader = "x-ms-client-request-id"
)
// customHeaderPolicy is a pipeline policy that makes sure every outgoing
// request carries an x-ms-client-request-id header.
type customHeaderPolicy struct{}

// Do forwards the request untouched when the client request ID header is
// already populated. Otherwise it generates a fresh UUID, writes it to the
// header, records it in the request context under clientRequestIDKey, and
// then hands the request to the next policy in the pipeline.
func (p *customHeaderPolicy) Do(req *policy.Request) (*http.Response, error) {
	raw := req.Raw()
	if raw.Header.Get(msClientRequestIDHeader) != "" {
		return req.Next()
	}

	requestID := uuid.New().String()
	raw.Header.Set(msClientRequestIDHeader, requestID)

	// Overwrite the underlying request in place so the context value travels
	// with the same *http.Request the rest of the pipeline sees.
	*raw = *raw.WithContext(context.WithValue(raw.Context(), clientRequestIDKey, requestID))
	return req.Next()
}
// CustomHeaderPolicynew returns a policy.Policy backed by a new
// customHeaderPolicy value.
func CustomHeaderPolicynew() policy.Policy {
	return new(customHeaderPolicy)
}
// getCredentials retrieves Azure API credentials.
func getCredentials(cfg config) (azcore.TokenCredential, *arm.ClientOptions, error) {
func getCredentials(cfg config, maxRetries int) (azcore.TokenCredential, *arm.ClientOptions, error) {
cloudCfg, err := getCloudConfiguration(cfg.Cloud)
if err != nil {
return nil, nil, fmt.Errorf("failed to get cloud configuration: %w", err)
}
clientOpts := azcore.ClientOptions{
Cloud: cloudCfg,
Retry: policy.RetryOptions{
MaxRetries: int32(maxRetries),
},
Logging: policy.LogOptions{
AllowedHeaders: []string{
msRequestIDHeader,
msCorrelationRequestHeader,
msClientRequestIDHeader,
},
},
PerCallPolicies: []policy.Policy{
CustomHeaderPolicynew(),
},
}
log.Debugf("Configured Azure client with maxRetries: %d", clientOpts.Retry.MaxRetries)
armClientOpts := &arm.ClientOptions{
ClientOptions: clientOpts,
}

View File

@ -17,11 +17,19 @@ limitations under the License.
package azure
import (
"context"
"fmt"
"io"
"net/http"
"path"
"runtime"
"strconv"
"strings"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/stretchr/testify/assert"
)
@ -59,3 +67,268 @@ func TestOverrideConfiguration(t *testing.T) {
assert.Equal(t, cfg.ResourceGroup, "rg-override")
assert.Equal(t, cfg.ActiveDirectoryAuthorityHost, "aad-endpoint-override")
}
// transportFunc adapts a plain function to a type with a Do method, so a
// closure can stand in as the transport in the custom header policy test.
type transportFunc func(*http.Request) (*http.Response, error)

// Do invokes the wrapped function with req and returns its results unchanged.
func (fn transportFunc) Do(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
// TestCustomHeaderPolicyWithRetries drives a pipeline built from the SDK retry
// policy and customHeaderPolicy against a mock transport that answers 429
// until the retry budget is spent. It checks that the client request ID header
// is present, stays the same across attempts, is mirrored in the request
// context, and that the number of attempts matches the configured MaxRetries.
func TestCustomHeaderPolicyWithRetries(t *testing.T) {
	// Set up test environment
	defaultRetries := 3
	flagValue := "-6"
	isSet := true

	// Parse the flag value once; both branches below reuse the result.
	retries, err := parseMaxRetries(flagValue, defaultRetries)
	if err != nil {
		t.Fatalf("Failed to parse retries: %v", err)
	}

	maxRetries := int32(retries)
	if !isSet || flagValue == "0" {
		// Use default if flag not provided OR if flag is "0"
		maxRetries = int32(defaultRetries)
		t.Logf("Using default value: %d (flag provided: %v, value: %q)",
			defaultRetries, isSet, flagValue)
	} else {
		// Flag was provided with non-zero value
		t.Logf("Using provided flag value: %d", retries)
	}

	var attempt int32
	var firstRequestID string

	// http.Header lookups use the canonical key form, so a map literal keyed by
	// the raw lower-case name would be invisible to Header.Get.
	requestIDKey := http.CanonicalHeaderKey("x-ms-client-request-id")

	// Create mock transport that simulates 429 responses
	mockTransport := transportFunc(func(req *http.Request) (*http.Response, error) {
		attempt++

		// Get the request ID from header
		requestID := req.Header.Get("x-ms-client-request-id")
		if requestID == "" {
			t.Fatalf("Request ID missing on attempt %d", attempt)
		}

		if attempt == 1 {
			// On first attempt, store the request ID
			firstRequestID = requestID
			t.Logf("Initial request ID: %s", firstRequestID)
		} else {
			// On subsequent attempts, verify it matches the first request ID
			if requestID != firstRequestID {
				t.Fatalf("Request ID changed on retry %d: got %s, want %s",
					attempt, requestID, firstRequestID)
			}
			t.Logf("Request ID preserved on attempt %d: %s", attempt, requestID)
		}

		// Verify the ID is also in the context
		if ctxID, ok := req.Context().Value(clientRequestIDKey).(string); !ok || ctxID != requestID {
			t.Errorf("Context ID mismatch on attempt %d: got %v, want %s",
				attempt, ctxID, requestID)
		}

		// Return 429 for all but the last attempt. With a negative MaxRetries
		// only one attempt is expected, and it is throttled as well.
		if maxRetries < 0 || attempt <= maxRetries {
			t.Logf("Attempt %d: THROTTLED (429) - Request ID: %s", attempt, requestID)
			return &http.Response{
				StatusCode: http.StatusTooManyRequests,
				Body:       io.NopCloser(strings.NewReader("Too many requests")),
				Request:    req,
				Header: http.Header{
					requestIDKey:  []string{requestID},
					"Retry-After": []string{"1"},
				},
			}, nil
		}

		// Return 200 on final attempt
		t.Logf("Attempt %d: SUCCESS (200) - Request ID: %s", attempt, requestID)
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       io.NopCloser(strings.NewReader("Success")),
			Request:    req,
			Header: http.Header{
				requestIDKey: []string{requestID},
			},
		}, nil
	})

	// Create pipeline with retry policy and custom header policy
	mockPipeline := azruntime.NewPipeline(
		"testmodule",
		"1.0",
		azruntime.PipelineOptions{
			PerCall: []policy.Policy{
				CustomHeaderPolicynew(),
			},
		},
		&policy.ClientOptions{
			Retry: policy.RetryOptions{
				MaxRetries: maxRetries,
			},
			Transport: mockTransport,
		},
	)

	// Create request and execute
	req, err := azruntime.NewRequest(context.Background(), http.MethodGet, "https://example.com")
	if err != nil {
		t.Fatalf("Failed to create request: %v", err)
	}

	resp, err := mockPipeline.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	// Verify we got the expected number of attempts
	var expectedAttempts int32
	if maxRetries < 0 {
		expectedAttempts = 1 // For negative retries, only one attempt should be made
	} else {
		expectedAttempts = maxRetries + 1 // For zero or positive retries, attempts = retries + 1
	}

	if attempt != expectedAttempts {
		t.Errorf("Wrong number of attempts: got %d, want %d", attempt, expectedAttempts)
	}

	t.Logf("Test completed with %d attempts, all with request ID: %s", attempt, firstRequestID)
}
// TestMaxRetriesCount is a table-driven check of parseMaxRetries. Each case
// names the raw flag text, whether the flag was supplied at all, and either
// the integer expected back or the fact that parsing must fail. Cases where
// the flag was not supplied only log the default and do not call the parser.
func TestMaxRetriesCount(t *testing.T) {
	defaultRetries := 3

	cases := []struct {
		name        string
		input       string
		isSet       bool // indicates if flag was provided
		expected    int
		shouldError bool
		description string
	}{
		{
			name:        "FlagNotProvided",
			input:       "",
			isSet:       false,
			expected:    defaultRetries,
			description: "When flag is not provided, should use default value",
		},
		{
			name:        "FlagProvidedEmpty",
			input:       "",
			isSet:       true,
			shouldError: true,
			description: "When flag is provided but empty, should error",
		},
		{
			name:        "ValidPositive",
			input:       "5",
			isSet:       true,
			expected:    5,
			description: "Valid positive number should be accepted",
		},
		{
			name:        "ZeroRetries",
			input:       "0",
			isSet:       true,
			expected:    0,
			description: "Zero should be accepted and handled by SDK",
		},
		{
			name:        "NegativeRetries",
			input:       "-2",
			isSet:       true,
			expected:    -2,
			description: "Negative values should be accepted and handled by SDK",
		},
		{
			name:        "InvalidString",
			input:       "abc",
			isSet:       true,
			shouldError: true,
			description: "Non-numeric string should error",
		},
		{
			name:        "Whitespace",
			input:       " ",
			isSet:       true,
			shouldError: true,
			description: "Whitespace should error",
		},
		{
			name:        "SpecialChars",
			input:       "@#$%",
			isSet:       true,
			shouldError: true,
			description: "Special characters should error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Logf("=== Test Case: %s ===", tc.name)
			t.Logf("Description: %s", tc.description)
			t.Logf("Input: %q (flag provided: %v)", tc.input, tc.isSet)

			// Nothing to parse when the flag was never supplied.
			if !tc.isSet {
				t.Logf("Using default value: %d", defaultRetries)
				return
			}

			got, err := parseMaxRetries(tc.input, defaultRetries)

			// Failure cases: only the presence of an error matters.
			switch {
			case tc.shouldError && err == nil:
				t.Errorf("Expected error for input %q but got none", tc.input)
				return
			case tc.shouldError:
				t.Logf("Got expected error: %v", err)
				return
			}

			// Success cases: no error, and the parsed value must match.
			if err != nil {
				t.Errorf("Unexpected error: %v", err)
			}
			if got == tc.expected {
				t.Logf("Got expected value: %d", got)
				return
			}
			t.Errorf("Got %d retries, want %d", got, tc.expected)
		})
	}
}
// parseMaxRetries converts the textual value of the max-retries flag to an int.
//
// Surrounding whitespace is ignored. An empty (or all-whitespace) value is
// rejected, as is anything strconv.Atoi cannot parse; in the latter case the
// strconv error is wrapped with %w so callers can inspect it via errors.Is or
// errors.As. Zero and negative numbers are returned unchanged.
//
// defaultValue is not consulted; the parameter is kept so existing callers
// continue to compile.
func parseMaxRetries(value string, defaultValue int) (int, error) {
	// Trim whitespace
	value = strings.TrimSpace(value)

	// Empty string or whitespace should error
	if value == "" {
		return 0, fmt.Errorf("retry count must be provided when flag is set")
	}

	retries, err := strconv.Atoi(value)
	if err != nil {
		return 0, fmt.Errorf("invalid retry count %q: %w", value, err)
	}

	return retries, nil
}