From c8c0ddcf8132c6b8a639b9f20903a968691e7231 Mon Sep 17 00:00:00 2001 From: hhyo Date: Sun, 14 Apr 2019 16:46:43 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95=E7=94=A8?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sql/tests.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/sql/tests.py b/sql/tests.py index bcf1121e36..fbbebbec62 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -601,7 +601,7 @@ def test_user_query_priv_no_query_mgtpriv(self): self.assertEqual(json.loads(r.content), {"total": 0, "rows": []}) -class TestQuery(TestCase): +class TestQuery(TransactionTestCase): def setUp(self): self.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql', host='testhost', port=3306, user='mysql_user', password='mysql_password') @@ -625,10 +625,11 @@ def tearDown(self): archer_config = SysConfig() archer_config.set('disable_star', False) + @patch('sql.query.fetch') + @patch('sql.query.async_task') @patch('sql.engines.mysql.MysqlEngine.query') - @patch('sql.engines.mysql.MysqlEngine.query_masking') @patch('sql.query.query_priv_check') - def testCorrectSQL(self, _priv_check, _query_masking, _query): + def testCorrectSQL(self, _priv_check, _query, _async_task, _fetch): c = Client() some_sql = 'select some from some_table limit 100;' some_db = 'some_db' @@ -643,22 +644,23 @@ def testCorrectSQL(self, _priv_check, _query_masking, _query): q_result = ResultSet(full_sql=some_sql, rows=['value']) q_result.column_list = ['some'] - _query.return_value = q_result - _query_masking.return_value = q_result + _async_task.return_value = q_result + _fetch.return_value.result = q_result _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}} r = c.post('/query/', data={'instance_name': self.slave1.instance_name, 'sql_content': some_sql, 'db_name': some_db, 'limit_num': some_limit}) - _query.assert_called_once_with(db_name=some_db, sql=some_sql, limit_num=some_limit) + _async_task.assert_called_once_with(_query, db_name=some_db, sql=some_sql, limit_num=some_limit, timeout=60) r_json = r.json() self.assertEqual(r_json['data']['rows'], ['value']) self.assertEqual(r_json['data']['column_list'], ['some']) + @patch('sql.query.fetch') + @patch('sql.query.async_task') @patch('sql.engines.mysql.MysqlEngine.query') - @patch('sql.engines.mysql.MysqlEngine.query_masking') @patch('sql.query.query_priv_check') - def testSQLWithoutLimit(self, _priv_check, _query_masking, _query): + def testSQLWithoutLimit(self, _priv_check, _query, _async_task, _fetch): c = Client() some_limit = 100 sql_without_limit = 'select some from some_table' @@ -667,14 +669,16 @@ def testSQLWithoutLimit(self, _priv_check, _query_masking, _query): c.force_login(self.u2) q_result = ResultSet(full_sql=sql_without_limit, rows=['value']) q_result.column_list = ['some'] - _query.return_value = q_result - _query_masking.return_value = q_result + _async_task.return_value = q_result + _fetch.return_value.result = q_result + _fetch.return_value.time_taken.return_value = 1 _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}} r = c.post('/query/', data={'instance_name': self.slave1.instance_name, 'sql_content': sql_without_limit, 'db_name': some_db, 'limit_num': some_limit}) - _query.assert_called_once_with(db_name=some_db, sql=sql_with_limit, limit_num=some_limit) + _async_task.assert_called_once_with(_query, db_name=some_db, sql=sql_with_limit, limit_num=some_limit, + timeout=60) r_json = r.json() self.assertEqual(r_json['data']['rows'], ['value']) self.assertEqual(r_json['data']['column_list'], ['some']) @@ -682,12 +686,13 @@ def testSQLWithoutLimit(self, _priv_check, _query_masking, _query): # 带 * 且不带 limit 的sql sql_with_star = 'select * from some_table' filtered_sql_with_star = 'select * from some_table limit {0};'.format(some_limit) - _query.reset_mock() + _async_task.reset_mock() c.post('/query/', data={'instance_name': self.slave1.instance_name, 'sql_content': sql_with_star, 'db_name': some_db, 'limit_num': some_limit}) - _query.assert_called_once_with(db_name=some_db, sql=filtered_sql_with_star, limit_num=some_limit) + _async_task.assert_called_once_with(_query, db_name=some_db, sql=filtered_sql_with_star, limit_num=some_limit, + timeout=60) @patch('sql.query.query_priv_check') def testStarOptionOn(self, _priv_check):