diff --git a/provider/aws/config.go b/provider/aws/config.go index ecc53c904..7ea2d49c6 100644 --- a/provider/aws/config.go +++ b/provider/aws/config.go @@ -19,17 +19,15 @@ package aws import ( "context" "fmt" - "net/http" awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/sirupsen/logrus" - extdnshttp "sigs.k8s.io/external-dns/pkg/http" - "sigs.k8s.io/external-dns/pkg/apis/externaldns" ) @@ -84,8 +82,8 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) { config.WithRetryer(func() awsv2.Retryer { return retry.AddWithMaxAttempts(retry.NewStandard(), awsConfig.APIRetries) }), - config.WithHTTPClient(extdnshttp.NewInstrumentedClient(&http.Client{})), config.WithSharedConfigProfile(awsConfig.Profile), + config.WithAPIOptions(GetInstrumentationMiddlewares()), } cfg, err := config.LoadDefaultConfig(context.Background(), defaultOpts...) diff --git a/provider/aws/config_test.go b/provider/aws/config_test.go index 00b3b46aa..ec6ef021f 100644 --- a/provider/aws/config_test.go +++ b/provider/aws/config_test.go @@ -62,6 +62,19 @@ func Test_newV2Config(t *testing.T) { assert.Equal(t, "AKIAIOSFODNN7EXAMPLE", creds.AccessKeyID) assert.Equal(t, "topsecret", creds.SecretAccessKey) }) + + t.Run("should not error when AWS_CA_BUNDLE set", func(t *testing.T) { + // setup + os.Setenv("AWS_CA_BUNDLE", "../../internal/testresources/ca.pem") + defer os.Unsetenv("AWS_CA_BUNDLE") + + // when + _, err := newV2Config(AWSSessionConfig{}) + require.NoError(t, err) + + // then + assert.NoError(t, err) + }) } func prepareCredentialsFile(t *testing.T) (*os.File, error) { diff --git a/provider/aws/instrumented_config.go b/provider/aws/instrumented_config.go new file mode 100644 index 000000000..68134116a --- /dev/null +++ b/provider/aws/instrumented_config.go @@ -0,0 +1,103 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "fmt" + "time" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/prometheus/client_golang/prometheus" + + extdnshttp "sigs.k8s.io/external-dns/pkg/http" +) + +type requestMetrics struct { + StartTime time.Time +} + +type requestMetricsKey struct{} + +func getRequestMetric(ctx context.Context) requestMetrics { + requestMetrics, _ := middleware.GetStackValue(ctx, requestMetricsKey{}).(requestMetrics) + return requestMetrics +} + +func setRequestMetric(ctx context.Context, requestMetrics requestMetrics) context.Context { + return middleware.WithStackValue(ctx, requestMetricsKey{}, requestMetrics) +} + +var initializeTimedOperationMiddleware = middleware.InitializeMiddlewareFunc("timedOperation", func( + ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler, +) (middleware.InitializeOutput, middleware.Metadata, error) { + requestMetrics := requestMetrics{} + requestMetrics.StartTime = time.Now() + ctx = setRequestMetric(ctx, requestMetrics) + + return next.HandleInitialize(ctx, in) +}) + +var extractAWSRequestParameters = middleware.DeserializeMiddlewareFunc("extractAWSRequestParameters", func( + ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, +) (middleware.DeserializeOutput, middleware.Metadata, error) { + // Call the next middleware first to get the response + out, metadata, err := next.HandleDeserialize(ctx, in) + + requestMetrics := getRequestMetric(ctx) + + var host, scheme, method, path, status string + if req, ok := in.Request.(*smithyhttp.Request); ok && req != nil { + host = req.URL.Host + scheme = req.URL.Scheme + method = req.Method + path = req.URL.Path + } + + // Try to access HTTP response and status code + if resp, ok := out.RawResponse.(*smithyhttp.Response); ok && resp != nil { + status = fmt.Sprintf("%d", resp.StatusCode) + } + + labels := prometheus.Labels{ + "scheme": scheme, + "host": host, + "path": extdnshttp.PathProcessor(path), + "method": method, + "status": status, + } + extdnshttp.RequestDuration.With(labels).Observe(time.Since(requestMetrics.StartTime).Seconds()) + + return out, metadata, err +}) + +func GetInstrumentationMiddlewares() []func(*middleware.Stack) error { + return []func(s *middleware.Stack) error{ + func(s *middleware.Stack) error { + if err := s.Initialize.Add(initializeTimedOperationMiddleware, middleware.Before); err != nil { + return fmt.Errorf("error adding timedOperationMiddleware: %w", err) + } + + if err := s.Deserialize.Add(extractAWSRequestParameters, middleware.After); err != nil { + return fmt.Errorf("error adding extractAWSRequestParameters: %w", err) + } + + return nil + }, + } +} diff --git a/provider/aws/instrumented_config_test.go b/provider/aws/instrumented_config_test.go new file mode 100644 index 000000000..397b599d9 --- /dev/null +++ b/provider/aws/instrumented_config_test.go @@ -0,0 +1,102 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "net/http" + "net/url" + "testing" + "time" + + "github.com/aws/smithy-go/middleware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func Test_GetInstrumentationMiddlewares(t *testing.T) { + t.Run("adds expected middlewares", func(t *testing.T) { + stack := middleware.NewStack("test-stack", nil) + + for _, mw := range GetInstrumentationMiddlewares() { + err := mw(stack) + require.NoError(t, err) + } + + // Check Initialize stage + timedOperationMiddleware, found := stack.Initialize.Get("timedOperation") + assert.True(t, found, "timedOperation middleware should be present in Initialize stage") + assert.NotNil(t, timedOperationMiddleware) + + // Check Deserialize stage + extractAWSRequestParametersMiddleware, found := stack.Deserialize.Get("extractAWSRequestParameters") + assert.True(t, found, "extractAWSRequestParameters middleware should be present in Deserialize stage") + assert.NotNil(t, extractAWSRequestParametersMiddleware) + }) +} + +type MockInitializeHandler struct { + CapturedContext context.Context +} + +func (mock *MockInitializeHandler) HandleInitialize(ctx context.Context, in middleware.InitializeInput) (middleware.InitializeOutput, middleware.Metadata, error) { + mock.CapturedContext = ctx + + return middleware.InitializeOutput{}, middleware.Metadata{}, nil +} + +func Test_InitializedTimedOperationMiddleware(t *testing.T) { + testContext := context.Background() + mockInitializeHandler := &MockInitializeHandler{} + + _, _, err := initializeTimedOperationMiddleware.HandleInitialize(testContext, middleware.InitializeInput{}, mockInitializeHandler) + require.NoError(t, err) + + requestMetrics := middleware.GetStackValue(mockInitializeHandler.CapturedContext, requestMetricsKey{}).(requestMetrics) + assert.NotNil(t, requestMetrics.StartTime) +} + +type MockDeserializeHandler struct { +} + +func (mock *MockDeserializeHandler) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput) (middleware.DeserializeOutput, middleware.Metadata, error) { + return middleware.DeserializeOutput{}, middleware.Metadata{}, nil +} + +func Test_ExtractAWSRequestParameters(t *testing.T) { + testContext := context.Background() + middleware.WithStackValue(testContext, requestMetricsKey{}, requestMetrics{StartTime: time.Now()}) + + mockDeserializeHandler := &MockDeserializeHandler{} + + deserializeInput := middleware.DeserializeInput{ + Request: &smithyhttp.Request{ + Request: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Host: "example.com", + Scheme: "HTTPS", + Path: "/testPath", + }, + }, + }, + } + _, _, err := extractAWSRequestParameters.HandleDeserialize(testContext, deserializeInput, mockDeserializeHandler) + require.NoError(t, err) +}