chore(cloudflare): migrate ListRecords() to new lib (#5778)

* chore(cloudflare): migrate ListRecords() to new lib

* chore(cloudflare): test zoneService.ListDNSRecord()
This commit is contained in:
vflaux 2025-09-04 11:53:15 +02:00 committed by GitHub
parent e88b94bdee
commit 61ca17c0fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 67 deletions

View File

@ -110,7 +110,7 @@ type cloudFlareDNS interface {
ZoneIDByName(zoneName string) (string, error)
ListZones(ctx context.Context, params zones.ZoneListParams) autoPager[zones.Zone]
GetZone(ctx context.Context, zoneID string) (*zones.Zone, error)
ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error)
ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse]
CreateDNSRecord(ctx context.Context, params dns.RecordNewParams) (*dns.RecordResponse, error)
DeleteDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, recordID string) error
UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error
@ -152,13 +152,8 @@ func (z zoneService) CreateDNSRecord(ctx context.Context, params dns.RecordNewPa
return z.service.DNS.Records.New(ctx, params)
}
func (z zoneService) ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) {
records, info, err := z.serviceV0.ListDNSRecords(ctx, rc, rp)
convertedRecords := make([]dns.RecordResponse, 0, len(records))
for _, record := range records {
convertedRecords = append(convertedRecords, dnsRecordResponseFromLegacyDNSRecord(record))
}
return convertedRecords, info, err
func (z zoneService) ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] {
return z.service.DNS.Records.ListAutoPaging(ctx, params)
}
func (z zoneService) UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error {
@ -428,7 +423,7 @@ func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint,
var endpoints []*endpoint.Endpoint
for _, zone := range zones {
records, err := p.listDNSRecordsWithAutoPagination(ctx, zone.ID)
records, err := p.getDNSRecordsMap(ctx, zone.ID)
if err != nil {
return nil, err
}
@ -643,7 +638,7 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud
continue
}
records, err := p.listDNSRecordsWithAutoPagination(ctx, zoneID)
records, err := p.getDNSRecordsMap(ctx, zoneID)
if err != nil {
return fmt.Errorf("could not fetch records from zone, %w", err)
}
@ -860,27 +855,19 @@ func newDNSRecordIndex(r dns.RecordResponse) DNSRecordIndex {
return DNSRecordIndex{Name: r.Name, Type: string(r.Type), Content: r.Content}
}
// listDNSRecordsWithAutoPagination performs automatic pagination of results on requests to cloudflare.ListDNSRecords with custom per_page values
func (p *CloudFlareProvider) listDNSRecordsWithAutoPagination(ctx context.Context, zoneID string) (DNSRecordsMap, error) {
// getDNSRecordsMap retrieves all DNS records for a given zone and returns them as a DNSRecordsMap.
func (p *CloudFlareProvider) getDNSRecordsMap(ctx context.Context, zoneID string) (DNSRecordsMap, error) {
// for faster getRecordID lookup
records := make(DNSRecordsMap)
resultInfo := cloudflarev0.ResultInfo{PerPage: p.DNSRecordsConfig.PerPage, Page: 1}
params := cloudflarev0.ListDNSRecordsParams{ResultInfo: resultInfo}
for {
pageRecords, resultInfo, err := p.Client.ListDNSRecords(ctx, cloudflarev0.ZoneIdentifier(zoneID), params)
if err != nil {
return nil, convertCloudflareError(err)
}
for _, r := range pageRecords {
records[newDNSRecordIndex(r)] = r
}
params.ResultInfo = resultInfo.Next()
if params.Done() {
break
}
recordsMap := make(DNSRecordsMap)
params := dns.RecordListParams{ZoneID: cloudflare.F(zoneID)}
iter := p.Client.ListDNSRecords(ctx, params)
for record := range autoPagerIterator(iter) {
recordsMap[newDNSRecordIndex(record)] = record
}
return records, nil
if iter.Err() != nil {
return nil, convertCloudflareError(iter.Err())
}
return recordsMap, nil
}
func newCustomHostnameIndex(ch cloudflarev0.CustomHostname) CustomHostnameIndex {

View File

@ -22,12 +22,12 @@ import (
"fmt"
"os"
"slices"
"sort"
"strings"
"testing"
"time"
cloudflarev0 "github.com/cloudflare/cloudflare-go"
"github.com/cloudflare/cloudflare-go/v5"
"github.com/cloudflare/cloudflare-go/v5/dns"
"github.com/cloudflare/cloudflare-go/v5/zones"
"github.com/maxatome/go-testdeep/td"
@ -171,49 +171,22 @@ func (m *mockCloudFlareClient) CreateDNSRecord(ctx context.Context, params dns.R
return &record, nil
}
func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.ListDNSRecordsParams) ([]dns.RecordResponse, *cloudflarev0.ResultInfo, error) {
func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, params dns.RecordListParams) autoPager[dns.RecordResponse] {
if m.dnsRecordsError != nil {
return nil, &cloudflarev0.ResultInfo{}, m.dnsRecordsError
return &mockAutoPager[dns.RecordResponse]{err: m.dnsRecordsError}
}
result := []dns.RecordResponse{}
if zone, ok := m.Records[rc.Identifier]; ok {
iter := &mockAutoPager[dns.RecordResponse]{}
if zone, ok := m.Records[params.ZoneID.Value]; ok {
for _, record := range zone {
if strings.HasPrefix(record.Name, "newerror-list-") {
m.DeleteDNSRecord(ctx, rc, record.ID)
return nil, &cloudflarev0.ResultInfo{}, errors.New("failed to list erroring DNS record")
m.DeleteDNSRecord(ctx, cloudflarev0.ResourceIdentifier(params.ZoneID.Value), record.ID)
iter.err = errors.New("failed to list erroring DNS record")
return iter
}
result = append(result, record)
iter.items = append(iter.items, record)
}
}
if len(result) == 0 || rp.PerPage == 0 {
return result, &cloudflarev0.ResultInfo{Page: 1, TotalPages: 1, Count: 0, Total: 0}, nil
}
// if not pagination options were passed in, return the result as is
if rp.Page == 0 {
return result, &cloudflarev0.ResultInfo{Page: 1, TotalPages: 1, Count: len(result), Total: len(result)}, nil
}
// otherwise, split the result into chunks of size rp.PerPage to simulate the pagination from the API
chunks := [][]dns.RecordResponse{}
// to ensure consistency in the multiple calls to this function, sort the result slice
sort.Slice(result, func(i, j int) bool { return strings.Compare(result[i].ID, result[j].ID) > 0 })
for rp.PerPage < len(result) {
result, chunks = result[rp.PerPage:], append(chunks, result[0:rp.PerPage])
}
chunks = append(chunks, result)
// return the requested page
partialResult := chunks[rp.Page-1]
return partialResult, &cloudflarev0.ResultInfo{
PerPage: rp.PerPage,
Page: rp.Page,
TotalPages: len(chunks),
Count: len(partialResult),
Total: len(result),
}, nil
return iter
}
func (m *mockCloudFlareClient) UpdateDNSRecord(ctx context.Context, rc *cloudflarev0.ResourceContainer, rp cloudflarev0.UpdateDNSRecordParams) error {
@ -1500,7 +1473,7 @@ func TestGroupByNameAndTypeWithCustomHostnames_MX(t *testing.T) {
}
ctx := context.Background()
chs := CustomHostnamesMap{}
records, err := provider.listDNSRecordsWithAutoPagination(ctx, "001")
records, err := provider.getDNSRecordsMap(ctx, "001")
assert.NoError(t, err)
endpoints := provider.groupByNameAndTypeWithCustomHostnames(records, chs)
@ -3346,3 +3319,22 @@ func TestDnsRecordFromLegacyAPI(t *testing.T) {
})
}
}
func TestZoneService(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(t.Context())
cancel()
client := &zoneService{
service: cloudflare.NewClient(),
}
t.Run("UpdateDNSRecord", func(t *testing.T) {
t.Parallel()
iter := client.ListDNSRecords(ctx, dns.RecordListParams{ZoneID: cloudflare.F("foo")})
require.False(t, iter.Next())
require.Empty(t, iter.Current())
require.ErrorIs(t, iter.Err(), context.Canceled)
})
}