diff --git a/client.go b/client.go index aab0886e..9677c9d5 100644 --- a/client.go +++ b/client.go @@ -23,12 +23,22 @@ type QueryHandler interface { type RequestWriter interface { WriteMessages([]*Msg) Write(*Msg) + WriteClient(*Msg) os.Error + ReadClient() (*Msg, os.Error) } // hijacked connections...? type reply struct { - Client *Client + client *Client + addr string req *Msg + conn net.Conn +} + +type Request struct { + Request *Msg + Addr string + Client *Client } // QueryMux is an DNS request multiplexer. It matches the @@ -46,7 +56,7 @@ func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} } var DefaultQueryMux = NewQueryMux() func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) } -func newQueryChan() chan *Msg { return make(chan *Msg) } +func newQueryChan() chan *Request { return make(chan *Request) } // Default channel to use for the resolver var DefaultReplyChan = newQueryChanSlice() @@ -67,9 +77,6 @@ func HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) { DefaultQueryMux.HandleQueryFunc(pattern, handler) } -// Helper handlers -// Todo - // reusing zoneMatch from server.go func (mux *QueryMux) match(zone string) QueryHandler { var h QueryHandler @@ -101,124 +108,152 @@ func (mux *QueryMux) HandleQueryFunc(pattern string, handler func(RequestWriter, mux.Handle(pattern, HandlerQueryFunc(handler)) } -func (mux *QueryMux) QueryDNS(w RequestWriter, request *Msg) { - h := mux.match(request.Question[0].Name) +func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) { + h := mux.match(r.Question[0].Name) if h == nil { // h = RefusedHandler() // something else } - h.QueryDNS(w, request) + h.QueryDNS(w, r) } type Client struct { - Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one - Addr string // address to call - Attempts int // number of attempts - Retry bool // retry with TCP - ChannelQuery chan *Msg // read DNS request from this channel - ChannelReply chan []*Msg // read DNS request from this channel - Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil - ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections - WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections - conn net.Conn // current net work connection + Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one + Addr string // address to call + Attempts int // number of attempts + Retry bool // retry with TCP + ChannelQuery chan *Request // read DNS request from this channel + ChannelReply chan []*Msg // read DNS request from this channel + ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections + WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections } -// Query accepts incoming DNS request, -// Write to in -// creating a new service thread for each. The service threads -// read requests and then call handler to reply to them. -// Handler is typically nil, in which case the DefaultServeMux is used. -func Query(c chan *Msg, handler QueryHandler) os.Error { - client := &Client{ChannelQuery: c, Handler: handler} - return client.Query() +func NewClient() *Client { + c := new(Client) + c.Net = "udp" + c.Attempts = 1 + c.ChannelReply = DefaultReplyChan + return c } -func (c *Client) Query() os.Error { - handler := c.Handler +type Query struct { + ChannelQuery chan *Request // read DNS request from this channel + Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil +} + +func (q *Query) Query() os.Error { + handler := q.Handler if handler == nil { handler = DefaultQueryMux } forever: for { select { - case in := <-c.ChannelQuery: + case in := <-q.ChannelQuery: w := new(reply) - w.Client = c - w.req = in - handler.QueryDNS(w, w.req) + w.req = in.Request + w.addr = in.Addr + w.client = in.Client + handler.QueryDNS(w, in.Request) } } return nil } -func (c *Client) ListenAndQuery() os.Error { - if c.ChannelQuery == nil { - c.ChannelQuery = DefaultQueryChan +func (q *Query) ListenAndQuery() os.Error { + if q.ChannelQuery == nil { + q.ChannelQuery = DefaultQueryChan } - if c.ChannelReply == nil { - c.ChannelReply = DefaultReplyChan - } - return c.Query() + return q.Query() } -func (c *Client) Do(m *Msg, addr string) { - if c.ChannelQuery == nil { - DefaultQueryChan <- m - } - if c.Net == "" { - c.Net = "udp" - } - if c.Attempts == 0 { - c.Attempts = 1 - } - c.Addr = addr -} - -func ListenAndQuery(c chan *Msg, handler QueryHandler) { - client := &Client{ChannelQuery: c, Handler: handler} - go client.ListenAndQuery() +func ListenAndQuery(c chan *Request, handler QueryHandler) { + q := &Query{ChannelQuery: c, Handler: handler} + go q.ListenAndQuery() } func (w *reply) Write(m *Msg) { - // Write to the channel - w.Client.ChannelReply <- []*Msg{w.req, m} + w.Client().ChannelReply <- []*Msg{w.req, m} +} + +// async querie +func (c *Client) Do(m *Msg, a string) { + if c.ChannelQuery == nil { + DefaultQueryChan <- &Request{Client: c, Addr: a, Request: m} + } else { + c.ChannelQuery <- &Request{Client: c, Addr: a, Request: m} + } +} + +// A synchronize query +func (c *Client) Exchange(m *Msg, a string) *Msg { + w := new(reply) + w.client = c + w.addr = a + out, ok := m.Pack() + if !ok { + // + } + _, err := w.writeClient(out) + if err != nil { + return nil + } + + // udp / tcp + p := make([]byte, DefaultMsgSize) + n, err := w.readClient(p) + if err != nil { + return nil + } + p = p[:n] + if ok := m.Unpack(p); !ok { + return nil + } + return m } func (w *reply) WriteMessages(m []*Msg) { - // Write to the channel - m1 := append([]*Msg{w.req}, m...) // Really the way? - w.Client.ChannelReply <- m1 + m1 := append([]*Msg{w.req}, m...) + w.Client().ChannelReply <- m1 } -func (c *Client) Read() (*Msg, os.Error) { - if c.conn == nil { - panic("no connection") - } - var p []byte - var m *Msg - switch c.Net { - case "tcp": - - case "udp": - p = make([]byte, DefaultMsgSize) - n, err := c.read(p) - if err != nil { - return nil, err - } - p = p[:n] - if ok := m.Unpack(p); !ok { - return nil, ErrUnpack - } - } - return m, nil +func (w *reply) Client() *Client { + return w.client } -func (c *Client) read(p []byte) (n int, err os.Error) { - switch c.Net { +func (w *reply) Request() *Msg { + return w.req +} + +func (w *reply) ReadClient() (*Msg, os.Error) { + var p []byte + m := new(Msg) + switch w.Client().Net { case "tcp": - + // case "udp": - n, _, err = c.conn.(*net.UDPConn).ReadFromUDP(p) + p = make([]byte, DefaultMsgSize) + n, err := w.readClient(p) + if err != nil { + return nil, err + } + p = p[:n] + if ok := m.Unpack(p); !ok { + return nil, ErrUnpack + } + } + return m, nil +} + +func (w *reply) readClient(p []byte) (n int, err os.Error) { + if w.conn == nil { + panic("no connection") + } + switch w.Client().Net { + case "tcp": + // + case "udp": + n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) if err != nil { return n, err } @@ -226,28 +261,32 @@ func (c *Client) read(p []byte) (n int, err os.Error) { return } -func (c *Client) Write(m *Msg) os.Error { +func (w *reply) WriteClient(m *Msg) os.Error { out, ok := m.Pack() if !ok { return ErrPack } - _, err := c.write(out) + _, err := w.writeClient(out) if err != nil { return err } return nil } -// Fill Client.Conn with the connection -func (c *Client) write(p []byte) (n int, err os.Error) { - conn, err := net.Dial(c.Net, "", c.Addr) +func (w *reply) writeClient(p []byte) (n int, err os.Error) { + c := w.Client() + if c.Attempts == 0 { + panic("c.Attempts 0") + } + if c.Net == "" { + panic("c.Net empty") + } + + conn, err := net.Dial(c.Net, "", w.addr) if err != nil { return 0, err } - if c.Attempts == 0 { - panic("client attempts 0") - } - c.conn = conn + w.conn = conn switch c.Net { case "tcp": if len(p) < 2 {