mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-11-04 10:11:18 +01:00 
			
		
		
		
	This is a variant of DoChan that supports context propagation, such that the context provided to the inner function will only be canceled when there are no more waiters for a given key. This can be used to deduplicate expensive and cancelable calls among multiple callers safely. Updates #11935 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: Ibe1fb67442a854babbc6924fd8437b02cc9e7bcf
		
			
				
	
	
		
			477 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			477 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) Tailscale Inc & AUTHORS
 | 
						|
// SPDX-License-Identifier: BSD-3-Clause
 | 
						|
 | 
						|
// Copyright 2013 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package singleflight
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"os"
 | 
						|
	"os/exec"
 | 
						|
	"runtime"
 | 
						|
	"runtime/debug"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"sync/atomic"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
func TestDo(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
	v, err, _ := g.Do("key", func() (interface{}, error) {
 | 
						|
		return "bar", nil
 | 
						|
	})
 | 
						|
	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
 | 
						|
		t.Errorf("Do = %v; want %v", got, want)
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		t.Errorf("Do error = %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDoErr(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
	someErr := errors.New("Some error")
 | 
						|
	v, err, _ := g.Do("key", func() (interface{}, error) {
 | 
						|
		return nil, someErr
 | 
						|
	})
 | 
						|
	if err != someErr {
 | 
						|
		t.Errorf("Do error = %v; want someErr %v", err, someErr)
 | 
						|
	}
 | 
						|
	if v != nil {
 | 
						|
		t.Errorf("unexpected non-nil value %#v", v)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDoDupSuppress(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
	var wg1, wg2 sync.WaitGroup
 | 
						|
	c := make(chan string, 1)
 | 
						|
	var calls int32
 | 
						|
	fn := func() (interface{}, error) {
 | 
						|
		if atomic.AddInt32(&calls, 1) == 1 {
 | 
						|
			// First invocation.
 | 
						|
			wg1.Done()
 | 
						|
		}
 | 
						|
		v := <-c
 | 
						|
		c <- v // pump; make available for any future calls
 | 
						|
 | 
						|
		time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
 | 
						|
 | 
						|
		return v, nil
 | 
						|
	}
 | 
						|
 | 
						|
	const n = 10
 | 
						|
	wg1.Add(1)
 | 
						|
	for range n {
 | 
						|
		wg1.Add(1)
 | 
						|
		wg2.Add(1)
 | 
						|
		go func() {
 | 
						|
			defer wg2.Done()
 | 
						|
			wg1.Done()
 | 
						|
			v, err, _ := g.Do("key", fn)
 | 
						|
			if err != nil {
 | 
						|
				t.Errorf("Do error: %v", err)
 | 
						|
				return
 | 
						|
			}
 | 
						|
			if s, _ := v.(string); s != "bar" {
 | 
						|
				t.Errorf("Do = %T %v; want %q", v, v, "bar")
 | 
						|
			}
 | 
						|
		}()
 | 
						|
	}
 | 
						|
	wg1.Wait()
 | 
						|
	// At least one goroutine is in fn now and all of them have at
 | 
						|
	// least reached the line before the Do.
 | 
						|
	c <- "bar"
 | 
						|
	wg2.Wait()
 | 
						|
	if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
 | 
						|
		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Test that singleflight behaves correctly after Forget called.
 | 
						|
// See https://github.com/golang/go/issues/31420
 | 
						|
func TestForget(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
 | 
						|
	var (
 | 
						|
		firstStarted  = make(chan struct{})
 | 
						|
		unblockFirst  = make(chan struct{})
 | 
						|
		firstFinished = make(chan struct{})
 | 
						|
	)
 | 
						|
 | 
						|
	go func() {
 | 
						|
		g.Do("key", func() (i interface{}, e error) {
 | 
						|
			close(firstStarted)
 | 
						|
			<-unblockFirst
 | 
						|
			close(firstFinished)
 | 
						|
			return
 | 
						|
		})
 | 
						|
	}()
 | 
						|
	<-firstStarted
 | 
						|
	g.Forget("key")
 | 
						|
 | 
						|
	unblockSecond := make(chan struct{})
 | 
						|
	secondResult := g.DoChan("key", func() (i interface{}, e error) {
 | 
						|
		<-unblockSecond
 | 
						|
		return 2, nil
 | 
						|
	})
 | 
						|
 | 
						|
	close(unblockFirst)
 | 
						|
	<-firstFinished
 | 
						|
 | 
						|
	thirdResult := g.DoChan("key", func() (i interface{}, e error) {
 | 
						|
		return 3, nil
 | 
						|
	})
 | 
						|
 | 
						|
	close(unblockSecond)
 | 
						|
	<-secondResult
 | 
						|
	r := <-thirdResult
 | 
						|
	if r.Val != 2 {
 | 
						|
		t.Errorf("We should receive result produced by second call, expected: 2, got %d", r.Val)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDoChan(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
	ch := g.DoChan("key", func() (interface{}, error) {
 | 
						|
		return "bar", nil
 | 
						|
	})
 | 
						|
 | 
						|
	res := <-ch
 | 
						|
	v := res.Val
 | 
						|
	err := res.Err
 | 
						|
	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
 | 
						|
		t.Errorf("Do = %v; want %v", got, want)
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		t.Errorf("Do error = %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Test singleflight behaves correctly after Do panic.
 | 
						|
// See https://github.com/golang/go/issues/41133
 | 
						|
func TestPanicDo(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
	fn := func() (interface{}, error) {
 | 
						|
		panic("invalid memory address or nil pointer dereference")
 | 
						|
	}
 | 
						|
 | 
						|
	const n = 5
 | 
						|
	waited := int32(n)
 | 
						|
	panicCount := int32(0)
 | 
						|
	done := make(chan struct{})
 | 
						|
	for range n {
 | 
						|
		go func() {
 | 
						|
			defer func() {
 | 
						|
				if err := recover(); err != nil {
 | 
						|
					t.Logf("Got panic: %v\n%s", err, debug.Stack())
 | 
						|
					atomic.AddInt32(&panicCount, 1)
 | 
						|
				}
 | 
						|
 | 
						|
				if atomic.AddInt32(&waited, -1) == 0 {
 | 
						|
					close(done)
 | 
						|
				}
 | 
						|
			}()
 | 
						|
 | 
						|
			g.Do("key", fn)
 | 
						|
		}()
 | 
						|
	}
 | 
						|
 | 
						|
	select {
 | 
						|
	case <-done:
 | 
						|
		if panicCount != n {
 | 
						|
			t.Errorf("Expect %d panic, but got %d", n, panicCount)
 | 
						|
		}
 | 
						|
	case <-time.After(time.Second):
 | 
						|
		t.Fatalf("Do hangs")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGoexitDo(t *testing.T) {
 | 
						|
	var g Group[string, any]
 | 
						|
	fn := func() (interface{}, error) {
 | 
						|
		runtime.Goexit()
 | 
						|
		return nil, nil
 | 
						|
	}
 | 
						|
 | 
						|
	const n = 5
 | 
						|
	waited := int32(n)
 | 
						|
	done := make(chan struct{})
 | 
						|
	for range n {
 | 
						|
		go func() {
 | 
						|
			var err error
 | 
						|
			defer func() {
 | 
						|
				if err != nil {
 | 
						|
					t.Errorf("Error should be nil, but got: %v", err)
 | 
						|
				}
 | 
						|
				if atomic.AddInt32(&waited, -1) == 0 {
 | 
						|
					close(done)
 | 
						|
				}
 | 
						|
			}()
 | 
						|
			_, err, _ = g.Do("key", fn)
 | 
						|
		}()
 | 
						|
	}
 | 
						|
 | 
						|
	select {
 | 
						|
	case <-done:
 | 
						|
	case <-time.After(time.Second):
 | 
						|
		t.Fatalf("Do hangs")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPanicDoChan(t *testing.T) {
 | 
						|
	if runtime.GOOS == "js" {
 | 
						|
		t.Skipf("js does not support exec")
 | 
						|
	}
 | 
						|
 | 
						|
	if os.Getenv("TEST_PANIC_DOCHAN") != "" {
 | 
						|
		defer func() {
 | 
						|
			recover()
 | 
						|
		}()
 | 
						|
 | 
						|
		g := new(Group[string, any])
 | 
						|
		ch := g.DoChan("", func() (interface{}, error) {
 | 
						|
			panic("Panicking in DoChan")
 | 
						|
		})
 | 
						|
		<-ch
 | 
						|
		t.Fatalf("DoChan unexpectedly returned")
 | 
						|
	}
 | 
						|
 | 
						|
	t.Parallel()
 | 
						|
 | 
						|
	cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v")
 | 
						|
	cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1")
 | 
						|
	out := new(bytes.Buffer)
 | 
						|
	cmd.Stdout = out
 | 
						|
	cmd.Stderr = out
 | 
						|
	if err := cmd.Start(); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	err := cmd.Wait()
 | 
						|
	t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out)
 | 
						|
	if err == nil {
 | 
						|
		t.Errorf("Test subprocess passed; want a crash due to panic in DoChan")
 | 
						|
	}
 | 
						|
	if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) {
 | 
						|
		t.Errorf("Test subprocess failed with an unexpected failure mode.")
 | 
						|
	}
 | 
						|
	if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) {
 | 
						|
		t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPanicDoSharedByDoChan(t *testing.T) {
 | 
						|
	if runtime.GOOS == "js" {
 | 
						|
		t.Skipf("js does not support exec")
 | 
						|
	}
 | 
						|
 | 
						|
	if os.Getenv("TEST_PANIC_DOCHAN") != "" {
 | 
						|
		blocked := make(chan struct{})
 | 
						|
		unblock := make(chan struct{})
 | 
						|
 | 
						|
		g := new(Group[string, any])
 | 
						|
		go func() {
 | 
						|
			defer func() {
 | 
						|
				recover()
 | 
						|
			}()
 | 
						|
			g.Do("", func() (interface{}, error) {
 | 
						|
				close(blocked)
 | 
						|
				<-unblock
 | 
						|
				panic("Panicking in Do")
 | 
						|
			})
 | 
						|
		}()
 | 
						|
 | 
						|
		<-blocked
 | 
						|
		ch := g.DoChan("", func() (interface{}, error) {
 | 
						|
			panic("DoChan unexpectedly executed callback")
 | 
						|
		})
 | 
						|
		close(unblock)
 | 
						|
		<-ch
 | 
						|
		t.Fatalf("DoChan unexpectedly returned")
 | 
						|
	}
 | 
						|
 | 
						|
	t.Parallel()
 | 
						|
 | 
						|
	cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v")
 | 
						|
	cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1")
 | 
						|
	out := new(bytes.Buffer)
 | 
						|
	cmd.Stdout = out
 | 
						|
	cmd.Stderr = out
 | 
						|
	if err := cmd.Start(); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	err := cmd.Wait()
 | 
						|
	t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out)
 | 
						|
	if err == nil {
 | 
						|
		t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan")
 | 
						|
	}
 | 
						|
	if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) {
 | 
						|
		t.Errorf("Test subprocess failed with an unexpected failure mode.")
 | 
						|
	}
 | 
						|
	if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) {
 | 
						|
		t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDoChanContext(t *testing.T) {
 | 
						|
	t.Run("Basic", func(t *testing.T) {
 | 
						|
		ctx, cancel := context.WithCancel(context.Background())
 | 
						|
		defer cancel()
 | 
						|
 | 
						|
		var g Group[string, int]
 | 
						|
		ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) {
 | 
						|
			return 1, nil
 | 
						|
		})
 | 
						|
		ret := <-ch
 | 
						|
		assertOKResult(t, ret, 1)
 | 
						|
	})
 | 
						|
 | 
						|
	t.Run("DoesNotPropagateValues", func(t *testing.T) {
 | 
						|
		ctx, cancel := context.WithCancel(context.Background())
 | 
						|
		defer cancel()
 | 
						|
 | 
						|
		key := new(int)
 | 
						|
		const value = "hello world"
 | 
						|
 | 
						|
		ctx = context.WithValue(ctx, key, value)
 | 
						|
 | 
						|
		var g Group[string, int]
 | 
						|
		ch := g.DoChanContext(ctx, "foobar", func(ctx context.Context) (int, error) {
 | 
						|
			if _, ok := ctx.Value(key).(string); ok {
 | 
						|
				t.Error("expected no value, but was present in context")
 | 
						|
			}
 | 
						|
			return 1, nil
 | 
						|
		})
 | 
						|
		ret := <-ch
 | 
						|
		assertOKResult(t, ret, 1)
 | 
						|
	})
 | 
						|
 | 
						|
	t.Run("NoCancelWhenWaiters", func(t *testing.T) {
 | 
						|
		testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
 | 
						|
		defer testCancel()
 | 
						|
 | 
						|
		trigger := make(chan struct{})
 | 
						|
 | 
						|
		ctx1, cancel1 := context.WithCancel(context.Background())
 | 
						|
		defer cancel1()
 | 
						|
		ctx2, cancel2 := context.WithCancel(context.Background())
 | 
						|
		defer cancel2()
 | 
						|
 | 
						|
		fn := func(ctx context.Context) (int, error) {
 | 
						|
			select {
 | 
						|
			case <-ctx.Done():
 | 
						|
				return 0, ctx.Err()
 | 
						|
			case <-trigger:
 | 
						|
				return 1234, nil
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		// Create two waiters, then cancel the first before we trigger
 | 
						|
		// the function to return a value. This shouldn't result in a
 | 
						|
		// context canceled error.
 | 
						|
		var g Group[string, int]
 | 
						|
		ch1 := g.DoChanContext(ctx1, "key", fn)
 | 
						|
		ch2 := g.DoChanContext(ctx2, "key", fn)
 | 
						|
 | 
						|
		cancel1()
 | 
						|
 | 
						|
		// The first channel, now that it's canceled, should return a
 | 
						|
		// context canceled error.
 | 
						|
		select {
 | 
						|
		case res := <-ch1:
 | 
						|
			if !errors.Is(res.Err, context.Canceled) {
 | 
						|
				t.Errorf("unexpected error; got %v, want context.Canceled", res.Err)
 | 
						|
			}
 | 
						|
		case <-testCtx.Done():
 | 
						|
			t.Fatal("test timed out")
 | 
						|
		}
 | 
						|
 | 
						|
		// Actually return
 | 
						|
		close(trigger)
 | 
						|
		res := <-ch2
 | 
						|
		assertOKResult(t, res, 1234)
 | 
						|
	})
 | 
						|
 | 
						|
	t.Run("AllCancel", func(t *testing.T) {
 | 
						|
		for _, n := range []int{1, 2, 10, 20} {
 | 
						|
			t.Run(fmt.Sprintf("NumWaiters=%d", n), func(t *testing.T) {
 | 
						|
				testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
 | 
						|
				defer testCancel()
 | 
						|
 | 
						|
				trigger := make(chan struct{})
 | 
						|
				defer close(trigger)
 | 
						|
 | 
						|
				fn := func(ctx context.Context) (int, error) {
 | 
						|
					select {
 | 
						|
					case <-ctx.Done():
 | 
						|
						return 0, ctx.Err()
 | 
						|
					case <-trigger:
 | 
						|
						t.Error("unexpected trigger; want all callers to cancel")
 | 
						|
						return 0, errors.New("unexpected trigger")
 | 
						|
					}
 | 
						|
				}
 | 
						|
 | 
						|
				// Launch N goroutines that all wait on the same key.
 | 
						|
				var (
 | 
						|
					g       Group[string, int]
 | 
						|
					chs     []<-chan Result[int]
 | 
						|
					cancels []context.CancelFunc
 | 
						|
				)
 | 
						|
				for i := range n {
 | 
						|
					ctx, cancel := context.WithCancel(context.Background())
 | 
						|
					defer cancel()
 | 
						|
					cancels = append(cancels, cancel)
 | 
						|
 | 
						|
					ch := g.DoChanContext(ctx, "key", fn)
 | 
						|
					chs = append(chs, ch)
 | 
						|
 | 
						|
					// Every third goroutine should cancel
 | 
						|
					// immediately, which better tests the
 | 
						|
					// cancel logic.
 | 
						|
					if i%3 == 0 {
 | 
						|
						cancel()
 | 
						|
					}
 | 
						|
				}
 | 
						|
 | 
						|
				// Now that everything is waiting, cancel all the contexts.
 | 
						|
				for _, cancel := range cancels {
 | 
						|
					cancel()
 | 
						|
				}
 | 
						|
 | 
						|
				// Wait for a result from each channel. They
 | 
						|
				// should all return an error showing a context
 | 
						|
				// cancel.
 | 
						|
				for _, ch := range chs {
 | 
						|
					select {
 | 
						|
					case res := <-ch:
 | 
						|
						if !errors.Is(res.Err, context.Canceled) {
 | 
						|
							t.Errorf("unexpected error; got %v, want context.Canceled", res.Err)
 | 
						|
						}
 | 
						|
					case <-testCtx.Done():
 | 
						|
						t.Fatal("test timed out")
 | 
						|
					}
 | 
						|
				}
 | 
						|
			})
 | 
						|
		}
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func assertOKResult[V comparable](t testing.TB, res Result[V], want V) {
 | 
						|
	if res.Err != nil {
 | 
						|
		t.Fatalf("unexpected error: %v", res.Err)
 | 
						|
	}
 | 
						|
	if res.Val != want {
 | 
						|
		t.Fatalf("unexpected value; got %v, want %v", res.Val, want)
 | 
						|
	}
 | 
						|
}
 |