mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-11-04 02:01:14 +01:00 
			
		
		
		
	This exports a number of things from the derp (generic + client) package to be used by the new derpserver package, as now used by cmd/derper. And then enough other misc changes to lock in that cmd/tailscaled can be configured to not bring in tailscale.com/client/local. (The webclient in particular, even when disabled, was bringing it in, so that's now fixed) Fixes #17257 Change-Id: I88b6c7958643fb54f386dd900bddf73d2d4d96d5 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
		
			
				
	
	
		
			236 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			236 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) Tailscale Inc & AUTHORS
 | 
						|
// SPDX-License-Identifier: BSD-3-Clause
 | 
						|
 | 
						|
package derp
 | 
						|
 | 
						|
import (
 | 
						|
	"bufio"
 | 
						|
	"bytes"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"reflect"
 | 
						|
	"sync"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"tailscale.com/tstest"
 | 
						|
	"tailscale.com/types/key"
 | 
						|
)
 | 
						|
 | 
						|
// dummyNetConn is a placeholder net.Conn for tests that drive a Client only
// through its buffered reader. SetReadDeadline is the sole method it
// implements; every other net.Conn method is promoted from the embedded
// interface, which is nil in a zero value and would panic if called.
type dummyNetConn struct {
	net.Conn
}

// SetReadDeadline discards the deadline and always reports success.
func (dummyNetConn) SetReadDeadline(_ time.Time) error {
	return nil
}
 | 
						|
 | 
						|
