From 0c2bda5c0f8e6525765ae15e6b8ed5cd04229bfb Mon Sep 17 00:00:00 2001 From: Ke Chen Date: Fri, 10 Nov 2023 22:51:20 +0800 Subject: [PATCH] feat: move middleware get user to routes --- apis/routes.go | 19 +++++++++++++++++++ bootstrap/init.go | 16 ---------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/apis/routes.go b/apis/routes.go index 06eca9d..9892706 100644 --- a/apis/routes.go +++ b/apis/routes.go @@ -1,6 +1,8 @@ package apis import ( + "github.com/opentreehole/go-common" + "treehole_next/apis/division" "treehole_next/apis/favourite" "treehole_next/apis/floor" @@ -11,7 +13,9 @@ import ( "treehole_next/apis/subscription" "treehole_next/apis/tag" "treehole_next/apis/user" + "treehole_next/config" _ "treehole_next/docs" + "treehole_next/models" "github.com/gofiber/fiber/v2" fiberSwagger "github.com/swaggo/fiber-swagger" @@ -32,6 +36,7 @@ func RegisterRoutes(app *fiber.App) { group := app.Group("/api") group.Get("/", Index) + group.Use(MiddlewareGetUser) division.RegisterRoutes(group) tag.RegisterRoutes(group) hole.RegisterRoutes(group) @@ -43,3 +48,17 @@ func RegisterRoutes(app *fiber.App) { user.RegisterRoutes(group) message.RegisterRoutes(group) } + +func MiddlewareGetUser(c *fiber.Ctx) error { + userObject, err := models.GetUser(c) + if err != nil { + return err + } + c.Locals("user", userObject) + if config.Config.AdminOnly { + if !userObject.IsAdmin { + return common.Forbidden() + } + } + return c.Next() +} diff --git a/bootstrap/init.go b/bootstrap/init.go index 75049fd..d00ce05 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -43,23 +43,7 @@ func registerMiddlewares(app *fiber.App) { if config.Config.Mode != "bench" { app.Use(common.MiddlewareCustomLogger) } - app.Use(MiddlewareGetUser) app.Use(pprof.New()) - -} - -func MiddlewareGetUser(c *fiber.Ctx) error { - user, err := models.GetUser(c) - if err != nil { - return err - } - c.Locals("user", user) - if config.Config.AdminOnly { - if !user.IsAdmin { - return common.Forbidden() - } - } - return c.Next() } func startTasks() context.CancelFunc {