Use common code for creating AWS sessions

This commit is contained in:
John Gardiner Myers 2023-05-28 19:57:40 -07:00
parent a9ae5ea43a
commit 794a10dfbe
5 changed files with 101 additions and 86 deletions

23
main.go
View File

@ -25,6 +25,9 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/route53"
sd "github.com/aws/aws-sdk-go/service/servicediscovery"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
@ -180,6 +183,20 @@ func main() {
zoneTypeFilter := provider.NewZoneTypeFilter(cfg.AWSZoneType) zoneTypeFilter := provider.NewZoneTypeFilter(cfg.AWSZoneType)
zoneTagFilter := provider.NewZoneTagFilter(cfg.AWSZoneTagFilter) zoneTagFilter := provider.NewZoneTagFilter(cfg.AWSZoneTagFilter)
var awsSession *session.Session
if cfg.Provider == "aws" || cfg.Provider == "aws-sd" {
awsSession, err = aws.NewSession(
aws.AWSSessionConfig{
AssumeRole: cfg.AWSAssumeRole,
AssumeRoleExternalID: cfg.AWSAssumeRoleExternalID,
APIRetries: cfg.AWSAPIRetries,
},
)
if err != nil {
log.Fatal(err)
}
}
var p provider.Provider var p provider.Provider
switch cfg.Provider { switch cfg.Provider {
case "akamai": case "akamai":
@ -207,13 +224,11 @@ func main() {
BatchChangeSize: cfg.AWSBatchChangeSize, BatchChangeSize: cfg.AWSBatchChangeSize,
BatchChangeInterval: cfg.AWSBatchChangeInterval, BatchChangeInterval: cfg.AWSBatchChangeInterval,
EvaluateTargetHealth: cfg.AWSEvaluateTargetHealth, EvaluateTargetHealth: cfg.AWSEvaluateTargetHealth,
AssumeRole: cfg.AWSAssumeRole,
AssumeRoleExternalID: cfg.AWSAssumeRoleExternalID,
APIRetries: cfg.AWSAPIRetries,
PreferCNAME: cfg.AWSPreferCNAME, PreferCNAME: cfg.AWSPreferCNAME,
DryRun: cfg.DryRun, DryRun: cfg.DryRun,
ZoneCacheDuration: cfg.AWSZoneCacheDuration, ZoneCacheDuration: cfg.AWSZoneCacheDuration,
}, },
route53.New(awsSession),
) )
case "aws-sd": case "aws-sd":
// Check that only compatible Registry is used with AWS-SD // Check that only compatible Registry is used with AWS-SD
@ -221,7 +236,7 @@ func main() {
log.Infof("Registry \"%s\" cannot be used with AWS Cloud Map. Switching to \"aws-sd\".", cfg.Registry) log.Infof("Registry \"%s\" cannot be used with AWS Cloud Map. Switching to \"aws-sd\".", cfg.Registry)
cfg.Registry = "aws-sd" cfg.Registry = "aws-sd"
} }
p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.AWSAssumeRole, cfg.AWSAssumeRoleExternalID, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID) p, err = awssd.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.DryRun, cfg.AWSSDServiceCleanup, cfg.TXTOwnerID, sd.New(awsSession))
case "azure-dns", "azure": case "azure-dns", "azure":
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun) p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun)
case "azure-private-dns": case "azure-private-dns":

View File

