Skip to content

Commit

Permalink
service/rds: Fix presign URL for same region (aws#331)
Browse files Browse the repository at this point in the history
Fixes RDS no-autopresign URL for same region issue for aws-sdk-go-v2.

Solves the issue by making sure that the presigned URLs are not created, when the source and destination regions are the same. Added and updated the tests accordingly.

Fix aws#271
  • Loading branch information
skotambkar authored and Kotambkar committed Aug 13, 2019
1 parent a357131 commit 467aa8a
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 66 deletions.
16 changes: 16 additions & 0 deletions service/rds/customizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func copyDBSnapshotPresign(r *request.Request) {
}

originParams.DestinationRegion = aws.String(r.Config.Region)
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}
newParams := awsutil.CopyOf(r.Params).(*CopyDBSnapshotInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand All @@ -59,6 +63,10 @@ func createDBInstanceReadReplicaPresign(r *request.Request) {
}

originParams.DestinationRegion = aws.String(r.Config.Region)
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}
newParams := awsutil.CopyOf(r.Params).(*CreateDBInstanceReadReplicaInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand All @@ -71,6 +79,10 @@ func copyDBClusterSnapshotPresign(r *request.Request) {
}

originParams.DestinationRegion = aws.String(r.Config.Region)
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}
newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand All @@ -83,6 +95,10 @@ func createDBClusterPresign(r *request.Request) {
}

originParams.DestinationRegion = aws.String(r.Config.Region)
// preSignedUrl is not required for instances in the same region.
if *originParams.SourceRegion == *originParams.DestinationRegion {
return
}
newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput)
originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams)
}
Expand Down
262 changes: 196 additions & 66 deletions service/rds/customizations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"io/ioutil"
"net/url"
"regexp"
"strings"
"testing"
"time"

Expand All @@ -16,8 +15,7 @@ import (
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
)

func TestPresignWithPresignNotSet(t *testing.T) {
reqs := map[string]*request.Request{}
func TestCopyDBSnapshotNoPanic(t *testing.T) {

cfg := unit.Config()
cfg.Region = "us-west-2"
Expand All @@ -34,73 +32,184 @@ func TestPresignWithPresignNotSet(t *testing.T) {
t.Errorf("expect no panic, got %v", p)
}

reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
}).Request

reqs[opCreateDBInstanceReadReplica] = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
}).Request

for op, req := range reqs {
req.Sign()
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))

u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))

exp := fmt.Sprintf(`^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=us-west-2.+`, op)
if re, a := regexp.MustCompile(exp), u; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
}
}

