[SPARK-50039][CONNECT][PYTHON] API compatibility check for Grouping
### What changes were proposed in this pull request?

This PR proposes to add an API compatibility check for the Spark SQL Grouping functions.
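The check follows the same pattern as the existing compatibility tests in `test_connect_compatibility.py`: gather the public members of the Classic and Connect `GroupedData` classes and compare their presence and signatures. A minimal sketch of that idea, purely for illustration (the helper name `compare_grouping_api` and its return shape are not the actual `check_compatibility` used by the test suite):

```python
import inspect


def compare_grouping_api(classic_cls, connect_cls):
    """Illustrative only: report members that differ between the two classes."""
    classic = {n: m for n, m in inspect.getmembers(classic_cls, callable) if not n.startswith("_")}
    connect = {n: m for n, m in inspect.getmembers(connect_cls, callable) if not n.startswith("_")}

    missing_in_connect = classic.keys() - connect.keys()
    missing_in_classic = connect.keys() - classic.keys()
    signature_mismatches = {
        name
        for name in classic.keys() & connect.keys()
        if inspect.signature(classic[name]) != inspect.signature(connect[name])
    }
    return missing_in_connect, missing_in_classic, signature_mismatches
```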

### Why are the changes needed?

To guarantee the same behavior between Spark Classic and Spark Connect.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48560 from itholic/compat_grouping.

Authored-by: Haejoon Lee <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
itholic authored and HyukjinKwon committed Oct 20, 2024
1 parent ae75cac commit 76ea894
Showing 4 changed files with 38 additions and 20 deletions.
10 changes: 5 additions & 5 deletions python/pyspark/sql/connect/group.py
@@ -193,29 +193,29 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
session=self._df._session,
)

- def min(self, *cols: str) -> "DataFrame":
+ def min(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("min", list(cols))

min.__doc__ = PySparkGroupedData.min.__doc__

- def max(self, *cols: str) -> "DataFrame":
+ def max(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("max", list(cols))

max.__doc__ = PySparkGroupedData.max.__doc__

- def sum(self, *cols: str) -> "DataFrame":
+ def sum(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("sum", list(cols))

sum.__doc__ = PySparkGroupedData.sum.__doc__

- def avg(self, *cols: str) -> "DataFrame":
+ def avg(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("avg", list(cols))

avg.__doc__ = PySparkGroupedData.avg.__doc__

mean = avg

- def count(self) -> "DataFrame":
+ def count(self: "GroupedData") -> "DataFrame":
return self.agg(F._invoke_function("count", F.lit(1)).alias("count"))

count.__doc__ = PySparkGroupedData.count.__doc__
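Adding the explicit `self: "GroupedData"` annotation is what lets these Connect methods line up with their Classic counterparts, whose decorator-generated `_api` wrappers (see `python/pyspark/sql/group.py` below) already annotate `self`. A toy illustration of why that matters, assuming the compatibility check compares `inspect.signature` on both sides:

```python
import inspect

def classic_min(self: "GroupedData", *cols: str) -> "DataFrame": ...
def connect_min(self: "GroupedData", *cols: str) -> "DataFrame": ...
def connect_min_before(self, *cols: str) -> "DataFrame": ...  # no annotation on self

assert inspect.signature(classic_min) == inspect.signature(connect_min)
assert inspect.signature(classic_min) != inspect.signature(connect_min_before)
```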
18 changes: 9 additions & 9 deletions python/pyspark/sql/group.py
@@ -32,7 +32,7 @@


def dfapi(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
def _api(self: "GroupedData") -> DataFrame:
def _api(self: "GroupedData") -> "DataFrame":
name = f.__name__
jdf = getattr(self._jgd, name)()
return DataFrame(jdf, self.session)
@@ -43,7 +43,7 @@ def _api(self: "GroupedData") -> DataFrame:


def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
def _api(self: "GroupedData", *cols: str) -> DataFrame:
def _api(self: "GroupedData", *cols: str) -> "DataFrame":
from pyspark.sql.classic.column import _to_seq

name = f.__name__
@@ -80,14 +80,14 @@ def __repr__(self) -> str:
return super().__repr__()

@overload
- def agg(self, *exprs: Column) -> DataFrame:
+ def agg(self, *exprs: Column) -> "DataFrame":
...

@overload
- def agg(self, __exprs: Dict[str, str]) -> DataFrame:
+ def agg(self, __exprs: Dict[str, str]) -> "DataFrame":
...

- def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
+ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
"""Compute aggregates and returns the result as a :class:`DataFrame`.
The available aggregate functions can be:
@@ -190,7 +190,7 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
return DataFrame(jdf, self.session)

@dfapi
- def count(self) -> DataFrame: # type: ignore[empty-body]
+ def count(self) -> "DataFrame": # type: ignore[empty-body]
"""Counts the number of records for each group.
.. versionadded:: 1.3.0
@@ -241,7 +241,7 @@ def mean(self, *cols: str) -> DataFrame: # type: ignore[empty-body]
"""

@df_varargs_api
- def avg(self, *cols: str) -> DataFrame: # type: ignore[empty-body]
+ def avg(self, *cols: str) -> "DataFrame": # type: ignore[empty-body]
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
@@ -292,7 +292,7 @@ def avg(self, *cols: str) -> DataFrame: # type: ignore[empty-body]
"""

@df_varargs_api
- def max(self, *cols: str) -> DataFrame: # type: ignore[empty-body]
+ def max(self, *cols: str) -> "DataFrame": # type: ignore[empty-body]
"""Computes the max value for each numeric columns for each group.
.. versionadded:: 1.3.0
@@ -336,7 +336,7 @@ def max(self, *cols: str) -> DataFrame: # type: ignore[empty-body]
"""

@df_varargs_api
- def min(self, *cols: str) -> DataFrame: # type: ignore[empty-body]
+ def min(self, *cols: str) -> "DataFrame": # type: ignore[empty-body]
"""Computes the min value for each numeric column for each group.
.. versionadded:: 1.3.0
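Quoting the return type turns it into a forward reference, so `inspect.signature` reports the string `"DataFrame"` rather than the class object; Classic and Connect each have their own `DataFrame` class, so only the string form can match on both sides. A self-contained illustration of the difference (assuming, as above, that annotations are compared as part of the signature):

```python
import inspect

class DataFrame: ...

def unquoted(self) -> DataFrame: ...    # annotation is the class object itself
def quoted(self) -> "DataFrame": ...    # annotation stays the string 'DataFrame'

assert inspect.signature(unquoted).return_annotation is DataFrame
assert inspect.signature(quoted).return_annotation == "DataFrame"
```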
12 changes: 6 additions & 6 deletions python/pyspark/sql/pandas/group_ops.py
@@ -49,7 +49,7 @@ class PandasGroupedOpsMixin:
can use this class.
"""

- def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> DataFrame:
+ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
"""
It is an alias of :meth:`pyspark.sql.GroupedData.applyInPandas`; however, it takes a
:meth:`pyspark.sql.functions.pandas_udf` whereas
@@ -121,8 +121,8 @@ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> DataFrame:
return self.applyInPandas(udf.func, schema=udf.returnType) # type: ignore[attr-defined]

def applyInPandas(
self, func: "PandasGroupedMapFunction", schema: Union[StructType, str]
) -> DataFrame:
self, func: "PandasGroupedMapFunction", schema: Union["StructType", str]
) -> "DataFrame":
"""
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
as a `DataFrame`.
@@ -246,7 +246,7 @@ def applyInPandasWithState(
stateStructType: Union[StructType, str],
outputMode: str,
timeoutConf: str,
- ) -> DataFrame:
+ ) -> "DataFrame":
"""
Applies the given function to each group of data, while maintaining a user-defined
per-group state. The result Dataset will represent the flattened record returned by the
@@ -684,8 +684,8 @@ def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
self._gd2 = gd2

def applyInPandas(
self, func: "PandasCogroupedMapFunction", schema: Union[StructType, str]
) -> DataFrame:
self, func: "PandasCogroupedMapFunction", schema: Union["StructType", str]
) -> "DataFrame":
"""
Applies a function to each cogroup using pandas and returns the result
as a `DataFrame`.
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests/test_connect_compatibility.py
@@ -31,6 +31,7 @@
from pyspark.sql.window import Window as ClassicWindow
from pyspark.sql.window import WindowSpec as ClassicWindowSpec
import pyspark.sql.functions as ClassicFunctions
+ from pyspark.sql.group import GroupedData as ClassicGroupedData

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
@@ -43,6 +44,7 @@
from pyspark.sql.connect.window import Window as ConnectWindow
from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec
import pyspark.sql.connect.functions as ConnectFunctions
+ from pyspark.sql.connect.group import GroupedData as ConnectGroupedData


class ConnectCompatibilityTestsMixin:
@@ -357,6 +359,22 @@ def test_functions_compatibility(self):
expected_missing_classic_methods,
)

+ def test_grouping_compatibility(self):
+     """Test Grouping compatibility between classic and connect."""
+     expected_missing_connect_properties = set()
+     expected_missing_classic_properties = set()
+     expected_missing_connect_methods = {"transformWithStateInPandas"}
+     expected_missing_classic_methods = set()
+     self.check_compatibility(
+         ClassicGroupedData,
+         ConnectGroupedData,
+         "Grouping",
+         expected_missing_connect_properties,
+         expected_missing_classic_properties,
+         expected_missing_connect_methods,
+         expected_missing_classic_methods,
+     )


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
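The only allowlisted gap is `transformWithStateInPandas`, which the test records as present on the Classic `GroupedData` but missing from the Connect one at the time of this change; any other divergence in properties, methods, or signatures fails the test. For everything the check covers, the same grouping code is expected to run unchanged against either backend. A hedged usage sketch, assuming an active `SparkSession` named `spark` (Classic or Connect):

```python
# The grouping methods verified above are intended to behave identically
# on Spark Classic and Spark Connect.
df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])

df.groupBy("key").count().show()                 # GroupedData.count
df.groupBy("key").avg("value").show()            # GroupedData.avg (alias: mean)
df.groupBy("key").agg({"value": "max"}).show()   # GroupedData.agg with a dict of column -> aggregate
```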
