Skip to content

Commit

Permalink
feat: add csrf middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
katallaxie authored Oct 28, 2024
1 parent 784b987 commit 2c32192
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 106 deletions.
48 changes: 12 additions & 36 deletions adapters/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,6 @@ func init() {
gob.Register(&GothCsrfToken{})
}

// CsrfTokenGenerator is a function that generates a CSRF token.
type CsrfTokenGenerator func() (string, error)

// DefaultCsrfTokenGenerator generates a new CSRF token.
func DefaultCsrfTokenGenerator() (string, error) {
token, err := uuid.NewV7()
if err != nil {
return "", err
}

return token.String(), nil
}

// AccountType represents the type of an account.
type AccountType string

Expand Down Expand Up @@ -156,8 +143,18 @@ func (s *GothSession) IsValid() bool {
}

// GetCsrfToken returns the CSRF token.
func (s *GothSession) GetCsrfToken() string {
return s.CsrfToken.Token
func (s *GothSession) GetCsrfToken() GothCsrfToken {
return s.CsrfToken
}

// HasExpired returns true if the session has expired.
func (c GothCsrfToken) HasExpired() bool {
return c.ExpiresAt.Before(time.Now())
}

// IsValid returns true if the token is valid.
func (c GothCsrfToken) IsValid(token string) bool {
return c.Token == token
}

// GothVerificationToken is a verification token for a user
Expand Down Expand Up @@ -206,12 +203,6 @@ type Adapter interface {
CreateVerificationToken(ctx context.Context, verficationToken GothVerificationToken) (GothVerificationToken, error)
// UseVerficationToken uses a verification token.
UseVerficationToken(ctx context.Context, identifier string, token string) (GothVerificationToken, error)
// CreateCsrfToken creates a new CSRF token.
CreateCsrfToken(ctx context.Context, csrfToken GothCsrfToken) (GothCsrfToken, error)
// GetCsrfToken retrieves a CSRF token by token.
GetCsrfToken(ctx context.Context, token string) (GothCsrfToken, error)
// DeleteCsrfToken deletes a CSRF token by token.
DeleteCsrfToken(ctx context.Context, token string) error
}

var _ Adapter = (*UnimplementedAdapter)(nil)
Expand Down Expand Up @@ -293,18 +284,3 @@ func (a *UnimplementedAdapter) CreateVerificationToken(_ context.Context, erfica
func (a *UnimplementedAdapter) UseVerficationToken(_ context.Context, identifier string, token string) (GothVerificationToken, error) {
return GothVerificationToken{}, ErrUnimplemented
}

// CreateCsrfToken creates a new CSRF token.
func (a *UnimplementedAdapter) CreateCsrfToken(_ context.Context, csrfToken GothCsrfToken) (GothCsrfToken, error) {
return GothCsrfToken{}, ErrUnimplemented
}

// GetCsrfToken retrieves a CSRF token by token.
func (a *UnimplementedAdapter) GetCsrfToken(_ context.Context, token string) (GothCsrfToken, error) {
return GothCsrfToken{}, ErrUnimplemented
}

// DeleteCsrfToken deletes a CSRF token by token.
func (a *UnimplementedAdapter) DeleteCsrfToken(_ context.Context, token string) error {
return ErrUnimplemented
}
131 changes: 61 additions & 70 deletions csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (

"github.com/google/uuid"
"github.com/valyala/fasthttp"
goth "github.com/zeiss/fiber-goth"
"github.com/zeiss/fiber-goth/adapters"
"github.com/zeiss/pkg/slices"
"github.com/zeiss/pkg/utilx"

"github.com/gofiber/fiber/v2"
Expand All @@ -16,6 +18,10 @@ var (
ErrMissingHeader = fiber.NewError(fiber.StatusForbidden, "missing csrf token in header")
// ErrTokenNotFound is returned when the token is not found in the session.
ErrTokenNotFound = fiber.NewError(fiber.StatusForbidden, "csrf token not found in session")
// ErrMissingSession is returned when the session is missing from the context.
ErrMissingSession = fiber.NewError(fiber.StatusForbidden, "missing session in context")
// ErrGenerateToken is returned when the token generator returns an error.
ErrGenerateToken = fiber.NewError(fiber.StatusForbidden, "failed to generate csrf token")
)

// HeaderName is the default header name used to extract the token.
Expand All @@ -38,6 +44,10 @@ type Config struct {
// Adapter adapters.Adapter
Adapter adapters.Adapter

// IgnoredMethods is a list of methods to ignore from CSRF protection.
// Optional. Default: []string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
IgnoredMethods []string

// ErrorHandler is executed when an error is returned from fiber.Handler.
//
// Optional. Default: DefaultErrorHandler
Expand Down Expand Up @@ -93,6 +103,7 @@ var ConfigDefault = Config{
ErrorHandler: defaultErrorHandler,
Extractor: FromHeader(HeaderName),
TokenGenerator: DefaultCsrfTokenGenerator,
IgnoredMethods: []string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace},
}

// CsrfTokenGenerator is a function that generates a CSRF token.
Expand Down Expand Up @@ -147,107 +158,87 @@ func configDefault(config ...Config) Config {
cfg.TokenGenerator = ConfigDefault.TokenGenerator
}

return cfg
}
if cfg.IgnoredMethods == nil {
cfg.IgnoredMethods = ConfigDefault.IgnoredMethods
}

// Handler ...
type Handler struct {
config Config
return cfg
}

// New creates a new csrf middleware.
// nolint:gocyclo
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)

// handler := &Handler{
// config: cfg,
// }

var token string

// Return new handler
return func(c *fiber.Ctx) error {
// Skip middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}

switch c.Method() {
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
// cookieToken := c.Cookies(cfg.CookieName)
default:
extractedToken, err := cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}

if utilx.Empty(extractedToken) {
return cfg.ErrorHandler(c, ErrTokenNotFound)
}

raw := ""

if utilx.Empty(raw) {
// expire the token
cookieValue := fasthttp.Cookie{}
cookieValue.SetKey(cfg.CookieName)
cookieValue.SetValueBytes([]byte(""))
cookieValue.SetHTTPOnly(cfg.CookieHTTPOnly)
cookieValue.SetSameSite(cfg.CookieSameSite)
cookieValue.SetExpire(time.Now().Add(-time.Hour))
cookieValue.SetPath(cfg.CookiePath)
cookieValue.SetDomain(cfg.CookieDomain)
cookieValue.SetSecure(cfg.CookieSecure)

// Set the cookie
c.Response().Header.SetCookie(&cookieValue)

return cfg.ErrorHandler(c, ErrTokenNotFound)
}
// extract the session
session, err := goth.SessionFromContext(c)
if err != nil {
return cfg.ErrorHandler(c, ErrMissingSession)
}

// Skip middleware if the method is ignored
if slices.Any(func(method string) bool { return method == c.Method() }, cfg.IgnoredMethods...) {
return c.Next()
}

// extract the token
token, err := cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, ErrTokenNotFound)
}

