Skip to content

Commit

Permalink
Implement Async Snowflake SQL API Operator
Browse files Browse the repository at this point in the history
Implement Async Snowflake SQL API Operator

Implement Async Snowflake SQL API Operator

Covered a few more details about the operator and the Snowflake SQL API in the docstring

Implement Async Snowflake SQL API Operator

Implement Async Snowflake SQL API Operator

Implement Async Snowflake SQL API Operator

Implement Async Snowflake SQL API Operator

Implement Async Snowflake SQL API Operator

Implement Async Snowflake SQL API Operator

Implemented an async Snowflake SQL API operator that supports running multiple SQL statements sequentially, which is the behavior of the SnowflakeOperator. The Snowflake SQL API allows submitting multiple SQL statements in a single request; in combination with aiohttp, this may be an option for creating a SnowflakeSQLOperatorAsync that matches the query submission behavior of the SnowflakeOperator.

Test case

Add Test case

Added Test case for Snowflake SQL API Trigger and Operator

Test case fix

Test case fix

Skip code coverage for import

Doc fix

Add example DAG Documenting

-  Added Example DAG Documentation
- Changed the class name to camel case

Docs Fix

Fix doc

Update Doc

Fix doc string

Move SnowflakeSqlApiOperatorAsync to snowflake.py file
  • Loading branch information
bharanidharan14 committed Jul 11, 2022
1 parent b432e72 commit 32cd057
Show file tree
Hide file tree
Showing 12 changed files with 1,418 additions and 4 deletions.
5 changes: 4 additions & 1 deletion .circleci/integration-tests/master_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def prepare_dag_dependency(task_info, execution_time):
chain(*http_trigger_tasks)

