mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-31 00:01:40 +01:00 
			
		
		
		
	The callback itself is not removed as it is used in other repos, making it simpler for those to slowly transition to the eventbus. Updates #15160 Signed-off-by: Claus Lensbøl <claus@tailscale.com>
		
			
				
	
	
		
			867 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			867 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| package resolver
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"encoding/binary"
 | |
| 	"flag"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/netip"
 | |
| 	"reflect"
 | |
| 	"slices"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"sync/atomic"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	dns "golang.org/x/net/dns/dnsmessage"
 | |
| 	"tailscale.com/control/controlknobs"
 | |
| 	"tailscale.com/health"
 | |
| 	"tailscale.com/net/netmon"
 | |
| 	"tailscale.com/net/tsdial"
 | |
| 	"tailscale.com/tstest"
 | |
| 	"tailscale.com/types/dnstype"
 | |
| 	"tailscale.com/util/eventbus/eventbustest"
 | |
| )
 | |
| 
 | |
| func (rr resolverAndDelay) String() string {
 | |
| 	return fmt.Sprintf("%v+%v", rr.name, rr.startDelay)
 | |
| }
 | |
| 
 | |
| func TestResolversWithDelays(t *testing.T) {
 | |
| 	// query
 | |
| 	q := func(ss ...string) (ipps []*dnstype.Resolver) {
 | |
| 		for _, host := range ss {
 | |
| 			ipps = append(ipps, &dnstype.Resolver{Addr: host})
 | |
| 		}
 | |
| 		return
 | |
| 	}
 | |
| 	// output
 | |
| 	o := func(ss ...string) (rr []resolverAndDelay) {
 | |
| 		for _, s := range ss {
 | |
| 			var d time.Duration
 | |
| 			s, durStr, hasPlus := strings.Cut(s, "+")
 | |
| 			if hasPlus {
 | |
| 				var err error
 | |
| 				d, err = time.ParseDuration(durStr)
 | |
| 				if err != nil {
 | |
| 					panic(fmt.Sprintf("parsing duration in %q: %v", s, err))
 | |
| 				}
 | |
| 			}
 | |
| 			rr = append(rr, resolverAndDelay{
 | |
| 				name:       &dnstype.Resolver{Addr: s},
 | |
| 				startDelay: d,
 | |
| 			})
 | |
| 		}
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		name string
 | |
| 		in   []*dnstype.Resolver
 | |
| 		want []resolverAndDelay
 | |
| 	}{
 | |
| 		{
 | |
| 			name: "unknown-no-delays",
 | |
| 			in:   q("1.2.3.4", "2.3.4.5"),
 | |
| 			want: o("1.2.3.4", "2.3.4.5"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "google-all-ipv4",
 | |
| 			in:   q("8.8.8.8", "8.8.4.4"),
 | |
| 			want: o("https://dns.google/dns-query", "8.8.8.8+0.5s", "8.8.4.4+0.7s"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "google-only-ipv6",
 | |
| 			in:   q("2001:4860:4860::8888", "2001:4860:4860::8844"),
 | |
| 			want: o("https://dns.google/dns-query", "2001:4860:4860::8888+0.5s", "2001:4860:4860::8844+0.7s"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "google-all-four",
 | |
| 			in:   q("8.8.8.8", "8.8.4.4", "2001:4860:4860::8888", "2001:4860:4860::8844"),
 | |
| 			want: o("https://dns.google/dns-query", "8.8.8.8+0.5s", "8.8.4.4+0.7s", "2001:4860:4860::8888+0.5s", "2001:4860:4860::8844+0.7s"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "quad9-one-v4-one-v6",
 | |
| 			in:   q("9.9.9.9", "2620:fe::fe"),
 | |
| 			want: o("https://dns.quad9.net/dns-query", "9.9.9.9+0.5s", "2620:fe::fe+0.5s"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "nextdns-ipv6-expand",
 | |
| 			in:   q("2a07:a8c0::c3:a884"),
 | |
| 			want: o("https://dns.nextdns.io/c3a884"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "nextdns-doh-input",
 | |
| 			in:   q("https://dns.nextdns.io/c3a884"),
 | |
| 			want: o("https://dns.nextdns.io/c3a884"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "controld-ipv6-expand",
 | |
| 			in:   q("2606:1a40:0:6:7b5b:5949:35ad:0"),
 | |
| 			want: o("https://dns.controld.com/hyq3ipr2ct"),
 | |
| 		},
 | |
| 		{
 | |
| 			name: "controld-doh-input",
 | |
| 			in:   q("https://dns.controld.com/hyq3ipr2ct"),
 | |
| 			want: o("https://dns.controld.com/hyq3ipr2ct"),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			got := resolversWithDelays(tt.in)
 | |
| 			if !reflect.DeepEqual(got, tt.want) {
 | |
| 				t.Errorf("got %v; want %v", got, tt.want)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGetRCode(t *testing.T) {
 | |
| 	tests := []struct {
 | |
| 		name   string
 | |
| 		packet []byte
 | |
| 		want   dns.RCode
 | |
| 	}{
 | |
| 		{
 | |
| 			name:   "empty",
 | |
| 			packet: []byte{},
 | |
| 			want:   dns.RCode(5),
 | |
| 		},
 | |
| 		{
 | |
| 			name:   "too-short",
 | |
| 			packet: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
 | |
| 			want:   dns.RCode(5),
 | |
| 		},
 | |
| 		{
 | |
| 			name:   "noerror",
 | |
| 			packet: []byte{0xC4, 0xFE, 0x81, 0xA0, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01},
 | |
| 			want:   dns.RCode(0),
 | |
| 		},
 | |
| 		{
 | |
| 			name:   "refused",
 | |
| 			packet: []byte{0xee, 0xa1, 0x81, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
 | |
| 			want:   dns.RCode(5),
 | |
| 		},
 | |
| 		{
 | |
| 			name:   "nxdomain",
 | |
| 			packet: []byte{0x34, 0xf4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01},
 | |
| 			want:   dns.RCode(3),
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			got := getRCode(tt.packet)
 | |
| 			if got != tt.want {
 | |
| 				t.Errorf("got %d; want %d", got, tt.want)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| var testDNS = flag.Bool("test-dns", false, "run tests that require a working DNS server")
 | |
| 
 | |
| func TestGetKnownDoHClientForProvider(t *testing.T) {
 | |
| 	var fwd forwarder
 | |
| 	c, ok := fwd.getKnownDoHClientForProvider("https://dns.google/dns-query")
 | |
| 	if !ok {
 | |
| 		t.Fatal("not found")
 | |
| 	}
 | |
| 	if !*testDNS {
 | |
| 		t.Skip("skipping without --test-dns")
 | |
| 	}
 | |
| 	res, err := c.Head("https://dns.google/")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	defer res.Body.Close()
 | |
| 	t.Logf("Got: %+v", res)
 | |
| }
 | |
| 
 | |
| func BenchmarkNameFromQuery(b *testing.B) {
 | |
| 	builder := dns.NewBuilder(nil, dns.Header{})
 | |
| 	builder.StartQuestions()
 | |
| 	builder.Question(dns.Question{
 | |
| 		Name:  dns.MustNewName("foo.example."),
 | |
| 		Type:  dns.TypeA,
 | |
| 		Class: dns.ClassINET,
 | |
| 	})
 | |
| 	msg, err := builder.Finish()
 | |
| 	if err != nil {
 | |
| 		b.Fatal(err)
 | |
| 	}
 | |
| 	b.ResetTimer()
 | |
| 	b.ReportAllocs()
 | |
| 	for range b.N {
 | |
| 		_, _, err := nameFromQuery(msg)
 | |
| 		if err != nil {
 | |
| 			b.Fatal(err)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Reproduces https://github.com/tailscale/tailscale/issues/2533
 | |
| // Fixed by https://github.com/tailscale/tailscale/commit/f414a9cc01f3264912513d07c0244ff4f3e4ba54
 | |
| //
 | |
| // NOTE: fuzz tests act like unit tests when run without `-fuzz`
 | |
| func FuzzClampEDNSSize(f *testing.F) {
 | |
| 	// Empty DNS packet
 | |
| 	f.Add([]byte{
 | |
| 		// query id
 | |
| 		0x12, 0x34,
 | |
| 		// flags: standard query, recurse
 | |
| 		0x01, 0x20,
 | |
| 		// num questions
 | |
| 		0x00, 0x00,
 | |
| 		// num answers
 | |
| 		0x00, 0x00,
 | |
| 		// num authority RRs
 | |
| 		0x00, 0x00,
 | |
| 		// num additional RRs
 | |
| 		0x00, 0x00,
 | |
| 	})
 | |
| 
 | |
| 	// Empty OPT
 | |
| 	f.Add([]byte{
 | |
| 		// header
 | |
| 		0xaf, 0x66, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00,
 | |
| 		0x00, 0x00, 0x00, 0x01,
 | |
| 		// query
 | |
| 		0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
 | |
| 		0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
 | |
| 		// OPT
 | |
| 		0x00,       // name: <root>
 | |
| 		0x00, 0x29, // type: OPT
 | |
| 		0x10, 0x00, // UDP payload size
 | |
| 		0x00,       // higher bits in extended RCODE
 | |
| 		0x00,       // EDNS0 version
 | |
| 		0x80, 0x00, // "Z" field
 | |
| 		0x00, 0x00, // data length
 | |
| 	})
 | |
| 
 | |
| 	// Query for "google.com"
 | |
| 	f.Add([]byte{
 | |
| 		// header
 | |
| 		0xaf, 0x66, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00,
 | |
| 		0x00, 0x00, 0x00, 0x01,
 | |
| 		// query
 | |
| 		0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
 | |
| 		0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
 | |
| 		// OPT
 | |
| 		0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00,
 | |
| 		0x0c, 0x00, 0x0a, 0x00, 0x08, 0x62, 0x18, 0x1a, 0xcb, 0x19,
 | |
| 		0xd7, 0xee, 0x23,
 | |
| 	})
 | |
| 
 | |
| 	f.Fuzz(func(t *testing.T, data []byte) {
 | |
| 		clampEDNSSize(data, maxResponseBytes)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| type testDNSServerOptions struct {
 | |
| 	SkipUDP bool
 | |
| 	SkipTCP bool
 | |
| }
 | |
| 
 | |
| func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, onRequest func(bool, []byte)) (port uint16) {
 | |
| 	if opts != nil && opts.SkipUDP && opts.SkipTCP {
 | |
| 		tb.Fatal("cannot skip both UDP and TCP servers")
 | |
| 	}
 | |
| 
 | |
| 	logf := tstest.WhileTestRunningLogger(tb)
 | |
| 
 | |
| 	tcpResponse := make([]byte, len(response)+2)
 | |
| 	binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
 | |
| 	copy(tcpResponse[2:], response)
 | |
| 
 | |
| 	// Repeatedly listen until we can get the same port.
 | |
| 	const tries = 25
 | |
| 	var (
 | |
| 		tcpLn *net.TCPListener
 | |
| 		udpLn *net.UDPConn
 | |
| 		err   error
 | |
| 	)
 | |
| 	for try := 0; try < tries; try++ {
 | |
| 		if tcpLn != nil {
 | |
| 			tcpLn.Close()
 | |
| 			tcpLn = nil
 | |
| 		}
 | |
| 
 | |
| 		tcpLn, err = net.ListenTCP("tcp4", &net.TCPAddr{
 | |
| 			IP:   net.IPv4(127, 0, 0, 1),
 | |
| 			Port: 0, // Choose one
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			tb.Fatal(err)
 | |
| 		}
 | |
| 		udpLn, err = net.ListenUDP("udp4", &net.UDPAddr{
 | |
| 			IP:   net.IPv4(127, 0, 0, 1),
 | |
| 			Port: tcpLn.Addr().(*net.TCPAddr).Port,
 | |
| 		})
 | |
| 		if err == nil {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	if tcpLn == nil || udpLn == nil {
 | |
| 		if tcpLn != nil {
 | |
| 			tcpLn.Close()
 | |
| 		}
 | |
| 		if udpLn != nil {
 | |
| 			udpLn.Close()
 | |
| 		}
 | |
| 
 | |
| 		// Skip instead of being fatal to avoid flaking on extremely
 | |
| 		// heavily-loaded CI systems.
 | |
| 		tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries)
 | |
| 	}
 | |
| 
 | |
| 	port = uint16(tcpLn.Addr().(*net.TCPAddr).Port)
 | |
| 
 | |
| 	handleConn := func(conn net.Conn) {
 | |
| 		defer conn.Close()
 | |
| 
 | |
| 		// Read the length header, then the buffer
 | |
| 		var length uint16
 | |
| 		if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
 | |
| 			logf("error reading length header: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 		req := make([]byte, length)
 | |
| 		n, err := io.ReadFull(conn, req)
 | |
| 		if err != nil {
 | |
| 			logf("error reading query: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 		req = req[:n]
 | |
| 		onRequest(true, req)
 | |
| 
 | |
| 		// Write response
 | |
| 		if _, err := conn.Write(tcpResponse); err != nil {
 | |
| 			logf("error writing response: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var wg sync.WaitGroup
 | |
| 
 | |
| 	if opts == nil || !opts.SkipTCP {
 | |
| 		wg.Add(1)
 | |
| 		go func() {
 | |
| 			defer wg.Done()
 | |
| 			for {
 | |
| 				conn, err := tcpLn.Accept()
 | |
| 				if err != nil {
 | |
| 					return
 | |
| 				}
 | |
| 				go handleConn(conn)
 | |
| 			}
 | |
| 		}()
 | |
| 	}
 | |
| 
 | |
| 	handleUDP := func(addr netip.AddrPort, req []byte) {
 | |
| 		onRequest(false, req)
 | |
| 		if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
 | |
| 			logf("error writing response: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if opts == nil || !opts.SkipUDP {
 | |
| 		wg.Add(1)
 | |
| 		go func() {
 | |
| 			defer wg.Done()
 | |
| 			for {
 | |
| 				buf := make([]byte, 65535)
 | |
| 				n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
 | |
| 				if err != nil {
 | |
| 					return
 | |
| 				}
 | |
| 				buf = buf[:n]
 | |
| 				go handleUDP(addr, buf)
 | |
| 			}
 | |
| 		}()
 | |
| 	}
 | |
| 
 | |
| 	tb.Cleanup(func() {
 | |
| 		tcpLn.Close()
 | |
| 		udpLn.Close()
 | |
| 		logf("waiting for listeners to finish...")
 | |
| 		wg.Wait()
 | |
| 	})
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
 | |
| 	name := dns.MustNewName(domain)
 | |
| 
 | |
| 	builder := dns.NewBuilder(nil, dns.Header{Response: true})
 | |
| 	builder.StartQuestions()
 | |
| 	builder.Question(dns.Question{
 | |
| 		Name:  name,
 | |
| 		Type:  dns.TypeA,
 | |
| 		Class: dns.ClassINET,
 | |
| 	})
 | |
| 	builder.StartAnswers()
 | |
| 	for i := range 120 {
 | |
| 		builder.AResource(dns.ResourceHeader{
 | |
| 			Name:  name,
 | |
| 			Class: dns.ClassINET,
 | |
| 			TTL:   300,
 | |
| 		}, dns.AResource{
 | |
| 			A: [4]byte{127, 0, 0, byte(i)},
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	response, err = builder.Finish()
 | |
| 	if err != nil {
 | |
| 		tb.Fatal(err)
 | |
| 	}
 | |
| 	if len(response) <= maxResponseBytes {
 | |
| 		tb.Fatalf("got len(largeResponse)=%d, want > %d", len(response), maxResponseBytes)
 | |
| 	}
 | |
| 
 | |
| 	// Our request is a single A query for the domain in the answer, above.
 | |
| 	builder = dns.NewBuilder(nil, dns.Header{})
 | |
| 	builder.StartQuestions()
 | |
| 	builder.Question(dns.Question{
 | |
| 		Name:  dns.MustNewName(domain),
 | |
| 		Type:  dns.TypeA,
 | |
| 		Class: dns.ClassINET,
 | |
| 	})
 | |
| 	request, err = builder.Finish()
 | |
| 	if err != nil {
 | |
| 		tb.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
 | |
| 	logf := tstest.WhileTestRunningLogger(tb)
 | |
| 	bus := eventbustest.NewBus(tb)
 | |
| 	netMon, err := netmon.New(bus, logf)
 | |
| 	if err != nil {
 | |
| 		tb.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	var dialer tsdial.Dialer
 | |
| 	dialer.SetNetMon(netMon)
 | |
| 	dialer.SetBus(bus)
 | |
| 
 | |
| 	fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil)
 | |
| 	if modify != nil {
 | |
| 		modify(fwd)
 | |
| 	}
 | |
| 
 | |
| 	resolvers := make([]resolverAndDelay, len(ports))
 | |
| 	for i, port := range ports {
 | |
| 		resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}
 | |
| 	}
 | |
| 
 | |
| 	rpkt := packet{
 | |
| 		bs:     request,
 | |
| 		family: "tcp",
 | |
| 		addr:   netip.MustParseAddrPort("127.0.0.1:12345"),
 | |
| 	}
 | |
| 
 | |
| 	rchan := make(chan packet, 1)
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | |
| 	tb.Cleanup(cancel)
 | |
| 	err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...)
 | |
| 	select {
 | |
| 	case res := <-rchan:
 | |
| 		return res.bs, err
 | |
| 	case <-ctx.Done():
 | |
| 		return nil, ctx.Err()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // makeTestRequest returns a new TypeA request for the given domain.
 | |
| func makeTestRequest(tb testing.TB, domain string) []byte {
 | |
| 	tb.Helper()
 | |
| 	name := dns.MustNewName(domain)
 | |
| 	builder := dns.NewBuilder(nil, dns.Header{})
 | |
| 	builder.StartQuestions()
 | |
| 	builder.Question(dns.Question{
 | |
| 		Name:  name,
 | |
| 		Type:  dns.TypeA,
 | |
| 		Class: dns.ClassINET,
 | |
| 	})
 | |
| 	request, err := builder.Finish()
 | |
| 	if err != nil {
 | |
| 		tb.Fatal(err)
 | |
| 	}
 | |
| 	return request
 | |
| }
 | |
| 
 | |
| // makeTestResponse returns a new Type A response for the given domain,
 | |
| // with the specified status code and zero or more addresses.
 | |
| func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
 | |
| 	tb.Helper()
 | |
| 	name := dns.MustNewName(domain)
 | |
| 	builder := dns.NewBuilder(nil, dns.Header{
 | |
| 		Response:      true,
 | |
| 		Authoritative: true,
 | |
| 		RCode:         code,
 | |
| 	})
 | |
| 	builder.StartQuestions()
 | |
| 	q := dns.Question{
 | |
| 		Name:  name,
 | |
| 		Type:  dns.TypeA,
 | |
| 		Class: dns.ClassINET,
 | |
| 	}
 | |
| 	builder.Question(q)
 | |
| 	if len(addrs) > 0 {
 | |
| 		builder.StartAnswers()
 | |
| 		for _, addr := range addrs {
 | |
| 			builder.AResource(dns.ResourceHeader{
 | |
| 				Name:  q.Name,
 | |
| 				Class: q.Class,
 | |
| 				TTL:   120,
 | |
| 			}, dns.AResource{
 | |
| 				A: addr.As4(),
 | |
| 			})
 | |
| 		}
 | |
| 	}
 | |
| 	response, err := builder.Finish()
 | |
| 	if err != nil {
 | |
| 		tb.Fatal(err)
 | |
| 	}
 | |
| 	return response
 | |
| }
 | |
| 
 | |
| func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte {
 | |
| 	resp, err := runTestQuery(tb, request, modify, ports...)
 | |
| 	if err != nil {
 | |
| 		tb.Fatalf("error making request: %v", err)
 | |
| 	}
 | |
| 	return resp
 | |
| }
 | |
| 
 | |
| func beVerbose(f *forwarder) {
 | |
| 	f.verboseFwd = true
 | |
| }
 | |
| 
 | |
| func TestForwarderTCPFallback(t *testing.T) {
 | |
| 	const domain = "large-dns-response.tailscale.com."
 | |
| 
 | |
| 	// Make a response that's very large, containing a bunch of localhost addresses.
 | |
| 	request, largeResponse := makeLargeResponse(t, domain)
 | |
| 
 | |
| 	var sawTCPRequest atomic.Bool
 | |
| 	port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
 | |
| 		if isTCP {
 | |
| 			t.Logf("saw TCP request")
 | |
| 			sawTCPRequest.Store(true)
 | |
| 		} else {
 | |
| 			t.Logf("saw UDP request")
 | |
| 		}
 | |
| 
 | |
| 		if !bytes.Equal(request, gotRequest) {
 | |
| 			t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	resp := mustRunTestQuery(t, request, beVerbose, port)
 | |
| 	if !bytes.Equal(resp, largeResponse) {
 | |
| 		t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
 | |
| 	}
 | |
| 	if !sawTCPRequest.Load() {
 | |
| 		t.Errorf("DNS server never saw TCP request")
 | |
| 	}
 | |
| 
 | |
| 	// NOTE: can't assert that we see a UDP request here since we might
 | |
| 	// race and run the TCP query first. We test the UDP codepath in
 | |
| 	// TestForwarderTCPFallbackDisabled below, though.
 | |
| }
 | |
| 
 | |
| // Test to ensure that if the UDP listener is unresponsive, we always make a
 | |
| // TCP request even if we never get a response.
 | |
| func TestForwarderTCPFallbackTimeout(t *testing.T) {
 | |
| 	const domain = "large-dns-response.tailscale.com."
 | |
| 
 | |
| 	// Make a response that's very large, containing a bunch of localhost addresses.
 | |
| 	request, largeResponse := makeLargeResponse(t, domain)
 | |
| 
 | |
| 	var sawTCPRequest atomic.Bool
 | |
| 	opts := &testDNSServerOptions{SkipUDP: true}
 | |
| 	port := runDNSServer(t, opts, largeResponse, func(isTCP bool, gotRequest []byte) {
 | |
| 		if isTCP {
 | |
| 			t.Logf("saw TCP request")
 | |
| 			sawTCPRequest.Store(true)
 | |
| 		} else {
 | |
| 			t.Error("saw unexpected UDP request")
 | |
| 		}
 | |
| 
 | |
| 		if !bytes.Equal(request, gotRequest) {
 | |
| 			t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	resp := mustRunTestQuery(t, request, beVerbose, port)
 | |
| 	if !bytes.Equal(resp, largeResponse) {
 | |
| 		t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
 | |
| 	}
 | |
| 	if !sawTCPRequest.Load() {
 | |
| 		t.Errorf("DNS server never saw TCP request")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestForwarderTCPFallbackDisabled(t *testing.T) {
 | |
| 	const domain = "large-dns-response.tailscale.com."
 | |
| 
 | |
| 	// Make a response that's very large, containing a bunch of localhost addresses.
 | |
| 	request, largeResponse := makeLargeResponse(t, domain)
 | |
| 
 | |
| 	var sawUDPRequest atomic.Bool
 | |
| 	port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
 | |
| 		if isTCP {
 | |
| 			t.Error("saw unexpected TCP request")
 | |
| 		} else {
 | |
| 			t.Logf("saw UDP request")
 | |
| 			sawUDPRequest.Store(true)
 | |
| 		}
 | |
| 
 | |
| 		if !bytes.Equal(request, gotRequest) {
 | |
| 			t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
 | |
| 		fwd.verboseFwd = true
 | |
| 		// Disable retries for this test.
 | |
| 		fwd.controlKnobs = &controlknobs.Knobs{}
 | |
| 		fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
 | |
| 	}, port)
 | |
| 
 | |
| 	wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
 | |
| 
 | |
| 	// Set the truncated flag on the expected response, since that's what we expect.
 | |
| 	flags := binary.BigEndian.Uint16(wantResp[2:4])
 | |
| 	flags |= dnsFlagTruncated
 | |
| 	binary.BigEndian.PutUint16(wantResp[2:4], flags)
 | |
| 
 | |
| 	if !bytes.Equal(resp, wantResp) {
 | |
| 		t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp)
 | |
| 	}
 | |
| 	if !sawUDPRequest.Load() {
 | |
| 		t.Errorf("DNS server never saw UDP request")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Test to ensure that we propagate DNS errors
 | |
| func TestForwarderTCPFallbackError(t *testing.T) {
 | |
| 	const domain = "error-response.tailscale.com."
 | |
| 
 | |
| 	// Our response is a SERVFAIL
 | |
| 	response := makeTestResponse(t, domain, dns.RCodeServerFailure)
 | |
| 
 | |
| 	// Our request is a single A query for the domain in the answer, above.
 | |
| 	request := makeTestRequest(t, domain)
 | |
| 
 | |
| 	var sawRequest atomic.Bool
 | |
| 	port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
 | |
| 		sawRequest.Store(true)
 | |
| 		if !bytes.Equal(request, gotRequest) {
 | |
| 			t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	resp, err := runTestQuery(t, request, beVerbose, port)
 | |
| 	if !sawRequest.Load() {
 | |
| 		t.Error("did not see DNS request")
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("wanted nil, got %v", err)
 | |
| 	}
 | |
| 	var parser dns.Parser
 | |
| 	respHeader, err := parser.Start(resp)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("parser.Start() failed: %v", err)
 | |
| 	}
 | |
| 	if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want {
 | |
| 		t.Errorf("wanted %v, got %v", want, got)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Test to ensure that if we have more than one resolver, and at least one of them
 | |
| // returns a successful response, we propagate it.
 | |
| func TestForwarderWithManyResolvers(t *testing.T) {
 | |
| 	const domain = "example.com."
 | |
| 	request := makeTestRequest(t, domain)
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		name          string
 | |
| 		responses     [][]byte // upstream responses
 | |
| 		wantResponses [][]byte // we should receive one of these from the forwarder
 | |
| 	}{
 | |
| 		{
 | |
| 			name: "Success",
 | |
| 			responses: [][]byte{ // All upstream servers returned successful, but different, response.
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{ // We may forward whichever response is received first.
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "ServFail",
 | |
| 			responses: [][]byte{ // All upstream servers returned a SERVFAIL.
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "ServFail+Success",
 | |
| 			responses: [][]byte{ // All upstream servers fail except for one.
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{ // We should forward the successful response.
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "NXDomain",
 | |
| 			responses: [][]byte{ // All upstream servers returned NXDOMAIN.
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "NXDomain+Success",
 | |
| 			responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one.
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "Refused",
 | |
| 			responses: [][]byte{ // All upstream servers return different failures.
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 				makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "MixFail",
 | |
| 			responses: [][]byte{ // All upstream servers return different failures.
 | |
| 				makeTestResponse(t, domain, dns.RCodeServerFailure),
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 			},
 | |
| 			wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
 | |
| 				makeTestResponse(t, domain, dns.RCodeNameError),
 | |
| 				makeTestResponse(t, domain, dns.RCodeRefused),
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			ports := make([]uint16, len(tt.responses))
 | |
| 			for i := range tt.responses {
 | |
| 				ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
 | |
| 			}
 | |
| 			gotResponse, err := runTestQuery(t, request, beVerbose, ports...)
 | |
| 			if err != nil {
 | |
| 				t.Fatalf("wanted nil, got %v", err)
 | |
| 			}
 | |
| 			responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool {
 | |
| 				return slices.Equal(gotResponse, wantResponse)
 | |
| 			})
 | |
| 			if !responseOk {
 | |
| 				t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0])
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // mdnsResponder at minimum has an expectation that NXDOMAIN must include the
 | |
| // question, otherwise it will penalize our server (#13511).
 | |
| func TestNXDOMAINIncludesQuestion(t *testing.T) {
 | |
| 	var domain = "lb._dns-sd._udp.example.org."
 | |
| 
 | |
| 	// Our response is a NXDOMAIN
 | |
| 	response := func() []byte {
 | |
| 		name := dns.MustNewName(domain)
 | |
| 
 | |
| 		builder := dns.NewBuilder(nil, dns.Header{
 | |
| 			Response: true,
 | |
| 			RCode:    dns.RCodeNameError,
 | |
| 		})
 | |
| 		builder.StartQuestions()
 | |
| 		builder.Question(dns.Question{
 | |
| 			Name:  name,
 | |
| 			Type:  dns.TypePTR,
 | |
| 			Class: dns.ClassINET,
 | |
| 		})
 | |
| 		response, err := builder.Finish()
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		return response
 | |
| 	}()
 | |
| 
 | |
| 	// Our request is a single PTR query for the domain in the answer, above.
 | |
| 	request := func() []byte {
 | |
| 		builder := dns.NewBuilder(nil, dns.Header{})
 | |
| 		builder.StartQuestions()
 | |
| 		builder.Question(dns.Question{
 | |
| 			Name:  dns.MustNewName(domain),
 | |
| 			Type:  dns.TypePTR,
 | |
| 			Class: dns.ClassINET,
 | |
| 		})
 | |
| 		request, err := builder.Finish()
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		return request
 | |
| 	}()
 | |
| 
 | |
| 	port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
 | |
| 	})
 | |
| 
 | |
| 	res, err := runTestQuery(t, request, beVerbose, port)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	if !slices.Equal(res, response) {
 | |
| 		t.Errorf("invalid response\ngot: %+v\nwant: %+v", res, response)
 | |
| 	}
 | |
| }
 |