Mirror of https://github.com/tailscale/tailscale.git
This deprecates the old "DERP string" packing of a DERP region ID into
an IP:port of 127.3.3.40:$REGION_ID and just uses an integer, like
PeerChange.DERPRegion does. We still support servers sending the old
form; they're converted to the new form internally right when they're
read off the network.

Updates #14636

Change-Id: I9427ec071f02a2c6d75ccb0fcbf0ecff9f19f26f
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
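For illustration, a minimal sketch of the legacy-to-integer conversion the
commit describes; the helper name and exact parsing here are assumptions for
illustration, not the actual Tailscale implementation:

package main

import (
	"fmt"
	"net/netip"
)

// derpRegionFromLegacyString converts a legacy "DERP string" such as
// "127.3.3.40:10" into its integer region ID (10 here). The magic
// address 127.3.3.40 marks the value as a packed region ID; the port
// carries the region number. (Hypothetical helper, for illustration.)
func derpRegionFromLegacyString(s string) (region int, ok bool) {
	ap, err := netip.ParseAddrPort(s)
	if err != nil || ap.Addr() != netip.MustParseAddr("127.3.3.40") {
		return 0, false
	}
	return int(ap.Port()), true
}

func main() {
	if region, ok := derpRegionFromLegacyString("127.3.3.40:10"); ok {
		fmt.Println("DERP region:", region) // prints: DERP region: 10
	}
}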
391 lines · 9.6 KiB · Go
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package controlbase

import (
	"bufio"
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"runtime"
	"strings"
	"sync"
	"testing"
	"testing/iotest"
	"time"

	chp "golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/net/nettest"
	"tailscale.com/net/memnet"
	"tailscale.com/types/key"
)

const testProtocolVersion = 1

func TestMessageSize(t *testing.T) {
	// This test is a regression guard against someone looking at
	// maxCiphertextSize, going "huh, we could be more efficient if it
	// were larger", and accidentally violating the Noise spec. Do not
	// change this max value; it's a deliberate limitation of the
	// cryptographic protocol we use (see Section 3 "Message Format"
	// of the Noise spec).
	const max = 65535
	if maxCiphertextSize > max {
		t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max)
	}
}

func TestConnBasic(t *testing.T) {
	client, server := pair(t)

	sb := sinkReads(server)

	want := "test"
	if _, err := io.WriteString(client, want); err != nil {
		t.Fatalf("client write failed: %v", err)
	}
	client.Close()

	if got := sb.String(4); got != want {
		t.Fatalf("wrong content received: got %q, want %q", got, want)
	}
	if err := sb.Error(); err != io.EOF {
		t.Fatal("client close wasn't seen by server")
	}
	if sb.Total() != 4 {
		t.Fatalf("wrong number of bytes received: got %d, want 4", sb.Total())
	}
}

// bufferedWriteConn wraps a net.Conn and gives control over how
// Writes get batched out.
type bufferedWriteConn struct {
	net.Conn
	w           *bufio.Writer
	manualFlush bool
}

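// Write buffers bs in c.w. Unless manualFlush is set, the buffer is
// flushed to the underlying conn after every write.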
func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
	n, err := c.w.Write(bs)
	if err == nil && !c.manualFlush {
		err = c.w.Flush()
	}
	return n, err
}

// TestFastPath exercises the Read codepath that can receive multiple
// Noise frames at once and decode each in turn without making another
// syscall.
func TestFastPath(t *testing.T) {
	s1, s2 := memnet.NewConn("noise", 128000)
	b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false}
	client, server := pairWithConns(t, b, s2)

	b.manualFlush = true

	sb := sinkReads(server)

	const packets = 10
	s := "test"
	for range packets {
		// Many separate writes, to force separate Noise frames that
		// all get buffered up and then all sent as a single slice to
		// the server.
		if _, err := io.WriteString(client, s); err != nil {
			t.Fatalf("client write failed: %v", err)
		}
	}
	if err := b.w.Flush(); err != nil {
		t.Fatalf("client flush failed: %v", err)
	}
	client.Close()

	want := strings.Repeat(s, packets)
	if got := sb.String(len(want)); got != want {
		t.Fatalf("wrong content received: got %q, want %q", got, want)
	}
	if err := sb.Error(); err != io.EOF {
		t.Fatal("client close wasn't seen by server")
	}
}

// Writes things larger than a single Noise frame, to check the
// chunking on the encoder and decoder.
func TestBigData(t *testing.T) {
	client, server := pair(t)

	serverReads := sinkReads(server)
	clientReads := sinkReads(client)

	const sz = 15 * 1024 // 15 KiB
	clientStr := strings.Repeat("abcde", sz/5)
	serverStr := strings.Repeat("fghij", sz/5*2)

	if _, err := io.WriteString(client, clientStr); err != nil {
		t.Fatalf("writing client>server: %v", err)
	}
	if _, err := io.WriteString(server, serverStr); err != nil {
		t.Fatalf("writing server>client: %v", err)
	}

	if serverGot := serverReads.String(sz); serverGot != clientStr {
		t.Error("server didn't receive what client sent")
	}
	if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
		t.Error("client didn't receive what server sent")
	}

	getNonce := func(n [chp.NonceSize]byte) uint64 {
		if binary.BigEndian.Uint32(n[:4]) != 0 {
			panic("unexpected nonce")
		}
		return binary.BigEndian.Uint64(n[4:])
	}

	// Reach into the Conns and verify the cipher nonces advanced as
	// expected.
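	// (The specific values checked below assume each Noise frame
	// carries roughly 4 KiB of plaintext: the client's 15 KiB write
	// should chunk into 4 frames, advancing its tx nonce to 4, and
	// the server's 30 KiB write into 8.)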
	if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
		t.Error("desynchronized client tx nonce")
	}
	if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
		t.Error("desynchronized server tx nonce")
	}
	if n := getNonce(client.tx.nonce); n != 4 {
		t.Errorf("wrong client tx nonce, got %d want 4", n)
	}
	if n := getNonce(server.tx.nonce); n != 8 {
		t.Errorf("wrong server tx nonce, got %d want 8", n)
	}
}

// readerConn wraps a net.Conn and routes its Reads through a separate
// io.Reader.
type readerConn struct {
	net.Conn
	r io.Reader
}

func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) }

// Check that the receiver can handle not being able to read an entire
// frame in a single syscall.
func TestDataTrickle(t *testing.T) {
	s1, s2 := memnet.NewConn("noise", 128000)
	client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)})
	serverReads := sinkReads(server)

	const sz = 10000
	clientStr := strings.Repeat("abcde", sz/5)
	if _, err := io.WriteString(client, clientStr); err != nil {
		t.Fatalf("writing client>server: %v", err)
	}

	serverGot := serverReads.String(sz)
	if serverGot != clientStr {
		t.Error("server didn't receive what client sent")
	}
}

func TestConnStd(t *testing.T) {
	// You can run this test manually; noise.Conn should pass all of
	// nettest's checks except TestConn/PastTimeout,
	// TestConn/FutureTimeout, and TestConn/ConcurrentMethods, because
	// those tests assume that write errors are recoverable, and
	// they're not on our Conn due to cipher security.
	t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
	nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
		s1, s2 := memnet.NewConn("noise", 4096)
		controlKey := key.NewMachine()
		machineKey := key.NewMachine()
		serverErr := make(chan error, 1)
		go func() {
			var err error
			c2, err = Server(context.Background(), s2, controlKey, nil)
			serverErr <- err
		}()
		c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
		if err != nil {
			s1.Close()
			s2.Close()
			return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
		}
		if err := <-serverErr; err != nil {
			c1.Close()
			s1.Close()
			s2.Close()
			return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
		}
		return c1, c2, func() {
			c1.Close()
			c2.Close()
		}, nil
	})
}

// Tests that the idle memory overhead of a Conn blocked in a read is
// reasonable (under 2 KiB). It was previously over 8 KiB with two
// 4 KiB buffers for rx/tx. This makes sure we don't regress. Hopefully
// it doesn't turn into a flaky test. If so, const max can be adjusted,
// or it can be deleted or reworked.
func TestConnMemoryOverhead(t *testing.T) {
	num := 1000
	if testing.Short() {
		num = 100
	}
	ng0 := runtime.NumGoroutine()

	runtime.GC()
	var ms0 runtime.MemStats
	runtime.ReadMemStats(&ms0)

	var closers []io.Closer
	closeAll := func() {
		for _, c := range closers {
			c.Close()
		}
		closers = nil
	}
	defer closeAll()

	for range num {
		client, server := pair(t)
		closers = append(closers, client, server)
		go func() {
			var buf [1]byte
			client.Read(buf[:])
		}()
	}

	t0 := time.Now()
	deadline := t0.Add(3 * time.Second)
	var ngo int
	for time.Now().Before(deadline) {
		runtime.GC()
		ngo = runtime.NumGoroutine()
		if ngo >= num {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	if ngo < num {
		t.Fatalf("only %v goroutines; expected %v+", ngo, num)
	}
	runtime.GC()
	var ms runtime.MemStats
	runtime.ReadMemStats(&ms)
	growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc)
	growthEach := float64(growthTotal) / float64(num)
	t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach)
	const max = 2048
	if growthEach > max {
		t.Errorf("allocated more than expected; want max %v bytes/each", max)
	}

	closeAll()

	// And make sure our goroutines go away too.
	deadline = time.Now().Add(3 * time.Second)
	for time.Now().Before(deadline) {
		ngo = runtime.NumGoroutine()
		if ngo < ng0+num/10 {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	if ngo >= ng0+num/10 {
		t.Errorf("goroutines didn't go back down; started at %v, now %v", ng0, ngo)
	}
}

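// readSink drains an io.Reader on a background goroutine, accumulating
// everything read (and the terminal read error) under its mutex so
// tests can block until a given amount of data, or an error, arrives.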
type readSink struct {
	r io.Reader

	cond *sync.Cond
	sync.Mutex
	bs  bytes.Buffer
	err error
}

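// sinkReads starts draining r in the background and returns a readSink
// for observing what has arrived so far.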
func sinkReads(r io.Reader) *readSink {
	ret := &readSink{
		r: r,
	}
	ret.cond = sync.NewCond(&ret.Mutex)
	go func() {
		var buf [4096]byte
		for {
			n, err := r.Read(buf[:])
			ret.Lock()
			ret.bs.Write(buf[:n])
			if err != nil {
				ret.err = err
			}
			ret.cond.Broadcast()
			ret.Unlock()
			if err != nil {
				return
			}
		}
	}()
	return ret
}

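// String blocks until at least total bytes have been received (or the
// reader has failed) and returns the first total bytes read, or
// everything read so far if the reader stopped early.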
func (s *readSink) String(total int) string {
	s.Lock()
	defer s.Unlock()
	for s.bs.Len() < total && s.err == nil {
		s.cond.Wait()
	}
	if s.err != nil {
		total = s.bs.Len()
	}
	return string(s.bs.Bytes()[:total])
}

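// Error blocks until the background reader has finished and returns its
// terminal error (io.EOF on a clean close).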
func (s *readSink) Error() error {
	s.Lock()
	defer s.Unlock()
	for s.err == nil {
		s.cond.Wait()
	}
	return s.err
}

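// Total reports how many bytes have been received so far.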
func (s *readSink) Total() int {
	s.Lock()
	defer s.Unlock()
	return s.bs.Len()
}

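// pairWithConns runs the controlbase handshake over the given client
// and server net.Conns and returns the resulting client and server
// *Conns, failing the test if either side's handshake errors.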
func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
	var (
		controlKey = key.NewMachine()
		machineKey = key.NewMachine()
		server     *Conn
		serverErr  = make(chan error, 1)
	)
	go func() {
		var err error
		server, err = Server(context.Background(), serverConn, controlKey, nil)
		serverErr <- err
	}()

	client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public(), testProtocolVersion)
	if err != nil {
		t.Fatalf("client connection failed: %v", err)
	}
	if err := <-serverErr; err != nil {
		t.Fatalf("server connection failed: %v", err)
	}
	return client, server
}

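// pair returns a connected client/server *Conn pair over an in-memory
// network.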
func pair(t *testing.T) (*Conn, *Conn) {
	s1, s2 := memnet.NewConn("noise", 128000)
	return pairWithConns(t, s1, s2)
}