mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-29 23:31:28 +01:00 
			
		
		
		
	Updates #1412 Change-Id: Icd880035a31df59797b8379f4af19da5c4c453e2 Co-authored-by: Maisem Ali <maisem@tailscale.com> Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
		
			
				
	
	
		
			96 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			96 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| // Package tstest provides utilities for use in unit tests.
 | |
| package tstest
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"os"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync/atomic"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"tailscale.com/envknob"
 | |
| 	"tailscale.com/logtail/backoff"
 | |
| 	"tailscale.com/types/logger"
 | |
| 	"tailscale.com/util/cibuild"
 | |
| )
 | |
| 
 | |
// Replace sets *target to val for the remainder of the test t.
//
// The value *target held before the call is restored by a t.Cleanup
// function, i.e. when t and all of its subtests have completed.
// A nil target fails the test immediately.
func Replace[T any](t testing.TB, target *T, val T) {
	t.Helper()
	if target == nil {
		t.Fatal("Replace: nil pointer")
		panic("unreachable") // pacify staticcheck
	}
	old := *target
	t.Cleanup(func() {
		*target = old
	})
	*target = val
}
 | |
| 
 | |
| // WaitFor retries try for up to maxWait.
 | |
| // It returns nil once try returns nil the first time.
 | |
| // If maxWait passes without success, it returns try's last error.
 | |
| func WaitFor(maxWait time.Duration, try func() error) error {
 | |
| 	bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4)
 | |
| 	deadline := time.Now().Add(maxWait)
 | |
| 	var err error
 | |
| 	for time.Now().Before(deadline) {
 | |
| 		err = try()
 | |
| 		if err == nil {
 | |
| 			break
 | |
| 		}
 | |
| 		bo.BackOff(context.Background(), err)
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
// testNum counts the calls to Shard made so far in this process. Shard uses
// it to assign each test execution to a shard.
var testNum atomic.Int32

// Shard skips t unless t belongs to the shard selected by the TS_TEST_SHARD
// environment variable.
//
// TS_TEST_SHARD has the form "n/m": m is the total number of shards and n is
// the 1-based shard this process should run. The calls to Shard within a
// process are numbered 0, 1, 2, ...; call k belongs to shard (k mod m)+1 and
// is skipped unless that equals n. For example, to run with 4 shards, set
// TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 for the four jobs.
//
// If TS_TEST_SHARD is unset, contains no "/", or either half is not an
// integer or is zero, sharding is disabled and t runs normally. A spec that
// parses but is out of range (n < 1, m < 1, or n > m) fails t instead of
// silently skipping every test in the process.
func Shard(t testing.TB) {
	t.Helper()
	spec := os.Getenv("TS_TEST_SHARD")
	a, b, ok := strings.Cut(spec, "/")
	if !ok {
		return
	}
	wantShard, errWant := strconv.ParseInt(a, 10, 32)
	shards, errShards := strconv.ParseInt(b, 10, 32)
	if errWant != nil || errShards != nil || wantShard == 0 || shards == 0 {
		// Unparseable or zero: treat as "sharding disabled".
		return
	}
	if shards < 1 || wantShard < 1 || wantShard > shards {
		t.Fatalf("invalid TS_TEST_SHARD=%q: want \"n/m\" with 1 <= n <= m", spec)
		return // only reached if t's Fatalf returns (e.g. a fake testing.TB)
	}

	shard := ((testNum.Add(1) - 1) % int32(shards)) + 1
	if shard != int32(wantShard) {
		t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, spec)
	}
}
 | |
| 
 | |
| // SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD
 | |
| // environment variable isn't set.
 | |
| func SkipOnUnshardedCI(t testing.TB) {
 | |
| 	if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" {
 | |
| 		t.Skip("skipping on CI without TS_TEST_SHARD")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS")
 | |
| 
 | |
| // Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true.
 | |
| func Parallel(t *testing.T) {
 | |
| 	if !serializeParallel() {
 | |
| 		t.Parallel()
 | |
| 	}
 | |
| }
 |