Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

derive: fix handling of generic bounds #178

Merged
merged 2 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 1 addition & 52 deletions book/src/format.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,55 +22,4 @@ enum Request {
}
```

NOTE: for generic structs and enums the `derive` macro adds `Format` bounds to the *types of the generic fields* rather than to all the generic (input) parameters of the struct / enum.
Built-in `derive` attributes like `#[derive(Debug)]` use the latter approach.
To our knowledge `derive(Format)` approach is more accurate in that it doesn't over-constrain the generic type parameters.
The different between the two approaches is depicted below:

``` rust
# extern crate defmt;
# use defmt::Format;

#[derive(Format)]
struct S<'a, T> {
x: Option<&'a T>,
y: u8,
}
```

``` rust
# extern crate defmt;
# use defmt::Format;

// `Format` produces this implementation
impl<'a, T> Format for S<'a, T>
where
Option<&'a T>: Format // <- main difference
{
// ..
# fn format(&self, f: &mut defmt::Formatter) {}
}

#[derive(Debug)]
struct S<'a, T> {
x: Option<&'a T>,
y: u8,
}
```

``` rust
# use std::fmt::Debug;
# struct S<'a, T> {
# x: Option<&'a T>,
# y: u8,
# }

// `Debug` produces this implementation
impl<'a, T> Debug for S<'a, T>
where
T: Debug // <- main difference
{
// ..
# fn fmt(&self, f: &mut core::fmt::Formatter) -> std::fmt::Result { Ok(()) }
}
```
NOTE: Like built-in derives like `#[derive(Debug)]`, `#[derive(Format)]` will add `Format` bounds to the generic type parameters of the struct.
42 changes: 18 additions & 24 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use proc_macro::{Span, TokenStream};
use defmt_parser::Fragment;
use proc_macro2::{Ident as Ident2, Span as Span2, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::GenericParam;
use syn::WhereClause;
use syn::{
parse::{self, Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned as _,
Data, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, ItemFn, ItemStruct, LitInt,
LitStr, ReturnType, Token, Type,
LitStr, ReturnType, Token, Type, WherePredicate,
};

#[proc_macro_attribute]
Expand Down Expand Up @@ -175,7 +177,7 @@ impl MLevel {
// `#[derive(Format)]`
#[proc_macro_derive(Format)]
pub fn format(ts: TokenStream) -> TokenStream {
let input = parse_macro_input!(ts as DeriveInput);
let mut input = parse_macro_input!(ts as DeriveInput);
let span = input.span();

let ident = input.ident;
Expand Down Expand Up @@ -254,30 +256,22 @@ pub fn format(ts: TokenStream) -> TokenStream {
}
}

let params = input.generics.params;
let predicates = if params.is_empty() {
vec![]
} else {
// `Format` bounds for non-native field types
let mut preds = field_types
.into_iter()
.map(|ty| quote!(#ty: defmt::Format))
.collect::<Vec<_>>();
// extend with the where clause from the struct/enum declaration
if let Some(where_clause) = input.generics.where_clause {
preds.extend(
where_clause
.predicates
.into_iter()
.map(|pred| quote!(#pred)),
)
let where_clause = input.generics.make_where_clause();
let mut where_clause: WhereClause = where_clause.clone();
let (impl_generics, type_generics, _) = input.generics.split_for_impl();

// Extend where-clause with `Format` bounds for type parameters.
for param in &input.generics.params {
if let GenericParam::Type(ty) = param {
let ident = &ty.ident;
where_clause
.predicates
.push(syn::parse::<WherePredicate>(quote!(#ident: defmt::Format).into()).unwrap());
}
preds
};
}

quote!(
impl<#params> defmt::Format for #ident<#params>
where #(#predicates),*
{
impl #impl_generics defmt::Format for #ident #type_generics #where_clause {
fn format(&self, f: &mut defmt::Formatter) {
#(#exprs)*
}
Expand Down
39 changes: 39 additions & 0 deletions tests/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,3 +757,42 @@ fn format_slice_enum_generic_struct() {
],
);
}

#[test]
fn derive_with_bounds() {
#[derive(Format)]
struct S<T: Copy> {
val: T,
}

#[derive(Format)]
struct S2<'a: 'b, 'b> {
a: &'a u8,
b: &'b u8,
}

let index = fetch_string_index();
check_format_implementation(
&S { val: 0 },
&[
index, // "S {{ val: {:?} }}"
inc(index, 1), // "{:i32}"
0,
0,
0,
0,
],
);

let index = fetch_string_index();
check_format_implementation(
&S2 { a: &1, b: &2 },
&[
index, // "S2 { a: {:?}, b: {:?} }}"
inc(index, 1), // "{:u8}"
1,
inc(index, 2), // "{:u8}"
2,
],
);
}