Skip to content

Commit

Permalink
Add Deferrable switch to SnowflakeSqlApiOperator (#31596)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 committed Jul 4, 2023
1 parent 8a6766f commit 891c2e4
Show file tree
Hide file tree
Showing 10 changed files with 490 additions and 19 deletions.
43 changes: 31 additions & 12 deletions airflow/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pathlib import Path
from typing import Any

import aiohttp
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand Down Expand Up @@ -59,7 +60,8 @@ class SnowflakeSqlApiHook(SnowflakeHook):
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:param token_life_time: lifetime of the JWT Token in timedelta
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
:param deferrable: Run operator in the deferrable mode.
"""

LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime
Expand Down Expand Up @@ -225,17 +227,7 @@ def check_query_output(self, query_ids: list[str]) -> None:
f"Response: {e.response.content}, Status Code: {e.response.status_code}"
)

def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
"""
Based on the query id async HTTP request is made to snowflake SQL API and return response.
:param query_id: statement handle id for the individual statements.
"""
self.log.info("Retrieving status for query id %s", {query_id})
header, params, url = self.get_request_url_header_params(query_id)
response = requests.get(url, params=params, headers=header)
status_code = response.status_code
resp = response.json()
def _process_response(self, status_code, resp):
self.log.info("Snowflake SQL GET statements status API response: %s", resp)
if status_code == 202:
return {"status": "running", "message": "Query statements are still running"}
Expand All @@ -254,3 +246,30 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
}
else:
return {"status": "error", "message": resp["message"]}

def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
"""
Based on the query id async HTTP request is made to snowflake SQL API and return response.
:param query_id: statement handle id for the individual statements.
"""
self.log.info("Retrieving status for query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id)
response = requests.get(url, params=params, headers=header)
status_code = response.status_code
resp = response.json()
return self._process_response(status_code, resp)

async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
"""
Based on the query id async HTTP request is made to snowflake SQL API and return response.
:param query_id: statement handle id for the individual statements.
"""
self.log.info("Retrieving status for query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id)
async with aiohttp.ClientSession(headers=header) as session:
async with session.get(url, params=params) as response:
status_code = response.status
resp = await response.json()
return self._process_response(status_code, resp)
46 changes: 41 additions & 5 deletions airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, SupportsAbs, cast

from airflow import AirflowException
from airflow.exceptions import AirflowProviderDeprecationWarning
Expand All @@ -33,6 +33,7 @@
from airflow.providers.snowflake.hooks.snowflake_sql_api import (
SnowflakeSqlApiHook,
)
from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -430,6 +431,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
:param bindings: (Optional) Values of bind variables in the SQL statement.
When executing the statement, Snowflake replaces placeholders (? and :name) in
the statement with these specified values.
:param deferrable: Run operator in the deferrable mode.
""" # noqa

LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime
Expand All @@ -450,6 +452,7 @@ def __init__(
token_life_time: timedelta = LIFETIME,
token_renewal_delta: timedelta = RENEWAL_DELTA,
bindings: dict[str, Any] | None = None,
deferrable: bool = False,
**kwargs: Any,
) -> None:
self.snowflake_conn_id = snowflake_conn_id
Expand All @@ -459,6 +462,7 @@ def __init__(
self.token_renewal_delta = token_renewal_delta
self.bindings = bindings
self.execute_async = False
self.deferrable = deferrable
if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover
hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
kwargs["hook_params"] = {
Expand All @@ -482,6 +486,7 @@ def execute(self, context: Context) -> None:
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
deferrable=self.deferrable,
)
self.query_ids = self._hook.execute_query(
self.sql, statement_count=self.statement_count, bindings=self.bindings # type: ignore[arg-type]
Expand All @@ -491,10 +496,23 @@ def execute(self, context: Context) -> None:
if self.do_xcom_push:
context["ti"].xcom_push(key="query_ids", value=self.query_ids)

statement_status = self.poll_on_queries()
if statement_status["error"]:
raise AirflowException(statement_status["error"])
self._hook.check_query_output(self.query_ids)
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=SnowflakeSqlApiTrigger(
poll_interval=self.poll_interval,
query_ids=self.query_ids,
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
),
method_name="execute_complete",
)
else:
statement_status = self.poll_on_queries()
if statement_status["error"]:
raise AirflowException(statement_status["error"])
self._hook.check_query_output(self.query_ids)

def poll_on_queries(self):
"""Poll on requested queries."""
Expand All @@ -517,3 +535,21 @@ def poll_on_queries(self):
queries_in_progress.remove(query_id)
time.sleep(self.poll_interval)
return {"success": statement_success_status, "error": statement_error_status}

def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
query_ids = cast(List[str], event["statement_query_ids"])
hook.check_query_output(query_ids)
self.log.info("%s completed successfully.", self.task_id)
else:
self.log.info("%s completed successfully.", self.task_id)
5 changes: 5 additions & 0 deletions airflow/providers/snowflake/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,8 @@ transfers:
connection-types:
- hook-class-name: airflow.providers.snowflake.hooks.snowflake.SnowflakeHook
connection-type: snowflake

triggers:
- integration-name: Snowflake
python-modules:
- airflow.providers.snowflake.triggers.snowflake_trigger
16 changes: 16 additions & 0 deletions airflow/providers/snowflake/triggers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
109 changes: 109 additions & 0 deletions airflow/providers/snowflake/triggers/snowflake_trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from datetime import timedelta
from typing import Any, AsyncIterator

from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class SnowflakeSqlApiTrigger(BaseTrigger):
"""
Fetch the status for the query ids passed.
:param poll_interval: polling period in seconds to check for the status
:param query_ids: List of Query ids to run and poll for the status
:param snowflake_conn_id: Reference to Snowflake connection id
:param token_life_time: lifetime of the JWT Token in timedelta
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
"""

def __init__(
self,
poll_interval: float,
query_ids: list[str],
snowflake_conn_id: str,
token_life_time: timedelta,
token_renewal_delta: timedelta,
):
super().__init__()
self.poll_interval = poll_interval
self.query_ids = query_ids
self.snowflake_conn_id = snowflake_conn_id
self.token_life_time = token_life_time
self.token_renewal_delta = token_renewal_delta

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes SnowflakeSqlApiTrigger arguments and classpath."""
return (
"airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger",
{
"poll_interval": self.poll_interval,
"query_ids": self.query_ids,
"snowflake_conn_id": self.snowflake_conn_id,
"token_life_time": self.token_life_time,
"token_renewal_delta": self.token_renewal_delta,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Wait for the query the snowflake query to complete."""
SnowflakeSqlApiHook(
self.snowflake_conn_id,
self.token_life_time,
self.token_renewal_delta,
)
try:
statement_query_ids: list[str] = []
for query_id in self.query_ids:
while True:
statement_status = await self.get_query_status(query_id)
if statement_status["status"] not in ["running"]:
break
await asyncio.sleep(self.poll_interval)
if statement_status["status"] == "error":
yield TriggerEvent(statement_status)
return
if statement_status["status"] == "success":
statement_query_ids.extend(statement_status["statement_handles"])
yield TriggerEvent(
{
"status": "success",
"statement_query_ids": statement_query_ids,
}
)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})

