From 7ee707502032beaa70b02b6ce9fcbe6441fae983 Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 8 Oct 2024 11:10:54 -0700 Subject: [PATCH] SNOW-1690717 Applying Snowpark Python function (sin) --- CHANGELOG.md | 1 + .../modin/plugin/_internal/apply_utils.py | 21 +++++++- .../compiler/snowflake_query_compiler.py | 39 ++++++++++++++ .../test_apply_snowpark_python_functions.py | 51 +++++++++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 tests/integ/modin/test_apply_snowpark_python_functions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ac6cc0fe09..8d70b117e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,7 @@ - Added support for `errors="ignore"` in `pd.to_datetime`. - Added support for `DataFrame.tz_localize` and `Series.tz_localize`. - Added support for `DataFrame.tz_convert` and `Series.tz_convert`. +- Added support for applying Snowpark Python functions (e.g., `sin`) in `Series.map`, `Series.apply`, `DataFrame.apply` and `DataFrame.applymap`. #### Improvements diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 413fdd8f02..f0511478b4 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import inspect import json import sys from collections import namedtuple @@ -14,10 +15,11 @@ from pandas._typing import AggFuncType from pandas.api.types import is_scalar +from snowflake.snowpark import functions from snowflake.snowpark._internal.type_utils import PYTHON_TO_SNOW_TYPE_MAPPINGS from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints from snowflake.snowpark.column import Column as SnowparkColumn -from snowflake.snowpark.functions import builtin, col, dense_rank, udf, udtf +from snowflake.snowpark.functions import builtin, col, dense_rank, sin, udf, udtf from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( OrderedDataFrame, @@ -29,6 +31,7 @@ parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label, parse_snowflake_object_construct_identifier_to_map, ) +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL from snowflake.snowpark.session import Session from snowflake.snowpark.types import ( @@ -58,6 +61,10 @@ # https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs cloudpickle.register_pickle_by_value(sys.modules[__name__]) +SUPPORTED_SNOWPARK_PYTHON_FUNCTIONS_IN_APPLY = { + sin, +} + class GroupbyApplySortMethod(Enum): """ @@ -1356,3 +1363,15 @@ def groupby_apply_sort_method( else GroupbyApplySortMethod.GROUP_KEY_APPEARANCE_ORDER ) ) + + +def is_supported_snowpark_python_function(func: AggFuncType) -> bool: + """Return True if the `func` is a supported Snowpark Python function.""" + func_module = inspect.getmodule(func) + if functions != func_module: + return False + if func not in SUPPORTED_SNOWPARK_PYTHON_FUNCTIONS_IN_APPLY: + ErrorMessage.not_implemented( + f"Snowpark Python function `{func.__name__}` is not supported yet." + ) + return True diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index e5b199279b..25533f8bd6 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -182,6 +182,7 @@ groupby_apply_create_internal_frame_from_final_ordered_dataframe, groupby_apply_pivot_result_to_final_ordered_dataframe, groupby_apply_sort_method, + is_supported_snowpark_python_function, sort_apply_udtf_result_columns_by_pandas_positions, ) from snowflake.snowpark.modin.plugin._internal.binary_op_utils import ( @@ -8292,6 +8293,21 @@ def apply( "Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'" ) + if is_supported_snowpark_python_function(func): + if axis != 0: + ErrorMessage.not_implemented( + f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with axis = {axis}." + ) + if raw is not False: + ErrorMessage.not_implemented( + f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with raw = {raw}." + ) + if args: + ErrorMessage.not_implemented( + f"Snowpark pandas apply API doesn't yet support Snowpark Python function `{func.__name__}` with args = '{args}'." + ) + return self._apply_snowpark_python_function_to_columns(func) + if axis == 0: frame = self._modin_frame @@ -8546,6 +8562,19 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no func, raw, result_type, args, column_index, input_types, **kwargs ) + def _apply_snowpark_python_function_to_columns( + self, + snowpark_function: Callable, + ) -> "SnowflakeQueryCompiler": + """Apply Snowpark Python function to columns.""" + + def sf_function(col: SnowparkColumn) -> SnowparkColumn: + return snowpark_function(col) + + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns(sf_function) + ) + def applymap( self, func: AggFuncType, @@ -8566,6 +8595,16 @@ def applymap( """ self._raise_not_implemented_error_for_timedelta() + if is_supported_snowpark_python_function(func): + if na_action: + ErrorMessage.not_implemented( + f"Snowpark pandas applymap API doesn't yet support Snowpark Python function `{func.__name__}` with na_action == '{na_action}'" + ) + if args: + ErrorMessage.not_implemented( + f"Snowpark pandas applymap API doesn't yet support Snowpark Python function `{func.__name__}` with args = '{args}'." + ) + return self._apply_snowpark_python_function_to_columns(func) # Currently, NULL values are always passed into the udtf even if strict=True, # which is a bug on the server side SNOW-880105. # The fix will not land soon, so we are going to raise not implemented error for now. diff --git a/tests/integ/modin/test_apply_snowpark_python_functions.py b/tests/integ/modin/test_apply_snowpark_python_functions.py new file mode 100644 index 0000000000..c40801a051 --- /dev/null +++ b/tests/integ/modin/test_apply_snowpark_python_functions.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import math + +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest + +from tests.integ.modin.utils import assert_frame_equal, assert_series_equal +from tests.integ.utils.sql_counter import sql_count_checker + + +@sql_count_checker(query_count=4) +def test_apply_sin(): + from snowflake.snowpark.functions import sin + + native_s = native_pd.Series([0.00, -1.23, 10, math.pi, math.pi / 2]) + s = pd.Series(native_s) + + assert_series_equal(s.apply(sin), native_s.apply(math.sin)) + assert_series_equal(s.map(sin), native_s.map(math.sin)) + assert_frame_equal( + s.to_frame().applymap(sin), native_s.to_frame().applymap(math.sin) + ) + assert_frame_equal( + s.to_frame().apply(sin), + native_s.to_frame().apply(np.sin), # Note math.sin does not work with df.apply + ) + + +@sql_count_checker(query_count=0) +def test_apply_snowpark_python_function_not_implemented(): + from snowflake.snowpark.functions import cos, sin + + with pytest.raises(NotImplementedError): + pd.Series([1, 2, 3]).apply(cos) + with pytest.raises(NotImplementedError): + pd.Series([1, 2, 3]).to_frame().applymap(sin, na_action="ignore") + with pytest.raises(NotImplementedError): + pd.Series([1, 2, 3]).to_frame().applymap(sin, args=[1, 2]) + with pytest.raises(NotImplementedError): + pd.DataFrame({"a": [1, 2, 3]}).apply(cos) + with pytest.raises(NotImplementedError): + pd.DataFrame({"a": [1, 2, 3]}).apply(sin, raw=True) + with pytest.raises(NotImplementedError): + pd.DataFrame({"a": [1, 2, 3]}).apply(sin, axis=1) + with pytest.raises(NotImplementedError): + pd.DataFrame({"a": [1, 2, 3]}).apply(sin, args=(1, 2))