diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 9422434a1d998..da249cb3592f4 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -160,9 +160,18 @@ your ``MyExtensionArray`` class, as follows: MyExtensionArray._add_arithmetic_ops() MyExtensionArray._add_comparison_ops() -Note that since ``pandas`` automatically calls the underlying operator on each -element one-by-one, this might not be as performant as implementing your own -version of the associated operators directly on the ``ExtensionArray``. + +.. note:: + + Since ``pandas`` automatically calls the underlying operator on each + element one-by-one, this might not be as performant as implementing your own + version of the associated operators directly on the ``ExtensionArray``. + +This implementation will try to reconstruct a new ``ExtensionArray`` with the +result of the element-wise operation. Whether or not that succeeds depends on +whether the operation returns a result that's valid for the ``ExtensionArray``. +If an ``ExtensionArray`` cannot be reconstructed, a list containing the scalars +returned instead. .. _extending.extension.testing: diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 7bf13fb2fecc0..5a18a1049ee9c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -775,10 +775,18 @@ def convert_values(param): res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] if coerce_to_dtype: - try: - res = self._from_sequence(res) - except TypeError: - pass + if op.__name__ in {'divmod', 'rdivmod'}: + try: + a, b = zip(*res) + res = (self._from_sequence(a), + self._from_sequence(b)) + except TypeError: + pass + else: + try: + res = self._from_sequence(res) + except TypeError: + pass return res diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 387942234e6fd..f324cc2e0f345 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -138,5 +138,9 @@ def _concat_same_type(cls, to_concat): return cls(np.concatenate([x._data for x in to_concat])) +def to_decimal(values, context=None): + return DecimalArray([decimal.Decimal(x) for x in values], context=context) + + DecimalArray._add_arithmetic_ops() DecimalArray._add_comparison_ops() diff --git a/pandas/tests/extension/test_ops.py b/pandas/tests/extension/test_ops.py new file mode 100644 index 0000000000000..931194b474af4 --- /dev/null +++ b/pandas/tests/extension/test_ops.py @@ -0,0 +1,22 @@ +import pytest + +from pandas.tests.extension.decimal.array import to_decimal +import pandas.util.testing as tm + + +@pytest.mark.parametrize("reverse, expected_div, expected_mod", [ + (False, [0, 1, 1, 2], [1, 0, 1, 0]), + (True, [2, 1, 0, 0], [0, 0, 2, 2]), +]) +def test_divmod(reverse, expected_div, expected_mod): + # https://github.com/pandas-dev/pandas/issues/22930 + arr = to_decimal([1, 2, 3, 4]) + if reverse: + div, mod = divmod(2, arr) + else: + div, mod = divmod(arr, 2) + expected_div = to_decimal(expected_div) + expected_mod = to_decimal(expected_mod) + + tm.assert_extension_array_equal(div, expected_div) + tm.assert_extension_array_equal(mod, expected_mod)