Implement Async Snowflake SQL API Operator
Implemented an Async Snowflake SQL API Operator. To support running multiple SQL statements sequentially, which is the behavior of the SnowflakeOperator, the Snowflake SQL API allows submitting multiple SQL statements in a single request. In combination with aiohttp, this makes it possible to create a SnowflakeSQLOperatorAsync that matches the query submission behavior of the SnowflakeOperator. Also added test cases for the Snowflake SQL API trigger and operator.
1 parent 25f4b8a, commit 1a84b0d
Showing 7 changed files with 991 additions and 0 deletions.
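For context, a rough sketch of the kind of request body the Snowflake SQL API accepts when several statements are submitted in one POST to /api/v2/statements. It mirrors the payload built by the new hook (snowflake_sql_api.py, below); all field values here are illustrative, not part of the commit.

# Illustrative only: a multi-statement request body for
# POST https://<account>.snowflakecomputing.com/api/v2/statements,
# mirroring what SnowflakeSQLAPIHookAsync.execute_query builds below.
payload = {
    "statement": "create or replace table t (i int); insert into t values (1); select * from t;",
    "timeout": 60,
    "resultSetMetaData": {"format": "json"},
    "database": "MY_DB",  # placeholder values
    "schema": "PUBLIC",
    "warehouse": "MY_WH",
    "role": "MY_ROLE",
    "bindings": {},
    "parameters": {"MULTI_STATEMENT_COUNT": 3, "query_tag": "example"},
}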
astronomer/providers/snowflake/example_dags/example_snowflake_sql_api.py (56 additions, 0 deletions)
"""Example use of SnowflakeAsync related providers."""

import os
from datetime import timedelta

from airflow import DAG
from airflow.utils.timezone import datetime

from astronomer.providers.snowflake.operators.snowflake_sql_api import (
    SnowflakeSQLOperatorAsync,
)

SNOWFLAKE_CONN_ID = os.getenv("ASTRO_SNOWFLAKE_CONN_ID", "snowflake_default")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))

# SQL commands
SQL_MULTIPLE_STMTS = (
    "create or replace table user_test (i int); insert into user_test (i) "
    "values (200); insert into user_test (i) values (300); select i from user_test order by i;"
)
SINGLE_STMT = "select i from user_test order by i;"

default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "snowflake_conn_id": SNOWFLAKE_CONN_ID,
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}

with DAG(
    dag_id="example_snowflake_sql_api",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    default_args=default_args,
    tags=["example", "async", "snowflake"],
    catchup=False,
) as dag:
    # [START howto_operator_snowflake_op_sql_multiple_stmt]
    snowflake_op_sql_multiple_stmt = SnowflakeSQLOperatorAsync(
        task_id="snowflake_op_sql_multiple_stmt",
        dag=dag,
        sql=SQL_MULTIPLE_STMTS,
        statement_count=4,
    )
    # [END howto_operator_snowflake_op_sql_multiple_stmt]

    # [START howto_operator_snowflake_single_sql_stmt]
    snowflake_single_sql_stmt = SnowflakeSQLOperatorAsync(
        task_id="snowflake_single_sql_stmt",
        dag=dag,
        sql=SINGLE_STMT,
        statement_count=1,
    )
    # [END howto_operator_snowflake_single_sql_stmt]

    snowflake_op_sql_multiple_stmt >> snowflake_single_sql_stmt
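The new hook authenticates with a key-pair JWT rather than a password, so the connection referenced by this DAG needs a private key in its extras. A minimal sketch of one way such a connection might be registered; the account, user, key path, and passphrase are placeholder assumptions, and the extra field names follow the hook's get_private_key logic shown below.

# A sketch, not part of this commit: creating a Snowflake connection whose
# extras carry the private key used for key-pair (JWT) authentication.
import json

from airflow.models import Connection
from airflow.settings import Session

conn = Connection(
    conn_id="snowflake_default",    # matches the ASTRO_SNOWFLAKE_CONN_ID default above
    conn_type="snowflake",
    login="my_user",                # placeholder Snowflake user
    password="my_key_passphrase",   # becomes the passphrase for the encrypted private key
    extra=json.dumps(
        {
            "account": "my_account",
            "warehouse": "my_warehouse",
            "database": "my_database",
            "role": "my_role",
            "region": "us-east-1",
            # Provide either a key file path or the key content, never both.
            "private_key_file": "/path/to/rsa_key.p8",
        }
    ),
)

session = Session()
session.add(conn)
session.commit()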
astronomer/providers/snowflake/hooks/snowflake_sql_api.py (243 additions, 0 deletions)
import uuid
from abc import ABC
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import aiohttp
import requests
from aiohttp import ClientResponseError
from airflow import AirflowException
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from astronomer.providers.snowflake.hooks.sql_api_generate_jwt import JWTGenerator

class SnowflakeSQLAPIHookAsync(SnowflakeHook, ABC):
    """
    A client to interact with Snowflake.

    This hook requires the snowflake_conn_id connection. The snowflake host, login,
    and password fields must be set up in the connection. Other inputs can be defined
    in the connection or hook instantiation. If used with the S3ToSnowflakeOperator,
    add 'aws_access_key_id' and 'aws_secret_access_key' to the extra field in the connection.

    :param snowflake_conn_id: Reference to
        :ref:`Snowflake connection id<howto/connection:snowflake>`
    :param account: snowflake account name
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and
        Okta, ADFS or any other SAML 2.0-compliant identity provider
        (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate
        through native Okta.
    :param warehouse: name of snowflake warehouse
    :param database: name of snowflake database
    :param region: name of snowflake region
    :param role: name of snowflake role
    :param schema: name of snowflake schema
    :param session_parameters: You can set session-level parameters at
        the time you connect to Snowflake
    """

    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes

    def __init__(
        self,
        snowflake_conn_id: str,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        *args: Any,
        **kwargs: Any,
    ):
        self.snowflake_conn_id = snowflake_conn_id
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta
        super().__init__(snowflake_conn_id, *args, **kwargs)
        self.private_key = None

    def get_private_key(self) -> None:
        """Gets the private key from the Snowflake connection."""
        conn = self.get_connection(self.snowflake_conn_id)

        # If private_key_file is specified in the extra json, load the contents of the file as a private key.
        # If private_key_content is specified in the extra json, use it as a private key.
        # As a next step, specify this private key in the connection configuration.
        # The connection password then becomes the passphrase for the private key.
        # If your private key is not encrypted (not recommended), then leave the password empty.

        private_key_file = conn.extra_dejson.get(
            "extra__snowflake__private_key_file"
        ) or conn.extra_dejson.get("private_key_file")
        private_key_content = conn.extra_dejson.get(
            "extra__snowflake__private_key_content"
        ) or conn.extra_dejson.get("private_key_content")

        private_key_pem = None
        if private_key_content and private_key_file:
            raise AirflowException(
                "The private_key_file and private_key_content extra fields are mutually exclusive. "
                "Please remove one."
            )
        elif private_key_file:
            private_key_pem = Path(private_key_file).read_bytes()
        elif private_key_content:
            private_key_pem = private_key_content.encode()

        if private_key_pem:
            passphrase = None
            if conn.password:
                passphrase = conn.password.strip().encode()

            self.private_key = serialization.load_pem_private_key(
                private_key_pem, password=passphrase, backend=default_backend()
            )

    def execute_query(
        self, sql: str, statement_count: int, query_tag: str = "", bindings: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """
        Run the query in Snowflake by making an API request to the Snowflake SQL API.

        :param sql: the SQL string to be executed, possibly containing multiple statements
        :param statement_count: set the MULTI_STATEMENT_COUNT field to the number of SQL statements in the request
        :param query_tag: (Optional) query tag that you want to associate with the SQL statement.
            For details, see the QUERY_TAG parameter:
            https://docs.snowflake.com/en/sql-reference/parameters.html#label-query-tag
        :param bindings: (Optional) values of bind variables in the SQL statement.
            When executing the statement, Snowflake replaces placeholders (? and :name) in
            the statement with these specified values.
        """
        conn_config = self._get_conn_params()

        req_id = uuid.uuid4()
        url = "https://{0}.snowflakecomputing.com/api/v2/statements".format(conn_config["account"])
        params: Dict[str, Any] = {"requestId": str(req_id), "async": True, "pageSize": 10}
        headers = self.get_headers()
        if bindings is None:
            bindings = {}
        data = {
            "statement": sql,
            "timeout": 60,
            "resultSetMetaData": {"format": "json"},
            "database": conn_config["database"],
            "schema": conn_config["schema"],
            "warehouse": conn_config["warehouse"],
            "role": conn_config["role"],
            "bindings": bindings,
            "parameters": {
                "MULTI_STATEMENT_COUNT": statement_count,
                "query_tag": query_tag,
            },
        }
        response = requests.post(url, json=data, headers=headers, params=params)
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            self.log.error("HTTP error: %s", response.reason)
            self.log.error(response.text)
            raise AirflowException(str(response.status_code) + ":" + response.reason)
        json_response = response.json()
        self.log.info("Snowflake SQL POST API response: %s", json_response)
        if "statementHandles" in json_response:
            self.query_ids = json_response["statementHandles"]
        elif "statementHandle" in json_response:
            self.query_ids.append(json_response["statementHandle"])
        else:
            raise AirflowException("No statementHandle/statementHandles present in response")
        return self.query_ids

    def get_headers(self) -> Dict[str, Any]:
        """Form the request headers by generating a JWT token from the private key and connection details."""
        if not self.private_key:
            self.get_private_key()
        conn_config = self._get_conn_params()

        # Get the JWT token from the connection details and the private key
        token = JWTGenerator(
            conn_config["account"],  # type: ignore[arg-type]
            conn_config["user"],  # type: ignore[arg-type]
            private_key=self.private_key,
            lifetime=self.token_life_time,
            renewal_delay=self.token_renewal_delta,
        ).get_token()

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
            "Accept": "application/json",
            "User-Agent": "snowflakeSQLAPI/1.0",
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
        }
        return headers

    def get_request_url_header_params(self, query_id: str) -> Tuple[Dict[str, Any], Dict[str, Any], str]:
        """
        Build the request headers, query params and URL, using the account name identifier
        from the connection params and the given query id.

        :param query_id: statement handle id for the individual statement.
        """
        conn_config = self._get_conn_params()
        req_id = uuid.uuid4()
        header = self.get_headers()
        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
        url = "https://{0}.snowflakecomputing.com/api/v2/statements/{1}".format(
            conn_config["account"], query_id
        )
        return header, params, url

    def check_query_output(self, query_ids: List[str]) -> None:
        """
        Make an HTTP GET request to the Snowflake SQL API for each of the given query ids and log the response.

        :param query_ids: statement handle ids for the individual statements.
        """
        for query_id in query_ids:
            header, params, url = self.get_request_url_header_params(query_id)
            response = requests.get(url, headers=header, params=params)
            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError:
                self.log.error("HTTP error: %s", response.reason)
                self.log.error(response.text)
                raise AirflowException(str(response.status_code) + ":" + response.reason)
            self.log.info(response.json())

    async def get_sql_api_query_status(self, query_id: str) -> Dict[str, Union[str, List[str]]]:
        """
        Make an async HTTP GET request to the Snowflake SQL API for the given query id and return the response.

        :param query_id: statement handle id for the individual statement.
        """
        self.log.info("Retrieving status for query id %s", query_id)
        header, params, url = self.get_request_url_header_params(query_id)
        async with aiohttp.ClientSession(headers=header) as session:
            async with session.get(url, params=params) as response:
                try:
                    status_code = response.status
                    resp = await response.json()
                    self.log.info("Snowflake SQL GET statements status API response: %s", resp)
                    if status_code == 202:
                        return {"status": "running", "message": "Query statements are still running"}
                    elif status_code == 422:
                        return {"status": "error", "message": resp["message"]}
                    elif status_code == 200:
                        statement_handles = []
                        if "statementHandles" in resp and resp["statementHandles"] is not None:
                            statement_handles = resp["statementHandles"]
                        elif "statementHandle" in resp and resp["statementHandle"] is not None:
                            statement_handles.append(resp["statementHandle"])
                        return {
                            "status": "success",
                            "message": resp["message"],
                            "statement_handles": statement_handles,
                        }
                    else:
                        return {"status": "error", "message": resp["message"]}
                except ClientResponseError as e:
                    msg = f"HTTP error with status: {e.status}"
                    self.log.exception(msg)
                    return {"status": "error", "message": msg}
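Taken together, a rough usage sketch of the hook's synchronous methods; the conn_id and SQL are illustrative and this snippet is not part of the commit.

# Sketch: submit a multi-statement request via the SQL API, then fetch and
# log the results for each returned statement handle.
from astronomer.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSQLAPIHookAsync

hook = SnowflakeSQLAPIHookAsync(snowflake_conn_id="snowflake_default")
query_ids = hook.execute_query(
    sql="create or replace table t (i int); insert into t values (1);",
    statement_count=2,
)
hook.check_query_output(query_ids)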
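The commit message also mentions a Snowflake SQL API trigger, which is not among the files rendered here. A minimal sketch of how a trigger-style loop might poll get_sql_api_query_status until the statements finish; the helper name, conn_id, and poll interval are assumptions, not the actual trigger implementation.

# Sketch only: poll the async status endpoint until each statement handle
# stops reporting "running", then return the collected responses.
import asyncio
from typing import Dict, List

from astronomer.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSQLAPIHookAsync


async def wait_for_statements(query_ids: List[str], poll_interval: float = 5.0) -> List[Dict]:
    hook = SnowflakeSQLAPIHookAsync(snowflake_conn_id="snowflake_default")
    results = []
    for query_id in query_ids:
        while True:
            resp = await hook.get_sql_api_query_status(query_id)
            if resp["status"] != "running":
                results.append(resp)
                break
            await asyncio.sleep(poll_interval)
    return results


# e.g. asyncio.run(wait_for_statements(query_ids))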