Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Snowflake Agent Bug #2605

Merged
merged 38 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
52c79b6
fix snowflake agent bug
Future-Outlier Jul 24, 2024
2d5c1a6
a work version
Future-Outlier Jul 24, 2024
804052a
Snowflake work version
Future-Outlier Jul 25, 2024
90f08dc
fix secret encode
Future-Outlier Jul 25, 2024
c0f84f2
all works, I am so happy
Future-Outlier Jul 25, 2024
05adfd1
improve additional protocol
Future-Outlier Jul 25, 2024
89d633f
fix tests
Future-Outlier Jul 25, 2024
d1d8024
Fix Tests
Future-Outlier Jul 25, 2024
a17f28d
update agent
pingsutw Jul 25, 2024
547a801
Add snowflake test
pingsutw Jul 25, 2024
a6de45c
nit
pingsutw Jul 25, 2024
14c4318
sd
pingsutw Jul 25, 2024
76637e8
snowflake loglinks
Future-Outlier Jul 25, 2024
762ad0b
add metadata
Future-Outlier Jul 26, 2024
1fcd2de
secret
pingsutw Jul 29, 2024
4a8c8ba
nit
pingsutw Jul 29, 2024
3a7a9cd
remove table
Future-Outlier Jul 30, 2024
2704555
add comment for get private key
Future-Outlier Jul 30, 2024
469b86c
update comments:
Future-Outlier Jul 30, 2024
378327f
Fix Tests
Future-Outlier Jul 30, 2024
d71ef8f
update comments
Future-Outlier Jul 30, 2024
6a8cd9a
update comments
Future-Outlier Jul 30, 2024
5035063
Better Secrets
Future-Outlier Jul 30, 2024
aaff3d2
use union secret
Future-Outlier Jul 30, 2024
45a788d
Update Changes
Future-Outlier Jul 30, 2024
dfe6f97
use if not get_plugin().secret_requires_group()
Future-Outlier Jul 30, 2024
03e8b69
Use Union SDK
Future-Outlier Jul 30, 2024
41e2a19
Update
Future-Outlier Jul 30, 2024
af5a2f1
Fix Secrets
Future-Outlier Jul 30, 2024
c4b641e
Fix Secrets
Future-Outlier Jul 30, 2024
c8f472f
remove package.json
Future-Outlier Jul 31, 2024
4b08fbe
lint
Future-Outlier Jul 31, 2024
de6ce1a
add snowflake-connector-python
Future-Outlier Jul 31, 2024
58a1106
fix test_snowflake
Future-Outlier Jul 31, 2024
4aa2411
Try to fix tests
Future-Outlier Jul 31, 2024
31e57c8
fix tests
Future-Outlier Jul 31, 2024
4a9e936
Try Fix snowflake Import
Future-Outlier Jul 31, 2024
1dd36b2
snowflake test passed
Future-Outlier Jul 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ def lazy_import_transformers(cls):
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
register_snowflake_handlers,
)
from flytekit.types.structured.structured_dataset import DuplicateHandlerError

Expand Down Expand Up @@ -1015,6 +1016,11 @@ def lazy_import_transformers(cls):
from flytekit.types import numpy # noqa: F401
if is_imported("PIL"):
from flytekit.types.file import image # noqa: F401
if is_imported("snowflake.connector"):
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
try:
register_snowflake_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for snowflake is already registered.")
pingsutw marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
Expand Down
14 changes: 14 additions & 0 deletions flytekit/types/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,17 @@ def register_bigquery_handlers():
"We won't register bigquery handler for structured dataset because "
"we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery"
)


def register_snowflake_handlers():
    """Register the pandas <-> Snowflake structured-dataset handlers, if available.

    Importing the handler module pulls in ``snowflake-connector-python``; when that
    package is not installed we log at info level and skip registration instead of
    failing flytekit import.
    """
    try:
        from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler

        for handler in (SnowflakeToPandasDecodingHandler(), PandasToSnowflakeEncodingHandlers()):
            StructuredDatasetTransformerEngine.register(handler)

    except ImportError:
        logger.info(
            "We won't register snowflake handler for structured dataset because "
            "we can't find package snowflake-connector-python"
        )
105 changes: 105 additions & 0 deletions flytekit/types/structured/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import re
import typing

import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

import flytekit
from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
)

SNOWFLAKE = "snowflake"


