Skip to content

Commit

Permalink
fix inconsistent index naming with union/intersect GH35847 (#36413)
Browse files Browse the repository at this point in the history
  • Loading branch information
iamlemec authored Oct 7, 2020
1 parent ec8c1c4 commit 1f67100
Show file tree
Hide file tree
Showing 16 changed files with 196 additions and 72 deletions.
8 changes: 8 additions & 0 deletions doc/source/user_guide/merging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ functionality below.
frames = [ process_your_file(f) for f in files ]
result = pd.concat(frames)

.. note::

   When concatenating DataFrames with named axes, pandas will attempt to preserve
   these index/column names whenever possible. In the case where all inputs share a
   common name, this name will be assigned to the result. When the input names do
   not all agree, the result will be unnamed. The same is true for :class:`MultiIndex`,
   but the logic is applied separately on a level-by-level basis.


Set logic on the other axes
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
20 changes: 20 additions & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ Alternatively, you can also use the dtype object:
behaviour or API may still change without warning. Especially the behaviour
regarding NaN (distinct from NA missing values) is subject to change.

.. _whatsnew_120.index_name_preservation:

Index/column name preservation when aggregating
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When aggregating using :meth:`concat` or the :class:`DataFrame` constructor, pandas
will attempt to preserve index (and column) names whenever possible (:issue:`35847`).
In the case where all inputs share a common name, this name will be assigned to the
result. When the input names do not all agree, the result will be unnamed. Here is an
example where the index name is preserved:

.. ipython:: python

    idx = pd.Index(range(5), name='abc')
    ser = pd.Series(range(5, 10), index=idx)
    pd.concat({'x': ser[1:], 'y': ser[:-1]}, axis=1)

The same is true for :class:`MultiIndex`, but the logic is applied separately on a
level-by-level basis.

.. _whatsnew_120.enhancements.other:

Other enhancements
Expand Down
32 changes: 4 additions & 28 deletions pandas/core/indexes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from pandas._libs import NaT, lib
from pandas.errors import InvalidIndexError

import pandas.core.common as com
from pandas.core.indexes.base import (
Index,
_new_Index,
ensure_index,
ensure_index_from_sequences,
get_unanimous_names,
)
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.datetimes import DatetimeIndex
Expand Down Expand Up @@ -57,7 +57,7 @@
"ensure_index_from_sequences",
"get_objs_combined_axis",
"union_indexes",
"get_consensus_names",
"get_unanimous_names",
"all_indexes_same",
]

Expand Down Expand Up @@ -221,9 +221,9 @@ def conv(i):
if not all(index.equals(other) for other in indexes[1:]):
index = _unique_indices(indexes)

name = get_consensus_names(indexes)[0]
name = get_unanimous_names(*indexes)[0]
if name != index.name:
index = index._shallow_copy(name=name)
index = index.rename(name)
return index
else: # kind='list'
return _unique_indices(indexes)
Expand Down Expand Up @@ -267,30 +267,6 @@ def _sanitize_and_check(indexes):
return indexes, "array"


def get_consensus_names(indexes):
    """
    Give a consensus 'names' to indexes.

    If there's exactly one non-empty 'names', return this,
    otherwise, return empty.

    Parameters
    ----------
    indexes : list of Index objects
        Must be non-empty; the first element supplies ``nlevels`` for the
        fallback result.

    Returns
    -------
    list
        A list representing the consensus 'names' found.
    """
    # Collect the distinct 'names' that have at least one non-None entry.
    # Each is converted to a tuple so it is hashable and can live in a set;
    # the winner is converted back to a list on return.
    consensus_names = {tuple(i.names) for i in indexes if com.any_not_none(*i.names)}
    if len(consensus_names) == 1:
        return list(next(iter(consensus_names)))
    # Zero or several distinct candidates: no consensus, so every level is unnamed.
    return [None] * indexes[0].nlevels


def all_indexes_same(indexes):
"""
Determine if all indexes contain the same elements.
Expand Down
50 changes: 37 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import copy as copy_func
from datetime import datetime
from itertools import zip_longest
import operator
from textwrap import dedent
from typing import (
Expand All @@ -11,6 +12,7 @@
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
Expand Down Expand Up @@ -2525,7 +2527,7 @@ def _get_reconciled_name_object(self, other):
"""
name = get_op_result_name(self, other)
if self.name != name:
return self._shallow_copy(name=name)
return self.rename(name)
return self

