Skip to content

Commit

Permalink
async method should allow args not only receiver (#4015)
Browse files Browse the repository at this point in the history
* async method should allow args not only receiver

* add changelog md
  • Loading branch information
reswqa authored Mar 30, 2024
1 parent 4d033c4 commit 74d9d23
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions newsfragments/4015.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix the bug that an async `#[pymethod]` with receiver can't have any other args.
22 changes: 19 additions & 3 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Display;

use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::utils::Ctx;
Expand Down Expand Up @@ -518,17 +518,33 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)),
None => quote!(None),
};
let evaluate_args = || -> (Vec<Ident>, TokenStream) {
let mut arg_names = Vec::with_capacity(args.len());
let mut evaluate_arg = quote! {};
for arg in &args {
let arg_name = format_ident!("arg_{}", arg_names.len());
arg_names.push(arg_name.clone());
evaluate_arg.extend(quote! {
let #arg_name = #arg
});
}
(arg_names, evaluate_arg)
};
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
let (arg_name, evaluate_arg) = evaluate_args();
quote! {{
#evaluate_arg;
let __guard = #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
async move { function(&__guard, #(#args),*).await }
async move { function(&__guard, #(#arg_name),*).await }
}}
}
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => {
let (arg_name, evaluate_arg) = evaluate_args();
quote! {{
#evaluate_arg;
let mut __guard = #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
async move { function(&mut __guard, #(#args),*).await }
async move { function(&mut __guard, #(#arg_name),*).await }
}}
}
_ => {
Expand Down
33 changes: 33 additions & 0 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,39 @@ fn coroutine_panic() {
})
}

#[test]
fn test_async_method_receiver_with_other_args() {
#[pyclass]
struct Value(i32);
#[pymethods]
impl Value {
#[new]
fn new() -> Self {
Self(0)
}
async fn get_value_plus_with(&self, v: i32) -> i32 {
self.0 + v
}
async fn set_value(&mut self, new_value: i32) -> i32 {
self.0 = new_value;
self.0
}
}

Python::with_gil(|gil| {
let test = r#"
import asyncio
v = Value()
assert asyncio.run(v.get_value_plus_with(3)) == 3
assert asyncio.run(v.set_value(10)) == 10
assert asyncio.run(v.get_value_plus_with(1)) == 11
"#;
let locals = [("Value", gil.get_type_bound::<Value>())].into_py_dict_bound(gil);
py_run!(gil, *locals, test);
});
}

#[test]
fn test_async_method_receiver() {
#[pyclass]
Expand Down

0 comments on commit 74d9d23

Please sign in to comment.