diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 6db8be27ea..ea91594a70 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -171,6 +171,12 @@ func New(config ...Config) fiber.Handler { // If the request does not have Origin header, the request is outside the scope of CORS if originHeader == "" { + // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches + // Unless all origins are allowed, we include the Vary header to cache the response correctly + if !allowAllOrigins { + c.Vary(fiber.HeaderOrigin) + } + return c.Next() } @@ -215,17 +221,28 @@ func New(config ...Config) fiber.Handler { // Simple request // Ommit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { + if !allowAllOrigins { + // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches + c.Vary(fiber.HeaderOrigin) + } setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg) return c.Next() } - // Preflight request + // Pre-flight request + + // Response to OPTIONS request should not be cached but, + // some caching can be configured to cache such responses. + // To Avoid poisoning the cache, we include the Vary header + // of preflight responses: c.Vary(fiber.HeaderAccessControlRequestMethod) c.Vary(fiber.HeaderAccessControlRequestHeaders) if cfg.AllowPrivateNetwork && c.Get(fiber.HeaderAccessControlRequestPrivateNetwork) == "true" { c.Vary(fiber.HeaderAccessControlRequestPrivateNetwork) c.Set(fiber.HeaderAccessControlAllowPrivateNetwork, "true") } + c.Vary(fiber.HeaderOrigin) + setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) // Send 204 No Content @@ -235,8 +252,6 @@ func New(config ...Config) fiber.Handler { // Function to set CORS headers func setCORSHeaders(c fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) { - c.Vary(fiber.HeaderOrigin) - if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' if allowOrigin == "*" { diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index c3eac3b4d0..5bf943e8b6 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -68,6 +68,33 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) } +func Test_CORS_AllowOrigins_Vary(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New( + Config{ + AllowOrigins: "http://localhost", + }, + )) + + h := app.Handler() + + // Test Vary header non-Cors request + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set") + + // Test Vary header Cors request + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + h(ctx) + require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set") +} + // go test -run -v Test_CORS_Wildcard func Test_CORS_Wildcard(t *testing.T) { t.Parallel() @@ -95,6 +122,7 @@ func Test_CORS_Wildcard(t *testing.T) { // Check result require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response + require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set") require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) @@ -105,6 +133,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") handler(ctx) + require.NotContains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should not be set") require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) }