From 898494118c41f341a994e1162b04c8f5a9b557b3 Mon Sep 17 00:00:00 2001 From: Henry <2671230065@qq.com> Date: Sat, 3 Aug 2024 17:36:22 +0800 Subject: [PATCH] feat: invalidate token on logout --- server/common/auth.go | 23 +++++++++++++++++++++++ server/handles/auth.go | 9 +++++++++ server/router.go | 1 + 3 files changed, 33 insertions(+) diff --git a/server/common/auth.go b/server/common/auth.go index b6a79b752aa..0de718cf9e8 100644 --- a/server/common/auth.go +++ b/server/common/auth.go @@ -3,6 +3,7 @@ package common import ( "time" + "github.com/Xhofe/go-cache" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/model" "github.com/golang-jwt/jwt/v4" @@ -17,6 +18,8 @@ type UserClaims struct { jwt.RegisteredClaims } +var validTokenCache = cache.NewMemCache[bool]() + func GenerateToken(user *model.User) (tokenString string, err error) { claim := UserClaims{ Username: user.Username, @@ -28,6 +31,10 @@ func GenerateToken(user *model.User) (tokenString string, err error) { }} token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim) tokenString, err = token.SignedString(SecretKey) + if err != nil { + return "", err + } + validTokenCache.Set(tokenString, true) return tokenString, err } @@ -35,6 +42,9 @@ func ParseToken(tokenString string) (*UserClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { return SecretKey, nil }) + if IsTokenInvalidated(tokenString) { + return nil, errors.New("token is invalidated") + } if err != nil { if ve, ok := err.(*jwt.ValidationError); ok { if ve.Errors&jwt.ValidationErrorMalformed != 0 { @@ -53,3 +63,16 @@ func ParseToken(tokenString string) (*UserClaims, error) { } return nil, errors.New("couldn't handle this token") } + +func InvalidateToken(tokenString string) error { + if tokenString == "" { + return nil // don't invalidate empty guest token + } + validTokenCache.Del(tokenString) + return nil +} + +func IsTokenInvalidated(tokenString string) bool { + _, ok := validTokenCache.Get(tokenString) + return !ok +} diff --git a/server/handles/auth.go b/server/handles/auth.go index 209bdd3a2b8..e1f512c4dc1 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -181,3 +181,12 @@ func Verify2FA(c *gin.Context) { common.SuccessResp(c) } } + +func LogOut(c *gin.Context) { + err := common.InvalidateToken(c.GetHeader("Authorization")) + if err != nil { + common.ErrorResp(c, err, 500) + } else { + common.SuccessResp(c) + } +} diff --git a/server/router.go b/server/router.go index 5f784aa4b7d..5be593f7497 100644 --- a/server/router.go +++ b/server/router.go @@ -54,6 +54,7 @@ func Init(e *gin.Engine) { auth.POST("/me/update", handles.UpdateCurrent) auth.POST("/auth/2fa/generate", handles.Generate2FA) auth.POST("/auth/2fa/verify", handles.Verify2FA) + auth.GET("/auth/logout", handles.LogOut) // auth api.GET("/auth/sso", handles.SSOLoginRedirect)