From 8490378fad54c10ce9a1279995a42ed7a2202fa0 Mon Sep 17 00:00:00 2001 From: Brian Anderson Date: Thu, 2 Feb 2023 16:36:09 -0700 Subject: [PATCH] Support variants with multiple fields in UDTs --- soroban-sdk-macros/src/derive_enum.rs | 297 +++++++++++++++------ soroban-sdk/src/tests/contract_udt_enum.rs | 33 ++- 2 files changed, 250 insertions(+), 80 deletions(-) diff --git a/soroban-sdk-macros/src/derive_enum.rs b/soroban-sdk-macros/src/derive_enum.rs index 750c4fcf1..64299d0c0 100644 --- a/soroban-sdk-macros/src/derive_enum.rs +++ b/soroban-sdk-macros/src/derive_enum.rs @@ -5,8 +5,8 @@ use soroban_env_common::Symbol; use syn::{spanned::Spanned, Attribute, DataEnum, Error, Fields, Ident, Path}; use stellar_xdr::{ - ScSpecEntry, ScSpecTypeDef, ScSpecUdtUnionCaseTupleV0, ScSpecUdtUnionCaseV0, - ScSpecUdtUnionCaseVoidV0, ScSpecUdtUnionV0, StringM, WriteXdr, + Error as XdrError, ScSpecEntry, ScSpecTypeDef, ScSpecUdtUnionCaseTupleV0, ScSpecUdtUnionCaseV0, + ScSpecUdtUnionCaseVoidV0, ScSpecUdtUnionV0, StringM, VecM, WriteXdr, }; use crate::{doc::docs_from_attrs, map_type::map_type}; @@ -33,28 +33,19 @@ pub fn derive_type_enum( .iter() .map(|v| { // TODO: Choose discriminant type based on repr type of enum. - // TODO: Should we use variants explicit discriminant? Probably not. - // Should have a separate derive for those types of enums that maps - // to an integer type only. // TODO: Use attributes tagged on variant to control whether field is included. - // TODO: Support multi-field enum variants. - // TODO: Or, error on multi-field enum variants. // TODO: Handle field names longer than a symbol. Hash the name? Truncate the name? let ident = &v.ident; let name = &ident.to_string(); if let Err(e) = Symbol::try_from_str(name) { errors.push(Error::new(ident.span(), format!("enum variant name {}", e))); } - if v.fields.len() > 1 { - errors.push(Error::new(v.fields.span(), format!("enum variant name {} has too many tuple values, max 1 supported", ident))); - } match v.fields { Fields::Named(_) => { errors.push(Error::new(v.fields.span(), format!("enum variant {} has unsupported named fields", ident))); } - _ => {} - }; - let field = v.fields.iter().next(); + _ => { } + } let discriminant_const_sym_ident = format_ident!("DISCRIMINANT_SYM_{}", name.to_uppercase()); let discriminant_const_u64_ident = format_ident!("DISCRIMINANT_U64_{}", name.to_uppercase()); let discriminant_const_sym = quote! { @@ -67,75 +58,34 @@ pub fn derive_type_enum( #discriminant_const_sym #discriminant_const_u64 }; - if let Some(f) = field { - let spec_case = ScSpecUdtUnionCaseV0::TupleV0( - ScSpecUdtUnionCaseTupleV0 { - doc: docs_from_attrs(&v.attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error. - name: name.try_into().unwrap_or_else(|_| StringM::default()), - type_: vec![ - match map_type(&f.ty) { - Ok(t) => t, - Err(e) => { - errors.push(e); - ScSpecTypeDef::I32 - } - }, - ].try_into().unwrap() - } + let has_fields = v.fields.iter().next().is_some(); + if has_fields { + let VariantTokens { + spec_case, try_from, try_into, try_from_xdr, into_xdr + } = map_tuple_variant( + path, + enum_ident, + &name, + ident, + &v.attrs, + &discriminant_const_sym_ident, + &discriminant_const_u64_ident, + &v.fields, + &mut errors, ); - let try_from = quote! { - #discriminant_const_u64_ident => { - if iter.len() > 1 { - return Err(#path::ConversionError); - } - Self::#ident(iter.next().ok_or(#path::ConversionError)??.try_into_val(env)?) - } - }; - let try_into = quote! { - #enum_ident::#ident(ref value) => { - let tup: (#path::RawVal, #path::RawVal) = (#discriminant_const_sym_ident.into(), value.try_into_val(env)?); - tup.try_into_val(env) - } - }; - let try_from_xdr = quote! { - #name => { - if iter.len() > 1 { - return Err(#path::xdr::Error::Invalid); - } - let rv: #path::RawVal = iter.next().ok_or(#path::xdr::Error::Invalid)?.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?; - Self::#ident(rv.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?) - } - }; - let into_xdr = quote! { #enum_ident::#ident(value) => (#name, value).try_into().map_err(|_| #path::xdr::Error::Invalid)? }; (spec_case, discriminant_const, try_from, try_into, try_from_xdr, into_xdr) } else { - let spec_case = ScSpecUdtUnionCaseV0::VoidV0(ScSpecUdtUnionCaseVoidV0 { - doc: docs_from_attrs(&v.attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error. - name: name.try_into().unwrap_or_else(|_| StringM::default()), - }); - let try_from = quote! { - #discriminant_const_u64_ident => { - if iter.len() > 0 { - return Err(#path::ConversionError); - } - Self::#ident - } - }; - let try_into = quote! { - #enum_ident::#ident => { - let tup: (#path::RawVal,) = (#discriminant_const_sym_ident.into(),); - tup.try_into_val(env) - } - }; - let try_from_xdr = quote! { - #name => { - if iter.len() > 0 { - return Err(#path::xdr::Error::Invalid); - } - Self::#ident - } - }; - let into_xdr = quote! { #enum_ident::#ident => (#name,).try_into().map_err(|_| #path::xdr::Error::Invalid)? }; + let VariantTokens { + spec_case, try_from, try_into, try_from_xdr, into_xdr + } = map_empty_variant( + path, + enum_ident, + &name, + ident, + &v.attrs, + &discriminant_const_sym_ident, + &discriminant_const_u64_ident, + ); (spec_case, discriminant_const, try_from, try_into, try_from_xdr, into_xdr) } }) @@ -309,3 +259,192 @@ pub fn derive_type_enum( } } } + +struct VariantTokens { + spec_case: ScSpecUdtUnionCaseV0, + try_from: TokenStream2, + try_into: TokenStream2, + try_from_xdr: TokenStream2, + into_xdr: TokenStream2, +} + +fn map_empty_variant( + path: &Path, + enum_ident: &Ident, + name: &str, + ident: &Ident, + attrs: &[Attribute], + discriminant_const_sym_ident: &Ident, + discriminant_const_u64_ident: &Ident, +) -> VariantTokens { + let spec_case = ScSpecUdtUnionCaseV0::VoidV0(ScSpecUdtUnionCaseVoidV0 { + doc: docs_from_attrs(attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error. + name: name.try_into().unwrap_or_else(|_| StringM::default()), + }); + let try_from = quote! { + #discriminant_const_u64_ident => { + if iter.len() > 0 { + return Err(#path::ConversionError); + } + Self::#ident + } + }; + let try_into = quote! { + #enum_ident::#ident => { + let tup: (#path::RawVal,) = (#discriminant_const_sym_ident.into(),); + tup.try_into_val(env) + } + }; + let try_from_xdr = quote! { + #name => { + if iter.len() > 0 { + return Err(#path::xdr::Error::Invalid); + } + Self::#ident + } + }; + let into_xdr = quote! { #enum_ident::#ident => (#name,).try_into().map_err(|_| #path::xdr::Error::Invalid)? }; + + VariantTokens { + spec_case, + try_from, + try_into, + try_from_xdr, + into_xdr, + } +} + +fn map_tuple_variant( + path: &Path, + enum_ident: &Ident, + name: &str, + ident: &Ident, + attrs: &[Attribute], + discriminant_const_sym_ident: &Ident, + discriminant_const_u64_ident: &Ident, + fields: &Fields, + errors: &mut Vec, +) -> VariantTokens { + let spec_case = { + let field_types = fields + .iter() + .map(|f| match map_type(&f.ty) { + Ok(t) => t, + Err(e) => { + errors.push(e); + ScSpecTypeDef::I32 + } + }) + .collect::>(); + let field_types = match VecM::try_from(field_types) { + Ok(t) => t, + Err(e) => { + let v = VecM::default(); + let max_len = v.max_len(); + match e { + XdrError::LengthExceedsMax => { + errors.push(Error::new( + fields.span(), + format!( + "enum variant name {} has too many tuple values, max {} supported", + ident, max_len + ), + )); + } + e => { + errors.push(Error::new(fields.span(), format!("{e}"))); + } + } + v + } + }; + ScSpecUdtUnionCaseV0::TupleV0(ScSpecUdtUnionCaseTupleV0 { + doc: docs_from_attrs(attrs).try_into().unwrap(), // TODO: Truncate docs, or display friendly compile error. + name: name.try_into().unwrap_or_else(|_| StringM::default()), + type_: field_types.try_into().unwrap(), + }) + }; + let num_fields = fields.iter().len(); + let try_from = { + let field_convs = fields + .iter() + .enumerate() + .map(|(_i, _f)| { + quote! { + iter.next().ok_or(#path::ConversionError)??.try_into_val(env)? + } + }) + .collect::>(); + quote! { + #discriminant_const_u64_ident => { + if iter.len() > #num_fields { + return Err(#path::ConversionError); + } + Self::#ident( #(#field_convs,)* ) + } + } + }; + let try_into = { + let fragments = fields + .iter() + .enumerate() + .map(|(i, _f)| { + let binding_name = format_ident!("value{i}"); + let field_conv = quote! { + #binding_name.try_into_val(env)? + }; + let tup_elem_type = quote! { + #path::RawVal + }; + (binding_name, field_conv, tup_elem_type) + }) + .multiunzip(); + let (binding_names, field_convs, tup_elem_types): (Vec<_>, Vec<_>, Vec<_>) = fragments; + quote! { + #enum_ident::#ident(#(ref #binding_names,)* ) => { + let tup: (#path::RawVal, #(#tup_elem_types,)* ) = (#discriminant_const_sym_ident.into(), #(#field_convs,)* ); + tup.try_into_val(env) + } + } + }; + let try_from_xdr = { + let fragments = fields.iter().enumerate().map(|(i, _f)| { + let rawval_name = format_ident!("rv{i}"); + let rawval_binding = quote! { + let #rawval_name: #path::RawVal = iter.next().ok_or(#path::xdr::Error::Invalid)?.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)?; + }; + let into_field = quote! { + #rawval_name.try_into_val(env).map_err(|_| #path::xdr::Error::Invalid)? + }; + (rawval_binding, into_field) + }).multiunzip(); + let (rawval_bindings, into_fields): (Vec<_>, Vec<_>) = fragments; + quote! { + #name => { + if iter.len() > #num_fields { + return Err(#path::xdr::Error::Invalid); + } + #(#rawval_bindings)* + Self::#ident( #(#into_fields,)* ) + } + } + }; + let into_xdr = { + let binding_names = fields + .iter() + .enumerate() + .map(|(i, _f)| format_ident!("value{i}")) + .collect::>(); + quote! { + #enum_ident::#ident( #(#binding_names,)* ) => (#name, #(#binding_names,)* ).try_into().map_err(|_| #path::xdr::Error::Invalid)? + } + }; + + VariantTokens { + spec_case, + try_from, + try_into, + try_from_xdr, + into_xdr, + } +} diff --git a/soroban-sdk/src/tests/contract_udt_enum.rs b/soroban-sdk/src/tests/contract_udt_enum.rs index 66c7da08d..4f5e750f2 100644 --- a/soroban-sdk/src/tests/contract_udt_enum.rs +++ b/soroban-sdk/src/tests/contract_udt_enum.rs @@ -1,6 +1,8 @@ use crate as soroban_sdk; +use soroban_sdk::xdr::ScVec; use soroban_sdk::{ - contractimpl, contracttype, symbol, vec, ConversionError, Env, IntoVal, RawVal, TryFromVal, Vec, + contractimpl, contracttype, symbol, vec, ConversionError, Env, IntoVal, RawVal, TryFromVal, + TryIntoVal, Vec, }; #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -8,6 +10,14 @@ use soroban_sdk::{ pub enum Udt { Aaa, Bbb(i32), + MaxFields(u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32), + Nested(Udt2, Udt2), +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[contracttype] +pub struct Udt2 { + a: u32, } pub struct Contract; @@ -56,3 +66,24 @@ fn test_error_on_partial_decode() { let udt = Udt::try_from_val(&env, &vec.to_raw()); assert_eq!(udt, Err(ConversionError)); } + +#[test] +fn round_trips() { + let env = Env::default(); + + let before = Udt::Nested(Udt2 { a: 1 }, Udt2 { a: 2 }); + let rawval: RawVal = before.try_into_val(&env).unwrap(); + let after: Udt = rawval.try_into_val(&env).unwrap(); + assert_eq!(before, after); + let scvec: ScVec = before.try_into().unwrap(); + let after: Udt = scvec.try_into_val(&env).unwrap(); + assert_eq!(before, after); + + let before = Udt::MaxFields(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12); + let rawval: RawVal = before.try_into_val(&env).unwrap(); + let after: Udt = rawval.try_into_val(&env).unwrap(); + assert_eq!(before, after); + let scvec: ScVec = before.try_into().unwrap(); + let after: Udt = scvec.try_into_val(&env).unwrap(); + assert_eq!(before, after); +}