Skip to content
/ polars Public
forked from pola-rs/polars

Commit

Permalink
fix: Fix map_elements for List return dtypes (pola-rs#18567)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Sep 5, 2024
1 parent f562b13 commit c85b338
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 33 deletions.
8 changes: 5 additions & 3 deletions crates/polars-python/src/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,11 @@ fn iterator_to_list(
for _ in 0..init_null_count {
builder.append_null()
}
builder
.append_opt_series(first_value)
.map_err(PyPolarsErr::from)?;
if first_value.is_some() {
builder
.append_opt_series(first_value)
.map_err(PyPolarsErr::from)?;
}
for opt_val in it {
match opt_val {
None => builder.append_null(),
Expand Down
67 changes: 37 additions & 30 deletions crates/polars-python/src/map/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
let series = py_pyseries.extract::<PySeries>().unwrap().series;
let dt = series.dtype();
applyer
.apply_lambda_with_list_out_type(py, lambda.to_object(py), null_count, &series, dt)
.apply_lambda_with_list_out_type(
py,
lambda.to_object(py),
null_count,
Some(&series),
dt,
)
.map(|ca| ca.into_series().into())
} else if out.is_instance_of::<PyList>() || out.is_instance_of::<PyTuple>() {
let series = SERIES.call1(py, (out,))?;
Expand All @@ -63,13 +69,14 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
let new_lambda = PyCFunction::new_closure_bound(py, None, None, move |args, _kwargs| {
Python::with_gil(|py| {
let out = lambda_owned.call1(py, args)?;
// check if Series, if not, call series constructor on it
SERIES.call1(py, (out,))
})
})?
.to_object(py);

let result = applyer
.apply_lambda_with_list_out_type(py, new_lambda, null_count, &series, dt)
.apply_lambda_with_list_out_type(py, new_lambda, null_count, Some(&series), dt)
.map(|ca| ca.into_series().into());
match result {
Ok(out) => Ok(out),
Expand Down Expand Up @@ -172,7 +179,7 @@ pub trait ApplyLambda<'a> {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked>;

Expand Down Expand Up @@ -417,10 +424,10 @@ impl<'a> ApplyLambda<'a> for BooleanChunked {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let lambda = lambda.bind(py);
if init_null_count == self.len() {
Ok(ChunkedArray::full_null(self.name().clone(), self.len()))
Expand All @@ -434,7 +441,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand All @@ -449,7 +456,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down Expand Up @@ -720,10 +727,10 @@ where
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let lambda = lambda.bind(py);
if init_null_count == self.len() {
Ok(ChunkedArray::full_null(self.name().clone(), self.len()))
Expand All @@ -737,7 +744,7 @@ where
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand All @@ -752,7 +759,7 @@ where
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down Expand Up @@ -1017,10 +1024,10 @@ impl<'a> ApplyLambda<'a> for StringChunked {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let lambda = lambda.bind(py);
if init_null_count == self.len() {
Ok(ChunkedArray::full_null(self.name().clone(), self.len()))
Expand All @@ -1034,7 +1041,7 @@ impl<'a> ApplyLambda<'a> for StringChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand All @@ -1049,7 +1056,7 @@ impl<'a> ApplyLambda<'a> for StringChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down Expand Up @@ -1441,10 +1448,10 @@ impl<'a> ApplyLambda<'a> for ListChunked {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let pypolars = PyModule::import_bound(py, "polars")?;
let lambda = lambda.bind(py);
if init_null_count == self.len() {
Expand All @@ -1459,7 +1466,7 @@ impl<'a> ApplyLambda<'a> for ListChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand All @@ -1472,7 +1479,7 @@ impl<'a> ApplyLambda<'a> for ListChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down Expand Up @@ -1868,10 +1875,10 @@ impl<'a> ApplyLambda<'a> for ArrayChunked {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let pypolars = PyModule::import_bound(py, "polars")?;
let lambda = lambda.bind(py);
if init_null_count == self.len() {
Expand All @@ -1886,7 +1893,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand All @@ -1899,7 +1906,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down Expand Up @@ -2187,10 +2194,10 @@ impl<'a> ApplyLambda<'a> for ObjectChunked<ObjectValue> {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let lambda = lambda.bind(py);
if init_null_count == self.len() {
Ok(ChunkedArray::full_null(self.name().clone(), self.len()))
Expand All @@ -2204,7 +2211,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked<ObjectValue> {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand All @@ -2219,7 +2226,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked<ObjectValue> {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down Expand Up @@ -2415,10 +2422,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
py: Python,
lambda: PyObject,
init_null_count: usize,
first_value: &Series,
first_value: Option<&Series>,
dt: &DataType,
) -> PyResult<ListChunked> {
let skip = 1;
let skip = usize::from(first_value.is_some());
let lambda = lambda.bind(py);
let it = iter_struct(self)
.skip(init_null_count + skip)
Expand All @@ -2427,7 +2434,7 @@ impl<'a> ApplyLambda<'a> for StructChunked {
dt,
it,
init_null_count,
Some(first_value),
first_value,
self.name().clone(),
self.len(),
)
Expand Down
27 changes: 27 additions & 0 deletions crates/polars-python/src/series/map.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use pyo3::prelude::*;
use pyo3::types::PyCFunction;
use pyo3::Python;

use super::PySeries;
use crate::error::PyPolarsErr;
use crate::map::series::{call_lambda_and_extract, ApplyLambda};
use crate::prelude::*;
use crate::py_modules::SERIES;
use crate::{apply_method_all_arrow_series2, raise_err};

#[pymethods]
Expand Down Expand Up @@ -222,6 +224,31 @@ impl PySeries {

ca.into_series()
},
Some(DataType::List(inner)) => {
// Make sure the function returns a Series of the correct data type.
let function_owned = function.to_object(py);
let dtype_py = Wrap((*inner).clone()).to_object(py);
let function_wrapped =
PyCFunction::new_closure_bound(py, None, None, move |args, _kwargs| {
Python::with_gil(|py| {
let out = function_owned.call1(py, args)?;
SERIES.call1(py, ("", out, dtype_py.clone()))
})
})?
.to_object(py);

let ca = dispatch_apply!(
series,
apply_lambda_with_list_out_type,
py,
function_wrapped,
0,
None,
inner.as_ref()
)?;

ca.into_series()
},
#[cfg(feature = "object")]
Some(DataType::Object(_, _)) => {
let ca = dispatch_apply!(
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,15 @@ def test_map_elements_list_dtype_18472() -> None:
result = s.map_elements(lambda s: [i.strip() if i else None for i in s])
expected = pl.Series([[None], ["abc", None]])
assert_series_equal(result, expected)


def test_map_elements_list_return_dtype() -> None:
    """Check that `map_elements` honours an explicit List `return_dtype`.

    Each inner list is incremented element-wise; the result is compared
    against a Series built with the same `List(UInt16)` dtype that was
    requested via `return_dtype`.
    """
    return_dtype = pl.List(pl.UInt16)
    source = pl.Series([[1], [2, 3]])

    def add_one(values):  # type: ignore[no-untyped-def]
        # Returns a plain Python list for every input element.
        return [item + 1 for item in values]

    actual = source.map_elements(add_one, return_dtype=return_dtype)

    wanted = pl.Series([[2], [3, 4]], dtype=return_dtype)
    assert_series_equal(actual, wanted)

0 comments on commit c85b338

Please sign in to comment.