diff --git a/sigv4/sigv4.go b/sigv4/sigv4.go index 9c8b9834..007cd77d 100644 --- a/sigv4/sigv4.go +++ b/sigv4/sigv4.go @@ -20,6 +20,7 @@ import ( "io/ioutil" "net/http" "net/textproto" + "path" "sync" "time" @@ -115,9 +116,9 @@ func (rt *sigV4RoundTripper) RoundTrip(req *http.Request) (*http.Response, error }() req.Body = ioutil.NopCloser(seeker) - // Escape URL like documented in AWS documentation. - // https://docs.aws.amazon.com/sdk-for-go/api/aws/signer/v4/#pkg-overview - req.URL.Path = req.URL.EscapedPath() + // Clean path like documented in AWS documentation. + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + req.URL.Path = path.Clean(req.URL.Path) // Clone the request and trim out headers that we don't want to sign. signReq := req.Clone(req.Context()) diff --git a/sigv4/sigv4_test.go b/sigv4/sigv4_test.go index 18313d8e..700db669 100644 --- a/sigv4/sigv4_test.go +++ b/sigv4/sigv4_test.go @@ -64,7 +64,7 @@ func TestSigV4RoundTripper(t *testing.T) { cli := &http.Client{Transport: rt} - req, err := http.NewRequest(http.MethodPost, "google.com", strings.NewReader("Hello, world!")) + req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!")) require.NoError(t, err) _, err = cli.Do(req) @@ -78,7 +78,7 @@ func TestSigV4RoundTripper(t *testing.T) { // Perform the same request but with a header that shouldn't included in the // signature; validate that the Authorization signature matches. t.Run("Ignored Headers", func(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "google.com", strings.NewReader("Hello, world!")) + req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!")) require.NoError(t, err) req.Header.Add("Uber-Trace-Id", "some-trace-id") @@ -91,12 +91,14 @@ func TestSigV4RoundTripper(t *testing.T) { }) t.Run("Escape URL", func(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "google.com/test//test", strings.NewReader("Hello, world!")) + req, err := http.NewRequest(http.MethodPost, "https://example.com/test//test", strings.NewReader("Hello, world!")) require.NoError(t, err) - require.Equal(t, "google.com/test//test", req.URL.Path) + require.Equal(t, "/test//test", req.URL.Path) - // Escape URL and check - req.URL.Path = req.URL.EscapedPath() - require.Equal(t, "google.com/test/test", req.URL.Path) + _, err = cli.Do(req) + require.NoError(t, err) + require.NotNil(t, gotReq) + + require.Equal(t, "/test/test", gotReq.URL.Path) }) }