diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index 7f16b109a2aac..ba6db137de083 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -104,6 +104,7 @@ Performance improvements ~~~~~~~~~~~~~~~~~~~~~~~~ - Performance improvement in :meth:`.GroupBy.median` for nullable dtypes (:issue:`37493`) - Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`) +- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`) - Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`) - Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`) - Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`) @@ -170,7 +171,7 @@ Missing MultiIndex ^^^^^^^^^^ - Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`) -- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`) +- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`) - Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`) - diff --git a/pandas/_libs/lib.pyi b/pandas/_libs/lib.pyi index 77d3cbe92bef9..079f3ae5546be 100644 --- a/pandas/_libs/lib.pyi +++ b/pandas/_libs/lib.pyi @@ -59,7 +59,7 @@ def is_bool_array(values: np.ndarray, skipna: bool = ...): ... def fast_multiget(mapping: dict, keys: np.ndarray, default=...) -> np.ndarray: ... def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ... def fast_unique_multiple_list(lists: list, sort: bool | None = ...) -> list: ... -def fast_unique_multiple(arrays: list, sort: bool = ...) -> list: ... +def fast_unique_multiple(left: np.ndarray, right: np.ndarray) -> list: ... def map_infer( arr: np.ndarray, f: Callable[[Any], Any], diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index ec7c3d61566dc..f1473392147f9 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -1,9 +1,7 @@ from collections import abc from decimal import Decimal from enum import Enum -import inspect from typing import Literal -import warnings cimport cython from cpython.datetime cimport ( @@ -31,8 +29,6 @@ from cython cimport ( floating, ) -from pandas.util._exceptions import find_stack_level - import_datetime() import numpy as np @@ -314,51 +310,42 @@ def item_from_zerodim(val: object) -> object: @cython.wraparound(False) @cython.boundscheck(False) -def fast_unique_multiple(list arrays, sort: bool = True): +def fast_unique_multiple(ndarray left, ndarray right) -> list: """ - Generate a list of unique values from a list of arrays. + Generate a list indices we have to add to the left to get the union + of both arrays. Parameters ---------- - list : array-like - List of array-like objects. - sort : bool - Whether or not to sort the resulting unique list. + left : np.ndarray + Left array that is used as base. + right : np.ndarray + right array that is checked for values that are not in left. + right can not have duplicates. Returns ------- - list of unique values + list of indices that we have to add to the left array. """ cdef: - ndarray[object] buf - Py_ssize_t k = len(arrays) - Py_ssize_t i, j, n - list uniques = [] - dict table = {} + Py_ssize_t j, n + list indices = [] + set table = set() object val, stub = 0 - for i in range(k): - buf = arrays[i] - n = len(buf) - for j in range(n): - val = buf[j] - if val not in table: - table[val] = stub - uniques.append(val) + n = len(left) + for j in range(n): + val = left[j] + if val not in table: + table.add(val) - if sort is None: - try: - uniques.sort() - except TypeError: - warnings.warn( - "The values in the array are unorderable. " - "Pass `sort=False` to suppress this warning.", - RuntimeWarning, - stacklevel=find_stack_level(inspect.currentframe()), - ) - pass + n = len(right) + for j in range(n): + val = right[j] + if val not in table: + indices.append(j) - return uniques + return indices @cython.wraparound(False) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 1b35cc03f6fdd..9506e775ceb03 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3638,21 +3638,38 @@ def _union(self, other, sort) -> MultiIndex: if ( any(-1 in code for code in self.codes) and any(-1 in code for code in other.codes) - or self.has_duplicates or other.has_duplicates ): - # This is only necessary if both sides have nans or one has dups, + # This is only necessary if both sides have nans or other has dups, # fast_unique_multiple is faster result = super()._union(other, sort) if isinstance(result, MultiIndex): return result + return MultiIndex.from_arrays( + zip(*result), sortorder=None, names=result_names + ) else: rvals = other._values.astype(object, copy=False) - result = lib.fast_unique_multiple([self._values, rvals], sort=sort) + right_missing = lib.fast_unique_multiple(self._values, rvals) + if right_missing: + result = self.append(other.take(right_missing)) + else: + result = self._get_reconciled_name_object(other) - return MultiIndex.from_arrays(zip(*result), sortorder=None, names=result_names) + if sort is None: + try: + result = result.sort_values() + except TypeError: + warnings.warn( + "The values in the array are unorderable. " + "Pass `sort=False` to suppress this warning.", + RuntimeWarning, + stacklevel=find_stack_level(inspect.currentframe()), + ) + pass + return result def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: return is_object_dtype(dtype) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index 7383d5a551e7b..ce310a75e8e45 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -261,12 +261,6 @@ def test_union(idx, sort): assert result.equals(idx) -@pytest.mark.xfail( - # This test was commented out from Oct 2011 to Dec 2021, may no longer - # be relevant. - reason="Length of names must match number of levels in MultiIndex", - raises=ValueError, -) def test_union_with_regular_index(idx): other = Index(["A", "B", "C"]) @@ -277,7 +271,9 @@ def test_union_with_regular_index(idx): msg = "The values in the array are unorderable" with tm.assert_produces_warning(RuntimeWarning, match=msg): result2 = idx.union(other) - assert result.equals(result2) + # This is more consistent now, if sorting fails then we don't sort at all + # in the MultiIndex case. + assert not result.equals(result2) def test_intersection(idx, sort): @@ -525,6 +521,26 @@ def test_union_nan_got_duplicated(): tm.assert_index_equal(result, mi2) +@pytest.mark.parametrize("val", [4, 1]) +def test_union_keep_ea_dtype(any_numeric_ea_dtype, val): + # GH#48505 + + arr1 = Series([val, 2], dtype=any_numeric_ea_dtype) + arr2 = Series([2, 1], dtype=any_numeric_ea_dtype) + midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None]) + midx2 = MultiIndex.from_arrays([arr2, [2, 1]]) + result = midx.union(midx2) + if val == 4: + expected = MultiIndex.from_arrays( + [Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]] + ) + else: + expected = MultiIndex.from_arrays( + [Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]] + ) + tm.assert_index_equal(result, expected) + + def test_union_duplicates(index, request): # GH#38977 if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)): @@ -534,18 +550,19 @@ def test_union_duplicates(index, request): values = index.unique().values.tolist() mi1 = MultiIndex.from_arrays([values, [1] * len(values)]) mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)]) - result = mi1.union(mi2) + result = mi2.union(mi1) expected = mi2.sort_values() + tm.assert_index_equal(result, expected) + if mi2.levels[0].dtype == np.uint64 and (mi2.get_level_values(0) < 2**63).all(): # GH#47294 - union uses lib.fast_zip, converting data to Python integers # and loses type information. Result is then unsigned only when values are - # sufficiently large to require unsigned dtype. + # sufficiently large to require unsigned dtype. This happens only if other + # has dups or one of both have missing values expected = expected.set_levels( [expected.levels[0].astype(int), expected.levels[1]] ) - tm.assert_index_equal(result, expected) - - result = mi2.union(mi1) + result = mi1.union(mi2) tm.assert_index_equal(result, expected)