From 0a4c79e38bc1871716c9febf8a023f3d70a2e808 Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Mon, 31 Oct 2022 17:20:53 -0600
Subject: [PATCH 1/8] Add SnowflakeTableDataSet

---
 kedro/extras/datasets/pandas/__init__.py     |   3 +
 .../datasets/pandas/snowflake_dataset.py     | 251 ++++++++++++++++++
 2 files changed, 254 insertions(+)
 create mode 100644 kedro/extras/datasets/pandas/snowflake_dataset.py

diff --git a/kedro/extras/datasets/pandas/__init__.py b/kedro/extras/datasets/pandas/__init__.py
index b84015d1d9..b485099c3f 100644
--- a/kedro/extras/datasets/pandas/__init__.py
+++ b/kedro/extras/datasets/pandas/__init__.py
@@ -13,6 +13,7 @@
     "SQLTableDataSet",
     "XMLDataSet",
     "GenericDataSet",
+    "SnowflakeTableDataSet",
 ]
 
 from contextlib import suppress
@@ -37,3 +38,5 @@ from .xml_dataset import XMLDataSet
 with suppress(ImportError):
     from .generic_dataset import GenericDataSet
+with suppress(ImportError):
+    from .snowflake_dataset import SnowflakeTableDataSet

diff --git a/kedro/extras/datasets/pandas/snowflake_dataset.py b/kedro/extras/datasets/pandas/snowflake_dataset.py
new file mode 100644
index 0000000000..cc34cc503e
--- /dev/null
+++ b/kedro/extras/datasets/pandas/snowflake_dataset.py
@@ -0,0 +1,251 @@
+"""``SnowflakeTableDataSet`` to load and save data to a Snowflake backend."""
+
+import copy
+import re
+from typing import Any, Dict, Optional
+
+import pandas as pd
+from snowflake.snowpark import Session
+
+from kedro.io.core import AbstractDataSet, DataSetError
+
+KNOWN_PIP_INSTALL = {
+    "snowflake.snowpark": "snowflake-snowpark-python",
+}
+
+DRIVER_ERROR_MESSAGE = """
+A module/driver is missing when connecting to Snowflake
+\n\n
+"""
+
+
+def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
+    provide better guidance for the user.
+
+    Args:
+        module_import_error: Error raised while connecting to a SQL server.
+
+    Returns:
+        Instructions for installing the missing driver. ``None`` is returned
+        if the error relates to an unknown driver.
+
+    """
+
+    # module errors contain string "No module named 'module_name'"
+    # we are trying to extract module_name surrounded by quotes here
+    res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
+
+    # in case module import error does not match our expected pattern
+    # we have no recommendation
+    if not res:
+        return None
+
+    missing_module = res[0]
+
+    if KNOWN_PIP_INSTALL.get(missing_module):
+        return (
+            f"You can also try installing missing driver with\n"
+            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
+        )
+
+    return None
+
+
+def _get_missing_module_error(import_error: ImportError) -> DataSetError:
+    missing_module_instruction = _find_known_drivers(import_error)
+
+    if missing_module_instruction is None:
+        return DataSetError(
+            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
+        )
+
+    return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
+
+
+class SnowflakeTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
+    """``SQLTableDataSet`` loads data from a SQL table and saves a pandas
+    dataframe to a table. It uses ``pandas.DataFrame`` internally,
+    so it supports all allowed pandas options on ``read_sql_table`` and
+    ``to_sql`` methods. Since Pandas uses SQLAlchemy behind the scenes, when
+    instantiating ``SQLTableDataSet`` one needs to pass a compatible connection
+    string either in ``credentials`` (see the example code snippet below) or in
+    ``load_args`` and ``save_args``. Connection string formats supported by
+    SQLAlchemy can be found here:
+    https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+
+    ``SQLTableDataSet`` modifies the save parameters and stores
+    the data with no index. This is designed to make load and save methods
+    symmetric.
+
+    Example adding a catalog entry with
+    `YAML API `_:
+
+    .. code-block:: yaml
+
+        >>> shuttles_table_dataset:
+        >>>   type: pandas.SQLTableDataSet
+        >>>   credentials: db_credentials
+        >>>   table_name: shuttles
+        >>>   load_args:
+        >>>     schema: dwschema
+        >>>   save_args:
+        >>>     schema: dwschema
+        >>>     if_exists: replace
+
+    Sample database credentials entry in ``credentials.yml``:
+
+    .. code-block:: yaml
+
+        >>> db_creds:
+        >>>   con: postgresql://scott:tiger@localhost/test
+
+    Example using Python API:
+    ::
+
+        >>> from kedro.extras.datasets.pandas import SQLTableDataSet
+        >>> import pandas as pd
+        >>>
+        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5],
+        >>>                      "col3": [5, 6]})
+        >>> table_name = "table_a"
+        >>> credentials = {
+        >>>     "con": "postgresql://scott:tiger@localhost/test"
+        >>> }
+        >>> data_set = SQLTableDataSet(table_name=table_name,
+        >>>                            credentials=credentials)
+        >>>
+        >>> data_set.save(data)
+        >>> reloaded = data_set.load()
+        >>>
+        >>> assert data.equals(reloaded)
+
+    """
+
+    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
+    # using Any because of Sphinx but it should be
+    # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
+    sessions: Dict[str, Any] = {}
+
+    def __init__(
+        self,
+        table_name: str,
+        credentials: Dict[str, Any],
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new ``SQLTableDataSet``.
+
+        Args:
+            table_name: The table name to load or save data to. It
+                overwrites name in ``save_args`` and ``table_name``
+                parameters in ``load_args``.
+            credentials: A dictionary with a ``SQLAlchemy`` connection string.
+                Users are supposed to provide the connection string 'con'
+                through credentials. It overwrites `con` parameter in
+                ``load_args`` and ``save_args`` in case it is provided. To find
+                all supported connection string formats, see here:
+                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+            load_args: Provided to underlying pandas ``read_sql_table``
+                function along with the connection string.
+                To find all supported arguments, see here:
+                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_table.html
+                To find all supported connection string formats, see here:
+                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+            save_args: Provided to underlying pandas ``to_sql`` function along
+                with the connection string.
+                To find all supported arguments, see here:
+                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_sql.html
+                To find all supported connection string formats, see here:
+                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+                It has ``index=False`` in the default parameters.
+
+        Raises:
+            DataSetError: When either ``table_name`` or ``con`` is empty.
+
+        """
+
+        if not table_name:
+            raise DataSetError("'table_name' argument cannot be empty.")
+
+        if not credentials:
+            raise DataSetError("Please configure expected credentials")
+
+        # print(self._load_args)
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        self._load_args["table_name"] = table_name
+        self._save_args["name"] = table_name
+
+        self._credentials = credentials["credentials"]
+
+        # self._connection_str = credentials["con"]
+        self.create_connection(self._credentials)
+
+    @classmethod
+    def create_connection(cls, credentials: dict) -> None:
+        """Given a connection string, create singleton connection
+        to be used across all instances of `SQLQueryDataSet` that
+        need to connect to the same source.
+        connection_params = {
+            "account": "",
+            "user": "",
+            "password": "",
+            "role": "",
+            "warehouse": "",
+            "database": "",
+            "schema": ""
+        }
+        """
+        if credentials["account"] in cls.sessions:
+            return
+        try:
+            session = Session.builder.configs(credentials).create()
+        except ImportError as import_error:
+            raise _get_missing_module_error(import_error) from import_error
+        except Exception as exception:
+            raise exception
+
+        cls.sessions[credentials["account"]] = session
+        print("session")
+        print(session)
+        print("connection successful")
+
+    def _describe(self) -> Dict[str, Any]:
+        load_args = copy.deepcopy(self._load_args)
+        save_args = copy.deepcopy(self._save_args)
+        del load_args["table_name"]
+        del save_args["name"]
+        return dict(
+            table_name=self._load_args["table_name"],
+            load_args=load_args,
+            save_args=save_args,
+        )
+
+    def _load(self) -> pd.DataFrame:
+        pass
+
+    def _save(self, data: pd.DataFrame) -> None:
+        # pd df to snowpark df
+        session = self.sessions[self._credentials["account"]]  # type: ignore
+        sp_df = session.create_dataframe(data)
+        table_name = [
+            self._credentials.get("database"),
+            self._credentials.get("schema"),
+            self._load_args["table_name"],
+        ]
+        sp_df.write.mode("overwrite").save_as_table(table_name, table_type="")
+
+    def _exists(self) -> bool:
+        session = self.sessions[self._credentials["account"]]  # type: ignore
+        schema = self._load_args.get("schema", None)
+        exists = self._load_args["table_name"] in session.table_names(schema)
+        return exists
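The `create_connection` classmethod above memoises Snowpark sessions in a class-level `sessions` dict keyed by Snowflake account, so several catalog entries pointing at the same account share one connection. A minimal sketch of that pattern outside kedro (parameter values are placeholders; assumes `snowflake-snowpark-python` is installed):

```python
from typing import Any, Dict

from snowflake.snowpark import Session

_SESSIONS: Dict[str, Session] = {}  # one cached session per Snowflake account


def get_session(connection_params: Dict[str, Any]) -> Session:
    """Create a Snowpark session on first use and reuse it afterwards."""
    account = connection_params["account"]
    if account not in _SESSIONS:
        # Session.builder.configs() accepts the usual Snowflake connection
        # parameters: account, user, password, role, warehouse, ...
        _SESSIONS[account] = Session.builder.configs(connection_params).create()
    return _SESSIONS[account]
```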
From a657d3a8f7e632caf3f8e9d686f8518ddb2f9f95 Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Mon, 31 Oct 2022 17:21:53 -0600
Subject: [PATCH 2/8] Add snowflake-snowpark-python

---
 test_requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test_requirements.txt b/test_requirements.txt
index 52305e865a..111da4c8c8 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -50,6 +50,7 @@ requests-mock~=1.6
 requests~=2.20
 s3fs>=0.3.0, <0.5 # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem.
 SQLAlchemy~=1.2
+snowflake-snowpark-python~=0.12.0
 tables~=3.6.0; platform_system == "Windows" and python_version<'3.9'
 tables~=3.6; platform_system != "Windows"
 tensorflow~=2.0

From 95e272ba42e6870a1277493c8a81c9bb1295c588 Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Tue, 15 Nov 2022 11:43:22 -0600
Subject: [PATCH 3/8] Delete snowflake_dataset.py on pandas folder

---
 .../datasets/pandas/snowflake_dataset.py     | 251 ------------------
 1 file changed, 251 deletions(-)
 delete mode 100644 kedro/extras/datasets/pandas/snowflake_dataset.py

diff --git a/kedro/extras/datasets/pandas/snowflake_dataset.py b/kedro/extras/datasets/pandas/snowflake_dataset.py
deleted file mode 100644
index cc34cc503e..0000000000
--- a/kedro/extras/datasets/pandas/snowflake_dataset.py
+++ /dev/null
@@ -1,251 +0,0 @@
-"""``SnowflakeTableDataSet`` to load and save data to a Snowflake backend."""
-
-import copy
-import re
-from typing import Any, Dict, Optional
-
-import pandas as pd
-from snowflake.snowpark import Session
-
-from kedro.io.core import AbstractDataSet, DataSetError
-
-KNOWN_PIP_INSTALL = {
-    "snowflake.snowpark": "snowflake-snowpark-python",
-}
-
-DRIVER_ERROR_MESSAGE = """
-A module/driver is missing when connecting to Snowflake
-\n\n
-"""
-
-
-def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
-    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
-    provide better guidance for the user.
-
-    Args:
-        module_import_error: Error raised while connecting to a SQL server.
-
-    Returns:
-        Instructions for installing the missing driver. ``None`` is returned
-        if the error relates to an unknown driver.
-
-    """
-
-    # module errors contain string "No module named 'module_name'"
-    # we are trying to extract module_name surrounded by quotes here
-    res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
-
-    # in case module import error does not match our expected pattern
-    # we have no recommendation
-    if not res:
-        return None
-
-    missing_module = res[0]
-
-    if KNOWN_PIP_INSTALL.get(missing_module):
-        return (
-            f"You can also try installing missing driver with\n"
-            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
-        )
-
-    return None
-
-
-def _get_missing_module_error(import_error: ImportError) -> DataSetError:
-    missing_module_instruction = _find_known_drivers(import_error)
-
-    if missing_module_instruction is None:
-        return DataSetError(
-            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
-        )
-
-    return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
-
-
-class SnowflakeTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
-    """``SQLTableDataSet`` loads data from a SQL table and saves a pandas
-    dataframe to a table. It uses ``pandas.DataFrame`` internally,
-    so it supports all allowed pandas options on ``read_sql_table`` and
-    ``to_sql`` methods. Since Pandas uses SQLAlchemy behind the scenes, when
-    instantiating ``SQLTableDataSet`` one needs to pass a compatible connection
-    string either in ``credentials`` (see the example code snippet below) or in
-    ``load_args`` and ``save_args``. Connection string formats supported by
-    SQLAlchemy can be found here:
-    https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
-
-    ``SQLTableDataSet`` modifies the save parameters and stores
-    the data with no index. This is designed to make load and save methods
-    symmetric.
-
-    Example adding a catalog entry with
-    `YAML API `_:
-
-    .. code-block:: yaml
-
-        >>> shuttles_table_dataset:
-        >>>   type: pandas.SQLTableDataSet
-        >>>   credentials: db_credentials
-        >>>   table_name: shuttles
-        >>>   load_args:
-        >>>     schema: dwschema
-        >>>   save_args:
-        >>>     schema: dwschema
-        >>>     if_exists: replace
-
-    Sample database credentials entry in ``credentials.yml``:
-
-    .. code-block:: yaml
-
-        >>> db_creds:
-        >>>   con: postgresql://scott:tiger@localhost/test
-
-    Example using Python API:
-    ::
-
-        >>> from kedro.extras.datasets.pandas import SQLTableDataSet
-        >>> import pandas as pd
-        >>>
-        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5],
-        >>>                      "col3": [5, 6]})
-        >>> table_name = "table_a"
-        >>> credentials = {
-        >>>     "con": "postgresql://scott:tiger@localhost/test"
-        >>> }
-        >>> data_set = SQLTableDataSet(table_name=table_name,
-        >>>                            credentials=credentials)
-        >>>
-        >>> data_set.save(data)
-        >>> reloaded = data_set.load()
-        >>>
-        >>> assert data.equals(reloaded)
-
-    """
-
-    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
-    DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
-    # using Any because of Sphinx but it should be
-    # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
-    sessions: Dict[str, Any] = {}
-
-    def __init__(
-        self,
-        table_name: str,
-        credentials: Dict[str, Any],
-        load_args: Dict[str, Any] = None,
-        save_args: Dict[str, Any] = None,
-    ) -> None:
-        """Creates a new ``SQLTableDataSet``.
-
-        Args:
-            table_name: The table name to load or save data to. It
-                overwrites name in ``save_args`` and ``table_name``
-                parameters in ``load_args``.
-            credentials: A dictionary with a ``SQLAlchemy`` connection string.
-                Users are supposed to provide the connection string 'con'
-                through credentials. It overwrites `con` parameter in
-                ``load_args`` and ``save_args`` in case it is provided. To find
-                all supported connection string formats, see here:
-                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
-            load_args: Provided to underlying pandas ``read_sql_table``
-                function along with the connection string.
-                To find all supported arguments, see here:
-                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_table.html
-                To find all supported connection string formats, see here:
-                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
-            save_args: Provided to underlying pandas ``to_sql`` function along
-                with the connection string.
-                To find all supported arguments, see here:
-                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_sql.html
-                To find all supported connection string formats, see here:
-                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
-                It has ``index=False`` in the default parameters.
-
-        Raises:
-            DataSetError: When either ``table_name`` or ``con`` is empty.
-
-        """
-
-        if not table_name:
-            raise DataSetError("'table_name' argument cannot be empty.")
-
-        if not credentials:
-            raise DataSetError("Please configure expected credentials")
-
-        # print(self._load_args)
-
-        # Handle default load and save arguments
-        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
-        if load_args is not None:
-            self._load_args.update(load_args)
-        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
-        if save_args is not None:
-            self._save_args.update(save_args)
-
-        self._load_args["table_name"] = table_name
-        self._save_args["name"] = table_name
-
-        self._credentials = credentials["credentials"]
-
-        # self._connection_str = credentials["con"]
-        self.create_connection(self._credentials)
-
-    @classmethod
-    def create_connection(cls, credentials: dict) -> None:
-        """Given a connection string, create singleton connection
-        to be used across all instances of `SQLQueryDataSet` that
-        need to connect to the same source.
-        connection_params = {
-            "account": "",
-            "user": "",
-            "password": "",
-            "role": "",
-            "warehouse": "",
-            "database": "",
-            "schema": ""
-        }
-        """
-        if credentials["account"] in cls.sessions:
-            return
-        try:
-            session = Session.builder.configs(credentials).create()
-        except ImportError as import_error:
-            raise _get_missing_module_error(import_error) from import_error
-        except Exception as exception:
-            raise exception
-
-        cls.sessions[credentials["account"]] = session
-        print("session")
-        print(session)
-        print("connection successful")
-
-    def _describe(self) -> Dict[str, Any]:
-        load_args = copy.deepcopy(self._load_args)
-        save_args = copy.deepcopy(self._save_args)
-        del load_args["table_name"]
-        del save_args["name"]
-        return dict(
-            table_name=self._load_args["table_name"],
-            load_args=load_args,
-            save_args=save_args,
-        )
-
-    def _load(self) -> pd.DataFrame:
-        pass
-
-    def _save(self, data: pd.DataFrame) -> None:
-        # pd df to snowpark df
-        session = self.sessions[self._credentials["account"]]  # type: ignore
-        sp_df = session.create_dataframe(data)
-        table_name = [
-            self._credentials.get("database"),
-            self._credentials.get("schema"),
-            self._load_args["table_name"],
-        ]
-        sp_df.write.mode("overwrite").save_as_table(table_name, table_type="")
-
-    def _exists(self) -> bool:
-        session = self.sessions[self._credentials["account"]]  # type: ignore
-        schema = self._load_args.get("schema", None)
-        exists = self._load_args["table_name"] in session.table_names(schema)
-        return exists

From 864360938be53946683c39117b610525a09e366a Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Tue, 15 Nov 2022 11:43:58 -0600
Subject: [PATCH 4/8] Update __init__.py

---
 kedro/extras/datasets/pandas/__init__.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/kedro/extras/datasets/pandas/__init__.py b/kedro/extras/datasets/pandas/__init__.py
index b485099c3f..b84015d1d9 100644
--- a/kedro/extras/datasets/pandas/__init__.py
+++ b/kedro/extras/datasets/pandas/__init__.py
@@ -13,7 +13,6 @@
     "SQLTableDataSet",
     "XMLDataSet",
     "GenericDataSet",
-    "SnowflakeTableDataSet",
 ]
 
 from contextlib import suppress
@@ -38,5 +37,3 @@ from .xml_dataset import XMLDataSet
 with suppress(ImportError):
     from .generic_dataset import GenericDataSet
-with suppress(ImportError):
-    from .snowflake_dataset import SnowflakeTableDataSet

From d4976f7b7dfaba6b37544c9c3f468effb2657068 Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Tue, 15 Nov 2022 11:45:00 -0600
Subject: [PATCH 5/8] Create __init__.py

---
 kedro/extras/datasets/snowflake/__init__.py | 8 ++++++++
 1 file changed, 8 insertions(+)
 create mode 100644 kedro/extras/datasets/snowflake/__init__.py

diff --git a/kedro/extras/datasets/snowflake/__init__.py b/kedro/extras/datasets/snowflake/__init__.py
new file mode 100644
index 0000000000..124de7463d
--- /dev/null
+++ b/kedro/extras/datasets/snowflake/__init__.py
@@ -0,0 +1,8 @@
+"""Provides I/O modules for Snowflake."""
+
+__all__ = ["SnowflakeTableDataSet"]
+
+from contextlib import suppress
+
+with suppress(ImportError):
+    from .snowflake_dataset import SnowflakeTableDataSet
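The `with suppress(ImportError)` guard used in these `__init__.py` files is kedro's pattern for optional dependencies: the dataset class is only exported when its driver can be imported, and the package still loads otherwise. A self-contained illustration of the same idiom:

```python
from contextlib import suppress

with suppress(ImportError):
    # If snowflake-snowpark-python is not installed, this import fails
    # quietly and the surrounding package still loads; only users who
    # actually reference the Snowflake datasets are affected.
    from snowflake.snowpark import Session
```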
From 7ccef62784585db893f5dd9346af04fcf1e54caf Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Tue, 15 Nov 2022 11:45:39 -0600
Subject: [PATCH 6/8] Create snowflake_dataset.py

---
 .../datasets/snowflake/snowflake_dataset.py  | 252 ++++++++++++++++++
 1 file changed, 252 insertions(+)
 create mode 100644 kedro/extras/datasets/snowflake/snowflake_dataset.py

diff --git a/kedro/extras/datasets/snowflake/snowflake_dataset.py b/kedro/extras/datasets/snowflake/snowflake_dataset.py
new file mode 100644
index 0000000000..e1f63ad4cf
--- /dev/null
+++ b/kedro/extras/datasets/snowflake/snowflake_dataset.py
@@ -0,0 +1,252 @@
+"""``SnowflakeTableDataSet`` to load and save data to a Snowflake backend."""
+
+import copy
+import re
+from typing import Any, Dict, Optional
+
+import pandas as pd
+from snowflake.snowpark import Session
+
+from kedro.io.core import AbstractDataSet, DataSetError
+
+KNOWN_PIP_INSTALL = {
+    "snowflake.snowpark": "snowflake-snowpark-python",
+}
+
+DRIVER_ERROR_MESSAGE = """
+A module/driver is missing when connecting to Snowflake
+\n\n
+"""
+
+
+def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
+    provide better guidance for the user.
+
+    Args:
+        module_import_error: Error raised while connecting to Snowflake.
+
+    Returns:
+        Instructions for installing the missing driver. ``None`` is returned
+        if the error relates to an unknown driver.
+
+    """
+
+    # module errors contain string "No module named 'module_name'"
+    # we are trying to extract module_name surrounded by quotes here
+    res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
+
+    # in case module import error does not match our expected pattern
+    # we have no recommendation
+    if not res:
+        return None
+
+    missing_module = res[0]
+
+    if KNOWN_PIP_INSTALL.get(missing_module):
+        return (
+            f"You can also try installing missing driver with\n"
+            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
+        )
+
+    return None
+
+
+def _get_missing_module_error(import_error: ImportError) -> DataSetError:
+    missing_module_instruction = _find_known_drivers(import_error)
+
+    if missing_module_instruction is None:
+        return DataSetError(
+            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
+        )
+
+    return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
+
+
+class SnowflakeTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
+    """``SnowflakeTableDataSet`` loads data from a Snowflake table and saves
+    a pandas dataframe to a table through Snowpark. Data is exchanged as
+    ``pandas.DataFrame`` objects, while the connection is established with a
+    Snowpark ``Session`` created from the connection parameters supplied in
+    ``credentials``.
+
+    ``SnowflakeTableDataSet`` stores the data with no index. This is designed
+    to make load and save methods symmetric.
+
+    Example adding a catalog entry with
+    `YAML API `_:
+
+    .. code-block:: yaml
+
+        >>> shuttles_table_dataset:
+        >>>   type: snowflake.SnowflakeTableDataSet
+        >>>   credentials: db_credentials
+        >>>   table_name: shuttles
+        >>>   load_args:
+        >>>     schema: dwschema
+        >>>   save_args:
+        >>>     schema: dwschema
+        >>>     mode: overwrite
+
+    Sample database credentials entry in ``credentials.yml``:
+
+    .. code-block:: yaml
+
+        >>> db_creds:
+        >>>   account: ""
+        >>>   user: ""
+        >>>   password: ""
+        >>>   role: ""
+        >>>   warehouse: ""
+        >>>   database: ""
+        >>>   schema: ""
+
+    Example using Python API:
+    ::
+
+        >>> from kedro.extras.datasets.snowflake import SnowflakeTableDataSet
+        >>> import pandas as pd
+        >>>
+        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5],
+        >>>                      "col3": [5, 6]})
+        >>> table_name = "table_a"
+        >>> credentials = {
+        >>>     "account": "",
+        >>>     "user": "",
+        >>>     "password": "",
+        >>>     "role": "",
+        >>>     "warehouse": "",
+        >>>     "database": "",
+        >>>     "schema": ""
+        >>> }
+        >>> data_set = SnowflakeTableDataSet(table_name=table_name,
+        >>>                                  credentials=credentials)
+        >>>
+        >>> data_set.save(data)
+        >>> reloaded = data_set.load()
+        >>>
+        >>> assert data.equals(reloaded)
+
+    """
+
+    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
+
+    def __init__(
+        self,
+        table_name: str,
+        credentials: Dict[str, Any],
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new ``SnowflakeTableDataSet``.
+
+        Args:
+            table_name: The table name to load or save data to. It
+                overwrites ``name`` in ``save_args`` and the ``table_name``
+                parameter in ``load_args``.
+            credentials: A dictionary with the Snowflake connection
+                parameters used to create the Snowpark session:
+                ``account``, ``user`` and ``password``, plus the optional
+                ``role``, ``warehouse``, ``database`` and ``schema``.
+            load_args: Load options; ``table_name`` is injected from the
+                ``table_name`` argument and ``schema`` is used when checking
+                whether the table exists.
+            save_args: Options forwarded to the Snowpark ``DataFrameWriter``
+                when saving the table: ``mode``, ``column_order`` and
+                ``table_type``. It has ``index=False`` in the default
+                parameters.
+
+        Raises:
+            DataSetError: When either ``table_name`` or ``credentials``
+                is empty.
+
+        """
+
+        if not table_name:
+            raise DataSetError("'table_name' argument cannot be empty.")
+
+        if not credentials:
+            raise DataSetError("Please configure expected credentials")
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        self._load_args["table_name"] = table_name
+        self._save_args["name"] = table_name
+
+        self._credentials = credentials
+
+        self._session = self._get_session(self._credentials)
+
+    @classmethod
+    def _get_session(cls, credentials: dict) -> None:
+        """Given a dictionary of connection parameters, create a
+        Snowpark session used by this dataset instance.
+        connection_params = {
+            "account": "",
+            "user": "",
+            "password": "",
+            "role": "",
+            "warehouse": "",
+            "database": "",
+            "schema": ""
+        }
+        """
+        try:
+            session = Session.builder.configs(credentials).create()
+        except ImportError as import_error:
+            raise _get_missing_module_error(import_error) from import_error
+        return session
+
+    def _describe(self) -> Dict[str, Any]:
+        load_args = copy.deepcopy(self._load_args)
+        save_args = copy.deepcopy(self._save_args)
+        return dict(
+            table_name=self._load_args["table_name"],
+            load_args=load_args,
+            save_args=save_args,
+        )
+
+    def _load(self) -> pd.DataFrame:
+        sp_df = self._session.table(self._load_args["table_name"])
+        return sp_df.to_pandas()
+
+    def _save(self, data: pd.DataFrame) -> None:
+        # convert the pandas dataframe to a Snowpark dataframe before writing
+        sp_df = self._session.create_dataframe(data)
+        table_name = [
+            self._credentials.get("database"),
+            self._credentials.get("schema"),
+            self._save_args["name"],
+        ]
+        sp_df.write.mode(self._save_args.get("mode", "overwrite")).save_as_table(
+            table_name,
+            column_order=self._save_args.get("column_order", "index"),
+            table_type=self._save_args.get("table_type", ""),
+        )
+
+    def _exists(self) -> bool:
+        schema = self._load_args.get("schema", None)
+        return self._load_args["table_name"] in self._session.table_names(schema)
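A PR adding a dataset will also need tests that avoid a live Snowflake connection. One way to get there, sketched with `unittest.mock` (the test name and credential values are hypothetical; assumes `snowflake-snowpark-python` is installed so the module imports):

```python
from unittest import mock

from kedro.extras.datasets.snowflake import SnowflakeTableDataSet


def test_describe_reports_table_name():
    # Stub out session creation so no network connection is attempted.
    with mock.patch.object(
        SnowflakeTableDataSet, "_get_session", return_value=mock.Mock()
    ):
        data_set = SnowflakeTableDataSet(
            table_name="weather",
            credentials={"account": "dummy", "user": "user", "password": "pass"},
        )
    assert data_set._describe()["table_name"] == "weather"
```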
From 6d22856c196de7b36ea18b0cd66a38ecc1d55b1a Mon Sep 17 00:00:00 2001
From: heber-urdaneta <98349957+heber-urdaneta@users.noreply.github.com>
Date: Tue, 15 Nov 2022 11:47:10 -0600
Subject: [PATCH 7/8] Updated pyarrow dependency

---
 test_requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test_requirements.txt b/test_requirements.txt
index 111da4c8c8..665156be42 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -37,7 +37,7 @@ Pillow~=9.0
 plotly>=4.8.0, <6.0
 pre-commit>=2.9.2, <3.0 # The hook `mypy` requires pre-commit version 2.9.2.
 psutil==5.8.0
-pyarrow>=1.0, <7.0
+pyarrow>=1.0, <9.0
 pylint>=2.5.2, <3.0
 pyproj~=3.0
 pyspark>=2.2, <4.0
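The pyarrow bump accompanies the Snowpark dependency because Snowpark's `DataFrame.to_pandas` (the call behind the dataset's `_load`) relies on pyarrow for the conversion. A sketch of the round trip the datasets build on (placeholder connection values; real credentials are needed for this to actually run):

```python
import pandas as pd
from snowflake.snowpark import Session

# Placeholder values; a real account, user and password are required.
connection_params = {"account": "...", "user": "...", "password": "..."}
session = Session.builder.configs(connection_params).create()

local_df = pd.DataFrame({"col1": [1, 2], "col2": [4, 5]})

# pandas -> Snowpark: what _save does before writing the table
sp_df = session.create_dataframe(local_df)
sp_df.write.mode("overwrite").save_as_table("DEMO_TABLE")

# Snowpark -> pandas: what _load does; this conversion needs pyarrow
reloaded = session.table("DEMO_TABLE").to_pandas()
```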
From da3d463c98a9a1df45bf63e61d25f73ca6994276 Mon Sep 17 00:00:00 2001
From: Vladimir Filimonov
Date: Wed, 16 Nov 2022 15:51:55 +0100
Subject: [PATCH 8/8] Bumped python version to 3.8 as required by snowpark

Updated snowpark test prerequisites
Draft implementation of SnowParkDataSet class
---
 docs/source/get_started/prerequisites.md     |   4 +-
 .../datasets/snowflake/snowflake_dataset.py  |   2 +-
 .../datasets/snowflake/snowpark_dataset.py   | 221 ++++++++++++++++++
 test_requirements.txt                        |   2 +-
 4 files changed, 225 insertions(+), 4 deletions(-)
 create mode 100644 kedro/extras/datasets/snowflake/snowpark_dataset.py

diff --git a/docs/source/get_started/prerequisites.md b/docs/source/get_started/prerequisites.md
index 1f5126d5b9..8c6e6ebb70 100644
--- a/docs/source/get_started/prerequisites.md
+++ b/docs/source/get_started/prerequisites.md
@@ -25,10 +25,10 @@ Depending on your preferred Python installation, you can create virtual environm
 Create a new Python virtual environment, called `kedro-environment`, using `conda`:
 
 ```bash
-conda create --name kedro-environment python=3.7 -y
+conda create --name kedro-environment python=3.8 -y
 ```
 
-This will create an isolated Python 3.7 environment. To activate it:
+This will create an isolated Python 3.8 environment. To activate it:
 
 ```bash
 conda activate kedro-environment

diff --git a/kedro/extras/datasets/snowflake/snowflake_dataset.py b/kedro/extras/datasets/snowflake/snowflake_dataset.py
index e1f63ad4cf..bd6cc3cecb 100644
--- a/kedro/extras/datasets/snowflake/snowflake_dataset.py
+++ b/kedro/extras/datasets/snowflake/snowflake_dataset.py
@@ -196,7 +196,7 @@ def __init__(
         self._session = self._get_session(self._credentials)
 
     @classmethod
-    def _get_session(cls, credentials: dict) -> None:
+    def _get_session(cls, credentials: dict) -> Session:
         """Given a dictionary of connection parameters, create a
         Snowpark session used by this dataset instance.
         connection_params = {

diff --git a/kedro/extras/datasets/snowflake/snowpark_dataset.py b/kedro/extras/datasets/snowflake/snowpark_dataset.py
new file mode 100644
index 0000000000..b968b6d5e9
--- /dev/null
+++ b/kedro/extras/datasets/snowflake/snowpark_dataset.py
@@ -0,0 +1,221 @@
+"""``AbstractDataSet`` implementation to access Snowflake using Snowpark dataframes
+"""
+from copy import deepcopy
+from re import findall
+from typing import Any, Dict, Optional, Union
+
+import pandas as pd
+import snowflake.snowpark as sp
+
+from kedro.io.core import AbstractDataSet, DataSetError
+
+KNOWN_PIP_INSTALL = {
+    "snowflake.snowpark": "snowflake-snowpark-python",
+}
+
+DRIVER_ERROR_MESSAGE = """
+A module/driver is missing when connecting to Snowflake
+\n\n
+"""
+
+
+def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
+    provide better guidance for the user.
+
+    Args:
+        module_import_error: Error raised while connecting to Snowflake.
+
+    Returns:
+        Instructions for installing the missing driver. ``None`` is returned
+        if the error relates to an unknown driver.
+
+    """
+
+    # module errors contain string "No module named 'module_name'"
+    # we are trying to extract module_name surrounded by quotes here
+    res = findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
+
+    # in case module import error does not match our expected pattern
+    # we have no recommendation
+    if not res:
+        return None
+
+    missing_module = res[0]
+
+    if KNOWN_PIP_INSTALL.get(missing_module):
+        return (
+            f"You can also try installing missing driver with\n"
+            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
+        )
+
+    return None
+
+
+def _get_missing_module_error(import_error: ImportError) -> DataSetError:
+    missing_module_instruction = _find_known_drivers(import_error)
+
+    if missing_module_instruction is None:
+        return DataSetError(
+            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
+        )
+
+    return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
+
+
+# TODO: Update docstring after interface finalised
+# TODO: Add to docs example of using API to add dataset
+class SnowParkDataSet(
+    AbstractDataSet[pd.DataFrame, pd.DataFrame]
+):  # pylint: disable=too-many-instance-attributes
+    """``SnowParkDataSet`` loads and saves Snowpark dataframes.
+
+    Example adding a catalog entry with
+    `YAML API `_:
+
+    .. code-block:: yaml
+
+        >>> weather:
+        >>>   type: snowflake.SnowParkDataSet
+        >>>   table_name: weather_data
+        >>>   warehouse: warehouse_warehouse
+        >>>   database: meteorology
+        >>>   schema: observations
+        >>>   credentials: db_credentials
+        >>>   load_args (WIP):
+        >>>     Do we need any?
+        >>>   save_args:
+        >>>     mode: overwrite
+    """
+
+    # this dataset cannot be used with ``ParallelRunner``,
+    # therefore it has the attribute ``_SINGLE_PROCESS = True``
+    # for parallelism within a pipeline please consider
+    # ``ThreadRunner`` instead
+    _SINGLE_PROCESS = True
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        table_name: str,
+        warehouse: str,
+        database: str,
+        schema: str,
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+        credentials: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new instance of ``SnowParkDataSet``.
+
+        Args:
+            table_name: The table name to load or save data to.
+            warehouse: Name of the Snowflake warehouse to use.
+            database: Name of the database the table belongs to.
+            schema: Name of the schema the table belongs to.
+            load_args: Currently not used; see the WIP note in the class
+                docstring above.
+            save_args: Whatever is supported by the snowpark writer
+                https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
+            credentials: Snowflake connection parameters passed to
+                ``Session.builder.configs``.
+        """
+
+        if not table_name:
+            raise DataSetError("'table_name' argument cannot be empty.")
+
+        # TODO: Check if we can use default warehouse when user is not providing one explicitly
+        if not warehouse:
+            raise DataSetError("'warehouse' argument cannot be empty.")
+
+        if not database:
+            raise DataSetError("'database' argument cannot be empty.")
+
+        if not schema:
+            raise DataSetError("'schema' argument cannot be empty.")
+
+        if not credentials:
+            raise DataSetError("Please configure expected credentials")
+
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        self._credentials = credentials
+
+        self._session = self._get_session(self._credentials)
+        self._table_name = table_name
+        self._warehouse = warehouse
+        self._database = database
+        self._schema = schema
+
+    def _describe(self) -> Dict[str, Any]:
+        return dict(
+            table_name=self._table_name,
+            warehouse=self._warehouse,
+            database=self._database,
+            schema=self._schema,
+        )
+
+    # TODO: Do we want to make it static method?
+    @classmethod
+    def _get_session(cls, credentials: dict) -> sp.Session:
+        """Given a dictionary of connection parameters, create a
+        Snowpark session for this dataset instance.
+        connection_params = {
+            "account": "",
+            "user": "",
+            "password": "",
+            "role": "", (optional)
+            "warehouse": "", (optional)
+            "database": "", (optional)
+            "schema": "" (optional)
+        }
+        """
+        try:
+            session = sp.Session.builder.configs(credentials).create()
+        except ImportError as import_error:
+            raise _get_missing_module_error(import_error) from import_error
+        return session
+
+    def _load(self) -> sp.DataFrame:
+        table_name = [
+            self._database,
+            self._schema,
+            self._table_name,
+        ]
+
+        sp_df = self._session.table(".".join(table_name))
+        return sp_df
+
+    def _save(self, data: Union[pd.DataFrame, sp.DataFrame]) -> None:
+        if not isinstance(data, sp.DataFrame):
+            sp_df = self._session.create_dataframe(data)
+        else:
+            sp_df = data
+
+        table_name = [
+            self._database,
+            self._schema,
+            self._table_name,
+        ]
+
+        sp_df.write.mode(self._save_args.get("mode", "overwrite")).save_as_table(
+            table_name,
+            column_order=self._save_args.get("column_order", "index"),
+            table_type=self._save_args.get("table_type", ""),
+            statement_params=self._save_args.get("statement_params"),
+        )
+
+    def _exists(self) -> bool:
+        session = self._session
+        schema = self._schema
+        exists = self._table_name in session.table_names(schema)
+        return exists

diff --git a/test_requirements.txt b/test_requirements.txt
index 665156be42..808dcbc80f 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -49,8 +49,8 @@ redis~=4.1
 requests-mock~=1.6
 requests~=2.20
 s3fs>=0.3.0, <0.5 # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem.
+snowflake-snowpark-python~=1.0.0
 SQLAlchemy~=1.2
-snowflake-snowpark-python~=0.12.0
 tables~=3.6.0; platform_system == "Windows" and python_version<'3.9'
 tables~=3.6; platform_system != "Windows"
 tensorflow~=2.0
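To close, a sketch of how a node could consume the draft `SnowParkDataSet` once a catalog entry like the one in its docstring is in place. Because `_load` returns a `snowflake.snowpark.DataFrame`, filtering can be pushed down to Snowflake before anything is materialised locally (node, dataset and column names below are hypothetical):

```python
from kedro.pipeline import node, pipeline
from snowflake.snowpark import DataFrame as SnowparkDataFrame
from snowflake.snowpark.functions import col


def keep_recent_observations(weather: SnowparkDataFrame) -> SnowparkDataFrame:
    # The filter executes inside Snowflake, not in local memory.
    return weather.filter(col("OBSERVATION_YEAR") >= 2020)


# "weather" and "recent_weather" would both be SnowParkDataSet catalog entries.
data_processing = pipeline(
    [
        node(
            func=keep_recent_observations,
            inputs="weather",
            outputs="recent_weather",
            name="filter_weather_node",
        )
    ]
)
```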