Skip to content

Commit

Permalink
*: Support for caching_sha2_password authentication (#24991)
Browse files Browse the repository at this point in the history
  • Loading branch information
dveeden authored Jul 5, 2021
1 parent fee39d3 commit f23e100
Show file tree
Hide file tree
Showing 11 changed files with 309 additions and 58 deletions.
8 changes: 7 additions & 1 deletion executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
if !ok {
return errors.Trace(ErrPasswordFormat)
}
_, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, `INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`, mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, mysql.AuthNativePassword)
authPlugin := mysql.AuthNativePassword
if user.AuthOpt.AuthPlugin != "" {
authPlugin = user.AuthOpt.AuthPlugin
}
_, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx,
`INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`,
mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, authPlugin)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions executor/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"strings"

"github.com/pingcap/check"
. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/auth"
Expand Down Expand Up @@ -444,14 +445,13 @@ func (s *testSuite5) TestShowCreateUser(c *C) {
rows = tk1.MustQuery("show create user current_user")
rows.Check(testkit.Rows("CREATE USER 'check_priv'@'127.0.0.1' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK"))

// Creating users with `IDENTIFIED WITH 'caching_sha2_password'` is not supported yet. So manually creating an entry for now.
// later this can be changed to test the full path once 'caching_sha2_password' support is completed.
tk.MustExec("CREATE USER 'sha_test'@'%' IDENTIFIED BY 'temp_passwd'")
tk.MustExec("UPDATE mysql.user SET plugin='caching_sha2_password', authentication_string=0x24412430303524532C06366D1D1E2B2F4437681A057B6807193D1C4B6E772F667A764663534E6C3978716C3057644D73427A787747674679687632644A384F337941704A542F WHERE user='sha_test' AND host='%'")
tk.MustExec("FLUSH PRIVILEGES")
// Creating users with `IDENTIFIED WITH 'caching_sha2_password'`
tk.MustExec("CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' BY 'temp_passwd'")

// Compare only the start of the output as the salt changes every time.
rows = tk.MustQuery("SHOW CREATE USER 'sha_test'@'%'")
rows.Check(testkit.Rows("CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' AS '$A$005$S,\x066m\x1d\x1e+/D7h\x1a\x05{h\a\x19=\x1cKnw/fzvFcSNl9xql0WdMsBzxwGgFyhv2dJ8O3yApJT/' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK"))
c.Assert(rows.Rows()[0][0].(string)[:78], check.Equals, "CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' AS '$A$005$")

}

func (s *testSuite5) TestUnprivilegedShow(c *C) {
Expand Down
38 changes: 35 additions & 3 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,13 +777,18 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm
continue
}
pwd, ok := spec.EncodedPassword()

if !ok {
return errors.Trace(ErrPasswordFormat)
}
authPlugin := mysql.AuthNativePassword
if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" {
authPlugin = spec.AuthOpt.AuthPlugin
}
if s.IsCreateRole {
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, mysql.AuthNativePassword, "Y")
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, authPlugin, "Y")
} else {
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, mysql.AuthNativePassword)
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, authPlugin)
}
users = append(users, spec.User)
}
Expand Down Expand Up @@ -908,6 +913,13 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error {
continue
}

authplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname)
if err != nil {
return err
}
if spec.AuthOpt != nil {
spec.AuthOpt.AuthPlugin = authplugin
}
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
if spec.AuthOpt != nil {
pwd, ok := spec.EncodedPassword()
Expand Down Expand Up @@ -1322,6 +1334,15 @@ func userExistsInternal(sqlExecutor sqlexec.SQLExecutor, name string, host strin
return rows > 0, err
}

func (e *SimpleExec) userAuthPlugin(name string, host string) (string, error) {
pm := privilege.GetPrivilegeManager(e.ctx)
authplugin, err := pm.GetAuthPlugin(name, host)
if err != nil {
return "", err
}
return authplugin, nil
}

func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {
var u, h string
if s.User == nil {
Expand All @@ -1347,9 +1368,20 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {
return errors.Trace(ErrPasswordNoMatch)
}

authplugin, err := e.userAuthPlugin(u, h)
if err != nil {
return err
}
var pwd string
if authplugin == mysql.AuthCachingSha2Password {
pwd = auth.NewSha2Password(s.Password)
} else {
pwd = auth.EncodePassword(s.Password)
}

// update mysql.user
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h)
stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, h)
if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ type Manager interface {

// IsDynamicPrivilege returns if a privilege is in the list of privileges.
IsDynamicPrivilege(privNameInUpper string) bool

// Get the authentication plugin for a user
GetAuthPlugin(user, host string) (string, error)
}

const key keyType = 0
Expand Down
9 changes: 8 additions & 1 deletion privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const (
References_priv,Alter_priv,Execute_priv,Index_priv,Create_view_priv,Show_view_priv,
Create_role_priv,Drop_role_priv,Create_tmp_table_priv,Lock_tables_priv,Create_routine_priv,
Alter_routine_priv,Event_priv,Shutdown_priv,Reload_priv,File_priv,Config_priv,Repl_client_priv,Repl_slave_priv,
account_locked FROM mysql.user`
account_locked,plugin FROM mysql.user`
sqlLoadGlobalGrantsTable = `SELECT HIGH_PRIORITY Host,User,Priv,With_Grant_Option FROM mysql.global_grants`
)

Expand Down Expand Up @@ -96,6 +96,7 @@ type UserRecord struct {
AuthenticationString string
Privileges mysql.PrivilegeType
AccountLocked bool // A role record when this field is true
AuthPlugin string
}

// NewUserRecord return a UserRecord, only use for unit test.
Expand Down Expand Up @@ -632,6 +633,12 @@ func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField
if row.GetEnum(i).String() == "Y" {
value.AccountLocked = true
}
case f.ColumnAsName.L == "plugin":
if row.GetString(i) != "" {
value.AuthPlugin = row.GetString(i)
} else {
value.AuthPlugin = mysql.AuthNativePassword
}
case f.Column.Tp == mysql.TypeEnum:
if row.GetEnum(i).String() != "Y" {
continue
Expand Down
83 changes: 65 additions & 18 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,31 @@ func (p *UserPrivileges) RequestVerificationWithUser(db, table, column string, p
return mysqlPriv.RequestVerification(roles, user.Username, user.Hostname, db, table, column, priv)
}

func (p *UserPrivileges) isValidHash(record *UserRecord) bool {
pwd := record.AuthenticationString
if pwd == "" {
return true
}
if record.AuthPlugin == mysql.AuthNativePassword {
if len(pwd) == mysql.PWDHashLen+1 {
return true
}
logutil.BgLogger().Error("user password from system DB not like a mysql_native_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd)))
return false
}

if record.AuthPlugin == mysql.AuthCachingSha2Password {
if len(pwd) == mysql.SHAPWDHashLen {
return true
}
logutil.BgLogger().Error("user password from system DB not like a caching_sha2_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd)))
return false
}

logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd)))
return false
}

// GetEncodedPassword implements the Manager interface.
func (p *UserPrivileges) GetEncodedPassword(user, host string) string {
mysqlPriv := p.Handle.Get()
Expand All @@ -173,19 +198,28 @@ func (p *UserPrivileges) GetEncodedPassword(user, host string) string {
zap.String("user", user), zap.String("host", host))
return ""
}
pwd := record.AuthenticationString
switch len(pwd) {
case 0:
return pwd
case mysql.PWDHashLen + 1: // mysql_native_password
return pwd
case 70: // caching_sha2_password
return pwd
}
logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", user), zap.Int("hash_length", len(pwd)))
if p.isValidHash(record) {
return record.AuthenticationString
}
return ""
}

// GetAuthPlugin gets the authentication plugin for the account identified by the user and host
func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) {
mysqlPriv := p.Handle.Get()
record := mysqlPriv.connectionVerification(user, host)
if record == nil {
return "", errors.New("Failed to get user record")
}
if len(record.AuthenticationString) == 0 {
return "", nil
}
if p.isValidHash(record) {
return record.AuthPlugin, nil
}
return "", errors.New("Failed to get plugin for user")
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
if SkipWithGrant {
Expand Down Expand Up @@ -251,8 +285,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
}

pwd := record.AuthenticationString
if len(pwd) != 0 && len(pwd) != mysql.PWDHashLen+1 {
logutil.BgLogger().Error("user password from system DB not like sha1sum", zap.String("user", user))
if !p.isValidHash(record) {
return
}

Expand All @@ -268,13 +301,27 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
return
}

hpwd, err := auth.DecodePassword(pwd)
if err != nil {
logutil.BgLogger().Error("decode password string failed", zap.Error(err))
return
}
if record.AuthPlugin == mysql.AuthNativePassword {
hpwd, err := auth.DecodePassword(pwd)
if err != nil {
logutil.BgLogger().Error("decode password string failed", zap.Error(err))
return
}

if !auth.CheckScrambledPassword(salt, hpwd, authentication) {
if !auth.CheckScrambledPassword(salt, hpwd, authentication) {
return
}
} else if record.AuthPlugin == mysql.AuthCachingSha2Password {
authok, err := auth.CheckShaPassword([]byte(pwd), string(authentication))
if err != nil {
logutil.BgLogger().Error("Failed to check caching_sha2_password", zap.Error(err))
}

if !authok {
return
}
} else {
logutil.BgLogger().Error("unknown authentication plugin", zap.String("user", user), zap.String("plugin", record.AuthPlugin))
return
}

Expand Down
Loading

0 comments on commit f23e100

Please sign in to comment.