Skip to content

Commit

Permalink
Support variants with multiple fields in UDTs
Browse files Browse the repository at this point in the history
  • Loading branch information
brson committed Feb 10, 2023
1 parent bb98641 commit 8490378
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 80 deletions.
297 changes: 218 additions & 79 deletions soroban-sdk-macros/src/derive_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use soroban_env_common::Symbol;
use syn::{spanned::Spanned, Attribute, DataEnum, Error, Fields, Ident, Path};

use stellar_xdr::{
ScSpecEntry, ScSpecTypeDef, ScSpecUdtUnionCaseTupleV0, ScSpecUdtUnionCaseV0,
ScSpecUdtUnionCaseVoidV0, ScSpecUdtUnionV0, StringM, WriteXdr,
Error as XdrError, ScSpecEntry, ScSpecTypeDef, ScSpecUdtUnionCaseTupleV0, ScSpecUdtUnionCaseV0,
ScSpecUdtUnionCaseVoidV0, ScSpecUdtUnionV0, StringM, VecM, WriteXdr,
};

use crate::{doc::docs_from_attrs, map_type::map_type};
Expand All @@ -33,28 +33,19 @@ pub fn derive_type_enum(
.iter()
.map(|v| {
// TODO: Choose discriminant type based on repr type of enum.
// TODO: Should we use variants explicit discriminant? Probably not.
// Should have a separate derive for those types of enums that maps
// to an integer type only.
// TODO: Use attributes tagged on variant to control whether field is included.
// TODO: Support multi-field enum variants.
// TODO: Or, error on multi-field enum variants.
// TODO: Handle field names longer than a symbol. Hash the name? Truncate the name?
let ident = &v.ident;
let name = &ident.to_string();
if let Err(e) = Symbol::try_from_str(name) {
errors.push(Error::new(ident.span(), format!("enum variant name {}", e)));
}
if v.fields.len() > 1 {
errors.push(Error::new(v.fields.span(), format!("enum variant name {} has too many tuple values, max 1 supported", ident)));
}
match v.fields {
Fields::Named(_) => {
errors.push(Error::new(v.fields.span(), format!("enum variant {} has unsupported named fields", ident)));
}
_ => {}
};
let field = v.fields.iter().next();
_ => { }
}
let discriminant_const_sym_ident = format_ident!("DISCRIMINANT_SYM_{}", name.to_uppercase());
let discriminant_const_u64_ident = format_ident!("DISCRIMINANT_U64_{}", name.to_uppercase());
let discriminant_const_sym = quote! {
Expand All @@ -67,75 +58,34 @@ pub fn derive_type_enum(
#discriminant_const_sym
#discriminant_const_u64
};
if let Some(f) = field {
let spec_case = ScSpecUdtUnionCaseV0::TupleV0(
ScSpecUdtUnionCaseTupleV0 {
doc: docs_from_attrs(&v.attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
type_: vec![
match map_type(&f.ty) {
Ok(t) => t,
Err(e) => {
errors.push(e);
ScSpecTypeDef::I32
}
},
].try_into().unwrap()
}
let has_fields = v.fields.iter().next().is_some();
if has_fields {
let VariantTokens {
spec_case, try_from, try_into, try_from_xdr, into_xdr
} = map_tuple_variant(
path,
enum_ident,
&name,
ident,
&v.attrs,
&discriminant_const_sym_ident,
&discriminant_const_u64_ident,
&v.fields,
&mut errors,
);
let try_from = quote! {
#discriminant_const_u64_ident => {
if iter.len() > 1 {
return Err(#path::ConversionError);
}
Self::#ident(iter.next().ok_or(#path::ConversionError)??.try_into_val(env)?)
}
};
let try_into = quote! {
#enum_ident::#ident(ref value) => {
let tup: (#path::RawVal, #path::RawVal) = (#discriminant_const_sym_ident.into(), value.try_into_val(env)?);
tup.try_into_val(env)
}
};
let try_from_xdr = quote! {
#name => {
if iter.len() > 1 {
return Err(#path::xdr::Error::Invalid);
}
let rv: #path::RawVal = iter.next().ok_or(#path::xdr::Error::Invalid)?.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?;
Self::#ident(rv.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?)
}
};
let into_xdr = quote! { #enum_ident::#ident(value) => (#name, value).try_into().map_err(|_| #path::xdr::Error::Invalid)? };
(spec_case, discriminant_const, try_from, try_into, try_from_xdr, into_xdr)
} else {
let spec_case = ScSpecUdtUnionCaseV0::VoidV0(ScSpecUdtUnionCaseVoidV0 {
doc: docs_from_attrs(&v.attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
});
let try_from = quote! {
#discriminant_const_u64_ident => {
if iter.len() > 0 {
return Err(#path::ConversionError);
}
Self::#ident
}
};
let try_into = quote! {
#enum_ident::#ident => {
let tup: (#path::RawVal,) = (#discriminant_const_sym_ident.into(),);
tup.try_into_val(env)
}
};
let try_from_xdr = quote! {
#name => {
if iter.len() > 0 {
return Err(#path::xdr::Error::Invalid);
}
Self::#ident
}
};
let into_xdr = quote! { #enum_ident::#ident => (#name,).try_into().map_err(|_| #path::xdr::Error::Invalid)? };
let VariantTokens {
spec_case, try_from, try_into, try_from_xdr, into_xdr
} = map_empty_variant(
path,
enum_ident,
&name,
ident,
&v.attrs,
&discriminant_const_sym_ident,
&discriminant_const_u64_ident,
);
(spec_case, discriminant_const, try_from, try_into, try_from_xdr, into_xdr)
}
})
Expand Down Expand Up @@ -309,3 +259,192 @@ pub fn derive_type_enum(
}
}
}

struct VariantTokens {
spec_case: ScSpecUdtUnionCaseV0,
try_from: TokenStream2,
try_into: TokenStream2,
try_from_xdr: TokenStream2,
into_xdr: TokenStream2,
}

fn map_empty_variant(
path: &Path,
enum_ident: &Ident,
name: &str,
ident: &Ident,
attrs: &[Attribute],
discriminant_const_sym_ident: &Ident,
discriminant_const_u64_ident: &Ident,
) -> VariantTokens {
let spec_case = ScSpecUdtUnionCaseV0::VoidV0(ScSpecUdtUnionCaseVoidV0 {
doc: docs_from_attrs(attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
});
let try_from = quote! {
#discriminant_const_u64_ident => {
if iter.len() > 0 {
return Err(#path::ConversionError);
}
Self::#ident
}
};
let try_into = quote! {
#enum_ident::#ident => {
let tup: (#path::RawVal,) = (#discriminant_const_sym_ident.into(),);
tup.try_into_val(env)
}
};
let try_from_xdr = quote! {
#name => {
if iter.len() > 0 {
return Err(#path::xdr::Error::Invalid);
}
Self::#ident
}
};
let into_xdr = quote! { #enum_ident::#ident => (#name,).try_into().map_err(|_| #path::xdr::Error::Invalid)? };

VariantTokens {
spec_case,
try_from,
try_into,
try_from_xdr,
into_xdr,
}
}

fn map_tuple_variant(
path: &Path,
enum_ident: &Ident,
name: &str,
ident: &Ident,
attrs: &[Attribute],
discriminant_const_sym_ident: &Ident,
discriminant_const_u64_ident: &Ident,
fields: &Fields,
errors: &mut Vec<Error>,
) -> VariantTokens {
let spec_case = {
let field_types = fields
.iter()
.map(|f| match map_type(&f.ty) {
Ok(t) => t,
Err(e) => {
errors.push(e);
ScSpecTypeDef::I32
}
})
.collect::<Vec<_>>();
let field_types = match VecM::try_from(field_types) {
Ok(t) => t,
Err(e) => {
let v = VecM::default();
let max_len = v.max_len();
match e {
XdrError::LengthExceedsMax => {
errors.push(Error::new(
fields.span(),
format!(
"enum variant name {} has too many tuple values, max {} supported",
ident, max_len
),
));
}
e => {
errors.push(Error::new(fields.span(), format!("{e}")));
}
}
v
}
};
ScSpecUdtUnionCaseV0::TupleV0(ScSpecUdtUnionCaseTupleV0 {
doc: docs_from_attrs(attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error.
name: name.try_into().unwrap_or_else(|_| StringM::default()),
type_: field_types.try_into().unwrap(),
})
};
let num_fields = fields.iter().len();
let try_from = {
let field_convs = fields
.iter()
.enumerate()
.map(|(_i, _f)| {
quote! {
iter.next().ok_or(#path::ConversionError)??.try_into_val(env)?
}
})
.collect::<Vec<_>>();
quote! {
#discriminant_const_u64_ident => {
if iter.len() > #num_fields {
return Err(#path::ConversionError);
}
Self::#ident( #(#field_convs,)* )
}
}
};
let try_into = {
let fragments = fields
.iter()
.enumerate()
.map(|(i, _f)| {
let binding_name = format_ident!("value{i}");
let field_conv = quote! {
#binding_name.try_into_val(env)?
};
let tup_elem_type = quote! {
#path::RawVal
};
(binding_name, field_conv, tup_elem_type)
})
.multiunzip();
let (binding_names, field_convs, tup_elem_types): (Vec<_>, Vec<_>, Vec<_>) = fragments;
quote! {
#enum_ident::#ident(#(ref #binding_names,)* ) => {
let tup: (#path::RawVal, #(#tup_elem_types,)* ) = (#discriminant_const_sym_ident.into(), #(#field_convs,)* );
tup.try_into_val(env)
}
}
};
let try_from_xdr = {
let fragments = fields.iter().enumerate().map(|(i, _f)| {
let rawval_name = format_ident!("rv{i}");
let rawval_binding = quote! {
let #rawval_name: #path::RawVal = iter.next().ok_or(#path::xdr::Error::Invalid)?.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?;
};
let into_field = quote! {
#rawval_name.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?
};
(rawval_binding, into_field)
}).multiunzip();
let (rawval_bindings, into_fields): (Vec<_>, Vec<_>) = fragments;
quote! {
#name => {
if iter.len() > #num_fields {
return Err(#path::xdr::Error::Invalid);
}
#(#rawval_bindings)*
Self::#ident( #(#into_fields,)* )
}
}
};
let into_xdr = {
let binding_names = fields
.iter()
.enumerate()
.map(|(i, _f)| format_ident!("value{i}"))
.collect::<Vec<_>>();
quote! {
#enum_ident::#ident( #(#binding_names,)* ) => (#name, #(#binding_names,)* ).try_into().map_err(|_| #path::xdr::Error::Invalid)?
}
};

VariantTokens {
spec_case,
try_from,
try_into,
try_from_xdr,
into_xdr,
}
}
33 changes: 32 additions & 1 deletion soroban-sdk/src/tests/contract_udt_enum.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
use crate as soroban_sdk;
use soroban_sdk::xdr::ScVec;
use soroban_sdk::{
contractimpl, contracttype, symbol, vec, ConversionError, Env, IntoVal, RawVal, TryFromVal, Vec,
contractimpl, contracttype, symbol, vec, ConversionError, Env, IntoVal, RawVal, TryFromVal,
TryIntoVal, Vec,
};

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[contracttype]
pub enum Udt {
Aaa,
Bbb(i32),
MaxFields(u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32),
Nested(Udt2, Udt2),
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[contracttype]
pub struct Udt2 {
a: u32,
}

pub struct Contract;
Expand Down Expand Up @@ -56,3 +66,24 @@ fn test_error_on_partial_decode() {
let udt = Udt::try_from_val(&env, &vec.to_raw());
assert_eq!(udt, Err(ConversionError));
}

#[test]
fn round_trips() {
let env = Env::default();

let before = Udt::Nested(Udt2 { a: 1 }, Udt2 { a: 2 });
let rawval: RawVal = before.try_into_val(&env).unwrap();
let after: Udt = rawval.try_into_val(&env).unwrap();
assert_eq!(before, after);
let scvec: ScVec = before.try_into().unwrap();
let after: Udt = scvec.try_into_val(&env).unwrap();
assert_eq!(before, after);

let before = Udt::MaxFields(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
let rawval: RawVal = before.try_into_val(&env).unwrap();
let after: Udt = rawval.try_into_val(&env).unwrap();
assert_eq!(before, after);
let scvec: ScVec = before.try_into().unwrap();
let after: Udt = scvec.try_into_val(&env).unwrap();
assert_eq!(before, after);
}

0 comments on commit 8490378

Please sign in to comment.