diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..4d7c9e4693f 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,4 +1,6 @@ ### SDK Features +* `aws/request`: Fixes bug in WithSetRequestHeaders where the header key was added to the header map directly + * Addresses an issue where the header keys being added were being added directly to the header map, and did not have the canonical header casing applied. This introduced bugs where instead of overwriting existing header key, it added another map entry. ### SDK Enhancements diff --git a/aws/request/handlers.go b/aws/request/handlers.go index e819ab6c0e8..9556332b65e 100644 --- a/aws/request/handlers.go +++ b/aws/request/handlers.go @@ -330,6 +330,9 @@ func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) { // WithSetRequestHeaders updates the operation request's HTTP header to contain // the header key value pairs provided. If the header key already exists in the // request's HTTP header set, the existing value(s) will be replaced. +// +// Header keys added will be added as canonical format with title casing +// applied via http.Header.Set method. func WithSetRequestHeaders(h map[string]string) Option { return withRequestHeader(h).SetRequestHeaders } @@ -338,6 +341,6 @@ type withRequestHeader map[string]string func (h withRequestHeader) SetRequestHeaders(r *Request) { for k, v := range h { - r.HTTPRequest.Header[k] = []string{v} + r.HTTPRequest.Header.Set(k, v) } } diff --git a/aws/request/handlers_test.go b/aws/request/handlers_test.go index b2da558d6a0..cd922bfe613 100644 --- a/aws/request/handlers_test.go +++ b/aws/request/handlers_test.go @@ -1,6 +1,7 @@ package request_test import ( + "net/http" "reflect" "testing" @@ -197,6 +198,39 @@ func TestStopHandlers(t *testing.T) { } } +func TestWithSetRequestHeaders(t *testing.T) { + fn := request.WithSetRequestHeaders(map[string]string{ + "x-foo-bar": "abc123", + "X-Bar-foo": "efg456", + }) + + req := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}} + fn(req) + + expect := map[string][]string{ + "X-Foo-Bar": {"abc123"}, + "X-Bar-Foo": {"efg456"}, + } + + if e, a := len(req.HTTPRequest.Header), len(expect); e != a { + t.Fatalf("expect %v headers, got %v", e, a) + } + for k, expectVs := range expect { + actualVs, ok := req.HTTPRequest.Header[k] + if !ok { + t.Errorf("expect %v header", k) + } + if e, a := len(expectVs), len(actualVs); e != a { + t.Fatalf("expect %v values for %v, got %v", e, k, a) + } + for i, expectV := range expectVs { + if e, a := expectV, actualVs[i]; e != a { + t.Errorf("expect %v[%d] to be %v, got %v", k, i, e, a) + } + } + } +} + func BenchmarkNewRequest(b *testing.B) { svc := s3.New(unit.Session)