Skip to content

Commit

Permalink
use RecordBatchOptions when converting a pyarrow RecordBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-J-Ward committed Aug 28, 2024
1 parent a937869 commit f8d417f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
27 changes: 27 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,33 @@ def test_tensor_array():

del b


def test_empty_recordbatch_with_row_count():
"""
The result of a `count` on a dataset is a RecordBatch with no columns but with `num_rows` set
"""

# If you know how to create an empty RecordBatch with a specific number of rows, please share
# Create an empty schema with no fields
schema = pa.schema([])

# Create an empty RecordBatch with 0 columns
record_batch = pa.RecordBatch.from_arrays([], schema=schema)

# Set the desired number of rows by creating a table and slicing
num_rows = 5 # Replace with your desired number of rows
empty_table = pa.Table.from_batches([record_batch]).slice(0, num_rows)

# Get the first batch from the table which will have the desired number of rows
batch = empty_table.to_batches()[0]

b = rust.round_trip_record_batch(batch)
assert b == batch
assert b.schema == batch.schema
assert b.schema.metadata == batch.schema.metadata

del b

def test_record_batch_reader():
"""
Python -> Rust -> Python
Expand Down
14 changes: 11 additions & 3 deletions arrow/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader, StructArray};
use arrow_array::{RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
Expand Down Expand Up @@ -333,6 +333,13 @@ impl<T: ToPyArrow> ToPyArrow for Vec<T> {

impl FromPyArrow for RecordBatch {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
let row_count = value
.getattr("num_rows")
.ok()
.and_then(|x| x.extract().ok());
println!("USING row_count: {:?}", row_count);
let options = RecordBatchOptions::default().with_row_count(row_count);

// Newer versions of PyArrow as well as other libraries with Arrow data implement this
// method, so prefer it over _export_to_c.
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
Expand Down Expand Up @@ -371,7 +378,7 @@ impl FromPyArrow for RecordBatch {
0,
"Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
);
return RecordBatch::try_new(schema, columns).map_err(to_py_err);
return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
}

validate_class("RecordBatch", value)?;
Expand All @@ -386,7 +393,8 @@ impl FromPyArrow for RecordBatch {
.map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
.collect::<PyResult<_>>()?;

let batch = RecordBatch::try_new(schema, arrays).map_err(to_py_err)?;
let batch =
RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
Ok(batch)
}
}
Expand Down

0 comments on commit f8d417f

Please sign in to comment.