diff --git a/goth.go b/goth.go index d0f254f..d82cf6e 100644 --- a/goth.go +++ b/goth.go @@ -5,13 +5,20 @@ package goth import ( + "bytes" + "compress/gzip" "encoding/base64" "errors" + "fmt" + "io" "math/rand" + "strings" "time" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/session" "github.com/markbates/goth" + "github.com/markbates/goth/gothic" ) var _ GothHandler = (*BeginAuthHandler)(nil) @@ -37,6 +44,83 @@ const ( provider = "provider" ) +// SessionStore ... +type SessionStore interface { + Get(c *fiber.Ctx, key string) (string, error) + Update(c *fiber.Ctx, key, value string) error +} + +var _ SessionStore = (*sessionStore)(nil) + +// NewSessionStore ... +func NewSessionStore(store *session.Store) *sessionStore { + return &sessionStore{ + store: store, + } +} + +type sessionStore struct { + store *session.Store +} + +// Get returns session data. +func (s *sessionStore) Get(c *fiber.Ctx, key string) (string, error) { + session, err := s.store.Get(c) + if err != nil { + return "", err + } + + value := session.Get(key) + if value == nil { + return "", errors.New("could not find a matching session for this request") + } + + rdata := strings.NewReader(value.(string)) + r, err := gzip.NewReader(rdata) + if err != nil { + return "", err + } + + v, err := io.ReadAll(r) + if err != nil { + return "", err + } + + return string(v), nil +} + +// Update updates session data. +func (s *sessionStore) Update(c *fiber.Ctx, key, value string) error { + session, err := s.store.Get(c) + if err != nil { + return err + } + + var b bytes.Buffer + + gz := gzip.NewWriter(&b) + if _, err := gz.Write([]byte(value)); err != nil { + return err + } + + if err := gz.Flush(); err != nil { + return err + } + + if err := gz.Close(); err != nil { + return err + } + + session.Set(key, b.String()) + + err = session.Save() + if err != nil { + return err + } + + return nil +} + // ProviderFromContext returns the provider from the request context. func ProviderFromContext(c *fiber.Ctx) { } @@ -51,7 +135,7 @@ func (BeginAuthHandler) New(cfg Config) fiber.Handler { return c.Next() } - url, err := GetAuthURLFromContext(c) + url, err := GetAuthURLFromContext(c, cfg.Session) if err != nil { return cfg.ErrorHandler(c, err) } @@ -73,7 +157,7 @@ func NewBeginAuthHandler(config ...Config) fiber.Handler { } // GetAuthURLFromContext returns the provider specific authentication URL. -func GetAuthURLFromContext(c *fiber.Ctx) (string, error) { +func GetAuthURLFromContext(c *fiber.Ctx, session SessionStore) (string, error) { p := c.Query(provider) if p == "" { return "", ErrMissingProviderName @@ -94,12 +178,17 @@ func GetAuthURLFromContext(c *fiber.Ctx) (string, error) { return "", err } + err = session.Update(c, p, sess.Marshal()) + if err != nil { + return "", err + } + return url, err } // GetStateFromContext return the state that is returned during the callback. func GetStateFromContext(ctx *fiber.Ctx) string { - return ctx.Query("state") + return ctx.Query(state) } // Config caputes the configuration for running the goth middleware. @@ -110,6 +199,9 @@ type Config struct { // BeginAuthHandler ... BeginAuthHandler GothHandler + // Session ... + Session SessionStore + // ErrorHandler is executed when an error is returned from fiber.Handler. // // Optional. Default: DefaultErrorHandler @@ -127,6 +219,11 @@ func defaultErrorHandler(_ *fiber.Ctx, _ error) error { return fiber.ErrBadRequest } +var defaultSessionConfig = session.Config{ + KeyLookup: fmt.Sprintf("cookie:%s", gothic.SessionName), + CookieHTTPOnly: true, +} + // Helper function to set default values func configDefault(config ...Config) Config { if len(config) < 1 { @@ -140,6 +237,10 @@ func configDefault(config ...Config) Config { cfg.Next = ConfigDefault.Next } + if cfg.Session == nil { + cfg.Session = NewSessionStore(session.New(defaultSessionConfig)) + } + if cfg.BeginAuthHandler == nil { cfg.BeginAuthHandler = ConfigDefault.BeginAuthHandler }