mirror of
				https://github.com/miekg/dns.git
				synced 2025-11-04 12:41:01 +01:00 
			
		
		
		
	Reduce some code duplication by making Exchange() use Client.Exchange(). When performing an Exchange if the query ID does not match the answer ID return an error. Also add a test for this condition.
		
			
				
	
	
		
			400 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			400 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package dns
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"net"
 | 
						|
	"runtime"
 | 
						|
	"sync"
 | 
						|
	"testing"
 | 
						|
)
 | 
						|
 | 
						|
func HelloServer(w ResponseWriter, req *Msg) {
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetReply(req)
 | 
						|
 | 
						|
	m.Extra = make([]RR, 1)
 | 
						|
	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
 | 
						|
	w.WriteMsg(m)
 | 
						|
}
 | 
						|
 | 
						|
func HelloServerBadId(w ResponseWriter, req *Msg) {
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetReply(req)
 | 
						|
	m.Id += 1
 | 
						|
 | 
						|
	m.Extra = make([]RR, 1)
 | 
						|
	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
 | 
						|
	w.WriteMsg(m)
 | 
						|
}
 | 
						|
 | 
						|
func AnotherHelloServer(w ResponseWriter, req *Msg) {
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetReply(req)
 | 
						|
 | 
						|
	m.Extra = make([]RR, 1)
 | 
						|
	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}}
 | 
						|
	w.WriteMsg(m)
 | 
						|
}
 | 
						|
 | 
						|
func RunLocalUDPServer(laddr string) (*Server, string, error) {
 | 
						|
	pc, err := net.ListenPacket("udp", laddr)
 | 
						|
	if err != nil {
 | 
						|
		return nil, "", err
 | 
						|
	}
 | 
						|
	server := &Server{PacketConn: pc}
 | 
						|
 | 
						|
	waitLock := sync.Mutex{}
 | 
						|
	waitLock.Lock()
 | 
						|
	server.NotifyStartedFunc = waitLock.Unlock
 | 
						|
 | 
						|
	go func() {
 | 
						|
		server.ActivateAndServe()
 | 
						|
		pc.Close()
 | 
						|
	}()
 | 
						|
 | 
						|
	waitLock.Lock()
 | 
						|
	return server, pc.LocalAddr().String(), nil
 | 
						|
}
 | 
						|
 | 
						|
func RunLocalUDPServerUnsafe(laddr string) (*Server, string, error) {
 | 
						|
	pc, err := net.ListenPacket("udp", laddr)
 | 
						|
	if err != nil {
 | 
						|
		return nil, "", err
 | 
						|
	}
 | 
						|
	server := &Server{PacketConn: pc, Unsafe: true}
 | 
						|
 | 
						|
	waitLock := sync.Mutex{}
 | 
						|
	waitLock.Lock()
 | 
						|
	server.NotifyStartedFunc = waitLock.Unlock
 | 
						|
 | 
						|
	go func() {
 | 
						|
		server.ActivateAndServe()
 | 
						|
		pc.Close()
 | 
						|
	}()
 | 
						|
 | 
						|
	waitLock.Lock()
 | 
						|
	return server, pc.LocalAddr().String(), nil
 | 
						|
}
 | 
						|
 | 
						|
func RunLocalTCPServer(laddr string) (*Server, string, error) {
 | 
						|
	l, err := net.Listen("tcp", laddr)
 | 
						|
	if err != nil {
 | 
						|
		return nil, "", err
 | 
						|
	}
 | 
						|
 | 
						|
	server := &Server{Listener: l}
 | 
						|
 | 
						|
	waitLock := sync.Mutex{}
 | 
						|
	waitLock.Lock()
 | 
						|
	server.NotifyStartedFunc = waitLock.Unlock
 | 
						|
 | 
						|
	go func() {
 | 
						|
		server.ActivateAndServe()
 | 
						|
		l.Close()
 | 
						|
	}()
 | 
						|
 | 
						|
	waitLock.Lock()
 | 
						|
	return server, l.Addr().String(), nil
 | 
						|
}
 | 
						|
 | 
						|
