Skip to content

Commit

Permalink
Add support for #derive[FromStr] for enums
Browse files Browse the repository at this point in the history
This allows deriving FromStr for enums with no fields in the obvious way

fixes JelteF#59
  • Loading branch information
aj-bagwell committed Dec 3, 2021
1 parent 718fbb6 commit c4190a8
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = "1.0.3"
syn = "1.0.81"
convert_case = { version = "0.4", optional = true}

[build-dependencies]
Expand All @@ -51,7 +51,7 @@ deref_mut = []
display = ["syn/extra-traits"]
error = ["syn/extra-traits"]
from = ["syn/extra-traits"]
from_str = []
from_str = ["convert_case"]
index = []
index_mut = []
into = ["syn/extra-traits"]
Expand Down
59 changes: 57 additions & 2 deletions doc/from_str.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
% What #[derive(FromStr)] generates

Deriving `FromStr` only works for newtypes, i.e structs with only a single
Deriving `FromStr` only works for enums with no fields
or newtypes, i.e structs with only a single
field. The result is that you will be able to call the `parse()` method on a
string to convert it to your newtype. This only works when the type that is
contained in the type implements `FromStr`.
Expand Down Expand Up @@ -77,4 +78,58 @@ impl ::core::str::FromStr for Point1D {

# Enums

Deriving `FromStr` is not supported for enums.
When deriving `FromStr` for an enums with variants with no fields it will
generate a `from_str` method that converts strings that match the variant name
to the variant. If using a case insensitive match would give a unique variant
(i.e you dont have both a `MyEnum::Foo` and a `MyEnum::foo` variant) then case
insensitve matching will be used, otherwise it will fall back to exact string
matchng.

Since the string may not match any vairants an error type is needed so one
will be generated of the format `Parse{}Error`

e.g. Given the following enum:

```rust
# #[macro_use] extern crate derive_more;
# fn main(){}
#[derive(FromStr)]
enum EnumNoFields {
Foo,
Bar,
Baz,
}
```

Code like this will be generated:

```rust
# enum EnumNoFields {
# Foo,
# Bar,
# Baz,
# }

#[derive(Debug, Clone, PartialEq, Eq)]
struct ParseEnumNoFieldsError;

impl std::fmt::Display for ParseEnumNoFieldsError {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fmt.write_str("invalid enum no fields")
}
}

impl std::error::Error for ParseEnumNoFieldsError {}

impl ::core::str::FromStr for EnumNoFields {
type Err = ParseEnumNoFieldsError;
fn from_str(src: &str) -> Result<Self, Self::Err> {
Ok(match src.to_lowercase().as_str() {
"foo" => EnumNoFields::Foo,
"bar" => EnumNoFields::Bar,
"baz" => EnumNoFields::Baz,
_ => return Err(ParseEnumNoFieldsError{}),
})
}
}
```
90 changes: 87 additions & 3 deletions src/from_str.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::utils::{DeriveType, HashMap};
use crate::utils::{SingleFieldData, State};
use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::quote;
use quote::{format_ident, quote};
use syn::{parse::Result, DeriveInput};

/// Provides the hook to expand `#[derive(FromStr)]` into an implementation of `FromStr`
Expand All @@ -12,6 +14,14 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
trait_name.to_lowercase(),
)?;

if state.derive_type == DeriveType::Enum {
Ok(enum_from(input, state, trait_name))
} else {
Ok(struct_from(&state, trait_name))
}
}

pub fn struct_from(state: &State, trait_name: &'static str) -> TokenStream {
// We cannot set defaults for fields, once we do we can remove this check
if state.fields.len() != 1 || state.enabled_fields().len() != 1 {
panic_one_field(trait_name);
Expand All @@ -32,7 +42,7 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
let initializers = [quote!(#casted_trait::from_str(src)?)];
let body = single_field_data.initializer(&initializers);

Ok(quote! {
quote! {
impl#impl_generics #trait_path for #input_type#ty_generics #where_clause
{
type Err = <#field_type as #trait_path>::Err;
Expand All @@ -41,7 +51,81 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
Ok(#body)
}
}
})
}
}

fn enum_from(
input: &DeriveInput,
state: State,
trait_name: &'static str,
) -> TokenStream {
let mut variants_caseinsensitive = HashMap::default();
for variant_state in state.enabled_variant_data().variant_states {
let variant = variant_state.variant.unwrap();
if !variant.fields.is_empty() {
panic!("Only enums with no fields can derive({})", trait_name)
}

variants_caseinsensitive
.entry(variant.ident.to_string().to_lowercase())
.or_insert_with(Vec::new)
.push(variant.ident.clone());
}

let input_type = &input.ident;
let visibility = &input.vis;

let err_name = format_ident!("Parse{}Error", input_type);
let err_message =
format!("invalid {}", input_type.to_string().to_case(Case::Lower));

let mut cases = vec![];

// if a case insensitve match is unique match do that
// otherwise do a case sensitive match
for (ref canonical, ref variants) in variants_caseinsensitive {
if variants.len() == 1 {
let variant = &variants[0];
cases.push(quote! {
#canonical => #input_type::#variant,
})
} else {
for variant in variants {
let variant_str = variant.to_string();
cases.push(quote! {
#canonical if(src == #variant_str) => #input_type::#variant,
})
}
}
}

let trait_path = state.trait_path;

quote! {

#[derive(Debug, Clone, PartialEq, Eq)]
#visibility struct #err_name;

impl std::fmt::Display for #err_name {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fmt.write_str(#err_message)
}
}

impl std::error::Error for #err_name {}

impl #trait_path for #input_type
{
type Err = #err_name;
#[inline]
fn from_str(src: &str) -> ::core::result::Result<Self, Self::Err> {
Ok(match src.to_lowercase().as_str() {
#(#cases)*
_ => return Err(#err_name{}),
})
}
}
}
}

fn panic_one_field(trait_name: &str) -> ! {
Expand Down
22 changes: 22 additions & 0 deletions tests/from_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,25 @@ struct MyInt(i32);
struct Point1D {
x: i32,
}

#[derive(Debug, FromStr, PartialEq, Eq)]
enum EnumNoFields {
Foo,
Bar,
Baz,
}

#[test]
fn enum_test() {
assert_eq!("Foo".parse::<EnumNoFields>().unwrap(), EnumNoFields::Foo);
assert_eq!("FOO".parse::<EnumNoFields>().unwrap(), EnumNoFields::Foo);
assert_eq!("foo".parse::<EnumNoFields>().unwrap(), EnumNoFields::Foo);
assert_eq!(
"other".parse::<EnumNoFields>().unwrap_err(),
ParseEnumNoFieldsError {}
);
assert_eq!(
ParseEnumNoFieldsError {}.to_string(),
"invalid enum no fields"
);
}

0 comments on commit c4190a8

Please sign in to comment.