This commit is contained in:
krishicks 2025-08-03 04:53:05 -07:00 committed by GitHub
commit 4ccf906ec5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 16 deletions

View File

@ -93,6 +93,7 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) {
return awsv2.Config{}, fmt.Errorf("instantiating AWS config: %w", err)
}
var credentials awsv2.CredentialsProvider
if awsConfig.AssumeRole != "" {
stsSvc := sts.NewFromConfig(cfg)
var assumeRoleOpts []func(*stscredsv2.AssumeRoleOptions)
@ -107,9 +108,39 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) {
} else {
logrus.Infof("Assuming role: %s", awsConfig.AssumeRole)
}
creds := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...)
cfg.Credentials = awsv2.NewCredentialsCache(creds)
provider := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...)
credentials = awsv2.NewCredentialsCache(provider)
} else {
credentials = newReloadableStaticCredentialsProvider(defaultOpts...)
}
cfg.Credentials = credentials
return cfg, nil
}
// reloadableStaticCredentialsProvider is a credentials provider that loads
// default credentials on each retrieval. This makes it possible to load fresh
// credentials stored in a file referenced by AWS_SHARED_CREDENTIALS_FILE that
// is updated by another process.
type reloadableStaticCredentialsProvider struct {
opts []func(*config.LoadOptions) error
}
func newReloadableStaticCredentialsProvider(opts ...func(*config.LoadOptions) error) awsv2.CredentialsProvider {
return &reloadableStaticCredentialsProvider{opts: opts}
}
func (p *reloadableStaticCredentialsProvider) Retrieve(ctx context.Context) (awsv2.Credentials, error) {
cfg, err := config.LoadDefaultConfig(ctx, p.opts...)
if err != nil {
return awsv2.Credentials{}, fmt.Errorf("instantiating AWS config: %w", err)
}
creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return awsv2.Credentials{}, fmt.Errorf("retrieving credentials: %w", err)
}
return creds, nil
}

View File

@ -19,6 +19,7 @@ package aws
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@ -28,11 +29,20 @@ import (
func Test_newV2Config(t *testing.T) {
t.Run("should use profile from credentials file", func(t *testing.T) {
// setup
credsFile, err := prepareCredentialsFile(t)
defer os.Remove(credsFile.Name())
dir := t.TempDir()
credsFile := filepath.Join(dir, "credentials")
err := os.WriteFile(credsFile, []byte(`
[profile1]
aws_access_key_id=AKID1234
aws_secret_access_key=SECRET1
[profile2]
aws_access_key_id=AKID2345
aws_secret_access_key=SECRET2
`), 0777)
require.NoError(t, err)
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile.Name())
defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile)
// when
cfg, err := newV2Config(AWSSessionConfig{Profile: "profile2"})
@ -45,6 +55,44 @@ func Test_newV2Config(t *testing.T) {
assert.Equal(t, "SECRET2", creds.SecretAccessKey)
})
t.Run("should respect updates to the credentials file", func(t *testing.T) {
// setup
dir := t.TempDir()
credsFile := filepath.Join(dir, "credentials")
err := os.WriteFile(credsFile, []byte(`
[default]
aws_access_key_id=AKID1234
aws_secret_access_key=SECRET1
`), 0777)
require.NoError(t, err)
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile)
cfg, err := newV2Config(AWSSessionConfig{})
require.NoError(t, err)
creds, err := cfg.Credentials.Retrieve(context.Background())
require.NoError(t, err)
assert.Equal(t, "AKID1234", creds.AccessKeyID)
assert.Equal(t, "SECRET1", creds.SecretAccessKey)
// given
err = os.WriteFile(credsFile, []byte(`
[default]
aws_access_key_id=AKID2345
aws_secret_access_key=SECRET2
`), 0777)
require.NoError(t, err)
// when
creds, err = cfg.Credentials.Retrieve(context.Background())
// then
assert.NoError(t, err)
assert.Equal(t, "AKID2345", creds.AccessKeyID)
assert.Equal(t, "SECRET2", creds.SecretAccessKey)
})
t.Run("should respect env variables without profile", func(t *testing.T) {
// setup
os.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE")
@ -63,13 +111,3 @@ func Test_newV2Config(t *testing.T) {
assert.Equal(t, "topsecret", creds.SecretAccessKey)
})
}
func prepareCredentialsFile(t *testing.T) (*os.File, error) {
credsFile, err := os.CreateTemp("", "aws-*.creds")
require.NoError(t, err)
_, err = credsFile.WriteString("[profile1]\naws_access_key_id=AKID1234\naws_secret_access_key=SECRET1\n\n[profile2]\naws_access_key_id=AKID2345\naws_secret_access_key=SECRET2\n")
require.NoError(t, err)
err = credsFile.Close()
require.NoError(t, err)
return credsFile, err
}