diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index e89ed12dac5..4244d29db91 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -154,13 +154,16 @@ func InitialSettings() []model.SettingItem { // SSO settings {Key: conf.SSOLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC}, - {Key: conf.SSOLoginplatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC}, + {Key: conf.SSOLoginPlatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC}, {Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOOrganizationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOApplicationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOEndpointName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOJwtPublicKey, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOAutoRegister, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSODefaultDir, Value: "/", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSODefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.SSO, Flag: model.PRIVATE}, // qbittorrent settings {Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, diff --git a/internal/conf/const.go b/internal/conf/const.go index 403e2b48d8c..206f8e1f793 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -57,14 +57,17 @@ const ( IndexProgress = "index_progress" //SSO - SSOClientId = "sso_client_id" - SSOClientSecret = "sso_client_secret" - SSOLoginEnabled = "sso_login_enabled" - SSOLoginplatform = "sso_login_platform" - SSOOrganizationName = "sso_organization_name" - SSOApplicationName = "sso_application_name" - SSOEndpointName = "sso_endpoint_name" - SSOJwtPublicKey = "sso_jwt_public_key" + SSOClientId = "sso_client_id" + SSOClientSecret = "sso_client_secret" + SSOLoginEnabled = "sso_login_enabled" + SSOLoginPlatform = "sso_login_platform" + SSOOrganizationName = "sso_organization_name" + SSOApplicationName = "sso_application_name" + SSOEndpointName = "sso_endpoint_name" + SSOJwtPublicKey = "sso_jwt_public_key" + SSOAutoRegister = "sso_auto_register" + SSODefaultDir = "sso_default_dir" + SSODefaultPermission = "sso_default_permission" // qbittorrent QbittorrentUrl = "qbittorrent_url" diff --git a/internal/model/user.go b/internal/model/user.go index 2cde7a545ce..ca225c079ef 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -33,7 +33,7 @@ type User struct { // 10: can add qbittorrent tasks Permission int32 `json:"permission"` OtpSecret string `json:"-"` - SsoID string `json:"sso_id"` + SsoID string `json:"sso_id"` // unique by sso platform } func (u User) IsGuest() bool { diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go index 76a7f5d0618..f7e85807925 100644 --- a/server/handles/ssologin.go +++ b/server/handles/ssologin.go @@ -11,8 +11,10 @@ import ( "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" "github.com/alist-org/alist/v3/server/common" "github.com/coreos/go-oidc" "github.com/gin-gonic/gin" @@ -20,14 +22,15 @@ import ( "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "golang.org/x/oauth2" + "gorm.io/gorm" ) var opts = totp.ValidateOpts{ // state verify won't expire in 30 secs, which is quite enough for the callback Period: 30, - Skew: 1, + Skew: 1, // in some OIDC providers(such as Authelia), state parameter must be at least 8 characters - Digits: otp.DigitsEight, + Digits: otp.DigitsEight, Algorithm: otp.AlgorithmSHA1, } @@ -35,7 +38,7 @@ func SSOLoginRedirect(c *gin.Context) { method := c.Query("method") enabled := setting.GetBool(conf.SSOLoginEnabled) clientId := setting.GetStr(conf.SSOClientId) - platform := setting.GetStr(conf.SSOLoginplatform) + platform := setting.GetStr(conf.SSOLoginPlatform) var r_url string var redirect_uri string if enabled { @@ -76,7 +79,7 @@ func SSOLoginRedirect(c *gin.Context) { return } // generate state parameter - state,err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) + state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) if err != nil { common.ErrorStrResp(c, err.Error(), 400) return @@ -118,13 +121,39 @@ func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) { }, nil } +func autoRegister(username, userID string, err error) (*model.User, error) { + if !errors.Is(err, gorm.ErrRecordNotFound) || !setting.GetBool(conf.SSOAutoRegister) { + return nil, err + } + if username == "" { + return nil, errors.New("cannot get username from SSO provider") + } + user := &model.User{ + ID: 0, + Username: username, + Password: random.String(16), + Permission: int32(setting.GetInt(conf.SSODefaultPermission, 0)), + BasePath: setting.GetStr(conf.SSODefaultDir), + Role: 0, + Disabled: false, + SsoID: userID, + } + if err = db.CreateUser(user); err != nil { + if strings.HasPrefix(err.Error(), "UNIQUE constraint failed") && strings.HasSuffix(err.Error(), "username") { + user.Username = user.Username + "_" + userID + if err = db.CreateUser(user); err != nil { + return nil, err + } + } else { + return nil, err + } + } + return user, nil +} + func OIDCLoginCallback(c *gin.Context) { argument := c.Query("method") - enabled := setting.GetBool(conf.SSOLoginEnabled) clientId := setting.GetStr(conf.SSOClientId) - if !enabled { - common.ErrorResp(c, errors.New("invalid request"), 500) - } endpoint := setting.GetStr(conf.SSOEndpointName) provider, err := oidc.NewProvider(c, endpoint) if err != nil { @@ -170,7 +199,7 @@ func OIDCLoginCallback(c *gin.Context) { } claims := UserInfo{} if err := idToken.Claims(&claims); err != nil { - c.Error(err) + common.ErrorResp(c, err, 400) return } UserID := claims.Name @@ -189,7 +218,10 @@ func OIDCLoginCallback(c *gin.Context) { if argument == "sso_get_token" { user, err := db.GetUserBySSOID(UserID) if err != nil { - common.ErrorResp(c, err, 400) + user, err = autoRegister(UserID, UserID, err) + if err != nil { + common.ErrorResp(c, err, 400) + } } token, err := common.GenerateToken(user.Username) if err != nil { @@ -209,133 +241,145 @@ func OIDCLoginCallback(c *gin.Context) { } func SSOLoginCallback(c *gin.Context) { + enabled := setting.GetBool(conf.SSOLoginEnabled) + if !enabled { + common.ErrorResp(c, errors.New("sso login is disabled"), 500) + } argument := c.Query("method") - if argument == "get_sso_id" || argument == "sso_get_token" { - enabled := setting.GetBool(conf.SSOLoginEnabled) - clientId := setting.GetStr(conf.SSOClientId) - platform := setting.GetStr(conf.SSOLoginplatform) - clientSecret := setting.GetStr(conf.SSOClientSecret) - var url1, url2, additionalbody, scope, authstring, idstring string - switch platform { - case "Github": - url1 = "https://github.com/login/oauth/access_token" - url2 = "https://api.github.com/user" - additionalbody = "" - authstring = "code" - scope = "read:user" - idstring = "id" - case "Microsoft": - url1 = "https://login.microsoftonline.com/common/oauth2/v2.0/token" - url2 = "https://graph.microsoft.com/v1.0/me" - additionalbody = "&grant_type=authorization_code" - scope = "user.read" - authstring = "code" - idstring = "id" - case "Google": - url1 = "https://oauth2.googleapis.com/token" - url2 = "https://www.googleapis.com/oauth2/v1/userinfo" - additionalbody = "&grant_type=authorization_code" - scope = "https://www.googleapis.com/auth/userinfo.profile" - authstring = "code" - idstring = "id" - case "Dingtalk": - url1 = "https://api.dingtalk.com/v1.0/oauth2/userAccessToken" - url2 = "https://api.dingtalk.com/v1.0/contact/users/me" - authstring = "authCode" - idstring = "unionId" - case "Casdoor": - endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") - url1 = endpoint + "/api/login/oauth/access_token" - url2 = endpoint + "/api/userinfo" - additionalbody = "&grant_type=authorization_code" - scope = "profile" - authstring = "code" - idstring = "preferred_username" - case "OIDC": - OIDCLoginCallback(c) - return - default: - common.ErrorStrResp(c, "invalid platform", 400) - return - } - if enabled { - callbackCode := c.Query(authstring) - if callbackCode == "" { - common.ErrorStrResp(c, "No code provided", 400) - return - } - var resp *resty.Response - var err error - if platform == "Dingtalk" { - resp, err = ssoClient.R().SetHeader("content-type", "application/json").SetHeader("Accept", "application/json"). - SetBody(map[string]string{ - "clientId": clientId, - "clientSecret": clientSecret, - "code": callbackCode, - "grantType": "authorization_code", - }). - Post(url1) - } else { - resp, err = ssoClient.R().SetHeader("content-type", "application/x-www-form-urlencoded").SetHeader("Accept", "application/json"). - SetBody("client_id=" + clientId + "&client_secret=" + clientSecret + "&code=" + callbackCode + "&redirect_uri=" + common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument + "&scope=" + scope + additionalbody). - Post(url1) - } - if err != nil { - common.ErrorResp(c, err, 400) - return - } - if platform == "Dingtalk" { - accessToken := utils.Json.Get(resp.Body(), "accessToken").ToString() - resp, err = ssoClient.R().SetHeader("x-acs-dingtalk-access-token", accessToken). - Get(url2) - } else { - accessToken := utils.Json.Get(resp.Body(), "access_token").ToString() - resp, err = ssoClient.R().SetHeader("Authorization", "Bearer "+accessToken). - Get(url2) - } - if err != nil { - common.ErrorResp(c, err, 400) - return - } - UserID := utils.Json.Get(resp.Body(), idstring).ToString() - if UserID == "0" { - common.ErrorResp(c, errors.New("error occured"), 400) - return - } - if argument == "get_sso_id" { - html := fmt.Sprintf(` + if !utils.SliceContains([]string{"get_sso_id", "sso_get_token"}, argument) { + common.ErrorResp(c, errors.New("invalid request"), 500) + } + clientId := setting.GetStr(conf.SSOClientId) + platform := setting.GetStr(conf.SSOLoginPlatform) + clientSecret := setting.GetStr(conf.SSOClientSecret) + var tokenUrl, userUrl, scope, authField, idField, usernameField string + additionalForm := make(map[string]string) + switch platform { + case "Github": + tokenUrl = "https://github.com/login/oauth/access_token" + userUrl = "https://api.github.com/user" + authField = "code" + scope = "read:user" + idField = "id" + usernameField = "login" + case "Microsoft": + tokenUrl = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + userUrl = "https://graph.microsoft.com/v1.0/me" + additionalForm["grant_type"] = "authorization_code" + scope = "user.read" + authField = "code" + idField = "id" + usernameField = "displayName" + case "Google": + tokenUrl = "https://oauth2.googleapis.com/token" + userUrl = "https://www.googleapis.com/oauth2/v1/userinfo" + additionalForm["grant_type"] = "authorization_code" + scope = "https://www.googleapis.com/auth/userinfo.profile" + authField = "code" + idField = "id" + usernameField = "name" + case "Dingtalk": + tokenUrl = "https://api.dingtalk.com/v1.0/oauth2/userAccessToken" + userUrl = "https://api.dingtalk.com/v1.0/contact/users/me" + authField = "authCode" + idField = "unionId" + usernameField = "nick" + case "Casdoor": + endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") + tokenUrl = endpoint + "/api/login/oauth/access_token" + userUrl = endpoint + "/api/userinfo" + additionalForm["grant_type"] = "authorization_code" + scope = "profile" + authField = "code" + idField = "sub" + usernameField = "preferred_username" + case "OIDC": + OIDCLoginCallback(c) + return + default: + common.ErrorStrResp(c, "invalid platform", 400) + return + } + callbackCode := c.Query(authField) + if callbackCode == "" { + common.ErrorStrResp(c, "No code provided", 400) + return + } + var resp *resty.Response + var err error + if platform == "Dingtalk" { + resp, err = ssoClient.R().SetHeader("content-type", "application/json").SetHeader("Accept", "application/json"). + SetBody(map[string]string{ + "clientId": clientId, + "clientSecret": clientSecret, + "code": callbackCode, + "grantType": "authorization_code", + }). + Post(tokenUrl) + } else { + resp, err = ssoClient.R().SetHeader("Accept", "application/json"). + SetFormData(map[string]string{ + "client_id": clientId, + "client_secret": clientSecret, + "code": callbackCode, + "redirect_uri": common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument, + "scope": scope, + }).SetFormData(additionalForm).Post(tokenUrl) + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if platform == "Dingtalk" { + accessToken := utils.Json.Get(resp.Body(), "accessToken").ToString() + resp, err = ssoClient.R().SetHeader("x-acs-dingtalk-access-token", accessToken). + Get(userUrl) + } else { + accessToken := utils.Json.Get(resp.Body(), "access_token").ToString() + resp, err = ssoClient.R().SetHeader("Authorization", "Bearer "+accessToken). + Get(userUrl) + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + userID := utils.Json.Get(resp.Body(), idField).ToString() + if utils.SliceContains([]string{"", "0"}, userID) { + common.ErrorResp(c, errors.New("error occured"), 400) + return + } + if argument == "get_sso_id" { + html := fmt.Sprintf(`
- `, UserID) - c.Data(200, "text/html; charset=utf-8", []byte(html)) - return - } - if argument == "sso_get_token" { - user, err := db.GetUserBySSOID(UserID) - if err != nil { - common.ErrorResp(c, err, 400) - } - token, err := common.GenerateToken(user.Username) - if err != nil { - common.ErrorResp(c, err, 400) - } - html := fmt.Sprintf(` - - - - `, token) - c.Data(200, "text/html; charset=utf-8", []byte(html)) - return - } - } else { - common.ErrorResp(c, errors.New("invalid request"), 500) + + +`, userID) + c.Data(200, "text/html; charset=utf-8", []byte(html)) + return + } + username := utils.Json.Get(resp.Body(), usernameField).ToString() + user, err := db.GetUserBySSOID(userID) + if err != nil { + user, err = autoRegister(username, userID, err) + if err != nil { + common.ErrorResp(c, err, 400) + return } } + token, err := common.GenerateToken(user.Username) + if err != nil { + common.ErrorResp(c, err, 400) + } + html := fmt.Sprintf(` +
+
`, token) + c.Data(200, "text/html; charset=utf-8", []byte(html)) }