Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed parameter format of Decimal type (fix #121) #123

Merged
merged 1 commit into from
Feb 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyathena/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ def _format_seq(formatter, escaper, val):
else:
formatted = '{0}'.format(formatted)
results.append(formatted)
return '({0})'.format(','.join(results))
return '({0})'.format(', '.join(results))


def _format_decimal(formatter, escaper, val):
return "DECIMAL {0}".format(escaper('{0:f}'.format(val)))


_DEFAULT_FORMATTERS = {
Expand All @@ -87,7 +91,7 @@ def _format_seq(formatter, escaper, val):
int: _format_default,
float: _format_default,
long: _format_default,
Decimal: _format_default,
Decimal: _format_decimal,
bool: _format_bool,
str: _format_str,
unicode: _format_str,
Expand Down
5 changes: 5 additions & 0 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ def test_unicode(self, cursor):
cursor.execute('SELECT %(param)s FROM one_row', {'param': unicode_str})
self.assertEqual(cursor.fetchall(), [(unicode_str,)])

@with_cursor()
def test_decimal(self, cursor):
cursor.execute('SELECT %(decimal)s', {'decimal': Decimal('0.00000000001')})
self.assertEqual(cursor.fetchall(), [(Decimal('0.00000000001'),)])

@with_cursor()
def test_null(self, cursor):
cursor.execute('SELECT null FROM many_rows')
Expand Down
22 changes: 11 additions & 11 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ def test_format_decimal(self):
expected = """
SELECT *
FROM test_table
WHERE col_decimal <= 0.0000000001
WHERE col_decimal <= DECIMAL '0.0000000001'
""".strip()

actual = self.format("""
SELECT *
FROM test_table
WHERE col_decimal <= %(param).10f
WHERE col_decimal <= %(param)s
""", {'param': Decimal('0.0000000001')})
self.assertEqual(actual, expected)

Expand Down Expand Up @@ -175,7 +175,7 @@ def test_format_none_list(self):
expected = """
SELECT *
FROM test_table
WHERE col IN (null,null)
WHERE col IN (null, null)
""".strip()

actual = self.format("""
Expand All @@ -190,7 +190,7 @@ def test_format_datetime_list(self):
SELECT *
FROM test_table
WHERE col_timestamp IN
(timestamp'2017-01-01 12:00:00.000',timestamp'2017-01-02 06:00:00.000')
(timestamp'2017-01-01 12:00:00.000', timestamp'2017-01-02 06:00:00.000')
""".strip()

actual = self.format("""
Expand All @@ -205,7 +205,7 @@ def test_format_date_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_date IN (date'2017-01-01',date'2017-01-02')
WHERE col_date IN (date'2017-01-01', date'2017-01-02')
""".strip()

actual = self.format("""
Expand All @@ -219,7 +219,7 @@ def test_format_int_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_int IN (1,2)
WHERE col_int IN (1, 2)
""".strip()

actual = self.format("""
Expand All @@ -234,7 +234,7 @@ def test_format_float_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_float IN (0.100000,0.200000)
WHERE col_float IN (0.100000, 0.200000)
""".strip()

actual = self.format("""
Expand All @@ -248,7 +248,7 @@ def test_format_decimal_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_decimal IN (0.0000000001,99.9999999999)
WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999')
""".strip()

actual = self.format("""
Expand All @@ -262,7 +262,7 @@ def test_format_bool_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_boolean IN (True,False)
WHERE col_boolean IN (True, False)
""".strip()

actual = self.format("""
Expand All @@ -276,7 +276,7 @@ def test_format_str_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_string IN ('amazon','athena')
WHERE col_string IN ('amazon', 'athena')
""".strip()

actual = self.format("""
Expand All @@ -290,7 +290,7 @@ def test_format_unicode_list(self):
expected = """
SELECT *
FROM test_table
WHERE col_string IN ('密林','女神')
WHERE col_string IN ('密林', '女神')
""".strip()

actual = self.format("""
Expand Down