func TestServing(t *testing.T) {
 | 
						|
	HandleFunc("miek.nl.", HelloServer)
 | 
						|
	HandleFunc("example.com.", AnotherHelloServer)
 | 
						|
	defer HandleRemove("miek.nl.")
 | 
						|
	defer HandleRemove("example.com.")
 | 
						|
 | 
						|
	s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	defer s.Shutdown()
 | 
						|
 | 
						|
	c := new(Client)
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetQuestion("miek.nl.", TypeTXT)
 | 
						|
	r, _, err := c.Exchange(m, addrstr)
 | 
						|
	if err != nil || len(r.Extra) == 0 {
 | 
						|
		t.Fatal("failed to exchange miek.nl", err)
 | 
						|
	}
 | 
						|
	txt := r.Extra[0].(*TXT).Txt[0]
 | 
						|
	if txt != "Hello world" {
 | 
						|
		t.Error("Unexpected result for miek.nl", txt, "!= Hello world")
 | 
						|
	}
 | 
						|
 | 
						|
	m.SetQuestion("example.com.", TypeTXT)
 | 
						|
	r, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal("failed to exchange example.com", err)
 | 
						|
	}
 | 
						|
	txt = r.Extra[0].(*TXT).Txt[0]
 | 
						|
	if txt != "Hello example" {
 | 
						|
		t.Error("Unexpected result for example.com", txt, "!= Hello example")
 | 
						|
	}
 | 
						|
 | 
						|
	// Test Mixes cased as noticed by Ask.
 | 
						|
	m.SetQuestion("eXaMplE.cOm.", TypeTXT)
 | 
						|
	r, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err != nil {
 | 
						|
		t.Error("failed to exchange eXaMplE.cOm", err)
 | 
						|
	}
 | 
						|
	txt = r.Extra[0].(*TXT).Txt[0]
 | 
						|
	if txt != "Hello example" {
 | 
						|
		t.Error("Unexpected result for example.com", txt, "!= Hello example")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func BenchmarkServe(b *testing.B) {
 | 
						|
	b.StopTimer()
 | 
						|
	HandleFunc("miek.nl.", HelloServer)
 | 
						|
	defer HandleRemove("miek.nl.")
 | 
						|
	a := runtime.GOMAXPROCS(4)
 | 
						|
 | 
						|
	s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		b.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	defer s.Shutdown()
 | 
						|
 | 
						|
	c := new(Client)
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetQuestion("miek.nl", TypeSOA)
 | 
						|
 | 
						|
	b.StartTimer()
 | 
						|
	for i := 0; i < b.N; i++ {
 | 
						|
		c.Exchange(m, addrstr)
 | 
						|
	}
 | 
						|
	runtime.GOMAXPROCS(a)
 | 
						|
}
 | 
						|
 | 
						|
func benchmarkServe6(b *testing.B) {
 | 
						|
	b.StopTimer()
 | 
						|
	HandleFunc("miek.nl.", HelloServer)
 | 
						|
	defer HandleRemove("miek.nl.")
 | 
						|
	a := runtime.GOMAXPROCS(4)
 | 
						|
	s, addrstr, err := RunLocalUDPServer("[::1]:0")
 | 
						|
	if err != nil {
 | 
						|
		b.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	defer s.Shutdown()
 | 
						|
 | 
						|
	c := new(Client)
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetQuestion("miek.nl", TypeSOA)
 | 
						|
 | 
						|
	b.StartTimer()
 | 
						|
	for i := 0; i < b.N; i++ {
 | 
						|
		c.Exchange(m, addrstr)
 | 
						|
	}
 | 
						|
	runtime.GOMAXPROCS(a)
 | 
						|
}
 | 
						|
 | 
						|
func HelloServerCompress(w ResponseWriter, req *Msg) {
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetReply(req)
 | 
						|
	m.Extra = make([]RR, 1)
 | 
						|
	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
 | 
						|
	m.Compress = true
 | 
						|
	w.WriteMsg(m)
 | 
						|
}
 | 
						|
 | 
						|
func BenchmarkServeCompress(b *testing.B) {
 | 
						|
	b.StopTimer()
 | 
						|
	HandleFunc("miek.nl.", HelloServerCompress)
 | 
						|
	defer HandleRemove("miek.nl.")
 | 
						|
	a := runtime.GOMAXPROCS(4)
 | 
						|
	s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		b.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	defer s.Shutdown()
 | 
						|
 | 
						|
	c := new(Client)
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetQuestion("miek.nl", TypeSOA)
 | 
						|
	b.StartTimer()
 | 
						|
	for i := 0; i < b.N; i++ {
 | 
						|
		c.Exchange(m, addrstr)
 | 
						|
	}
 | 
						|
	runtime.GOMAXPROCS(a)
 | 
						|
}
 | 
						|
 | 
						|
func TestDotAsCatchAllWildcard(t *testing.T) {
 | 
						|
	mux := NewServeMux()
 | 
						|
	mux.Handle(".", HandlerFunc(HelloServer))
 | 
						|
	mux.Handle("example.com.", HandlerFunc(AnotherHelloServer))
 | 
						|
 | 
						|
	handler := mux.match("www.miek.nl.", TypeTXT)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("wildcard match failed")
 | 
						|
	}
 | 
						|
 | 
						|
	handler = mux.match("www.example.com.", TypeTXT)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("example.com match failed")
 | 
						|
	}
 | 
						|
 | 
						|
	handler = mux.match("a.www.example.com.", TypeTXT)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("a.www.example.com match failed")
 | 
						|
	}
 | 
						|
 | 
						|
	handler = mux.match("boe.", TypeTXT)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("boe. match failed")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCaseFolding(t *testing.T) {
 | 
						|
	mux := NewServeMux()
 | 
						|
	mux.Handle("_udp.example.com.", HandlerFunc(HelloServer))
 | 
						|
 | 
						|
	handler := mux.match("_dns._udp.example.com.", TypeSRV)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("case sensitive characters folded")
 | 
						|
	}
 | 
						|
 | 
						|
	handler = mux.match("_DNS._UDP.EXAMPLE.COM.", TypeSRV)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("case insensitive characters not folded")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestRootServer(t *testing.T) {
 | 
						|
	mux := NewServeMux()
 | 
						|
	mux.Handle(".", HandlerFunc(HelloServer))
 | 
						|
 | 
						|
	handler := mux.match(".", TypeNS)
 | 
						|
	if handler == nil {
 | 
						|
		t.Error("root match failed")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// maxRec pairs an integer with the read/write lock used to access it from
// several goroutines.
type maxRec struct {
	sync.RWMutex
	// max is read under RLock and written under Lock by the code in this file.
	max int
}
 | 
						|
 | 
						|
// M holds the number of A records HelloServerLargeResponse puts in a reply;
// tests change M.max under the lock to vary the response size.
var M = &maxRec{}
 | 
						|
 | 
						|
func HelloServerLargeResponse(resp ResponseWriter, req *Msg) {
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetReply(req)
 | 
						|
	m.Authoritative = true
 | 
						|
	m1 := 0
 | 
						|
	M.RLock()
 | 
						|
	m1 = M.max
 | 
						|
	M.RUnlock()
 | 
						|
	for i := 0; i < m1; i++ {
 | 
						|
		aRec := &A{
 | 
						|
			Hdr: RR_Header{
 | 
						|
				Name:   req.Question[0].Name,
 | 
						|
				Rrtype: TypeA,
 | 
						|
				Class:  ClassINET,
 | 
						|
				Ttl:    0,
 | 
						|
			},
 | 
						|
			A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(),
 | 
						|
		}
 | 
						|
		m.Answer = append(m.Answer, aRec)
 | 
						|
	}
 | 
						|
	resp.WriteMsg(m)
 | 
						|
}
 | 
						|
 | 
						|
func TestServingLargeResponses(t *testing.T) {
 | 
						|
	HandleFunc("example.", HelloServerLargeResponse)
 | 
						|
	defer HandleRemove("example.")
 | 
						|
 | 
						|
	s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	defer s.Shutdown()
 | 
						|
 | 
						|
	// Create request
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetQuestion("web.service.example.", TypeANY)
 | 
						|
 | 
						|
	c := new(Client)
 | 
						|
	c.Net = "udp"
 | 
						|
	M.Lock()
 | 
						|
	M.max = 2
 | 
						|
	M.Unlock()
 | 
						|
	_, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err != nil {
 | 
						|
		t.Errorf("failed to exchange: %v", err)
 | 
						|
	}
 | 
						|
	// This must fail
 | 
						|
	M.Lock()
 | 
						|
	M.max = 20
 | 
						|
	M.Unlock()
 | 
						|
	_, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err == nil {
 | 
						|
		t.Error("failed to fail exchange, this should generate packet error")
 | 
						|
	}
 | 
						|
	// But this must work again
 | 
						|
	c.UDPSize = 7000
 | 
						|
	_, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err != nil {
 | 
						|
		t.Errorf("failed to exchange: %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestServingResponse(t *testing.T) {
 | 
						|
	if testing.Short() {
 | 
						|
		t.Skip("skipping test in short mode.")
 | 
						|
	}
 | 
						|
	HandleFunc("miek.nl.", HelloServer)
 | 
						|
	s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	c := new(Client)
 | 
						|
	m := new(Msg)
 | 
						|
	m.SetQuestion("miek.nl.", TypeTXT)
 | 
						|
	m.Response = false
 | 
						|
	_, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal("failed to exchange", err)
 | 
						|
	}
 | 
						|
	m.Response = true
 | 
						|
	_, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err == nil {
 | 
						|
		t.Fatal("exchanged response message")
 | 
						|
	}
 | 
						|
 | 
						|
	s.Shutdown()
 | 
						|
	s, addrstr, err = RunLocalUDPServerUnsafe("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	defer s.Shutdown()
 | 
						|
 | 
						|
	m.Response = true
 | 
						|
	_, _, err = c.Exchange(m, addrstr)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal("could exchanged response message in Unsafe mode")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestShutdownTCP(t *testing.T) {
 | 
						|
	s, _, err := RunLocalTCPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	err = s.Shutdown()
 | 
						|
	if err != nil {
 | 
						|
		t.Errorf("Could not shutdown test TCP server, %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestShutdownUDP(t *testing.T) {
 | 
						|
	s, _, err := RunLocalUDPServer("127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to run test server: %v", err)
 | 
						|
	}
 | 
						|
	err = s.Shutdown()
 | 
						|
	if err != nil {
 | 
						|
		t.Errorf("Could not shutdown test UDP server, %v", err)
 | 
						|
	}
 | 
						|
}
 |