mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-10 16:47:01 +02:00
Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com> --------- Co-authored-by: Robert <17119716+robmonte@users.noreply.github.com> Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>
232 lines
6.5 KiB
Go
232 lines
6.5 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package aws
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/iam"
|
|
"github.com/aws/aws-sdk-go/service/sts"
|
|
"github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-secure-stdlib/awsutil"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
|
|
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
|
|
func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) ([]*aws.Config, error) {
|
|
// set fallback region (we can overwrite later)
|
|
fallbackRegion := os.Getenv("AWS_REGION")
|
|
if fallbackRegion == "" {
|
|
fallbackRegion = os.Getenv("AWS_DEFAULT_REGION")
|
|
}
|
|
if fallbackRegion == "" {
|
|
fallbackRegion = "us-east-1"
|
|
}
|
|
|
|
maxRetries := aws.UseServiceDefaultRetries
|
|
|
|
entry, err := s.Get(ctx, "config/root")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var configs []*aws.Config
|
|
|
|
// ensure the nil case uses defaults
|
|
if entry == nil {
|
|
ccfg := awsutil.CredentialsConfig{
|
|
HTTPClient: cleanhttp.DefaultClient(),
|
|
Logger: logger,
|
|
Region: fallbackRegion,
|
|
}
|
|
creds, err := ccfg.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
configs = append(configs, &aws.Config{
|
|
Credentials: creds,
|
|
Region: aws.String(fallbackRegion),
|
|
Endpoint: aws.String(""),
|
|
MaxRetries: aws.Int(maxRetries),
|
|
})
|
|
|
|
return configs, nil
|
|
}
|
|
|
|
var config rootConfig
|
|
if err := entry.DecodeJSON(&config); err != nil {
|
|
return nil, fmt.Errorf("error reading root configuration: %w", err)
|
|
}
|
|
|
|
var endpoints []string
|
|
var regions []string
|
|
credsConfig := &awsutil.CredentialsConfig{}
|
|
|
|
credsConfig.AccessKey = config.AccessKey
|
|
credsConfig.SecretKey = config.SecretKey
|
|
credsConfig.HTTPClient = cleanhttp.DefaultClient()
|
|
credsConfig.Logger = logger
|
|
|
|
maxRetries = config.MaxRetries
|
|
if clientType == "iam" && config.IAMEndpoint != "" {
|
|
endpoints = append(endpoints, config.IAMEndpoint)
|
|
} else if clientType == "sts" && config.STSEndpoint != "" {
|
|
endpoints = append(endpoints, config.STSEndpoint)
|
|
if config.STSRegion != "" {
|
|
regions = append(regions, config.STSRegion)
|
|
}
|
|
|
|
if len(config.STSFallbackEndpoints) > 0 {
|
|
endpoints = append(endpoints, config.STSFallbackEndpoints...)
|
|
}
|
|
|
|
if len(config.STSFallbackRegions) > 0 {
|
|
regions = append(regions, config.STSFallbackRegions...)
|
|
}
|
|
}
|
|
|
|
if config.IdentityTokenAudience != "" {
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
|
|
}
|
|
|
|
fetcher := &PluginIdentityTokenFetcher{
|
|
sys: b.System(),
|
|
logger: b.Logger(),
|
|
ns: ns,
|
|
audience: config.IdentityTokenAudience,
|
|
ttl: config.IdentityTokenTTL,
|
|
}
|
|
|
|
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
|
|
credsConfig.WebIdentityTokenFetcher = fetcher
|
|
credsConfig.RoleARN = config.RoleARN
|
|
}
|
|
|
|
if len(regions) == 0 {
|
|
regions = append(regions, fallbackRegion)
|
|
}
|
|
|
|
if len(regions) != len(endpoints) {
|
|
// this probably can't happen, if the input was checked correctly
|
|
return nil, errors.New("number of regions does not match number of endpoints")
|
|
}
|
|
|
|
for i := 0; i < len(endpoints); i++ {
|
|
if len(regions) > i {
|
|
credsConfig.Region = regions[i]
|
|
} else {
|
|
credsConfig.Region = fallbackRegion
|
|
}
|
|
creds, err := credsConfig.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
configs = append(configs, &aws.Config{
|
|
Credentials: creds,
|
|
Region: aws.String(credsConfig.Region),
|
|
Endpoint: aws.String(endpoints[i]),
|
|
MaxRetries: aws.Int(maxRetries),
|
|
HTTPClient: cleanhttp.DefaultClient(),
|
|
})
|
|
}
|
|
|
|
return configs, nil
|
|
}
|
|
|
|
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
|
|
awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(awsConfig) != 1 {
|
|
return nil, errors.New("could not obtain aws config")
|
|
}
|
|
sess, err := session.NewSession(awsConfig[0])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client := iam.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain iam client")
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
|
|
awsConfig, err := b.getRootConfigs(ctx, s, "sts", logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var client *sts.STS
|
|
|
|
for _, cfg := range awsConfig {
|
|
sess, err := session.NewSession(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client = sts.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain sts client")
|
|
}
|
|
|
|
// ping the client - we only care about errors
|
|
_, err = client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
|
|
if err == nil {
|
|
return client, nil
|
|
} else {
|
|
b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", cfg.Endpoint, "failed region", cfg.Region)
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("could not obtain sts client")
|
|
}
|
|
|
|
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
|
|
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
|
|
// When the client's STS credentials expire, it will use this interface to fetch a new
|
|
// plugin identity token and exchange it for new STS credentials.
|
|
type PluginIdentityTokenFetcher struct {
|
|
sys logical.SystemView
|
|
logger hclog.Logger
|
|
audience string
|
|
ns *namespace.Namespace
|
|
ttl time.Duration
|
|
}
|
|
|
|
var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)
|
|
|
|
func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
|
|
nsCtx := namespace.ContextWithNamespace(ctx, f.ns)
|
|
resp, err := f.sys.GenerateIdentityToken(nsCtx, &pluginutil.IdentityTokenRequest{
|
|
Audience: f.audience,
|
|
TTL: f.ttl,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
|
|
}
|
|
f.logger.Info("fetched new plugin identity token")
|
|
|
|
if resp.TTL < f.ttl {
|
|
f.logger.Debug("generated plugin identity token has shorter TTL than requested",
|
|
"requested", f.ttl, "actual", resp.TTL)
|
|
}
|
|
|
|
return []byte(resp.Token.Token()), nil
|
|
}
|