@ -457,12 +457,12 @@ func (cfg *Config) ParseFlags(args []string) error {
app.Flag("alibaba-cloud-zone-type", "When using the Alibaba Cloud provider, filter for zones of this type (optional, options: public, private)").Default(defaultConfig.AlibabaCloudZoneType).EnumVar(&cfg.AlibabaCloudZoneType, "", "public", "private") app.Flag("alibaba-cloud-zone-type", "When using the Alibaba Cloud provider, filter for zones of this type (optional, options: public, private)").Default(defaultConfig.AlibabaCloudZoneType).EnumVar(&cfg.AlibabaCloudZoneType, "", "public", "private")
app.Flag("aws-zone-type", "When using the AWS provider, filter for zones of this type (optional, options: public, private)").Default(defaultConfig.AWSZoneType).EnumVar(&cfg.AWSZoneType, "", "public", "private") app.Flag("aws-zone-type", "When using the AWS provider, filter for zones of this type (optional, options: public, private)").Default(defaultConfig.AWSZoneType).EnumVar(&cfg.AWSZoneType, "", "public", "private")
app.Flag("aws-zone-tags", "When using the AWS provider, filter for zones with these tags").Default("").StringsVar(&cfg.AWSZoneTagFilter) app.Flag("aws-zone-tags", "When using the AWS provider, filter for zones with these tags").Default("").StringsVar(&cfg.AWSZoneTagFilter)
app.Flag("aws-assume-role", "When using the AWS provider, assume this IAM role. Useful for hosted zones in another AWS account. Specify the full ARN, e.g. `arn:aws:iam::123455567:role/external-dns` (optional)").Default(defaultConfig.AWSAssumeRole).StringVar(&cfg.AWSAssumeRole) app.Flag("aws-assume-role", "When using the AWS API, assume this IAM role. Useful for hosted zones in another AWS account. Specify the full ARN, e.g. `arn:aws:iam::123455567:role/external-dns` (optional)").Default(defaultConfig.AWSAssumeRole).StringVar(&cfg.AWSAssumeRole)
app.Flag("aws-assume-role-external-id", "When using the AWS provider and assuming a role then specify this external ID` (optional)").Default(defaultConfig.AWSAssumeRoleExternalID).StringVar(&cfg.AWSAssumeRoleExternalID) app.Flag("aws-assume-role-external-id", "When using the AWS API and assuming a role then specify this external ID` (optional)").Default(defaultConfig.AWSAssumeRoleExternalID).StringVar(&cfg.AWSAssumeRoleExternalID)
app.Flag("aws-batch-change-size", "When using the AWS provider, set the maximum number of changes that will be applied in each batch.").Default(strconv.Itoa(defaultConfig.AWSBatchChangeSize)).IntVar(&cfg.AWSBatchChangeSize) app.Flag("aws-batch-change-size", "When using the AWS provider, set the maximum number of changes that will be applied in each batch.").Default(strconv.Itoa(defaultConfig.AWSBatchChangeSize)).IntVar(&cfg.AWSBatchChangeSize)
app.Flag("aws-batch-change-interval", "When using the AWS provider, set the interval between batch changes.").Default(defaultConfig.AWSBatchChangeInterval.String()).DurationVar(&cfg.AWSBatchChangeInterval) app.Flag("aws-batch-change-interval", "When using the AWS provider, set the interval between batch changes.").Default(defaultConfig.AWSBatchChangeInterval.String()).DurationVar(&cfg.AWSBatchChangeInterval)
app.Flag("aws-evaluate-target-health", "When using the AWS provider, set whether to evaluate the health of a DNS target (default: enabled, disable with --no-aws-evaluate-target-health)").Default(strconv.FormatBool(defaultConfig.AWSEvaluateTargetHealth)).BoolVar(&cfg.AWSEvaluateTargetHealth) app.Flag("aws-evaluate-target-health", "When using the AWS provider, set whether to evaluate the health of a DNS target (default: enabled, disable with --no-aws-evaluate-target-health)").Default(strconv.FormatBool(defaultConfig.AWSEvaluateTargetHealth)).BoolVar(&cfg.AWSEvaluateTargetHealth)
app.Flag("aws-api-retries", "When using the AWS provider, set the maximum number of retries for API calls before giving up.").Default(strconv.Itoa(defaultConfig.AWSAPIRetries)).IntVar(&cfg.AWSAPIRetries) app.Flag("aws-api-retries", "When using the AWS API, set the maximum number of retries before giving up.").Default(strconv.Itoa(defaultConfig.AWSAPIRetries)).IntVar(&cfg.AWSAPIRetries)
app.Flag("aws-prefer-cname", "When using the AWS provider, prefer using CNAME instead of ALIAS (default: disabled)").BoolVar(&cfg.AWSPreferCNAME) app.Flag("aws-prefer-cname", "When using the AWS provider, prefer using CNAME instead of ALIAS (default: disabled)").BoolVar(&cfg.AWSPreferCNAME)
app.Flag("aws-zones-cache-duration", "When using the AWS provider, set the zones list cache TTL (0s to disable).").Default(defaultConfig.AWSZoneCacheDuration.String()).DurationVar(&cfg.AWSZoneCacheDuration) app.Flag("aws-zones-cache-duration", "When using the AWS provider, set the zones list cache TTL (0s to disable).").Default(defaultConfig.AWSZoneCacheDuration.String()).DurationVar(&cfg.AWSZoneCacheDuration)
app.Flag("aws-sd-service-cleanup", "When using the AWS CloudMap provider, delete empty Services without endpoints (default: disabled)").BoolVar(&cfg.AWSSDServiceCleanup) app.Flag("aws-sd-service-cleanup", "When using the AWS CloudMap provider, delete empty Services without endpoints (default: disabled)").BoolVar(&cfg.AWSSDServiceCleanup)

View File

@ -25,11 +25,8 @@ import (
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/route53" "github.com/aws/aws-sdk-go/service/route53"
"github.com/linki/instrumented_http"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -226,49 +223,15 @@ type AWSConfig struct {
BatchChangeSize int BatchChangeSize int
BatchChangeInterval time.Duration BatchChangeInterval time.Duration
EvaluateTargetHealth bool EvaluateTargetHealth bool
AssumeRole string
AssumeRoleExternalID string
APIRetries int
PreferCNAME bool PreferCNAME bool
DryRun bool DryRun bool
ZoneCacheDuration time.Duration ZoneCacheDuration time.Duration
} }
// NewAWSProvider initializes a new AWS Route53 based Provider. // NewAWSProvider initializes a new AWS Route53 based Provider.
func NewAWSProvider(awsConfig AWSConfig) (*AWSProvider, error) { func NewAWSProvider(awsConfig AWSConfig, client Route53API) (*AWSProvider, error) {
config := aws.NewConfig().WithMaxRetries(awsConfig.APIRetries)
config.WithHTTPClient(
instrumented_http.NewClient(config.HTTPClient, &instrumented_http.Callbacks{
PathProcessor: func(path string) string {
parts := strings.Split(path, "/")
return parts[len(parts)-1]
},
}),
)
session, err := session.NewSessionWithOptions(session.Options{
Config: *config,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, errors.Wrap(err, "failed to instantiate AWS session")
}
if awsConfig.AssumeRole != "" {
if awsConfig.AssumeRoleExternalID != "" {
log.Infof("Assuming role: %s with external id %s", awsConfig.AssumeRole, awsConfig.AssumeRoleExternalID)
session.Config.WithCredentials(stscreds.NewCredentials(session, awsConfig.AssumeRole, func(p *stscreds.AssumeRoleProvider) {
p.ExternalID = &awsConfig.AssumeRoleExternalID
}))
} else {
log.Infof("Assuming role: %s", awsConfig.AssumeRole)
session.Config.WithCredentials(stscreds.NewCredentials(session, awsConfig.AssumeRole))
}
}
provider := &AWSProvider{ provider := &AWSProvider{
client: route53.New(session), client: client,
domainFilter: awsConfig.DomainFilter, domainFilter: awsConfig.DomainFilter,
zoneIDFilter: awsConfig.ZoneIDFilter, zoneIDFilter: awsConfig.ZoneIDFilter,
zoneTypeFilter: awsConfig.ZoneTypeFilter, zoneTypeFilter: awsConfig.ZoneTypeFilter,

75
provider/aws/session.go Normal file
View File

@ -0,0 +1,75 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/linki/instrumented_http"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"sigs.k8s.io/external-dns/pkg/apis/externaldns"
)
// AWSSessionConfig contains configuration to create a new AWS provider.
type AWSSessionConfig struct {
AssumeRole string
AssumeRoleExternalID string
APIRetries int
}
func NewSession(awsConfig AWSSessionConfig) (*session.Session, error) {
config := aws.NewConfig().WithMaxRetries(awsConfig.APIRetries)
config.WithHTTPClient(
instrumented_http.NewClient(config.HTTPClient, &instrumented_http.Callbacks{
PathProcessor: func(path string) string {
parts := strings.Split(path, "/")
return parts[len(parts)-1]
},
}),
)
session, err := session.NewSessionWithOptions(session.Options{
Config: *config,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, errors.Wrap(err, "failed to instantiate AWS session")
}
if awsConfig.AssumeRole != "" {
if awsConfig.AssumeRoleExternalID != "" {
logrus.Infof("Assuming role: %s with external id %s", awsConfig.AssumeRole, awsConfig.AssumeRoleExternalID)
session.Config.WithCredentials(stscreds.NewCredentials(session, awsConfig.AssumeRole, func(p *stscreds.AssumeRoleProvider) {
p.ExternalID = &awsConfig.AssumeRoleExternalID
}))
} else {
logrus.Infof("Assuming role: %s", awsConfig.AssumeRole)
session.Config.WithCredentials(stscreds.NewCredentials(session, awsConfig.AssumeRole))
}
}
session.Handlers.Build.PushBack(request.MakeAddToUserAgentHandler("ExternalDNS", externaldns.Version))
return session, nil
}

View File

@ -25,15 +25,10 @@ import (
"strings" "strings"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
sd "github.com/aws/aws-sdk-go/service/servicediscovery" sd "github.com/aws/aws-sdk-go/service/servicediscovery"
"github.com/linki/instrumented_http"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/pkg/apis/externaldns"
"sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/plan"
"sigs.k8s.io/external-dns/provider" "sigs.k8s.io/external-dns/provider"
) )
@ -86,42 +81,9 @@ type AWSSDProvider struct {
} }
// NewAWSSDProvider initializes a new AWS Cloud Map based Provider. // NewAWSSDProvider initializes a new AWS Cloud Map based Provider.
func NewAWSSDProvider(domainFilter endpoint.DomainFilter, namespaceType string, assumeRole string, assumeRoleExternalID string, dryRun, cleanEmptyService bool, ownerID string) (*AWSSDProvider, error) { func NewAWSSDProvider(domainFilter endpoint.DomainFilter, namespaceType string, dryRun, cleanEmptyService bool, ownerID string, client AWSSDClient) (*AWSSDProvider, error) {
config := aws.NewConfig()
config = config.WithHTTPClient(
instrumented_http.NewClient(config.HTTPClient, &instrumented_http.Callbacks{
PathProcessor: func(path string) string {
parts := strings.Split(path, "/")
return parts[len(parts)-1]
},
}),
)
sess, err := session.NewSessionWithOptions(session.Options{
Config: *config,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, err
}
if assumeRole != "" {
if assumeRoleExternalID != "" {
log.Infof("Assuming role %q with external ID %q", assumeRole, assumeRoleExternalID)
sess.Config.WithCredentials(stscreds.NewCredentials(sess, assumeRole, func(p *stscreds.AssumeRoleProvider) {
p.ExternalID = &assumeRoleExternalID
}))
} else {
log.Infof("Assuming role: %s", assumeRole)
sess.Config.WithCredentials(stscreds.NewCredentials(sess, assumeRole))
}
}
sess.Handlers.Build.PushBack(request.MakeAddToUserAgentHandler("ExternalDNS", externaldns.Version))
provider := &AWSSDProvider{ provider := &AWSSDProvider{
client: sd.New(sess), client: client,
dryRun: dryRun, dryRun: dryRun,
namespaceFilter: domainFilter, namespaceFilter: domainFilter,
namespaceTypeFilter: newSdNamespaceFilter(namespaceType), namespaceTypeFilter: newSdNamespaceFilter(namespaceType),