Skip to content

Commit

Permalink
Only put Display-like bounds on type variables
Browse files Browse the repository at this point in the history
Resolves #363

Related to #371

Requires #377

This makes sure we don't add unnecessary bounds for Display-like
derives. It does so in the same way as #371 did for the Debug derive: By
only adding bounds when the type contains a type variable.
  • Loading branch information
JelteF committed Jul 20, 2024
1 parent 67fa12b commit 4f33694
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 8 deletions.
144 changes: 136 additions & 8 deletions impl/src/fmt/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenSt
let trait_ident = format_ident!("{trait_name}");
let ident = &input.ident;

let ctx = (&attrs, ident, &trait_ident, &attr_name);
let type_params: Vec<_> = input
.generics
.params
.iter()
.filter_map(|p| match p {
syn::GenericParam::Type(t) => Some(&t.ident),
syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None,
})
.collect();

let ctx: ExpansionCtx = (&attrs, &type_params, ident, &trait_ident, &attr_name);
let (bounds, body) = match &input.data {
syn::Data::Struct(s) => expand_struct(s, ctx),
syn::Data::Enum(e) => expand_enum(e, ctx),
Expand Down Expand Up @@ -62,13 +72,15 @@ pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenSt

/// Type alias for an expansion context:
/// - [`ContainerAttributes`].
/// - Type parameters. Slice of [`syn::Ident`].
/// - Struct/enum/union [`syn::Ident`].
/// - Derived trait [`syn::Ident`].
/// - Attribute name [`syn::Ident`].
///
/// [`syn::Ident`]: struct@syn::Ident
type ExpansionCtx<'a> = (
&'a ContainerAttributes,
&'a [&'a syn::Ident],
&'a syn::Ident,
&'a syn::Ident,
&'a syn::Ident,
Expand All @@ -77,12 +89,13 @@ type ExpansionCtx<'a> = (
/// Expands a [`fmt::Display`]-like derive macro for the provided struct.
fn expand_struct(
s: &syn::DataStruct,
(attrs, ident, trait_ident, _): ExpansionCtx<'_>,
(attrs, type_params, ident, trait_ident, _): ExpansionCtx<'_>,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
let s = Expansion {
shared_format: None,
attrs,
fields: &s.fields,
type_params,
trait_ident,
ident,
};
Expand Down Expand Up @@ -111,7 +124,8 @@ fn expand_struct(
/// Expands a [`fmt`]-like derive macro for the provided enum.
fn expand_enum(
e: &syn::DataEnum,
(shared_attrs, _, trait_ident, attr_name): ExpansionCtx<'_>,

(shared_attrs, type_params, _, trait_ident, attr_name): ExpansionCtx<'_>,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
let (bounds, match_arms) = e.variants.iter().try_fold(
(Vec::new(), TokenStream::new()),
Expand All @@ -138,6 +152,7 @@ fn expand_enum(
shared_format: shared_attrs.fmt.as_ref(),
attrs: &attrs,
fields: &variant.fields,
type_params,
trait_ident,
ident,
};
Expand Down Expand Up @@ -175,7 +190,7 @@ fn expand_enum(
/// Expands a [`fmt::Display`]-like derive macro for the provided union.
fn expand_union(
u: &syn::DataUnion,
(attrs, _, _, attr_name): ExpansionCtx<'_>,
(attrs, _, _, _, attr_name): ExpansionCtx<'_>,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
let fmt = &attrs.fmt.as_ref().ok_or_else(|| {
syn::Error::new(
Expand Down Expand Up @@ -210,6 +225,9 @@ struct Expansion<'a> {
/// Struct or enum [`syn::Fields`].
fields: &'a syn::Fields,

/// Type parameters in this struct or enum.
type_params: &'a [&'a syn::Ident],

/// [`fmt`] trait [`syn::Ident`].
///
/// [`syn::Ident`]: struct@syn::Ident
Expand Down Expand Up @@ -352,10 +370,13 @@ impl<'a> Expansion<'a> {
if let Some(shared_format) = self.shared_format {
let shared_bounds = shared_format
.bounded_types(self.fields)
.map(|(ty, trait_name)| {
.filter_map(|(ty, trait_name)| {
if !self.contains_generic_param(ty) {
return None;
}
let trait_ident = format_ident!("{trait_name}");

parse_quote! { #ty: derive_more::core::fmt::#trait_ident }
Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
})
.chain(self.attrs.bounds.0.clone())
.collect();
Expand All @@ -379,6 +400,9 @@ impl<'a> Expansion<'a> {
.next()
.map(|f| {
let ty = &f.ty;
if !self.contains_generic_param(ty) {
return vec![];
}
let trait_ident = &self.trait_ident;
vec![parse_quote! { #ty: derive_more::core::fmt::#trait_ident }]
})
Expand All @@ -389,15 +413,119 @@ impl<'a> Expansion<'a> {

bounds.extend(
fmt.bounded_types(self.fields)
.map(|(ty, trait_name)| {
.filter_map(|(ty, trait_name)| {
if !self.contains_generic_param(ty) {
return None;
}
let trait_ident = format_ident!("{trait_name}");

parse_quote! { #ty: derive_more::core::fmt::#trait_ident }
Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
})
.chain(self.attrs.bounds.0.clone()),
);
bounds
}

/// Checks whether the provided [`syn::Path`] contains any of these [`Expansion::type_params`].
fn path_contains_generic_param(&self, path: &syn::Path) -> bool {
path.segments
.iter()
.any(|segment| match &segment.arguments {
syn::PathArguments::None => false,
syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments { args, .. },
) => args.iter().any(|generic| match generic {
syn::GenericArgument::Type(ty)
| syn::GenericArgument::AssocType(syn::AssocType { ty, .. }) => {
self.contains_generic_param(ty)
}

syn::GenericArgument::Lifetime(_)
| syn::GenericArgument::Const(_)
| syn::GenericArgument::AssocConst(_)
| syn::GenericArgument::Constraint(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
}),
syn::PathArguments::Parenthesized(
syn::ParenthesizedGenericArguments { inputs, output, .. },
) => {
inputs.iter().any(|ty| self.contains_generic_param(ty))
|| match output {
syn::ReturnType::Default => false,
syn::ReturnType::Type(_, ty) => {
self.contains_generic_param(ty)
}
}
}
})
}

/// Checks whether the provided [`syn::Type`] contains any of these [`Expansion::type_params`].
fn contains_generic_param(&self, ty: &syn::Type) -> bool {
if self.type_params.is_empty() {
return false;
}
match ty {
syn::Type::Path(syn::TypePath { qself, path }) => {
if let Some(qself) = qself {
if self.contains_generic_param(&qself.ty) {
return true;
}
}

if let Some(ident) = path.get_ident() {
self.type_params.iter().any(|param| *param == ident)
} else {
self.path_contains_generic_param(path)
}
}

syn::Type::Array(syn::TypeArray { elem, .. })
| syn::Type::Group(syn::TypeGroup { elem, .. })
| syn::Type::Paren(syn::TypeParen { elem, .. })
| syn::Type::Ptr(syn::TypePtr { elem, .. })
| syn::Type::Reference(syn::TypeReference { elem, .. })
| syn::Type::Slice(syn::TypeSlice { elem, .. }) => {
self.contains_generic_param(elem)
}

syn::Type::BareFn(syn::TypeBareFn { inputs, output, .. }) => {
inputs
.iter()
.any(|arg| self.contains_generic_param(&arg.ty))
|| match output {
syn::ReturnType::Default => false,
syn::ReturnType::Type(_, ty) => self.contains_generic_param(ty),
}
}
syn::Type::Tuple(syn::TypeTuple { elems, .. }) => {
elems.iter().any(|ty| self.contains_generic_param(ty))
}

syn::Type::ImplTrait(_) => false,
syn::Type::Infer(_) => false,
syn::Type::Macro(_) => false,
syn::Type::Never(_) => false,
syn::Type::TraitObject(syn::TypeTraitObject { bounds, .. }) => {
bounds.iter().any(|bound| match bound {
syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => {
self.path_contains_generic_param(path)
}
syn::TypeParamBound::Lifetime(_) => false,
syn::TypeParamBound::Verbatim(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
})
}
syn::Type::Verbatim(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
}
}
}

/// Matches the provided derive macro `name` to appropriate actual trait name.
Expand Down
Loading

0 comments on commit 4f33694

Please sign in to comment.