Skip to content

Commit

Permalink
ENH: Implement interpolation for arrow and masked dtypes (#56757)
Browse files Browse the repository at this point in the history
* ENH: Implement interpolation for arrow and masked dtypes

* Fixup

* Fix typing

* Update
  • Loading branch information
phofl authored Jan 10, 2024
1 parent fce520d commit 5fc2ed2
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ Other enhancements
- :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`)
- :meth:`Series.ffill`, :meth:`Series.bfill`, :meth:`DataFrame.ffill`, and :meth:`DataFrame.bfill` have gained the argument ``limit_area``; 3rd party :class:`.ExtensionArray` authors need to add this argument to the method ``_pad_or_backfill`` (:issue:`56492`)
- Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`)
- Implement :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` for :class:`ArrowDtype` and masked dtypes (:issue:`56267`)
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
- Implemented :meth:`Series.dt` methods and attributes for :class:`ArrowDtype` with ``pyarrow.duration`` type (:issue:`52284`)
- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`)
Expand Down
40 changes: 40 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def floordiv_compat(
AxisInt,
Dtype,
FillnaOptions,
InterpolateOptions,
Iterator,
NpDtype,
NumpySorter,
Expand Down Expand Up @@ -2068,6 +2069,45 @@ def _maybe_convert_setitem_value(self, value):
raise TypeError(msg) from err
return value

def interpolate(
    self,
    *,
    method: InterpolateOptions,
    axis: int,
    index,
    limit,
    limit_direction,
    limit_area,
    copy: bool,
    **kwargs,
) -> Self:
    """
    Fill missing values by interpolation; see NDFrame.interpolate.__doc__.

    Floating and (un)signed integer dtypes are supported; every other
    dtype kind raises ``NotImplementedError``. ``axis`` and ``copy`` are
    accepted but not consulted here: the helper is always called with
    ``axis=0`` and the result is always a newly built ``type(self)``.
    """
    # Positions that are currently missing. The array is handed to the
    # interpolation helper as ``mask`` and reused below when boxing the
    # result, so the helper is expected to update it in place.
    na_mask = self.isna()

    kind = self.dtype.kind
    if kind == "f":
        values = self._pa_array.to_numpy()
    elif kind in "iu":
        # Integers are widened to float64; nulls get a 0.0 placeholder
        # that stays hidden behind ``na_mask`` unless it is filled.
        values = self.to_numpy(dtype="f8", na_value=0.0)
    else:
        raise NotImplementedError(
            f"interpolate is not implemented for dtype={self.dtype}"
        )

    missing.interpolate_2d_inplace(
        values,
        mask=na_mask,
        method=method,
        axis=0,
        index=index,
        limit=limit,
        limit_direction=limit_direction,
        limit_area=limit_area,
        **kwargs,
    )

    boxed = self._box_pa_array(pa.array(values, mask=na_mask))
    return type(self)(boxed)

@classmethod
def _if_else(
cls,
Expand Down
54 changes: 54 additions & 0 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AxisInt,
DtypeObj,
FillnaOptions,
InterpolateOptions,
NpDtype,
PositionalIndexer,
Scalar,
Expand Down Expand Up @@ -99,6 +100,7 @@
NumpyValueArrayLike,
)
from pandas._libs.missing import NAType
from pandas.core.arrays import FloatingArray

from pandas.compat.numpy import function as nv

Expand Down Expand Up @@ -1521,6 +1523,58 @@ def all(
else:
return self.dtype.na_value

def interpolate(
self,
*,
method: InterpolateOptions,
axis: int,
index,
limit,
limit_direction,
limit_area,
copy: bool,
**kwargs,
) -> FloatingArray:
"""
See NDFrame.interpolate.__doc__.
"""
# NB: we return type(self) even if copy=False
if self.dtype.kind == "f":
if copy:
data = self._data.copy()
mask = self._mask.copy()
else:
data = self._data
mask = self._mask
elif self.dtype.kind in "iu":
copy = True
data = self._data.astype("f8")
mask = self._mask.copy()
else:
raise NotImplementedError(
f"interpolate is not implemented for dtype={self.dtype}"
)

missing.interpolate_2d_inplace(
data,
method=method,
axis=0,
index=index,
limit=limit,
limit_direction=limit_direction,
limit_area=limit_area,
mask=mask,
**kwargs,
)
if not copy:
return self # type: ignore[return-value]
if self.dtype.kind == "f":
return type(self)._simple_new(data, mask) # type: ignore[return-value]
else:
from pandas.core.arrays import FloatingArray

return FloatingArray._simple_new(data, mask)

def _accumulate(
self, name: str, *, skipna: bool = True, **kwargs
) -> BaseMaskedArray:
Expand Down
14 changes: 11 additions & 3 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def interpolate_2d_inplace(
limit_direction: str = "forward",
limit_area: str | None = None,
fill_value: Any | None = None,
mask=None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -396,6 +397,7 @@ def func(yvalues: np.ndarray) -> None:
limit_area=limit_area_validated,
fill_value=fill_value,
bounds_error=False,
mask=mask,
**kwargs,
)

Expand Down Expand Up @@ -440,6 +442,7 @@ def _interpolate_1d(
fill_value: Any | None = None,
bounds_error: bool = False,
order: int | None = None,
mask=None,
**kwargs,
) -> None:
"""
Expand All @@ -453,8 +456,10 @@ def _interpolate_1d(
-----
Fills 'yvalues' in-place.
"""

invalid = isna(yvalues)
if mask is not None:
invalid = mask
else:
invalid = isna(yvalues)
valid = ~invalid

if not valid.any():
Expand Down Expand Up @@ -531,7 +536,10 @@ def _interpolate_1d(
**kwargs,
)

if is_datetimelike:
if mask is not None:
mask[:] = False
mask[preserve_nans] = True
elif is_datetimelike:
yvalues[preserve_nans] = NaT.value
else:
yvalues[preserve_nans] = np.nan
Expand Down
41 changes: 37 additions & 4 deletions pandas/tests/frame/methods/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,41 @@ def test_interpolate_empty_df(self):
assert result is None
tm.assert_frame_equal(df, expected)

def test_interpolate_ea_raise(self):
def test_interpolate_ea(self, any_int_ea_dtype):
# GH#55347
df = DataFrame({"a": [1, None, 2]}, dtype="Int64")
with pytest.raises(NotImplementedError, match="does not implement"):
df.interpolate()
df = DataFrame({"a": [1, None, None, None, 3]}, dtype=any_int_ea_dtype)
orig = df.copy()
result = df.interpolate(limit=2)
expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="Float64")
tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(df, orig)

@pytest.mark.parametrize(
    "dtype",
    [
        "Float64",
        "Float32",
        pytest.param("float32[pyarrow]", marks=td.skip_if_no("pyarrow")),
        pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
    ],
)
def test_interpolate_ea_float(self, dtype):
    """Float extension dtypes interpolate without changing dtype or input."""
    # GH#55347
    frame = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype)
    snapshot = frame.copy()

    interpolated = frame.interpolate(limit=2)

    # limit=2: only the first two of the three consecutive gaps are filled
    want = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype=dtype)
    tm.assert_frame_equal(interpolated, want)
    # the frame that interpolate() was called on must be unchanged
    tm.assert_frame_equal(frame, snapshot)

@pytest.mark.parametrize(
    "dtype",
    ["int64", "uint64", "int32", "int16", "int8", "uint32", "uint16", "uint8"],
)
def test_interpolate_arrow(self, dtype):
    """Arrow-backed integer columns interpolate to ``float64[pyarrow]``."""
    # GH#55347
    pytest.importorskip("pyarrow")
    arrow_dtype = dtype + "[pyarrow]"
    frame = DataFrame({"a": [1, None, None, None, 3]}, dtype=arrow_dtype)

    interpolated = frame.interpolate(limit=2)

    # limit=2: only the first two of the three consecutive gaps are filled
    want = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="float64[pyarrow]")
    tm.assert_frame_equal(interpolated, want)

0 comments on commit 5fc2ed2

Please sign in to comment.