diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 939daaea79..3c4fdc0529 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -4,6 +4,7 @@ import ( "errors" "net/url" "reflect" + "strings" "time" "github.com/gofiber/fiber/v2" @@ -220,7 +221,7 @@ func isCsrfFromCookie(extractor interface{}) bool { // returns an error if the referer header is not present or is invalid // returns nil if the referer header is valid func refererMatchesHost(c *fiber.Ctx) error { - referer := c.Get(fiber.HeaderReferer) + referer := strings.ToLower(c.Get(fiber.HeaderReferer)) if referer == "" { return ErrNoReferer } @@ -230,9 +231,9 @@ func refererMatchesHost(c *fiber.Ctx) error { return ErrBadReferer } - if refererURL.Scheme+"://"+refererURL.Host != c.Protocol()+"://"+c.Hostname() { - return ErrBadReferer + if refererURL.Scheme == c.Protocol() && refererURL.Host == c.Hostname() { + return nil } - return nil + return ErrBadReferer } diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index a51a932386..60f93abef1 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -992,7 +992,10 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) { return c.SendStatus(fiber.StatusTeapot) }) - fctx := &fasthttp.RequestCtx{} + app.Post("/", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusTeapot) + }) + h := app.Handler() ctx := &fasthttp.RequestCtx{} @@ -1002,17 +1005,27 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) { token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] + // Test Correct Referer POST + ctx.Request.Reset() + ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") + ctx.Request.URI().SetScheme("https") + ctx.Request.URI().SetHost("example.com") + ctx.Request.Header.SetProtocol("https") + ctx.Request.Header.SetHost("example.com") + ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com") ctx.Request.Header.Set(HeaderName, token) + ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { - h(fctx) + h(ctx) } - utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) + utils.AssertEqual(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode()) } // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4 @@ -1024,7 +1037,6 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) { return c.SendStatus(fiber.StatusTeapot) }) - fctx := &fasthttp.RequestCtx{} h := app.Handler() ctx := &fasthttp.RequestCtx{} @@ -1034,8 +1046,9 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - h(fctx) + h(ctx) } - utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) + // Ensure the GET request returns a 418 status code + utils.AssertEqual(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode()) }