diff --git a/Cargo.toml b/Cargo.toml index 458267db..c921628a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = "1.0.3" +syn = "1.0.81" convert_case = { version = "0.4", optional = true} [build-dependencies] @@ -51,7 +51,7 @@ deref_mut = [] display = ["syn/extra-traits"] error = ["syn/extra-traits"] from = ["syn/extra-traits"] -from_str = [] +from_str = ["convert_case"] index = [] index_mut = [] into = ["syn/extra-traits"] diff --git a/doc/from_str.md b/doc/from_str.md index 3da68ca8..c5b71508 100644 --- a/doc/from_str.md +++ b/doc/from_str.md @@ -1,6 +1,7 @@ % What #[derive(FromStr)] generates -Deriving `FromStr` only works for newtypes, i.e structs with only a single +Deriving `FromStr` only works for enums with no fields +or newtypes, i.e structs with only a single field. The result is that you will be able to call the `parse()` method on a string to convert it to your newtype. This only works when the type that is contained in the type implements `FromStr`. @@ -77,4 +78,58 @@ impl ::core::str::FromStr for Point1D { # Enums -Deriving `FromStr` is not supported for enums. +When deriving `FromStr` for an enums with variants with no fields it will +generate a `from_str` method that converts strings that match the variant name +to the variant. If using a case insensitive match would give a unique variant +(i.e you dont have both a `MyEnum::Foo` and a `MyEnum::foo` variant) then case +insensitve matching will be used, otherwise it will fall back to exact string +matchng. + +Since the string may not match any vairants an error type is needed so one +will be generated of the format `Parse{}Error` + +e.g. Given the following enum: + +```rust +# #[macro_use] extern crate derive_more; +# fn main(){} +#[derive(FromStr)] +enum EnumNoFields { + Foo, + Bar, + Baz, +} +``` + +Code like this will be generated: + +```rust +# enum EnumNoFields { +# Foo, +# Bar, +# Baz, +# } + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ParseEnumNoFieldsError; + +impl std::fmt::Display for ParseEnumNoFieldsError { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt.write_str("invalid enum no fields") + } +} + +impl std::error::Error for ParseEnumNoFieldsError {} + +impl ::core::str::FromStr for EnumNoFields { + type Err = ParseEnumNoFieldsError; + fn from_str(src: &str) -> Result { + Ok(match src.to_lowercase().as_str() { + "foo" => EnumNoFields::Foo, + "bar" => EnumNoFields::Bar, + "baz" => EnumNoFields::Baz, + _ => return Err(ParseEnumNoFieldsError{}), + }) + } +} +``` diff --git a/src/from_str.rs b/src/from_str.rs index 5f26fdd1..589b9a66 100644 --- a/src/from_str.rs +++ b/src/from_str.rs @@ -1,6 +1,8 @@ +use crate::utils::{DeriveType, HashMap}; use crate::utils::{SingleFieldData, State}; +use convert_case::{Case, Casing}; use proc_macro2::TokenStream; -use quote::quote; +use quote::{format_ident, quote}; use syn::{parse::Result, DeriveInput}; /// Provides the hook to expand `#[derive(FromStr)]` into an implementation of `FromStr` @@ -12,6 +14,14 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result TokenStream { // We cannot set defaults for fields, once we do we can remove this check if state.fields.len() != 1 || state.enabled_fields().len() != 1 { panic_one_field(trait_name); @@ -32,7 +42,7 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result::Err; @@ -41,7 +51,81 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result TokenStream { + let mut variants_caseinsensitive = HashMap::default(); + for variant_state in state.enabled_variant_data().variant_states { + let variant = variant_state.variant.unwrap(); + if !variant.fields.is_empty() { + panic!("Only enums with no fields can derive({})", trait_name) + } + + variants_caseinsensitive + .entry(variant.ident.to_string().to_lowercase()) + .or_insert_with(Vec::new) + .push(variant.ident.clone()); + } + + let input_type = &input.ident; + let visibility = &input.vis; + + let err_name = format_ident!("Parse{}Error", input_type); + let err_message = + format!("invalid {}", input_type.to_string().to_case(Case::Lower)); + + let mut cases = vec![]; + + // if a case insensitve match is unique match do that + // otherwise do a case sensitive match + for (ref canonical, ref variants) in variants_caseinsensitive { + if variants.len() == 1 { + let variant = &variants[0]; + cases.push(quote! { + #canonical => #input_type::#variant, + }) + } else { + for variant in variants { + let variant_str = variant.to_string(); + cases.push(quote! { + #canonical if(src == #variant_str) => #input_type::#variant, + }) + } + } + } + + let trait_path = state.trait_path; + + quote! { + + #[derive(Debug, Clone, PartialEq, Eq)] + #visibility struct #err_name; + + impl std::fmt::Display for #err_name { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt.write_str(#err_message) + } + } + + impl std::error::Error for #err_name {} + + impl #trait_path for #input_type + { + type Err = #err_name; + #[inline] + fn from_str(src: &str) -> ::core::result::Result { + Ok(match src.to_lowercase().as_str() { + #(#cases)* + _ => return Err(#err_name{}), + }) + } + } + } } fn panic_one_field(trait_name: &str) -> ! { diff --git a/tests/from_str.rs b/tests/from_str.rs index 4bb0eca6..8ecc3835 100644 --- a/tests/from_str.rs +++ b/tests/from_str.rs @@ -9,3 +9,25 @@ struct MyInt(i32); struct Point1D { x: i32, } + +#[derive(Debug, FromStr, PartialEq, Eq)] +enum EnumNoFields { + Foo, + Bar, + Baz, +} + +#[test] +fn enum_test() { + assert_eq!("Foo".parse::().unwrap(), EnumNoFields::Foo); + assert_eq!("FOO".parse::().unwrap(), EnumNoFields::Foo); + assert_eq!("foo".parse::().unwrap(), EnumNoFields::Foo); + assert_eq!( + "other".parse::().unwrap_err(), + ParseEnumNoFieldsError {} + ); + assert_eq!( + ParseEnumNoFieldsError {}.to_string(), + "invalid enum no fields" + ); +}