diff --git a/_examples/axfr/axfr.go b/_examples/axfr/axfr.go index 0c52ddc2..98a81348 100644 --- a/_examples/axfr/axfr.go +++ b/_examples/axfr/axfr.go @@ -18,9 +18,9 @@ func main() { res.Servers[0] = *nameserver c := make(chan dns.Xfr) - m := new(dns.Msg) m.Question = make([]dns.Question, 1) + if *serial > 0 { m.Question[0] = dns.Question{zone, dns.TypeIXFR, dns.ClassINET} soa := new(dns.RR_SOA) @@ -28,11 +28,10 @@ func main() { soa.Serial = uint32(*serial) m.Ns = make([]dns.RR, 1) m.Ns[0] = soa - go res.Ixfr(m, c) } else { m.Question[0] = dns.Question{zone, dns.TypeAXFR, dns.ClassINET} - go res.Axfr(m, c) } + go res.Xfr(m, nil, c) for x := range c { fmt.Printf("%v %v\n",x.Add, x.RR) } diff --git a/dns.go b/dns.go index 4d45bc7f..df835fa6 100644 --- a/dns.go +++ b/dns.go @@ -223,9 +223,7 @@ func (d *Conn) Exchange(request []byte, nosend bool) (reply []byte, err os.Error reply = make([]byte, DefaultMsgSize) } n, err = d.Read(reply) - println("READ ", n) if err != nil { - println(err.String()) return nil, err } reply = reply[:n] diff --git a/resolver.go b/resolver.go index 4b44bda2..be8539b7 100644 --- a/resolver.go +++ b/resolver.go @@ -97,11 +97,23 @@ type Xfr struct { Err os.Error } +func (res *Resolver) Xfr(q *Msg, t *Tsig, m chan Xfr) { + switch q.Question[0].Qtype { + case TypeAXFR: + res.axfr(q, t, m) + case TypeIXFR: + res.ixfr(q, t, m) + default: + // wrong request + return + } +} + // Start an IXFR, q should contain a *Msg with the question // for an IXFR: "miek.nl" ANY IXFR. RRs that should be added // have Xfr.Add set to true otherwise it is false. // Channel m is closed when the IXFR ends. -func (res *Resolver) Ixfr(q *Msg, m chan Xfr) { +func (res *Resolver) ixfr(q *Msg, t *Tsig, m chan Xfr) { var ( x Xfr inb []byte @@ -204,7 +216,7 @@ Server: // returned over the channel, so the caller will receive // the zone as-is. Xfr.Add is always true. // The channel is closed to signal the end of the AXFR. -func (res *Resolver) AxfrTSIG(q *Msg, m chan Xfr, t *Tsig) { +func (res *Resolver) axfr(q *Msg, t *Tsig, m chan Xfr) { var inb []byte in := new(Msg) port, err := check(res, q) @@ -243,7 +255,7 @@ Server: continue Server } - in.Unpack(inb) + in.Unpack(inb) // TODO(mg): error handling if in.Id != q.Id { c.Close() return @@ -266,81 +278,7 @@ Server: sendMsg(in, m, true) return } - } - } - panic("not reached") - return - } - return -} - - -// Start an AXFR, q should contain a message with the question -// for an AXFR: "miek.nl" ANY AXFR. The closing SOA isn't -// returned over the channel, so the caller will receive -// the zone as-is. Xfr.Add is always true. -// The channel is closed to signal the end of the AXFR. -func (res *Resolver) Axfr(q *Msg, m chan Xfr) { - var inb []byte - port, err := check(res, q) - if err != nil { - return - } - in := new(Msg) - - defer close(m) - sending, ok := q.Pack() - if !ok { - return - } - -Server: - for i := 0; i < len(res.Servers); i++ { - server := res.Servers[i] + ":" + port - c, err := net.Dial("tcp", "", server) - if err != nil { - continue Server - } - d := new(Conn) - d.TCP = c.(*net.TCPConn) - d.Addr = d.TCP.RemoteAddr() - - first := true - defer c.Close() // TODO(mg): if not open? - for { - if first { - inb, err = d.Exchange(sending, false) - } else { - inb, err = d.Exchange(sending, true) - } - if err != nil { - c.Close() - continue Server - } - if !in.Unpack(inb) { - println("Failed to unpack") - } - if in.Id != q.Id { - c.Close() - return - } - if first { - if !checkXfrSOA(in, true) { - c.Close() - continue Server - } - first = !first - } - - if !first { - if !checkXfrSOA(in, false) { - // Soa record not the last one - sendMsg(in, m, false) - continue - } else { - sendMsg(in, m, true) - return - } + d.Tsig.TimersOnly = true // } } panic("not reached") diff --git a/tsig.go b/tsig.go index 403c0c13..febc2fd3 100644 --- a/tsig.go +++ b/tsig.go @@ -28,7 +28,7 @@ type Tsig struct { // Request MAC RequestMAC string // Only include the timers if true. - Timers bool + TimersOnly bool } // HMAC hashing codes. These are transmitted as domain names. @@ -93,18 +93,19 @@ func (t *Tsig) Generate(msg []byte) ([]byte, bool) { return nil, false } - // okay, create TSIG, add to message + // Create TSIG and add it to the message. + q := new(Msg) + q.Unpack(msg) // TODO(mg): error handling + rr := new(RR_TSIG) rr.Hdr = RR_Header{Name: t.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0} rr.Fudge = t.Fudge rr.TimeSigned = t.TimeSigned rr.Algorithm = t.Algorithm + rr.OrigId = q.Id rr.MAC = t.MAC rr.MACSize = uint16(len(t.MAC) / 2) - q := new(Msg) - q.Unpack(msg) - q.Extra = append(q.Extra, rr) send, ok := q.Pack() return send, ok @@ -153,7 +154,16 @@ func (t *Tsig) Buffer(msg []byte) ([]byte, bool) { } tsigvar := make([]byte, DefaultMsgSize) - if t.Timers { + if t.TimersOnly { + tsig := new(timerWireFmt) + tsig.TimeSigned = t.TimeSigned + tsig.Fudge = t.Fudge + n, ok1 := packStruct(tsig, tsigvar, 0) + if !ok1 { + return nil, false + } + tsigvar = tsigvar[:n] + } else { tsig := new(tsigWireFmt) tsig.Name = strings.ToLower(t.Name) tsig.Class = ClassANY @@ -169,15 +179,6 @@ func (t *Tsig) Buffer(msg []byte) ([]byte, bool) { return nil, false } tsigvar = tsigvar[:n] - } else { - tsig := new(timerWireFmt) - tsig.TimeSigned = t.TimeSigned - tsig.Fudge = t.Fudge - n, ok1 := packStruct(tsig, tsigvar, 0) - if !ok1 { - return nil, false - } - tsigvar = tsigvar[:n] } if t.RequestMAC != "" { x := append(macbuf, msg...)