Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable sort=True for Index.union, Index.difference and Index.intersection #13497

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def union(self, other, sort=None):
2. `self` or `other` has length 0.

* False : do not sort the result.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand Down Expand Up @@ -395,10 +396,10 @@ def union(self, other, sort=None):
if not isinstance(other, BaseIndex):
other = cudf.Index(other, name=self.name)

if sort not in {None, False}:
if sort not in {None, False, True}:
raise ValueError(
f"The 'sort' keyword only takes the values of "
f"None or False; {sort} was passed."
f"[None, False, True]; {sort} was passed."
)

if not len(other) or self.equals(other):
Expand All @@ -425,6 +426,7 @@ def intersection(self, other, sort=False):
* False : do not sort the result.
* None : sort the result, except when `self` and `other` are equal
or when the values cannot be compared.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand Down Expand Up @@ -475,10 +477,10 @@ def intersection(self, other, sort=False):
if not isinstance(other, BaseIndex):
other = cudf.Index(other, name=self.name)

if sort not in {None, False}:
if sort not in {None, False, True}:
raise ValueError(
f"The 'sort' keyword only takes the values of "
f"None or False; {sort} was passed."
f"[None, False, True]; {sort} was passed."
)

if self.equals(other):
Expand Down Expand Up @@ -768,6 +770,7 @@ def difference(self, other, sort=None):
* None : Attempt to sort the result, but catch any TypeErrors
from comparing incomparable elements.
* False : Do not sort the result.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand All @@ -787,16 +790,18 @@ def difference(self, other, sort=None):
>>> idx1.difference(idx2, sort=False)
Index([2, 1], dtype='int64')
"""
if sort not in {None, False}:
if sort not in {None, False, True}:
raise ValueError(
f"The 'sort' keyword only takes the values "
f"of None or False; {sort} was passed."
f"of [None, False, True]; {sort} was passed."
)

other = cudf.Index(other)

if is_mixed_with_object_dtype(self, other):
if is_mixed_with_object_dtype(self, other) or len(other) == 0:
difference = self.copy()
if sort is True:
return difference.sort_values()
else:
other = other.copy(deep=False)
other.names = self.names
Expand All @@ -813,7 +818,7 @@ def difference(self, other, sort=None):
if self.dtype != other.dtype:
difference = difference.astype(self.dtype)

if sort is None and len(other):
if sort in {None, True} and len(other):
return difference.sort_values()

return difference
Expand Down Expand Up @@ -1170,7 +1175,7 @@ def _union(self, other, sort=None):
)
union_result = cudf.core.index._index_from_data({0: res._data[0]})

if sort is None and len(other):
if sort in {None, True} and len(other):
return union_result.sort_values()
return union_result

Expand All @@ -1187,7 +1192,7 @@ def _intersection(self, other, sort=None):
._data
)

if sort is None and len(other):
if sort in {None, True} and len(other):
return intersection_result.sort_values()
return intersection_result

Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def _union(self, other, sort=None):
):
result = type(self)(start_r, end_r + step_s / 2, step_s / 2)
if result is not None:
if sort is None and not result.is_monotonic_increasing:
if sort in {None, True} and not result.is_monotonic_increasing:
return result.sort_values()
else:
return result
Expand All @@ -710,7 +710,7 @@ def _union(self, other, sort=None):
return self._as_int_index()._union(other, sort=sort)

@_cudf_nvtx_annotate
def _intersection(self, other, sort=False):
def _intersection(self, other, sort=None):
if not isinstance(other, RangeIndex):
return super()._intersection(other, sort=sort)

Expand Down Expand Up @@ -750,7 +750,7 @@ def _intersection(self, other, sort=False):

if (self.step < 0 and other.step < 0) is not (new_index.step < 0):
new_index = new_index[::-1]
if sort is None:
if sort in {None, True}:
new_index = new_index.sort_values()

return new_index
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ def _union(self, other, sort=None):

midx = MultiIndex.from_frame(result_df.iloc[:, : self.nlevels])
midx.names = self.names if self.names == other.names else None
if sort is None and len(other):
if sort in {None, True} and len(other):
return midx.sort_values()
return midx

Expand All @@ -1819,7 +1819,7 @@ def _intersection(self, other, sort=None):

result_df = cudf.merge(self_df, other_df, how="inner")
midx = self.__class__.from_frame(result_df, names=res_name)
if sort is None and len(other):
if sort in {None, True} and len(other):
return midx.sort_values()
return midx

Expand Down
50 changes: 41 additions & 9 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def test_index_to_series(data):
[],
],
)
@pytest.mark.parametrize("sort", [None, False])
@pytest.mark.parametrize("sort", [None, False, True])
def test_index_difference(data, other, sort):
pd_data = pd.Index(data)
pd_other = pd.Index(other)
Expand All @@ -801,8 +801,8 @@ def test_index_difference_sort_error():
assert_exceptions_equal(
pdi.difference,
gdi.difference,
([pdi], {"sort": True}),
([gdi], {"sort": True}),
([pdi], {"sort": "A"}),
([gdi], {"sort": "A"}),
)


Expand Down Expand Up @@ -2236,21 +2236,53 @@ def test_range_index_concat(objs):
[
(pd.RangeIndex(0, 10), pd.RangeIndex(3, 7)),
(pd.RangeIndex(0, 10), pd.RangeIndex(10, 20)),
(pd.RangeIndex(0, 10, 2), pd.RangeIndex(1, 5, 3)),
(pd.RangeIndex(1, 5, 3), pd.RangeIndex(0, 10, 2)),
(pd.RangeIndex(1, 10, 3), pd.RangeIndex(1, 5, 2)),
pytest.param(
pd.RangeIndex(0, 10, 2),
pd.RangeIndex(1, 5, 3),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
pytest.param(
pd.RangeIndex(1, 5, 3),
pd.RangeIndex(0, 10, 2),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
pytest.param(
pd.RangeIndex(1, 10, 3),
pd.RangeIndex(1, 5, 2),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
(pd.RangeIndex(1, 5, 2), pd.RangeIndex(1, 10, 3)),
(pd.RangeIndex(1, 100, 3), pd.RangeIndex(1, 50, 3)),
(pd.RangeIndex(1, 100, 3), pd.RangeIndex(1, 50, 6)),
(pd.RangeIndex(1, 100, 6), pd.RangeIndex(1, 50, 3)),
pytest.param(
pd.RangeIndex(1, 100, 6),
pd.RangeIndex(1, 50, 3),
marks=pytest.mark.xfail(
condition=PANDAS_GE_200,
reason="https://github.com/pandas-dev/pandas/issues/53490",
strict=False,
),
),
(pd.RangeIndex(0, 10, name="a"), pd.RangeIndex(90, 100, name="b")),
(pd.Index([0, 1, 2, 30], name="a"), pd.Index([90, 100])),
(pd.Index([0, 1, 2, 30], name="a"), [90, 100]),
(pd.Index([0, 1, 2, 30]), pd.Index([0, 10, 1.0, 11])),
(pd.Index(["a", "b", "c", "d", "c"]), pd.Index(["a", "c", "z"])),
],
)
@pytest.mark.parametrize("sort", [None, False])
@pytest.mark.parametrize("sort", [None, False, True])
def test_union_index(idx1, idx2, sort):
expected = idx1.union(idx2, sort=sort)

Expand Down Expand Up @@ -2280,7 +2312,7 @@ def test_union_index(idx1, idx2, sort):
(pd.Index([True, False, True, True]), pd.Index([True, True])),
],
)
@pytest.mark.parametrize("sort", [None, False])
@pytest.mark.parametrize("sort", [None, False, True])
def test_intersection_index(idx1, idx2, sort):

expected = idx1.intersection(idx2, sort=sort)
Expand Down