mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-10 16:47:01 +02:00
247 lines
6.3 KiB
Go
247 lines
6.3 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package aws
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go/service/iam/iamiface"
|
|
"github.com/aws/aws-sdk-go/service/sts/stsiface"
|
|
"github.com/hashicorp/vault/sdk/framework"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/hashicorp/vault/sdk/queue"
|
|
)
|
|
|
|
// Package-level constants shared by the AWS secrets backend.
const (
	// rootConfigPath is the storage key of the root configuration. It is
	// listed under SealWrapStorage, and a change to it invalidates the cached
	// IAM/STS clients.
	rootConfigPath = "config/root"

	// minAwsUserRollbackAge is handed to the framework as WALRollbackMinAge.
	minAwsUserRollbackAge = time.Minute * 5

	// Operation prefixes; presumably consumed by the path definitions when
	// naming their operations (not visible in this file, confirm at callers).
	operationPrefixAWS    = "aws"
	operationPrefixAWSASD = "aws-config"
)
|
|
|
|
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
|
b := Backend(conf)
|
|
if err := b.Setup(ctx, conf); err != nil {
|
|
return nil, err
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func Backend(_ *logical.BackendConfig) *backend {
|
|
var b backend
|
|
b.minAllowableRotationPeriod = minAllowableRotationPeriod
|
|
b.credRotationQueue = queue.New()
|
|
b.Backend = &framework.Backend{
|
|
Help: strings.TrimSpace(backendHelp),
|
|
|
|
PathsSpecial: &logical.Paths{
|
|
LocalStorage: []string{
|
|
framework.WALPrefix,
|
|
},
|
|
SealWrapStorage: []string{
|
|
rootConfigPath,
|
|
pathStaticCreds + "/",
|
|
},
|
|
},
|
|
|
|
Paths: []*framework.Path{
|
|
pathConfigRoot(&b),
|
|
pathConfigRotateRoot(&b),
|
|
pathConfigLease(&b),
|
|
pathRoles(&b),
|
|
pathListRoles(&b),
|
|
pathStaticRoles(&b),
|
|
pathStaticCredentials(&b),
|
|
pathUser(&b),
|
|
},
|
|
|
|
Secrets: []*framework.Secret{
|
|
secretAccessKeys(&b),
|
|
},
|
|
|
|
InitializeFunc: b.initialize,
|
|
Invalidate: b.invalidate,
|
|
WALRollback: b.walRollback,
|
|
WALRollbackMinAge: minAwsUserRollbackAge,
|
|
PeriodicFunc: func(ctx context.Context, req *logical.Request) error {
|
|
if b.WriteSafeReplicationState() {
|
|
return b.rotateExpiredStaticCreds(ctx, req)
|
|
}
|
|
return nil
|
|
},
|
|
BackendType: logical.TypeLogical,
|
|
}
|
|
|
|
return &b
|
|
}
|
|
|
|
type backend struct {
|
|
*framework.Backend
|
|
|
|
// Mutex to protect access to reading and writing policies
|
|
roleMutex sync.RWMutex
|
|
|
|
// Mutex to protect access to iam/sts clients and client configs
|
|
clientMutex sync.RWMutex
|
|
|
|
// iamClient and stsClient hold configured iam and sts clients for reuse, and
|
|
// to enable mocking with AWS iface for tests
|
|
iamClient iamiface.IAMAPI
|
|
stsClient stsiface.STSAPI
|
|
|
|
// the age of a static role's credential is tracked by a priority queue and handled
|
|
// by the PeriodicFunc
|
|
credRotationQueue *queue.PriorityQueue
|
|
|
|
minAllowableRotationPeriod time.Duration
|
|
}
|
|
|
|
// backendHelp is the mount-level help text for the AWS backend. Backend trims
// the surrounding whitespace before passing it to the framework.
const backendHelp = `
The AWS backend dynamically generates AWS access keys for a set of
IAM policies. The AWS access keys have a configurable lease set and
are automatically revoked at the end of the lease.

After mounting this backend, credentials to generate IAM keys must
be configured with the "root" path and policies must be written using
the "roles/" endpoints before any access keys can be generated.
`
|
|
|
|
func (b *backend) invalidate(ctx context.Context, key string) {
|
|
switch {
|
|
case key == rootConfigPath:
|
|
b.clearClients()
|
|
}
|
|
}
|
|
|
|
// clearClients clears the backend's IAM and STS clients
|
|
func (b *backend) clearClients() {
|
|
b.clientMutex.Lock()
|
|
defer b.clientMutex.Unlock()
|
|
b.iamClient = nil
|
|
b.stsClient = nil
|
|
}
|
|
|
|
// clientIAM returns the configured IAM client. If nil, it constructs a new one
|
|
// and returns it, setting it the internal variable
|
|
func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) {
|
|
b.clientMutex.RLock()
|
|
if b.iamClient != nil {
|
|
b.clientMutex.RUnlock()
|
|
return b.iamClient, nil
|
|
}
|
|
|
|
// Upgrade the lock for writing
|
|
b.clientMutex.RUnlock()
|
|
b.clientMutex.Lock()
|
|
defer b.clientMutex.Unlock()
|
|
|
|
// check client again, in the event that a client was being created while we
|
|
// waited for Lock()
|
|
if b.iamClient != nil {
|
|
return b.iamClient, nil
|
|
}
|
|
|
|
iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b.iamClient = iamClient
|
|
|
|
return b.iamClient, nil
|
|
}
|
|
|
|
func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.STSAPI, error) {
|
|
b.clientMutex.RLock()
|
|
if b.stsClient != nil {
|
|
b.clientMutex.RUnlock()
|
|
return b.stsClient, nil
|
|
}
|
|
|
|
// Upgrade the lock for writing
|
|
b.clientMutex.RUnlock()
|
|
b.clientMutex.Lock()
|
|
defer b.clientMutex.Unlock()
|
|
|
|
// check client again, in the event that a client was being created while we
|
|
// waited for Lock()
|
|
if b.stsClient != nil {
|
|
return b.stsClient, nil
|
|
}
|
|
|
|
stsClient, err := b.nonCachedClientSTS(ctx, s, b.Logger())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b.stsClient = stsClient
|
|
|
|
return b.stsClient, nil
|
|
}
|
|
|
|
func (b *backend) initialize(ctx context.Context, request *logical.InitializationRequest) error {
|
|
if !b.WriteSafeReplicationState() {
|
|
b.Logger().Info("skipping populating rotation queue")
|
|
return nil
|
|
}
|
|
b.Logger().Info("populating rotation queue")
|
|
|
|
creds, err := request.Storage.List(ctx, pathStaticCreds+"/")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.Logger().Debug(fmt.Sprintf("Adding %d items to the rotation queue", len(creds)))
|
|
for _, roleName := range creds {
|
|
if roleName == "" {
|
|
continue
|
|
}
|
|
credPath := formatCredsStoragePath(roleName)
|
|
credsEntry, err := request.Storage.Get(ctx, credPath)
|
|
if err != nil {
|
|
return fmt.Errorf("could not read credentials: %w", err)
|
|
}
|
|
if credsEntry == nil {
|
|
continue
|
|
}
|
|
credentials := awsCredentials{}
|
|
if err := credsEntry.DecodeJSON(&credentials); err != nil {
|
|
return fmt.Errorf("failed to decode credentials: %w", err)
|
|
}
|
|
|
|
configEntry, err := request.Storage.Get(ctx, formatRoleStoragePath(roleName))
|
|
if err != nil {
|
|
return fmt.Errorf("could not read role: %w", err)
|
|
}
|
|
if configEntry == nil {
|
|
continue
|
|
}
|
|
config := staticRoleEntry{}
|
|
if err := configEntry.DecodeJSON(&config); err != nil {
|
|
return fmt.Errorf("failed to decode role config: %w", err)
|
|
}
|
|
|
|
if credentials.Expiration == nil {
|
|
expiration := time.Now().UTC().Add(config.RotationPeriod)
|
|
credentials.Expiration = &expiration
|
|
_, err := logical.StorageEntryJSON(credPath, creds)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal object to JSON: %w", err)
|
|
}
|
|
b.Logger().Debug("no known expiration time for credentials so resetting the expiration", "role", roleName, "new expiration", expiration)
|
|
}
|
|
|
|
err = b.credRotationQueue.Push(&queue.Item{
|
|
Key: config.Name,
|
|
Value: config,
|
|
Priority: credentials.priority(config),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add creds for role %s to queue: %w", roleName, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|