Skip to content

Commit

Permalink
Enable sort=True for Index.union, Index.difference and `Index.i…
Browse files Browse the repository at this point in the history
…ntersection` (#13497)

This PR enables `sort=True` for `union`, `difference`, and `intersection` APIs in `Index`. 

This also fixes 1 pytest failure and adds 77 pytests:
On `Index_sort_2.0`:
```
= 230 failed, 95836 passed, 2045 skipped, 768 xfailed, 308 xpassed in 438.88s (0:07:18) =
```
On `pandas_2.0_feature_branch`:
```
= 231 failed, 95767 passed, 2045 skipped, 764 xfailed, 300 xpassed in 432.59s (0:07:12) =
```

xref: pandas-dev/pandas#25151
  • Loading branch information
galipremsagar authored Jun 3, 2023
1 parent 63b8fb1 commit 139e32d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 24 deletions.
25 changes: 15 additions & 10 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def union(self, other, sort=None):
2. `self` or `other` has length 0.
* False : do not sort the result.
* True : Sort the result (which may raise TypeError).
Returns
-------
Expand Down Expand Up @@ -395,10 +396,10 @@ def union(self, other, sort=None):
if not isinstance(other, BaseIndex):
other = cudf.Index(other, name=self.name)

if sort not in {None, False}:
if sort not in {None, False, True}:
raise ValueError(
f"The 'sort' keyword only takes the values of "
f"None or False; {sort} was passed."
f"[None, False, True]; {sort} was passed."
)

if not len(other) or self.equals(other):
Expand All @@ -425,6 +426,7 @@ def intersection(self, other, sort=False):
* False : do not sort the result.
* None : sort the result, except when `self` and `other` are equal
or when the values cannot be compared.
* True : Sort the result (which may raise TypeError).
Returns
-------
Expand Down Expand Up @@ -475,10 +477,10 @@ def intersection(self, other, sort=False):
if not isinstance(other, BaseIndex):
other = cudf.Index(other, name=self.name)

if sort not in {None, False}:
if sort not in {None, False, True}:
raise ValueError(
f"The 'sort' keyword only takes the values of "
f"None or False; {sort} was passed."
f"[None, False, True]; {sort} was passed."
)

if self.equals(other):
Expand Down Expand Up @@ -768,6 +770,7 @@ def difference(self, other, sort=None):
* None : Attempt to sort the result, but catch any TypeErrors
from comparing incomparable elements.
* False : Do not sort the result.
* True : Sort the result (which may raise TypeError).
Returns
-------
Expand All @@ -787,16 +790,18 @@ def difference(self, other, sort=None):
>>> idx1.difference(idx2, sort=False)
Index([2, 1], dtype='int64')
"""
if sort not in {None, False}:
if sort not in {None, False, True}:
raise ValueError(
f"The 'sort' keyword only takes the values "
f"of None or False; {sort} was passed."
f"of [None, False, True]; {sort} was passed."
)

other = cudf.Index(other)

if is_mixed_with_object_dtype(self, other):
if is_mixed_with_object_dtype(self, other) or len(other) == 0:
difference = self.copy()
if sort is True:
return difference.sort_values()
else:
other = other.copy(deep=False)
other.names = self.names
Expand All @@ -813,7 +818,7 @@ def difference(self, other, sort=None):
if self.dtype != other.dtype:
difference = difference.astype(self.dtype)

if sort is None and len(other):
if sort in {None, True} and len(other):
return difference.sort_values()

return difference
Expand Down Expand Up @@ -1170,7 +1175,7 @@ def _union(self, other, sort=None):
)
union_result = cudf.core.index._index_from_data({0: res._data[0]})

if sort is None and len(other):
if sort in {None, True} and len(other):
return union_result.sort_values()
return union_result

Expand All @@ -1187,7 +1192,7 @@ def _intersection(self, other, sort=None):
._data
)

if sort is None and len(other):
if sort is {None, True} and len(other):
return intersection_result.sort_values()
return intersection_result

Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def _union(self, other, sort=None):
):
result = type(self)(start_r, end_r + step_s / 2, step_s / 2)
if result is not None:
if sort is None and not result.is_monotonic_increasing:
if sort in {None, True} and not result.is_monotonic_increasing:
return result.sort_values()
else:
return result
Expand All @@ -710,7 +710,7 @@ def _union(self, other, sort=None):
return self._as_int_index()._union(other, sort=sort)

@_cudf_nvtx_annotate
def _intersection(self, other, sort=False):
def _intersection(self, other, sort=None):
if not isinstance(other, RangeIndex):
return super()._intersection(other, sort=sort)

Expand Down Expand Up @@ -750,7 +750,7 @@ def _intersection(self, other, sort=False):

if (self.step < 0 and other.step < 0) is not (new_index.step < 0):
new_index = new_index[::-1]
if sort is None:
if sort in {None, True}:
new_index = new_index.sort_values()

return new_index
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ def _union(self, other, sort=None):

midx = MultiIndex.from_frame(result_df.iloc[:, : self.nlevels])
midx.names = self.names if self.names == other.names else None
if sort is None and len(other):
if sort in {None, True} and len(other):
return midx.sort_values()
return midx

Expand All @@ -1819,7 +1819,7 @@ def _intersection(self, other, sort=None):

result_df = cudf.merge(self_df, other_df, how="inner")
midx = self.__class__.from_frame(result_df, names=res_name)
if sort is None and len(other):
if sort in {None, True} and len(other):
return midx.sort_values()
return midx

Expand Down
50 changes: 41 additions & 9 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def test_index_to_series(data):
[],
],
)
@pytest.mark.parametrize("sort", [None, False])
@pytest.mark.parametrize("sort", [None, False, True])
def test_index_difference(data, other, sort):
pd_data = pd.Index(data)
pd_other = pd.Index(other)
Expand All @@ -801,8 +801,8 @@ def test_index_difference_sort_error():
assert_exceptions_equal(
pdi.difference,
gdi.difference,
([pdi], {"sort": True}),
([gdi], {"sort": True}),
([pdi], {"sort": "A"}),
([gdi], {"sort": "A"}),
)


Expand Down Expand Up @@ -2236,21 +2236,53 @@ def test_range_index_concat(objs):
[
(pd.RangeIndex(0, 10), pd.RangeIndex(3, 7)),
(pd.RangeIndex(0, 10), pd.RangeIndex(10, 20)),
(pd.RangeIndex(0, 10, 2), pd.RangeIndex(1, 5, 3)),
(pd.RangeIndex(1, 5, 3), pd.RangeIndex(0, 10, 2)),
(pd.RangeIndex(1, 10, 3), pd.RangeIndex(1, 5, 2)),
pytest.param(
pd.RangeIndex(0, 10, 2),
pd.RangeIndex(1, 5, 3),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
pytest.param(
pd.RangeIndex(1, 5, 3),
pd.RangeIndex(0, 10, 2),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
pytest.param(
pd.RangeIndex(1, 10, 3),
pd.RangeIndex(1, 5, 2),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
(pd.RangeIndex(1, 5, 2), pd.RangeIndex(1, 10, 3)),
(pd.RangeIndex(1, 100, 3), pd.RangeIndex(1, 50, 3)),
(pd.RangeIndex(1, 100, 3), pd.RangeIndex(1, 50, 6)),
(pd.RangeIndex(1, 100, 6), pd.RangeIndex(1, 50, 3)),
pytest.param(
pd.RangeIndex(1, 100, 6),
pd.RangeIndex(1, 50, 3),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
(pd.RangeIndex(0, 10, name="a"), pd.RangeIndex(90, 100, name="b")),
(pd.Index([0, 1, 2, 30], name="a"), pd.Index([90, 100])),
(pd.Index([0, 1, 2, 30], name="a"), [90, 100]),
(pd.Index([0, 1, 2, 30]), pd.Index([0, 10, 1.0, 11])),
(pd.Index(["a", "b", "c", "d", "c"]), pd.Index(["a", "c", "z"])),
],
)
@pytest.mark.parametrize("sort", [None, False])
@pytest.mark.parametrize("sort", [None, False, True])
def test_union_index(idx1, idx2, sort):
expected = idx1.union(idx2, sort=sort)

Expand Down Expand Up @@ -2280,7 +2312,7 @@ def test_union_index(idx1, idx2, sort):
(pd.Index([True, False, True, True]), pd.Index([True, True])),
],
)
@pytest.mark.parametrize("sort", [None, False])
@pytest.mark.parametrize("sort", [None, False, True])
def test_intersection_index(idx1, idx2, sort):

expected = idx1.intersection(idx2, sort=sort)
Expand Down

0 comments on commit 139e32d

Please sign in to comment.