diff --git a/Makefile b/Makefile
index 0b6bd723b..86daa6313 100644
--- a/Makefile
+++ b/Makefile
@@ -56,3 +56,7 @@ test-no-spark:
 
 test-no-spark-sequential:
 	cd kedro-datasets && pytest tests --no-cov --ignore tests/spark
+
+# kedro-datasets/snowflake tests are skipped from the default scope
+test-snowflake-only:
+	cd kedro-datasets && pytest tests --no-cov --numprocesses 1 --dist loadfile -m snowflake
diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 36b2d6c12..3e108e7f4 100644
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -9,6 +9,7 @@
 | Type | Description | Location |
 | ------------------------------------ | -------------------------------------------------------------------------- | ----------------------------- |
 | `polars.CSVDataSet` | A `CSVDataSet` backed by [polars](https://www.pola.rs/), a lighting fast dataframe package built entirely using Rust. | `kedro_datasets.polars` |
+| `snowflake.SnowparkTableDataSet` | Work with [Snowpark](https://www.snowflake.com/en/data-cloud/snowpark/) DataFrames from tables in Snowflake. | `kedro_datasets.snowflake` |
 
 ## Bug fixes and other changes
 * Add `mssql` backend to the `SQLQueryDataSet` DataSet using `pyodbc` library.
diff --git a/kedro-datasets/kedro_datasets/snowflake/__init__.py b/kedro-datasets/kedro_datasets/snowflake/__init__.py
new file mode 100644
index 000000000..fdcd16af2
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/snowflake/__init__.py
@@ -0,0 +1,8 @@
+"""Provides I/O modules for Snowflake."""
+
+__all__ = ["SnowparkTableDataSet"]
+
+from contextlib import suppress
+
+with suppress(ImportError):
+    from .snowpark_dataset import SnowparkTableDataSet
diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
new file mode 100644
index 000000000..e0ea1c1db
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
@@ -0,0 +1,232 @@
+"""``AbstractDataSet`` implementation to access Snowflake using Snowpark dataframes.
+"""
+import logging
+from copy import deepcopy
+from typing import Any, Dict
+
+import snowflake.snowpark as sp
+from kedro.io.core import AbstractDataSet, DataSetError
+
+logger = logging.getLogger(__name__)
+
+
+class SnowparkTableDataSet(AbstractDataSet):
+    """``SnowparkTableDataSet`` loads and saves Snowpark dataframes.
+
+    As of Mar-2023, the snowpark connector only works with Python 3.8.
+
+    Example usage for the
+    `YAML API `_:
+
+    .. code-block:: yaml
+
+        weather:
+          type: kedro_datasets.snowflake.SnowparkTableDataSet
+          table_name: "weather_data"
+          database: "meteorology"
+          schema: "observations"
+          credentials: db_credentials
+          save_args:
+            mode: overwrite
+            column_order: name
+            table_type: ''
+
+    You can skip everything but "table_name" if the database and
+    schema are provided via credentials. That way catalog entries can be shorter
+    if, for example, all used Snowflake tables live in the same database/schema.
+    Values in the dataset definition take priority over those defined in credentials.
+
+    Example:
+    The credentials file provides all connection attributes; the "weather" catalog
+    entry reuses the credentials parameters, while the "polygons" catalog entry
+    reuses all credentials parameters except that it provides a different schema name.
+    The second credentials file example uses ``externalbrowser`` authentication.
+
+    catalog.yml
+
+    .. code-block:: yaml
+
+        weather:
+          type: kedro_datasets.snowflake.SnowparkTableDataSet
+          table_name: "weather_data"
+          database: "meteorology"
+          schema: "observations"
+          credentials: snowflake_client
+          save_args:
+            mode: overwrite
+            column_order: name
+            table_type: ''
+
+        polygons:
+          type: kedro_datasets.snowflake.SnowparkTableDataSet
+          table_name: "geopolygons"
+          credentials: snowflake_client
+          schema: "geodata"
+
+    credentials.yml
+
+    .. code-block:: yaml
+
+        snowflake_client:
+          account: 'ab12345.eu-central-1'
+          port: 443
+          warehouse: "datascience_wh"
+          database: "detailed_data"
+          schema: "observations"
+          user: "service_account_abc"
+          password: "supersecret"
+
+    credentials.yml (with externalbrowser authenticator)
+
+    .. code-block:: yaml
+
+        snowflake_client:
+          account: 'ab12345.eu-central-1'
+          port: 443
+          warehouse: "datascience_wh"
+          database: "detailed_data"
+          schema: "observations"
+          user: "john_doe@wdomain.com"
+          authenticator: "externalbrowser"
+    """
+
+    # this dataset cannot be used with ``ParallelRunner``,
+    # therefore it has the attribute ``_SINGLE_PROCESS = True``
+    # for parallelism within a pipeline please consider
+    # ``ThreadRunner`` instead
+    _SINGLE_PROCESS = True
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        table_name: str,
+        schema: str = None,
+        database: str = None,
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+        credentials: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new instance of ``SnowparkTableDataSet``.
+
+        Args:
+            table_name: The table name to load or save data to.
+            schema: Name of the schema where ``table_name`` is.
+                Optional, as it can be provided as part of the ``credentials``
+                dictionary. The argument value takes priority over the one
+                provided in ``credentials``, if any.
+            database: Name of the database where ``schema`` is.
+                Optional, as it can be provided as part of the ``credentials``
+                dictionary. The argument value takes priority over the one
+                provided in ``credentials``, if any.
+            load_args: Currently not used.
+            save_args: Passed on to the underlying Snowpark ``save_as_table``.
+                To find all supported arguments, see here:
+                https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
+            credentials: A dictionary with Snowpark connection parameters.
+                To find all supported arguments, see here:
+                https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect
+        """
+
+        if not table_name:
+            raise DataSetError("'table_name' argument cannot be empty.")
+
+        if not credentials:
+            raise DataSetError("'credentials' argument cannot be empty.")
+
+        if not database:
+            if not ("database" in credentials and credentials["database"]):
+                raise DataSetError(
+                    "'database' must be provided by credentials or dataset."
+                )
+            database = credentials["database"]
+
+        if not schema:
+            if not ("schema" in credentials and credentials["schema"]):
+                raise DataSetError(
+                    "'schema' must be provided by credentials or dataset."
+                )
+            schema = credentials["schema"]
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        self._table_name = table_name
+        self._database = database
+        self._schema = schema
+
+        connection_parameters = credentials
+        connection_parameters.update(
+            {"database": self._database, "schema": self._schema}
+        )
+        self._connection_parameters = connection_parameters
+        self._session = self._get_session(self._connection_parameters)
+
+    def _describe(self) -> Dict[str, Any]:
+        return {
+            "table_name": self._table_name,
+            "database": self._database,
+            "schema": self._schema,
+        }
+
+    @staticmethod
+    def _get_session(connection_parameters) -> sp.Session:
+        """Given connection parameters, create a singleton connection
+        to be used across all instances of ``SnowparkTableDataSet`` that
+        need to connect to the same source.
+        ``connection_parameters`` is a dictionary of any values
+        supported by the snowflake python connector:
+        https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect
+        example:
+        connection_parameters = {
+            "account": "",
+            "user": "",
+            "password": "", (optional)
+            "role": "", (optional)
+            "warehouse": "", (optional)
+            "database": "", (optional)
+            "schema": "", (optional)
+            "authenticator": "" (optional)
+        }
+        """
+        try:
+            logger.debug("Trying to reuse active snowpark session...")
+            session = sp.context.get_active_session()
+        except sp.exceptions.SnowparkSessionException:
+            logger.debug("No active snowpark session found. Creating")
+            session = sp.Session.builder.configs(connection_parameters).create()
+        return session
+
+    def _load(self) -> sp.DataFrame:
+        table_name = [
+            self._database,
+            self._schema,
+            self._table_name,
+        ]
+
+        sp_df = self._session.table(".".join(table_name))
+        return sp_df
+
+    def _save(self, data: sp.DataFrame) -> None:
+        table_name = [
+            self._database,
+            self._schema,
+            self._table_name,
+        ]
+
+        data.write.save_as_table(table_name, **self._save_args)
+
+    def _exists(self) -> bool:
+        session = self._session
+        query = "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES \
+            WHERE TABLE_SCHEMA = '{schema}' \
+            AND TABLE_NAME = '{table_name}'"
+        rows = session.sql(
+            query.format(
+                database=self._database,
+                schema=self._schema,
+                table_name=self._table_name,
+            )
+        ).collect()
+        return rows[0][0] == 1
diff --git a/kedro-datasets/kedro_datasets/video/video_dataset.py b/kedro-datasets/kedro_datasets/video/video_dataset.py
index 07f0e1c8f..03311146d 100644
--- a/kedro-datasets/kedro_datasets/video/video_dataset.py
+++ b/kedro-datasets/kedro_datasets/video/video_dataset.py
@@ -258,7 +258,6 @@ class VideoDataSet(AbstractDataSet[AbstractVideo, AbstractVideo]):
 
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         filepath: str,
diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml
index a32898cf6..6df7bd372 100644
--- a/kedro-datasets/pyproject.toml
+++ b/kedro-datasets/pyproject.toml
@@ -34,7 +34,7 @@ min-public-methods = 1
 [tool.coverage.report]
 fail_under = 100
 show_missing = true
-omit = ["tests/*", "kedro_datasets/holoviews/*"]
+omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*"]
 exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
 
 [tool.pytest.ini_options]
diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py
index 44bb97185..26e583574 100644
--- a/kedro-datasets/setup.py
+++ b/kedro-datasets/setup.py
@@ -82,6 +82,9 @@ def _collect_requirements(requires):
     "spark.SparkJDBCDataSet": [SPARK, HDFS, S3FS],
     "spark.DeltaTableDataSet": [SPARK, HDFS, S3FS, "delta-spark~=1.0"],
 }
+snowpark_require = {
+    "snowflake.SnowparkTableDataSet": ["snowflake-snowpark-python~=1.0.0", "pyarrow~=8.0"]
+}
 svmlight_require = {"svmlight.SVMLightDataSet": ["scikit-learn~=1.0.2", "scipy~=1.7.3"]}
 tensorflow_required = {
     "tensorflow.TensorflowModelDataset": [
@@ -136,6 +139,7 @@ def _collect_requirements(requires):
     **video_require,
     **plotly_require,
     **spark_require,
+    **snowpark_require,
     **svmlight_require,
     **tensorflow_required,
     **yaml_require,
diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt
index c6fc16a3e..48c3b511b 100644
--- a/kedro-datasets/test_requirements.txt
+++ b/kedro-datasets/test_requirements.txt
@@ -37,7 +37,7 @@ plotly>=4.8.0, <6.0
 polars~=0.15.13
 pre-commit>=2.9.2, <3.0 # The hook `mypy` requires pre-commit version 2.9.2.
 psutil==5.8.0
-pyarrow>=1.0, <7.0
+pyarrow~=8.0
 pylint>=2.5.2, <3.0
 pyodbc~=4.0.35
 pyproj~=3.0
@@ -52,6 +52,7 @@ requests~=2.20
 s3fs>=0.3.0, <0.5 # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem.
 scikit-learn~=1.0.2
 scipy~=1.7.3
+snowflake-snowpark-python~=1.0.0; python_version == '3.8'
 SQLAlchemy~=1.2
 tables~=3.6.0; platform_system == "Windows" and python_version<'3.9'
 tables~=3.6; platform_system != "Windows"
diff --git a/kedro-datasets/tests/snowflake/README.md b/kedro-datasets/tests/snowflake/README.md
new file mode 100644
index 000000000..69fde3fd9
--- /dev/null
+++ b/kedro-datasets/tests/snowflake/README.md
@@ -0,0 +1,34 @@
+# Snowpark connector testing
+
+Running the automated tests for the Snowpark connector requires access to a real Snowflake instance. Therefore, the tests in this folder are **disabled** by default and excluded from the pytest execution scope via [conftest.py](conftest.py).
+
+The [Makefile](/Makefile) provides a separate target, ``test-snowflake-only``, to run only the tests related to the Snowpark connector. To run the tests, you need to provide Snowflake connection parameters via environment variables:
+* SNOWSQL_ACCOUNT - Snowflake account name with region, e.g. `ab12345.eu-central-2`
+* SNOWSQL_WAREHOUSE - Snowflake virtual warehouse to use
+* SNOWSQL_DATABASE - Database to use
+* SNOWSQL_SCHEMA - Schema to use when creating tables for tests
+* SNOWSQL_ROLE - Role to use for the connection
+* SNOWSQL_USER - Username to use for the connection
+* SNOWSQL_PWD - Plain password to use for the connection
+
+All environment variables must be provided for the tests to run.
+
+Here is an example shell command to run the Snowpark tests via the make utility:
+```bash
+SNOWSQL_ACCOUNT='ab12345.eu-central-2' SNOWSQL_WAREHOUSE='DEV_WH' SNOWSQL_DATABASE='DEV_DB' SNOWSQL_ROLE='DEV_ROLE' SNOWSQL_USER='DEV_USER' SNOWSQL_SCHEMA='DATA' SNOWSQL_PWD='supersecret' make test-snowflake-only
+```
+
+Currently, the tests support only simple username & password authentication, not SSO/MFA.
+
+As of Mar-2023, the snowpark connector only works with Python 3.8.
+
+## Snowflake permissions required
+Credentials provided via environment variables should have the following permissions granted for the tests to run successfully:
+* Create tables in a given schema
+* Drop tables in a given schema
+* Insert rows into tables in a given schema
+* Query tables in a given schema
+* Query `INFORMATION_SCHEMA.TABLES` of the respective database
+
+## Extending tests
+Contributors adding new tests should add the `@pytest.mark.snowflake` decorator to each test. The exclusion of Snowpark-related tests from the overall execution scope in [conftest.py](conftest.py) is based on these markers.
diff --git a/kedro-datasets/tests/snowflake/__init__.py b/kedro-datasets/tests/snowflake/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kedro-datasets/tests/snowflake/conftest.py b/kedro-datasets/tests/snowflake/conftest.py
new file mode 100644
index 000000000..f6188da76
--- /dev/null
+++ b/kedro-datasets/tests/snowflake/conftest.py
@@ -0,0 +1,24 @@
+"""
+We disable execution of tests that require a real Snowflake instance
+to run by default. Explicitly providing the -m snowflake option to
+pytest will make these, and only these, tests run
+"""
+import pytest
+
+
+def pytest_collection_modifyitems(config, items):
+    markers_arg = config.getoption("-m")
+
+    # Naive implementation to handle basic marker expressions
+    # Will not work if someone (ever) runs pytest with complex marker
+    # expressions like "-m spark and not (snowflake or pandas)"
+    if (
+        "snowflake" in markers_arg.lower()
+        and "not snowflake" not in markers_arg.lower()
+    ):
+        return
+
+    skip_snowflake = pytest.mark.skip(reason="need -m snowflake option to run")
+    for item in items:
+        if "snowflake" in item.keywords:
+            item.add_marker(skip_snowflake)
diff --git a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py
new file mode 100644
index 000000000..2133953b5
--- /dev/null
+++ b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py
@@ -0,0 +1,166 @@
+import datetime
+import os
+
+import pytest
+from kedro.io import DataSetError
+
+try:
+    import snowflake.snowpark as sp
+
+    from kedro_datasets.snowflake import SnowparkTableDataSet as spds
+except ImportError:
+    pass  # this is only so that test discovery succeeds on Python != 3.8
+
+
+def get_connection():
+    account = os.getenv("SNOWSQL_ACCOUNT")
+    warehouse = os.getenv("SNOWSQL_WAREHOUSE")
+    database = os.getenv("SNOWSQL_DATABASE")
+    role = os.getenv("SNOWSQL_ROLE")
+    user = os.getenv("SNOWSQL_USER")
+    schema = os.getenv("SNOWSQL_SCHEMA")
+    password = os.getenv("SNOWSQL_PWD")
+
+    if not (
+        account and warehouse and database and role and user and schema and password
+    ):
+        raise DataSetError(
+            "Snowflake connection environment variables were not provided in full"
+        )
+
+    conn = {
+        "account": account,
+        "warehouse": warehouse,
+        "database": database,
+        "role": role,
+        "user": user,
+        "schema": schema,
+        "password": password,
+    }
+    return conn
+
+
+def sf_setup_db(sf_session):
+    # For table exists test
+    run_query(sf_session, 'CREATE TABLE KEDRO_PYTEST_TESTEXISTS ("name" VARCHAR)')
+
+    # For load test
+    query = 'CREATE TABLE KEDRO_PYTEST_TESTLOAD ("name" VARCHAR\
+            , "age" INTEGER\
+            , "bday" date\
+            , "height" float\
+            , "insert_dttm" timestamp)'
+    run_query(sf_session, query)
+
+    query = "INSERT INTO KEDRO_PYTEST_TESTLOAD VALUES ('John'\
+            , 23\
+            , to_date('1999-12-02','YYYY-MM-DD')\
+            , 6.5\
+            , to_timestamp_ntz('2022-12-02 13:20:01',\
+            'YYYY-MM-DD hh24:mi:ss'))"
+    run_query(sf_session, query)
+
+    query = "INSERT INTO KEDRO_PYTEST_TESTLOAD VALUES ('Jane'\
+            , 41\
+            , to_date('1981-01-03','YYYY-MM-DD')\
+            , 5.7\
+            , to_timestamp_ntz('2022-12-02 13:21:11',\
+            'YYYY-MM-DD hh24:mi:ss'))"
+    run_query(sf_session, query)
+
+
+def sf_db_cleanup(sf_session):
+    run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTEXISTS")
+    run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTLOAD")
+    run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTSAVE")
+
+
+def run_query(session, query):
+    df = session.sql(query)
+    df.collect()
+    return df
+
+
+def df_equals_ignore_dtype(df1, df2):
+    # Pytest will show the respective stdout only if the test fails,
+    # which helps to debug right away what exactly did not match
+
+    c1 = df1.to_pandas().values.tolist()
+    c2 = df2.to_pandas().values.tolist()
+
+    print(c1)
+    print("--- comparing to ---")
+    print(c2)
+
+    for i, row in enumerate(c1):
+        for j, column in enumerate(row):
+            if not column == c2[i][j]:
+                print(f"{column} not equal to {c2[i][j]}")
+                return False
+    return True
+
+
+@pytest.fixture
+def sample_sp_df(sf_session):
+    return sf_session.create_dataframe(
+        [
+            [
+                "John",
+                23,
+                datetime.date(1999, 12, 2),
+                6.5,
+                datetime.datetime(2022, 12, 2, 13, 20, 1),
+            ],
+            [
+                "Jane",
+                41,
+                datetime.date(1981, 1, 3),
+                5.7,
+                datetime.datetime(2022, 12, 2, 13, 21, 11),
+            ],
+        ],
+        schema=["name", "age", "bday", "height", "insert_dttm"],
+    )
+
+
+@pytest.fixture
+def sf_session():
+    sf_session = sp.Session.builder.configs(get_connection()).create()
+
+    # Running cleanup in case a previous run was interrupted w/o proper cleanup
+    sf_db_cleanup(sf_session)
+    sf_setup_db(sf_session)
+
+    yield sf_session
+    sf_db_cleanup(sf_session)
+    sf_session.close()
+
+
+class TestSnowparkTableDataSet:
+    @pytest.mark.snowflake
+    def test_save(self, sample_sp_df, sf_session):
+        sp_df = spds(table_name="KEDRO_PYTEST_TESTSAVE", credentials=get_connection())
+        sp_df._save(sample_sp_df)
+        sp_df_saved = sf_session.table("KEDRO_PYTEST_TESTSAVE")
+        assert sp_df_saved.count() == 2
+
+    @pytest.mark.snowflake
+    def test_load(self, sample_sp_df, sf_session):
+        print(sf_session)
+        sp_df = spds(
+            table_name="KEDRO_PYTEST_TESTLOAD", credentials=get_connection()
+        )._load()
+
+        # Ignoring dtypes as e.g. age can be int8 vs int64 and pandas.compare
+        # fails on that
+        assert df_equals_ignore_dtype(sample_sp_df, sp_df) is True
+
+    @pytest.mark.snowflake
+    def test_exists(self, sf_session):
+        print(sf_session)
+        df_e = spds(table_name="KEDRO_PYTEST_TESTEXISTS", credentials=get_connection())
+        df_ne = spds(
+            table_name="KEDRO_PYTEST_TESTNEXISTS", credentials=get_connection()
+        )
+        assert df_e._exists() is True
+        assert df_ne._exists() is False