From 1ad76fe65b25cf93cc98c860ccd4dc112a76e06c Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Wed, 26 Jun 2013 22:18:09 +0100 Subject: [PATCH] Added packLen packLen() returns the length of an uncompressed packet buffer, this is used when packing a packet. This is needed for compression. When compression is used, we first create the full packet and *then* compress it. If we use Len() which accounts for compression, we can get buffer overruns, when packing the (then still uncompressed) packet. --- dns_test.go | 57 ++++++++++++++++++++--------------------------------- msg.go | 44 ++++++++++++++++++++++++++++------------- 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/dns_test.go b/dns_test.go index 292b9a2a..c64ed3a5 100644 --- a/dns_test.go +++ b/dns_test.go @@ -159,49 +159,34 @@ func TestCompressLength(t *testing.T) { // Does the predicted length match final packed length func TestMsgLenTest(t *testing.T) { - var ( - // util function to build messages - makeMsg = func(question string, ans, ns, e []RR) *Msg { - var msg Msg - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return &msg - } + // util function to build messages + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } - name = "12345678901234567890123456789012345.12345678.123." - rrA, _ = NewRR(name + " 3600 IN A 192.0.2.1") - rrMx, _ = NewRR(name + " 3600 IN MX 10 " + name) - rrTxt, _ = NewRR(name + ` 3600 IN TXT "I am a TXT"`) - tests = []*Msg{ - makeMsg(name, nil, nil, nil), - makeMsg(name, []RR{rrA}, nil, nil), - makeMsg(name, []RR{rrMx}, nil, nil), - makeMsg(name, []RR{rrTxt}, nil, nil), - makeMsg(name, []RR{rrA, rrA}, nil, nil), - makeMsg(name, []RR{rrMx, rrMx}, nil, nil), - makeMsg(name, []RR{rrTxt, rrTxt}, nil, nil), - makeMsg(name, []RR{rrA}, []RR{rrA}, nil), - makeMsg(name, []RR{rrMx}, []RR{rrMx}, nil), - makeMsg(name, []RR{rrTxt}, []RR{rrTxt}, nil), - makeMsg(name, []RR{rrA, rrMx, rrTxt}, []RR{rrA, rrMx, rrTxt}, nil), - makeMsg(name, []RR{rrA, rrMx, rrTxt}, []RR{rrA, rrMx, rrTxt}, []RR{rrA, rrMx, rrTxt})} - ) + name1 := "12345678901234567890123456789012345.12345678.123." + rrA, _ := NewRR(name1 + " 3600 IN A 192.0.2.1") + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + tests := []*Msg{ + makeMsg(name1, []RR{rrA}, nil, nil), + makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)} for _, msg := range tests { - var ( - predicted = msg.Len() - buf, err = msg.Pack() - actual = len(buf) - ) + predicted := msg.Len() + buf, err := msg.Pack() if err != nil { t.Error(err) t.Fail() } - if predicted != actual { - t.Errorf("Predicted length is wrong: predicted %d, actual %d\n%s", predicted, actual, msg) + if predicted != len(buf) { + t.Errorf("Predicted length is wrong: predicted %s (len=%d) %d, actual %d\n", + msg.Question[0].Name, len(msg.Answer), predicted, len(buf)) t.Fail() } } diff --git a/msg.go b/msg.go index b3a68f1b..edbea2c8 100644 --- a/msg.go +++ b/msg.go @@ -283,7 +283,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c // the offset of the current name, because that's // where we need to insert the pointer later - // If compress is true, we're allowed to compress this dname + // If compress is true, we're allowed to compress this dname if pointer == -1 && compress { pointer = p // Where to point to nameoffset = offset // Where to point from @@ -298,7 +298,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c if len(bs) == 1 && bs[0] == '.' { return off, nil } - // If we did compression and we find something at the pointer here + // If we did compression and we find something add the pointer here if pointer != -1 { // We have two bytes (14 bits) to put the pointer in msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000)) @@ -1255,9 +1255,7 @@ func (dns *Msg) Pack() (msg []byte, err error) { dh.Nscount = uint16(len(ns)) dh.Arcount = uint16(len(extra)) - // TODO(mg): still a little too much, but better than 64K... - msg = make([]byte, dns.Len()+10) - + msg = make([]byte, dns.packLen()+10) // TODO(miekg): +10 should go sometimses // Pack it in: header and then the pieces. off := 0 off, err = packStructCompress(&dh, msg, off, compression, dns.Compress) @@ -1393,10 +1391,28 @@ func (dns *Msg) String() string { return s } -// Len return the message length when in (un)compressed wire format. +// packLen returns the message length when in UNcompressed wire format. +func (dns *Msg) packLen() int { + // Message header is always 12 bytes + l := 12 + for i := 0; i < len(dns.Question); i++ { + l += dns.Question[i].len() + } + for i := 0; i < len(dns.Answer); i++ { + l += dns.Answer[i].len() + } + for i := 0; i < len(dns.Ns); i++ { + l += dns.Ns[i].len() + } + for i := 0; i < len(dns.Extra); i++ { + l += dns.Extra[i].len() + } + return l +} + +// Len returns the message length when in (un)compressed wire format. // If dns.Compress is true compression it is taken into account, currently -// this only counts owner name compression. There is no check for -// nil valued sections (allocated, but contain no RRs). +// this only counts owner name compression. func (dns *Msg) Len() int { // Message header is always 12 bytes l := 12 @@ -1413,8 +1429,8 @@ func (dns *Msg) Len() int { } for i := 0; i < len(dns.Answer); i++ { if dns.Compress { - if v, ok := compression[dns.Answer[i].Header().Name]; ok { - l += dns.Answer[i].len() - v + if _, ok := compression[dns.Answer[i].Header().Name]; ok { + l += dns.Answer[i].len() - len(dns.Answer[i].Header().Name) + 2 continue } compressionHelper(compression, dns.Answer[i].Header().Name) @@ -1423,8 +1439,8 @@ func (dns *Msg) Len() int { } for i := 0; i < len(dns.Ns); i++ { if dns.Compress { - if v, ok := compression[dns.Ns[i].Header().Name]; ok { - l += dns.Ns[i].len() - v + if _, ok := compression[dns.Ns[i].Header().Name]; ok { + l += dns.Ns[i].len() - len(dns.Ns[i].Header().Name) + 2 continue } compressionHelper(compression, dns.Ns[i].Header().Name) @@ -1433,8 +1449,8 @@ func (dns *Msg) Len() int { } for i := 0; i < len(dns.Extra); i++ { if dns.Compress { - if v, ok := compression[dns.Extra[i].Header().Name]; ok { - l += dns.Extra[i].len() - v + if _, ok := compression[dns.Extra[i].Header().Name]; ok { + l += dns.Extra[i].len() - len(dns.Extra[i].Header().Name) + 2 continue } compressionHelper(compression, dns.Extra[i].Header().Name)