async def get_query_status(self, query_id: str) -> dict[str, Any]:
"""
Async function to check whether the query statement submitted via SQL API is still
running state and returns True if it is still running else
return False.
"""
hook = SnowflakeSqlApiHook(
self.snowflake_conn_id,
self.token_life_time,
self.token_renewal_delta,
)
return await hook.get_sql_api_query_status_async(query_id)

def _set_context(self, context):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ SnowflakeSqlApiOperator
Use the :class:`SnowflakeSqlApiHook <airflow.providers.snowflake.operators.snowflake>` to execute
SQL commands in a `Snowflake <https://docs.snowflake.com/en/>`__ database.

You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``.
This will ensure that the task is deferred from the Airflow worker slot and polling for the task status happens on the trigger.

Using the Operator
^^^^^^^^^^^^^^^^^^
Expand Down
56 changes: 55 additions & 1 deletion tests/providers/snowflake/hooks/test_snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pathlib import Path
from typing import Any
from unittest import mock
from unittest.mock import AsyncMock

import pytest
import requests
Expand Down Expand Up @@ -396,7 +397,6 @@ def test_get_private_key_should_support_private_auth_with_unencrypted_key(
), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
SnowflakeSqlApiHook(snowflake_conn_id="test_conn").get_private_key()

@pytest.mark.asyncio
@pytest.mark.parametrize(
"status_code,response,expected_response",
[
Expand Down Expand Up @@ -456,3 +456,57 @@ def json(self):
mock_requests.get.return_value = MockResponse(status_code, response)
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
assert hook.get_sql_api_query_status("uuid") == expected_response

@pytest.mark.asyncio
@pytest.mark.parametrize(
"status_code,response,expected_response",
[
(
200,
{
"status": "success",
"message": "Statement executed successfully.",
"statementHandle": "uuid",
},
{
"status": "success",
"message": "Statement executed successfully.",
"statement_handles": ["uuid"],
},
),
(
200,
{
"status": "success",
"message": "Statement executed successfully.",
"statementHandles": ["uuid", "uuid1"],
},
{
"status": "success",
"message": "Statement executed successfully.",
"statement_handles": ["uuid", "uuid1"],
},
),
(202, {}, {"status": "running", "message": "Query statements are still running"}),
(422, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}),
(404, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}),
],
)
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
"get_request_url_header_params"
)
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get")
async def test_get_sql_api_query_status_async(
self, mock_get, mock_geturl_header_params, status_code, response, expected_response
):
"""Test Async get_sql_api_query_status_async function by mocking the status,
response and expected response"""
req_id = uuid.uuid4()
params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
mock_get.return_value.__aenter__.return_value.status = status_code
mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=response)
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
response = await hook.get_sql_api_query_status_async("uuid")
assert response == expected_response
Loading

0 comments on commit 891c2e4

Please sign in to comment.