func TestPresignWithPresignSet(t *testing.T) {
reqs := map[string]*request.Request{}
func TestPresignCrossRegionRequest(t *testing.T) {

cfg := unit.Config()
cfg.Region = "us-west-2"
cfg.EndpointResolver = endpoints.NewDefaultResolver()

svc := New(cfg)
const regexPattern= `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+`

f := func() {
// Doesn't panic on nil input
req := svc.CopyDBSnapshotRequest(nil)
req.Sign()
}
if paniced, p := awstesting.DidPanic(f); paniced {
t.Errorf("expect no panic, got %v", p)
}
cases := map[string]struct {
Req *request.Request
Assert func(*testing.T, string)
}{
opCopyDBSnapshot: {
Req: func() *request.Request {
req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCopyDBSnapshot, cfg.Region)),
},

reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("presignedURL"),
}).Request
opCreateDBInstanceReadReplica: {
Req: func() *request.Request {
req := svc.CreateDBInstanceReadReplicaRequest(
&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCreateDBInstanceReadReplica, cfg.Region)),
},
opCopyDBClusterSnapshot: {
Req: func() *request.Request {
req := svc.CopyDBClusterSnapshotRequest(
&CopyDBClusterSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCopyDBClusterSnapshot, cfg.Region)),
},
opCreateDBCluster: {
Req: func() *request.Request {
req := svc.CreateDBClusterRequest(
&CreateDBClusterInput{
SourceRegion: aws.String("us-west-1"),
DBClusterIdentifier: aws.String("foo"),
Engine: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
opCreateDBCluster, cfg.Region)),
},
opCopyDBSnapshot + " same region": {
Req: func() *request.Request {
req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-2"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsEmpty(),
},
opCreateDBInstanceReadReplica + " same region": {
Req: func() *request.Request {
req := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-2"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsEmpty(),
},
opCopyDBClusterSnapshot + " same region": {
Req: func() *request.Request {
req := svc.CopyDBClusterSnapshotRequest(
&CopyDBClusterSnapshotInput{
SourceRegion: aws.String("us-west-2"),
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsEmpty(),
},
opCreateDBCluster + " same region": {
Req: func() *request.Request {
req := svc.CreateDBClusterRequest(
&CreateDBClusterInput{
SourceRegion: aws.String("us-west-2"),
DBClusterIdentifier: aws.String("foo"),
Engine: aws.String("bar"),
})
return req.Request
}(),
Assert: assertAsEmpty(),
},
opCopyDBSnapshot + " presignURL set": {
Req: func() *request.Request {
req := svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req.Request
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
opCreateDBInstanceReadReplica + " presignURL set": {
Req: func() *request.Request {
req := svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req.Request
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
opCopyDBClusterSnapshot + " presignURL set": {
Req: func() *request.Request {
req := svc.CopyDBClusterSnapshotRequest(
&CopyDBClusterSnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceDBClusterSnapshotIdentifier: aws.String("foo"),
TargetDBClusterSnapshotIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req.Request
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
opCreateDBCluster + " presignURL set": {
Req: func() *request.Request {
req := svc.CreateDBClusterRequest(
&CreateDBClusterInput{
SourceRegion: aws.String("us-west-1"),
DBClusterIdentifier: aws.String("foo"),
Engine: aws.String("bar"),
PreSignedUrl: aws.String("mockPresignedURL"),
})
return req.Request
}(),
Assert: assertAsEqual("mockPresignedURL"),
},
}

reqs[opCreateDBInstanceReadReplica] = svc.CreateDBInstanceReadReplicaRequest(&CreateDBInstanceReadReplicaInput{
SourceRegion: aws.String("us-west-1"),
SourceDBInstanceIdentifier: aws.String("foo"),
DBInstanceIdentifier: aws.String("bar"),
PreSignedUrl: aws.String("presignedURL"),
}).Request
for name, c := range cases {
t.Run(name, func(t *testing.T) {
if err := c.Req.Sign(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))

for _, req := range reqs {
req.Sign()
u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))

b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))
c.Assert(t, u)

u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
if e, a := "presignedURL", u; !strings.Contains(a, e) {
t.Errorf("expect %s to be in %s", e, a)
}
})
}
}

Expand All @@ -112,15 +221,6 @@ func TestPresignWithSourceNotSet(t *testing.T) {

svc := New(cfg)

f := func() {
// Doesn't panic on nil input
req := svc.CopyDBSnapshotRequest(nil)
req.Sign()
}
if paniced, p := awstesting.DidPanic(f); paniced {
t.Errorf("expect no panic, got %v", p)
}

reqs[opCopyDBSnapshot] = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{
SourceDBSnapshotIdentifier: aws.String("foo"),
TargetDBSnapshotIdentifier: aws.String("bar"),
Expand All @@ -133,3 +233,33 @@ func TestPresignWithSourceNotSet(t *testing.T) {
}
}
}

func assertAsRegexMatch(exp string) func(*testing.T, string) {
return func(t *testing.T, v string) {
t.Helper()

if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
}
}

func assertAsEmpty() func(*testing.T, string) {
return func(t *testing.T, v string) {
t.Helper()

if len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
}
}

func assertAsEqual(expect string) func(*testing.T, string) {
return func(t *testing.T, v string) {
t.Helper()

if e, a := expect, v; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
}

0 comments on commit 467aa8a

Please sign in to comment.