From 228cfd009c0fd35b7b3c199e67d56d3ff404ba8e Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sat, 21 Jan 2023 11:07:53 +0100 Subject: [PATCH] Extend FromPyObject to support any object supporting the buffer protocol when possible. --- src/types/bytes.rs | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/types/bytes.rs b/src/types/bytes.rs index 3ed906f22d6..bc5f9190f50 100644 --- a/src/types/bytes.rs +++ b/src/types/bytes.rs @@ -5,7 +5,10 @@ use std::os::raw::c_char; use std::slice::SliceIndex; use std::str; -use super::bytearray::PyByteArray; +#[cfg(any(not(Py_LIMITED_API), Py_3_11))] +use crate::buffer::PyBuffer; +#[cfg(all(Py_LIMITED_API, not(Py_3_11)))] +use crate::types::PyByteArray; /// Represents a Python `bytes` object. /// @@ -129,7 +132,21 @@ impl> Index for PyBytes { /// If the source object is a `bytes` object, the `Cow` will be borrowed and /// pointing into the source object, and no copying or heap allocations will happen. /// If it is a `bytearray`, its contents will be copied to an owned `Cow`. +/// +/// When not using the limited API or Python 3.11 or later, this will accept any object +/// containing bytes and supporting the buffer protocol in addition to`bytearray`. impl<'source> FromPyObject<'source> for Cow<'source, [u8]> { + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok(bytes) = ob.downcast::() { + return Ok(Cow::Borrowed(bytes.as_bytes())); + } + + let buffer = PyBuffer::::get(ob)?; + buffer.to_vec(ob.py()).map(Cow::Owned) + } + + #[cfg(all(Py_LIMITED_API, not(Py_3_11)))] fn extract(ob: &'source PyAny) -> PyResult { if let Ok(bytes) = ob.downcast::() { return Ok(Cow::Borrowed(bytes.as_bytes())); @@ -215,6 +232,15 @@ mod tests { let cow = byte_array.extract::>().unwrap(); assert_eq!(cow, Cow::<[u8]>::Owned(b"foobar".to_vec())); + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + { + let array = py + .eval(r#"__import__("array").array('B', b"foobar")"#, None, None) + .unwrap(); + let cow = array.extract::>().unwrap(); + assert_eq!(cow, Cow::<[u8]>::Owned(b"foobar".to_vec())); + } + let something_else_entirely = py.eval("42", None, None).unwrap(); something_else_entirely .extract::>()