diff --git a/provider/rfc2136/rfc2136.go b/provider/rfc2136/rfc2136.go index 8353bc2d9..c0671c994 100644 --- a/provider/rfc2136/rfc2136.go +++ b/provider/rfc2136/rfc2136.go @@ -256,7 +256,12 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes for c, chunk := range chunkBy(changes.Create, r.batchChangeSize) { log.Debugf("Processing batch %d of create changes", c) - m := new(dns.Msg) + m := make(map[string]*dns.Msg) + m["."] = new(dns.Msg) // Add the root zone + for _, z := range r.zoneNames { + z = dns.Fqdn(z) + m[z] = new(dns.Msg) + } for _, ep := range chunk { if !r.domainFilter.Match(ep.DNSName) { log.Debugf("Skipping record %s because it was filtered out by the specified --domain-filter", ep.DNSName) @@ -265,18 +270,19 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes zone := findMsgZone(ep, r.zoneNames) r.krb5Realm = strings.ToUpper(zone) - m.SetUpdate(zone) + m[zone].SetUpdate(zone) - r.AddRecord(m, ep) + r.AddRecord(m[zone], ep) } // only send if there are records available - if len(m.Ns) > 0 { - err := r.actions.SendMessage(m) - if err != nil { - log.Errorf("RFC2136 update failed: %v", err) - errors = append(errors, err) - continue + for _, z := range m { + if len(z.Ns) > 0 { + if err := r.actions.SendMessage(z); err != nil { + log.Errorf("RFC2136 create record failed: %v", err) + errors = append(errors, err) + continue + } } } } @@ -284,7 +290,12 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes for c, chunk := range chunkBy(changes.UpdateNew, r.batchChangeSize) { log.Debugf("Processing batch %d of update changes", c) - m := new(dns.Msg) + m := make(map[string]*dns.Msg) + m["."] = new(dns.Msg) // Add the root zone + for _, z := range r.zoneNames { + z = dns.Fqdn(z) + m[z] = new(dns.Msg) + } for i, ep := range chunk { if !r.domainFilter.Match(ep.DNSName) { @@ -294,18 +305,19 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes zone := findMsgZone(ep, r.zoneNames) r.krb5Realm = strings.ToUpper(zone) - m.SetUpdate(zone) + m[zone].SetUpdate(zone) - r.UpdateRecord(m, changes.UpdateOld[i], ep) + r.UpdateRecord(m[zone], changes.UpdateOld[i], ep) } // only send if there are records available - if len(m.Ns) > 0 { - err := r.actions.SendMessage(m) - if err != nil { - log.Errorf("RFC2136 update failed: %v", err) - errors = append(errors, err) - continue + for _, z := range m { + if len(z.Ns) > 0 { + if err := r.actions.SendMessage(z); err != nil { + log.Errorf("RFC2136 update record failed: %v", err) + errors = append(errors, err) + continue + } } } } @@ -313,8 +325,12 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes for c, chunk := range chunkBy(changes.Delete, r.batchChangeSize) { log.Debugf("Processing batch %d of delete changes", c) - m := new(dns.Msg) - + m := make(map[string]*dns.Msg) + m["."] = new(dns.Msg) // Add the root zone + for _, z := range r.zoneNames { + z = dns.Fqdn(z) + m[z] = new(dns.Msg) + } for _, ep := range chunk { if !r.domainFilter.Match(ep.DNSName) { log.Debugf("Skipping record %s because it was filtered out by the specified --domain-filter", ep.DNSName) @@ -323,18 +339,19 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes zone := findMsgZone(ep, r.zoneNames) r.krb5Realm = strings.ToUpper(zone) - m.SetUpdate(zone) + m[zone].SetUpdate(zone) - r.RemoveRecord(m, ep) + r.RemoveRecord(m[zone], ep) } // only send if there are records available - if len(m.Ns) > 0 { - err := r.actions.SendMessage(m) - if err != nil { - log.Errorf("RFC2136 update failed: %v", err) - errors = append(errors, err) - continue + for _, z := range m { + if len(z.Ns) > 0 { + if err := r.actions.SendMessage(z); err != nil { + log.Errorf("RFC2136 delete record failed: %v", err) + errors = append(errors, err) + continue + } } } } diff --git a/provider/rfc2136/rfc2136_test.go b/provider/rfc2136/rfc2136_test.go index 26f3d7186..a143ef7b0 100644 --- a/provider/rfc2136/rfc2136_test.go +++ b/provider/rfc2136/rfc2136_test.go @@ -19,6 +19,8 @@ package rfc2136 import ( "context" "fmt" + "regexp" + "sort" "strings" "testing" "time" @@ -46,8 +48,23 @@ func newStub() *rfc2136Stub { } } +func getSortedChanges(msgs []*dns.Msg) []string { + r := []string{} + for _, d := range msgs { + // only care about section after the ZONE SECTION: as the id: needs stripped out in order to sort and grantee the order when sorting + r = append(r, strings.Split(d.String(), "ZONE SECTION:")[1]) + } + sort.Strings(r) + return r +} + func (r *rfc2136Stub) SendMessage(msg *dns.Msg) error { - log.Info(msg.String()) + zone := extractZoneFromMessage(msg.String()) + // Make sure the zone starts with . to make sure HasSuffix does not match forbar.com for zone bar.com + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + log.Infof("zone=%s", zone) lines := extractUpdateSectionFromMessage(msg) for _, line := range lines { // break at first empty line @@ -57,6 +74,12 @@ func (r *rfc2136Stub) SendMessage(msg *dns.Msg) error { line = strings.Replace(line, "\t", " ", -1) log.Info(line) + record := strings.Split(line, " ")[0] + if !strings.HasSuffix(record, zone) { + err := fmt.Errorf("Message contains updates outside of it's zone. zone=%v record=%v", zone, record) + log.Error(err) + return err + } if strings.Contains(line, " NONE ") { r.updateMsgs = append(r.updateMsgs, msg) @@ -98,12 +121,28 @@ func createRfc2136StubProvider(stub *rfc2136Stub) (provider.Provider, error) { return NewRfc2136Provider("", 0, nil, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{}, false, 300*time.Second, false, "", "", "", 50, stub) } +func createRfc2136StubProviderWithZones(stub *rfc2136Stub) (provider.Provider, error) { + zones := []string{"foo.com", "foobar.com"} + return NewRfc2136Provider("", 0, zones, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{}, false, 300*time.Second, false, "", "", "", 50, stub) +} + +func createRfc2136StubProviderWithZonesFilters(stub *rfc2136Stub) (provider.Provider, error) { + zones := []string{"foo.com", "foobar.com"} + return NewRfc2136Provider("", 0, zones, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{Filters: zones}, false, 300*time.Second, false, "", "", "", 50, stub) +} + func extractUpdateSectionFromMessage(msg fmt.Stringer) []string { const searchPattern = "UPDATE SECTION:" updateSectionOffset := strings.Index(msg.String(), searchPattern) return strings.Split(strings.TrimSpace(msg.String()[updateSectionOffset+len(searchPattern):]), "\n") } +func extractZoneFromMessage(msg string) string { + re := regexp.MustCompile(`ZONE SECTION:\n;(?P[\.,\-,\w,\d]+)\t`) + matches := re.FindStringSubmatch(msg) + return matches[re.SubexpIndex("ZONE")] +} + // TestRfc2136GetRecordsMultipleTargets simulates a single record with multiple targets. func TestRfc2136GetRecordsMultipleTargets(t *testing.T) { stub := newStub() @@ -154,6 +193,32 @@ func TestRfc2136GetRecords(t *testing.T) { assert.True(t, contains(recs, "v2.foo.com")) } +// Make sure the test version of SendMessage raises an error +// if a zone update ever contains records outside of it's zone +// as the TestRfc2136ApplyChanges tests all assume this +func TestRfc2136SendMessage(t *testing.T) { + stub := newStub() + + m := new(dns.Msg) + m.SetUpdate("foo.com.") + rr, err := dns.NewRR(fmt.Sprintf("%s %d %s %s", "v1.foo.com.", 0, "A", "1.2.3.4")) + m.Insert([]dns.RR{rr}) + + err = stub.SendMessage(m) + assert.NoError(t, err) + + rr, err = dns.NewRR(fmt.Sprintf("%s %d %s %s", "v1.bar.com.", 0, "A", "1.2.3.4")) + m.Insert([]dns.RR{rr}) + + err = stub.SendMessage(m) + assert.Error(t, err) + + m.SetUpdate(".") + err = stub.SendMessage(m) + assert.NoError(t, err) +} + +// These tests are use the . root zone with no filters func TestRfc2136ApplyChanges(t *testing.T) { stub := newStub() provider, err := createRfc2136StubProvider(stub) @@ -210,6 +275,145 @@ func TestRfc2136ApplyChanges(t *testing.T) { assert.True(t, strings.Contains(stub.updateMsgs[1].String(), "v2.foobar.com")) } +// These tests all use the foo.com and foobar.com zones with no filters +// createMsgs and updateMsgs need sorted when are are used +func TestRfc2136ApplyChangesWithZones(t *testing.T) { + stub := newStub() + provider, err := createRfc2136StubProviderWithZones(stub) + assert.NoError(t, err) + + p := &plan.Changes{ + Create: []*endpoint.Endpoint{ + { + DNSName: "v1.foo.com", + RecordType: "A", + Targets: []string{"1.2.3.4"}, + RecordTTL: endpoint.TTL(400), + }, + { + DNSName: "v1.foobar.com", + RecordType: "TXT", + Targets: []string{"boom"}, + }, + { + DNSName: "ns.foobar.com", + RecordType: "NS", + Targets: []string{"boom"}, + }, + }, + Delete: []*endpoint.Endpoint{ + { + DNSName: "v2.foo.com", + RecordType: "A", + Targets: []string{"1.2.3.4"}, + }, + { + DNSName: "v2.foobar.com", + RecordType: "TXT", + Targets: []string{"boom2"}, + }, + }, + } + + err = provider.ApplyChanges(context.Background(), p) + assert.NoError(t, err) + + assert.Equal(t, 3, len(stub.createMsgs)) + createMsgs := getSortedChanges(stub.createMsgs) + assert.Equal(t, 3, len(createMsgs)) + + assert.True(t, strings.Contains(createMsgs[0], "v1.foo.com")) + assert.True(t, strings.Contains(createMsgs[0], "1.2.3.4")) + + assert.True(t, strings.Contains(createMsgs[1], "v1.foobar.com")) + assert.True(t, strings.Contains(createMsgs[1], "boom")) + + assert.True(t, strings.Contains(createMsgs[2], "ns.foobar.com")) + assert.True(t, strings.Contains(createMsgs[2], "boom")) + + assert.Equal(t, 2, len(stub.updateMsgs)) + updateMsgs := getSortedChanges(stub.updateMsgs) + assert.Equal(t, 2, len(updateMsgs)) + + assert.True(t, strings.Contains(updateMsgs[0], "v2.foo.com")) + assert.True(t, strings.Contains(updateMsgs[1], "v2.foobar.com")) +} + +// These tests use the foo.com and foobar.com zones and with filters set to both zones +// createMsgs and updateMsgs need sorted when are are used +func TestRfc2136ApplyChangesWithZonesFilters(t *testing.T) { + stub := newStub() + provider, err := createRfc2136StubProviderWithZonesFilters(stub) + assert.NoError(t, err) + + p := &plan.Changes{ + Create: []*endpoint.Endpoint{ + { + DNSName: "v1.foo.com", + RecordType: "A", + Targets: []string{"1.2.3.4"}, + RecordTTL: endpoint.TTL(400), + }, + { + DNSName: "v1.foobar.com", + RecordType: "TXT", + Targets: []string{"boom"}, + }, + { + DNSName: "ns.foobar.com", + RecordType: "NS", + Targets: []string{"boom"}, + }, + { + DNSName: "filtered-out.foo.bar", + RecordType: "A", + Targets: []string{"1.2.3.4"}, + RecordTTL: endpoint.TTL(400), + }, + }, + Delete: []*endpoint.Endpoint{ + { + DNSName: "v2.foo.com", + RecordType: "A", + Targets: []string{"1.2.3.4"}, + }, + { + DNSName: "v2.foobar.com", + RecordType: "TXT", + Targets: []string{"boom2"}, + }, + }, + } + + err = provider.ApplyChanges(context.Background(), p) + assert.NoError(t, err) + + assert.Equal(t, 3, len(stub.createMsgs)) + createMsgs := getSortedChanges(stub.createMsgs) + assert.Equal(t, 3, len(createMsgs)) + + assert.True(t, strings.Contains(createMsgs[0], "v1.foo.com")) + assert.True(t, strings.Contains(createMsgs[0], "1.2.3.4")) + + assert.True(t, strings.Contains(createMsgs[1], "v1.foobar.com")) + assert.True(t, strings.Contains(createMsgs[1], "boom")) + + assert.True(t, strings.Contains(createMsgs[2], "ns.foobar.com")) + assert.True(t, strings.Contains(createMsgs[2], "boom")) + + for _, s := range createMsgs { + assert.False(t, strings.Contains(s, "filtered-out.foo.bar")) + } + + assert.Equal(t, 2, len(stub.updateMsgs)) + updateMsgs := getSortedChanges(stub.updateMsgs) + assert.Equal(t, 2, len(updateMsgs)) + + assert.True(t, strings.Contains(updateMsgs[0], "v2.foo.com")) + assert.True(t, strings.Contains(updateMsgs[1], "v2.foobar.com")) + +} + func TestRfc2136ApplyChangesWithDifferentTTLs(t *testing.T) { stub := newStub()