Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: return RangeIndex from difference, symmetric_difference #36564

Merged
merged 7 commits into from
Oct 7, 2020
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ Other enhancements
- :meth:`Rolling.mean()` and :meth:`Rolling.sum()` use Kahan summation to calculate the mean to avoid numerical problems (:issue:`10319`, :issue:`11645`, :issue:`13254`, :issue:`32761`, :issue:`36031`)
- :meth:`DatetimeIndex.searchsorted`, :meth:`TimedeltaIndex.searchsorted`, :meth:`PeriodIndex.searchsorted`, and :meth:`Series.searchsorted` with datetimelike dtypes will now try to cast string arguments (listlike and scalar) to the matching datetimelike type (:issue:`36346`)
- Added methods :meth:`IntegerArray.prod`, :meth:`IntegerArray.min`, and :meth:`IntegerArray.max` (:issue:`33790`)
- Where possible :meth:`RangeIndex.difference` and :meth:`RangeIndex.symmetric_difference` will return :class:`RangeIndex` instead of :class:`Int64Index` (:issue:`36564`)

.. _whatsnew_120.api_breaking.python:

Expand Down
59 changes: 59 additions & 0 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ def equals(self, other: object) -> bool:
return self._range == other._range
return super().equals(other)

# --------------------------------------------------------------------
# Set Operations

