diff --git a/parser/auth/auth.go b/parser/auth/auth.go index 109fdda04b644..cb9abb883a33d 100644 --- a/parser/auth/auth.go +++ b/parser/auth/auth.go @@ -33,6 +33,7 @@ type UserIdentity struct { CurrentUser bool AuthUsername string // Username matched in privileges system AuthHostname string // Match in privs system (i.e. could be a wildcard) + AuthPlugin string // The plugin specified in handshake, only used during authentication. } // Restore implements Node interface. diff --git a/privilege/privileges/BUILD.bazel b/privilege/privileges/BUILD.bazel index 563ea25ee90ed..25f85025647ab 100644 --- a/privilege/privileges/BUILD.bazel +++ b/privilege/privileges/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "//parser/terror", "//privilege", "//sessionctx", + "//sessionctx/sessionstates", "//sessionctx/variable", "//types", "//util", @@ -66,6 +67,7 @@ go_test( "//privilege", "//session", "//sessionctx", + "//sessionctx/sessionstates", "//sessionctx/variable", "//testkit", "//testkit/testsetup", diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index af038f984f413..9fc73fc26fab0 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/sessionstates" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" @@ -544,7 +545,13 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } - if record.AuthPlugin == mysql.AuthTiDBAuthToken { + // If the user uses session token to log in, skip checking record.AuthPlugin. + if user.AuthPlugin == mysql.AuthTiDBSessionToken { + if err = sessionstates.ValidateSessionToken(authentication, user.Username); err != nil { + logutil.BgLogger().Warn("verify session token failed", zap.String("username", user.Username), zap.Error(err)) + return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) + } + } else if record.AuthPlugin == mysql.AuthTiDBAuthToken { if len(authentication) == 0 { logutil.BgLogger().Error("empty authentication") return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) @@ -617,7 +624,11 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse } else { info.ResourceGroupName = record.ResourceGroup } - info.InSandBoxMode, err = p.CheckPasswordExpired(sessionVars, record) + // Skip checking password expiration if the session is migrated from another session. + // Otherwise, the user cannot log in or execute statements after migration. + if user.AuthPlugin != mysql.AuthTiDBSessionToken { + info.InSandBoxMode, err = p.CheckPasswordExpired(sessionVars, record) + } return } diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index fc67838616e8c..9c0cfa2f84dd2 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -20,9 +20,11 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/json" "fmt" "net/url" "os" + "path/filepath" "strings" "testing" "time" @@ -39,6 +41,7 @@ import ( "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/sessionstates" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/testutil" @@ -3148,3 +3151,57 @@ func TestPasswordExpireWithSandBoxMode(t *testing.T) { require.NoError(t, err) require.False(t, tk.Session().InSandBoxMode()) } + +func TestVerificationInfoWithSessionTokenPlugin(t *testing.T) { + // prepare signing certs + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "test1_cert.pem") + keyPath := filepath.Join(tempDir, "test1_key.pem") + err := util.CreateCertificates(certPath, keyPath, 4096, x509.RSA, x509.UnknownSignatureAlgorithm) + require.NoError(t, err) + sessionstates.SetKeyPath(keyPath) + sessionstates.SetCertPath(certPath) + + // prepare user + store := createStoreAndPrepareDB(t) + rootTk := testkit.NewTestKit(t, store) + rootTk.MustExec(`CREATE USER 'testuser'@'localhost' PASSWORD EXPIRE`) + // prepare session token + token, err := sessionstates.CreateSessionToken("testuser") + require.NoError(t, err) + tokenBytes, err := json.Marshal(token) + require.NoError(t, err) + + // Test password expiration without sandbox. + user := &auth.UserIdentity{Username: "testuser", Hostname: "localhost", AuthPlugin: mysql.AuthTiDBSessionToken} + tk := testkit.NewTestKit(t, store) + err = tk.Session().Auth(user, tokenBytes, nil) + require.NoError(t, err) + require.False(t, tk.Session().InSandBoxMode()) + + // Test password expiration with sandbox. + variable.IsSandBoxModeEnabled.Store(true) + err = tk.Session().Auth(user, tokenBytes, nil) + require.NoError(t, err) + require.False(t, tk.Session().InSandBoxMode()) + + // Disable resource group. + require.Equal(t, "", tk.Session().GetSessionVars().ResourceGroupName) + + // Enable resource group. + variable.EnableResourceControl.Store(true) + err = tk.Session().Auth(user, tokenBytes, nil) + require.NoError(t, err) + require.Equal(t, "default", tk.Session().GetSessionVars().ResourceGroupName) + + // Non-default resource group. + rootTk.MustExec("CREATE RESOURCE GROUP rg1 WRU_PER_SEC = 999") + rootTk.MustExec(`ALTER USER 'testuser'@'localhost' RESOURCE GROUP rg1`) + err = tk.Session().Auth(user, tokenBytes, nil) + require.NoError(t, err) + require.Equal(t, "rg1", tk.Session().GetSessionVars().ResourceGroupName) + + // Wrong token + err = tk.Session().Auth(user, nil, nil) + require.ErrorContains(t, err, "Access denied") +} diff --git a/server/conn.go b/server/conn.go index d96362396fb9e..4d4300c099ecb 100644 --- a/server/conn.go +++ b/server/conn.go @@ -75,7 +75,6 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/sessionstates" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" @@ -833,16 +832,8 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e return errAccessDeniedNoPassword.FastGenByArgs(cc.user, host) } - userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host} - if authPlugin == mysql.AuthTiDBSessionToken { - if !cc.ctx.AuthWithoutVerification(userIdentity) { - return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) - } - if err = sessionstates.ValidateSessionToken(authData, cc.user); err != nil { - logutil.BgLogger().Warn("verify session token failed", zap.String("username", cc.user), zap.Error(err)) - return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) - } - } else if err = cc.ctx.Auth(userIdentity, authData, cc.salt); err != nil { + userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host, AuthPlugin: authPlugin} + if err = cc.ctx.Auth(userIdentity, authData, cc.salt); err != nil { return err } cc.ctx.SetPort(port) diff --git a/server/conn_test.go b/server/conn_test.go index fcfdc2c9b5a42..9f8033cd1f98d 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1484,7 +1484,7 @@ func TestAuthPlugin2(t *testing.T) { require.NoError(t, err) } -func TestAuthTokenPlugin(t *testing.T) { +func TestAuthSessionTokenPlugin(t *testing.T) { // create the cert tempDir := t.TempDir() certPath := filepath.Join(tempDir, "test1_cert.pem") @@ -1555,6 +1555,13 @@ func TestAuthTokenPlugin(t *testing.T) { err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin) require.NoError(t, err) + // login succeeds even if the password expires now + tk.MustExec("ALTER USER auth_session_token PASSWORD EXPIRE") + err = cc.openSessionAndDoAuth([]byte{}, mysql.AuthNativePassword) + require.ErrorContains(t, err, "Your password has expired") + err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin) + require.NoError(t, err) + // wrong token should fail tokenBytes[0] ^= 0xff err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin) diff --git a/session/session.go b/session/session.go index 3359c4c9ac192..7e62b74d53a64 100644 --- a/session/session.go +++ b/session/session.go @@ -4138,6 +4138,10 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte if len(s.lockedTables) > 0 { return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables") } + // It's insecure to migrate sandBoxMode because users can fake it. + if s.InSandBoxMode() { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session is in sandbox mode") + } if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { return err diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index 21de8d53727d6..4d1541cc9443d 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -1252,6 +1252,16 @@ func TestShowStateFail(t *testing.T) { tk.MustExec("drop table test.t1") }, }, + { + // enable sandbox mode + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.Session().EnableSandBoxMode() + }, + showErr: errno.ErrCannotMigrateSession, + cleanFunc: func(tk *testkit.TestKit) { + tk.Session().DisableSandBoxMode() + }, + }, { // after COM_STMT_SEND_LONG_DATA setFunc: func(tk *testkit.TestKit, conn server.MockConn) {