Skip to content

Commit

Permalink
Add test case for workgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed Nov 24, 2019
1 parent 448141e commit 9166710
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 39 deletions.
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ Depends on the following environment variables:
$ export AWS_DEFAULT_REGION=us-west-2
$ export AWS_ATHENA_S3_STAGING_DIR=s3://YOUR_S3_BUCKET/path/to/
And you need to create a workgroup named ``test-pyathena``.

Run test
~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ENV = Env()
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
S3_PREFIX = 'test_pyathena'
WORK_GROUP = 'test-pyathena'
SCHEMA = 'test_pyathena_' + ''.join([random.choice(
string.ascii_lowercase + string.digits) for _ in xrange(10)])

Expand Down
69 changes: 37 additions & 32 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pyathena.cursor import Cursor
from pyathena.error import DatabaseError, NotSupportedError, ProgrammingError
from pyathena.model import AthenaQueryExecution
from tests.conftest import ENV, S3_PREFIX, SCHEMA
from tests.conftest import ENV, S3_PREFIX, SCHEMA, WORK_GROUP
from tests.util import with_cursor


Expand All @@ -30,10 +30,10 @@ class TestCursor(unittest.TestCase):
https://github.com/dropbox/PyHive/blob/master/pyhive/tests/test_presto.py
"""

def connect(self):
return connect(schema_name=SCHEMA)
def connect(self, work_group=None):
return connect(schema_name=SCHEMA, work_group=work_group)

@with_cursor
@with_cursor()
def test_fetchone(self, cursor):
cursor.execute('SELECT * FROM one_row')
self.assertEqual(cursor.rownumber, 0)
Expand All @@ -57,26 +57,26 @@ def test_fetchone(self, cursor):
self.assertIsNone(cursor.kms_key)
self.assertEqual(cursor.work_group, 'primary')

@with_cursor
@with_cursor()
def test_fetchmany(self, cursor):
cursor.execute('SELECT * FROM many_rows LIMIT 15')
self.assertEqual(len(cursor.fetchmany(10)), 10)
self.assertEqual(len(cursor.fetchmany(10)), 5)

@with_cursor
@with_cursor()
def test_fetchall(self, cursor):
cursor.execute('SELECT * FROM one_row')
self.assertEqual(cursor.fetchall(), [(1,)])
cursor.execute('SELECT a FROM many_rows ORDER BY a')
self.assertEqual(cursor.fetchall(), [(i,) for i in xrange(10000)])

@with_cursor
@with_cursor()
def test_iterator(self, cursor):
cursor.execute('SELECT * FROM one_row')
self.assertEqual(list(cursor), [(1,)])
self.assertRaises(StopIteration, cursor.__next__)

@with_cursor
@with_cursor()
def test_cache_size(self, cursor):
# To test caching, we need to make sure the query is unique, otherwise
# we might accidentally pick up the cache results from another CI run.
Expand All @@ -99,67 +99,67 @@ def test_cache_size(self, cursor):
# When using caching, the same query ID should be returned.
self.assertIn(third_query_id, [first_query_id, second_query_id])

@with_cursor
@with_cursor()
def test_arraysize(self, cursor):
cursor.arraysize = 5
cursor.execute('SELECT * FROM many_rows LIMIT 20')
self.assertEqual(len(cursor.fetchmany()), 5)

@with_cursor
@with_cursor()
def test_arraysize_default(self, cursor):
self.assertEqual(cursor.arraysize, Cursor.DEFAULT_FETCH_SIZE)

@with_cursor
@with_cursor()
def test_invalid_arraysize(self, cursor):
with self.assertRaises(ProgrammingError):
cursor.arraysize = 10000
with self.assertRaises(ProgrammingError):
cursor.arraysize = -1

@with_cursor
@with_cursor()
def test_description(self, cursor):
cursor.execute('SELECT 1 AS foobar FROM one_row')
self.assertEqual(cursor.description,
[('foobar', 'integer', None, None, 10, 0, 'UNKNOWN')])

@with_cursor
@with_cursor()
def test_description_initial(self, cursor):
self.assertIsNone(cursor.description)

@with_cursor
@with_cursor()
def test_description_failed(self, cursor):
try:
cursor.execute('blah_blah')
except DatabaseError:
pass
self.assertIsNone(cursor.description)

@with_cursor
@with_cursor()
def test_bad_query(self, cursor):
def run():
cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
cursor.fetchone()
self.assertRaises(DatabaseError, run)

@with_cursor
@with_cursor()
def test_fetch_no_data(self, cursor):
self.assertRaises(ProgrammingError, cursor.fetchone)
self.assertRaises(ProgrammingError, cursor.fetchmany)
self.assertRaises(ProgrammingError, cursor.fetchall)

@with_cursor
@with_cursor()
def test_null_param(self, cursor):
cursor.execute('SELECT %(param)s FROM one_row', {'param': None})
self.assertEqual(cursor.fetchall(), [(None,)])

@with_cursor
@with_cursor()
def test_no_params(self, cursor):
self.assertRaises(DatabaseError, lambda: cursor.execute(
'SELECT %(param)s FROM one_row'))
self.assertRaises(KeyError, lambda: cursor.execute(
'SELECT %(param)s FROM one_row', {'a': 1}))

@with_cursor
@with_cursor()
def test_contain_special_character_query(self, cursor):
cursor.execute("""
SELECT col_string FROM one_row_complex
Expand All @@ -182,7 +182,7 @@ def test_contain_special_character_query(self, cursor):
""")
self.assertEqual(cursor.fetchall(), [('a string', '%%')])

@with_cursor
@with_cursor()
def test_contain_special_character_query_with_parameter(self, cursor):
self.assertRaises(TypeError, lambda: cursor.execute(
"""
Expand Down Expand Up @@ -211,17 +211,17 @@ def test_escape(self):
bad_str = """`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t """
self.run_escape_case(bad_str)

