From 17254c4768a9cd873d86f428fee14dbfa7834ceb Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 27 Aug 2020 22:58:17 -0400 Subject: [PATCH] fix inconsistent index naming with union/intersect GH35847 --- doc/source/user_guide/merging.rst | 8 ++ doc/source/whatsnew/v1.2.0.rst | 20 +++++ pandas/core/indexes/api.py | 32 +------ pandas/core/indexes/base.py | 50 ++++++++--- pandas/core/indexes/datetimelike.py | 22 +++-- pandas/core/indexes/datetimes.py | 6 +- pandas/core/indexes/interval.py | 2 +- pandas/core/indexes/multi.py | 16 ++-- pandas/core/indexes/range.py | 3 +- pandas/core/reshape/concat.py | 4 +- pandas/tests/frame/test_constructors.py | 4 +- pandas/tests/indexes/datetimes/test_setops.py | 2 +- pandas/tests/indexes/multi/test_join.py | 2 +- pandas/tests/indexes/multi/test_setops.py | 6 +- pandas/tests/indexes/test_common.py | 87 +++++++++++++++++++ pandas/tests/reshape/test_concat.py | 4 +- 16 files changed, 196 insertions(+), 72 deletions(-) diff --git a/doc/source/user_guide/merging.rst b/doc/source/user_guide/merging.rst index da16aaf5b3a56..eeac0ed4837dd 100644 --- a/doc/source/user_guide/merging.rst +++ b/doc/source/user_guide/merging.rst @@ -154,6 +154,14 @@ functionality below. frames = [ process_your_file(f) for f in files ] result = pd.concat(frames) +.. note:: + + When concatenating DataFrames with named axes, pandas will attempt to preserve + these index/column names whenever possible. In the case where all inputs share a + common name, this name will be assigned to the result. When the input names do + not all agree, the result will be unnamed. The same is true for :class:`MultiIndex`, + but the logic is applied separately on a level-by-level basis. 
+ Set logic on the other axes ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 2fe878897b2e7..a2e879791d198 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -157,6 +157,26 @@ Alternatively, you can also use the dtype object: behaviour or API may still change without warning. Expecially the behaviour regarding NaN (distinct from NA missing values) is subject to change. +.. _whatsnew_120.index_name_preservation: + +Index/column name preservation when aggregating +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When aggregating using :meth:`concat` or the :class:`DataFrame` constructor, pandas +will attempt to preserve index (and column) names whenever possible (:issue:`35847`). +In the case where all inputs share a common name, this name will be assigned to the +result. When the input names do not all agree, the result will be unnamed. Here is an +example where the index name is preserved: + +.. ipython:: python + + idx = pd.Index(range(5), name='abc') + ser = pd.Series(range(5, 10), index=idx) + pd.concat({'x': ser[1:], 'y': ser[:-1]}, axis=1) + +The same is true for :class:`MultiIndex`, but the logic is applied separately on a +level-by-level basis. + .. 
_whatsnew_120.enhancements.other: Other enhancements diff --git a/pandas/core/indexes/api.py b/pandas/core/indexes/api.py index d352b001f5d2a..18981a2190552 100644 --- a/pandas/core/indexes/api.py +++ b/pandas/core/indexes/api.py @@ -4,12 +4,12 @@ from pandas._libs import NaT, lib from pandas.errors import InvalidIndexError -import pandas.core.common as com from pandas.core.indexes.base import ( Index, _new_Index, ensure_index, ensure_index_from_sequences, + get_unanimous_names, ) from pandas.core.indexes.category import CategoricalIndex from pandas.core.indexes.datetimes import DatetimeIndex @@ -57,7 +57,7 @@ "ensure_index_from_sequences", "get_objs_combined_axis", "union_indexes", - "get_consensus_names", + "get_unanimous_names", "all_indexes_same", ] @@ -221,9 +221,9 @@ def conv(i): if not all(index.equals(other) for other in indexes[1:]): index = _unique_indices(indexes) - name = get_consensus_names(indexes)[0] + name = get_unanimous_names(*indexes)[0] if name != index.name: - index = index._shallow_copy(name=name) + index = index.rename(name) return index else: # kind='list' return _unique_indices(indexes) @@ -267,30 +267,6 @@ def _sanitize_and_check(indexes): return indexes, "array" -def get_consensus_names(indexes): - """ - Give a consensus 'names' to indexes. - - If there's exactly one non-empty 'names', return this, - otherwise, return empty. - - Parameters - ---------- - indexes : list of Index objects - - Returns - ------- - list - A list representing the consensus 'names' found. - """ - # find the non-none names, need to tupleify to make - # the set hashable, then reverse on return - consensus_names = {tuple(i.names) for i in indexes if com.any_not_none(*i.names)} - if len(consensus_names) == 1: - return list(list(consensus_names)[0]) - return [None] * indexes[0].nlevels - - def all_indexes_same(indexes): """ Determine if all indexes contain the same elements. 
diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 4967e13a9855a..1e42ebf25c26d 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -1,5 +1,6 @@ from copy import copy as copy_func from datetime import datetime +from itertools import zip_longest import operator from textwrap import dedent from typing import ( @@ -11,6 +12,7 @@ List, Optional, Sequence, + Tuple, TypeVar, Union, ) @@ -2525,7 +2527,7 @@ def _get_reconciled_name_object(self, other): """ name = get_op_result_name(self, other) if self.name != name: - return self._shallow_copy(name=name) + return self.rename(name) return self def _union_incompatible_dtypes(self, other, sort): @@ -2633,7 +2635,9 @@ def union(self, other, sort=None): if not self._can_union_without_object_cast(other): return self._union_incompatible_dtypes(other, sort=sort) - return self._union(other, sort=sort) + result = self._union(other, sort=sort) + + return self._wrap_setop_result(other, result) def _union(self, other, sort): """ @@ -2655,10 +2659,10 @@ def _union(self, other, sort): Index """ if not len(other) or self.equals(other): - return self._get_reconciled_name_object(other) + return self if not len(self): - return other._get_reconciled_name_object(self) + return other # TODO(EA): setops-refactor, clean all this up lvals = self._values @@ -2700,12 +2704,16 @@ def _union(self, other, sort): stacklevel=3, ) - # for subclasses - return self._wrap_setop_result(other, result) + return self._shallow_copy(result) def _wrap_setop_result(self, other, result): name = get_op_result_name(self, other) - return self._shallow_copy(result, name=name) + if isinstance(result, Index): + if result.name != name: + return result.rename(name) + return result + else: + return self._shallow_copy(result, name=name) # TODO: standardize return type of non-union setops type(self vs other) def intersection(self, other, sort=False): @@ -2775,15 +2783,12 @@ def intersection(self, other, sort=False): indexer = 
algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0]) indexer = indexer[indexer != -1] - taken = other.take(indexer) - res_name = get_op_result_name(self, other) + result = other.take(indexer) if sort is None: - taken = algos.safe_sort(taken.values) - return self._shallow_copy(taken, name=res_name) + result = algos.safe_sort(result.values) - taken.name = res_name - return taken + return self._wrap_setop_result(other, result) def difference(self, other, sort=None): """ @@ -5968,3 +5973,22 @@ def _maybe_asobject(dtype, klass, data, copy: bool, name: Label, **kwargs): return index.astype(object) return klass(data, dtype=dtype, copy=copy, name=name, **kwargs) + + +def get_unanimous_names(*indexes: Index) -> Tuple[Label, ...]: + """ + Return common name if all indices agree, otherwise None (level-by-level). + + Parameters + ---------- + *indexes : Index objects + + Returns + ------- + tuple + A tuple representing the unanimous 'names' found. + """ + name_tups = [tuple(i.names) for i in indexes] + name_sets = [{*ns} for ns in zip_longest(*name_tups)] + names = tuple(ns.pop() if len(ns) == 1 else None for ns in name_sets) + return names diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 4440238dbd493..5821ff0aca3c2 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -719,15 +719,14 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - res_name = get_op_result_name(self, other) if self.equals(other): return self._get_reconciled_name_object(other) if len(self) == 0: - return self.copy() + return self.copy()._get_reconciled_name_object(other) if len(other) == 0: - return other.copy() + return other.copy()._get_reconciled_name_object(self) if not isinstance(other, type(self)): result = Index.intersection(self, other, sort=sort) @@ -735,7 +734,6 @@ def intersection(self, other, sort=False): if result.freq is None: # 
TODO: no tests rely on this; needed? result = result._with_freq("infer") - result.name = res_name return result elif not self._can_fast_intersect(other): @@ -743,9 +741,7 @@ def intersection(self, other, sort=False): # We need to invalidate the freq because Index.intersection # uses _shallow_copy on a view of self._data, which will preserve # self.freq if we're not careful. - result = result._with_freq(None)._with_freq("infer") - result.name = res_name - return result + return result._with_freq(None)._with_freq("infer") # to make our life easier, "sort" the two ranges if self[0] <= other[0]: @@ -759,11 +755,13 @@ def intersection(self, other, sort=False): start = right[0] if end < start: - return type(self)(data=[], dtype=self.dtype, freq=self.freq, name=res_name) + result = type(self)(data=[], dtype=self.dtype, freq=self.freq) else: lslice = slice(*left.slice_locs(start, end)) left_chunk = left._values[lslice] - return type(self)._simple_new(left_chunk, name=res_name) + result = type(self)._simple_new(left_chunk) + + return self._wrap_setop_result(other, result) def _can_fast_intersect(self: _T, other: _T) -> bool: if self.freq is None: @@ -858,7 +856,7 @@ def _fast_union(self, other, sort=None): # The can_fast_union check ensures that the result.freq # should match self.freq dates = type(self._data)(dates, freq=self.freq) - result = type(self)._simple_new(dates, name=self.name) + result = type(self)._simple_new(dates) return result else: return left @@ -883,8 +881,8 @@ def _union(self, other, sort): result = result._with_freq("infer") return result else: - i8self = Int64Index._simple_new(self.asi8, name=self.name) - i8other = Int64Index._simple_new(other.asi8, name=other.name) + i8self = Int64Index._simple_new(self.asi8) + i8other = Int64Index._simple_new(other.asi8) i8result = i8self._union(i8other, sort=sort) result = type(self)(i8result, dtype=self.dtype, freq="infer") return result diff --git a/pandas/core/indexes/datetimes.py 
b/pandas/core/indexes/datetimes.py index 06405995f7685..67b71ce63a6e3 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -25,7 +25,7 @@ from pandas.core.arrays.datetimes import DatetimeArray, tz_to_dtype import pandas.core.common as com -from pandas.core.indexes.base import Index, maybe_extract_name +from pandas.core.indexes.base import Index, get_unanimous_names, maybe_extract_name from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin from pandas.core.indexes.extension import inherit_names from pandas.core.tools.times import to_time @@ -405,6 +405,10 @@ def union_many(self, others): this = this._fast_union(other) else: this = Index.union(this, other) + + res_name = get_unanimous_names(self, *others)[0] + if this.name != res_name: + return this.rename(res_name) return this # -------------------------------------------------------------------- diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index e7747761d2a29..efb8a3e850b1a 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1023,7 +1023,7 @@ def intersection( if sort is None: taken = taken.sort_values() - return taken + return self._wrap_setop_result(other, taken) def _intersection_unique(self, other: "IntervalIndex") -> "IntervalIndex": """ diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index a157fdfdde447..7e2bad3e4bf93 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -47,7 +47,12 @@ from pandas.core.arrays.categorical import factorize_from_iterables import pandas.core.common as com import pandas.core.indexes.base as ibase -from pandas.core.indexes.base import Index, _index_shared_docs, ensure_index +from pandas.core.indexes.base import ( + Index, + _index_shared_docs, + ensure_index, + get_unanimous_names, +) from pandas.core.indexes.frozen import FrozenList from pandas.core.indexes.numeric import Int64Index import pandas.core.missing as 
missing @@ -3426,7 +3431,7 @@ def union(self, other, sort=None): other, result_names = self._convert_can_do_setop(other) if len(other) == 0 or self.equals(other): - return self + return self.rename(result_names) # TODO: Index.union returns other when `len(self)` is 0. @@ -3468,7 +3473,7 @@ def intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) if self.equals(other): - return self + return self.rename(result_names) if not is_object_dtype(other.dtype): # The intersection is empty @@ -3539,7 +3544,7 @@ def difference(self, other, sort=None): other, result_names = self._convert_can_do_setop(other) if len(other) == 0: - return self + return self.rename(result_names) if self.equals(other): return MultiIndex( @@ -3587,7 +3592,8 @@ def _convert_can_do_setop(self, other): except TypeError as err: raise TypeError(msg) from err else: - result_names = self.names if self.names == other.names else None + result_names = get_unanimous_names(self, other) + return other, result_names # -------------------------------------------------------------------- diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index f0b0773aeb47b..90b713e8f09a9 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -539,7 +539,8 @@ def intersection(self, other, sort=False): new_index = new_index[::-1] if sort is None: new_index = new_index.sort_values() - return new_index + + return self._wrap_setop_result(other, new_index) def _min_fitting_element(self, lower_limit: int) -> int: """Returns the smallest element greater than or equal to the limit""" diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index a07c7b49ac55b..91310a3659f33 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -23,8 +23,8 @@ MultiIndex, all_indexes_same, ensure_index, - get_consensus_names, get_objs_combined_axis, + get_unanimous_names, ) import pandas.core.indexes.base as ibase from 
pandas.core.internals import concatenate_block_managers @@ -655,7 +655,7 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde ) # also copies - names = names + get_consensus_names(indexes) + names = list(names) + list(get_unanimous_names(*indexes)) return MultiIndex( levels=levels, codes=codes_list, names=names, verify_integrity=False diff --git a/pandas/tests/frame/test_constructors.py b/pandas/tests/frame/test_constructors.py index b5e211895672a..8ec11d14cd606 100644 --- a/pandas/tests/frame/test_constructors.py +++ b/pandas/tests/frame/test_constructors.py @@ -1637,8 +1637,8 @@ def test_constructor_Series_differently_indexed(self): "name_in1,name_in2,name_in3,name_out", [ ("idx", "idx", "idx", "idx"), - ("idx", "idx", None, "idx"), - ("idx", None, None, "idx"), + ("idx", "idx", None, None), + ("idx", None, None, None), ("idx1", "idx2", None, None), ("idx1", "idx1", "idx2", None), ("idx1", "idx2", "idx3", None), diff --git a/pandas/tests/indexes/datetimes/test_setops.py b/pandas/tests/indexes/datetimes/test_setops.py index 102c8f97a8a6b..a8baf67273490 100644 --- a/pandas/tests/indexes/datetimes/test_setops.py +++ b/pandas/tests/indexes/datetimes/test_setops.py @@ -473,7 +473,7 @@ def test_intersection_list(self): values = [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")] idx = pd.DatetimeIndex(values, name="a") res = idx.intersection(values) - tm.assert_index_equal(res, idx) + tm.assert_index_equal(res, idx.rename(None)) def test_month_range_union_tz_pytz(self, sort): from pytz import timezone diff --git a/pandas/tests/indexes/multi/test_join.py b/pandas/tests/indexes/multi/test_join.py index 6be9ec463ce36..562d07d283293 100644 --- a/pandas/tests/indexes/multi/test_join.py +++ b/pandas/tests/indexes/multi/test_join.py @@ -46,7 +46,7 @@ def test_join_level_corner_case(idx): def test_join_self(idx, join_type): joined = idx.join(idx, how=join_type) - assert idx is joined + tm.assert_index_equal(joined, idx) def 
test_join_multi(): diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index 6d4928547cad1..0b17c1c4c9679 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -243,10 +243,10 @@ def test_union(idx, sort): # corner case, pass self or empty thing: the_union = idx.union(idx, sort=sort) - assert the_union is idx + tm.assert_index_equal(the_union, idx) the_union = idx.union(idx[:0], sort=sort) - assert the_union is idx + tm.assert_index_equal(the_union, idx) # FIXME: dont leave commented-out # won't work in python 3 @@ -278,7 +278,7 @@ def test_intersection(idx, sort): # corner case, pass self the_int = idx.intersection(idx, sort=sort) - assert the_int is idx + tm.assert_index_equal(the_int, idx) # empty intersection: disjoint empty = idx[:2].intersection(idx[2:], sort=sort) diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index e2dea7828b3ad..94b10572fb5e1 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -124,6 +124,93 @@ def test_corner_union(self, index, fname, sname, expected_name): expected = index.drop(index).set_names(expected_name) tm.assert_index_equal(union, expected) + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", "A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_union_unequal(self, index, fname, sname, expected_name): + if isinstance(index, MultiIndex) or not index.is_unique: + pytest.skip("Not for MultiIndex or repeated indices") + + # test copy.union(subset) - need sort for unicode and string + first = index.copy().set_names(fname) + second = index[1:].set_names(sname) + union = first.union(second).sort_values() + expected = index.set_names(expected_name).sort_values() + tm.assert_index_equal(union, expected) + + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", 
"A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_corner_intersect(self, index, fname, sname, expected_name): + # GH35847 + # Test intersections with various name combinations + + if isinstance(index, MultiIndex) or not index.is_unique: + pytest.skip("Not for MultiIndex or repeated indices") + + # Test copy.intersection(copy) + first = index.copy().set_names(fname) + second = index.copy().set_names(sname) + intersect = first.intersection(second) + expected = index.copy().set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + # Test copy.intersection(empty) + first = index.copy().set_names(fname) + second = index.drop(index).set_names(sname) + intersect = first.intersection(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + # Test empty.intersection(copy) + first = index.drop(index).set_names(fname) + second = index.copy().set_names(sname) + intersect = first.intersection(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + # Test empty.intersection(empty) + first = index.drop(index).set_names(fname) + second = index.drop(index).set_names(sname) + intersect = first.intersection(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", "A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_intersect_unequal(self, index, fname, sname, expected_name): + if isinstance(index, MultiIndex) or not index.is_unique: + pytest.skip("Not for MultiIndex or repeated indices") + + # test copy.intersection(subset) - need sort for unicode and string + first = index.copy().set_names(fname) + second = index[1:].set_names(sname) + intersect = first.intersection(second).sort_values() + expected = 
index[1:].set_names(expected_name).sort_values() + tm.assert_index_equal(intersect, expected) + def test_to_flat_index(self, index): # 22866 if isinstance(index, MultiIndex): diff --git a/pandas/tests/reshape/test_concat.py b/pandas/tests/reshape/test_concat.py index b0f6a8ef0c517..f0eb745041a66 100644 --- a/pandas/tests/reshape/test_concat.py +++ b/pandas/tests/reshape/test_concat.py @@ -1300,8 +1300,8 @@ def test_concat_ignore_index(self, sort): "name_in1,name_in2,name_in3,name_out", [ ("idx", "idx", "idx", "idx"), - ("idx", "idx", None, "idx"), - ("idx", None, None, "idx"), + ("idx", "idx", None, None), + ("idx", None, None, None), ("idx1", "idx2", None, None), ("idx1", "idx1", "idx2", None), ("idx1", "idx2", "idx3", None),