Skip to content

Commit

Permalink
Make optional output_location attribute in AthenaOperator (#35265)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Oct 30, 2023
1 parent 6f112cf commit ba4b55a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
7 changes: 3 additions & 4 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,18 @@ def run_query(
client_request_token: str | None = None,
workgroup: str = "primary",
) -> str:
"""Run a Presto query on Athena with provided config.
"""Run a Trino/Presto query on Athena with provided config.
.. seealso::
- :external+boto3:py:meth:`Athena.Client.start_query_execution`
:param query: Presto query to run.
:param query: Trino/Presto query to run.
:param query_context: Context in which query need to be run.
:param result_configuration: Dict with path to store results in and
config related to encryption.
:param client_request_token: Unique token created by user to avoid
multiple executions of same query.
:param workgroup: Athena workgroup name, when not specified, will be
``'primary'``.
:param workgroup: Athena workgroup name, when not specified, will be ``'primary'``.
:return: Submitted query execution ID.
"""
params = {
Expand Down
9 changes: 7 additions & 2 deletions airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
:param database: Database to select. (templated)
:param catalog: Catalog to select. (templated)
:param output_location: s3 path to write the query results into. (templated)
To run the query, you must specify the query results location using one of the ways:
either for individual queries using either this setting (client-side),
or in the workgroup, using WorkGroupConfiguration.
If none of them is set, Athena issues an error that no output location is provided
:param client_request_token: Unique token created by user to avoid multiple executions of same query
:param workgroup: Athena workgroup in which query will be run. (templated)
:param query_execution_context: Context in which query need to be run
Expand Down Expand Up @@ -79,7 +83,7 @@ def __init__(
*,
query: str,
database: str,
output_location: str,
output_location: str | None = None,
client_request_token: str | None = None,
workgroup: str = "primary",
query_execution_context: dict[str, str] | None = None,
Expand Down Expand Up @@ -114,7 +118,8 @@ def execute(self, context: Context) -> str | None:
"""Run Trino/Presto Query on Amazon Athena."""
self.query_execution_context["Database"] = self.database
self.query_execution_context["Catalog"] = self.catalog
self.result_configuration["OutputLocation"] = self.output_location
if self.output_location:
self.result_configuration["OutputLocation"] = self.output_location
self.query_execution_id = self.hook.run_query(
self.query,
self.query_execution_context,
Expand Down
20 changes: 18 additions & 2 deletions tests/providers/amazon/aws/operators/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,13 @@ def setup_method(self):
task_id="test_athena_operator",
query="SELECT * FROM TEST_TABLE",
database="TEST_DATABASE",
output_location="s3://test_s3_bucket/",
client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
sleep_time=0,
max_polling_attempts=3,
)
self.athena = AthenaOperator(**self.default_op_kwargs, aws_conn_id=None, dag=self.dag)
self.athena = AthenaOperator(
**self.default_op_kwargs, output_location="s3://test_s3_bucket/", aws_conn_id=None, dag=self.dag
)

def test_base_aws_op_attributes(self):
op = AthenaOperator(**self.default_op_kwargs)
Expand Down Expand Up @@ -201,6 +202,21 @@ def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status):

assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID

@mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_optional_output_location(self, mock_conn, mock_run_query, mock_check_query_status):
op = AthenaOperator(**self.default_op_kwargs, aws_conn_id=None)

op.execute({})
mock_run_query.assert_called_once_with(
MOCK_DATA["query"],
query_context,
{}, # Should be an empty dict since we do not provide output_location
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)

@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
def test_is_deferred(self, mock_run_query):
self.athena.deferrable = True
Expand Down

0 comments on commit ba4b55a

Please sign in to comment.