diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
index 661c4d8f0..4a014b81a 100644
--- a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
+++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
@@ -2,16 +2,15 @@
 """
 import logging
 from copy import deepcopy
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
-import pandas as pd
 import snowflake.snowpark as sp
 from kedro.io.core import AbstractDataSet, DataSetError
 
 logger = logging.getLogger(__name__)
 
 
-class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
+class SnowparkTableDataSet(AbstractDataSet[sp.DataFrame, sp.DataFrame]):
     """``SnowparkTableDataSet`` loads and saves Snowpark dataframes.
 
     Example usage for the
@@ -38,7 +37,8 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
     Example:
     Credentials file provides all connection attributes, catalog entry
     "weather" reuse credentials parameters, "polygons" catalog entry reuse
-    all credentials parameters except providing different schema name
+    all credentials parameters except providing a different schema name.
+    The second credentials example below uses ``externalbrowser`` authentication.
 
     catalog.yml
 
@@ -46,6 +46,9 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
         weather:
           type: kedro_datasets.snowflake.SnowparkTableDataSet
           table_name: "weather_data"
+          database: "meteorology"
+          schema: "observations"
+          credentials: snowflake_client
           save_args:
             mode: overwrite
             column_order: name
@@ -54,6 +57,7 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
         polygons:
           type: kedro_datasets.snowflake.SnowparkTableDataSet
           table_name: "geopolygons"
+          credentials: snowflake_client
           schema: "geodata"
 
     credentials.yml
@@ -68,6 +72,18 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
             user: "service_account_abc"
             password: "supersecret"
 
+    credentials.yml (with externalbrowser authenticator)
+
+    .. code-block:: yaml
+        snowflake_client:
+          account: 'ab12345.eu-central-1'
+          port: 443
+          warehouse: "datascience_wh"
+          database: "detailed_data"
+          schema: "observations"
+          user: "john_doe@wdomain.com"
+          authenticator: "externalbrowser"
+
     As of Jan-2023, the snowpark connector only works with Python 3.8
     """
 
@@ -177,10 +193,8 @@ def _get_session(connection_parameters) -> sp.Session:
         """
         try:
             logger.debug("Trying to reuse active snowpark session...")
-            # if hook is implemented, get active session
             session = sp.context.get_active_session()
         except sp.exceptions.SnowparkSessionException:
-            # create session if there is no active one
             logger.debug("No active snowpark session found. Creating")
Creating") session = sp.Session.builder.configs(connection_parameters).create() return session @@ -195,19 +209,14 @@ def _load(self) -> sp.DataFrame: sp_df = self._session.table(".".join(table_name)) return sp_df - def _save(self, data: Union[pd.DataFrame, sp.DataFrame]) -> None: - if not isinstance(data, sp.DataFrame): - sp_df = self._session.create_dataframe(data) - else: - sp_df = data - + def _save(self, data: sp.DataFrame) -> None: table_name = [ self._database, self._schema, self._table_name, ] - sp_df.write.save_as_table(table_name, **self._save_args) + data.write.save_as_table(table_name, **self._save_args) def _exists(self) -> bool: session = self._session diff --git a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py index 527647a37..665bd5a7b 100644 --- a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py +++ b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py @@ -1,7 +1,6 @@ import datetime import os -import pandas as pd import pytest from kedro.io import DataSetError @@ -82,14 +81,16 @@ def run_query(session, query): return df -def pandas_equals_ignore_dtype(df1, df2): +def df_equals_ignore_dtype(df1, df2): # Pytest will show respective stdout only if test fails # this will help to debug what was exactly not matching right away - print(df1) + + c1 = df1.to_pandas().values.tolist() + c2 = df2.to_pandas().values.tolist() + + print(c1) print("--- comparing to ---") - print(df2) - c1 = df1.values.tolist() - c2 = df2.values.tolist() + print(c2) for i, row in enumerate(c1): for j, column in enumerate(row): @@ -100,19 +101,25 @@ def pandas_equals_ignore_dtype(df1, df2): @pytest.fixture -def sample_pandas_df() -> pd.DataFrame: - return pd.DataFrame( - { - "name": ["John", "Jane"], - "age": [23, 41], - "bday": [datetime.date(1999, 12, 2), datetime.date(1981, 1, 3)], - "height": [6.5, 5.7], - "insert_dttm": [ +def sample_sp_df(sf_session) -> sp.DataFrame: + return sf_session.create_dataframe( + [ + [ + "John", + 23, + datetime.date(1999, 12, 2), + 6.5, datetime.datetime(2022, 12, 2, 13, 20, 1), + ], + [ + "Jane", + 41, + datetime.date(1981, 1, 3), + 5.7, datetime.datetime(2022, 12, 2, 13, 21, 11), ], - }, - columns=["name", "age", "bday", "height", "insert_dttm"], + ], + schema=["name", "age", "bday", "height", "insert_dttm"], ) @@ -131,23 +138,22 @@ def sf_session(): class TestSnowparkTableDataSet: @pytest.mark.snowflake - def test_save(self, sample_pandas_df, sf_session): + def test_save(self, sample_sp_df, sf_session): sp_df = spds(table_name="KEDRO_PYTEST_TESTSAVE", credentials=get_connection()) - sp_df._save(sample_pandas_df) + sp_df._save(sample_sp_df) sp_df_saved = sf_session.table("KEDRO_PYTEST_TESTSAVE") assert sp_df_saved.count() == 2 @pytest.mark.snowflake - def test_load(self, sample_pandas_df, sf_session): + def test_load(self, sample_sp_df, sf_session): print(sf_session) - df_sf = spds( + sp_df = spds( table_name="KEDRO_PYTEST_TESTLOAD", credentials=get_connection() )._load() - sf = df_sf.to_pandas() # Ignoring dtypes as ex. age can be int8 vs int64 and pandas.compare # fails on that - assert pandas_equals_ignore_dtype(sample_pandas_df, sf) is True + assert df_equals_ignore_dtype(sample_sp_df, sp_df) is True @pytest.mark.snowflake def test_exists(self, sf_session):