func TestClientRecv(t *testing.T) {
 | 
						|
	tests := []struct {
 | 
						|
		name  string
 | 
						|
		input []byte
 | 
						|
		want  any
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name: "ping",
 | 
						|
			input: []byte{
 | 
						|
				byte(FramePing), 0, 0, 0, 8,
 | 
						|
				1, 2, 3, 4, 5, 6, 7, 8,
 | 
						|
			},
 | 
						|
			want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "pong",
 | 
						|
			input: []byte{
 | 
						|
				byte(FramePong), 0, 0, 0, 8,
 | 
						|
				1, 2, 3, 4, 5, 6, 7, 8,
 | 
						|
			},
 | 
						|
			want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "health_bad",
 | 
						|
			input: []byte{
 | 
						|
				byte(FrameHealth), 0, 0, 0, 3,
 | 
						|
				byte('B'), byte('A'), byte('D'),
 | 
						|
			},
 | 
						|
			want: HealthMessage{Problem: "BAD"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "health_ok",
 | 
						|
			input: []byte{
 | 
						|
				byte(FrameHealth), 0, 0, 0, 0,
 | 
						|
			},
 | 
						|
			want: HealthMessage{},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "server_restarting",
 | 
						|
			input: []byte{
 | 
						|
				byte(FrameRestarting), 0, 0, 0, 8,
 | 
						|
				0, 0, 0, 1,
 | 
						|
				0, 0, 0, 2,
 | 
						|
			},
 | 
						|
			want: ServerRestartingMessage{
 | 
						|
				ReconnectIn: 1 * time.Millisecond,
 | 
						|
				TryFor:      2 * time.Millisecond,
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			c := &Client{
 | 
						|
				nc:    dummyNetConn{},
 | 
						|
				br:    bufio.NewReader(bytes.NewReader(tt.input)),
 | 
						|
				logf:  t.Logf,
 | 
						|
				clock: &tstest.Clock{},
 | 
						|
			}
 | 
						|
			got, err := c.Recv()
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
			if !reflect.DeepEqual(got, tt.want) {
 | 
						|
				t.Errorf("got %#v; want %#v", got, tt.want)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientSendPing(t *testing.T) {
 | 
						|
	var buf bytes.Buffer
 | 
						|
	c := &Client{
 | 
						|
		bw: bufio.NewWriter(&buf),
 | 
						|
	}
 | 
						|
	if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	want := []byte{
 | 
						|
		byte(FramePing), 0, 0, 0, 8,
 | 
						|
		1, 2, 3, 4, 5, 6, 7, 8,
 | 
						|
	}
 | 
						|
	if !bytes.Equal(buf.Bytes(), want) {
 | 
						|
		t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientSendPong(t *testing.T) {
 | 
						|
	var buf bytes.Buffer
 | 
						|
	c := &Client{
 | 
						|
		bw: bufio.NewWriter(&buf),
 | 
						|
	}
 | 
						|
	if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	want := []byte{
 | 
						|
		byte(FramePong), 0, 0, 0, 8,
 | 
						|
		1, 2, 3, 4, 5, 6, 7, 8,
 | 
						|
	}
 | 
						|
	if !bytes.Equal(buf.Bytes(), want) {
 | 
						|
		t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func BenchmarkWriteUint32(b *testing.B) {
 | 
						|
	w := bufio.NewWriter(io.Discard)
 | 
						|
	b.ReportAllocs()
 | 
						|
	b.ResetTimer()
 | 
						|
	for range b.N {
 | 
						|
		writeUint32(w, 0x0ba3a)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// nopRead is an io.Reader that never fails. Each Read claims to have filled
// the whole buffer without actually writing to it, so a bufio.Reader layered
// on top of it always has bytes available.
type nopRead struct{}

// Read reports len(p) bytes read and a nil error, leaving p untouched.
func (nopRead) Read(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}
 | 
						|
 | 
						|
var sinkU32 uint32
 | 
						|
 | 
						|
func BenchmarkReadUint32(b *testing.B) {
 | 
						|
	r := bufio.NewReader(nopRead{})
 | 
						|
	var err error
 | 
						|
	b.ReportAllocs()
 | 
						|
	b.ResetTimer()
 | 
						|
	for range b.N {
 | 
						|
		sinkU32, err = readUint32(r)
 | 
						|
		if err != nil {
 | 
						|
			b.Fatal(err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// countWriter is an io.Writer that discards its input while recording how
// many Write calls it received and how many bytes they carried. All access
// to the counters goes through mu.
type countWriter struct {
	mu     sync.Mutex
	writes int   // Write calls since construction or the last ResetStats
	bytes  int64 // bytes passed to Write over the same period
}

// Write records one call of len(p) bytes and reports p as fully written.
func (w *countWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.mu.Lock()
	w.writes++
	w.bytes += int64(n)
	w.mu.Unlock()
	return n, nil
}

// Stats returns the write count and byte total accumulated so far.
func (w *countWriter) Stats() (writes int, bytes int64) {
	w.mu.Lock()
	writes, bytes = w.writes, w.bytes
	w.mu.Unlock()
	return writes, bytes
}

// ResetStats zeroes both counters.
func (w *countWriter) ResetStats() {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.writes = 0
	w.bytes = 0
}
 | 
						|
 | 
						|
func TestClientSendRateLimiting(t *testing.T) {
 | 
						|
	cw := new(countWriter)
 | 
						|
	c := &Client{
 | 
						|
		bw:    bufio.NewWriter(cw),
 | 
						|
		clock: &tstest.Clock{},
 | 
						|
	}
 | 
						|
	c.setSendRateLimiter(ServerInfoMessage{})
 | 
						|
 | 
						|
	pkt := make([]byte, 1000)
 | 
						|
	if err := c.send(key.NodePublic{}, pkt); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	writes1, bytes1 := cw.Stats()
 | 
						|
	if writes1 != 1 {
 | 
						|
		t.Errorf("writes = %v, want 1", writes1)
 | 
						|
	}
 | 
						|
 | 
						|
	// Flood should all succeed.
 | 
						|
	cw.ResetStats()
 | 
						|
	for range 1000 {
 | 
						|
		if err := c.send(key.NodePublic{}, pkt); err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	writes1K, bytes1K := cw.Stats()
 | 
						|
	if writes1K != 1000 {
 | 
						|
		t.Logf("writes = %v; want 1000", writes1K)
 | 
						|
	}
 | 
						|
	if got, want := bytes1K, bytes1*1000; got != want {
 | 
						|
		t.Logf("bytes = %v; want %v", got, want)
 | 
						|
	}
 | 
						|
 | 
						|
	// Set a rate limiter
 | 
						|
	cw.ResetStats()
 | 
						|
	c.setSendRateLimiter(ServerInfoMessage{
 | 
						|
		TokenBucketBytesPerSecond: 1,
 | 
						|
		TokenBucketBytesBurst:     int(bytes1 * 2),
 | 
						|
	})
 | 
						|
	for range 1000 {
 | 
						|
		if err := c.send(key.NodePublic{}, pkt); err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	writesLimited, bytesLimited := cw.Stats()
 | 
						|
	if writesLimited == 0 || writesLimited == writes1K {
 | 
						|
		t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited)
 | 
						|
	}
 | 
						|
	if bytesLimited < bytes1*2 || bytesLimited >= bytes1K {
 | 
						|
		t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K)
 | 
						|
	}
 | 
						|
}
 |