diff --git a/strum/src/additional_attributes.rs b/strum/src/additional_attributes.rs index 0b32a25a..e3d86ebd 100644 --- a/strum/src/additional_attributes.rs +++ b/strum/src/additional_attributes.rs @@ -34,6 +34,9 @@ //! ); //! ``` //! +//! You can also apply the `#[strum(ascii_case_insensitive)]` attribute to the enum, +//! and this has the same effect of applying it to every variant. +//! //! Custom attributes are applied to a variant by adding `#[strum(parameter="value")]` to the variant. //! //! - `serialize="..."`: Changes the text that `FromStr()` looks for when parsing a string. This attribute can @@ -58,6 +61,10 @@ //! //! - `disabled`: removes variant from generated code. //! +//! - `ascii_case_insensitive`: makes the comparison to this variant case insensitive (ASCII only). +//! If the whole enum is marked `ascii_case_insensitive`, you can specify `ascii_case_insensitive = false` +//! to disable case insensitivity on this variant. +//! //! - `message=".."`: Adds a message to enum variant. This is used in conjunction with the `EnumMessage` //! trait to associate a message with a variant. If `detailed_message` is not provided, //! then `message` will also be returned when get_detailed_message() is called. diff --git a/strum_macros/src/helpers/metadata.rs b/strum_macros/src/helpers/metadata.rs index e8b35588..3c676cb2 100644 --- a/strum_macros/src/helpers/metadata.rs +++ b/strum_macros/src/helpers/metadata.rs @@ -4,7 +4,7 @@ use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, spanned::Spanned, - Attribute, DeriveInput, Ident, LitStr, Path, Token, Variant, Visibility, + Attribute, DeriveInput, Ident, LitBool, LitStr, Path, Token, Variant, Visibility, }; use super::case_style::CaseStyle; @@ -28,6 +28,7 @@ pub mod kw { custom_keyword!(disabled); custom_keyword!(default); custom_keyword!(props); + custom_keyword!(ascii_case_insensitive); } pub enum EnumMeta { @@ -35,14 +36,23 @@ pub enum EnumMeta { kw: kw::serialize_all, case_style: CaseStyle, }, + AsciiCaseInsensitive(kw::ascii_case_insensitive), } impl Parse for EnumMeta { fn parse(input: ParseStream) -> syn::Result { - let kw = input.parse::()?; - input.parse::()?; - let case_style = input.parse()?; - Ok(EnumMeta::SerializeAll { kw, case_style }) + let lookahead = input.lookahead1(); + if lookahead.peek(kw::serialize_all) { + let kw = input.parse::()?; + input.parse::()?; + let case_style = input.parse()?; + Ok(EnumMeta::SerializeAll { kw, case_style }) + } else if lookahead.peek(kw::ascii_case_insensitive) { + let kw = input.parse()?; + Ok(EnumMeta::AsciiCaseInsensitive(kw)) + } else { + Err(lookahead.error()) + } } } @@ -50,6 +60,7 @@ impl Spanned for EnumMeta { fn span(&self) -> Span { match self { EnumMeta::SerializeAll { kw, .. } => kw.span(), + EnumMeta::AsciiCaseInsensitive(kw) => kw.span(), } } } @@ -142,6 +153,10 @@ pub enum VariantMeta { }, Disabled(kw::disabled), Default(kw::default), + AsciiCaseInsensitive { + kw: kw::ascii_case_insensitive, + value: bool, + }, Props { kw: kw::props, props: Vec<(LitStr, LitStr)>, @@ -175,6 +190,15 @@ impl Parse for VariantMeta { Ok(VariantMeta::Disabled(input.parse()?)) } else if lookahead.peek(kw::default) { Ok(VariantMeta::Default(input.parse()?)) + } else if lookahead.peek(kw::ascii_case_insensitive) { + let kw = input.parse()?; + let value = if input.peek(Token![=]) { + let _: Token![=] = input.parse()?; + input.parse::()?.value() + } else { + true + }; + Ok(VariantMeta::AsciiCaseInsensitive { kw, value }) } else if lookahead.peek(kw::props) { let kw = input.parse()?; let content; @@ -216,6 +240,7 @@ impl Spanned for VariantMeta { VariantMeta::ToString { kw, .. } => kw.span, VariantMeta::Disabled(kw) => kw.span, VariantMeta::Default(kw) => kw.span, + VariantMeta::AsciiCaseInsensitive { kw, .. } => kw.span, VariantMeta::Props { kw, .. } => kw.span, } } diff --git a/strum_macros/src/helpers/type_props.rs b/strum_macros/src/helpers/type_props.rs index 1e9c1e47..1e0a42c8 100644 --- a/strum_macros/src/helpers/type_props.rs +++ b/strum_macros/src/helpers/type_props.rs @@ -14,6 +14,7 @@ pub trait HasTypeProperties { #[derive(Debug, Clone, Default)] pub struct StrumTypeProperties { pub case_style: Option, + pub ascii_case_insensitive: bool, pub discriminant_derives: Vec, pub discriminant_name: Option, pub discriminant_others: Vec, @@ -28,6 +29,7 @@ impl HasTypeProperties for DeriveInput { let discriminants_meta = self.get_discriminants_metadata()?; let mut serialize_all_kw = None; + let mut ascii_case_insensitive_kw = None; for meta in strum_meta { match meta { EnumMeta::SerializeAll { case_style, kw } => { @@ -38,6 +40,14 @@ impl HasTypeProperties for DeriveInput { serialize_all_kw = Some(kw); output.case_style = Some(case_style); } + EnumMeta::AsciiCaseInsensitive(kw) => { + if let Some(fst_kw) = ascii_case_insensitive_kw { + return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive")); + } + + ascii_case_insensitive_kw = Some(kw); + output.ascii_case_insensitive = true; + } } } diff --git a/strum_macros/src/helpers/variant_props.rs b/strum_macros/src/helpers/variant_props.rs index 1077bcb2..a7e94862 100644 --- a/strum_macros/src/helpers/variant_props.rs +++ b/strum_macros/src/helpers/variant_props.rs @@ -13,6 +13,7 @@ pub trait HasStrumVariantProperties { pub struct StrumVariantProperties { pub disabled: Option, pub default: Option, + pub ascii_case_insensitive: Option, pub message: Option, pub detailed_message: Option, pub string_props: Vec<(LitStr, LitStr)>, @@ -65,6 +66,7 @@ impl HasStrumVariantProperties for Variant { let mut to_string_kw = None; let mut disabled_kw = None; let mut default_kw = None; + let mut ascii_case_insensitive_kw = None; for meta in self.get_metadata()? { match meta { VariantMeta::Message { value, kw } => { @@ -110,6 +112,14 @@ impl HasStrumVariantProperties for Variant { default_kw = Some(kw); output.default = Some(kw); } + VariantMeta::AsciiCaseInsensitive { kw, value } => { + if let Some(fst_kw) = ascii_case_insensitive_kw { + return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive")); + } + + ascii_case_insensitive_kw = Some(kw); + output.ascii_case_insensitive = Some(value); + } VariantMeta::Props { props, .. } => { output.string_props.extend(props); } diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs index 4bd6fb12..df06a5d8 100644 --- a/strum_macros/src/lib.rs +++ b/strum_macros/src/lib.rs @@ -66,6 +66,10 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) { /// // Notice that we can disable certain variants from being found /// #[strum(disabled)] /// Yellow, +/// +/// // We can make the comparison case insensitive (however Unicode is not supported at the moment) +/// #[strum(ascii_case_insensitive)] +/// Black, /// } /// /// /* @@ -77,7 +81,9 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) { /// match s { /// "Red" => ::std::result::Result::Ok(Color::Red), /// "Green" => ::std::result::Result::Ok(Color::Green { range:Default::default() }), -/// "blue" | "b" => ::std::result::Result::Ok(Color::Blue(Default::default())), +/// "blue" => ::std::result::Result::Ok(Color::Blue(Default::default())), +/// "b" => ::std::result::Result::Ok(Color::Blue(Default::default())), +/// s if s.eq_ignore_ascii_case("Black") => ::std::result::Result::Ok(Color::Black), /// _ => ::std::result::Result::Err(::strum::ParseError::VariantNotFound), /// } /// } @@ -95,6 +101,8 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) { /// assert!(color_variant.is_err()); /// // however the variant is still normally usable /// println!("{:?}", Color::Yellow); +/// let color_variant = Color::from_str("bLACk").unwrap(); +/// assert_eq!(Color::Black, color_variant); /// ``` #[proc_macro_derive(EnumString, attributes(strum))] pub fn from_string(input: proc_macro::TokenStream) -> proc_macro::TokenStream { diff --git a/strum_macros/src/macros/strings/from_string.rs b/strum_macros/src/macros/strings/from_string.rs index 060bfdcb..bf17d8ec 100644 --- a/strum_macros/src/macros/strings/from_string.rs +++ b/strum_macros/src/macros/strings/from_string.rs @@ -50,8 +50,20 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { continue; } + let is_ascii_case_insensitive = variant_properties + .ascii_case_insensitive + .unwrap_or(type_properties.ascii_case_insensitive); // If we don't have any custom variants, add the default serialized name. - let attrs = variant_properties.get_serializations(type_properties.case_style); + let attrs = variant_properties + .get_serializations(type_properties.case_style) + .into_iter() + .map(|serialization| { + if is_ascii_case_insensitive { + quote! { s if s.eq_ignore_ascii_case(#serialization) } + } else { + quote! { #serialization } + } + }); let params = match &variant.fields { Fields::Unit => quote! {}, @@ -69,7 +81,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { } }; - arms.push(quote! { #(#attrs)|* => ::std::result::Result::Ok(#name::#ident #params) }); + arms.push(quote! { #(#attrs => ::std::result::Result::Ok(#name::#ident #params)),* }); } arms.push(default); diff --git a/strum_tests/tests/from_str.rs b/strum_tests/tests/from_str.rs index 0c5782b8..77922e9d 100644 --- a/strum_tests/tests/from_str.rs +++ b/strum_tests/tests/from_str.rs @@ -13,6 +13,8 @@ enum Color { Green(String), #[strum(to_string = "purp")] Purple, + #[strum(serialize = "blk", serialize = "Black", ascii_case_insensitive)] + Black, } #[test] @@ -44,6 +46,12 @@ fn color_default() { ); } +#[test] +fn color_ascii_case_insensitive() { + assert_eq!(Color::Black, Color::from_str("BLK").unwrap()); + assert_eq!(Color::Black, Color::from_str("bLaCk").unwrap()); +} + #[derive(Debug, Eq, PartialEq, EnumString)] #[strum(serialize_all = "snake_case")] enum Brightness { @@ -122,3 +130,42 @@ enum Generic { fn generic_test() { assert_eq!(Generic::Gen(""), Generic::from_str("Gen").unwrap()); } + +#[derive(Debug, Eq, PartialEq, EnumString)] +#[strum(ascii_case_insensitive)] +enum CaseInsensitiveEnum { + NoAttr, + #[strum(ascii_case_insensitive = false)] + NoCaseInsensitive, + #[strum(ascii_case_insensitive = true)] + CaseInsensitive, +} + +#[test] +fn case_insensitive_enum_no_attr() { + assert_eq!( + CaseInsensitiveEnum::NoAttr, + CaseInsensitiveEnum::from_str("noattr").unwrap() + ); +} + +#[test] +fn case_insensitive_enum_no_case_insensitive() { + assert_eq!( + CaseInsensitiveEnum::NoCaseInsensitive, + CaseInsensitiveEnum::from_str("NoCaseInsensitive").unwrap(), + ); + assert!(CaseInsensitiveEnum::from_str("nocaseinsensitive").is_err()); +} + +#[test] +fn case_insensitive_enum_case_insensitive() { + assert_eq!( + CaseInsensitiveEnum::CaseInsensitive, + CaseInsensitiveEnum::from_str("CaseInsensitive").unwrap(), + ); + assert_eq!( + CaseInsensitiveEnum::CaseInsensitive, + CaseInsensitiveEnum::from_str("caseinsensitive").unwrap(), + ); +}