Skip to content

Commit

Permalink
Also replace IterANextOutput by autoref-based specialization to allow…
Browse files Browse the repository at this point in the history
… returning arbitrary values
  • Loading branch information
adamreichold committed Dec 19, 2023
1 parent 2d28af6 commit 874ab10
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 30 deletions.
5 changes: 3 additions & 2 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,9 @@ const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc")
);
const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc");
const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc");
const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion(
TokenGenerator(|| quote! { _pyo3::class::pyasync::IterANextOutput::<_, _> }),
const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_specialized_conversion(
TokenGenerator(|| quote! { AsyncIterBaseKind, AsyncIterOptionKind, AsyncIterResultOptionKind }),
TokenGenerator(|| quote! { async_iter_tag }),
);
const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT);
const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc")
Expand Down
85 changes: 84 additions & 1 deletion src/callback.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Utilities for a Python callable object that invokes a Rust function.

use crate::err::{PyErr, PyResult};
use crate::exceptions::PyOverflowError;
use crate::exceptions::{PyOverflowError, PyStopAsyncIteration};
use crate::ffi::{self, Py_hash_t};
use crate::{IntoPy, PyObject, Python};
use std::isize;
Expand Down Expand Up @@ -260,3 +260,86 @@ pub trait IterResultOptionKind {
}

impl<Value> IterResultOptionKind for PyResult<Option<Value>> {}

// Autoref-based specialization for handling `__anext__` returning `Option`

#[doc(hidden)]
pub struct AsyncIterBaseTag;

impl AsyncIterBaseTag {
#[inline]
pub fn convert<Value, Target>(self, py: Python<'_>, value: Value) -> PyResult<Target>
where
Value: IntoPyCallbackOutput<Target>,
{
value.convert(py)
}
}

#[doc(hidden)]
pub trait AsyncIterBaseKind {
fn async_iter_tag(&self) -> AsyncIterBaseTag {
AsyncIterBaseTag
}
}

impl<Value> AsyncIterBaseKind for &Value {}

#[doc(hidden)]
pub struct AsyncIterOptionTag;

impl AsyncIterOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: Option<Value>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Some(value) => value.convert(py),
None => Err(PyStopAsyncIteration::new_err(())),
}
}
}

#[doc(hidden)]
pub trait AsyncIterOptionKind {
fn async_iter_tag(&self) -> AsyncIterOptionTag {
AsyncIterOptionTag
}
}

impl<Value> AsyncIterOptionKind for Option<Value> {}

#[doc(hidden)]
pub struct AsyncIterResultOptionTag;

impl AsyncIterResultOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: PyResult<Option<Value>>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Ok(Some(value)) => value.convert(py),
Ok(None) => Err(PyStopAsyncIteration::new_err(())),
Err(err) => Err(err),
}
}
}

#[doc(hidden)]
pub trait AsyncIterResultOptionKind {
fn async_iter_tag(&self) -> AsyncIterResultOptionTag {
AsyncIterResultOptionTag
}
}

impl<Value> AsyncIterResultOptionKind for PyResult<Option<Value>> {}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ pub mod class {
/// For compatibility reasons this has not yet been removed, however will be done so
/// once <https://github.com/rust-lang/rust/issues/30827> is resolved.
pub mod pyasync {
#[allow(deprecated)]
pub use crate::pyclass::{IterANextOutput, PyIterANextOutput};
}

Expand Down
43 changes: 16 additions & 27 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ where
/// Output of `__anext__`.
///
/// <https://docs.python.org/3/reference/expressions.html#agen.__anext__>
#[deprecated(
since = "0.21.0",
note = "Use `Option` or `PyStopAsyncIteration` instead."
)]
pub enum IterANextOutput<T, U> {
/// An expression which the generator yielded.
Yield(T),
Expand All @@ -163,40 +167,25 @@ pub enum IterANextOutput<T, U> {
}

/// An [IterANextOutput] of Python objects.
#[deprecated(
since = "0.21.0",
note = "Use `Option` or `PyStopAsyncIteration` instead."
)]
#[allow(deprecated)]
pub type PyIterANextOutput = IterANextOutput<PyObject, PyObject>;

impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterANextOutput {
fn convert(self, _py: Python<'_>) -> PyResult<*mut ffi::PyObject> {
match self {
IterANextOutput::Yield(o) => Ok(o.into_ptr()),
IterANextOutput::Return(opt) => {
Err(crate::exceptions::PyStopAsyncIteration::new_err((opt,)))
}
}
}
}

impl<T, U> IntoPyCallbackOutput<PyIterANextOutput> for IterANextOutput<T, U>
#[allow(deprecated)]
impl<T, U> IntoPyCallbackOutput<*mut ffi::PyObject> for IterANextOutput<T, U>
where
T: IntoPy<PyObject>,
U: IntoPy<PyObject>,
{
fn convert(self, py: Python<'_>) -> PyResult<PyIterANextOutput> {
match self {
IterANextOutput::Yield(o) => Ok(IterANextOutput::Yield(o.into_py(py))),
IterANextOutput::Return(o) => Ok(IterANextOutput::Return(o.into_py(py))),
}
}
}

impl<T> IntoPyCallbackOutput<PyIterANextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python<'_>) -> PyResult<PyIterANextOutput> {
fn convert(self, py: Python<'_>) -> PyResult<*mut ffi::PyObject> {
match self {
Some(o) => Ok(PyIterANextOutput::Yield(o.into_py(py))),
None => Ok(PyIterANextOutput::Return(py.None().into())),
IterANextOutput::Yield(o) => Ok(o.into_py(py).into_ptr()),
IterANextOutput::Return(o) => Err(crate::exceptions::PyStopAsyncIteration::new_err(
o.into_py(py),
)),
}
}
}
Expand Down

0 comments on commit 874ab10

Please sign in to comment.