def get_private_key() -> bytes:
    """Load the Snowflake private key from the Flyte secret store as DER bytes.

    Reads the PEM-encoded key stored under the ``snowflake`` secret and
    re-serializes it to PKCS8/DER, the form ``snowflake.connector.connect``
    expects for its ``private_key`` argument.

    :return: DER-encoded, unencrypted PKCS8 private key bytes.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    # NOTE(review): the secret group is passed as None here — confirm the
    # configured secrets manager supports a group-less lookup for "snowflake".
    pk_string = flytekit.current_context().secrets.get(None, "snowflake", encode_mode="r")

    # cryptography needs str to be stripped and converted to bytes
    pk_string = pk_string.strip().encode()
    p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend())

    pkb = p_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    return pkb


def _write_to_sf(structured_dataset: StructuredDataset):
    """Write the dataset's dataframe into the Snowflake table encoded in its URI.

    The URI must have the exact shape
    ``snowflake://user:account/warehouse/database/schema/table`` — the same
    format this plugin produces, so the positional parse below is safe for
    URIs we generate ourselves.

    :param structured_dataset: dataset whose ``dataframe`` is uploaded via
        ``write_pandas``; its ``uri`` supplies the connection parameters.
    :raises ValueError: if ``structured_dataset.uri`` is None.
    """
    if structured_dataset.uri is None:
        raise ValueError("structured_dataset.uri cannot be None.")

    uri = structured_dataset.uri
    # NOTE(review): brittle positional parse; only valid for URIs created by this
    # plugin (same pattern used in _read_from_sf — consider a shared constant).
    _, user, account, warehouse, database, schema, table = re.split("\\/|://|:", uri)
    df = structured_dataset.dataframe

    conn = snowflake.connector.connect(
        user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse
    )

    write_pandas(conn, df, table)


def _read_from_sf(
    flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata
) -> pd.DataFrame:
    """Fetch the result set of a completed Snowflake query as a pandas DataFrame.

    ``flyte_value.uri`` encodes the connection parameters plus the query id:
    ``snowflake://user:account/warehouse/database/schema/query_id``.

    :param flyte_value: literal whose URI identifies the finished query.
    :param current_task_metadata: unused here; part of the decoder contract.
    :raises ValueError: if ``flyte_value.uri`` is None.
    """
    if flyte_value.uri is None:
        raise ValueError("structured_dataset.uri cannot be None.")

    uri = flyte_value.uri
    # Same positional URI parse as _write_to_sf; only valid for plugin-built URIs.
    _, user, account, warehouse, database, schema, query_id = re.split("\\/|://|:", uri)

    # Fix: dropped the hard-coded table="FLYTEAGENT.PUBLIC.TEST" argument — it was a
    # leftover test value and is not a connection parameter (flagged in review).
    conn = snowflake.connector.connect(
        user=user,
        account=account,
        private_key=get_private_key(),
        database=database,
        schema=schema,
        warehouse=warehouse,
    )

    cs = conn.cursor()
    cs.get_results_from_sfqid(query_id)
    return cs.fetch_pandas_all()


class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder):
    """Encoder that persists a pandas DataFrame to Snowflake for ``snowflake://`` URIs."""

    def __init__(self):
        super().__init__(python_type=pd.DataFrame, protocol=SNOWFLAKE, supported_format="")

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        """Upload the dataframe, then return a literal pointing at the same URI."""
        _write_to_sf(structured_dataset)
        dataset_uri = typing.cast(str, structured_dataset.uri)
        dataset_metadata = StructuredDatasetMetadata(structured_dataset_type)
        return literals.StructuredDataset(uri=dataset_uri, metadata=dataset_metadata)


class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder):
    """Decoder that materializes a Snowflake query result as a pandas DataFrame."""

    def __init__(self):
        super().__init__(pd.DataFrame, protocol=SNOWFLAKE, supported_format="")

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> pd.DataFrame:
        # Delegates to the module-level reader, which connects to Snowflake using
        # the connection info encoded in flyte_value.uri and fetches the result set.
        return _read_from_sf(flyte_value, current_task_metadata)
17 changes: 14 additions & 3 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, is_dataclass
from typing import Dict, Generator, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union

from dataclasses_json import config
from fsspec.utils import get_protocol
Expand Down Expand Up @@ -222,7 +222,12 @@ def extract_cols_and_format(


class StructuredDatasetEncoder(ABC):
def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None):
def __init__(
self,
python_type: Type[T],
protocol: Optional[str] = None,
supported_format: Optional[str] = None,
):
"""
Extend this abstract class, implement the encode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand Down Expand Up @@ -284,7 +289,13 @@ def encode(


class StructuredDatasetDecoder(ABC):
def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None):
def __init__(
self,
python_type: Type[DF],
protocol: Optional[str] = None,
supported_format: Optional[str] = None,
additional_protocols: Optional[List[str]] = None,
):
"""
Extend this abstract class, implement the decode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[BigQueryConfig],
task_config: BigQueryConfig,
inputs: Optional[Dict[str, Type]] = None,
output_structured_dataset_type: Optional[Type[StructuredDataset]] = None,
**kwargs,
Expand Down
61 changes: 42 additions & 19 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from dataclasses import dataclass
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog

from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger
import flytekit
from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType

snowflake_connector = lazy_module("snowflake.connector")
from snowflake import connector as sc

TASK_TYPE = "snowflake"
SNOWFLAKE_PRIVATE_KEY = "snowflake_private_key"
Expand All @@ -27,15 +27,16 @@ class SnowflakeJobMetadata(ResourceMeta):
warehouse: str
table: str
query_id: str
has_output: bool


def get_private_key():
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

import flytekit

pk_string = flytekit.current_context().secrets.get(SNOWFLAKE_PRIVATE_KEY, encode_mode="rb")
pk_string = flytekit.current_context().secrets.get(SNOWFLAKE_PRIVATE_KEY, encode_mode="r")
# cryptography needs str to be stripped and converted to bytes
pk_string = pk_string.strip().encode()
p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend())

