Skip to content

Commit

Permalink
Support zero column RecordBatches in pyarrow integration (use RecordBatchOptions when converting a pyarrow RecordBatch) (#6320)
Browse files Browse the repository at this point in the history

* use RecordBatchOptions when converting a pyarrow RecordBatch

Ref: #6318

* add assertion that num_rows persists through the round trip

* add implementation comment

* nicer creation of empty recordbatch in test_empty_recordbatch_with_row_count

* use len provided by pycapsule interface when available

* update test comment
  • Loading branch information
Michael-J-Ward authored Aug 31, 2024
1 parent 6e50503 commit 0c15191
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
23 changes: 23 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,29 @@ def test_tensor_array():

del b


def test_empty_recordbatch_with_row_count():
    """
    Round-trip a pyarrow.RecordBatch with no columns but with `num_rows` set.

    `datafusion-python` gets this as the result of a `count(*)` query, so the
    Rust side must preserve the row count even though no column carries it.
    """

    # Build a 4-row batch, then drop every column; num_rows survives the
    # select, leaving a zero-column batch whose only payload is its row count.
    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}).select([])
    num_rows = 4
    assert batch.num_rows == num_rows
    assert batch.num_columns == 0

    # Python -> Rust -> Python round trip via the integration-test helper.
    b = rust.round_trip_record_batch(batch)
    assert b == batch
    assert b.schema == batch.schema
    assert b.schema.metadata == batch.schema.metadata

    # The key regression check: the row count must persist through the
    # round trip (requires RecordBatchOptions::with_row_count on the Rust side).
    assert b.num_rows == batch.num_rows

    del b

def test_record_batch_reader():
"""
Python -> Rust -> Python
Expand Down
14 changes: 11 additions & 3 deletions arrow/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader, StructArray};
use arrow_array::{RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
Expand Down Expand Up @@ -361,6 +361,7 @@ impl FromPyArrow for RecordBatch {
"Expected Struct type from __arrow_c_array.",
));
}
let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
let array = StructArray::from(array_data);
// StructArray does not embed metadata from schema. We need to override
// the output schema with the schema from the capsule.
Expand All @@ -371,7 +372,7 @@ impl FromPyArrow for RecordBatch {
0,
"Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
);
return RecordBatch::try_new(schema, columns).map_err(to_py_err);
return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
}

validate_class("RecordBatch", value)?;
Expand All @@ -386,7 +387,14 @@ impl FromPyArrow for RecordBatch {
.map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
.collect::<PyResult<_>>()?;

let batch = RecordBatch::try_new(schema, arrays).map_err(to_py_err)?;
let row_count = value
.getattr("num_rows")
.ok()
.and_then(|x| x.extract().ok());
let options = RecordBatchOptions::default().with_row_count(row_count);

let batch =
RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
Ok(batch)
}
}
Expand Down

0 comments on commit 0c15191

Please sign in to comment.