Skip to content

Commit

Permalink
feat: add coroutine __name__/__qualname__ and not-awaited warning
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 23, 2023
1 parent 81df150 commit 9502795
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 18 deletions.
1 change: 1 addition & 0 deletions newsfragments/3588.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `__name__`/`__qualname__` attributes to `Coroutine`, as well as a Python warning when the coroutine is dropped without having been awaited
18 changes: 17 additions & 1 deletion pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,23 @@ impl<'a> FnSpec<'a> {
let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) };
let python_name = &self.python_name;
let qualname = match cls {
Some(cls) => quote! {
_pyo3::impl_::coroutine::method_coroutine_qualname::<#cls>(py, stringify!(#python_name))
},
None => quote! {
_pyo3::impl_::coroutine::coroutine_qualname(py, py.from_borrowed_ptr_or_opt::<_pyo3::types::PyModule>(_slf), stringify!(#python_name))
},
};
call = quote! {{
let future = #call;
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::types::PyString::new(py, stringify!(#python_name)).into(),
#qualname,
async move { _pyo3::impl_::wrap::OkWrap::wrap_no_gil(future.await) }
)
}};
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
Expand Down
47 changes: 44 additions & 3 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use pyo3_macros::{pyclass, pymethods};

use crate::{
coroutine::waker::AsyncioWaker,
exceptions::{PyRuntimeError, PyStopIteration},
exceptions::{PyAttributeError, PyRuntimeError, PyRuntimeWarning, PyStopIteration},
panic::PanicException,
pyclass::IterNextOutput,
types::PyIterator,
types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};

Expand All @@ -30,6 +30,8 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname: Option<Py<PyString>>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -41,7 +43,11 @@ impl Coroutine {
/// (should always be `None` anyway).
///
/// `Coroutine `throw` drop the wrapped future and reraise the exception passed
pub(crate) fn from_future<F, T, E>(future: F) -> Self
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
mut qualname: Option<Py<PyString>>,
future: F,
) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
Expand All @@ -52,7 +58,10 @@ impl Coroutine {
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
qualname = qualname.or_else(|| name.clone());
Self {
name,
qualname,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
waker: None,
}
Expand Down Expand Up @@ -113,6 +122,20 @@ pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResul

#[pymethods(crate = "crate")]
impl Coroutine {
#[getter]
fn __name__(&self) -> PyResult<Py<PyString>> {
self.name
.clone()
.ok_or_else(|| PyAttributeError::new_err("__name__"))
}

#[getter]
fn __qualname__(&self) -> PyResult<Py<PyString>> {
self.qualname
.clone()
.ok_or_else(|| PyAttributeError::new_err("__qualname__"))
}

fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
iter_result(self.poll(py, None)?)
}
Expand All @@ -135,3 +158,21 @@ impl Coroutine {
self.poll(py, None)
}
}

impl Drop for Coroutine {
fn drop(&mut self) {
if self.future.is_some() {
Python::with_gil(|gil| {
let qualname = self
.qualname
.as_ref()
.map_or(Ok("<coroutine>"), |n| n.as_ref(gil).to_str())
.unwrap();
let message = format!("coroutine {} was never awaited", qualname);
PyErr::warn(gil, gil.get_type::<PyRuntimeWarning>(), &message, 2)
.expect("warning error");
self.poll(gil, None).expect("coroutine close error");
})
}
}
}
38 changes: 26 additions & 12 deletions src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
use crate::coroutine::Coroutine;
use crate::impl_::wrap::OkWrap;
use crate::{IntoPy, PyErr, PyObject, Python};
use std::future::Future;

/// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`.
pub fn wrap_future<F, R, T>(future: F) -> Coroutine
use crate::{
coroutine::Coroutine,
types::{PyModule, PyString},
IntoPy, Py, PyClass, PyErr, PyObject, Python,
};

pub fn new_coroutine<F, T, E>(name: Py<PyString>, qualname: Py<PyString>, future: F) -> Coroutine
where
F: Future<Output = R> + Send + 'static,
R: OkWrap<T>,
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
PyErr: From<R::Error>,
PyErr: From<E>,
{
let future = async move {
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
future.await.wrap(unsafe { Python::assume_gil_acquired() })
Coroutine::new(Some(name), Some(qualname), future)
}

pub fn coroutine_qualname(py: Python<'_>, module: Option<&PyModule>, name: &str) -> Py<PyString> {
match module.and_then(|m| m.name().ok()) {
Some(module) => PyString::new(py, &format!("{}.{}", module, name)),
None => PyString::new(py, name),
}
.into()
}

pub fn method_coroutine_qualname<T: PyClass>(py: Python<'_>, name: &str) -> Py<PyString> {
let class = T::NAME;
let qualname = match T::MODULE {
Some(module) => format!("{}.{}.{}", module, class, name),
None => format!("{}.{}", class, name),
};
Coroutine::from_future(future)
PyString::new(py, &qualname).into()
}
7 changes: 7 additions & 0 deletions src/impl_/wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<T> SomeWrap<Option<T>> for Option<T> {
/// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`.
pub trait OkWrap<T> {
type Error;
fn wrap_no_gil(self) -> Result<T, Self::Error>;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error>;
}

Expand All @@ -30,6 +31,9 @@ where
T: IntoPy<PyObject>,
{
type Error = PyErr;
fn wrap_no_gil(self) -> Result<T, Self::Error> {
Ok(self)
}
fn wrap(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
Ok(self.into_py(py))
}
Expand All @@ -40,6 +44,9 @@ where
T: IntoPy<PyObject>,
{
type Error = E;
fn wrap_no_gil(self) -> Result<T, Self::Error> {
self
}
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error> {
self.map(|o| o.into_py(py))
}
Expand Down
69 changes: 67 additions & 2 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#![cfg(feature = "macros")]
#![cfg(not(target_arch = "wasm32"))]
use std::ops::Deref;
use std::{task::Poll, thread, time::Duration};

use futures::{channel::oneshot, future::poll_fn};
use pyo3::types::{IntoPyDict, PyType};
use pyo3::{prelude::*, py_run};

#[path = "../src/tests/common.rs"]
Expand All @@ -20,8 +22,8 @@ fn handle_windows(test: &str) -> String {
#[test]
fn noop_coroutine() {
#[pyfunction]
async fn noop() -> usize {
42
async fn noop() -> PyResult<usize> {
Ok(42)
}
Python::with_gil(|gil| {
let noop = wrap_pyfunction!(noop, gil).unwrap();
Expand All @@ -30,6 +32,69 @@ fn noop_coroutine() {
})
}

#[test]
fn test_coroutine_qualname() {
#[pyfunction]
async fn my_fn() {}
#[pyclass]
struct MyClass;
#[pymethods]
impl MyClass {
#[new]
fn new() -> Self {
Self
}
// TODO use &self when possible
async fn my_method(_self: Py<Self>) {}
#[classmethod]
async fn my_classmethod(_cls: Py<PyType>) {}
#[staticmethod]
async fn my_staticmethod() {}
}
#[pyclass(module = "my_module")]
struct MyClassWithModule;
#[pymethods]
impl MyClassWithModule {
#[new]
fn new() -> Self {
Self
}
// TODO use &self when possible
async fn my_method(_self: Py<Self>) {}
#[classmethod]
async fn my_classmethod(_cls: Py<PyType>) {}
#[staticmethod]
async fn my_staticmethod() {}
}
Python::with_gil(|gil| {
let test = r#"
for coro, name, qualname in [
(my_fn(), "my_fn", "my_fn"),
(my_fn_with_module(), "my_fn", "my_module.my_fn"),
(MyClass().my_method(), "my_method", "MyClass.my_method"),
#(MyClass().my_classmethod(), "my_classmethod", "MyClass.my_classmethod"),
(MyClass.my_staticmethod(), "my_staticmethod", "MyClass.my_staticmethod"),
(MyClassWithModule().my_method(), "my_method", "my_module.MyClassWithModule.my_method"),
#(MyClassWithModule().my_classmethod(), "my_classmethod", "my_module.MyClassWithModule.my_classmethod"),
(MyClassWithModule.my_staticmethod(), "my_staticmethod", "my_module.MyClassWithModule.my_staticmethod"),
]:
assert coro.__name__ == name and coro.__qualname__ == qualname
"#;
let my_module = PyModule::new(gil, "my_module").unwrap();
let locals = [
("my_fn", wrap_pyfunction!(my_fn, gil).unwrap().deref()),
(
"my_fn_with_module",
wrap_pyfunction!(my_fn, my_module).unwrap(),
),
("MyClass", gil.get_type::<MyClass>()),
("MyClassWithModule", gil.get_type::<MyClassWithModule>()),
]
.into_py_dict(gil);
py_run!(gil, *locals, &handle_windows(test));
})
}

#[test]
fn sleep_0_like_coroutine() {
#[pyfunction]
Expand Down

0 comments on commit 9502795

Please sign in to comment.