From b09d7abc559904d111e7f398dd79bd69a45177f7 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Tue, 23 Mar 2021 12:17:32 +0100 Subject: [PATCH] update TransIP's Go client to v6 --- go.mod | 2 +- go.sum | 4 +- provider/transip/transip.go | 413 ++++++++++++++++--------------- provider/transip/transip_test.go | 346 +++++++++++++++++--------- 4 files changed, 446 insertions(+), 319 deletions(-) diff --git a/go.mod b/go.mod index feff79fa2..8ed81f3cd 100644 --- a/go.mod +++ b/go.mod @@ -50,7 +50,7 @@ require ( github.com/smartystreets/gunit v1.3.4 // indirect github.com/stretchr/testify v1.6.1 github.com/terra-farm/udnssdk v1.3.5 // indirect - github.com/transip/gotransip v5.8.2+incompatible + github.com/transip/gotransip/v6 v6.6.0 github.com/ultradns/ultradns-sdk-go v0.0.0-20200616202852-e62052662f60 github.com/vinyldns/go-vinyldns v0.0.0-20200211145900-fe8a3d82e556 github.com/vultr/govultr v0.4.2 diff --git a/go.sum b/go.sum index 8a5a6e67b..deb93320b 100644 --- a/go.sum +++ b/go.sum @@ -856,8 +856,8 @@ github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhV github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 h1:LnC5Kc/wtumK+WB441p7ynQJzVuNRJiqddSIE3IlSEQ= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/transip/gotransip v5.8.2+incompatible h1:aNJhw/w/3QBqFcHAIPz1ytoK5FexeMzbUCGrrhWr3H0= -github.com/transip/gotransip v5.8.2+incompatible/go.mod h1:uacMoJVmrfOcscM4Bi5NVg708b7c6rz2oDTWqa7i2Ic= +github.com/transip/gotransip/v6 v6.6.0 h1:dAHCTZzX98H6QE2kA4R9acAXu5RPPTwMSUFtpKZF3Nk= +github.com/transip/gotransip/v6 v6.6.0/go.mod h1:pQZ36hWWRahCUXkFWlx9Hs711gLd8J4qdgLdRzmtY+g= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= diff --git a/provider/transip/transip.go b/provider/transip/transip.go index 297678ab0..6543da915 100644 --- a/provider/transip/transip.go +++ b/provider/transip/transip.go @@ -23,8 +23,8 @@ import ( "strings" log "github.com/sirupsen/logrus" - "github.com/transip/gotransip" - transip "github.com/transip/gotransip/domain" + "github.com/transip/gotransip/v6" + "github.com/transip/gotransip/v6/domain" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" @@ -40,9 +40,11 @@ const ( // TransIPProvider is an implementation of Provider for TransIP. type TransIPProvider struct { provider.BaseProvider - client gotransip.SOAPClient + domainRepo domain.Repository domainFilter endpoint.DomainFilter dryRun bool + + zoneMap provider.ZoneIDName } // NewTransIPProvider initializes a new TransIP Provider. @@ -64,7 +66,7 @@ func NewTransIPProvider(accountName, privateKeyFile string, domainFilter endpoin } // create new TransIP API client - c, err := gotransip.NewSOAPClient(gotransip.ClientConfig{ + client, err := gotransip.NewClient(gotransip.ClientConfiguration{ AccountName: accountName, PrivateKeyPath: privateKeyFile, Mode: apiMode, @@ -73,233 +75,280 @@ func NewTransIPProvider(accountName, privateKeyFile string, domainFilter endpoin return nil, fmt.Errorf("could not setup TransIP API client: %s", err.Error()) } - // return tipCloud struct + // return TransIPProvider struct return &TransIPProvider{ - client: c, + domainRepo: domain.Repository{Client: client}, domainFilter: domainFilter, dryRun: dryRun, + zoneMap: provider.ZoneIDName{}, }, nil } // ApplyChanges applies a given set of changes in a given zone. func (p *TransIPProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { - // build zonefinder with all our zones so we can use FindZone - // and a mapping of zones and their domain name - zones, err := p.fetchZones() + // fetch all zones we currently have + // this does NOT include any DNS entries, so we'll have to fetch these for + // each zone that gets updated + zones, err := p.domainRepo.GetAll() if err != nil { return err } - zoneNameMapper := provider.ZoneIDName{} - zonesByName := make(map[string]transip.Domain) - updatedZones := make(map[string]bool) + // refresh zone mapping + zoneMap := provider.ZoneIDName{} for _, zone := range zones { // TransIP API doesn't expose a unique identifier for zones, other than than // the domain name itself - zoneNameMapper.Add(zone.Name, zone.Name) - zonesByName[zone.Name] = zone + zoneMap.Add(zone.Name, zone.Name) } + p.zoneMap = zoneMap - // first see if we need to delete anything + // first remove obsolete DNS records for _, ep := range changes.Delete { - log.WithFields(log.Fields{"record": ep.DNSName, "type": ep.RecordType}).Info("endpoint has to go") + epLog := log.WithFields(log.Fields{ + "record": ep.DNSName, + "type": ep.RecordType, + }) + epLog.Info("endpoint has to go") - zone, err := p.zoneForZoneName(ep.DNSName, zoneNameMapper, zonesByName) + zoneName, entries, err := p.entriesForEndpoint(ep) if err != nil { - log.Errorf("could not find zone for %s: %s", ep.DNSName, err.Error()) - continue + epLog.WithError(err).Error("could not get DNS entries") + return err } - log.Debugf("removing records for %s", zone.Name) + epLog = epLog.WithField("zone", zoneName) - // remove current records from DNS entry set - entries := p.removeEndpointFromEntries(ep, zone) - - // update zone in zone map - zone.DNSEntries = entries - zonesByName[zone.Name] = zone - // flag zone for updating - updatedZones[zone.Name] = true - } - - for _, ep := range changes.Create { - log.WithFields(log.Fields{"record": ep.DNSName, "type": ep.RecordType}).Info("endpoint is missing") - - zone, err := p.zoneForZoneName(ep.DNSName, zoneNameMapper, zonesByName) - if err != nil { - log.Errorf("could not find zone for %s: %s", ep.DNSName, err.Error()) - continue - } - - log.Debugf("creating records for %s", zone.Name) - - // add new entries to set - zone.DNSEntries = p.addEndpointToEntries(ep, zone, zone.DNSEntries) - - // update zone in zone map - zonesByName[zone.Name] = zone - // flag zone for updating - updatedZones[zone.Name] = true - log.WithFields(log.Fields{"zone": zone.Name}).Debug("flagging for update") - } - - for _, ep := range changes.UpdateNew { - log.WithFields(log.Fields{"record": ep.DNSName, "type": ep.RecordType}).Debug("needs updating") - - zone, err := p.zoneForZoneName(ep.DNSName, zoneNameMapper, zonesByName) - if err != nil { - log.WithFields(log.Fields{"record": ep.DNSName}).Warn(err.Error()) - continue - } - - // updating the records is basically finding all matching records according - // to the name and the type, removing them from the set and add the new - // records - log.WithFields(log.Fields{ - "zone": zone.Name, - "dnsname": ep.DNSName, - "recordtype": ep.RecordType, - }).Debug("removing matching entries") - - // remove current records from DNS entry set - entries := p.removeEndpointFromEntries(ep, zone) - - // add new entries to set - entries = p.addEndpointToEntries(ep, zone, entries) - - // check to see if actually anything changed in the DNSEntry set - if p.dnsEntriesAreEqual(entries, zone.DNSEntries) { - log.WithFields(log.Fields{"zone": zone.Name}).Debug("not updating identical entries") - continue - } - - // update zone in zone map - zone.DNSEntries = entries - zonesByName[zone.Name] = zone - // flag zone for updating - updatedZones[zone.Name] = true - - log.WithFields(log.Fields{"zone": zone.Name}).Debug("flagging for update") - } - - // go over all updated zones and set new DNSEntry set - for uz := range updatedZones { - zone, ok := zonesByName[uz] - if !ok { - log.WithFields(log.Fields{"zone": uz}).Debug("updated zone no longer found") + if len(entries) == 0 { + epLog.Info("no matching entries found") continue } if p.dryRun { - log.WithFields(log.Fields{"zone": zone.Name}).Info("not updating in dry-run mode") + epLog.Info("not removing DNS entries in dry-run mode") continue } - log.WithFields(log.Fields{"zone": zone.Name}).Info("updating DNS entries") - if err := transip.SetDNSEntries(p.client, zone.Name, zone.DNSEntries); err != nil { - log.WithFields(log.Fields{"zone": zone.Name, "error": err.Error()}).Warn("failed to update") + for _, entry := range entries { + log.WithFields(log.Fields{ + "domain": zoneName, + "name": entry.Name, + "type": entry.Type, + "content": entry.Content, + "ttl": entry.Expire, + }).Info("removing DNS entry") + + err = p.domainRepo.RemoveDNSEntry(zoneName, entry) + if err != nil { + epLog.WithError(err).Error("could not remove DNS entry") + return err + } + } + } + + // then create new DNS records + for _, ep := range changes.Create { + epLog := log.WithFields(log.Fields{ + "record": ep.DNSName, + "type": ep.RecordType, + }) + epLog.Info("endpoint should be created") + + zoneName, err := p.zoneNameForDNSName(ep.DNSName) + if err != nil { + epLog.WithError(err).Warn("could not find zone for endpoint") + continue + } + + epLog = epLog.WithField("zone", zoneName) + + if p.dryRun { + epLog.Info("not adding DNS entries in dry-run mode") + continue + } + + for _, entry := range dnsEntriesForEndpoint(ep, zoneName) { + log.WithFields(log.Fields{ + "domain": zoneName, + "name": entry.Name, + "type": entry.Type, + "content": entry.Content, + "ttl": entry.Expire, + }).Info("creating DNS entry") + + err = p.domainRepo.AddDNSEntry(zoneName, entry) + if err != nil { + epLog.WithError(err).Error("could not add DNS entry") + return err + } + } + } + + // then update existing DNS records + for _, ep := range changes.UpdateNew { + epLog := log.WithFields(log.Fields{ + "record": ep.DNSName, + "type": ep.RecordType, + }) + epLog.Debug("endpoint needs updating") + + zoneName, entries, err := p.entriesForEndpoint(ep) + if err != nil { + epLog.WithError(err).Error("could not get DNS entries") + return err + } + + epLog = epLog.WithField("zone", zoneName) + + if len(entries) == 0 { + epLog.Info("no matching entries found") + continue + } + + newEntries := dnsEntriesForEndpoint(ep, zoneName) + + // check to see if actually anything changed in the DNSEntry set + if dnsEntriesAreEqual(newEntries, entries) { + epLog.Debug("not updating identical DNS entries") + continue + } + + if p.dryRun { + epLog.Info("not updating DNS entries in dry-run mode") + continue + } + + // TransIP API client does have an UpdateDNSEntry call but that does only + // allow you to update the content of a DNSEntry, not the TTL + // to work around this, remove the old entry first and add the new entry + for _, entry := range entries { + log.WithFields(log.Fields{ + "domain": zoneName, + "name": entry.Name, + "type": entry.Type, + "content": entry.Content, + "ttl": entry.Expire, + }).Info("removing DNS entry") + + err = p.domainRepo.RemoveDNSEntry(zoneName, entry) + if err != nil { + epLog.WithError(err).Error("could not remove DNS entry") + return err + } + } + + for _, entry := range newEntries { + log.WithFields(log.Fields{ + "domain": zoneName, + "name": entry.Name, + "type": entry.Type, + "content": entry.Content, + "ttl": entry.Expire, + }).Info("adding DNS entry") + + err = p.domainRepo.AddDNSEntry(zoneName, entry) + if err != nil { + epLog.WithError(err).Error("could not add DNS entry") + return err + } } } return nil } -// fetchZones returns a list of all domains within the account -func (p *TransIPProvider) fetchZones() ([]transip.Domain, error) { - domainNames, err := transip.GetDomainNames(p.client) - if err != nil { - return nil, err - } - - domains, err := transip.BatchGetInfo(p.client, domainNames) - if err != nil { - return nil, err - } - - var zones []transip.Domain - for _, d := range domains { - if !p.domainFilter.Match(d.Name) { - continue - } - - zones = append(zones, d) - } - - return zones, nil -} - -// Zones returns the list of hosted zones. -func (p *TransIPProvider) Zones() ([]transip.Domain, error) { - zones, err := p.fetchZones() - if err != nil { - return nil, err - } - - return zones, nil -} - -// Records returns the list of records in a given zone. +// Records returns the list of records in all zones func (p *TransIPProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { - zones, err := p.Zones() + zones, err := p.domainRepo.GetAll() if err != nil { return nil, err } var endpoints []*endpoint.Endpoint - var name string // go over all zones and their DNS entries and create endpoints for them for _, zone := range zones { - for _, r := range zone.DNSEntries { - if !provider.SupportedRecordType(string(r.Type)) { + entries, err := p.domainRepo.GetDNSEntries(zone.Name) + if err != nil { + return nil, err + } + + for _, r := range entries { + if !provider.SupportedRecordType(r.Type) { continue } - name = p.endpointNameForRecord(r, zone) - endpoints = append(endpoints, endpoint.NewEndpointWithTTL(name, string(r.Type), endpoint.TTL(r.TTL), r.Content)) + name := endpointNameForRecord(r, zone.Name) + endpoints = append(endpoints, endpoint.NewEndpointWithTTL(name, r.Type, endpoint.TTL(r.Expire), r.Content)) } } return endpoints, nil } +func (p *TransIPProvider) entriesForEndpoint(ep *endpoint.Endpoint) (string, []domain.DNSEntry, error) { + zoneName, err := p.zoneNameForDNSName(ep.DNSName) + if err != nil { + return "", nil, err + } + + epName := recordNameForEndpoint(ep, zoneName) + dnsEntries, err := p.domainRepo.GetDNSEntries(zoneName) + if err != nil { + return zoneName, nil, err + } + + matches := []domain.DNSEntry{} + for _, entry := range dnsEntries { + if ep.RecordType != entry.Type { + continue + } + + if entry.Name == epName { + matches = append(matches, entry) + } + } + + return zoneName, matches, nil +} + // endpointNameForRecord returns "www.example.org" for DNSEntry with Name "www" and // Domain with Name "example.org" -func (p *TransIPProvider) endpointNameForRecord(r transip.DNSEntry, d transip.Domain) string { +func endpointNameForRecord(r domain.DNSEntry, zoneName string) string { // root name is identified by "@" and should be translated to domain name for // the endpoint entry. if r.Name == "@" { - return d.Name + return zoneName } - return fmt.Sprintf("%s.%s", r.Name, d.Name) + return fmt.Sprintf("%s.%s", r.Name, zoneName) } // recordNameForEndpoint returns "www" for Endpoint with DNSName "www.example.org" // and Domain with Name "example.org" -func (p *TransIPProvider) recordNameForEndpoint(ep *endpoint.Endpoint, d transip.Domain) string { +func recordNameForEndpoint(ep *endpoint.Endpoint, zoneName string) string { // root name is identified by "@" and should be translated to domain name for // the endpoint entry. - if ep.DNSName == d.Name { + if ep.DNSName == zoneName { return "@" } - return strings.TrimSuffix(ep.DNSName, "."+d.Name) + return strings.TrimSuffix(ep.DNSName, "."+zoneName) } // getMinimalValidTTL returns max between given Endpoint's RecordTTL and // transipMinimalValidTTL -func (p *TransIPProvider) getMinimalValidTTL(ep *endpoint.Endpoint) int64 { +func getMinimalValidTTL(ep *endpoint.Endpoint) int { // TTL cannot be lower than transipMinimalValidTTL if ep.RecordTTL < transipMinimalValidTTL { return transipMinimalValidTTL } - return int64(ep.RecordTTL) + return int(ep.RecordTTL) } // dnsEntriesAreEqual compares the entries in 2 sets and returns true if the // content of the entries is equal -func (p *TransIPProvider) dnsEntriesAreEqual(a, b transip.DNSEntries) bool { +func dnsEntriesAreEqual(a, b []domain.DNSEntry) bool { if len(a) != len(b) { return false } @@ -315,7 +364,7 @@ func (p *TransIPProvider) dnsEntriesAreEqual(a, b transip.DNSEntries) bool { continue } - if aa.TTL != bb.TTL { + if aa.Expire != bb.Expire { continue } @@ -330,45 +379,22 @@ func (p *TransIPProvider) dnsEntriesAreEqual(a, b transip.DNSEntries) bool { return (len(a) == match) } -// removeEndpointFromEntries removes DNS entries from zone's set that match the -// type and name from given endpoint and returns the resulting DNS entry set -func (p *TransIPProvider) removeEndpointFromEntries(ep *endpoint.Endpoint, zone transip.Domain) transip.DNSEntries { - // create new entry set - entries := transip.DNSEntries{} - // go over each DNS entry to see if it is a match - for _, e := range zone.DNSEntries { - // if we have match, don't copy it to the new entry set - if p.endpointNameForRecord(e, zone) == ep.DNSName && string(e.Type) == ep.RecordType { - log.WithFields(log.Fields{ - "name": e.Name, - "content": e.Content, - "type": e.Type, - }).Debug("found match") - continue +// dnsEntriesForEndpoint creates DNS entries for given endpoint and returns +// resulting DNS entry set +func dnsEntriesForEndpoint(ep *endpoint.Endpoint, zoneName string) []domain.DNSEntry { + ttl := getMinimalValidTTL(ep) + + entries := []domain.DNSEntry{} + for _, target := range ep.Targets { + // external hostnames require a trailing dot in TransIP API + if ep.RecordType == "CNAME" { + target = provider.EnsureTrailingDot(target) } - entries = append(entries, e) - } - - return entries -} - -// addEndpointToEntries creates DNS entries for given endpoint and returns -// resulting DNS entry set -func (p *TransIPProvider) addEndpointToEntries(ep *endpoint.Endpoint, zone transip.Domain, entries transip.DNSEntries) transip.DNSEntries { - ttl := p.getMinimalValidTTL(ep) - for _, target := range ep.Targets { - log.WithFields(log.Fields{ - "zone": zone.Name, - "dnsname": ep.DNSName, - "recordtype": ep.RecordType, - "ttl": ttl, - "target": target, - }).Debugf("adding new record") - entries = append(entries, transip.DNSEntry{ - Name: p.recordNameForEndpoint(ep, zone), - TTL: ttl, - Type: transip.DNSEntryType(ep.RecordType), + entries = append(entries, domain.DNSEntry{ + Name: recordNameForEndpoint(ep, zoneName), + Expire: ttl, + Type: ep.RecordType, Content: target, }) } @@ -378,16 +404,11 @@ func (p *TransIPProvider) addEndpointToEntries(ep *endpoint.Endpoint, zone trans // zoneForZoneName returns the zone mapped to given name or error if zone could // not be found -func (p *TransIPProvider) zoneForZoneName(name string, m provider.ZoneIDName, z map[string]transip.Domain) (transip.Domain, error) { - _, zoneName := m.FindZone(name) +func (p *TransIPProvider) zoneNameForDNSName(name string) (string, error) { + _, zoneName := p.zoneMap.FindZone(name) if zoneName == "" { - return transip.Domain{}, fmt.Errorf("could not find zoneName for %s", name) + return "", fmt.Errorf("could not find zoneName for %s", name) } - zone, ok := z[zoneName] - if !ok { - return zone, fmt.Errorf("could not find zone for %s", zoneName) - } - - return zone, nil + return zoneName, nil } diff --git a/provider/transip/transip_test.go b/provider/transip/transip_test.go index 76e090b78..91c622afb 100644 --- a/provider/transip/transip_test.go +++ b/provider/transip/transip_test.go @@ -17,116 +17,123 @@ limitations under the License. package transip import ( + "context" + "encoding/json" + "errors" + "strings" "testing" "github.com/stretchr/testify/assert" - transip "github.com/transip/gotransip/domain" + "github.com/stretchr/testify/require" + "github.com/transip/gotransip/v6/domain" + "github.com/transip/gotransip/v6/rest" "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/provider" ) +func newProvider() *TransIPProvider { + return &TransIPProvider{ + zoneMap: provider.ZoneIDName{}, + } +} + func TestTransIPDnsEntriesAreEqual(t *testing.T) { - p := TransIPProvider{} // test with equal set - a := transip.DNSEntries{ - transip.DNSEntry{ + a := []domain.DNSEntry{ + { Name: "www.example.org", - Type: transip.DNSEntryTypeCNAME, - TTL: 3600, + Type: "CNAME", + Expire: 3600, Content: "www.example.com", }, - transip.DNSEntry{ + { Name: "www.example.com", - Type: transip.DNSEntryTypeA, - TTL: 3600, + Type: "A", + Expire: 3600, Content: "192.168.0.1", }, } - b := transip.DNSEntries{ - transip.DNSEntry{ + b := []domain.DNSEntry{ + { Name: "www.example.com", - Type: transip.DNSEntryTypeA, - TTL: 3600, + Type: "A", + Expire: 3600, Content: "192.168.0.1", }, - transip.DNSEntry{ + { Name: "www.example.org", - Type: transip.DNSEntryTypeCNAME, - TTL: 3600, + Type: "CNAME", + Expire: 3600, Content: "www.example.com", }, } - assert.Equal(t, true, p.dnsEntriesAreEqual(a, b)) + assert.Equal(t, true, dnsEntriesAreEqual(a, b)) // change type on one of b's records - b[1].Type = transip.DNSEntryTypeNS - assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) - b[1].Type = transip.DNSEntryTypeCNAME + b[1].Type = "NS" + assert.Equal(t, false, dnsEntriesAreEqual(a, b)) + b[1].Type = "CNAME" // change ttl on one of b's records - b[1].TTL = 1800 - assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) - b[1].TTL = 3600 + b[1].Expire = 1800 + assert.Equal(t, false, dnsEntriesAreEqual(a, b)) + b[1].Expire = 3600 // change name on one of b's records b[1].Name = "example.org" - assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) + assert.Equal(t, false, dnsEntriesAreEqual(a, b)) // remove last entry of b b = b[:1] - assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) + assert.Equal(t, false, dnsEntriesAreEqual(a, b)) } func TestTransIPGetMinimalValidTTL(t *testing.T) { - p := TransIPProvider{} // test with 'unconfigured' TTL ep := &endpoint.Endpoint{} - assert.Equal(t, int64(transipMinimalValidTTL), p.getMinimalValidTTL(ep)) + assert.EqualValues(t, transipMinimalValidTTL, getMinimalValidTTL(ep)) // test with lower than minimal ttl ep.RecordTTL = (transipMinimalValidTTL - 1) - assert.Equal(t, int64(transipMinimalValidTTL), p.getMinimalValidTTL(ep)) + assert.EqualValues(t, transipMinimalValidTTL, getMinimalValidTTL(ep)) // test with higher than minimal ttl ep.RecordTTL = (transipMinimalValidTTL + 1) - assert.Equal(t, int64(transipMinimalValidTTL+1), p.getMinimalValidTTL(ep)) + assert.EqualValues(t, transipMinimalValidTTL+1, getMinimalValidTTL(ep)) } func TestTransIPRecordNameForEndpoint(t *testing.T) { - p := TransIPProvider{} ep := &endpoint.Endpoint{ DNSName: "example.org", } - d := transip.Domain{ + d := domain.Domain{ Name: "example.org", } - assert.Equal(t, "@", p.recordNameForEndpoint(ep, d)) + assert.Equal(t, "@", recordNameForEndpoint(ep, d.Name)) ep.DNSName = "www.example.org" - assert.Equal(t, "www", p.recordNameForEndpoint(ep, d)) + assert.Equal(t, "www", recordNameForEndpoint(ep, d.Name)) } func TestTransIPEndpointNameForRecord(t *testing.T) { - p := TransIPProvider{} - r := transip.DNSEntry{ + r := domain.DNSEntry{ Name: "@", } - d := transip.Domain{ + d := domain.Domain{ Name: "example.org", } - assert.Equal(t, d.Name, p.endpointNameForRecord(r, d)) + assert.Equal(t, d.Name, endpointNameForRecord(r, d.Name)) r.Name = "www" - assert.Equal(t, "www.example.org", p.endpointNameForRecord(r, d)) + assert.Equal(t, "www.example.org", endpointNameForRecord(r, d.Name)) } func TestTransIPAddEndpointToEntries(t *testing.T) { - p := TransIPProvider{} - // prepare endpoint ep := &endpoint.Endpoint{ DNSName: "www.example.org", @@ -139,94 +146,193 @@ func TestTransIPAddEndpointToEntries(t *testing.T) { } // prepare zone with DNS entry set - zone := transip.Domain{ + zone := domain.Domain{ Name: "example.org", - // 2 matching A records - DNSEntries: transip.DNSEntries{ - // 1 non-matching A record - transip.DNSEntry{ - Name: "mail", - Type: transip.DNSEntryTypeA, - Content: "192.168.0.1", - TTL: 3600, - }, - // 1 non-matching MX record - transip.DNSEntry{ - Name: "@", - Type: transip.DNSEntryTypeMX, - Content: "mail.example.org", - TTL: 3600, - }, - }, } // add endpoint to zone's entries - result := p.addEndpointToEntries(ep, zone, zone.DNSEntries) + result := dnsEntriesForEndpoint(ep, zone.Name) - assert.Equal(t, 4, len(result)) - assert.Equal(t, "mail", result[0].Name) - assert.Equal(t, transip.DNSEntryTypeA, result[0].Type) - assert.Equal(t, "@", result[1].Name) - assert.Equal(t, transip.DNSEntryTypeMX, result[1].Type) - assert.Equal(t, "www", result[2].Name) - assert.Equal(t, transip.DNSEntryTypeA, result[2].Type) - assert.Equal(t, "192.168.0.1", result[2].Content) - assert.Equal(t, int64(1800), result[2].TTL) - assert.Equal(t, "www", result[3].Name) - assert.Equal(t, transip.DNSEntryTypeA, result[3].Type) - assert.Equal(t, "192.168.0.2", result[3].Content) - assert.Equal(t, int64(1800), result[3].TTL) -} - -func TestTransIPRemoveEndpointFromEntries(t *testing.T) { - p := TransIPProvider{} - - // prepare endpoint - ep := &endpoint.Endpoint{ - DNSName: "www.example.org", - RecordType: "A", + if assert.Equal(t, 2, len(result)) { + assert.Equal(t, "www", result[0].Name) + assert.Equal(t, "A", result[0].Type) + assert.Equal(t, "192.168.0.1", result[0].Content) + assert.EqualValues(t, 1800, result[0].Expire) + assert.Equal(t, "www", result[1].Name) + assert.Equal(t, "A", result[1].Type) + assert.Equal(t, "192.168.0.2", result[1].Content) + assert.EqualValues(t, 1800, result[1].Expire) } - // prepare zone with DNS entry set - zone := transip.Domain{ - Name: "example.org", - // 2 matching A records - DNSEntries: transip.DNSEntries{ - transip.DNSEntry{ - Name: "www", - Type: transip.DNSEntryTypeA, - Content: "192.168.0.1", - TTL: 3600, - }, - transip.DNSEntry{ - Name: "www", - Type: transip.DNSEntryTypeA, - Content: "192.168.0.2", - TTL: 3600, - }, - // 1 non-matching A record - transip.DNSEntry{ - Name: "mail", - Type: transip.DNSEntryTypeA, - Content: "192.168.0.1", - TTL: 3600, - }, - // 1 non-matching MX record - transip.DNSEntry{ - Name: "@", - Type: transip.DNSEntryTypeMX, - Content: "mail.example.org", - TTL: 3600, - }, + // try again with CNAME + ep.RecordType = "CNAME" + ep.Targets = []string{"foo.bar"} + result = dnsEntriesForEndpoint(ep, zone.Name) + if assert.Equal(t, 1, len(result)) { + assert.Equal(t, "CNAME", result[0].Type) + assert.Equal(t, "foo.bar.", result[0].Content) + } +} + +func TestZoneNameForDNSName(t *testing.T) { + p := newProvider() + p.zoneMap.Add("example.com", "example.com") + + zoneName, err := p.zoneNameForDNSName("www.example.com") + if assert.NoError(t, err) { + assert.Equal(t, "example.com", zoneName) + } + + _, err = p.zoneNameForDNSName("www.example.org") + if assert.Error(t, err) { + assert.Equal(t, "could not find zoneName for www.example.org", err.Error()) + } +} + +// fakeClient mocks the REST API client +type fakeClient struct { + getFunc func(rest.Request, interface{}) error +} + +func (f *fakeClient) Get(request rest.Request, dest interface{}) error { + if f.getFunc == nil { + return errors.New("GET not defined") + } + + return f.getFunc(request, dest) +} + +func (f fakeClient) Put(request rest.Request) error { + return errors.New("PUT not implemented") +} + +func (f fakeClient) Post(request rest.Request) error { + return errors.New("POST not implemented") +} + +func (f fakeClient) Delete(request rest.Request) error { + return errors.New("DELETE not implemented") +} + +func (f fakeClient) Patch(request rest.Request) error { + return errors.New("PATCH not implemented") +} + +func TestProviderRecords(t *testing.T) { + // set up the fake REST client + client := &fakeClient{} + client.getFunc = func(req rest.Request, dest interface{}) error { + var data []byte + switch { + case req.Endpoint == "/domains": + // return list of some domain names + // names only, other fields are not used + data = []byte(`{"domains":[{"name":"example.org"}, {"name":"example.com"}]}`) + case strings.HasSuffix(req.Endpoint, "/dns"): + // return list of DNS entries + // also some unsupported types + data = []byte(`{"dnsEntries":[{"name":"www", "expire":1234, "type":"CNAME", "content":"@"},{"type":"MX"},{"type":"AAAA"}]}`) + } + + // unmarshal the prepared return data into the given destination type + return json.Unmarshal(data, &dest) + } + + // set up provider + p := newProvider() + p.domainRepo = domain.Repository{Client: client} + + endpoints, err := p.Records(context.TODO()) + if assert.NoError(t, err) { + if assert.Equal(t, 2, len(endpoints)) { + assert.Equal(t, "www.example.org", endpoints[0].DNSName) + assert.EqualValues(t, "@", endpoints[0].Targets[0]) + assert.Equal(t, "CNAME", endpoints[0].RecordType) + assert.Equal(t, 0, len(endpoints[0].Labels)) + assert.EqualValues(t, 1234, endpoints[0].RecordTTL) + } + } +} + +func TestProviderEntriesForEndpoint(t *testing.T) { + // set up fake REST client + client := &fakeClient{} + + // set up provider + p := newProvider() + p.domainRepo = domain.Repository{Client: client} + p.zoneMap.Add("example.com", "example.com") + + // get entries for endpoint with unknown zone + _, _, err := p.entriesForEndpoint(&endpoint.Endpoint{ + DNSName: "www.example.org", + }) + if assert.Error(t, err) { + assert.Equal(t, "could not find zoneName for www.example.org", err.Error()) + } + + // get entries for endpoint with known zone but client returns error + // we leave GET functions undefined so we know which error to expect + zoneName, _, err := p.entriesForEndpoint(&endpoint.Endpoint{ + DNSName: "www.example.com", + }) + if assert.Error(t, err) { + assert.Equal(t, "GET not defined", err.Error()) + } + assert.Equal(t, "example.com", zoneName) + + // to be able to return a valid set of DNS entries through the API, we define + // some first, then JSON encode them and have the fake API client's Get function + // return that + // in this set are some entries that do and others that don't match the given + // endpoint + dnsEntries := []domain.DNSEntry{ + { + Name: "www", + Type: "A", + Expire: 3600, + Content: "1.2.3.4", + }, + { + Name: "ftp", + Type: "A", + Expire: 86400, + Content: "3.4.5.6", + }, + { + Name: "www", + Type: "A", + Expire: 3600, + Content: "2.3.4.5", + }, + { + Name: "www", + Type: "CNAME", + Expire: 3600, + Content: "@", }, } + var v struct { + DNSEntries []domain.DNSEntry `json:"dnsEntries"` + } + v.DNSEntries = dnsEntries + returnData, err := json.Marshal(&v) + require.NoError(t, err) - // remove endpoint from zone's entries - result := p.removeEndpointFromEntries(ep, zone) - - assert.Equal(t, 2, len(result)) - assert.Equal(t, "mail", result[0].Name) - assert.Equal(t, transip.DNSEntryTypeA, result[0].Type) - assert.Equal(t, "@", result[1].Name) - assert.Equal(t, transip.DNSEntryTypeMX, result[1].Type) + // define GET function + client.getFunc = func(unused rest.Request, dest interface{}) error { + // unmarshal the prepared return data into the given dnsEntriesWrapper + return json.Unmarshal(returnData, &dest) + } + _, entries, err := p.entriesForEndpoint(&endpoint.Endpoint{ + DNSName: "www.example.com", + RecordType: "A", + }) + if assert.NoError(t, err) { + if assert.Equal(t, 2, len(entries)) { + // only first and third entry should be returned + assert.Equal(t, dnsEntries[0], entries[0]) + assert.Equal(t, dnsEntries[2], entries[1]) + } + } }