Skip to content

Commit

Permalink
Extend FromPyObject to support any object supporting the buffer proto…
Browse files Browse the repository at this point in the history
…col when possible.
  • Loading branch information
adamreichold committed Feb 8, 2023
1 parent 2dd9d1b commit 228cfd0
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion src/types/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use std::os::raw::c_char;
use std::slice::SliceIndex;
use std::str;

use super::bytearray::PyByteArray;
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
use crate::buffer::PyBuffer;
#[cfg(all(Py_LIMITED_API, not(Py_3_11)))]
use crate::types::PyByteArray;

/// Represents a Python `bytes` object.
///
Expand Down Expand Up @@ -129,7 +132,21 @@ impl<I: SliceIndex<[u8]>> Index<I> for PyBytes {
/// If the source object is a `bytes` object, the `Cow` will be borrowed and
/// pointing into the source object, and no copying or heap allocations will happen.
/// If it is a `bytearray`, its contents will be copied to an owned `Cow`.
///
/// When not using the limited API or Python 3.11 or later, this will accept any object
/// containing bytes and supporting the buffer protocol in addition to`bytearray`.
impl<'source> FromPyObject<'source> for Cow<'source, [u8]> {
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok(bytes) = ob.downcast::<PyBytes>() {
return Ok(Cow::Borrowed(bytes.as_bytes()));
}

let buffer = PyBuffer::<u8>::get(ob)?;
buffer.to_vec(ob.py()).map(Cow::Owned)
}

#[cfg(all(Py_LIMITED_API, not(Py_3_11)))]
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok(bytes) = ob.downcast::<PyBytes>() {
return Ok(Cow::Borrowed(bytes.as_bytes()));
Expand Down Expand Up @@ -215,6 +232,15 @@ mod tests {
let cow = byte_array.extract::<Cow<'_, [u8]>>().unwrap();
assert_eq!(cow, Cow::<[u8]>::Owned(b"foobar".to_vec()));

#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
{
let array = py
.eval(r#"__import__("array").array('B', b"foobar")"#, None, None)
.unwrap();
let cow = array.extract::<Cow<'_, [u8]>>().unwrap();
assert_eq!(cow, Cow::<[u8]>::Owned(b"foobar".to_vec()));
}

let something_else_entirely = py.eval("42", None, None).unwrap();
something_else_entirely
.extract::<Cow<'_, [u8]>>()
Expand Down

0 comments on commit 228cfd0

Please sign in to comment.