// Generate a new token
// if the token is empty, abort
if utilx.Empty(token) {
// csrfToken, err := cfg.TokenGenerator()
// if err != nil {
// return cfg.ErrorHandler(c, err)
// }
return cfg.ErrorHandler(c, ErrTokenNotFound)
}

if session.GetCsrfToken().HasExpired() {
return cfg.ErrorHandler(c, ErrTokenNotFound)
}

// Create the cookie
cookieValue := fasthttp.Cookie{}
cookieValue.SetKey(cfg.CookieName)
cookieValue.SetValueBytes([]byte(token))
cookieValue.SetHTTPOnly(cfg.CookieHTTPOnly)
cookieValue.SetSameSite(cfg.CookieSameSite)
cookieValue.SetExpire(time.Now().Add(cfg.IdleTimeout))
cookieValue.SetPath(cfg.CookiePath)
cookieValue.SetDomain(cfg.CookieDomain)
cookieValue.SetSecure(cfg.CookieSecure)
if !session.GetCsrfToken().IsValid(token) {
return cfg.ErrorHandler(c, ErrTokenNotFound)
}

// Set the cookie
c.Response().Header.SetCookie(&cookieValue)
t, err := cfg.TokenGenerator()
if err != nil {
return cfg.ErrorHandler(c, ErrGenerateToken)
}

// Add the token to the context
c.Vary(fiber.HeaderCookie)
session.CsrfToken = adapters.GothCsrfToken{
Token: t,
ExpiresAt: time.Now().Add(cfg.IdleTimeout),
}

session, err = cfg.Adapter.UpdateSession(c.Context(), session)
if err != nil {
return cfg.ErrorHandler(c, err)
}

// Add the token to the context
c.Locals(csrfTokenKey, token)
// Set the session in the context
c.Locals(csrfTokenKey, session.CsrfToken)

// Continue stack
// continue stack
return c.Next()
}
}

// CsrfTokenFromContext returns the csrf token from the context.
func CsrfTokenFromContext(c *fiber.Ctx) string {
token, ok := c.Locals(csrfTokenKey).(string)
// CsrfTokenFromContext returns the CSRF token from the context.
func CsrfTokenFromContext(c *fiber.Ctx) (string, error) {
token, ok := c.Locals(csrfTokenKey).(adapters.GothCsrfToken)
if !ok {
return ""
return "", ErrTokenNotFound
}

return token
return token.Token, nil
}

// FromHeader returns a function that extracts token from the request header.
Expand Down
10 changes: 10 additions & 0 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

goth "github.com/zeiss/fiber-goth"
gorm_adapter "github.com/zeiss/fiber-goth/adapters/gorm"
"github.com/zeiss/fiber-goth/csrf"
"github.com/zeiss/fiber-goth/providers"
"github.com/zeiss/fiber-goth/providers/entraid"
"github.com/zeiss/fiber-goth/providers/github"
Expand Down Expand Up @@ -132,6 +133,15 @@ func run(_ context.Context) error {
return c.JSON(session)
})

app.Get("/protected", func(c *fiber.Ctx) error {
t, err := csrf.CsrfTokenFromContext(c)
if err != nil {
return err
}

return c.SendString(t)
})

app.Get("/login", func(c *fiber.Ctx) error {
c.Set(fiber.HeaderContentType, fiber.MIMETextHTML)
return t.Execute(c.Response().BodyWriter(), providerIndex)
Expand Down
2 changes: 2 additions & 0 deletions goth.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ func (CompleteAuthCompleteHandler) New(cfg Config) fiber.Handler {
cookieValue.SetExpire(expires)
cookieValue.SetPath("/")

c.Vary(fiber.HeaderCookie)

c.Response().Header.SetCookie(&cookieValue)

return cfg.CompletionFilter(c)
Expand Down

0 comments on commit 2c32192

Please sign in to comment.