mirror of
https://github.com/prometheus/prometheus.git
synced 2026-05-04 20:06:12 +02:00
AWS SD: Load Region Fallback (#18019)
* AWS SD: Load Region Fallback --------- Signed-off-by: matt-gp <small_minority@hotmail.com>
This commit is contained in:
parent
2f1a797d1a
commit
e27bcdf03f
@ -14,10 +14,13 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/common/config"
|
||||
"github.com/prometheus/common/model"
|
||||
@ -100,6 +103,12 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
}
|
||||
*c = SDConfig(aux)
|
||||
|
||||
var err error
|
||||
c.Region, err = loadRegion(context.Background(), c.Region)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not determine AWS region: %w", err)
|
||||
}
|
||||
|
||||
switch c.Role {
|
||||
case RoleEC2:
|
||||
if c.EC2SDConfig == nil {
|
||||
@ -107,9 +116,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
c.EC2SDConfig = &ec2Config
|
||||
}
|
||||
c.EC2SDConfig.HTTPClientConfig = c.HTTPClientConfig
|
||||
if c.Region != "" {
|
||||
c.EC2SDConfig.Region = c.Region
|
||||
}
|
||||
c.EC2SDConfig.Region = c.Region
|
||||
if c.Endpoint != "" {
|
||||
c.EC2SDConfig.Endpoint = c.Endpoint
|
||||
}
|
||||
@ -140,9 +147,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
c.ECSSDConfig = &ecsConfig
|
||||
}
|
||||
c.ECSSDConfig.HTTPClientConfig = c.HTTPClientConfig
|
||||
if c.Region != "" {
|
||||
c.ECSSDConfig.Region = c.Region
|
||||
}
|
||||
c.ECSSDConfig.Region = c.Region
|
||||
if c.Endpoint != "" {
|
||||
c.ECSSDConfig.Endpoint = c.Endpoint
|
||||
}
|
||||
@ -173,9 +178,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
c.LightsailSDConfig = &lightsailConfig
|
||||
}
|
||||
c.LightsailSDConfig.HTTPClientConfig = c.HTTPClientConfig
|
||||
if c.Region != "" {
|
||||
c.LightsailSDConfig.Region = c.Region
|
||||
}
|
||||
c.LightsailSDConfig.Region = c.Region
|
||||
if c.Endpoint != "" {
|
||||
c.LightsailSDConfig.Endpoint = c.Endpoint
|
||||
}
|
||||
@ -203,9 +206,7 @@ func (c *SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
c.MSKSDConfig = &mskConfig
|
||||
}
|
||||
c.MSKSDConfig.HTTPClientConfig = c.HTTPClientConfig
|
||||
if c.Region != "" {
|
||||
c.MSKSDConfig.Region = c.Region
|
||||
}
|
||||
c.MSKSDConfig.Region = c.Region
|
||||
if c.Endpoint != "" {
|
||||
c.MSKSDConfig.Endpoint = c.Endpoint
|
||||
}
|
||||
@ -268,3 +269,32 @@ func (c *SDConfig) NewDiscoverer(opts discovery.DiscovererOptions) (discovery.Di
|
||||
return nil, fmt.Errorf("unknown AWS SD role %q", c.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// loadRegion finds the region in order: AWS config/env vars ->IMDS.
|
||||
func loadRegion(ctx context.Context, specifiedRegion string) (string, error) {
|
||||
if specifiedRegion != "" {
|
||||
return specifiedRegion, nil
|
||||
}
|
||||
|
||||
cfg, err := awsConfig.LoadDefaultConfig(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load AWS config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Region != "" {
|
||||
return cfg.Region, nil
|
||||
}
|
||||
|
||||
// Fallback (may fail in non-AWS environments)
|
||||
imdsClient := imds.NewFromConfig(cfg)
|
||||
region, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get region from IMDS: %w", err)
|
||||
}
|
||||
|
||||
if region.Region == "" {
|
||||
return "", errors.New("region not found in AWS config or IMDS")
|
||||
}
|
||||
|
||||
return region.Region, nil
|
||||
}
|
||||
|
||||
@ -14,7 +14,13 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -309,3 +315,175 @@ func TestMultipleSDConfigsDoNotShareState(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getRandomRegion is a helper to return a pseudo-random AWS region for testing.
|
||||
func getRandomRegion() string {
|
||||
regions := []string{
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-west-1",
|
||||
"us-west-2",
|
||||
"eu-west-1",
|
||||
"eu-west-2",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-northeast-1",
|
||||
"ap-northeast-2",
|
||||
}
|
||||
|
||||
return regions[rand.IntN(len(regions))]
|
||||
}
|
||||
|
||||
func TestLoadRegion(t *testing.T) {
|
||||
t.Run("with_env_region", func(t *testing.T) {
|
||||
randomRegion := getRandomRegion()
|
||||
t.Setenv("AWS_REGION", randomRegion)
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used
|
||||
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
|
||||
|
||||
region, err := loadRegion(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, randomRegion, region)
|
||||
})
|
||||
|
||||
t.Run("with_config_file_default_profile", func(t *testing.T) {
|
||||
randomRegion := getRandomRegion()
|
||||
|
||||
// Create a temporary AWS config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config")
|
||||
|
||||
configContent := `[default]
|
||||
region = ` + randomRegion + `
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Set up environment to use the config file
|
||||
t.Setenv("AWS_CONFIG_FILE", configFile)
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
// Clear any region environment variables to force config file usage
|
||||
t.Setenv("AWS_REGION", "")
|
||||
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
|
||||
t.Setenv("AWS_DEFAULT_REGION", "")
|
||||
|
||||
region, err := loadRegion(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, randomRegion, region)
|
||||
})
|
||||
|
||||
t.Run("with_config_file_named_profile", func(t *testing.T) {
|
||||
randomRegion := getRandomRegion()
|
||||
|
||||
// Create a temporary AWS config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config")
|
||||
|
||||
configContent := `[default]
|
||||
region = ` + getRandomRegion() + `
|
||||
|
||||
[profile ` + randomRegion + `-profile]
|
||||
region = ` + randomRegion + `
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Set up environment to use the config file
|
||||
t.Setenv("AWS_CONFIG_FILE", configFile)
|
||||
t.Setenv("AWS_PROFILE", randomRegion+"-profile")
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
// Clear any region environment variables to force config file usage
|
||||
t.Setenv("AWS_REGION", "")
|
||||
t.Setenv("AWS_DEFAULT_REGION", "")
|
||||
|
||||
region, err := loadRegion(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, randomRegion, region)
|
||||
})
|
||||
|
||||
t.Run("with_specified_region", func(t *testing.T) {
|
||||
specifiedRegion := getRandomRegion()
|
||||
|
||||
// Even with environment region set differently, specified region should take precedence
|
||||
t.Setenv("AWS_REGION", getRandomRegion())
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
|
||||
region, err := loadRegion(context.Background(), specifiedRegion)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, specifiedRegion, region)
|
||||
})
|
||||
|
||||
t.Run("imds_fallback", func(t *testing.T) {
|
||||
randomRegion := getRandomRegion()
|
||||
|
||||
// Mock IMDS server that returns a region
|
||||
mockIMDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle instance identity document (contains region info)
|
||||
if r.URL.Path == "/latest/dynamic/instance-identity/document" {
|
||||
imdsPayload := `{"region": "` + randomRegion + `"}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(imdsPayload))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer mockIMDS.Close()
|
||||
|
||||
// Set up environment with no region but valid credentials
|
||||
// This will force fallback to IMDS
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
// Unset any existing region
|
||||
t.Setenv("AWS_REGION", "")
|
||||
t.Setenv("AWS_DEFAULT_REGION", "")
|
||||
t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used
|
||||
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
|
||||
// Point IMDS to our mock server
|
||||
t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", mockIMDS.URL)
|
||||
|
||||
region, err := loadRegion(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, randomRegion, region)
|
||||
})
|
||||
|
||||
t.Run("imds_empty_region", func(t *testing.T) {
|
||||
// Mock IMDS server that returns empty region
|
||||
mockIMDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle instance identity document with empty region
|
||||
if r.URL.Path == "/latest/dynamic/instance-identity/document" {
|
||||
imdsPayload := `{"region": ""}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(imdsPayload))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer mockIMDS.Close()
|
||||
|
||||
// Set up environment with no region but valid credentials
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
// Unset any existing region
|
||||
t.Setenv("AWS_REGION", "")
|
||||
t.Setenv("AWS_DEFAULT_REGION", "")
|
||||
t.Setenv("AWS_CONFIG_FILE", "") // Ensure no config file is used
|
||||
t.Setenv("AWS_PROFILE", "") // Ensure no profile file is used
|
||||
// Point IMDS to our mock server
|
||||
t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", mockIMDS.URL)
|
||||
|
||||
_, err := loadRegion(context.Background(), "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to get region from IMDS")
|
||||
})
|
||||
}
|
||||
|
||||
@ -27,7 +27,6 @@ import (
|
||||
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
ec2Types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
@ -125,31 +124,10 @@ func (c *EC2SDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
cfg, err := awsConfig.LoadDefaultConfig(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Region != "" {
|
||||
// If the region is already set in the config, use it.
|
||||
// This can happen if the user has set the region in the AWS config file or environment variables.
|
||||
c.Region = cfg.Region
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
// Try to get the region from the instance metadata service (IMDS).
|
||||
imdsClient := imds.NewFromConfig(cfg)
|
||||
region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Region = region.Region
|
||||
}
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
return errors.New("EC2 SD configuration requires a region")
|
||||
// Check if the region is set, if not attempt to load it from the AWS SDK.
|
||||
c.Region, err = loadRegion(context.Background(), c.Region)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not determine AWS region: %w", err)
|
||||
}
|
||||
|
||||
for _, f := range c.Filters {
|
||||
|
||||
@ -27,7 +27,6 @@ import (
|
||||
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ecs"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ecs/types"
|
||||
@ -137,17 +136,9 @@ func (c *ECSSDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
cfg, err := awsConfig.LoadDefaultConfig(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := imds.NewFromConfig(cfg)
|
||||
result, err := client.GetRegion(context.Background(), &imds.GetRegionInput{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ECS SD configuration requires a region. Tried to fetch it from the instance metadata: %w", err)
|
||||
}
|
||||
c.Region = result.Region
|
||||
c.Region, err = loadRegion(context.Background(), c.Region)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not determine AWS region: %w", err)
|
||||
}
|
||||
|
||||
return c.HTTPClientConfig.Validate()
|
||||
|
||||
@ -26,7 +26,6 @@ import (
|
||||
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/lightsail"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
"github.com/aws/smithy-go"
|
||||
@ -106,30 +105,9 @@ func (c *LightsailSDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
cfg, err := awsConfig.LoadDefaultConfig(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Region != "" {
|
||||
// Use the region from the AWS config. It will load environment variables and shared config files.
|
||||
c.Region = cfg.Region
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
// Try to get the region from the instance metadata service (IMDS).
|
||||
imdsClient := imds.NewFromConfig(cfg)
|
||||
region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Region = region.Region
|
||||
}
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
return errors.New("lightsail SD configuration requires a region")
|
||||
c.Region, err = loadRegion(context.Background(), c.Region)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not determine AWS region: %w", err)
|
||||
}
|
||||
|
||||
return c.HTTPClientConfig.Validate()
|
||||
|
||||
@ -27,7 +27,6 @@ import (
|
||||
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kafka"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kafka/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
@ -136,29 +135,9 @@ func (c *MSKSDConfig) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
cfg, err := awsConfig.LoadDefaultConfig(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg.Region != "" {
|
||||
// If the region is already set in the config, use it (env vars).
|
||||
c.Region = cfg.Region
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
// Try to get the region from IMDS.
|
||||
imdsClient := imds.NewFromConfig(cfg)
|
||||
region, err := imdsClient.GetRegion(context.Background(), &imds.GetRegionInput{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Region = region.Region
|
||||
}
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
return errors.New("MSK SD configuration requires a region")
|
||||
c.Region, err = loadRegion(context.Background(), c.Region)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not determine AWS region: %w", err)
|
||||
}
|
||||
|
||||
return c.HTTPClientConfig.Validate()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user