Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Athena parameterized queries when paramstyle is qmark (fix #545) #557

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyathena/arrow/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]:
if self._unload:
Expand All @@ -125,6 +126,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
return (
query_id,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/arrow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> ArrowCursor:
self._reset_state()
Expand All @@ -129,6 +130,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
Expand Down
2 changes: 2 additions & 0 deletions pyathena/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> Tuple[str, "Future[Union[AthenaResultSet, Any]]"]:
query_id = self._execute(
Expand All @@ -115,6 +116,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
return query_id, self._executor.submit(self._collect_result_set, query_id)

Expand Down
15 changes: 13 additions & 2 deletions pyathena/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import pyathena
from pyathena.converter import Converter, DefaultTypeConverter
from pyathena.error import DatabaseError, OperationalError, ProgrammingError
from pyathena.formatter import Formatter
Expand Down Expand Up @@ -144,6 +145,7 @@ def _build_start_query_execution_request(
s3_staging_dir: Optional[str] = None,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
execution_parameters: Optional[List[str]] = None,
) -> Dict[str, Any]:
request: Dict[str, Any] = {
"QueryString": query,
Expand Down Expand Up @@ -177,6 +179,8 @@ def _build_start_query_execution_request(
else self._result_reuse_minutes,
}
request["ResultReuseConfiguration"] = {"ResultReuseByAgeConfiguration": reuse_conf}
if execution_parameters:
request["ExecutionParameters"] = execution_parameters
return request

def _build_start_calculation_execution_request(
Expand Down Expand Up @@ -546,15 +550,21 @@ def _find_previous_query_id(
def _execute(
self,
operation: str,
parameters: Optional[Dict[str, Any]] = None,
parameters: Optional[Union[Dict[str, Any], List[str]]] = None,
work_group: Optional[str] = None,
s3_staging_dir: Optional[str] = None,
cache_size: Optional[int] = 0,
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
) -> str:
query = self._formatter.format(operation, parameters)
if pyathena.paramstyle == "qmark" or paramstyle == "qmark":
query = operation
execution_parameters = cast(Optional[List[str]], parameters)
else:
query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters))
execution_parameters = None
_logger.debug(query)

request = self._build_start_query_execution_request(
Expand All @@ -563,6 +573,7 @@ def _execute(
s3_staging_dir=s3_staging_dir,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
execution_parameters=execution_parameters,
)
query_id = self._find_previous_query_id(
query,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def execute(
cache_expiration_time: int = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> Cursor:
self._reset_state()
Expand All @@ -94,6 +95,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
Expand Down
2 changes: 2 additions & 0 deletions pyathena/pandas/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
keep_default_na: bool = False,
na_values: Optional[Iterable[str]] = ("",),
quoting: int = 1,
Expand All @@ -138,6 +139,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
return (
query_id,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/pandas/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
keep_default_na: bool = False,
na_values: Optional[Iterable[str]] = ("",),
quoting: int = 1,
Expand All @@ -154,6 +155,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
Expand Down