diff --git a/_examples/chaos/chaos.go b/_examples/chaos/chaos.go index 47f5f16d..d1dba1f5 100644 --- a/_examples/chaos/chaos.go +++ b/_examples/chaos/chaos.go @@ -27,12 +27,12 @@ func main() { } for _, a := range addr { m.Question[0] = dns.Question{"version.bind.", dns.TypeTXT, dns.ClassCHAOS} - in := c.Exchange(m, a) + in, _ := c.Exchange(m, a) if in != nil && in.Answer != nil { fmt.Printf("%v\n", in.Answer[0]) } m.Question[0] = dns.Question{"hostname.bind.", dns.TypeTXT, dns.ClassCHAOS} - in = c.Exchange(m, a) + in, _ = c.Exchange(m, a) if in != nil && in.Answer != nil { fmt.Printf("%v\n", in.Answer[0]) } @@ -40,6 +40,8 @@ func main() { } func qhandler(w dns.RequestWriter, m *dns.Msg) { + w.Dial() + defer w.Close() w.Send(m) r, _ := w.Receive() w.Write(r) diff --git a/_examples/funkensturm/config_delay.go b/_examples/funkensturm/config_delay.go index 79f7e032..f03e520a 100644 --- a/_examples/funkensturm/config_delay.go +++ b/_examples/funkensturm/config_delay.go @@ -40,7 +40,7 @@ func delay(m *dns.Msg) (buf []byte) { } println("Ok: let it through") for _, c := range qr { - o = c.Client.Exchange(m, c.Addr) + o, _ = c.Client.Exchange(m, c.Addr) } buf, _ = o.Pack() return diff --git a/_examples/funkensturm/config_rproxy.go b/_examples/funkensturm/config_rproxy.go index 4ba4a212..2bd575a1 100644 --- a/_examples/funkensturm/config_rproxy.go +++ b/_examples/funkensturm/config_rproxy.go @@ -99,7 +99,7 @@ func checkcache(m *dns.Msg) (o []byte) { println("Cache miss") var p *dns.Msg for _, c := range qr { - p = c.Client.Exchange(m, c.Addr) + p, _ = c.Client.Exchange(m, c.Addr) } cache.add(p) o, _ = p.Pack() diff --git a/_examples/funkensturm/config_sign.go b/_examples/funkensturm/config_sign.go index 447f3855..3ef5d40e 100644 --- a/_examples/funkensturm/config_sign.go +++ b/_examples/funkensturm/config_sign.go @@ -38,7 +38,7 @@ func sign(m *dns.Msg) *dns.Msg { func sendsign(m *dns.Msg) (o []byte) { var p *dns.Msg for _, c := range qr { - p = c.Client.Exchange(m, c.Addr) + p, _ = c.Client.Exchange(m, c.Addr) } o, _ = sign(p).Pack() println("signing") @@ -48,7 +48,7 @@ func sendsign(m *dns.Msg) (o []byte) { func send(m *dns.Msg) (o []byte) { var p *dns.Msg for _, c := range qr { - p = c.Client.Exchange(m, c.Addr) + p, _ = c.Client.Exchange(m, c.Addr) } o, _ = p.Pack() return diff --git a/_examples/key2ds/key2ds.go b/_examples/key2ds/key2ds.go index efe7181e..869441ba 100644 --- a/_examples/key2ds/key2ds.go +++ b/_examples/key2ds/key2ds.go @@ -28,7 +28,7 @@ func main() { m.Extra = append(m.Extra, e) c := dns.NewClient() - r := c.Exchange(m, conf.Servers[0]) + r, _ := c.Exchange(m, conf.Servers[0]) if r == nil { fmt.Printf("*** no answer received for %s\n", os.Args[1]) os.Exit(1) diff --git a/_examples/mx/mx.go b/_examples/mx/mx.go index 410e6d7f..53af98bd 100644 --- a/_examples/mx/mx.go +++ b/_examples/mx/mx.go @@ -22,9 +22,9 @@ func main() { m.MsgHdr.RecursionDesired = true // Simple sync query, nothing fancy - r := c.Exchange(m, config.Servers[0]) - - if r == nil { + r, err := c.Exchange(m, config.Servers[0]) + if err != nil { + fmt.Printf("%s\n", err.String()) os.Exit(1) } diff --git a/_examples/q/q.go b/_examples/q/q.go index f0b82177..0b65f13b 100644 --- a/_examples/q/q.go +++ b/_examples/q/q.go @@ -11,7 +11,10 @@ import ( func q(w dns.RequestWriter, m *dns.Msg) { w.Send(m) - r, _ := w.Receive() + r, err := w.Receive() + if err != nil { + fmt.Printf("%s\n", err.String()) + } w.Write(r) } @@ -130,9 +133,11 @@ forever: select { case r := <-dns.DefaultReplyChan: if r[1] != nil { - if r[0].Id != r[1].Id { - fmt.Printf("Id mismatch\n") - } + if r[1].Rcode == dns.RcodeSuccess { + if r[0].Id != r[1].Id { + fmt.Printf("Id mismatch\n") + } + } if *short { r[1] = shortMsg(r[1]) } diff --git a/client.go b/client.go index dbeab5b7..13bff4ad 100644 --- a/client.go +++ b/client.go @@ -24,6 +24,8 @@ type RequestWriter interface { Write(*Msg) Send(*Msg) os.Error Receive() (*Msg, os.Error) + Close() os.Error + Dial() os.Error } // hijacked connections...? @@ -123,6 +125,7 @@ type Client struct { ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections TsigSecret map[string]string // secret(s) for Tsig map[] + //Conn net.Conn // if set, use this connection, otherwise Dial again TODO // LocalAddr string // Local address to use } @@ -191,44 +194,59 @@ func (c *Client) Do(m *Msg, a string) { // ExchangeBuf performs a synchronous query. It sends the buffer m to the // address (net.Addr?) contained in a -func (c *Client) ExchangeBuffer(inbuf []byte, a string, outbuf []byte) bool { +func (c *Client) ExchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, err os.Error) { w := new(reply) w.client = c w.addr = a - _, err := w.writeClient(inbuf) - defer w.closeClient() // XXX here?? what about TCP which should remain open - if err != nil { - println(err.String()) - return false + if err = w.Dial(); err != nil { + return 0, err + } + defer w.Close() // XXX here?? what about TCP which should remain open + if n, err = w.writeClient(inbuf); err != nil { + return 0, err } - // udp / tcp TODO - n, err := w.readClient(outbuf) - if err != nil { - return false + if n, err = w.readClient(outbuf); err != nil { + return n, err } - outbuf = outbuf[:n] - return true + return n, nil } // Exchange performs an synchronous query. It sends the message m to the address // contained in a and waits for an reply. -func (c *Client) Exchange(m *Msg, a string) *Msg { +func (c *Client) Exchange(m *Msg, a string) (r *Msg, err os.Error) { + var n int out, ok := m.Pack() if !ok { panic("failed to pack message") } - in := make([]byte, DefaultMsgSize) - if ok := c.ExchangeBuffer(out, a, in); !ok { - return nil - } - r := new(Msg) - if ok := r.Unpack(in); !ok { - return nil + in := make([]byte, DefaultMsgSize) + if n, err = c.ExchangeBuffer(out, a, in); err != nil { + return nil, err } - return r + r = new(Msg) + if ok := r.Unpack(in[:n]); !ok { + return nil, ErrUnpack + } + return r, nil } +// Dial connects to the address addr for the networks c.Net +func (w *reply) Dial() os.Error { + conn, err := net.Dial(w.Client().Net, w.addr) + if err != nil { + return err + } + w.conn = conn + return nil +} + +// UDP/TCP stuff big TODO +func (w *reply) Close() (err os.Error) { + return w.conn.Close() +} + + func (w *reply) WriteMessages(m []*Msg) { m1 := append([]*Msg{w.req}, m...) w.Client().ChannelReply <- m1 @@ -347,12 +365,12 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) { if w.Client().Net == "" { panic("c.Net empty") } - - conn, err := net.Dial(w.Client().Net, w.addr) - if err != nil { - return 0, err + if w.conn == nil { + // No connection yet, dial it. impl. at this place? TODO + if err = w.Dial(); err != nil { + return 0, err + } } - w.conn = conn switch w.Client().Net { case "tcp", "tcp4", "tcp6": if len(p) < 2 { @@ -361,7 +379,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) { for a := 0; a < w.Client().Attempts; a++ { l := make([]byte, 2) l[0], l[1] = packUint16(uint16(len(p))) - n, err = conn.Write(l) + n, err = w.conn.Write(l) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { continue @@ -371,7 +389,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) { if n != 2 { return n, io.ErrShortWrite } - n, err = conn.Write(p) + n, err = w.conn.Write(p) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { continue @@ -380,7 +398,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) { } i := n if i < len(p) { - j, err := conn.Write(p[i:len(p)]) + j, err := w.conn.Write(p[i:len(p)]) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { // We are half way in our write... @@ -394,7 +412,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) { } case "udp", "udp4", "udp6": for a := 0; a < w.Client().Attempts; a++ { - n, err = conn.(*net.UDPConn).WriteTo(p, conn.RemoteAddr()) + n, err = w.conn.(*net.UDPConn).WriteTo(p, w.conn.RemoteAddr()) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { continue @@ -405,8 +423,3 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) { } return 0, nil } - -// UDP/TCP stuff -func (w *reply) closeClient() (err os.Error) { - return w.conn.Close() -} diff --git a/msg.go b/msg.go index b972ca88..d5b288cf 100644 --- a/msg.go +++ b/msg.go @@ -365,13 +365,14 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) //fmt.Fprintf(os.Stderr, "dns: overflow packing uint8") return len(msg), false } - msg[off] = byte(i) + msg[off] = byte(fv.Uint()) off++ case reflect.Uint16: if off+2 > len(msg) { //fmt.Fprintf(os.Stderr, "dns: overflow packing uint16") return len(msg), false } + i := fv.Uint() msg[off] = byte(i >> 8) msg[off+1] = byte(i) off += 2 @@ -380,6 +381,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) //fmt.Fprintf(os.Stderr, "dns: overflow packing uint32") return len(msg), false } + i := fv.Uint() msg[off] = byte(i >> 24) msg[off+1] = byte(i >> 16) msg[off+2] = byte(i >> 8) @@ -391,6 +393,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) //fmt.Fprintf(os.Stderr, "dns: overflow packing uint64") return len(msg), false } + i := fv.Uint() msg[off] = byte(i >> 40) msg[off+1] = byte(i >> 32) msg[off+2] = byte(i >> 24)