diff --git a/msg.go b/msg.go index 73092d0f..cc1ab7c0 100644 --- a/msg.go +++ b/msg.go @@ -16,7 +16,7 @@ package dns import ( "os" -// "fmt" + // "fmt" "reflect" "net" "rand" @@ -67,7 +67,7 @@ var Rr_str = map[uint16]string{ TypeTXT: "TXT", TypeSRV: "SRV", TypeNAPTR: "NAPTR", - TypeKX: "KX", + TypeKX: "KX", TypeCERT: "CERT", TypeDNAME: "DNAME", TypeA: "A", @@ -75,7 +75,7 @@ var Rr_str = map[uint16]string{ TypeLOC: "LOC", TypeOPT: "OPT", TypeDS: "DS", - TypeDHCID: "DHCID", + TypeDHCID: "DHCID", TypeIPSECKEY: "IPSECKEY", TypeSSHFP: "SSHFP", TypeRRSIG: "RRSIG", @@ -332,7 +332,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o var _ = byte(fv.Elem(j).(*reflect.UintValue).Get()) } // handle type bit maps - // TODO(mg) + // TODO(mg) } case *reflect.StructValue: off, ok = packStructValue(fv, msg, off) @@ -426,10 +426,10 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o // length of string. String is RAW (not encoded in hex, nor base64) copy(msg[off:off+len(s)], s) off += len(s) - case "txt": - // Counted string: 1 byte length, but the string may be longer - // than 255, in that case it should be multiple strings, for now: - fallthrough + case "txt": + // Counted string: 1 byte length, but the string may be longer + // than 255, in that case it should be multiple strings, for now: + fallthrough case "": // Counted string: 1 byte length. if len(s) > 255 || off+1+len(s) > len(msg) { @@ -520,13 +520,13 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, //fmt.Fprintf(os.Stderr, "dns: overflow unpacking NSEC") return len(msg), false } - if blocks == 0 { - // Nothing encoded in this window - // Kinda lame to alloc above and to clear it here - nsec = nsec[:ni] - fv.Set(reflect.NewValue(nsec).(*reflect.SliceValue)) - break - } + if blocks == 0 { + // Nothing encoded in this window + // Kinda lame to alloc above and to clear it here + nsec = nsec[:ni] + fv.Set(reflect.NewValue(nsec).(*reflect.SliceValue)) + break + } off += 2 for j := 0; j < blocks; j++ { @@ -689,15 +689,15 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, name := val.FieldByName("HashLength") size = int(name.(*reflect.UintValue).Get()) } - case "RR_TSIG": - switch f.Name { - case "MAC": - name := val.FieldByName("MACSize") - size = int(name.(*reflect.UintValue).Get()) - case "OtherData": - name := val.FieldByName("OtherLen") - size = int(name.(*reflect.UintValue).Get()) - } + case "RR_TSIG": + switch f.Name { + case "MAC": + name := val.FieldByName("MACSize") + size = int(name.(*reflect.UintValue).Get()) + case "OtherData": + name := val.FieldByName("OtherLen") + size = int(name.(*reflect.UintValue).Get()) + } } if off+size > len(msg) { //fmt.Fprintf(os.Stderr, "dns: failure unpacking size-hex string") @@ -705,10 +705,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, } s = hex.EncodeToString(msg[off : off+size]) off += size - case "txt": - // 1 or multiple txt pieces + case "txt": + // 1 or multiple txt pieces rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get()) - Txt: + Txt: if off >= len(msg) || off+1+int(msg[off]) > len(msg) { //fmt.Fprintf(os.Stderr, "dns: failure unpacking txt string") return len(msg), false @@ -721,10 +721,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, } off += n s += string(b) - if off < rdlength { - // More to come - goto Txt - } + if off < rdlength { + // More to come + goto Txt + } case "": if off >= len(msg) || off+1+int(msg[off]) > len(msg) { //fmt.Fprintf(os.Stderr, "dns: failure unpacking string") @@ -770,6 +770,10 @@ func unpackBase64(b []byte) string { } // Helper function for packing, mostly used in dnssec.go +func packUint16(i uint16) (byte, byte) { + return byte(i >> 8), byte(i) +} + func packBase64(s []byte) ([]byte, os.Error) { b64len := base64.StdEncoding.DecodedLen(len(s)) buf := make([]byte, b64len) @@ -1022,25 +1026,32 @@ func (dns *Msg) String() string { if len(dns.Question) > 0 { s += "\n;; QUESTION SECTION:\n" for i := 0; i < len(dns.Question); i++ { - s += dns.Question[i].String() + "\n" + // Need check if it exists? TODO(mg) + s += dns.Question[i].String() + "\n" } } if len(dns.Answer) > 0 { s += "\n;; ANSWER SECTION:\n" for i := 0; i < len(dns.Answer); i++ { - s += dns.Answer[i].String() + "\n" + if dns.Answer[i] != nil { + s += dns.Answer[i].String() + "\n" + } } } if len(dns.Ns) > 0 { s += "\n;; AUTHORITY SECTION:\n" for i := 0; i < len(dns.Ns); i++ { - s += dns.Ns[i].String() + "\n" + if dns.Ns[i] != nil { + s += dns.Ns[i].String() + "\n" + } } } if len(dns.Extra) > 0 { s += "\n;; ADDITIONAL SECTION:\n" for i := 0; i < len(dns.Extra); i++ { - s += dns.Extra[i].String() + "\n" + if dns.Extra[i] != nil { + s += dns.Extra[i].String() + "\n" + } } } return s diff --git a/resolver.go b/resolver.go index 487cb7d0..022f9788 100644 --- a/resolver.go +++ b/resolver.go @@ -45,9 +45,10 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) { // Check if there is a TSIG appended, if so, check it var ( c net.Conn - in *Msg port string + inb []byte ) + in := new(Msg) if len(res.Servers) == 0 { return nil, &Error{Error: "No servers defined"} } @@ -81,9 +82,12 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) { continue } if res.Tcp { - in, err = exchangeTCP(c, sending, res, true) + inb, err = exchangeTCP(c, sending, res, true) + in.Unpack(inb) + } else { - in, err = exchangeUDP(c, sending, res, true) + inb, err = exchangeUDP(c, sending, res, true) + in.Unpack(inb) } res.Rtt[server] = time.Nanoseconds() - t @@ -114,9 +118,12 @@ type Xfr struct { // Channel m is closed when the IXFR ends. func (res *Resolver) Ixfr(q *Msg, m chan Xfr) { // TSIG - var port string - var in *Msg - var x Xfr + var ( + port string + x Xfr + inb []byte + ) + in := new(Msg) if res.Port == "" { port = "53" } else { @@ -149,9 +156,11 @@ Server: defer c.Close() for { if first { - in, err = exchangeTCP(c, sending, res, true) + inb, err = exchangeTCP(c, sending, res, true) + in.Unpack(inb) } else { - in, err = exchangeTCP(c, sending, res, false) + inb, err = exchangeTCP(c, sending, res, false) + in.Unpack(inb) } if err != nil { @@ -220,8 +229,11 @@ Server: // the zone as-is. Xfr.Add is always true. // The channel is closed to signal the end of the AXFR. func (res *Resolver) AxfrTSIG(q *Msg, m chan Xfr, secret string) { - var port string - var in *Msg + var ( + port string + inb []byte + ) + in := new(Msg) if res.Port == "" { port = "53" } else { @@ -263,9 +275,17 @@ Server: defer c.Close() // TODO(mg): if not open? for { if first { - in, err = exchangeTCP(c, sending, res, true) + inb, err = exchangeTCP(c, sending, res, true) + stripTSIG(inb) + /* + pt2 := new(Msg) + pt2.Unpack(t2) + //println("P", pt2.String()) + */ + in.Unpack(inb) } else { - in, err = exchangeTCP(c, sending, res, false) + inb, err = exchangeTCP(c, sending, res, false) + in.Unpack(inb) } if err != nil { @@ -282,7 +302,7 @@ Server: t := in.Extra[len(in.Extra)-1] switch t.(type) { case *RR_TSIG: - if t.(*RR_TSIG).Verify(in, secret, reqmac) { + if t.(*RR_TSIG).Verify(inb, secret, reqmac) { println("Validates") } else { println("DOES NOT validate") @@ -322,8 +342,11 @@ Server: // the zone as-is. Xfr.Add is always true. // The channel is closed to signal the end of the AXFR. func (res *Resolver) Axfr(q *Msg, m chan Xfr) { - var port string - var in *Msg + var ( + port string + inb []byte + ) + in := new(Msg) if res.Port == "" { port = "53" } else { @@ -343,17 +366,6 @@ func (res *Resolver) Axfr(q *Msg, m chan Xfr) { return } -/* - // Need the secret! - var tsig *RR_TSIG - // Check if there is a TSIG added - if len(q.Extra) > 0 { - lastrr := q.Extra[len(q.Extra)-1] - if lastrr.Header().Rrtype == TypeTSIG { - tsig = lastrr.(*RR_TSIG) - } - } - */ Server: for i := 0; i < len(res.Servers); i++ { server := res.Servers[i] + ":" + port @@ -365,9 +377,11 @@ Server: defer c.Close() // TODO(mg): if not open? for { if first { - in, err = exchangeTCP(c, sending, res, true) + inb, err = exchangeTCP(c, sending, res, true) + in.Unpack(inb) } else { - in, err = exchangeTCP(c, sending, res, false) + inb, err = exchangeTCP(c, sending, res, false) + in.Unpack(inb) } if err != nil { @@ -408,7 +422,7 @@ Server: // Send a request on the connection and hope for a reply. // Up to res.Attempts attempts. If send is false, nothing // is send. -func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) { +func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) ([]byte, os.Error) { var timeout int64 var attempts int if r.Mangle != nil { @@ -443,18 +457,13 @@ func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) } return nil, err } - - in := new(Msg) - if !in.Unpack(buf) { - continue - } - return in, nil + return buf, nil } return nil, &Error{Error: servErr} } // Up to res.Attempts attempts. -func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) { +func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) ([]byte, os.Error) { var timeout int64 var attempts int if r.Mangle != nil { @@ -484,7 +493,7 @@ func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) } c.SetReadTimeout(timeout * 1e9) // nanoseconds - // The server replies with two bytes length + // The server replies with two bytes length. buf, err := recvTCP(c) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { @@ -492,11 +501,7 @@ func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) } return nil, err } - in := new(Msg) - if !in.Unpack(buf) { - continue - } - return in, nil + return buf, nil } return nil, &Error{Error: servErr} } @@ -510,7 +515,7 @@ func sendUDP(m []byte, c net.Conn) os.Error { } func recvUDP(c net.Conn) ([]byte, os.Error) { - m := make([]byte, DefaultMsgSize) // More than enough??? + m := make([]byte, DefaultMsgSize) n, err := c.Read(m) if err != nil { return nil, err @@ -537,8 +542,7 @@ func sendTCP(m []byte, c net.Conn) os.Error { } func recvTCP(c net.Conn) ([]byte, os.Error) { - l := make([]byte, 2) // receiver length - // The server replies with two bytes length + l := make([]byte, 2) // The server replies with two bytes length. _, err := c.Read(l) if err != nil { return nil, err diff --git a/tsig.go b/tsig.go index 6b6e65e7..c4669a13 100644 --- a/tsig.go +++ b/tsig.go @@ -107,7 +107,7 @@ func (t *RR_TSIG) Generate(m *Msg, secret string) bool { // the TSIG record still attached (as the last rr in the Additional // section). Return true on success. // The secret is a base64 encoded string with the secret. -func (t *RR_TSIG) Verify(m *Msg, secret, reqmac string) bool { +func (t *RR_TSIG) Verify(m []byte, secret, reqmac string) bool { rawsecret, err := packBase64([]byte(secret)) if err != nil { return false @@ -121,9 +121,8 @@ func (t *RR_TSIG) Verify(m *Msg, secret, reqmac string) bool { if t.Header().Rrtype != TypeTSIG { return false } - println(msg2.String()) + msg2.MsgHdr.Id = t.OrigId - println(msg2.String()) msg2.Extra = msg2.Extra[:len(msg2.Extra)-1] // Strip off the TSIG buf, ok := tsigToBuf(t, msg2, reqmac) if !ok { @@ -182,3 +181,53 @@ func tsigToBuf(rr *RR_TSIG, msg *Msg, reqmac string) ([]byte, bool) { } return buf, true } + +// Strip the TSIG from the pkt. +func stripTSIG(orig []byte) ([]byte, bool) { + // Copied from msg.go's Unpack() + // Header. + var dh Header + dns := new(Msg) + msg := make([]byte, len(orig)) + copy(msg, orig) // fhhh.. another copy + off := 0 + tsigoff := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return nil, false + } + if dh.Arcount == 0 { + // No records at all in the additional. + return nil, false + } + + // Arrays. + dns.Question = make([]Question, dh.Qdcount) + dns.Answer = make([]RR, dh.Ancount) + dns.Ns = make([]RR, dh.Nscount) + dns.Extra = make([]RR, dh.Arcount) + + for i := 0; i < len(dns.Question); i++ { + off, ok = unpackStruct(&dns.Question[i], msg, off) + } + for i := 0; i < len(dns.Answer); i++ { + dns.Answer[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.Ns); i++ { + dns.Ns[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.Extra); i++ { + tsigoff = off + dns.Extra[i], off, ok = unpackRR(msg, off) + if dns.Extra[i].Header().Rrtype == TypeTSIG { + // Adjust Arcount. + arcount, _ := unpackUint16(msg, 10) + msg[10], msg[11] = packUint16(arcount-1) + break + } + } + if !ok { + return nil, false + } + return msg[:tsigoff], true +}