From 162535e806dbc98422f1399319752dab6fe8617c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustav=20S=C3=B6rn=C3=A4s?= Date: Fri, 19 Jul 2024 12:09:42 +0200 Subject: [PATCH] Only put Debug-like bounds on type variables (#371, #363) ## Synopsis The problem, as reported in the issue, is that code like the following ```rust #[derive(derive_more::Debug)] struct Item { next: Option>, } ``` expands into something like ```rust impl std::fmt::Debug for Item where Item: Debug { /* ... */ } ``` which does not compile. This PR changes the Debug derive so it does not emit those bounds. ## Solution My understanding of the current code is that we iterate over all fields of the struct/enum and add either a specific format bound (e.g. `: fmt::Binary`), a default `: fmt::Debug` bound or skip it if either it is marked as `#[debug(skip)]` or the entire container has a format attribute. The suggested solution in the issue (if I understood it correctly) was to only add bounds if the type is a type variable, since rustc already knows if a concrete type is, say, `: fmt::Debug`. So, instead of adding the bound for every type, we first check that the type contains one of the container's type variables. Since types can be nested, it is an unfortunately long recursive function handling the different types of types. This part of Rust syntax is probably not going to change, so perhaps it is feasible to shorten some of the branches into `_ => false`. One drawback of this implementation is that we iterate over the list of type variables every time we find a "leaf type". I chose `Vec` over `HashSet` because in my experience there are only a handful of type variables per container. Co-authored-by: Jelte Fennema-Nio Co-authored-by: Kai Ren --- impl/src/fmt/debug.rs | 143 ++++++++++++++++++++++++++++++++++++++++-- tests/debug.rs | 138 ++++++++++++++++++++++++++++++++++++++++ tests/sum.rs | 1 + 3 files changed, 276 insertions(+), 6 deletions(-) diff --git a/impl/src/fmt/debug.rs b/impl/src/fmt/debug.rs index 9b4e99ff..524f0218 100644 --- a/impl/src/fmt/debug.rs +++ b/impl/src/fmt/debug.rs @@ -24,9 +24,21 @@ pub fn expand(input: &syn::DeriveInput, _: &str) -> syn::Result { .unwrap_or_default(); let ident = &input.ident; + let type_params: Vec<_> = input + .generics + .params + .iter() + .filter_map(|p| match p { + syn::GenericParam::Type(t) => Some(&t.ident), + syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None, + }) + .collect(); + let (bounds, body) = match &input.data { - syn::Data::Struct(s) => expand_struct(attrs, ident, s, &attr_name), - syn::Data::Enum(e) => expand_enum(attrs, e, &attr_name), + syn::Data::Struct(s) => { + expand_struct(attrs, ident, s, &type_params, &attr_name) + } + syn::Data::Enum(e) => expand_enum(attrs, e, &type_params, &attr_name), syn::Data::Union(_) => { return Err(syn::Error::new( input.span(), @@ -64,11 +76,13 @@ fn expand_struct( attrs: ContainerAttributes, ident: &Ident, s: &syn::DataStruct, + type_params: &[&syn::Ident], attr_name: &syn::Ident, ) -> syn::Result<(Vec, TokenStream)> { let s = Expansion { attr: &attrs, fields: &s.fields, + type_params, ident, attr_name, }; @@ -99,6 +113,7 @@ fn expand_struct( fn expand_enum( mut attrs: ContainerAttributes, e: &syn::DataEnum, + type_params: &[&syn::Ident], attr_name: &syn::Ident, ) -> syn::Result<(Vec, TokenStream)> { if let Some(enum_fmt) = attrs.fmt.as_ref() { @@ -136,6 +151,7 @@ fn expand_enum( let v = Expansion { attr: &attrs, fields: &variant.fields, + type_params, ident, attr_name, }; @@ -195,6 +211,9 @@ struct Expansion<'a> { /// Struct or enum [`syn::Fields`]. fields: &'a syn::Fields, + /// Type parameters in this struct or enum. + type_params: &'a [&'a syn::Ident], + /// Name of the attributes, considered by this macro. attr_name: &'a syn::Ident, } @@ -334,15 +353,26 @@ impl<'a> Expansion<'a> { let mut out = self.attr.bounds.0.clone().into_iter().collect::>(); if let Some(fmt) = self.attr.fmt.as_ref() { - out.extend(fmt.bounded_types(self.fields).map(|(ty, trait_name)| { - let trait_ident = format_ident!("{trait_name}"); + out.extend(fmt.bounded_types(self.fields).filter_map( + |(ty, trait_name)| { + if !self.contains_generic_param(ty) { + return None; + } + + let trait_ident = format_ident!("{trait_name}"); - parse_quote! { #ty: derive_more::core::fmt::#trait_ident } - })); + Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident }) + }, + )); Ok(out) } else { self.fields.iter().try_fold(out, |mut out, field| { let ty = &field.ty; + + if !self.contains_generic_param(ty) { + return Ok(out); + } + match FieldAttribute::parse_attrs(&field.attrs, self.attr_name)? .map(Spanning::into_inner) { @@ -362,4 +392,105 @@ impl<'a> Expansion<'a> { }) } } + + /// Checks whether the provided [`syn::Path`] contains any of these [`Expansion::type_params`]. + fn path_contains_generic_param(&self, path: &syn::Path) -> bool { + path.segments + .iter() + .any(|segment| match &segment.arguments { + syn::PathArguments::None => false, + syn::PathArguments::AngleBracketed( + syn::AngleBracketedGenericArguments { args, .. }, + ) => args.iter().any(|generic| match generic { + syn::GenericArgument::Type(ty) + | syn::GenericArgument::AssocType(syn::AssocType { ty, .. }) => { + self.contains_generic_param(ty) + } + + syn::GenericArgument::Lifetime(_) + | syn::GenericArgument::Const(_) + | syn::GenericArgument::AssocConst(_) + | syn::GenericArgument::Constraint(_) => false, + _ => unimplemented!( + "syntax is not supported by `derive_more`, please report a bug", + ), + }), + syn::PathArguments::Parenthesized( + syn::ParenthesizedGenericArguments { inputs, output, .. }, + ) => { + inputs.iter().any(|ty| self.contains_generic_param(ty)) + || match output { + syn::ReturnType::Default => false, + syn::ReturnType::Type(_, ty) => { + self.contains_generic_param(ty) + } + } + } + }) + } + + /// Checks whether the provided [`syn::Type`] contains any of these [`Expansion::type_params`]. + fn contains_generic_param(&self, ty: &syn::Type) -> bool { + if self.type_params.is_empty() { + return false; + } + match ty { + syn::Type::Path(syn::TypePath { qself, path }) => { + if let Some(qself) = qself { + if self.contains_generic_param(&qself.ty) { + return true; + } + } + + if let Some(ident) = path.get_ident() { + self.type_params.iter().any(|param| *param == ident) + } else { + self.path_contains_generic_param(path) + } + } + + syn::Type::Array(syn::TypeArray { elem, .. }) + | syn::Type::Group(syn::TypeGroup { elem, .. }) + | syn::Type::Paren(syn::TypeParen { elem, .. }) + | syn::Type::Ptr(syn::TypePtr { elem, .. }) + | syn::Type::Reference(syn::TypeReference { elem, .. }) + | syn::Type::Slice(syn::TypeSlice { elem, .. }) => { + self.contains_generic_param(elem) + } + + syn::Type::BareFn(syn::TypeBareFn { inputs, output, .. }) => { + inputs + .iter() + .any(|arg| self.contains_generic_param(&arg.ty)) + || match output { + syn::ReturnType::Default => false, + syn::ReturnType::Type(_, ty) => self.contains_generic_param(ty), + } + } + syn::Type::Tuple(syn::TypeTuple { elems, .. }) => { + elems.iter().any(|ty| self.contains_generic_param(ty)) + } + + syn::Type::ImplTrait(_) => false, + syn::Type::Infer(_) => false, + syn::Type::Macro(_) => false, + syn::Type::Never(_) => false, + syn::Type::TraitObject(syn::TypeTraitObject { bounds, .. }) => { + bounds.iter().any(|bound| match bound { + syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => { + self.path_contains_generic_param(path) + } + syn::TypeParamBound::Lifetime(_) => false, + syn::TypeParamBound::Verbatim(_) => false, + _ => unimplemented!( + "syntax is not supported by `derive_more`, please report a bug", + ), + }) + } + syn::Type::Verbatim(_) => false, + _ => unimplemented!( + "syntax is not supported by `derive_more`, please report a bug", + ), + } + } } diff --git a/tests/debug.rs b/tests/debug.rs index bc6b6132..0c4b180a 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -1932,3 +1932,141 @@ mod complex_enum_syntax { assert_eq!(format!("{:?}", Enum::A), "A"); } } + +// See: https://github.com/JelteF/derive_more/issues/363 +mod type_variables { + mod our_alloc { + #[cfg(not(feature = "std"))] + pub use alloc::{boxed::Box, format, vec, vec::Vec}; + #[cfg(feature = "std")] + pub use std::{boxed::Box, format, vec, vec::Vec}; + } + + use our_alloc::{format, vec, Box, Vec}; + + use derive_more::Debug; + + #[derive(Debug)] + struct ItemStruct { + next: Option>, + } + + #[derive(Debug)] + struct ItemTuple(Option>); + + #[derive(Debug)] + #[debug("Item({_0:?})")] + struct ItemTupleContainerFmt(Option>); + + #[derive(Debug)] + enum ItemEnum { + Node { children: Vec, inner: i32 }, + Leaf { inner: i32 }, + } + + #[derive(Debug)] + struct VecMeansDifferent { + next: our_alloc::Vec, + real: Vec, + } + + #[derive(Debug)] + struct Array { + #[debug("{t}")] + t: [T; 10], + } + + mod parens { + #![allow(unused_parens)] // test that type is found even in parentheses + + use derive_more::Debug; + + #[derive(Debug)] + struct Paren { + t: (T), + } + } + + #[derive(Debug)] + struct ParenthesizedGenericArgumentsInput { + t: dyn Fn(T) -> i32, + } + + #[derive(Debug)] + struct ParenthesizedGenericArgumentsOutput { + t: dyn Fn(i32) -> T, + } + + #[derive(Debug)] + struct Ptr { + t: *const T, + } + + #[derive(Debug)] + struct Reference<'a, T> { + t: &'a T, + } + + #[derive(Debug)] + struct Slice<'a, T> { + t: &'a [T], + } + + #[derive(Debug)] + struct BareFn { + t: Box T>, + } + + #[derive(Debug)] + struct Tuple { + t: Box<(T, T)>, + } + + trait MyTrait {} + + #[derive(Debug)] + struct TraitObject { + t: Box>, + } + + #[test] + fn assert() { + assert_eq!( + format!( + "{:?}", + ItemStruct { + next: Some(Box::new(ItemStruct { next: None })) + }, + ), + "ItemStruct { next: Some(ItemStruct { next: None }) }", + ); + + assert_eq!( + format!("{:?}", ItemTuple(Some(Box::new(ItemTuple(None))))), + "ItemTuple(Some(ItemTuple(None)))", + ); + + assert_eq!( + format!( + "{:?}", + ItemTupleContainerFmt(Some(Box::new(ItemTupleContainerFmt(None)))), + ), + "Item(Some(Item(None)))", + ); + + let item = ItemEnum::Node { + children: vec![ + ItemEnum::Node { + children: vec![], + inner: 0, + }, + ItemEnum::Leaf { inner: 1 }, + ], + inner: 2, + }; + assert_eq!( + format!("{item:?}"), + "Node { children: [Node { children: [], inner: 0 }, Leaf { inner: 1 }], inner: 2 }", + ) + } +} diff --git a/tests/sum.rs b/tests/sum.rs index a273bf80..503d6b4f 100644 --- a/tests/sum.rs +++ b/tests/sum.rs @@ -1,4 +1,5 @@ #![cfg_attr(not(feature = "std"), no_std)] +#![allow(dead_code)] // some code is tested for type checking only use derive_more::Sum;