diff --git a/executor/simple.go b/executor/simple.go index 43bbc72733f60..2242cc1916e7b 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -841,7 +841,6 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } failedUsers := make([]string, 0, len(s.UserList)) - notExistUsers := make([]string, 0, len(s.UserList)) sysSession, err := e.getSysSession() defer e.releaseSysSession(sysSession) if err != nil { @@ -849,104 +848,84 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } sqlExecutor := sysSession.(sqlexec.SQLExecutor) + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return err + } + for _, user := range s.UserList { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { return err } if !exists { - notExistUsers = append(notExistUsers, user.String()) - continue + if s.IfExists { + e.ctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrUserDropExists.GenWithStackByArgs(user)) + } else { + failedUsers = append(failedUsers, user.String()) + break + } } // begin a transaction to delete a user. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { - return err - } sql := fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } // delete privileges from mysql.db sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } // delete privileges from mysql.tables_priv sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } // delete relationship from mysql.role_edges sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE TO_HOST = '%s' and TO_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE FROM_HOST = '%s' and FROM_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } // delete relationship from mysql.default_roles sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE DEFAULT_ROLE_HOST = '%s' and DEFAULT_ROLE_USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE HOST = '%s' and USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { - return err - } - continue + break } - //TODO: need delete columns_priv once we implement columns_priv functionality. - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { - failedUsers = append(failedUsers, user.String()) - } } - if len(notExistUsers) > 0 { - if s.IfExists { - for _, user := range notExistUsers { - e.ctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrUserDropExists.GenWithStackByArgs(user)) - } - } else { - failedUsers = append(failedUsers, notExistUsers...) + if len(failedUsers) == 0 { + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return err + } + } else { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + return err + } + if s.IsDropRole { + return ErrCannotUser.GenWithStackByArgs("DROP ROLE", strings.Join(failedUsers, ",")) } - } - - if len(failedUsers) > 0 { return ErrCannotUser.GenWithStackByArgs("DROP USER", strings.Join(failedUsers, ",")) } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) diff --git a/executor/simple_test.go b/executor/simple_test.go index 02b49b5d3e192..99847e7d511c8 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -377,11 +377,9 @@ func (s *testSuite3) TestUser(c *C) { _, err = tk.Exec(dropUserSQL) c.Check(err, NotNil) dropUserSQL = `DROP USER 'test3'@'localhost';` - _, err = tk.Exec(dropUserSQL) - c.Check(err, NotNil) + tk.MustExec(dropUserSQL) dropUserSQL = `DROP USER 'test1'@'localhost';` - _, err = tk.Exec(dropUserSQL) - c.Check(err, NotNil) + tk.MustExec(dropUserSQL) // Test positive cases without IF EXISTS. createUserSQL = `CREATE USER 'test1'@'localhost', 'test3'@'localhost';` tk.MustExec(createUserSQL) @@ -625,3 +623,20 @@ func (s *testSuite3) TestIssue9111(c *C) { tk.MustExec("drop user 'user_admin'@'localhost';") } + +func (s *testSuite3) TestRoleAtomic(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("create role r2;") + _, err := tk.Exec("create role r1, r2, r3") + c.Check(err, NotNil) + // Check atomic create role. + result := tk.MustQuery(`SELECT user FROM mysql.User WHERE user in ('r1', 'r2', 'r3')`) + result.Check(testkit.Rows("r2")) + // Check atomic drop role. + _, err = tk.Exec("drop role r1, r2, r3") + c.Check(err, NotNil) + result = tk.MustQuery(`SELECT user FROM mysql.User WHERE user in ('r1', 'r2', 'r3')`) + result.Check(testkit.Rows("r2")) + tk.MustExec("drop role r2;") +}