diff --git a/go.mod b/go.mod index da536da00..543b2ebdf 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,12 @@ require ( github.com/alecthomas/kingpin/v2 v2.4.0 github.com/aliyun/alibaba-cloud-sdk-go v1.63.0 github.com/aws/aws-sdk-go v1.55.5 + github.com/aws/aws-sdk-go-v2 v1.30.3 + github.com/aws/aws-sdk-go-v2/config v1.27.27 + github.com/aws/aws-sdk-go-v2/credentials v1.17.27 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.10 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.34.4 + github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 github.com/bodgit/tsig v1.2.2 github.com/cenkalti/backoff/v4 v4.3.0 github.com/civo/civogo v0.3.73 @@ -85,6 +91,17 @@ require ( github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.22.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect + github.com/aws/smithy-go v1.20.3 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 1e7d64fc8..1798b3272 100644 --- a/go.sum +++ b/go.sum @@ -119,6 +119,40 @@ github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= +github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= +github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= +github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.10 h1:orAIBscNu5aIjDOnKIrjO+IUFPMLKj3Lp0bPf4chiPc= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.14.10/go.mod h1:GNjJ8daGhv10hmQYCnmkV8HuY6xXOXV4vzBssSjEIlU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.34.4 h1:utG3S4T+X7nONPIpRoi1tVcQdAdJxntiVS2yolPJyXc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.34.4/go.mod h1:q9vzW3Xr1KEXa8n4waHiFt1PrppNDlMymlYP+xpsFbY= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.22.3 h1:r27/FnxLPixKBRIlslsvhqscBuMK8uysCYG9Kfgm098= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.22.3/go.mod h1:jqOFyN+QSWSoQC+ppyc4weiO8iNQXbzRbxDjQ1ayYd4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.16 h1:lhAX5f7KpgwyieXjbDnRTjPEUI0l3emSRyxXj1PXP8w= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.16/go.mod h1:AblAlCwvi7Q/SFowvckgN+8M3uFPlopSYeLlbNDArhA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= +github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= +github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20160804104726-4c0e84591b9a/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/main.go b/main.go index dfdde6692..27f91877c 100644 --- a/main.go +++ b/main.go @@ -25,8 +25,7 @@ import ( "syscall" "time" - awsSDK "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go/service/route53" sd "github.com/aws/aws-sdk-go/service/servicediscovery" "github.com/go-logr/logr" @@ -383,11 +382,15 @@ func main() { var r registry.Registry switch cfg.Registry { case "dynamodb": - config := awsSDK.NewConfig() + var dynamodbOpts []func(*dynamodb.Options) if cfg.AWSDynamoDBRegion != "" { - config = config.WithRegion(cfg.AWSDynamoDBRegion) + dynamodbOpts = []func(*dynamodb.Options){ + func(opts *dynamodb.Options) { + opts.Region = cfg.AWSDynamoDBRegion + }, + } } - r, err = registry.NewDynamoDBRegistry(p, cfg.TXTOwnerID, dynamodb.New(aws.CreateDefaultSession(cfg), config), cfg.AWSDynamoDBTable, cfg.TXTPrefix, cfg.TXTSuffix, cfg.TXTWildcardReplacement, cfg.ManagedDNSRecordTypes, cfg.ExcludeDNSRecordTypes, []byte(cfg.TXTEncryptAESKey), cfg.TXTCacheInterval) + r, err = registry.NewDynamoDBRegistry(p, cfg.TXTOwnerID, dynamodb.NewFromConfig(aws.CreateDefaultV2Config(cfg), dynamodbOpts...), cfg.AWSDynamoDBTable, cfg.TXTPrefix, cfg.TXTSuffix, cfg.TXTWildcardReplacement, cfg.ManagedDNSRecordTypes, cfg.ExcludeDNSRecordTypes, []byte(cfg.TXTEncryptAESKey), cfg.TXTCacheInterval) case "noop": r, err = registry.NewNoopRegistry(p) case "txt": diff --git a/provider/aws/session.go b/provider/aws/session.go index da578b292..8de5d0f40 100644 --- a/provider/aws/session.go +++ b/provider/aws/session.go @@ -17,9 +17,16 @@ limitations under the License. package aws import ( + "context" "fmt" + "net/http" "strings" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/request" @@ -38,6 +45,20 @@ type AWSSessionConfig struct { Profile string } +func CreateDefaultV2Config(cfg *externaldns.Config) awsv2.Config { + result, err := newV2Config( + AWSSessionConfig{ + AssumeRole: cfg.AWSAssumeRole, + AssumeRoleExternalID: cfg.AWSAssumeRoleExternalID, + APIRetries: cfg.AWSAPIRetries, + }, + ) + if err != nil { + logrus.Fatal(err) + } + return result +} + func CreateDefaultSession(cfg *externaldns.Config) *session.Session { result, err := newSession( AWSSessionConfig{ @@ -123,3 +144,42 @@ func newSession(awsConfig AWSSessionConfig) (*session.Session, error) { return session, nil } + +func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) { + defaultOpts := []func(*config.LoadOptions) error{ + config.WithRetryer(func() awsv2.Retryer { + return retry.AddWithMaxAttempts(retry.NewStandard(), awsConfig.APIRetries) + }), + config.WithHTTPClient(instrumented_http.NewClient(&http.Client{}, &instrumented_http.Callbacks{ + PathProcessor: func(path string) string { + parts := strings.Split(path, "/") + return parts[len(parts)-1] + }, + })), + config.WithSharedConfigProfile(awsConfig.Profile), + } + + cfg, err := config.LoadDefaultConfig(context.Background(), defaultOpts...) + if err != nil { + return awsv2.Config{}, fmt.Errorf("instantiating AWS config: %w", err) + } + + if awsConfig.AssumeRole != "" { + stsSvc := sts.NewFromConfig(cfg) + var assumeRoleOpts []func(*stscredsv2.AssumeRoleOptions) + if awsConfig.AssumeRoleExternalID != "" { + logrus.Infof("Assuming role: %s with external id %s", awsConfig.AssumeRole, awsConfig.AssumeRoleExternalID) + assumeRoleOpts = []func(*stscredsv2.AssumeRoleOptions){ + func(opts *stscredsv2.AssumeRoleOptions) { + opts.ExternalID = &awsConfig.AssumeRoleExternalID + }, + } + } else { + logrus.Infof("Assuming role: %s", awsConfig.AssumeRole) + } + creds := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...) + cfg.Credentials = awsv2.NewCredentialsCache(creds) + } + + return cfg, nil +} diff --git a/provider/aws/session_test.go b/provider/aws/session_test.go index 206fcf940..b73485c03 100644 --- a/provider/aws/session_test.go +++ b/provider/aws/session_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "os" "testing" @@ -63,6 +64,45 @@ func Test_newSession(t *testing.T) { }) } +func Test_newV2Config(t *testing.T) { + t.Run("should use profile from credentials file", func(t *testing.T) { + // setup + credsFile, err := prepareCredentialsFile(t) + defer os.Remove(credsFile.Name()) + require.NoError(t, err) + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile.Name()) + defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") + + // when + cfg, err := newV2Config(AWSSessionConfig{Profile: "profile2"}) + require.NoError(t, err) + creds, err := cfg.Credentials.Retrieve(context.Background()) + + // then + assert.NoError(t, err) + assert.Equal(t, "AKID2345", creds.AccessKeyID) + assert.Equal(t, "SECRET2", creds.SecretAccessKey) + }) + + t.Run("should respect env variables without profile", func(t *testing.T) { + // setup + os.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") + os.Setenv("AWS_SECRET_ACCESS_KEY", "topsecret") + defer os.Unsetenv("AWS_ACCESS_KEY_ID") + defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + // when + cfg, err := newV2Config(AWSSessionConfig{}) + require.NoError(t, err) + creds, err := cfg.Credentials.Retrieve(context.Background()) + + // then + assert.NoError(t, err) + assert.Equal(t, "AKIAIOSFODNN7EXAMPLE", creds.AccessKeyID) + assert.Equal(t, "topsecret", creds.SecretAccessKey) + }) +} + func prepareCredentialsFile(t *testing.T) (*os.File, error) { credsFile, err := os.CreateTemp("", "aws-*.creds") require.NoError(t, err) diff --git a/registry/dynamodb.go b/registry/dynamodb.go index b13d55ce9..805985f34 100644 --- a/registry/dynamodb.go +++ b/registry/dynamodb.go @@ -23,9 +23,10 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + dynamodbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/sets" @@ -34,11 +35,11 @@ import ( "sigs.k8s.io/external-dns/provider" ) -// DynamoDBAPI is the subset of the AWS Route53 API that we actually use. Add methods as required. Signatures must match exactly. +// DynamoDBAPI is the subset of the AWS DynamoDB API that we actually use. Add methods as required. Signatures must match exactly. type DynamoDBAPI interface { - DescribeTableWithContext(ctx aws.Context, input *dynamodb.DescribeTableInput, opts ...request.Option) (*dynamodb.DescribeTableOutput, error) - ScanPagesWithContext(ctx aws.Context, input *dynamodb.ScanInput, fn func(*dynamodb.ScanOutput, bool) bool, opts ...request.Option) error - BatchExecuteStatementWithContext(aws.Context, *dynamodb.BatchExecuteStatementInput, ...request.Option) (*dynamodb.BatchExecuteStatementOutput, error) + DescribeTable(context.Context, *dynamodb.DescribeTableInput, ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) + Scan(context.Context, *dynamodb.ScanInput, ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) + BatchExecuteStatement(context.Context, *dynamodb.BatchExecuteStatementInput, ...func(*dynamodb.Options)) (*dynamodb.BatchExecuteStatementOutput, error) } // DynamoDBRegistry implements registry interface with ownership implemented via an AWS DynamoDB table. @@ -225,7 +226,7 @@ func (im *DynamoDBRegistry) ApplyChanges(ctx context.Context, changes *plan.Chan Delete: endpoint.FilterEndpointsByOwnerID(im.ownerID, changes.Delete), } - statements := make([]*dynamodb.BatchStatementRequest, 0, len(filteredChanges.Create)+len(filteredChanges.UpdateNew)) + statements := make([]dynamodbtypes.BatchStatementRequest, 0, len(filteredChanges.Create)+len(filteredChanges.UpdateNew)) for _, r := range filteredChanges.Create { if r.Labels == nil { r.Labels = make(map[string]string) @@ -286,12 +287,15 @@ func (im *DynamoDBRegistry) ApplyChanges(ctx context.Context, changes *plan.Chan } } - err := im.executeStatements(ctx, statements, func(request *dynamodb.BatchStatementRequest, response *dynamodb.BatchStatementResponse) error { + err := im.executeStatements(ctx, statements, func(request dynamodbtypes.BatchStatementRequest, response dynamodbtypes.BatchStatementResponse) error { var context string if strings.HasPrefix(*request.Statement, "INSERT") { - if aws.StringValue(response.Error.Code) == "DuplicateItem" { + if response.Error.Code == dynamodbtypes.BatchStatementErrorCodeEnumDuplicateItem { // We lost a race with a different owner or another owner has an orphaned ownership record. - key := fromDynamoKey(request.Parameters[0]) + key, err := fromDynamoKey(request.Parameters[0]) + if err != nil { + return err + } for i, endpoint := range filteredChanges.Create { if endpoint.Key() == key { log.Infof("Skipping endpoint %v because owner does not match", endpoint) @@ -303,11 +307,19 @@ func (im *DynamoDBRegistry) ApplyChanges(ctx context.Context, changes *plan.Chan } } } - context = fmt.Sprintf("inserting dynamodb record %q", aws.StringValue(request.Parameters[0].S)) + var record string + if err := attributevalue.Unmarshal(request.Parameters[0], &record); err != nil { + return fmt.Errorf("inserting dynamodb record: %w", err) + } + context = fmt.Sprintf("inserting dynamodb record %q", record) } else { - context = fmt.Sprintf("updating dynamodb record %q", aws.StringValue(request.Parameters[1].S)) + var record string + if err := attributevalue.Unmarshal(request.Parameters[1], &record); err != nil { + return fmt.Errorf("inserting dynamodb record: %w", err) + } + context = fmt.Sprintf("updating dynamodb record %q", record) } - return fmt.Errorf("%s: %s: %s", context, aws.StringValue(response.Error.Code), aws.StringValue(response.Error.Message)) + return fmt.Errorf("%s: %s: %s", context, response.Error.Code, *response.Error.Message) }) if err != nil { im.recordsCache = nil @@ -326,7 +338,7 @@ func (im *DynamoDBRegistry) ApplyChanges(ctx context.Context, changes *plan.Chan return err } - statements = make([]*dynamodb.BatchStatementRequest, 0, len(filteredChanges.Delete)+len(im.orphanedLabels)) + statements = make([]dynamodbtypes.BatchStatementRequest, 0, len(filteredChanges.Delete)+len(im.orphanedLabels)) for _, r := range filteredChanges.Delete { statements = im.appendDelete(statements, r.Key()) } @@ -335,9 +347,13 @@ func (im *DynamoDBRegistry) ApplyChanges(ctx context.Context, changes *plan.Chan delete(im.labels, r) } im.orphanedLabels = nil - return im.executeStatements(ctx, statements, func(request *dynamodb.BatchStatementRequest, response *dynamodb.BatchStatementResponse) error { + return im.executeStatements(ctx, statements, func(request dynamodbtypes.BatchStatementRequest, response dynamodbtypes.BatchStatementResponse) error { im.labels = nil - return fmt.Errorf("deleting dynamodb record %q: %s: %s", aws.StringValue(request.Parameters[0].S), aws.StringValue(response.Error.Code), aws.StringValue(response.Error.Message)) + record, err := fromDynamoKey(request.Parameters[0]) + if err != nil { + return fmt.Errorf("deleting dynamodb record: %w", err) + } + return fmt.Errorf("deleting dynamodb record %q: %s: %s", record, response.Error.Code, *response.Error.Message) }) } @@ -347,7 +363,7 @@ func (im *DynamoDBRegistry) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]* } func (im *DynamoDBRegistry) readLabels(ctx context.Context) error { - table, err := im.dynamodbAPI.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{ + table, err := im.dynamodbAPI.DescribeTable(ctx, &dynamodb.DescribeTableInput{ TableName: aws.String(im.table), }) if err != nil { @@ -356,8 +372,8 @@ func (im *DynamoDBRegistry) readLabels(ctx context.Context) error { foundKey := false for _, def := range table.Table.AttributeDefinitions { - if aws.StringValue(def.AttributeName) == "k" { - if aws.StringValue(def.AttributeType) != "S" { + if *def.AttributeName == "k" { + if def.AttributeType != dynamodbtypes.ScalarAttributeTypeS { return fmt.Errorf("table %q attribute \"k\" must have type \"S\"", im.table) } foundKey = true @@ -367,7 +383,7 @@ func (im *DynamoDBRegistry) readLabels(ctx context.Context) error { return fmt.Errorf("table %q must have attribute \"k\" of type \"S\"", im.table) } - if aws.StringValue(table.Table.KeySchema[0].AttributeName) != "k" { + if *table.Table.KeySchema[0].AttributeName != "k" { return fmt.Errorf("table %q must have hash key \"k\"", im.table) } if len(table.Table.KeySchema) > 1 { @@ -375,76 +391,92 @@ func (im *DynamoDBRegistry) readLabels(ctx context.Context) error { } labels := map[endpoint.EndpointKey]endpoint.Labels{} - err = im.dynamodbAPI.ScanPagesWithContext(ctx, &dynamodb.ScanInput{ + scanPaginator := dynamodb.NewScanPaginator(im.dynamodbAPI, &dynamodb.ScanInput{ TableName: aws.String(im.table), FilterExpression: aws.String("o = :ownerval"), - ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ - ":ownerval": {S: aws.String(im.ownerID)}, + ExpressionAttributeValues: map[string]dynamodbtypes.AttributeValue{ + ":ownerval": &dynamodbtypes.AttributeValueMemberS{Value: im.ownerID}, }, ProjectionExpression: aws.String("k,l"), ConsistentRead: aws.Bool(true), - }, func(output *dynamodb.ScanOutput, last bool) bool { - for _, item := range output.Items { - labels[fromDynamoKey(item["k"])] = fromDynamoLabels(item["l"], im.ownerID) - } - return true }) - if err != nil { - return fmt.Errorf("querying dynamodb: %w", err) + for scanPaginator.HasMorePages() { + output, err := scanPaginator.NextPage(ctx) + if err != nil { + return fmt.Errorf("scanning table %q: %w", im.table, err) + } + for _, item := range output.Items { + k, err := fromDynamoKey(item["k"]) + if err != nil { + return fmt.Errorf("querying dynamodb for key: %w", err) + } + l, err := fromDynamoLabels(item["l"], im.ownerID) + if err != nil { + return fmt.Errorf("querying dynamodb for labels: %w", err) + } + + labels[k] = l + } } im.labels = labels return nil } -func fromDynamoKey(key *dynamodb.AttributeValue) endpoint.EndpointKey { - split := strings.SplitN(aws.StringValue(key.S), "#", 3) +func fromDynamoKey(key dynamodbtypes.AttributeValue) (endpoint.EndpointKey, error) { + var ep string + if err := attributevalue.Unmarshal(key, &ep); err != nil { + return endpoint.EndpointKey{}, fmt.Errorf("unmarshalling endpoint key: %w", err) + } + split := strings.SplitN(ep, "#", 3) return endpoint.EndpointKey{ DNSName: split[0], RecordType: split[1], SetIdentifier: split[2], + }, nil +} + +func toDynamoKey(key endpoint.EndpointKey) dynamodbtypes.AttributeValue { + return &dynamodbtypes.AttributeValueMemberS{ + Value: fmt.Sprintf("%s#%s#%s", key.DNSName, key.RecordType, key.SetIdentifier), } } -func toDynamoKey(key endpoint.EndpointKey) *dynamodb.AttributeValue { - return &dynamodb.AttributeValue{ - S: aws.String(fmt.Sprintf("%s#%s#%s", key.DNSName, key.RecordType, key.SetIdentifier)), - } -} - -func fromDynamoLabels(label *dynamodb.AttributeValue, owner string) endpoint.Labels { +func fromDynamoLabels(label dynamodbtypes.AttributeValue, owner string) (endpoint.Labels, error) { labels := endpoint.NewLabels() - for k, v := range label.M { - labels[k] = aws.StringValue(v.S) + if err := attributevalue.Unmarshal(label, &labels); err != nil { + return endpoint.Labels{}, fmt.Errorf("unmarshalling labels: %w", err) } labels[endpoint.OwnerLabelKey] = owner - return labels + return labels, nil } -func toDynamoLabels(labels endpoint.Labels) *dynamodb.AttributeValue { - labelMap := make(map[string]*dynamodb.AttributeValue, len(labels)) +func toDynamoLabels(labels endpoint.Labels) dynamodbtypes.AttributeValue { + labelMap := make(map[string]dynamodbtypes.AttributeValue, len(labels)) for k, v := range labels { if k == endpoint.OwnerLabelKey { continue } - labelMap[k] = &dynamodb.AttributeValue{S: aws.String(v)} + labelMap[k] = &dynamodbtypes.AttributeValueMemberS{Value: v} } - return &dynamodb.AttributeValue{M: labelMap} + return &dynamodbtypes.AttributeValueMemberM{Value: labelMap} } -func (im *DynamoDBRegistry) appendInsert(statements []*dynamodb.BatchStatementRequest, key endpoint.EndpointKey, new endpoint.Labels) []*dynamodb.BatchStatementRequest { - return append(statements, &dynamodb.BatchStatementRequest{ - Statement: aws.String(fmt.Sprintf("INSERT INTO %q VALUE {'k':?, 'o':?, 'l':?}", im.table)), - Parameters: []*dynamodb.AttributeValue{ +func (im *DynamoDBRegistry) appendInsert(statements []dynamodbtypes.BatchStatementRequest, key endpoint.EndpointKey, new endpoint.Labels) []dynamodbtypes.BatchStatementRequest { + return append(statements, dynamodbtypes.BatchStatementRequest{ + Statement: aws.String(fmt.Sprintf("INSERT INTO %q VALUE {'k':?, 'o':?, 'l':?}", im.table)), + ConsistentRead: aws.Bool(true), + Parameters: []dynamodbtypes.AttributeValue{ toDynamoKey(key), - {S: aws.String(im.ownerID)}, + &dynamodbtypes.AttributeValueMemberS{ + Value: im.ownerID, + }, toDynamoLabels(new), }, - ConsistentRead: aws.Bool(true), }) } -func (im *DynamoDBRegistry) appendUpdate(statements []*dynamodb.BatchStatementRequest, key endpoint.EndpointKey, old endpoint.Labels, new endpoint.Labels) []*dynamodb.BatchStatementRequest { +func (im *DynamoDBRegistry) appendUpdate(statements []dynamodbtypes.BatchStatementRequest, key endpoint.EndpointKey, old endpoint.Labels, new endpoint.Labels) []dynamodbtypes.BatchStatementRequest { if len(old) == len(new) { equal := true for k, v := range old { @@ -458,28 +490,28 @@ func (im *DynamoDBRegistry) appendUpdate(statements []*dynamodb.BatchStatementRe } } - return append(statements, &dynamodb.BatchStatementRequest{ + return append(statements, dynamodbtypes.BatchStatementRequest{ Statement: aws.String(fmt.Sprintf("UPDATE %q SET \"l\"=? WHERE \"k\"=?", im.table)), - Parameters: []*dynamodb.AttributeValue{ + Parameters: []dynamodbtypes.AttributeValue{ toDynamoLabels(new), toDynamoKey(key), }, }) } -func (im *DynamoDBRegistry) appendDelete(statements []*dynamodb.BatchStatementRequest, key endpoint.EndpointKey) []*dynamodb.BatchStatementRequest { - return append(statements, &dynamodb.BatchStatementRequest{ +func (im *DynamoDBRegistry) appendDelete(statements []dynamodbtypes.BatchStatementRequest, key endpoint.EndpointKey) []dynamodbtypes.BatchStatementRequest { + return append(statements, dynamodbtypes.BatchStatementRequest{ Statement: aws.String(fmt.Sprintf("DELETE FROM %q WHERE \"k\"=? AND \"o\"=?", im.table)), - Parameters: []*dynamodb.AttributeValue{ + Parameters: []dynamodbtypes.AttributeValue{ toDynamoKey(key), - {S: aws.String(im.ownerID)}, + &dynamodbtypes.AttributeValueMemberS{Value: im.ownerID}, }, }) } -func (im *DynamoDBRegistry) executeStatements(ctx context.Context, statements []*dynamodb.BatchStatementRequest, handleErr func(request *dynamodb.BatchStatementRequest, response *dynamodb.BatchStatementResponse) error) error { +func (im *DynamoDBRegistry) executeStatements(ctx context.Context, statements []dynamodbtypes.BatchStatementRequest, handleErr func(request dynamodbtypes.BatchStatementRequest, response dynamodbtypes.BatchStatementResponse) error) error { for len(statements) > 0 { - var chunk []*dynamodb.BatchStatementRequest + var chunk []dynamodbtypes.BatchStatementRequest if len(statements) > int(dynamodbMaxBatchSize) { chunk = statements[:dynamodbMaxBatchSize] statements = statements[dynamodbMaxBatchSize:] @@ -488,7 +520,7 @@ func (im *DynamoDBRegistry) executeStatements(ctx context.Context, statements [] statements = nil } - output, err := im.dynamodbAPI.BatchExecuteStatementWithContext(ctx, &dynamodb.BatchExecuteStatementInput{ + output, err := im.dynamodbAPI.BatchExecuteStatement(ctx, &dynamodb.BatchExecuteStatementInput{ Statements: chunk, }) if err != nil { @@ -501,9 +533,13 @@ func (im *DynamoDBRegistry) executeStatements(ctx context.Context, statements [] op, _, _ := strings.Cut(*request.Statement, " ") var key string if op == "UPDATE" { - key = *request.Parameters[1].S + if err := attributevalue.Unmarshal(request.Parameters[1], &key); err != nil { + return err + } } else { - key = *request.Parameters[0].S + if err := attributevalue.Unmarshal(request.Parameters[0], &key); err != nil { + return err + } } log.Infof("%s dynamodb record %q", op, key) } else { diff --git a/registry/dynamodb_test.go b/registry/dynamodb_test.go index 1cadd7dcd..c280554fe 100644 --- a/registry/dynamodb_test.go +++ b/registry/dynamodb_test.go @@ -22,9 +22,10 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + dynamodbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/sets" @@ -69,40 +70,40 @@ func TestDynamoDBRegistryNew(t *testing.T) { func TestDynamoDBRegistryRecordsBadTable(t *testing.T) { for _, tc := range []struct { name string - setup func(desc *dynamodb.TableDescription) + setup func(desc *dynamodbtypes.TableDescription) expected string }{ { name: "missing attribute k", - setup: func(desc *dynamodb.TableDescription) { + setup: func(desc *dynamodbtypes.TableDescription) { desc.AttributeDefinitions[0].AttributeName = aws.String("wrong") }, expected: "table \"test-table\" must have attribute \"k\" of type \"S\"", }, { name: "wrong attribute type", - setup: func(desc *dynamodb.TableDescription) { - desc.AttributeDefinitions[0].AttributeType = aws.String("SS") + setup: func(desc *dynamodbtypes.TableDescription) { + desc.AttributeDefinitions[0].AttributeType = "SS" }, expected: "table \"test-table\" attribute \"k\" must have type \"S\"", }, { name: "wrong key", - setup: func(desc *dynamodb.TableDescription) { + setup: func(desc *dynamodbtypes.TableDescription) { desc.KeySchema[0].AttributeName = aws.String("wrong") }, expected: "table \"test-table\" must have hash key \"k\"", }, { name: "has range key", - setup: func(desc *dynamodb.TableDescription) { - desc.AttributeDefinitions = append(desc.AttributeDefinitions, &dynamodb.AttributeDefinition{ + setup: func(desc *dynamodbtypes.TableDescription) { + desc.AttributeDefinitions = append(desc.AttributeDefinitions, dynamodbtypes.AttributeDefinition{ AttributeName: aws.String("o"), - AttributeType: aws.String("S"), + AttributeType: dynamodbtypes.ScalarAttributeTypeS, }) - desc.KeySchema = append(desc.KeySchema, &dynamodb.KeySchemaElement{ + desc.KeySchema = append(desc.KeySchema, dynamodbtypes.KeySchemaElement{ AttributeName: aws.String("o"), - KeyType: aws.String("RANGE"), + KeyType: dynamodbtypes.KeyTypeRange, }) }, expected: "table \"test-table\" must not have a range key", @@ -559,8 +560,8 @@ func TestDynamoDBRegistryApplyChanges(t *testing.T) { }, }, stubConfig: DynamoDBStubConfig{ - ExpectInsertError: map[string]string{ - "new.test-zone.example.org#CNAME#set-new": "DuplicateItem", + ExpectInsertError: map[string]dynamodbtypes.BatchStatementErrorCodeEnum{ + "new.test-zone.example.org#CNAME#set-new": dynamodbtypes.BatchStatementErrorCodeEnumDuplicateItem, }, ExpectDelete: sets.New("quux.test-zone.example.org#A#set-2"), }, @@ -620,7 +621,7 @@ func TestDynamoDBRegistryApplyChanges(t *testing.T) { }, }, stubConfig: DynamoDBStubConfig{ - ExpectInsertError: map[string]string{ + ExpectInsertError: map[string]dynamodbtypes.BatchStatementErrorCodeEnum{ "new.test-zone.example.org#CNAME#set-new": "TestingError", }, }, @@ -928,7 +929,7 @@ func TestDynamoDBRegistryApplyChanges(t *testing.T) { }, }, stubConfig: DynamoDBStubConfig{ - ExpectUpdateError: map[string]string{ + ExpectUpdateError: map[string]dynamodbtypes.BatchStatementErrorCodeEnum{ "bar.test-zone.example.org#CNAME#": "TestingError", }, }, @@ -1073,15 +1074,15 @@ func TestDynamoDBRegistryApplyChanges(t *testing.T) { type DynamoDBStub struct { t *testing.T stubConfig *DynamoDBStubConfig - tableDescription dynamodb.TableDescription + tableDescription dynamodbtypes.TableDescription changesApplied bool } type DynamoDBStubConfig struct { ExpectInsert map[string]map[string]string - ExpectInsertError map[string]string + ExpectInsertError map[string]dynamodbtypes.BatchStatementErrorCodeEnum ExpectUpdate map[string]map[string]string - ExpectUpdateError map[string]string + ExpectUpdateError map[string]dynamodbtypes.BatchStatementErrorCodeEnum ExpectDelete sets.Set[string] } @@ -1100,17 +1101,17 @@ func newDynamoDBAPIStub(t *testing.T, stubConfig *DynamoDBStubConfig) (*DynamoDB stub := &DynamoDBStub{ t: t, stubConfig: stubConfig, - tableDescription: dynamodb.TableDescription{ - AttributeDefinitions: []*dynamodb.AttributeDefinition{ + tableDescription: dynamodbtypes.TableDescription{ + AttributeDefinitions: []dynamodbtypes.AttributeDefinition{ { AttributeName: aws.String("k"), - AttributeType: aws.String("S"), + AttributeType: dynamodbtypes.ScalarAttributeTypeS, }, }, - KeySchema: []*dynamodb.KeySchemaElement{ + KeySchema: []dynamodbtypes.KeySchemaElement{ { AttributeName: aws.String("k"), - KeyType: aws.String("HASH"), + KeyType: dynamodbtypes.KeyTypeHash, }, }, }, @@ -1131,7 +1132,7 @@ func newDynamoDBAPIStub(t *testing.T, stubConfig *DynamoDBStubConfig) (*DynamoDB } } -func (r *DynamoDBStub) DescribeTableWithContext(ctx aws.Context, input *dynamodb.DescribeTableInput, opts ...request.Option) (*dynamodb.DescribeTableOutput, error) { +func (r *DynamoDBStub) DescribeTable(ctx context.Context, input *dynamodb.DescribeTableInput, opts ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) { assert.NotNil(r.t, ctx) assert.Equal(r.t, "test-table", *input.TableName, "table name") return &dynamodb.DescribeTableOutput{ @@ -1139,75 +1140,80 @@ func (r *DynamoDBStub) DescribeTableWithContext(ctx aws.Context, input *dynamodb }, nil } -func (r *DynamoDBStub) ScanPagesWithContext(ctx aws.Context, input *dynamodb.ScanInput, fn func(*dynamodb.ScanOutput, bool) bool, opts ...request.Option) error { +func (r *DynamoDBStub) Scan(ctx context.Context, input *dynamodb.ScanInput, opts ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) { assert.NotNil(r.t, ctx) assert.Equal(r.t, "test-table", *input.TableName, "table name") assert.Equal(r.t, "o = :ownerval", *input.FilterExpression) assert.Len(r.t, input.ExpressionAttributeValues, 1) - assert.Equal(r.t, "test-owner", *input.ExpressionAttributeValues[":ownerval"].S) + var owner string + assert.Nil(r.t, attributevalue.Unmarshal(input.ExpressionAttributeValues[":ownerval"], &owner)) + assert.Equal(r.t, "test-owner", owner) assert.Equal(r.t, "k,l", *input.ProjectionExpression) assert.True(r.t, *input.ConsistentRead) - fn(&dynamodb.ScanOutput{ - Items: []map[string]*dynamodb.AttributeValue{ + return &dynamodb.ScanOutput{ + Items: []map[string]dynamodbtypes.AttributeValue{ { - "k": &dynamodb.AttributeValue{S: aws.String("bar.test-zone.example.org#CNAME#")}, - "l": &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - endpoint.ResourceLabelKey: {S: aws.String("ingress/default/my-ingress")}, + "k": &dynamodbtypes.AttributeValueMemberS{Value: "bar.test-zone.example.org#CNAME#"}, + "l": &dynamodbtypes.AttributeValueMemberM{Value: map[string]dynamodbtypes.AttributeValue{ + endpoint.ResourceLabelKey: &dynamodbtypes.AttributeValueMemberS{Value: "ingress/default/my-ingress"}, }}, }, { - "k": &dynamodb.AttributeValue{S: aws.String("baz.test-zone.example.org#A#set-1")}, - "l": &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - endpoint.ResourceLabelKey: {S: aws.String("ingress/default/my-ingress")}, + "k": &dynamodbtypes.AttributeValueMemberS{Value: "baz.test-zone.example.org#A#set-1"}, + "l": &dynamodbtypes.AttributeValueMemberM{Value: map[string]dynamodbtypes.AttributeValue{ + endpoint.ResourceLabelKey: &dynamodbtypes.AttributeValueMemberS{Value: "ingress/default/my-ingress"}, }}, }, { - "k": &dynamodb.AttributeValue{S: aws.String("baz.test-zone.example.org#A#set-2")}, - "l": &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - endpoint.ResourceLabelKey: {S: aws.String("ingress/default/other-ingress")}, + "k": &dynamodbtypes.AttributeValueMemberS{Value: "baz.test-zone.example.org#A#set-2"}, + "l": &dynamodbtypes.AttributeValueMemberM{Value: map[string]dynamodbtypes.AttributeValue{ + endpoint.ResourceLabelKey: &dynamodbtypes.AttributeValueMemberS{Value: "ingress/default/other-ingress"}, }}, }, { - "k": &dynamodb.AttributeValue{S: aws.String("quux.test-zone.example.org#A#set-2")}, - "l": &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{ - endpoint.ResourceLabelKey: {S: aws.String("ingress/default/quux-ingress")}, + "k": &dynamodbtypes.AttributeValueMemberS{Value: "quux.test-zone.example.org#A#set-2"}, + "l": &dynamodbtypes.AttributeValueMemberM{Value: map[string]dynamodbtypes.AttributeValue{ + endpoint.ResourceLabelKey: &dynamodbtypes.AttributeValueMemberS{Value: "ingress/default/quux-ingress"}, }}, }, }, - }, true) - return nil + }, nil } -func (r *DynamoDBStub) BatchExecuteStatementWithContext(context aws.Context, input *dynamodb.BatchExecuteStatementInput, option ...request.Option) (*dynamodb.BatchExecuteStatementOutput, error) { +func (r *DynamoDBStub) BatchExecuteStatement(context context.Context, input *dynamodb.BatchExecuteStatementInput, option ...func(*dynamodb.Options)) (*dynamodb.BatchExecuteStatementOutput, error) { assert.NotNil(r.t, context) - hasDelete := strings.HasPrefix(strings.ToLower(aws.StringValue(input.Statements[0].Statement)), "delete") + hasDelete := strings.HasPrefix(strings.ToLower(*input.Statements[0].Statement), "delete") assert.Equal(r.t, hasDelete, r.changesApplied, "delete after provider changes, everything else before") assert.LessOrEqual(r.t, len(input.Statements), 25) - responses := make([]*dynamodb.BatchStatementResponse, 0, len(input.Statements)) + responses := make([]dynamodbtypes.BatchStatementResponse, 0, len(input.Statements)) for _, statement := range input.Statements { - assert.Equal(r.t, hasDelete, strings.HasPrefix(strings.ToLower(aws.StringValue(statement.Statement)), "delete")) - switch aws.StringValue(statement.Statement) { + assert.Equal(r.t, hasDelete, strings.HasPrefix(strings.ToLower(*statement.Statement), "delete")) + switch *statement.Statement { case "DELETE FROM \"test-table\" WHERE \"k\"=? AND \"o\"=?": assert.True(r.t, r.changesApplied, "unexpected delete before provider changes") - key := aws.StringValue(statement.Parameters[0].S) + var key string + assert.Nil(r.t, attributevalue.Unmarshal(statement.Parameters[0], &key)) assert.True(r.t, r.stubConfig.ExpectDelete.Has(key), "unexpected delete for key %q", key) r.stubConfig.ExpectDelete.Delete(key) - assert.Equal(r.t, "test-owner", aws.StringValue(statement.Parameters[1].S)) + var testOwner string + assert.Nil(r.t, attributevalue.Unmarshal(statement.Parameters[1], &testOwner)) + assert.Equal(r.t, "test-owner", testOwner) - responses = append(responses, &dynamodb.BatchStatementResponse{}) + responses = append(responses, dynamodbtypes.BatchStatementResponse{}) case "INSERT INTO \"test-table\" VALUE {'k':?, 'o':?, 'l':?}": assert.False(r.t, r.changesApplied, "unexpected insert after provider changes") - key := aws.StringValue(statement.Parameters[0].S) + var key string + assert.Nil(r.t, attributevalue.Unmarshal(statement.Parameters[0], &key)) if code, exists := r.stubConfig.ExpectInsertError[key]; exists { delete(r.stubConfig.ExpectInsertError, key) - responses = append(responses, &dynamodb.BatchStatementResponse{ - Error: &dynamodb.BatchStatementError{ - Code: aws.String(code), + responses = append(responses, dynamodbtypes.BatchStatementResponse{ + Error: &dynamodbtypes.BatchStatementError{ + Code: code, Message: aws.String("testing error"), }, }) @@ -1218,10 +1224,15 @@ func (r *DynamoDBStub) BatchExecuteStatementWithContext(context aws.Context, inp assert.True(r.t, found, "unexpected insert for key %q", key) delete(r.stubConfig.ExpectInsert, key) - assert.Equal(r.t, "test-owner", aws.StringValue(statement.Parameters[1].S)) + var testOwner string + assert.Nil(r.t, attributevalue.Unmarshal(statement.Parameters[1], &testOwner)) + assert.Equal(r.t, "test-owner", testOwner) - for label, attribute := range statement.Parameters[2].M { - value := aws.StringValue(attribute.S) + var labels map[string]string + err := attributevalue.Unmarshal(statement.Parameters[2], &labels) + assert.Nil(r.t, err) + + for label, value := range labels { expectedValue, found := expectedLabels[label] assert.True(r.t, found, "insert for key %q has unexpected label %q", key, label) delete(expectedLabels, label) @@ -1232,17 +1243,18 @@ func (r *DynamoDBStub) BatchExecuteStatementWithContext(context aws.Context, inp r.t.Errorf("insert for key %q did not get expected label %q", key, label) } - responses = append(responses, &dynamodb.BatchStatementResponse{}) + responses = append(responses, dynamodbtypes.BatchStatementResponse{}) case "UPDATE \"test-table\" SET \"l\"=? WHERE \"k\"=?": assert.False(r.t, r.changesApplied, "unexpected update after provider changes") - key := aws.StringValue(statement.Parameters[1].S) + var key string + assert.Nil(r.t, attributevalue.Unmarshal(statement.Parameters[1], &key)) if code, exists := r.stubConfig.ExpectUpdateError[key]; exists { delete(r.stubConfig.ExpectInsertError, key) - responses = append(responses, &dynamodb.BatchStatementResponse{ - Error: &dynamodb.BatchStatementError{ - Code: aws.String(code), + responses = append(responses, dynamodbtypes.BatchStatementResponse{ + Error: &dynamodbtypes.BatchStatementError{ + Code: code, Message: aws.String("testing error"), }, }) @@ -1253,8 +1265,10 @@ func (r *DynamoDBStub) BatchExecuteStatementWithContext(context aws.Context, inp assert.True(r.t, found, "unexpected update for key %q", key) delete(r.stubConfig.ExpectUpdate, key) - for label, attribute := range statement.Parameters[0].M { - value := aws.StringValue(attribute.S) + var labels map[string]string + assert.Nil(r.t, attributevalue.Unmarshal(statement.Parameters[0], &labels)) + + for label, value := range labels { expectedValue, found := expectedLabels[label] assert.True(r.t, found, "update for key %q has unexpected label %q", key, label) delete(expectedLabels, label) @@ -1265,10 +1279,10 @@ func (r *DynamoDBStub) BatchExecuteStatementWithContext(context aws.Context, inp r.t.Errorf("update for key %q did not get expected label %q", key, label) } - responses = append(responses, &dynamodb.BatchStatementResponse{}) + responses = append(responses, dynamodbtypes.BatchStatementResponse{}) default: - r.t.Errorf("unexpected statement: %s", aws.StringValue(statement.Statement)) + r.t.Errorf("unexpected statement: %s", *statement.Statement) } }