From 891c2e401928ecafea78f7c6c3b453663ef03dce Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Wed, 5 Jul 2023 02:19:47 +0530 Subject: [PATCH] Add Deferrable switch to SnowflakeSqlApiOperator (#31596) --- .../snowflake/hooks/snowflake_sql_api.py | 43 +++++-- .../snowflake/operators/snowflake.py | 46 ++++++- airflow/providers/snowflake/provider.yaml | 5 + .../providers/snowflake/triggers/__init__.py | 16 +++ .../snowflake/triggers/snowflake_trigger.py | 109 ++++++++++++++++ .../operators/snowflake.rst | 2 + .../snowflake/hooks/test_snowflake_sql_api.py | 56 +++++++- .../snowflake/operators/test_snowflake.py | 96 +++++++++++++- .../providers/snowflake/triggers/__init__.py | 16 +++ .../snowflake/triggers/test_snowflake.py | 120 ++++++++++++++++++ 10 files changed, 490 insertions(+), 19 deletions(-) create mode 100644 airflow/providers/snowflake/triggers/__init__.py create mode 100644 airflow/providers/snowflake/triggers/snowflake_trigger.py create mode 100644 tests/providers/snowflake/triggers/__init__.py create mode 100644 tests/providers/snowflake/triggers/test_snowflake.py diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 0d808291ffa0..eec3c7349e1e 100644 --- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -21,6 +21,7 @@ from pathlib import Path from typing import Any +import aiohttp import requests from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -59,7 +60,8 @@ class SnowflakeSqlApiHook(SnowflakeHook): :param session_parameters: You can set session-level parameters at the time you connect to Snowflake :param token_life_time: lifetime of the JWT Token in timedelta - :param token_renewal_delta: Renewal time of the JWT Token in timedelta + :param token_renewal_delta: Renewal time of the JWT Token in timedelta + :param deferrable: Run operator in the deferrable mode. """ LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime @@ -225,17 +227,7 @@ def check_query_output(self, query_ids: list[str]) -> None: f"Response: {e.response.content}, Status Code: {e.response.status_code}" ) - def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: - """ - Based on the query id async HTTP request is made to snowflake SQL API and return response. - - :param query_id: statement handle id for the individual statements. - """ - self.log.info("Retrieving status for query id %s", {query_id}) - header, params, url = self.get_request_url_header_params(query_id) - response = requests.get(url, params=params, headers=header) - status_code = response.status_code - resp = response.json() + def _process_response(self, status_code, resp): self.log.info("Snowflake SQL GET statements status API response: %s", resp) if status_code == 202: return {"status": "running", "message": "Query statements are still running"} @@ -254,3 +246,30 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: } else: return {"status": "error", "message": resp["message"]} + + def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: + """ + Based on the query id async HTTP request is made to snowflake SQL API and return response. + + :param query_id: statement handle id for the individual statements. + """ + self.log.info("Retrieving status for query id %s", query_id) + header, params, url = self.get_request_url_header_params(query_id) + response = requests.get(url, params=params, headers=header) + status_code = response.status_code + resp = response.json() + return self._process_response(status_code, resp) + + async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]: + """ + Based on the query id async HTTP request is made to snowflake SQL API and return response. + + :param query_id: statement handle id for the individual statements. + """ + self.log.info("Retrieving status for query id %s", query_id) + header, params, url = self.get_request_url_header_params(query_id) + async with aiohttp.ClientSession(headers=header) as session: + async with session.get(url, params=params) as response: + status_code = response.status + resp = await response.json() + return self._process_response(status_code, resp) diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index b56b14b0d5a4..db35fa000731 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -20,7 +20,7 @@ import time import warnings from datetime import timedelta -from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, SupportsAbs, cast from airflow import AirflowException from airflow.exceptions import AirflowProviderDeprecationWarning @@ -33,6 +33,7 @@ from airflow.providers.snowflake.hooks.snowflake_sql_api import ( SnowflakeSqlApiHook, ) +from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -430,6 +431,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator): :param bindings: (Optional) Values of bind variables in the SQL statement. When executing the statement, Snowflake replaces placeholders (? and :name) in the statement with these specified values. + :param deferrable: Run operator in the deferrable mode. """ # noqa LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime @@ -450,6 +452,7 @@ def __init__( token_life_time: timedelta = LIFETIME, token_renewal_delta: timedelta = RENEWAL_DELTA, bindings: dict[str, Any] | None = None, + deferrable: bool = False, **kwargs: Any, ) -> None: self.snowflake_conn_id = snowflake_conn_id @@ -459,6 +462,7 @@ def __init__( self.token_renewal_delta = token_renewal_delta self.bindings = bindings self.execute_async = False + self.deferrable = deferrable if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover hook_params = kwargs.pop("hook_params", {}) # pragma: no cover kwargs["hook_params"] = { @@ -482,6 +486,7 @@ def execute(self, context: Context) -> None: snowflake_conn_id=self.snowflake_conn_id, token_life_time=self.token_life_time, token_renewal_delta=self.token_renewal_delta, + deferrable=self.deferrable, ) self.query_ids = self._hook.execute_query( self.sql, statement_count=self.statement_count, bindings=self.bindings # type: ignore[arg-type] @@ -491,10 +496,23 @@ def execute(self, context: Context) -> None: if self.do_xcom_push: context["ti"].xcom_push(key="query_ids", value=self.query_ids) - statement_status = self.poll_on_queries() - if statement_status["error"]: - raise AirflowException(statement_status["error"]) - self._hook.check_query_output(self.query_ids) + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=SnowflakeSqlApiTrigger( + poll_interval=self.poll_interval, + query_ids=self.query_ids, + snowflake_conn_id=self.snowflake_conn_id, + token_life_time=self.token_life_time, + token_renewal_delta=self.token_renewal_delta, + ), + method_name="execute_complete", + ) + else: + statement_status = self.poll_on_queries() + if statement_status["error"]: + raise AirflowException(statement_status["error"]) + self._hook.check_query_output(self.query_ids) def poll_on_queries(self): """Poll on requested queries.""" @@ -517,3 +535,21 @@ def poll_on_queries(self): queries_in_progress.remove(query_id) time.sleep(self.poll_interval) return {"success": statement_success_status, "error": statement_error_status} + + def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event: + if "status" in event and event["status"] == "error": + msg = f"{event['status']}: {event['message']}" + raise AirflowException(msg) + elif "status" in event and event["status"] == "success": + hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id) + query_ids = cast(List[str], event["statement_query_ids"]) + hook.check_query_output(query_ids) + self.log.info("%s completed successfully.", self.task_id) + else: + self.log.info("%s completed successfully.", self.task_id) diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml index 1e68fbddca55..2cea953ab428 100644 --- a/airflow/providers/snowflake/provider.yaml +++ b/airflow/providers/snowflake/provider.yaml @@ -100,3 +100,8 @@ transfers: connection-types: - hook-class-name: airflow.providers.snowflake.hooks.snowflake.SnowflakeHook connection-type: snowflake + +triggers: + - integration-name: Snowflake + python-modules: + - airflow.providers.snowflake.triggers.snowflake_trigger diff --git a/airflow/providers/snowflake/triggers/__init__.py b/airflow/providers/snowflake/triggers/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/snowflake/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/snowflake/triggers/snowflake_trigger.py b/airflow/providers/snowflake/triggers/snowflake_trigger.py new file mode 100644 index 000000000000..4f1e0cffb299 --- /dev/null +++ b/airflow/providers/snowflake/triggers/snowflake_trigger.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from datetime import timedelta +from typing import Any, AsyncIterator + +from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class SnowflakeSqlApiTrigger(BaseTrigger): + """ + Fetch the status for the query ids passed. + + :param poll_interval: polling period in seconds to check for the status + :param query_ids: List of Query ids to run and poll for the status + :param snowflake_conn_id: Reference to Snowflake connection id + :param token_life_time: lifetime of the JWT Token in timedelta + :param token_renewal_delta: Renewal time of the JWT Token in timedelta + """ + + def __init__( + self, + poll_interval: float, + query_ids: list[str], + snowflake_conn_id: str, + token_life_time: timedelta, + token_renewal_delta: timedelta, + ): + super().__init__() + self.poll_interval = poll_interval + self.query_ids = query_ids + self.snowflake_conn_id = snowflake_conn_id + self.token_life_time = token_life_time + self.token_renewal_delta = token_renewal_delta + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes SnowflakeSqlApiTrigger arguments and classpath.""" + return ( + "airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger", + { + "poll_interval": self.poll_interval, + "query_ids": self.query_ids, + "snowflake_conn_id": self.snowflake_conn_id, + "token_life_time": self.token_life_time, + "token_renewal_delta": self.token_renewal_delta, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Wait for the query the snowflake query to complete.""" + SnowflakeSqlApiHook( + self.snowflake_conn_id, + self.token_life_time, + self.token_renewal_delta, + ) + try: + statement_query_ids: list[str] = [] + for query_id in self.query_ids: + while True: + statement_status = await self.get_query_status(query_id) + if statement_status["status"] not in ["running"]: + break + await asyncio.sleep(self.poll_interval) + if statement_status["status"] == "error": + yield TriggerEvent(statement_status) + return + if statement_status["status"] == "success": + statement_query_ids.extend(statement_status["statement_handles"]) + yield TriggerEvent( + { + "status": "success", + "statement_query_ids": statement_query_ids, + } + ) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + + async def get_query_status(self, query_id: str) -> dict[str, Any]: + """ + Async function to check whether the query statement submitted via SQL API is still + running state and returns True if it is still running else + return False. + """ + hook = SnowflakeSqlApiHook( + self.snowflake_conn_id, + self.token_life_time, + self.token_renewal_delta, + ) + return await hook.get_sql_api_query_status_async(query_id) + + def _set_context(self, context): + pass diff --git a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst index 1e80f3af2955..d3ffbec5a4cf 100644 --- a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst +++ b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst @@ -66,6 +66,8 @@ SnowflakeSqlApiOperator Use the :class:`SnowflakeSqlApiHook ` to execute SQL commands in a `Snowflake `__ database. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. +This will ensure that the task is deferred from the Airflow worker slot and polling for the task status happens on the trigger. Using the Operator ^^^^^^^^^^^^^^^^^^ diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py index 61e88de864df..fd2da72c928e 100644 --- a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py +++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py @@ -21,6 +21,7 @@ from pathlib import Path from typing import Any from unittest import mock +from unittest.mock import AsyncMock import pytest import requests @@ -396,7 +397,6 @@ def test_get_private_key_should_support_private_auth_with_unencrypted_key( ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."): SnowflakeSqlApiHook(snowflake_conn_id="test_conn").get_private_key() - @pytest.mark.asyncio @pytest.mark.parametrize( "status_code,response,expected_response", [ @@ -456,3 +456,57 @@ def json(self): mock_requests.get.return_value = MockResponse(status_code, response) hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") assert hook.get_sql_api_query_status("uuid") == expected_response + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "status_code,response,expected_response", + [ + ( + 200, + { + "status": "success", + "message": "Statement executed successfully.", + "statementHandle": "uuid", + }, + { + "status": "success", + "message": "Statement executed successfully.", + "statement_handles": ["uuid"], + }, + ), + ( + 200, + { + "status": "success", + "message": "Statement executed successfully.", + "statementHandles": ["uuid", "uuid1"], + }, + { + "status": "success", + "message": "Statement executed successfully.", + "statement_handles": ["uuid", "uuid1"], + }, + ), + (202, {}, {"status": "running", "message": "Query statements are still running"}), + (422, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}), + (404, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}), + ], + ) + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." + "get_request_url_header_params" + ) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") + async def test_get_sql_api_query_status_async( + self, mock_get, mock_geturl_header_params, status_code, response, expected_response + ): + """Test Async get_sql_api_query_status_async function by mocking the status, + response and expected response""" + req_id = uuid.uuid4() + params = {"requestId": str(req_id), "page": 2, "pageSize": 10} + mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/" + mock_get.return_value.__aenter__.return_value.status = status_code + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=response) + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + response = await hook.get_sql_api_query_status_async("uuid") + assert response == expected_response diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 8f32c6e62d8d..41cbfe671746 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -19,10 +19,13 @@ from unittest import mock +import pendulum import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models.dag import DAG +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance from airflow.providers.snowflake.operators.snowflake import ( SnowflakeCheckOperator, SnowflakeIntervalCheckOperator, @@ -30,7 +33,9 @@ SnowflakeSqlApiOperator, SnowflakeValueCheckOperator, ) +from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger from airflow.utils import timezone +from airflow.utils.types import DagRunType DEFAULT_DATE = timezone.datetime(2015, 1, 1) DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() @@ -88,6 +93,34 @@ def test_get_db_hook( mock_get_db_hook.assert_called_once() +def create_context(task, dag=None): + if dag is None: + dag = DAG(dag_id="dag") + tzinfo = pendulum.timezone("UTC") + execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo) + dag_run = DagRun( + dag_id=dag.dag_id, + execution_date=execution_date, + run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), + ) + + task_instance = TaskInstance(task=task) + task_instance.dag_run = dag_run + task_instance.xcom_push = mock.Mock() + return { + "dag": dag, + "ts": execution_date.isoformat(), + "task": task, + "ti": task_instance, + "task_instance": task_instance, + "run_id": dag_run.run_id, + "dag_run": dag_run, + "execution_date": execution_date, + "data_interval_end": execution_date, + "logical_date": execution_date, + } + + class TestSnowflakeSqlApiOperator: @pytest.fixture def mock_execute_query(self): @@ -142,3 +175,64 @@ def test_snowflake_sql_api_to_fails_when_one_query_fails( mock_get_sql_api_query_status.side_effect = [{"status": "error"}, {"status": "success"}] with pytest.raises(AirflowException): operator.execute(context=None) + + @pytest.mark.parametrize("mock_sql, statement_count", [(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)]) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query") + def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, statement_count): + """ + Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired + when the SnowflakeSqlApiOperator is executed. + """ + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + sql=mock_sql, + statement_count=statement_count, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + operator.execute(create_context(operator)) + + assert isinstance( + exc.value.trigger, SnowflakeSqlApiTrigger + ), "Trigger is not a SnowflakeSqlApiTrigger" + + def test_snowflake_sql_api_execute_complete_failure(self): + """Test SnowflakeSqlApiOperator raise AirflowException of error event""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + deferrable=True, + ) + with pytest.raises(AirflowException): + operator.execute_complete( + context=None, + event={"status": "error", "message": "Test failure message", "type": "FAILED_WITH_ERROR"}, + ) + + @pytest.mark.parametrize( + "mock_event", + [ + None, + ({"status": "success", "statement_query_ids": ["uuid", "uuid"]}), + ], + ) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.check_query_output") + def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event): + """Tests execute_complete assert with successful message""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + deferrable=True, + ) + + with mock.patch.object(operator.log, "info") as mock_log_info: + operator.execute_complete(context=None, event=mock_event) + mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) diff --git a/tests/providers/snowflake/triggers/__init__.py b/tests/providers/snowflake/triggers/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/providers/snowflake/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/snowflake/triggers/test_snowflake.py b/tests/providers/snowflake/triggers/test_snowflake.py new file mode 100644 index 000000000000..9fc14591625e --- /dev/null +++ b/tests/providers/snowflake/triggers/test_snowflake.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from datetime import timedelta +from unittest import mock + +import pytest + +from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger +from airflow.triggers.base import TriggerEvent + +TASK_ID = "snowflake_check" +POLL_INTERVAL = 1.0 +LIFETIME = timedelta(minutes=59) +RENEWAL_DELTA = timedelta(minutes=54) +MODULE = "airflow.providers.snowflake" + + +class TestSnowflakeSqlApiTrigger: + TRIGGER = SnowflakeSqlApiTrigger( + poll_interval=POLL_INTERVAL, + query_ids=["uuid"], + snowflake_conn_id="test_conn", + token_life_time=LIFETIME, + token_renewal_delta=RENEWAL_DELTA, + ) + + def test_snowflake_sql_trigger_serialization(self): + """ + Asserts that the SnowflakeSqlApiTrigger correctly serializes its arguments + and classpath. + """ + classpath, kwargs = self.TRIGGER.serialize() + assert classpath == "airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger" + assert kwargs == { + "poll_interval": POLL_INTERVAL, + "query_ids": ["uuid"], + "snowflake_conn_id": "test_conn", + "token_life_time": LIFETIME, + "token_renewal_delta": RENEWAL_DELTA, + } + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.snowflake_trigger.SnowflakeSqlApiTrigger.get_query_status") + @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async") + async def test_snowflake_sql_trigger_running( + self, mock_get_sql_api_query_status_async, mock_get_query_status + ): + """Tests that the SnowflakeSqlApiTrigger in running by mocking get_query_status to true""" + mock_get_query_status.return_value = {"status": "running"} + + task = asyncio.create_task(self.TRIGGER.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.snowflake_trigger.SnowflakeSqlApiTrigger.get_query_status") + @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async") + async def test_snowflake_sql_trigger_completed( + self, mock_get_sql_api_query_status_async, mock_get_query_status + ): + """ + Test SnowflakeSqlApiTrigger run method with success status and mock the get_sql_api_query_status + result and get_query_status to False. + """ + mock_get_query_status.return_value = {"status": "success", "statement_handles": ["uuid", "uuid1"]} + statement_query_ids = ["uuid", "uuid1"] + mock_get_sql_api_query_status_async.return_value = { + "message": "Statement executed successfully.", + "status": "success", + "statement_handles": statement_query_ids, + } + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success", "statement_query_ids": statement_query_ids}) == actual + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async") + async def test_snowflake_sql_trigger_failure_status(self, mock_get_sql_api_query_status_async): + """Test SnowflakeSqlApiTrigger task is executed and triggered with failure status.""" + mock_response = { + "status": "error", + "message": "An error occurred when executing the statement. Check " + "the error code and error message for details", + } + mock_get_sql_api_query_status_async.return_value = mock_response + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + assert TriggerEvent(mock_response) == actual + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async") + async def test_snowflake_sql_trigger_exception(self, mock_get_sql_api_query_status_async): + """Tests the SnowflakeSqlApiTrigger does not fire if there is an exception.""" + mock_get_sql_api_query_status_async.side_effect = Exception("Test exception") + + task = [i async for i in self.TRIGGER.run()] + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task