@with_cursor
@with_cursor()
def run_escape_case(self, cursor, bad_str):
cursor.execute('SELECT %(a)d, %(b)s FROM one_row', {'a': 1, 'b': bad_str})
self.assertEqual(cursor.fetchall(), [(1, bad_str,)])

@with_cursor
@with_cursor()
def test_none_empty_query(self, cursor):
self.assertRaises(ProgrammingError, lambda: cursor.execute(None))
self.assertRaises(ProgrammingError, lambda: cursor.execute(''))

@with_cursor
@with_cursor()
def test_invalid_params(self, cursor):
self.assertRaises(TypeError, lambda: cursor.execute(
'SELECT * FROM one_row', {'foo': {'bar': 1}}))
Expand All @@ -233,21 +233,21 @@ def test_open_close(self):
with conn.cursor():
pass

@with_cursor
@with_cursor()
def test_unicode(self, cursor):
unicode_str = '王兢'
cursor.execute('SELECT %(param)s FROM one_row', {'param': unicode_str})
self.assertEqual(cursor.fetchall(), [(unicode_str,)])

@with_cursor
@with_cursor()
def test_null(self, cursor):
cursor.execute('SELECT null FROM many_rows')
self.assertEqual(cursor.fetchall(), [(None,)] * 10000)
cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows')
self.assertEqual(cursor.fetchall(),
[(None if a % 11 == 0 else a,) for a in xrange(10000)])

@with_cursor
@with_cursor()
def test_query_id(self, cursor):
self.assertIsNone(cursor.query_id)
cursor.execute('SELECT * from one_row')
Expand All @@ -256,14 +256,14 @@ def test_query_id(self, cursor):
r'^[0-9a-f]{8}-[0-9a-f]{4}-[4][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$'
self.assertTrue(re.match(expected_pattern, cursor.query_id))

@with_cursor
@with_cursor()
def test_output_location(self, cursor):
self.assertIsNone(cursor.output_location)
cursor.execute('SELECT * from one_row')
self.assertEqual(cursor.output_location,
'{0}{1}.csv'.format(ENV.s3_staging_dir, cursor.query_id))

@with_cursor
@with_cursor()
def test_query_execution_initial(self, cursor):
self.assertFalse(cursor.has_result_set)
self.assertIsNone(cursor.rownumber)
Expand All @@ -277,7 +277,7 @@ def test_query_execution_initial(self, cursor):
self.assertIsNone(cursor.execution_time_in_millis)
self.assertIsNone(cursor.output_location)

@with_cursor
@with_cursor()
def test_complex(self, cursor):
cursor.execute("""
SELECT
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_complex(self, cursor):
NUMBER,
])

@with_cursor
@with_cursor()
def test_cancel(self, cursor):
def cancel(c):
time.sleep(randint(1, 5))
Expand All @@ -382,7 +382,7 @@ def cancel(c):
CROSS JOIN many_rows b
"""))

@with_cursor
@with_cursor()
def test_cancel_initial(self, cursor):
self.assertRaises(ProgrammingError, cursor.cancel)

Expand Down Expand Up @@ -411,7 +411,7 @@ def test_no_ops(self):
cursor.close()
conn.close()

@with_cursor
@with_cursor()
def test_show_partition(self, cursor):
location = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'partition_table')
Expand All @@ -423,3 +423,8 @@ def test_show_partition(self, cursor):
cursor.execute('SHOW PARTITIONS partition_table')
self.assertEqual(sorted(cursor.fetchall()),
[('b={0}'.format(i),) for i in xrange(10)])

@with_cursor(work_group=WORK_GROUP)
def test_workgroup(self, cursor):
cursor.execute('SELECT * FROM one_row')
self.assertEqual(cursor.work_group, WORK_GROUP)
16 changes: 9 additions & 7 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ def __init__(self):
'Required environment variable `AWS_ATHENA_S3_STAGING_DIR` not found.'


def with_cursor(fn):
@functools.wraps(fn)
def wrapped_fn(self, *args, **kwargs):
with contextlib.closing(self.connect()) as conn:
with conn.cursor() as cursor:
fn(self, cursor, *args, **kwargs)
return wrapped_fn
def with_cursor(work_group=None):
def _with_cursor(fn):
@functools.wraps(fn)
def wrapped_fn(self, *args, **kwargs):
with contextlib.closing(self.connect(work_group=work_group)) as conn:
with conn.cursor() as cursor:
fn(self, cursor, *args, **kwargs)
return wrapped_fn
return _with_cursor


def with_async_cursor(fn):
Expand Down

0 comments on commit 9166710

Please sign in to comment.