Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Comparison between two PeriodIndexes doesn't validate length (GH #23078) #23896

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,7 @@ Datetimelike
- Bug in :class:`DatetimeIndex` and :class:`TimedeltaIndex` where indexing with ``Ellipsis`` would incorrectly lose the index's ``freq`` attribute (:issue:`21282`)
- Clarified error message produced when passing an incorrect ``freq`` argument to :class:`DatetimeIndex` with ``NaT`` as the first entry in the passed data (:issue:`11587`)
- Bug in :func:`to_datetime` where ``box`` and ``utc`` arguments were ignored when passing a :class:`DataFrame` or ``dict`` of unit mappings (:issue:`23760`)
- Bug in :class:`PeriodIndex` where comparing indexes of different lengths did not raise a ``ValueError`` (:issue:`23078`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use double backticks around ValueError


Timedelta
^^^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,10 @@ def _add_comparison_ops(cls):
cls.__le__ = cls._create_comparison_method(operator.le)
cls.__ge__ = cls._create_comparison_method(operator.ge)

def _validate_shape(self, other):
if len(self) != len(other):
raise ValueError('Lengths must match to compare')


class ExtensionScalarOpsMixin(ExtensionOpsMixin):
"""
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def wrapper(self, other):
elif isinstance(other, cls):
self._check_compatible_with(other)

self._validate_shape(other)

if not_implemented:
return NotImplemented
result = op(other.asi8)
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def test_direct_arith_with_series_returns_not_implemented(self, data):
"{} does not implement add".format(data.__class__.__name__)
)

def test_arith_diff_lengths(self, data, all_arithmetic_operators):
op = self.get_op_from_name(all_arithmetic_operators)
other = data[:3]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test actually wouldn't catch the behavior in #23078, which is caused specifically by comparisons of PeriodIndex with length-1 objects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue in #23078 was treated in the function test_comp_op (pandas/tests/indexes/period/test_period.py).

with pytest.raises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you match the message as well

op(data, other)


class BaseComparisonOpsTests(BaseOpsUtil):
"""Various Series and DataFrame comparison ops methods."""
Expand Down Expand Up @@ -164,3 +170,9 @@ def test_direct_arith_with_series_returns_not_implemented(self, data):
raise pytest.skip(
"{} does not implement __eq__".format(data.__class__.__name__)
)

def test_compare_diff_lengths(self, data, all_compare_operators):
op = self.get_op_from_name(all_compare_operators)
other = data[:3]
with pytest.raises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

match the message

op(data, other)
13 changes: 13 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,13 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
def test_error(self):
pass

# TODO
# Raise ValueError when carrying out arithmetic operation
# on two decimal arrays of different lengths
@pytest.mark.xfail(reason="raise of ValueError not implemented")
def test_arith_diff_lengths(self, data, all_arithmetic_operators):
    """Run the inherited different-length arithmetic check.

    Marked xfail with the reason that raising ValueError is not
    implemented.
    """
    # Request the arithmetic-operator fixture (not the comparison one) so
    # the parent test is driven with the operators it is written for.
    super().test_arith_diff_lengths(data, all_arithmetic_operators)


class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests):

Expand All @@ -324,6 +331,12 @@ def test_compare_array(self, data, all_compare_operators):
for i in alter]
self._compare_other(s, data, op_name, other)

# TODO:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these xfail? make the change as needed in the decimal code

# Raise ValueError when comparing decimal arrays of different lengths
@pytest.mark.xfail(reason="raise of ValueError not implemented")
def test_compare_diff_lengths(self, data, all_compare_operators):
    """Run the inherited different-length comparison check.

    Marked xfail with the reason that raising ValueError is not
    implemented.
    """
    # The parent method name must be spelled exactly: a misspelling raises
    # AttributeError, which a non-strict xfail would silently absorb, so the
    # inherited check would never actually run.
    super().test_compare_diff_lengths(data, all_compare_operators)


class DecimalArrayWithoutFromSequence(DecimalArray):
"""Helper class for testing error handling in _from_sequence."""
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,13 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
s, op, other, exc=TypeError
)

def test_arith_diff_lengths(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these added?

pass


class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
    def test_compare_diff_lengths(self):
        """Intentionally empty; performs no checks.

        NOTE(review): presumably disables the inherited test of the same
        name for JSON arrays -- confirm this is the intent.
        """


class TestPrinting(BaseJSON, base.BasePrintingTests):
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
s, op, other, exc=TypeError
)

def test_arith_diff_lengths(self):
    """Intentionally empty; performs no checks.

    NOTE(review): presumably disables an inherited test of the same name
    for categorical data -- confirm this is the intent.
    """


class TestComparisonOps(base.BaseComparisonOpsTests):

Expand All @@ -233,3 +236,11 @@ def _compare_other(self, s, data, op_name, other):
else:
with pytest.raises(TypeError):
op(data, other)

@pytest.mark.parametrize('op_name', ['__eq__', '__ne__'])
def test_compare_diff_lengths(self, data, op_name):
    """Comparing ``data`` with ``data[:3]`` must raise ValueError.

    Parametrized over ``__eq__`` and ``__ne__`` only.
    """
    compare = self.get_op_from_name(op_name)
    truncated = data[:3]
    with pytest.raises(ValueError):
        compare(data, truncated)
7 changes: 7 additions & 0 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def check_opname(self, s, op_name, other, exc=None):
def _compare_other(self, s, data, op_name, other):
self.check_opname(s, op_name, other)

@pytest.mark.filterwarnings("ignore:elementwise:DeprecationWarning")
def test_compare_diff_lengths(self, data, all_compare_operators):
    """Comparing ``data`` with ``data[:3]`` must raise ValueError.

    DeprecationWarnings whose message starts with "elementwise" are
    ignored for this test.
    """
    compare = self.get_op_from_name(all_compare_operators)
    truncated = data[:3]
    with pytest.raises(ValueError):
        compare(data, truncated)


class TestInterface(base.BaseInterfaceTests):
pass
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def test_add_series_with_extension_array(self, data):
def test_error(self):
pass

def test_arith_diff_lengths(self, data):
    """Subtracting ``data[:3]`` from ``data`` must raise ValueError.

    Only ``__sub__`` is exercised here.
    """
    subtract = self.get_op_from_name('__sub__')
    truncated = data[:3]
    with pytest.raises(ValueError):
        subtract(data, truncated)

def test_direct_arith_with_series_returns_not_implemented(self, data):
# Override to use __sub__ instead of __add__
other = pd.Series(data)
Expand Down
22 changes: 22 additions & 0 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,17 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
all_arithmetic_operators
)

def test_arith_diff_lengths(self, data, all_arithmetic_operators):
    """For float data, arithmetic with ``data[:3]`` must raise ValueError.

    When ``is_float_dtype(data)`` is false the test does nothing.
    """
    from pandas.core.dtypes.common import is_float_dtype

    if not is_float_dtype(data):
        # Non-float data is not checked by this test.
        return

    arith = self.get_op_from_name(all_arithmetic_operators)
    truncated = data[:3]
    with pytest.raises(ValueError):
        arith(data, truncated)


class TestComparisonOps(BaseSparseTests, base.BaseComparisonOpsTests):

Expand Down Expand Up @@ -348,6 +359,17 @@ def _compare_other(self, s, data, op_name, other):
result = op(s, other)
tm.assert_series_equal(result, expected)

def test_compare_diff_lengths(self, data, all_compare_operators):
    """For float data, comparing with ``data[:3]`` must raise ValueError.

    When ``is_float_dtype(data)`` is false the test does nothing.
    """
    from pandas.core.dtypes.common import is_float_dtype

    if not is_float_dtype(data):
        # Non-float data is not checked by this test.
        return

    compare = self.get_op_from_name(all_compare_operators)
    truncated = data[:3]
    with pytest.raises(ValueError):
        compare(data, truncated)


class TestPrinting(BaseSparseTests, base.BasePrintingTests):
@pytest.mark.xfail(reason='Different repr', strict=True)
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/indexes/period/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,12 @@ def test_insert(self):
result = period_range('2017Q1', periods=4, freq='Q').insert(1, na)
tm.assert_index_equal(result, expected)

def test_comp_op(self):
makbigc marked this conversation as resolved.
Show resolved Hide resolved
# GH 23078
index = period_range('2017', periods=12, freq="A-DEC")
with pytest.raises(ValueError, match="Lengths must match"):
index <= index[[0]]


def test_maybe_convert_timedelta():
pi = PeriodIndex(['2000', '2001'], freq='D')
Expand Down