diff --git a/provider/oci/oci.go b/provider/oci/oci.go index 77d70da4e..c526be630 100644 --- a/provider/oci/oci.go +++ b/provider/oci/oci.go @@ -170,6 +170,39 @@ func (p *OCIProvider) zones(ctx context.Context) (map[string]dns.ZoneSummary, er return zones, nil } +// Merge Endpoints with the same Name and Type into a single endpoint with multiple Targets. +func mergeEndpointsMultiTargets(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint { + endpointsByNameType := map[string][]*endpoint.Endpoint{} + + for _, ep := range endpoints { + key := fmt.Sprintf("%s-%s", ep.DNSName, ep.RecordType) + endpointsByNameType[key] = append(endpointsByNameType[key], ep) + } + + // If there were no merges, return endpoints. + if len(endpointsByNameType) == len(endpoints) { + return endpoints + } + + // Otherwise, create a new list of endpoints with the consolidated targets. + var mergedEndpoints []*endpoint.Endpoint + for _, endpoints := range endpointsByNameType { + dnsName := endpoints[0].DNSName + recordType := endpoints[0].RecordType + recordTTL := endpoints[0].RecordTTL + + targets := make([]string, len(endpoints)) + for i, e := range endpoints { + targets[i] = e.Targets[0] + } + + e := endpoint.NewEndpointWithTTL(dnsName, recordType, recordTTL, targets...) + mergedEndpoints = append(mergedEndpoints, e) + } + + return mergedEndpoints +} + func (p *OCIProvider) addPaginatedZones(ctx context.Context, zones map[string]dns.ZoneSummary, scope dns.GetZoneScopeEnum) error { var page *string // Loop until we have listed all zones. @@ -200,9 +233,22 @@ func (p *OCIProvider) addPaginatedZones(ctx context.Context, zones map[string]dn func (p *OCIProvider) newFilteredRecordOperations(endpoints []*endpoint.Endpoint, opType dns.RecordOperationOperationEnum) []dns.RecordOperation { ops := []dns.RecordOperation{} - for _, endpoint := range endpoints { - if p.domainFilter.Match(endpoint.DNSName) { - ops = append(ops, newRecordOperation(endpoint, opType)) + for _, ep := range endpoints { + if ep == nil { + continue + } + if p.domainFilter.Match(ep.DNSName) { + for _, t := range ep.Targets { + singleTargetEp := &endpoint.Endpoint{ + DNSName: ep.DNSName, + Targets: []string{t}, + RecordType: ep.RecordType, + RecordTTL: ep.RecordTTL, + Labels: ep.Labels, + ProviderSpecific: ep.ProviderSpecific, + } + ops = append(ops, newRecordOperation(singleTargetEp, opType)) + } } } return ops @@ -248,6 +294,8 @@ func (p *OCIProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) } } + endpoints = mergeEndpointsMultiTargets(endpoints) + return endpoints, nil } @@ -299,6 +347,20 @@ func (p *OCIProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e return nil } +// AdjustEndpoints modifies the endpoints as needed by the specific provider +func (p *OCIProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) { + adjustedEndpoints := []*endpoint.Endpoint{} + for _, e := range endpoints { + // OCI DNS does not support the set-identifier attribute, so we remove it to avoid plan failure + if e.SetIdentifier != "" { + log.Warnf("Adjusting endpont: %v. Ignoring unsupported annotation 'set-identifier': %s", *e, e.SetIdentifier) + e.SetIdentifier = "" + } + adjustedEndpoints = append(adjustedEndpoints, e) + } + return adjustedEndpoints, nil +} + // newRecordOperation returns a RecordOperation based on a given endpoint. func newRecordOperation(ep *endpoint.Endpoint, opType dns.RecordOperationOperationEnum) dns.RecordOperation { targets := make([]string, len(ep.Targets)) diff --git a/provider/oci/oci_test.go b/provider/oci/oci_test.go index 5a1d9bb32..7906951aa 100644 --- a/provider/oci/oci_test.go +++ b/provider/oci/oci_test.go @@ -541,7 +541,7 @@ func newMutableMockOCIDNSClient(zones []dns.ZoneSummary, recordsByZone map[strin for zoneID, records := range recordsByZone { for _, record := range records { - c.records[zoneID][ociRecordKey(*record.Rtype, *record.Domain)] = record + c.records[zoneID][ociRecordKey(*record.Rtype, *record.Domain, *record.Rdata)] = record } } @@ -577,8 +577,18 @@ func (c *mutableMockOCIDNSClient) GetZoneRecords(ctx context.Context, request dn return } -func ociRecordKey(rType, domain string) string { - return rType + "/" + domain +func ociRecordKey(rType, domain string, ip string) string { + rdata := "" + if rType == "A" { // adds support for multi-targets with same rtype and domain + rdata = "_" + ip + } + return rType + "_" + domain + rdata +} + +func sortEndpointTargets(endpoints []*endpoint.Endpoint) { + for _, ep := range endpoints { + sort.Strings([]string(ep.Targets)) + } } func (c *mutableMockOCIDNSClient) PatchZoneRecords(ctx context.Context, request dns.PatchZoneRecordsRequest) (response dns.PatchZoneRecordsResponse, err error) { @@ -599,7 +609,7 @@ func (c *mutableMockOCIDNSClient) PatchZoneRecords(ctx context.Context, request }) for _, op := range request.Items { - k := ociRecordKey(*op.Rtype, *op.Domain) + k := ociRecordKey(*op.Rtype, *op.Domain, *op.Rdata) switch op.Operation { case dns.RecordOperationOperationAdd: records[k] = dns.Record{ @@ -702,6 +712,7 @@ func TestMutableMockOCIDNSClient(t *testing.T) { } func TestOCIApplyChanges(t *testing.T) { + testCases := []struct { name string zones []dns.ZoneSummary @@ -840,10 +851,15 @@ func TestOCIApplyChanges(t *testing.T) { Rtype: common.String(endpoint.RecordTypeA), Ttl: common.Int(ociRecordTTL), }, { - Domain: common.String("bar.foo.com"), + Domain: common.String("car.foo.com"), Rdata: common.String("bar.com."), Rtype: common.String(endpoint.RecordTypeCNAME), Ttl: common.Int(ociRecordTTL), + }, { + Domain: common.String("bar.foo.com"), + Rdata: common.String("baz.com."), + Rtype: common.String(endpoint.RecordTypeCNAME), + Ttl: common.Int(ociRecordTTL), }}, }, changes: &plan.Changes{ @@ -851,10 +867,10 @@ func TestOCIApplyChanges(t *testing.T) { "foo.foo.com", endpoint.RecordTypeA, endpoint.TTL(ociRecordTTL), - "baz.com.", + "127.0.0.1", )}, UpdateOld: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( - "bar.foo.com", + "car.foo.com", endpoint.RecordTypeCNAME, endpoint.TTL(ociRecordTTL), "baz.com.", @@ -886,6 +902,129 @@ func TestOCIApplyChanges(t *testing.T) { "127.0.0.1"), }, }, + { + name: "combine_multi_target", + zones: []dns.ZoneSummary{{ + Id: common.String("ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959"), + Name: common.String("foo.com"), + }}, + + changes: &plan.Changes{ + Create: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "foo.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "192.168.1.2", + ), endpoint.NewEndpointWithTTL( + "foo.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "192.168.2.5", + )}, + }, + expectedEndpoints: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "foo.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), "192.168.1.2", "192.168.2.5", + )}, + }, + { + name: "remove_from_multi_target", + zones: []dns.ZoneSummary{{ + Id: common.String("ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959"), + Name: common.String("foo.com"), + }}, + records: map[string][]dns.Record{ + "ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959": {{ + Domain: common.String("foo.foo.com"), + Rdata: common.String("192.168.1.2"), + Rtype: common.String(endpoint.RecordTypeA), + Ttl: common.Int(ociRecordTTL), + }, { + Domain: common.String("foo.foo.com"), + Rdata: common.String("192.168.2.5"), + Rtype: common.String(endpoint.RecordTypeA), + Ttl: common.Int(ociRecordTTL), + }}, + }, + changes: &plan.Changes{ + Delete: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "foo.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "192.168.1.2", + )}, + }, + expectedEndpoints: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "foo.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), "192.168.2.5", + )}, + }, + { + name: "update_multi_target", + zones: []dns.ZoneSummary{{ + Id: common.String("ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959"), + Name: common.String("foo.com"), + }}, + records: map[string][]dns.Record{ + "ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959": {{ + Domain: common.String("first.foo.com"), + Rdata: common.String("10.77.4.5"), + Rtype: common.String(endpoint.RecordTypeA), + Ttl: common.Int(ociRecordTTL), + }}, + }, + changes: &plan.Changes{ + UpdateOld: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "first.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "10.77.4.5", + )}, + UpdateNew: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "first.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "10.77.6.10", + )}, + }, + expectedEndpoints: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "first.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "10.77.6.10", + )}, + }, + { + name: "increase_multi_target", + zones: []dns.ZoneSummary{{ + Id: common.String("ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959"), + Name: common.String("foo.com"), + }}, + records: map[string][]dns.Record{ + "ocid1.dns-zone.oc1..e1e042ef0bfbb5c251b9713fd7bf8959": {{ + Domain: common.String("first.foo.com"), + Rdata: common.String("10.77.4.5"), + Rtype: common.String(endpoint.RecordTypeA), + Ttl: common.Int(ociRecordTTL), + }}, + }, + changes: &plan.Changes{ + Create: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "first.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "10.77.6.10", + )}, + }, + expectedEndpoints: []*endpoint.Endpoint{endpoint.NewEndpointWithTTL( + "first.foo.com", + endpoint.RecordTypeA, + endpoint.TTL(ociRecordTTL), + "10.77.4.5", "10.77.6.10", + )}, + }, } for _, tc := range testCases { @@ -904,6 +1043,8 @@ func TestOCIApplyChanges(t *testing.T) { require.Equal(t, tc.err, err) endpoints, err := provider.Records(ctx) require.NoError(t, err) + sortEndpointTargets(endpoints) + sortEndpointTargets(tc.expectedEndpoints) require.ElementsMatch(t, tc.expectedEndpoints, endpoints) }) }