vault/builtin/logical/aws/client.go
kpcraig d8482b008a
VAULT-32804: Add STS Fallback parameters to secrets-aws engine (#29051)
Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>

---------

Co-authored-by: Robert <17119716+robmonte@users.noreply.github.com>
Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>
2024-12-05 16:22:21 -05:00

232 lines
6.5 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package aws
import (
"context"
"errors"
"fmt"
"os"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) ([]*aws.Config, error) {
// set fallback region (we can overwrite later)
fallbackRegion := os.Getenv("AWS_REGION")
if fallbackRegion == "" {
fallbackRegion = os.Getenv("AWS_DEFAULT_REGION")
}
if fallbackRegion == "" {
fallbackRegion = "us-east-1"
}
maxRetries := aws.UseServiceDefaultRetries
entry, err := s.Get(ctx, "config/root")
if err != nil {
return nil, err
}
var configs []*aws.Config
// ensure the nil case uses defaults
if entry == nil {
ccfg := awsutil.CredentialsConfig{
HTTPClient: cleanhttp.DefaultClient(),
Logger: logger,
Region: fallbackRegion,
}
creds, err := ccfg.GenerateCredentialChain()
if err != nil {
return nil, err
}
configs = append(configs, &aws.Config{
Credentials: creds,
Region: aws.String(fallbackRegion),
Endpoint: aws.String(""),
MaxRetries: aws.Int(maxRetries),
})
return configs, nil
}
var config rootConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, fmt.Errorf("error reading root configuration: %w", err)
}
var endpoints []string
var regions []string
credsConfig := &awsutil.CredentialsConfig{}
credsConfig.AccessKey = config.AccessKey
credsConfig.SecretKey = config.SecretKey
credsConfig.HTTPClient = cleanhttp.DefaultClient()
credsConfig.Logger = logger
maxRetries = config.MaxRetries
if clientType == "iam" && config.IAMEndpoint != "" {
endpoints = append(endpoints, config.IAMEndpoint)
} else if clientType == "sts" && config.STSEndpoint != "" {
endpoints = append(endpoints, config.STSEndpoint)
if config.STSRegion != "" {
regions = append(regions, config.STSRegion)
}
if len(config.STSFallbackEndpoints) > 0 {
endpoints = append(endpoints, config.STSFallbackEndpoints...)
}
if len(config.STSFallbackRegions) > 0 {
regions = append(regions, config.STSFallbackRegions...)
}
}
if config.IdentityTokenAudience != "" {
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
}
fetcher := &PluginIdentityTokenFetcher{
sys: b.System(),
logger: b.Logger(),
ns: ns,
audience: config.IdentityTokenAudience,
ttl: config.IdentityTokenTTL,
}
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
credsConfig.WebIdentityTokenFetcher = fetcher
credsConfig.RoleARN = config.RoleARN
}
if len(regions) == 0 {
regions = append(regions, fallbackRegion)
}
if len(regions) != len(endpoints) {
// this probably can't happen, if the input was checked correctly
return nil, errors.New("number of regions does not match number of endpoints")
}
for i := 0; i < len(endpoints); i++ {
if len(regions) > i {
credsConfig.Region = regions[i]
} else {
credsConfig.Region = fallbackRegion
}
creds, err := credsConfig.GenerateCredentialChain()
if err != nil {
return nil, err
}
configs = append(configs, &aws.Config{
Credentials: creds,
Region: aws.String(credsConfig.Region),
Endpoint: aws.String(endpoints[i]),
MaxRetries: aws.Int(maxRetries),
HTTPClient: cleanhttp.DefaultClient(),
})
}
return configs, nil
}
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
if len(awsConfig) != 1 {
return nil, errors.New("could not obtain aws config")
}
sess, err := session.NewSession(awsConfig[0])
if err != nil {
return nil, err
}
client := iam.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
}
return client, nil
}
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := b.getRootConfigs(ctx, s, "sts", logger)
if err != nil {
return nil, err
}
var client *sts.STS
for _, cfg := range awsConfig {
sess, err := session.NewSession(cfg)
if err != nil {
return nil, err
}
client = sts.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain sts client")
}
// ping the client - we only care about errors
_, err = client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return client, nil
} else {
b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", cfg.Endpoint, "failed region", cfg.Region)
}
}
return nil, fmt.Errorf("could not obtain sts client")
}
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
// When the client's STS credentials expire, it will use this interface to fetch a new
// plugin identity token and exchange it for new STS credentials.
type PluginIdentityTokenFetcher struct {
sys logical.SystemView
logger hclog.Logger
audience string
ns *namespace.Namespace
ttl time.Duration
}
var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)
func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
nsCtx := namespace.ContextWithNamespace(ctx, f.ns)
resp, err := f.sys.GenerateIdentityToken(nsCtx, &pluginutil.IdentityTokenRequest{
Audience: f.audience,
TTL: f.ttl,
})
if err != nil {
return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
}
f.logger.Info("fetched new plugin identity token")
if resp.TTL < f.ttl {
f.logger.Debug("generated plugin identity token has shorter TTL than requested",
"requested", f.ttl, "actual", resp.TTL)
}
return []byte(resp.Token.Token()), nil
}