mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-31 08:11:32 +01:00 
			
		
		
		
	This can be used to reload a value periodically, whether from disk or another source, while handling jitter and graceful shutdown. Updates tailscale/corp#1297 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: Iee2b4385c9abae59805f642a7308837877cb5b3f
		
			
				
	
	
		
			172 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| package reload
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"reflect"
 | |
| 	"sync/atomic"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"tailscale.com/tstest"
 | |
| )
 | |
| 
 | |
| func TestReloader(t *testing.T) {
 | |
| 	buf := []byte("hello world")
 | |
| 
 | |
| 	ctx := context.Background()
 | |
| 	r, err := newUnstarted[string](ctx, ReloadOpts[string]{
 | |
| 		Logf: t.Logf,
 | |
| 		Read: func(context.Context) ([]byte, error) {
 | |
| 			return buf, nil
 | |
| 		},
 | |
| 		Unmarshal: func(b []byte) (string, error) {
 | |
| 			return "The value is: " + string(b), nil
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// We should have an initial value.
 | |
| 	const wantInitial = "The value is: hello world"
 | |
| 	if v := r.store.Load(); v != wantInitial {
 | |
| 		t.Errorf("got initial value %q, want %q", v, wantInitial)
 | |
| 	}
 | |
| 
 | |
| 	// Reloading should result in a new value
 | |
| 	buf = []byte("new value")
 | |
| 	if err := r.updateOnce(); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	const wantReload = "The value is: new value"
 | |
| 	if v := r.store.Load(); v != wantReload {
 | |
| 		t.Errorf("got reloaded value %q, want %q", v, wantReload)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestReloader_InitialError(t *testing.T) {
 | |
| 	fakeErr := errors.New("fake error")
 | |
| 
 | |
| 	ctx := context.Background()
 | |
| 	_, err := newUnstarted[string](ctx, ReloadOpts[string]{
 | |
| 		Logf:      t.Logf,
 | |
| 		Read:      func(context.Context) ([]byte, error) { return nil, fakeErr },
 | |
| 		Unmarshal: func(b []byte) (string, error) { panic("unused because Read fails") },
 | |
| 	})
 | |
| 	if err == nil {
 | |
| 		t.Fatal("expected non-nil error")
 | |
| 	}
 | |
| 	if !errors.Is(err, fakeErr) {
 | |
| 		t.Errorf("wanted errors.Is(%v, fakeErr)=true", err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestReloader_ReloadError(t *testing.T) {
 | |
| 	fakeErr := errors.New("fake error")
 | |
| 	shouldError := false
 | |
| 
 | |
| 	ctx := context.Background()
 | |
| 	r, err := newUnstarted[string](ctx, ReloadOpts[string]{
 | |
| 		Logf: t.Logf,
 | |
| 		Read: func(context.Context) ([]byte, error) {
 | |
| 			return []byte("hello"), nil
 | |
| 		},
 | |
| 		Unmarshal: func(b []byte) (string, error) {
 | |
| 			if shouldError {
 | |
| 				return "", fakeErr
 | |
| 			}
 | |
| 			return string(b), nil
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	if got := r.store.Load(); got != "hello" {
 | |
| 		t.Fatalf("got value %q, want \"hello\"", got)
 | |
| 	}
 | |
| 
 | |
| 	shouldError = true
 | |
| 
 | |
| 	if err := r.updateOnce(); err == nil {
 | |
| 		t.Errorf("expected error from updateOnce")
 | |
| 	}
 | |
| 	if got := r.store.Load(); got != "hello" {
 | |
| 		t.Fatalf("got value %q, want \"hello\"", got)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestReloader_Run(t *testing.T) {
 | |
| 	ctx, cancel := context.WithCancel(context.Background())
 | |
| 	defer cancel()
 | |
| 
 | |
| 	var ncalls atomic.Int64
 | |
| 	load, err := New[string](ctx, ReloadOpts[string]{
 | |
| 		Logf:     tstest.WhileTestRunningLogger(t),
 | |
| 		Interval: 10 * time.Millisecond,
 | |
| 		Read: func(context.Context) ([]byte, error) {
 | |
| 			return []byte("hello"), nil
 | |
| 		},
 | |
| 		Unmarshal: func(b []byte) (string, error) {
 | |
| 			callNum := ncalls.Add(1)
 | |
| 			if callNum == 3 {
 | |
| 				cancel()
 | |
| 			}
 | |
| 			return fmt.Sprintf("call %d: %s", callNum, b), nil
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	want := "call 1: hello"
 | |
| 	if got := load(); got != want {
 | |
| 		t.Fatalf("got value %q, want %q", got, want)
 | |
| 	}
 | |
| 
 | |
| 	// Wait for the periodic refresh to cancel our context
 | |
| 	select {
 | |
| 	case <-ctx.Done():
 | |
| 	case <-time.After(10 * time.Second):
 | |
| 		t.Fatal("test timed out")
 | |
| 	}
 | |
| 
 | |
| 	// Depending on how goroutines get scheduled, we can either read call 2
 | |
| 	// (if we woke up before the run goroutine stores call 3), or call 3
 | |
| 	// (if we woke up after the run goroutine stores the next value). Check
 | |
| 	// for both.
 | |
| 	want1, want2 := "call 2: hello", "call 3: hello"
 | |
| 	if got := load(); got != want1 && got != want2 {
 | |
| 		t.Fatalf("got value %q, want %q or %q", got, want1, want2)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestFromJSONFile(t *testing.T) {
 | |
| 	type testStruct struct {
 | |
| 		Value  string
 | |
| 		Number int
 | |
| 	}
 | |
| 	fpath := filepath.Join(t.TempDir(), "test.json")
 | |
| 	if err := os.WriteFile(fpath, []byte(`{"Value": "hello", "Number": 1234}`), 0600); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	ctx := context.Background()
 | |
| 	r, err := newUnstarted(ctx, FromJSONFile[*testStruct](fpath))
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	got := r.store.Load()
 | |
| 	want := &testStruct{Value: "hello", Number: 1234}
 | |
| 	if !reflect.DeepEqual(got, want) {
 | |
| 		t.Errorf("got %+v, want %+v", got, want)
 | |
| 	}
 | |
| }
 |