Skip to content

Commit

Permalink
Clear constant values generated code, defined behavior with the alter…
Browse files Browse the repository at this point in the history
…native variants, add error messages. (#89)
  • Loading branch information
GilShoshan94 authored Jan 9, 2023
1 parent c26e9a2 commit 4b1985c
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 25 deletions.
4 changes: 2 additions & 2 deletions num_enum/tests/renamed_num_enum.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[test]
fn no_std() {
assert!(::std::process::Command::new("cargo")
.args(&[
.args([
"run",
"--manifest-path",
concat!(
Expand All @@ -17,7 +17,7 @@ fn no_std() {
#[test]
fn std() {
assert!(::std::process::Command::new("cargo")
.args(&[
.args([
"run",
"--manifest-path",
concat!(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: The discriminant '2' collides with a value attributed to a previous variant
--> tests/try_build/compile_fail/alternative_clashes_with_variant.rs:7:5
|
7 | Two = 2,
| ^^^
12 changes: 0 additions & 12 deletions num_enum/tests/try_build/compile_fail/unreachable_patterns.stderr

This file was deleted.

6 changes: 5 additions & 1 deletion num_enum/tests/try_from_primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ fn alternative_values_and_default_value() {
One = 1,
#[num_enum(alternatives = [3])]
TwoOrThree = 2,
Four = 4,
}

let zero: Result<Enum, _> = 0u8.try_into();
Expand All @@ -336,7 +337,10 @@ fn alternative_values_and_default_value() {
assert_eq!(three, Ok(Enum::TwoOrThree));

let four: Result<Enum, _> = 4u8.try_into();
assert_eq!(four, Ok(Enum::Zero));
assert_eq!(four, Ok(Enum::Four));

let five: Result<Enum, _> = 5u8.try_into();
assert_eq!(five, Ok(Enum::Zero));
}

#[test]
Expand Down
92 changes: 82 additions & 10 deletions num_enum_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
extern crate proc_macro;

use ::proc_macro::TokenStream;
use ::proc_macro2::Span;
use ::quote::{format_ident, quote};
use ::syn::{
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use std::collections::BTreeSet;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
Data, DeriveInput, Error, Expr, Fields, Ident, LitInt, LitStr, Meta, Result,
Attribute, Data, DeriveInput, Error, Expr, Fields, Ident, Lit, LitInt, LitStr, Meta, Result,
};

macro_rules! die {
Expand All @@ -24,13 +25,23 @@ macro_rules! die {
};
}

fn literal(i: u64) -> Expr {
fn literal(i: i128) -> Expr {
let literal = LitInt::new(&i.to_string(), Span::call_site());
parse_quote! {
#literal
}
}

fn expr_to_int(val_exp: &Expr) -> Result<i128> {
Ok(match val_exp {
Expr::Lit(ref val_exp_lit) => match val_exp_lit.lit {
Lit::Int(ref lit_int) => lit_int.base10_parse()?,
_ => die!(val_exp => "Expected integer"),
},
_ => die!(val_exp => "Expected literal"),
})
}

mod kw {
syn::custom_keyword!(default);
syn::custom_keyword!(catch_all);
Expand Down Expand Up @@ -278,6 +289,9 @@ impl Parse for EnumInfo {
let mut has_default_variant: bool = false;
let mut has_catch_all_variant: bool = false;

// Vec to keep track of the used discriminants and alt values.
let mut val_set: BTreeSet<i128> = BTreeSet::new();

let mut next_discriminant = literal(0);
for variant in data.variants.into_iter() {
let ident = variant.ident.clone();
Expand All @@ -289,6 +303,8 @@ impl Parse for EnumInfo {

let mut attr_spans: AttributeSpans = Default::default();
let mut alternative_values: Vec<Expr> = vec![];
// Keep the attribute around for better error reporting.
let mut alt_attr_ref: Vec<&Attribute> = vec![];

// `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
// and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
Expand Down Expand Up @@ -366,6 +382,7 @@ impl Parse for EnumInfo {
NumEnumVariantAttributeItem::Alternatives(alternatives) => {
attr_spans.alternatives.push(alternatives.span());
alternative_values.extend(alternatives.expressions);
alt_attr_ref.push(attribute);
}
}
}
Expand All @@ -388,7 +405,63 @@ impl Parse for EnumInfo {
}
}

let canonical_value = discriminant.clone();
let canonical_value = discriminant;
let canonical_value_int = expr_to_int(&canonical_value)?;

// Check for collision.
if val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
}

// Deal with the alternative values.
let alt_val = alternative_values
.iter()
.map(expr_to_int)
.collect::<Result<Vec<_>>>()?;

debug_assert_eq!(alt_val.len(), alternative_values.len());

if !alt_val.is_empty() {
let mut alt_val_sorted = alt_val.clone();
alt_val_sorted.sort_unstable();
let alt_val_sorted = alt_val_sorted;

// check if the current discriminant is not in the alternative values.
if let Some(i) = alt_val.iter().position(|&x| x == canonical_value_int) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
}

// Search for duplicates, the vec is sorted. Warn about them.
if (1..alt_val_sorted.len()).any(|i| alt_val_sorted[i] == alt_val_sorted[i - 1])
{
let attr = *alt_attr_ref.last().unwrap();
die!(attr => "There is duplication in the alternative values");
}
// Search if those alt_val where already attributed.
// (The val_set is BTreeSet, and last() is the is the maximum in the set.)
if let Some(last_upper_val) = val_set.last() {
if alt_val_sorted.first().unwrap() <= last_upper_val {
for (i, val) in alt_val_sorted.iter().enumerate() {
if val_set.contains(val) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
}
}
}
}

// Reconstruct the alternative_values vec of Expr but sorted.
alternative_values = alt_val_sorted
.iter()
.map(|val| literal(val.to_owned()))
.collect();

// Add the alternative values to the the set to keep track.
val_set.extend(alt_val_sorted);
}

// Add the current discriminant to the the set to keep track.
let newly_inserted = val_set.insert(canonical_value_int);
debug_assert!(newly_inserted);

variants.push(VariantInfo {
ident,
Expand All @@ -399,9 +472,8 @@ impl Parse for EnumInfo {
alternative_values,
});

next_discriminant = parse_quote! {
#repr::wrapping_add(#discriminant, 1)
};
// Get the next value for the discriminant.
next_discriminant = literal(canonical_value_int + 1);
}

EnumInfo {
Expand Down

0 comments on commit 4b1985c

Please sign in to comment.