From b1f0d4e495cb1007b4ee3f7fdf86071171543c0d Mon Sep 17 00:00:00 2001 From: Alexander Scheel Date: Tue, 23 May 2023 15:44:05 -0400 Subject: [PATCH] Add nonce service to sdk/helpers, use in PKI (#20688) * Build a better nonce service Signed-off-by: Alexander Scheel * Add internal nonce service for testing Signed-off-by: Alexander Scheel * Add benchmarks for nonce service Signed-off-by: Alexander Scheel * Add statistics around how long tidy took Signed-off-by: Alexander Scheel * Replace ACME nonces with shared nonce service Signed-off-by: Alexander Scheel * Add an initialize method to nonce services Signed-off-by: Alexander Scheel * Use the new initialize helper on nonce service in PKI Signed-off-by: Alexander Scheel * Add additional tests for nonces Signed-off-by: Alexander Scheel * Format sdk/helper/nonce Signed-off-by: Alexander Scheel * Use default 90s nonce expiry in PKI Signed-off-by: Alexander Scheel * Remove parallel test case as covered by benchmark Signed-off-by: Alexander Scheel * Add additional commentary to encrypted nonce implementation Signed-off-by: Alexander Scheel * Add nonce to test_packages Signed-off-by: Alexander Scheel --------- Signed-off-by: Alexander Scheel --- .../scripts/generate-test-package-lists.sh | 1 + builtin/logical/pki/acme_state.go | 80 +--- builtin/logical/pki/acme_state_test.go | 1 + builtin/logical/pki/backend.go | 5 + sdk/helper/nonce/benchmark_test.go | 248 ++++++++++ sdk/helper/nonce/encrypted_nonce.go | 443 ++++++++++++++++++ sdk/helper/nonce/nonce.go | 70 +++ sdk/helper/nonce/nonce_test.go | 92 ++++ sdk/helper/nonce/sync_map_nonce.go | 107 +++++ 9 files changed, 981 insertions(+), 66 deletions(-) create mode 100644 sdk/helper/nonce/benchmark_test.go create mode 100644 sdk/helper/nonce/encrypted_nonce.go create mode 100644 sdk/helper/nonce/nonce.go create mode 100644 sdk/helper/nonce/nonce_test.go create mode 100644 sdk/helper/nonce/sync_map_nonce.go diff --git a/.github/scripts/generate-test-package-lists.sh 
b/.github/scripts/generate-test-package-lists.sh index 8c902705bc..a74075e461 100755 --- a/.github/scripts/generate-test-package-lists.sh +++ b/.github/scripts/generate-test-package-lists.sh @@ -79,6 +79,7 @@ test_packages[5]+=" $base/vault/external_tests/sealmigration" if [ "${ENTERPRISE:+x}" == "x" ] ; then test_packages[5]+=" $base/vault/external_tests/transform" fi +test_packages[5]+=" $base/sdk/helper/nonce" # Total time: 588 test_packages[6]+=" $base" diff --git a/builtin/logical/pki/acme_state.go b/builtin/logical/pki/acme_state.go index b50a515006..baef33f561 100644 --- a/builtin/logical/pki/acme_state.go +++ b/builtin/logical/pki/acme_state.go @@ -17,13 +17,11 @@ import ( "time" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/nonce" "github.com/hashicorp/vault/sdk/logical" ) const ( - // How long nonces are considered valid. - nonceExpiry = 15 * time.Minute - // How many bytes are in a token. Per RFC 8555 Section // 8.3. HTTP Challenge and Section 11.3 Token Entropy: // @@ -40,9 +38,9 @@ const ( ) type acmeState struct { - nextExpiry *atomic.Int64 - nonces *sync.Map // map[string]time.Time - validator *ACMEChallengeEngine + nonces nonce.NonceService + + validator *ACMEChallengeEngine configDirty *atomic.Bool _config sync.RWMutex @@ -56,8 +54,7 @@ type acmeThumbprint struct { func NewACMEState() *acmeState { state := &acmeState{ - nextExpiry: new(atomic.Int64), - nonces: new(sync.Map), + nonces: nonce.NewNonceService(), validator: NewACMEChallengeEngine(), configDirty: new(atomic.Bool), } @@ -68,6 +65,11 @@ func NewACMEState() *acmeState { } func (a *acmeState) Initialize(b *backend, sc *storageContext) error { + // Initialize the nonce service. + if err := a.nonces.Initialize(); err != nil { + return fmt.Errorf("failed to initialize the ACME nonce service: %w", err) + } + // Load the ACME config. 
_, err := a.getConfigWithUpdate(sc) if err != nil { @@ -80,6 +82,7 @@ func (a *acmeState) Initialize(b *backend, sc *storageContext) error { } go a.validator.Run(b, a) + // All good. return nil } @@ -124,10 +127,6 @@ func (a *acmeState) getConfigWithUpdate(sc *storageContext) (*acmeConfigEntry, e return &configCopy, nil } -func generateNonce() (string, error) { - return generateRandomBase64(21) -} - func generateRandomBase64(srcBytes int) (string, error) { data := make([]byte, 21) if _, err := io.ReadFull(rand.Reader, data); err != nil { @@ -138,66 +137,15 @@ func generateRandomBase64(srcBytes int) (string, error) { } func (a *acmeState) GetNonce() (string, time.Time, error) { - now := time.Now() - nonce, err := generateNonce() - if err != nil { - return "", now, err - } - - then := now.Add(nonceExpiry) - a.nonces.Store(nonce, then) - - nextExpiry := a.nextExpiry.Load() - next := time.Unix(nextExpiry, 0) - if now.After(next) || then.Before(next) { - a.nextExpiry.Store(then.Unix()) - } - - return nonce, then, nil + return a.nonces.Get() } func (a *acmeState) RedeemNonce(nonce string) bool { - rawTimeout, present := a.nonces.LoadAndDelete(nonce) - if !present { - return false - } - - timeout := rawTimeout.(time.Time) - if time.Now().After(timeout) { - return false - } - - return true + return a.nonces.Redeem(nonce) } func (a *acmeState) DoTidyNonces() { - now := time.Now() - expiry := a.nextExpiry.Load() - then := time.Unix(expiry, 0) - - if expiry == 0 || now.After(then) { - a.TidyNonces() - } -} - -func (a *acmeState) TidyNonces() { - now := time.Now() - nextRun := now.Add(nonceExpiry) - - a.nonces.Range(func(key, value any) bool { - timeout := value.(time.Time) - if now.After(timeout) { - a.nonces.Delete(key) - } - - if timeout.Before(nextRun) { - nextRun = timeout - } - - return false /* don't quit looping */ - }) - - a.nextExpiry.Store(nextRun.Unix()) + a.nonces.Tidy() } type ACMEAccountStatus string diff --git a/builtin/logical/pki/acme_state_test.go 
b/builtin/logical/pki/acme_state_test.go index fda5f436e7..8d4f12127a 100644 --- a/builtin/logical/pki/acme_state_test.go +++ b/builtin/logical/pki/acme_state_test.go @@ -13,6 +13,7 @@ func TestAcmeNonces(t *testing.T) { t.Parallel() a := NewACMEState() + a.nonces.Initialize() // Simple operation should succeed. nonce, _, err := a.GetNonce() diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index f8cdf2b553..3367493dc4 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -693,9 +693,14 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er return nil } + // First tidy any ACME nonces to free memory. + b.acmeState.DoTidyNonces() + + // Then run unified transfer. backgroundSc := b.makeStorageContext(context.Background(), b.storage) go runUnifiedTransfer(backgroundSc) + // Then run the CRL rebuild and tidy operation. crlErr := doCRL() tidyErr := doAutoTidy() diff --git a/sdk/helper/nonce/benchmark_test.go b/sdk/helper/nonce/benchmark_test.go new file mode 100644 index 0000000000..05b6226315 --- /dev/null +++ b/sdk/helper/nonce/benchmark_test.go @@ -0,0 +1,248 @@ +package nonce + +import ( + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const ( + benchValidity = 5 * time.Second + logMemStats = true +) + +func benchWrapper(helper func(*testing.B, NonceService), b *testing.B, s NonceService) { + err := s.Initialize() + require.NoError(b, err) + + var m1, m2 runtime.MemStats + runtime.GC() + runtime.ReadMemStats(&m1) + helper(b, s) + runtime.ReadMemStats(&m2) + + if logMemStats { + b.Logf("in-use memory size: %v", m2.Alloc-m1.Alloc) + b.Logf("total alloc size: %v", m2.TotalAlloc-m1.TotalAlloc) + b.Logf("in-use memory count: %v", (m2.Mallocs-m2.Frees)-(m1.Mallocs-m1.Frees)) + b.Logf("total alloc count: %v", m2.Mallocs-m1.Mallocs) + } + b.Logf("Tidy output: %v", s.Tidy()) +} + +func BenchmarkEncryptedNonceServiceGet(b *testing.B) { + s := 
newEncryptedNonceService(benchValidity) + benchWrapper(benchGet, b, s) +} + +func BenchmarkSyncMapNonceServiceGet(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGet, b, s) +} + +func benchGet(b *testing.B, s NonceService) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + } +} + +func BenchmarkEncryptedNonceServiceGetRedeem(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetRedeem, b, s) +} + +func BenchmarkSyncMapNonceServiceGetRedeem(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetRedeem, b, s) +} + +func benchGetRedeem(b *testing.B, s NonceService) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + ok := s.Redeem(token) + require.True(b, ok) + ok = s.Redeem(token) + require.False(b, ok) + } +} + +func BenchmarkEncryptedNonceServiceGetRedeemTidy(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetRedeemTidy, b, s) +} + +func BenchmarkSyncMapNonceServiceGetRedeemTidy(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetRedeemTidy, b, s) +} + +func benchGetRedeemTidy(b *testing.B, s NonceService) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + ok := s.Redeem(token) + require.True(b, ok) + s.Tidy() + } +} + +func BenchmarkEncryptedNonceServiceSequentialTidy(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetRedeemSequentialTidy, b, s) +} + +func BenchmarkSyncMapNonceServiceSequentialTidy(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetRedeemSequentialTidy, b, s) +} + +func benchGetRedeemSequentialTidy(b *testing.B, s NonceService) { + b.ResetTimer() + + for i := 0; i < b.N; i++ { + token, 
_, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + ok := s.Redeem(token) + require.True(b, ok) + } + + s.Tidy() +} + +func BenchmarkEncryptedNonceServiceRandomTidy(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetRedeemRandomTidy, b, s) +} + +func BenchmarkSyncMapNonceServiceRandomTidy(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetRedeemRandomTidy, b, s) +} + +func benchGetRedeemRandomTidy(b *testing.B, s NonceService) { + b.ResetTimer() + + for i := 0; i < b.N; i++ { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + if (i % 3) == 1 { + ok := s.Redeem(token) + require.True(b, ok) + } + } + + s.Tidy() +} + +func BenchmarkEncryptedNonceServiceParallelGet(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetParallelGet, b, s) +} + +func BenchmarkSyncMapNonceServiceParallelGet(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetParallelGet, b, s) +} + +func benchGetParallelGet(b *testing.B, s NonceService) { + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + } + }) +} + +func BenchmarkEncryptedNonceServiceParallelGetRedeem(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetRedeemParallelGetRedeem, b, s) +} + +func BenchmarkSyncMapNonceServiceParallelGetRedeem(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetRedeemParallelGetRedeem, b, s) +} + +func benchGetRedeemParallelGetRedeem(b *testing.B, s NonceService) { + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + ok := s.Redeem(token) + require.True(b, ok) + } + }) +} + +func BenchmarkEncryptedNonceServiceParallelGetRedeemTidy(b *testing.B) { + 
s := newEncryptedNonceService(benchValidity) + benchWrapper(benchGetRedeemParallelGetRedeemTidy, b, s) +} + +func BenchmarkSyncMapNonceServiceParallelGetRedeemTidy(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchGetRedeemParallelGetRedeemTidy, b, s) +} + +func benchGetRedeemParallelGetRedeemTidy(b *testing.B, s NonceService) { + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + ok := s.Redeem(token) + require.True(b, ok) + s.Tidy() + } + }) +} + +func BenchmarkEncryptedNonceServiceParallelTidy(b *testing.B) { + s := newEncryptedNonceService(benchValidity) + benchWrapper(benchParallelTidy, b, s) +} + +func BenchmarkSyncMapNonceServiceParallelTidy(b *testing.B) { + s := newSyncMapNonceService(benchValidity) + benchWrapper(benchParallelTidy, b, s) +} + +func benchParallelTidy(b *testing.B, s NonceService) { + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + token, _, err := s.Get() + require.NoError(b, err) + require.NotEmpty(b, token) + ok := s.Redeem(token) + require.True(b, ok) + } + }) + + b.StopTimer() + time.Sleep(2*time.Second + benchValidity) + runtime.GC() + b.StartTimer() + s.Tidy() +} diff --git a/sdk/helper/nonce/encrypted_nonce.go b/sdk/helper/nonce/encrypted_nonce.go new file mode 100644 index 0000000000..ffaa199c02 --- /dev/null +++ b/sdk/helper/nonce/encrypted_nonce.go @@ -0,0 +1,443 @@ +// Nonce is a class for generating and validating nonces loosely based off +// the design of Let's Encrypt's Boulder nonce service here: +// +// https://github.com/letsencrypt/boulder/blob/main/nonce/nonce.go +// +// We use an encrypted tuple of (expiry timestamp, counter value), allowing +// us to maintain a map of only unexpired counter values that have been +// redeemed. 
This means that issuing nonces involves updating counter values +// and creating only up to a fixed amount of memory (the size of the validity +// period) in the maxIssued map, whereas the sync.Map potentially grows +// indefinitely when also coupled with the fact that sync.Map never releases +// memory back to the host when its size has shrunk. +// +// Redeeming a nonce thus only stores the used counter value (8 bytes) +// and other checks for delayed or reused nonces remain as fast as parsing +// and decrypting the token value. + +package nonce + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "sort" + "sync" + "sync/atomic" + "time" +) + +const ( + // Internal, versioned sentinel to make sure our base64 data is truly + // a nonce-like value. + nonceSentinel = "vault0" + + // Wire length of the nonce, excluding raw url base64 encoding: + // - 6 byte sentinel (above), + // - 8 byte AES-GCM IV + // - 16 byte encrypted (timestamp, counter) tuple (1 AES block) + // - 16 byte AES-GCM tag. + nonceLength = len(nonceSentinel) + 8 + 16 + 16 + + // Length of the decrypted plaintext underlying the nonce: + // - 8 byte expiry timestamp, unix seconds + // - 8 byte incrementing counter value (uint64) + noncePlaintextLength = 8 + 8 +) + +type ( + ensTimestamp uint64 + ensCounter uint64 +) + +type encryptedNonceService struct { + // How long a nonce is valid for. This directly correlates to memory + // usage (retention of redeemed nonces). + validity time.Duration + + // Underlying cipher for minting tokens. + crypt cipher.AEAD + + // The next counter value to use for issuing, _after_ calling Add(1) + // on it. + nextCounter *atomic.Uint64 + + // The remaining fields are locked by this read-only mutex. During + // issuing a nonce, we update maxIssued; during redeeming we update + // minCounter (an atomic) and redeemedTokens, and during tidy, we + // potentially update all fields. 
+ // + // By storing maxIssued, we can (from our tidy run) update the + // minCounter value when nonces were not redeemed recently, to make + // any later redemptions fast (within a time period). + // + // The outer map in redeemedTokens and maxIssued map are of fixed size, + // around the size of validity (in seconds). However, the internal + // redeemedTokens[timestamp] maps may grow unbounded (assuming a + // sufficiently fast system that can mint tokens infinitely fast). + // However, once this timestamp expires, we can fully delete all + // references to that map, and thus free up a potentially significant + // chunk of memory. + issueLock *sync.Mutex + maxIssued map[ensTimestamp]ensCounter + minCounter *atomic.Uint64 + redeemedTokens map[ensTimestamp]map[ensCounter]struct{} +} + +func newEncryptedNonceService(validity time.Duration) *encryptedNonceService { + return &encryptedNonceService{ + validity: validity, + + // nextCounter.Add(1) returns the _new_ value; by initializing to + // zero, we guarantee that nextCounter = minCounter + 1 on the first + // read; if it is redeemed right away, we then hold that invariant. + nextCounter: new(atomic.Uint64), + + issueLock: new(sync.Mutex), + maxIssued: make(map[ensTimestamp]ensCounter, validity/time.Second), + minCounter: new(atomic.Uint64), + redeemedTokens: make(map[ensTimestamp]map[ensCounter]struct{}, validity/time.Second), + } +} + +func (ens *encryptedNonceService) Initialize() error { + // On initialization, create a new AES key. This avoids having issues + // with the number of encryptions we can do under this service. + // + // Note that the nonce service will panic if this is not created. 
+ key := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return fmt.Errorf("failed to initialize AES key: %w", err) + } + + block, err := aes.NewCipher(key) + if err != nil { + return fmt.Errorf("failed to initialize AES cipher: %w", err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("failed to initialize AES-GCM: %w", err) + } + + ens.crypt = aead + return nil +} + +// This nonce service is strict (prohibits reuse of nonces even within +// the validity period) but is not cross-node: there would need to be +// an external communication mechanism to map nonce->node and only +// check for redemption there. Additionally, each initialization +// creates new key material and thus nonces from other nodes would +// not validate. + +func (ens *encryptedNonceService) IsStrict() bool { return true } +func (ens *encryptedNonceService) IsCrossNode() bool { return false } + +func (ens *encryptedNonceService) encryptNonce(counter uint64, expiry time.Time) (token string, err error) { + // counter is an 8-byte value and expiry (as a unix timestamp) is + // likewise, so we have exactly one block of data. + // + // We encode in argument order, i.e., (counter, expiry), and in + // big endian format. + + // Like Let's Encrypt, we use a 12-byte nonce with the leading four + // bytes as zero and the remaining 8 bytes random. This gives us a + // 2^((8*8)/2) = 2^32 birthday paradox limit to reuse a nonce. However, + // note that nonce reuse (in AES-GCM) doesn't leak the key, only the + // XOR of the plaintext. Here, as long as they can't forge valid nonces, + // we're fine. 
+ nonce := make([]byte, 12) + for i := 0; i < 4; i++ { + nonce[i] = 0 + } + if _, err := io.ReadFull(rand.Reader, nonce[4:]); err != nil { + return "", fmt.Errorf("failed to read AEAD nonce: %w", err) + } + + plaintext := make([]byte, noncePlaintextLength) + binary.BigEndian.PutUint64(plaintext[0:], counter) + binary.BigEndian.PutUint64(plaintext[8:], uint64(expiry.Unix())) + ciphertext := ens.crypt.Seal(nil, nonce, plaintext, nil) + + // Now, generate the wire format of the nonce. Use a prefix, the nonce, + // and then the ciphertext. + var wire []byte + wire = append(wire, []byte(nonceSentinel)...) + wire = append(wire, nonce[4:]...) + wire = append(wire, ciphertext...) + + if len(wire) != nonceLength { + return "", fmt.Errorf("expected nonce length of %v got %v", nonceLength, len(wire)) + } + + return base64.RawURLEncoding.EncodeToString(wire), nil +} + +func (ens *encryptedNonceService) recordCounterForTime(counter uint64, expiry time.Time) { + timestamp := ensTimestamp(expiry.Unix()) + value := ensCounter(counter) + + ens.issueLock.Lock() + defer ens.issueLock.Unlock() + + // This allows us to update minCounter when a given timestamp expires, if + // we haven't seen all of that timestamp's nonces redeemed. Otherwise, we + // could potentially be stuck at a lower counter value, making it harder + // for us to check if nonces are redeemed quickly. 
+ + lastValue, ok := ens.maxIssued[timestamp] + if !ok || lastValue < value { + ens.maxIssued[timestamp] = value + } +} + +func (ens *encryptedNonceService) Get() (token string, expiry time.Time, err error) { + counter := ens.nextCounter.Add(1) + now := time.Now() + then := now.Add(ens.validity) + + token, err = ens.encryptNonce(counter, then) + if err != nil { + return "", now, err + } + + ens.recordCounterForTime(counter, then) + return token, then, nil +} + +func (ens *encryptedNonceService) decryptNonce(token string) (counter uint64, expiry time.Time, ok bool) { + zero := time.Time{} + + wire, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return 0, zero, false + } + + if len(wire) != nonceLength { + return 0, zero, false + } + + data := wire + + sentinel := data[0:len(nonceSentinel)] + data = data[len(nonceSentinel):] + if subtle.ConstantTimeCompare([]byte(nonceSentinel), sentinel) != 1 { + return 0, zero, false + } + + nonce := make([]byte, 12) + for i := 0; i < 4; i++ { + nonce[i] = 0 + } + copy(nonce[4:12], data[0:8]) + data = data[8:] + + ciphertext := data[:] + + plaintext, err := ens.crypt.Open(nil, nonce, ciphertext, nil) + if err != nil { + return 0, zero, false + } + + if len(plaintext) != noncePlaintextLength { + return 0, zero, false + } + + counter = binary.BigEndian.Uint64(plaintext[0:8]) + unix := binary.BigEndian.Uint64(plaintext[8:]) + expiry = time.Unix(int64(unix), 0) + + return counter, expiry, true +} + +func (ens *encryptedNonceService) Redeem(token string) bool { + now := time.Now() + counter, expiry, ok := ens.decryptNonce(token) + if !ok { + return false + } + + if expiry.Before(now) { + return false + } + + if counter <= ens.minCounter.Load() { + return false + } + + timestamp := ensTimestamp(expiry.Unix()) + counterValue := ensCounter(counter) + + // From here on out, we're doing the expensive checks. This _looks_ + // like a valid token, but now we want to verify the used-exactly-once + // nature. 
+ ens.issueLock.Lock() + defer ens.issueLock.Unlock() + + minCounter := ens.minCounter.Load() + if counter <= minCounter { + // Someone else redeemed this token or time has rolled over before we + // grabbed this lock. Reject this token. + return false + } + + // Check if this has already been redeemed. + timestampMap, present := ens.redeemedTokens[timestamp] + if !present { + // No tokens have been redeemed for this timestamp. Provision the + // timestamp-specific map, but wait to see if we need to add into + // it. + timestampMap = make(map[ensCounter]struct{}) + ens.redeemedTokens[timestamp] = timestampMap + } + + _, present = timestampMap[counterValue] + if present { + // Token was already redeemed. Reject this request. + return false + } + + // From here on out, the token is valid. Let's start by seeing if we can + // free any memory usage. + minCounter = ens.tidyMemoryHoldingLock(now, minCounter) + + // Before we add to the map, we should see if we can save memory by just + // incrementing the minimum accepted by one, instead of adding to the + // timestamp for out of order redemption. + if minCounter+1 == counter { + minCounter = counter + } else { + // Otherwise, we've got to flag this counter as redeemed. + timestampMap[counterValue] = struct{}{} + } + + // Finally, update our value of minCounter because we held the lock. + ens.minCounter.Store(minCounter) + + return true +} + +func (ens *encryptedNonceService) tidyMemoryHoldingLock(now time.Time, minCounter uint64) uint64 { + // Quick and dirty tidy: any expired timestamps should be deleted, which + // should free the most memory (relatively speaking, given a uniform + // usage pattern). This also avoids an expensive iteration over all + // redeemed counter values. + // + // First tidy the redeemed tokens, as that is the largest value. 
+ var deleteCandidates []ensTimestamp + for candidate := range ens.redeemedTokens { + if candidate < ensTimestamp(now.Unix()) { + deleteCandidates = append(deleteCandidates, candidate) + } + } + for _, candidate := range deleteCandidates { + delete(ens.redeemedTokens, candidate) + } + + // Then tidy the last used timestamp values. Here, any removed timestamps + // have an expiry time before now, which means they cannot be used. This + // means minCounter can be raised to the largest counter issued for them. + deleteCandidates = nil + for candidate, lastIssuedInTimestamp := range ens.maxIssued { + if candidate < ensTimestamp(now.Unix()) { + deleteCandidates = append(deleteCandidates, candidate) + if lastIssuedInTimestamp > ensCounter(minCounter) { + minCounter = uint64(lastIssuedInTimestamp) + } + } + } + for _, candidate := range deleteCandidates { + delete(ens.maxIssued, candidate) + } + return minCounter +} + +func (ens *encryptedNonceService) tidySequentialNonces(now time.Time, minCounter uint64) uint64 { + // This potentially slow sequential tidy allows us to free up an + // incremental amount of memory when out-of-order (common) redemption + // occurs. The underlying map may not shrink, but it should have + // additional capacity to handle additional redemptions of cohort + // nonces without additional allocations _if_ this tidy works. + // + // This is made possible by updating the minCounter based on the + // earlier maxIssued map and tries to maintain the fast-case invariant + // described in newEncryptedNonceService(...). 
+ var timestamps []ensTimestamp + for timestamp := range ens.redeemedTokens { + timestamps = append(timestamps, timestamp) + } + + sort.Slice(timestamps, func(i, j int) bool { return timestamps[i] < timestamps[j] }) + var deleteCandidates []ensTimestamp + for _, timestamp := range timestamps { + counters := ens.redeemedTokens[timestamp] + for len(counters) > 0 { + _, present := counters[ensCounter(minCounter+1)] + if !present { + return minCounter + } + + minCounter += 1 + delete(counters, ensCounter(minCounter)) + } + + if len(counters) == 0 { + deleteCandidates = append(deleteCandidates, timestamp) + } + } + for _, candidate := range deleteCandidates { + delete(ens.redeemedTokens, candidate) + } + + return minCounter +} + +func (ens *encryptedNonceService) getMessage(lock time.Duration, memory time.Duration, sequential time.Duration) string { + now := time.Now() + var message string + message += fmt.Sprintf("len(ens.maxIssued): %v\n", len(ens.maxIssued)) + message += fmt.Sprintf("len(ens.redeemedTokens): %v\n", len(ens.redeemedTokens)) + + var total int + for timestamp, counters := range ens.redeemedTokens { + message += fmt.Sprintf(" ens.redeemedTokens[%v]: %v\n", timestamp, len(counters)) + total += len(counters) + } + build := time.Now() + + message += fmt.Sprintf("total redeemed tokens: %v\n", total) + message += fmt.Sprintf("time to grab lock: %v\n", lock) + message += fmt.Sprintf("time to tidy memory: %v\n", memory) + message += fmt.Sprintf("time to tidy sequential: %v\n", sequential) + message += fmt.Sprintf("time to build message: %v\n", build.Sub(now)) + return message +} + +func (ens *encryptedNonceService) Tidy() *NonceStatus { + lockStart := time.Now() + ens.issueLock.Lock() + defer ens.issueLock.Unlock() + lockEnd := time.Now() + + minCounter := ens.minCounter.Load() + + now := time.Now() + minCounter = ens.tidyMemoryHoldingLock(now, minCounter) + memory := time.Now() + minCounter = ens.tidySequentialNonces(now, minCounter) + sequential := time.Now() 
+ ens.minCounter.Store(minCounter) + + issued := ens.nextCounter.Load() + return &NonceStatus{ + Issued: issued, + Outstanding: issued - minCounter, + Message: ens.getMessage(lockEnd.Sub(lockStart), memory.Sub(now), sequential.Sub(memory)), + } +} diff --git a/sdk/helper/nonce/nonce.go b/sdk/helper/nonce/nonce.go new file mode 100644 index 0000000000..2fa4e661fc --- /dev/null +++ b/sdk/helper/nonce/nonce.go @@ -0,0 +1,70 @@ +// Nonce is a class for generating and validating nonces loosely based off +// the design of Let's Encrypt's Boulder nonce service here: +// +// https://github.com/letsencrypt/boulder/blob/main/nonce/nonce.go + +package nonce + +import ( + "time" +) + +// NonceService is an interface for issuing and redeeming nonces, with +// a hook to periodically free resources when no redemptions have happened +// recently. +// +// Notably, nonces are not guaranteed to be stored or persisted; nonces +// from one startup will not necessarily be valid from another. +type NonceService interface { + // Before using a nonce service, it must be initialized. Failure to + // initialize might result in panics or other unexpected results. + Initialize() error + + // Get a nonce; returns three values: + // + // 1. The nonce itself, a base64-url-no-padding opaque value. + // 2. A time at which the nonce will expire, based on the validity + // period specified at construction. By default, the service issues + // short-lived nonces. + // 3. An error if one occurred during generation of the nonce. + Get() (string, time.Time, error) + + // Redeem the given nonce, returning whether or not it was accepted. A + // nonce given twice will be rejected if the service is a strict nonce + // service, but potentially accepted if the nonce service is loose + // (i.e., temporal revocation only). + Redeem(string) bool + + // A hook to tidy the memory usage of the underlying implementation; is + // implementation dependent. Some implementations may not return status + // information. 
+ Tidy() *NonceStatus + + // If true, this is a strict only-once redemption service implementation, + // else a nonce could be accepted more than once within some safety + // window. + IsStrict() bool + + // Whether or not this service is usable across nodes. + IsCrossNode() bool +} + +func NewNonceService() NonceService { + // By default, we create an encrypted nonce service that is strict but not + // cross node, using a default window of 90 seconds (equal to the default + // context request timeout window). + return NewNonceServiceWithValidity(90 * time.Second) +} + +func NewNonceServiceWithValidity(validity time.Duration) NonceService { + return newEncryptedNonceService(validity) +} + +// Status information about the number of nonces in this service, perhaps +// local to this node. Presumably, the delta roughly correlates to present +// memory usage. +type NonceStatus struct { + Issued uint64 + Outstanding uint64 + Message string +} diff --git a/sdk/helper/nonce/nonce_test.go b/sdk/helper/nonce/nonce_test.go new file mode 100644 index 0000000000..a60d78e182 --- /dev/null +++ b/sdk/helper/nonce/nonce_test.go @@ -0,0 +1,92 @@ +package nonce + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNonceService(t *testing.T) { + t.Parallel() + + s := NewNonceService() + err := s.Initialize() + require.NoError(t, err) + + // Double redemption should fail. + nonce, _, err := s.Get() + require.NoError(t, err) + require.NotEmpty(t, nonce) + + require.True(t, s.Redeem(nonce)) + require.False(t, s.Redeem(nonce)) + + // Redeeming in opposite order should work. 
+ var nonces []string + numNonces := 100 + for i := 0; i < numNonces; i++ { + nonce, _, err = s.Get() + require.NoError(t, err) + require.NotEmpty(t, nonce) + + nonces = append(nonces, nonce) + } + + for i := len(nonces) - 1; i >= 0; i-- { + nonce = nonces[i] + require.True(t, s.Redeem(nonce)) + } + + for i := 0; i < len(nonces); i++ { + nonce = nonces[i] + require.False(t, s.Redeem(nonce)) + } + + status := s.Tidy() + require.NotNil(t, status) + require.Equal(t, uint64(1+numNonces), status.Issued) + require.Equal(t, uint64(0), status.Outstanding) +} + +func TestNonceExpiry(t *testing.T) { + t.Parallel() + + s := NewNonceServiceWithValidity(2 * time.Second) + err := s.Initialize() + require.NoError(t, err) + + // Issue and redeem should succeed. + nonce, _, err := s.Get() + original := nonce + require.NoError(t, err) + require.NotEmpty(t, nonce) + require.True(t, s.Redeem(nonce)) + + // Issue and wait should fail to redeem. + nonce, _, err = s.Get() + require.NoError(t, err) + require.NotEmpty(t, nonce) + time.Sleep(3 * time.Second) + require.False(t, s.Redeem(nonce)) + + // Issue and wait+tidy should fail to redeem. + nonce, _, err = s.Get() + require.NoError(t, err) + require.NotEmpty(t, nonce) + time.Sleep(3 * time.Second) + s.Tidy() + require.False(t, s.Redeem(nonce)) + require.False(t, s.Redeem(nonce)) + + nonce, _, err = s.Get() + require.NoError(t, err) + require.NotEmpty(t, nonce) + s.Tidy() + time.Sleep(3 * time.Second) + require.False(t, s.Redeem(nonce)) + require.False(t, s.Redeem(nonce)) + + // Original nonce should fail on second use. 
+ require.False(t, s.Redeem(original)) +} diff --git a/sdk/helper/nonce/sync_map_nonce.go b/sdk/helper/nonce/sync_map_nonce.go new file mode 100644 index 0000000000..d1da8000d6 --- /dev/null +++ b/sdk/helper/nonce/sync_map_nonce.go @@ -0,0 +1,107 @@ +package nonce + +import ( + "crypto/rand" + "encoding/base64" + "io" + "sync" + "sync/atomic" + "time" +) + +type syncMapNonceService struct { + validity time.Duration + issued *atomic.Uint64 + nextExpiry *atomic.Int64 + nonces *sync.Map // map[string]time.Time +} + +var _ NonceService = &syncMapNonceService{} + +func newSyncMapNonceService(validity time.Duration) *syncMapNonceService { + return &syncMapNonceService{ + validity: validity, + issued: new(atomic.Uint64), + nextExpiry: new(atomic.Int64), + nonces: new(sync.Map), + } +} + +func (a *syncMapNonceService) Initialize() error { return nil } +func (a *syncMapNonceService) IsStrict() bool { return true } +func (a *syncMapNonceService) IsCrossNode() bool { return false } + +func generateNonce() (string, error) { + return generateRandomBase64(21) +} + +func generateRandomBase64(srcBytes int) (string, error) { + data := make([]byte, srcBytes) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(data), nil +} + +func (a *syncMapNonceService) Get() (string, time.Time, error) { + now := time.Now() + nonce, err := generateNonce() + if err != nil { + return "", now, err + } + + then := now.Add(a.validity) + a.nonces.Store(nonce, then) + + nextExpiry := a.nextExpiry.Load() + next := time.Unix(nextExpiry, 0) + if then.Before(next) { + a.nextExpiry.Store(then.Unix()) + } + + a.issued.Add(1) + + return nonce, then, nil +} + +func (a *syncMapNonceService) Redeem(nonce string) bool { + rawTimeout, present := a.nonces.LoadAndDelete(nonce) + if !present { + return false + } + + timeout := rawTimeout.(time.Time) + if time.Now().After(timeout) { + return false + } + + return true +} + +func (a 
*syncMapNonceService) Tidy() *NonceStatus { + now := time.Now() + nextRun := now.Add(a.validity) + var outstanding uint64 + a.nonces.Range(func(key, value any) bool { + timeout := value.(time.Time) + if now.After(timeout) { + a.nonces.Delete(key) + } else { + outstanding += 1 + } + + if timeout.Before(nextRun) { + nextRun = timeout + } + + return true /* keep looping; sync.Map.Range stops once the callback returns false */ + }) + + a.nextExpiry.Store(nextRun.Unix()) + + return &NonceStatus{ + Issued: a.issued.Load(), + Outstanding: outstanding, + } +}