Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates related to SNOW-616002 #405

Merged
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
THIS_DIR = os.path.dirname(os.path.realpath(__file__))
SRC_DIR = os.path.join(THIS_DIR, "src")
SNOWPARK_SRC_DIR = os.path.join(SRC_DIR, "snowflake", "snowpark")
CONNECTOR_DEPENDENCY_VERSION = "2.7.4"
CONNECTOR_DEPENDENCY_VERSION = "2.7.11"

# read the version
VERSION = ()
Expand Down
20 changes: 20 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,7 @@ def write_pandas(
quote_identifiers: bool = True,
auto_create_table: bool = False,
create_temp_table: bool = False,
overwrite: bool = False,
) -> Table:
"""Writes a pandas DataFrame to a table in Snowflake and returns a
Snowpark :class:`DataFrame` object referring to the table where the
Expand All @@ -982,6 +983,10 @@ def write_pandas(
table, you can always create your own table before calling this function. For example, auto-created
tables will store :class:`list`, :class:`tuple` and :class:`dict` as strings in a VARCHAR column.
create_temp_table: The to-be-created table will be temporary if this is set to ``True``.
overwrite: Default value is ``False`` and the Pandas DataFrame data is appended to the existing table. If set to ``True`` and if auto_create_table is also set to ``True``,
then it drops the table. If set to ``True`` and if auto_create_table is set to ``False``,
then it trunctates the table. Note that in both cases (when overwrite is set to ``True``) it will replace the existing
contents of the table with that of the passed in Pandas DataFrame.

Example::

Expand All @@ -993,6 +998,20 @@ def write_pandas(
0 1 Steve
1 2 Bob

>>> pandas_df2 = pd.DataFrame([(3, "John")], columns=["id", "name"])
>>> snowpark_df2 = session.write_pandas(pandas_df, "write_pandas_table", auto_create_table=False)
iamontheinet marked this conversation as resolved.
Show resolved Hide resolved
>>> snowpark_df2.to_pandas()
id name
0 1 Steve
1 2 Bob
2 3 John

>>> pandas_df3 = pd.DataFrame([(1, "Jane")], columns=["id", "name"])
>>> snowpark_df3 = session.write_pandas(pandas_df, "write_pandas_table", auto_create_table=False, overwrite=True)
iamontheinet marked this conversation as resolved.
Show resolved Hide resolved
>>> snowpark_df3.to_pandas()
id name
0 1 Jane

Note:
Unless ``auto_create_table`` is ``True``, you must first create a table in
Snowflake that the passed in pandas DataFrame can be written to. If
Expand Down Expand Up @@ -1026,6 +1045,7 @@ def write_pandas(
quote_identifiers=quote_identifiers,
auto_create_table=auto_create_table,
create_temp_table=create_temp_table,
overwrite=overwrite,
iamontheinet marked this conversation as resolved.
Show resolved Hide resolved
)
except ProgrammingError as pe:
if pe.msg.endswith("does not exist"):
Expand Down
93 changes: 93 additions & 0 deletions tests/integ/test_pandas_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pandas import DataFrame as PandasDF
from pandas.testing import assert_frame_equal

from snowflake.connector.errors import ProgrammingError
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.exceptions import SnowparkPandasException
from tests.utils import Utils
Expand Down Expand Up @@ -48,6 +49,98 @@ def tmp_table_complex(session):
Utils.drop_table(session, table_name)


@pytest.mark.parametrize("quote_identifiers", [True, False])
@pytest.mark.parametrize("auto_create_table", [True, False])
@pytest.mark.parametrize("overwrite", [True, False])
def test_write_pandas_with_overwrite(
session,
tmp_table_basic,
quote_identifiers: bool,
auto_create_table: bool,
overwrite: bool,
):
pd1 = PandasDF(
[
(1, 4.5, "Nike"),
(2, 7.5, "Adidas"),
(3, 10.5, "Puma"),
],
columns=["id".upper(), "foot_size".upper(), "shoe_make".upper()],
)

pd2 = PandasDF(
[(1, 8.0, "Dia Dora")],
columns=["id".upper(), "foot_size".upper(), "shoe_make".upper()],
)

pd3 = PandasDF(
[(1, "dash", 1000, 32)],
columns=["id".upper(), "name".upper(), "points".upper(), "age".upper()],
)

table_name = tmp_table_basic

# Create initial table and insert 3 rows
drop_sql = f'DROP TABLE IF EXISTS "{table_name}"'
session.sql(drop_sql).collect()
df1 = session.write_pandas(
pd1, table_name, quote_identifiers=quote_identifiers, auto_create_table=True
)
results = df1.to_pandas()
assert_frame_equal(results, pd1, check_dtype=False)

# Insert 1 row
df2 = session.write_pandas(
pd2,
table_name,
quote_identifiers=quote_identifiers,
overwrite=overwrite,
auto_create_table=auto_create_table,
)
results = df2.to_pandas()
if overwrite:
# Results should match pd2
assert_frame_equal(results, pd2, check_dtype=False)
else:
# Results count should match pd1 + pd2
assert results.shape[0] == 4

# Insert 1 row with new schema
if auto_create_table:
if overwrite:
# In this case, the table is first dropped and since there's a new schema, the results should now match pd3
df3 = session.write_pandas(
pd3,
table_name,
quote_identifiers=quote_identifiers,
overwrite=overwrite,
auto_create_table=auto_create_table,
)
results = df3.to_pandas()
assert_frame_equal(results, pd3, check_dtype=False)
else:
# In this case, the table is truncated but since there's a new schema, it should fail
with pytest.raises(ProgrammingError) as ex_info:
session.write_pandas(
pd3,
table_name,
quote_identifiers=quote_identifiers,
overwrite=overwrite,
auto_create_table=auto_create_table,
)
assert "invalid identifier 'NAME'" in str(ex_info)

with pytest.raises(SnowparkPandasException) as ex_info:
session.write_pandas(pd1, "tmp_table")
assert (
'Cannot write pandas DataFrame to table "tmp_table" because it does not exist. '
"Create table before trying to write a pandas DataFrame" in str(ex_info)
)

# Drop tables that were created for this test
session.sql(drop_sql).collect()


def test_write_pandas(session, tmp_table_basic):
pd = PandasDF(
[
Expand Down