def _union_incompatible_dtypes(self, other, sort):
Expand Down Expand Up @@ -2633,7 +2635,9 @@ def union(self, other, sort=None):
if not self._can_union_without_object_cast(other):
return self._union_incompatible_dtypes(other, sort=sort)

return self._union(other, sort=sort)
result = self._union(other, sort=sort)

return self._wrap_setop_result(other, result)

def _union(self, other, sort):
"""
Expand All @@ -2655,10 +2659,10 @@ def _union(self, other, sort):
Index
"""
if not len(other) or self.equals(other):
return self._get_reconciled_name_object(other)
return self

if not len(self):
return other._get_reconciled_name_object(self)
return other

# TODO(EA): setops-refactor, clean all this up
lvals = self._values
Expand Down Expand Up @@ -2700,12 +2704,16 @@ def _union(self, other, sort):
stacklevel=3,
)

# for subclasses
return self._wrap_setop_result(other, result)
return self._shallow_copy(result)

def _wrap_setop_result(self, other, result):
name = get_op_result_name(self, other)
return self._shallow_copy(result, name=name)
if isinstance(result, Index):
if result.name != name:
return result.rename(name)
return result
else:
return self._shallow_copy(result, name=name)

# TODO: standardize return type of non-union setops type(self vs other)
def intersection(self, other, sort=False):
Expand Down Expand Up @@ -2775,15 +2783,12 @@ def intersection(self, other, sort=False):
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
indexer = indexer[indexer != -1]

taken = other.take(indexer)
res_name = get_op_result_name(self, other)
result = other.take(indexer)

if sort is None:
taken = algos.safe_sort(taken.values)
return self._shallow_copy(taken, name=res_name)
result = algos.safe_sort(result.values)

taken.name = res_name
return taken
return self._wrap_setop_result(other, result)

def difference(self, other, sort=None):
"""
Expand Down Expand Up @@ -5968,3 +5973,22 @@ def _maybe_asobject(dtype, klass, data, copy: bool, name: Label, **kwargs):
return index.astype(object)

return klass(data, dtype=dtype, copy=copy, name=name, **kwargs)


def get_unanimous_names(*indexes: Index) -> Tuple[Label, ...]:
    """
    Return common name if all indices agree, otherwise None (level-by-level).

    Parameters
    ----------
    *indexes : Index
        Index objects whose ``names`` are compared position by position.

    Returns
    -------
    tuple
        One entry per level: the shared name when every index agrees on
        that level, otherwise None.
    """
    name_tups = [tuple(i.names) for i in indexes]
    # zip_longest pads the shorter 'names' with None, so a level that is
    # missing from any of the indexes always resolves to None.
    name_sets = [{*ns} for ns in zip_longest(*name_tups)]
    names = tuple(ns.pop() if len(ns) == 1 else None for ns in name_sets)
    return names
22 changes: 10 additions & 12 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,33 +719,29 @@ def intersection(self, other, sort=False):
"""
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
res_name = get_op_result_name(self, other)

if self.equals(other):
return self._get_reconciled_name_object(other)

if len(self) == 0:
return self.copy()
return self.copy()._get_reconciled_name_object(other)
if len(other) == 0:
return other.copy()
return other.copy()._get_reconciled_name_object(self)

if not isinstance(other, type(self)):
result = Index.intersection(self, other, sort=sort)
if isinstance(result, type(self)):
if result.freq is None:
# TODO: no tests rely on this; needed?
result = result._with_freq("infer")
result.name = res_name
return result

elif not self._can_fast_intersect(other):
result = Index.intersection(self, other, sort=sort)
# We need to invalidate the freq because Index.intersection
# uses _shallow_copy on a view of self._data, which will preserve
# self.freq if we're not careful.
result = result._with_freq(None)._with_freq("infer")
result.name = res_name
return result
return result._with_freq(None)._with_freq("infer")

# to make our life easier, "sort" the two ranges
if self[0] <= other[0]:
Expand All @@ -759,11 +755,13 @@ def intersection(self, other, sort=False):
start = right[0]

if end < start:
return type(self)(data=[], dtype=self.dtype, freq=self.freq, name=res_name)
result = type(self)(data=[], dtype=self.dtype, freq=self.freq)
else:
lslice = slice(*left.slice_locs(start, end))
left_chunk = left._values[lslice]
return type(self)._simple_new(left_chunk, name=res_name)
result = type(self)._simple_new(left_chunk)

