Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Keep dtypes in MultiIndex.union without NAs #48505

Merged
merged 7 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :meth:`.GroupBy.median` for nullable dtypes (:issue:`37493`)
- Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`)
- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`)
- Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`)
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
- Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`)
Expand Down Expand Up @@ -170,7 +171,7 @@ Missing
MultiIndex
^^^^^^^^^^
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`)
- Bug in :meth:`MultiIndex.union` losing extension array dtype (:issue:`48498`, :issue:`48505`)
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
-

Expand Down
2 changes: 1 addition & 1 deletion pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def is_bool_array(values: np.ndarray, skipna: bool = ...): ...
def fast_multiget(mapping: dict, keys: np.ndarray, default=...) -> np.ndarray: ...
def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ...
def fast_unique_multiple_list(lists: list, sort: bool | None = ...) -> list: ...
def fast_unique_multiple(arrays: list, sort: bool = ...) -> list: ...
def fast_unique_multiple(left: np.ndarray, right: np.ndarray) -> list: ...
def map_infer(
arr: np.ndarray,
f: Callable[[Any], Any],
Expand Down
59 changes: 23 additions & 36 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections import abc
from decimal import Decimal
from enum import Enum
import inspect
from typing import Literal
import warnings

cimport cython
from cpython.datetime cimport (
Expand Down Expand Up @@ -31,8 +29,6 @@ from cython cimport (
floating,
)

from pandas.util._exceptions import find_stack_level

import_datetime()

import numpy as np
Expand Down Expand Up @@ -314,51 +310,42 @@ def item_from_zerodim(val: object) -> object:

@cython.wraparound(False)
@cython.boundscheck(False)
def fast_unique_multiple(list arrays, sort: bool = True):
def fast_unique_multiple(ndarray left, ndarray right) -> list:
"""
Generate a list of unique values from a list of arrays.
Generate a list of indices we have to add to the left to get the union
of both arrays.

Parameters
----------
list : array-like
List of array-like objects.
sort : bool
Whether or not to sort the resulting unique list.
left : np.ndarray
Left array that is used as base.
right : np.ndarray
Right array that is checked for values that are not in left.
right cannot have duplicates.

Returns
-------
list of unique values
list of indices into right whose values are missing from left and have to be added to the left array.
"""
cdef:
ndarray[object] buf
Py_ssize_t k = len(arrays)
Py_ssize_t i, j, n
list uniques = []
dict table = {}
Py_ssize_t j, n
list indices = []
set table = set()
object val, stub = 0

for i in range(k):
buf = arrays[i]
n = len(buf)
for j in range(n):
val = buf[j]
if val not in table:
table[val] = stub
uniques.append(val)
n = len(left)
for j in range(n):
val = left[j]
if val not in table:
table.add(val)

if sort is None:
try:
uniques.sort()
except TypeError:
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
RuntimeWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
pass
n = len(right)
for j in range(n):
val = right[j]
if val not in table:
indices.append(j)

return uniques
return indices


@cython.wraparound(False)
Expand Down
25 changes: 21 additions & 4 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3638,21 +3638,38 @@ def _union(self, other, sort) -> MultiIndex:
if (
any(-1 in code for code in self.codes)
and any(-1 in code for code in other.codes)
or self.has_duplicates
or other.has_duplicates
):
# This is only necessary if both sides have nans or one has dups,
# This is only necessary if both sides have nans or other has dups,
# fast_unique_multiple is faster
result = super()._union(other, sort)

if isinstance(result, MultiIndex):
return result
return MultiIndex.from_arrays(
zip(*result), sortorder=None, names=result_names
)

else:
rvals = other._values.astype(object, copy=False)
result = lib.fast_unique_multiple([self._values, rvals], sort=sort)
right_missing = lib.fast_unique_multiple(self._values, rvals)
if right_missing:
result = self.append(other.take(right_missing))
else:
result = self._get_reconciled_name_object(other)

return MultiIndex.from_arrays(zip(*result), sortorder=None, names=result_names)
if sort is None:
try:
result = result.sort_values()
except TypeError:
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
RuntimeWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
pass
return result

def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
return is_object_dtype(dtype)
Expand Down
41 changes: 29 additions & 12 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,6 @@ def test_union(idx, sort):
assert result.equals(idx)


@pytest.mark.xfail(
# This test was commented out from Oct 2011 to Dec 2021, may no longer
# be relevant.
reason="Length of names must match number of levels in MultiIndex",
raises=ValueError,
)
def test_union_with_regular_index(idx):
other = Index(["A", "B", "C"])

Expand All @@ -277,7 +271,9 @@ def test_union_with_regular_index(idx):
msg = "The values in the array are unorderable"
with tm.assert_produces_warning(RuntimeWarning, match=msg):
result2 = idx.union(other)
assert result.equals(result2)
# This is more consistent now: if sorting fails, then we don't sort at all
# in the MultiIndex case.
assert not result.equals(result2)


def test_intersection(idx, sort):
Expand Down Expand Up @@ -525,6 +521,26 @@ def test_union_nan_got_duplicated():
tm.assert_index_equal(result, mi2)


@pytest.mark.parametrize("val", [4, 1])
def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
# GH#48505

arr1 = Series([val, 2], dtype=any_numeric_ea_dtype)
arr2 = Series([2, 1], dtype=any_numeric_ea_dtype)
midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None])
midx2 = MultiIndex.from_arrays([arr2, [2, 1]])
result = midx.union(midx2)
if val == 4:
expected = MultiIndex.from_arrays(
[Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]]
)
else:
expected = MultiIndex.from_arrays(
[Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]]
)
tm.assert_index_equal(result, expected)


def test_union_duplicates(index, request):
# GH#38977
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
Expand All @@ -534,18 +550,19 @@ def test_union_duplicates(index, request):
values = index.unique().values.tolist()
mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])
result = mi1.union(mi2)
result = mi2.union(mi1)
expected = mi2.sort_values()
tm.assert_index_equal(result, expected)

if mi2.levels[0].dtype == np.uint64 and (mi2.get_level_values(0) < 2**63).all():
# GH#47294 - union uses lib.fast_zip, converting data to Python integers
# and loses type information. Result is then unsigned only when values are
# sufficiently large to require unsigned dtype.
# sufficiently large to require unsigned dtype. This happens only if other
# has dups or both sides have missing values
expected = expected.set_levels(
[expected.levels[0].astype(int), expected.levels[1]]
)
tm.assert_index_equal(result, expected)

result = mi2.union(mi1)
result = mi1.union(mi2)
tm.assert_index_equal(result, expected)


Expand Down