Forget to the add the i := fv.Uint()

This commit is contained in:
Miek Gieben 2011-08-08 13:10:35 +02:00
parent 0d917b3c92
commit 861a2adb6c
9 changed files with 73 additions and 50 deletions

View File

@ -27,12 +27,12 @@ func main() {
} }
for _, a := range addr { for _, a := range addr {
m.Question[0] = dns.Question{"version.bind.", dns.TypeTXT, dns.ClassCHAOS} m.Question[0] = dns.Question{"version.bind.", dns.TypeTXT, dns.ClassCHAOS}
in := c.Exchange(m, a) in, _ := c.Exchange(m, a)
if in != nil && in.Answer != nil { if in != nil && in.Answer != nil {
fmt.Printf("%v\n", in.Answer[0]) fmt.Printf("%v\n", in.Answer[0])
} }
m.Question[0] = dns.Question{"hostname.bind.", dns.TypeTXT, dns.ClassCHAOS} m.Question[0] = dns.Question{"hostname.bind.", dns.TypeTXT, dns.ClassCHAOS}
in = c.Exchange(m, a) in, _ = c.Exchange(m, a)
if in != nil && in.Answer != nil { if in != nil && in.Answer != nil {
fmt.Printf("%v\n", in.Answer[0]) fmt.Printf("%v\n", in.Answer[0])
} }
@ -40,6 +40,8 @@ func main() {
} }
func qhandler(w dns.RequestWriter, m *dns.Msg) { func qhandler(w dns.RequestWriter, m *dns.Msg) {
w.Dial()
defer w.Close()
w.Send(m) w.Send(m)
r, _ := w.Receive() r, _ := w.Receive()
w.Write(r) w.Write(r)

View File

@ -40,7 +40,7 @@ func delay(m *dns.Msg) (buf []byte) {
} }
println("Ok: let it through") println("Ok: let it through")
for _, c := range qr { for _, c := range qr {
o = c.Client.Exchange(m, c.Addr) o, _ = c.Client.Exchange(m, c.Addr)
} }
buf, _ = o.Pack() buf, _ = o.Pack()
return return

View File

@ -99,7 +99,7 @@ func checkcache(m *dns.Msg) (o []byte) {
println("Cache miss") println("Cache miss")
var p *dns.Msg var p *dns.Msg
for _, c := range qr { for _, c := range qr {
p = c.Client.Exchange(m, c.Addr) p, _ = c.Client.Exchange(m, c.Addr)
} }
cache.add(p) cache.add(p)
o, _ = p.Pack() o, _ = p.Pack()

View File

@ -38,7 +38,7 @@ func sign(m *dns.Msg) *dns.Msg {
func sendsign(m *dns.Msg) (o []byte) { func sendsign(m *dns.Msg) (o []byte) {
var p *dns.Msg var p *dns.Msg
for _, c := range qr { for _, c := range qr {
p = c.Client.Exchange(m, c.Addr) p, _ = c.Client.Exchange(m, c.Addr)
} }
o, _ = sign(p).Pack() o, _ = sign(p).Pack()
println("signing") println("signing")
@ -48,7 +48,7 @@ func sendsign(m *dns.Msg) (o []byte) {
func send(m *dns.Msg) (o []byte) { func send(m *dns.Msg) (o []byte) {
var p *dns.Msg var p *dns.Msg
for _, c := range qr { for _, c := range qr {
p = c.Client.Exchange(m, c.Addr) p, _ = c.Client.Exchange(m, c.Addr)
} }
o, _ = p.Pack() o, _ = p.Pack()
return return

View File

@ -28,7 +28,7 @@ func main() {
m.Extra = append(m.Extra, e) m.Extra = append(m.Extra, e)
c := dns.NewClient() c := dns.NewClient()
r := c.Exchange(m, conf.Servers[0]) r, _ := c.Exchange(m, conf.Servers[0])
if r == nil { if r == nil {
fmt.Printf("*** no answer received for %s\n", os.Args[1]) fmt.Printf("*** no answer received for %s\n", os.Args[1])
os.Exit(1) os.Exit(1)

View File

@ -22,9 +22,9 @@ func main() {
m.MsgHdr.RecursionDesired = true m.MsgHdr.RecursionDesired = true
// Simple sync query, nothing fancy // Simple sync query, nothing fancy
r := c.Exchange(m, config.Servers[0]) r, err := c.Exchange(m, config.Servers[0])
if err != nil {
if r == nil { fmt.Printf("%s\n", err.String())
os.Exit(1) os.Exit(1)
} }

View File

@ -11,7 +11,10 @@ import (
func q(w dns.RequestWriter, m *dns.Msg) { func q(w dns.RequestWriter, m *dns.Msg) {
w.Send(m) w.Send(m)
r, _ := w.Receive() r, err := w.Receive()
if err != nil {
fmt.Printf("%s\n", err.String())
}
w.Write(r) w.Write(r)
} }
@ -130,9 +133,11 @@ forever:
select { select {
case r := <-dns.DefaultReplyChan: case r := <-dns.DefaultReplyChan:
if r[1] != nil { if r[1] != nil {
if r[1].Rcode == dns.RcodeSuccess {
if r[0].Id != r[1].Id { if r[0].Id != r[1].Id {
fmt.Printf("Id mismatch\n") fmt.Printf("Id mismatch\n")
} }
}
if *short { if *short {
r[1] = shortMsg(r[1]) r[1] = shortMsg(r[1])
} }

View File

@ -24,6 +24,8 @@ type RequestWriter interface {
Write(*Msg) Write(*Msg)
Send(*Msg) os.Error Send(*Msg) os.Error
Receive() (*Msg, os.Error) Receive() (*Msg, os.Error)
Close() os.Error
Dial() os.Error
} }
// hijacked connections...? // hijacked connections...?
@ -123,6 +125,7 @@ type Client struct {
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret> TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
//Conn net.Conn // if set, use this connection, otherwise Dial again TODO
// LocalAddr string // Local address to use // LocalAddr string // Local address to use
} }
@ -191,44 +194,59 @@ func (c *Client) Do(m *Msg, a string) {
// ExchangeBuf performs a synchronous query. It sends the buffer m to the // ExchangeBuf performs a synchronous query. It sends the buffer m to the
// address (net.Addr?) contained in a // address (net.Addr?) contained in a
func (c *Client) ExchangeBuffer(inbuf []byte, a string, outbuf []byte) bool { func (c *Client) ExchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, err os.Error) {
w := new(reply) w := new(reply)
w.client = c w.client = c
w.addr = a w.addr = a
_, err := w.writeClient(inbuf) if err = w.Dial(); err != nil {
defer w.closeClient() // XXX here?? what about TCP which should remain open return 0, err
if err != nil { }
println(err.String()) defer w.Close() // XXX here?? what about TCP which should remain open
return false if n, err = w.writeClient(inbuf); err != nil {
return 0, err
} }
// udp / tcp TODO // udp / tcp TODO
n, err := w.readClient(outbuf) if n, err = w.readClient(outbuf); err != nil {
if err != nil { return n, err
return false
} }
outbuf = outbuf[:n] return n, nil
return true
} }
// Exchange performs an synchronous query. It sends the message m to the address // Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. // contained in a and waits for an reply.
func (c *Client) Exchange(m *Msg, a string) *Msg { func (c *Client) Exchange(m *Msg, a string) (r *Msg, err os.Error) {
var n int
out, ok := m.Pack() out, ok := m.Pack()
if !ok { if !ok {
panic("failed to pack message") panic("failed to pack message")
} }
in := make([]byte, DefaultMsgSize) in := make([]byte, DefaultMsgSize)
if ok := c.ExchangeBuffer(out, a, in); !ok { if n, err = c.ExchangeBuffer(out, a, in); err != nil {
return nil return nil, err
} }
r := new(Msg) r = new(Msg)
if ok := r.Unpack(in); !ok { if ok := r.Unpack(in[:n]); !ok {
return nil return nil, ErrUnpack
} }
return r return r, nil
} }
// Dial connects to the address addr for the networks c.Net
func (w *reply) Dial() os.Error {
conn, err := net.Dial(w.Client().Net, w.addr)
if err != nil {
return err
}
w.conn = conn
return nil
}
// UDP/TCP stuff big TODO
func (w *reply) Close() (err os.Error) {
return w.conn.Close()
}
func (w *reply) WriteMessages(m []*Msg) { func (w *reply) WriteMessages(m []*Msg) {
m1 := append([]*Msg{w.req}, m...) m1 := append([]*Msg{w.req}, m...)
w.Client().ChannelReply <- m1 w.Client().ChannelReply <- m1
@ -347,12 +365,12 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
if w.Client().Net == "" { if w.Client().Net == "" {
panic("c.Net empty") panic("c.Net empty")
} }
if w.conn == nil {
conn, err := net.Dial(w.Client().Net, w.addr) // No connection yet, dial it. impl. at this place? TODO
if err != nil { if err = w.Dial(); err != nil {
return 0, err return 0, err
} }
w.conn = conn }
switch w.Client().Net { switch w.Client().Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
if len(p) < 2 { if len(p) < 2 {
@ -361,7 +379,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
for a := 0; a < w.Client().Attempts; a++ { for a := 0; a < w.Client().Attempts; a++ {
l := make([]byte, 2) l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(p))) l[0], l[1] = packUint16(uint16(len(p)))
n, err = conn.Write(l) n, err = w.conn.Write(l)
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
continue continue
@ -371,7 +389,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
if n != 2 { if n != 2 {
return n, io.ErrShortWrite return n, io.ErrShortWrite
} }
n, err = conn.Write(p) n, err = w.conn.Write(p)
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
continue continue
@ -380,7 +398,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
} }
i := n i := n
if i < len(p) { if i < len(p) {
j, err := conn.Write(p[i:len(p)]) j, err := w.conn.Write(p[i:len(p)])
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
// We are half way in our write... // We are half way in our write...
@ -394,7 +412,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
} }
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
for a := 0; a < w.Client().Attempts; a++ { for a := 0; a < w.Client().Attempts; a++ {
n, err = conn.(*net.UDPConn).WriteTo(p, conn.RemoteAddr()) n, err = w.conn.(*net.UDPConn).WriteTo(p, w.conn.RemoteAddr())
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
continue continue
@ -405,8 +423,3 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
} }
return 0, nil return 0, nil
} }
// UDP/TCP stuff
func (w *reply) closeClient() (err os.Error) {
return w.conn.Close()
}

5
msg.go
View File

@ -365,13 +365,14 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool)
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint8") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint8")
return len(msg), false return len(msg), false
} }
msg[off] = byte(i) msg[off] = byte(fv.Uint())
off++ off++
case reflect.Uint16: case reflect.Uint16:
if off+2 > len(msg) { if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint16") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint16")
return len(msg), false return len(msg), false
} }
i := fv.Uint()
msg[off] = byte(i >> 8) msg[off] = byte(i >> 8)
msg[off+1] = byte(i) msg[off+1] = byte(i)
off += 2 off += 2
@ -380,6 +381,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool)
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint32") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint32")
return len(msg), false return len(msg), false
} }
i := fv.Uint()
msg[off] = byte(i >> 24) msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16) msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8) msg[off+2] = byte(i >> 8)
@ -391,6 +393,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool)
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint64") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint64")
return len(msg), false return len(msg), false
} }
i := fv.Uint()
msg[off] = byte(i >> 40) msg[off] = byte(i >> 40)
msg[off+1] = byte(i >> 32) msg[off+1] = byte(i >> 32)
msg[off+2] = byte(i >> 24) msg[off+2] = byte(i >> 24)