mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-11-04 02:01:14 +01:00 
			
		
		
		
	This updates all source files to use a new standard header for copyright and license declaration. Notably, copyright no longer includes a date, and we now use the standard SPDX-License-Identifier header. This commit was done almost entirely mechanically with perl, and then some minimal manual fixes. Updates #6865 Signed-off-by: Will Norris <will@tailscale.com>
		
			
				
	
	
		
			147 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			147 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) Tailscale Inc & AUTHORS
 | 
						|
// SPDX-License-Identifier: BSD-3-Clause
 | 
						|
 | 
						|
package speedtest
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/rand"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
// Serve starts up the server on a given host and port pair. It starts to listen for
 | 
						|
// connections and handles each one in a goroutine. Because it runs in an infinite loop,
 | 
						|
// this function only returns if any of the speedtests return with errors, or if the
 | 
						|
// listener is closed.
 | 
						|
func Serve(l net.Listener) error {
 | 
						|
	for {
 | 
						|
		conn, err := l.Accept()
 | 
						|
		if errors.Is(err, net.ErrClosed) {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		err = handleConnection(conn)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// handleConnection handles the initial exchange between the server and the client.
 | 
						|
// It reads the testconfig message into a config struct. If any errors occur with
 | 
						|
// the testconfig (specifically, if there is a version mismatch), it will return those
 | 
						|
// errors to the client with a configResponse. After the exchange, it will start
 | 
						|
// the speed test.
 | 
						|
func handleConnection(conn net.Conn) error {
 | 
						|
	defer conn.Close()
 | 
						|
	var conf config
 | 
						|
 | 
						|
	decoder := json.NewDecoder(conn)
 | 
						|
	err := decoder.Decode(&conf)
 | 
						|
	encoder := json.NewEncoder(conn)
 | 
						|
 | 
						|
	// Both return and encode errors that occurred before the test started.
 | 
						|
	if err != nil {
 | 
						|
		encoder.Encode(configResponse{Error: err.Error()})
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// The server should always be doing the opposite of what the client is doing.
 | 
						|
	conf.Direction.Reverse()
 | 
						|
 | 
						|
	if conf.Version != version {
 | 
						|
		err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version)
 | 
						|
		encoder.Encode(configResponse{Error: err.Error()})
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	// Start the test
 | 
						|
	encoder.Encode(configResponse{})
 | 
						|
	_, err = doTest(conn, conf)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// TODO: include code to detect whether the connection is direct or goes through DERP
 | 
						|
 | 
						|
// doTest contains the code to run both the upload and download speedtest.
 | 
						|
// the direction value in the config parameter determines which test to run.
 | 
						|
func doTest(conn net.Conn, conf config) ([]Result, error) {
 | 
						|
	bufferData := make([]byte, blockSize)
 | 
						|
 | 
						|
	intervalBytes := 0
 | 
						|
	totalBytes := 0
 | 
						|
 | 
						|
	var currentTime time.Time
 | 
						|
	var results []Result
 | 
						|
 | 
						|
	if conf.Direction == Download {
 | 
						|
		conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second))
 | 
						|
	} else {
 | 
						|
		_, err := rand.Read(bufferData)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
	}
 | 
						|
 | 
						|
	startTime := time.Now()
 | 
						|
	lastCalculated := startTime
 | 
						|
 | 
						|
SpeedTestLoop:
 | 
						|
	for {
 | 
						|
		var n int
 | 
						|
		var err error
 | 
						|
 | 
						|
		if conf.Direction == Download {
 | 
						|
			n, err = io.ReadFull(conn, bufferData)
 | 
						|
			switch err {
 | 
						|
			case io.EOF, io.ErrUnexpectedEOF:
 | 
						|
				break SpeedTestLoop
 | 
						|
			case nil:
 | 
						|
				// successful read
 | 
						|
			default:
 | 
						|
				return nil, fmt.Errorf("unexpected error has occurred: %w", err)
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			n, err = conn.Write(bufferData)
 | 
						|
			if err != nil {
 | 
						|
				// If the write failed, there is most likely something wrong with the connection.
 | 
						|
				return nil, fmt.Errorf("upload failed: %w", err)
 | 
						|
			}
 | 
						|
		}
 | 
						|
		intervalBytes += n
 | 
						|
 | 
						|
		currentTime = time.Now()
 | 
						|
		// checks if the current time is more or equal to the lastCalculated time plus the increment
 | 
						|
		if currentTime.Sub(lastCalculated) >= increment {
 | 
						|
			results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
 | 
						|
			lastCalculated = currentTime
 | 
						|
			totalBytes += intervalBytes
 | 
						|
			intervalBytes = 0
 | 
						|
		}
 | 
						|
 | 
						|
		if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration {
 | 
						|
			break SpeedTestLoop
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// get last segment
 | 
						|
	if currentTime.Sub(lastCalculated) > minInterval {
 | 
						|
		results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
 | 
						|
	}
 | 
						|
 | 
						|
	// get total
 | 
						|
	totalBytes += intervalBytes
 | 
						|
	if currentTime.Sub(startTime) > minInterval {
 | 
						|
		results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true})
 | 
						|
	}
 | 
						|
 | 
						|
	return results, nil
 | 
						|
}
 |