diff --git a/discovery/aws/aws.go b/discovery/aws/aws.go index 9db87965bb..69b3b41c06 100644 --- a/discovery/aws/aws.go +++ b/discovery/aws/aws.go @@ -14,10 +14,13 @@ package aws import ( + "context" "errors" "fmt" "time" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/common/config" "github.com/prometheus/common/model" @@ -100,6 +103,12 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error { } *c = SDConfig(aux) + var err error + c.Region, err = loadRegion(context.Background(), c.Region) + if err != nil { + return fmt.Errorf("could not determine AWS region: %w", err) + } + switch c.Role { case RoleEC2: if c.EC2SDConfig == nil { @@ -107,9 +116,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error { c.EC2SDConfig = &ec2Config } c.EC2SDConfig.HTTPClientConfig = c.HTTPClientConfig - if c.Region != "" { - c.EC2SDConfig.Region = c.Region - } + c.EC2SDConfig.Region = c.Region if c.Endpoint != "" { c.EC2SDConfig.Endpoint = c.Endpoint } @@ -140,9 +147,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error { c.ECSSDConfig = &ecsConfig } c.ECSSDConfig.HTTPClientConfig = c.HTTPClientConfig - if c.Region != "" { - c.ECSSDConfig.Region = c.Region - } + c.ECSSDConfig.Region = c.Region if c.Endpoint != "" { c.ECSSDConfig.Endpoint = c.Endpoint } @@ -173,9 +178,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error { c.LightsailSDConfig = &lightsailConfig } c.LightsailSDConfig.HTTPClientConfig = c.HTTPClientConfig - if c.Region != "" { - c.LightsailSDConfig.Region = c.Region - } + c.LightsailSDConfig.Region = c.Region if c.Endpoint != "" { c.LightsailSDConfig.Endpoint = c.Endpoint } @@ -203,9 +206,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error { c.MSKSDConfig = &mskConfig } c.MSKSDConfig.HTTPClientConfig = c.HTTPClientConfig - if c.Region != "" { - c.MSKSDConfig.Region = c.Region - } + c.MSKSDConfig.Region = c.Region if c.Endpoint != "" { c.MSKSDConfig.Endpoint = c.Endpoint } @@ -268,3 +269,32 @@ func (c *SDConfig) NewDiscoverer(opts discovery.DiscovererOptions) (discovery.Di return nil, fmt.Errorf("unknown AWS SD role %q", c.Role) } } + +// loadRegion finds the region in order: AWS config/env vars ->IMDS. +func loadRegion(ctx context.Context, specifiedRegion string) (string, error) { + if specifiedRegion != "" { + return specifiedRegion, nil + } + + cfg, err := awsConfig.LoadDefaultConfig(ctx) + if err != nil { + return "", fmt.Errorf("failed to load AWS config: %w", err) + } + + if cfg.Region != "" { + return cfg.Region, nil + } + + // Fallback (may fail in non-AWS environments) + imdsClient := imds.NewFromConfig(cfg) + region, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{}) + if err != nil { + return "", fmt.Errorf("failed to get region from IMDS: %w", err) + } + + if region.Region == "" { + return "", errors.New("region not found in AWS config or IMDS") + } + + return region.Region, nil +} diff --git a/discovery/aws/aws_test.go b/discovery/aws/aws_test.go index b47a6cd92c..d1ec7b2282 100644 --- a/discovery/aws/aws_test.go +++ b/discovery/aws/aws_test.go @@ -14,7 +14,13 @@ package aws import ( + "context" "errors" + "math/rand/v2" + "net/http" + "net/http/httptest" + "os" + "path/filepath" "testing" "time" @@ -309,3 +315,175 @@ func TestMultipleSDConfigsDoNotShareState(t *testing.T) { }) } } + +// getRandomRegion is a helper to return a pseudo-random AWS region for testing. +func getRandomRegion() string { + regions := []string{ + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", + "eu-west-1", + "eu-west-2", + "ap-southeast-1", + "ap-southeast-2", + "ap-northeast-1", + "ap-northeast-2", + } + + return regions[rand.IntN(len(regions))] +} + +func TestLoadRegion(t *testing.T) { + t.Run("with_env_region", func(t *testing.T) { + randomRegion := getRandomRegion() + t.Setenv("AWS_REGION", randomRegion) + t.Setenv("AWS_ACCESS_KEY_ID", "dummy") + t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy") + t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used + t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used + + region, err := loadRegion(context.Background(), "") + require.NoError(t, err) + require.Equal(t, randomRegion, region) + }) + + t.Run("with_config_file_default_profile", func(t *testing.T) { + randomRegion := getRandomRegion() + + // Create a temporary AWS config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config") + + configContent := `[default] +region = ` + randomRegion + ` +` + + err := os.WriteFile(configFile, []byte(configContent), 0o644) + require.NoError(t, err) + defer os.Remove(configFile) + + // Set up environment to use the config file + t.Setenv("AWS_CONFIG_FILE", configFile) + t.Setenv("AWS_ACCESS_KEY_ID", "dummy") + t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy") + // Clear any region environment variables to force config file usage + t.Setenv("AWS_REGION", "") + t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used + t.Setenv("AWS_DEFAULT_REGION", "") + + region, err := loadRegion(context.Background(), "") + require.NoError(t, err) + require.Equal(t, randomRegion, region) + }) + + t.Run("with_config_file_named_profile", func(t *testing.T) { + randomRegion := getRandomRegion() + + // Create a temporary AWS config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config") + + configContent := `[default] +region = ` + getRandomRegion() + ` + +[profile ` + randomRegion + `-profile] +region = ` + randomRegion + ` +` + + err := os.WriteFile(configFile, []byte(configContent), 0o644) + require.NoError(t, err) + defer os.Remove(configFile) + + // Set up environment to use the config file + t.Setenv("AWS_CONFIG_FILE", configFile) + t.Setenv("AWS_PROFILE", randomRegion+"-profile") + t.Setenv("AWS_ACCESS_KEY_ID", "dummy") + t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy") + // Clear any region environment variables to force config file usage + t.Setenv("AWS_REGION", "") + t.Setenv("AWS_DEFAULT_REGION", "") + + region, err := loadRegion(context.Background(), "") + require.NoError(t, err) + require.Equal(t, randomRegion, region) + }) + + t.Run("with_specified_region", func(t *testing.T) { + specifiedRegion := getRandomRegion() + + // Even with environment region set differently, specified region should take precedence + t.Setenv("AWS_REGION", getRandomRegion()) + t.Setenv("AWS_ACCESS_KEY_ID", "dummy") + t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy") + + region, err := loadRegion(context.Background(), specifiedRegion) + require.NoError(t, err) + require.Equal(t, specifiedRegion, region) + }) + + t.Run("imds_fallback", func(t *testing.T) { + randomRegion := getRandomRegion() + + // Mock IMDS server that returns a region + mockIMDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle instance identity document (contains region info) + if r.URL.Path == "/latest/dynamic/instance-identity/document" { + imdsPayload := `{"region": "` + randomRegion + `"}` + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(imdsPayload)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mockIMDS.Close() + + // Set up environment with no region but valid credentials + // This will force fallback to IMDS + t.Setenv("AWS_ACCESS_KEY_ID", "dummy") + t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy") + // Unset any existing region + t.Setenv("AWS_REGION", "") + t.Setenv("AWS_DEFAULT_REGION", "") + t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used + t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used + // Point IMDS to our mock server + t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", mockIMDS.URL) + + region, err := loadRegion(context.Background(), "") + require.NoError(t, err) + require.Equal(t, randomRegion, region) + }) + + t.Run("imds_empty_region", func(t *testing.T) { + // Mock IMDS server that returns empty region + mockIMDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle instance identity document with empty region + if r.URL.Path == "/latest/dynamic/instance-identity/document" { + imdsPayload := `{"region": ""}` + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(imdsPayload)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mockIMDS.Close() + + // Set up environment with no region but valid credentials + t.Setenv("AWS_ACCESS_KEY_ID", "dummy") + t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy") + // Unset any existing region + t.Setenv("AWS_REGION", "") + t.Setenv("AWS_DEFAULT_REGION", "") + t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used + t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used + // Point IMDS to our mock server + t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", mockIMDS.URL) + + _, err := loadRegion(context.Background(), "") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to get region from IMDS") + }) +} diff --git a/discovery/aws/ec2.go b/discovery/aws/ec2.go index 19ecebd491..4daff43ecc 100644 --- a/discovery/aws/ec2.go +++ b/discovery/aws/ec2.go @@ -27,7 +27,6 @@ import ( awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2Types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -125,31 +124,10 @@ func (c *EC2SDConfig) UnmarshalYAML(unmarshal func(any) error) error { return err } - if c.Region == "" { - cfg, err := awsConfig.LoadDefaultConfig(context.Background()) - if err != nil { - return err - } - - if cfg.Region != "" { - // If the region is already set in the config, use it. - // This can happen if the user has set the region in the AWS config file or environment variables. - c.Region = cfg.Region - } - - if c.Region == "" { - // Try to get the region from the instance metadata service (IMDS). - imdsClient := imds.NewFromConfig(cfg) - region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{}) - if err != nil { - return err - } - c.Region = region.Region - } - } - - if c.Region == "" { - return errors.New("EC2 SD configuration requires a region") + // Check if the region is set, if not attempt to load it from the AWS SDK. + c.Region, err = loadRegion(context.Background(), c.Region) + if err != nil { + return fmt.Errorf("could not determine AWS region: %w", err) } for _, f := range c.Filters { diff --git a/discovery/aws/ecs.go b/discovery/aws/ecs.go index 1d5ff366de..e9d578aec3 100644 --- a/discovery/aws/ecs.go +++ b/discovery/aws/ecs.go @@ -27,7 +27,6 @@ import ( awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" @@ -137,17 +136,9 @@ func (c *ECSSDConfig) UnmarshalYAML(unmarshal func(any) error) error { return err } - if c.Region == "" { - cfg, err := awsConfig.LoadDefaultConfig(context.TODO()) - if err != nil { - return err - } - client := imds.NewFromConfig(cfg) - result, err := client.GetRegion(context.Background(), &imds.GetRegionInput{}) - if err != nil { - return fmt.Errorf("ECS SD configuration requires a region. Tried to fetch it from the instance metadata: %w", err) - } - c.Region = result.Region + c.Region, err = loadRegion(context.Background(), c.Region) + if err != nil { + return fmt.Errorf("could not determine AWS region: %w", err) } return c.HTTPClientConfig.Validate() diff --git a/discovery/aws/lightsail.go b/discovery/aws/lightsail.go index b13f26cc5f..69a5b6625f 100644 --- a/discovery/aws/lightsail.go +++ b/discovery/aws/lightsail.go @@ -26,7 +26,6 @@ import ( awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/lightsail" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" @@ -106,30 +105,9 @@ func (c *LightsailSDConfig) UnmarshalYAML(unmarshal func(any) error) error { return err } - if c.Region == "" { - cfg, err := awsConfig.LoadDefaultConfig(context.Background()) - if err != nil { - return err - } - - if cfg.Region != "" { - // Use the region from the AWS config. It will load environment variables and shared config files. - c.Region = cfg.Region - } - - if c.Region == "" { - // Try to get the region from the instance metadata service (IMDS). - imdsClient := imds.NewFromConfig(cfg) - region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{}) - if err != nil { - return err - } - c.Region = region.Region - } - } - - if c.Region == "" { - return errors.New("lightsail SD configuration requires a region") + c.Region, err = loadRegion(context.Background(), c.Region) + if err != nil { + return fmt.Errorf("could not determine AWS region: %w", err) } return c.HTTPClientConfig.Validate() diff --git a/discovery/aws/msk.go b/discovery/aws/msk.go index 2a2b240d49..a68960066f 100644 --- a/discovery/aws/msk.go +++ b/discovery/aws/msk.go @@ -27,7 +27,6 @@ import ( awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/kafka" "github.com/aws/aws-sdk-go-v2/service/kafka/types" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -136,29 +135,9 @@ func (c *MSKSDConfig) UnmarshalYAML(unmarshal func(any) error) error { return err } - if c.Region == "" { - cfg, err := awsConfig.LoadDefaultConfig(context.Background()) - if err != nil { - return err - } - if cfg.Region != "" { - // If the region is already set in the config, use it (env vars). - c.Region = cfg.Region - } - - if c.Region == "" { - // Try to get the region from IMDS. - imdsClient := imds.NewFromConfig(cfg) - region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{}) - if err != nil { - return err - } - c.Region = region.Region - } - } - - if c.Region == "" { - return errors.New("MSK SD configuration requires a region") + c.Region, err = loadRegion(context.Background(), c.Region) + if err != nil { + return fmt.Errorf("could not determine AWS region: %w", err) } return c.HTTPClientConfig.Validate()