From 82f365cd33ccf6236597cd0bceae795a7266627c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 28 May 2021 14:14:07 +0200 Subject: [PATCH] Support for caching_sha2_password authentication Issue link: https://github.com/pingcap/tidb/issues/9411 What this does: - Check the `plugin` column of the `mysql.user` table. - Based on the plugin from the user record and the plugin the client sent we may need to switch the authentication plugin to match the one from the user record - For accounts with `caching_sha2_password` send the "fast authentication failed" response to trigger full authentication. - call `auth.CheckShaPassword` to validate the user. Implemented functionality: - Full authentication with `caching_sha2_password` over TLS - The `default_authentication_plugin` variable - `CREATE USER... IDENTIFIED WITH 'caching_sha2_password'...` - `SET PASSWORD...` - `ALTER USER ... IDENTIFIED BY...` Missing functionality: - Support for the RSA public key request packet & response - Support for RSA key based secret exchange - Fast authentication (validate against cached entry) Related: - Requires https://github.com/pingcap/parser/pull/1232 - https://github.com/pingcap/tidb/pull/24141 makes testing of this easier, but this is not required. --- executor/grant.go | 8 +- executor/show_test.go | 12 +-- executor/simple.go | 38 ++++++- privilege/privilege.go | 3 + privilege/privileges/cache.go | 9 +- privilege/privileges/privileges.go | 83 ++++++++++---- server/conn.go | 167 ++++++++++++++++++++++++----- server/conn_test.go | 11 +- session/session.go | 10 ++ session/session_test.go | 23 ++++ sessionctx/variable/sysvar.go | 3 + 11 files changed, 309 insertions(+), 58 deletions(-) diff --git a/executor/grant.go b/executor/grant.go index 49536dc79aa86..60ea934d6f976 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -136,7 +136,13 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { if !ok { return errors.Trace(ErrPasswordFormat) } - _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, `INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`, mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, mysql.AuthNativePassword) + authPlugin := mysql.AuthNativePassword + if user.AuthOpt.AuthPlugin != "" { + authPlugin = user.AuthOpt.AuthPlugin + } + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, + `INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`, + mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, authPlugin) if err != nil { return err } diff --git a/executor/show_test.go b/executor/show_test.go index 9a18fdfe348b8..0b52704e32782 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -18,6 +18,7 @@ import ( "fmt" "strings" + "github.com/pingcap/check" . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/parser/auth" @@ -444,14 +445,13 @@ func (s *testSuite5) TestShowCreateUser(c *C) { rows = tk1.MustQuery("show create user current_user") rows.Check(testkit.Rows("CREATE USER 'check_priv'@'127.0.0.1' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) - // Creating users with `IDENTIFIED WITH 'caching_sha2_password'` is not supported yet. So manually creating an entry for now. - // later this can be changed to test the full path once 'caching_sha2_password' support is completed. - tk.MustExec("CREATE USER 'sha_test'@'%' IDENTIFIED BY 'temp_passwd'") - tk.MustExec("UPDATE mysql.user SET plugin='caching_sha2_password', authentication_string=0x24412430303524532C06366D1D1E2B2F4437681A057B6807193D1C4B6E772F667A764663534E6C3978716C3057644D73427A787747674679687632644A384F337941704A542F WHERE user='sha_test' AND host='%'") - tk.MustExec("FLUSH PRIVILEGES") + // Creating users with `IDENTIFIED WITH 'caching_sha2_password'` + tk.MustExec("CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' BY 'temp_passwd'") + // Compare only the start of the output as the salt changes every time. rows = tk.MustQuery("SHOW CREATE USER 'sha_test'@'%'") - rows.Check(testkit.Rows("CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' AS '$A$005$S,\x066m\x1d\x1e+/D7h\x1a\x05{h\a\x19=\x1cKnw/fzvFcSNl9xql0WdMsBzxwGgFyhv2dJ8O3yApJT/' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) + c.Assert(rows.Rows()[0][0].(string)[:78], check.Equals, "CREATE USER 'sha_test'@'%' IDENTIFIED WITH 'caching_sha2_password' AS '$A$005$") + } func (s *testSuite5) TestUnprivilegedShow(c *C) { diff --git a/executor/simple.go b/executor/simple.go index d4681b9b24757..2ba3483625a8b 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -777,13 +777,18 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm continue } pwd, ok := spec.EncodedPassword() + if !ok { return errors.Trace(ErrPasswordFormat) } + authPlugin := mysql.AuthNativePassword + if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { + authPlugin = spec.AuthOpt.AuthPlugin + } if s.IsCreateRole { - sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, mysql.AuthNativePassword, "Y") + sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, authPlugin, "Y") } else { - sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, mysql.AuthNativePassword) + sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, authPlugin) } users = append(users, spec.User) } @@ -908,6 +913,13 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { continue } + authplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) + if err != nil { + return err + } + if spec.AuthOpt != nil { + spec.AuthOpt.AuthPlugin = authplugin + } exec := e.ctx.(sqlexec.RestrictedSQLExecutor) if spec.AuthOpt != nil { pwd, ok := spec.EncodedPassword() @@ -1322,6 +1334,15 @@ func userExistsInternal(sqlExecutor sqlexec.SQLExecutor, name string, host strin return rows > 0, err } +func (e *SimpleExec) userAuthPlugin(name string, host string) (string, error) { + pm := privilege.GetPrivilegeManager(e.ctx) + authplugin, err := pm.GetAuthPlugin(name, host) + if err != nil { + return "", err + } + return authplugin, nil +} + func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { var u, h string if s.User == nil { @@ -1347,9 +1368,20 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { return errors.Trace(ErrPasswordNoMatch) } + authplugin, err := e.userAuthPlugin(u, h) + if err != nil { + return err + } + var pwd string + if authplugin == mysql.AuthCachingSha2Password { + pwd = auth.NewSha2Password(s.Password) + } else { + pwd = auth.EncodePassword(s.Password) + } + // update mysql.user exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, h) if err != nil { return err } diff --git a/privilege/privilege.go b/privilege/privilege.go index f732d9da1199b..f3fbade9cfc0b 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -80,6 +80,9 @@ type Manager interface { // IsDynamicPrivilege returns if a privilege is in the list of privileges. IsDynamicPrivilege(privNameInUpper string) bool + + // Get the authentication plugin for a user + GetAuthPlugin(user, host string) (string, error) } const key keyType = 0 diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index caf25df1eeaaf..cd1797f1e0ba7 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -63,7 +63,7 @@ const ( References_priv,Alter_priv,Execute_priv,Index_priv,Create_view_priv,Show_view_priv, Create_role_priv,Drop_role_priv,Create_tmp_table_priv,Lock_tables_priv,Create_routine_priv, Alter_routine_priv,Event_priv,Shutdown_priv,Reload_priv,File_priv,Config_priv,Repl_client_priv,Repl_slave_priv, - account_locked FROM mysql.user` + account_locked,plugin FROM mysql.user` sqlLoadGlobalGrantsTable = `SELECT HIGH_PRIORITY Host,User,Priv,With_Grant_Option FROM mysql.global_grants` ) @@ -96,6 +96,7 @@ type UserRecord struct { AuthenticationString string Privileges mysql.PrivilegeType AccountLocked bool // A role record when this field is true + AuthPlugin string } // NewUserRecord return a UserRecord, only use for unit test. @@ -632,6 +633,12 @@ func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField if row.GetEnum(i).String() == "Y" { value.AccountLocked = true } + case f.ColumnAsName.L == "plugin": + if row.GetString(i) != "" { + value.AuthPlugin = row.GetString(i) + } else { + value.AuthPlugin = mysql.AuthNativePassword + } case f.Column.Tp == mysql.TypeEnum: if row.GetEnum(i).String() != "Y" { continue diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 744c4882c319a..f8ab4270525ee 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -164,6 +164,31 @@ func (p *UserPrivileges) RequestVerificationWithUser(db, table, column string, p return mysqlPriv.RequestVerification(roles, user.Username, user.Hostname, db, table, column, priv) } +func (p *UserPrivileges) isValidHash(record *UserRecord) bool { + pwd := record.AuthenticationString + if pwd == "" { + return true + } + if record.AuthPlugin == mysql.AuthNativePassword { + if len(pwd) == mysql.PWDHashLen+1 { + return true + } + logutil.BgLogger().Error("user password from system DB not like a mysql_native_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + return false + } + + if record.AuthPlugin == mysql.AuthCachingSha2Password { + if len(pwd) == mysql.SHAPWDHashLen { + return true + } + logutil.BgLogger().Error("user password from system DB not like a caching_sha2_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + return false + } + + logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + return false +} + // GetEncodedPassword implements the Manager interface. func (p *UserPrivileges) GetEncodedPassword(user, host string) string { mysqlPriv := p.Handle.Get() @@ -173,19 +198,28 @@ func (p *UserPrivileges) GetEncodedPassword(user, host string) string { zap.String("user", user), zap.String("host", host)) return "" } - pwd := record.AuthenticationString - switch len(pwd) { - case 0: - return pwd - case mysql.PWDHashLen + 1: // mysql_native_password - return pwd - case 70: // caching_sha2_password - return pwd - } - logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", user), zap.Int("hash_length", len(pwd))) + if p.isValidHash(record) { + return record.AuthenticationString + } return "" } +// GetAuthPlugin gets the authentication plugin for the account identified by the user and host +func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) { + mysqlPriv := p.Handle.Get() + record := mysqlPriv.connectionVerification(user, host) + if record == nil { + return "", errors.New("Failed to get user record") + } + if len(record.AuthenticationString) == 0 { + return "", nil + } + if p.isValidHash(record) { + return record.AuthPlugin, nil + } + return "", errors.New("Failed to get plugin for user") +} + // GetAuthWithoutVerification implements the Manager interface. func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) { if SkipWithGrant { @@ -251,8 +285,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio } pwd := record.AuthenticationString - if len(pwd) != 0 && len(pwd) != mysql.PWDHashLen+1 { - logutil.BgLogger().Error("user password from system DB not like sha1sum", zap.String("user", user)) + if !p.isValidHash(record) { return } @@ -268,13 +301,27 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } - hpwd, err := auth.DecodePassword(pwd) - if err != nil { - logutil.BgLogger().Error("decode password string failed", zap.Error(err)) - return - } + if record.AuthPlugin == mysql.AuthNativePassword { + hpwd, err := auth.DecodePassword(pwd) + if err != nil { + logutil.BgLogger().Error("decode password string failed", zap.Error(err)) + return + } - if !auth.CheckScrambledPassword(salt, hpwd, authentication) { + if !auth.CheckScrambledPassword(salt, hpwd, authentication) { + return + } + } else if record.AuthPlugin == mysql.AuthCachingSha2Password { + authok, err := auth.CheckShaPassword([]byte(pwd), string(authentication)) + if err != nil { + logutil.BgLogger().Error("Failed to check caching_sha2_password", zap.Error(err)) + } + + if !authok { + return + } + } else { + logutil.BgLogger().Error("unknown authentication plugin", zap.String("user", user), zap.String("plugin", record.AuthPlugin)) return } diff --git a/server/conn.go b/server/conn.go index 4fe41782d851d..130e7768bf702 100644 --- a/server/conn.go +++ b/server/conn.go @@ -158,6 +158,7 @@ func newClientConn(s *Server) *clientConn { alloc: arena.NewAllocator(32 * 1024), status: connStatusDispatching, lastActive: time.Now(), + authPlugin: mysql.AuthNativePassword, } } @@ -182,7 +183,8 @@ type clientConn struct { status int32 // dispatching/reading/shutdown/waitshutdown lastCode uint16 // last error code collation uint8 // collation used by client, may be different from the collation used by database. - lastActive time.Time + lastActive time.Time // last active time + authPlugin string // default authentication plugin // mu is used for cancelling the execution of current transaction. mu struct { @@ -198,15 +200,17 @@ func (cc *clientConn) String() string { ) } -// authSwitchRequest is used when the client asked to speak something -// other than mysql_native_password. The server is allowed to ask -// the client to switch, so lets ask for mysql_native_password +// authSwitchRequest is used by the server to ask the client to switch to a different authentication +// plugin. MySQL 8.0 libmysqlclient based clients by default always try `caching_sha2_password`, even +// when the server advertises the its default to be `mysql_native_password`. In addition to this switching +// may be needed on a per user basis as the authentication method is set per user. // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest -func (cc *clientConn) authSwitchRequest(ctx context.Context) ([]byte, error) { - enclen := 1 + len(mysql.AuthNativePassword) + 1 + len(cc.salt) + 1 +// https://bugs.mysql.com/bug.php?id=93044 +func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) { + enclen := 1 + len(plugin) + 1 + len(cc.salt) + 1 data := cc.alloc.AllocWithLen(4, enclen) data = append(data, mysql.AuthSwitchRequest) // switch request - data = append(data, []byte(mysql.AuthNativePassword)...) + data = append(data, []byte(plugin)...) data = append(data, byte(0x00)) // requires null data = append(data, cc.salt...) data = append(data, 0) @@ -230,6 +234,7 @@ func (cc *clientConn) authSwitchRequest(ctx context.Context) ([]byte, error) { } return nil, err } + cc.authPlugin = plugin return resp, nil } @@ -348,9 +353,29 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error { data = append(data, cc.salt[8:]...) data = append(data, 0) // auth-plugin name - data = append(data, []byte(mysql.AuthNativePassword)...) + if cc.ctx == nil { + err := cc.openSession() + if err != nil { + return err + } + } + defAuthPlugin, err := variable.GetGlobalSystemVar(cc.ctx.GetSessionVars(), variable.DefaultAuthPlugin) + if err != nil { + return err + } + cc.authPlugin = defAuthPlugin + data = append(data, []byte(defAuthPlugin)...) + + // Close the session to force this to be re-opened after we parse the response. This is needed + // to ensure we use the collation and client flags from the response for the session. + err = cc.ctx.Close() + if err != nil { + return err + } + cc.ctx = nil + data = append(data, 0) - err := cc.writePacket(data) + err = cc.writePacket(data) if err != nil { return err } @@ -492,11 +517,15 @@ func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41 // MySQL client sets the wrong capability, it will set this bit even server doesn't // support ClientPluginAuthLenencClientData. // https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478 - num, null, off := parseLengthEncodedInt(data[offset:]) - offset += off - if !null { - packet.Auth = data[offset : offset+int(num)] - offset += int(num) + if data[offset] == 0x1 { // No auth data + offset += 2 + } else { + num, null, off := parseLengthEncodedInt(data[offset:]) + offset += off + if !null { + packet.Auth = data[offset : offset+int(num)] + offset += int(num) + } } } else if packet.Capability&mysql.ClientSecureConnection > 0 { // auth length and auth @@ -643,27 +672,66 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } - // switching from other methods should work, but not tested - if resp.AuthPlugin == mysql.AuthCachingSha2Password { - resp.Auth, err = cc.authSwitchRequest(ctx) - if err != nil { - logutil.Logger(ctx).Warn("attempt to send auth switch request packet failed", zap.Error(err)) - return err - } - } cc.capability = resp.Capability & cc.server.capability cc.user = resp.User cc.dbname = resp.DBName cc.collation = resp.Collation cc.attrs = resp.Attrs + newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin) + if err != nil { + logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) + } + if len(newAuth) > 0 { + resp.Auth = newAuth + } + + switch resp.AuthPlugin { + case mysql.AuthCachingSha2Password: + resp.Auth, err = cc.authSha(ctx) + if err != nil { + return err + } + case mysql.AuthNativePassword: + default: + return errors.New("Unknown auth plugin") + } + err = cc.openSessionAndDoAuth(resp.Auth) if err != nil { - logutil.Logger(ctx).Warn("open new session failure", zap.Error(err)) + logutil.Logger(ctx).Warn("open new session or authentication failure", zap.Error(err)) } return err } +func (cc *clientConn) authSha(ctx context.Context) ([]byte, error) { + + const ( + ShaCommand = 1 + RequestRsaPubKey = 2 + FastAuthOk = 3 + FastAuthFail = 4 + ) + + err := cc.writePacket([]byte{0, 0, 0, 0, ShaCommand, FastAuthFail}) + if err != nil { + logutil.Logger(ctx).Error("authSha packet write failed", zap.Error(err)) + return nil, err + } + err = cc.flush(ctx) + if err != nil { + logutil.Logger(ctx).Error("authSha packet flush failed", zap.Error(err)) + return nil, err + } + + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error("authSha packet read failed", zap.Error(err)) + return nil, err + } + return bytes.Trim(data, "\x00"), nil +} + func (cc *clientConn) SessionStatusToString() string { status := cc.ctx.Status() inTxn, autoCommit := 0, 0 @@ -678,7 +746,7 @@ func (cc *clientConn) SessionStatusToString() string { ) } -func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { +func (cc *clientConn) openSession() error { var tlsStatePtr *tls.ConnectionState if cc.tlsConn != nil { tlsState := cc.tlsConn.ConnectionState() @@ -690,9 +758,22 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { return err } - if err = cc.server.checkConnectionCount(); err != nil { + err = cc.server.checkConnectionCount() + if err != nil { return err } + return nil +} + +func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { + // Open a context unless this was done before. + if cc.ctx == nil { + err := cc.openSession() + if err != nil { + return err + } + } + hasPassword := "YES" if len(authData) == 0 { hasPassword = "NO" @@ -715,6 +796,42 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { return nil } +// Check if the Authentication Plugin of the server, client and user configuration matches +func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) { + // Open a context unless this was done before. + if cc.ctx == nil { + err := cc.openSession() + if err != nil { + return nil, err + } + } + + userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost}) + if err != nil { + return nil, err + } + if len(userplugin) == 0 { + *authPlugin = mysql.AuthNativePassword + return nil, nil + } + + // If the authentication method send by the server (cc.authPlugin) doesn't match + // the plugin configured for the user account in the mysql.user.plugin column + // or if the authentication method send by the server doesn't match the authentication + // method send by the client (*authPlugin) then we need to switch the authentication + // method to match the one configured for that specific user. + if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) { + authData, err := cc.authSwitchRequest(ctx, userplugin) + if err != nil { + return nil, err + } + *authPlugin = userplugin + return authData, nil + } + + return nil, nil +} + func (cc *clientConn) PeerHost(hasPassword string) (host, port string, err error) { if len(cc.peerHost) > 0 { return cc.peerHost, "", nil diff --git a/server/conn_test.go b/server/conn_test.go index 20c63b6e4f9b7..6c259303d78a6 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -205,17 +205,20 @@ func (ts *ConnTestSuite) TestAuthSwitchRequest(c *C) { func (ts *ConnTestSuite) TestInitialHandshake(c *C) { c.Parallel() var outBuffer bytes.Buffer + cfg := newTestConfig() + drv := NewTiDBDriver(ts.store) + srv, err := NewServer(cfg, drv) + c.Assert(err, IsNil) cc := &clientConn{ connectionID: 1, salt: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14}, - server: &Server{ - capability: defaultCapability, - }, + server: srv, pkt: &packetIO{ bufWriter: bufio.NewWriter(&outBuffer), }, } - err := cc.writeInitialHandshake(context.TODO()) + + err = cc.writeInitialHandshake(context.TODO()) c.Assert(err, IsNil) expected := new(bytes.Buffer) diff --git a/session/session.go b/session/session.go index 322aa1ef45954..af2528a7ead10 100644 --- a/session/session.go +++ b/session/session.go @@ -149,6 +149,7 @@ type Session interface { Close() Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool AuthWithoutVerification(user *auth.UserIdentity) bool + AuthPluginForUser(user *auth.UserIdentity) (string, error) ShowProcess() *util.ProcessInfo // Return the information of the txn current running TxnInfo() *txninfo.TxnInfo @@ -2147,6 +2148,15 @@ func (s *session) GetSessionVars() *variable.SessionVars { return s.sessionVars } +func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { + pm := privilege.GetPrivilegeManager(s) + authplugin, err := pm.GetAuthPlugin(user.Username, user.Hostname) + if err != nil { + return "", err + } + return authplugin, nil +} + func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) diff --git a/session/session_test.go b/session/session_test.go index 04b35919f2992..9f9371007bbb1 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4823,3 +4823,26 @@ func (s *testStatisticsSuite) TestNewCollationStatsWithPrefixIndex(c *C) { "1 3 15 0 2 0", )) } + +func (s *testSessionSuite) TestAuthPluginForUser(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("CREATE USER 'tapfu1' IDENTIFIED WITH mysql_native_password BY 'tapfu1'") + plugin, err := tk.Se.AuthPluginForUser(&auth.UserIdentity{Username: "tapfu1", Hostname: `%`}) + c.Assert(err, IsNil) + c.Assert(plugin, Equals, "mysql_native_password") + + tk.MustExec("CREATE USER 'tapfu2' IDENTIFIED WITH mysql_native_password") + plugin, err = tk.Se.AuthPluginForUser(&auth.UserIdentity{Username: "tapfu2", Hostname: `%`}) + c.Assert(err, IsNil) + c.Assert(plugin, Equals, "") + + tk.MustExec("CREATE USER 'tapfu3' IDENTIFIED WITH caching_sha2_password BY 'tapfu3'") + plugin, err = tk.Se.AuthPluginForUser(&auth.UserIdentity{Username: "tapfu3", Hostname: `%`}) + c.Assert(err, IsNil) + c.Assert(plugin, Equals, "caching_sha2_password") + + tk.MustExec("CREATE USER 'tapfu4' IDENTIFIED WITH caching_sha2_password") + plugin, err = tk.Se.AuthPluginForUser(&auth.UserIdentity{Username: "tapfu4", Hostname: `%`}) + c.Assert(err, IsNil) + c.Assert(plugin, Equals, "") +} diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 028c92acad0d3..2f4a43a4889c3 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -1759,6 +1759,7 @@ var defaultSysVars = []*SysVar{ return nil }}, {Scope: ScopeGlobal, Name: SkipNameResolve, Value: Off, Type: TypeBool}, + {Scope: ScopeGlobal, Name: DefaultAuthPlugin, Value: mysql.AuthNativePassword, Type: TypeEnum, PossibleValues: []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password}}, } // FeedbackProbability points to the FeedbackProbability in statistics package. @@ -2037,6 +2038,8 @@ const ( SystemTimeZone = "system_time_zone" // CTEMaxRecursionDepth is the name of 'cte_max_recursion_depth' system variable. CTEMaxRecursionDepth = "cte_max_recursion_depth" + // DefaultAuthPlugin is the name of 'default_authentication_plugin' system variable. + DefaultAuthPlugin = "default_authentication_plugin" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables.