AWS SD: Load Region Fallback (#18019)

* AWS SD: Load Region Fallback

---------

Signed-off-by: matt-gp <small_minority@hotmail.com>
This commit is contained in:
Matt 2026-02-10 11:02:24 +00:00 committed by GitHub
parent 2f1a797d1a
commit e27bcdf03f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 233 additions and 99 deletions

View File

@ -14,10 +14,13 @@
package aws
import (
"context"
"errors"
"fmt"
"time"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/common/config"
"github.com/prometheus/common/model"
@ -100,6 +103,12 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
}
*c = SDConfig(aux)
var err error
c.Region, err = loadRegion(context.Background(), c.Region)
if err != nil {
return fmt.Errorf("could not determine AWS region: %w", err)
}
switch c.Role {
case RoleEC2:
if c.EC2SDConfig == nil {
@ -107,9 +116,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
c.EC2SDConfig = &ec2Config
}
c.EC2SDConfig.HTTPClientConfig = c.HTTPClientConfig
if c.Region != "" {
c.EC2SDConfig.Region = c.Region
}
c.EC2SDConfig.Region = c.Region
if c.Endpoint != "" {
c.EC2SDConfig.Endpoint = c.Endpoint
}
@ -140,9 +147,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
c.ECSSDConfig = &ecsConfig
}
c.ECSSDConfig.HTTPClientConfig = c.HTTPClientConfig
if c.Region != "" {
c.ECSSDConfig.Region = c.Region
}
c.ECSSDConfig.Region = c.Region
if c.Endpoint != "" {
c.ECSSDConfig.Endpoint = c.Endpoint
}
@ -173,9 +178,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
c.LightsailSDConfig = &lightsailConfig
}
c.LightsailSDConfig.HTTPClientConfig = c.HTTPClientConfig
if c.Region != "" {
c.LightsailSDConfig.Region = c.Region
}
c.LightsailSDConfig.Region = c.Region
if c.Endpoint != "" {
c.LightsailSDConfig.Endpoint = c.Endpoint
}
@ -203,9 +206,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
c.MSKSDConfig = &mskConfig
}
c.MSKSDConfig.HTTPClientConfig = c.HTTPClientConfig
if c.Region != "" {
c.MSKSDConfig.Region = c.Region
}
c.MSKSDConfig.Region = c.Region
if c.Endpoint != "" {
c.MSKSDConfig.Endpoint = c.Endpoint
}
@ -268,3 +269,32 @@ func (c *SDConfig) NewDiscoverer(opts discovery.DiscovererOptions) (discovery.Di
return nil, fmt.Errorf("unknown AWS SD role %q", c.Role)
}
}
// loadRegion finds the region in order: AWS config/env vars ->IMDS.
func loadRegion(ctx context.Context, specifiedRegion string) (string, error) {
if specifiedRegion != "" {
return specifiedRegion, nil
}
cfg, err := awsConfig.LoadDefaultConfig(ctx)
if err != nil {
return "", fmt.Errorf("failed to load AWS config: %w", err)
}
if cfg.Region != "" {
return cfg.Region, nil
}
// Fallback (may fail in non-AWS environments)
imdsClient := imds.NewFromConfig(cfg)
region, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{})
if err != nil {
return "", fmt.Errorf("failed to get region from IMDS: %w", err)
}
if region.Region == "" {
return "", errors.New("region not found in AWS config or IMDS")
}
return region.Region, nil
}

View File

@ -14,7 +14,13 @@
package aws
import (
"context"
"errors"
"math/rand/v2"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
@ -309,3 +315,175 @@ func TestMultipleSDConfigsDoNotShareState(t *testing.T) {
})
}
}
// getRandomRegion is a helper to return a pseudo-random AWS region for testing.
func getRandomRegion() string {
regions := []string{
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2",
"eu-west-1",
"eu-west-2",
"ap-southeast-1",
"ap-southeast-2",
"ap-northeast-1",
"ap-northeast-2",
}
return regions[rand.IntN(len(regions))]
}
func TestLoadRegion(t *testing.T) {
t.Run("with_env_region", func(t *testing.T) {
randomRegion := getRandomRegion()
t.Setenv("AWS_REGION", randomRegion)
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
region, err := loadRegion(context.Background(), "")
require.NoError(t, err)
require.Equal(t, randomRegion, region)
})
t.Run("with_config_file_default_profile", func(t *testing.T) {
randomRegion := getRandomRegion()
// Create a temporary AWS config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config")
configContent := `[default]
region = ` + randomRegion + `
`
err := os.WriteFile(configFile, []byte(configContent), 0o644)
require.NoError(t, err)
defer os.Remove(configFile)
// Set up environment to use the config file
t.Setenv("AWS_CONFIG_FILE", configFile)
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
// Clear any region environment variables to force config file usage
t.Setenv("AWS_REGION", "")
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
t.Setenv("AWS_DEFAULT_REGION", "")
region, err := loadRegion(context.Background(), "")
require.NoError(t, err)
require.Equal(t, randomRegion, region)
})
t.Run("with_config_file_named_profile", func(t *testing.T) {
randomRegion := getRandomRegion()
// Create a temporary AWS config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config")
configContent := `[default]
region = ` + getRandomRegion() + `
[profile ` + randomRegion + `-profile]
region = ` + randomRegion + `
`
err := os.WriteFile(configFile, []byte(configContent), 0o644)
require.NoError(t, err)
defer os.Remove(configFile)
// Set up environment to use the config file
t.Setenv("AWS_CONFIG_FILE", configFile)
t.Setenv("AWS_PROFILE", randomRegion+"-profile")
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
// Clear any region environment variables to force config file usage
t.Setenv("AWS_REGION", "")
t.Setenv("AWS_DEFAULT_REGION", "")
region, err := loadRegion(context.Background(), "")
require.NoError(t, err)
require.Equal(t, randomRegion, region)
})
t.Run("with_specified_region", func(t *testing.T) {
specifiedRegion := getRandomRegion()
// Even with environment region set differently, specified region should take precedence
t.Setenv("AWS_REGION", getRandomRegion())
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
region, err := loadRegion(context.Background(), specifiedRegion)
require.NoError(t, err)
require.Equal(t, specifiedRegion, region)
})
t.Run("imds_fallback", func(t *testing.T) {
randomRegion := getRandomRegion()
// Mock IMDS server that returns a region
mockIMDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle instance identity document (contains region info)
if r.URL.Path == "/latest/dynamic/instance-identity/document" {
imdsPayload := `{"region": "` + randomRegion + `"}`
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(imdsPayload))
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer mockIMDS.Close()
// Set up environment with no region but valid credentials
// This will force fallback to IMDS
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
// Unset any existing region
t.Setenv("AWS_REGION", "")
t.Setenv("AWS_DEFAULT_REGION", "")
t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
// Point IMDS to our mock server
t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", mockIMDS.URL)
region, err := loadRegion(context.Background(), "")
require.NoError(t, err)
require.Equal(t, randomRegion, region)
})
t.Run("imds_empty_region", func(t *testing.T) {
// Mock IMDS server that returns empty region
mockIMDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle instance identity document with empty region
if r.URL.Path == "/latest/dynamic/instance-identity/document" {
imdsPayload := `{"region": ""}`
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(imdsPayload))
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer mockIMDS.Close()
// Set up environment with no region but valid credentials
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
// Unset any existing region
t.Setenv("AWS_REGION", "")
t.Setenv("AWS_DEFAULT_REGION", "")
t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
// Point IMDS to our mock server
t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", mockIMDS.URL)
_, err := loadRegion(context.Background(), "")
require.Error(t, err)
require.Contains(t, err.Error(), "failed to get region from IMDS")
})
}

View File

@ -27,7 +27,6 @@ import (
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2Types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
@ -125,31 +124,10 @@ func (c *EC2SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
return err
}
if c.Region == "" {
cfg, err := awsConfig.LoadDefaultConfig(context.Background())
if err != nil {
return err
}
if cfg.Region != "" {
// If the region is already set in the config, use it.
// This can happen if the user has set the region in the AWS config file or environment variables.
c.Region = cfg.Region
}
if c.Region == "" {
// Try to get the region from the instance metadata service (IMDS).
imdsClient := imds.NewFromConfig(cfg)
region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{})
if err != nil {
return err
}
c.Region = region.Region
}
}
if c.Region == "" {
return errors.New("EC2 SD configuration requires a region")
// Check if the region is set, if not attempt to load it from the AWS SDK.
c.Region, err = loadRegion(context.Background(), c.Region)
if err != nil {
return fmt.Errorf("could not determine AWS region: %w", err)
}
for _, f := range c.Filters {

View File

@ -27,7 +27,6 @@ import (
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/ecs/types"
@ -137,17 +136,9 @@ func (c *ECSSDConfig) UnmarshalYAML(unmarshal func(any) error) error {
return err
}
if c.Region == "" {
cfg, err := awsConfig.LoadDefaultConfig(context.TODO())
if err != nil {
return err
}
client := imds.NewFromConfig(cfg)
result, err := client.GetRegion(context.Background(), &imds.GetRegionInput{})
if err != nil {
return fmt.Errorf("ECS SD configuration requires a region. Tried to fetch it from the instance metadata: %w", err)
}
c.Region = result.Region
c.Region, err = loadRegion(context.Background(), c.Region)
if err != nil {
return fmt.Errorf("could not determine AWS region: %w", err)
}
return c.HTTPClientConfig.Validate()

View File

@ -26,7 +26,6 @@ import (
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/lightsail"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go"
@ -106,30 +105,9 @@ func (c *LightsailSDConfig) UnmarshalYAML(unmarshal func(any) error) error {
return err
}
if c.Region == "" {
cfg, err := awsConfig.LoadDefaultConfig(context.Background())
if err != nil {
return err
}
if cfg.Region != "" {
// Use the region from the AWS config. It will load environment variables and shared config files.
c.Region = cfg.Region
}
if c.Region == "" {
// Try to get the region from the instance metadata service (IMDS).
imdsClient := imds.NewFromConfig(cfg)
region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{})
if err != nil {
return err
}
c.Region = region.Region
}
}
if c.Region == "" {
return errors.New("lightsail SD configuration requires a region")
c.Region, err = loadRegion(context.Background(), c.Region)
if err != nil {
return fmt.Errorf("could not determine AWS region: %w", err)
}
return c.HTTPClientConfig.Validate()

View File

@ -27,7 +27,6 @@ import (
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/kafka"
"github.com/aws/aws-sdk-go-v2/service/kafka/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
@ -136,29 +135,9 @@ func (c *MSKSDConfig) UnmarshalYAML(unmarshal func(any) error) error {
return err
}
if c.Region == "" {
cfg, err := awsConfig.LoadDefaultConfig(context.Background())
if err != nil {
return err
}
if cfg.Region != "" {
// If the region is already set in the config, use it (env vars).
c.Region = cfg.Region
}
if c.Region == "" {
// Try to get the region from IMDS.
imdsClient := imds.NewFromConfig(cfg)
region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{})
if err != nil {
return err
}
c.Region = region.Region
}
}
if c.Region == "" {
return errors.New("MSK SD configuration requires a region")
c.Region, err = loadRegion(context.Background(), c.Region)
if err != nil {
return fmt.Errorf("could not determine AWS region: %w", err)
}
return c.HTTPClientConfig.Validate()