diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 825b33fac..a51b6fc49 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -10,7 +10,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw W 💣 github.com/dblohm7/wingoes from tailscale.com/util/winutil github.com/fxamacker/cbor/v2 from tailscale.com/tka - github.com/go-json-experiment/json from tailscale.com/types/opt + github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ @@ -146,9 +146,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/cloudenv from tailscale.com/hostinfo+ W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy tailscale.com/util/ctxkey from tailscale.com/tsweb+ + 💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/hostinfo+ tailscale.com/util/fastuuid from tailscale.com/tsweb + 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale tailscale.com/util/lineread from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns @@ -159,8 +161,17 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/cmd/derper+ tailscale.com/util/syspolicy from tailscale.com/ipn + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/util/syspolicy+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ + W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/derp+ tailscale.com/version/distro from tailscale.com/envknob+ @@ -180,6 +191,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ W golang.org/x/exp/constraints from tailscale.com/util/winutil + golang.org/x/exp/maps from tailscale.com/util/syspolicy/internal/metrics+ L golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http @@ -240,7 +252,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa encoding/pem from crypto/tls+ errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ - flag from tailscale.com/cmd/derper + flag from tailscale.com/cmd/derper+ fmt from compress/flate+ go/token from google.golang.org/protobuf/internal/strs hash from crypto+ @@ -273,7 +285,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa os from crypto/rand+ os/exec from github.com/coreos/go-iptables/iptables+ os/signal from tailscale.com/cmd/derper - W os/user from tailscale.com/util/winutil + W os/user from tailscale.com/util/winutil+ path from github.com/prometheus/client_golang/prometheus/internal+ path/filepath from crypto/x509+ reflect from crypto/x509+ diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index d8bca1130..261ae9d6a 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -96,7 +96,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ 💣 github.com/fsnotify/fsnotify from sigs.k8s.io/controller-runtime/pkg/certwatcher github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/gaissmai/bart from tailscale.com/net/ipset+ - github.com/go-json-experiment/json from tailscale.com/types/opt + github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+ github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+ @@ -803,6 +803,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/singleflight from tailscale.com/control/controlclient+ tailscale.com/util/slicesx from tailscale.com/appc+ tailscale.com/util/syspolicy from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/control/controlclient+ @@ -811,7 +818,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/vizerror from tailscale.com/tailcfg+ 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns + W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index c03be655d..9e13b43f2 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -9,7 +9,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+ W 💣 github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode github.com/fxamacker/cbor/v2 from tailscale.com/tka - github.com/go-json-experiment/json from tailscale.com/types/opt + github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ @@ -152,9 +152,11 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+ tailscale.com/util/ctxkey from tailscale.com/types/logger + 💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/groupmember from tailscale.com/client/web + 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale+ tailscale.com/util/lineread from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns @@ -167,11 +169,19 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/singleflight from tailscale.com/net/dnscache+ tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ tailscale.com/util/syspolicy from tailscale.com/ipn - tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli tailscale.com/util/vizerror from tailscale.com/tailcfg+ 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate + W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/client/web+ tailscale.com/version/distro from tailscale.com/client/web+ @@ -191,7 +201,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ - golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli + golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli+ golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index f6edbe7d7..cc9c1d7c2 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -90,7 +90,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 github.com/djherbis/times from tailscale.com/drive/driveimpl github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/gaissmai/bart from tailscale.com/net/tstun+ - github.com/go-json-experiment/json from tailscale.com/types/opt + github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+ github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+ @@ -395,6 +395,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/singleflight from tailscale.com/control/controlclient+ tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+ + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+ @@ -403,7 +410,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/vizerror from tailscale.com/tailcfg+ 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns + W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index b5c22e54b..7f3cfaa08 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -52,6 +52,8 @@ import ( "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" @@ -2546,6 +2548,14 @@ func TestPreferencePolicyInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + definitions := make([]*setting.Definition, 0, len(preferencePolicies)+1) + definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(syspolicy.ControlURL))) + for _, pp := range preferencePolicies { + definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(pp.key))) + } + if err := setting.SetDefinitionsForTest(t, definitions...); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } for _, pp := range preferencePolicies { t.Run(string(pp.key), func(t *testing.T) { var h syspolicy.Handler @@ -2572,7 +2582,7 @@ func TestPreferencePolicyInfo(t *testing.T) { msh.stringPolicies[pp.key] = &tt.policyValue h = msh } - syspolicy.SetHandlerForTest(t, h) + rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, syspolicy.WrapHandler(h)) prefs := defaultPrefs.AsStruct() pp.set(prefs, tt.initialValue) diff --git a/util/syspolicy/caching_handler.go b/util/syspolicy/caching_handler.go deleted file mode 100644 index 5192958bc..000000000 --- a/util/syspolicy/caching_handler.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "sync" -) - -// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested -// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached, -// otherwise the actual error is returned and the next read for that key will retry using the handler. -type CachingHandler struct { - mu sync.Mutex - strings map[string]string - uint64s map[string]uint64 - bools map[string]bool - strArrs map[string][]string - notFound map[string]bool - handler Handler -} - -// NewCachingHandler creates a CachingHandler given a handler. -func NewCachingHandler(handler Handler) *CachingHandler { - return &CachingHandler{ - handler: handler, - strings: make(map[string]string), - uint64s: make(map[string]uint64), - bools: make(map[string]bool), - strArrs: make(map[string][]string), - notFound: make(map[string]bool), - } -} - -// ReadString reads the policy settings value string given the key. -// ReadString first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadString(key string) (string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strings[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return "", ErrNoSuchKey - } - val, err := ch.handler.ReadString(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return "", err - } else if err != nil { - return "", err - } - ch.strings[key] = val - return val, nil -} - -// ReadUInt64 reads the policy settings uint64 value given the key. -// ReadUInt64 first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.uint64s[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return 0, ErrNoSuchKey - } - val, err := ch.handler.ReadUInt64(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return 0, err - } else if err != nil { - return 0, err - } - ch.uint64s[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadBoolean(key string) (bool, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.bools[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return false, ErrNoSuchKey - } - val, err := ch.handler.ReadBoolean(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return false, err - } else if err != nil { - return false, err - } - ch.bools[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strArrs[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return nil, ErrNoSuchKey - } - val, err := ch.handler.ReadStringArray(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return nil, err - } else if err != nil { - return nil, err - } - ch.strArrs[key] = val - return val, nil -} diff --git a/util/syspolicy/caching_handler_test.go b/util/syspolicy/caching_handler_test.go deleted file mode 100644 index 881f6ff83..000000000 --- a/util/syspolicy/caching_handler_test.go +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "testing" -) - -func TestHandlerReadString(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue string - handlerError error - preserveHandler bool - wantValue string - wantErr error - strings map[string]string - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - strings: map[string]string{"test": "foo"}, - wantValue: "foo", - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: "foo", - wantValue: "foo", - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - s: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.strings != nil { - cache.strings = tt.strings - } - got, err := cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } -} - -func TestHandlerReadUint64(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue uint64 - handlerError error - preserveHandler bool - wantValue uint64 - wantErr error - uint64s map[string]uint64 - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - uint64s: map[string]uint64{"test": 1}, - wantValue: 1, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: 1, - wantValue: 1, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - u64: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.uint64s != nil { - cache.uint64s = tt.uint64s - } - got, err := cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} - -func TestHandlerReadBool(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue bool - handlerError error - preserveHandler bool - wantValue bool - wantErr error - bools map[string]bool - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - bools: map[string]bool{"test": true}, - wantValue: true, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: true, - wantValue: true, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - b: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.bools != nil { - cache.bools = tt.bools - } - got, err := cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} diff --git a/util/syspolicy/handler.go b/util/syspolicy/handler.go index f1fad9770..0671dc058 100644 --- a/util/syspolicy/handler.go +++ b/util/syspolicy/handler.go @@ -4,16 +4,15 @@ package syspolicy import ( - "errors" - "sync/atomic" -) - -var ( - handlerUsed atomic.Bool - handler Handler = defaultHandler{} + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" ) // Handler reads system policies from OS-specific storage. +// +// Deprecated: implementing a [Store] should be preferred. type Handler interface { // ReadString reads the policy setting's string value for the given key. // It should return ErrNoSuchKey if the key does not have a value set. @@ -29,55 +28,81 @@ type Handler interface { ReadStringArray(key string) ([]string, error) } -// ErrNoSuchKey is returned by a Handler when the specified key does not have a -// value set. -var ErrNoSuchKey = errors.New("no such key") - -// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple. -type defaultHandler struct{} - -func (defaultHandler) ReadString(_ string) (string, error) { - return "", ErrNoSuchKey -} - -func (defaultHandler) ReadUInt64(_ string) (uint64, error) { - return 0, ErrNoSuchKey -} - -func (defaultHandler) ReadBoolean(_ string) (bool, error) { - return false, ErrNoSuchKey -} - -func (defaultHandler) ReadStringArray(_ string) ([]string, error) { - return nil, ErrNoSuchKey -} - -// markHandlerInUse is called before handler methods are called. -func markHandlerInUse() { - handlerUsed.Store(true) -} - -// RegisterHandler initializes the policy handler and ensures registration will happen once. +// RegisterHandler wraps and registers the specified handler as the device's +// policy [Store] for the program's lifetime. +// +// Deprecated: using [RegisterStore] should be preferred. func RegisterHandler(h Handler) { - // Technically this assignment is not concurrency safe, but in the - // event that there was any risk of a data race, we will panic due to - // the CompareAndSwap failing. - handler = h - if !handlerUsed.CompareAndSwap(false, true) { - panic("handler was already used before registration") - } + rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h)) } // TB is a subset of testing.TB that we use to set up test helpers. // It's defined here to avoid pulling in the testing package. -type TB interface { - Helper() - Cleanup(func()) +type TB = internal.TB + +// SetHandlerForTest wraps and sets the specified handler as the device's policy +// [Store] for the duration of tb. +// +// Deprecated: using [resultant.RegisterStoreForTest] should be preferred. +func SetHandlerForTest(tb TB, h Handler) { + if err := setWellKnownSettingsForTest(tb); err != nil { + tb.Fatalf("setWellKnownSettingsForTest failed: %v", err) + } + rsop.RegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.CurrentScope(), WrapHandler(h)) } -func SetHandlerForTest(tb TB, h Handler) { - tb.Helper() - oldHandler := handler - handler = h - tb.Cleanup(func() { handler = oldHandler }) +var _ source.Store = (*handlerStore)(nil) + +// handlerStore is a [source.Store] that calls the underlying [Handler]. +// TODO(nickkhyl): remove it when the corp and android repos are updated. +type handlerStore struct { + h Handler +} + +// WrapHandler returns a [source.Store] that wraps the specified [Handler]. +func WrapHandler(h Handler) source.Store { + return handlerStore{h} +} + +func (s handlerStore) Lock() error { + if lockable, ok := s.h.(source.Lockable); ok { + return lockable.Lock() + } + return nil +} + +func (s handlerStore) Unlock() { + if lockable, ok := s.h.(source.Lockable); ok { + lockable.Unlock() + } +} + +func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) { + if lockable, ok := s.h.(source.Changeable); ok { + return lockable.RegisterChangeCallback(callback) + } + return func() {}, nil +} + +func (s handlerStore) ReadString(key setting.Key) (string, error) { + return s.h.ReadString(string(key)) +} + +func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) { + return s.h.ReadUInt64(string(key)) +} + +func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) { + return s.h.ReadBoolean(string(key)) +} + +func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) { + return s.h.ReadStringArray(string(key)) +} + +func (s handlerStore) Done() <-chan struct{} { + if expirable, ok := s.h.(source.Expirable); ok { + return expirable.Done() + } + return nil } diff --git a/util/syspolicy/handler_test.go b/util/syspolicy/handler_test.go deleted file mode 100644 index 39b18936f..000000000 --- a/util/syspolicy/handler_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import "testing" - -func TestDefaultHandlerReadValues(t *testing.T) { - var h defaultHandler - - got, err := h.ReadString(string(AdminConsoleVisibility)) - if got != "" || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", got, err) - } - result, err := h.ReadUInt64(string(LogSCMInteractions)) - if result != 0 || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", result, err) - } -} diff --git a/util/syspolicy/handler_windows.go b/util/syspolicy/handler_windows.go deleted file mode 100644 index 661853ead..000000000 --- a/util/syspolicy/handler_windows.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "fmt" - - "tailscale.com/util/clientmetric" - "tailscale.com/util/winutil" -) - -var ( - windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors") - windowsAny = clientmetric.NewGauge("windows_syspolicy_any") -) - -type windowsHandler struct{} - -func init() { - RegisterHandler(NewCachingHandler(windowsHandler{})) - - keyList := []struct { - isSet func(Key) bool - keys []Key - }{ - { - isSet: func(k Key) bool { - _, err := handler.ReadString(string(k)) - return err == nil - }, - keys: stringKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadBoolean(string(k)) - return err == nil - }, - keys: boolKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadUInt64(string(k)) - return err == nil - }, - keys: uint64Keys, - }, - } - - var anySet bool - for _, l := range keyList { - for _, k := range l.keys { - if !l.isSet(k) { - continue - } - clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1) - anySet = true - } - } - if anySet { - windowsAny.Set(1) - } -} - -func (windowsHandler) ReadString(key string) (string, error) { - s, err := winutil.GetPolicyString(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - - return s, err -} - -func (windowsHandler) ReadUInt64(key string) (uint64, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} - -func (windowsHandler) ReadBoolean(key string) (bool, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value != 0, err -} - -func (windowsHandler) ReadStringArray(key string) ([]string, error) { - value, err := winutil.GetPolicyStringArray(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} diff --git a/util/syspolicy/internal/internal.go b/util/syspolicy/internal/internal.go new file mode 100644 index 000000000..4c3e28d39 --- /dev/null +++ b/util/syspolicy/internal/internal.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package internal contains miscellaneous functions and types +// that are internal to the syspolicy packages. +package internal + +import ( + "bytes" + + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/types/lazy" + "tailscale.com/version" +) + +// OSForTesting is the operating system override used for testing. +// It follows the same naming convention as [version.OS]. +var OSForTesting lazy.SyncValue[string] + +// OS is like [version.OS], but supports a test hook. +func OS() string { + return OSForTesting.Get(version.OS) +} + +// TB is a subset of testing.TB that we use to set up test helpers. +// It's defined here to avoid pulling in the testing package. +type TB interface { + Helper() + Cleanup(func()) + Logf(format string, args ...any) + Error(args ...any) + Errorf(format string, args ...any) + Fatal(args ...any) + Fatalf(format string, args ...any) +} + +// EqualJSONForTest compares the JSON in j1 and j2 for semantic equality. +// It returns "", "", true if j1 and j2 are equal. Otherwise, it returns +// indented versions of j1 and j2 and false. +func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) { + tb.Helper() + j1 = j1.Clone() + j2 = j2.Clone() + // Canonicalize JSON values for comparison. + if err := j1.Canonicalize(); err != nil { + tb.Error(err) + } + if err := j2.Canonicalize(); err != nil { + tb.Error(err) + } + // Check and return true if the two values are structurally equal. + if bytes.Equal(j1, j2) { + return "", "", true + } + // Otherwise, format the values for display and return false. + if err := j1.Indent("", "\t"); err != nil { + tb.Fatal(err) + } + if err := j2.Indent("", "\t"); err != nil { + tb.Fatal(err) + } + return j1.String(), j2.String(), false +} diff --git a/util/syspolicy/internal/lazyinit/lazyinit.go b/util/syspolicy/internal/lazyinit/lazyinit.go new file mode 100644 index 000000000..94c16c238 --- /dev/null +++ b/util/syspolicy/internal/lazyinit/lazyinit.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The lazyinit package facilitates deferred package initialization. +package lazyinit + +import ( + "sync" + "sync/atomic" +) + +var packageInit deferredOnce + +// Defer defers the specified action until [Do] is called. +// It returns a boolean indicating whether [Do] has already been called. +func Defer(action func() error) bool { + return packageInit.Defer(action) +} + +// DeferWithCleanup is like [Defer], but the action function returns a cleanup +// function to be called in case of an error. +func DeferWithCleanup(action func() (cleanup func(), err error)) bool { + return packageInit.DeferWithCleanup(action) +} + +// Do runs all deferred functions and returns an error if any of them fail. +func Do() error { + return packageInit.Do() +} + +type deferredOnce struct { + done atomic.Uint32 + err error + m sync.Mutex + funcs []func() (cleanup func(), err error) +} + +func (o *deferredOnce) Defer(action func() error) bool { + return o.DeferWithCleanup(func() (cleanup func(), err error) { + return nil, action() + }) +} + +func (o *deferredOnce) DeferWithCleanup(action func() (cleanup func(), err error)) bool { + o.m.Lock() + defer o.m.Unlock() + if o.done.Load() != 0 { + return false + } + o.funcs = append(o.funcs, action) + return true +} + +func (o *deferredOnce) Do() error { + if o.done.Load() == 0 { + o.doSlow() + } + return o.err +} + +func (o *deferredOnce) doSlow() (err error) { + o.m.Lock() + defer o.m.Unlock() + if o.done.Load() == 0 { + defer func() { + o.done.Store(1) + o.err = err + }() + for _, f := range o.funcs { + cleanup, err := f() + if err != nil { + return err + } + if cleanup != nil { + defer func() { + if err != nil { + cleanup() + } + }() + } + } + } + return o.err +} diff --git a/util/syspolicy/internal/loggerx/logger.go b/util/syspolicy/internal/loggerx/logger.go new file mode 100644 index 000000000..b28610826 --- /dev/null +++ b/util/syspolicy/internal/loggerx/logger.go @@ -0,0 +1,46 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package loggerx provides logging functions to the rest of the syspolicy packages. +package loggerx + +import ( + "log" + + "tailscale.com/types/lazy" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/internal" +) + +const ( + errorPrefix = "syspolicy: " + verbosePrefix = "syspolicy: [v2] " +) + +var ( + lazyErrorf lazy.SyncValue[logger.Logf] + lazyVerbosef lazy.SyncValue[logger.Logf] +) + +// Errorf formats and writes an error message to the log. +func Errorf(format string, args ...any) { + errorf := lazyErrorf.Get(func() logger.Logf { + return logger.WithPrefix(log.Printf, errorPrefix) + }) + errorf(format, args...) +} + +// Verbosef formats and writes an optional, verbose message to the log. +func Verbosef(format string, args ...any) { + verbosef := lazyVerbosef.Get(func() logger.Logf { + return logger.WithPrefix(log.Printf, verbosePrefix) + }) + verbosef(format, args...) +} + +// SetForTest sets the specified errorf and verbosef functions for the duration +// of tb and its subtests. +func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) { + lazyErrorf.SetForTest(tb, errorf, nil) + lazyVerbosef.SetForTest(tb, verbosef, nil) +} diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go new file mode 100644 index 000000000..4f2bf5396 --- /dev/null +++ b/util/syspolicy/internal/metrics/metrics.go @@ -0,0 +1,315 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package metrics provides logging and reporting for policy settings and scopes. +package metrics + +import ( + "strings" + "sync" + + xmaps "golang.org/x/exp/maps" + + "tailscale.com/syncs" + "tailscale.com/types/lazy" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" +) + +var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook + +// ShouldReport reports whether metrics should be reported on the current environment. +func ShouldReport() bool { + return lazyReportMetrics.Get(func() bool { + // macOS, iOS and tvOS create their own metrics, + // and we don't have syspolicy on any other platforms. + return setting.PlatformList{"android", "windows"}.HasCurrent() + }) +} + +// Reset metrics for the specified policy origin. +func Reset(origin *setting.Origin) { + scopeMetrics(origin).Reset() +} + +// ReportConfigured updates metrics and logs that the specified setting is +// configured with the given value in the origin. +func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) { + settingMetricsFor(setting).ReportValue(origin, value) +} + +// ReportError updates metrics and logs that the specified setting has an error +// in the origin. +func ReportError(origin *setting.Origin, setting *setting.Definition, err error) { + settingMetricsFor(setting).ReportError(origin, err) +} + +// ReportNotConfigured updates metrics and logs that the specified setting is +// not configured in the origin. +func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) { + settingMetricsFor(setting).Reset(origin) +} + +// metric is an interface implemented by [clientmetric.Metric] and [funcMetric]. +type metric interface { + Add(v int64) + Set(v int64) +} + +// policyScopeMetrics are metrics that apply to an entire policy scope rather +// than a specific policy setting. +type policyScopeMetrics struct { + hasAny metric + numErrored metric +} + +func newScopeMetrics(scope setting.Scope) *policyScopeMetrics { + prefix := metricScopeName(scope) + if prefix != "" { + prefix += "_" + } + // {os}_syspolicy_{scope_unless_device}_any + // Example: windows_syspolicy_any or windows_syspolicy_user_any. + hasAny := newMetric(prefix+"any", clientmetric.TypeGauge) + // {os}_syspolicy_{scope_unless_device}_errors + // Example: windows_syspolicy_errors or windows_syspolicy_user_errors. + // + // TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter? + // It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such. + // But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected + // policy value type or format and the actual value read from the underlying store (like the Windows Registry). + // We'll encounter the same error every time we re-read the policy setting from the backing store + // until the policy value is corrected by the user, or until we fix the bug in the code or ADMX. + // There's probably no reason to count and accumulate them over time. + numErrored := newMetric(prefix+"errors", clientmetric.TypeCounter) + return &policyScopeMetrics{hasAny, numErrored} +} + +// ReportHasSettings is called when there's any configured policy setting in the scope. +func (m *policyScopeMetrics) ReportHasSettings() { + if m != nil { + m.hasAny.Set(1) + } +} + +// ReportError is called when there's any errored policy setting in the scope. +func (m *policyScopeMetrics) ReportError() { + if m != nil { + m.numErrored.Add(1) + } +} + +// Reset is called to reset the policy scope metrics, such as when the policy scope +// is about to be reloaded. +func (m *policyScopeMetrics) Reset() { + if m != nil { + m.hasAny.Set(0) + // numErrored is a counter and cannot be (re-)set. + } +} + +// settingMetrics are metrics for a single policy setting in one or more scopes. +type settingMetrics struct { + definition *setting.Definition + isSet []metric // by scope + hasErrors []metric // by scope +} + +// ReportValue is called when the policy setting is found to be configured in the specified source. +func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) { + if m == nil { + return + } + if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) { + m.isSet[scope].Set(1) + m.hasErrors[scope].Set(0) + } + scopeMetrics(origin).ReportHasSettings() + loggerx.Verbosef("%v(%q) = %v\n", origin, m.definition.Key(), v) +} + +// ReportError is called when there's an error with the policy setting in the specified source. +func (m *settingMetrics) ReportError(origin *setting.Origin, err error) { + if m == nil { + return + } + if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) { + m.isSet[scope].Set(0) + m.hasErrors[scope].Set(1) + } + scopeMetrics(origin).ReportError() + loggerx.Errorf("%v(%q): %v\n", origin, m.definition.Key(), err) +} + +// Reset is called to reset the policy setting's metrics, such as when +// the policy setting does not exist or the source containing the policy +// is about to be reloaded. +func (m *settingMetrics) Reset(origin *setting.Origin) { + if m == nil { + return + } + if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) { + m.isSet[scope].Set(0) + m.hasErrors[scope].Set(0) + } +} + +// metricFn is a function that adds or sets a metric value. +type metricFn = func(name string, typ clientmetric.Type, v int64) + +// funcMetric implements [metric] by calling the specified add and set functions. +// Used for testing, and with nil functions on platforms that do not support +// syspolicy, and on platforms that report policy metrics from the GUI. +type funcMetric struct { + name string + typ clientmetric.Type + add, set metricFn +} + +func (m funcMetric) Add(v int64) { + if m.add != nil { + m.add(m.name, m.typ, v) + } +} + +func (m funcMetric) Set(v int64) { + if m.set != nil { + m.set(m.name, m.typ, v) + } +} + +var ( + lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics] + lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics] + lazyUserMetrics lazy.SyncValue[*policyScopeMetrics] +) + +func scopeMetrics(origin *setting.Origin) *policyScopeMetrics { + switch origin.Scope().Kind() { + case setting.DeviceSetting: + return lazyDeviceMetrics.Get(func() *policyScopeMetrics { + return newScopeMetrics(setting.DeviceSetting) + }) + case setting.ProfileSetting: + return lazyProfileMetrics.Get(func() *policyScopeMetrics { + return newScopeMetrics(setting.ProfileSetting) + }) + case setting.UserSetting: + return lazyUserMetrics.Get(func() *policyScopeMetrics { + return newScopeMetrics(setting.UserSetting) + }) + default: + panic("unreachable") + } +} + +var ( + settingMetricsMu sync.RWMutex + settingMetricsMap map[setting.Key]*settingMetrics +) + +func settingMetricsFor(setting *setting.Definition) *settingMetrics { + settingMetricsMu.RLock() + if metrics, ok := settingMetricsMap[setting.Key()]; ok { + settingMetricsMu.RUnlock() + return metrics + } + settingMetricsMu.RUnlock() + return settingMetricsForSlow(setting) +} + +func settingMetricsForSlow(d *setting.Definition) *settingMetrics { + settingMetricsMu.Lock() + defer settingMetricsMu.Unlock() + if metrics, ok := settingMetricsMap[d.Key()]; ok { + return metrics + } + + isSet := make([]metric, d.Scope()+1) + hasErrors := make([]metric, d.Scope()+1) + for i := range isSet { + scope := setting.Scope(i) + // {os}_syspolicy_{key}_{scope_unless_device} + // Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user. + isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge) + // {os}_syspolicy_{key}_{scope_unless_device}_error + // Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error. + hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge) + } + metrics := &settingMetrics{d, isSet, hasErrors} + mak.Set(&settingMetricsMap, d.Key(), metrics) + return metrics +} + +// hooks for testing +var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn] + +// SetHooksForTest sets the specified addMetric and setMetric functions +// as the metric functions for the duration of tb and all its subtests. +func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { + oldAddMetric := addMetricTestHook.Swap(addMetric) + oldSetMetric := setMetricTestHook.Swap(setMetric) + tb.Cleanup(func() { + addMetricTestHook.Store(oldAddMetric) + setMetricTestHook.Store(oldSetMetric) + }) + + settingMetricsMu.Lock() + oldSettingMetricsMap := xmaps.Clone(settingMetricsMap) + clear(settingMetricsMap) + settingMetricsMu.Unlock() + tb.Cleanup(func() { + settingMetricsMu.Lock() + settingMetricsMap = oldSettingMetricsMap + settingMetricsMu.Unlock() + }) + + // (re-)set the scope metrics to use the test hooks for the duration of tb. + lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil) + lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil) + lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil) +} + +func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { + name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_") + if tag := metricScopeName(scope); tag != "" { + name += "_" + tag + } + if suffix != "" { + name += "_" + suffix + } + return newMetric(name, typ) +} + +func newMetric(name string, typ clientmetric.Type) metric { + name = internal.OS() + "_syspolicy_" + name + switch { + case !ShouldReport(): + return &funcMetric{name: name, typ: typ} + case testenv.InTest(): + return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()} + case typ == clientmetric.TypeCounter: + return clientmetric.NewCounter(name) + case typ == clientmetric.TypeGauge: + return clientmetric.NewGauge(name) + default: + panic("unreachable") + } +} + +func metricScopeName(scope setting.Scope) string { + switch scope { + case setting.DeviceSetting: + return "" + case setting.ProfileSetting: + return "profile" + case setting.UserSetting: + return "user" + default: + panic("unreachable") + } +} diff --git a/util/syspolicy/internal/metrics/metrics_test.go b/util/syspolicy/internal/metrics/metrics_test.go new file mode 100644 index 000000000..07be4773c --- /dev/null +++ b/util/syspolicy/internal/metrics/metrics_test.go @@ -0,0 +1,423 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "errors" + "testing" + + "tailscale.com/types/lazy" + "tailscale.com/util/clientmetric" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" +) + +func TestSettingMetricNames(t *testing.T) { + tests := []struct { + name string + key setting.Key + scope setting.Scope + suffix string + typ clientmetric.Type + osOverride string + wantMetricName string + }{ + { + name: "windows-device-no-suffix", + key: "AdminConsole", + scope: setting.DeviceSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole", + }, + { + name: "windows-user-no-suffix", + key: "AdminConsole", + scope: setting.UserSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole_user", + }, + { + name: "windows-profile-no-suffix", + key: "AdminConsole", + scope: setting.ProfileSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole_profile", + }, + { + name: "windows-profile-err", + key: "AdminConsole", + scope: setting.ProfileSetting, + suffix: "error", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole_profile_error", + }, + { + name: "android-device-no-suffix", + key: "AdminConsole", + scope: setting.DeviceSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "android", + wantMetricName: "android_syspolicy_AdminConsole", + }, + { + name: "key-path", + key: "category/subcategory/setting", + scope: setting.DeviceSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "fakeos", + wantMetricName: "fakeos_syspolicy_category_subcategory_setting", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric) + if !ok { + t.Fatal("metric is not a funcMetric") + } + if metric.name != tt.wantMetricName { + t.Errorf("got %q, want %q", metric.name, tt.wantMetricName) + } + }) + } +} + +func TestScopeMetrics(t *testing.T) { + tests := []struct { + name string + scope setting.Scope + osOverride string + wantHasAnyName string + wantNumErroredName string + wantHasAnyType clientmetric.Type + wantNumErroredType clientmetric.Type + }{ + { + name: "windows-device", + scope: setting.DeviceSetting, + osOverride: "windows", + wantHasAnyName: "windows_syspolicy_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "windows_syspolicy_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + { + name: "windows-profile", + scope: setting.ProfileSetting, + osOverride: "windows", + wantHasAnyName: "windows_syspolicy_profile_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "windows_syspolicy_profile_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + { + name: "windows-user", + scope: setting.UserSetting, + osOverride: "windows", + wantHasAnyName: "windows_syspolicy_user_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "windows_syspolicy_user_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + { + name: "android-device", + scope: setting.DeviceSetting, + osOverride: "android", + wantHasAnyName: "android_syspolicy_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "android_syspolicy_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + metrics := newScopeMetrics(tt.scope) + hasAny, ok := metrics.hasAny.(*funcMetric) + if !ok { + t.Fatal("hasAny is not a funcMetric") + } + numErrored, ok := metrics.numErrored.(*funcMetric) + if !ok { + t.Fatal("numErrored is not a funcMetric") + } + if hasAny.name != tt.wantHasAnyName { + t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName) + } + if hasAny.typ != tt.wantHasAnyType { + t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType) + } + if numErrored.name != tt.wantNumErroredName { + t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName) + } + if numErrored.typ != tt.wantNumErroredType { + t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType) + } + }) + } +} + +type testSettingDetails struct { + definition *setting.Definition + origin *setting.Origin + value any + err error +} + +func TestReportMetrics(t *testing.T) { + tests := []struct { + name string + osOverride string + useMetrics bool + settings []testSettingDetails + wantMetrics []TestState + wantResetMetrics []TestState + }{ + { + name: "none", + osOverride: "windows", + settings: []testSettingDetails{}, + wantMetrics: []TestState{}, + }, + { + name: "single-value", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_TestSetting01", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_TestSetting01", 0}, + }, + }, + { + name: "single-error", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting02_error", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting02_error", 0}, + }, + }, + { + name: "value-and-error", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + }, + + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting01", 1}, + {"windows_syspolicy_TestSetting02_error", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting01", 0}, + {"windows_syspolicy_TestSetting02_error", 0}, + }, + }, + { + name: "two-values", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 17, + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_TestSetting01", 1}, + {"windows_syspolicy_TestSetting02", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_TestSetting01", 0}, + {"windows_syspolicy_TestSetting02", 0}, + }, + }, + { + name: "two-errors", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_errors", 2}, + {"windows_syspolicy_TestSetting01_error", 1}, + {"windows_syspolicy_TestSetting02_error", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_errors", 2}, + {"windows_syspolicy_TestSetting01_error", 0}, + {"windows_syspolicy_TestSetting02_error", 0}, + }, + }, + { + name: "multi-scope", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + { + definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.CurrentProfileScope), + err: errors.New("bang!"), + }, + { + definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.CurrentUserScope), + value: 17, + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_profile_errors", 1}, + {"windows_syspolicy_user_any", 1}, + {"windows_syspolicy_TestSetting01", 1}, + {"windows_syspolicy_TestSetting02_profile_error", 1}, + {"windows_syspolicy_TestSetting03_user", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_profile_errors", 1}, + {"windows_syspolicy_user_any", 0}, + {"windows_syspolicy_TestSetting01", 0}, + {"windows_syspolicy_TestSetting02_profile_error", 0}, + {"windows_syspolicy_TestSetting03_user", 0}, + }, + }, + { + name: "report-metrics-on-android", + osOverride: "android", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + wantMetrics: []TestState{ + {"android_syspolicy_any", 1}, + {"android_syspolicy_TestSetting01", 1}, + }, + wantResetMetrics: []TestState{ + {"android_syspolicy_any", 0}, + {"android_syspolicy_TestSetting01", 0}, + }, + }, + { + name: "do-not-report-metrics-on-macos", + osOverride: "macos", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + + wantMetrics: []TestState{}, // none reported + }, + { + name: "do-not-report-metrics-on-ios", + osOverride: "ios", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + + wantMetrics: []TestState{}, // none reported + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset the lazy value so it'll be re-evaluated with the osOverride. + lazyReportMetrics = lazy.SyncValue[bool]{} + t.Cleanup(func() { + // Also reset it during the cleanup. + lazyReportMetrics = lazy.SyncValue[bool]{} + }) + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + + h := NewTestHandler(t) + SetHooksForTest(t, h.AddMetric, h.SetMetric) + + for _, s := range tt.settings { + if s.err != nil { + ReportError(s.origin, s.definition, s.err) + } else { + ReportConfigured(s.origin, s.definition, s.value) + } + } + h.MustEqual(tt.wantMetrics...) + + for _, s := range tt.settings { + Reset(s.origin) + ReportNotConfigured(s.origin, s.definition) + } + h.MustEqual(tt.wantResetMetrics...) + }) + } +} diff --git a/util/syspolicy/internal/metrics/test_handler.go b/util/syspolicy/internal/metrics/test_handler.go new file mode 100644 index 000000000..50ee42bbe --- /dev/null +++ b/util/syspolicy/internal/metrics/test_handler.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "strings" + + "tailscale.com/util/clientmetric" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal" +) + +// TestState represents a metric name and its expected value. +type TestState struct { + Name string // `$os` in the name will be replaced by the actual operating system name.` + Value int64 +} + +// TestHandler facilitates testing of the code that uses metrics. +type TestHandler struct { + t internal.TB + + m map[string]int64 +} + +// NewTestHandler returns a new TestHandler. +func NewTestHandler(t internal.TB) *TestHandler { + return &TestHandler{t, make(map[string]int64)} +} + +// AddMetric increments the metric with the specified name and type by delta d. +func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) { + h.t.Helper() + if typ == clientmetric.TypeCounter && d < 0 { + h.t.Fatalf("an attempt was made to decrement a counter metric %q", name) + } + if v, ok := h.m[name]; ok || d != 0 { + h.m[name] = v + d + } +} + +// SetMetric sets the metric with the specified name and type to the value v. +func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) { + h.t.Helper() + if typ == clientmetric.TypeCounter { + h.t.Fatalf("an attempt was made to set a counter metric %q", name) + } + if _, ok := h.m[name]; ok || v != 0 { + h.m[name] = v + } +} + +// MustEqual fails the test if the actual metric state differs from the specified state. +func (h *TestHandler) MustEqual(metrics ...TestState) { + h.t.Helper() + h.MustContain(metrics...) + h.mustNoExtra(metrics...) +} + +// MustContain fails the test if the specified metrics are not set or have +// different values than specified. It permits other metrics to be set in +// addition to the ones being tested. +func (h *TestHandler) MustContain(metrics ...TestState) { + h.t.Helper() + for _, m := range metrics { + name := strings.ReplaceAll(m.Name, "$os", internal.OS()) + v, ok := h.m[name] + if !ok { + h.t.Errorf("%q: got (none), want %v", name, m.Value) + } else if v != m.Value { + h.t.Fatalf("%q: got %v, want %v", name, v, m.Value) + } + } +} + +func (h *TestHandler) mustNoExtra(metrics ...TestState) { + h.t.Helper() + s := make(set.Set[string]) + for i := range metrics { + s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS())) + } + for n, v := range h.m { + if !s.Contains(n) { + h.t.Errorf("%q: got %v, want (none)", n, v) + } + } +} diff --git a/util/syspolicy/policy_keys.go b/util/syspolicy/policy_keys.go index ef0cfed8f..cf5685c01 100644 --- a/util/syspolicy/policy_keys.go +++ b/util/syspolicy/policy_keys.go @@ -3,7 +3,21 @@ package syspolicy -type Key string +import ( + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal/lazyinit" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" +) + +type Key = setting.Key + +// The const block below lists known policy keys. +// When adding a key to this list, remember to add a corresponding +// [setting.Definition] to [implicitDefinitions] below. +// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder. +// Preferably, use a strongly typed policy hierarchy, such as [Policy], +// instead of adding each key to the list below. const ( // Keys with a string value @@ -96,3 +110,83 @@ const ( // AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes. AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes" ) + +// implicitDefinitions is a list of [setting.Definition] that will be registered +// automatically by [settingDefinitions] as soon as the package needs to ready a policy. +var implicitDefinitions = []*setting.Definition{ + // Device policy settings + setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue), + + // User policy settings + setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue), + setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue), + setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue), + setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue), + setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue), +} + +func init() { + lazyinit.Defer(func() error { + // Avoid implicit [SettingDefinition] registration during tests. + // Each test should control which policy settings to register. + // Use [setting.SetDefinitionsForTest] to specify necessary definitions, + // or [setWellKnownSettingsForTest] to set implicit definitions for the test duration. + if testenv.InTest() { + return nil + } + for _, d := range implicitDefinitions { + setting.RegisterDefinition(d) + } + return nil + }) +} + +var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap] + +// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key, +// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist +// among implicit policy definitions. +func WellKnownSettingDefinition(k Key) (*setting.Definition, error) { + m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) { + return setting.DefinitionMapOf(implicitDefinitions) + }) + if err != nil { + return nil, err + } + if d, ok := m[k]; ok { + return d, nil + } + return nil, ErrNoSuchKey +} + +// setWellKnownSettingsForTest registers all implicit setting definitions +// for the duration of the test. +func setWellKnownSettingsForTest(tb lazy.TB) error { + return setting.SetDefinitionsForTest(tb, implicitDefinitions...) +} diff --git a/util/syspolicy/policy_keys_test.go b/util/syspolicy/policy_keys_test.go new file mode 100644 index 000000000..4d3260f3e --- /dev/null +++ b/util/syspolicy/policy_keys_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "os" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/setting" +) + +func TestKnownKeysRegistered(t *testing.T) { + keyConsts, err := listStringConsts[Key]("policy_keys.go") + if err != nil { + t.Fatalf("listStringConsts failed: %v", err) + } + + m, err := setting.DefinitionMapOf(implicitDefinitions) + if err != nil { + t.Fatalf("definitionMapOf failed: %v", err) + } + + for _, key := range keyConsts { + t.Run(string(key), func(t *testing.T) { + d := m[key] + if d == nil { + t.Fatalf("%q was not registered", key) + } + if d.Key() != key { + t.Fatalf("d.Key got: %s, want %s", d.Key(), key) + } + }) + } +} + +func TestNotAWellKnownSetting(t *testing.T) { + d, err := WellKnownSettingDefinition("TestSettingDoesNotExist") + if d != nil || err == nil { + t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey) + } +} + +func listStringConsts[T ~string](filename string) (map[string]T, error) { + fset := token.NewFileSet() + src, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + f, err := parser.ParseFile(fset, filename, src, 0) + if err != nil { + return nil, err + } + + consts := make(map[string]T) + typeName := reflect.TypeFor[T]().Name() + for _, d := range f.Decls { + g, ok := d.(*ast.GenDecl) + if !ok || g.Tok != token.CONST { + continue + } + + for _, s := range g.Specs { + vs, ok := s.(*ast.ValueSpec) + if !ok || len(vs.Names) != len(vs.Values) { + continue + } + if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName { + continue + } + + for i, n := range vs.Names { + lit, ok := vs.Values[i].(*ast.BasicLit) + if !ok { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i])) + } + val, err := strconv.Unquote(lit.Value) + if err != nil { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value) + } + consts[n.Name] = T(val) + } + } + } + + return consts, nil +} diff --git a/util/syspolicy/policy_keys_windows.go b/util/syspolicy/policy_keys_windows.go deleted file mode 100644 index 5e9a71695..000000000 --- a/util/syspolicy/policy_keys_windows.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -var stringKeys = []Key{ - ControlURL, - LogTarget, - Tailnet, - ExitNodeID, - ExitNodeIP, - EnableIncomingConnections, - EnableServerMode, - ExitNodeAllowLANAccess, - EnableTailscaleDNS, - EnableTailscaleSubnets, - AdminConsoleVisibility, - NetworkDevicesVisibility, - TestMenuVisibility, - UpdateMenuVisibility, - RunExitNodeVisibility, - PreferencesMenuVisibility, - ExitNodeMenuVisibility, - AutoUpdateVisibility, - ResetToDefaultsVisibility, - KeyExpirationNoticeTime, - PostureChecking, - ManagedByOrganizationName, - ManagedByCaption, - ManagedByURL, -} - -var boolKeys = []Key{ - LogSCMInteractions, - FlushDNSOnSessionUnlock, -} - -var uint64Keys = []Key{} diff --git a/util/syspolicy/rsop/change_callbacks.go b/util/syspolicy/rsop/change_callbacks.go new file mode 100644 index 000000000..e46ee38f6 --- /dev/null +++ b/util/syspolicy/rsop/change_callbacks.go @@ -0,0 +1,109 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "reflect" + "slices" + "sync" + "time" + + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" +) + +// Change represents a change from the Old to the New value of type T. +type Change[T any] struct { + New, Old T +} + +// PolicyChangeCallback is a function called whenever a policy changes. +type PolicyChangeCallback func(*PolicyChange) + +// PolicyChange describes a policy change. +type PolicyChange struct { + snapshots Change[*setting.Snapshot] +} + +// New returns the [setting.Snapshot] after the change. +func (c PolicyChange) New() *setting.Snapshot { + return c.snapshots.New +} + +// Old returns the [setting.Snapshot] before the change. +func (c PolicyChange) Old() *setting.Snapshot { + return c.snapshots.Old +} + +// HasChanged reports whether a policy setting with the specified [setting.Key], has changed. +func (c PolicyChange) HasChanged(key setting.Key) bool { + new, newErr := c.snapshots.New.GetErr(key) + old, oldErr := c.snapshots.Old.GetErr(key) + if newErr != nil && oldErr != nil { + return false + } + if newErr != nil || oldErr != nil { + return true + } + switch newVal := new.(type) { + case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration: + return newVal != old + case []string: + if oldVal, ok := old.([]string); ok { + return slices.Equal(newVal, oldVal) + } + return false + default: + loggerx.Errorf("%q has an unsupported value type: %T", newVal) + return reflect.DeepEqual(new, old) + } +} + +// policyChangeCallbacks are the callbacks to invoke when the resultant policy changes. +// It is safe for concurrent use. +type policyChangeCallbacks struct { + mu sync.RWMutex + cbs set.HandleSet[PolicyChangeCallback] +} + +// Register adds the specified callback to be invoked whenever the policy changes. +func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) { + c.mu.Lock() + handle := c.cbs.Add(callback) + c.mu.Unlock() + return func() { + c.mu.Lock() + delete(c.cbs, handle) + c.mu.Unlock() + } +} + +// Invoke calls the registered callback functions with the specified policy change info. +func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) { + var wg sync.WaitGroup + defer wg.Wait() + + c.mu.RLock() + defer c.mu.RUnlock() + + wg.Add(len(c.cbs)) + change := &PolicyChange{snapshots: snapshots} + for _, cb := range c.cbs { + go func() { + defer wg.Done() + cb(change) + }() + } +} + +// Close awaits the completion of active callbacks and prevents any further invocations. +func (c *policyChangeCallbacks) Close() { + c.mu.Lock() + defer c.mu.Unlock() + if c.cbs != nil { + clear(c.cbs) + c.cbs = nil + } +} diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go new file mode 100644 index 000000000..9191f80cb --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy.go @@ -0,0 +1,698 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rsop facilitates [source.Store] registration via [RegisterStore] +// and provides access to the resultant policy merged from all registered sources +// via [PolicyFor]. +package rsop + +import ( + "errors" + "fmt" + "reflect" + "slices" + "sync" + "sync/atomic" + "time" + + "tailscale.com/syncs" + "tailscale.com/types/lazy" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/internal/lazyinit" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +var errResultantPolicyClosed = errors.New("resultant policy closed") + +// The minimum and maximum wait times after detecting a policy change +// before reloading the policy. +// Policy changes occurring within [policyReloadMinDelay] of each other +// will be batched together, resulting in a single policy reload +// no later than [policyReloadMaxDelay] after the first detected change. +// In other words, the resultant policy will be reloaded no more often than once +// every 5 seconds, but at most 15 seconds after an underlying [source.Store] +// has issued a policy change callback. +// See [Policy.watchReload]. +const ( + defaultPolicyReloadMinDelay = 5 * time.Second + defaultPolicyReloadMaxDelay = 15 * time.Second +) + +// policyReloadMinDelay and policyReloadMaxDelay are test hooks. +// Their values default to [defaultPolicyReloadMinDelay] and [defaultPolicyReloadMaxDelay]. +var ( + policyReloadMinDelay, policyReloadMaxDelay lazy.SyncValue[time.Duration] +) + +// Policy provides access to the current resultant [setting.Snapshot] for a given +// scope and allows to reload it from the underlying [source.Store]s. It also allows to +// subscribe and receive a callback whenever the resultant [setting.Snapshot] is +// changed. It is safe for concurrent use. +type Policy struct { + scope setting.PolicyScope + + reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required + changeSourceCh chan sourceChangeRequest // written to to add a new or remove an existing source + closeCh chan struct{} // closed to signal that the Policy is being closed + doneCh chan struct{} // closed by closeInternal when watchReload exits + + // resultant is the most recent version of the [setting.Snapshot] containing policy settings + // merged from all applicable sources. + resultant atomic.Pointer[setting.Snapshot] + + changeCallbacks policyChangeCallbacks + + mu sync.RWMutex + sources source.ReadableSources + closing bool // Close was called (even if we're still closing) +} + +// newPolicy returns a new [Policy] for the specified [setting.PolicyScope] +// that tracks changes and merges policy settings read from the specified sources. +func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (p *Policy, err error) { + readableSources := source.ReadableSources(make([]source.ReadableSource, len(sources))) + for i, s := range sources { + reader, err := s.Reader() + if err != nil { + return nil, fmt.Errorf("failed to get a store reader: %v", err) + } + session, err := reader.OpenSession() + if err != nil { + return nil, fmt.Errorf("failed to open a reading session: %v", err) + } + + readableSource := source.ReadableSource{ + Source: s, + ReadingSession: session, + } + readableSources[i] = readableSource + defer func() { + if err != nil { + readableSource.Close() + } + }() + } + + // Sort policy sources by their precedence from lower to higher. + // For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}. + readableSources.StableSort() + + p = &Policy{ + scope: scope, + sources: readableSources, + reloadCh: make(chan reloadRequest, 1), + changeSourceCh: make(chan sourceChangeRequest), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + if err := p.start(); err != nil { + return nil, err + } + return p, nil +} + +// IsValid reports whether p is in a valid state and has not been closed. +func (p *Policy) IsValid() bool { + select { + case <-p.closeCh: + return false + default: + return true + } +} + +// Scope returns the [setting.PolicyScope] that this resultant policy applies to. +func (p *Policy) Scope() setting.PolicyScope { + return p.scope +} + +// Get returns the most recent resultant [setting.Snapshot]. +func (p *Policy) Get() *setting.Snapshot { + return p.resultant.Load() +} + +// RegisterChangeCallback adds a function to be called whenever the resultant +// policy changes. The returned function can be used to unregister the callback. +func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) { + return p.changeCallbacks.Register(callback) +} + +// Reload synchronously re-reads policy settings from the underlying policy +// [source.Store], constructing a new merged [setting.Snapshot] even if the policy remains +// unchanged. In most scenarios, there's no need to re-read the policy manually. +// Instead, it is recommended to register a policy change callback, or to use +// the most recent [setting.Snapshot] returned by the [Policy.Get] method. +func (p *Policy) Reload() (*setting.Snapshot, error) { + return p.reload(true) +} + +// reload is like Reload, but allows to specify whether to re-read policy settings +// from unchanged policy sources. +func (p *Policy) reload(force bool) (*setting.Snapshot, error) { + respCh := make(chan reloadResponse, 1) + select { + case p.reloadCh <- reloadRequest{force: force, respCh: respCh}: + // continue + case <-p.closeCh: + return nil, errResultantPolicyClosed + } + select { + case resp := <-respCh: + return resp.policy, resp.err + case <-p.closeCh: + return nil, errResultantPolicyClosed + } +} + +// Done returns a channel that is closed when the [Policy] is closed. +func (p *Policy) Done() <-chan struct{} { + return p.doneCh +} + +func (p *Policy) start() error { + if _, err := p.reloadNow(false); err != nil { + return err + } + go p.watchPolicyChanges() + go p.watchReload() + return nil +} + +// readAndMerge reads and merges policy settings from the underlying sources, +// returning a [setting.Snapshot] with the merged result. +// If the force parameter is true, it re-reads policy settings from each store +// even if no policy change was observed, and returns an error if the read +// operation fails. +func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) { + p.mu.RLock() + defer p.mu.RUnlock() + // Start with an empty policy in the target scope. + resultant := setting.NewSnapshot(nil, setting.SummaryWith(p.scope)) + // Then merge policy settings from all sources. + // Policy sources with the highest precedence (e.g., the device policy) are merged last, + // overriding any conflicting policy settings with lower precedence. + for _, s := range p.sources { + var policy *setting.Snapshot + if force { + var err error + if policy, err = s.ReadSettings(); err != nil { + return nil, err + } + } else { + policy = s.GetSettings() + } + resultant = setting.MergeSnapshots(resultant, policy) + } + return resultant, nil +} + +// reloadAsync requests an asynchronous background policy reload. +// The policy will be reloaded no later than in [policyReloadMaxDelay]. +func (p *Policy) reloadAsync() { + select { + case p.reloadCh <- reloadRequest{}: + // Sent. + default: + // A reload request is already en route. + } +} + +// reloadNow loads and merges policies from all sources, updating the resultant policy. +// If the force parameter is true, it forcibly reloads policies +// from the underlying policy store, even if no policy changes were detected. +// +// Except for the initial policy reload during the [Policy] creation, +// this method should only be called from the [Policy.watchReload] goroutine. +func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) { + new, err := p.readAndMerge(force) + if err != nil { + return nil, err + } + old := p.resultant.Swap(new) + // A nil old value indicates the initial policy load rather than a policy change. + // Additionally, we should not invoke the policy change callbacks unless the + // policy items have actually changed. + if old != nil && !old.EqualItems(new) { + snapshots := Change[*setting.Snapshot]{New: new, Old: old} + p.changeCallbacks.Invoke(snapshots) + } + return new, nil +} + +// AddSource adds the specified source to the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this resultant policy, +// or if the resultant policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) AddSource(source *source.Source) error { + return p.changeSource(source, nil) +} + +// RemoveSource removes the specified source from the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error if the +// resultant policy is being closed, or if policy refresh fails with an error. +func (p *Policy) RemoveSource(source *source.Source) error { + return p.changeSource(nil, source) +} + +// ReplaceSource replaces the old source with the new source atomically, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this resultant policy, +// or if the resultant policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) ReplaceSource(old, new *source.Source) error { + return p.changeSource(new, old) +} + +func (p *Policy) changeSource(toAdd, toRemove *source.Source) error { + if toAdd == toRemove { + return nil + } + if toAdd != nil && !p.scope.IsWithinOf(toAdd.Scope()) { + return errors.New("scope mismatch") + } + respCh := make(chan error, 1) + req := sourceChangeRequest{toAdd, toRemove, respCh} + select { + case p.changeSourceCh <- req: + return <-respCh + case <-p.closeCh: + return errResultantPolicyClosed + } +} + +// watchPolicyChanges awaits a policy change notification from any of the sources +// and calls reloadAsync whenever a notification is received. +func (p *Policy) watchPolicyChanges() { + const ( + closeIdx = iota + changeSourceIdx + policyChangedOffset + ) + + // The cases are Close, ChangeSource, PolicyChanged[0],...,PolicyChanged[N-1]. + p.mu.RLock() + cases := make([]reflect.SelectCase, len(p.sources)+policyChangedOffset) + // Add the PolicyChanged[N] cases. + for i, source := range p.sources { + cases[i+policyChangedOffset] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(source.PolicyChanged())} + } + // Add the Close case. + cases[closeIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.closeCh)} + // Add the ChangeSource case. + cases[changeSourceIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.changeSourceCh)} + p.mu.RUnlock() + + for { + switch chosen, recv, ok := reflect.Select(cases); chosen { + case closeIdx: // Close + // Exit the watch as the closeCh was closed, indicating that + // the [Policy] is being closed. + return + case changeSourceIdx: // ChangeSource + // We've received a source change request from one of the AddSource, + // RemoveSource, or ReplaceSource methods, meaning that we need to: + // - Open a reader session if a new source is being added; + // - Update the p.sources slice; + // - Update the cases slice; + // - Trigger a synchronous policy reload; + // - Report an error, if any, back to the caller. + req := recv.Interface().(sourceChangeRequest) + needClose, err := func() (close bool, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if req.toAdd != nil { + if !p.sources.Contains(req.toAdd) { + reader, err := req.toAdd.Reader() + if err != nil { + return false, fmt.Errorf("failed to get a store reader: %v", err) + } + session, err := reader.OpenSession() + if err != nil { + return false, fmt.Errorf("failed to open a reading session: %v", err) + } + + addAt := p.sources.InsertionIndexOf(req.toAdd) + toAdd := source.ReadableSource{ + Source: req.toAdd, + ReadingSession: session, + } + p.sources = slices.Insert(p.sources, addAt, toAdd) + newCase := reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(toAdd.PolicyChanged())} + caseIndex := addAt + policyChangedOffset + cases = slices.Insert(cases, caseIndex, newCase) + } + } + if req.toDelete != nil { + if deleteAt := p.sources.IndexOf(req.toDelete); deleteAt != -1 { + p.sources.DeleteAt(deleteAt) + caseIndex := deleteAt + policyChangedOffset + cases = slices.Delete(cases, caseIndex, caseIndex+1) + } + } + return len(p.sources) == 0, nil + }() + if err == nil { + if needClose { + // Close the resultant policy if the last policy source was deleted. + p.Close() + } else { + // Otherwise, reload the policy synchronously. + _, err = p.reload(false) + } + } + req.respCh <- err + default: // PolicyChanged[N] + if !ok { + // One of the PolicyChanged channels was closed, indicating that + // the corresponding [source.Source] is no longer valid. + // We can no longer keep this [Policy] up to date + // and should close it. + p.Close() + return + } + + // One of the PolicyChanged channels was signaled. + // We should request an asynchronous policy reload. + p.reloadAsync() + } + } +} + +// watchReload processes incoming synchronous and asynchronous policy reload requests. +// Synchronous requests (with a non-nil respCh) are served immediately. +// Asynchronous requests are debounced and throttled: they are executed at least +// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay] +// after the first request in a batch. +func (p *Policy) watchReload() { + force := false // whether a forced refresh was requested + var delayCh, timeoutCh <-chan time.Time + reload := func(respCh chan<- reloadResponse) { + delayCh, timeoutCh = nil, nil + policy, err := p.reloadNow(force) + if err != nil { + loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err) + } + if respCh != nil { + respCh <- reloadResponse{policy: policy, err: err} + } + force = false + } + +loop: + for { + select { + case req := <-p.reloadCh: + if req.force { + force = true + } + if req.respCh != nil { + reload(req.respCh) + continue + } + if delayCh == nil { + timeoutCh = time.After(policyReloadMaxDelay.Get(func() time.Duration { return defaultPolicyReloadMaxDelay })) + } + delayCh = time.After(policyReloadMinDelay.Get(func() time.Duration { return defaultPolicyReloadMinDelay })) + case <-delayCh: + reload(nil) + case <-timeoutCh: + reload(nil) + case <-p.closeCh: + break loop + } + } + + p.closeInternal() +} + +func (p *Policy) closeInternal() { + p.mu.Lock() + defer p.mu.Unlock() + p.sources.Close() + p.changeCallbacks.Close() + close(p.doneCh) +} + +// Close initiates the closing of the resultant policy. +// The actual closing is performed by closeInternal when watchReload exits, +// and the Done() channel is closed when closeInternal finishes. +func (p *Policy) Close() { + p.mu.Lock() + defer p.mu.Unlock() + if p.closing { + return + } + p.closing = true + close(p.closeCh) +} + +// sourceChangeRequest is a request to add and/or remove source from a [Policy]. +type sourceChangeRequest struct { + toAdd, toDelete *source.Source + respCh chan<- error +} + +// reloadRequest describes a policy reload request. +type reloadRequest struct { + // force triggers an immediate synchronous policy reload, + // reloading the policy regardless of whether a policy change was detected. + force bool + // respCh is an optional channel. If non-nil, it makes the reload request + // synchronous and receives the result. + respCh chan<- reloadResponse +} + +type reloadResponse struct { + policy *setting.Snapshot + err error +} + +var ( + policyMu sync.RWMutex + policySources []*source.Source + resultantPolicies []*Policy + + resultantPolicyLRU [setting.MaxSettingScope + 1]syncs.AtomicValue[*Policy] // by [Scope.Kind] +) + +// registerSource registers the specified [source.Source] to be used by the package. +// It updates existing [Policy]s returned by [PolicyFor] to use this source if +// they are within the source's [setting.PolicyScope]. +func registerSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + if slices.Contains(policySources, source) { + return nil + } + policySources = append(policySources, source) + return forEachResultantPolicyLocked(func(policy *Policy) error { + if !policy.Scope().IsWithinOf(source.Scope()) { + return nil + } + return policy.AddSource(source) + }) +} + +// replaceSource is like [unregisterSource](old) followed by [registerSource](new), +// but is atomic from the perspective of each [Policy]. +func replaceSource(old, new *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + oldIndex := slices.Index(policySources, old) + if oldIndex == -1 { + return fmt.Errorf("the source is not registered: %v", old) + } + policySources[oldIndex] = new + return forEachResultantPolicyLocked(func(policy *Policy) error { + if policy.Scope().IsWithinOf(old.Scope()) || policy.Scope().IsWithinOf(new.Scope()) { + return nil + } + return policy.ReplaceSource(old, new) + }) +} + +// unregisterSource unregisters the specified [source.Source], +// so that it won't be used by any new or existing [Policy]. +func unregisterSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + index := slices.Index(policySources, source) + if index == -1 { + return nil + } + policySources = slices.Delete(policySources, index, index+1) + return forEachResultantPolicyLocked(func(policy *Policy) error { + if !policy.Scope().IsWithinOf(source.Scope()) { + return nil + } + return policy.RemoveSource(source) + }) +} + +// forEachResultantPolicyLocked calls fn for every [Policy] in [resultantPolicies]. +// It accumulates the returned errors, except for [errResultantPolicyClosed], +// and returns an error that wraps all errors returned by fn. +// The [policyMu] mutex must be held while this function is executed. +func forEachResultantPolicyLocked(fn func(p *Policy) error) error { + var errs []error + for _, policy := range resultantPolicies { + err := fn(policy) + if err != nil && !errors.Is(err, errResultantPolicyClosed) { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +// PolicyFor returns the [Policy] for the specified scope, +// creating one from the registered [source.Store]s if it does not exist. +func PolicyFor(scope setting.PolicyScope) (*Policy, error) { + if err := lazyinit.Do(); err != nil { + return nil, err + } + if policy := resultantPolicyLRU[scope.Kind()].Load(); policy != nil && policy.Scope() == scope && policy.IsValid() { + return policy, nil + } + return policyForSlow(scope) +} + +func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) { + defer func() { + if policy != nil { + resultantPolicyLRU[scope.Kind()].Store(policy) + } + }() + + policyMu.RLock() + if policy, ok := findPolicyByScopeLocked(scope); ok { + policyMu.RUnlock() + return policy, nil + } + policyMu.RUnlock() + + policyMu.Lock() + defer policyMu.Unlock() + if policy, ok := findPolicyByScopeLocked(scope); ok { + return policy, nil + } + sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool { + return scope.IsWithinOf(source.Scope()) + }) + policy, err = newPolicy(scope, sources...) + if err != nil { + return nil, err + } + resultantPolicies = append(resultantPolicies, policy) + go func() { + <-policy.Done() + deletePolicy(policy) + }() + return policy, nil +} + +// findPolicyByScopeLocked returns a policy with the specified scope and true if +// one exists, otherwise it returns nil, false. +// [policyMu] must be held. +func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) { + for _, policy := range resultantPolicies { + if policy.Scope() == target && policy.IsValid() { + return policy, true + } + } + return nil, false +} + +// deletePolicy deletes the specified resultant policy from the [resultantPolicies] list. +func deletePolicy(policy *Policy) { + policyMu.Lock() + if i := slices.Index(resultantPolicies, policy); i != -1 { + resultantPolicies = slices.Delete(resultantPolicies, i, i+1) + } + resultantPolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil) + policyMu.Unlock() +} + +// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore] +// or [StoreRegistration.Unregister] is called more than once. +var ErrAlreadyConsumed = errors.New("the store registration is no longer valid") + +// StoreRegistration is a [source.Store] registered for use in the specified scope. +// It can be used to unregister the store, or replace it with another one. +type StoreRegistration struct { + source *source.Source + consumed atomic.Uint32 + m sync.Mutex +} + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + return newStoreRegistration(name, scope, store) +} + +// RegisterStoreForTest is like [RegisterStore], but unregisters the store when +// tb and all its subtests complete. +func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + reg, err := RegisterStore(name, scope, store) + if err == nil { + tb.Cleanup(func() { + if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) { + tb.Fatalf("Unregister failed: %v", err) + } + }) + } + return reg, err // may be nil or non-nil +} + +func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + source := source.NewSource(name, scope, store) + if err := registerSource(source); err != nil { + return nil, err + } + return &StoreRegistration{source: source}, nil +} + +// ReplaceStore replaces the registered store with the new one, +// returning a new [StoreRegistration] or an error. +func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) { + var res *StoreRegistration + err := r.consume(func() error { + newSource := source.NewSource(r.source.Name(), r.source.Scope(), new) + if err := replaceSource(r.source, newSource); err != nil { + return err + } + res = &StoreRegistration{source: newSource} + return nil + }) + return res, err +} + +// Unregister reverts the registration. +func (r *StoreRegistration) Unregister() error { + return r.consume(func() error { return unregisterSource(r.source) }) +} + +// consume invokes fn, consuming r if no error is returned. +// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call. +func (r *StoreRegistration) consume(fn func() error) (err error) { + if r.consumed.Load() != 0 { + return ErrAlreadyConsumed + } + return r.consumeSlow(fn) +} + +func (r *StoreRegistration) consumeSlow(fn func() error) (err error) { + r.m.Lock() + defer r.m.Unlock() + if r.consumed.Load() != 0 { + return ErrAlreadyConsumed + } + if err = fn(); err == nil { + r.consumed.Store(1) + } + return err // may be nil or non-nil +} diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go new file mode 100644 index 000000000..744d4bfe9 --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy_test.go @@ -0,0 +1,368 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "slices" + "sort" + "testing" + + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +func TestRegisterSourceAndGetResultantPolicy(t *testing.T) { + type sourceConfig struct { + name string + scope setting.PolicyScope + settingKey setting.Key + settingValue string + wantEffective bool + } + tests := []struct { + name string + scope setting.PolicyScope + initialSources []sourceConfig + additionalSources []sourceConfig + wantSnapshot *setting.Snapshot + }{ + { + name: "DevicePolicy/NoSources", + scope: setting.DeviceScope, + wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope), + }, + { + name: "UserScope/NoSources", + scope: setting.CurrentUserScope, + wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/OneInitialSource", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/OneAdditionalSource", + scope: setting.DeviceScope, + additionalSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/ManyInitialSources/NoConflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/ManyInitialSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/MixedSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceD", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueD", + wantEffective: true, + }, + { + name: "TestSourceE", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueE", + wantEffective: true, + }, + { + name: "TestSourceF", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueF", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "UserScope/Init-DeviceSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)), + }, setting.CurrentUserScope), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceProfile", + scope: setting.CurrentProfileScope, + settingKey: "TestKeyB", + settingValue: "ProfileValue", + wantEffective: true, + }, + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)), + }, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/User-Source-does-not-apply", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyA", + settingValue: "UserValue", + wantEffective: false, // Registering a user source should have no impact on the device policy. + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Register all settings that we use in this test. + var definitions []*setting.Definition + for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) { + definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue)) + } + if err := setting.SetDefinitionsForTest(t, definitions...); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Add the initial policy sources. + var wantSources []*source.Source + for _, s := range tt.initialSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("failed to register policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + // Retrieve the resultant policy. + policy, err := resultantPolicyForTest(t, tt.scope) + if err != nil { + t.Fatalf("failed to get resultant policy for %v", tt.scope) + } + + // Add additional setting sources one by one, and check the policy settings at each step. + for _, s := range tt.additionalSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("failed to register additional policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + sort.SliceStable(wantSources, func(i, j int) bool { + return wantSources[i].Compare(wantSources[j]) < 0 + }) + gotSources := make([]*source.Source, len(policy.sources)) + for i, s := range policy.sources { + gotSources[i] = s.Source + } + if !slices.Equal(gotSources, wantSources) { + t.Errorf("Sources: got %v; want %v", gotSources, wantSources) + } + + // Verify the final resultant settings snapshots. + if got := policy.Get(); !got.Equal(tt.wantSnapshot) { + t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot) + } + }) + } +} + +// resultantPolicyForTest is like [resultantPolicyFor], but it deletes the policy +// when tb and all its subtests complete. +func resultantPolicyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { + policy, err := PolicyFor(target) + if err != nil { + return nil, err + } + tb.Cleanup(func() { + policy.Close() + <-policy.Done() + deletePolicy(policy) + }) + return policy, nil +} diff --git a/util/syspolicy/setting/errors.go b/util/syspolicy/setting/errors.go new file mode 100644 index 000000000..8d5e73754 --- /dev/null +++ b/util/syspolicy/setting/errors.go @@ -0,0 +1,60 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import "errors" + +var ( + // ErrNotConfigured is returned when the requested policy setting is not configured. + ErrNotConfigured = errors.New("not configured") + // ErrTypeMismatch is returned when there's a type mismatch between the actual type + // of the setting value and the expected type. + ErrTypeMismatch = errors.New("type mismatch") + // ErrNoSuchKey is returned by [DefinitionOf] when no policy setting + // has been registered with the specified key. + // + // Until 2024-08-02, this error was also returned by a [Handler] when the specified + // key did not have a value set. While the package maintains compatibility with this + // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer + // [source.Store] implementations. + ErrNoSuchKey = errors.New("no such key") +) + +// Error is an error when reading or parsing a policy setting. +type Error struct { + text string +} + +// NewError returns a [Error] with the specified error message. +func NewError(text string) *Error { + return &Error{text} +} + +// WrapError returns an [Error] with the text of the specified error, +// or nil if err is nil, [ErrNotConfigured], or [ErrNoSuchKey]. +func WrapError(err error) *Error { + if err == nil || errors.Is(err, ErrNotConfigured) || errors.Is(err, ErrNoSuchKey) { + return nil + } + if err, ok := err.(*Error); ok { + return err + } + return &Error{err.Error()} +} + +// Error implements error. +func (e Error) Error() string { + return e.text +} + +// MarshalText implements [encoding.TextMarshaler]. +func (e Error) MarshalText() (text []byte, err error) { + return []byte(e.Error()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler]. +func (e *Error) UnmarshalText(text []byte) error { + e.text = string(text) + return nil +} diff --git a/util/syspolicy/setting/key.go b/util/syspolicy/setting/key.go new file mode 100644 index 000000000..406fde132 --- /dev/null +++ b/util/syspolicy/setting/key.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +// Key is a string that uniquely identifies a policy and must remain unchanged +// once established and documented for a given policy setting. It may contain +// alphanumeric characters and zero or more [KeyPathSeparator]s to group +// individual policy settings into categories. +type Key string + +// KeyPathSeparator allows logical grouping of policy settings into categories. +const KeyPathSeparator = "/" diff --git a/util/syspolicy/setting/origin.go b/util/syspolicy/setting/origin.go new file mode 100644 index 000000000..3e61cd946 --- /dev/null +++ b/util/syspolicy/setting/origin.go @@ -0,0 +1,71 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "fmt" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" +) + +// Origin describes where a policy or a policy setting is configured. +type Origin struct { + data settingOrigin +} + +// settingOrigin is the marshallable data of a [Origin]. +type settingOrigin struct { + Name string `json:",omitzero"` + Scope PolicyScope +} + +// NewOrigin returns a new [Origin] with the specified scope. +func NewOrigin(scope PolicyScope) *Origin { + return NewNamedOrigin("", scope) +} + +// NewNamedOrigin returns a new [Origin] with the specified scope and name. +func NewNamedOrigin(name string, scope PolicyScope) *Origin { + return &Origin{settingOrigin{name, scope}} +} + +// Scope reports the policy [PolicyScope] where the setting is configured. +func (s Origin) Scope() PolicyScope { + return s.data.Scope +} + +// Name returns the name of the policy source where the setting is configured, +// or "" if not available. +func (s Origin) Name() string { + return s.data.Name +} + +// String implements [fmt.Stringer]. +func (s Origin) String() string { + if s.Name() != "" { + return fmt.Sprintf("%s (%v)", s.Name(), s.Scope()) + } + return s.Scope().String() +} + +// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { + return jsonv2.MarshalEncode(out, &s.data, opts) +} + +// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. +func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { + return jsonv2.UnmarshalDecode(in, &s.data, opts) +} + +// MarshalJSON implements [json.Marshaler]. +func (s Origin) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(s) // uses MarshalJSONV2 +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (s *Origin) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 +} diff --git a/util/syspolicy/setting/policy_scope.go b/util/syspolicy/setting/policy_scope.go new file mode 100644 index 000000000..636c815b2 --- /dev/null +++ b/util/syspolicy/setting/policy_scope.go @@ -0,0 +1,195 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "fmt" + "strings" + + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal/lazyinit" +) + +var ( + lazyCurrentScope lazy.SyncValue[PolicyScope] + + // DeviceScope indicates a scope containing device-global policies. + DeviceScope = PolicyScope{kind: DeviceSetting} + // CurrentProfileScope indicates a scope containing policies that apply to the + // currently active Tailscale profile. + CurrentProfileScope = PolicyScope{kind: ProfileSetting} + // CurrentUserScope indicates a scope containing policies that apply to the + // current user, for whatever that means on the current platform and + // in the current application context. + CurrentUserScope = PolicyScope{kind: UserSetting} +) + +// PolicyScope is a management scope. +type PolicyScope struct { + kind Scope + userID string + profileID string +} + +// CurrentScope returns the default [PolicyScope] that the package will use to return +// the policy settings for unless a different scope is explicitly requested. +// This defaults to [DeviceScope], unless the process runs as a user (rather than LocalSystem) +// on Windows, in which case it returns the [CurrentUserScope]. +func CurrentScope() PolicyScope { + // Allow deferred package init functions to override the default scope. + lazyinit.Do() + return lazyCurrentScope.Get(func() PolicyScope { return DeviceScope }) +} + +// SetCurrentScope attempts to set the specified scope as the current scope, +// and reports whether it succeeds. +// It can be called only once and must be during lazy package initialization. +func SetCurrentScope(scope PolicyScope) bool { + return lazyCurrentScope.Set(scope) +} + +// UserScopeOf returns a policy [PolicyScope] of the specified user. +func UserScopeOf(uid string) PolicyScope { + return PolicyScope{kind: UserSetting, userID: uid} +} + +// Kind reports the base [Scope] of s. +func (s PolicyScope) Kind() Scope { + return s.kind +} + +// IsApplicableSetting reports whether the specified setting applies to +// and can be retrieved for this scope. Policy settings are applicable +// to their own scopes as well as more specific scopes. For example, +// device settings are applicable to device, profile and user scopes, +// but user settings are only applicable to user scopes. +// For instance, a menu visibility setting is inherently a user setting +// and only makes sense in the context of a specific user. +func (s PolicyScope) IsApplicableSetting(setting *Definition) bool { + return setting != nil && setting.Scope() <= s.Kind() +} + +// IsConfigurableSetting reports whether the specified setting can be configured +// by a policy at this scope. Policy settings are configurable at their own scopes +// as well as broader scopes. For example, [UserSetting]s are configurable in +// user, profile, and device scopes, but [DeviceSetting]s are only configurable +// in the [DeviceScope]. For instance, the InstallUpdates policy setting +// can only be configured in the device scope, as it controls whether updates +// will be installed automatically on the device, rather than for specific users. +func (s PolicyScope) IsConfigurableSetting(setting *Definition) bool { + return setting != nil && setting.Scope() >= s.Kind() +} + +// IsWithinOf reports whether policy settings that apply to s2 also apply to s. +// For example, policy settings that apply to the [DeviceScope] also apply to +// the [CurrentUserScope]. +func (s PolicyScope) IsWithinOf(s2 PolicyScope) bool { + if s2.Kind() > s.Kind() { + return false + } + switch s2.Kind() { + case DeviceSetting: + return true + case ProfileSetting: + return s.profileID == s2.profileID + case UserSetting: + return s.userID == s2.userID + default: + panic("unreachable") + } +} + +// IsStrictlyWithinOf is like [IsWithinOf], except it returns false +// when s and s2 is the same scope. +func (s PolicyScope) IsStrictlyWithinOf(s2 PolicyScope) bool { + return s != s2 && s.IsWithinOf(s2) +} + +// String implements [fmt.Stringer]. +func (s PolicyScope) String() string { + if s.profileID == "" && s.userID == "" { + return s.kind.String() + } + return s.stringSlow() +} + +// MarshalText implements [encoding.TextMarshaler]. +func (s PolicyScope) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +// MarshalText implements [encoding.TextUnmarshaler]. +func (s *PolicyScope) UnmarshalText(b []byte) error { + *s = PolicyScope{} + parts := strings.SplitN(string(b), "/", 2) + if len(parts) == 0 { + return fmt.Errorf("%s is not a valid scope", b) + } + for i, part := range parts { + kind, id, err := parseScopeAndID(part) + if err != nil { + return err + } + if i > 0 && kind <= s.kind { + return fmt.Errorf("invalid scope hierarchy: %s", b) + } + s.kind = kind + switch kind { + case DeviceSetting: + if id != "" { + return fmt.Errorf("the device scope must not have an ID: %s", b) + } + case ProfileSetting: + s.profileID = id + case UserSetting: + s.userID = id + } + } + return nil +} + +func (s PolicyScope) stringSlow() string { + var sb strings.Builder + writeScopeWithID := func(s Scope, id string) { + sb.WriteString(s.String()) + if id != "" { + sb.WriteRune('(') + sb.WriteString(id) + sb.WriteRune(')') + } + } + if s.kind == ProfileSetting || s.profileID != "" { + writeScopeWithID(ProfileSetting, s.profileID) + if s.kind != ProfileSetting { + sb.WriteRune('/') + } + } + if s.kind == UserSetting { + writeScopeWithID(UserSetting, s.userID) + } + return sb.String() +} + +func parseScopeAndID(s string) (scope Scope, id string, err error) { + name, params, ok := extractScopeAndParams(s) + if !ok { + return 0, "", fmt.Errorf("%q is not a valid scope string", s) + } + if err := scope.UnmarshalText([]byte(name)); err != nil { + return 0, "", err + } + return scope, params, nil +} + +func extractScopeAndParams(s string) (name, params string, ok bool) { + paramsStart := strings.Index(s, "(") + if paramsStart == -1 { + return s, "", true + } + paramsEnd := strings.LastIndex(s, ")") + if paramsEnd < paramsStart { + return "", "", false + } + return s[0:paramsStart], s[paramsStart+1 : paramsEnd], true +} diff --git a/util/syspolicy/setting/policy_scope_test.go b/util/syspolicy/setting/policy_scope_test.go new file mode 100644 index 000000000..8140fc5a0 --- /dev/null +++ b/util/syspolicy/setting/policy_scope_test.go @@ -0,0 +1,550 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "reflect" + "testing" + + jsonv2 "github.com/go-json-experiment/json" +) + +func TestPolicyScopeIsApplicableSetting(t *testing.T) { + tests := []struct { + name string + scope PolicyScope + setting *Definition + wantApplicable bool + }{ + { + name: "DeviceScope/DeviceSetting", + scope: DeviceScope, + setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue), + wantApplicable: true, + }, + { + name: "DeviceScope/ProfileSetting", + scope: DeviceScope, + setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue), + wantApplicable: false, + }, + { + name: "DeviceScope/UserSetting", + scope: DeviceScope, + setting: NewDefinition("TestSetting", UserSetting, IntegerValue), + wantApplicable: false, + }, + { + name: "ProfileScope/DeviceSetting", + scope: CurrentProfileScope, + setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue), + wantApplicable: true, + }, + { + name: "ProfileScope/ProfileSetting", + scope: CurrentProfileScope, + setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue), + wantApplicable: true, + }, + { + name: "ProfileScope/UserSetting", + scope: CurrentProfileScope, + setting: NewDefinition("TestSetting", UserSetting, IntegerValue), + wantApplicable: false, + }, + { + name: "UserScope/DeviceSetting", + scope: CurrentUserScope, + setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue), + wantApplicable: true, + }, + { + name: "UserScope/ProfileSetting", + scope: CurrentUserScope, + setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue), + wantApplicable: true, + }, + { + name: "UserScope/UserSetting", + scope: CurrentUserScope, + setting: NewDefinition("TestSetting", UserSetting, IntegerValue), + wantApplicable: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotApplicable := tt.scope.IsApplicableSetting(tt.setting) + if gotApplicable != tt.wantApplicable { + t.Fatalf("got %v, want %v", gotApplicable, tt.wantApplicable) + } + }) + } +} + +func TestPolicyScopeIsConfigurableSetting(t *testing.T) { + tests := []struct { + name string + scope PolicyScope + setting *Definition + wantConfigurable bool + }{ + { + name: "DeviceScope/DeviceSetting", + scope: DeviceScope, + setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue), + wantConfigurable: true, + }, + { + name: "DeviceScope/ProfileSetting", + scope: DeviceScope, + setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue), + wantConfigurable: true, + }, + { + name: "DeviceScope/UserSetting", + scope: DeviceScope, + setting: NewDefinition("TestSetting", UserSetting, IntegerValue), + wantConfigurable: true, + }, + { + name: "ProfileScope/DeviceSetting", + scope: CurrentProfileScope, + setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue), + wantConfigurable: false, + }, + { + name: "ProfileScope/ProfileSetting", + scope: CurrentProfileScope, + setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue), + wantConfigurable: true, + }, + { + name: "ProfileScope/UserSetting", + scope: CurrentProfileScope, + setting: NewDefinition("TestSetting", UserSetting, IntegerValue), + wantConfigurable: true, + }, + { + name: "UserScope/DeviceSetting", + scope: CurrentUserScope, + setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue), + wantConfigurable: false, + }, + { + name: "UserScope/ProfileSetting", + scope: CurrentUserScope, + setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue), + wantConfigurable: false, + }, + { + name: "UserScope/UserSetting", + scope: CurrentUserScope, + setting: NewDefinition("TestSetting", UserSetting, IntegerValue), + wantConfigurable: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotConfigurable := tt.scope.IsConfigurableSetting(tt.setting) + if gotConfigurable != tt.wantConfigurable { + t.Fatalf("got %v, want %v", gotConfigurable, tt.wantConfigurable) + } + }) + } +} + +func TestPolicyScopeIsWithinOf(t *testing.T) { + tests := []struct { + name string + scopeA PolicyScope + scopeB PolicyScope + wantBWithinOfA bool + wantBStrictlyWithinOfA bool + }{ + { + name: "DeviceScope/DeviceScope", + scopeA: DeviceScope, + scopeB: DeviceScope, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: false, + }, + { + name: "DeviceScope/CurrentProfileScope", + scopeA: DeviceScope, + scopeB: CurrentProfileScope, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: true, + }, + { + name: "DeviceScope/UserScope", + scopeA: DeviceScope, + scopeB: CurrentUserScope, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: true, + }, + { + name: "ProfileScope/DeviceScope", + scopeA: CurrentProfileScope, + scopeB: DeviceScope, + wantBWithinOfA: false, + wantBStrictlyWithinOfA: false, + }, + { + name: "ProfileScope/ProfileScope", + scopeA: CurrentProfileScope, + scopeB: CurrentProfileScope, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: false, + }, + { + name: "ProfileScope/UserScope", + scopeA: CurrentProfileScope, + scopeB: CurrentUserScope, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: true, + }, + { + name: "UserScope/DeviceScope", + scopeA: CurrentUserScope, + scopeB: DeviceScope, + wantBWithinOfA: false, + wantBStrictlyWithinOfA: false, + }, + { + name: "UserScope/ProfileScope", + scopeA: CurrentUserScope, + scopeB: CurrentProfileScope, + wantBWithinOfA: false, + wantBStrictlyWithinOfA: false, + }, + { + name: "UserScope/UserScope", + scopeA: CurrentUserScope, + scopeB: CurrentUserScope, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: false, + }, + { + name: "UserScope(1234)/UserScope(1234)", + scopeA: UserScopeOf("1234"), + scopeB: UserScopeOf("1234"), + wantBWithinOfA: true, + wantBStrictlyWithinOfA: false, + }, + { + name: "UserScope(1234)/UserScope(5678)", + scopeA: UserScopeOf("1234"), + scopeB: UserScopeOf("5678"), + wantBWithinOfA: false, + wantBStrictlyWithinOfA: false, + }, + { + name: "ProfileScope(A)/UserScope(A/1234)", + scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"}, + scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"}, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: true, + }, + { + name: "ProfileScope(A)/UserScope(B/1234)", + scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"}, + scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "B"}, + wantBWithinOfA: false, + wantBStrictlyWithinOfA: false, + }, + { + name: "UserScope(1234)/UserScope(A/1234)", + scopeA: PolicyScope{kind: UserSetting, userID: "1234"}, + scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"}, + wantBWithinOfA: true, + wantBStrictlyWithinOfA: true, + }, + { + name: "UserScope(1234)/UserScope(A/5678)", + scopeA: PolicyScope{kind: UserSetting, userID: "1234"}, + scopeB: PolicyScope{kind: UserSetting, userID: "5678", profileID: "A"}, + wantBWithinOfA: false, + wantBStrictlyWithinOfA: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotWithinOf := tt.scopeB.IsWithinOf(tt.scopeA) + if gotWithinOf != tt.wantBWithinOfA { + t.Fatalf("WithinOf: got %v, want %v", gotWithinOf, tt.wantBWithinOfA) + } + + gotStrictlyWithinOf := tt.scopeB.IsStrictlyWithinOf(tt.scopeA) + if gotStrictlyWithinOf != tt.wantBStrictlyWithinOfA { + t.Fatalf("StrictlyWithinOf: got %v, want %v", gotStrictlyWithinOf, tt.wantBStrictlyWithinOfA) + } + }) + } +} + +func TestPolicyScopeMarshalUnmarshal(t *testing.T) { + tests := []struct { + name string + in any + wantJSON string + wantError bool + }{ + { + name: "null-scope", + in: &struct { + Scope PolicyScope + }{}, + wantJSON: `{"Scope":"Device"}`, + }, + { + name: "null-scope-omit-zero", + in: &struct { + Scope PolicyScope `json:",omitzero"` + }{}, + wantJSON: `{}`, + }, + { + name: "device-scope", + in: &struct { + Scope PolicyScope + }{DeviceScope}, + wantJSON: `{"Scope":"Device"}`, + }, + { + name: "current-profile-scope", + in: &struct { + Scope PolicyScope + }{CurrentProfileScope}, + wantJSON: `{"Scope":"Profile"}`, + }, + { + name: "current-user-scope", + in: &struct { + Scope PolicyScope + }{CurrentUserScope}, + wantJSON: `{"Scope":"User"}`, + }, + { + name: "specific-user-scope", + in: &struct { + Scope PolicyScope + }{UserScopeOf("_")}, + wantJSON: `{"Scope":"User(_)"}`, + }, + { + name: "specific-user-scope", + in: &struct { + Scope PolicyScope + }{UserScopeOf("S-1-5-21-3698941153-1525015703-2649197413-1001")}, + wantJSON: `{"Scope":"User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`, + }, + { + name: "specific-profile-scope", + in: &struct { + Scope PolicyScope + }{PolicyScope{kind: ProfileSetting, profileID: "1234"}}, + wantJSON: `{"Scope":"Profile(1234)"}`, + }, + { + name: "specific-profile-and-user-scope", + in: &struct { + Scope PolicyScope + }{PolicyScope{ + kind: UserSetting, + profileID: "1234", + userID: "S-1-5-21-3698941153-1525015703-2649197413-1001", + }}, + wantJSON: `{"Scope":"Profile(1234)/User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotJSON, err := jsonv2.Marshal(tt.in) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(gotJSON) != tt.wantJSON { + t.Fatalf("Marshal got %s, want %s", gotJSON, tt.wantJSON) + } + wantBack := tt.in + gotBack := reflect.New(reflect.TypeOf(tt.in).Elem()).Interface() + err = jsonv2.Unmarshal(gotJSON, gotBack) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !reflect.DeepEqual(gotBack, wantBack) { + t.Fatalf("Unmarshal got %+v, want %+v", gotBack, wantBack) + } + }) + } +} + +func TestPolicyScopeUnmarshalSpecial(t *testing.T) { + tests := []struct { + name string + json string + want any + wantError bool + }{ + { + name: "empty", + json: "{}", + want: &struct { + Scope PolicyScope + }{}, + }, + { + name: "too-many-scopes", + json: `{"Scope":"Device/Profile/User"}`, + wantError: true, + }, + { + name: "user/profile", // incorrect order + json: `{"Scope":"User/Profile"}`, + wantError: true, + }, + { + name: "profile-user-no-params", + json: `{"Scope":"Profile/User"}`, + want: &struct { + Scope PolicyScope + }{CurrentUserScope}, + }, + { + name: "unknown-scope", + json: `{"Scope":"Unknown"}`, + wantError: true, + }, + { + name: "unknown-scope/unknown-scope", + json: `{"Scope":"Unknown/Unknown"}`, + wantError: true, + }, + { + name: "device-scope/unknown-scope", + json: `{"Scope":"Device/Unknown"}`, + wantError: true, + }, + { + name: "unknown-scope/device-scope", + json: `{"Scope":"Unknown/Device"}`, + wantError: true, + }, + { + name: "slash", + json: `{"Scope":"/"}`, + wantError: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &struct { + Scope PolicyScope + }{} + err := jsonv2.Unmarshal([]byte(tt.json), got) + if (err != nil) != tt.wantError { + t.Errorf("Marshal error: got %v, want %v", err, tt.wantError) + } + if err != nil { + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("Unmarshal got %+v, want %+v", got, tt.want) + } + }) + } + +} + +func TestExtractScopeAndParams(t *testing.T) { + tests := []struct { + name string + s string + scope string + params string + wantOk bool + }{ + { + name: "empty", + s: "", + wantOk: true, + }, + { + name: "scope-only", + s: "device", + scope: "device", + wantOk: true, + }, + { + name: "scope-with-params", + s: "user(1234)", + scope: "user", + params: "1234", + wantOk: true, + }, + { + name: "params-empty-scope", + s: "(1234)", + scope: "", + params: "1234", + wantOk: true, + }, + { + name: "params-with-brackets", + s: "test()())))())", + scope: "test", + params: ")())))()", + wantOk: true, + }, + { + name: "no-closing-bracket", + s: "user(1234", + scope: "", + params: "", + wantOk: false, + }, + { + name: "open-before-close", + s: ")user(1234", + scope: "", + params: "", + wantOk: false, + }, + { + name: "brackets-only", + s: ")(", + scope: "", + params: "", + wantOk: false, + }, + { + name: "closing-bracket", + s: ")", + scope: "", + params: "", + wantOk: false, + }, + { + name: "opening-bracket", + s: ")", + scope: "", + params: "", + wantOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scope, params, ok := extractScopeAndParams(tt.s) + if ok != tt.wantOk { + t.Logf("OK: got %v; want %v", ok, tt.wantOk) + } + if scope != tt.scope { + t.Logf("Scope: got %q; want %q", scope, tt.scope) + } + if params != tt.params { + t.Logf("Params: got %v; want %v", params, tt.params) + } + }) + } +} diff --git a/util/syspolicy/setting/raw_item.go b/util/syspolicy/setting/raw_item.go new file mode 100644 index 000000000..a901b505a --- /dev/null +++ b/util/syspolicy/setting/raw_item.go @@ -0,0 +1,47 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +// RawItem contains a raw policy setting as read from a policy store, or an +// error if the requested setting could not be read from the store. As a special +// case, it may also hold a value of the [Visibility], [PreferenceOption], +// or [time.Duration] types. While the policy store interface does not support +// these types natively, and the values of these types have to be unmarshalled +// or converted from strings, these setting types predate the typed policy +// hierarchies, and must be supported at this layer. +type RawItem struct { + value any + err *Error + origin *Origin // or nil +} + +// RawItemOf returns [RawItem] with the specified value. +func RawItemOf(value any) RawItem { + return RawItemWith(value, nil, nil) +} + +// RawItemWith returns an [RawItem] with the specified value, error and origin. +func RawItemWith(value any, err *Error, origin *Origin) RawItem { + return RawItem{value, err, origin} +} + +// Value returns the value of an untyped policy setting, +// or nil if the policy setting is not configured. +func (i RawItem) Value() any { + return i.value +} + +// Error returns the error that occurred when reading the policy setting, +// or nil if no error occurred. +func (i RawItem) Error() error { + if i.err != nil { + return i.err + } + return nil +} + +// Origin returns an optional [Origin] indicating the policy settings is configured. +func (i RawItem) Origin() *Origin { + return i.origin +} diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go new file mode 100644 index 000000000..e60aab12c --- /dev/null +++ b/util/syspolicy/setting/setting.go @@ -0,0 +1,352 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package setting contain types for policy settings. +package setting + +import ( + "fmt" + "slices" + "strings" + "sync" + "time" + + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/internal/lazyinit" +) + +// Scope indicates the broadest scope at which a policy setting may apply, +// and the narrowest scope at which it may be configured. +type Scope int8 + +const ( + // DeviceSetting indicates a policy setting that applies to a device, regardless of + // which OS user or Tailscale profile is currently active, if any. + // It can only be configured at a [DeviceScope]. + DeviceSetting Scope = iota + // ProfileSetting indicates a policy setting that applies to a Tailscale profile. + // It can only be configured for a specific profile or at a [DeviceScope], + // in which case it applies to all profiles on the device. + ProfileSetting + // UserSetting indicates a policy setting that applies to users. + // It can be configured for a user, profile, or the entire device. + UserSetting + + // MaxSettingScope is the maximum possible [Scope] value. + MaxSettingScope = UserSetting +) + +// String implements [fmt.Stringer]. +func (s Scope) String() string { + switch s { + case DeviceSetting: + return "Device" + case ProfileSetting: + return "Profile" + case UserSetting: + return "User" + default: + panic("unreachable") + } +} + +// MarshalText implements [encoding.TextMarshaler]. +func (s Scope) MarshalText() (text []byte, err error) { + return []byte(s.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler]. +func (s *Scope) UnmarshalText(text []byte) error { + switch strings.ToLower(string(text)) { + case "device": + *s = DeviceSetting + case "profile": + *s = ProfileSetting + case "user": + *s = UserSetting + default: + return fmt.Errorf("%q is not a valid scope", string(text)) + } + return nil +} + +// Type is a policy setting value type. +// Except for [InvalidValue], which represents an invalid policy setting type, +// and [PreferenceOptionValue], [VisibilityValue], and [DurationValue], +// which have special handling due to their legacy status in the package, +// SettingTypes represent the raw value types readable from policy stores. +type Type int + +const ( + // InvalidValue indicates an invalid policy setting value type. + InvalidValue Type = iota + // BooleanValue indicates a policy setting whose underlying type in the + // [source.Store] is a bool. + BooleanValue + // IntegerValue indicates a policy setting whose underlying type in the + // [source.Store] is a uint64. + IntegerValue + // StringValue indicates a policy setting whose underlying type in the + // [source.Store] is a string. + StringValue + // StringListValue indicates a policy setting whose underlying type in the + // [source.Store] is a []string. + StringListValue + // PreferenceOptionValue indicates a three-state policy setting whose + // underlying type in the [source.Store] is a string, but the actual value + // is a [PreferenceOption]. + PreferenceOptionValue + // VisibilityValue indicates a two-state boolean-like policy setting whose + // underlying type in the [source.Store] is a string, but the actual value + // is a [Visibility]. + VisibilityValue + // DurationValue indicates an interval/period/duration policy setting whose + // underlying type in the [source.Store] is a string, but the actual value + // is a [time.Duration]. + DurationValue +) + +// String returns a string representation of t. +func (t Type) String() string { + switch t { + case InvalidValue: + return "Invalid" + case BooleanValue: + return "Boolean" + case IntegerValue: + return "Integer" + case StringValue: + return "String" + case StringListValue: + return "StringList" + case PreferenceOptionValue: + return "PreferenceOption" + case VisibilityValue: + return "Visibility" + case DurationValue: + return "Duration" + default: + panic("unreachable") + } +} + +// ValueType is a constraint that allows Go types corresponding to [Type]. +type ValueType interface { + bool | uint64 | string | []string | Visibility | PreferenceOption | time.Duration +} + +// Definition defines policy key, scope and value type. +type Definition struct { + key Key + scope Scope + typ Type + platforms PlatformList +} + +// NewDefinition returns a new [Definition] with the specified +// key, scope, type and supported platforms (see [PlatformList]). +func NewDefinition(k Key, s Scope, t Type, platforms ...string) *Definition { + return &Definition{key: k, scope: s, typ: t, platforms: platforms} +} + +// Key returns a policy setting's identifier. +func (d *Definition) Key() Key { + if d == nil { + return "" + } + return d.key +} + +// Scope reports the broadest [Scope] the policy setting may apply to. +func (d *Definition) Scope() Scope { + if d == nil { + return 0 + } + return d.scope +} + +// Type reports the underlying value type of the policy setting. +func (d *Definition) Type() Type { + if d == nil { + return InvalidValue + } + return d.typ +} + +// IsSupported reports whether the policy setting is supported on the current OS. +func (d *Definition) IsSupported() bool { + if d == nil { + return false + } + return d.platforms.HasCurrent() +} + +// SupportedPlatforms reports platforms on which the policy setting is supported. +// An empty [PlatformList] indicates that s is available on all platforms. +func (d *Definition) SupportedPlatforms() PlatformList { + if d == nil { + return nil + } + return d.platforms +} + +// String implements [fmt.Stringer]. +func (d *Definition) String() string { + if d == nil { + return "(nil)" + } + return fmt.Sprintf("%v(%q, %v)", d.scope, d.key, d.typ) +} + +// Equal reports whether d and d2 have the same key, type and scope. +// It does not check whether both s and s2 are supported on the same platforms. +func (d *Definition) Equal(d2 *Definition) bool { + if d == d2 { + return true + } + if d == nil || d2 == nil { + return false + } + return d.key == d2.key && d.typ == d2.typ && d.scope == d2.scope +} + +// DefinitionMap is a map of setting [Definition] by [Key]. +type DefinitionMap map[Key]*Definition + +var ( + definitions lazy.SyncValue[DefinitionMap] + + definitionsMu sync.Mutex + definitionsList []*Definition + definitionsUsed bool +) + +// Register registers a policy setting with the specified key, scope, and value type. +// All policy settings must be registered before any of them can be used. +// Register panics if called after invoking any syspolicy functions that use the +// registered policy definitions, such as functions that read the policy. +func Register(k Key, s Scope, t Type, platforms ...string) { + RegisterDefinition(NewDefinition(k, s, t, platforms...)) +} + +// RegisterDefinition is like [Register], but accepts a [Definition]. +func RegisterDefinition(d *Definition) { + definitionsMu.Lock() + defer definitionsMu.Unlock() + registerLocked(d) +} + +func registerLocked(d *Definition) { + if definitionsUsed { + panic("policy definitions are already in use") + } + definitionsList = append(definitionsList, d) +} + +func settingDefinitions() (DefinitionMap, error) { + return definitions.GetErr(func() (DefinitionMap, error) { + lazyinit.Do() + definitionsMu.Lock() + defer definitionsMu.Unlock() + definitionsUsed = true + return DefinitionMapOf(definitionsList) + }) +} + +// DefinitionMapOf returns a [DefinitionMap] with the specified settings, +// or an error if any settings have the same key but different type or scope. +func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) { + m := make(DefinitionMap, len(settings)) + for _, s := range settings { + if existing, exists := m[s.key]; exists { + if existing.Equal(s) { + // Ignore duplicate setting definitions if they match. It is acceptable + // if the same policy setting was registered more than once + // (e.g. by the syspolicy package itself and by iOS/Android code). + existing.platforms.mergeFrom(s.platforms) + continue + } + return nil, fmt.Errorf("duplicate policy definition: %q", s.key) + } + m[s.key] = s + } + return m, nil +} + +// SetDefinitionsForTest allows to register the specified setting definitions +// for the test duration. It is not concurrency-safe, but unlike [Register], +// it does not panic and can be called anytime. +// It returns an error if ds contains two different settings with the same [Key]. +func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error { + m, err := DefinitionMapOf(ds) + if err != nil { + return err + } + definitions.SetForTest(tb, m, err) + return nil +} + +// DefinitionOf returns a setting definition by key, +// or [ErrNoSuchKey] if the specified key does not exist, +// or an error if there are conflicting policy definitions. +func DefinitionOf(k Key) (*Definition, error) { + ds, err := settingDefinitions() + if err != nil { + return nil, err + } + if d, ok := ds[k]; ok { + return d, nil + } + return nil, ErrNoSuchKey +} + +// Definitions returns all registered setting definitions, +// or an error if different policies were registered under the same name. +func Definitions() ([]*Definition, error) { + ds, err := settingDefinitions() + if err != nil { + return nil, err + } + res := make([]*Definition, 0, len(ds)) + for _, d := range ds { + res = append(res, d) + } + return res, nil +} + +// PlatformList is a list of OSes. +// An empty list indicates that all possible platforms are supported. +type PlatformList []string + +// Has reports whether the list contains the target platform. +func (l PlatformList) Has(target string) bool { + if len(l) == 0 { + return true + } + return slices.ContainsFunc(l, func(os string) bool { + return strings.EqualFold(os, target) + }) +} + +// HasCurrent is like Has, but for the current platform. +func (l PlatformList) HasCurrent() bool { + return l.Has(internal.OS()) +} + +// mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions, +// if either l or l2 is empty, the merged result in l will also be empty. +func (l *PlatformList) mergeFrom(l2 PlatformList) { + switch { + case len(*l) == 0: + // No-op. An empty list indicates no platform restrictions. + case len(l2) == 0: + // Merging with an empty list results in an empty list. + *l = l2 + default: + // Append, sort and dedup. + *l = append(*l, l2...) + slices.Sort(*l) + *l = slices.Compact(*l) + } +} diff --git a/util/syspolicy/setting/setting_test.go b/util/syspolicy/setting/setting_test.go new file mode 100644 index 000000000..3cc08e7da --- /dev/null +++ b/util/syspolicy/setting/setting_test.go @@ -0,0 +1,344 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "slices" + "strings" + "testing" + + "tailscale.com/types/lazy" + "tailscale.com/types/ptr" + "tailscale.com/util/syspolicy/internal" +) + +func TestSettingDefinition(t *testing.T) { + tests := []struct { + name string + setting *Definition + osOverride string + wantKey Key + wantScope Scope + wantType Type + wantIsSupported bool + wantSupportedPlatforms PlatformList + wantString string + }{ + { + name: "Nil", + setting: nil, + wantKey: "", + wantScope: 0, + wantType: InvalidValue, + wantIsSupported: false, + wantString: "(nil)", + }, + { + name: "Device/Invalid", + setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, InvalidValue), + wantKey: "TestDevicePolicySetting", + wantScope: DeviceSetting, + wantType: InvalidValue, + wantIsSupported: true, + wantString: `Device("TestDevicePolicySetting", Invalid)`, + }, + { + name: "Device/Integer", + setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue), + wantKey: "TestDevicePolicySetting", + wantScope: DeviceSetting, + wantType: IntegerValue, + wantIsSupported: true, + wantString: `Device("TestDevicePolicySetting", Integer)`, + }, + { + name: "Profile/String", + setting: NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue), + wantKey: "TestProfilePolicySetting", + wantScope: ProfileSetting, + wantType: StringValue, + wantIsSupported: true, + wantString: `Profile("TestProfilePolicySetting", String)`, + }, + { + name: "Device/StringList", + setting: NewDefinition("AllowedSuggestedExitNodes", DeviceSetting, StringListValue), + wantKey: "AllowedSuggestedExitNodes", + wantScope: DeviceSetting, + wantType: StringListValue, + wantIsSupported: true, + wantString: `Device("AllowedSuggestedExitNodes", StringList)`, + }, + { + name: "Device/PreferenceOption", + setting: NewDefinition("AdvertiseExitNode", DeviceSetting, PreferenceOptionValue), + wantKey: "AdvertiseExitNode", + wantScope: DeviceSetting, + wantType: PreferenceOptionValue, + wantIsSupported: true, + wantString: `Device("AdvertiseExitNode", PreferenceOption)`, + }, + { + name: "User/Boolean", + setting: NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue), + wantKey: "TestUserPolicySetting", + wantScope: UserSetting, + wantType: BooleanValue, + wantIsSupported: true, + wantString: `User("TestUserPolicySetting", Boolean)`, + }, + { + name: "User/Visibility", + setting: NewDefinition("AdminConsole", UserSetting, VisibilityValue), + wantKey: "AdminConsole", + wantScope: UserSetting, + wantType: VisibilityValue, + wantIsSupported: true, + wantString: `User("AdminConsole", Visibility)`, + }, + { + name: "User/Duration", + setting: NewDefinition("KeyExpirationNotice", UserSetting, DurationValue), + wantKey: "KeyExpirationNotice", + wantScope: UserSetting, + wantType: DurationValue, + wantIsSupported: true, + wantString: `User("KeyExpirationNotice", Duration)`, + }, + { + name: "SupportedSetting", + setting: NewDefinition("DesktopPolicySetting", DeviceSetting, StringValue, "macos", "windows"), + osOverride: "windows", + wantKey: "DesktopPolicySetting", + wantScope: DeviceSetting, + wantType: StringValue, + wantIsSupported: true, + wantSupportedPlatforms: PlatformList{"macos", "windows"}, + wantString: `Device("DesktopPolicySetting", String)`, + }, + { + name: "UnsupportedSetting", + setting: NewDefinition("AndroidPolicySetting", DeviceSetting, StringValue, "android"), + osOverride: "macos", + wantKey: "AndroidPolicySetting", + wantScope: DeviceSetting, + wantType: StringValue, + wantIsSupported: false, + wantSupportedPlatforms: PlatformList{"android"}, + wantString: `Device("AndroidPolicySetting", String)`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.osOverride != "" { + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + } + if !tt.setting.Equal(tt.setting) { + t.Errorf("the setting should be equal to itself") + } + if tt.setting != nil && !tt.setting.Equal(ptr.To(*tt.setting)) { + t.Errorf("the setting should be equal to its shallow copy") + } + if gotKey := tt.setting.Key(); gotKey != tt.wantKey { + t.Errorf("Key: got %q, want %q", gotKey, tt.wantKey) + } + if gotScope := tt.setting.Scope(); gotScope != tt.wantScope { + t.Errorf("Scope: got %v, want %v", gotScope, tt.wantScope) + } + if gotType := tt.setting.Type(); gotType != tt.wantType { + t.Errorf("Type: got %v, want %v", gotType, tt.wantType) + } + if gotIsSupported := tt.setting.IsSupported(); gotIsSupported != tt.wantIsSupported { + t.Errorf("IsSupported: got %v, want %v", gotIsSupported, tt.wantIsSupported) + } + if gotSupportedPlatforms := tt.setting.SupportedPlatforms(); !slices.Equal(gotSupportedPlatforms, tt.wantSupportedPlatforms) { + t.Errorf("SupportedPlatforms: got %v, want %v", gotSupportedPlatforms, tt.wantSupportedPlatforms) + } + if gotString := tt.setting.String(); gotString != tt.wantString { + t.Errorf("String: got %v, want %v", gotString, tt.wantString) + } + }) + } +} + +func TestRegisterSettingDefinition(t *testing.T) { + const testPolicySettingKey Key = "TestPolicySetting" + tests := []struct { + name string + key Key + wantEq *Definition + wantErr error + }{ + { + name: "GetRegistered", + key: "TestPolicySetting", + wantEq: NewDefinition(testPolicySettingKey, DeviceSetting, StringValue), + }, + { + name: "GetNonRegistered", + key: "OtherPolicySetting", + wantEq: nil, + wantErr: ErrNoSuchKey, + }, + } + + resetSettingDefinitions(t) + Register(testPolicySettingKey, DeviceSetting, StringValue) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotErr := DefinitionOf(tt.key) + if gotErr != tt.wantErr { + t.Errorf("gotErr %v, wantErr %v", gotErr, tt.wantErr) + } + if !got.Equal(tt.wantEq) { + t.Errorf("got %v, want %v", got, tt.wantEq) + } + }) + } +} + +func TestRegisterAfterUsePanics(t *testing.T) { + resetSettingDefinitions(t) + + Register("TestPolicySetting", DeviceSetting, StringValue) + DefinitionOf("TestPolicySetting") + + func() { + defer func() { + if gotPanic, wantPanic := recover(), "policy definitions are already in use"; gotPanic != wantPanic { + t.Errorf("gotPanic: %q, wantPanic: %q", gotPanic, wantPanic) + } + }() + + Register("TestPolicySetting", DeviceSetting, StringValue) + }() +} + +func TestRegisterDuplicateSettings(t *testing.T) { + + tests := []struct { + name string + settings []*Definition + wantEq *Definition + wantErrStr string + }{ + { + name: "NoConflict/Exact", + settings: []*Definition{ + NewDefinition("TestPolicySetting", DeviceSetting, StringValue), + NewDefinition("TestPolicySetting", DeviceSetting, StringValue), + }, + wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), + }, + { + name: "NoConflict/MergeOS-First", + settings: []*Definition{ + NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"), + NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms + }, + wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms + }, + { + name: "NoConflict/MergeOS-Second", + settings: []*Definition{ + NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms + NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"), + }, + wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms + }, + { + name: "NoConflict/MergeOS-Both", + settings: []*Definition{ + NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos"), + NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "windows"), + }, + wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos", "windows"), + }, + { + name: "Conflict/Scope", + settings: []*Definition{ + NewDefinition("TestPolicySetting", DeviceSetting, StringValue), + NewDefinition("TestPolicySetting", UserSetting, StringValue), + }, + wantEq: nil, + wantErrStr: `duplicate policy definition: "TestPolicySetting"`, + }, + { + name: "Conflict/Type", + settings: []*Definition{ + NewDefinition("TestPolicySetting", UserSetting, StringValue), + NewDefinition("TestPolicySetting", UserSetting, IntegerValue), + }, + wantEq: nil, + wantErrStr: `duplicate policy definition: "TestPolicySetting"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetSettingDefinitions(t) + for _, s := range tt.settings { + Register(s.Key(), s.Scope(), s.Type(), s.SupportedPlatforms()...) + } + got, err := DefinitionOf("TestPolicySetting") + var gotErrStr string + if err != nil { + gotErrStr = err.Error() + } + if gotErrStr != tt.wantErrStr { + t.Fatalf("ErrStr: got %q, want %q", gotErrStr, tt.wantErrStr) + } + if !got.Equal(tt.wantEq) { + t.Errorf("Definition got %v, want %v", got, tt.wantEq) + } + if !slices.Equal(got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms()) { + t.Errorf("SupportedPlatforms got %v, want %v", got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms()) + } + }) + } +} + +func TestListSettingDefinitions(t *testing.T) { + definitions := []*Definition{ + NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue), + NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue), + NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue), + NewDefinition("TestStringListPolicySetting", DeviceSetting, StringListValue), + } + if err := SetDefinitionsForTest(t, definitions...); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + cmp := func(l, r *Definition) int { + return strings.Compare(string(l.Key()), string(r.Key())) + } + want := append([]*Definition{}, definitions...) + slices.SortFunc(want, cmp) + + got, err := Definitions() + if err != nil { + t.Fatalf("Definitions failed: %v", err) + } + slices.SortFunc(got, cmp) + + if !slices.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func resetSettingDefinitions(t *testing.T) { + t.Cleanup(func() { + definitionsMu.Lock() + definitionsList = nil + definitions = lazy.SyncValue[DefinitionMap]{} + definitionsUsed = false + definitionsMu.Unlock() + }) + + definitionsMu.Lock() + definitionsList = nil + definitions = lazy.SyncValue[DefinitionMap]{} + definitionsUsed = false + definitionsMu.Unlock() +} diff --git a/util/syspolicy/setting/snapshot.go b/util/syspolicy/setting/snapshot.go new file mode 100644 index 000000000..4f4934a72 --- /dev/null +++ b/util/syspolicy/setting/snapshot.go @@ -0,0 +1,153 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + xmaps "golang.org/x/exp/maps" + "tailscale.com/util/deephash" +) + +// Snapshot is an immutable collection of [RawItem]s, representing +// a set of policy settings applied at a specific moment in time. +// A nil pointer to [Snapshot] is valid. +type Snapshot struct { + m map[Key]RawItem + sig deephash.Sum // of m + summary Summary +} + +// NewSnapshot returns a new [Snapshot] with the specified items and options. +func NewSnapshot(items map[Key]RawItem, opts ...SummaryOption) *Snapshot { + return &Snapshot{m: items, sig: deephash.Hash(&items), summary: SummaryWith(opts...)} +} + +type keyItemPair struct { + Key Key + Item RawItem +} + +// All returns an iterator over [[Key], [RawItem]] key-value pairs in b. The +// iteration order is not specified and is not guaranteed to be the same from +// one call to the next. +func (s *Snapshot) All() []keyItemPair { + if s == nil { + return nil + } + // TODO(nickkhyl): return iter.Seq2[[Key], [RawItem]] in Go 1.23, + // and remove [keyItemPair]. + items := make([]keyItemPair, 0, len(s.m)) + for k, i := range s.m { + items = append(items, keyItemPair{k, i}) + } + return items +} + +// Get returns the value of the policy setting with the specified key +// or nil if it does not exist or could not be read. +func (s *Snapshot) Get(k Key) any { + v, _ := s.GetErr(k) + return v +} + +// GetErr returns the value of the policy setting with the specified key, +// [ErrNotConfigured] if it does not exist, or an error returned by +// the policy Store if the policy setting could not be read. +func (s *Snapshot) GetErr(k Key) (any, error) { + if s != nil { + if s, ok := s.m[k]; ok { + return s.Value(), s.Error() + } + } + return nil, ErrNotConfigured +} + +// GetSetting returns the untyped policy setting with the specified key and true +// if a policy setting with such key has been configured; +// otherwise, it returns zero, false. +func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) { + setting, ok = s.m[k] + return setting, ok +} + +// Equal reports whether s and s2 are equal. +func (s *Snapshot) Equal(s2 *Snapshot) bool { + if !s.EqualItems(s2) { + return false + } + return s.Summary() == s2.Summary() +} + +// EqualItems reports whether items in s and s2 are equal. +func (s *Snapshot) EqualItems(s2 *Snapshot) bool { + if s == s2 { + return true + } + if s.Len() != s2.Len() { + return false + } + if s.Len() == 0 { + return true + } + return s.sig == s2.sig +} + +// Keys return an iterator over keys in s. The iteration order is not specified +// and is not guaranteed to be the same from one call to the next. +func (s *Snapshot) Keys() []Key { + if s.m == nil { + return nil + } + // TODO(nickkhyl): return iter.Seq[Key] in Go 1.23. + return xmaps.Keys(s.m) +} + +// Len reports the number of [RawItem]s in s. +func (s *Snapshot) Len() int { + if s == nil { + return 0 + } + return len(s.m) +} + +// Summary returns information about s as a whole rather than about specific [RawItem]s in it. +func (s *Snapshot) Summary() Summary { + if s == nil { + return Summary{} + } + return s.summary +} + +// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s +// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope]. +// If there's a conflict between policy settings in the two snapshots, +// the policy settings from the snapshot with the broader scope take precedence. +// In other words, policy settings configured for the [DeviceScope] win +// over policy settings configured for a user scope. +func MergeSnapshots(snapshot1, snapshot2 *Snapshot) *Snapshot { + scope1, ok1 := snapshot1.Summary().Scope().GetOk() + scope2, ok2 := snapshot2.Summary().Scope().GetOk() + if ok1 && ok2 && scope2.IsStrictlyWithinOf(scope1) { + // Swap snapshots if snapshot1 has higher precedence than snapshot2. + snapshot1, snapshot2 = snapshot2, snapshot1 + } + if snapshot2.Len() == 0 { + return snapshot1 + } + summaryOpts := make([]SummaryOption, 0, 2) + if scope, ok := snapshot1.Summary().Scope().GetOk(); ok { + // Use the scope from snapshot1, if present, which is the more specific snapshot. + summaryOpts = append(summaryOpts, scope) + } + if snapshot1.Len() == 0 { + if origin, ok := snapshot2.Summary().Origin().GetOk(); ok { + // Use the origin from snapshot2 if snapshot1 is empty. + summaryOpts = append(summaryOpts, origin) + } + return &Snapshot{snapshot2.m, snapshot2.sig, SummaryWith(summaryOpts...)} + } + m := make(map[Key]RawItem, snapshot1.Len()+snapshot2.Len()) + xmaps.Copy(m, snapshot1.m) + xmaps.Copy(m, snapshot2.m) // snapshot2 has higher precedence + return &Snapshot{m, deephash.Hash(&m), SummaryWith(summaryOpts...)} +} diff --git a/util/syspolicy/setting/snapshot_test.go b/util/syspolicy/setting/snapshot_test.go new file mode 100644 index 000000000..378fa6033 --- /dev/null +++ b/util/syspolicy/setting/snapshot_test.go @@ -0,0 +1,372 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "testing" + "time" +) + +func TestMergeSnapshots(t *testing.T) { + tests := []struct { + name string + s1, s2 *Snapshot + want *Snapshot + }{ + { + name: "both-nil", + s1: nil, + s2: nil, + want: NewSnapshot(map[Key]RawItem{}), + }, + { + name: "both-empty", + s1: NewSnapshot(map[Key]RawItem{}), + s2: NewSnapshot(map[Key]RawItem{}), + want: NewSnapshot(map[Key]RawItem{}), + }, + { + name: "first-nil", + s1: nil, + s2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }), + }, + { + name: "first-empty", + s1: NewSnapshot(map[Key]RawItem{}), + s2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + }, + { + name: "second-nil", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }), + s2: nil, + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }), + }, + { + name: "second-empty", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + s2: NewSnapshot(map[Key]RawItem{}), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + }, + { + name: "no-conflicts", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + s2: NewSnapshot(map[Key]RawItem{ + "Setting4": {value: 2 * time.Hour}, + "Setting5": {value: VisibleByPolicy}, + "Setting6": {value: ShowChoiceByPolicy}, + }), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + "Setting4": {value: 2 * time.Hour}, + "Setting5": {value: VisibleByPolicy}, + "Setting6": {value: ShowChoiceByPolicy}, + }), + }, + { + name: "with-conflicts", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }), + s2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 456}, + "Setting3": {value: false}, + "Setting4": {value: 2 * time.Hour}, + }), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 456}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + "Setting4": {value: 2 * time.Hour}, + }), + }, + { + name: "with-scope-first-wins", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }, DeviceScope), + s2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 456}, + "Setting3": {value: false}, + "Setting4": {value: 2 * time.Hour}, + }, CurrentUserScope), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + "Setting4": {value: 2 * time.Hour}, + }, CurrentUserScope), + }, + { + name: "with-scope-second-wins", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }, CurrentUserScope), + s2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 456}, + "Setting3": {value: false}, + "Setting4": {value: 2 * time.Hour}, + }, DeviceScope), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 456}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + "Setting4": {value: 2 * time.Hour}, + }, CurrentUserScope), + }, + { + name: "with-scope-both-empty", + s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), + s2: NewSnapshot(map[Key]RawItem{}, DeviceScope), + want: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), + }, + { + name: "with-scope-first-empty", + s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), + s2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}}, DeviceScope), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }, CurrentUserScope), + }, + { + name: "with-scope-second-empty", + s1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }, CurrentUserScope), + s2: NewSnapshot(map[Key]RawItem{}, DeviceScope), + want: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }, CurrentUserScope), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MergeSnapshots(tt.s1, tt.s2) + if !got.Equal(tt.want) { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestSnapshotEqual(t *testing.T) { + tests := []struct { + name string + b1, b2 *Snapshot + wantEqual bool + wantEqualItems bool + }{ + { + name: "nil-nil", + b1: nil, + b2: nil, + wantEqual: true, + wantEqualItems: true, + }, + { + name: "nil-empty", + b1: nil, + b2: NewSnapshot(map[Key]RawItem{}), + wantEqual: true, + wantEqualItems: true, + }, + { + name: "empty-nil", + b1: NewSnapshot(map[Key]RawItem{}), + b2: nil, + wantEqual: true, + wantEqualItems: true, + }, + { + name: "empty-empty", + b1: NewSnapshot(map[Key]RawItem{}), + b2: NewSnapshot(map[Key]RawItem{}), + wantEqual: true, + wantEqualItems: true, + }, + { + name: "first-nil", + b1: nil, + b2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + wantEqual: false, + wantEqualItems: false, + }, + { + name: "first-empty", + b1: NewSnapshot(map[Key]RawItem{}), + b2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + wantEqual: false, + wantEqualItems: false, + }, + { + name: "second-nil", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: true}, + }), + b2: nil, + wantEqual: false, + wantEqualItems: false, + }, + { + name: "second-empty", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + b2: NewSnapshot(map[Key]RawItem{}), + wantEqual: false, + wantEqualItems: false, + }, + { + name: "same-items-same-order-no-scope", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + b2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }), + wantEqual: true, + wantEqualItems: true, + }, + { + name: "same-items-same-order-same-scope", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }, DeviceScope), + b2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }, DeviceScope), + wantEqual: true, + wantEqualItems: true, + }, + { + name: "same-items-different-order-same-scope", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }, DeviceScope), + b2: NewSnapshot(map[Key]RawItem{ + "Setting3": {value: false}, + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + }, DeviceScope), + wantEqual: true, + wantEqualItems: true, + }, + { + name: "same-items-same-order-different-scope", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }, DeviceScope), + b2: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }, CurrentUserScope), + wantEqual: false, + wantEqualItems: true, + }, + { + name: "different-items-same-scope", + b1: NewSnapshot(map[Key]RawItem{ + "Setting1": {value: 123}, + "Setting2": {value: "String"}, + "Setting3": {value: false}, + }, DeviceScope), + b2: NewSnapshot(map[Key]RawItem{ + "Setting4": {value: 2 * time.Hour}, + "Setting5": {value: VisibleByPolicy}, + "Setting6": {value: ShowChoiceByPolicy}, + }, DeviceScope), + wantEqual: false, + wantEqualItems: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotEqual := tt.b1.Equal(tt.b2); gotEqual != tt.wantEqual { + t.Errorf("WantEqual: got %v, want %v", gotEqual, tt.wantEqual) + } + if gotEqualItems := tt.b1.EqualItems(tt.b2); gotEqualItems != tt.wantEqualItems { + t.Errorf("WantEqualItems: got %v, want %v", gotEqualItems, tt.wantEqualItems) + } + }) + } +} diff --git a/util/syspolicy/setting/summary.go b/util/syspolicy/setting/summary.go new file mode 100644 index 000000000..5855b22e3 --- /dev/null +++ b/util/syspolicy/setting/summary.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/types/opt" +) + +// Summary is an immutable [PolicyScope] and [Origin]. +type Summary struct { + data summary +} + +type summary struct { + Scope opt.Value[PolicyScope] `json:",omitzero"` + Origin opt.Value[Origin] `json:",omitzero"` +} + +// SummaryWith returns a [Summary] with the specified options. +func SummaryWith(opts ...SummaryOption) Summary { + var summary Summary + for _, o := range opts { + o.applySummaryOption(&summary) + } + return summary +} + +// Scope reports the [PolicyScope] in s. +func (s Summary) Scope() opt.Value[PolicyScope] { + return s.data.Scope +} + +// Origin reports the [Origin] in s. +func (s Summary) Origin() opt.Value[Origin] { + return s.data.Origin +} + +// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { + return jsonv2.MarshalEncode(out, &s.data, opts) +} + +// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. +func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { + return jsonv2.UnmarshalDecode(in, &s.data, opts) +} + +// MarshalJSON implements [json.Marshaler]. +func (s Summary) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(s) // uses MarshalJSONV2 +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (s *Summary) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 +} + +// SummaryOption is an option that configures [Summary] +// The following are allowed options: +// +// - [Summary] +// - [PolicyScope] +// - [Origin] +type SummaryOption interface { + applySummaryOption(summary *Summary) +} + +func (s PolicyScope) applySummaryOption(summary *Summary) { + summary.data.Scope.Set(s) +} + +func (o Origin) applySummaryOption(summary *Summary) { + summary.data.Origin.Set(o) + if !summary.data.Scope.IsSet() { + summary.data.Scope.Set(o.Scope()) + } +} + +func (s Summary) applySummaryOption(summary *Summary) { + *summary = s +} diff --git a/util/syspolicy/setting/types.go b/util/syspolicy/setting/types.go new file mode 100644 index 000000000..16f9e7445 --- /dev/null +++ b/util/syspolicy/setting/types.go @@ -0,0 +1,132 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "encoding" +) + +// PreferenceOption is a policy that governs whether a boolean variable +// is forcibly assigned an administrator-defined value, or allowed to receive +// a user-defined value. +type PreferenceOption int + +const ( + ShowChoiceByPolicy PreferenceOption = iota + NeverByPolicy + AlwaysByPolicy +) + +// Show returns if the UI option that controls the choice administered by this +// policy should be shown. Currently this is true if and only if the policy is +// [ShowChoiceByPolicy]. +func (p PreferenceOption) Show() bool { + return p == ShowChoiceByPolicy +} + +// ShouldEnable checks if the choice administered by this policy should be +// enabled. If the administrator has chosen a setting, the administrator's +// setting is returned, otherwise userChoice is returned. +func (p PreferenceOption) ShouldEnable(userChoice bool) bool { + switch p { + case NeverByPolicy: + return false + case AlwaysByPolicy: + return true + default: + return userChoice + } +} + +// IsAlways reports whether the preference should always be enabled. +func (p PreferenceOption) IsAlways() bool { + return p == AlwaysByPolicy +} + +// IsNever reports whether the preference should always be disabled. +func (p PreferenceOption) IsNever() bool { + return p == NeverByPolicy +} + +// WillOverride checks if the choice administered by the policy is different +// from the user's choice. +func (p PreferenceOption) WillOverride(userChoice bool) bool { + return p.ShouldEnable(userChoice) != userChoice +} + +// String returns a string representation of p. +func (p PreferenceOption) String() string { + switch p { + case AlwaysByPolicy: + return "always" + case NeverByPolicy: + return "never" + default: + return "user-decides" + } +} + +// MarshalText implements [encoding.TextMarshaler]. +func (p *PreferenceOption) MarshalText() (text []byte, err error) { + return []byte(p.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler]. +func (p *PreferenceOption) UnmarshalText(text []byte) error { + switch string(text) { + case "always": + *p = AlwaysByPolicy + case "never": + *p = NeverByPolicy + default: + *p = ShowChoiceByPolicy + } + return nil +} + +// Visibility is a policy that controls whether or not a particular +// component of a user interface is to be shown. +type Visibility byte + +var ( + _ encoding.TextMarshaler = (*Visibility)(nil) + _ encoding.TextUnmarshaler = (*Visibility)(nil) +) + +const ( + VisibleByPolicy Visibility = 'v' + HiddenByPolicy Visibility = 'h' +) + +// Show reports whether the UI option administered by this policy should be shown. +// Currently this is true if the policy is not [hiddenByPolicy]. +func (p Visibility) Show() bool { + return p != HiddenByPolicy +} + +// String returns a string representation of p. +func (p Visibility) String() string { + switch p { + case 'h': + return "hide" + default: + return "show" + } +} + +// MarshalText implements [encoding.TextMarshaler]. +func (p Visibility) MarshalText() (text []byte, err error) { + return []byte(p.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler]. +func (p *Visibility) UnmarshalText(text []byte) error { + switch string(text) { + case "hide": + *p = HiddenByPolicy + default: + *p = VisibleByPolicy + } + return nil +} diff --git a/util/syspolicy/source/policy_reader.go b/util/syspolicy/source/policy_reader.go new file mode 100644 index 000000000..e608fd0da --- /dev/null +++ b/util/syspolicy/source/policy_reader.go @@ -0,0 +1,393 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "io" + "slices" + "sort" + "sync" + "time" + + "tailscale.com/util/mak" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/internal/metrics" + "tailscale.com/util/syspolicy/setting" +) + +// Reader reads all configured policy settings from a given [Store]. +// It registers a change callback with the [Store] and maintains the current version +// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store] +// whenever a new snapshot is requested +// It is safe for concurrent use. +type Reader struct { + store Store + origin *setting.Origin + settings []*setting.Definition + unregisterChangeNotifier func() + doneCh chan struct{} // closed when policyCache is closed. + + mu sync.RWMutex + closing bool + upToDate bool + lastPolicy *setting.Snapshot + sessions set.HandleSet[*ReadingSession] +} + +// newReader returns a new [Reader] that reads policy settings from a given [Store]. +// The returned reader takes ownership of the store. If the store implements [io.Closer], +// the returned reader will close the store when it is closed. +func newReader(store Store, origin *setting.Origin) (*Reader, error) { + settings, err := setting.Definitions() + if err != nil { + return nil, err + } + + if expirable, ok := store.(Expirable); ok { + select { + case <-expirable.Done(): + return nil, ErrStoreClosed + default: + } + } + + reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})} + if changeable, ok := store.(Changeable); ok { + // We should subscribe to policy change notifications first before reading + // the policy settings from the store. This way we won't miss any notifications. + if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil { + // Errors registering policy change callbacks are non-fatal. + // TODO(nickkhyl): implement a background policy refresh every X minutes? + loggerx.Errorf("failed to register %v policy change callback: %v\n", origin, err) + } + } + + if _, err := reader.reload(true); err != nil { + if reader.unregisterChangeNotifier != nil { + reader.unregisterChangeNotifier() + } + return nil, err + } + + if expirable, ok := store.(Expirable); ok { + if waitCh := expirable.Done(); waitCh != nil { + go func() { + select { + case <-waitCh: + reader.Close() + case <-reader.doneCh: + } + }() + } + } + + return reader, nil +} + +// GetSettings returns the current [*setting.Snapshot], +// re-reading it from from the underlying [Store] only if the policy +// has changed since it was read last. It never fails and returns +// the previous version of the policy settings if a read attempt fails. +func (r *Reader) GetSettings() *setting.Snapshot { + r.mu.RLock() + if r.upToDate { + r.mu.RUnlock() + return r.lastPolicy + } + r.mu.RUnlock() + + policy, err := r.reload(false) + if err != nil { + // If the policy could not be reloaded at all, we'll return the last cached version of it. + // On the contrary, errors specific to individual policy items are always propagated to the callers. + loggerx.Errorf("failed to reload %v policy: %v\n", r.origin, err) + } + return policy +} + +// ReadSettings reads policy settings from the underlying [Store] even if no +// changes were detected. It returns the new [*setting.Snapshot], nil on +// success, or nil, error in case of failure. +func (r *Reader) ReadSettings() (*setting.Snapshot, error) { + b, err := r.reload(true) + if err != nil { + return nil, err + } + return b, nil +} + +// reload is like [Reader.ReadSettings], but allows specifying whether to re-read +// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails. +func (r *Reader) reload(force bool) (*setting.Snapshot, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.upToDate && !force { + return r.lastPolicy, nil + } + + if lockable, ok := r.store.(Lockable); ok { + if err := lockable.Lock(); err != nil { + return r.lastPolicy, err + } + defer lockable.Unlock() + } + + r.upToDate = true + + metrics.Reset(r.origin) + + var m map[setting.Key]setting.RawItem + if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 { + m = make(map[setting.Key]setting.RawItem, lastPolicyCount) + } + for _, s := range r.settings { + if !r.origin.Scope().IsConfigurableSetting(s) { + // Skip settings that cannot be configured in the current scope. + continue + } + + val, err := readPolicySettingValue(r.store, s) + if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) { + metrics.ReportNotConfigured(r.origin, s) + continue + } + + if err == nil { + metrics.ReportConfigured(r.origin, s, val) + } else { + metrics.ReportError(r.origin, s, err) + } + + // If there's an error reading a single policy, such as a value type mismatch, + // we'll wrap the error to preserve its text and return it + // whenever someone attempts to fetch the value. + mak.Set(&m, s.Key(), setting.RawItemWith(val, setting.WrapError(err), r.origin)) + } + + newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin)) + if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) { + r.lastPolicy = newPolicy + } + return r.lastPolicy, nil +} + +// ReadingSession is like [Reader], but with a channel that's written +// to when there's a policy change, and closed when the session is terminated. +type ReadingSession struct { + reader *Reader + policyChangedCh chan struct{} // 1-buffered channel + handle set.Handle // in the reader.sessions + closeInternal func() +} + +// OpenSession opens and returns a new session to r, allowing the caller +// to get notified whenever a policy change is reported by the [source.Store], +// or an [ErrStoreClosed] if the reader has already been closed. +func (r *Reader) OpenSession() (*ReadingSession, error) { + session := &ReadingSession{ + reader: r, + policyChangedCh: make(chan struct{}, 1), + } + session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) }) + r.mu.Lock() + if !r.closing { + session.handle = r.sessions.Add(session) + r.mu.Unlock() + return session, nil + } + r.mu.Unlock() + return nil, ErrStoreClosed +} + +// GetSettings is like [Reader.GetSettings]. +func (s *ReadingSession) GetSettings() *setting.Snapshot { + return s.reader.GetSettings() +} + +// ReadSettings is like [Reader.ReadSettings]. +func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) { + return s.reader.ReadSettings() +} + +// PolicyChanged returns a channel that's written to when +// there's a policy change, closed when the session is terminated. +func (s *ReadingSession) PolicyChanged() <-chan struct{} { + return s.policyChangedCh +} + +// Close unregisters this session with the [Reader]. +func (s *ReadingSession) Close() { + s.reader.mu.Lock() + delete(s.reader.sessions, s.handle) + s.closeInternal() + s.reader.mu.Unlock() +} + +// onPolicyChange handles a policy change notification from the [Store], +// invalidating the current [setting.Snapshot] in r, +// and notifying the active [ReadingSession]s. +func (r *Reader) onPolicyChange() { + r.mu.Lock() + defer r.mu.Unlock() + r.upToDate = false + for _, s := range r.sessions { + select { + case s.policyChangedCh <- struct{}{}: + // Notified. + default: + // 1-buffered channel is full, meaning that another policy change + // notification is already en route. + } + } +} + +// Close closes the store reader and the underlying store. +func (r *Reader) Close() error { + r.mu.Lock() + if r.closing { + r.mu.Unlock() + return nil + } + r.closing = true + r.mu.Unlock() + + if r.unregisterChangeNotifier != nil { + r.unregisterChangeNotifier() + r.unregisterChangeNotifier = nil + } + + if closer, ok := r.store.(io.Closer); ok { + if err := closer.Close(); err != nil { + return err + } + } + r.store = nil + + close(r.doneCh) + + r.mu.Lock() + defer r.mu.Unlock() + for _, c := range r.sessions { + c.closeInternal() + } + r.sessions = nil + return nil +} + +// Done returns a channel that is closed when the reader is closed. +func (r *Reader) Done() <-chan struct{} { + return r.doneCh +} + +// ReadableSource is a [Source] open for reading. +type ReadableSource struct { + *Source + *ReadingSession +} + +// Close closes the underlying [ReadingSession]. +func (s ReadableSource) Close() { + s.ReadingSession.Close() +} + +// ReadableSources is a slice of [ReadableSource]. +type ReadableSources []ReadableSource + +// Contains reports whether s contains the specified source. +func (s ReadableSources) Contains(source *Source) bool { + return s.IndexOf(source) != -1 +} + +// IndexOf returns position of the specified source in s, or -1 +// if the source does not exist. +func (s ReadableSources) IndexOf(source *Source) int { + return slices.IndexFunc(s, func(rs ReadableSource) bool { + return rs.Source == source + }) +} + +// InsertionIndexOf returns the position at which source can be inserted +// to maintain the sorted order of the readableSources. +// The return value is unspecified if s is not sorted on entry to InsertionIndexOf. +func (s ReadableSources) InsertionIndexOf(source *Source) int { + low, high := 0, len(s) + for low < high { + mid := (low + high) / 2 + if s[mid].Compare(source) <= 0 { + low = mid + 1 + } else { + high = mid + } + } + return low +} + +// StableSort sorts the readableSources by the precedence, so that policy settings +// from sources with higher precedence (e.g., [DeviceScope]) will be merged last, +// overriding any policy settings with the same keys configured in sources with +// lower precedence (e.g., [CurrentUserScope]). +func (s *ReadableSources) StableSort() { + sort.SliceStable(*s, func(i, j int) bool { + return (*s)[i].Source.Compare((*s)[j].Source) < 0 + }) +} + +// DeleteAt closes and deletes the i-th source from s. +func (s *ReadableSources) DeleteAt(i int) { + (*s)[i].Close() + *s = slices.Delete(*s, i, i+1) +} + +// Close closes and deletes all sources in s. +func (s *ReadableSources) Close() { + for _, s := range *s { + s.Close() + } + *s = nil +} + +func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) { + switch key := s.Key(); s.Type() { + case setting.BooleanValue: + return store.ReadBoolean(key) + case setting.IntegerValue: + return store.ReadUInt64(key) + case setting.StringValue: + return store.ReadString(key) + case setting.StringListValue: + return store.ReadStringArray(key) + case setting.PreferenceOptionValue: + s, err := store.ReadString(key) + if err == nil { + var value setting.PreferenceOption + if err = value.UnmarshalText([]byte(s)); err == nil { + return value, nil + } + } + return setting.ShowChoiceByPolicy, err + case setting.VisibilityValue: + s, err := store.ReadString(key) + if err == nil { + var value setting.Visibility + if err = value.UnmarshalText([]byte(s)); err == nil { + return value, nil + } + } + return setting.VisibleByPolicy, err + case setting.DurationValue: + s, err := store.ReadString(key) + if err == nil { + var value time.Duration + if value, err = time.ParseDuration(s); err == nil { + return value, nil + } + } + return nil, err + default: + return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type()) + } +} diff --git a/util/syspolicy/source/policy_reader_test.go b/util/syspolicy/source/policy_reader_test.go new file mode 100644 index 000000000..f2d411d12 --- /dev/null +++ b/util/syspolicy/source/policy_reader_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "testing" + "time" + + "tailscale.com/util/must" + "tailscale.com/util/syspolicy/setting" +) + +func TestReaderLifecycle(t *testing.T) { + tests := []struct { + name string + origin *setting.Origin + definitions []*setting.Definition + wantReads []TestExpectedReads + initStrings []TestSetting[string] + initUInt64s []TestSetting[uint64] + initWant *setting.Snapshot + addStrings []TestSetting[string] + addStringLists []TestSetting[[]string] + newWant *setting.Snapshot + }{ + { + name: "read-all-settings-once", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue), + setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "StringValue", Type: setting.StringValue, NumTimes: 1}, + {Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1}, + {Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1}, + {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, + { + name: "re-read-all-settings-when-the-policy-changes", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue), + setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "StringValue", Type: setting.StringValue, NumTimes: 1}, + {Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1}, + {Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1}, + {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")}, + addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})}, + newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, + { + name: "read-settings-if-in-scope/device", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue), + setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue), + }, + wantReads: []TestExpectedReads{ + {Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1}, + {Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1}, + }, + }, + { + name: "read-settings-if-in-scope/profile", + origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue), + setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue), + }, + wantReads: []TestExpectedReads{ + // Device settings cannot be configured at the profile scope and should not be read. + {Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1}, + }, + }, + { + name: "read-settings-if-in-scope/user", + origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue), + setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue), + }, + wantReads: []TestExpectedReads{ + // Device and profile settings cannot be configured at the profile scope and should not be read. + {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1}, + }, + }, + { + name: "read-stringy-settings", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initStrings: []TestSetting[string]{ + TestSettingOf("DurationValue", "2h30m"), + TestSettingOf("PreferenceOptionValue", "always"), + TestSettingOf("VisibilityValue", "show"), + }, + initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, + { + name: "read-erroneous-stringy-settings", + origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue), + setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initStrings: []TestSetting[string]{ + TestSettingOf("DurationValue1", "soon"), + TestSettingWithError[string]("DurationValue2", setting.NewError("bang!")), + TestSettingOf("PreferenceOptionValue", "sometimes"), + }, + initUInt64s: []TestSetting[uint64]{ + TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch + }, + initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "DurationValue1": setting.RawItemWith(nil, setting.NewError("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "DurationValue2": setting.RawItemWith(nil, setting.NewError("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewError("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + }, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setting.SetDefinitionsForTest(t, tt.definitions...) + store := NewTestStore(t) + store.SetStrings(tt.initStrings...) + store.SetUInt64s(tt.initUInt64s...) + + reader, err := newReader(store, tt.origin) + if err != nil { + t.Fatalf("newReader failed: %v", err) + } + + if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) { + t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant) + } + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + + // Should not result in new reads as there were no changes. + N := 100 + for range N { + reader.GetSettings() + } + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + store.ResetCounters() + + got, err := reader.ReadSettings() + if err != nil { + t.Fatalf("ReadSettings failed: %v", err) + } + + if tt.initWant != nil && !got.Equal(tt.initWant) { + t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant) + } + + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + store.ResetCounters() + + if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 { + store.SetStrings(tt.addStrings...) + store.SetStringLists(tt.addStringLists...) + + // As the settings have changed, GetSettings needs to re-read them. + if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) { + t.Errorf("New Settings do not match: got %v, want %v", got, want) + } + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + } + + select { + case <-reader.Done(): + t.Fatalf("the reader is closed") + default: + } + + store.Close() + + <-reader.Done() + }) + } +} + +func TestReadingSession(t *testing.T) { + setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue)) + store := NewTestStore(t) + + origin := setting.NewOrigin(setting.DeviceScope) + reader, err := newReader(store, origin) + if err != nil { + t.Fatalf("newReader failed: %v", err) + } + session, err := reader.OpenSession() + if err != nil { + t.Fatalf("failed to open a reading session: %v", err) + } + t.Cleanup(session.Close) + + if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) { + t.Errorf("Settings do not match: got %v, want %v", got, want) + } + + select { + case _, ok := <-session.PolicyChanged(): + if ok { + t.Fatalf("the policy changed notification was sent prematurely") + } else { + t.Fatalf("the session was closed prematurely") + } + default: + } + + store.SetStrings(TestSettingOf("StringValue", "S1")) + _, ok := <-session.PolicyChanged() + if !ok { + t.Fatalf("the session was closed prematurely") + } + + want := setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "StringValue": setting.RawItemWith("S1", nil, origin), + }, origin) + if got := session.GetSettings(); !got.Equal(want) { + t.Errorf("Settings do not match: got %v, want %v", got, want) + } + + store.Close() + if _, ok = <-session.PolicyChanged(); ok { + t.Fatalf("the session must be closed") + } +} diff --git a/util/syspolicy/source/policy_store.go b/util/syspolicy/source/policy_store.go new file mode 100644 index 000000000..9b150825e --- /dev/null +++ b/util/syspolicy/source/policy_store.go @@ -0,0 +1,146 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "errors" + "fmt" + "io" + + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/setting" +) + +// ErrStoreClosed is an error returned when attempting to use a [Store] after it +// has been closed. +var ErrStoreClosed = errors.New("the policy store has been closed") + +// Store provides methods to read system policy settings from OS-specific storage. +// Implementations must be concurrency-safe, and may also implement +// [Lockable], [Changeable], [Expirable] and [io.Closer]. +// +// If a [Store] implementation also implements [io.Closer], +// it will be called by the package to release the resources +// when the store is no longer needed. +type Store interface { + // ReadString returns the value of a [setting.StringValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an [setting.ErrTypeMismatch] if the policy setting is not of a string type. + ReadString(key setting.Key) (string, error) + // ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an [setting.ErrTypeMismatch] if the policy setting is not of a string type. + ReadUInt64(key setting.Key) (uint64, error) + // ReadBoolean returns the value of a [setting.BooleanValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an [setting.ErrTypeMismatch] if the policy setting is not of a string type. + ReadBoolean(key setting.Key) (bool, error) + // ReadStringArray returns the value of a [setting.StringListValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an [setting.ErrTypeMismatch] if the policy setting is not of a string list type. + ReadStringArray(key setting.Key) ([]string, error) +} + +// Lockable is an optional interface that [Store] implementations may support. +// Locking a [Store] is not mandatory as [Store] must be concurrency-safe, +// but is recommended to avoid issues where consecutive read calls for related +// policies might return inconsistent results if a policy change occurs between +// the calls. +type Lockable interface { + + // Lock acquires a read lock on the policy store, + // ensuring the store's state remains unchanged while locked. + // Multiple readers can hold the lock simultaneously. + // It should return nil if the store does not support locking, + // or an error if the store cannot be locked. + Lock() error + // Unlock unlocks the policy store. + // It is a runtime error if the store is not locked on entry to Unlock. + Unlock() +} + +// Changeable is an optional interface that [Store] implementations may support. +type Changeable interface { + // RegisterChangeCallback adds a function that will be called + // whenever there's a policy change in the [Store]. + // The returned function can be used to unregister the callback. + RegisterChangeCallback(callback func()) (unregister func(), err error) +} + +// Expirable is an optional interface that [Store] implementations may support. +type Expirable interface { + // Done returns a channel that is closed when the policy [Store] should no longer be used. + // It should return nil if the store never expires. + Done() <-chan struct{} +} + +// Source represents a named source of policy settings for a given scope. +type Source struct { + name string + scope setting.PolicyScope + store Store + origin *setting.Origin + + lazyReader lazy.SyncValue[*Reader] +} + +// NewSource returns a new [Source] with the specified name, scope, and store. +func NewSource(name string, scope setting.PolicyScope, store Store) *Source { + return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)} +} + +// Name reports the name of the policy source. +func (s *Source) Name() string { + return s.name +} + +// Scope reports the management scope of the policy source. +func (s *Source) Scope() setting.PolicyScope { + return s.scope +} + +// Store returns the [Store] that can be used to read policy settings from this source. +func (s *Source) Store() Store { + return s.store +} + +// Reader returns a [Reader] that reads from this source's [Store]. +func (s *Source) Reader() (*Reader, error) { + return s.lazyReader.GetErr(func() (*Reader, error) { + return newReader(s.store, s.origin) + }) +} + +// String implements [fmt.Stringer]. +func (s *Source) String() string { + if s.Name() != "" { + return fmt.Sprintf("%s (%v)", s.Name(), s.Scope()) + } + return s.Scope().String() +} + +// Compare returns an integer comparing [Source] s and s2 +// by their precedence, following the "last-wins" model. +// The result will be: +// +// -1 if policy settings from s should be processed before policy settings from s2; +// +1 if policy settings from s should be processed after policy settings from s2, overriding s2; +// 0 if the relative processing order of policy settings in s and s2 is unspecified. +func (s *Source) Compare(s2 *Source) int { + return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind()) +} + +// Close closes the [Source] and the underlying [Store]. +func (s *Source) Close() error { + // The [Reader], if any, owns the [Store]. + if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil { + return reader.Close() + } + // Otherwise, it is our responsibility to close it. + if closer, ok := s.store.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go new file mode 100644 index 000000000..5d6503981 --- /dev/null +++ b/util/syspolicy/source/policy_store_windows.go @@ -0,0 +1,438 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "strings" + "sync" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/winutil/gp" +) + +const ( + softwareKeyName = "Software" + tsPoliciesSubkey = `Policies\Tailscale` + tsIPNSubkey = "Tailscale IPN" // the legacy key we need to fallback to +) + +var ( + // [PlatformPolicyStore] implements [Store]. + _ Store = (*PlatformPolicyStore)(nil) +) + +// PlatformPolicyStore implements [Store] by providing read access to the Registry-based +// Tailscale policies, such as those configured via Group Policy or MDM. It is +// recommended to lock it when reading multiple policy values in a row. It also +// allows subscribing to notifications when there's a policy change. +type PlatformPolicyStore struct { + scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy] + + // The softwareKey can be HKLM\Software, HKCU\Software, or + // HKU\{SID}\Software. Anything below the Software subkey, including + // Software\Policies, may not yet exist or could be deleted throughout the + // [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer + // to always use a real registry key (rather than a predefined HKLM or HKCU) + // to simplify bookkeeping (predefined keys should never be closed). + // Finally, this will allow us to watch for any registry changes directly + // should we need this in the future in addition to gp.ChangeWatcher. + softwareKey registry.Key + watcher *gp.ChangeWatcher + + done chan struct{} // done is closed when Close call completes + + // The policyLock can be locked by the caller when reading multiple policy settings + // to prevent the Group Policy Client service from modifying policies while + // they are being read. + // + // When both policyLock and mu need to be taken, mu must be taken before policyLock. + policyLock *gp.PolicyLock + + mu sync.RWMutex + tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked. + cbs set.HandleSet[func()] // policy change callbacks + lockCnt int + locked sync.WaitGroup + closing bool + readable bool +} + +type registryValueGetter[T any] func(key registry.Key, name setting.Key) (T, error) + +// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine. +func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) { + softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ) + if err != nil { + return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err) + } + return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, 0) +} + +// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token. +// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY] +// access. The caller retains ownership of the token. +func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) { + var err error + var softwareKey registry.Key + if token != 0 { + var user *windows.Tokenuser + if user, err = token.GetTokenUser(); err != nil { + return nil, fmt.Errorf("failed to get token user: %w", err) + } + userSid := user.User.Sid + softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ) + } else { + softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ) + } + if err != nil { + return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err) + } + return newPlatformPolicyStore(gp.UserPolicy, softwareKey, token) +} + +func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, token windows.Token) (_ *PlatformPolicyStore, err error) { + store := &PlatformPolicyStore{ + scope: scope, + softwareKey: softwareKey, + done: make(chan struct{}), + readable: true, + } + defer func() { + if err != nil { + store.Close() + } + }() + + switch scope { + case gp.MachinePolicy: + store.policyLock = gp.NewMachinePolicyLock() + case gp.UserPolicy: + if store.policyLock, err = gp.NewUserPolicyLock(token); err != nil { + return nil, fmt.Errorf("failed to create a user policy lock: %w", err) + } + default: + panic("unreachable") + } + + return store, nil +} + +// Lock locks the policy store, preventing the system from modifying the policies +// while they are being read. It is a read lock that may be acquired by multiple goroutines. +// Each Lock call must be balanced by exactly one Unlock call. +func (ps *PlatformPolicyStore) Lock() (err error) { + ps.mu.Lock() + defer ps.mu.Unlock() + + if ps.closing { + return ErrStoreClosed + } + + ps.lockCnt += 1 + if ps.lockCnt != 1 { + return nil + } + defer func() { + if err != nil { + ps.lockCnt -= 1 + } + }() + + // Ensure ps remains open while the lock is held. + ps.locked.Add(1) + defer func() { + if err != nil { + ps.locked.Done() + } + }() + + // Acquire the GP lock to prevent the system from modifying policy settings + // while they are being read. + if err := ps.policyLock.Lock(); err != nil { + if errors.Is(err, gp.ErrInvalidLockState) { + return ErrStoreClosed + } + return err + } + defer func() { + if err != nil { + ps.policyLock.Unlock() + } + }() + + // Keep the Tailscale's registry keys open for the duration of the lock. + keyNames := tailscaleKeyNamesFor(ps.scope) + ps.tsKeys = make([]registry.Key, 0, len(keyNames)) + for _, keyName := range keyNames { + var tsKey registry.Key + tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ) + if err != nil { + if err == registry.ErrNotExist { + continue + } + return err + } + ps.tsKeys = append(ps.tsKeys, tsKey) + } + + return nil +} + +// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0. +// It panics if ps is not locked on entry to Unlock. +func (ps *PlatformPolicyStore) Unlock() { + ps.mu.Lock() + defer ps.mu.Unlock() + + ps.lockCnt -= 1 + if ps.lockCnt < 0 { + panic("negative lockCnt") + } else if ps.lockCnt != 0 { + return + } + + for _, key := range ps.tsKeys { + key.Close() + } + ps.tsKeys = nil + ps.policyLock.Unlock() + ps.locked.Done() +} + +// RegisterChangeCallback adds a function that will be called whenever there's a policy change. +// It returns a function that needs to be called to unregister the specified callback or an error. +// The error is [ErrStoreClosed] if ps has already been closed. +func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.closing { + return nil, ErrStoreClosed + } + + handle := ps.cbs.Add(cb) + if len(ps.cbs) == 1 { + if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil { + return nil, err + } + } + + return func() { + ps.mu.Lock() + defer ps.mu.Unlock() + delete(ps.cbs, handle) + if len(ps.cbs) == 0 { + if ps.watcher != nil { + ps.watcher.Close() + ps.watcher = nil + } + } + }, nil +} + +func (ps *PlatformPolicyStore) onChange() { + ps.mu.RLock() + defer ps.mu.RUnlock() + if ps.closing { + return + } + for _, callback := range ps.cbs { + go callback() + } +} + +// ReadString retrieves a string policy with the specified name. +// It returns [ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadString(name setting.Key) (val string, err error) { + return getPolicyValue(ps, canonicalizeValueName(name), + func(key registry.Key, name setting.Key) (string, error) { + val, _, err := key.GetStringValue(string(name)) + return val, err + }) +} + +// ReadUInt64 retrieves an integer policy with the specified name. +// It returns [ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadUInt64(name setting.Key) (uint64, error) { + return getPolicyValue(ps, canonicalizeValueName(name), + func(key registry.Key, name setting.Key) (uint64, error) { + val, _, err := key.GetIntegerValue(string(name)) + return val, err + }) +} + +// ReadBoolean retrieves a boolean policy with the specified name. +// It returns [ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadBoolean(name setting.Key) (bool, error) { + return getPolicyValue(ps, canonicalizeValueName(name), + func(key registry.Key, name setting.Key) (bool, error) { + val, _, err := key.GetIntegerValue(string(name)) + if err != nil { + return false, err + } + return val != 0, nil + }) +} + +// ReadString retrieves a multi-string policy with the specified name. +// It returns [ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadStringArray(name setting.Key) ([]string, error) { + return getPolicyValue(ps, name, + func(key registry.Key, name setting.Key) ([]string, error) { + val, _, err := key.GetStringsValue(string(canonicalizeValueName(name))) + if err != registry.ErrNotExist { + return val, err + } + + // The idiomatic way to store multiple string values in Group Policy + // and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ) + // values under a subkey rather than in a single REG_MULTI_SZ value. + // + // See the Group Policy: Registry Extension Encoding specification, + // and specifically the ListElement and ListBox types. + // https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf + valKey, err := registry.OpenKey(key, string(canonicalizeKeyName(name)), windows.KEY_READ) + if err != nil { + return nil, err + } + valNames, err := valKey.ReadValueNames(0) + if err != nil { + return nil, err + } + val = make([]string, 0, len(valNames)) + for _, name := range valNames { + switch item, _, err := valKey.GetStringValue(name); { + case err == registry.ErrNotExist: + continue + case err != nil: + return nil, err + default: + val = append(val, item) + } + } + return val, nil + }) +} + +func canonicalizeKeyName(name setting.Key) setting.Key { + return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `\`)) +} + +func canonicalizeValueName(name setting.Key) setting.Key { + return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `_`)) +} + +func getPolicyValue[T any](ps *PlatformPolicyStore, name setting.Key, getter registryValueGetter[T]) (T, error) { + var zero T + + ps.mu.RLock() + defer ps.mu.RUnlock() + if !ps.readable { + return zero, setting.ErrNotConfigured + } + + if ps.tsKeys != nil { + // A non-nil tsKeys indicates that ps has been locked. + // It may be empty if Tailscale policy keys do not exist. + for _, tsKey := range ps.tsKeys { + val, err := getter(tsKey, name) + if err == nil || err != registry.ErrNotExist { + return val, err + } + } + return zero, setting.ErrNotConfigured + } + + // The ps has not been locked, so we don't have any pre-opened keys. + for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) { + var tsKey registry.Key + tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ) + if err != nil { + if err == registry.ErrNotExist { + continue + } + return zero, err + } + defer tsKey.Close() + + val, err := getter(tsKey, name) + if err == nil || err != registry.ErrNotExist { + return val, err + } + } + + return zero, setting.ErrNotConfigured +} + +// Close closes the policy store and releases any associated resources. +// It cancels pending locks and prevents any new lock attempts, +// but waits for existing locks to be released. +func (ps *PlatformPolicyStore) Close() error { + // Request to close the Group Policy read lock. + // Existing held locks will remain valid, but any new or pending locks + // will fail. In certain scenarios, the corresponding write lock may be held + // by the Group Policy service for extended periods (minutes rather than + // seconds or milliseconds). In such cases, we prefer not to wait that long + // if the ps is being closed anyway. + if ps.policyLock != nil { + ps.policyLock.Close() + } + + // Signal to the external code that ps should no longer be used. + close(ps.done) + + // Mark ps as closing to fast-fail any new lock attempts. + // Callers that have already locked it can finish their reading. + ps.mu.Lock() + if ps.closing { + ps.mu.Unlock() + return nil + } + ps.closing = true + if ps.watcher != nil { + ps.watcher.Close() + ps.watcher = nil + } + ps.mu.Unlock() + + // Wait for any outstanding locks to be released. + ps.locked.Wait() + + // Deny any further read attempts and release remaining resources. + ps.mu.Lock() + defer ps.mu.Unlock() + ps.cbs = nil + ps.policyLock = nil + ps.readable = false + if ps.softwareKey != 0 { + ps.softwareKey.Close() + ps.softwareKey = 0 + } + return nil +} + +// Done returns a channel that is closed when the Close method is called. +func (ps *PlatformPolicyStore) Done() <-chan struct{} { + return ps.done +} + +func tailscaleKeyNamesFor(scope gp.Scope) []string { + switch scope { + case gp.MachinePolicy: + // If a computer-side policy value does not exist under Software\Policies\Tailscale, + // we need to fallback and use the legacy Software\Tailscale IPN key. + return []string{tsPoliciesSubkey, tsIPNSubkey} + case gp.UserPolicy: + // However, we've never used the legacy key with user-side policies, + // and we should never do so. Unlike HKLM\Software\Tailscale IPN, + // its HKCU counterpart is user-writable. + return []string{tsPoliciesSubkey} + default: + panic("unreachable") + } +} diff --git a/util/syspolicy/source/policy_store_windows_test.go b/util/syspolicy/source/policy_store_windows_test.go new file mode 100644 index 000000000..60c76837f --- /dev/null +++ b/util/syspolicy/source/policy_store_windows_test.go @@ -0,0 +1,298 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/util/cibuild" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/winutil" + "tailscale.com/util/winutil/gp" +) + +type testPolicyValue struct { + name setting.Key + value any +} + +func TestLockUnlockPolicyStore(t *testing.T) { + store, err := NewMachinePlatformPolicyStore() + if err != nil { + t.Fatalf("NewMachinePolicyStore failed: %v", err) + } + + t.Run("One-Goroutine", func(t *testing.T) { + if err := store.Lock(); err != nil { + t.Errorf("store.Lock(): got %v; want nil", err) + return + } + if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) { + t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured) + } + store.Unlock() + }) + + // Lock the store N times from different goroutines. + const N = 100 + var unlocked atomic.Int32 + t.Run("N-Goroutines", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(N) + for range N { + go func() { + if err := store.Lock(); err != nil { + t.Errorf("store.Lock(): got %v; want nil", err) + return + } + if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) { + t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured) + } + wg.Done() + time.Sleep(10 * time.Millisecond) + unlocked.Add(1) + store.Unlock() + }() + } + + // Wait until the store is locked N times. + wg.Wait() + }) + + // Close the store. The call should wait for all held locks to be released. + if err := store.Close(); err != nil { + t.Fatalf("(*PolicyStore).Close failed: %v", err) + } + if locked := unlocked.Load(); locked != N { + t.Errorf("locked.Load(): got %v; want %v", locked, N) + } + + // Any further attempts to lock it should fail. + if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) { + t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed) + } +} + +func TestReadPolicyStore(t *testing.T) { + if !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user") + } + tests := []struct { + name setting.Key + newValue any + legacyValue any + want any + }{ + {name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"}, + {name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"}, + {name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""}, + {name: "BoolPolicy_True", newValue: true, want: true}, + {name: "BoolPolicy_False", newValue: false, want: false}, + {name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64 + {name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)}, + {name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}}, + {name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}}, + } + + runTests := func(t *testing.T, userStore bool, token windows.Token) { + var hive registry.Key + if userStore { + hive = registry.CURRENT_USER + } else { + hive = registry.LOCAL_MACHINE + } + + // Write policy values to the registry. + newValues := make([]testPolicyValue, 0, len(tests)) + for _, tt := range tests { + if tt.newValue != nil { + newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue}) + } + } + policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey + cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues) + if err != nil { + t.Fatalf("createTestPolicyValues failed: %v", err) + } + t.Cleanup(cleanup) + + // Write legacy policy values to the registry. + legacyValues := make([]testPolicyValue, 0, len(tests)) + for _, tt := range tests { + if tt.legacyValue != nil { + legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue}) + } + } + legacyKeyName := softwareKeyName + `\` + tsIPNSubkey + cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues) + if err != nil { + t.Fatalf("createTestPolicyValues failed: %v", err) + } + t.Cleanup(cleanup) + + var store *PlatformPolicyStore + if userStore { + store, err = NewUserPlatformPolicyStore(token) + } else { + store, err = NewMachinePlatformPolicyStore() + } + if err != nil { + t.Fatalf("NewXPolicyStore failed: %v", err) + } + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Errorf("(*PolicyStore).Close failed: %v", err) + } + }) + + // testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry. + testReadValues := func(t *testing.T, withLocks bool) { + for _, tt := range tests { + t.Run(string(tt.name), func(t *testing.T) { + if userStore && tt.newValue == nil { + t.Skip("there is no legacy policies for users") + } + + t.Parallel() + + if withLocks { + if err := store.Lock(); err != nil { + t.Errorf("failed to acquire the lock: %v", err) + } + defer store.Unlock() + } + + var got any + var err error + switch tt.want.(type) { + case string: + got, err = store.ReadString(tt.name) + case uint64: + got, err = store.ReadUInt64(tt.name) + case bool: + got, err = store.ReadBoolean(tt.name) + case []string: + got, err = store.ReadStringArray(tt.name) + } + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } + } + t.Run("NoLock", func(t *testing.T) { + testReadValues(t, false) + }) + + t.Run("WithLock", func(t *testing.T) { + testReadValues(t, true) + }) + } + + t.Run("MachineStore", func(t *testing.T) { + runTests(t, false, 0) + }) + + t.Run("CurrentUserStore", func(t *testing.T) { + runTests(t, true, 0) + }) + + t.Run("UserStoreWithToken", func(t *testing.T) { + var token windows.Token + if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil { + t.Fatalf("OpenProcessToken: %v", err) + } + defer token.Close() + runTests(t, true, token) + }) +} + +func TestPolicyStoreChangeNotifications(t *testing.T) { + if cibuild.On() { + t.Skipf("test requires running on a real Windows environment") + } + store, err := NewMachinePlatformPolicyStore() + if err != nil { + t.Fatalf("NewMachinePolicyStore failed: %v", err) + } + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Errorf("(*PolicyStore).Close failed: %v", err) + } + }) + + done := make(chan struct{}) + unregister, err := store.RegisterChangeCallback(func() { close(done) }) + if err != nil { + t.Fatalf("RegisterChangeCallback failed: %v", err) + } + t.Cleanup(unregister) + + // RefreshMachinePolicy is a non-blocking call. + if err := gp.RefreshMachinePolicy(true); err != nil { + t.Fatalf("RefreshMachinePolicy failed: %v", err) + } + + // We should receive a policy change notification when + // the Group Policy service completes policy processing. + // Otherwise, the test will eventually time out. + <-done +} + +func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) { + key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS) + if err != nil { + return nil, err + } + doCleanup := func() { + for _, v := range values { + key.DeleteValue(string(v.name)) + } + key.Close() + if !existing { + registry.DeleteKey(hive, keyName) + } + } + defer func() { + if err != nil { + doCleanup() + } + }() + + for _, v := range values { + switch value := v.value.(type) { + case string: + err = key.SetStringValue(string(v.name), value) + case uint32: + err = key.SetDWordValue(string(v.name), value) + case uint64: + err = key.SetQWordValue(string(v.name), value) + case bool: + if value { + err = key.SetDWordValue(string(v.name), 1) + } else { + err = key.SetDWordValue(string(v.name), 0) + } + case []string: + err = key.SetStringsValue(string(v.name), value) + default: + err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name) + } + if err != nil { + return nil, err + } + } + return doCleanup, nil +} diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go new file mode 100644 index 000000000..fd422d852 --- /dev/null +++ b/util/syspolicy/source/test_store.go @@ -0,0 +1,446 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "fmt" + "sync" + "sync/atomic" + + xmaps "golang.org/x/exp/maps" + "tailscale.com/util/mak" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" +) + +var _ Store = (*TestStore)(nil) + +// TestValueType is a constraint that allows types supported by [TestStore]. +type TestValueType interface { + bool | uint64 | string | []string +} + +// TestSetting is a policy setting in a [TestStore]. +type TestSetting[T TestValueType] struct { + // Key is the setting's unique identifier. + Key setting.Key + // Error is the error to be returned by the [TestStore] when reading + // a policy setting with the specified key. + Error error + // Value is the value to be returned by the [TestStore] when reading + // a policy setting with the specified key. + // It is only used if the Error is nil. + Value T +} + +// TestSettingOf returns a [TestSetting] representing a policy setting +// configured with the specified key and value. +func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] { + return TestSetting[T]{Key: key, Value: value} +} + +// TestSettingWithError returns a [TestSetting] representing a policy setting +// with the specified key and error. +func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] { + return TestSetting[T]{Key: key, Error: err} +} + +// testReadOperation describes a single policy setting read operation. +type testReadOperation struct { + // Key is the setting's unique identifier. + Key setting.Key + // Type is a value type of a read operation. + // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue] + Type setting.Type +} + +// TestExpectedReads is the number of read operations with the specified details. +type TestExpectedReads struct { + // Key is the setting's unique identifier. + Key setting.Key + // Type is a value type of a read operation. + // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue] + Type setting.Type + // NumTimes is how many times a setting with the specified key and type should have been read. + NumTimes int +} + +func (r TestExpectedReads) operation() testReadOperation { + return testReadOperation{r.Key, r.Type} +} + +// TestStore is a [Store] that can be used in tests. +type TestStore struct { + tb internal.TB + + done chan struct{} + + storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock]. + storeLockCount atomic.Int32 + + mu sync.RWMutex + suspendCount int // change callback are suspended if > 0 + mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended. + cbs set.HandleSet[func()] + + readsMu sync.Mutex + reads map[testReadOperation]int // how many times a policy setting was read +} + +// NewTestStore returns a new [TestStore]. +// The tb will be used to report coding errors detected by the [TestStore]. +func NewTestStore(tb internal.TB) *TestStore { + m := make(map[setting.Key]any) + return &TestStore{ + tb: tb, + done: make(chan struct{}), + mr: m, + mw: m, + } +} + +// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], +// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. +func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { + m := make(map[setting.Key]any) + store := &TestStore{ + tb: tb, + done: make(chan struct{}), + mr: m, + mw: m, + } + switch settings := any(settings).(type) { + case []TestSetting[bool]: + store.SetBooleans(settings...) + case []TestSetting[uint64]: + store.SetUInt64s(settings...) + case []TestSetting[string]: + store.SetStrings(settings...) + case []TestSetting[[]string]: + store.SetStringLists(settings...) + } + return store +} + +// Lock implements [Store]. +func (s *TestStore) Lock() error { + s.storeLock.RLock() + s.storeLockCount.Add(1) + return nil +} + +// Unlock implements [Store]. +func (s *TestStore) Unlock() { + if s.storeLockCount.Add(-1) < 0 { + s.tb.Fatal("negative storeLockCount") + } + s.storeLock.RUnlock() +} + +// RegisterChangeCallback implements [Store]. +func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) { + s.mu.Lock() + defer s.mu.Unlock() + handle := s.cbs.Add(callback) + return func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.cbs, handle) + }, nil +} + +// ReadString implements [Store]. +func (s *TestStore) ReadString(key setting.Key) (string, error) { + defer s.recordRead(key, setting.StringValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return "", setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return "", err + } + str, ok := v.(string) + if !ok { + return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v) + } + return str, nil +} + +// ReadUInt64 implements [Store]. +func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) { + defer s.recordRead(key, setting.IntegerValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return 0, setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return 0, err + } + u64, ok := v.(uint64) + if !ok { + return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v) + } + return u64, nil +} + +// ReadBoolean implements [Store]. +func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) { + defer s.recordRead(key, setting.BooleanValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return false, setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return false, err + } + b, ok := v.(bool) + if !ok { + return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v) + } + return b, nil +} + +// ReadStringArray implements [Store]. +func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) { + defer s.recordRead(key, setting.StringListValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return nil, setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return nil, err + } + slice, ok := v.([]string) + if !ok { + return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v) + } + return slice, nil +} + +func (s *TestStore) recordRead(key setting.Key, typ setting.Type) { + s.readsMu.Lock() + op := testReadOperation{key, typ} + num := s.reads[op] + num++ + mak.Set(&s.reads, op, num) + s.readsMu.Unlock() +} + +func (s *TestStore) ResetCounters() { + s.readsMu.Lock() + clear(s.reads) + s.readsMu.Unlock() +} + +// ReadsMustEqual fails the test if the actual reads differs from the specified reads. +func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) { + s.tb.Helper() + s.readsMu.Lock() + defer s.readsMu.Unlock() + s.readsMustContainLocked(reads...) + s.readMustNoExtraLocked(reads...) +} + +// ReadsMustContain fails the test if the specified reads have not been made, +// or have been made a different number of times. It permits other values to be +// read in addition to the ones being tested. +func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) { + s.tb.Helper() + s.readsMu.Lock() + defer s.readsMu.Unlock() + s.readsMustContainLocked(reads...) +} + +func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) { + s.tb.Helper() + for _, r := range reads { + if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes { + s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes) + } + } +} + +func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) { + s.tb.Helper() + rs := make(set.Set[testReadOperation]) + for i := range reads { + rs.Add(reads[i].operation()) + } + for ro, num := range s.reads { + if !rs.Contains(ro) { + s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num) + } + } +} + +// Suspend suspends the store, batching changes and notifications +// until [TestStore.Resume] is called the same number of times as Suspend. +func (s *TestStore) Suspend() { + s.mu.Lock() + defer s.mu.Unlock() + if s.suspendCount++; s.suspendCount == 1 { + s.mw = xmaps.Clone(s.mr) + } +} + +// Resume resumes the store, applying the changes and invoking +// the change callbacks. +func (s *TestStore) Resume() { + s.storeLock.Lock() + s.mu.Lock() + switch s.suspendCount--; { + case s.suspendCount == 0: + s.mr = s.mw + s.mu.Unlock() + s.storeLock.Unlock() + s.notifyPolicyChanged() + case s.suspendCount < 0: + s.tb.Fatal("negative suspendCount") + default: + s.mu.Unlock() + s.storeLock.Unlock() + } +} + +// SetBooleans sets the specified boolean settings in s. +func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// SetUInt64s sets the specified integer settings in s. +func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// SetStrings sets the specified string settings in s. +func (s *TestStore) SetStrings(settings ...TestSetting[string]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// SetStrings sets the specified string list settings in s. +func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// Delete deletes the specified settings from s. +func (s *TestStore) Delete(keys ...setting.Key) { + s.storeLock.Lock() + for _, key := range keys { + s.mu.Lock() + delete(s.mw, key) + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// Clear deletes all settings from s. +func (s *TestStore) Clear() { + s.storeLock.Lock() + s.mu.Lock() + clear(s.mw) + s.mu.Unlock() + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +func (s *TestStore) notifyPolicyChanged() { + s.mu.RLock() + if s.suspendCount != 0 { + s.mu.RUnlock() + return + } + cbs := xmaps.Values(s.cbs) + s.mu.RUnlock() + + var wg sync.WaitGroup + wg.Add(len(cbs)) + for _, cb := range cbs { + go func() { + defer wg.Done() + cb() + }() + } + wg.Wait() +} + +// Close closes s, notifying its users that it has expired. +func (s *TestStore) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if s.done != nil { + close(s.done) + s.done = nil + } +} + +// Done implements [Store]. +func (s *TestStore) Done() <-chan struct{} { + return s.done +} diff --git a/util/syspolicy/syspolicy.go b/util/syspolicy/syspolicy.go index 76e11e2b6..1ff9ff97a 100644 --- a/util/syspolicy/syspolicy.go +++ b/util/syspolicy/syspolicy.go @@ -1,122 +1,83 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package syspolicy provides functions to retrieve system settings of a device. +// Package syspolicy facilitates retrieval of the current policy settings +// applied to the device or user and receiving notifications when the policy +// changes. +// +// It provides functions that return specific policy settings by their unique +// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString], +// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration]. package syspolicy import ( "errors" + "fmt" + "reflect" "time" + + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" ) +var ( + // ErrNotConfigured is returned when the requested policy setting is not configured. + ErrNotConfigured = setting.ErrNotConfigured + // ErrTypeMismatch is returned when there's a type mismatch between the actual type + // of the setting value and the expected type. + ErrTypeMismatch = setting.ErrTypeMismatch + // ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting + // has been registered with the specified key. + // + // Until 2024-08-02, this error was also returned by a [Handler] when the specified + // key did not have a value set. While the package maintains compatibility with this + // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer + // [source.Store] implementations. + ErrNoSuchKey = setting.ErrNoSuchKey +) + +// GetString returns a string policy setting with the specified key, +// or defaultValue if it does not exist. func GetString(key Key, defaultValue string) (string, error) { - markHandlerInUse() - v, err := handler.ReadString(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetUint64 returns a numeric policy setting with the specified key, +// or defaultValue if it does not exist. func GetUint64(key Key, defaultValue uint64) (uint64, error) { - markHandlerInUse() - v, err := handler.ReadUInt64(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetBoolean returns a boolean policy setting with the specified key, +// or defaultValue if it does not exist. func GetBoolean(key Key, defaultValue bool) (bool, error) { - markHandlerInUse() - v, err := handler.ReadBoolean(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetStringArray returns a multi-string policy setting with the specified key, +// or defaultValue if it does not exist. func GetStringArray(key Key, defaultValue []string) ([]string, error) { - markHandlerInUse() - v, err := handler.ReadStringArray(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } -// PreferenceOption is a policy that governs whether a boolean variable -// is forcibly assigned an administrator-defined value, or allowed to receive -// a user-defined value. -type PreferenceOption int - -const ( - showChoiceByPolicy PreferenceOption = iota - neverByPolicy - alwaysByPolicy +type ( + // PreferenceOption is a policy that governs whether a boolean variable + // is forcibly assigned an administrator-defined value, or allowed to receive + // a user-defined value. + PreferenceOption = setting.PreferenceOption + // Visibility is a policy that controls whether or not a particular + // component of a user interface is to be shown. + Visibility = setting.Visibility ) -// Show returns if the UI option that controls the choice administered by this -// policy should be shown. Currently this is true if and only if the policy is -// showChoiceByPolicy. -func (p PreferenceOption) Show() bool { - return p == showChoiceByPolicy -} - -// ShouldEnable checks if the choice administered by this policy should be -// enabled. If the administrator has chosen a setting, the administrator's -// setting is returned, otherwise userChoice is returned. -func (p PreferenceOption) ShouldEnable(userChoice bool) bool { - switch p { - case neverByPolicy: - return false - case alwaysByPolicy: - return true - default: - return userChoice - } -} - -// WillOverride checks if the choice administered by the policy is different -// from the user's choice. -func (p PreferenceOption) WillOverride(userChoice bool) bool { - return p.ShouldEnable(userChoice) != userChoice -} - // GetPreferenceOption loads a policy from the registry that can be // managed by an enterprise policy management system and allows administrative // overrides of users' choices in a way that we do not want tailcontrol to have // the authority to set. It describes user-decides/always/never options, where // "always" and "never" remove the user's ability to make a selection. If not // present or set to a different value, "user-decides" is the default. -func GetPreferenceOption(name Key) (PreferenceOption, error) { - opt, err := GetString(name, "user-decides") - if err != nil { - return showChoiceByPolicy, err - } - switch opt { - case "always": - return alwaysByPolicy, nil - case "never": - return neverByPolicy, nil - default: - return showChoiceByPolicy, nil - } -} - -// Visibility is a policy that controls whether or not a particular -// component of a user interface is to be shown. -type Visibility byte - -const ( - visibleByPolicy Visibility = 'v' - hiddenByPolicy Visibility = 'h' -) - -// Show reports whether the UI option administered by this policy should be shown. -// Currently this is true if and only if the policy is visibleByPolicy. -func (p Visibility) Show() bool { - return p == visibleByPolicy +func GetPreferenceOption(name Key) (setting.PreferenceOption, error) { + return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy) } // GetVisibility loads a policy from the registry that can be managed @@ -124,17 +85,8 @@ func (p Visibility) Show() bool { // for UI elements. The registry value should be a string set to "show" (return // true) or "hide" (return true). If not present or set to a different value, // "show" (return false) is the default. -func GetVisibility(name Key) (Visibility, error) { - opt, err := GetString(name, "show") - if err != nil { - return visibleByPolicy, err - } - switch opt { - case "hide": - return hiddenByPolicy, nil - default: - return visibleByPolicy, nil - } +func GetVisibility(name Key) (setting.Visibility, error) { + return getCurrentPolicySettingValue(name, setting.VisibleByPolicy) } // GetDuration loads a policy from the registry that can be managed @@ -143,15 +95,48 @@ func GetVisibility(name Key) (Visibility, error) { // understands. If the registry value is "" or can not be processed, // defaultValue is returned instead. func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) { - opt, err := GetString(name, "") - if opt == "" || err != nil { - return defaultValue, err + d, err := getCurrentPolicySettingValue(name, defaultValue) + if err != nil { + return d, err } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { + if d < 0 { return defaultValue, nil } - return v, nil + return d, nil +} + +// getCurrentPolicySettingValue returns the value of the policy setting +// specified by its key from the [rsop.Policy] of the [CurrentScope]. It +// returns def if the policy setting is not configured, or an error if it has +// an error or could not be converted to the specified type T. +func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) { + resultant, err := rsop.PolicyFor(setting.CurrentScope()) + if err != nil { + return def, err + } + value, err := resultant.Get().GetErr(key) + if err != nil { + if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) { + return def, nil + } + return def, err + } + if res, ok := value.(T); ok { + return res, nil + } + return convertPolicySettingValueTo(value, def) +} + +func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) { + // Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string + // if someone requests a string instead of the actual setting's value. + // TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests. + if reflect.TypeFor[T]().Kind() == reflect.String { + if str, ok := value.(fmt.Stringer); ok { + return any(str.String()).(T), nil + } + } + return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def) } // SelectControlURL returns the ControlURL to use based on a value in diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go index c2810ebbb..2adbe9d25 100644 --- a/util/syspolicy/syspolicy_test.go +++ b/util/syspolicy/syspolicy_test.go @@ -5,16 +5,24 @@ package syspolicy import ( "errors" + "fmt" "slices" "testing" "time" + + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/internal/metrics" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" ) // testHandler encompasses all data types returned when testing any of the syspolicy // methods that involve getting a policy value. // For keys and the corresponding values, check policy_keys.go. type testHandler struct { - t *testing.T + t testing.TB key Key s string u64 uint64 @@ -28,7 +36,10 @@ var someOtherError = errors.New("error other than not found") func (th *testHandler) ReadString(key string) (string, error) { if key != string(th.key) { - th.t.Errorf("ReadString(%q) want %q", key, th.key) + // The syspolicy package now reads and caches all registered policy settings. + // Therefore, it is expected to call the handler requesting all policies + // rather than just the specific ones we asked for. + return "", ErrNotConfigured } th.calls++ return th.s, th.err @@ -36,7 +47,10 @@ func (th *testHandler) ReadString(key string) (string, error) { func (th *testHandler) ReadUInt64(key string) (uint64, error) { if key != string(th.key) { - th.t.Errorf("ReadUint64(%q) want %q", key, th.key) + // The syspolicy package now reads and caches all registered policy settings. + // Therefore, it is expected to call the handler requesting all policies + // rather than just the specific ones we asked for. + return 0, ErrNotConfigured } th.calls++ return th.u64, th.err @@ -44,7 +58,10 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) { func (th *testHandler) ReadBoolean(key string) (bool, error) { if key != string(th.key) { - th.t.Errorf("ReadBool(%q) want %q", key, th.key) + // The syspolicy package now reads and caches all registered policy settings. + // Therefore, it is expected to call the handler requesting all policies + // rather than just the specific ones we asked for. + return false, ErrNotConfigured } th.calls++ return th.b, th.err @@ -52,7 +69,10 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) { func (th *testHandler) ReadStringArray(key string) ([]string, error) { if key != string(th.key) { - th.t.Errorf("ReadStringArray(%q) want %q", key, th.key) + // The syspolicy package now reads and caches all registered policy settings. + // Therefore, it is expected to call the handler requesting all policies + // rather than just the specific ones we asked for. + return nil, ErrNotConfigured } th.calls++ return th.sArr, th.err @@ -67,23 +87,28 @@ func TestGetString(t *testing.T) { defaultValue string wantValue string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: AdminConsoleVisibility, handlerValue: "hide", wantValue: "hide", + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", key: EnableServerMode, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non-blank default", key: EnableServerMode, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, defaultValue: "test", wantValue: "test", wantError: nil, @@ -93,11 +118,17 @@ func TestGetString(t *testing.T) { key: NetworkDevicesVisibility, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_NetworkDevices_error", Value: 1}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) SetHandlerForTest(t, &testHandler{ t: t, key: tt.key, @@ -105,12 +136,21 @@ func TestGetString(t *testing.T) { err: tt.handlerError, }) value, err := GetString(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-08-02, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -127,7 +167,7 @@ func TestGetUint64(t *testing.T) { }{ { name: "read existing value", - key: KeyExpirationNoticeTime, + key: LogSCMInteractions, handlerValue: 1, wantValue: 1, }, @@ -135,14 +175,14 @@ func TestGetUint64(t *testing.T) { name: "read non-existing value", key: LogSCMInteractions, handlerValue: 0, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0, }, { name: "read non-existing value, non-zero default", key: LogSCMInteractions, defaultValue: 2, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 2, }, { @@ -155,14 +195,21 @@ func TestGetUint64(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ + // None of the policy settings tested here are integers. + // In fact, we don't have any integer policies as of 2024-07-29. + // However, we can register each of them as an integer policy setting + // for the duration of the test, providing us with something to test against. + if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, WrapHandler(&testHandler{ t: t, key: tt.key, u64: tt.handlerValue, err: tt.handlerError, - }) + })) value, err := GetUint64(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { @@ -181,32 +228,43 @@ func TestGetBoolean(t *testing.T) { defaultValue bool wantValue bool wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: FlushDNSOnSessionUnlock, handlerValue: true, wantValue: true, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1}, + }, }, { name: "read non-existing value", key: LogSCMInteractions, handlerValue: false, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: false, }, { name: "reading value returns other error", key: FlushDNSOnSessionUnlock, handlerError: someOtherError, - wantError: someOtherError, + wantError: someOtherError, // expect error... defaultValue: true, - wantValue: false, + wantValue: true, // ...AND default value if the handler fails. + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) SetHandlerForTest(t, &testHandler{ t: t, key: tt.key, @@ -214,12 +272,21 @@ func TestGetBoolean(t *testing.T) { err: tt.handlerError, }) value, err := GetBoolean(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-08-02, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -232,42 +299,61 @@ func TestGetPreferenceOption(t *testing.T) { handlerError error wantValue PreferenceOption wantError error + wantMetrics []metrics.TestState }{ { name: "always by policy", key: EnableIncomingConnections, handlerValue: "always", - wantValue: alwaysByPolicy, + wantValue: setting.AlwaysByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "never by policy", key: EnableIncomingConnections, handlerValue: "never", - wantValue: neverByPolicy, + wantValue: setting.NeverByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "use default", key: EnableIncomingConnections, handlerValue: "", - wantValue: showChoiceByPolicy, + wantValue: setting.ShowChoiceByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "read non-existing value", key: EnableIncomingConnections, - handlerError: ErrNoSuchKey, - wantValue: showChoiceByPolicy, + handlerError: ErrNotConfigured, + wantValue: setting.ShowChoiceByPolicy, }, { name: "other error is returned", key: EnableIncomingConnections, handlerError: someOtherError, - wantValue: showChoiceByPolicy, + wantValue: setting.ShowChoiceByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) SetHandlerForTest(t, &testHandler{ t: t, key: tt.key, @@ -275,12 +361,21 @@ func TestGetPreferenceOption(t *testing.T) { err: tt.handlerError, }) option, err := GetPreferenceOption(tt.key) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if option != tt.wantValue { t.Errorf("option=%v, want %v", option, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-08-02, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -293,38 +388,53 @@ func TestGetVisibility(t *testing.T) { handlerError error wantValue Visibility wantError error + wantMetrics []metrics.TestState }{ { name: "hidden by policy", key: AdminConsoleVisibility, handlerValue: "hide", - wantValue: hiddenByPolicy, + wantValue: setting.HiddenByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "visibility default", key: AdminConsoleVisibility, handlerValue: "show", - wantValue: visibleByPolicy, + wantValue: setting.VisibleByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", key: AdminConsoleVisibility, handlerValue: "show", - handlerError: ErrNoSuchKey, - wantValue: visibleByPolicy, + handlerError: ErrNotConfigured, + wantValue: setting.VisibleByPolicy, }, { name: "other error is returned", key: AdminConsoleVisibility, handlerValue: "show", handlerError: someOtherError, - wantValue: visibleByPolicy, + wantValue: setting.VisibleByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AdminConsole_error", Value: 1}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) SetHandlerForTest(t, &testHandler{ t: t, key: tt.key, @@ -332,12 +442,21 @@ func TestGetVisibility(t *testing.T) { err: tt.handlerError, }) visibility, err := GetVisibility(tt.key) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if visibility != tt.wantValue { t.Errorf("visibility=%v, want %v", visibility, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-08-02, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -351,6 +470,7 @@ func TestGetDuration(t *testing.T) { defaultValue time.Duration wantValue time.Duration wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", @@ -358,25 +478,34 @@ func TestGetDuration(t *testing.T) { handlerValue: "2h", wantValue: 2 * time.Hour, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice", Value: 1}, + }, }, { name: "invalid duration value", key: KeyExpirationNoticeTime, handlerValue: "-20", wantValue: 24 * time.Hour, + wantError: errors.New(`time: missing unit in duration "-20"`), defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, { name: "read non-existing value", key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 24 * time.Hour, defaultValue: 24 * time.Hour, }, { name: "read non-existing value different default", key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0 * time.Second, defaultValue: 0 * time.Second, }, @@ -387,11 +516,17 @@ func TestGetDuration(t *testing.T) { wantValue: 24 * time.Hour, wantError: someOtherError, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) SetHandlerForTest(t, &testHandler{ t: t, key: tt.key, @@ -399,12 +534,21 @@ func TestGetDuration(t *testing.T) { err: tt.handlerError, }) duration, err := GetDuration(tt.key, tt.defaultValue) - if err != tt.wantError { + if fmt.Sprint(err) != fmt.Sprint(tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if duration != tt.wantValue { t.Errorf("duration=%v, want %v", duration, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-08-02, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -418,23 +562,28 @@ func TestGetStringArray(t *testing.T) { defaultValue []string wantValue []string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: AllowedSuggestedExitNodes, handlerValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1}, + }, }, { name: "read non-existing value", key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non nil default", key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, defaultValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, wantError: nil, @@ -444,11 +593,17 @@ func TestGetStringArray(t *testing.T) { key: AllowedSuggestedExitNodes, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) SetHandlerForTest(t, &testHandler{ t: t, key: tt.key, @@ -456,16 +611,47 @@ func TestGetStringArray(t *testing.T) { err: tt.handlerError, }) value, err := GetStringArray(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if !slices.Equal(tt.wantValue, value) { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-08-02, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } +func BenchmarkGetString(b *testing.B) { + loggerx.SetForTest(b, logger.Discard, logger.Discard) + setWellKnownSettingsForTest(b) + + store := source.NewTestStore(b) + wantControlURL := "https://login.tailscale.com" + store.SetStrings(source.TestSetting[string]{Key: ControlURL, Value: wantControlURL}) + + _, err := rsop.RegisterStoreForTest(b, "Test Store", setting.DeviceScope, store) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com") + if gotControlURL != wantControlURL { + b.Fatalf("got %v; want %v", gotControlURL, wantControlURL) + } + } +} + func TestSelectControlURL(t *testing.T) { tests := []struct { reg, disk, want string @@ -497,3 +683,13 @@ func TestSelectControlURL(t *testing.T) { } } } + +func errorsMatchForTest(got, want error) bool { + if got == nil && want == nil { + return true + } + if got == nil || want == nil { + return false + } + return errors.Is(got, want) || got.Error() == want.Error() +} diff --git a/util/syspolicy/syspolicy_windows.go b/util/syspolicy/syspolicy_windows.go new file mode 100644 index 000000000..d17fa10b5 --- /dev/null +++ b/util/syspolicy/syspolicy_windows.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "errors" + "fmt" + "os/user" + + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/internal/lazyinit" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" +) + +func init() { + // On Windows, we should automatically register the Registry-based policy + // store for the device. If we are running in a user's security context + // (e.g., we're the GUI), we should also register the Registry policy store for + // the user. In the future, we should register (and unregister) user policy + // stores whenever a user connects to the local backend. This ensures the + // backend is aware of the user's policy settings and can send them to the + // GUI/CLI/Web clients on demand or whenever they change. + // + // Other platforms, such as macOS, iOS and Android, should register their + // platform-specific policy stores via [RegisterStore] (or [RegisterHandler] + // until they implement the [Store] interface). + // + // External code, such as the ipnlocal package, may choose to register + // additional policy stores, such as config files and policies received from + // the control plane. + lazyinit.Defer(func() error { + // Do not register or use default policy stores during tests. + // Each test should set up its own necessary configurations. + if testenv.InTest() { + return nil + } + return configureSyspolicy(nil) + }) +} + +// configureSyspolicy configures syspolicy for use on Windows, +// either in test or regular builds depending on whether tb has a non-nil value. +func configureSyspolicy(tb internal.TB) error { + const localSystemSID = "S-1-5-18" + // Always create and register a machine policy store that reads + // policy settings from the HKEY_LOCAL_MACHINE registry hive. + machineStore, err := source.NewMachinePlatformPolicyStore() + if err != nil { + return fmt.Errorf("failed to create the machine policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore) + } + if err != nil { + return err + } + // Check whether the current process is running as Local System or not. + u, err := user.Current() + if err != nil { + return err + } + if u.Uid == localSystemSID { + return nil + } + // If it's not a Local System's process (e.g., the GUI and not the tailscaled service), + // we should create and use a policy store for the current user that reads + // policy settings from that user's registry hive (HKEY_CURRENT_USER). + userStore, err := source.NewUserPlatformPolicyStore(0) + if err != nil { + return fmt.Errorf("failed to create the current user's policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore) + } + if err != nil { + return err + } + // And also set [CurrentUserScope] as the [CurrentScope], so [GetString], + // [GetVisibility] and similar functions would be returning a merged result + // of the machine's and user's policies. + if !setting.SetCurrentScope(setting.CurrentUserScope) { + return errors.New("current scope already set") + } + return nil +} diff --git a/util/winutil/gp/policylock_windows.go b/util/winutil/gp/policylock_windows.go index f92c534bb..95453aa16 100644 --- a/util/winutil/gp/policylock_windows.go +++ b/util/winutil/gp/policylock_windows.go @@ -189,6 +189,7 @@ func (l *PolicyLock) lockSlow() (err error) { select { case resultCh <- policyLockResult{handle, err}: // lockSlow has received the result. + break send_result default: select { case <-closing: