Skip to content

Commit

Permalink
[SPARK-46471][PS][TESTS][FOLLOWUPS] Move `OpsOnDiffFramesEnabledTests…
Browse files Browse the repository at this point in the history
…` to `pyspark.pandas.tests.diff_frames_ops.*``

### What changes were proposed in this pull request?
Move `OpsOnDiffFramesEnabledTests` to `pyspark.pandas.tests.diff_frames_ops.*``

### Why are the changes needed?
test code clean up

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#44471 from zhengruifeng/ps_test_diff_ops_3.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
zhengruifeng authored and dongjoon-hyun committed Dec 24, 2023
1 parent bd29569 commit ef7e1e6
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 92 deletions.
4 changes: 2 additions & 2 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,6 @@ def __hash__(self):
"pyspark.pandas.tests.groupby.test_stat_prod",
"pyspark.pandas.tests.groupby.test_value_counts",
"pyspark.pandas.tests.test_indexing",
"pyspark.pandas.tests.test_ops_on_diff_frames",
"pyspark.pandas.tests.diff_frames_ops.test_align",
"pyspark.pandas.tests.diff_frames_ops.test_arithmetic",
"pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext",
Expand All @@ -872,6 +871,7 @@ def __hash__(self):
"pyspark.pandas.tests.diff_frames_ops.test_arithmetic_chain_ext_float",
"pyspark.pandas.tests.diff_frames_ops.test_assign_frame",
"pyspark.pandas.tests.diff_frames_ops.test_assign_series",
"pyspark.pandas.tests.diff_frames_ops.test_basic",
"pyspark.pandas.tests.diff_frames_ops.test_bitwise",
"pyspark.pandas.tests.diff_frames_ops.test_combine_first",
"pyspark.pandas.tests.diff_frames_ops.test_compare_series",
Expand Down Expand Up @@ -1235,7 +1235,6 @@ def __hash__(self):
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_map",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_property",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_round",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_ext_float",
Expand All @@ -1244,6 +1243,7 @@ def __hash__(self):
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_arithmetic_chain_ext_float",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_assign_frame",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_assign_series",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_basic",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_bitwise",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_combine_first",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_compare_series",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@
#
import unittest

from pyspark.pandas.tests.test_ops_on_diff_frames import OpsOnDiffFramesEnabledTestsMixin

from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
from pyspark.pandas.tests.diff_frames_ops.test_basic import BasicMixin


class OpsOnDiffFramesEnabledParityTests(
OpsOnDiffFramesEnabledTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
class BasicParityTests(
BasicMixin,
PandasOnSparkTestUtils,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames import * # noqa: F401
from pyspark.pandas.tests.connect.diff_frames_ops.test_parity_basic import * # noqa: F401

try:
import xmlrunner # type: ignore[import]
Expand Down
45 changes: 45 additions & 0 deletions python/pyspark/pandas/tests/diff_frames_ops/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,29 @@ def pdf2(self):
index=list(range(9)),
)

@property
def pdf5(self):
return pd.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"b": [4, 5, 6, 3, 2, 1, 0, 0, 0],
"c": [4, 5, 6, 3, 2, 1, 0, 0, 0],
},
index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
).set_index(["a", "b"])

@property
def pdf6(self):
return pd.DataFrame(
{
"a": [9, 8, 7, 6, 5, 4, 3, 2, 1],
"b": [0, 0, 0, 4, 5, 6, 1, 2, 3],
"c": [9, 8, 7, 6, 5, 4, 3, 2, 1],
"e": [4, 5, 6, 3, 2, 1, 0, 0, 0],
},
index=list(range(9)),
).set_index(["a", "b"])

@property
def pser1(self):
midx = pd.MultiIndex(
Expand All @@ -130,10 +153,32 @@ def pser2(self):
)
return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], index=midx)

@property
def psdf5(self):
return ps.from_pandas(self.pdf5)

@property
def psdf6(self):
return ps.from_pandas(self.pdf6)

def test_arithmetic(self):
self._test_arithmetic_frame(self.pdf1, self.pdf2, check_extension=False)
self._test_arithmetic_series(self.pser1, self.pser2, check_extension=False)

def test_multi_index_arithmetic(self):
psdf5 = self.psdf5
psdf6 = self.psdf6
pdf5 = self.pdf5
pdf6 = self.pdf6

