fix the async API

The async concurrent api works.
client.Exchange() is there as a sync
This commit is contained in:
Miek Gieben 2011-04-17 21:56:40 +02:00
parent 9f104d58f9
commit f46069608a

189
client.go
View File

@ -23,12 +23,22 @@ type QueryHandler interface {
type RequestWriter interface { type RequestWriter interface {
WriteMessages([]*Msg) WriteMessages([]*Msg)
Write(*Msg) Write(*Msg)
WriteClient(*Msg) os.Error
ReadClient() (*Msg, os.Error)
} }
// hijacked connections...? // hijacked connections...?
type reply struct { type reply struct {
Client *Client client *Client
addr string
req *Msg req *Msg
conn net.Conn
}
type Request struct {
Request *Msg
Addr string
Client *Client
} }
// QueryMux is an DNS request multiplexer. It matches the // QueryMux is an DNS request multiplexer. It matches the
@ -46,7 +56,7 @@ func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} }
var DefaultQueryMux = NewQueryMux() var DefaultQueryMux = NewQueryMux()
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) } func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
func newQueryChan() chan *Msg { return make(chan *Msg) } func newQueryChan() chan *Request { return make(chan *Request) }
// Default channel to use for the resolver // Default channel to use for the resolver
var DefaultReplyChan = newQueryChanSlice() var DefaultReplyChan = newQueryChanSlice()
@ -67,9 +77,6 @@ func HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) {
DefaultQueryMux.HandleQueryFunc(pattern, handler) DefaultQueryMux.HandleQueryFunc(pattern, handler)
} }
// Helper handlers
// Todo
// reusing zoneMatch from server.go // reusing zoneMatch from server.go
func (mux *QueryMux) match(zone string) QueryHandler { func (mux *QueryMux) match(zone string) QueryHandler {
var h QueryHandler var h QueryHandler
@ -101,13 +108,13 @@ func (mux *QueryMux) HandleQueryFunc(pattern string, handler func(RequestWriter,
mux.Handle(pattern, HandlerQueryFunc(handler)) mux.Handle(pattern, HandlerQueryFunc(handler))
} }
func (mux *QueryMux) QueryDNS(w RequestWriter, request *Msg) { func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
h := mux.match(request.Question[0].Name) h := mux.match(r.Question[0].Name)
if h == nil { if h == nil {
// h = RefusedHandler() // h = RefusedHandler()
// something else // something else
} }
h.QueryDNS(w, request) h.QueryDNS(w, r)
} }
type Client struct { type Client struct {
@ -115,93 +122,118 @@ type Client struct {
Addr string // address to call Addr string // address to call
Attempts int // number of attempts Attempts int // number of attempts
Retry bool // retry with TCP Retry bool // retry with TCP
ChannelQuery chan *Msg // read DNS request from this channel ChannelQuery chan *Request // read DNS request from this channel
ChannelReply chan []*Msg // read DNS request from this channel ChannelReply chan []*Msg // read DNS request from this channel
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
conn net.Conn // current net work connection
} }
// Query accepts incoming DNS request, func NewClient() *Client {
// Write to in c := new(Client)
// creating a new service thread for each. The service threads c.Net = "udp"
// read requests and then call handler to reply to them. c.Attempts = 1
// Handler is typically nil, in which case the DefaultServeMux is used. c.ChannelReply = DefaultReplyChan
func Query(c chan *Msg, handler QueryHandler) os.Error { return c
client := &Client{ChannelQuery: c, Handler: handler}
return client.Query()
} }
func (c *Client) Query() os.Error { type Query struct {
handler := c.Handler ChannelQuery chan *Request // read DNS request from this channel
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
}
func (q *Query) Query() os.Error {
handler := q.Handler
if handler == nil { if handler == nil {
handler = DefaultQueryMux handler = DefaultQueryMux
} }
forever: forever:
for { for {
select { select {
case in := <-c.ChannelQuery: case in := <-q.ChannelQuery:
w := new(reply) w := new(reply)
w.Client = c w.req = in.Request
w.req = in w.addr = in.Addr
handler.QueryDNS(w, w.req) w.client = in.Client
handler.QueryDNS(w, in.Request)
} }
} }
return nil return nil
} }
func (c *Client) ListenAndQuery() os.Error { func (q *Query) ListenAndQuery() os.Error {
if c.ChannelQuery == nil { if q.ChannelQuery == nil {
c.ChannelQuery = DefaultQueryChan q.ChannelQuery = DefaultQueryChan
} }
if c.ChannelReply == nil { return q.Query()
c.ChannelReply = DefaultReplyChan
}
return c.Query()
} }
func (c *Client) Do(m *Msg, addr string) { func ListenAndQuery(c chan *Request, handler QueryHandler) {
if c.ChannelQuery == nil { q := &Query{ChannelQuery: c, Handler: handler}
DefaultQueryChan <- m go q.ListenAndQuery()
}
if c.Net == "" {
c.Net = "udp"
}
if c.Attempts == 0 {
c.Attempts = 1
}
c.Addr = addr
}
func ListenAndQuery(c chan *Msg, handler QueryHandler) {
client := &Client{ChannelQuery: c, Handler: handler}
go client.ListenAndQuery()
} }
func (w *reply) Write(m *Msg) { func (w *reply) Write(m *Msg) {
// Write to the channel w.Client().ChannelReply <- []*Msg{w.req, m}
w.Client.ChannelReply <- []*Msg{w.req, m} }
// async querie
func (c *Client) Do(m *Msg, a string) {
if c.ChannelQuery == nil {
DefaultQueryChan <- &Request{Client: c, Addr: a, Request: m}
} else {
c.ChannelQuery <- &Request{Client: c, Addr: a, Request: m}
}
}
// A synchronize query
func (c *Client) Exchange(m *Msg, a string) *Msg {
w := new(reply)
w.client = c
w.addr = a
out, ok := m.Pack()
if !ok {
//
}
_, err := w.writeClient(out)
if err != nil {
return nil
}
// udp / tcp
p := make([]byte, DefaultMsgSize)
n, err := w.readClient(p)
if err != nil {
return nil
}
p = p[:n]
if ok := m.Unpack(p); !ok {
return nil
}
return m
} }
func (w *reply) WriteMessages(m []*Msg) { func (w *reply) WriteMessages(m []*Msg) {
// Write to the channel m1 := append([]*Msg{w.req}, m...)
m1 := append([]*Msg{w.req}, m...) // Really the way? w.Client().ChannelReply <- m1
w.Client.ChannelReply <- m1
} }
func (c *Client) Read() (*Msg, os.Error) { func (w *reply) Client() *Client {
if c.conn == nil { return w.client
panic("no connection") }
}
var p []byte
var m *Msg
switch c.Net {
case "tcp":
func (w *reply) Request() *Msg {
return w.req
}
func (w *reply) ReadClient() (*Msg, os.Error) {
var p []byte
m := new(Msg)
switch w.Client().Net {
case "tcp":
//
case "udp": case "udp":
p = make([]byte, DefaultMsgSize) p = make([]byte, DefaultMsgSize)
n, err := c.read(p) n, err := w.readClient(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -213,12 +245,15 @@ func (c *Client) Read() (*Msg, os.Error) {
return m, nil return m, nil
} }
func (c *Client) read(p []byte) (n int, err os.Error) { func (w *reply) readClient(p []byte) (n int, err os.Error) {
switch c.Net { if w.conn == nil {
panic("no connection")
}
switch w.Client().Net {
case "tcp": case "tcp":
//
case "udp": case "udp":
n, _, err = c.conn.(*net.UDPConn).ReadFromUDP(p) n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
if err != nil { if err != nil {
return n, err return n, err
} }
@ -226,28 +261,32 @@ func (c *Client) read(p []byte) (n int, err os.Error) {
return return
} }
func (c *Client) Write(m *Msg) os.Error { func (w *reply) WriteClient(m *Msg) os.Error {
out, ok := m.Pack() out, ok := m.Pack()
if !ok { if !ok {
return ErrPack return ErrPack
} }
_, err := c.write(out) _, err := w.writeClient(out)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
// Fill Client.Conn with the connection func (w *reply) writeClient(p []byte) (n int, err os.Error) {
func (c *Client) write(p []byte) (n int, err os.Error) { c := w.Client()
conn, err := net.Dial(c.Net, "", c.Addr) if c.Attempts == 0 {
panic("c.Attempts 0")
}
if c.Net == "" {
panic("c.Net empty")
}
conn, err := net.Dial(c.Net, "", w.addr)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if c.Attempts == 0 { w.conn = conn
panic("client attempts 0")
}
c.conn = conn
switch c.Net { switch c.Net {
case "tcp": case "tcp":
if len(p) < 2 { if len(p) < 2 {