diff --git a/goth.go b/goth.go index fc86467..d0f254f 100644 --- a/goth.go +++ b/goth.go @@ -5,12 +5,21 @@ package goth import ( + "encoding/base64" + "errors" + "math/rand" + "time" + "github.com/gofiber/fiber/v2" "github.com/markbates/goth" ) var _ GothHandler = (*BeginAuthHandler)(nil) +const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int @@ -20,6 +29,14 @@ const ( providerKey contextKey = 0 ) +// ErrMissingProviderName is thrown if the provider cannot be determined. +var ErrMissingProviderName = errors.New("missing provider name in request") + +const ( + state = "state" + provider = "provider" +) + // ProviderFromContext returns the provider from the request context. func ProviderFromContext(c *fiber.Ctx) { } @@ -36,7 +53,7 @@ func (BeginAuthHandler) New(cfg Config) fiber.Handler { url, err := GetAuthURLFromContext(c) if err != nil { - return c.Status(fiber.StatusBadRequest).SendString(err.Error()) + return cfg.ErrorHandler(c, err) } return c.Redirect(url, fiber.StatusTemporaryRedirect) @@ -57,12 +74,17 @@ func NewBeginAuthHandler(config ...Config) fiber.Handler { // GetAuthURLFromContext returns the provider specific authentication URL. func GetAuthURLFromContext(c *fiber.Ctx) (string, error) { - provider, err := goth.GetProvider("") + p := c.Query(provider) + if p == "" { + return "", ErrMissingProviderName + } + + provider, err := goth.GetProvider(p) if err != nil { return "", err } - sess, err := provider.BeginAuth("") + sess, err := provider.BeginAuth(stateFromContext(c)) if err != nil { return "", err } @@ -75,6 +97,11 @@ func GetAuthURLFromContext(c *fiber.Ctx) (string, error) { return url, err } +// GetStateFromContext return the state that is returned during the callback. +func GetStateFromContext(ctx *fiber.Ctx) string { + return ctx.Query("state") +} + // Config caputes the configuration for running the goth middleware. type Config struct { // Next defines a function to skip this middleware when returned true. @@ -82,13 +109,24 @@ type Config struct { // BeginAuthHandler ... BeginAuthHandler GothHandler + + // ErrorHandler is executed when an error is returned from fiber.Handler. + // + // Optional. Default: DefaultErrorHandler + ErrorHandler fiber.ErrorHandler } // ConfigDefault is the default config. var ConfigDefault = Config{ + ErrorHandler: defaultErrorHandler, BeginAuthHandler: BeginAuthHandler{}, } +// default ErrorHandler that process return error from fiber.Handler +func defaultErrorHandler(_ *fiber.Ctx, _ error) error { + return fiber.ErrBadRequest +} + // Helper function to set default values func configDefault(config ...Config) Config { if len(config) < 1 { @@ -108,3 +146,22 @@ func configDefault(config ...Config) Config { return cfg } + +func stateFromContext(ctx *fiber.Ctx) string { + state := ctx.Query(state) + if len(state) > 0 { + return state + } + + nonce := generateRandomString(64) + + return base64.URLEncoding.EncodeToString(nonce) +} + +func generateRandomString(length int) []byte { + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return b +}