From a849e78d05cc77ce7df27b247c6c938b0fa54948 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte <961963+jasdel@users.noreply.github.com> Date: Fri, 22 Apr 2022 15:19:31 -0700 Subject: [PATCH] `aws/request` Fixes WithSetRequestHeaders key added to header map directly Addresses an issue where the header keys being added were being added directly to the header map, and did not have the canonical header casing applied. This introduced bugs where instead of overwriting existing header key, it added another map entry. --- CHANGELOG_PENDING.md | 2 ++ aws/request/handlers.go | 2 +- aws/request/handlers_test.go | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..4d7c9e4693f 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,4 +1,6 @@ ### SDK Features +* `aws/request`: Fixes bug in WithSetRequestHeaders where the header key was added to the header map directly + * Addresses an issue where the header keys being added were being added directly to the header map, and did not have the canonical header casing applied. This introduced bugs where instead of overwriting existing header key, it added another map entry. ### SDK Enhancements diff --git a/aws/request/handlers.go b/aws/request/handlers.go index e819ab6c0e8..e0fc048ebec 100644 --- a/aws/request/handlers.go +++ b/aws/request/handlers.go @@ -338,6 +338,6 @@ type withRequestHeader map[string]string func (h withRequestHeader) SetRequestHeaders(r *Request) { for k, v := range h { - r.HTTPRequest.Header[k] = []string{v} + r.HTTPRequest.Header.Set(k, v) } } diff --git a/aws/request/handlers_test.go b/aws/request/handlers_test.go index b2da558d6a0..cd922bfe613 100644 --- a/aws/request/handlers_test.go +++ b/aws/request/handlers_test.go @@ -1,6 +1,7 @@ package request_test import ( + "net/http" "reflect" "testing" @@ -197,6 +198,39 @@ func TestStopHandlers(t *testing.T) { } } +func TestWithSetRequestHeaders(t *testing.T) { + fn := request.WithSetRequestHeaders(map[string]string{ + "x-foo-bar": "abc123", + "X-Bar-foo": "efg456", + }) + + req := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}} + fn(req) + + expect := map[string][]string{ + "X-Foo-Bar": {"abc123"}, + "X-Bar-Foo": {"efg456"}, + } + + if e, a := len(req.HTTPRequest.Header), len(expect); e != a { + t.Fatalf("expect %v headers, got %v", e, a) + } + for k, expectVs := range expect { + actualVs, ok := req.HTTPRequest.Header[k] + if !ok { + t.Errorf("expect %v header", k) + } + if e, a := len(expectVs), len(actualVs); e != a { + t.Fatalf("expect %v values for %v, got %v", e, k, a) + } + for i, expectV := range expectVs { + if e, a := expectV, actualVs[i]; e != a { + t.Errorf("expect %v[%d] to be %v, got %v", k, i, e, a) + } + } + } +} + func BenchmarkNewRequest(b *testing.B) { svc := s3.New(unit.Session)