return self._wrap_setop_result(other, result)

def _can_fast_intersect(self: _T, other: _T) -> bool:
if self.freq is None:
Expand Down Expand Up @@ -858,7 +856,7 @@ def _fast_union(self, other, sort=None):
# The can_fast_union check ensures that the result.freq
# should match self.freq
dates = type(self._data)(dates, freq=self.freq)
result = type(self)._simple_new(dates, name=self.name)
result = type(self)._simple_new(dates)
return result
else:
return left
Expand All @@ -883,8 +881,8 @@ def _union(self, other, sort):
result = result._with_freq("infer")
return result
else:
i8self = Int64Index._simple_new(self.asi8, name=self.name)
i8other = Int64Index._simple_new(other.asi8, name=other.name)
i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = i8self._union(i8other, sort=sort)
result = type(self)(i8result, dtype=self.dtype, freq="infer")
return result
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from pandas.core.arrays.datetimes import DatetimeArray, tz_to_dtype
import pandas.core.common as com
from pandas.core.indexes.base import Index, maybe_extract_name
from pandas.core.indexes.base import Index, get_unanimous_names, maybe_extract_name
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
from pandas.core.indexes.extension import inherit_names
from pandas.core.tools.times import to_time
Expand Down Expand Up @@ -405,6 +405,10 @@ def union_many(self, others):
this = this._fast_union(other)
else:
this = Index.union(this, other)

res_name = get_unanimous_names(self, *others)[0]
if this.name != res_name:
return this.rename(res_name)
return this

# --------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ def intersection(
if sort is None:
taken = taken.sort_values()

return taken
return self._wrap_setop_result(other, taken)

def _intersection_unique(self, other: "IntervalIndex") -> "IntervalIndex":
"""
Expand Down
16 changes: 11 additions & 5 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@
from pandas.core.arrays.categorical import factorize_from_iterables
import pandas.core.common as com
import pandas.core.indexes.base as ibase
from pandas.core.indexes.base import Index, _index_shared_docs, ensure_index
from pandas.core.indexes.base import (
Index,
_index_shared_docs,
ensure_index,
get_unanimous_names,
)
from pandas.core.indexes.frozen import FrozenList
from pandas.core.indexes.numeric import Int64Index
import pandas.core.missing as missing
Expand Down Expand Up @@ -3426,7 +3431,7 @@ def union(self, other, sort=None):
other, result_names = self._convert_can_do_setop(other)

if len(other) == 0 or self.equals(other):
return self
return self.rename(result_names)

# TODO: Index.union returns other when `len(self)` is 0.

Expand Down Expand Up @@ -3468,7 +3473,7 @@ def intersection(self, other, sort=False):
other, result_names = self._convert_can_do_setop(other)

if self.equals(other):
return self
return self.rename(result_names)

if not is_object_dtype(other.dtype):
# The intersection is empty
Expand Down Expand Up @@ -3539,7 +3544,7 @@ def difference(self, other, sort=None):
other, result_names = self._convert_can_do_setop(other)

if len(other) == 0:
return self
return self.rename(result_names)

if self.equals(other):
return MultiIndex(
Expand Down Expand Up @@ -3587,7 +3592,8 @@ def _convert_can_do_setop(self, other):
except TypeError as err:
raise TypeError(msg) from err
else:
result_names = self.names if self.names == other.names else None
result_names = get_unanimous_names(self, other)

return other, result_names

# --------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ def intersection(self, other, sort=False):
new_index = new_index[::-1]
if sort is None:
new_index = new_index.sort_values()
return new_index

return self._wrap_setop_result(other, new_index)

def _min_fitting_element(self, lower_limit: int) -> int:
"""Returns the smallest element greater than or equal to the limit"""
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
MultiIndex,
all_indexes_same,
ensure_index,
get_consensus_names,
get_objs_combined_axis,
get_unanimous_names,
)
import pandas.core.indexes.base as ibase
from pandas.core.internals import concatenate_block_managers
Expand Down Expand Up @@ -655,7 +655,7 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
)

# also copies
names = names + get_consensus_names(indexes)
names = list(names) + list(get_unanimous_names(*indexes))

return MultiIndex(
levels=levels, codes=codes_list, names=names, verify_integrity=False
Expand Down
Loading

0 comments on commit 1f67100

Please sign in to comment.