vault/physical/spanner/spanner_ha.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

411 lines
10 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package spanner
import (
"context"
"fmt"
"sync"
"time"
"cloud.google.com/go/spanner"
metrics "github.com/armon/go-metrics"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/physical"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
)
// Compile-time assertions: the build fails if *Backend stops satisfying
// physical.HABackend or *Lock stops satisfying physical.Lock.
var (
_ physical.HABackend = (*Backend)(nil)
_ physical.Lock = (*Lock)(nil)
)
// Default timings for a Lock. LockWith copies each of these into the
// corresponding unexported field of every Lock it creates.
const (
// LockRenewInterval is the time to wait between lock renewals. It is the
// default ticker period of renewLock.
LockRenewInterval = 5 * time.Second
// LockRetryInterval is the amount of time to wait if the lock fails before
// trying again. It is the default ticker period of attemptLock, so the
// first acquisition attempt also happens only after one interval.
LockRetryInterval = 5 * time.Second
// LockTTL is the default lock TTL. writeLock may take over a record owned
// by another identity once that record's Timestamp is at least this old.
LockTTL = 15 * time.Second
// LockWatchRetryInterval is the default ticker period of watchLock, i.e.
// the time between consecutive checks of the lock record (whether or not
// the previous check failed).
LockWatchRetryInterval = 5 * time.Second
// LockWatchRetryMax bounds the number of failed reads watchLock tolerates
// before signaling that leadership is lost. watchLock gives up once its
// failure count reaches LockWatchRetryMax-1.
LockWatchRetryMax = 5
)
// Metric keys passed to metrics.MeasureSince by the Lock methods.
var (
// metricLockUnlock is the metric key under which Unlock records its
// duration (the lock delete).
metricLockUnlock = []string{"spanner", "lock", "unlock"}
// metricLockLock is the metric key under which Lock records its duration
// (the lock acquisition).
metricLockLock = []string{"spanner", "lock", "lock"}
// metricLockValue is the metric key under which Value records its duration
// (reading the current lock record).
metricLockValue = []string{"spanner", "lock", "value"}
)
// Lock is the HA lock. It is created by Backend.LockWith and is backed by a
// single row, keyed by key, in the backend's HA table.
type Lock struct {
// backend is the underlying physical backend.
backend *Backend
// key is the name of the key. value is the value of the key. Both are
// written into the lock record by writeLock.
key, value string
// held is a boolean indicating if the lock is currently held. It is set in
// Lock after a successful acquisition and cleared in Unlock after the
// delete transaction succeeds; both do so while holding lock.
held bool
// identity is the internal identity of this key (unique to this server
// instance). LockWith fills it with a freshly generated UUID.
identity string
// lock serializes Lock and Unlock and guards held.
lock sync.Mutex
// stopCh is the channel that stops all operations. It may be closed in the
// event of a leader loss or graceful shutdown. stopped is a boolean
// indicating if we are stopped - it exists to prevent double closing the
// channel. stopLock is the mutex held while stopCh is replaced or closed
// and while stopped is read or written.
stopCh chan struct{}
stopped bool
stopLock sync.Mutex
// Allow modifying the Lock durations for ease of unit testing. LockWith
// initializes these from the exported Lock* constants.
renewInterval time.Duration
retryInterval time.Duration
ttl time.Duration
watchRetryInterval time.Duration
watchRetryMax int
}
// LockRecord is the struct that corresponds to a lock. It is what writeLock
// stores in the HA table and what get and Unlock decode rows into.
// NOTE(review): the field names presumably have to match the table's column
// names ("Key", "Value", "Identity", "Timestamp") - confirm against the schema.
type LockRecord struct {
// Key is the name of the lock; it is also the key used to look the row up.
Key string
// Value is the caller-supplied value handed to LockWith.
Value string
// Identity is the identity of the Lock instance that wrote the record.
Identity string
// Timestamp is the UTC time of the last write; writeLock compares its age
// against the TTL.
Timestamp time.Time
}
// HAEnabled implements physical.HABackend and indicates whether this backend
// supports high availability. It simply reports the backend's haEnabled field.
func (b *Backend) HAEnabled() bool {
return b.haEnabled
}
// LockWith builds a Lock for the given key that, once acquired, stores value
// in the lock record. Nothing is written to Spanner here; acquisition happens
// in Lock.
//
// Every returned Lock gets its own freshly generated UUID as identity and the
// package default timings. An error is returned only if the UUID cannot be
// generated.
func (b *Backend) LockWith(key, value string) (physical.Lock, error) {
	id, err := uuid.GenerateUUID()
	if err != nil {
		return nil, fmt.Errorf("lock with: %w", err)
	}

	lock := &Lock{
		backend:  b,
		key:      key,
		value:    value,
		identity: id,

		// A new lock has no stop channel yet, so it starts out "stopped".
		stopped: true,
	}

	// Default timings; these are fields so that they can be overridden.
	lock.renewInterval = LockRenewInterval
	lock.retryInterval = LockRetryInterval
	lock.ttl = LockTTL
	lock.watchRetryInterval = LockWatchRetryInterval
	lock.watchRetryMax = LockWatchRetryMax

	return lock, nil
}
// Lock blocks until the lock is acquired, an error occurs, or stopCh is
// closed.
//
// stopCh is optional; closing it aborts the acquisition attempt, in which case
// Lock returns (nil, nil). On success the returned channel is closed when the
// lock is released through Unlock or when watchLock gives up on it, i.e. when
// leadership is lost. Calling Lock while the lock is already held is an error.
func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	defer metrics.MeasureSince(metricLockLock, time.Now())

	l.lock.Lock()
	defer l.lock.Unlock()

	if l.held {
		return nil, errors.New("lock already held")
	}

	// attemptLock only returns once the lock is ours, the attempt failed, or
	// stopCh was closed.
	ok, err := l.attemptLock(stopCh)
	switch {
	case err != nil:
		return nil, fmt.Errorf("lock: %w", err)
	case !ok:
		return nil, nil
	}

	// The lock is ours from here on.
	l.held = true

	// Install a fresh stop channel for this tenure.
	leaderCh := make(chan struct{})
	l.stopLock.Lock()
	l.stopCh, l.stopped = leaderCh, false
	l.stopLock.Unlock()

	// Keep the record fresh and keep checking that it is still ours.
	go l.renewLock()
	go l.watchLock()

	return leaderCh, nil
}
// Unlock releases the lock.
//
// It returns nil without doing anything when the lock is not held. Otherwise
// it closes stopCh (unless it is already closed) so that the renew and watch
// goroutines stop, and then removes the lock row inside a read-write
// transaction - but only if the row still carries this Lock's identity. A row
// that no longer exists counts as already released.
//
// If the transaction fails, the error is returned wrapped with "unlock: " and
// held stays true, so Unlock can be called again.
func (l *Lock) Unlock() error {
	defer metrics.MeasureSince(metricLockUnlock, time.Now())

	l.lock.Lock()
	defer l.lock.Unlock()
	if !l.held {
		return nil
	}

	// Stop any existing locking or renewal attempts
	l.stopLock.Lock()
	if !l.stopped {
		l.stopped = true
		close(l.stopCh)
	}
	l.stopLock.Unlock()

	// Delete
	ctx := context.Background()
	if _, err := l.backend.haClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
		row, err := txn.ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Identity"})
		if err != nil {
			// No row means there is nothing left to delete; that is a
			// successful release. Every other read error must surface so the
			// caller knows the lock row may still exist.
			if spanner.ErrCode(err) == codes.NotFound {
				return nil
			}
			return err
		}

		var r LockRecord
		if derr := row.ToStruct(&r); derr != nil {
			return fmt.Errorf("failed to decode to struct: %w", derr)
		}

		// If the identity is different, then between the time we stopped
		// renewing and now the TTL expired and someone else grabbed the lock.
		// We do not want to delete a lock that is not our own.
		if r.Identity != l.identity {
			return nil
		}

		return txn.BufferWrite([]*spanner.Mutation{
			spanner.Delete(l.backend.haTable, spanner.Key{l.key}),
		})
	}); err != nil {
		return fmt.Errorf("unlock: %w", err)
	}

	// We are no longer holding the lock
	l.held = false

	return nil
}
// Value reads the current lock record and reports whether one exists, together
// with the value stored in it.
//
// It returns (false, "", nil) when there is no record and (false, "", err)
// when the read fails. Note that "held" here means held by anyone: the
// record's identity is not compared against this Lock's identity.
func (l *Lock) Value() (bool, string, error) {
	defer metrics.MeasureSince(metricLockValue, time.Now())

	record, err := l.get(context.Background())
	switch {
	case err != nil:
		return false, "", err
	case record == nil:
		return false, "", nil
	default:
		return true, record.Value, nil
	}
}
// attemptLock tries to write the lock record once per retryInterval tick until
// it succeeds, fails with an error, or stopCh is closed.
//
// It returns (true, nil) once the record has been written, (false, nil) when
// stopCh is closed first, and (false, err) as soon as a write attempt errors.
// The first attempt is made on the first tick, not immediately.
func (l *Lock) attemptLock(stopCh <-chan struct{}) (bool, error) {
	retry := time.NewTicker(l.retryInterval)
	defer retry.Stop()

	for {
		// Wait for the next tick unless the caller gives up first.
		select {
		case <-stopCh:
			return false, nil
		case <-retry.C:
		}

		ok, err := l.writeLock()
		if err != nil {
			return false, fmt.Errorf("attempt lock: %w", err)
		}
		if ok {
			return true, nil
		}
		// Someone else still holds a live lock; try again on the next tick.
	}
}
// renewLock rewrites the lock record on every renewInterval tick until stopCh
// is closed. It is started as a goroutine by Lock.
func (l *Lock) renewLock() {
	renew := time.NewTicker(l.renewInterval)
	defer renew.Stop()

	for {
		select {
		case <-l.stopCh:
			return
		case <-renew.C:
			// Renewal is best effort: the outcome is deliberately discarded.
			// The watchLock goroutine started alongside this one is what
			// closes stopCh once the record is gone or owned by another
			// identity.
			_, _ = l.writeLock()
		}
	}
}
// watchLock checks whether the lock has changed in the table and closes the
// leader channel accordingly. It is started as a goroutine by Lock.
//
// On every watchRetryInterval tick the lock record is read. The loop ends when
//
//   - stopCh has been closed (by Unlock),
//   - the record is missing or carries another identity, or
//   - the number of failed reads reaches watchRetryMax-1. The failure count
//     is cumulative; it is not reset by a successful read.
//
// After the loop, stopCh is closed unless it already has been.
func (l *Lock) watchLock() {
	retries := 0
	ticker := time.NewTicker(l.watchRetryInterval)
	// Release the ticker when the watch ends; one watchLock runs per
	// acquisition, so an unstopped ticker would pile up over time.
	defer ticker.Stop()

OUTER:
	for {
		// Check if the channel is already closed
		select {
		case <-l.stopCh:
			break OUTER
		default:
		}

		// Check if we've exceeded retries
		if retries >= l.watchRetryMax-1 {
			break OUTER
		}

		// Wait for the timer
		select {
		case <-ticker.C:
		case <-l.stopCh:
			break OUTER
		}

		// Attempt to read the key
		r, err := l.get(context.Background())
		if err != nil {
			retries++
			continue
		}

		// Verify the identity is the same
		if r == nil || r.Identity != l.identity {
			break OUTER
		}
	}

	// Signal loss of leadership, guarding against a double close.
	l.stopLock.Lock()
	defer l.stopLock.Unlock()
	if !l.stopped {
		l.stopped = true
		close(l.stopCh)
	}
}
// writeLock writes the given lock using the following algorithm:
//
//   - lock does not exist
//   - write the lock
//
//   - lock exists
//   - if key is empty or identity is the same or timestamp exceeds TTL
//   - update the lock to self
//
// It returns true when the record was written by the committed transaction,
// false when another identity holds a record that has not yet expired, and an
// error (wrapped with "write lock: ") when the transaction fails.
func (l *Lock) writeLock() (bool, error) {
	// Keep track of whether the lock was written
	lockWritten := false

	// Create a transaction to read and then (maybe) update the record.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The transaction will be retried, and it could sit in a queue behind, say,
	// the delete operation. To stop the transaction, we cancel the context when
	// stopCh is closed. Before the first acquisition stopCh is still nil, in
	// which case this goroutine simply ends via ctx.Done() when writeLock
	// returns.
	go func() {
		select {
		case <-l.stopCh:
			cancel()
		case <-ctx.Done():
		}
	}()

	_, err := l.backend.haClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
		// The client may run this function more than once. Start every attempt
		// from a clean slate so that an aborted attempt that buffered a write
		// cannot make a later attempt, which decides not to write, report the
		// lock as acquired.
		lockWritten = false

		row, err := txn.ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Key", "Identity", "Timestamp"})
		if err != nil && spanner.ErrCode(err) != codes.NotFound {
			return err
		}

		// If there was a record, verify that the record is still trustable.
		if row != nil {
			var r LockRecord
			if derr := row.ToStruct(&r); derr != nil {
				return fmt.Errorf("failed to decode to struct: %w", derr)
			}

			// If the key is empty or the identity is ours or the ttl expired, we can
			// write. Otherwise, return now because we cannot.
			if r.Key != "" && r.Identity != l.identity && time.Now().UTC().Sub(r.Timestamp) < l.ttl {
				return nil
			}
		}

		m, err := spanner.InsertOrUpdateStruct(l.backend.haTable, &LockRecord{
			Key:       l.key,
			Value:     l.value,
			Identity:  l.identity,
			Timestamp: time.Now().UTC(),
		})
		if err != nil {
			return fmt.Errorf("failed to generate struct: %w", err)
		}
		if err := txn.BufferWrite([]*spanner.Mutation{m}); err != nil {
			return fmt.Errorf("failed to write: %w", err)
		}

		// Mark that the lock was acquired
		lockWritten = true

		return nil
	})
	if err != nil {
		return false, fmt.Errorf("write lock: %w", err)
	}

	return lockWritten, nil
}
// get reads the lock record for l.key with a single-use read.
//
// It returns (nil, nil) when no record exists, the decoded record on success,
// and a wrapped error when the read or the decoding fails.
func (l *Lock) get(ctx context.Context) (*LockRecord, error) {
	columns := []string{"Key", "Value", "Timestamp", "Identity"}

	row, err := l.backend.haClient.Single().ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, columns)
	if err != nil {
		// A missing row is not an error: the lock simply is not held.
		if spanner.ErrCode(err) == codes.NotFound {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to read value for %q: %w", l.key, err)
	}

	record := new(LockRecord)
	if derr := row.ToStruct(record); derr != nil {
		return nil, fmt.Errorf("failed to decode lock: %w", derr)
	}
	return record, nil
}