Remove pd interactions and add docs #2

Merged: 2 commits, Feb 24, 2023
35 changes: 22 additions & 13 deletions kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
@@ -2,16 +2,15 @@
 """
 import logging
 from copy import deepcopy
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
-import pandas as pd
 import snowflake.snowpark as sp
 from kedro.io.core import AbstractDataSet, DataSetError
 
 logger = logging.getLogger(__name__)
 
 
-class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
+class SnowparkTableDataSet(AbstractDataSet[sp.DataFrame, sp.DataFrame]):
     """``SnowparkTableDataSet`` loads and saves Snowpark dataframes.
 
     Example usage for the
@@ -38,14 +37,18 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
     Example:
     Credentials file provides all connection attributes, catalog entry
     "weather" reuse credentials parameters, "polygons" catalog entry reuse
-    all credentials parameters except providing different schema name
+    all credentials parameters except providing different schema name.
+    Second example of credentials file uses externalbrowser authentication
 
     catalog.yml
 
     .. code-block:: yaml
         weather:
          type: kedro_datasets.snowflake.SnowparkTableDataSet
+         table_name: "weather_data"
+         database: "meteorology"
+         schema: "observations"
          credentials: snowflake_client
          save_args:
            mode: overwrite
            column_order: name
@@ -54,6 +57,7 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
        polygons:
          type: kedro_datasets.snowflake.SnowparkTableDataSet
          table_name: "geopolygons"
+         credentials: snowflake_client
          schema: "geodata"
 
     credentials.yml
@@ -68,6 +72,18 @@ class SnowparkTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]):
          user: "service_account_abc"
          password: "supersecret"
 
+    credentials.yml (with externalbrowser authenticator)
+
+    .. code-block:: yaml
+        snowflake_client:
+          account: 'ab12345.eu-central-1'
+          port: 443
+          warehouse: "datascience_wh"
+          database: "detailed_data"
+          schema: "observations"
+          user: "john_doe@wdomain.com"
+          authenticator: "externalbrowser"
+
     As of Jan-2023, the snowpark connector only works with Python 3.8
     """

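Note: the catalog and credentials snippets above map one-to-one onto the dataset's constructor arguments. A minimal sketch of the equivalent in-code construction, assuming the keyword names match the YAML keys shown (the values are the illustrative ones from the docstring, not real credentials):

    from kedro_datasets.snowflake import SnowparkTableDataSet

    # The "weather" catalog entry, built directly in Python (sketch).
    weather = SnowparkTableDataSet(
        table_name="weather_data",
        database="meteorology",
        schema="observations",
        credentials={
            "account": "ab12345.eu-central-1",
            "warehouse": "datascience_wh",
            "user": "service_account_abc",
            "password": "supersecret",
        },
        save_args={"mode": "overwrite", "column_order": "name"},
    )
    sp_df = weather.load()  # after this PR, returns a snowflake.snowpark.DataFrame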
@@ -177,10 +193,8 @@ def _get_session(connection_parameters) -> sp.Session:
         """
         try:
             logger.debug("Trying to reuse active snowpark session...")
-            # if hook is implemented, get active session
             session = sp.context.get_active_session()
         except sp.exceptions.SnowparkSessionException:
-            # create session if there is no active one
             logger.debug("No active snowpark session found. Creating")
             session = sp.Session.builder.configs(connection_parameters).create()
         return session
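The fallback above is the usual Snowpark reuse-or-create pattern: grab an active session if one exists (for example one opened by a Kedro hook), otherwise build one from the connection parameters. A standalone sketch using the illustrative values from the docstring:

    import snowflake.snowpark as sp

    connection_parameters = {
        "account": "ab12345.eu-central-1",
        "warehouse": "datascience_wh",
        "database": "detailed_data",
        "schema": "observations",
        "user": "john_doe@wdomain.com",
        "authenticator": "externalbrowser",
    }

    try:
        # Reuse a session somebody else (e.g. a hook) already opened.
        session = sp.context.get_active_session()
    except sp.exceptions.SnowparkSessionException:
        # No active session found, so create one from the parameters.
        session = sp.Session.builder.configs(connection_parameters).create()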
@@ -195,19 +209,14 @@ def _load(self) -> sp.DataFrame:
         sp_df = self._session.table(".".join(table_name))
         return sp_df
 
-    def _save(self, data: Union[pd.DataFrame, sp.DataFrame]) -> None:
-        if not isinstance(data, sp.DataFrame):
-            sp_df = self._session.create_dataframe(data)
-        else:
-            sp_df = data
-
+    def _save(self, data: sp.DataFrame) -> None:
         table_name = [
             self._database,
             self._schema,
             self._table_name,
         ]
 
-        sp_df.write.save_as_table(table_name, **self._save_args)
+        data.write.save_as_table(table_name, **self._save_args)
 
     def _exists(self) -> bool:
         session = self._session
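With the pandas branch gone, `_save` accepts only `sp.DataFrame`, so a caller holding a pandas object must convert explicitly. A sketch of the caller's side, assuming an existing Snowpark `session` and a `dataset` instance of `SnowparkTableDataSet`:

    import pandas as pd

    pandas_df = pd.DataFrame({"name": ["John", "Jane"], "age": [23, 41]})

    # The conversion that _save used to do implicitly now happens here.
    sp_df = session.create_dataframe(pandas_df)
    dataset.save(sp_df)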
50 changes: 28 additions & 22 deletions kedro-datasets/tests/snowflake/test_snowpark_dataset.py
@@ -1,7 +1,6 @@
 import datetime
 import os
 
-import pandas as pd
 import pytest
 from kedro.io import DataSetError
 
@@ -82,14 +81,16 @@ def run_query(session, query):
     return df
 
 
-def pandas_equals_ignore_dtype(df1, df2):
+def df_equals_ignore_dtype(df1, df2):
     # Pytest will show respective stdout only if test fails
     # this will help to debug what was exactly not matching right away
-    print(df1)
+
+    c1 = df1.to_pandas().values.tolist()
+    c2 = df2.to_pandas().values.tolist()
+
+    print(c1)
     print("--- comparing to ---")
-    print(df2)
-    c1 = df1.values.tolist()
-    c2 = df2.values.tolist()
+    print(c2)
 
     for i, row in enumerate(c1):
         for j, column in enumerate(row):
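The renamed helper now takes two Snowpark DataFrames and materialises both with `.to_pandas()` before the elementwise comparison (the loop body continues below the fold of this hunk). The intended call pattern, with an illustrative expected dataframe:

    expected = sf_session.create_dataframe(
        [["John", 23], ["Jane", 41]], schema=["name", "age"]
    )
    loaded = sf_session.table("KEDRO_PYTEST_TESTLOAD")

    assert df_equals_ignore_dtype(expected, loaded)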
@@ -100,19 +101,25 @@ def pandas_equals_ignore_dtype(df1, df2):
 
 
 @pytest.fixture
-def sample_pandas_df() -> pd.DataFrame:
-    return pd.DataFrame(
-        {
-            "name": ["John", "Jane"],
-            "age": [23, 41],
-            "bday": [datetime.date(1999, 12, 2), datetime.date(1981, 1, 3)],
-            "height": [6.5, 5.7],
-            "insert_dttm": [
+def sample_sp_df(sf_session) -> sp.DataFrame:
+    return sf_session.create_dataframe(
+        [
+            [
+                "John",
+                23,
+                datetime.date(1999, 12, 2),
+                6.5,
                 datetime.datetime(2022, 12, 2, 13, 20, 1),
+            ],
+            [
+                "Jane",
+                41,
+                datetime.date(1981, 1, 3),
+                5.7,
                 datetime.datetime(2022, 12, 2, 13, 21, 11),
             ],
-        },
-        columns=["name", "age", "bday", "height", "insert_dttm"],
+        ],
+        schema=["name", "age", "bday", "height", "insert_dttm"],
     )
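The rewritten fixture leans on `Session.create_dataframe` accepting a list of rows plus a flat list of column names as `schema`; nothing touches Snowflake until the dataframe is evaluated. The same API in isolation:

    # Two rows, two named columns; requires a live Snowpark session.
    df = sf_session.create_dataframe(
        [[1, "x"], [2, "y"]],
        schema=["id", "label"],
    )
    df.show()  # evaluates the dataframe and prints both rows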


@@ -131,23 +138,22 @@ def sf_session():
 
 class TestSnowparkTableDataSet:
     @pytest.mark.snowflake
-    def test_save(self, sample_pandas_df, sf_session):
+    def test_save(self, sample_sp_df, sf_session):
         sp_df = spds(table_name="KEDRO_PYTEST_TESTSAVE", credentials=get_connection())
-        sp_df._save(sample_pandas_df)
+        sp_df._save(sample_sp_df)
         sp_df_saved = sf_session.table("KEDRO_PYTEST_TESTSAVE")
         assert sp_df_saved.count() == 2
 
     @pytest.mark.snowflake
-    def test_load(self, sample_pandas_df, sf_session):
+    def test_load(self, sample_sp_df, sf_session):
         print(sf_session)
-        df_sf = spds(
+        sp_df = spds(
             table_name="KEDRO_PYTEST_TESTLOAD", credentials=get_connection()
         )._load()
-        sf = df_sf.to_pandas()
 
         # Ignoring dtypes as ex. age can be int8 vs int64 and pandas.compare
         # fails on that
-        assert pandas_equals_ignore_dtype(sample_pandas_df, sf) is True
+        assert df_equals_ignore_dtype(sample_sp_df, sp_df) is True
 
     @pytest.mark.snowflake
     def test_exists(self, sf_session):
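These tests are gated behind the custom `snowflake` pytest marker, so they only run when selected explicitly against a live account. A sketch of how such a marker is typically registered (an assumption; the project's actual conftest is not part of this diff):

    # conftest.py (hypothetical)
    def pytest_configure(config):
        # Register the marker so `pytest -m snowflake` selects these tests
        # and unmarked runs don't warn about an unknown marker.
        config.addinivalue_line(
            "markers",
            "snowflake: integration tests that need a live Snowflake account",
        )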