diff --git a/README.md b/README.md index 746f094..770c80d 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,19 @@ $ go get github.com/zeiss/fiber-goth * GitHub (github.com, Enterprise, and Enterprise Cloud) * Microsoft Entra ID +## CSRF + +The middleware supports CSRF protection. It is added via the following package. + +```golang +import "github.com/zeiss/fiber-goth/csrf" + +app := fiber.New() +app.Use(csrf.New()) +``` + +The CSRF protection depends on the session middleware. + ## Examples See [examples](https://github.com/zeiss/fiber-goth/tree/master/examples) to understand the provided interfaces diff --git a/csrf/csrf.go b/csrf/csrf.go index 0dda3e6..43428dc 100644 --- a/csrf/csrf.go +++ b/csrf/csrf.go @@ -4,7 +4,6 @@ import ( "time" "github.com/google/uuid" - "github.com/valyala/fasthttp" goth "github.com/zeiss/fiber-goth" "github.com/zeiss/fiber-goth/adapters" "github.com/zeiss/pkg/slices" @@ -22,6 +21,8 @@ var ( ErrMissingSession = fiber.NewError(fiber.StatusForbidden, "missing session in context") // ErrGenerateToken is returned when the token generator returns an error. ErrGenerateToken = fiber.NewError(fiber.StatusForbidden, "failed to generate csrf token") + // ErrMissingToken is returned when the token is missing from the request. + ErrMissingToken = fiber.NewError(fiber.StatusForbidden, "missing csrf token in request") ) // HeaderName is the default header name used to extract the token. @@ -56,35 +57,6 @@ type Config struct { // Extractor is the function used to extract the token from the request. Extractor func(c *fiber.Ctx) (string, error) - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool - - // Decides whether cookie should last for only the browser sesison. - // Ignores Expiration if set to true - CookieSessionOnly bool - - // SingleUseToken indicates if the CSRF token be destroyed - // and a new one generated on each use. - // - // Optional. Default: false - SingleUseToken bool - - // CookieName is the name of the cookie used to store the session. - CookieName string - - // CookieSameSite is the SameSite attribute of the cookie. - CookieSameSite fasthttp.CookieSameSite - - // CookiePath is the path of the cookie. - CookiePath string - - // CookieDomain is the domain of the cookie. - CookieDomain string - - // CookieHTTPOnly is the HTTPOnly attribute of the cookie. - CookieHTTPOnly bool - // TrustedOrigins is a list of origins that are allowed to set the cookie. TrustedOrigins []string @@ -98,8 +70,6 @@ type Config struct { // ConfigDefault is the default config. var ConfigDefault = Config{ IdleTimeout: 30 * time.Minute, - CookieName: "csrf_", - CookieSameSite: fasthttp.CookieSameSiteLaxMode, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), TokenGenerator: DefaultCsrfTokenGenerator, @@ -138,14 +108,6 @@ func configDefault(config ...Config) Config { cfg.IdleTimeout = ConfigDefault.IdleTimeout } - if cfg.CookieName == "" { - cfg.CookieName = ConfigDefault.CookieName - } - - if cfg.CookieSameSite == 0 { - cfg.CookieSameSite = ConfigDefault.CookieSameSite - } - if cfg.ErrorHandler == nil { cfg.ErrorHandler = ConfigDefault.ErrorHandler } @@ -253,3 +215,42 @@ func FromHeader(param string) func(c *fiber.Ctx) (string, error) { return token, nil } } + +// FromParam returns a function that extracts token from the request query parameter. +func FromParam(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + token := c.Params(param) + + if utilx.Empty(token) { + return "", ErrMissingToken + } + + return token, nil + } +} + +// FromForm returns a function that extracts token from the request form. +func FromForm(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + token := c.FormValue(param) + + if utilx.Empty(token) { + return "", ErrMissingToken + } + + return token, nil + } +} + +// FromQuery returns a function that extracts token from the request query parameter. +func FromQuery(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + token := c.Query(param) + + if utilx.Empty(token) { + return "", ErrMissingToken + } + + return token, nil + } +}