From aff86e704dabecbf99edd1e0ad62c216819dbc15 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Wed, 15 Nov 2023 13:18:45 -0500 Subject: [PATCH] Implement Arrow PyCapsule Interface (#5070) * arrow ffi array copy * remove copy_ffi_array * docstring * wip: pycapsule support * return * Update arrow/src/pyarrow.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * remove sync impl * Update arrow/src/pyarrow.rs Co-authored-by: Will Jones * Remove copy() * Need &mut FFI_ArrowArray for std::mem::replace * Use std::ptr::replace * update comments * Minimize unsafe block * revert pub release functions * Add RecordBatch and Stream conversion * fix returns * Fix return type * Fix name * fix ci * Add tests * Add table test * skip if pre pyarrow 14 * bump python version in CI to use pyarrow 14 * Add record batch test * Update arrow/src/pyarrow.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * run on pyarrow 13 and 14 * Update .github/workflows/integration.yml Co-authored-by: Will Jones --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Co-authored-by: Will Jones --- .github/workflows/integration.yml | 6 +- arrow-pyarrow-integration-testing/README.md | 2 + .../tests/test_sql.py | 138 +++++++++++++++++- arrow-schema/src/ffi.rs | 2 + arrow/src/pyarrow.rs | 134 ++++++++++++++++- 5 files changed, 274 insertions(+), 8 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 6e2b4420408a..f939a6a13b58 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -106,6 +106,8 @@ jobs: strategy: matrix: rust: [ stable ] + # PyArrow 13 was the last version prior to introduction to Arrow PyCapsules + pyarrow: [ "13", "14" ] steps: - uses: actions/checkout@v4 with: @@ -128,14 +130,14 @@ jobs: key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - uses: actions/setup-python@v4 with: - python-version: '3.7' + python-version: '3.8' - name: Upgrade pip and setuptools run: pip install --upgrade pip setuptools wheel virtualenv - name: Create virtualenv and install dependencies run: | virtualenv venv source venv/bin/activate - pip install maturin toml pytest pytz pyarrow>=5.0 + pip install maturin toml pytest pytz pyarrow==${{ matrix.pyarrow }} - name: Run Rust tests run: | source venv/bin/activate diff --git a/arrow-pyarrow-integration-testing/README.md b/arrow-pyarrow-integration-testing/README.md index e63953ad7900..5ca2ea76b88c 100644 --- a/arrow-pyarrow-integration-testing/README.md +++ b/arrow-pyarrow-integration-testing/README.md @@ -25,6 +25,7 @@ Note that this crate uses two languages and an external ABI: * `Rust` * `Python` * C ABI privately exposed by `Pyarrow`. +* PyCapsule ABI publicly exposed by `pyarrow` ## Basic idea @@ -36,6 +37,7 @@ we can use pyarrow's interface to move pointers from and to Rust. ## Relevant literature * [Arrow's CDataInterface](https://arrow.apache.org/docs/format/CDataInterface.html) +* [Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html) * [Rust's FFI](https://doc.rust-lang.org/nomicon/ffi.html) * [Pyarrow private binds](https://github.com/apache/arrow/blob/ae1d24efcc3f1ac2a876d8d9f544a34eb04ae874/python/pyarrow/array.pxi#L1226) * [PyO3](https://docs.rs/pyo3/0.12.1/pyo3/index.html) diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 1748fd3ffb6b..16d4e0f12f88 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -27,6 +27,8 @@ import arrow_pyarrow_integration_testing as rust +PYARROW_PRE_14 = int(pa.__version__.split('.')[0]) < 14 + @contextlib.contextmanager def no_pyarrow_leak(): @@ -113,6 +115,34 @@ def assert_pyarrow_leak(): _unsupported_pyarrow_types = [ ] +# As of pyarrow 14, pyarrow implements the Arrow PyCapsule interface +# (https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html). +# This defines that Arrow consumers should allow any object that has specific "dunder" +# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle +# _any_ class, without pyarrow-specific handling. +class SchemaWrapper: + def __init__(self, schema): + self.schema = schema + + def __arrow_c_schema__(self): + return self.schema.__arrow_c_schema__() + + +class ArrayWrapper: + def __init__(self, array): + self.array = array + + def __arrow_c_array__(self): + return self.array.__arrow_c_array__() + + +class StreamWrapper: + def __init__(self, stream): + self.stream = stream + + def __arrow_c_stream__(self): + return self.stream.__arrow_c_stream__() + @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) def test_type_roundtrip(pyarrow_type): @@ -120,6 +150,14 @@ def test_type_roundtrip(pyarrow_type): assert restored == pyarrow_type assert restored is not pyarrow_type +@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") +@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) +def test_type_roundtrip_pycapsule(pyarrow_type): + wrapped = SchemaWrapper(pyarrow_type) + restored = rust.round_trip_type(wrapped) + assert restored == pyarrow_type + assert restored is not pyarrow_type + @pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str) def test_type_roundtrip_raises(pyarrow_type): @@ -138,6 +176,20 @@ def test_field_roundtrip(pyarrow_type): field = rust.round_trip_field(pyarrow_field) assert field == pyarrow_field +@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") +@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) +def test_field_roundtrip_pycapsule(pyarrow_type): + pyarrow_field = pa.field("test", pyarrow_type, nullable=True) + wrapped = SchemaWrapper(pyarrow_field) + field = rust.round_trip_field(wrapped) + assert field == wrapped.schema + + if pyarrow_type != pa.null(): + # A null type field may not be non-nullable + pyarrow_field = pa.field("test", pyarrow_type, nullable=False) + field = rust.round_trip_field(wrapped) + assert field == wrapped.schema + def test_field_metadata_roundtrip(): metadata = {"hello": "World! 😊", "x": "2"} pyarrow_field = pa.field("test", pa.int32(), metadata=metadata) @@ -163,6 +215,17 @@ def test_primitive_python(): del b +@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") +def test_primitive_python_pycapsule(): + """ + Python -> Rust -> Python + """ + a = pa.array([1, 2, 3]) + wrapped = ArrayWrapper(a) + b = rust.double(wrapped) + assert b == pa.array([2, 4, 6]) + + def test_primitive_rust(): """ Rust -> Python -> Rust @@ -433,6 +496,33 @@ def test_record_batch_reader(): got_batches = list(b) assert got_batches == batches +@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") +def test_record_batch_reader_pycapsule(): + """ + Python -> Rust -> Python + """ + schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) + batches = [ + pa.record_batch([[[1], [2, 42]]], schema), + pa.record_batch([[None, [], [5, 6]]], schema), + ] + a = pa.RecordBatchReader.from_batches(schema, batches) + wrapped = StreamWrapper(a) + b = rust.round_trip_record_batch_reader(wrapped) + + assert b.schema == schema + got_batches = list(b) + assert got_batches == batches + + # Also try the boxed reader variant + a = pa.RecordBatchReader.from_batches(schema, batches) + wrapped = StreamWrapper(a) + b = rust.boxed_reader_roundtrip(wrapped) + assert b.schema == schema + got_batches = list(b) + assert got_batches == batches + + def test_record_batch_reader_error(): schema = pa.schema([('ints', pa.list_(pa.int32()))]) @@ -453,24 +543,64 @@ def iter_batches(): with pytest.raises(ValueError, match="invalid utf-8"): rust.round_trip_record_batch_reader(reader) + +@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") +def test_record_batch_pycapsule(): + """ + Python -> Rust -> Python + """ + schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) + batch = pa.record_batch([[[1], [2, 42]]], schema) + wrapped = StreamWrapper(batch) + b = rust.round_trip_record_batch_reader(wrapped) + new_table = b.read_all() + new_batches = new_table.to_batches() + + assert len(new_batches) == 1 + new_batch = new_batches[0] + + assert batch == new_batch + assert batch.schema == new_batch.schema + + +@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") +def test_table_pycapsule(): + """ + Python -> Rust -> Python + """ + schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) + batches = [ + pa.record_batch([[[1], [2, 42]]], schema), + pa.record_batch([[None, [], [5, 6]]], schema), + ] + table = pa.Table.from_batches(batches) + wrapped = StreamWrapper(table) + b = rust.round_trip_record_batch_reader(wrapped) + new_table = b.read_all() + + assert table.schema == new_table.schema + assert table == new_table + assert len(table.to_batches()) == len(new_table.to_batches()) + + def test_reject_other_classes(): # Arbitrary type that is not a PyArrow type not_pyarrow = ["hello"] with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Array, got builtins.list"): rust.round_trip_array(not_pyarrow) - + with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Schema, got builtins.list"): rust.round_trip_schema(not_pyarrow) - + with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Field, got builtins.list"): rust.round_trip_field(not_pyarrow) - + with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.DataType, got builtins.list"): rust.round_trip_type(not_pyarrow) with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatch, got builtins.list"): rust.round_trip_record_batch(not_pyarrow) - + with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatchReader, got builtins.list"): rust.round_trip_record_batch_reader(not_pyarrow) diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs index 7e33a78fec27..640a7de79878 100644 --- a/arrow-schema/src/ffi.rs +++ b/arrow-schema/src/ffi.rs @@ -351,6 +351,8 @@ impl Drop for FFI_ArrowSchema { } } +unsafe impl Send for FFI_ArrowSchema {} + impl TryFrom<&FFI_ArrowSchema> for DataType { type Error = ArrowError; diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 517c333addde..4d262b0d106f 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -59,12 +59,12 @@ use std::convert::{From, TryFrom}; use std::ptr::{addr_of, addr_of_mut}; use std::sync::Arc; -use arrow_array::{RecordBatchIterator, RecordBatchReader}; +use arrow_array::{RecordBatchIterator, RecordBatchReader, StructArray}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::ffi::Py_uintptr_t; use pyo3::import_exception; use pyo3::prelude::*; -use pyo3::types::{PyList, PyTuple}; +use pyo3::types::{PyCapsule, PyList, PyTuple}; use crate::array::{make_array, ArrayData}; use crate::datatypes::{DataType, Field, Schema}; @@ -118,8 +118,40 @@ fn validate_class(expected: &str, value: &PyAny) -> PyResult<()> { Ok(()) } +fn validate_pycapsule(capsule: &PyCapsule, name: &str) -> PyResult<()> { + let capsule_name = capsule.name()?; + if capsule_name.is_none() { + return Err(PyValueError::new_err( + "Expected schema PyCapsule to have name set.", + )); + } + + let capsule_name = capsule_name.unwrap().to_str()?; + if capsule_name != name { + return Err(PyValueError::new_err(format!( + "Expected name '{}' in PyCapsule, instead got '{}'", + name, capsule_name + ))); + } + + Ok(()) +} + impl FromPyArrow for DataType { fn from_pyarrow(value: &PyAny) -> PyResult { + // Newer versions of PyArrow as well as other libraries with Arrow data implement this + // method, so prefer it over _export_to_c. + // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + if value.hasattr("__arrow_c_schema__")? { + let capsule: &PyCapsule = + PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?; + validate_pycapsule(capsule, "arrow_schema")?; + + let schema_ptr = unsafe { capsule.reference::() }; + let dtype = DataType::try_from(schema_ptr).map_err(to_py_err)?; + return Ok(dtype); + } + validate_class("DataType", value)?; let c_schema = FFI_ArrowSchema::empty(); @@ -143,6 +175,19 @@ impl ToPyArrow for DataType { impl FromPyArrow for Field { fn from_pyarrow(value: &PyAny) -> PyResult { + // Newer versions of PyArrow as well as other libraries with Arrow data implement this + // method, so prefer it over _export_to_c. + // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + if value.hasattr("__arrow_c_schema__")? { + let capsule: &PyCapsule = + PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?; + validate_pycapsule(capsule, "arrow_schema")?; + + let schema_ptr = unsafe { capsule.reference::() }; + let field = Field::try_from(schema_ptr).map_err(to_py_err)?; + return Ok(field); + } + validate_class("Field", value)?; let c_schema = FFI_ArrowSchema::empty(); @@ -166,6 +211,19 @@ impl ToPyArrow for Field { impl FromPyArrow for Schema { fn from_pyarrow(value: &PyAny) -> PyResult { + // Newer versions of PyArrow as well as other libraries with Arrow data implement this + // method, so prefer it over _export_to_c. + // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + if value.hasattr("__arrow_c_schema__")? { + let capsule: &PyCapsule = + PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?; + validate_pycapsule(capsule, "arrow_schema")?; + + let schema_ptr = unsafe { capsule.reference::() }; + let schema = Schema::try_from(schema_ptr).map_err(to_py_err)?; + return Ok(schema); + } + validate_class("Schema", value)?; let c_schema = FFI_ArrowSchema::empty(); @@ -189,6 +247,30 @@ impl ToPyArrow for Schema { impl FromPyArrow for ArrayData { fn from_pyarrow(value: &PyAny) -> PyResult { + // Newer versions of PyArrow as well as other libraries with Arrow data implement this + // method, so prefer it over _export_to_c. + // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + if value.hasattr("__arrow_c_array__")? { + let tuple = value.getattr("__arrow_c_array__")?.call0()?; + + if !tuple.is_instance_of::() { + return Err(PyTypeError::new_err( + "Expected __arrow_c_array__ to return a tuple.", + )); + } + + let schema_capsule: &PyCapsule = PyTryInto::try_into(tuple.get_item(0)?)?; + let array_capsule: &PyCapsule = PyTryInto::try_into(tuple.get_item(1)?)?; + + validate_pycapsule(schema_capsule, "arrow_schema")?; + validate_pycapsule(array_capsule, "arrow_array")?; + + let schema_ptr = unsafe { schema_capsule.reference::() }; + let array_ptr = array_capsule.pointer() as *mut FFI_ArrowArray; + let array = unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::empty()) }; + return ffi::from_ffi(array, schema_ptr).map_err(to_py_err); + } + validate_class("Array", value)?; // prepare a pointer to receive the Array struct @@ -247,6 +329,37 @@ impl ToPyArrow for Vec { impl FromPyArrow for RecordBatch { fn from_pyarrow(value: &PyAny) -> PyResult { + // Newer versions of PyArrow as well as other libraries with Arrow data implement this + // method, so prefer it over _export_to_c. + // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + if value.hasattr("__arrow_c_array__")? { + let tuple = value.getattr("__arrow_c_array__")?.call0()?; + + if !tuple.is_instance_of::() { + return Err(PyTypeError::new_err( + "Expected __arrow_c_array__ to return a tuple.", + )); + } + + let schema_capsule: &PyCapsule = PyTryInto::try_into(tuple.get_item(0)?)?; + let array_capsule: &PyCapsule = PyTryInto::try_into(tuple.get_item(1)?)?; + + validate_pycapsule(schema_capsule, "arrow_schema")?; + validate_pycapsule(array_capsule, "arrow_array")?; + + let schema_ptr = unsafe { schema_capsule.reference::() }; + let array_ptr = array_capsule.pointer() as *mut FFI_ArrowArray; + let ffi_array = unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::empty()) }; + let array_data = ffi::from_ffi(ffi_array, schema_ptr).map_err(to_py_err)?; + if !matches!(array_data.data_type(), DataType::Struct(_)) { + return Err(PyTypeError::new_err( + "Expected Struct type from __arrow_c_array.", + )); + } + let array = StructArray::from(array_data); + return Ok(array.into()); + } + validate_class("RecordBatch", value)?; // TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches let schema = value.getattr("schema")?; @@ -276,6 +389,23 @@ impl ToPyArrow for RecordBatch { /// Supports conversion from `pyarrow.RecordBatchReader` to [ArrowArrayStreamReader]. impl FromPyArrow for ArrowArrayStreamReader { fn from_pyarrow(value: &PyAny) -> PyResult { + // Newer versions of PyArrow as well as other libraries with Arrow data implement this + // method, so prefer it over _export_to_c. + // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + if value.hasattr("__arrow_c_stream__")? { + let capsule: &PyCapsule = + PyTryInto::try_into(value.getattr("__arrow_c_stream__")?.call0()?)?; + validate_pycapsule(capsule, "arrow_array_stream")?; + + let stream_ptr = capsule.pointer() as *mut FFI_ArrowArrayStream; + let stream = unsafe { std::ptr::replace(stream_ptr, FFI_ArrowArrayStream::empty()) }; + + let stream_reader = ArrowArrayStreamReader::try_new(stream) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + return Ok(stream_reader); + } + validate_class("RecordBatchReader", value)?; // prepare a pointer to receive the stream struct