diff --git a/sanitize.go b/sanitize.go index 2bb3223e..6924db96 100644 --- a/sanitize.go +++ b/sanitize.go @@ -1,5 +1,7 @@ package dns +import "strings" + // Dedup removes identical RRs from rrs. It preserves the original ordering. // The lowest TTL of any duplicates is used in the remaining one. // @@ -8,13 +10,17 @@ package dns // if it finds a: a.miek.nl. CNAME foo, all other RRs with the ownername a.miek.nl. // will be removed. When a DNAME is found all RRs with an ownername below that of // the DNAME will be removed. -// Note that the class of the CNAME/DNAME is *not* taken into account. TODO(miek)? func Dedup(rrs []RR) []RR { m := make(map[string]RR) keys := make([]string, 0, len(rrs)) + // Save possible cname and dname domainnames. Currently a slice, don't + // expect millions here.. + cname := []string{} + dname := []string{} + for _, r := range rrs { - key := normalizedString(r) + key, end := normalizedString(r) keys = append(keys, key) if _, ok := m[key]; ok { // Shortest TTL wins. @@ -23,6 +29,17 @@ func Dedup(rrs []RR) []RR { } continue } + + if r.Header().Rrtype == TypeCNAME { + // we do end+3 here, so we capture the full domain name *and* + // the class field which mnemonic is always two chars. + cname = append(cname, key[:end+3]) + + } + if r.Header().Rrtype == TypeDNAME { + dname = append(dname, key[:end+3]) + } + m[key] = r } // If the length of the result map equals the amount of RRs we got, @@ -30,20 +47,33 @@ func Dedup(rrs []RR) []RR { if len(m) == len(rrs) { return rrs } - var i = 0 - for i, _ = range rrs { + + ret := make([]RR, 0, len(rrs)) + for i, r := range rrs { + // If keys[i] lives in the map, we should copy it and remove + // it from the map. + if _, ok := m[keys[i]]; ok { + if needsDeletion(r, keys[i], cname, dname) { + delete(m, keys[i]) + continue + } + + delete(m, keys[i]) + ret = append(ret, r) + } + if len(m) == 0 { break } - // We saved the key for each RR. - delete(m, keys[i]) } - return rrs[:i] + + return ret } // normalizedString returns a normalized string from r. The TTL -// is removed and the domain name is lowercased. -func normalizedString(r RR) string { +// is removed and the domain name is lowercased. The returned integer +// is the index where the domain name ends + 1. +func normalizedString(r RR) (string, int) { // A string Go DNS makes has: domainnameTTL... b := []byte(r.String()) @@ -78,30 +108,27 @@ func normalizedString(r RR) string { // remove TTL. copy(b[ttlStart:], b[ttlEnd:]) cut := ttlEnd - ttlStart - return string(b[:len(b)-cut]) + // ttlStart + 3 puts us on the start of the rdata + return string(b[:len(b)-cut]), ttlStart } -// dropCNAMEAndDNAME drops records from rrs that are not allowed, taking the rules -// for CNAME and DNAME into account. -func dropCNAMEAndDNAME(rrs []RR) []RR { - ret := make([]RR, 0, len(rrs)) +func needsDeletion(r RR, s string, cname, dname []string) bool { + if r.Header().Rrtype == TypeCNAME || r.Header().Rrtype == TypeDNAME { + return false + } - - return nil - /* - make separate step that remove cname, dname - switch r.Header().Rrtype { - case TypeCNAME: - cname = append(cname, strings.ToLower(r.Header().Name)) - case TypeDNAME: - dname = append(dname, strings.ToLower(r.Header().Name)) - default: - if len(cname) == 0 && len(dname) == 0 { - break - } - if strings.EqualFold + // For CNAME we can do strings.HasPrefix with s. + // For DNAME we can do strings.Contains with s. + // Either signals a removal of this RR. + for _, c := range cname { + if strings.HasPrefix(s, c) { + return true } - */ - - + } + for _, d := range dname { + if strings.Contains(s, d) { + return true + } + } + return false } diff --git a/sanitize_test.go b/sanitize_test.go index fc619dc0..3a1b4481 100644 --- a/sanitize_test.go +++ b/sanitize_test.go @@ -1,44 +1,65 @@ package dns -import "testing" +import ( + "reflect" + "testing" +) func TestDedup(t *testing.T) { - testcases := map[[3]RR]string{ + // make it []string + testcases := map[[3]RR][]string{ [...]RR{ newRR(t, "mIek.nl. IN A 127.0.0.1"), newRR(t, "mieK.nl. IN A 127.0.0.1"), newRR(t, "miek.Nl. IN A 127.0.0.1"), - }: "mIek.nl.\t3600\tIN\tA\t127.0.0.1", + }: []string{"mIek.nl.\t3600\tIN\tA\t127.0.0.1"}, [...]RR{ newRR(t, "miEk.nl. 2000 IN A 127.0.0.1"), newRR(t, "mieK.Nl. 1000 IN A 127.0.0.1"), newRR(t, "Miek.nL. 500 IN A 127.0.0.1"), - }: "miEk.nl.\t500\tIN\tA\t127.0.0.1", + }: []string{"miEk.nl.\t500\tIN\tA\t127.0.0.1"}, [...]RR{ newRR(t, "miek.nl. IN A 127.0.0.1"), newRR(t, "miek.nl. CH A 127.0.0.1"), newRR(t, "miek.nl. IN A 127.0.0.1"), - }: "miek.nl.\t3600\tIN\tA\t127.0.0.1", + }: []string{"miek.nl.\t3600\tIN\tA\t127.0.0.1", + "miek.nl.\t3600\tCH\tA\t127.0.0.1", + }, [...]RR{ newRR(t, "miek.nl. CH A 127.0.0.1"), newRR(t, "miek.nl. IN A 127.0.0.1"), + newRR(t, "miek.de. IN A 127.0.0.1"), + }: []string{"miek.nl.\t3600\tCH\tA\t127.0.0.1", + "miek.de.\t3600\tIN\tA\t127.0.0.1", + }, + [...]RR{ + newRR(t, "miek.de. IN A 127.0.0.1"), newRR(t, "miek.nl. IN A 127.0.0.1"), - }: "miek.nl.\t3600\tCH\tA\t127.0.0.1", + newRR(t, "miek.nl. IN A 127.0.0.1"), + }: []string{"miek.de.\t3600\tIN\tA\t127.0.0.1", + "miek.de.\t3600\tIN\tA\t127.0.0.1", + }, } for rr, expected := range testcases { out := Dedup([]RR{rr[0], rr[1], rr[2]}) - if len(out) == 0 || len(out) == 3 { - t.Logf("dedup failed, wrong number of RRs returned") - t.Fail() - } - if o := out[0].String(); o != expected { - t.Logf("dedup failed, expected %s, got %s", expected, o) - t.Fail() + if !reflect.DeepEqual(out, expected) { + t.Fatalf("expected %v, got %v", expected, out) } } } +func TestDedupWithCNAME(t *testing.T) { + in := []RR{ + newRR(t, "miek.Nl. CNAME a."), + newRR(t, "miEk.nl. IN A 127.0.0.1"), + newRR(t, "miek.Nl. IN A 127.0.0.1"), + newRR(t, "miek.de. IN A 127.0.0.1"), + } + out := Dedup(in) + t.Logf("%+v\n", out) +} + func TestNormalizedString(t *testing.T) { tests := map[RR]string{ newRR(t, "mIEk.Nl. 3600 IN A 127.0.0.1"): "miek.nl.\tIN\tA\t127.0.0.1", @@ -46,7 +67,7 @@ func TestNormalizedString(t *testing.T) { newRR(t, "m\\\tIeK.nl. 3600 in A 127.0.0.1"): "m\\tiek.nl.\tIN\tA\t127.0.0.1", } for tc, expected := range tests { - a1 := normalizedString(tc) + a1, _ := normalizedString(tc) if a1 != expected { t.Logf("expected %s, got %s", expected, a1) t.Fail()