Reformat code to meet lint standards

This commit is contained in:
Rick Henry 2021-12-07 14:38:34 +00:00
parent 023eb23ef9
commit 6fc68a82db
No known key found for this signature in database
GPG Key ID: 07243CA36106218D
2 changed files with 456 additions and 457 deletions

View File

@ -17,225 +17,225 @@ limitations under the License.
package safedns package safedns
import ( import (
"context" "context"
"fmt" "fmt"
"os" "os"
"github.com/ukfast/sdk-go/pkg/service/safedns" log "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" ukfClient "github.com/ukfast/sdk-go/pkg/client"
ukf_client "github.com/ukfast/sdk-go/pkg/client" ukfConnection "github.com/ukfast/sdk-go/pkg/connection"
ukf_connection "github.com/ukfast/sdk-go/pkg/connection" "github.com/ukfast/sdk-go/pkg/service/safedns"
"sigs.k8s.io/external-dns/provider" "sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan"
"sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/provider"
) )
// SafeDNS is an interface that is a subset of the SafeDNS service API that are actually used. // SafeDNS is an interface that is a subset of the SafeDNS service API that are actually used.
// Signatures must match exactly. // Signatures must match exactly.
type SafeDNS interface { type SafeDNS interface {
CreateZoneRecord(zoneName string, req safedns.CreateRecordRequest) (int, error) CreateZoneRecord(zoneName string, req safedns.CreateRecordRequest) (int, error)
DeleteZoneRecord(zoneName string, recordID int) error DeleteZoneRecord(zoneName string, recordID int) error
GetZone(zoneName string) (safedns.Zone, error) GetZone(zoneName string) (safedns.Zone, error)
GetZoneRecord(zoneName string, recordID int) (safedns.Record, error) GetZoneRecord(zoneName string, recordID int) (safedns.Record, error)
GetZoneRecords(zoneName string, parameters ukf_connection.APIRequestParameters) ([]safedns.Record, error) GetZoneRecords(zoneName string, parameters ukfConnection.APIRequestParameters) ([]safedns.Record, error)
GetZones(parameters ukf_connection.APIRequestParameters) ([]safedns.Zone, error) GetZones(parameters ukfConnection.APIRequestParameters) ([]safedns.Zone, error)
PatchZoneRecord(zoneName string, recordID int, patch safedns.PatchRecordRequest) (int, error) PatchZoneRecord(zoneName string, recordID int, patch safedns.PatchRecordRequest) (int, error)
UpdateZoneRecord(zoneName string, record safedns.Record) (int, error) UpdateZoneRecord(zoneName string, record safedns.Record) (int, error)
} }
// SafeDNSProvider implements the DNS provider spec for UKFast SafeDNS. // SafeDNSProvider implements the DNS provider spec for UKFast SafeDNS.
type SafeDNSProvider struct { type SafeDNSProvider struct {
provider.BaseProvider provider.BaseProvider
Client SafeDNS Client SafeDNS
// Only consider hosted zones managing domains ending in this suffix // Only consider hosted zones managing domains ending in this suffix
domainFilter endpoint.DomainFilter domainFilter endpoint.DomainFilter
DryRun bool DryRun bool
APIRequestParams ukf_connection.APIRequestParameters APIRequestParams ukfConnection.APIRequestParameters
} }
// ZoneRecord is a datatype to simplify management of a record in a zone. // ZoneRecord is a datatype to simplify management of a record in a zone.
type ZoneRecord struct { type ZoneRecord struct {
ID int ID int
Name string Name string
Type safedns.RecordType Type safedns.RecordType
TTL safedns.RecordTTL TTL safedns.RecordTTL
Zone string Zone string
Content string Content string
} }
func NewSafeDNSProvider(domainFilter endpoint.DomainFilter, dryRun bool) (*SafeDNSProvider, error) { func NewSafeDNSProvider(domainFilter endpoint.DomainFilter, dryRun bool) (*SafeDNSProvider, error) {
token, ok := os.LookupEnv("SAFEDNS_TOKEN") token, ok := os.LookupEnv("SAFEDNS_TOKEN")
if !ok { if !ok {
return nil, fmt.Errorf("No SAFEDNS_TOKEN found in environment") return nil, fmt.Errorf("no SAFEDNS_TOKEN found in environment")
} }
ukfAPIConnection := ukf_connection.NewAPIKeyCredentialsAPIConnection(token) ukfAPIConnection := ukfConnection.NewAPIKeyCredentialsAPIConnection(token)
ukfClient := ukf_client.NewClient(ukfAPIConnection) ukfClient := ukfClient.NewClient(ukfAPIConnection)
safeDNS := ukfClient.SafeDNSService() safeDNS := ukfClient.SafeDNSService()
provider := &SafeDNSProvider{ provider := &SafeDNSProvider{
Client: safeDNS, Client: safeDNS,
domainFilter: domainFilter, domainFilter: domainFilter,
DryRun: dryRun, DryRun: dryRun,
APIRequestParams: *ukf_connection.NewAPIRequestParameters(), APIRequestParams: *ukfConnection.NewAPIRequestParameters(),
} }
return provider, nil return provider, nil
} }
// Zones returns the list of hosted zones in the SafeDNS account // Zones returns the list of hosted zones in the SafeDNS account
func (p *SafeDNSProvider) Zones(ctx context.Context) ([]safedns.Zone, error) { func (p *SafeDNSProvider) Zones(ctx context.Context) ([]safedns.Zone, error) {
var zones []safedns.Zone var zones []safedns.Zone
allZones, err := p.Client.GetZones(p.APIRequestParams) allZones, err := p.Client.GetZones(p.APIRequestParams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Check each found zone to see whether they match the domain filter provided. If they do, append it to the array of // Check each found zone to see whether they match the domain filter provided. If they do, append it to the array of
// zones defined above. If not, continue to the next item in the loop. // zones defined above. If not, continue to the next item in the loop.
for _, zone := range allZones { for _, zone := range allZones {
if p.domainFilter.Match(zone.Name) { if p.domainFilter.Match(zone.Name) {
zones = append(zones, zone) zones = append(zones, zone)
} else { } else {
continue continue
} }
} }
return zones, nil return zones, nil
} }
func (p *SafeDNSProvider) ZoneRecords(ctx context.Context) ([]ZoneRecord, error){ func (p *SafeDNSProvider) ZoneRecords(ctx context.Context) ([]ZoneRecord, error) {
zones, err := p.Zones(ctx) zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var zoneRecords []ZoneRecord var zoneRecords []ZoneRecord
for _, zone := range zones { for _, zone := range zones {
// For each zone in the zonelist, get all records of an ExternalDNS supported type. // For each zone in the zonelist, get all records of an ExternalDNS supported type.
records, err := p.Client.GetZoneRecords(zone.Name, p.APIRequestParams) records, err := p.Client.GetZoneRecords(zone.Name, p.APIRequestParams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, r := range records { for _, r := range records {
zoneRecord := ZoneRecord{ zoneRecord := ZoneRecord{
ID: r.ID, ID: r.ID,
Name: r.Name, Name: r.Name,
Type: r.Type, Type: r.Type,
TTL: r.TTL, TTL: r.TTL,
Zone: zone.Name, Zone: zone.Name,
Content: r.Content, Content: r.Content,
} }
zoneRecords = append(zoneRecords, zoneRecord) zoneRecords = append(zoneRecords, zoneRecord)
} }
} }
return zoneRecords, nil return zoneRecords, nil
} }
// Records returns a list of Endpoint resources created from all records in supported zones. // Records returns a list of Endpoint resources created from all records in supported zones.
func (p *SafeDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (p *SafeDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
var endpoints []*endpoint.Endpoint var endpoints []*endpoint.Endpoint
zoneRecords, err := p.ZoneRecords(ctx) zoneRecords, err := p.ZoneRecords(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, r := range zoneRecords { for _, r := range zoneRecords {
if provider.SupportedRecordType(string(r.Type)) { if provider.SupportedRecordType(string(r.Type)) {
endpoints = append(endpoints, endpoint.NewEndpointWithTTL(r.Name, string(r.Type), endpoint.TTL(r.TTL), r.Content)) endpoints = append(endpoints, endpoint.NewEndpointWithTTL(r.Name, string(r.Type), endpoint.TTL(r.TTL), r.Content))
} }
} }
return endpoints, nil return endpoints, nil
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *SafeDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { func (p *SafeDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
// Identify the zone name for each record // Identify the zone name for each record
zoneNameIDMapper := provider.ZoneIDName{} zoneNameIDMapper := provider.ZoneIDName{}
zones, err := p.Zones(ctx) zones, err := p.Zones(ctx)
if err != nil { if err != nil {
return err return err
} }
for _, zone := range zones { for _, zone := range zones {
zoneNameIDMapper.Add(zone.Name, zone.Name) zoneNameIDMapper.Add(zone.Name, zone.Name)
} }
zoneRecords, err := p.ZoneRecords(ctx) zoneRecords, err := p.ZoneRecords(ctx)
if err != nil { if err != nil {
return err return err
} }
for _, endpoint := range changes.Create { for _, endpoint := range changes.Create {
_, ZoneName := zoneNameIDMapper.FindZone(endpoint.DNSName) _, ZoneName := zoneNameIDMapper.FindZone(endpoint.DNSName)
for _, target := range endpoint.Targets { for _, target := range endpoint.Targets {
request := safedns.CreateRecordRequest { request := safedns.CreateRecordRequest{
Name: endpoint.DNSName, Name: endpoint.DNSName,
Type: endpoint.RecordType, Type: endpoint.RecordType,
Content: target, Content: target,
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"zoneID": ZoneName, "zoneID": ZoneName,
"dnsName": endpoint.DNSName, "dnsName": endpoint.DNSName,
"recordType": endpoint.RecordType, "recordType": endpoint.RecordType,
"Value": target, "Value": target,
}).Info("Creating record") }).Info("Creating record")
_, err := p.Client.CreateZoneRecord(ZoneName, request) _, err := p.Client.CreateZoneRecord(ZoneName, request)
if err != nil { if err != nil {
return err return err
} }
} }
} }
for _, endpoint := range changes.UpdateNew { for _, endpoint := range changes.UpdateNew {
// TODO: Find a more effient way of doing this. // TODO: Find a more effient way of doing this.
// Currently iterates over each zoneRecord in ZoneRecords for each Endpoint in UpdateNew; the same will go for // Currently iterates over each zoneRecord in ZoneRecords for each Endpoint in UpdateNew; the same will go for
// Delete. As it's double-iteration, that's O(n^2), which isn't great. // Delete. As it's double-iteration, that's O(n^2), which isn't great.
var zoneRecord ZoneRecord var zoneRecord ZoneRecord
for _, target := range endpoint.Targets { for _, target := range endpoint.Targets {
for _, zr := range zoneRecords { for _, zr := range zoneRecords {
if zr.Name == endpoint.DNSName && zr.Content == target { if zr.Name == endpoint.DNSName && zr.Content == target {
zoneRecord = zr zoneRecord = zr
break break
} }
} }
newTTL := safedns.RecordTTL(int(endpoint.RecordTTL)) newTTL := safedns.RecordTTL(int(endpoint.RecordTTL))
newRecord := safedns.PatchRecordRequest{ newRecord := safedns.PatchRecordRequest{
Name: endpoint.DNSName, Name: endpoint.DNSName,
Content: target, Content: target,
TTL: &newTTL, TTL: &newTTL,
Type: endpoint.RecordType, Type: endpoint.RecordType,
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"zoneID": zoneRecord.Zone, "zoneID": zoneRecord.Zone,
"dnsName": newRecord.Name, "dnsName": newRecord.Name,
"recordType": newRecord.Type, "recordType": newRecord.Type,
"Value": newRecord.Content, "Value": newRecord.Content,
"Priority": newRecord.Priority, "Priority": newRecord.Priority,
}).Info("Patching record") }).Info("Patching record")
_, err = p.Client.PatchZoneRecord(zoneRecord.Zone, zoneRecord.ID, newRecord) _, err = p.Client.PatchZoneRecord(zoneRecord.Zone, zoneRecord.ID, newRecord)
if err != nil { if err != nil {
return err return err
} }
} }
} }
for _, endpoint := range changes.Delete { for _, endpoint := range changes.Delete {
// TODO: Find a more effient way of doing this. // TODO: Find a more effient way of doing this.
var zoneRecord ZoneRecord var zoneRecord ZoneRecord
for _, zr := range zoneRecords { for _, zr := range zoneRecords {
if zr.Name == endpoint.DNSName && string(zr.Type) == endpoint.RecordType { if zr.Name == endpoint.DNSName && string(zr.Type) == endpoint.RecordType {
zoneRecord = zr zoneRecord = zr
break break
} }
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"zoneID": zoneRecord.Zone, "zoneID": zoneRecord.Zone,
"dnsName": zoneRecord.Name, "dnsName": zoneRecord.Name,
"recordType": zoneRecord.Type, "recordType": zoneRecord.Type,
}).Info("Deleting record") }).Info("Deleting record")
err := p.Client.DeleteZoneRecord(zoneRecord.Zone, zoneRecord.ID) err := p.Client.DeleteZoneRecord(zoneRecord.Zone, zoneRecord.ID)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }

View File

@ -17,351 +17,350 @@ limitations under the License.
package safedns package safedns
import ( import (
"context" "context"
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ukfast/sdk-go/pkg/service/safedns" ukfConnection "github.com/ukfast/sdk-go/pkg/connection"
ukf_connection "github.com/ukfast/sdk-go/pkg/connection" "github.com/ukfast/sdk-go/pkg/service/safedns"
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/plan"
) )
// Create an implementation of the SafeDNS interface for Mocking // Create an implementation of the SafeDNS interface for Mocking
type MockSafeDNSService struct { type MockSafeDNSService struct {
mock.Mock mock.Mock
} }
func (m *MockSafeDNSService) CreateZoneRecord(zoneName string, req safedns.CreateRecordRequest) (int, error) { func (m *MockSafeDNSService) CreateZoneRecord(zoneName string, req safedns.CreateRecordRequest) (int, error) {
args := m.Called(zoneName, req) args := m.Called(zoneName, req)
return args.Int(0), args.Error(1) return args.Int(0), args.Error(1)
} }
func (m *MockSafeDNSService) DeleteZoneRecord(zoneName string, recordID int) error { func (m *MockSafeDNSService) DeleteZoneRecord(zoneName string, recordID int) error {
args := m.Called(zoneName, recordID) args := m.Called(zoneName, recordID)
return args.Error(0) return args.Error(0)
} }
func (m *MockSafeDNSService) GetZone(zoneName string) (safedns.Zone, error) { func (m *MockSafeDNSService) GetZone(zoneName string) (safedns.Zone, error) {
args := m.Called(zoneName) args := m.Called(zoneName)
return args.Get(0).(safedns.Zone), args.Error(1) return args.Get(0).(safedns.Zone), args.Error(1)
} }
func (m *MockSafeDNSService) GetZoneRecord(zoneName string, recordID int) (safedns.Record, error) { func (m *MockSafeDNSService) GetZoneRecord(zoneName string, recordID int) (safedns.Record, error) {
args := m.Called(zoneName, recordID) args := m.Called(zoneName, recordID)
return args.Get(0).(safedns.Record), args.Error(1) return args.Get(0).(safedns.Record), args.Error(1)
} }
func (m *MockSafeDNSService) GetZoneRecords(zoneName string, parameters ukf_connection.APIRequestParameters) ([]safedns.Record, error) { func (m *MockSafeDNSService) GetZoneRecords(zoneName string, parameters ukfConnection.APIRequestParameters) ([]safedns.Record, error) {
args := m.Called(zoneName, parameters) args := m.Called(zoneName, parameters)
return args.Get(0).([]safedns.Record), args.Error(1) return args.Get(0).([]safedns.Record), args.Error(1)
} }
func (m *MockSafeDNSService) GetZones(parameters ukf_connection.APIRequestParameters) ([]safedns.Zone, error) { func (m *MockSafeDNSService) GetZones(parameters ukfConnection.APIRequestParameters) ([]safedns.Zone, error) {
args := m.Called(parameters) args := m.Called(parameters)
return args.Get(0).([]safedns.Zone), args.Error(1) return args.Get(0).([]safedns.Zone), args.Error(1)
} }
func (m *MockSafeDNSService) PatchZoneRecord(zoneName string, recordID int, patch safedns.PatchRecordRequest) (int, error) { func (m *MockSafeDNSService) PatchZoneRecord(zoneName string, recordID int, patch safedns.PatchRecordRequest) (int, error) {
args := m.Called(zoneName, recordID, patch) args := m.Called(zoneName, recordID, patch)
return args.Int(0), args.Error(1) return args.Int(0), args.Error(1)
} }
func (m *MockSafeDNSService) UpdateZoneRecord(zoneName string, record safedns.Record) (int, error) { func (m *MockSafeDNSService) UpdateZoneRecord(zoneName string, record safedns.Record) (int, error) {
args := m.Called(zoneName, record) args := m.Called(zoneName, record)
return args.Int(0), args.Error(1) return args.Int(0), args.Error(1)
} }
// Utility functions // Utility functions
func createZones() []safedns.Zone { func createZones() []safedns.Zone {
return []safedns.Zone{ return []safedns.Zone{
{Name: "foo.com", Description: "Foo dot com"}, {Name: "foo.com", Description: "Foo dot com"},
{Name: "bar.io", Description: ""}, {Name: "bar.io", Description: ""},
{Name: "baz.org", Description: "Org"}, {Name: "baz.org", Description: "Org"},
} }
} }
func createFooRecords() []safedns.Record { func createFooRecords() []safedns.Record {
return []safedns.Record{ return []safedns.Record{
{ {
ID: 11, ID: 11,
Type: safedns.RecordTypeA, Type: safedns.RecordTypeA,
Name: "foo.com", Name: "foo.com",
Content: "targetFoo", Content: "targetFoo",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
{ {
ID: 12, ID: 12,
Type: safedns.RecordTypeTXT, Type: safedns.RecordTypeTXT,
Name: "foo.com", Name: "foo.com",
Content: "text", Content: "text",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
{ {
ID: 13, ID: 13,
Type: safedns.RecordTypeCAA, Type: safedns.RecordTypeCAA,
Name: "foo.com", Name: "foo.com",
Content: "", Content: "",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
} }
} }
func createBarRecords() []safedns.Record { func createBarRecords() []safedns.Record {
return []safedns.Record{} return []safedns.Record{}
} }
func createBazRecords() []safedns.Record { func createBazRecords() []safedns.Record {
return []safedns.Record{ return []safedns.Record{
{ {
ID: 31, ID: 31,
Type: safedns.RecordTypeA, Type: safedns.RecordTypeA,
Name: "baz.org", Name: "baz.org",
Content: "targetBaz", Content: "targetBaz",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
{ {
ID: 32, ID: 32,
Type: safedns.RecordTypeTXT, Type: safedns.RecordTypeTXT,
Name: "baz.org", Name: "baz.org",
Content: "text", Content: "text",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
{ {
ID: 33, ID: 33,
Type: safedns.RecordTypeA, Type: safedns.RecordTypeA,
Name: "api.baz.org", Name: "api.baz.org",
Content: "targetBazAPI", Content: "targetBazAPI",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
{ {
ID: 34, ID: 34,
Type: safedns.RecordTypeTXT, Type: safedns.RecordTypeTXT,
Name: "api.baz.org", Name: "api.baz.org",
Content: "text", Content: "text",
TTL: safedns.RecordTTL(3600), TTL: safedns.RecordTTL(3600),
}, },
} }
} }
// Actual tests // Actual tests
func TestNewSafeDNSProvider(t *testing.T) { func TestNewSafeDNSProvider(t *testing.T) {
_ = os.Setenv("SAFEDNS_TOKEN", "DUMMYVALUE") _ = os.Setenv("SAFEDNS_TOKEN", "DUMMYVALUE")
_, err := NewSafeDNSProvider(endpoint.NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) _, err := NewSafeDNSProvider(endpoint.NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
require.NoError(t, err) require.NoError(t, err)
_ = os.Unsetenv("SAFEDNS_TOKEN") _ = os.Unsetenv("SAFEDNS_TOKEN")
_, err = NewSafeDNSProvider(endpoint.NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true) _, err = NewSafeDNSProvider(endpoint.NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
require.Error(t, err) require.Error(t, err)
} }
func TestRecords(t *testing.T) { func TestRecords(t *testing.T) {
mockSafeDNSService := MockSafeDNSService{} mockSafeDNSService := MockSafeDNSService{}
provider := &SafeDNSProvider{ provider := &SafeDNSProvider{
Client: &mockSafeDNSService, Client: &mockSafeDNSService,
domainFilter: endpoint.NewDomainFilter([]string{}), domainFilter: endpoint.NewDomainFilter([]string{}),
DryRun: false, DryRun: false,
} }
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZones", "GetZones",
mock.Anything, mock.Anything,
).Return(createZones(), nil).Once() ).Return(createZones(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZoneRecords", "GetZoneRecords",
"foo.com", "foo.com",
mock.Anything, mock.Anything,
).Return(createFooRecords(), nil).Once() ).Return(createFooRecords(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZoneRecords", "GetZoneRecords",
"bar.io", "bar.io",
mock.Anything, mock.Anything,
).Return(createBarRecords(), nil).Once() ).Return(createBarRecords(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZoneRecords", "GetZoneRecords",
"baz.org", "baz.org",
mock.Anything, mock.Anything,
).Return(createBazRecords(), nil).Once() ).Return(createBazRecords(), nil).Once()
actual, err := provider.Records(context.Background()) actual, err := provider.Records(context.Background())
require.NoError(t, err) require.NoError(t, err)
expected := []*endpoint.Endpoint{ expected := []*endpoint.Endpoint{
{ {
DNSName: "foo.com", DNSName: "foo.com",
Targets: []string{"targetFoo"}, Targets: []string{"targetFoo"},
RecordType: "A", RecordType: "A",
RecordTTL: 3600, RecordTTL: 3600,
Labels: endpoint.NewLabels(), Labels: endpoint.NewLabels(),
}, },
{ {
DNSName: "foo.com", DNSName: "foo.com",
Targets: []string{"text"}, Targets: []string{"text"},
RecordType: "TXT", RecordType: "TXT",
RecordTTL: 3600, RecordTTL: 3600,
Labels: endpoint.NewLabels(), Labels: endpoint.NewLabels(),
}, },
{ {
DNSName: "baz.org", DNSName: "baz.org",
Targets: []string{"targetBaz"}, Targets: []string{"targetBaz"},
RecordType: "A", RecordType: "A",
RecordTTL: 3600, RecordTTL: 3600,
Labels: endpoint.NewLabels(), Labels: endpoint.NewLabels(),
}, },
{ {
DNSName: "baz.org", DNSName: "baz.org",
Targets: []string{"text"}, Targets: []string{"text"},
RecordType: "TXT", RecordType: "TXT",
RecordTTL: 3600, RecordTTL: 3600,
Labels: endpoint.NewLabels(), Labels: endpoint.NewLabels(),
}, },
{ {
DNSName: "api.baz.org", DNSName: "api.baz.org",
Targets: []string{"targetBazAPI"}, Targets: []string{"targetBazAPI"},
RecordType: "A", RecordType: "A",
RecordTTL: 3600, RecordTTL: 3600,
Labels: endpoint.NewLabels(), Labels: endpoint.NewLabels(),
}, },
{ {
DNSName: "api.baz.org", DNSName: "api.baz.org",
Targets: []string{"text"}, Targets: []string{"text"},
RecordType: "TXT", RecordType: "TXT",
RecordTTL: 3600, RecordTTL: 3600,
Labels: endpoint.NewLabels(), Labels: endpoint.NewLabels(),
}, },
} }
mockSafeDNSService.AssertExpectations(t) mockSafeDNSService.AssertExpectations(t)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
} }
func TestSafeDNSApplyChanges( t *testing.T) { func TestSafeDNSApplyChanges(t *testing.T) {
mockSafeDNSService := MockSafeDNSService{} mockSafeDNSService := MockSafeDNSService{}
provider := &SafeDNSProvider{ provider := &SafeDNSProvider{
Client: &mockSafeDNSService, Client: &mockSafeDNSService,
domainFilter: endpoint.NewDomainFilter([]string{}), domainFilter: endpoint.NewDomainFilter([]string{}),
DryRun: false, DryRun: false,
} }
// Dummy data // Dummy data
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZones", "GetZones",
mock.Anything, mock.Anything,
).Return(createZones(), nil).Once() ).Return(createZones(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZones", "GetZones",
mock.Anything, mock.Anything,
).Return(createZones(), nil).Once() ).Return(createZones(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZoneRecords", "GetZoneRecords",
"foo.com", "foo.com",
mock.Anything, mock.Anything,
).Return(createFooRecords(), nil).Once() ).Return(createFooRecords(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZoneRecords", "GetZoneRecords",
"bar.io", "bar.io",
mock.Anything, mock.Anything,
).Return(createBarRecords(), nil).Once() ).Return(createBarRecords(), nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"GetZoneRecords", "GetZoneRecords",
"baz.org", "baz.org",
mock.Anything, mock.Anything,
).Return(createBazRecords(), nil).Once() ).Return(createBazRecords(), nil).Once()
// Apply actions // Apply actions
mockSafeDNSService.On( mockSafeDNSService.On(
"DeleteZoneRecord", "DeleteZoneRecord",
"baz.org", "baz.org",
33, 33,
).Return(nil).Once() ).Return(nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"DeleteZoneRecord", "DeleteZoneRecord",
"baz.org", "baz.org",
34, 34,
).Return(nil).Once() ).Return(nil).Once()
TTL300 := safedns.RecordTTL(300) TTL300 := safedns.RecordTTL(300)
mockSafeDNSService.On( mockSafeDNSService.On(
"PatchZoneRecord", "PatchZoneRecord",
"foo.com", "foo.com",
11, 11,
safedns.PatchRecordRequest{ safedns.PatchRecordRequest{
Type: "A", Type: "A",
Name: "foo.com", Name: "foo.com",
Content: "targetFoo", Content: "targetFoo",
TTL: &TTL300, TTL: &TTL300,
}, },
).Return(123, nil).Once() ).Return(123, nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"CreateZoneRecord", "CreateZoneRecord",
"bar.io", "bar.io",
safedns.CreateRecordRequest{ safedns.CreateRecordRequest{
Type: "A", Type: "A",
Name: "create.bar.io", Name: "create.bar.io",
Content: "targetBar", Content: "targetBar",
}, },
).Return(246, nil).Once() ).Return(246, nil).Once()
mockSafeDNSService.On( mockSafeDNSService.On(
"CreateZoneRecord", "CreateZoneRecord",
"bar.io", "bar.io",
safedns.CreateRecordRequest{ safedns.CreateRecordRequest{
Type: "A", Type: "A",
Name: "bar.io", Name: "bar.io",
Content: "targetBar", Content: "targetBar",
}, },
).Return(369, nil).Once() ).Return(369, nil).Once()
err := provider.ApplyChanges(context.Background(), &plan.Changes{ err := provider.ApplyChanges(context.Background(), &plan.Changes{
Create: []*endpoint.Endpoint{ Create: []*endpoint.Endpoint{
{ {
DNSName: "create.bar.io", DNSName: "create.bar.io",
RecordType: "A", RecordType: "A",
Targets: []string{"targetBar"}, Targets: []string{"targetBar"},
RecordTTL: 3600, RecordTTL: 3600,
}, },
{ {
DNSName: "bar.io", DNSName: "bar.io",
RecordType: "A", RecordType: "A",
Targets: []string{"targetBar"}, Targets: []string{"targetBar"},
RecordTTL: 3600, RecordTTL: 3600,
}, },
}, },
Delete: []*endpoint.Endpoint{ Delete: []*endpoint.Endpoint{
{ {
DNSName: "api.baz.org", DNSName: "api.baz.org",
RecordType: "A", RecordType: "A",
}, },
{ {
DNSName: "api.baz.org", DNSName: "api.baz.org",
RecordType: "TXT", RecordType: "TXT",
}, },
}, },
UpdateNew: []*endpoint.Endpoint{ UpdateNew: []*endpoint.Endpoint{
{ {
DNSName: "foo.com", DNSName: "foo.com",
RecordType: "A", RecordType: "A",
RecordTTL: 300, RecordTTL: 300,
Targets: []string{"targetFoo"}, Targets: []string{"targetFoo"},
}, },
}, },
UpdateOld: []*endpoint.Endpoint{}, UpdateOld: []*endpoint.Endpoint{},
}) })
require.NoError(t, err) require.NoError(t, err)
mockSafeDNSService.AssertExpectations(t) mockSafeDNSService.AssertExpectations(t)
} }