update TransIP's Go client to v6

This commit is contained in:
Reinier Schoof 2021-03-23 12:17:32 +01:00
parent 6ca57a58d3
commit b09d7abc55
4 changed files with 446 additions and 319 deletions

2
go.mod
View File

@ -50,7 +50,7 @@ require (
github.com/smartystreets/gunit v1.3.4 // indirect github.com/smartystreets/gunit v1.3.4 // indirect
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1
github.com/terra-farm/udnssdk v1.3.5 // indirect github.com/terra-farm/udnssdk v1.3.5 // indirect
github.com/transip/gotransip v5.8.2+incompatible github.com/transip/gotransip/v6 v6.6.0
github.com/ultradns/ultradns-sdk-go v0.0.0-20200616202852-e62052662f60 github.com/ultradns/ultradns-sdk-go v0.0.0-20200616202852-e62052662f60
github.com/vinyldns/go-vinyldns v0.0.0-20200211145900-fe8a3d82e556 github.com/vinyldns/go-vinyldns v0.0.0-20200211145900-fe8a3d82e556
github.com/vultr/govultr v0.4.2 github.com/vultr/govultr v0.4.2

4
go.sum
View File

@ -856,8 +856,8 @@ github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhV
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 h1:LnC5Kc/wtumK+WB441p7ynQJzVuNRJiqddSIE3IlSEQ= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 h1:LnC5Kc/wtumK+WB441p7ynQJzVuNRJiqddSIE3IlSEQ=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/transip/gotransip v5.8.2+incompatible h1:aNJhw/w/3QBqFcHAIPz1ytoK5FexeMzbUCGrrhWr3H0= github.com/transip/gotransip/v6 v6.6.0 h1:dAHCTZzX98H6QE2kA4R9acAXu5RPPTwMSUFtpKZF3Nk=
github.com/transip/gotransip v5.8.2+incompatible/go.mod h1:uacMoJVmrfOcscM4Bi5NVg708b7c6rz2oDTWqa7i2Ic= github.com/transip/gotransip/v6 v6.6.0/go.mod h1:pQZ36hWWRahCUXkFWlx9Hs711gLd8J4qdgLdRzmtY+g=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8=

View File

@ -23,8 +23,8 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/transip/gotransip" "github.com/transip/gotransip/v6"
transip "github.com/transip/gotransip/domain" "github.com/transip/gotransip/v6/domain"
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/plan"
@ -40,9 +40,11 @@ const (
// TransIPProvider is an implementation of Provider for TransIP. // TransIPProvider is an implementation of Provider for TransIP.
type TransIPProvider struct { type TransIPProvider struct {
provider.BaseProvider provider.BaseProvider
client gotransip.SOAPClient domainRepo domain.Repository
domainFilter endpoint.DomainFilter domainFilter endpoint.DomainFilter
dryRun bool dryRun bool
zoneMap provider.ZoneIDName
} }
// NewTransIPProvider initializes a new TransIP Provider. // NewTransIPProvider initializes a new TransIP Provider.
@ -64,7 +66,7 @@ func NewTransIPProvider(accountName, privateKeyFile string, domainFilter endpoin
} }
// create new TransIP API client // create new TransIP API client
c, err := gotransip.NewSOAPClient(gotransip.ClientConfig{ client, err := gotransip.NewClient(gotransip.ClientConfiguration{
AccountName: accountName, AccountName: accountName,
PrivateKeyPath: privateKeyFile, PrivateKeyPath: privateKeyFile,
Mode: apiMode, Mode: apiMode,
@ -73,233 +75,280 @@ func NewTransIPProvider(accountName, privateKeyFile string, domainFilter endpoin
return nil, fmt.Errorf("could not setup TransIP API client: %s", err.Error()) return nil, fmt.Errorf("could not setup TransIP API client: %s", err.Error())
} }
// return tipCloud struct // return TransIPProvider struct
return &TransIPProvider{ return &TransIPProvider{
client: c, domainRepo: domain.Repository{Client: client},
domainFilter: domainFilter, domainFilter: domainFilter,
dryRun: dryRun, dryRun: dryRun,
zoneMap: provider.ZoneIDName{},
}, nil }, nil
} }
// ApplyChanges applies a given set of changes in a given zone. // ApplyChanges applies a given set of changes in a given zone.
func (p *TransIPProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { func (p *TransIPProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
// build zonefinder with all our zones so we can use FindZone // fetch all zones we currently have
// and a mapping of zones and their domain name // this does NOT include any DNS entries, so we'll have to fetch these for
zones, err := p.fetchZones() // each zone that gets updated
zones, err := p.domainRepo.GetAll()
if err != nil { if err != nil {
return err return err
} }
zoneNameMapper := provider.ZoneIDName{} // refresh zone mapping
zonesByName := make(map[string]transip.Domain) zoneMap := provider.ZoneIDName{}
updatedZones := make(map[string]bool)
for _, zone := range zones { for _, zone := range zones {
// TransIP API doesn't expose a unique identifier for zones, other than than // TransIP API doesn't expose a unique identifier for zones, other than than
// the domain name itself // the domain name itself
zoneNameMapper.Add(zone.Name, zone.Name) zoneMap.Add(zone.Name, zone.Name)
zonesByName[zone.Name] = zone
} }
p.zoneMap = zoneMap
// first see if we need to delete anything // first remove obsolete DNS records
for _, ep := range changes.Delete { for _, ep := range changes.Delete {
log.WithFields(log.Fields{"record": ep.DNSName, "type": ep.RecordType}).Info("endpoint has to go") epLog := log.WithFields(log.Fields{
"record": ep.DNSName,
"type": ep.RecordType,
})
epLog.Info("endpoint has to go")
zone, err := p.zoneForZoneName(ep.DNSName, zoneNameMapper, zonesByName) zoneName, entries, err := p.entriesForEndpoint(ep)
if err != nil { if err != nil {
log.Errorf("could not find zone for %s: %s", ep.DNSName, err.Error()) epLog.WithError(err).Error("could not get DNS entries")
continue return err
} }
log.Debugf("removing records for %s", zone.Name) epLog = epLog.WithField("zone", zoneName)
// remove current records from DNS entry set if len(entries) == 0 {
entries := p.removeEndpointFromEntries(ep, zone) epLog.Info("no matching entries found")
// update zone in zone map
zone.DNSEntries = entries
zonesByName[zone.Name] = zone
// flag zone for updating
updatedZones[zone.Name] = true
}
for _, ep := range changes.Create {
log.WithFields(log.Fields{"record": ep.DNSName, "type": ep.RecordType}).Info("endpoint is missing")
zone, err := p.zoneForZoneName(ep.DNSName, zoneNameMapper, zonesByName)
if err != nil {
log.Errorf("could not find zone for %s: %s", ep.DNSName, err.Error())
continue
}
log.Debugf("creating records for %s", zone.Name)
// add new entries to set
zone.DNSEntries = p.addEndpointToEntries(ep, zone, zone.DNSEntries)
// update zone in zone map
zonesByName[zone.Name] = zone
// flag zone for updating
updatedZones[zone.Name] = true
log.WithFields(log.Fields{"zone": zone.Name}).Debug("flagging for update")
}
for _, ep := range changes.UpdateNew {
log.WithFields(log.Fields{"record": ep.DNSName, "type": ep.RecordType}).Debug("needs updating")
zone, err := p.zoneForZoneName(ep.DNSName, zoneNameMapper, zonesByName)
if err != nil {
log.WithFields(log.Fields{"record": ep.DNSName}).Warn(err.Error())
continue
}
// updating the records is basically finding all matching records according
// to the name and the type, removing them from the set and add the new
// records
log.WithFields(log.Fields{
"zone": zone.Name,
"dnsname": ep.DNSName,
"recordtype": ep.RecordType,
}).Debug("removing matching entries")
// remove current records from DNS entry set
entries := p.removeEndpointFromEntries(ep, zone)
// add new entries to set
entries = p.addEndpointToEntries(ep, zone, entries)
// check to see if actually anything changed in the DNSEntry set
if p.dnsEntriesAreEqual(entries, zone.DNSEntries) {
log.WithFields(log.Fields{"zone": zone.Name}).Debug("not updating identical entries")
continue
}
// update zone in zone map
zone.DNSEntries = entries
zonesByName[zone.Name] = zone
// flag zone for updating
updatedZones[zone.Name] = true
log.WithFields(log.Fields{"zone": zone.Name}).Debug("flagging for update")
}
// go over all updated zones and set new DNSEntry set
for uz := range updatedZones {
zone, ok := zonesByName[uz]
if !ok {
log.WithFields(log.Fields{"zone": uz}).Debug("updated zone no longer found")
continue continue
} }
if p.dryRun { if p.dryRun {
log.WithFields(log.Fields{"zone": zone.Name}).Info("not updating in dry-run mode") epLog.Info("not removing DNS entries in dry-run mode")
continue continue
} }
log.WithFields(log.Fields{"zone": zone.Name}).Info("updating DNS entries") for _, entry := range entries {
if err := transip.SetDNSEntries(p.client, zone.Name, zone.DNSEntries); err != nil { log.WithFields(log.Fields{
log.WithFields(log.Fields{"zone": zone.Name, "error": err.Error()}).Warn("failed to update") "domain": zoneName,
"name": entry.Name,
"type": entry.Type,
"content": entry.Content,
"ttl": entry.Expire,
}).Info("removing DNS entry")
err = p.domainRepo.RemoveDNSEntry(zoneName, entry)
if err != nil {
epLog.WithError(err).Error("could not remove DNS entry")
return err
}
}
}
// then create new DNS records
for _, ep := range changes.Create {
epLog := log.WithFields(log.Fields{
"record": ep.DNSName,
"type": ep.RecordType,
})
epLog.Info("endpoint should be created")
zoneName, err := p.zoneNameForDNSName(ep.DNSName)
if err != nil {
epLog.WithError(err).Warn("could not find zone for endpoint")
continue
}
epLog = epLog.WithField("zone", zoneName)
if p.dryRun {
epLog.Info("not adding DNS entries in dry-run mode")
continue
}
for _, entry := range dnsEntriesForEndpoint(ep, zoneName) {
log.WithFields(log.Fields{
"domain": zoneName,
"name": entry.Name,
"type": entry.Type,
"content": entry.Content,
"ttl": entry.Expire,
}).Info("creating DNS entry")
err = p.domainRepo.AddDNSEntry(zoneName, entry)
if err != nil {
epLog.WithError(err).Error("could not add DNS entry")
return err
}
}
}
// then update existing DNS records
for _, ep := range changes.UpdateNew {
epLog := log.WithFields(log.Fields{
"record": ep.DNSName,
"type": ep.RecordType,
})
epLog.Debug("endpoint needs updating")
zoneName, entries, err := p.entriesForEndpoint(ep)
if err != nil {
epLog.WithError(err).Error("could not get DNS entries")
return err
}
epLog = epLog.WithField("zone", zoneName)
if len(entries) == 0 {
epLog.Info("no matching entries found")
continue
}
newEntries := dnsEntriesForEndpoint(ep, zoneName)
// check to see if actually anything changed in the DNSEntry set
if dnsEntriesAreEqual(newEntries, entries) {
epLog.Debug("not updating identical DNS entries")
continue
}
if p.dryRun {
epLog.Info("not updating DNS entries in dry-run mode")
continue
}
// TransIP API client does have an UpdateDNSEntry call but that does only
// allow you to update the content of a DNSEntry, not the TTL
// to work around this, remove the old entry first and add the new entry
for _, entry := range entries {
log.WithFields(log.Fields{
"domain": zoneName,
"name": entry.Name,
"type": entry.Type,
"content": entry.Content,
"ttl": entry.Expire,
}).Info("removing DNS entry")
err = p.domainRepo.RemoveDNSEntry(zoneName, entry)
if err != nil {
epLog.WithError(err).Error("could not remove DNS entry")
return err
}
}
for _, entry := range newEntries {
log.WithFields(log.Fields{
"domain": zoneName,
"name": entry.Name,
"type": entry.Type,
"content": entry.Content,
"ttl": entry.Expire,
}).Info("adding DNS entry")
err = p.domainRepo.AddDNSEntry(zoneName, entry)
if err != nil {
epLog.WithError(err).Error("could not add DNS entry")
return err
}
} }
} }
return nil return nil
} }
// fetchZones returns a list of all domains within the account // Records returns the list of records in all zones
func (p *TransIPProvider) fetchZones() ([]transip.Domain, error) {
domainNames, err := transip.GetDomainNames(p.client)
if err != nil {
return nil, err
}
domains, err := transip.BatchGetInfo(p.client, domainNames)
if err != nil {
return nil, err
}
var zones []transip.Domain
for _, d := range domains {
if !p.domainFilter.Match(d.Name) {
continue
}
zones = append(zones, d)
}
return zones, nil
}
// Zones returns the list of hosted zones.
func (p *TransIPProvider) Zones() ([]transip.Domain, error) {
zones, err := p.fetchZones()
if err != nil {
return nil, err
}
return zones, nil
}
// Records returns the list of records in a given zone.
func (p *TransIPProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (p *TransIPProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones() zones, err := p.domainRepo.GetAll()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var endpoints []*endpoint.Endpoint var endpoints []*endpoint.Endpoint
var name string
// go over all zones and their DNS entries and create endpoints for them // go over all zones and their DNS entries and create endpoints for them
for _, zone := range zones { for _, zone := range zones {
for _, r := range zone.DNSEntries { entries, err := p.domainRepo.GetDNSEntries(zone.Name)
if !provider.SupportedRecordType(string(r.Type)) { if err != nil {
return nil, err
}
for _, r := range entries {
if !provider.SupportedRecordType(r.Type) {
continue continue
} }
name = p.endpointNameForRecord(r, zone) name := endpointNameForRecord(r, zone.Name)
endpoints = append(endpoints, endpoint.NewEndpointWithTTL(name, string(r.Type), endpoint.TTL(r.TTL), r.Content)) endpoints = append(endpoints, endpoint.NewEndpointWithTTL(name, r.Type, endpoint.TTL(r.Expire), r.Content))
} }
} }
return endpoints, nil return endpoints, nil
} }
func (p *TransIPProvider) entriesForEndpoint(ep *endpoint.Endpoint) (string, []domain.DNSEntry, error) {
zoneName, err := p.zoneNameForDNSName(ep.DNSName)
if err != nil {
return "", nil, err
}
epName := recordNameForEndpoint(ep, zoneName)
dnsEntries, err := p.domainRepo.GetDNSEntries(zoneName)
if err != nil {
return zoneName, nil, err
}
matches := []domain.DNSEntry{}
for _, entry := range dnsEntries {
if ep.RecordType != entry.Type {
continue
}
if entry.Name == epName {
matches = append(matches, entry)
}
}
return zoneName, matches, nil
}
// endpointNameForRecord returns "www.example.org" for DNSEntry with Name "www" and // endpointNameForRecord returns "www.example.org" for DNSEntry with Name "www" and
// Domain with Name "example.org" // Domain with Name "example.org"
func (p *TransIPProvider) endpointNameForRecord(r transip.DNSEntry, d transip.Domain) string { func endpointNameForRecord(r domain.DNSEntry, zoneName string) string {
// root name is identified by "@" and should be translated to domain name for // root name is identified by "@" and should be translated to domain name for
// the endpoint entry. // the endpoint entry.
if r.Name == "@" { if r.Name == "@" {
return d.Name return zoneName
} }
return fmt.Sprintf("%s.%s", r.Name, d.Name) return fmt.Sprintf("%s.%s", r.Name, zoneName)
} }
// recordNameForEndpoint returns "www" for Endpoint with DNSName "www.example.org" // recordNameForEndpoint returns "www" for Endpoint with DNSName "www.example.org"
// and Domain with Name "example.org" // and Domain with Name "example.org"
func (p *TransIPProvider) recordNameForEndpoint(ep *endpoint.Endpoint, d transip.Domain) string { func recordNameForEndpoint(ep *endpoint.Endpoint, zoneName string) string {
// root name is identified by "@" and should be translated to domain name for // root name is identified by "@" and should be translated to domain name for
// the endpoint entry. // the endpoint entry.
if ep.DNSName == d.Name { if ep.DNSName == zoneName {
return "@" return "@"
} }
return strings.TrimSuffix(ep.DNSName, "."+d.Name) return strings.TrimSuffix(ep.DNSName, "."+zoneName)
} }
// getMinimalValidTTL returns max between given Endpoint's RecordTTL and // getMinimalValidTTL returns max between given Endpoint's RecordTTL and
// transipMinimalValidTTL // transipMinimalValidTTL
func (p *TransIPProvider) getMinimalValidTTL(ep *endpoint.Endpoint) int64 { func getMinimalValidTTL(ep *endpoint.Endpoint) int {
// TTL cannot be lower than transipMinimalValidTTL // TTL cannot be lower than transipMinimalValidTTL
if ep.RecordTTL < transipMinimalValidTTL { if ep.RecordTTL < transipMinimalValidTTL {
return transipMinimalValidTTL return transipMinimalValidTTL
} }
return int64(ep.RecordTTL) return int(ep.RecordTTL)
} }
// dnsEntriesAreEqual compares the entries in 2 sets and returns true if the // dnsEntriesAreEqual compares the entries in 2 sets and returns true if the
// content of the entries is equal // content of the entries is equal
func (p *TransIPProvider) dnsEntriesAreEqual(a, b transip.DNSEntries) bool { func dnsEntriesAreEqual(a, b []domain.DNSEntry) bool {
if len(a) != len(b) { if len(a) != len(b) {
return false return false
} }
@ -315,7 +364,7 @@ func (p *TransIPProvider) dnsEntriesAreEqual(a, b transip.DNSEntries) bool {
continue continue
} }
if aa.TTL != bb.TTL { if aa.Expire != bb.Expire {
continue continue
} }
@ -330,45 +379,22 @@ func (p *TransIPProvider) dnsEntriesAreEqual(a, b transip.DNSEntries) bool {
return (len(a) == match) return (len(a) == match)
} }
// removeEndpointFromEntries removes DNS entries from zone's set that match the // dnsEntriesForEndpoint creates DNS entries for given endpoint and returns
// type and name from given endpoint and returns the resulting DNS entry set // resulting DNS entry set
func (p *TransIPProvider) removeEndpointFromEntries(ep *endpoint.Endpoint, zone transip.Domain) transip.DNSEntries { func dnsEntriesForEndpoint(ep *endpoint.Endpoint, zoneName string) []domain.DNSEntry {
// create new entry set ttl := getMinimalValidTTL(ep)
entries := transip.DNSEntries{}
// go over each DNS entry to see if it is a match entries := []domain.DNSEntry{}
for _, e := range zone.DNSEntries { for _, target := range ep.Targets {
// if we have match, don't copy it to the new entry set // external hostnames require a trailing dot in TransIP API
if p.endpointNameForRecord(e, zone) == ep.DNSName && string(e.Type) == ep.RecordType { if ep.RecordType == "CNAME" {
log.WithFields(log.Fields{ target = provider.EnsureTrailingDot(target)
"name": e.Name,
"content": e.Content,
"type": e.Type,
}).Debug("found match")
continue
} }
entries = append(entries, e) entries = append(entries, domain.DNSEntry{
} Name: recordNameForEndpoint(ep, zoneName),
Expire: ttl,
return entries Type: ep.RecordType,
}
// addEndpointToEntries creates DNS entries for given endpoint and returns
// resulting DNS entry set
func (p *TransIPProvider) addEndpointToEntries(ep *endpoint.Endpoint, zone transip.Domain, entries transip.DNSEntries) transip.DNSEntries {
ttl := p.getMinimalValidTTL(ep)
for _, target := range ep.Targets {
log.WithFields(log.Fields{
"zone": zone.Name,
"dnsname": ep.DNSName,
"recordtype": ep.RecordType,
"ttl": ttl,
"target": target,
}).Debugf("adding new record")
entries = append(entries, transip.DNSEntry{
Name: p.recordNameForEndpoint(ep, zone),
TTL: ttl,
Type: transip.DNSEntryType(ep.RecordType),
Content: target, Content: target,
}) })
} }
@ -378,16 +404,11 @@ func (p *TransIPProvider) addEndpointToEntries(ep *endpoint.Endpoint, zone trans
// zoneForZoneName returns the zone mapped to given name or error if zone could // zoneForZoneName returns the zone mapped to given name or error if zone could
// not be found // not be found
func (p *TransIPProvider) zoneForZoneName(name string, m provider.ZoneIDName, z map[string]transip.Domain) (transip.Domain, error) { func (p *TransIPProvider) zoneNameForDNSName(name string) (string, error) {
_, zoneName := m.FindZone(name) _, zoneName := p.zoneMap.FindZone(name)
if zoneName == "" { if zoneName == "" {
return transip.Domain{}, fmt.Errorf("could not find zoneName for %s", name) return "", fmt.Errorf("could not find zoneName for %s", name)
} }
zone, ok := z[zoneName] return zoneName, nil
if !ok {
return zone, fmt.Errorf("could not find zone for %s", zoneName)
}
return zone, nil
} }

View File

@ -17,116 +17,123 @@ limitations under the License.
package transip package transip
import ( import (
"context"
"encoding/json"
"errors"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
transip "github.com/transip/gotransip/domain" "github.com/stretchr/testify/require"
"github.com/transip/gotransip/v6/domain"
"github.com/transip/gotransip/v6/rest"
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/provider"
) )
func newProvider() *TransIPProvider {
return &TransIPProvider{
zoneMap: provider.ZoneIDName{},
}
}
func TestTransIPDnsEntriesAreEqual(t *testing.T) { func TestTransIPDnsEntriesAreEqual(t *testing.T) {
p := TransIPProvider{}
// test with equal set // test with equal set
a := transip.DNSEntries{ a := []domain.DNSEntry{
transip.DNSEntry{ {
Name: "www.example.org", Name: "www.example.org",
Type: transip.DNSEntryTypeCNAME, Type: "CNAME",
TTL: 3600, Expire: 3600,
Content: "www.example.com", Content: "www.example.com",
}, },
transip.DNSEntry{ {
Name: "www.example.com", Name: "www.example.com",
Type: transip.DNSEntryTypeA, Type: "A",
TTL: 3600, Expire: 3600,
Content: "192.168.0.1", Content: "192.168.0.1",
}, },
} }
b := transip.DNSEntries{ b := []domain.DNSEntry{
transip.DNSEntry{ {
Name: "www.example.com", Name: "www.example.com",
Type: transip.DNSEntryTypeA, Type: "A",
TTL: 3600, Expire: 3600,
Content: "192.168.0.1", Content: "192.168.0.1",
}, },
transip.DNSEntry{ {
Name: "www.example.org", Name: "www.example.org",
Type: transip.DNSEntryTypeCNAME, Type: "CNAME",
TTL: 3600, Expire: 3600,
Content: "www.example.com", Content: "www.example.com",
}, },
} }
assert.Equal(t, true, p.dnsEntriesAreEqual(a, b)) assert.Equal(t, true, dnsEntriesAreEqual(a, b))
// change type on one of b's records // change type on one of b's records
b[1].Type = transip.DNSEntryTypeNS b[1].Type = "NS"
assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) assert.Equal(t, false, dnsEntriesAreEqual(a, b))
b[1].Type = transip.DNSEntryTypeCNAME b[1].Type = "CNAME"
// change ttl on one of b's records // change ttl on one of b's records
b[1].TTL = 1800 b[1].Expire = 1800
assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) assert.Equal(t, false, dnsEntriesAreEqual(a, b))
b[1].TTL = 3600 b[1].Expire = 3600
// change name on one of b's records // change name on one of b's records
b[1].Name = "example.org" b[1].Name = "example.org"
assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) assert.Equal(t, false, dnsEntriesAreEqual(a, b))
// remove last entry of b // remove last entry of b
b = b[:1] b = b[:1]
assert.Equal(t, false, p.dnsEntriesAreEqual(a, b)) assert.Equal(t, false, dnsEntriesAreEqual(a, b))
} }
func TestTransIPGetMinimalValidTTL(t *testing.T) { func TestTransIPGetMinimalValidTTL(t *testing.T) {
p := TransIPProvider{}
// test with 'unconfigured' TTL // test with 'unconfigured' TTL
ep := &endpoint.Endpoint{} ep := &endpoint.Endpoint{}
assert.Equal(t, int64(transipMinimalValidTTL), p.getMinimalValidTTL(ep)) assert.EqualValues(t, transipMinimalValidTTL, getMinimalValidTTL(ep))
// test with lower than minimal ttl // test with lower than minimal ttl
ep.RecordTTL = (transipMinimalValidTTL - 1) ep.RecordTTL = (transipMinimalValidTTL - 1)
assert.Equal(t, int64(transipMinimalValidTTL), p.getMinimalValidTTL(ep)) assert.EqualValues(t, transipMinimalValidTTL, getMinimalValidTTL(ep))
// test with higher than minimal ttl // test with higher than minimal ttl
ep.RecordTTL = (transipMinimalValidTTL + 1) ep.RecordTTL = (transipMinimalValidTTL + 1)
assert.Equal(t, int64(transipMinimalValidTTL+1), p.getMinimalValidTTL(ep)) assert.EqualValues(t, transipMinimalValidTTL+1, getMinimalValidTTL(ep))
} }
func TestTransIPRecordNameForEndpoint(t *testing.T) { func TestTransIPRecordNameForEndpoint(t *testing.T) {
p := TransIPProvider{}
ep := &endpoint.Endpoint{ ep := &endpoint.Endpoint{
DNSName: "example.org", DNSName: "example.org",
} }
d := transip.Domain{ d := domain.Domain{
Name: "example.org", Name: "example.org",
} }
assert.Equal(t, "@", p.recordNameForEndpoint(ep, d)) assert.Equal(t, "@", recordNameForEndpoint(ep, d.Name))
ep.DNSName = "www.example.org" ep.DNSName = "www.example.org"
assert.Equal(t, "www", p.recordNameForEndpoint(ep, d)) assert.Equal(t, "www", recordNameForEndpoint(ep, d.Name))
} }
func TestTransIPEndpointNameForRecord(t *testing.T) { func TestTransIPEndpointNameForRecord(t *testing.T) {
p := TransIPProvider{} r := domain.DNSEntry{
r := transip.DNSEntry{
Name: "@", Name: "@",
} }
d := transip.Domain{ d := domain.Domain{
Name: "example.org", Name: "example.org",
} }
assert.Equal(t, d.Name, p.endpointNameForRecord(r, d)) assert.Equal(t, d.Name, endpointNameForRecord(r, d.Name))
r.Name = "www" r.Name = "www"
assert.Equal(t, "www.example.org", p.endpointNameForRecord(r, d)) assert.Equal(t, "www.example.org", endpointNameForRecord(r, d.Name))
} }
func TestTransIPAddEndpointToEntries(t *testing.T) { func TestTransIPAddEndpointToEntries(t *testing.T) {
p := TransIPProvider{}
// prepare endpoint // prepare endpoint
ep := &endpoint.Endpoint{ ep := &endpoint.Endpoint{
DNSName: "www.example.org", DNSName: "www.example.org",
@ -139,94 +146,193 @@ func TestTransIPAddEndpointToEntries(t *testing.T) {
} }
// prepare zone with DNS entry set // prepare zone with DNS entry set
zone := transip.Domain{ zone := domain.Domain{
Name: "example.org", Name: "example.org",
// 2 matching A records
DNSEntries: transip.DNSEntries{
// 1 non-matching A record
transip.DNSEntry{
Name: "mail",
Type: transip.DNSEntryTypeA,
Content: "192.168.0.1",
TTL: 3600,
},
// 1 non-matching MX record
transip.DNSEntry{
Name: "@",
Type: transip.DNSEntryTypeMX,
Content: "mail.example.org",
TTL: 3600,
},
},
} }
// add endpoint to zone's entries // add endpoint to zone's entries
result := p.addEndpointToEntries(ep, zone, zone.DNSEntries) result := dnsEntriesForEndpoint(ep, zone.Name)
assert.Equal(t, 4, len(result)) if assert.Equal(t, 2, len(result)) {
assert.Equal(t, "mail", result[0].Name) assert.Equal(t, "www", result[0].Name)
assert.Equal(t, transip.DNSEntryTypeA, result[0].Type) assert.Equal(t, "A", result[0].Type)
assert.Equal(t, "@", result[1].Name) assert.Equal(t, "192.168.0.1", result[0].Content)
assert.Equal(t, transip.DNSEntryTypeMX, result[1].Type) assert.EqualValues(t, 1800, result[0].Expire)
assert.Equal(t, "www", result[2].Name) assert.Equal(t, "www", result[1].Name)
assert.Equal(t, transip.DNSEntryTypeA, result[2].Type) assert.Equal(t, "A", result[1].Type)
assert.Equal(t, "192.168.0.1", result[2].Content) assert.Equal(t, "192.168.0.2", result[1].Content)
assert.Equal(t, int64(1800), result[2].TTL) assert.EqualValues(t, 1800, result[1].Expire)
assert.Equal(t, "www", result[3].Name)
assert.Equal(t, transip.DNSEntryTypeA, result[3].Type)
assert.Equal(t, "192.168.0.2", result[3].Content)
assert.Equal(t, int64(1800), result[3].TTL)
}
func TestTransIPRemoveEndpointFromEntries(t *testing.T) {
p := TransIPProvider{}
// prepare endpoint
ep := &endpoint.Endpoint{
DNSName: "www.example.org",
RecordType: "A",
} }
// prepare zone with DNS entry set // try again with CNAME
zone := transip.Domain{ ep.RecordType = "CNAME"
Name: "example.org", ep.Targets = []string{"foo.bar"}
// 2 matching A records result = dnsEntriesForEndpoint(ep, zone.Name)
DNSEntries: transip.DNSEntries{ if assert.Equal(t, 1, len(result)) {
transip.DNSEntry{ assert.Equal(t, "CNAME", result[0].Type)
Name: "www", assert.Equal(t, "foo.bar.", result[0].Content)
Type: transip.DNSEntryTypeA, }
Content: "192.168.0.1", }
TTL: 3600,
}, func TestZoneNameForDNSName(t *testing.T) {
transip.DNSEntry{ p := newProvider()
Name: "www", p.zoneMap.Add("example.com", "example.com")
Type: transip.DNSEntryTypeA,
Content: "192.168.0.2", zoneName, err := p.zoneNameForDNSName("www.example.com")
TTL: 3600, if assert.NoError(t, err) {
}, assert.Equal(t, "example.com", zoneName)
// 1 non-matching A record }
transip.DNSEntry{
Name: "mail", _, err = p.zoneNameForDNSName("www.example.org")
Type: transip.DNSEntryTypeA, if assert.Error(t, err) {
Content: "192.168.0.1", assert.Equal(t, "could not find zoneName for www.example.org", err.Error())
TTL: 3600, }
}, }
// 1 non-matching MX record
transip.DNSEntry{ // fakeClient mocks the REST API client
Name: "@", type fakeClient struct {
Type: transip.DNSEntryTypeMX, getFunc func(rest.Request, interface{}) error
Content: "mail.example.org", }
TTL: 3600,
}, func (f *fakeClient) Get(request rest.Request, dest interface{}) error {
if f.getFunc == nil {
return errors.New("GET not defined")
}
return f.getFunc(request, dest)
}
func (f fakeClient) Put(request rest.Request) error {
return errors.New("PUT not implemented")
}
func (f fakeClient) Post(request rest.Request) error {
return errors.New("POST not implemented")
}
func (f fakeClient) Delete(request rest.Request) error {
return errors.New("DELETE not implemented")
}
func (f fakeClient) Patch(request rest.Request) error {
return errors.New("PATCH not implemented")
}
func TestProviderRecords(t *testing.T) {
// set up the fake REST client
client := &fakeClient{}
client.getFunc = func(req rest.Request, dest interface{}) error {
var data []byte
switch {
case req.Endpoint == "/domains":
// return list of some domain names
// names only, other fields are not used
data = []byte(`{"domains":[{"name":"example.org"}, {"name":"example.com"}]}`)
case strings.HasSuffix(req.Endpoint, "/dns"):
// return list of DNS entries
// also some unsupported types
data = []byte(`{"dnsEntries":[{"name":"www", "expire":1234, "type":"CNAME", "content":"@"},{"type":"MX"},{"type":"AAAA"}]}`)
}
// unmarshal the prepared return data into the given destination type
return json.Unmarshal(data, &dest)
}
// set up provider
p := newProvider()
p.domainRepo = domain.Repository{Client: client}
endpoints, err := p.Records(context.TODO())
if assert.NoError(t, err) {
if assert.Equal(t, 2, len(endpoints)) {
assert.Equal(t, "www.example.org", endpoints[0].DNSName)
assert.EqualValues(t, "@", endpoints[0].Targets[0])
assert.Equal(t, "CNAME", endpoints[0].RecordType)
assert.Equal(t, 0, len(endpoints[0].Labels))
assert.EqualValues(t, 1234, endpoints[0].RecordTTL)
}
}
}
func TestProviderEntriesForEndpoint(t *testing.T) {
// set up fake REST client
client := &fakeClient{}
// set up provider
p := newProvider()
p.domainRepo = domain.Repository{Client: client}
p.zoneMap.Add("example.com", "example.com")
// get entries for endpoint with unknown zone
_, _, err := p.entriesForEndpoint(&endpoint.Endpoint{
DNSName: "www.example.org",
})
if assert.Error(t, err) {
assert.Equal(t, "could not find zoneName for www.example.org", err.Error())
}
// get entries for endpoint with known zone but client returns error
// we leave GET functions undefined so we know which error to expect
zoneName, _, err := p.entriesForEndpoint(&endpoint.Endpoint{
DNSName: "www.example.com",
})
if assert.Error(t, err) {
assert.Equal(t, "GET not defined", err.Error())
}
assert.Equal(t, "example.com", zoneName)
// to be able to return a valid set of DNS entries through the API, we define
// some first, then JSON encode them and have the fake API client's Get function
// return that
// in this set are some entries that do and others that don't match the given
// endpoint
dnsEntries := []domain.DNSEntry{
{
Name: "www",
Type: "A",
Expire: 3600,
Content: "1.2.3.4",
},
{
Name: "ftp",
Type: "A",
Expire: 86400,
Content: "3.4.5.6",
},
{
Name: "www",
Type: "A",
Expire: 3600,
Content: "2.3.4.5",
},
{
Name: "www",
Type: "CNAME",
Expire: 3600,
Content: "@",
}, },
} }
var v struct {
DNSEntries []domain.DNSEntry `json:"dnsEntries"`
}
v.DNSEntries = dnsEntries
returnData, err := json.Marshal(&v)
require.NoError(t, err)
// remove endpoint from zone's entries // define GET function
result := p.removeEndpointFromEntries(ep, zone) client.getFunc = func(unused rest.Request, dest interface{}) error {
// unmarshal the prepared return data into the given dnsEntriesWrapper
assert.Equal(t, 2, len(result)) return json.Unmarshal(returnData, &dest)
assert.Equal(t, "mail", result[0].Name) }
assert.Equal(t, transip.DNSEntryTypeA, result[0].Type) _, entries, err := p.entriesForEndpoint(&endpoint.Endpoint{
assert.Equal(t, "@", result[1].Name) DNSName: "www.example.com",
assert.Equal(t, transip.DNSEntryTypeMX, result[1].Type) RecordType: "A",
})
if assert.NoError(t, err) {
if assert.Equal(t, 2, len(entries)) {
// only first and third entry should be returned
assert.Equal(t, dnsEntries[0], entries[0])
assert.Equal(t, dnsEntries[2], entries[1])
}
}
} }