From 241d44137111c9ddd1d41d449bbf0a5423e3eb5c Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Wed, 9 Feb 2011 17:59:06 +0100 Subject: [PATCH] fix the multiplexing --- msg.go | 1 - resolver.go | 4 ++ server.go | 182 ++++++++++++++++++++++++++++++------------------- server_test.go | 2 +- 4 files changed, 118 insertions(+), 71 deletions(-) diff --git a/msg.go b/msg.go index b5e15a8d..dd8907df 100644 --- a/msg.go +++ b/msg.go @@ -52,7 +52,6 @@ type Msg struct { Answer []RR Ns []RR Extra []RR -// Error os.Error // Do I want this?? } // Map of strings for each RR wire type. diff --git a/resolver.go b/resolver.go index d6542fe1..e9ce3cf3 100644 --- a/resolver.go +++ b/resolver.go @@ -53,7 +53,11 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) { in *Msg port string ) + if len(res.Servers) == 0 { + return nil, &Error{Error: "No servers defined"} + } // len(res.Server) == 0 can be perfectly valid, when setting up the resolver + // It is now if res.Port == "" { port = "53" } else { diff --git a/server.go b/server.go index 4810dbef..dcc018e1 100644 --- a/server.go +++ b/server.go @@ -27,9 +27,14 @@ import ( "net" ) -type Server struct { - // timeout and other stuff - Timeout int +// Wrap request in this struct +type Request struct { + Tcp bool // True for tcp, false for udp + Buf []byte // The received message + Addr net.Addr // Remote site + UDPConn *net.UDPConn // Connection for UDP + TCPConn *net.TCPConn // Connection for TCP + Error os.Error // Any errors that are found } // Every nameserver implements the Hander interface. It defines @@ -44,89 +49,128 @@ type Handler interface { // a TCP response. A TCP connection does need to // know explicitly be told the remote address. ServeTCP() must // take care of sending back a response to the requestor. - ReplyTCP(c *net.TCPConn, in []byte) + ReplyTCP(c *net.TCPConn, a net.Addr, in []byte) } -func ServeUDP(l *net.UDPConn, handler Handler) os.Error { - if handler == nil { - // handler == DefaultServer - } +func accepterUDP(l *net.UDPConn, ch chan *Request, quit chan bool) { for { - m := make([]byte, DefaultMsgSize) // TODO(mg) out of this loop? - n, radd, err := l.ReadFromUDP(m) - if err != nil { - return err - } - m = m[:n] - go handler.ReplyUDP(l, radd, m) - } - panic("not reached") -} - -func ServeTCP(l *net.TCPListener, handler Handler) os.Error { - if handler == nil { - // handler = DefaultServer - } - for { - b := make([]byte, 2) // receiver length - c, err := l.AcceptTCP() - if err != nil { - return err - } - - n, cerr := c.Read(b) - if cerr != nil { - return cerr - } - length := uint16(b[0])<<8 | uint16(b[1]) - if length == 0 { - return &Error{Error: "received nil msg length"} - } - m := make([]byte, length) - - n, cerr = c.Read(m) - if cerr != nil { - return cerr - } - i := n - if i < int(length) { - n, err = c.Read(m[i:]) + select { + case <-quit: + return + default: + r := new(Request) + r.Tcp = false + m := make([]byte, DefaultMsgSize) + n, radd, err := l.ReadFromUDP(m) if err != nil { - return err + r.Error = err + ch <- r + continue } - i += n + m = m[:n] + r.Buf = m + r.Addr = radd + r.UDPConn = l + ch <- r } - go handler.ReplyTCP(c, m) - } - panic("not reached") + } + panic("not reached") } -func ListenAndServeTCP(addr string, handler Handler) os.Error { - ta, err := net.ResolveTCPAddr(addr) - if err != nil { - return err - } - l, err := net.ListenTCP("tcp", ta) - if err != nil { - return err - } - err = ServeTCP(l, handler) - l.Close() - return err +func accepterTCP(l *net.TCPListener, ch chan *Request, quit chan bool) { + b := make([]byte, 2) + for { + select { + case <-quit: + return + default: + r := new(Request) + r.Tcp = true + c, err := l.AcceptTCP() + if err != nil { + r.Error = err + ch <- r + continue + } + n, cerr := c.Read(b) + if cerr != nil { + r.Error = cerr + ch <- r + continue + } + + length := uint16(b[0])<<8 | uint16(b[1]) + if length == 0 { + r.Error = &Error{Error: "received nil msg length"} + ch <- r + } + m := make([]byte, length) + + n, cerr = c.Read(m) + if cerr != nil { + r.Error = cerr + ch <- r + continue + } + i := n + if i < int(length) { + n, err = c.Read(m[i:]) + if err != nil { + r.Error = err + ch <- r + } + i += n + } + r.Buf = m + r.Addr = c.RemoteAddr() + r.TCPConn = c + ch <- r + } + } + panic("not reached") } -func ListenAndServeUDP(addr string, handler Handler) os.Error { +func ListenAndServe(addr string, handler Handler, q chan bool) os.Error { + ta, err := net.ResolveTCPAddr(addr) + if err != nil { + return err + } + lt, err := net.ListenTCP("tcp", ta) + if err != nil { + return err + } + ua, err := net.ResolveUDPAddr(addr) if err != nil { return err } - l, err := net.ListenUDP("udp", ua) + lu, err := net.ListenUDP("udp", ua) if err != nil { return err } - err = ServeUDP(l, handler) - l.Close() - return err + + rc := make(chan *Request) + qt := make(chan bool) + qu := make(chan bool) + go accepterTCP(lt, rc, qt) + go accepterUDP(lu, rc, qu) + + select { + case <-q: + /* quit received, lets stop */ + lt.Close() + lu.Close() + qt <- true + qu <- true + case r:=<-rc: + /* request recieved */ + if r.Tcp { + go handler.ReplyTCP(r.TCPConn, r.Addr, r.Buf) + } else { + go handler.ReplyUDP(r.UDPConn, r.Addr, r.Buf) + } + } + return err } // Send a buffer on the TCP connection. diff --git a/server_test.go b/server_test.go index a0af6bab..586abc9d 100644 --- a/server_test.go +++ b/server_test.go @@ -58,7 +58,7 @@ func TestResponder(t *testing.T) { var h *server go ListenAndServeTCP("127.0.0.1:8053", h) go ListenAndServeUDP("127.0.0.1:8053", h) - time.Sleep(20 * 1e9) + time.Sleep(1 * 1e9) } /*