# Series
self.assert_eq((psdf5.c - psdf6.e).sort_index(), (pdf5.c - pdf6.e).sort_index())

self.assert_eq((psdf5["c"] / psdf6["e"]).sort_index(), (pdf5["c"] / pdf6["e"]).sort_index())

# DataFrame
self.assert_eq((psdf5 + psdf6).sort_index(), (pdf5 + pdf6).sort_index(), almost=True)


class ArithmeticTests(
ArithmeticMixin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pyspark.testing.sqlutils import SQLTestUtils


class OpsOnDiffFramesEnabledTestsMixin:
class BasicMixin:
@classmethod
def setUpClass(cls):
super().setUpClass()
Expand Down Expand Up @@ -65,53 +65,6 @@ def pdf4(self):
index=list(range(9)),
)

@property
def pdf5(self):
return pd.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"b": [4, 5, 6, 3, 2, 1, 0, 0, 0],
"c": [4, 5, 6, 3, 2, 1, 0, 0, 0],
},
index=[0, 1, 3, 5, 6, 8, 9, 10, 11],
).set_index(["a", "b"])

@property
def pdf6(self):
return pd.DataFrame(
{
"a": [9, 8, 7, 6, 5, 4, 3, 2, 1],
"b": [0, 0, 0, 4, 5, 6, 1, 2, 3],
"c": [9, 8, 7, 6, 5, 4, 3, 2, 1],
"e": [4, 5, 6, 3, 2, 1, 0, 0, 0],
},
index=list(range(9)),
).set_index(["a", "b"])

@property
def pser1(self):
midx = pd.MultiIndex(
[["lama", "cow", "falcon", "koala"], ["speed", "weight", "length", "power"]],
[[0, 3, 1, 1, 1, 2, 2, 2], [0, 2, 0, 3, 2, 0, 1, 3]],
)
return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1], index=midx)

@property
def pser2(self):
midx = pd.MultiIndex(
[["lama", "cow", "falcon"], ["speed", "weight", "length"]],
[[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]],
)
return pd.Series([-45, 200, -1.2, 30, -250, 1.5, 320, 1, -0.3], index=midx)

@property
def pser3(self):
midx = pd.MultiIndex(
[["koalas", "cow", "falcon"], ["speed", "weight", "length"]],
[[0, 0, 0, 1, 1, 1, 2, 2, 2], [1, 1, 2, 0, 0, 2, 2, 2, 1]],
)
return pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)

@property
def psdf1(self):
return ps.from_pandas(self.pdf1)
Expand All @@ -128,26 +81,6 @@ def psdf3(self):
def psdf4(self):
return ps.from_pandas(self.pdf4)

@property
def psdf5(self):
return ps.from_pandas(self.pdf5)

@property
def psdf6(self):
return ps.from_pandas(self.pdf6)

@property
def psser1(self):
return ps.from_pandas(self.pser1)

@property
def psser2(self):
return ps.from_pandas(self.pser2)

@property
def psser3(self):
return ps.from_pandas(self.pser3)

def test_ranges(self):
self.assert_eq(
(ps.range(10) + ps.range(10)).sort_index(),
Expand Down Expand Up @@ -286,29 +219,17 @@ def test_different_columns(self):

self.assert_eq((psdf1 + psdf4).sort_index(), (pdf1 + pdf4).sort_index(), almost=True)

def test_multi_index_arithmetic(self):
psdf5 = self.psdf5
psdf6 = self.psdf6
pdf5 = self.pdf5
pdf6 = self.pdf6

# Series
self.assert_eq((psdf5.c - psdf6.e).sort_index(), (pdf5.c - pdf6.e).sort_index())

self.assert_eq((psdf5["c"] / psdf6["e"]).sort_index(), (pdf5["c"] / pdf6["e"]).sort_index())

# DataFrame
self.assert_eq((psdf5 + psdf6).sort_index(), (pdf5 + pdf6).sort_index(), almost=True)


class OpsOnDiffFramesEnabledTests(
OpsOnDiffFramesEnabledTestsMixin, PandasOnSparkTestCase, SQLTestUtils
class BasicTests(
BasicMixin,
PandasOnSparkTestCase,
SQLTestUtils,
):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.test_ops_on_diff_frames import * # noqa: F401
from pyspark.pandas.tests.diff_frames_ops.test_basic import * # noqa: F401

try:
import xmlrunner
Expand Down

0 comments on commit ef7e1e6

Please sign in to comment.