Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: Support for caching_sha2_password authentication #24991

Merged
merged 4 commits into from
Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
if !ok {
return errors.Trace(ErrPasswordFormat)
}
_, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, `INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`, mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, mysql.AuthNativePassword)
authPlugin := mysql.AuthNativePassword
if user.AuthOpt.AuthPlugin != "" {
authPlugin = user.AuthOpt.AuthPlugin
}
_, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx,
`INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`,
mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, authPlugin)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions executor/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"strings"

"github.com/pingcap/check"
. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/auth"
Expand Down Expand Up @@ -444,14 +445,13 @@ func (s *testSuite5) TestShowCreateUser(c *C) {
rows = tk1.MustQuery("show create user current_user")
rows.Check(testkit.Rows("CREATE USER 'check_priv'@'127.0.0.1' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK"))

// Creating users with `IDENTIFIED WITH 'caching_sha2_password'` is not supported yet. So manually creating an entry for now.
// later this can be changed to test the full path once 'caching_sha2_password' support is completed.
tk.MustExec("CREATE USER 'sha_test'@'%' IDENTIFIED BY 'temp_passwd'")
tk.MustExec("UPDATE mysql.user SET plugin='caching_sha2_password', authentication_string=0x24412430303524532C06366D1D1E2B2F4437681A057B6807193D1C4B6E772F667A764663534E6C3978716C3057644D73427A787747674679687632644A384F337941704A542F WHERE user='sha_test' AND host='%'")
tk.MustExec("FLUSH PRIVILEGES")
// Creating users with `IDENTIFIED WITH 'caching_sha2_password'`
tk.MustExec("CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' BY 'temp_passwd'")

// Compare only the start of the output as the salt changes every time.
rows = tk.MustQuery("SHOW CREATE USER 'sha_test'@'%'")
rows.Check(testkit.Rows("CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' AS '$A$005$S,\x066m\x1d\x1e+/D7h\x1a\x05{h\a\x19=\x1cKnw/fzvFcSNl9xql0WdMsBzxwGgFyhv2dJ8O3yApJT/' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK"))
c.Assert(rows.Rows()[0][0].(string)[:78], check.Equals, "CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' AS '$A$005$")

}

func (s *testSuite5) TestUnprivilegedShow(c *C) {
Expand Down
38 changes: 35 additions & 3 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,13 +777,18 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm
continue
}
pwd, ok := spec.EncodedPassword()

if !ok {
return errors.Trace(ErrPasswordFormat)
}
authPlugin := mysql.AuthNativePassword
if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" {
authPlugin = spec.AuthOpt.AuthPlugin
}
if s.IsCreateRole {
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, mysql.AuthNativePassword, "Y")
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, authPlugin, "Y")
} else {
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, mysql.AuthNativePassword)
sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, authPlugin)
}
users = append(users, spec.User)
}
Expand Down Expand Up @@ -908,6 +913,13 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error {
continue
}

authplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname)
if err != nil {
return err
}
if spec.AuthOpt != nil {
spec.AuthOpt.AuthPlugin = authplugin
}
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
if spec.AuthOpt != nil {
pwd, ok := spec.EncodedPassword()
Expand Down Expand Up @@ -1322,6 +1334,15 @@ func userExistsInternal(sqlExecutor sqlexec.SQLExecutor, name string, host strin
return rows > 0, err
}

func (e *SimpleExec) userAuthPlugin(name string, host string) (string, error) {
pm := privilege.GetPrivilegeManager(e.ctx)
authplugin, err := pm.GetAuthPlugin(name, host)
if err != nil {
return "", err
}
return authplugin, nil
}

func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {
var u, h string
if s.User == nil {
Expand All @@ -1347,9 +1368,20 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {
return errors.Trace(ErrPasswordNoMatch)
}

authplugin, err := e.userAuthPlugin(u, h)
if err != nil {
return err
}
var pwd string
if authplugin == mysql.AuthCachingSha2Password {
pwd = auth.NewSha2Password(s.Password)
} else {
pwd = auth.EncodePassword(s.Password)
}

// update mysql.user
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h)
stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, h)
if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ type Manager interface {

// IsDynamicPrivilege returns if a privilege is in the list of privileges.
IsDynamicPrivilege(privNameInUpper string) bool

// Get the authentication plugin for a user
GetAuthPlugin(user, host string) (string, error)
}

const key keyType = 0
Expand Down
9 changes: 8 additions & 1 deletion privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const (
References_priv,Alter_priv,Execute_priv,Index_priv,Create_view_priv,Show_view_priv,
Create_role_priv,Drop_role_priv,Create_tmp_table_priv,Lock_tables_priv,Create_routine_priv,
Alter_routine_priv,Event_priv,Shutdown_priv,Reload_priv,File_priv,Config_priv,Repl_client_priv,Repl_slave_priv,
account_locked FROM mysql.user`
account_locked,plugin FROM mysql.user`
sqlLoadGlobalGrantsTable = `SELECT HIGH_PRIORITY Host,User,Priv,With_Grant_Option FROM mysql.global_grants`
)

Expand Down Expand Up @@ -96,6 +96,7 @@ type UserRecord struct {
AuthenticationString string
Privileges mysql.PrivilegeType
AccountLocked bool // A role record when this field is true
AuthPlugin string
}

// NewUserRecord return a UserRecord, only use for unit test.
Expand Down Expand Up @@ -632,6 +633,12 @@ func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField
if row.GetEnum(i).String() == "Y" {
value.AccountLocked = true
}
case f.ColumnAsName.L == "plugin":
if row.GetString(i) != "" {
value.AuthPlugin = row.GetString(i)
} else {
value.AuthPlugin = mysql.AuthNativePassword
}
case f.Column.Tp == mysql.TypeEnum:
if row.GetEnum(i).String() != "Y" {
continue
Expand Down
83 changes: 65 additions & 18 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,31 @@ func (p *UserPrivileges) RequestVerificationWithUser(db, table, column string, p
return mysqlPriv.RequestVerification(roles, user.Username, user.Hostname, db, table, column, priv)
}

func (p *UserPrivileges) isValidHash(record *UserRecord) bool {
pwd := record.AuthenticationString
if pwd == "" {
return true
}
if record.AuthPlugin == mysql.AuthNativePassword {
if len(pwd) == mysql.PWDHashLen+1 {
return true
}
logutil.BgLogger().Error("user password from system DB not like a mysql_native_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd)))
return false
}

if record.AuthPlugin == mysql.AuthCachingSha2Password {
if len(pwd) == mysql.SHAPWDHashLen {
return true
}
logutil.BgLogger().Error("user password from system DB not like a caching_sha2_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd)))
return false
}

logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd)))
return false
}

// GetEncodedPassword implements the Manager interface.
func (p *UserPrivileges) GetEncodedPassword(user, host string) string {
mysqlPriv := p.Handle.Get()
Expand All @@ -173,19 +198,28 @@ func (p *UserPrivileges) GetEncodedPassword(user, host string) string {
zap.String("user", user), zap.String("host", host))
return ""
}
pwd := record.AuthenticationString
switch len(pwd) {
case 0:
return pwd
case mysql.PWDHashLen + 1: // mysql_native_password
return pwd
case 70: // caching_sha2_password
return pwd
}
logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", user), zap.Int("hash_length", len(pwd)))
if p.isValidHash(record) {
return record.AuthenticationString
}
return ""
}

// GetAuthPlugin gets the authentication plugin for the account identified by the user and host
func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) {
mysqlPriv := p.Handle.Get()
record := mysqlPriv.connectionVerification(user, host)
if record == nil {
return "", errors.New("Failed to get user record")
}
if len(record.AuthenticationString) == 0 {
return "", nil
}
if p.isValidHash(record) {
return record.AuthPlugin, nil
}
return "", errors.New("Failed to get plugin for user")
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
if SkipWithGrant {
Expand Down Expand Up @@ -251,8 +285,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
}

pwd := record.AuthenticationString
if len(pwd) != 0 && len(pwd) != mysql.PWDHashLen+1 {
logutil.BgLogger().Error("user password from system DB not like sha1sum", zap.String("user", user))
if !p.isValidHash(record) {
return
}

Expand All @@ -268,13 +301,27 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
return
}

hpwd, err := auth.DecodePassword(pwd)
if err != nil {
logutil.BgLogger().Error("decode password string failed", zap.Error(err))
return
}
if record.AuthPlugin == mysql.AuthNativePassword {
hpwd, err := auth.DecodePassword(pwd)
if err != nil {
logutil.BgLogger().Error("decode password string failed", zap.Error(err))
return
}

if !auth.CheckScrambledPassword(salt, hpwd, authentication) {
if !auth.CheckScrambledPassword(salt, hpwd, authentication) {
return
}
} else if record.AuthPlugin == mysql.AuthCachingSha2Password {
authok, err := auth.CheckShaPassword([]byte(pwd), string(authentication))
if err != nil {
logutil.BgLogger().Error("Failed to check caching_sha2_password", zap.Error(err))
}

if !authok {
return
}
} else {
logutil.BgLogger().Error("unknown authentication plugin", zap.String("user", user), zap.String("plugin", record.AuthPlugin))
return
}

Expand Down
Loading