Skip to content

Commit

Permalink
Avoid returning PyObject and add wrapper types for arro3 export (#269)
Browse files Browse the repository at this point in the history
* wip: avoid pyobject

* Avoid PyObject

* couple more

* flesh out

* Use local pyo3-arrow

* fix compile
  • Loading branch information
kylebarron authored Dec 6, 2024
1 parent b5b8341 commit 3709db9
Show file tree
Hide file tree
Showing 35 changed files with 823 additions and 415 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ numpy = "0.23"
object_store = "0.11"
parquet = "53"
pyo3 = { version = "0.23", features = ["macros", "indexmap"] }
-pyo3-arrow = "0.6"
-# pyo3-arrow = { path = "./pyo3-arrow" }
+# pyo3-arrow = "0.6"
+pyo3-arrow = { path = "./pyo3-arrow" }
pyo3-async-runtimes = { version = "0.23", features = ["tokio-runtime"] }
pyo3-file = "0.10"
pyo3-object_store = { git = "https://github.com/developmentseed/object-store-rs", rev = "bad34862a92849dd7b69c28cd4c225446d3d15ab" }
Expand Down
19 changes: 10 additions & 9 deletions arro3-compute/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ use arrow_schema::{ArrowError, DataType};
use arrow_select::concat;
use pyo3::prelude::*;
use pyo3_arrow::error::PyArrowResult;
+use pyo3_arrow::export::Arro3Scalar;
use pyo3_arrow::input::AnyArray;
use pyo3_arrow::PyScalar;

#[pyfunction]
-pub fn max(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
+pub fn max(input: AnyArray) -> PyArrowResult<Arro3Scalar> {
match input {
AnyArray::Array(array) => {
let (array, field) = array.into_inner();
let result = max_array(array)?;
-Ok(PyScalar::try_new(result, field)?.to_arro3(py)?)
+Ok(PyScalar::try_new(result, field)?.into())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -43,7 +44,7 @@ pub fn max(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {

// Call max_array on intermediate outputs
let result = max_array(concatted)?;
-Ok(PyScalar::try_new(result, field)?.to_arro3(py)?)
+Ok(PyScalar::try_new(result, field)?.into())
}
}
}
Expand Down Expand Up @@ -112,12 +113,12 @@ fn max_boolean(array: &BooleanArray) -> ArrayRef {
}

#[pyfunction]
-pub fn min(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
+pub fn min(input: AnyArray) -> PyArrowResult<Arro3Scalar> {
match input {
AnyArray::Array(array) => {
let (array, field) = array.into_inner();
let result = min_array(array)?;
-Ok(PyScalar::try_new(result, field)?.to_arro3(py)?)
+Ok(PyScalar::try_new(result, field)?.into())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -138,7 +139,7 @@ pub fn min(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {

// Call min_array on intermediate outputs
let result = min_array(concatted)?;
-Ok(PyScalar::try_new(result, field)?.to_arro3(py)?)
+Ok(PyScalar::try_new(result, field)?.into())
}
}
}
Expand Down Expand Up @@ -207,12 +208,12 @@ fn min_boolean(array: &BooleanArray) -> ArrayRef {
}

#[pyfunction]
-pub fn sum(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
+pub fn sum(input: AnyArray) -> PyArrowResult<Arro3Scalar> {
match input {
AnyArray::Array(array) => {
let (array, field) = array.into_inner();
let result = sum_array(array)?;
-Ok(PyScalar::try_new(result, field)?.to_arro3(py)?)
+Ok(PyScalar::try_new(result, field)?.into())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -233,7 +234,7 @@ pub fn sum(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {

// Call sum_array on intermediate outputs
let result = sum_array(concatted)?;
-Ok(PyScalar::try_new(result, field)?.to_arro3(py)?)
+Ok(PyScalar::try_new(result, field)?.into())
}
}
}
Expand Down
42 changes: 32 additions & 10 deletions arro3-compute/src/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,72 @@ use pyo3_arrow::PyArray;

#[pyfunction]
pub fn add(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::add(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::add(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn add_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::add_wrapping(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::add_wrapping(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn div(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::div(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::div(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn mul(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::mul(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::mul(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn mul_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::mul_wrapping(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::mul_wrapping(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn neg(py: Python, array: PyArray) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::neg(array.as_ref())?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::neg(array.as_ref())?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn neg_wrapping(py: Python, array: PyArray) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::neg_wrapping(array.as_ref())?).to_arro3(py)?)
+Ok(
+PyArray::from_array_ref(numeric::neg_wrapping(array.as_ref())?)
+.to_arro3(py)?
+.unbind(),
+)
}

#[pyfunction]
pub fn rem(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::rem(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::rem(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn sub(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::sub(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::sub(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}

#[pyfunction]
pub fn sub_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
-Ok(PyArray::from_array_ref(numeric::sub_wrapping(&lhs, &rhs)?).to_arro3(py)?)
+Ok(PyArray::from_array_ref(numeric::sub_wrapping(&lhs, &rhs)?)
+.to_arro3(py)?
+.unbind())
}
14 changes: 10 additions & 4 deletions arro3-compute/src/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ pub fn is_null(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
match input {
AnyArray::Array(input) => {
let out = arrow::compute::is_null(input.as_ref())?;
-Ok(PyArray::from_array_ref(Arc::new(out)).to_arro3(py)?)
+Ok(PyArray::from_array_ref(Arc::new(out))
+.to_arro3(py)?
+.unbind())
}
AnyArray::Stream(input) => {
let input = input.into_reader()?;
Expand All @@ -25,7 +27,8 @@ pub fn is_null(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
});
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, out_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand All @@ -36,7 +39,9 @@ pub fn is_not_null(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
match input {
AnyArray::Array(input) => {
let out = arrow::compute::is_not_null(input.as_ref())?;
-Ok(PyArray::from_array_ref(Arc::new(out)).to_arro3(py)?)
+Ok(PyArray::from_array_ref(Arc::new(out))
+.to_arro3(py)?
+.unbind())
}
AnyArray::Stream(input) => {
let input = input.into_reader()?;
Expand All @@ -48,7 +53,8 @@ pub fn is_not_null(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
});
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, out_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand Down
8 changes: 6 additions & 2 deletions arro3-compute/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn cast(py: Python, input: AnyArray, to_type: PyField) -> PyArrowResult<PyOb
AnyArray::Array(arr) => {
let new_field = to_type.into_inner();
let out = arrow_cast::cast(arr.as_ref(), new_field.data_type())?;
-Ok(PyArray::new(out, new_field).to_arro3(py)?)
+Ok(PyArray::new(out, new_field).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -36,7 +36,11 @@ pub fn cast(py: Python, input: AnyArray, to_type: PyField) -> PyArrowResult<PyOb
let iter = reader
.into_iter()
.map(move |array| arrow_cast::cast(&array?, &to_type));
-Ok(PyArrayReader::new(Box::new(ArrayIterator::new(iter, new_field))).to_arro3(py)?)
+Ok(
+PyArrayReader::new(Box::new(ArrayIterator::new(iter, new_field)))
+.to_arro3(py)?
+.unbind(),
+)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion arro3-compute/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ pub fn concat(py: Python, input: PyChunkedArray) -> PyArrowResult<PyObject> {
let (chunks, field) = input.into_inner();
let array_refs = chunks.iter().map(|arr| arr.as_ref()).collect::<Vec<_>>();
let concatted = arrow_select::concat::concat(array_refs.as_slice())?;
-Ok(PyArray::new(concatted, field).to_arro3(py)?)
+Ok(PyArray::new(concatted, field).to_arro3(py)?.unbind())
}
5 changes: 3 additions & 2 deletions arro3-compute/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub(crate) fn dictionary_encode(py: Python, array: AnyArray) -> PyArrowResult<Py
AnyArray::Array(array) => {
let (array, _field) = array.into_inner();
let output_array = dictionary_encode_array(array)?;
-Ok(PyArray::from_array_ref(output_array).to_arro3(py)?)
+Ok(PyArray::from_array_ref(output_array).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -37,7 +37,8 @@ pub(crate) fn dictionary_encode(py: Python, array: AnyArray) -> PyArrowResult<Py
.map(move |array| dictionary_encode_array(array?));
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, output_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions arro3-compute/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn filter(py: Python, values: AnyArray, predicate: AnyArray) -> PyArrowResul
))?;

let filtered = arrow::compute::filter(values.as_ref(), predicate)?;
-Ok(PyArray::new(filtered, values_field).to_arro3(py)?)
+Ok(PyArray::new(filtered, values_field).to_arro3(py)?.unbind())
}
(AnyArray::Stream(values), AnyArray::Stream(predicate)) => {
let values = values.into_reader()?;
Expand All @@ -47,7 +47,8 @@ pub fn filter(py: Python, values: AnyArray, predicate: AnyArray) -> PyArrowResul
});
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, values_field)))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
_ => Err(PyValueError::new_err("Unsupported combination of array and stream").into()),
Expand Down
4 changes: 3 additions & 1 deletion arro3-compute/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ use pyo3_arrow::PyArray;
pub fn take(py: Python, values: PyArray, indices: PyArray) -> PyArrowResult<PyObject> {
let output_array =
py.allow_threads(|| arrow_select::take::take(values.as_ref(), indices.as_ref(), None))?;
-Ok(PyArray::new(output_array, values.field().clone()).to_arro3(py)?)
+Ok(PyArray::new(output_array, values.field().clone())
+.to_arro3(py)?
+.unbind())
}
5 changes: 3 additions & 2 deletions arro3-compute/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub fn date_part(py: Python, input: AnyArray, part: DatePart) -> PyArrowResult<P
match input {
AnyArray::Array(input) => {
let out = arrow::compute::date_part(input.as_ref(), part.into())?;
-Ok(PyArray::from_array_ref(out).to_arro3(py)?)
+Ok(PyArray::from_array_ref(out).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -98,7 +98,8 @@ pub fn date_part(py: Python, input: AnyArray, part: DatePart) -> PyArrowResult<P
.map(move |array| arrow::compute::date_part(array?.as_ref(), part));
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, output_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand Down
10 changes: 6 additions & 4 deletions arro3-core/src/accessors/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub(crate) fn dictionary_indices(py: Python, array: AnyArray) -> PyArrowResult<P
AnyArray::Array(array) => {
let (array, _field) = array.into_inner();
let output_array = _dictionary_indices(array)?;
-Ok(PyArray::from_array_ref(output_array).to_arro3(py)?)
+Ok(PyArray::from_array_ref(output_array).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -34,7 +34,8 @@ pub(crate) fn dictionary_indices(py: Python, array: AnyArray) -> PyArrowResult<P
.map(move |array| _dictionary_indices(array?));
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, out_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand All @@ -49,7 +50,7 @@ pub(crate) fn dictionary_dictionary(py: Python, array: AnyArray) -> PyArrowResul
AnyArray::Array(array) => {
let (array, _field) = array.into_inner();
let output_array = _dictionary_dictionary(array)?;
-Ok(PyArray::from_array_ref(output_array).to_arro3(py)?)
+Ok(PyArray::from_array_ref(output_array).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -70,7 +71,8 @@ pub(crate) fn dictionary_dictionary(py: Python, array: AnyArray) -> PyArrowResul
.map(move |array| _dictionary_dictionary(array?));
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, out_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions arro3-core/src/accessors/list_flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub fn list_flatten(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
let (array, field) = array.into_inner();
let flat_array = flatten_array(array)?;
let flat_field = flatten_field(field)?;
-Ok(PyArray::new(flat_array, flat_field).to_arro3(py)?)
+Ok(PyArray::new(flat_array, flat_field).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -26,7 +26,8 @@ pub fn list_flatten(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
});
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, flatten_field)))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions arro3-core/src/accessors/list_offsets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn list_offsets(py: Python, input: AnyArray, logical: bool) -> PyArrowResult
AnyArray::Array(array) => {
let (array, _field) = array.into_inner();
let offsets = _list_offsets(array, logical)?;
-Ok(PyArray::from_array_ref(offsets).to_arro3(py)?)
+Ok(PyArray::from_array_ref(offsets).to_arro3(py)?.unbind())
}
AnyArray::Stream(stream) => {
let reader = stream.into_reader()?;
Expand All @@ -36,7 +36,8 @@ pub fn list_offsets(py: Python, input: AnyArray, logical: bool) -> PyArrowResult
.map(move |array| _list_offsets(array?, logical));
Ok(
PyArrayReader::new(Box::new(ArrayIterator::new(iter, out_field.into())))
-.to_arro3(py)?,
+.to_arro3(py)?
+.unbind(),
)
}
}
Expand Down
Loading

0 comments on commit 3709db9

Please sign in to comment.