Skip to content

Commit

Permalink
fix(middleware/cors): Vary header handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sixcolors committed Mar 26, 2024
1 parent 83da096 commit 7b4a2aa
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
21 changes: 18 additions & 3 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ func New(config ...Config) fiber.Handler {

// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
// Unless all origins are allowed, we include the Vary header to cache the response correctly
if !allowAllOrigins {
c.Vary(fiber.HeaderOrigin)
}

return c.Next()
}

Expand Down Expand Up @@ -215,17 +221,28 @@ func New(config ...Config) fiber.Handler {
// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
if !allowAllOrigins {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
c.Vary(fiber.HeaderOrigin)
}
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}

// Preflight request
// Pre-flight request

// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// of preflight responses:
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
if cfg.AllowPrivateNetwork && c.Get(fiber.HeaderAccessControlRequestPrivateNetwork) == "true" {
c.Vary(fiber.HeaderAccessControlRequestPrivateNetwork)
c.Set(fiber.HeaderAccessControlAllowPrivateNetwork, "true")
}
c.Vary(fiber.HeaderOrigin)

setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)

// Send 204 No Content
Expand All @@ -235,8 +252,6 @@ func New(config ...Config) fiber.Handler {

// Function to set CORS headers
func setCORSHeaders(c fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
c.Vary(fiber.HeaderOrigin)

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {
Expand Down
29 changes: 29 additions & 0 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,33 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}

func Test_CORS_AllowOrigins_Vary(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(
Config{
AllowOrigins: "http://localhost",
},
))

h := app.Handler()

// Test Vary header non-Cors request
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")

// Test Vary header Cors request
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
}

// go test -run -v Test_CORS_Wildcard
func Test_CORS_Wildcard(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -95,6 +122,7 @@ func Test_CORS_Wildcard(t *testing.T) {

// Check result
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
Expand All @@ -105,6 +133,7 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
handler(ctx)

require.NotContains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should not be set")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
Expand Down

0 comments on commit 7b4a2aa

Please sign in to comment.