diff --git a/privilege/privilege.go b/privilege/privilege.go index bdcea116bed27..80dbe6e2ed18c 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -69,6 +69,12 @@ type Manager interface { // GetAllRoles return all roles of user. GetAllRoles(user, host string) []*auth.RoleIdentity + + // CheckAccountLock return if the account has been locked. + CheckAccountLocked(ctx sessionctx.Context, user, host string) bool + + // IncFailTimer is used to increase lock timer. + IncFailTimer(ctx sessionctx.Context, user, host string) } const key keyType = 0 diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 0d3af938f9e4a..b6f8691f28b2b 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -192,6 +192,12 @@ type roleGraphEdgesTable struct { roleList map[string]*auth.RoleIdentity } +type blackListItem struct { + baseRecord + + startTime time.Time +} + // Find method is used to find role from table func (g roleGraphEdgesTable) Find(user, host string) bool { if host == "" { @@ -229,6 +235,8 @@ type MySQLPrivilege struct { ColumnsPriv []columnsPrivRecord DefaultRoles []defaultRoleRecord RoleGraph map[string]roleGraphEdgesTable + PwdErrorCnt map[string]int + BlackList map[string][]blackListItem } // FindAllRole is used to find all roles grant to this user. @@ -256,6 +264,66 @@ func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.R return ret } +func (p *MySQLPrivilege) IncLoginFail(user, host string) int { + key := user + "@" + host + cnt, exist := p.PwdErrorCnt[key] + if exist { + p.PwdErrorCnt[key] = cnt + 1 + return cnt + 1 + } + p.PwdErrorCnt[key] = 1 + return 1 +} + +func (p *MySQLPrivilege) ClearLoginFail(user, host string) { + key := user + "@" + host + if p.PwdErrorCnt == nil { + p.PwdErrorCnt = make(map[string]int) + } + p.PwdErrorCnt[key] = 0 +} + +func (p *MySQLPrivilege) LockAccount(user, host string, sctx sessionctx.Context) error { + lock := types.CurrentTime(0) + ctx := context.Background() + sql := fmt.Sprintf("insert into mysql.login_blacklist(USER, HOST, lock_time) values('%s', '%s', '%s') on duplicate key update lock_time = '%s'", user, host, lock.String(), lock.String()) + _, err := sctx.(sqlexec.SQLExecutor).Execute(ctx, sql) + return err +} + +func (p *MySQLPrivilege) CheckAccountLock(user, host string, sctx sessionctx.Context, limit time.Duration) bool { + recs, exist := p.BlackList[user] + if exist { + for _, r := range recs { + if r.Host == host { + t := r.startTime + if time.Now().Sub(t) > limit*time.Second { + ctx := context.Background() + sql := fmt.Sprintf("delete from mysql.login_blacklist where user = '%s' and host = '%s'", user, host) + _, err := sctx.(sqlexec.SQLExecutor).Execute(ctx, sql) + if err != nil { + // ignore + } + return false + } else { + return true + } + } + } + } + return false +} + +func (p *MySQLPrivilege) LoadBlackList(ctx sessionctx.Context) error { + p.BlackList = make(map[string][]blackListItem) + err := p.loadTable(ctx, "select HOST, USER, lock_time from mysql.login_blacklist;", p.decodeBlackListRow) + if err != nil { + logutil.BgLogger().Warn("load mysql.user fail", zap.Error(err)) + return errLoadPrivilege.FastGen("mysql.login_blacklist") + } + return nil +} + // FindRole is used to detect whether there is edges between users and roles. func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdentity) bool { rec := p.matchUser(user, host) @@ -269,6 +337,9 @@ func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdent // LoadAll loads the tables from database to memory. func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error { + if p.PwdErrorCnt == nil { + p.PwdErrorCnt = make(map[string]int) + } err := p.LoadUserTable(ctx) if err != nil { logutil.BgLogger().Warn("load mysql.user fail", zap.Error(err)) @@ -324,6 +395,11 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error { } logutil.BgLogger().Warn("mysql.role_edges missing") } + + err = p.LoadBlackList(ctx) + if err != nil { + return errLoadPrivilege.FastGen("mysql.login_blacklist") + } return nil } @@ -584,6 +660,30 @@ func (record *baseRecord) assignUserOrHost(row chunk.Row, i int, f *ast.ResultFi } } +func (p *MySQLPrivilege) decodeBlackListRow(row chunk.Row, fs []*ast.ResultField) error { + var value blackListItem + for i, f := range fs { + switch { + case f.ColumnAsName.L == "host": + value.Host = row.GetString(i) + case f.ColumnAsName.L == "user": + value.User = row.GetString(i) + case f.ColumnAsName.L == "lock_time": + ti := row.GetTime(i) + var err error + value.startTime, err = ti.GoTime(time.Local) + if err != nil { + return err + } + } + } + if p.BlackList == nil { + p.BlackList = make(map[string][]blackListItem) + } + p.BlackList[value.User] = append(p.BlackList[value.User], value) + return nil +} + func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField) error { var value UserRecord for i, f := range fs { diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index f712060e79de3..4b25af9e6e897 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -16,7 +16,9 @@ package privileges import ( "crypto/tls" "fmt" + "strconv" "strings" + "time" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" @@ -24,6 +26,7 @@ import ( "github.com/pingcap/tidb/infoschema/perfschema" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" @@ -145,6 +148,81 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string return } +// CheckAccountLock return if the account has been locked. +func (p *UserPrivileges) CheckAccountLocked(ctx sessionctx.Context, user, host string) bool { + if SkipWithGrant { + return false + } + + mysqlPriv := p.Handle.Get() + record := mysqlPriv.connectionVerification(user, host) + if record == nil { + logutil.BgLogger().Error("get user privilege record fail", + zap.String("user", user), zap.String("host", host)) + return true + } + + u := record.User + h := record.Host + + sessionVars := ctx.GetSessionVars() + lockTimeVar, err := variable.GetSessionSystemVar(sessionVars, variable.LoginBlockInterval) + if err != nil { + logutil.BgLogger().Error("Get locktime fail.", zap.Error(err)) + return true + } + lockTime, err := strconv.ParseInt(lockTimeVar, 10, 0) + if err != nil { + logutil.BgLogger().Error("Parse password limit fail.", zap.Error(err)) + return true + } + if mysqlPriv.CheckAccountLock(u, h, ctx, time.Duration(lockTime)) { + return true + } + + return false +} + +func (p *UserPrivileges) IncFailTimer(ctx sessionctx.Context, user, host string) { + if SkipWithGrant { + return + } + + mysqlPriv := p.Handle.Get() + record := mysqlPriv.connectionVerification(user, host) + if record == nil { + logutil.BgLogger().Error("get user privilege record fail", + zap.String("user", user), zap.String("host", host)) + return + } + + u := record.User + h := record.Host + + sessionVars := ctx.GetSessionVars() + cnt := mysqlPriv.IncLoginFail(u, h) + limitVar, err := variable.GetSessionSystemVar(sessionVars, variable.MaxLoginAttempts) + if err != nil { + logutil.BgLogger().Error("Get password limit fail.", zap.Error(err)) + return + } + limit, err := strconv.ParseInt(limitVar, 10, 0) + if err != nil { + logutil.BgLogger().Error("Parse password limit fail.", zap.Error(err)) + return + } + if cnt > int(limit) { + err := mysqlPriv.LockAccount(u, h, ctx) + if err != nil { + logutil.BgLogger().Error("error occuer while locking account", zap.Error(err)) + } + err = p.Update(ctx) + if err != nil { + logutil.BgLogger().Error("error occuer while updating account lock info", zap.Error(err)) + } + } +} + // ConnectionVerification implements the Manager interface. func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { if SkipWithGrant { @@ -214,6 +292,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio p.user = user p.host = h + mysqlPriv.ClearLoginFail(u, h) success = true return } diff --git a/session/bootstrap.go b/session/bootstrap.go index 67b1a285c07a7..3a37af975cc69 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -284,6 +284,13 @@ const ( CreateOptRuleBlacklist = `CREATE TABLE IF NOT EXISTS mysql.opt_rule_blacklist ( name char(100) NOT NULL );` + + CreateLoginBlackList = `CREATE TABLE IF NOT EXISTS mysql.login_blacklist ( + HOST char(60) COLLATE utf8_bin NOT NULL DEFAULT '', + USER char(32) COLLATE utf8_bin NOT NULL DEFAULT '', + lock_time timestamp default null, + primary key(USER, HOST) + )` ) // bootstrap initiates system DB for a store. @@ -384,6 +391,8 @@ const ( version44 = 44 // version45 introduces CONFIG_PRIV for SET CONFIG statements. version45 = 45 + // version46 introduces mysql.login_blacklist to detect login fail. + version46 = 46 ) var ( @@ -432,6 +441,7 @@ var ( upgradeToVer43, upgradeToVer44, upgradeToVer45, + upgradeToVer46, } ) @@ -1045,6 +1055,19 @@ func upgradeToVer45(s Session, ver int64) { mustExecute(s, "UPDATE HIGH_PRIORITY mysql.user SET Config_priv='Y' where Super_priv='Y'") } +func upgradeToVer46(s Session, ver int64) { + if ver >= version46 { + return + } + mustExecute(s, CreateLoginBlackList) + sql := fmt.Sprintf("INSERT IGNORE INTO %s.%s (`VARIABLE_NAME`, `VARIABLE_VALUE`) VALUES ('%s', '%d')", + mysql.SystemDB, mysql.GlobalVariablesTable, variable.MaxLoginAttempts, 10) + mustExecute(s, sql) + sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (`VARIABLE_NAME`, `VARIABLE_VALUE`) VALUES ('%s', '%d')", + mysql.SystemDB, mysql.GlobalVariablesTable, variable.LoginBlockInterval, 3600) + mustExecute(s, sql) +} + // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. @@ -1108,6 +1131,7 @@ func doDDLWorks(s Session) { mustExecute(s, CreateExprPushdownBlacklist) // Create opt_rule_blacklist table. mustExecute(s, CreateOptRuleBlacklist) + mustExecute(s, CreateLoginBlackList) } // doDMLWorks executes DML statements in bootstrap stage. diff --git a/session/session.go b/session/session.go index e83cd1a6f6937..48807e1e6c434 100644 --- a/session/session.go +++ b/session/session.go @@ -1505,6 +1505,12 @@ func (s *session) GetSessionVars() *variable.SessionVars { func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) + // check account lock. + locked := pm.CheckAccountLocked(s, user.Username, user.Hostname) + if locked { + return false + } + // Check IP or localhost. var success bool user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) @@ -1513,6 +1519,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true } else if user.Hostname == variable.DefHostname { + pm.IncFailTimer(s, user.AuthUsername, user.AuthHostname) return false } @@ -1530,6 +1537,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by return true } } + pm.IncFailTimer(s, user.AuthUsername, user.AuthHostname) return false } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 0549fc46776cf..738a07c581454 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -718,6 +718,9 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBEnableSlowLog, BoolToIntStr(logutil.DefaultTiDBEnableSlowLog)}, {ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)}, {ScopeSession, TiDBCheckMb4ValueInUTF8, BoolToIntStr(config.GetGlobalConfig().CheckMb4ValueInUTF8)}, + // Used to block DoS. + {ScopeGlobal, MaxLoginAttempts, strconv.Itoa(DefMaxLoginAttempts)}, + {ScopeGlobal, LoginBlockInterval, strconv.Itoa(DefLoginBlockInterval)}, } // SynonymsSysVariables is synonyms of system variables. @@ -979,6 +982,10 @@ const ( ThreadPoolSize = "thread_pool_size" // WindowingUseHighPrecision is the name of 'windowing_use_high_precision' system variable. WindowingUseHighPrecision = "windowing_use_high_precision" + // MaxLoginAttempts is the name for 'max_login_attempts' system variable. + MaxLoginAttempts = "max_login_attempts" + // LoginBlockInterval is the name for 'login_block_interval' system variable. + LoginBlockInterval = "login_block_interval" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 1e510d308a99d..3a68658ebeff2 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -479,6 +479,8 @@ const ( DefTiDBStoreLimit = 0 DefTiDBMetricSchemaStep = 60 // 60s DefTiDBMetricSchemaRangeDuration = 60 // 60s + DefMaxLoginAttempts = 10 + DefLoginBlockInterval = 3600 ) // Process global variables.