Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1690717 Applying Snowpark Python function (sin) #2415

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
- Added support for `errors="ignore"` in `pd.to_datetime`.
- Added support for `DataFrame.tz_localize` and `Series.tz_localize`.
- Added support for `DataFrame.tz_convert` and `Series.tz_convert`.
- Added support for applying Snowpark Python functions (e.g., `sin`) in `Series.map`, `Series.apply`, `DataFrame.apply` and `DataFrame.applymap`.

#### Improvements

Expand Down
21 changes: 20 additions & 1 deletion src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import inspect
import json
import sys
from collections import namedtuple
Expand All @@ -14,10 +15,11 @@
from pandas._typing import AggFuncType
from pandas.api.types import is_scalar

from snowflake.snowpark import functions
from snowflake.snowpark._internal.type_utils import PYTHON_TO_SNOW_TYPE_MAPPINGS
from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints
from snowflake.snowpark.column import Column as SnowparkColumn
from snowflake.snowpark.functions import builtin, col, dense_rank, udf, udtf
from snowflake.snowpark.functions import builtin, col, dense_rank, sin, udf, udtf
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import (
OrderedDataFrame,
Expand All @@ -29,6 +31,7 @@
parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label,
parse_snowflake_object_construct_identifier_to_map,
)
from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -58,6 +61,10 @@
# https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
cloudpickle.register_pickle_by_value(sys.modules[__name__])

SUPPORTED_SNOWPARK_PYTHON_FUNCTIONS_IN_APPLY = {
sin,
}


class GroupbyApplySortMethod(Enum):
"""
Expand Down Expand Up @@ -1356,3 +1363,15 @@ def groupby_apply_sort_method(
else GroupbyApplySortMethod.GROUP_KEY_APPEARANCE_ORDER
)
)


def is_supported_snowpark_python_function(func: AggFuncType) -> bool:
"""Return True if the `func` is a supported Snowpark Python function."""
func_module = inspect.getmodule(func)
if functions != func_module:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem like checking the module is necessary if you are checking the function reference as well, but it probably doesn't hurt. Were you concerned about something in particular here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just use this to report NotImplementedError for those not in the supported list.

return False
if func not in SUPPORTED_SNOWPARK_PYTHON_FUNCTIONS_IN_APPLY:
ErrorMessage.not_implemented(
f"Snowpark Python function `{func.__name__}` is not supported yet."
)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
groupby_apply_create_internal_frame_from_final_ordered_dataframe,
groupby_apply_pivot_result_to_final_ordered_dataframe,
groupby_apply_sort_method,
is_supported_snowpark_python_function,
sort_apply_udtf_result_columns_by_pandas_positions,
)
from snowflake.snowpark.modin.plugin._internal.binary_op_utils import (
Expand Down Expand Up @@ -8292,6 +8293,21 @@ def apply(
"Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'"
)

if is_supported_snowpark_python_function(func):
sfc-gh-joshi marked this conversation as resolved.
Show resolved Hide resolved
if axis != 0:
ErrorMessage.not_implemented(
f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with axis = {axis}."
)
if raw is not False:
ErrorMessage.not_implemented(
f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with raw = {raw}."
)
if args:
ErrorMessage.not_implemented(
f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with args = '{args}'."
)
return self._apply_snowpark_python_function_to_columns(func)

if axis == 0:
frame = self._modin_frame

Expand Down Expand Up @@ -8546,6 +8562,19 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no
func, raw, result_type, args, column_index, input_types, **kwargs
)

def _apply_snowpark_python_function_to_columns(
self,
snowpark_function: Callable,
) -> "SnowflakeQueryCompiler":
"""Apply Snowpark Python function to columns."""

def sf_function(col: SnowparkColumn) -> SnowparkColumn:
return snowpark_function(col)

return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(sf_function)
)

def applymap(
self,
func: AggFuncType,
Expand All @@ -8566,6 +8595,16 @@ def applymap(
"""
self._raise_not_implemented_error_for_timedelta()

if is_supported_snowpark_python_function(func):
if na_action:
ErrorMessage.not_implemented(
f"Snowpark pandas applymap API doesn't yet support Snowpark Python function `{func.__name__}` with na_action == '{na_action}'"
)
if args:
ErrorMessage.not_implemented(
f"Snowpark pandas applymap API doesn't yet support Snowpark Python function `{func.__name__}` with args = '{args}'."
)
return self._apply_snowpark_python_function_to_columns(func)
# Currently, NULL values are always passed into the udtf even if strict=True,
# which is a bug on the server side SNOW-880105.
# The fix will not land soon, so we are going to raise not implemented error for now.
Expand Down
51 changes: 51 additions & 0 deletions tests/integ/modin/test_apply_snowpark_python_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import math

import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest

from tests.integ.modin.utils import assert_frame_equal, assert_series_equal
from tests.integ.utils.sql_counter import sql_count_checker


@sql_count_checker(query_count=4)
def test_apply_sin():
from snowflake.snowpark.functions import sin

native_s = native_pd.Series([0.00, -1.23, 10, math.pi, math.pi / 2])
s = pd.Series(native_s)

assert_series_equal(s.apply(sin), native_s.apply(math.sin))
assert_series_equal(s.map(sin), native_s.map(math.sin))
assert_frame_equal(
s.to_frame().applymap(sin), native_s.to_frame().applymap(math.sin)
)
assert_frame_equal(
s.to_frame().apply(sin),
native_s.to_frame().apply(np.sin), # Note math.sin does not work with df.apply
)


@sql_count_checker(query_count=0)
def test_apply_snowpark_python_function_not_implemented():
sfc-gh-joshi marked this conversation as resolved.
Show resolved Hide resolved
from snowflake.snowpark.functions import cos, sin

with pytest.raises(NotImplementedError):
pd.Series([1, 2, 3]).apply(cos)
with pytest.raises(NotImplementedError):
pd.Series([1, 2, 3]).to_frame().applymap(sin, na_action="ignore")
with pytest.raises(NotImplementedError):
pd.Series([1, 2, 3]).to_frame().applymap(sin, args=[1, 2])
with pytest.raises(NotImplementedError):
pd.DataFrame({"a": [1, 2, 3]}).apply(cos)
with pytest.raises(NotImplementedError):
pd.DataFrame({"a": [1, 2, 3]}).apply(sin, raw=True)
with pytest.raises(NotImplementedError):
pd.DataFrame({"a": [1, 2, 3]}).apply(sin, axis=1)
with pytest.raises(NotImplementedError):
pd.DataFrame({"a": [1, 2, 3]}).apply(sin, args=(1, 2))
Loading