Skip to content

Commit

Permalink
ENH: Keep dtypes in MultiIndex.union without NAs (#48505)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Sep 13, 2022
1 parent 68d6b47 commit aea824f
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 54 deletions.
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :meth:`.GroupBy.median` for nullable dtypes (:issue:`37493`)
- Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`)
- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`)
- Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`)
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
- Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`)
Expand Down Expand Up @@ -173,7 +174,7 @@ Missing
MultiIndex
^^^^^^^^^^
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`)
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`)
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
-

Expand Down
2 changes: 1 addition & 1 deletion pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def is_bool_array(values: np.ndarray, skipna: bool = ...): ...
def fast_multiget(mapping: dict, keys: np.ndarray, default=...) -> np.ndarray: ...
def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ...
def fast_unique_multiple_list(lists: list, sort: bool | None = ...) -> list: ...
def fast_unique_multiple(arrays: list, sort: bool = ...) -> list: ...
def fast_unique_multiple(left: np.ndarray, right: np.ndarray) -> list: ...
def map_infer(
arr: np.ndarray,
f: Callable[[Any], Any],
Expand Down
59 changes: 23 additions & 36 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections import abc
from decimal import Decimal
from enum import Enum
import inspect
from typing import Literal
import warnings

cimport cython
from cpython.datetime cimport (
Expand Down Expand Up @@ -31,8 +29,6 @@ from cython cimport (
floating,
)

from pandas.util._exceptions import find_stack_level

import_datetime()

import numpy as np
Expand Down Expand Up @@ -314,51 +310,42 @@ def item_from_zerodim(val: object) -> object:

@cython.wraparound(False)
@cython.boundscheck(False)
def fast_unique_multiple(list arrays, sort: bool = True):
def fast_unique_multiple(ndarray left, ndarray right) -> list:
"""
Generate a list of unique values from a list of arrays.
Generate a list indices we have to add to the left to get the union
of both arrays.

Parameters
----------
list : array-like
List of array-like objects.
sort : bool
Whether or not to sort the resulting unique list.
left : np.ndarray
Left array that is used as base.
right : np.ndarray
right array that is checked for values that are not in left.
right can not have duplicates.

Returns
-------
list of unique values
list of indices that we have to add to the left array.
"""
cdef:
ndarray[object] buf
Py_ssize_t k = len(arrays)
Py_ssize_t i, j, n
list uniques = []
dict table = {}
Py_ssize_t j, n
list indices = []
set table = set()
object val, stub = 0

for i in range(k):
buf = arrays[i]
n = len(buf)
for j in range(n):
val = buf[j]
if val not in table:
table[val] = stub
uniques.append(val)
n = len(left)
for j in range(n):
val = left[j]
if val not in table:
table.add(val)

if sort is None:
try:
uniques.sort()
except TypeError:
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
RuntimeWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
pass
n = len(right)
for j in range(n):
val = right[j]
if val not in table:
indices.append(j)

return uniques
return indices


@cython.wraparound(False)
Expand Down
25 changes: 21 additions & 4 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3638,21 +3638,38 @@ def _union(self, other, sort) -> MultiIndex:
if (
any(-1 in code for code in self.codes)
and any(-1 in code for code in other.codes)
or self.has_duplicates
or other.has_duplicates
):
# This is only necessary if both sides have nans or one has dups,
# This is only necessary if both sides have nans or other has dups,
# fast_unique_multiple is faster
result = super()._union(other, sort)

if isinstance(result, MultiIndex):
return result
return MultiIndex.from_arrays(
zip(*result), sortorder=None, names=result_names
)

else:
rvals = other._values.astype(object, copy=False)
result = lib.fast_unique_multiple([self._values, rvals], sort=sort)
right_missing = lib.fast_unique_multiple(self._values, rvals)
if right_missing:
result = self.append(other.take(right_missing))
else:
result = self._get_reconciled_name_object(other)

return MultiIndex.from_arrays(zip(*result), sortorder=None, names=result_names)
if sort is None:
try:
result = result.sort_values()
except TypeError:
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
RuntimeWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
pass
return result

def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
return is_object_dtype(dtype)
Expand Down
41 changes: 29 additions & 12 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,6 @@ def test_union(idx, sort):
assert result.equals(idx)


@pytest.mark.xfail(
# This test was commented out from Oct 2011 to Dec 2021, may no longer
# be relevant.
reason="Length of names must match number of levels in MultiIndex",
raises=ValueError,
)
def test_union_with_regular_index(idx):
other = Index(["A", "B", "C"])

Expand All @@ -277,7 +271,9 @@ def test_union_with_regular_index(idx):
msg = "The values in the array are unorderable"
with tm.assert_produces_warning(RuntimeWarning, match=msg):
result2 = idx.union(other)
assert result.equals(result2)
# This is more consistent now, if sorting fails then we don't sort at all
# in the MultiIndex case.
assert not result.equals(result2)


def test_intersection(idx, sort):
Expand Down Expand Up @@ -525,6 +521,26 @@ def test_union_nan_got_duplicated():
tm.assert_index_equal(result, mi2)


@pytest.mark.parametrize("val", [4, 1])
def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
# GH#48505

arr1 = Series([val, 2], dtype=any_numeric_ea_dtype)
arr2 = Series([2, 1], dtype=any_numeric_ea_dtype)
midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None])
midx2 = MultiIndex.from_arrays([arr2, [2, 1]])
result = midx.union(midx2)
if val == 4:
expected = MultiIndex.from_arrays(
[Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]]
)
else:
expected = MultiIndex.from_arrays(
[Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]]
)
tm.assert_index_equal(result, expected)


def test_union_duplicates(index, request):
# GH#38977
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
Expand All @@ -534,18 +550,19 @@ def test_union_duplicates(index, request):
values = index.unique().values.tolist()
mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])
result = mi1.union(mi2)
result = mi2.union(mi1)
expected = mi2.sort_values()
tm.assert_index_equal(result, expected)

if mi2.levels[0].dtype == np.uint64 and (mi2.get_level_values(0) < 2**63).all():
# GH#47294 - union uses lib.fast_zip, converting data to Python integers
# and loses type information. Result is then unsigned only when values are
# sufficiently large to require unsigned dtype.
# sufficiently large to require unsigned dtype. This happens only if other
# has dups or one of both have missing values
expected = expected.set_levels(
[expected.levels[0].astype(int), expected.levels[1]]
)
tm.assert_index_equal(result, expected)

result = mi2.union(mi1)
result = mi1.union(mi2)
tm.assert_index_equal(result, expected)


Expand Down

0 comments on commit aea824f

Please sign in to comment.