Skip to content

Commit

Permalink
修改测试用例
Browse files Browse the repository at this point in the history
  • Loading branch information
hhyo committed Apr 14, 2019
1 parent 8bc29b3 commit c8c0ddc
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def test_user_query_priv_no_query_mgtpriv(self):
self.assertEqual(json.loads(r.content), {"total": 0, "rows": []})


class TestQuery(TestCase):
class TestQuery(TransactionTestCase):
def setUp(self):
self.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql',
host='testhost', port=3306, user='mysql_user', password='mysql_password')
Expand All @@ -625,10 +625,11 @@ def tearDown(self):
archer_config = SysConfig()
archer_config.set('disable_star', False)

@patch('sql.query.fetch')
@patch('sql.query.async_task')
@patch('sql.engines.mysql.MysqlEngine.query')
@patch('sql.engines.mysql.MysqlEngine.query_masking')
@patch('sql.query.query_priv_check')
def testCorrectSQL(self, _priv_check, _query_masking, _query):
def testCorrectSQL(self, _priv_check, _query, _async_task, _fetch):
c = Client()
some_sql = 'select some from some_table limit 100;'
some_db = 'some_db'
Expand All @@ -643,22 +644,23 @@ def testCorrectSQL(self, _priv_check, _query_masking, _query):
q_result = ResultSet(full_sql=some_sql, rows=['value'])
q_result.column_list = ['some']

_query.return_value = q_result
_query_masking.return_value = q_result
_async_task.return_value = q_result
_fetch.return_value.result = q_result
_priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}}
r = c.post('/query/', data={'instance_name': self.slave1.instance_name,
'sql_content': some_sql,
'db_name': some_db,
'limit_num': some_limit})
_query.assert_called_once_with(db_name=some_db, sql=some_sql, limit_num=some_limit)
_async_task.assert_called_once_with(_query, db_name=some_db, sql=some_sql, limit_num=some_limit, timeout=60)
r_json = r.json()
self.assertEqual(r_json['data']['rows'], ['value'])
self.assertEqual(r_json['data']['column_list'], ['some'])

@patch('sql.query.fetch')
@patch('sql.query.async_task')
@patch('sql.engines.mysql.MysqlEngine.query')
@patch('sql.engines.mysql.MysqlEngine.query_masking')
@patch('sql.query.query_priv_check')
def testSQLWithoutLimit(self, _priv_check, _query_masking, _query):
def testSQLWithoutLimit(self, _priv_check, _query, _async_task, _fetch):
c = Client()
some_limit = 100
sql_without_limit = 'select some from some_table'
Expand All @@ -667,27 +669,30 @@ def testSQLWithoutLimit(self, _priv_check, _query_masking, _query):
c.force_login(self.u2)
q_result = ResultSet(full_sql=sql_without_limit, rows=['value'])
q_result.column_list = ['some']
_query.return_value = q_result
_query_masking.return_value = q_result
_async_task.return_value = q_result
_fetch.return_value.result = q_result
_fetch.return_value.time_taken.return_value = 1
_priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}}
r = c.post('/query/', data={'instance_name': self.slave1.instance_name,
'sql_content': sql_without_limit,
'db_name': some_db,
'limit_num': some_limit})
_query.assert_called_once_with(db_name=some_db, sql=sql_with_limit, limit_num=some_limit)
_async_task.assert_called_once_with(_query, db_name=some_db, sql=sql_with_limit, limit_num=some_limit,
timeout=60)
r_json = r.json()
self.assertEqual(r_json['data']['rows'], ['value'])
self.assertEqual(r_json['data']['column_list'], ['some'])

# 带 * 且不带 limit 的sql
sql_with_star = 'select * from some_table'
filtered_sql_with_star = 'select * from some_table limit {0};'.format(some_limit)
_query.reset_mock()
_async_task.reset_mock()
c.post('/query/', data={'instance_name': self.slave1.instance_name,
'sql_content': sql_with_star,
'db_name': some_db,
'limit_num': some_limit})
_query.assert_called_once_with(db_name=some_db, sql=filtered_sql_with_star, limit_num=some_limit)
_async_task.assert_called_once_with(_query, db_name=some_db, sql=filtered_sql_with_star, limit_num=some_limit,
timeout=60)

@patch('sql.query.query_priv_check')
def testStarOptionOn(self, _priv_check):
Expand Down

0 comments on commit c8c0ddc

Please sign in to comment.