From 72c3ed0a271a56f52fae8794acc28a5a71bb23df Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 29 Feb 2020 15:44:18 +0900 Subject: [PATCH] Fixed parameter format of Decimal type (fix #121) --- pyathena/formatter.py | 8 ++++++-- tests/test_cursor.py | 5 +++++ tests/test_formatter.py | 22 +++++++++++----------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pyathena/formatter.py b/pyathena/formatter.py index bb813763..4848a4a7 100644 --- a/pyathena/formatter.py +++ b/pyathena/formatter.py @@ -77,7 +77,11 @@ def _format_seq(formatter, escaper, val): else: formatted = '{0}'.format(formatted) results.append(formatted) - return '({0})'.format(','.join(results)) + return '({0})'.format(', '.join(results)) + + +def _format_decimal(formatter, escaper, val): + return "DECIMAL {0}".format(escaper('{0:f}'.format(val))) _DEFAULT_FORMATTERS = { @@ -87,7 +91,7 @@ def _format_seq(formatter, escaper, val): int: _format_default, float: _format_default, long: _format_default, - Decimal: _format_default, + Decimal: _format_decimal, bool: _format_bool, str: _format_str, unicode: _format_str, diff --git a/tests/test_cursor.py b/tests/test_cursor.py index a033e79b..06b3a0ca 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -236,6 +236,11 @@ def test_unicode(self, cursor): cursor.execute('SELECT %(param)s FROM one_row', {'param': unicode_str}) self.assertEqual(cursor.fetchall(), [(unicode_str,)]) + @with_cursor() + def test_decimal(self, cursor): + cursor.execute('SELECT %(decimal)s', {'decimal': Decimal('0.00000000001')}) + self.assertEqual(cursor.fetchall(), [(Decimal('0.00000000001'),)]) + @with_cursor() def test_null(self, cursor): cursor.execute('SELECT null FROM many_rows') diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 7214bd41..67156138 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -119,13 +119,13 @@ def test_format_decimal(self): expected = """ SELECT * FROM test_table - WHERE col_decimal <= 0.0000000001 + WHERE col_decimal <= DECIMAL '0.0000000001' """.strip() actual = self.format(""" SELECT * FROM test_table - WHERE col_decimal <= %(param).10f + WHERE col_decimal <= %(param)s """, {'param': Decimal('0.0000000001')}) self.assertEqual(actual, expected) @@ -175,7 +175,7 @@ def test_format_none_list(self): expected = """ SELECT * FROM test_table - WHERE col IN (null,null) + WHERE col IN (null, null) """.strip() actual = self.format(""" @@ -190,7 +190,7 @@ def test_format_datetime_list(self): SELECT * FROM test_table WHERE col_timestamp IN - (timestamp'2017-01-01 12:00:00.000',timestamp'2017-01-02 06:00:00.000') + (timestamp'2017-01-01 12:00:00.000', timestamp'2017-01-02 06:00:00.000') """.strip() actual = self.format(""" @@ -205,7 +205,7 @@ def test_format_date_list(self): expected = """ SELECT * FROM test_table - WHERE col_date IN (date'2017-01-01',date'2017-01-02') + WHERE col_date IN (date'2017-01-01', date'2017-01-02') """.strip() actual = self.format(""" @@ -219,7 +219,7 @@ def test_format_int_list(self): expected = """ SELECT * FROM test_table - WHERE col_int IN (1,2) + WHERE col_int IN (1, 2) """.strip() actual = self.format(""" @@ -234,7 +234,7 @@ def test_format_float_list(self): expected = """ SELECT * FROM test_table - WHERE col_float IN (0.100000,0.200000) + WHERE col_float IN (0.100000, 0.200000) """.strip() actual = self.format(""" @@ -248,7 +248,7 @@ def test_format_decimal_list(self): expected = """ SELECT * FROM test_table - WHERE col_decimal IN (0.0000000001,99.9999999999) + WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999') """.strip() actual = self.format(""" @@ -262,7 +262,7 @@ def test_format_bool_list(self): expected = """ SELECT * FROM test_table - WHERE col_boolean IN (True,False) + WHERE col_boolean IN (True, False) """.strip() actual = self.format(""" @@ -276,7 +276,7 @@ def test_format_str_list(self): expected = """ SELECT * FROM test_table - WHERE col_string IN ('amazon','athena') + WHERE col_string IN ('amazon', 'athena') """.strip() actual = self.format(""" @@ -290,7 +290,7 @@ def test_format_unicode_list(self): expected = """ SELECT * FROM test_table - WHERE col_string IN ('密林','女神') + WHERE col_string IN ('密林', '女神') """.strip() actual = self.format("""