diff --git a/service/rds/customizations.go b/service/rds/customizations.go index 536af93adff..aea8fe6b4dd 100644 --- a/service/rds/customizations.go +++ b/service/rds/customizations.go @@ -47,6 +47,10 @@ func copyDBSnapshotPresign(r *request.Request) { } originParams.DestinationRegion = aws.String(r.Config.Region) + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } newParams := awsutil.CopyOf(r.Params).(*CopyDBSnapshotInput) originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) } @@ -59,6 +63,10 @@ func createDBInstanceReadReplicaPresign(r *request.Request) { } originParams.DestinationRegion = aws.String(r.Config.Region) + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } newParams := awsutil.CopyOf(r.Params).(*CreateDBInstanceReadReplicaInput) originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) } @@ -71,6 +79,10 @@ func copyDBClusterSnapshotPresign(r *request.Request) { } originParams.DestinationRegion = aws.String(r.Config.Region) + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput) originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) } @@ -83,6 +95,10 @@ func createDBClusterPresign(r *request.Request) { } originParams.DestinationRegion = aws.String(r.Config.Region) + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput) originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) } diff --git a/service/rds/customizations_test.go b/service/rds/customizations_test.go index 2bf38fa4c5e..e928bd553ea 100644 --- a/service/rds/customizations_test.go +++ b/service/rds/customizations_test.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "net/url" "regexp" - "strings" "testing" "time" @@ -16,8 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" ) -func TestPresignWithPresignNotSet(t *testing.T) { - reqs := map[string]*request.Request{} +func TestCopyDBSnapshotNoPanic(t *testing.T) { cfg := unit.Config() cfg.Region = "us-west-2" @@ -34,73 +32,184 @@ func TestPresignWithPresignNotSet(t *testing.T) { t.Errorf("expect no panic, got %v", p) } - reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ - SourceRegion: aws.String("us-west-1"), - SourceDBSnapshotIdentifier: aws.String("foo"), - TargetDBSnapshotIdentifier: aws.String("bar"), - }).Request - - reqs[opCreateDBInstanceReadReplica] = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{ - SourceRegion: aws.String("us-west-1"), - SourceDBInstanceIdentifier: aws.String("foo"), - DBInstanceIdentifier: aws.String("bar"), - }).Request - - for op, req := range reqs { - req.Sign() - b, _ := ioutil.ReadAll(req.HTTPRequest.Body) - q, _ := url.ParseQuery(string(b)) - - u, _ := url.QueryUnescape(q.Get("PreSignedUrl")) - - exp := fmt.Sprintf(`^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=us-west-2.+`, op) - if re, a := regexp.MustCompile(exp), u; !re.MatchString(a) { - t.Errorf("expect %s to match %s", re, a) - } - } } -func TestPresignWithPresignSet(t *testing.T) { - reqs := map[string]*request.Request{} +func TestPresignCrossRegionRequest(t *testing.T) { cfg := unit.Config() cfg.Region = "us-west-2" + cfg.EndpointResolver = endpoints.NewDefaultResolver() svc := New(cfg) + const regexPattern= `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+` - f := func() { - // Doesn't panic on nil input - req := svc.CopyDBSnapshotRequest(nil) - req.Sign() - } - if paniced, p := awstesting.DidPanic(f); paniced { - t.Errorf("expect no panic, got %v", p) - } + cases := map[string]struct { + Req *request.Request + Assert func(*testing.T, string) + }{ + opCopyDBSnapshot: { + Req: func() *request.Request { + req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBSnapshotIdentifier: aws.String("foo"), + TargetDBSnapshotIdentifier: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCopyDBSnapshot, cfg.Region)), + }, - reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ - SourceRegion: aws.String("us-west-1"), - SourceDBSnapshotIdentifier: aws.String("foo"), - TargetDBSnapshotIdentifier: aws.String("bar"), - PreSignedUrl: aws.String("presignedURL"), - }).Request + opCreateDBInstanceReadReplica: { + Req: func() *request.Request { + req := svc.CreateDBInstanceReadReplicaRequest( + &CreateDBInstanceReadReplicaInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBInstanceIdentifier: aws.String("foo"), + DBInstanceIdentifier: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCreateDBInstanceReadReplica, cfg.Region)), + }, + opCopyDBClusterSnapshot: { + Req: func() *request.Request { + req := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCopyDBClusterSnapshot, cfg.Region)), + }, + opCreateDBCluster: { + Req: func() *request.Request { + req := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-1"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCreateDBCluster, cfg.Region)), + }, + opCopyDBSnapshot + " same region": { + Req: func() *request.Request { + req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ + SourceRegion: aws.String("us-west-2"), + SourceDBSnapshotIdentifier: aws.String("foo"), + TargetDBSnapshotIdentifier: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsEmpty(), + }, + opCreateDBInstanceReadReplica + " same region": { + Req: func() *request.Request { + req := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{ + SourceRegion: aws.String("us-west-2"), + SourceDBInstanceIdentifier: aws.String("foo"), + DBInstanceIdentifier: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsEmpty(), + }, + opCopyDBClusterSnapshot + " same region": { + Req: func() *request.Request { + req := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-2"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsEmpty(), + }, + opCreateDBCluster + " same region": { + Req: func() *request.Request { + req := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-2"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + }) + return req.Request + }(), + Assert: assertAsEmpty(), + }, + opCopyDBSnapshot + " presignURL set": { + Req: func() *request.Request { + req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBSnapshotIdentifier: aws.String("foo"), + TargetDBSnapshotIdentifier: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req.Request + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + opCreateDBInstanceReadReplica + " presignURL set": { + Req: func() *request.Request { + req := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBInstanceIdentifier: aws.String("foo"), + DBInstanceIdentifier: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req.Request + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + opCopyDBClusterSnapshot + " presignURL set": { + Req: func() *request.Request { + req := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req.Request + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + opCreateDBCluster + " presignURL set": { + Req: func() *request.Request { + req := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-1"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req.Request + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + } - reqs[opCreateDBInstanceReadReplica] = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{ - SourceRegion: aws.String("us-west-1"), - SourceDBInstanceIdentifier: aws.String("foo"), - DBInstanceIdentifier: aws.String("bar"), - PreSignedUrl: aws.String("presignedURL"), - }).Request + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if err := c.Req.Sign(); err != nil { + t.Fatalf("expect no error, got %v", err) + } + b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body) + q, _ := url.ParseQuery(string(b)) - for _, req := range reqs { - req.Sign() + u, _ := url.QueryUnescape(q.Get("PreSignedUrl")) - b, _ := ioutil.ReadAll(req.HTTPRequest.Body) - q, _ := url.ParseQuery(string(b)) + c.Assert(t, u) - u, _ := url.QueryUnescape(q.Get("PreSignedUrl")) - if e, a := "presignedURL", u; !strings.Contains(a, e) { - t.Errorf("expect %s to be in %s", e, a) - } + }) } } @@ -112,15 +221,6 @@ func TestPresignWithSourceNotSet(t *testing.T) { svc := New(cfg) - f := func() { - // Doesn't panic on nil input - req := svc.CopyDBSnapshotRequest(nil) - req.Sign() - } - if paniced, p := awstesting.DidPanic(f); paniced { - t.Errorf("expect no panic, got %v", p) - } - reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ SourceDBSnapshotIdentifier: aws.String("foo"), TargetDBSnapshotIdentifier: aws.String("bar"), @@ -133,3 +233,33 @@ func TestPresignWithSourceNotSet(t *testing.T) { } } } + +func assertAsRegexMatch(exp string) func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) { + t.Errorf("expect %s to match %s", re, a) + } + } +} + +func assertAsEmpty() func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if len(v) != 0 { + t.Errorf("expect empty, got %v", v) + } + } +} + +func assertAsEqual(expect string) func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if e, a := expect, v; e != a { + t.Errorf("expect %v, got %v", e, a) + } + } +}