pkb = p_key.private_bytes(
Expand All @@ -47,8 +48,8 @@ def get_private_key():
return pkb


def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector:
return snowflake_connector.connect(
def get_connection(metadata: SnowflakeJobMetadata) -> sc:
return sc.connect(
user=metadata.user,
account=metadata.account,
private_key=get_private_key(),
Expand All @@ -69,10 +70,11 @@ async def create(
) -> SnowflakeJobMetadata:
ctx = FlyteContextManager.current_context()
literal_types = task_template.interface.inputs
params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs else None

params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs.literals else None

config = task_template.config
conn = snowflake_connector.connect(
conn = sc.connect(
user=config["user"],
account=config["account"],
private_key=get_private_key(),
Expand All @@ -82,7 +84,7 @@ async def create(
)

cs = conn.cursor()
cs.execute_async(task_template.sql.statement, params=params)
cs.execute_async(task_template.sql.statement, params)

return SnowflakeJobMetadata(
user=config["user"],
Expand All @@ -91,34 +93,42 @@ async def create(
schema=config["schema"],
warehouse=config["warehouse"],
table=config["table"],
query_id=str(cs.sfqid),
query_id=cs.sfqid,
has_output=task_template.interface.outputs is not None and len(task_template.interface.outputs) > 0,
)

async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource:
conn = get_connection(resource_meta)
try:
query_status = conn.get_query_status_throw_if_error(resource_meta.query_id)
except snowflake_connector.ProgrammingError as err:
except sc.ProgrammingError as err:
logger.error("Failed to get snowflake job status with error:", err.msg)
return Resource(phase=TaskExecution.FAILED)

log_link = TaskLog(
uri=construct_query_link(resource_meta=resource_meta),
name="Snowflake Query Details",
)
# The snowflake job's state is determined by query status.
# https://github.com/snowflakedb/snowflake-connector-python/blob/main/src/snowflake/connector/constants.py#L373
cur_phase = convert_to_flyte_phase(str(query_status.name))
res = None

if cur_phase == TaskExecution.SUCCEEDED:
if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output:
ctx = FlyteContextManager.current_context()
output_metadata = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.table}"
uri = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}"
res = literals.LiteralMap(
{
"results": TypeEngine.to_literal(
ctx,
StructuredDataset(uri=output_metadata),
StructuredDataset(uri=uri),
StructuredDataset,
LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
)
}
).to_flyte_idl()
)

return Resource(phase=cur_phase, outputs=res)
return Resource(phase=cur_phase, outputs=res, log_links=[log_link])

async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs):
conn = get_connection(resource_meta)
Expand All @@ -131,4 +141,17 @@ async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs):
conn.close()


def construct_query_link(resource_meta: "SnowflakeJobMetadata") -> str:
    """Build the Snowflake web-UI URL for the query described by ``resource_meta``.

    The account identifier is assumed to be ``account-region``; when no ``-`` is
    present the region segment of the URL is left empty.
    """
    base_url = "https://app.snowflake.com"

    # Split the identifier into account and region; tolerate a region-less account.
    parts = resource_meta.account.split("-")
    account = parts[0]
    region = parts[1] if len(parts) > 1 else ""

    return f"{base_url}/{region}/{account}/#/compute/history/queries/{resource_meta.query_id}/detail"


AgentRegistry.register(SnowflakeAgent())
32 changes: 17 additions & 15 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@
class SnowflakeConfig(object):
    """
    SnowflakeConfig should be used to configure a Snowflake Task.
    You can use the query below to retrieve all metadata for this config.

    SELECT
        CURRENT_USER() AS "User",
        CONCAT(CURRENT_ORGANIZATION_NAME(), '-', CURRENT_ACCOUNT_NAME()) AS "Account",
        CURRENT_DATABASE() AS "Database",
        CURRENT_SCHEMA() AS "Schema",
        CURRENT_WAREHOUSE() AS "Warehouse";
    """

    # The user to run the query as
    user: str
    # The account identifier, in ORGANIZATION-ACCOUNT form (see query above)
    account: str
    # The database to query against
    database: str
    # The schema within the database
    schema: str
    # The warehouse that executes the query
    warehouse: str
    # The table associated with the query's structured-dataset output
    table: str


class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]):
Expand All @@ -47,7 +49,7 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[SnowflakeConfig] = None,
task_config: SnowflakeConfig,
inputs: Optional[Dict[str, Type]] = None,
output_schema_type: Optional[Type[StructuredDataset]] = None,
**kwargs,
Expand All @@ -63,13 +65,13 @@ def __init__(
:param output_schema_type: If some data is produced by this query, then you can specify the output schema type
:param kwargs: All other args required by Parent type - SQLTask
"""

outputs = None
if output_schema_type is not None:
outputs = {
"results": output_schema_type,
}
if task_config is None:
task_config = SnowflakeConfig()

super().__init__(
name=name,
task_config=task_config,
Expand Down
Loading
Loading