diff --git a/docs/source/get_started/prerequisites.md b/docs/source/get_started/prerequisites.md
index 1f5126d5b9..8c6e6ebb70 100644
--- a/docs/source/get_started/prerequisites.md
+++ b/docs/source/get_started/prerequisites.md
@@ -25,10 +25,10 @@ Depending on your preferred Python installation, you can create virtual environm
 Create a new Python virtual environment, called `kedro-environment`, using `conda`:
 
 ```bash
-conda create --name kedro-environment python=3.7 -y
+conda create --name kedro-environment python=3.8 -y
 ```
 
-This will create an isolated Python 3.7 environment. To activate it:
+This will create an isolated Python 3.8 environment. To activate it:
 
 ```bash
 conda activate kedro-environment
diff --git a/kedro/extras/datasets/snowflake/__init__.py b/kedro/extras/datasets/snowflake/__init__.py
new file mode 100644
index 0000000000..124de7463d
--- /dev/null
+++ b/kedro/extras/datasets/snowflake/__init__.py
@@ -0,0 +1,9 @@
+"""Provides I/O modules for Snowflake."""
+
+__all__ = ["SnowflakeTableDataSet", "SnowParkDataSet"]
+
+from contextlib import suppress
+
+with suppress(ImportError):
+    from .snowflake_dataset import SnowflakeTableDataSet
+    from .snowpark_dataset import SnowParkDataSet
diff --git a/kedro/extras/datasets/snowflake/snowflake_dataset.py b/kedro/extras/datasets/snowflake/snowflake_dataset.py
new file mode 100644
index 0000000000..bd6cc3cecb
--- /dev/null
+++ b/kedro/extras/datasets/snowflake/snowflake_dataset.py
@@ -0,0 +1,252 @@
+"""``SnowflakeTableDataSet`` to load and save data to a Snowflake backend."""
+
+import copy
+import re
+from typing import Any, Dict, Optional
+
+import pandas as pd
+from snowflake.snowpark import Session
+
+from kedro.io.core import AbstractDataSet, DataSetError
+
+KNOWN_PIP_INSTALL = {
+    "snowflake.snowpark": "snowflake-snowpark-python",
+}
+
+DRIVER_ERROR_MESSAGE = """
+A module/driver is missing when connecting to Snowflake
+\n\n
+"""
+
+
+def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
+    provide better guidance for the user.
+
+    Args:
+        module_import_error: Error raised while connecting to Snowflake.
+
+    Returns:
+        Instructions for installing the missing driver. ``None`` is returned
+        if the error relates to an unknown driver.
+
+    """
+
+    # module errors contain the string "No module named 'module_name'";
+    # we are trying to extract module_name surrounded by quotes here
+    res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
+
+    # in case the module import error does not match our expected pattern,
+    # we have no recommendation
+    if not res:
+        return None
+
+    missing_module = res[0]
+
+    if KNOWN_PIP_INSTALL.get(missing_module):
+        return (
+            f"You can also try installing the missing driver with\n"
+            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
+        )
+
+    return None
+
+
+def _get_missing_module_error(import_error: ImportError) -> DataSetError:
+    missing_module_instruction = _find_known_drivers(import_error)
+
+    if missing_module_instruction is None:
+        return DataSetError(
+            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
+        )
+
+    return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
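+
+
+# Illustrative behaviour of the two helpers above (a hedged sketch in
+# comments, not executed at runtime; the error string mirrors CPython's
+# "No module named '<name>'" format):
+#
+#   >>> err = ImportError("No module named 'snowflake.snowpark'")
+#   >>> print(_find_known_drivers(err))
+#   You can also try installing the missing driver with
+#
+#   pip install snowflake-snowpark-python
+#
+#   >>> _find_known_drivers(ImportError("No module named 'pyodbc'")) is None
+#   True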
+
+
+class SnowflakeTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
+    """``SnowflakeTableDataSet`` loads data from a Snowflake table into a
+    pandas dataframe and saves a pandas dataframe to a Snowflake table. It
+    uses Snowpark underneath: loading collects a Snowpark dataframe into
+    pandas, while saving converts the pandas dataframe to a Snowpark
+    dataframe and writes it with ``save_as_table``. Instead of a SQLAlchemy
+    connection string, a dictionary of Snowpark connection parameters has to
+    be passed in ``credentials`` (see the example code snippet below).
+
+    Example adding a catalog entry with the YAML API:
+
+    .. code-block:: yaml
+
+        >>> shuttles_table_dataset:
+        >>>   type: snowflake.SnowflakeTableDataSet
+        >>>   credentials: db_credentials
+        >>>   table_name: shuttles
+        >>>   load_args:
+        >>>     schema: dwschema
+        >>>   save_args:
+        >>>     schema: dwschema
+        >>>     mode: overwrite
+
+    Sample database credentials entry in ``credentials.yml``:
+
+    .. code-block:: yaml
+
+        >>> db_credentials:
+        >>>   account: <account>
+        >>>   user: <user>
+        >>>   password: <password>
+        >>>   role: <role>
+        >>>   warehouse: <warehouse>
+        >>>   database: <database>
+        >>>   schema: <schema>
+
+    Example using Python API:
+    ::
+
+        >>> from kedro.extras.datasets.snowflake import SnowflakeTableDataSet
+        >>> import pandas as pd
+        >>>
+        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5],
+        >>>                      "col3": [5, 6]})
+        >>> table_name = "table_a"
+        >>> credentials = {
+        >>>     "account": "<account>",
+        >>>     "user": "<user>",
+        >>>     "password": "<password>",
+        >>>     "role": "<role>",
+        >>>     "warehouse": "<warehouse>",
+        >>>     "database": "<database>",
+        >>>     "schema": "<schema>"
+        >>> }
+        >>> data_set = SnowflakeTableDataSet(table_name=table_name,
+        >>>                                  credentials=credentials,
+        >>>                                  save_args={"mode": "overwrite"})
+        >>>
+        >>> data_set.save(data)
+        >>> reloaded = data_set.load()
+        >>>
+        >>> assert data.equals(reloaded)
+
+    """
+
+    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: Dict[str, Any] = {}
+
+    def __init__(
+        self,
+        table_name: str,
+        credentials: Dict[str, Any],
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new ``SnowflakeTableDataSet``.
+
+        Args:
+            table_name: The table name to load or save data to. It
+                overwrites ``name`` in ``save_args`` and the ``table_name``
+                parameter in ``load_args``.
+            credentials: A dictionary of Snowpark connection parameters
+                (``account``, ``user``, ``password`` and, optionally,
+                ``role``, ``warehouse``, ``database`` and ``schema``).
+            load_args: ``table_name`` is injected from the argument above;
+                ``schema`` is used by the existence check.
+            save_args: ``mode`` selects the Snowpark write mode (defaults to
+                ``errorifexists``); ``column_order`` and ``table_type`` are
+                passed through to ``save_as_table`` when provided. To find
+                all supported arguments, see here:
+                https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
+
+        Raises:
+            DataSetError: When either ``table_name`` or ``credentials``
+                is empty.
+        """
+
+        if not table_name:
+            raise DataSetError("'table_name' argument cannot be empty.")
+
+        if not credentials:
+            raise DataSetError("'credentials' argument cannot be empty.")
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        self._load_args["table_name"] = table_name
+        self._save_args["name"] = table_name
+
+        self._credentials = credentials
+        self._session = self._get_session(self._credentials)
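+
+    # Illustrative construction (hypothetical values): with the YAML example
+    # from the class docstring, Kedro resolves ``db_credentials`` from
+    # ``credentials.yml`` before instantiation, so the resulting call is
+    # equivalent to
+    #
+    #   SnowflakeTableDataSet(
+    #       table_name="shuttles",
+    #       credentials={"account": "<account>", "user": "<user>",
+    #                    "password": "<password>", "database": "<database>",
+    #                    "schema": "<schema>"},
+    #       load_args={"schema": "dwschema"},
+    #       save_args={"schema": "dwschema", "mode": "overwrite"},
+    #   )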
+ """ + + if not table_name: + raise DataSetError("'table_name' argument cannot be empty.") + + if not credentials: + raise DataSetError("Please configure expected credentials") + + # print(self._load_args) + + # Handle default load and save arguments + self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) + if load_args is not None: + self._load_args.update(load_args) + self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS) + if save_args is not None: + self._save_args.update(save_args) + + self._load_args["table_name"] = table_name + self._save_args["name"] = table_name + + self._credentials = credentials["credentials"] + + # self._connection_str = credentials["con"] + self._session = self._get_session(self._credentials) + + @classmethod + def _get_session(cls, credentials: dict) -> Session: + """Given a connection string, create singleton connection + to be used across all instances of `SQLQueryDataSet` that + need to connect to the same source. + connection_params = { + "account": "", + "user": "", + "password": "", + "role": "", + "warehouse": "", + "database": "", + "schema": "" + } + """ + try: + session = Session.builder.configs(credentials).create() + except ImportError as import_error: + raise _get_missing_module_error(import_error) from import_error + except Exception as exception: + raise exception + return session + + def _describe(self) -> Dict[str, Any]: + load_args = copy.deepcopy(self._load_args) + save_args = copy.deepcopy(self._save_args) + return dict( + table_name=self._load_args["table_name"], + load_args=load_args, + save_args=save_args, + ) + + def _load(self) -> pd.DataFrame: + sp_df = self._session.table(self._load_args["table_name"]) + return sp_df.to_pandas() + + def _save(self, data: pd.DataFrame) -> None: + # pd df to snowpark df + sp_df = self._session.create_dataframe(data) + table_name = [ + self._credentials.get("database"), + self._credentials.get("schema"), + self._save_args["name"], + ] + sp_df.write.mode(self._save_args["mode"]).save_as_table( + table_name, + column_order=self._save_args["column_order"], + table_type=self._save_args["table_type"], + ) + + def _exists(self) -> bool: + session = self.sessions[self._credentials["account"]] # type: ignore + schema = self._load_args.get("schema", None) + exists = self._load_args["table_name"] in session.table_names(schema) + return exists diff --git a/kedro/extras/datasets/snowflake/snowpark_dataset.py b/kedro/extras/datasets/snowflake/snowpark_dataset.py new file mode 100644 index 0000000000..b968b6d5e9 --- /dev/null +++ b/kedro/extras/datasets/snowflake/snowpark_dataset.py @@ -0,0 +1,221 @@ +"""``AbstractDataSet`` implementation to access Snowflake using Snowpark dataframes +""" +from copy import deepcopy +from re import findall +from typing import Any, Dict, Optional, Union + +import pandas as pd +import snowflake.snowpark as sp + +from kedro.io.core import AbstractDataSet, DataSetError + +KNOWN_PIP_INSTALL = { + "snowflake.snowpark": "snowflake.snowpark", +} + +DRIVER_ERROR_MESSAGE = """ +A module/driver is missing when connecting to Snowflake +\n\n +""" + + +def _find_known_drivers(module_import_error: ImportError) -> Optional[str]: + """Looks up known keywords in a ``ModuleNotFoundError`` so that it can + provide better guideline for the user. + + Args: + module_import_error: Error raised while connecting to a SQL server. + + Returns: + Instructions for installing missing driver. An empty string is + returned in case error is related to an unknown driver. 
diff --git a/kedro/extras/datasets/snowflake/snowpark_dataset.py b/kedro/extras/datasets/snowflake/snowpark_dataset.py
new file mode 100644
index 0000000000..b968b6d5e9
--- /dev/null
+++ b/kedro/extras/datasets/snowflake/snowpark_dataset.py
@@ -0,0 +1,221 @@
+"""``AbstractDataSet`` implementation to access Snowflake using Snowpark dataframes."""
+from copy import deepcopy
+from re import findall
+from typing import Any, Dict, Optional, Union
+
+import pandas as pd
+import snowflake.snowpark as sp
+
+from kedro.io.core import AbstractDataSet, DataSetError
+
+KNOWN_PIP_INSTALL = {
+    "snowflake.snowpark": "snowflake-snowpark-python",
+}
+
+DRIVER_ERROR_MESSAGE = """
+A module/driver is missing when connecting to Snowflake
+\n\n
+"""
+
+
+def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
+    provide better guidance for the user.
+
+    Args:
+        module_import_error: Error raised while connecting to Snowflake.
+
+    Returns:
+        Instructions for installing the missing driver. ``None`` is returned
+        if the error relates to an unknown driver.
+
+    """
+
+    # module errors contain the string "No module named 'module_name'";
+    # we are trying to extract module_name surrounded by quotes here
+    res = findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
+
+    # in case the module import error does not match our expected pattern,
+    # we have no recommendation
+    if not res:
+        return None
+
+    missing_module = res[0]
+
+    if KNOWN_PIP_INSTALL.get(missing_module):
+        return (
+            f"You can also try installing the missing driver with\n"
+            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
+        )
+
+    return None
+
+
+def _get_missing_module_error(import_error: ImportError) -> DataSetError:
+    missing_module_instruction = _find_known_drivers(import_error)
+
+    if missing_module_instruction is None:
+        return DataSetError(
+            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
+        )
+
+    return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
+
+
+# TODO: Update docstring after interface finalised
+# TODO: Add to docs example of using API to add dataset
+class SnowParkDataSet(
+    AbstractDataSet[Union[pd.DataFrame, sp.DataFrame], sp.DataFrame]
+):  # pylint: disable=too-many-instance-attributes
+    """``SnowParkDataSet`` loads and saves Snowpark dataframes.
+
+    Example adding a catalog entry with the YAML API:
+
+    .. code-block:: yaml
+
+        >>> weather:
+        >>>   type: snowflake.SnowParkDataSet
+        >>>   table_name: weather_data
+        >>>   warehouse: warehouse_warehouse
+        >>>   database: meteorology
+        >>>   schema: observations
+        >>>   credentials: db_credentials
+        >>>   save_args:
+        >>>     mode: overwrite
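+
+    Example using Python API (an illustrative sketch; the credential values
+    are placeholders and it mirrors the ``SnowflakeTableDataSet`` example):
+    ::
+
+        >>> from kedro.extras.datasets.snowflake import SnowParkDataSet
+        >>> import pandas as pd
+        >>>
+        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5]})
+        >>> credentials = {
+        >>>     "account": "<account>",
+        >>>     "user": "<user>",
+        >>>     "password": "<password>"
+        >>> }
+        >>> data_set = SnowParkDataSet(table_name="weather_data",
+        >>>                            warehouse="<warehouse>",
+        >>>                            database="meteorology",
+        >>>                            schema="observations",
+        >>>                            credentials=credentials,
+        >>>                            save_args={"mode": "overwrite"})
+        >>> data_set.save(data)  # pandas input is converted to Snowpark
+        >>> reloaded = data_set.load()  # returns a snowflake.snowpark.DataFrame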
+    """
+
+    # this dataset cannot be used with ``ParallelRunner``,
+    # therefore it has the attribute ``_SINGLE_PROCESS = True``;
+    # for parallelism within a pipeline please consider
+    # ``ThreadRunner`` instead
+    _SINGLE_PROCESS = True
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        table_name: str,
+        warehouse: str,
+        database: str,
+        schema: str,
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+        credentials: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new instance of ``SnowParkDataSet``.
+
+        Args:
+            table_name: Name of the table to load from / save to.
+            warehouse: Name of the Snowflake virtual warehouse to use.
+            database: Name of the database the table belongs to.
+            schema: Name of the schema the table belongs to.
+            load_args: Currently unused; reserved for future loading options.
+            save_args: Whatever is supported by the Snowpark writer's
+                ``save_as_table``; ``mode`` selects the write mode.
+                https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
+            credentials: Snowpark connection parameters, as described in
+                ``_get_session`` below.
+
+        Raises:
+            DataSetError: When any of ``table_name``, ``warehouse``,
+                ``database``, ``schema`` or ``credentials`` is empty.
+        """
+
+        if not table_name:
+            raise DataSetError("'table_name' argument cannot be empty.")
+
+        # TODO: Check if we can use default warehouse when user is not providing one explicitly
+        if not warehouse:
+            raise DataSetError("'warehouse' argument cannot be empty.")
+
+        if not database:
+            raise DataSetError("'database' argument cannot be empty.")
+
+        if not schema:
+            raise DataSetError("'schema' argument cannot be empty.")
+
+        if not credentials:
+            raise DataSetError("'credentials' argument cannot be empty.")
+
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        self._credentials = credentials
+
+        self._session = self._get_session(self._credentials)
+        self._table_name = table_name
+        self._warehouse = warehouse
+        self._database = database
+        self._schema = schema
+
+    def _describe(self) -> Dict[str, Any]:
+        return dict(
+            table_name=self._table_name,
+            warehouse=self._warehouse,
+            database=self._database,
+            schema=self._schema,
+        )
+
+    # TODO: Do we want to make it a static method?
+    @classmethod
+    def _get_session(cls, credentials: dict) -> sp.Session:
+        """Given a dictionary of connection parameters, create a Snowpark
+        session to be used by the dataset, e.g.:
+
+        connection_params = {
+            "account": "<account>",
+            "user": "<user>",
+            "password": "<password>",
+            "role": "<role>",            # optional
+            "warehouse": "<warehouse>",  # optional
+            "database": "<database>",    # optional
+            "schema": "<schema>"         # optional
+        }
+        """
+        try:
+            session = sp.Session.builder.configs(credentials).create()
+        except ImportError as import_error:
+            raise _get_missing_module_error(import_error) from import_error
+        return session
+
+    def _load(self) -> sp.DataFrame:
+        # build the fully qualified table name and return a lazy
+        # Snowpark dataframe
+        table_name = [
+            self._database,
+            self._schema,
+            self._table_name,
+        ]
+
+        sp_df = self._session.table(".".join(table_name))
+        return sp_df
+
+    def _save(self, data: Union[pd.DataFrame, sp.DataFrame]) -> None:
+        # accept both pandas and Snowpark dataframes on save
+        if not isinstance(data, sp.DataFrame):
+            sp_df = self._session.create_dataframe(data)
+        else:
+            sp_df = data
+
+        table_name = [
+            self._database,
+            self._schema,
+            self._table_name,
+        ]
+
+        # only forward writer options that were actually configured
+        optional_kwargs = {
+            key: self._save_args[key]
+            for key in ("column_order", "table_type", "statement_params")
+            if key in self._save_args
+        }
+        sp_df.write.mode(self._save_args.get("mode", "errorifexists")).save_as_table(
+            table_name, **optional_kwargs
+        )
+
+    def _exists(self) -> bool:
+        # Snowpark sessions expose no table-listing helper, so check the
+        # INFORMATION_SCHEMA metadata instead
+        query = (
+            "SELECT TABLE_NAME FROM {database}.INFORMATION_SCHEMA.TABLES "
+            "WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table_name}'"
+        )
+        rows = self._session.sql(
+            query.format(
+                database=self._database,
+                schema=self._schema,
+                table_name=self._table_name,
+            )
+        ).collect()
+        return len(rows) > 0
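+
+# The existence check above issues a plain metadata query (an illustrative
+# rendering using the hypothetical names from the docstring example):
+#
+#   SELECT TABLE_NAME FROM meteorology.INFORMATION_SCHEMA.TABLES
+#   WHERE TABLE_SCHEMA = 'observations' AND TABLE_NAME = 'weather_data'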
diff --git a/test_requirements.txt b/test_requirements.txt
index 52305e865a..808dcbc80f 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -37,7 +37,7 @@ Pillow~=9.0
 plotly>=4.8.0, <6.0
 pre-commit>=2.9.2, <3.0 # The hook `mypy` requires pre-commit version 2.9.2.
 psutil==5.8.0
-pyarrow>=1.0, <7.0
+pyarrow>=1.0, <9.0
 pylint>=2.5.2, <3.0
 pyproj~=3.0
 pyspark>=2.2, <4.0
@@ -49,6 +49,7 @@ redis~=4.1
 requests-mock~=1.6
 requests~=2.20
 s3fs>=0.3.0, <0.5 # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem.
+snowflake-snowpark-python~=1.0.0
 SQLAlchemy~=1.2
 tables~=3.6.0; platform_system == "Windows" and python_version<'3.9'
 tables~=3.6; platform_system != "Windows"