mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-22 15:11:07 +02:00
348 lines
10 KiB
Go
348 lines
10 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package aws
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/iam"
	"github.com/aws/aws-sdk-go/service/sts"
	"github.com/hashicorp/go-cleanhttp"
	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-secure-stdlib/awsutil"
	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/sdk/helper/pluginutil"
	"github.com/hashicorp/vault/sdk/logical"
)
|
|
|
|
// getRootIAMConfig creates an *aws.Config for Vault to connect to IAM.
|
|
func (b *backend) getRootIAMConfig(ctx context.Context, s logical.Storage, logger hclog.Logger) (*aws.Config, error) {
|
|
credsConfig := &awsutil.CredentialsConfig{}
|
|
var endpoint string
|
|
var maxRetries int = aws.UseServiceDefaultRetries
|
|
|
|
entry, err := s.Get(ctx, "config/root")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if entry != nil {
|
|
var config rootConfig
|
|
if err := entry.DecodeJSON(&config); err != nil {
|
|
return nil, fmt.Errorf("error reading root configuration: %w", err)
|
|
}
|
|
|
|
credsConfig.AccessKey = config.AccessKey
|
|
credsConfig.SecretKey = config.SecretKey
|
|
credsConfig.Region = config.Region
|
|
maxRetries = config.MaxRetries
|
|
|
|
if config.IAMEndpoint != "" {
|
|
endpoint = *aws.String(config.IAMEndpoint)
|
|
}
|
|
|
|
if config.IdentityTokenAudience != "" {
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
|
|
}
|
|
|
|
fetcher := &PluginIdentityTokenFetcher{
|
|
sys: b.System(),
|
|
logger: b.Logger(),
|
|
ns: ns,
|
|
audience: config.IdentityTokenAudience,
|
|
ttl: config.IdentityTokenTTL,
|
|
}
|
|
|
|
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
|
|
credsConfig.WebIdentityTokenFetcher = fetcher
|
|
credsConfig.RoleARN = config.RoleARN
|
|
}
|
|
}
|
|
|
|
if credsConfig.Region == "" {
|
|
credsConfig.Region = getFallbackRegion()
|
|
}
|
|
|
|
credsConfig.HTTPClient = cleanhttp.DefaultClient()
|
|
|
|
credsConfig.Logger = logger
|
|
|
|
creds, err := credsConfig.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &aws.Config{
|
|
Credentials: creds,
|
|
Region: aws.String(credsConfig.Region),
|
|
Endpoint: &endpoint,
|
|
HTTPClient: cleanhttp.DefaultClient(),
|
|
MaxRetries: aws.Int(maxRetries),
|
|
}, nil
|
|
}
|
|
|
|
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
|
|
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
|
|
func (b *backend) getRootSTSConfigs(ctx context.Context, s logical.Storage, logger hclog.Logger) ([]*aws.Config, error) {
|
|
// set fallback region (we can overwrite later)
|
|
fallbackRegion := getFallbackRegion()
|
|
|
|
maxRetries := aws.UseServiceDefaultRetries
|
|
|
|
entry, err := s.Get(ctx, "config/root")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var configs []*aws.Config
|
|
|
|
// ensure the nil case uses defaults
|
|
if entry == nil {
|
|
ccfg := awsutil.CredentialsConfig{
|
|
HTTPClient: cleanhttp.DefaultClient(),
|
|
Logger: logger,
|
|
Region: fallbackRegion,
|
|
}
|
|
creds, err := ccfg.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
configs = append(configs, &aws.Config{
|
|
Credentials: creds,
|
|
Region: aws.String(fallbackRegion),
|
|
Endpoint: aws.String(""),
|
|
MaxRetries: aws.Int(maxRetries),
|
|
})
|
|
|
|
return configs, nil
|
|
}
|
|
|
|
var config rootConfig
|
|
if err := entry.DecodeJSON(&config); err != nil {
|
|
return nil, fmt.Errorf("error reading root configuration: %w", err)
|
|
}
|
|
|
|
var endpoints []string
|
|
var regions []string
|
|
credsConfig := &awsutil.CredentialsConfig{}
|
|
|
|
credsConfig.AccessKey = config.AccessKey
|
|
credsConfig.SecretKey = config.SecretKey
|
|
credsConfig.HTTPClient = cleanhttp.DefaultClient()
|
|
credsConfig.Logger = logger
|
|
|
|
if config.Region != "" {
|
|
regions = append(regions, config.Region)
|
|
}
|
|
|
|
maxRetries = config.MaxRetries
|
|
if config.STSEndpoint != "" {
|
|
endpoints = append(endpoints, config.STSEndpoint)
|
|
if config.STSRegion != "" {
|
|
// this retains original logic, where sts region was only used if sts endpoint was set
|
|
regions = []string{config.STSRegion} // override to be "only" region if set
|
|
}
|
|
|
|
if len(config.STSFallbackEndpoints) > 0 {
|
|
endpoints = append(endpoints, config.STSFallbackEndpoints...)
|
|
}
|
|
|
|
if len(config.STSFallbackRegions) > 0 {
|
|
regions = append(regions, config.STSFallbackRegions...)
|
|
}
|
|
}
|
|
|
|
opts := make([]awsutil.Option, 0)
|
|
if config.IdentityTokenAudience != "" {
|
|
ns, err := namespace.FromContext(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
|
|
}
|
|
|
|
fetcher := &PluginIdentityTokenFetcher{
|
|
sys: b.System(),
|
|
logger: b.Logger(),
|
|
ns: ns,
|
|
audience: config.IdentityTokenAudience,
|
|
ttl: config.IdentityTokenTTL,
|
|
}
|
|
|
|
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
|
|
credsConfig.WebIdentityTokenFetcher = fetcher
|
|
credsConfig.RoleARN = config.RoleARN
|
|
|
|
// explicitly disable environment and shared credential providers when using Web Identity Token Fetcher
|
|
// enables WIF usage in environments that may use AWS Profiles or environment variables for other use-cases
|
|
opts = append(opts, awsutil.WithEnvironmentCredentials(false), awsutil.WithSharedCredentials(false))
|
|
}
|
|
|
|
// at this point, in the IAM case,
|
|
// - regions contains config.Region, if it was set.
|
|
// - endpoints contains iam_endpoint, if it was set.
|
|
// in the sts case,
|
|
// - regions contains sts_region, if it was set, then sts_fallback_regions in order, if they were set.
|
|
// - endpoints contains sts_endpoint, if it was set, then sts_fallback_endpoints in order, if they were set.
|
|
|
|
// case in which nothing was supplied
|
|
if len(regions) == 0 {
|
|
// fallback region is in descending order, AWS_REGION, or AWS_DEFAULT_REGION, or us-east-1
|
|
regions = append(regions, fallbackRegion)
|
|
}
|
|
|
|
if len(endpoints) == 0 {
|
|
for _, v := range regions {
|
|
endpoints = append(endpoints, matchingSTSEndpoint(v))
|
|
}
|
|
}
|
|
|
|
// for this approach of using parallel arrays to part out the configs, we want equal numbers of regions and endpoints
|
|
if len(regions) != len(endpoints) {
|
|
return nil, errors.New("number of regions does not match number of endpoints")
|
|
}
|
|
|
|
for i := 0; i < len(endpoints); i++ {
|
|
if len(regions) > i {
|
|
credsConfig.Region = regions[i]
|
|
} else {
|
|
credsConfig.Region = fallbackRegion
|
|
}
|
|
creds, err := credsConfig.GenerateCredentialChain(opts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
configs = append(configs, &aws.Config{
|
|
Credentials: creds,
|
|
Region: aws.String(credsConfig.Region),
|
|
Endpoint: aws.String(endpoints[i]),
|
|
MaxRetries: aws.Int(maxRetries),
|
|
HTTPClient: cleanhttp.DefaultClient(),
|
|
})
|
|
}
|
|
|
|
return configs, nil
|
|
}
|
|
|
|
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (*iam.IAM, error) {
|
|
var awsConfig *aws.Config
|
|
var err error
|
|
|
|
if entry != nil && entry.AssumeRoleARN != "" {
|
|
awsConfig, err = b.assumeRoleStatic(ctx, s, entry)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to assume role %q: %w", entry.AssumeRoleARN, err)
|
|
}
|
|
} else {
|
|
awsConfig, err = b.getRootIAMConfig(ctx, s, logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
sess, err := session.NewSession(awsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client := iam.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain IAM client")
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
|
|
awsConfig, err := b.getRootSTSConfigs(ctx, s, logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var client *sts.STS
|
|
|
|
for _, cfg := range awsConfig {
|
|
sess, err := session.NewSession(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client = sts.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain sts client")
|
|
}
|
|
|
|
// ping the client - we only care about errors
|
|
_, err = client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
|
|
if err == nil {
|
|
return client, nil
|
|
} else {
|
|
b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", *cfg.Endpoint, "failed region", *cfg.Region)
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("could not obtain sts client")
|
|
}
|
|
|
|
// matchingSTSEndpoint returns the endpoint for the supplied region, according to
// http://docs.aws.amazon.com/general/latest/gr/sts.html
//
// Regions in the AWS China partition (those prefixed with "cn-") are served
// from the amazonaws.com.cn domain; all other regions use amazonaws.com.
func matchingSTSEndpoint(stsRegion string) string {
	domain := "amazonaws.com"
	if strings.HasPrefix(stsRegion, "cn-") {
		domain = "amazonaws.com.cn"
	}
	return fmt.Sprintf("https://sts.%s.%s", stsRegion, domain)
}
|
|
|
|
// getFallbackRegion picks the AWS region to use when none is configured.
// Candidates are consulted in the AWS specified order:
//   - the AWS_REGION environment variable, then
//   - the AWS_DEFAULT_REGION environment variable, then
//   - the literal "us-east-1"
//
// The first non-empty candidate wins.
func getFallbackRegion() string {
	for _, envVar := range []string{"AWS_REGION", "AWS_DEFAULT_REGION"} {
		if region := os.Getenv(envVar); region != "" {
			return region
		}
	}
	return "us-east-1"
}
|
|
|
|
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
|
|
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
|
|
// When the client's STS credentials expire, it will use this interface to fetch a new
|
|
// plugin identity token and exchange it for new STS credentials.
|
|
type PluginIdentityTokenFetcher struct {
|
|
sys logical.SystemView
|
|
logger hclog.Logger
|
|
audience string
|
|
ns *namespace.Namespace
|
|
ttl time.Duration
|
|
}
|
|
|
|
// Compile-time assertion that *PluginIdentityTokenFetcher satisfies stscreds.TokenFetcher.
var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)
|
|
|
|
func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
|
|
nsCtx := namespace.ContextWithNamespace(ctx, f.ns)
|
|
resp, err := f.sys.GenerateIdentityToken(nsCtx, &pluginutil.IdentityTokenRequest{
|
|
Audience: f.audience,
|
|
TTL: f.ttl,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
|
|
}
|
|
f.logger.Info("fetched new plugin identity token")
|
|
|
|
if resp.TTL < f.ttl {
|
|
f.logger.Debug("generated plugin identity token has shorter TTL than requested",
|
|
"requested", f.ttl, "actual", resp.TTL)
|
|
}
|
|
|
|
return []byte(resp.Token.Token()), nil
|
|
}
|