diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index 0bc91d3cd9637..2394d4721edfc 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -220,6 +220,7 @@ Missing MultiIndex ^^^^^^^^^^ +- Bug in :meth:`MultiIndex.argsort` raising ``TypeError`` when index contains :attr:`NA` (:issue:`48495`) - Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48606`) - Bug in :class:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`) - Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 574dfdf48055d..048c8f0ba5e69 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -1952,7 +1952,7 @@ def _lexsort_depth(self) -> int: return self.sortorder return _lexsort_depth(self.codes, self.nlevels) - def _sort_levels_monotonic(self) -> MultiIndex: + def _sort_levels_monotonic(self, raise_if_incomparable: bool = False) -> MultiIndex: """ This is an *internal* function. @@ -1999,7 +1999,8 @@ def _sort_levels_monotonic(self) -> MultiIndex: # indexer to reorder the levels indexer = lev.argsort() except TypeError: - pass + if raise_if_incomparable: + raise else: lev = lev.take(indexer) @@ -2245,9 +2246,9 @@ def append(self, other): def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]: if len(args) == 0 and len(kwargs) == 0: - # np.lexsort is significantly faster than self._values.argsort() - values = [self._get_level_values(i) for i in reversed(range(self.nlevels))] - return np.lexsort(values) + # lexsort is significantly faster than self._values.argsort() + target = self._sort_levels_monotonic(raise_if_incomparable=True) + return lexsort_indexer(target._get_codes_for_sorting()) return self._values.argsort(*args, **kwargs) @Appender(_index_shared_docs["repeat"] % _index_doc_kwargs) diff --git a/pandas/tests/indexes/multi/test_sorting.py b/pandas/tests/indexes/multi/test_sorting.py index 6fd1781beeda4..3f364473270fb 100644 --- a/pandas/tests/indexes/multi/test_sorting.py +++ b/pandas/tests/indexes/multi/test_sorting.py @@ -14,6 +14,7 @@ Index, MultiIndex, RangeIndex, + Timestamp, ) import pandas._testing as tm from pandas.core.indexes.frozen import FrozenList @@ -280,3 +281,26 @@ def test_remove_unused_levels_with_nan(): result = idx.levels expected = FrozenList([["a", np.nan], [4]]) assert str(result) == str(expected) + + +def test_sort_values_nan(): + # GH48495, GH48626 + midx = MultiIndex(levels=[["A", "B", "C"], ["D"]], codes=[[1, 0, 2], [-1, -1, 0]]) + result = midx.sort_values() + expected = MultiIndex( + levels=[["A", "B", "C"], ["D"]], codes=[[0, 1, 2], [-1, -1, 0]] + ) + tm.assert_index_equal(result, expected) + + +def test_sort_values_incomparable(): + # GH48495 + mi = MultiIndex.from_arrays( + [ + [1, Timestamp("2000-01-01")], + [3, 4], + ] + ) + match = "'<' not supported between instances of 'Timestamp' and 'int'" + with pytest.raises(TypeError, match=match): + mi.sort_values() diff --git a/pandas/tests/indexing/multiindex/test_sorted.py b/pandas/tests/indexing/multiindex/test_sorted.py index 2214aaa9cfbdb..1d2fd6586337b 100644 --- a/pandas/tests/indexing/multiindex/test_sorted.py +++ b/pandas/tests/indexing/multiindex/test_sorted.py @@ -2,9 +2,11 @@ import pytest from pandas import ( + NA, DataFrame, MultiIndex, Series, + array, ) import pandas._testing as tm @@ -86,6 +88,36 @@ def test_sort_values_key(self): tm.assert_frame_equal(result, expected) + def test_argsort_with_na(self): + # GH48495 + arrays = [ + array([2, NA, 1], dtype="Int64"), + array([1, 2, 3], dtype="Int64"), + ] + index = MultiIndex.from_arrays(arrays) + result = index.argsort() + expected = np.array([2, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + def test_sort_values_with_na(self): + # GH48495 + arrays = [ + array([2, NA, 1], dtype="Int64"), + array([1, 2, 3], dtype="Int64"), + ] + index = MultiIndex.from_arrays(arrays) + index = index.sort_values() + result = DataFrame(range(3), index=index) + + arrays = [ + array([1, 2, NA], dtype="Int64"), + array([3, 1, 2], dtype="Int64"), + ] + index = MultiIndex.from_arrays(arrays) + expected = DataFrame(range(3), index=index) + + tm.assert_frame_equal(result, expected) + def test_frame_getitem_not_sorted(self, multiindex_dataframe_random_data): frame = multiindex_dataframe_random_data df = frame.T