diff --git a/server.go b/server.go index 41725222..7d3be2ac 100644 --- a/server.go +++ b/server.go @@ -304,16 +304,17 @@ func (srv *Server) ActivateAndServe() error { // ActivateAndServe will return. All in progress queries are completed before the server // is taken down. If the Shutdown was not succesful an error is returned. func (srv *Server) Shutdown() error { - // Client sends fake request here to not wait for timeout in readUDP/readTCP loop - // and trap to stop event ASAP. - net, addr := srv.Net, srv.Addr + var net, addr string - if srv.Listener != nil { + switch { + case srv.Listener != nil: a := srv.Listener.Addr() net, addr = a.Network(), a.String() - } else if srv.PacketConn != nil { + case srv.PacketConn != nil: a := srv.PacketConn.LocalAddr() net, addr = a.Network(), a.String() + default: + net, addr = srv.Net, srv.Addr } switch net { @@ -322,8 +323,11 @@ func (srv *Server) Shutdown() error { case "udp", "udp4", "udp6": go func() { srv.stopUDP <- true }() } + + // Send packet to server socket in order to force readUDP or readTCP to finish waiting for data. + // TODO(asergeyev): Alternative concurrent watchdog is possible to create in "serve*" in future c := &Client{Net: net} - c.Exchange(new(Msg), addr) + go c.Exchange(new(Msg), addr) return nil } @@ -342,13 +346,6 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { } for { rw, e := l.AcceptTCP() - select { - case <-srv.stopTCP: - // Asked to shutdown - srv.wgTCP.Wait() - return nil - default: - } if e != nil { continue } @@ -358,6 +355,13 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { } srv.wgTCP.Add(1) go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) + select { + case <-srv.stopTCP: + // Asked to shutdown + srv.wgTCP.Wait() + return nil + default: + } } panic("dns: not reached") } @@ -377,6 +381,11 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { // deadline is not used here for { m, s, e := srv.readUDP(l, rtimeout) + if e != nil { + continue + } + srv.wgUDP.Add(1) + go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) select { case <-srv.stopUDP: // Asked to shutdown @@ -384,11 +393,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { return nil default: } - if e != nil { - continue - } - srv.wgUDP.Add(1) - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) } panic("dns: not reached") } diff --git a/server_test.go b/server_test.go index bbea2f85..e23d9aa3 100644 --- a/server_test.go +++ b/server_test.go @@ -10,7 +10,6 @@ import ( "runtime" "sync" "testing" - "time" ) func HelloServer(w ResponseWriter, req *Msg) { @@ -31,8 +30,8 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) { w.WriteMsg(m) } -func RunLocalUDPServer() (*Server, string, error) { - pc, err := net.ListenPacket("udp", "127.0.0.1:0") +func RunLocalUDPServer(laddr string) (*Server, string, error) { + pc, err := net.ListenPacket("udp", laddr) if err != nil { return nil, "", err } @@ -44,13 +43,26 @@ func RunLocalUDPServer() (*Server, string, error) { return server, pc.LocalAddr().String(), nil } +func RunLocalTCPServer(laddr string) (*Server, string, error) { + l, err := net.Listen("tcp", laddr) + if err != nil { + return nil, "", err + } + server := &Server{Listener: l} + go func() { + server.ActivateAndServe() + l.Close() + }() + return server, l.Addr().String(), nil +} + func TestServing(t *testing.T) { HandleFunc("miek.nl.", HelloServer) HandleFunc("example.com.", AnotherHelloServer) - s, addrstr, err := RunLocalUDPServer() + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") if err != nil { - t.Fatalf("Unable to run test server on port 8053: %s", err) + t.Fatalf("Unable to run test server: %s", err) } defer s.Shutdown() @@ -98,16 +110,20 @@ func BenchmarkServe(b *testing.B) { b.StopTimer() HandleFunc("miek.nl.", HelloServer) a := runtime.GOMAXPROCS(4) - go func() { - ListenAndServe("127.0.0.1:8053", "udp", nil) - }() + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + b.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + c := new(Client) m := new(Msg) m.SetQuestion("miek.nl", TypeSOA) b.StartTimer() for i := 0; i < b.N; i++ { - c.Exchange(m, "127.0.0.1:8053") + c.Exchange(m, addrstr) } runtime.GOMAXPROCS(a) } @@ -116,16 +132,19 @@ func benchmarkServe6(b *testing.B) { b.StopTimer() HandleFunc("miek.nl.", HelloServer) a := runtime.GOMAXPROCS(4) - go func() { - ListenAndServe("[::1]:8053", "udp", nil) - }() + s, addrstr, err := RunLocalUDPServer("[::1]:0") + if err != nil { + b.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + c := new(Client) m := new(Msg) m.SetQuestion("miek.nl", TypeSOA) b.StartTimer() for i := 0; i < b.N; i++ { - c.Exchange(m, "[::1]:8053") + c.Exchange(m, addrstr) } runtime.GOMAXPROCS(a) } @@ -143,16 +162,18 @@ func BenchmarkServeCompress(b *testing.B) { b.StopTimer() HandleFunc("miek.nl.", HelloServerCompress) a := runtime.GOMAXPROCS(4) - go func() { - ListenAndServe("127.0.0.1:8053", "udp", nil) - }() + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + b.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() c := new(Client) m := new(Msg) m.SetQuestion("miek.nl", TypeSOA) b.StartTimer() for i := 0; i < b.N; i++ { - c.Exchange(m, "127.0.0.1:8053") + c.Exchange(m, addrstr) } runtime.GOMAXPROCS(a) } @@ -242,16 +263,12 @@ func TestServingLargeResponses(t *testing.T) { mux := NewServeMux() mux.HandleFunc("example.", HelloServerLargeResponse) - server := &Server{ - Addr: "127.0.0.1:10000", - Net: "udp", - Handler: mux, + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) } - - go func() { - server.ListenAndServe() - }() - time.Sleep(50 * time.Millisecond) + defer s.Shutdown() + s.Handler = mux // Create request m := new(Msg) @@ -262,7 +279,7 @@ func TestServingLargeResponses(t *testing.T) { M.Lock() M.max = 2 M.Unlock() - _, _, err := c.Exchange(m, "127.0.0.1:10000") + _, _, err = c.Exchange(m, addrstr) if err != nil { t.Logf("failed to exchange: %s", err.Error()) t.Fail() @@ -271,14 +288,14 @@ func TestServingLargeResponses(t *testing.T) { M.Lock() M.max = 20 M.Unlock() - _, _, err = c.Exchange(m, "127.0.0.1:10000") + _, _, err = c.Exchange(m, addrstr) if err == nil { t.Logf("failed to fail exchange, this should generate packet error") t.Fail() } // But this must work again c.UDPSize = 7000 - _, _, err = c.Exchange(m, "127.0.0.1:10000") + _, _, err = c.Exchange(m, addrstr) if err != nil { t.Logf("failed to exchange: %s", err.Error()) t.Fail() @@ -288,31 +305,23 @@ func TestServingLargeResponses(t *testing.T) { // TODO(miek): These tests should actually fail when the server does // not shut down. func TestShutdownTCP(t *testing.T) { - server := Server{Addr: ":8055", Net: "tcp"} - go func() { - err := server.ListenAndServe() - if err != nil { - t.Logf("failed to setup the tcp server: %s\n", err.Error()) - t.Fail() - } - t.Logf("successfully stopped the tcp server") - }() - time.Sleep(4e8) - server.Shutdown() - time.Sleep(1e9) + s, _, err := RunLocalTCPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + err = s.Shutdown() + if err != nil { + t.Error("Could not shutdown test TCP server, %s", err) + } } func TestShutdownUDP(t *testing.T) { - server := Server{Addr: ":8054", Net: "udp"} - go func() { - err := server.ListenAndServe() - if err != nil { - t.Logf("failed to setup the udp server: %s\n", err.Error()) - t.Fail() - } - t.Logf("successfully stopped the udp server") - }() - time.Sleep(4e8) - server.Shutdown() - time.Sleep(1e9) + s, _, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + err = s.Shutdown() + if err != nil { + t.Error("Could not shutdown test UDP server, %s", err) + } }