Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variants with multiple fields in UDTs #850

Merged
merged 1 commit into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 218 additions & 79 deletions soroban-sdk-macros/src/derive_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use soroban_env_common::Symbol;
use syn::{spanned::Spanned, Attribute, DataEnum, Error, Fields, Ident, Path};

use stellar_xdr::{
ScSpecEntry, ScSpecTypeDef, ScSpecUdtUnionCaseTupleV0, ScSpecUdtUnionCaseV0,
ScSpecUdtUnionCaseVoidV0, ScSpecUdtUnionV0, StringM, WriteXdr,
Error as XdrError, ScSpecEntry, ScSpecTypeDef, ScSpecUdtUnionCaseTupleV0, ScSpecUdtUnionCaseV0,
ScSpecUdtUnionCaseVoidV0, ScSpecUdtUnionV0, StringM, VecM, WriteXdr,
};

use crate::{doc::docs_from_attrs, map_type::map_type};
Expand All @@ -33,28 +33,19 @@ pub fn derive_type_enum(
.iter()
.map(|v| {
// TODO: Choose discriminant type based on repr type of enum.
// TODO: Should we use variants explicit discriminant? Probably not.
// Should have a separate derive for those types of enums that maps
// to an integer type only.
// TODO: Use attributes tagged on variant to control whether field is included.
// TODO: Support multi-field enum variants.
// TODO: Or, error on multi-field enum variants.
// TODO: Handle field names longer than a symbol. Hash the name? Truncate the name?
let ident = &v.ident;
let name = &ident.to_string();
if let Err(e) = Symbol::try_from_str(name) {
errors.push(Error::new(ident.span(), format!("enum variant name {}", e)));
}
if v.fields.len() > 1 {
errors.push(Error::new(v.fields.span(), format!("enum variant name {} has too many tuple values, max 1 supported", ident)));
}
match v.fields {
Fields::Named(_) => {
errors.push(Error::new(v.fields.span(), format!("enum variant {} has unsupported named fields", ident)));
}
_ => {}
};
let field = v.fields.iter().next();
_ => { }
}
let discriminant_const_sym_ident = format_ident!("DISCRIMINANT_SYM_{}", name.to_uppercase());
let discriminant_const_u64_ident = format_ident!("DISCRIMINANT_U64_{}", name.to_uppercase());
let discriminant_const_sym = quote! {
Expand All @@ -67,75 +58,34 @@ pub fn derive_type_enum(
#discriminant_const_sym
#discriminant_const_u64
};
if let Some(f) = field {
let spec_case = ScSpecUdtUnionCaseV0::TupleV0(
ScSpecUdtUnionCaseTupleV0 {
doc: docs_from_attrs(&v.attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
type_: vec![
match map_type(&f.ty) {
Ok(t) => t,
Err(e) => {
errors.push(e);
ScSpecTypeDef::I32
}
},
].try_into().unwrap()
}
let has_fields = v.fields.iter().next().is_some();
if has_fields {
leighmcculloch marked this conversation as resolved.
Show resolved Hide resolved
let VariantTokens {
spec_case, try_from, try_into, try_from_xdr, into_xdr
} = map_tuple_variant(
path,
enum_ident,
&name,
ident,
&v.attrs,
&discriminant_const_sym_ident,
&discriminant_const_u64_ident,
&v.fields,
&mut errors,
);
let try_from = quote! {
#discriminant_const_u64_ident => {
if iter.len() > 1 {
return Err(#path::ConversionError);
}
Self::#ident(iter.next().ok_or(#path::ConversionError)??.try_into_val(env)?)
}
};
let try_into = quote! {
#enum_ident::#ident(ref value) => {
let tup: (#path::RawVal, #path::RawVal) = (#discriminant_const_sym_ident.into(), value.try_into_val(env)?);
tup.try_into_val(env)
}
};
let try_from_xdr = quote! {
#name => {
if iter.len() > 1 {
return Err(#path::xdr::Error::Invalid);
}
let rv: #path::RawVal = iter.next().ok_or(#path::xdr::Error::Invalid)?.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?;
Self::#ident(rv.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?)
}
};
let into_xdr = quote! { #enum_ident::#ident(value) => (#name, value).try_into().map_err(|_| #path::xdr::Error::Invalid)? };
(spec_case, discriminant_const, try_from, try_into, try_from_xdr, into_xdr)
} else {
let spec_case = ScSpecUdtUnionCaseV0::VoidV0(ScSpecUdtUnionCaseVoidV0 {
doc: docs_from_attrs(&v.attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
});
let try_from = quote! {
#discriminant_const_u64_ident => {
if iter.len() > 0 {
return Err(#path::ConversionError);
}
Self::#ident
}
};
let try_into = quote! {
#enum_ident::#ident => {
let tup: (#path::RawVal,) = (#discriminant_const_sym_ident.into(),);
tup.try_into_val(env)
}
};
let try_from_xdr = quote! {
#name => {
if iter.len() > 0 {
return Err(#path::xdr::Error::Invalid);
}
Self::#ident
}
};
let into_xdr = quote! { #enum_ident::#ident => (#name,).try_into().map_err(|_| #path::xdr::Error::Invalid)? };
let VariantTokens {
spec_case, try_from, try_into, try_from_xdr, into_xdr
} = map_empty_variant(
path,
enum_ident,
&name,
ident,
&v.attrs,
&discriminant_const_sym_ident,
&discriminant_const_u64_ident,
);
(spec_case, discriminant_const, try_from, try_into, try_from_xdr, into_xdr)
}
})
Expand Down Expand Up @@ -309,3 +259,192 @@ pub fn derive_type_enum(
}
}
}

struct VariantTokens {
spec_case: ScSpecUdtUnionCaseV0,
try_from: TokenStream2,
try_into: TokenStream2,
try_from_xdr: TokenStream2,
into_xdr: TokenStream2,
}

fn map_empty_variant(
path: &Path,
enum_ident: &Ident,
name: &str,
ident: &Ident,
attrs: &[Attribute],
discriminant_const_sym_ident: &Ident,
discriminant_const_u64_ident: &Ident,
) -> VariantTokens {
let spec_case = ScSpecUdtUnionCaseV0::VoidV0(ScSpecUdtUnionCaseVoidV0 {
doc: docs_from_attrs(attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
});
let try_from = quote! {
#discriminant_const_u64_ident => {
if iter.len() > 0 {
return Err(#path::ConversionError);
}
Self::#ident
}
};
let try_into = quote! {
#enum_ident::#ident => {
let tup: (#path::RawVal,) = (#discriminant_const_sym_ident.into(),);
tup.try_into_val(env)
}
};
let try_from_xdr = quote! {
#name => {
if iter.len() > 0 {
return Err(#path::xdr::Error::Invalid);
}
Self::#ident
}
};
let into_xdr = quote! { #enum_ident::#ident => (#name,).try_into().map_err(|_| #path::xdr::Error::Invalid)? };

VariantTokens {
spec_case,
try_from,
try_into,
try_from_xdr,
into_xdr,
}
}

fn map_tuple_variant(
path: &Path,
enum_ident: &Ident,
name: &str,
ident: &Ident,
attrs: &[Attribute],
discriminant_const_sym_ident: &Ident,
discriminant_const_u64_ident: &Ident,
fields: &Fields,
errors: &mut Vec<Error>,
) -> VariantTokens {
let spec_case = {
let field_types = fields
.iter()
.map(|f| match map_type(&f.ty) {
Ok(t) => t,
Err(e) => {
errors.push(e);
ScSpecTypeDef::I32
}
})
.collect::<Vec<_>>();
let field_types = match VecM::try_from(field_types) {
Ok(t) => t,
Err(e) => {
let v = VecM::default();
let max_len = v.max_len();
match e {
XdrError::LengthExceedsMax => {
errors.push(Error::new(
fields.span(),
format!(
"enum variant name {} has too many tuple values, max {} supported",
ident, max_len
),
));
}
e => {
errors.push(Error::new(fields.span(), format!("{e}")));
}
}
v
}
};
ScSpecUdtUnionCaseV0::TupleV0(ScSpecUdtUnionCaseTupleV0 {
doc: docs_from_attrs(attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
type_: field_types.try_into().unwrap(),
})
};
let num_fields = fields.iter().len();
let try_from = {
let field_convs = fields
.iter()
.enumerate()
.map(|(_i, _f)| {
quote! {
iter.next().ok_or(#path::ConversionError)??.try_into_val(env)?
}
})
.collect::<Vec<_>>();
quote! {
#discriminant_const_u64_ident => {
if iter.len() > #num_fields {
return Err(#path::ConversionError);
}
Self::#ident( #(#field_convs,)* )
}
}
};
let try_into = {
let fragments = fields
.iter()
.enumerate()
.map(|(i, _f)| {
let binding_name = format_ident!("value{i}");
let field_conv = quote! {
#binding_name.try_into_val(env)?
};
let tup_elem_type = quote! {
#path::RawVal
};
(binding_name, field_conv, tup_elem_type)
})
.multiunzip();
let (binding_names, field_convs, tup_elem_types): (Vec<_>, Vec<_>, Vec<_>) = fragments;
quote! {
#enum_ident::#ident(#(ref #binding_names,)* ) => {
let tup: (#path::RawVal, #(#tup_elem_types,)* ) = (#discriminant_const_sym_ident.into(), #(#field_convs,)* );
tup.try_into_val(env)
}
}
};
let try_from_xdr = {
let fragments = fields.iter().enumerate().map(|(i, _f)| {
let rawval_name = format_ident!("rv{i}");
let rawval_binding = quote! {
let #rawval_name: #path::RawVal = iter.next().ok_or(#path::xdr::Error::Invalid)?.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?;
};
let into_field = quote! {
#rawval_name.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?
};
(rawval_binding, into_field)
}).multiunzip();
let (rawval_bindings, into_fields): (Vec<_>, Vec<_>) = fragments;
quote! {
#name => {
if iter.len() > #num_fields {
return Err(#path::xdr::Error::Invalid);
}
#(#rawval_bindings)*
Self::#ident( #(#into_fields,)* )
}
}
};
let into_xdr = {
let binding_names = fields
.iter()
.enumerate()
.map(|(i, _f)| format_ident!("value{i}"))
.collect::<Vec<_>>();
quote! {
#enum_ident::#ident( #(#binding_names,)* ) => (#name, #(#binding_names,)* ).try_into().map_err(|_| #path::xdr::Error::Invalid)?
}
};

VariantTokens {
spec_case,
try_from,
try_into,
try_from_xdr,
into_xdr,
}
}
33 changes: 32 additions & 1 deletion soroban-sdk/src/tests/contract_udt_enum.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
use crate as soroban_sdk;
use soroban_sdk::xdr::ScVec;
use soroban_sdk::{
contractimpl, contracttype, symbol, vec, ConversionError, Env, IntoVal, RawVal, TryFromVal, Vec,
contractimpl, contracttype, symbol, vec, ConversionError, Env, IntoVal, RawVal, TryFromVal,
TryIntoVal, Vec,
};

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[contracttype]
pub enum Udt {
Aaa,
Bbb(i32),
MaxFields(u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32),
Nested(Udt2, Udt2),
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[contracttype]
pub struct Udt2 {
a: u32,
}

pub struct Contract;
Expand Down Expand Up @@ -56,3 +66,24 @@ fn test_error_on_partial_decode() {
let udt = Udt::try_from_val(&env, &vec.to_raw());
assert_eq!(udt, Err(ConversionError));
}

#[test]
fn round_trips() {
let env = Env::default();

let before = Udt::Nested(Udt2 { a: 1 }, Udt2 { a: 2 });
let rawval: RawVal = before.try_into_val(&env).unwrap();
let after: Udt = rawval.try_into_val(&env).unwrap();
assert_eq!(before, after);
let scvec: ScVec = before.try_into().unwrap();
let after: Udt = scvec.try_into_val(&env).unwrap();
assert_eq!(before, after);

let before = Udt::MaxFields(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
let rawval: RawVal = before.try_into_val(&env).unwrap();
let after: Udt = rawval.try_into_val(&env).unwrap();
assert_eq!(before, after);
let scvec: ScVec = before.try_into().unwrap();
let after: Udt = scvec.try_into_val(&env).unwrap();
assert_eq!(before, after);
}