Skip to content

Commit

Permalink
Merge pull request #57 from laughingman7743/fix_double_escaping_of_pe…
Browse files Browse the repository at this point in the history
…rcent_character_in_sqla

Fix double escaping of percent character in SQLAlchemy (fix #56)
  • Loading branch information
laughingman7743 authored Nov 26, 2018
2 parents d222366 + 0fdde5a commit 95ee446
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
26 changes: 25 additions & 1 deletion pyathena/sqlalchemy_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from sqlalchemy.engine import reflection
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.exc import NoSuchTableError, OperationalError
from sqlalchemy.sql.compiler import IdentifierPreparer, SQLCompiler
from sqlalchemy.sql.compiler import (BIND_PARAMS, BIND_PARAMS_ESC,
IdentifierPreparer, SQLCompiler)
from sqlalchemy.sql.sqltypes import (BIGINT, BINARY, BOOLEAN, DATE, DECIMAL, FLOAT,
INTEGER, NULLTYPE, STRINGTYPE, TIMESTAMP)
from tenacity import retry_if_exception, stop_after_attempt, wait_exponential
Expand Down Expand Up @@ -39,6 +40,29 @@ class AthenaCompiler(SQLCompiler):
def visit_char_length_func(self, fn, **kw):
return 'length{0}'.format(self.function_argspec(fn, **kw))

def visit_textclause(self, textclause, **kw):
def do_bindparam(m):
name = m.group(1)
if name in textclause._bindparams:
return self.process(textclause._bindparams[name], **kw)
else:
return self.bindparam_string(name, **kw)

if not self.stack:
self.isplaintext = True

if len(textclause._bindparams) == 0:
# Prevents double escaping of percent character
return textclause.text
else:
# un-escape any \:params
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
BIND_PARAMS.sub(
do_bindparam,
self.post_process_text(textclause.text))
)


_TYPE_MAPPINGS = {
'boolean': BOOLEAN,
Expand Down
31 changes: 31 additions & 0 deletions tests/test_sqlalchemy_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,34 @@ def test_get_column_type(self, engine, connection):
self.assertEqual(dialect._get_column_type('map(integer, integer)'), 'map')
self.assertEqual(dialect._get_column_type('row(a integer, b integer)'), 'row')
self.assertEqual(dialect._get_column_type('decimal(10,1)'), 'decimal')

@with_engine
def test_contain_percents_character_query(self, engine, connection):
query = sqlalchemy.sql.text("""
SELECT date_parse('20191030', '%Y%m%d')
""")
result = engine.execute(query)
self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30), )])

@with_engine
def test_query_with_parameter(self, engine, connection):
query = sqlalchemy.sql.text("""
SELECT :word
""")
result = engine.execute(query, word='cat')
self.assertEqual(result.fetchall(), [('cat', )])

@with_engine
def test_contain_percents_character_query_with_parameter(self, engine, connection):
query = sqlalchemy.sql.text("""
SELECT date_parse('20191030', '%Y%m%d'), :word
""")
result = engine.execute(query, word='cat')
self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30), 'cat')])

query = sqlalchemy.sql.text("""
SELECT col_string FROM one_row_complex
WHERE col_string LIKE 'a%' OR col_string LIKE :param
""")
result = engine.execute(query, param='b%')
self.assertEqual(result.fetchall(), [('a string', )])

0 comments on commit 95ee446

Please sign in to comment.