diff --git a/cors.go b/cors.go index 20a66d0..724f242 100644 --- a/cors.go +++ b/cors.go @@ -364,9 +364,11 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { // Note: the Fetch standard guarantees that at most one // Access-Control-Request-Headers header is present in the preflight request; // see step 5.2 in https://fetch.spec.whatwg.org/#cors-preflight-fetch-0. - reqHeaders, found := first(r.Header, "Access-Control-Request-Headers") - if found && !c.allowedHeadersAll && !c.allowedHeaders.Subsumes(reqHeaders[0]) { - c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders[0]) + // However, some gateways split that header into multiple headers of the same name; + // see https://github.com/rs/cors/issues/184. + reqHeaders, found := r.Header["Access-Control-Request-Headers"] + if found && !c.allowedHeadersAll && !c.allowedHeaders.Accepts(reqHeaders) { + c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders) return } if c.allowedOriginsAll { diff --git a/cors_test.go b/cors_test.go index a3c0aab..46430fc 100644 --- a/cors_test.go +++ b/cors_test.go @@ -15,26 +15,63 @@ var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) _, _ = w.Write(testResponse) }) -var allHeaders = []string{ - "Vary", - "Access-Control-Allow-Origin", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Headers", - "Access-Control-Allow-Credentials", - "Access-Control-Allow-Private-Network", - "Access-Control-Max-Age", - "Access-Control-Expose-Headers", +// For each key-value pair of this map, the value indicates whether the key +// is a list-based field (i.e. not a singleton field); +// see https://httpwg.org/specs/rfc9110.html#abnf.extension. +var allRespHeaders = map[string]bool{ + // see https://www.rfc-editor.org/rfc/rfc9110#section-12.5.5 + "Vary": true, + // see https://fetch.spec.whatwg.org/#http-new-header-syntax + "Access-Control-Allow-Origin": false, + "Access-Control-Allow-Credentials": false, + "Access-Control-Allow-Methods": true, + "Access-Control-Allow-Headers": true, + "Access-Control-Max-Age": false, + "Access-Control-Expose-Headers": true, + // see https://wicg.github.io/private-network-access/ + "Access-Control-Allow-Private-Network": false, } -func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]string) { +func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders http.Header) { t.Helper() - for _, name := range allHeaders { - got := strings.Join(resHeaders[name], ", ") + for name, listBased := range allRespHeaders { + got := resHeaders[name] want := expHeaders[name] - if got != want { + if !listBased && !slicesEqual(got, want) { t.Errorf("Response header %q = %q, want %q", name, got, want) + continue + } + if listBased && !slicesEqual(normalize(got), normalize(want)) { + t.Errorf("Response header %q = %q, want %q", name, got, want) + continue + } + } +} + +// normalize normalizes a list-based field value, +// preserving both empty elements and the order of elements. +func normalize(s []string) (res []string) { + for _, v := range s { + for _, e := range strings.Split(v, ",") { + e = strings.Trim(e, " \t") + res = append(res, e) } } + return +} + +// TODO: when updating go directive to 1.21 or later, +// use slices.Equal instead. +func slicesEqual(s1, s2 []string) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true } func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) { @@ -49,8 +86,8 @@ func TestSpec(t *testing.T) { name string options Options method string - reqHeaders map[string]string - resHeaders map[string]string + reqHeaders http.Header + resHeaders http.Header originAllowed bool }{ { @@ -59,9 +96,9 @@ func TestSpec(t *testing.T) { // Intentionally left blank. }, "GET", - map[string]string{}, - map[string]string{ - "Vary": "Origin", + http.Header{}, + http.Header{ + "Vary": {"Origin"}, }, true, }, @@ -71,12 +108,12 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"*"}, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "*", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"*"}, }, true, }, @@ -87,13 +124,13 @@ func TestSpec(t *testing.T) { AllowCredentials: true, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Credentials": "true", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Allow-Credentials": {"true"}, }, true, }, @@ -103,12 +140,12 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://foobar.com"}, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foobar.com", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, }, true, }, @@ -118,12 +155,12 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://*.bar.com"}, }, "GET", - map[string]string{ - "Origin": "http://foo.bar.com", + http.Header{ + "Origin": {"http://foo.bar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foo.bar.com", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foo.bar.com"}, }, true, }, @@ -133,11 +170,11 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://foobar.com"}, }, "GET", - map[string]string{ - "Origin": "http://barbaz.com", + http.Header{ + "Origin": {"http://barbaz.com"}, }, - map[string]string{ - "Vary": "Origin", + http.Header{ + "Vary": {"Origin"}, }, false, }, @@ -147,11 +184,11 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://*.bar.com"}, }, "GET", - map[string]string{ - "Origin": "http://foo.baz.com", + http.Header{ + "Origin": {"http://foo.baz.com"}, }, - map[string]string{ - "Vary": "Origin", + http.Header{ + "Vary": {"Origin"}, }, false, }, @@ -163,12 +200,12 @@ func TestSpec(t *testing.T) { }, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foobar.com", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, }, true, }, @@ -180,13 +217,13 @@ func TestSpec(t *testing.T) { }, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", - "Authorization": "secret", + http.Header{ + "Origin": {"http://foobar.com"}, + "Authorization": {"secret"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foobar.com", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, }, true, }, @@ -198,13 +235,13 @@ func TestSpec(t *testing.T) { }, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", - "Authorization": "secret", + http.Header{ + "Origin": {"http://foobar.com"}, + "Authorization": {"secret"}, }, - map[string]string{ - "Vary": "Origin, Authorization", - "Access-Control-Allow-Origin": "http://foobar.com", + http.Header{ + "Vary": {"Origin, Authorization"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, }, true, }, @@ -216,52 +253,52 @@ func TestSpec(t *testing.T) { }, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", - "Authorization": "not-secret", + http.Header{ + "Origin": {"http://foobar.com"}, + "Authorization": {"not-secret"}, }, - map[string]string{ - "Vary": "Origin", + http.Header{ + "Vary": {"Origin"}, }, false, }, { "MaxAge", Options{ - AllowedOrigins: []string{"http://example.com/"}, + AllowedOrigins: []string{"http://example.com"}, AllowedMethods: []string{"GET"}, MaxAge: 10, }, "OPTIONS", - map[string]string{ - "Origin": "http://example.com/", - "Access-Control-Request-Method": "GET", + http.Header{ + "Origin": {"http://example.com"}, + "Access-Control-Request-Method": {"GET"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://example.com/", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Max-Age": "10", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://example.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Max-Age": {"10"}, }, true, }, { "MaxAgeNegative", Options{ - AllowedOrigins: []string{"http://example.com/"}, + AllowedOrigins: []string{"http://example.com"}, AllowedMethods: []string{"GET"}, MaxAge: -1, }, "OPTIONS", - map[string]string{ - "Origin": "http://example.com/", - "Access-Control-Request-Method": "GET", + http.Header{ + "Origin": {"http://example.com"}, + "Access-Control-Request-Method": {"GET"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://example.com/", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Max-Age": "0", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://example.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Max-Age": {"0"}, }, true, }, @@ -272,14 +309,14 @@ func TestSpec(t *testing.T) { AllowedMethods: []string{"PUT", "DELETE"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "PUT", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"PUT"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "PUT", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"PUT"}, }, true, }, @@ -290,12 +327,12 @@ func TestSpec(t *testing.T) { AllowedMethods: []string{"PUT", "DELETE"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "PATCH", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"PATCH"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, }, true, }, @@ -306,16 +343,16 @@ func TestSpec(t *testing.T) { AllowedHeaders: []string{"X-Header-1", "x-header-2", "X-HEADER-3"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "x-header-1,x-header-2", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-header-1,x-header-2"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "x-header-1,x-header-2", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Headers": {"x-header-1,x-header-2"}, }, true, }, @@ -326,16 +363,16 @@ func TestSpec(t *testing.T) { AllowedHeaders: []string{}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "x-requested-with", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-requested-with"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "x-requested-with", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Headers": {"x-requested-with"}, }, true, }, @@ -346,16 +383,16 @@ func TestSpec(t *testing.T) { AllowedHeaders: []string{"*"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "x-header-1,x-header-2", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-header-1,x-header-2"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "x-header-1,x-header-2", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Headers": {"x-header-1,x-header-2"}, }, true, }, @@ -366,13 +403,13 @@ func TestSpec(t *testing.T) { AllowedHeaders: []string{"X-Header-1", "x-header-2"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "x-header-1,x-header-3", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"x-header-1,x-header-3"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, }, true, }, @@ -383,13 +420,13 @@ func TestSpec(t *testing.T) { ExposedHeaders: []string{"X-Header-1", "x-header-2"}, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Expose-Headers": "X-Header-1, X-Header-2", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Expose-Headers": {"X-Header-1, X-Header-2"}, }, true, }, @@ -400,15 +437,15 @@ func TestSpec(t *testing.T) { AllowCredentials: true, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Credentials": "true", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Credentials": {"true"}, }, true, }, @@ -419,16 +456,16 @@ func TestSpec(t *testing.T) { AllowPrivateNetwork: true, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", - "Access-Control-Request-Private-Network": "true", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Private-Network": {"true"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Private-Network": "true", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Private-Network": {"true"}, }, true, }, @@ -438,15 +475,15 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://foobar.com"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", - "Access-Control-Request-PrivateNetwork": "true", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-PrivateNetwork": {"true"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "http://foobar.com", - "Access-Control-Allow-Methods": "GET", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, }, true, }, @@ -456,14 +493,14 @@ func TestSpec(t *testing.T) { OptionsPassthrough: true, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", - "Access-Control-Request-Method": "GET", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, }, - map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Allow-Methods": {"GET"}, }, true, }, @@ -473,12 +510,12 @@ func TestSpec(t *testing.T) { AllowedOrigins: []string{"http://foobar.com"}, }, "OPTIONS", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foobar.com", + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, }, true, }, { @@ -490,12 +527,52 @@ func TestSpec(t *testing.T) { }, }, "GET", - map[string]string{ - "Origin": "http://foobar.com", + http.Header{ + "Origin": {"http://foobar.com"}, + }, + http.Header{ + "Vary": {"Origin"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + }, + true, + }, + { + "MultipleACRHHeaders", + Options{ + AllowedOrigins: []string{"http://foobar.com"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + }, + "OPTIONS", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"authorization", "content-type"}, }, - map[string]string{ - "Vary": "Origin", - "Access-Control-Allow-Origin": "http://foobar.com", + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Headers": {"authorization", "content-type"}, + }, + true, + }, + { + "MultipleACRHHeadersWithOWSAndEmptyElements", + Options{ + AllowedOrigins: []string{"http://foobar.com"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + }, + "OPTIONS", + http.Header{ + "Origin": {"http://foobar.com"}, + "Access-Control-Request-Method": {"GET"}, + "Access-Control-Request-Headers": {"authorization\t", " ", " content-type"}, + }, + http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, + "Access-Control-Allow-Origin": {"http://foobar.com"}, + "Access-Control-Allow-Methods": {"GET"}, + "Access-Control-Allow-Headers": {"authorization\t", " ", " content-type"}, }, true, }, @@ -506,8 +583,10 @@ func TestSpec(t *testing.T) { s := New(tc.options) req, _ := http.NewRequest(tc.method, "http://example.com/foo", nil) - for name, value := range tc.reqHeaders { - req.Header.Add(name, value) + for name, values := range tc.reqHeaders { + for _, value := range values { + req.Header.Add(name, value) + } } t.Run("OriginAllowed", func(t *testing.T) { @@ -591,12 +670,12 @@ func TestHandlePreflightInvalidOriginAbortion(t *testing.T) { }) res := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) - req.Header.Add("Origin", "http://example.com/") + req.Header.Add("Origin", "http://example.com") s.handlePreflight(res, req) - assertHeaders(t, res.Header(), map[string]string{ - "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + assertHeaders(t, res.Header(), http.Header{ + "Vary": {"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}, }) } @@ -609,7 +688,7 @@ func TestHandlePreflightNoOptionsAbortion(t *testing.T) { s.handlePreflight(res, req) - assertHeaders(t, res.Header(), map[string]string{}) + assertHeaders(t, res.Header(), http.Header{}) } func TestHandleActualRequestInvalidOriginAbortion(t *testing.T) { @@ -618,12 +697,12 @@ func TestHandleActualRequestInvalidOriginAbortion(t *testing.T) { }) res := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com/foo", nil) - req.Header.Add("Origin", "http://example.com/") + req.Header.Add("Origin", "http://example.com") s.handleActualRequest(res, req) - assertHeaders(t, res.Header(), map[string]string{ - "Vary": "Origin", + assertHeaders(t, res.Header(), http.Header{ + "Vary": {"Origin"}, }) } @@ -634,12 +713,12 @@ func TestHandleActualRequestInvalidMethodAbortion(t *testing.T) { }) res := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com/foo", nil) - req.Header.Add("Origin", "http://example.com/") + req.Header.Add("Origin", "http://example.com") s.handleActualRequest(res, req) - assertHeaders(t, res.Header(), map[string]string{ - "Vary": "Origin", + assertHeaders(t, res.Header(), http.Header{ + "Vary": {"Origin"}, }) } diff --git a/internal/sortedset.go b/internal/sortedset.go index 513da20..844f3f9 100644 --- a/internal/sortedset.go +++ b/internal/sortedset.go @@ -52,46 +52,134 @@ func (set SortedSet) String() string { return strings.Join(elems, ",") } -// Subsumes reports whether csv is a sequence of comma-separated names that are -// - all elements of set, -// - sorted in lexicographically order, +// Accepts reports whether values is a sequence of list-based field values +// whose elements are +// - all members of set, +// - sorted in lexicographical order, // - unique. -func (set SortedSet) Subsumes(csv string) bool { - if csv == "" { - return true +func (set SortedSet) Accepts(values []string) bool { + var ( // effectively constant + maxLen = maxOWSBytes + set.maxLen + maxOWSBytes + 1 // +1 for comma + ) + var ( + posOfLastNameSeen = -1 + name string + commaFound bool + emptyElements int + ok bool + ) + for _, s := range values { + for { + // As a defense against maliciously long names in s, + // we process only a small number of s's leading bytes per iteration. + name, s, commaFound = cutAtComma(s, maxLen) + name, ok = trimOWS(name, maxOWSBytes) + if !ok { + return false + } + if name == "" { + // RFC 9110 requires recipients to tolerate + // "a reasonable number of empty list elements"; see + // https://httpwg.org/specs/rfc9110.html#abnf.extension.recipient. + emptyElements++ + if emptyElements > maxEmptyElements { + return false + } + if !commaFound { // We have now exhausted the names in s. + break + } + continue + } + pos, ok := set.m[name] + if !ok { + return false + } + // The names in s are expected to be sorted in lexicographical order + // and to each appear at most once. + // Therefore, the positions (in set) of the names that + // appear in s should form a strictly increasing sequence. + // If that's not actually the case, bail out. + if pos <= posOfLastNameSeen { + return false + } + posOfLastNameSeen = pos + if !commaFound { // We have now exhausted the names in s. + break + } + } + } + return true +} + +const ( + maxOWSBytes = 1 // number of leading/trailing OWS bytes tolerated + maxEmptyElements = 16 // number of empty list elements tolerated +) + +func cutAtComma(s string, n int) (before, after string, found bool) { + // Note: this implementation draws inspiration from strings.Cut's. + end := min(len(s), n) + if i := strings.IndexByte(s[:end], ','); i >= 0 { + after = s[i+1:] // deal with this first to save one bounds check + return s[:i], after, true + } + return s, "", false +} + +// TrimOWS trims up to n bytes of [optional whitespace (OWS)] +// from the start of and/or the end of s. +// If no more than n bytes of OWS are found at the start of s +// and no more than n bytes of OWS are found at the end of s, +// it returns the trimmed result and true. +// Otherwise, it returns the original string and false. +// +// [optional whitespace (OWS)]: https://httpwg.org/specs/rfc9110.html#whitespace +func trimOWS(s string, n int) (trimmed string, ok bool) { + if s == "" { + return s, true + } + trimmed, ok = trimRightOWS(s, n) + if !ok { + return s, false } - posOfLastNameSeen := -1 - chunkSize := set.maxLen + 1 // (to accommodate for at least one comma) - for { - // As a defense against maliciously long names in csv, - // we only process at most chunkSize bytes per iteration. - end := min(len(csv), chunkSize) - comma := strings.IndexByte(csv[:end], ',') - var name string - if comma == -1 { - name = csv - } else { - name = csv[:comma] + trimmed, ok = trimLeftOWS(trimmed, n) + if !ok { + return s, false + } + return trimmed, true +} + +func trimLeftOWS(s string, n int) (string, bool) { + sCopy := s + var i int + for len(s) > 0 { + if i > n { + return sCopy, false } - pos, found := set.m[name] - if !found { - return false + if !(s[0] == ' ' || s[0] == '\t') { + break } - // The names in csv are expected to be sorted in lexicographical order - // and appear at most once in csv. - // Therefore, the positions (in set) of the names that - // appear in csv should form a strictly increasing sequence. - // If that's not actually the case, bail out. - if pos <= posOfLastNameSeen { - return false + s = s[1:] + i++ + } + return s, true +} + +func trimRightOWS(s string, n int) (string, bool) { + sCopy := s + var i int + for len(s) > 0 { + if i > n { + return sCopy, false } - posOfLastNameSeen = pos - if comma < 0 { // We've now processed all the names in csv. + last := len(s) - 1 + if !(s[last] == ' ' || s[last] == '\t') { break } - csv = csv[comma+1:] + s = s[:last] + i++ } - return true + return s, true } // TODO: when updating go directive to 1.21 or later, diff --git a/internal/sortedset_test.go b/internal/sortedset_test.go index 9727686..1a362fc 100644 --- a/internal/sortedset_test.go +++ b/internal/sortedset_test.go @@ -1,108 +1,207 @@ package internal import ( + "strings" "testing" ) func TestSortedSet(t *testing.T) { cases := []struct { - desc string - elems []string - combined string - subsets []string - notSubsets []string - wantSize int + desc string + elems []string + // expectations + size int + combined string + slice []string + accepted [][]string + rejected [][]string }{ { desc: "empty set", + size: 0, combined: "", - notSubsets: []string{ - "bar", - "bar,foo", + accepted: [][]string{ + // some empty elements, possibly with OWS + {""}, + {","}, + {"\t, , "}, + // multiple field lines, some empty elements + make([]string, maxEmptyElements), + }, + rejected: [][]string{ + {"x-bar"}, + {"x-bar,x-foo"}, + // too many empty elements + {strings.Repeat(",", maxEmptyElements+1)}, + // multiple field lines, too many empty elements + make([]string, maxEmptyElements+1), }, - wantSize: 0, }, { desc: "singleton set", - elems: []string{"foo"}, - combined: "foo", - subsets: []string{ - "", - "foo", + elems: []string{"x-foo"}, + size: 1, + combined: "x-foo", + slice: []string{"X-Foo"}, + accepted: [][]string{ + {"x-foo"}, + // some empty elements, possibly with OWS + {""}, + {","}, + {"\t, , "}, + {"\tx-foo ,"}, + {" x-foo\t,"}, + {strings.Repeat(",", maxEmptyElements) + "x-foo"}, + // multiple field lines, some empty elements + append(make([]string, maxEmptyElements), "x-foo"), + make([]string, maxEmptyElements), }, - notSubsets: []string{ - "bar", - "bar,foo", + rejected: [][]string{ + {"x-bar"}, + {"x-bar,x-foo"}, + // too much OWS + {"x-foo "}, + {" x-foo "}, + {" x-foo "}, + {"x-foo\t\t"}, + {"\tx-foo\t\t"}, + {"\t\tx-foo\t\t"}, + // too many empty elements + {strings.Repeat(",", maxEmptyElements+1) + "x-foo"}, + // multiple field lines, too many empty elements + append(make([]string, maxEmptyElements+1), "x-foo"), + make([]string, maxEmptyElements+1), }, - wantSize: 1, }, { desc: "no dupes", - elems: []string{"foo", "bar", "baz"}, - combined: "bar,baz,foo", - subsets: []string{ - "", - "bar", - "baz", - "foo", - "bar,baz", - "bar,foo", - "baz,foo", - "bar,baz,foo", + elems: []string{"x-foo", "x-bar", "x-baz"}, + size: 3, + combined: "x-bar,x-baz,x-foo", + slice: []string{"X-Bar", "X-Baz", "X-Foo"}, + accepted: [][]string{ + {"x-bar"}, + {"x-baz"}, + {"x-foo"}, + {"x-bar,x-baz"}, + {"x-bar,x-foo"}, + {"x-baz,x-foo"}, + {"x-bar,x-baz,x-foo"}, + // some empty elements, possibly with OWS + {""}, + {","}, + {"\t, , "}, + {"\tx-bar ,"}, + {" x-baz\t,"}, + {"x-foo,"}, + {"\tx-bar ,\tx-baz ,"}, + {" x-bar\t, x-foo\t,"}, + {"x-baz,x-foo,"}, + {" x-bar , x-baz , x-foo ,"}, + {"x-bar" + strings.Repeat(",", maxEmptyElements+1) + "x-foo"}, + // multiple field lines + {"x-bar", "x-foo"}, + {"x-bar", "x-baz,x-foo"}, + // multiple field lines, some empty elements + append(make([]string, maxEmptyElements), "x-bar", "x-foo"), + make([]string, maxEmptyElements), }, - notSubsets: []string{ - "qux", - "bar,baz,baz", - "qux,baz", - "qux,foo", - "quxbaz,foo", + rejected: [][]string{ + {"x-qux"}, + {"x-bar,x-baz,x-baz"}, + {"x-qux,x-baz"}, + {"x-qux,x-foo"}, + {"x-quxbaz,x-foo"}, + // too much OWS + {"x-bar "}, + {" x-baz "}, + {" x-foo "}, + {"x-bar\t\t,x-baz"}, + {"x-bar,\tx-foo\t\t"}, + {"\t\tx-baz,x-foo\t\t"}, + {" x-bar\t,\tx-baz\t ,x-foo"}, + // too many empty elements + {"x-bar" + strings.Repeat(",", maxEmptyElements+2) + "x-foo"}, + // multiple field lines, elements in the wrong order + {"x-foo", "x-bar"}, + // multiple field lines, too many empty elements + append(make([]string, maxEmptyElements+1), "x-bar", "x-foo"), + make([]string, maxEmptyElements+1), }, - wantSize: 3, }, { desc: "some dupes", - elems: []string{"foo", "bar", "bar", "foo", "e"}, - combined: "bar,e,foo", - subsets: []string{ - "", - "bar", - "e", - "foo", - "bar,foo", - "bar,e", - "e,foo", - "bar,e,foo", + elems: []string{"x-foo", "x-bar", "x-foo"}, + size: 2, + combined: "x-bar,x-foo", + slice: []string{"X-Bar", "X-Foo"}, + accepted: [][]string{ + {"x-bar"}, + {"x-foo"}, + {"x-bar,x-foo"}, + // some empty elements, possibly with OWS + {""}, + {","}, + {"\t, , "}, + {"\tx-bar ,"}, + {" x-foo\t,"}, + {"x-foo,"}, + {"\tx-bar ,\tx-foo ,"}, + {" x-bar\t, x-foo\t,"}, + {"x-bar,x-foo,"}, + {" x-bar , x-foo ,"}, + {"x-bar" + strings.Repeat(",", maxEmptyElements+1) + "x-foo"}, + // multiple field lines + {"x-bar", "x-foo"}, + // multiple field lines, some empty elements + append(make([]string, maxEmptyElements), "x-bar", "x-foo"), + make([]string, maxEmptyElements), }, - notSubsets: []string{ - "qux", - "qux,bar", - "qux,foo", - "qux,baz,foo", + rejected: [][]string{ + {"x-qux"}, + {"x-qux,x-bar"}, + {"x-qux,x-foo"}, + {"x-qux,x-baz,x-foo"}, + // too much OWS + {"x-qux "}, + {"x-qux,\t\tx-bar"}, + {"x-qux,x-foo\t\t"}, + {"\tx-qux , x-baz\t\t,x-foo"}, + // too many empty elements + {"x-bar" + strings.Repeat(",", maxEmptyElements+2) + "x-foo"}, + // multiple field lines, elements in the wrong order + {"x-foo", "x-bar"}, + // multiple field lines, too much whitespace + {"x-qux", "\t\tx-bar"}, + {"x-qux", "x-foo\t\t"}, + {"\tx-qux ", " x-baz\t\t,x-foo"}, + // multiple field lines, too many empty elements + append(make([]string, maxEmptyElements+1), "x-bar", "x-foo"), + make([]string, maxEmptyElements+1), }, - wantSize: 3, }, } for _, tc := range cases { f := func(t *testing.T) { elems := clone(tc.elems) - s := NewSortedSet(tc.elems...) - size := s.Size() - if s.Size() != tc.wantSize { + set := NewSortedSet(tc.elems...) + size := set.Size() + if set.Size() != tc.size { const tmpl = "NewSortedSet(%#v...).Size(): got %d; want %d" - t.Errorf(tmpl, elems, size, tc.wantSize) + t.Errorf(tmpl, elems, size, tc.size) } - combined := s.String() + combined := set.String() if combined != tc.combined { const tmpl = "NewSortedSet(%#v...).String(): got %q; want %q" t.Errorf(tmpl, elems, combined, tc.combined) } - for _, sub := range tc.subsets { - if !s.Subsumes(sub) { - const tmpl = "%q is not a subset of %q, but should be" - t.Errorf(tmpl, sub, s) + for _, a := range tc.accepted { + if !set.Accepts(a) { + const tmpl = "%q rejects %q, but should accept it" + t.Errorf(tmpl, set, a) } } - for _, notSub := range tc.notSubsets { - if s.Subsumes(notSub) { - const tmpl = "%q is a subset of %q, but should not be" - t.Errorf(tmpl, notSub, s) + for _, r := range tc.rejected { + if set.Accepts(r) { + const tmpl = "%q accepts %q, but should reject it" + t.Errorf(tmpl, set, r) } } } diff --git a/utils.go b/utils.go index 7019f45..41b0c28 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package cors import ( - "net/http" "strings" ) @@ -24,11 +23,3 @@ func convert(s []string, f func(string) string) []string { } return out } - -func first(hdrs http.Header, k string) ([]string, bool) { - v, found := hdrs[k] - if !found || len(v) == 0 { - return nil, false - } - return v[:1], true -}