vault/builtin/logical/aws/client.go

348 lines
10 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package aws
import (
"context"
"errors"
"fmt"
"os"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
// getRootIAMConfig creates an *aws.Config for Vault to connect to IAM.
//
// When a config/root entry exists, its static keys, region, iam_endpoint and
// max_retries are applied. If identity_token_audience is set, the credential
// chain is configured for plugin workload identity federation (WIF) using
// role_arn. Without an entry, awsutil's default credential chain is used. An
// unset region falls back to getFallbackRegion.
func (b *backend) getRootIAMConfig(ctx context.Context, s logical.Storage, logger hclog.Logger) (*aws.Config, error) {
	credsConfig := &awsutil.CredentialsConfig{}
	var endpoint string
	maxRetries := aws.UseServiceDefaultRetries
	var opts []awsutil.Option

	entry, err := s.Get(ctx, "config/root")
	if err != nil {
		return nil, err
	}
	if entry != nil {
		var config rootConfig
		if err := entry.DecodeJSON(&config); err != nil {
			return nil, fmt.Errorf("error reading root configuration: %w", err)
		}

		credsConfig.AccessKey = config.AccessKey
		credsConfig.SecretKey = config.SecretKey
		credsConfig.Region = config.Region
		maxRetries = config.MaxRetries
		// An empty iam_endpoint leaves endpoint empty, exactly as before.
		endpoint = config.IAMEndpoint

		if config.IdentityTokenAudience != "" {
			ns, err := namespace.FromContext(ctx)
			if err != nil {
				return nil, fmt.Errorf("failed to get namespace from context: %w", err)
			}
			fetcher := &PluginIdentityTokenFetcher{
				sys:      b.System(),
				logger:   b.Logger(),
				ns:       ns,
				audience: config.IdentityTokenAudience,
				ttl:      config.IdentityTokenTTL,
			}
			// A unique suffix keeps each assumed-role session name distinct.
			sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
			credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
			credsConfig.WebIdentityTokenFetcher = fetcher
			credsConfig.RoleARN = config.RoleARN

			// Same as getRootSTSConfigs: explicitly disable the environment and
			// shared credential providers when using the Web Identity Token Fetcher.
			// This lets WIF work in environments that use AWS profiles or
			// environment variables for other purposes. Without it, IAM calls could
			// use a different identity than STS calls.
			opts = append(opts, awsutil.WithEnvironmentCredentials(false), awsutil.WithSharedCredentials(false))
		}
	}

	if credsConfig.Region == "" {
		credsConfig.Region = getFallbackRegion()
	}
	credsConfig.HTTPClient = cleanhttp.DefaultClient()
	credsConfig.Logger = logger

	creds, err := credsConfig.GenerateCredentialChain(opts...)
	if err != nil {
		return nil, err
	}

	return &aws.Config{
		Credentials: creds,
		Region:      aws.String(credsConfig.Region),
		Endpoint:    &endpoint,
		HTTPClient:  cleanhttp.DefaultClient(),
		MaxRetries:  aws.Int(maxRetries),
	}, nil
}
// getRootSTSConfigs returns the candidate *aws.Config values for STS, in
// descending priority order; nonCachedClientSTS tries them in that order.
//
// Regions and endpoints are paired index-by-index:
//   - Without sts_endpoint, there is one candidate per region (region, or else
//     the fallback region from getFallbackRegion). It uses matchingSTSEndpoint
//     for that region.
//   - With sts_endpoint, it is paired with sts_region if set, otherwise region.
//     sts_fallback_endpoints then follow, paired with sts_fallback_regions.
//   - If sts_endpoint is set, neither sts_region nor region is set, and both
//     fallback lists have the same length, sts_endpoint is paired with the
//     fallback region. This keeps the fallback lists aligned.
//
// With no config/root entry at all, a single default candidate is returned.
//
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
func (b *backend) getRootSTSConfigs(ctx context.Context, s logical.Storage, logger hclog.Logger) ([]*aws.Config, error) {
	// set fallback region (we can overwrite later)
	fallbackRegion := getFallbackRegion()
	maxRetries := aws.UseServiceDefaultRetries

	entry, err := s.Get(ctx, "config/root")
	if err != nil {
		return nil, err
	}

	// ensure the nil case uses defaults
	if entry == nil {
		ccfg := awsutil.CredentialsConfig{
			HTTPClient: cleanhttp.DefaultClient(),
			Logger:     logger,
			Region:     fallbackRegion,
		}
		creds, err := ccfg.GenerateCredentialChain()
		if err != nil {
			return nil, err
		}
		// HTTPClient is set here as well, so every returned config uses its own
		// cleanhttp client instead of whatever the SDK falls back to when nil.
		return []*aws.Config{{
			Credentials: creds,
			Region:      aws.String(fallbackRegion),
			Endpoint:    aws.String(""),
			MaxRetries:  aws.Int(maxRetries),
			HTTPClient:  cleanhttp.DefaultClient(),
		}}, nil
	}

	var config rootConfig
	if err := entry.DecodeJSON(&config); err != nil {
		return nil, fmt.Errorf("error reading root configuration: %w", err)
	}
	maxRetries = config.MaxRetries

	credsConfig := &awsutil.CredentialsConfig{
		AccessKey:  config.AccessKey,
		SecretKey:  config.SecretKey,
		HTTPClient: cleanhttp.DefaultClient(),
		Logger:     logger,
	}

	var regions, endpoints []string
	if config.Region != "" {
		regions = append(regions, config.Region)
	}
	if config.STSEndpoint != "" {
		endpoints = append(endpoints, config.STSEndpoint)
		if config.STSRegion != "" {
			// This retains the original logic: sts_region is only used when
			// sts_endpoint is set, and it becomes the only primary region.
			regions = []string{config.STSRegion}
		}
		if len(regions) == 0 && len(config.STSFallbackRegions) == len(config.STSFallbackEndpoints) {
			// sts_endpoint has no region of its own. Without this, the first
			// sts_fallback_region would shift onto sts_endpoint. The regions list
			// would then be one short and fail the length check below.
			regions = append(regions, fallbackRegion)
		}
		endpoints = append(endpoints, config.STSFallbackEndpoints...)
		regions = append(regions, config.STSFallbackRegions...)
	}

	var opts []awsutil.Option
	if config.IdentityTokenAudience != "" {
		ns, err := namespace.FromContext(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to get namespace from context: %w", err)
		}
		fetcher := &PluginIdentityTokenFetcher{
			sys:      b.System(),
			logger:   b.Logger(),
			ns:       ns,
			audience: config.IdentityTokenAudience,
			ttl:      config.IdentityTokenTTL,
		}
		// A unique suffix keeps each assumed-role session name distinct.
		sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
		credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
		credsConfig.WebIdentityTokenFetcher = fetcher
		credsConfig.RoleARN = config.RoleARN
		// explicitly disable environment and shared credential providers when using Web Identity Token Fetcher
		// enables WIF usage in environments that may use AWS Profiles or environment variables for other use-cases
		opts = append(opts, awsutil.WithEnvironmentCredentials(false), awsutil.WithSharedCredentials(false))
	}

	// Nothing configured: fall back to AWS_REGION, then AWS_DEFAULT_REGION, then us-east-1.
	if len(regions) == 0 {
		regions = append(regions, fallbackRegion)
	}
	// No sts_endpoint: derive one endpoint per region.
	if len(endpoints) == 0 {
		for _, region := range regions {
			endpoints = append(endpoints, matchingSTSEndpoint(region))
		}
	}
	// regions and endpoints are parallel slices and must line up one-to-one.
	if len(regions) != len(endpoints) {
		return nil, errors.New("number of regions does not match number of endpoints")
	}

	configs := make([]*aws.Config, 0, len(endpoints))
	for i, endpoint := range endpoints {
		region := regions[i]
		credsConfig.Region = region
		creds, err := credsConfig.GenerateCredentialChain(opts...)
		if err != nil {
			return nil, err
		}
		configs = append(configs, &aws.Config{
			Credentials: creds,
			Region:      aws.String(region),
			Endpoint:    aws.String(endpoint),
			MaxRetries:  aws.Int(maxRetries),
			HTTPClient:  cleanhttp.DefaultClient(),
		})
	}
	return configs, nil
}
// nonCachedClientIAM builds a new IAM client on every call. If entry is non-nil
// and carries an AssumeRoleARN, the client is configured through
// assumeRoleStatic. Otherwise it uses the root IAM configuration returned by
// getRootIAMConfig.
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (*iam.IAM, error) {
	var (
		cfg *aws.Config
		err error
	)

	switch {
	case entry != nil && entry.AssumeRoleARN != "":
		if cfg, err = b.assumeRoleStatic(ctx, s, entry); err != nil {
			return nil, fmt.Errorf("failed to assume role %q: %w", entry.AssumeRoleARN, err)
		}
	default:
		if cfg, err = b.getRootIAMConfig(ctx, s, logger); err != nil {
			return nil, err
		}
	}

	sess, err := session.NewSession(cfg)
	if err != nil {
		return nil, err
	}

	// Defensive nil check kept from the original.
	if client := iam.New(sess); client != nil {
		return client, nil
	}
	return nil, fmt.Errorf("could not obtain IAM client")
}
// nonCachedClientSTS builds a new STS client on every call. It walks the
// candidate configs from getRootSTSConfigs in priority order. For each one it
// probes with GetCallerIdentity and returns the first client whose probe
// succeeds. If every probe fails, the last probe error is wrapped into the
// returned error.
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
	awsConfigs, err := b.getRootSTSConfigs(ctx, s, logger)
	if err != nil {
		return nil, err
	}

	var lastErr error
	for _, cfg := range awsConfigs {
		sess, err := session.NewSession(cfg)
		if err != nil {
			return nil, err
		}
		client := sts.New(sess)
		if client == nil {
			return nil, fmt.Errorf("could not obtain sts client")
		}

		// Ping the client; only the error matters. Passing ctx means a cancelled
		// or expired request stops probing instead of running to completion.
		_, err = client.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{})
		if err == nil {
			return client, nil
		}
		lastErr = err
		b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", *cfg.Endpoint, "failed region", *cfg.Region, "error", err)
	}

	if lastErr != nil {
		return nil, fmt.Errorf("could not obtain sts client: %w", lastErr)
	}
	return nil, fmt.Errorf("could not obtain sts client")
}
// matchingSTSEndpoint builds the regional STS endpoint URL for stsRegion, using
// the https://sts.<region>.amazonaws.com pattern described at
// http://docs.aws.amazon.com/general/latest/gr/sts.html
func matchingSTSEndpoint(stsRegion string) string {
	const endpointTemplate = "https://sts.%s.amazonaws.com"
	return fmt.Sprintf(endpointTemplate, stsRegion)
}
// getFallbackRegion picks a default AWS region. It checks these sources in order
// and uses the first non-empty one:
//   - the AWS_REGION environment variable
//   - the AWS_DEFAULT_REGION environment variable
//   - the literal "us-east-1"
func getFallbackRegion() string {
	for _, envVar := range []string{"AWS_REGION", "AWS_DEFAULT_REGION"} {
		if region := os.Getenv(envVar); region != "" {
			return region
		}
	}
	return "us-east-1"
}
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
// When the client's STS credentials expire, it will use this interface to fetch a new
// plugin identity token and exchange it for new STS credentials.
type PluginIdentityTokenFetcher struct {
	// sys is the Vault system view whose GenerateIdentityToken mints the token.
	sys logical.SystemView
	// logger receives the fetch/TTL log lines emitted by FetchToken.
	logger hclog.Logger
	// audience is the audience requested for the token (identity_token_audience).
	audience string
	// ns is the namespace attached to the context when generating the token.
	ns *namespace.Namespace
	// ttl is the requested token TTL; the issued TTL may come back shorter.
	ttl time.Duration
}

// Compile-time check that *PluginIdentityTokenFetcher satisfies stscreds.TokenFetcher.
var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)
// FetchToken implements stscreds.TokenFetcher. It asks Vault's system view for a
// plugin identity token with the fetcher's audience and TTL, generated within the
// fetcher's namespace. It returns the raw token bytes for the AWS SDK to exchange
// for STS credentials.
func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
	request := &pluginutil.IdentityTokenRequest{
		Audience: f.audience,
		TTL:      f.ttl,
	}
	resp, err := f.sys.GenerateIdentityToken(namespace.ContextWithNamespace(ctx, f.ns), request)
	if err != nil {
		return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
	}
	f.logger.Info("fetched new plugin identity token")

	// The issued TTL can be shorter than the one requested; note it at debug level.
	if actual := resp.TTL; actual < f.ttl {
		f.logger.Debug("generated plugin identity token has shorter TTL than requested",
			"requested", f.ttl, "actual", actual)
	}

	return []byte(resp.Token.Token()), nil
}