Skip to content

Commit

Permalink
Fixed #36562
Browse files Browse the repository at this point in the history
* Use special sorting comparator for tuple arrays which can be created when consolidate_first is called on DataFrames with MultiIndex which contain nan and string values
  • Loading branch information
ssche committed Oct 13, 2020
1 parent 6db851b commit 1320ff1
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 5 deletions.
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ MultiIndex
^^^^^^^^^^

- Bug in :meth:`DataFrame.xs` when used with :class:`IndexSlice` raises ``TypeError`` with message ``"Expected label or tuple of labels"`` (:issue:`35301`)
-
- Bug in :meth:`DataFrame.combine_first` when used with :class:`MultiIndex` containing string and ``NaN`` values raises ``TypeError`` with message ``"'<' not supported between instances of 'float' and 'str'"`` (:issue:`36562`)


I/O
^^^
Expand Down
48 changes: 44 additions & 4 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import operator
import functools
from textwrap import dedent
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, cast
from warnings import catch_warnings, simplefilter, warn
Expand Down Expand Up @@ -2055,13 +2056,52 @@ def sort_mixed(values):
strs = np.sort(values[str_pos])
return np.concatenate([nums, np.asarray(strs, dtype=object)])

def sort_tuples(values):
# sorts tuples with mixed values. can handle nan vs string comparisons.
def cmp_func(index_x, index_y):
x = values[index_x]
y = values[index_y]
if x == y:
return 0
len_x = len(x)
len_y = len(y)
for i in range(max(len_x, len_y)):
# check if the tuples have different lengths (shorter tuples
# first)
if i >= len_x:
return -1
if i >= len_y:
return +1
x_i_na = isna(x[i])
y_i_na = isna(y[i])
# values are the same -> resolve tie with next element
if (x_i_na and y_i_na) or (x[i] == y[i]):
continue
# check for nan values (sort nan to the end which is consistent
# with numpy
if x_i_na and not y_i_na:
return +1
if not x_i_na and y_i_na:
return -1
# normal greater/less than comparison
if x[i] < y[i]:
return -1
return +1
return 0

ixs = np.arange(len(values))
ixs = sorted(ixs, key=functools.cmp_to_key(cmp_func))
return values[ixs]

sorter = None
if (
not is_extension_array_dtype(values)
and lib.infer_dtype(values, skipna=False) == "mixed-integer"
):

ext_arr = is_extension_array_dtype(values)
if not ext_arr and lib.infer_dtype(values, skipna=False) == "mixed-integer":
# unorderable in py3 if mixed str/int
ordered = sort_mixed(values)
elif not ext_arr and values.size and isinstance(values[0], tuple):
# 1-D arrays with tuples of potentially mixed type (solves GH36562)
ordered = sort_tuples(values)
else:
try:
sorter = values.argsort()
Expand Down
23 changes: 23 additions & 0 deletions pandas/tests/indexing/multiindex/test_multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,26 @@ def test_multiindex_get_loc_list_raises(self):
msg = "unhashable type"
with pytest.raises(TypeError, match=msg):
idx.get_loc([])


def test_combine_first_with_nan_index():
mi1 = pd.MultiIndex.from_arrays(
[["b", "b", "c", "a", "b", np.nan], [1, 2, 3, 4, 5, 6]],
names=["a", "b"]
)
df = pd.DataFrame({"c": [1, 1, 1, 1, 1, 1]}, index=mi1)
mi2 = pd.MultiIndex.from_arrays(
[["a", "b", "c", "a", "b", "d"], [1, 1, 1, 1, 1, 1]], names=["a", "b"]
)
s = pd.Series([1, 2, 3, 4, 5, 6], index=mi2)
df_combined = df.combine_first(pd.DataFrame({"col": s}))
mi_expected = pd.MultiIndex.from_arrays([
["a", "a", "a", "b", "b", "b", "b", "c", "c", "d", np.nan],
[1, 1, 4, 1, 1, 2, 5, 1, 3, 1, 6,]
], names=["a", "b"])
assert (df_combined.index == mi_expected).all()
exp_col = np.asarray(
[1.0, 4.0, np.nan, 2.0, 5.0, np.nan, np.nan, 3.0, np.nan, 6.0, np.nan]
)
act_col = df_combined['col'].values
assert np.allclose(act_col, exp_col, rtol=0, atol=0, equal_nan=True)
7 changes: 7 additions & 0 deletions pandas/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,10 @@ def test_extension_array_codes(self, verify, na_sentinel):
expected_codes = np.array([0, 2, na_sentinel, 1], dtype=np.intp)
tm.assert_extension_array_equal(result, expected_values)
tm.assert_numpy_array_equal(codes, expected_codes)


def test_mixed_str_nan():
values = np.array(["b", np.nan, "a", "b"], dtype=object)
result = safe_sort(values)
expected = np.array([np.nan, "a", "b", "b"], dtype=object)
tm.assert_numpy_array_equal(result, expected)

0 comments on commit 1320ff1

Please sign in to comment.