def intersection(self, other, sort=False):
"""
Form the intersection of two Index objects.
Expand Down Expand Up @@ -634,6 +637,57 @@ def _union(self, other, sort):
return type(self)(start_r, end_r + step_o, step_o)
return self._int64index._union(other, sort=sort)

def difference(self, other, sort=None):
    """
    Return the elements of this index that are not present in ``other``.

    When ``other`` is also a RangeIndex and the surviving elements form a
    single run at the start or end of this index, the result is built
    directly as a RangeIndex. Every other case is delegated to the
    parent-class implementation.

    Parameters
    ----------
    other : Index or array-like
        Values to remove from this index.
    sort : default None
        Checked with ``_validate_sort_keyword`` and forwarded unchanged to
        the parent implementation whenever this method falls back.

    Returns
    -------
    Index
        A RangeIndex where the result is range-like, otherwise whatever
        the parent implementation returns.
    """
    self._validate_sort_keyword(sort)

    if not isinstance(other, RangeIndex):
        return super().difference(other, sort=sort)

    res_name = ops.get_op_result_name(self, other)

    # Reason about an ascending view of self; a descending self is
    # restored by reversing the result at the very end.
    first = self._range[::-1] if self.step < 0 else self._range
    overlap = self.intersection(other)

    if len(overlap) == 0:
        return self._shallow_copy(name=res_name)
    if len(overlap) == len(self):
        return self[:0].rename(res_name)
    if not isinstance(overlap, RangeIndex):
        # We won't end up with a RangeIndex, so fall back.  This check
        # comes before any use of ``overlap.step``.
        return super().difference(other, sort=sort)

    if overlap.step < 0:
        overlap = overlap[::-1]

    if len(overlap) > 1 and overlap.step != first.step:
        # The overlap skips elements of self (e.g. removing every other
        # value), so what remains is interleaved and not a simple range.
        return super().difference(other, sort=sort)

    if overlap[0] == first.start:
        # The difference is everything after the intersection
        new_rng = range(overlap[-1] + first.step, first.stop, first.step)
    elif overlap[-1] == first[-1]:
        # The difference is everything before the intersection.  Compare
        # against the last *element*: ``first.stop`` is exclusive and is
        # never itself a member.  ``overlap[0]`` is likewise the exclusive
        # stop of the new range.
        new_rng = range(first.start, overlap[0], first.step)
    else:
        # The difference is not range-like
        return super().difference(other, sort=sort)

    new_index = type(self)._simple_new(new_rng, name=res_name)
    if first is not self._range:
        new_index = new_index[::-1]
    return new_index

def symmetric_difference(self, other, result_name=None, sort=None):
    """
    Return the elements found in exactly one of ``self`` and ``other``.

    The range-aware path is taken only when ``other`` is a RangeIndex and
    ``sort`` is None; it combines the two one-sided differences with
    ``union``. Any other call is delegated to the parent implementation.

    Parameters
    ----------
    other : Index or array-like
    result_name : optional
        When not None, the result is renamed to this value.
    sort : default None
        Any non-None value routes the call to the parent implementation.

    Returns
    -------
    Index
    """
    range_path = isinstance(other, RangeIndex) and sort is None
    if not range_path:
        return super().symmetric_difference(other, result_name, sort)

    only_in_self = self.difference(other)
    only_in_other = other.difference(self)
    combined = only_in_self.union(only_in_other)

    return combined if result_name is None else combined.rename(result_name)

# --------------------------------------------------------------------

@doc(Int64Index.join)
def join(self, other, how="left", level=None, return_indexers=False, sort=False):
if how == "outer" and self is not other:
Expand Down Expand Up @@ -746,12 +800,17 @@ def __floordiv__(self, other):
return self._simple_new(new_range, name=self.name)
return self._int64index // other

# --------------------------------------------------------------------
# Reductions

def all(self) -> bool:
    """Return True when zero is not a member of the underlying range."""
    contains_zero = 0 in self._range
    return not contains_zero

def any(self) -> bool:
    """Return True when the underlying range holds at least one non-zero value."""
    for value in self._range:
        if value:
            return True
    return False

# --------------------------------------------------------------------

@classmethod
def _add_numeric_methods_binary(cls):
""" add in numeric methods, specialized to RangeIndex """
Expand Down
48 changes: 48 additions & 0 deletions pandas/tests/indexes/ranges/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,51 @@ def test_union_sorted(self, unions):
res3 = idx1._int64index.union(idx2, sort=None)
tm.assert_index_equal(res2, expected_sorted, exact=True)
tm.assert_index_equal(res3, expected_sorted)

def test_difference(self):
    """
    RangeIndex.difference against another RangeIndex (GH#12034).

    Covers removing everything, removing nothing, trimming either end,
    and punching a hole in the middle.
    """
    idx = RangeIndex.from_range(range(1, 10), name="foo")
    empty = RangeIndex.from_range(range(0), name="foo")

    # Removing every element leaves an empty index carrying the same name.
    tm.assert_index_equal(idx.difference(idx), empty)

    # Removing nothing keeps the values; mismatched names give name=None.
    unchanged = idx.difference(empty.rename("bar"))
    tm.assert_index_equal(unchanged, idx.rename(None))

    # Dropping a prefix leaves the tail.
    without_head = idx.difference(idx[:3])
    tm.assert_index_equal(without_head, idx[3:])

    # Dropping a suffix leaves the head.
    without_tail = idx.difference(idx[-3:])
    tm.assert_index_equal(without_tail, idx[:-3])

    # Dropping a middle slice leaves two disjoint runs.
    with_hole = idx.difference(idx[2:6])
    tm.assert_index_equal(with_hole, Int64Index([1, 2, 7, 8, 9], name="foo"))

def test_symmetric_difference(self):
    """
    RangeIndex.symmetric_difference against another RangeIndex (GH#12034).

    Covers identical, empty, overlapping, adjacent and gapped operands.
    """
    named = RangeIndex.from_range(range(1, 10), name="foo")
    empty = RangeIndex.from_range(range(0), name="foo")

    # An index against itself has nothing unique on either side.
    tm.assert_index_equal(named.symmetric_difference(named), empty)

    # Against an empty index with another name: same values, name=None.
    versus_empty = named.symmetric_difference(empty.rename("bar"))
    tm.assert_index_equal(versus_empty, named.rename(None))

    # Overlapping slices leave only the two non-shared ends.
    ends_only = named[:-2].symmetric_difference(named[2:])
    tm.assert_index_equal(ends_only, Int64Index([1, 2, 8, 9], name="foo"))

    unnamed = RangeIndex.from_range(range(10, 15))

    # Adjacent, disjoint ranges merge into one contiguous range.
    merged = named.symmetric_difference(unnamed)
    tm.assert_index_equal(merged, RangeIndex.from_range(range(1, 15)))

    # A gap at 10 between the operands stays in the result.
    gapped = named.symmetric_difference(unnamed[1:])
    tm.assert_index_equal(
        gapped, Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
    )