mirror of
https://github.com/miekg/dns.git
synced 2025-12-17 01:31:00 +01:00
fix the async API
The async concurrent api works. client.Exchange() is there as a sync
This commit is contained in:
parent
9f104d58f9
commit
f46069608a
185
client.go
185
client.go
@ -23,12 +23,22 @@ type QueryHandler interface {
|
|||||||
type RequestWriter interface {
|
type RequestWriter interface {
|
||||||
WriteMessages([]*Msg)
|
WriteMessages([]*Msg)
|
||||||
Write(*Msg)
|
Write(*Msg)
|
||||||
|
WriteClient(*Msg) os.Error
|
||||||
|
ReadClient() (*Msg, os.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// hijacked connections...?
|
// hijacked connections...?
|
||||||
type reply struct {
|
type reply struct {
|
||||||
Client *Client
|
client *Client
|
||||||
|
addr string
|
||||||
req *Msg
|
req *Msg
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
Request *Msg
|
||||||
|
Addr string
|
||||||
|
Client *Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryMux is an DNS request multiplexer. It matches the
|
// QueryMux is an DNS request multiplexer. It matches the
|
||||||
@ -46,7 +56,7 @@ func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} }
|
|||||||
var DefaultQueryMux = NewQueryMux()
|
var DefaultQueryMux = NewQueryMux()
|
||||||
|
|
||||||
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
|
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
|
||||||
func newQueryChan() chan *Msg { return make(chan *Msg) }
|
func newQueryChan() chan *Request { return make(chan *Request) }
|
||||||
|
|
||||||
// Default channel to use for the resolver
|
// Default channel to use for the resolver
|
||||||
var DefaultReplyChan = newQueryChanSlice()
|
var DefaultReplyChan = newQueryChanSlice()
|
||||||
@ -67,9 +77,6 @@ func HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) {
|
|||||||
DefaultQueryMux.HandleQueryFunc(pattern, handler)
|
DefaultQueryMux.HandleQueryFunc(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper handlers
|
|
||||||
// Todo
|
|
||||||
|
|
||||||
// reusing zoneMatch from server.go
|
// reusing zoneMatch from server.go
|
||||||
func (mux *QueryMux) match(zone string) QueryHandler {
|
func (mux *QueryMux) match(zone string) QueryHandler {
|
||||||
var h QueryHandler
|
var h QueryHandler
|
||||||
@ -101,13 +108,13 @@ func (mux *QueryMux) HandleQueryFunc(pattern string, handler func(RequestWriter,
|
|||||||
mux.Handle(pattern, HandlerQueryFunc(handler))
|
mux.Handle(pattern, HandlerQueryFunc(handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mux *QueryMux) QueryDNS(w RequestWriter, request *Msg) {
|
func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
|
||||||
h := mux.match(request.Question[0].Name)
|
h := mux.match(r.Question[0].Name)
|
||||||
if h == nil {
|
if h == nil {
|
||||||
// h = RefusedHandler()
|
// h = RefusedHandler()
|
||||||
// something else
|
// something else
|
||||||
}
|
}
|
||||||
h.QueryDNS(w, request)
|
h.QueryDNS(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@ -115,93 +122,118 @@ type Client struct {
|
|||||||
Addr string // address to call
|
Addr string // address to call
|
||||||
Attempts int // number of attempts
|
Attempts int // number of attempts
|
||||||
Retry bool // retry with TCP
|
Retry bool // retry with TCP
|
||||||
ChannelQuery chan *Msg // read DNS request from this channel
|
ChannelQuery chan *Request // read DNS request from this channel
|
||||||
ChannelReply chan []*Msg // read DNS request from this channel
|
ChannelReply chan []*Msg // read DNS request from this channel
|
||||||
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
|
|
||||||
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
|
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
|
||||||
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
|
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
|
||||||
conn net.Conn // current net work connection
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query accepts incoming DNS request,
|
func NewClient() *Client {
|
||||||
// Write to in
|
c := new(Client)
|
||||||
// creating a new service thread for each. The service threads
|
c.Net = "udp"
|
||||||
// read requests and then call handler to reply to them.
|
c.Attempts = 1
|
||||||
// Handler is typically nil, in which case the DefaultServeMux is used.
|
c.ChannelReply = DefaultReplyChan
|
||||||
func Query(c chan *Msg, handler QueryHandler) os.Error {
|
return c
|
||||||
client := &Client{ChannelQuery: c, Handler: handler}
|
|
||||||
return client.Query()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Query() os.Error {
|
type Query struct {
|
||||||
handler := c.Handler
|
ChannelQuery chan *Request // read DNS request from this channel
|
||||||
|
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) Query() os.Error {
|
||||||
|
handler := q.Handler
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = DefaultQueryMux
|
handler = DefaultQueryMux
|
||||||
}
|
}
|
||||||
forever:
|
forever:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case in := <-c.ChannelQuery:
|
case in := <-q.ChannelQuery:
|
||||||
w := new(reply)
|
w := new(reply)
|
||||||
w.Client = c
|
w.req = in.Request
|
||||||
w.req = in
|
w.addr = in.Addr
|
||||||
handler.QueryDNS(w, w.req)
|
w.client = in.Client
|
||||||
|
handler.QueryDNS(w, in.Request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListenAndQuery() os.Error {
|
func (q *Query) ListenAndQuery() os.Error {
|
||||||
if c.ChannelQuery == nil {
|
if q.ChannelQuery == nil {
|
||||||
c.ChannelQuery = DefaultQueryChan
|
q.ChannelQuery = DefaultQueryChan
|
||||||
}
|
}
|
||||||
if c.ChannelReply == nil {
|
return q.Query()
|
||||||
c.ChannelReply = DefaultReplyChan
|
|
||||||
}
|
|
||||||
return c.Query()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Do(m *Msg, addr string) {
|
func ListenAndQuery(c chan *Request, handler QueryHandler) {
|
||||||
if c.ChannelQuery == nil {
|
q := &Query{ChannelQuery: c, Handler: handler}
|
||||||
DefaultQueryChan <- m
|
go q.ListenAndQuery()
|
||||||
}
|
|
||||||
if c.Net == "" {
|
|
||||||
c.Net = "udp"
|
|
||||||
}
|
|
||||||
if c.Attempts == 0 {
|
|
||||||
c.Attempts = 1
|
|
||||||
}
|
|
||||||
c.Addr = addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListenAndQuery(c chan *Msg, handler QueryHandler) {
|
|
||||||
client := &Client{ChannelQuery: c, Handler: handler}
|
|
||||||
go client.ListenAndQuery()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *reply) Write(m *Msg) {
|
func (w *reply) Write(m *Msg) {
|
||||||
// Write to the channel
|
w.Client().ChannelReply <- []*Msg{w.req, m}
|
||||||
w.Client.ChannelReply <- []*Msg{w.req, m}
|
}
|
||||||
|
|
||||||
|
// async querie
|
||||||
|
func (c *Client) Do(m *Msg, a string) {
|
||||||
|
if c.ChannelQuery == nil {
|
||||||
|
DefaultQueryChan <- &Request{Client: c, Addr: a, Request: m}
|
||||||
|
} else {
|
||||||
|
c.ChannelQuery <- &Request{Client: c, Addr: a, Request: m}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A synchronize query
|
||||||
|
func (c *Client) Exchange(m *Msg, a string) *Msg {
|
||||||
|
w := new(reply)
|
||||||
|
w.client = c
|
||||||
|
w.addr = a
|
||||||
|
out, ok := m.Pack()
|
||||||
|
if !ok {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
_, err := w.writeClient(out)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// udp / tcp
|
||||||
|
p := make([]byte, DefaultMsgSize)
|
||||||
|
n, err := w.readClient(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p = p[:n]
|
||||||
|
if ok := m.Unpack(p); !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *reply) WriteMessages(m []*Msg) {
|
func (w *reply) WriteMessages(m []*Msg) {
|
||||||
// Write to the channel
|
m1 := append([]*Msg{w.req}, m...)
|
||||||
m1 := append([]*Msg{w.req}, m...) // Really the way?
|
w.Client().ChannelReply <- m1
|
||||||
w.Client.ChannelReply <- m1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Read() (*Msg, os.Error) {
|
func (w *reply) Client() *Client {
|
||||||
if c.conn == nil {
|
return w.client
|
||||||
panic("no connection")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *reply) Request() *Msg {
|
||||||
|
return w.req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *reply) ReadClient() (*Msg, os.Error) {
|
||||||
var p []byte
|
var p []byte
|
||||||
var m *Msg
|
m := new(Msg)
|
||||||
switch c.Net {
|
switch w.Client().Net {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
|
//
|
||||||
case "udp":
|
case "udp":
|
||||||
p = make([]byte, DefaultMsgSize)
|
p = make([]byte, DefaultMsgSize)
|
||||||
n, err := c.read(p)
|
n, err := w.readClient(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -213,12 +245,15 @@ func (c *Client) Read() (*Msg, os.Error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) read(p []byte) (n int, err os.Error) {
|
func (w *reply) readClient(p []byte) (n int, err os.Error) {
|
||||||
switch c.Net {
|
if w.conn == nil {
|
||||||
|
panic("no connection")
|
||||||
|
}
|
||||||
|
switch w.Client().Net {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
|
//
|
||||||
case "udp":
|
case "udp":
|
||||||
n, _, err = c.conn.(*net.UDPConn).ReadFromUDP(p)
|
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
@ -226,28 +261,32 @@ func (c *Client) read(p []byte) (n int, err os.Error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Write(m *Msg) os.Error {
|
func (w *reply) WriteClient(m *Msg) os.Error {
|
||||||
out, ok := m.Pack()
|
out, ok := m.Pack()
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrPack
|
return ErrPack
|
||||||
}
|
}
|
||||||
_, err := c.write(out)
|
_, err := w.writeClient(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill Client.Conn with the connection
|
func (w *reply) writeClient(p []byte) (n int, err os.Error) {
|
||||||
func (c *Client) write(p []byte) (n int, err os.Error) {
|
c := w.Client()
|
||||||
conn, err := net.Dial(c.Net, "", c.Addr)
|
if c.Attempts == 0 {
|
||||||
|
panic("c.Attempts 0")
|
||||||
|
}
|
||||||
|
if c.Net == "" {
|
||||||
|
panic("c.Net empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial(c.Net, "", w.addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if c.Attempts == 0 {
|
w.conn = conn
|
||||||
panic("client attempts 0")
|
|
||||||
}
|
|
||||||
c.conn = conn
|
|
||||||
switch c.Net {
|
switch c.Net {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
if len(p) < 2 {
|
if len(p) < 2 {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user