Skip to content

Commit

Permalink
ENH: return RangeIndex from difference, symmetric_difference (#36564)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Oct 7, 2020
1 parent 09e7829 commit 4f674a1
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ Other enhancements
- :meth:`Rolling.mean()` and :meth:`Rolling.sum()` use Kahan summation to calculate the mean to avoid numerical problems (:issue:`10319`, :issue:`11645`, :issue:`13254`, :issue:`32761`, :issue:`36031`)
- :meth:`DatetimeIndex.searchsorted`, :meth:`TimedeltaIndex.searchsorted`, :meth:`PeriodIndex.searchsorted`, and :meth:`Series.searchsorted` with datetimelike dtypes will now try to cast string arguments (listlike and scalar) to the matching datetimelike type (:issue:`36346`)
- Added methods :meth:`IntegerArray.prod`, :meth:`IntegerArray.min`, and :meth:`IntegerArray.max` (:issue:`33790`)
- Where possible, :meth:`RangeIndex.difference` and :meth:`RangeIndex.symmetric_difference` will return :class:`RangeIndex` instead of :class:`Int64Index` (:issue:`36564`)

.. _whatsnew_120.api_breaking.python:

Expand Down
59 changes: 59 additions & 0 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ def equals(self, other: object) -> bool:
return self._range == other._range
return super().equals(other)

# --------------------------------------------------------------------
# Set Operations

def intersection(self, other, sort=False):
"""
Form the intersection of two Index objects.
Expand Down Expand Up @@ -634,6 +637,57 @@ def _union(self, other, sort):
return type(self)(start_r, end_r + step_o, step_o)
return self._int64index._union(other, sort=sort)

def difference(self, other, sort=None):
    """
    Return the elements of this index that are not in ``other``.

    When ``other`` is a RangeIndex and the overlap is a prefix or a suffix
    of this index (with the same step), the result is built directly as a
    RangeIndex. Every other case is delegated to the parent class.

    Parameters
    ----------
    other : Index or array-like
    sort : optional, default None
        Checked by ``_validate_sort_keyword`` and forwarded unchanged to
        the parent implementation whenever we fall back.

    Returns
    -------
    Index
        A RangeIndex on the optimized paths, otherwise whatever the parent
        implementation returns.
    """
    # optimized set operation if we have another RangeIndex
    self._validate_sort_keyword(sort)

    if not isinstance(other, RangeIndex):
        return super().difference(other, sort=sort)

    res_name = ops.get_op_result_name(self, other)

    # Work on an ascending view of our own range so the endpoint
    # comparisons below do not depend on the sign of the step.
    first = self._range[::-1] if self.step < 0 else self._range
    overlap = self.intersection(other)

    if len(overlap) == 0:
        return self._shallow_copy(name=res_name)
    if len(overlap) == len(self):
        return self[:0].rename(res_name)
    if not isinstance(overlap, RangeIndex):
        # We won't end up with RangeIndex, so fall back
        return super().difference(other, sort=sort)

    # Only touch ``.step`` once we know the overlap is a RangeIndex.
    if overlap.step < 0:
        overlap = overlap[::-1]

    if len(overlap) > 1 and overlap.step != first.step:
        # The overlap skips over values of ``first``, so removing it leaves
        # gaps; the prefix/suffix shortcuts below would give wrong values.
        return super().difference(other, sort=sort)

    if overlap[0] == first.start:
        # The difference is everything after the intersection
        new_rng = range(overlap[-1] + first.step, first.stop, first.step)
    elif overlap[-1] == first[-1]:
        # The difference is everything before the intersection.
        # ``first.stop`` is exclusive, so compare against the last element,
        # and use ``overlap[0]`` itself as the (exclusive) new stop.
        new_rng = range(first.start, overlap[0], first.step)
    else:
        # The difference is not range-like
        return super().difference(other, sort=sort)

    new_index = type(self)._simple_new(new_rng, name=res_name)
    if first is not self._range:
        # We flipped a descending range above; restore the original order.
        new_index = new_index[::-1]
    return new_index

def symmetric_difference(self, other, result_name=None, sort=None):
    """
    Return the elements that are in exactly one of ``self`` and ``other``.

    When ``other`` is a RangeIndex and ``sort`` is None, the result is the
    union of the two one-sided differences. Any other input is handed to
    the parent class implementation.

    Parameters
    ----------
    other : Index or array-like
    result_name : optional
        If not None, the result is renamed to this value.
    sort : optional, default None
        Any value other than None routes the call to the parent class.

    Returns
    -------
    Index
    """
    can_use_range_path = isinstance(other, RangeIndex) and sort is None
    if not can_use_range_path:
        return super().symmetric_difference(other, result_name, sort)

    only_in_self = self.difference(other)
    only_in_other = other.difference(self)
    combined = only_in_self.union(only_in_other)

    if result_name is None:
        return combined
    return combined.rename(result_name)

# --------------------------------------------------------------------

@doc(Int64Index.join)
def join(self, other, how="left", level=None, return_indexers=False, sort=False):
if how == "outer" and self is not other:
Expand Down Expand Up @@ -746,12 +800,17 @@ def __floordiv__(self, other):
return self._simple_new(new_range, name=self.name)
return self._int64index // other

# --------------------------------------------------------------------
# Reductions

def all(self) -> bool:
    """Return True when the underlying range does not contain 0."""
    contains_zero = 0 in self._range
    return not contains_zero

def any(self) -> bool:
    """Return True when the underlying range holds at least one non-zero value."""
    for value in self._range:
        if value:
            return True
    return False

# --------------------------------------------------------------------

@classmethod
def _add_numeric_methods_binary(cls):
""" add in numeric methods, specialized to RangeIndex """
Expand Down
48 changes: 48 additions & 0 deletions pandas/tests/indexes/ranges/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,51 @@ def test_union_sorted(self, unions):
res3 = idx1._int64index.union(idx2, sort=None)
tm.assert_index_equal(res2, expected_sorted, exact=True)
tm.assert_index_equal(res3, expected_sorted)

def test_difference(self):
    """Check RangeIndex.difference when the other operand is also a RangeIndex."""
    # GH#12034 Cases where we operate against another RangeIndex and may
    # get back another RangeIndex
    idx = RangeIndex.from_range(range(1, 10), name="foo")
    empty = RangeIndex.from_range(range(0), name="foo")

    # Removing the index from itself leaves an empty index named "foo".
    tm.assert_index_equal(idx.difference(idx), empty)

    # Removing an empty index with a different name: compared against the
    # original values with the name cleared.
    tm.assert_index_equal(idx.difference(empty.rename("bar")), idx.rename(None))

    # Removing the leading three elements.
    tm.assert_index_equal(idx.difference(idx[:3]), idx[3:])

    # Removing the trailing three elements.
    tm.assert_index_equal(idx.difference(idx[-3:]), idx[:-3])

    # Removing an interior slice leaves values on both sides of the gap.
    without_middle = Int64Index([1, 2, 7, 8, 9], name="foo")
    tm.assert_index_equal(idx.difference(idx[2:6]), without_middle)

def test_symmetric_difference(self):
    """Check RangeIndex.symmetric_difference against other RangeIndex operands."""
    # GH#12034 Cases where we operate against another RangeIndex and may
    # get back another RangeIndex
    idx = RangeIndex.from_range(range(1, 10), name="foo")
    empty = RangeIndex.from_range(range(0), name="foo")

    # An index against itself has nothing on either side.
    tm.assert_index_equal(idx.symmetric_difference(idx), empty)

    # Against an empty index with a different name: compared against the
    # original values with the name cleared.
    tm.assert_index_equal(
        idx.symmetric_difference(empty.rename("bar")), idx.rename(None)
    )

    # Overlapping slices: only the two ends that are not shared remain.
    ends_only = idx[:-2].symmetric_difference(idx[2:])
    tm.assert_index_equal(ends_only, Int64Index([1, 2, 8, 9], name="foo"))

    following = RangeIndex.from_range(range(10, 15))

    # A range that starts right where ``idx`` ends.
    tm.assert_index_equal(
        idx.symmetric_difference(following), RangeIndex.from_range(range(1, 15))
    )

    # Dropping the first element of ``following`` leaves 10 out of the result.
    with_gap = idx.symmetric_difference(following[1:])
    tm.assert_index_equal(
        with_gap, Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
    )

0 comments on commit 4f674a1

Please sign in to comment.