Skip to content

Commit

Permalink
fix #[derive(FromPyObject)] expansion with trait bounds (#4645)
Browse files Browse the repository at this point in the history
* fix `#[derive(FromPyObject)]` expansion with trait bounds

* add newsfragment
  • Loading branch information
Icxolu authored Oct 25, 2024
1 parent 6aa5e6b commit b3bb667
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
1 change: 1 addition & 0 deletions newsfragments/4645.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fix `#[derive(FromPyObject)]` expansion on generic with trait bounds
21 changes: 12 additions & 9 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,24 +572,27 @@ fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::Life
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
/// adds `T: FromPyObject` on the derived implementation.
pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
let options = ContainerOptions::from_attrs(&tokens.attrs)?;
let ctx = &Ctx::new(&options.krate, None);
let Ctx { pyo3_path, .. } = &ctx;

let (_, ty_generics, _) = tokens.generics.split_for_impl();
let mut trait_generics = tokens.generics.clone();
let generics = &tokens.generics;
let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
lt.clone()
} else {
trait_generics.params.push(parse_quote!('py));
parse_quote!('py)
};
let mut where_clause: syn::WhereClause = parse_quote!(where);
for param in generics.type_params() {
let (impl_generics, _, where_clause) = trait_generics.split_for_impl();

let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
for param in trait_generics.type_params() {
let gen_ident = &param.ident;
where_clause
.predicates
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
.push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
}
let options = ContainerOptions::from_attrs(&tokens.attrs)?;
let ctx = &Ctx::new(&options.krate, None);
let Ctx { pyo3_path, .. } = &ctx;

let derives = match &tokens.data {
syn::Data::Enum(en) => {
Expand All @@ -616,7 +619,7 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
let ident = &tokens.ident;
Ok(quote!(
#[automatically_derived]
impl #trait_generics #pyo3_path::FromPyObject<#lt_param> for #ident #generics #where_clause {
impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
#derives
}
Expand Down
17 changes: 16 additions & 1 deletion tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use pyo3::types::{IntoPyDict, PyDict, PyList, PyString, PyTuple};

#[macro_use]
#[path = "../src/tests/common.rs"]
Expand Down Expand Up @@ -109,6 +109,21 @@ fn test_generic_transparent_named_field_struct() {
});
}

#[derive(Debug, FromPyObject)]
pub struct GenericWithBound<K: std::hash::Hash + Eq, V>(std::collections::HashMap<K, V>);

#[test]
fn test_generic_with_bound() {
Python::with_gil(|py| {
let dict = [("1", 1), ("2", 2)].into_py_dict(py).unwrap();
let map = dict.extract::<GenericWithBound<String, i32>>().unwrap().0;
assert_eq!(map.len(), 2);
assert_eq!(map["1"], 1);
assert_eq!(map["2"], 2);
assert!(!map.contains_key("3"));
});
}

#[derive(Debug, FromPyObject)]
pub struct E<T, T2> {
test: T,
Expand Down

0 comments on commit b3bb667

Please sign in to comment.