From 2c32192f6a70fa14c3cd9e7631f6d98785a5925c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20D=C3=B6ll?= Date: Mon, 28 Oct 2024 21:12:31 +0000 Subject: [PATCH] feat: add csrf middleware --- adapters/adapter.go | 48 ++++------------ csrf/csrf.go | 131 +++++++++++++++++++++----------------------- examples/main.go | 10 ++++ goth.go | 2 + 4 files changed, 85 insertions(+), 106 deletions(-) diff --git a/adapters/adapter.go b/adapters/adapter.go index ceb1fef..dc04ee0 100644 --- a/adapters/adapter.go +++ b/adapters/adapter.go @@ -18,19 +18,6 @@ func init() { gob.Register(&GothCsrfToken{}) } -// CsrfTokenGenerator is a function that generates a CSRF token. -type CsrfTokenGenerator func() (string, error) - -// DefaultCsrfTokenGenerator generates a new CSRF token. -func DefaultCsrfTokenGenerator() (string, error) { - token, err := uuid.NewV7() - if err != nil { - return "", err - } - - return token.String(), nil -} - // AccountType represents the type of an account. type AccountType string @@ -156,8 +143,18 @@ func (s *GothSession) IsValid() bool { } // GetCsrfToken returns the CSRF token. -func (s *GothSession) GetCsrfToken() string { - return s.CsrfToken.Token +func (s *GothSession) GetCsrfToken() GothCsrfToken { + return s.CsrfToken +} + +// HasExpired returns true if the session has expired. +func (c GothCsrfToken) HasExpired() bool { + return c.ExpiresAt.Before(time.Now()) +} + +// IsValid returns true if the token is valid. +func (c GothCsrfToken) IsValid(token string) bool { + return c.Token == token } // GothVerificationToken is a verification token for a user @@ -206,12 +203,6 @@ type Adapter interface { CreateVerificationToken(ctx context.Context, verficationToken GothVerificationToken) (GothVerificationToken, error) // UseVerficationToken uses a verification token. UseVerficationToken(ctx context.Context, identifier string, token string) (GothVerificationToken, error) - // CreateCsrfToken creates a new CSRF token. - CreateCsrfToken(ctx context.Context, csrfToken GothCsrfToken) (GothCsrfToken, error) - // GetCsrfToken retrieves a CSRF token by token. - GetCsrfToken(ctx context.Context, token string) (GothCsrfToken, error) - // DeleteCsrfToken deletes a CSRF token by token. - DeleteCsrfToken(ctx context.Context, token string) error } var _ Adapter = (*UnimplementedAdapter)(nil) @@ -293,18 +284,3 @@ func (a *UnimplementedAdapter) CreateVerificationToken(_ context.Context, erfica func (a *UnimplementedAdapter) UseVerficationToken(_ context.Context, identifier string, token string) (GothVerificationToken, error) { return GothVerificationToken{}, ErrUnimplemented } - -// CreateCsrfToken creates a new CSRF token. -func (a *UnimplementedAdapter) CreateCsrfToken(_ context.Context, csrfToken GothCsrfToken) (GothCsrfToken, error) { - return GothCsrfToken{}, ErrUnimplemented -} - -// GetCsrfToken retrieves a CSRF token by token. -func (a *UnimplementedAdapter) GetCsrfToken(_ context.Context, token string) (GothCsrfToken, error) { - return GothCsrfToken{}, ErrUnimplemented -} - -// DeleteCsrfToken deletes a CSRF token by token. -func (a *UnimplementedAdapter) DeleteCsrfToken(_ context.Context, token string) error { - return ErrUnimplemented -} diff --git a/csrf/csrf.go b/csrf/csrf.go index 7979a67..0dda3e6 100644 --- a/csrf/csrf.go +++ b/csrf/csrf.go @@ -5,7 +5,9 @@ import ( "github.com/google/uuid" "github.com/valyala/fasthttp" + goth "github.com/zeiss/fiber-goth" "github.com/zeiss/fiber-goth/adapters" + "github.com/zeiss/pkg/slices" "github.com/zeiss/pkg/utilx" "github.com/gofiber/fiber/v2" @@ -16,6 +18,10 @@ var ( ErrMissingHeader = fiber.NewError(fiber.StatusForbidden, "missing csrf token in header") // ErrTokenNotFound is returned when the token is not found in the session. ErrTokenNotFound = fiber.NewError(fiber.StatusForbidden, "csrf token not found in session") + // ErrMissingSession is returned when the session is missing from the context. + ErrMissingSession = fiber.NewError(fiber.StatusForbidden, "missing session in context") + // ErrGenerateToken is returned when the token generator returns an error. + ErrGenerateToken = fiber.NewError(fiber.StatusForbidden, "failed to generate csrf token") ) // HeaderName is the default header name used to extract the token. @@ -38,6 +44,10 @@ type Config struct { // Adapter adapters.Adapter Adapter adapters.Adapter + // IgnoredMethods is a list of methods to ignore from CSRF protection. + // Optional. Default: []string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace} + IgnoredMethods []string + // ErrorHandler is executed when an error is returned from fiber.Handler. // // Optional. Default: DefaultErrorHandler @@ -93,6 +103,7 @@ var ConfigDefault = Config{ ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), TokenGenerator: DefaultCsrfTokenGenerator, + IgnoredMethods: []string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}, } // CsrfTokenGenerator is a function that generates a CSRF token. @@ -147,25 +158,19 @@ func configDefault(config ...Config) Config { cfg.TokenGenerator = ConfigDefault.TokenGenerator } - return cfg -} + if cfg.IgnoredMethods == nil { + cfg.IgnoredMethods = ConfigDefault.IgnoredMethods + } -// Handler ... -type Handler struct { - config Config + return cfg } // New creates a new csrf middleware. +// nolint:gocyclo func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) - // handler := &Handler{ - // config: cfg, - // } - - var token string - // Return new handler return func(c *fiber.Ctx) error { // Skip middleware if Next returns true @@ -173,81 +178,67 @@ func New(config ...Config) fiber.Handler { return c.Next() } - switch c.Method() { - case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: - // cookieToken := c.Cookies(cfg.CookieName) - default: - extractedToken, err := cfg.Extractor(c) - if err != nil { - return cfg.ErrorHandler(c, err) - } - - if utilx.Empty(extractedToken) { - return cfg.ErrorHandler(c, ErrTokenNotFound) - } - - raw := "" - - if utilx.Empty(raw) { - // expire the token - cookieValue := fasthttp.Cookie{} - cookieValue.SetKey(cfg.CookieName) - cookieValue.SetValueBytes([]byte("")) - cookieValue.SetHTTPOnly(cfg.CookieHTTPOnly) - cookieValue.SetSameSite(cfg.CookieSameSite) - cookieValue.SetExpire(time.Now().Add(-time.Hour)) - cookieValue.SetPath(cfg.CookiePath) - cookieValue.SetDomain(cfg.CookieDomain) - cookieValue.SetSecure(cfg.CookieSecure) - - // Set the cookie - c.Response().Header.SetCookie(&cookieValue) - - return cfg.ErrorHandler(c, ErrTokenNotFound) - } + // extract the session + session, err := goth.SessionFromContext(c) + if err != nil { + return cfg.ErrorHandler(c, ErrMissingSession) + } + + // Skip middleware if the method is ignored + if slices.Any(func(method string) bool { return method == c.Method() }, cfg.IgnoredMethods...) { + return c.Next() + } + + // extract the token + token, err := cfg.Extractor(c) + if err != nil { + return cfg.ErrorHandler(c, ErrTokenNotFound) } - // Generate a new token + // if the token is empty, abort if utilx.Empty(token) { - // csrfToken, err := cfg.TokenGenerator() - // if err != nil { - // return cfg.ErrorHandler(c, err) - // } + return cfg.ErrorHandler(c, ErrTokenNotFound) + } + + if session.GetCsrfToken().HasExpired() { + return cfg.ErrorHandler(c, ErrTokenNotFound) } - // Create the cookie - cookieValue := fasthttp.Cookie{} - cookieValue.SetKey(cfg.CookieName) - cookieValue.SetValueBytes([]byte(token)) - cookieValue.SetHTTPOnly(cfg.CookieHTTPOnly) - cookieValue.SetSameSite(cfg.CookieSameSite) - cookieValue.SetExpire(time.Now().Add(cfg.IdleTimeout)) - cookieValue.SetPath(cfg.CookiePath) - cookieValue.SetDomain(cfg.CookieDomain) - cookieValue.SetSecure(cfg.CookieSecure) + if !session.GetCsrfToken().IsValid(token) { + return cfg.ErrorHandler(c, ErrTokenNotFound) + } - // Set the cookie - c.Response().Header.SetCookie(&cookieValue) + t, err := cfg.TokenGenerator() + if err != nil { + return cfg.ErrorHandler(c, ErrGenerateToken) + } - // Add the token to the context - c.Vary(fiber.HeaderCookie) + session.CsrfToken = adapters.GothCsrfToken{ + Token: t, + ExpiresAt: time.Now().Add(cfg.IdleTimeout), + } + + session, err = cfg.Adapter.UpdateSession(c.Context(), session) + if err != nil { + return cfg.ErrorHandler(c, err) + } - // Add the token to the context - c.Locals(csrfTokenKey, token) + // Set the session in the context + c.Locals(csrfTokenKey, session.CsrfToken) - // Continue stack + // continue stack return c.Next() } } -// CsrfTokenFromContext returns the csrf token from the context. -func CsrfTokenFromContext(c *fiber.Ctx) string { - token, ok := c.Locals(csrfTokenKey).(string) +// CsrfTokenFromContext returns the CSRF token from the context. +func CsrfTokenFromContext(c *fiber.Ctx) (string, error) { + token, ok := c.Locals(csrfTokenKey).(adapters.GothCsrfToken) if !ok { - return "" + return "", ErrTokenNotFound } - return token + return token.Token, nil } // FromHeader returns a function that extracts token from the request header. diff --git a/examples/main.go b/examples/main.go index aeeb458..09c1157 100644 --- a/examples/main.go +++ b/examples/main.go @@ -10,6 +10,7 @@ import ( goth "github.com/zeiss/fiber-goth" gorm_adapter "github.com/zeiss/fiber-goth/adapters/gorm" + "github.com/zeiss/fiber-goth/csrf" "github.com/zeiss/fiber-goth/providers" "github.com/zeiss/fiber-goth/providers/entraid" "github.com/zeiss/fiber-goth/providers/github" @@ -132,6 +133,15 @@ func run(_ context.Context) error { return c.JSON(session) }) + app.Get("/protected", func(c *fiber.Ctx) error { + t, err := csrf.CsrfTokenFromContext(c) + if err != nil { + return err + } + + return c.SendString(t) + }) + app.Get("/login", func(c *fiber.Ctx) error { c.Set(fiber.HeaderContentType, fiber.MIMETextHTML) return t.Execute(c.Response().BodyWriter(), providerIndex) diff --git a/goth.go b/goth.go index 4f56d88..5dc710e 100644 --- a/goth.go +++ b/goth.go @@ -260,6 +260,8 @@ func (CompleteAuthCompleteHandler) New(cfg Config) fiber.Handler { cookieValue.SetExpire(expires) cookieValue.SetPath("/") + c.Vary(fiber.HeaderCookie) + c.Response().Header.SetCookie(&cookieValue) return cfg.CompletionFilter(c)