diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index 46e7cdfac61..49721c23eb9 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -334,6 +334,7 @@ def union(self, other, sort=None): 2. `self` or `other` has length 0. * False : do not sort the result. + * True : Sort the result (which may raise TypeError). Returns ------- @@ -395,10 +396,10 @@ def union(self, other, sort=None): if not isinstance(other, BaseIndex): other = cudf.Index(other, name=self.name) - if sort not in {None, False}: + if sort not in {None, False, True}: raise ValueError( f"The 'sort' keyword only takes the values of " - f"None or False; {sort} was passed." + f"[None, False, True]; {sort} was passed." ) if not len(other) or self.equals(other): @@ -425,6 +426,7 @@ def intersection(self, other, sort=False): * False : do not sort the result. * None : sort the result, except when `self` and `other` are equal or when the values cannot be compared. + * True : Sort the result (which may raise TypeError). Returns ------- @@ -475,10 +477,10 @@ def intersection(self, other, sort=False): if not isinstance(other, BaseIndex): other = cudf.Index(other, name=self.name) - if sort not in {None, False}: + if sort not in {None, False, True}: raise ValueError( f"The 'sort' keyword only takes the values of " - f"None or False; {sort} was passed." + f"[None, False, True]; {sort} was passed." ) if self.equals(other): @@ -768,6 +770,7 @@ def difference(self, other, sort=None): * None : Attempt to sort the result, but catch any TypeErrors from comparing incomparable elements. * False : Do not sort the result. + * True : Sort the result (which may raise TypeError). Returns ------- @@ -787,16 +790,18 @@ def difference(self, other, sort=None): >>> idx1.difference(idx2, sort=False) Index([2, 1], dtype='int64') """ - if sort not in {None, False}: + if sort not in {None, False, True}: raise ValueError( f"The 'sort' keyword only takes the values " - f"of None or False; {sort} was passed." + f"of [None, False, True]; {sort} was passed." ) other = cudf.Index(other) - if is_mixed_with_object_dtype(self, other): + if is_mixed_with_object_dtype(self, other) or len(other) == 0: difference = self.copy() + if sort is True: + return difference.sort_values() else: other = other.copy(deep=False) other.names = self.names @@ -813,7 +818,7 @@ def difference(self, other, sort=None): if self.dtype != other.dtype: difference = difference.astype(self.dtype) - if sort is None and len(other): + if sort in {None, True} and len(other): return difference.sort_values() return difference @@ -1170,7 +1175,7 @@ def _union(self, other, sort=None): ) union_result = cudf.core.index._index_from_data({0: res._data[0]}) - if sort is None and len(other): + if sort in {None, True} and len(other): return union_result.sort_values() return union_result @@ -1187,7 +1192,7 @@ def _intersection(self, other, sort=None): ._data ) - if sort is None and len(other): + if sort is {None, True} and len(other): return intersection_result.sort_values() return intersection_result diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index c0664d3ca4d..a71c285b737 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -699,7 +699,7 @@ def _union(self, other, sort=None): ): result = type(self)(start_r, end_r + step_s / 2, step_s / 2) if result is not None: - if sort is None and not result.is_monotonic_increasing: + if sort in {None, True} and not result.is_monotonic_increasing: return result.sort_values() else: return result @@ -710,7 +710,7 @@ def _union(self, other, sort=None): return self._as_int_index()._union(other, sort=sort) @_cudf_nvtx_annotate - def _intersection(self, other, sort=False): + def _intersection(self, other, sort=None): if not isinstance(other, RangeIndex): return super()._intersection(other, sort=sort) @@ -750,7 +750,7 @@ def _intersection(self, other, sort=False): if (self.step < 0 and other.step < 0) is not (new_index.step < 0): new_index = new_index[::-1] - if sort is None: + if sort in {None, True}: new_index = new_index.sort_values() return new_index diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index cdc120935ee..4803e2b8e4b 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -1796,7 +1796,7 @@ def _union(self, other, sort=None): midx = MultiIndex.from_frame(result_df.iloc[:, : self.nlevels]) midx.names = self.names if self.names == other.names else None - if sort is None and len(other): + if sort in {None, True} and len(other): return midx.sort_values() return midx @@ -1819,7 +1819,7 @@ def _intersection(self, other, sort=None): result_df = cudf.merge(self_df, other_df, how="inner") midx = self.__class__.from_frame(result_df, names=res_name) - if sort is None and len(other): + if sort in {None, True} and len(other): return midx.sort_values() return midx diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index de4c72389cf..81369cd2c6e 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -781,7 +781,7 @@ def test_index_to_series(data): [], ], ) -@pytest.mark.parametrize("sort", [None, False]) +@pytest.mark.parametrize("sort", [None, False, True]) def test_index_difference(data, other, sort): pd_data = pd.Index(data) pd_other = pd.Index(other) @@ -801,8 +801,8 @@ def test_index_difference_sort_error(): assert_exceptions_equal( pdi.difference, gdi.difference, - ([pdi], {"sort": True}), - ([gdi], {"sort": True}), + ([pdi], {"sort": "A"}), + ([gdi], {"sort": "A"}), ) @@ -2236,13 +2236,45 @@ def test_range_index_concat(objs): [ (pd.RangeIndex(0, 10), pd.RangeIndex(3, 7)), (pd.RangeIndex(0, 10), pd.RangeIndex(10, 20)), - (pd.RangeIndex(0, 10, 2), pd.RangeIndex(1, 5, 3)), - (pd.RangeIndex(1, 5, 3), pd.RangeIndex(0, 10, 2)), - (pd.RangeIndex(1, 10, 3), pd.RangeIndex(1, 5, 2)), + pytest.param( + pd.RangeIndex(0, 10, 2), + pd.RangeIndex(1, 5, 3), + marks=pytest.mark.xfail( + condition=PANDAS_GE_200, + reason="https://github.com/pandas-dev/pandas/issues/53490", + strict=False, + ), + ), + pytest.param( + pd.RangeIndex(1, 5, 3), + pd.RangeIndex(0, 10, 2), + marks=pytest.mark.xfail( + condition=PANDAS_GE_200, + reason="https://github.com/pandas-dev/pandas/issues/53490", + strict=False, + ), + ), + pytest.param( + pd.RangeIndex(1, 10, 3), + pd.RangeIndex(1, 5, 2), + marks=pytest.mark.xfail( + condition=PANDAS_GE_200, + reason="https://github.com/pandas-dev/pandas/issues/53490", + strict=False, + ), + ), (pd.RangeIndex(1, 5, 2), pd.RangeIndex(1, 10, 3)), (pd.RangeIndex(1, 100, 3), pd.RangeIndex(1, 50, 3)), (pd.RangeIndex(1, 100, 3), pd.RangeIndex(1, 50, 6)), - (pd.RangeIndex(1, 100, 6), pd.RangeIndex(1, 50, 3)), + pytest.param( + pd.RangeIndex(1, 100, 6), + pd.RangeIndex(1, 50, 3), + marks=pytest.mark.xfail( + condition=PANDAS_GE_200, + reason="https://github.com/pandas-dev/pandas/issues/53490", + strict=False, + ), + ), (pd.RangeIndex(0, 10, name="a"), pd.RangeIndex(90, 100, name="b")), (pd.Index([0, 1, 2, 30], name="a"), pd.Index([90, 100])), (pd.Index([0, 1, 2, 30], name="a"), [90, 100]), @@ -2250,7 +2282,7 @@ def test_range_index_concat(objs): (pd.Index(["a", "b", "c", "d", "c"]), pd.Index(["a", "c", "z"])), ], ) -@pytest.mark.parametrize("sort", [None, False]) +@pytest.mark.parametrize("sort", [None, False, True]) def test_union_index(idx1, idx2, sort): expected = idx1.union(idx2, sort=sort) @@ -2280,7 +2312,7 @@ def test_union_index(idx1, idx2, sort): (pd.Index([True, False, True, True]), pd.Index([True, True])), ], ) -@pytest.mark.parametrize("sort", [None, False]) +@pytest.mark.parametrize("sort", [None, False, True]) def test_intersection_index(idx1, idx2, sort): expected = idx1.intersection(idx2, sort=sort)