mirror of
https://github.com/tailscale/tailscale.git
synced 2025-10-24 05:41:40 +02:00
Updates #16998 Change-Id: I735d75129a97a929092e9075107e41cdade18944 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
250 lines
6.1 KiB
Go
250 lines
6.1 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
// Package policytest contains test helpers for the syspolicy packages.
|
|
package policytest
|
|
|
|
import (
|
|
"fmt"
|
|
"maps"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
"tailscale.com/util/set"
|
|
"tailscale.com/util/syspolicy/pkey"
|
|
"tailscale.com/util/syspolicy/policyclient"
|
|
"tailscale.com/util/syspolicy/ptype"
|
|
)
|
|
|
|
// Config is a [policyclient.Client] implementation with a static mapping of
|
|
// values.
|
|
//
|
|
// It is used for testing purposes to simulate policy client behavior.
|
|
//
|
|
// It panics if a value is Set with one type and then accessed with a different
|
|
// expected type and/or value. Some accessors such as GetPreferenceOption and
|
|
// GetVisibility support either a ptype.PreferenceOption/ptype.Visibility in the
|
|
// map, or the string representation as supported by their UnmarshalText
|
|
// methods.
|
|
//
|
|
// The map value may be an error to return that error value from the accessor.
|
|
type Config map[pkey.Key]any
|
|
|
|
var _ policyclient.Client = Config{}
|
|
|
|
// Set sets key to value. The value should be of the correct type that it will
|
|
// be read as later. For PreferenceOption and Visibility, you may also set them
|
|
// to 'string' values and they'll be UnmarshalText'ed into their correct value
|
|
// at Get time.
|
|
//
|
|
// As a special case, the value can also be of type error to make the accessors
|
|
// return that error value.
|
|
func (c *Config) Set(key pkey.Key, value any) {
|
|
if *c == nil {
|
|
*c = make(map[pkey.Key]any)
|
|
}
|
|
(*c)[key] = value
|
|
|
|
if w, ok := (*c)[watchersKey].(*watchers); ok && key != watchersKey {
|
|
w.mu.Lock()
|
|
vals := slices.Collect(maps.Values(w.s))
|
|
w.mu.Unlock()
|
|
for _, f := range vals {
|
|
f(policyChange(key))
|
|
}
|
|
}
|
|
}
|
|
|
|
// SetMultiple is a batch version of [Config.Set]. It copies the contents of o
|
|
// into c and does at most one notification wake-up for the whole batch.
|
|
func (c *Config) SetMultiple(o Config) {
|
|
if *c == nil {
|
|
*c = make(map[pkey.Key]any)
|
|
}
|
|
|
|
maps.Copy(*c, o)
|
|
|
|
if w, ok := (*c)[watchersKey].(*watchers); ok {
|
|
w.mu.Lock()
|
|
vals := slices.Collect(maps.Values(w.s))
|
|
w.mu.Unlock()
|
|
for _, f := range vals {
|
|
f(policyChanges(o))
|
|
}
|
|
}
|
|
}
|
|
|
|
type policyChange pkey.Key
|
|
|
|
func (pc policyChange) HasChanged(v pkey.Key) bool { return pkey.Key(pc) == v }
|
|
func (pc policyChange) HasChangedAnyOf(keys ...pkey.Key) bool {
|
|
return slices.Contains(keys, pkey.Key(pc))
|
|
}
|
|
|
|
type policyChanges map[pkey.Key]any
|
|
|
|
func (pc policyChanges) HasChanged(v pkey.Key) bool {
|
|
_, ok := pc[v]
|
|
return ok
|
|
}
|
|
func (pc policyChanges) HasChangedAnyOf(keys ...pkey.Key) bool {
|
|
for _, k := range keys {
|
|
if pc.HasChanged(k) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
const watchersKey = "_policytest_watchers"
|
|
|
|
type watchers struct {
|
|
mu sync.Mutex
|
|
s set.HandleSet[func(policyclient.PolicyChange)]
|
|
}
|
|
|
|
// EnableRegisterChangeCallback makes c support the RegisterChangeCallback
|
|
// for testing. Without calling this, the RegisterChangeCallback does nothing.
|
|
// For watchers to be notified, use the [Config.Set] method. Changing the map
|
|
// directly obviously wouldn't work.
|
|
func (c *Config) EnableRegisterChangeCallback() {
|
|
if _, ok := (*c)[watchersKey]; !ok {
|
|
c.Set(watchersKey, new(watchers))
|
|
}
|
|
}
|
|
|
|
func (c Config) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case []string:
|
|
return val, nil
|
|
case error:
|
|
return nil, val
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a []string; got %T", key, val))
|
|
}
|
|
}
|
|
return defaultVal, nil
|
|
}
|
|
|
|
func (c Config) GetString(key pkey.Key, defaultVal string) (string, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case string:
|
|
return val, nil
|
|
case error:
|
|
return "", val
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a string; got %T", key, val))
|
|
}
|
|
}
|
|
return defaultVal, nil
|
|
}
|
|
|
|
func (c Config) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case bool:
|
|
return val, nil
|
|
case error:
|
|
return false, val
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a bool; got %T", key, val))
|
|
}
|
|
}
|
|
return defaultVal, nil
|
|
}
|
|
|
|
func (c Config) GetUint64(key pkey.Key, defaultVal uint64) (uint64, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case uint64:
|
|
return val, nil
|
|
case error:
|
|
return 0, val
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a uint64; got %T", key, val))
|
|
}
|
|
}
|
|
return defaultVal, nil
|
|
}
|
|
|
|
func (c Config) GetDuration(key pkey.Key, defaultVal time.Duration) (time.Duration, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case time.Duration:
|
|
return val, nil
|
|
case error:
|
|
return 0, val
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a time.Duration; got %T", key, val))
|
|
}
|
|
}
|
|
return defaultVal, nil
|
|
}
|
|
|
|
func (c Config) GetPreferenceOption(key pkey.Key, defaultVal ptype.PreferenceOption) (ptype.PreferenceOption, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case ptype.PreferenceOption:
|
|
return val, nil
|
|
case error:
|
|
var zero ptype.PreferenceOption
|
|
return zero, val
|
|
case string:
|
|
var p ptype.PreferenceOption
|
|
err := p.UnmarshalText(([]byte)(val))
|
|
return p, err
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key))
|
|
}
|
|
}
|
|
return defaultVal, nil
|
|
}
|
|
|
|
func (c Config) GetVisibility(key pkey.Key) (ptype.Visibility, error) {
|
|
if val, ok := c[key]; ok {
|
|
switch val := val.(type) {
|
|
case ptype.Visibility:
|
|
return val, nil
|
|
case error:
|
|
var zero ptype.Visibility
|
|
return zero, val
|
|
case string:
|
|
var p ptype.Visibility
|
|
err := p.UnmarshalText(([]byte)(val))
|
|
return p, err
|
|
default:
|
|
panic(fmt.Sprintf("key %s is not a ptype.Visibility", key))
|
|
}
|
|
}
|
|
return ptype.Visibility(ptype.ShowChoiceByPolicy), nil
|
|
}
|
|
|
|
func (c Config) HasAnyOf(keys ...pkey.Key) (bool, error) {
|
|
for _, key := range keys {
|
|
if _, ok := c[key]; ok {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (c Config) RegisterChangeCallback(callback func(policyclient.PolicyChange)) (func(), error) {
|
|
w, ok := c[watchersKey].(*watchers)
|
|
if !ok {
|
|
return func() {}, nil
|
|
}
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
h := w.s.Add(callback)
|
|
return func() {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
delete(w.s, h)
|
|
}, nil
|
|
}
|
|
|
|
// SetDebugLoggingEnabled is a no-op; enabled is ignored.
//
// NOTE(review): presumably present only so that Config satisfies
// policyclient.Client. Confirm against that interface.
func (c Config) SetDebugLoggingEnabled(enabled bool) {}
|