# Snowflake DAG
snowflake_task_info = [{"snowflake_dag": "example_snowflake"}]
snowflake_task_info = [
{"snowflake_dag": "example_snowflake"},
{"snowflake_sql_api_dag": "example_snowflake_sql_api"},
]
snowflake_trigger_tasks, ids = prepare_dag_dependency(snowflake_task_info, "{{ ds }}")
dag_run_ids.extend(ids)
chain(*snowflake_trigger_tasks)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Example use of SnowflakeSqlApiAsync operator."""

import os
from datetime import timedelta

from airflow import DAG
from airflow.utils.timezone import datetime

from astronomer.providers.snowflake.operators.snowflake import (
SnowflakeSqlApiOperatorAsync,
)

# Both settings can be overridden through environment variables so the same
# example runs unchanged locally and in the integration-test deployment.
SNOWFLAKE_CONN_ID = os.getenv("ASTRO_SNOWFLAKE_CONN_ID", "snowflake_api_default")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))


# Arguments applied to every task in the DAG. ``snowflake_conn_id`` is passed
# here so the individual operators below do not have to repeat it.
default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "snowflake_conn_id": SNOWFLAKE_CONN_ID,
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}

with DAG(
    dag_id="example_snowflake_sql_api",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    default_args=default_args,
    tags=["example", "async", "snowflake"],
    catchup=False,
) as dag:
    # Four statements submitted in one request; ``statement_count`` must match
    # the number of statements contained in ``sql``.
    # [START howto_operator_snowflake_sql_api_async]
    snowflake_op_sql_multiple_stmt = SnowflakeSqlApiOperatorAsync(
        task_id="snowflake_op_sql_multiple_stmt",
        sql=(
            "create or replace table user_test (i int); "
            "insert into user_test (i) values (200); "
            "insert into user_test (i) values (300); "
            "select i from user_test order by i;"
        ),
        statement_count=4,
    )
    # [END howto_operator_snowflake_sql_api_async]

    # A single statement reading back the rows inserted by the task above.
    # [START howto_operator_snowflake_single_sql_stmt]
    snowflake_single_sql_stmt = SnowflakeSqlApiOperatorAsync(
        task_id="snowflake_single_sql_stmt",
        sql="select i from user_test order by i;",
        statement_count=1,
    )
    # [END howto_operator_snowflake_single_sql_stmt]

    # The read-back query only makes sense once the table has been populated.
    snowflake_op_sql_multiple_stmt >> snowflake_single_sql_stmt
240 changes: 240 additions & 0 deletions astronomer/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import uuid
from abc import ABC
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import aiohttp
import requests
from airflow import AirflowException
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from astronomer.providers.snowflake.hooks.sql_api_generate_jwt import JWTGenerator


class SnowflakeSqlApiHookAsync(SnowflakeHook, ABC):
    """
    A client to interact with Snowflake using the SQL API, which allows submitting
    multiple SQL statements in a single request.

    A POST request submits the SQL statements for execution; the status of each
    statement is then polled asynchronously with aiohttp, and the query results
    can be fetched once execution has finished.

    This hook requires the ``snowflake_conn_id`` connection. The connection must provide
    the account, schema, database and warehouse, plus either the ``private_key_file`` or
    the ``private_key_content`` extra field. Other inputs can be defined in the
    connection or at hook instantiation.

    :param snowflake_conn_id: Reference to
        :ref:`Snowflake connection id<howto/connection:snowflake>`
    :param account: snowflake account name
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and
        Okta, ADFS or any other SAML 2.0-compliant identity provider
        (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate
        through native Okta.
    :param warehouse: name of snowflake warehouse
    :param database: name of snowflake database
    :param region: name of snowflake region
    :param role: name of snowflake role
    :param schema: name of snowflake schema
    :param session_parameters: You can set session-level parameters at
        the time you connect to Snowflake
    :param token_life_time: lifetime of the JWT Token in timedelta
    :param token_renewal_delta: Renewal time of the JWT Token in timedelta
    """

    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes

    def __init__(
        self,
        snowflake_conn_id: str,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        *args: Any,
        **kwargs: Any,
    ):
        self.snowflake_conn_id = snowflake_conn_id
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta
        super().__init__(snowflake_conn_id, *args, **kwargs)
        # Loaded lazily by get_private_key() the first time headers are built.
        self.private_key = None

    def get_private_key(self) -> None:
        """
        Load the private key from the Snowflake connection into ``self.private_key``.

        :raises AirflowException: if both ``private_key_file`` and ``private_key_content``
            are set in the connection extras.
        """
        conn = self.get_connection(self.snowflake_conn_id)

        # If private_key_file is specified in the extra json, load the contents of the file as a private key.
        # If private_key_content is specified in the extra json, use it as a private key.
        # As a next step, specify this private key in the connection configuration.
        # The connection password then becomes the passphrase for the private key.
        # If your private key is not encrypted (not recommended), then leave the password empty.

        private_key_file = conn.extra_dejson.get(
            "extra__snowflake__private_key_file"
        ) or conn.extra_dejson.get("private_key_file")
        private_key_content = conn.extra_dejson.get(
            "extra__snowflake__private_key_content"
        ) or conn.extra_dejson.get("private_key_content")

        private_key_pem = None
        if private_key_content and private_key_file:
            raise AirflowException(
                "The private_key_file and private_key_content extra fields are mutually exclusive. "
                "Please remove one."
            )
        elif private_key_file:
            private_key_pem = Path(private_key_file).read_bytes()
        elif private_key_content:
            private_key_pem = private_key_content.encode()

        if private_key_pem:
            passphrase = None
            if conn.password:
                passphrase = conn.password.strip().encode()

            self.private_key = serialization.load_pem_private_key(
                private_key_pem, password=passphrase, backend=default_backend()
            )

    def execute_query(
        self, sql: str, statement_count: int, query_tag: str = "", bindings: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """
        Using SnowflakeSQL API, run the query in snowflake by making API request

        :param sql: the sql string to be executed with possibly multiple statements
        :param statement_count: set the MULTI_STATEMENT_COUNT field to the number of SQL statements in the request
        :param query_tag: (Optional) Query tag that you want to associate with the SQL statement.
            For details, see https://docs.snowflake.com/en/sql-reference/parameters.html#label-query-tag parameter.
        :param bindings: (Optional) Values of bind variables in the SQL statement.
            When executing the statement, Snowflake replaces placeholders (? and :name) in
            the statement with these specified values.
        :return: the statement handles (query ids) returned by the API for this request
        :raises AirflowException: if the API responds with an HTTP error status, or if the
            response contains neither ``statementHandles`` nor ``statementHandle``.
        """
        conn_config = self._get_conn_params()

        req_id = uuid.uuid4()
        url = f"https://{conn_config['account']}.snowflakecomputing.com/api/v2/statements"
        params: Dict[str, Any] = {"requestId": str(req_id), "async": True, "pageSize": 10}
        headers = self.get_headers()
        if bindings is None:
            bindings = {}
        data = {
            "statement": sql,
            "resultSetMetaData": {"format": "json"},
            "database": conn_config["database"],
            "schema": conn_config["schema"],
            "warehouse": conn_config["warehouse"],
            "role": conn_config["role"],
            "bindings": bindings,
            "parameters": {
                "MULTI_STATEMENT_COUNT": statement_count,
                "query_tag": query_tag,
            },
        }
        response = requests.post(url, json=data, headers=headers, params=params)
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:  # pragma: no cover
            raise AirflowException(
                f"Response: {e.response.content}, " f"Status Code: {e.response.status_code}"
            ) from e  # pragma: no cover
        json_response = response.json()
        self.log.info("Snowflake SQL POST API response: %s", json_response)
        if "statementHandles" in json_response:
            self.query_ids = json_response["statementHandles"]
        elif "statementHandle" in json_response:
            # Replace rather than append so that, like the multi-statement branch,
            # the result only contains the handles of *this* request and does not
            # depend on ``query_ids`` having been populated before.
            self.query_ids = [json_response["statementHandle"]]
        else:
            raise AirflowException("No statementHandle/statementHandles present in response")
        return self.query_ids

    def get_headers(self) -> Dict[str, Any]:
        """
        Build the HTTP headers for a SQL API request.

        The private key is loaded on first use, and a JWT token is generated from it
        together with the account and user of the connection.

        :return: request headers, including the ``Authorization`` bearer token
        """
        if not self.private_key:
            self.get_private_key()
        conn_config = self._get_conn_params()

        # Get the JWT token from the connection details and the private key
        token = JWTGenerator(
            conn_config["account"],  # type: ignore[arg-type]
            conn_config["user"],  # type: ignore[arg-type]
            private_key=self.private_key,
            lifetime=self.token_life_time,
            renewal_delay=self.token_renewal_delta,
        ).get_token()

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
            "Accept": "application/json",
            "User-Agent": "snowflakeSQLAPI/1.0",
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
        }
        return headers

    def get_request_url_header_params(self, query_id: str) -> Tuple[Dict[str, Any], Dict[str, Any], str]:
        """
        Build the request headers, query parameters and URL for a statement status request.

        :param query_id: statement handle (query id) of an individual statement.
        :return: a ``(headers, params, url)`` tuple
        """
        conn_config = self._get_conn_params()
        req_id = uuid.uuid4()
        header = self.get_headers()
        # NOTE(review): "page": 2 is sent with every status request - confirm this is intended.
        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
        url = f"https://{conn_config['account']}.snowflakecomputing.com/api/v2/statements/{query_id}"
        return header, params, url

    def check_query_output(self, query_ids: List[str]) -> None:
        """
        Make a synchronous GET request to the Snowflake SQL API for each query id
        and log the JSON response.

        :param query_ids: statement handles (query ids) of the individual statements.
        :raises AirflowException: if any request returns an HTTP error status.
        """
        for query_id in query_ids:
            header, params, url = self.get_request_url_header_params(query_id)
            try:
                response = requests.get(url, headers=header, params=params)
                response.raise_for_status()
                self.log.info(response.json())
            except requests.exceptions.HTTPError as e:
                raise AirflowException(
                    f"Response: {e.response.content}, Status Code: {e.response.status_code}"
                ) from e

    async def get_sql_api_query_status(self, query_id: str) -> Dict[str, Union[str, List[str]]]:
        """
        Make an async GET request to the Snowflake SQL API for the given query id
        and translate the HTTP status code into a status dictionary.

        :param query_id: statement handle id of an individual statement.
        :return: a dict with a ``status`` of ``running`` (HTTP 202), ``success`` (HTTP 200)
            or ``error`` (any other status code) and a ``message``; on success it also
            carries ``statement_handles``.
        """
        # Pass the id itself: wrapping it in braces would log a one-element set.
        self.log.info("Retrieving status for query id %s", query_id)
        header, params, url = self.get_request_url_header_params(query_id)
        async with aiohttp.ClientSession(headers=header) as session:
            async with session.get(url, params=params) as response:
                status_code = response.status
                resp = await response.json()
                self.log.info("Snowflake SQL GET statements status API response: %s", resp)
                if status_code == 202:
                    return {"status": "running", "message": "Query statements are still running"}
                elif status_code == 422:
                    return {"status": "error", "message": resp["message"]}
                elif status_code == 200:
                    statement_handles = []
                    if resp.get("statementHandles"):
                        statement_handles = resp["statementHandles"]
                    elif resp.get("statementHandle"):
                        statement_handles.append(resp["statementHandle"])
                    return {
                        "status": "success",
                        "message": resp["message"],
                        "statement_handles": statement_handles,
                    }
                else:
                    return {"status": "error", "message": resp["message"]}
Loading

0 comments on commit 32cd057

Please sign in to comment.