Skip to content

Commit

Permalink
Merge pull request #3661 from PyO3/iter-output-type
Browse files Browse the repository at this point in the history
Replace (A)IterNextOutput by autoref-based specialization to allow returning arbitrary value
  • Loading branch information
adamreichold authored Dec 20, 2023
2 parents a3c92fa + 5528895 commit 1b3dc6d
Show file tree
Hide file tree
Showing 12 changed files with 436 additions and 136 deletions.
119 changes: 118 additions & 1 deletion guide/src/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,124 @@ Python::with_gil(|py| {
});
```

### `PyType::name` is now `PyType::qualname`
### `Iter(A)NextOutput` are deprecated

The `__next__` and `__anext__` magic methods can now return any type convertible into Python objects directly just like all other `#[pymethods]`. The `IterNextOutput` used by `__next__` and `IterANextOutput` used by `__anext__` are subsequently deprecated. Most importantly, this change allows returning an awaitable from `__anext__` without non-sensically wrapping it into `Yield` or `Some`. Only the return types `Option<T>` and `Result<Option<T>, E>` are still handled in a special manner where `Some(val)` yields `val` and `None` stops iteration.

Starting with an implementation of a Python iterator using `IterNextOutput`, e.g.

```rust
#![allow(deprecated)]
use pyo3::prelude::*;
use pyo3::iter::IterNextOutput;

#[pyclass]
struct PyClassIter {
count: usize,
}

#[pymethods]
impl PyClassIter {
fn __next__(&mut self) -> IterNextOutput<usize, &'static str> {
if self.count < 5 {
self.count += 1;
IterNextOutput::Yield(self.count)
} else {
IterNextOutput::Return("done")
}
}
}
```

If returning `"done"` via `StopIteration` is not really required, this should be written as

```rust
use pyo3::prelude::*;

#[pyclass]
struct PyClassIter {
count: usize,
}

#[pymethods]
impl PyClassIter {
fn __next__(&mut self) -> Option<usize> {
if self.count < 5 {
self.count += 1;
Some(self.count)
} else {
None
}
}
}
```

This form also has additional benefits: It has already worked in previous PyO3 versions, it matches the signature of Rust's [`Iterator` trait](https://doc.rust-lang.org/stable/std/iter/trait.Iterator.html) and it allows using a fast path in CPython which completely avoids the cost of raising a `StopIteration` exception. Note that using [`Option::transpose`](https://doc.rust-lang.org/stable/std/option/enum.Option.html#method.transpose) and the `Result<Option<T>, E>` variant, this form can also be used to wrap fallible iterators.

Alternatively, the implementation can also be done as it would in Python itself, i.e. by "raising" a `StopIteration` exception

```rust
use pyo3::prelude::*;
use pyo3::exceptions::PyStopIteration;

#[pyclass]
struct PyClassIter {
count: usize,
}

#[pymethods]
impl PyClassIter {
fn __next__(&mut self) -> PyResult<usize> {
if self.count < 5 {
self.count += 1;
Ok(self.count)
} else {
Err(PyStopIteration::new_err("done"))
}
}
}
```

Finally, an asynchronous iterator can directly return an awaitable without confusing wrapping

```rust
use pyo3::prelude::*;

#[pyclass]
struct PyClassAwaitable {
number: usize,
}

#[pymethods]
impl PyClassAwaitable {
fn __next__(&self) -> usize {
self.number
}

fn __await__(slf: Py<Self>) -> Py<Self> {
slf
}
}

#[pyclass]
struct PyClassAsyncIter {
number: usize,
}

#[pymethods]
impl PyClassAsyncIter {
fn __anext__(&mut self) -> PyClassAwaitable {
self.number += 1;
PyClassAwaitable { number: self.number }
}

fn __aiter__(slf: Py<Self>) -> Py<Self> {
slf
}
}
```

### `PyType::name` has been renamed to `PyType::qualname`

`PyType::name` has been renamed to `PyType::qualname` to indicate that it does indeed return the [qualified name](https://docs.python.org/3/glossary.html#term-qualified-name), matching the `__qualname__` attribute. The newly added `PyType::name` yields the full name including the module name now which corresponds to `__module__.__name__` on the level of attributes.

Expand Down
1 change: 1 addition & 0 deletions newsfragments/3661.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The `Iter(A)NextOutput` types are now deprecated and `__(a)next__` can directly return anything which can be converted into Python objects, i.e. awaitables do not need to be wrapped into `IterANextOutput` or `Option` any more. `Option` can still be used as well and returning `None` will trigger the fast path for `__next__`, stopping iteration without having to raise a `StopIteration` exception.
76 changes: 54 additions & 22 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,18 @@ impl FnType {
}
}

pub fn self_arg(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream {
pub fn self_arg(
&self,
cls: Option<&syn::Type>,
error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
) -> TokenStream {
match self {
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) => {
let mut receiver = st.receiver(
cls.expect("no class given for Fn with a \"self\" receiver"),
error_mode,
holders,
);
syn::Token![,](Span::call_site()).to_tokens(&mut receiver);
receiver
Expand Down Expand Up @@ -161,7 +167,12 @@ impl ExtractErrorMode {
}

impl SelfType {
pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream {
pub fn receiver(
&self,
cls: &syn::Type,
error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
) -> TokenStream {
// Due to use of quote_spanned in this function, need to bind these idents to the
// main macro callsite.
let py = syn::Ident::new("py", Span::call_site());
Expand All @@ -173,10 +184,15 @@ impl SelfType {
} else {
syn::Ident::new("extract_pyclass_ref", *span)
};
let holder = syn::Ident::new(&format!("holder_{}", holders.len()), *span);
holders.push(quote_spanned! { *span =>
#[allow(clippy::let_unit_value)]
let mut #holder = _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT;
});
error_mode.handle_error(quote_spanned! { *span =>
_pyo3::impl_::extract_argument::#method::<#cls>(
#py.from_borrowed_ptr::<_pyo3::PyAny>(#slf),
&mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT },
&mut #holder,
)
})
}
Expand Down Expand Up @@ -457,9 +473,6 @@ impl<'a> FnSpec<'a> {
ident: &proc_macro2::Ident,
cls: Option<&syn::Type>,
) -> Result<TokenStream> {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let func_name = &self.name;

let mut cancel_handle_iter = self
.signature
.arguments
Expand All @@ -473,7 +486,9 @@ impl<'a> FnSpec<'a> {
}
}

let rust_call = |args: Vec<TokenStream>| {
let rust_call = |args: Vec<TokenStream>, holders: &mut Vec<TokenStream>| {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, holders);

let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
Expand All @@ -486,14 +501,22 @@ impl<'a> FnSpec<'a> {
None => quote!(None),
};
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&__guard, #(#args),*).await }
}},
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&mut __guard, #(#args),*).await }
}},
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
holders.pop().unwrap(); // does not actually use holder created by `self_arg`

quote! {{
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&__guard, #(#args),*).await }
}}
}
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => {
holders.pop().unwrap(); // does not actually use holder created by `self_arg`

quote! {{
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&mut __guard, #(#args),*).await }
}}
}
_ => quote! { function(#self_arg #(#args),*) },
};
let mut call = quote! {{
Expand All @@ -519,6 +542,7 @@ impl<'a> FnSpec<'a> {
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};

let func_name = &self.name;
let rust_name = if let Some(cls) = cls {
quote!(#cls::#func_name)
} else {
Expand All @@ -527,6 +551,7 @@ impl<'a> FnSpec<'a> {

Ok(match self.convention {
CallingConvention::Noargs => {
let mut holders = Vec::new();
let args = self
.signature
.arguments
Expand All @@ -541,21 +566,23 @@ impl<'a> FnSpec<'a> {
}
})
.collect();
let call = rust_call(args);
let call = rust_call(args, &mut holders);

quote! {
unsafe fn #ident<'py>(
py: _pyo3::Python<'py>,
_slf: *mut _pyo3::ffi::PyObject,
) -> _pyo3::PyResult<*mut _pyo3::ffi::PyObject> {
let function = #rust_name; // Shadow the function name to avoid #3017
#( #holders )*
#call
}
}
}
CallingConvention::Fastcall => {
let (arg_convert, args) = impl_arg_params(self, cls, true)?;
let call = rust_call(args);
let mut holders = Vec::new();
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders)?;
let call = rust_call(args, &mut holders);
quote! {
unsafe fn #ident<'py>(
py: _pyo3::Python<'py>,
Expand All @@ -566,13 +593,15 @@ impl<'a> FnSpec<'a> {
) -> _pyo3::PyResult<*mut _pyo3::ffi::PyObject> {
let function = #rust_name; // Shadow the function name to avoid #3017
#arg_convert
#( #holders )*
#call
}
}
}
CallingConvention::Varargs => {
let (arg_convert, args) = impl_arg_params(self, cls, false)?;
let call = rust_call(args);
let mut holders = Vec::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders)?;
let call = rust_call(args, &mut holders);
quote! {
unsafe fn #ident<'py>(
py: _pyo3::Python<'py>,
Expand All @@ -582,13 +611,15 @@ impl<'a> FnSpec<'a> {
) -> _pyo3::PyResult<*mut _pyo3::ffi::PyObject> {
let function = #rust_name; // Shadow the function name to avoid #3017
#arg_convert
#( #holders )*
#call
}
}
}
CallingConvention::TpNew => {
let (arg_convert, args) = impl_arg_params(self, cls, false)?;
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let mut holders = Vec::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders)?;
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, &mut holders);
let call = quote! { #rust_name(#self_arg #(#args),*) };
quote! {
unsafe fn #ident(
Expand All @@ -600,6 +631,7 @@ impl<'a> FnSpec<'a> {
use _pyo3::callback::IntoPyCallbackOutput;
let function = #rust_name; // Shadow the function name to avoid #3017
#arg_convert
#( #holders )*
let result = #call;
let initializer: _pyo3::PyClassInitializer::<#cls> = result.convert(py)?;
let cell = initializer.create_cell_from_subtype(py, _slf)?;
Expand Down
9 changes: 3 additions & 6 deletions pyo3-macros-backend/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,23 @@ pub fn impl_arg_params(
spec: &FnSpec<'_>,
self_: Option<&syn::Type>,
fastcall: bool,
holders: &mut Vec<TokenStream>,
) -> Result<(TokenStream, Vec<TokenStream>)> {
let args_array = syn::Ident::new("output", Span::call_site());

if !fastcall && is_forwarded_args(&spec.signature) {
// In the varargs convention, we can just pass though if the signature
// is (*args, **kwds).
let mut holders = Vec::new();
let arg_convert = spec
.signature
.arguments
.iter()
.map(|arg| impl_arg_param(arg, &mut 0, &args_array, &mut holders))
.map(|arg| impl_arg_param(arg, &mut 0, &args_array, holders))
.collect::<Result<_>>()?;
return Ok((
quote! {
let _args = py.from_borrowed_ptr::<_pyo3::types::PyTuple>(_args);
let _kwargs: ::std::option::Option<&_pyo3::types::PyDict> = py.from_borrowed_ptr_or_opt(_kwargs);
#( #holders )*
},
arg_convert,
));
Expand Down Expand Up @@ -75,12 +74,11 @@ pub fn impl_arg_params(
let num_params = positional_parameter_names.len() + keyword_only_parameters.len();

let mut option_pos = 0;
let mut holders = Vec::new();
let param_conversion = spec
.signature
.arguments
.iter()
.map(|arg| impl_arg_param(arg, &mut option_pos, &args_array, &mut holders))
.map(|arg| impl_arg_param(arg, &mut option_pos, &args_array, holders))
.collect::<Result<_>>()?;

let args_handler = if spec.signature.python_signature.varargs.is_some() {
Expand Down Expand Up @@ -134,7 +132,6 @@ pub fn impl_arg_params(
keyword_only_parameters: &[#(#keyword_only_parameters),*],
};
let mut #args_array = [::std::option::Option::None; #num_params];
#( #holders )*
let (_args, _kwargs) = #extract_expression;
},
param_conversion,
Expand Down
Loading

0 comments on commit 1b3dc6d

Please sign in to comment.