diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2964f6263b..1ef81d97ce 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,7 +37,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2021-01-30 override: true - name: coverage with tarpaulin run: | @@ -66,7 +66,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2021-01-30 override: true components: rustfmt, clippy - name: Validate release notes entry @@ -116,7 +116,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2021-01-30 override: true - name: Build run: cargo build --all-features --verbose @@ -132,7 +132,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2021-01-30 override: true - name: Run WASM tests run: wasm-pack test --node -- --workspace @@ -167,7 +167,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly + toolchain: nightly-2021-01-30 override: true - name: Build run: cargo build --all-features --release && strip target/release/fe && mv target/release/fe target/release/${{ matrix.BIN_FILE }} diff --git a/analyzer/src/lib.rs b/analyzer/src/lib.rs index f6e523b735..e1f0dbd061 100644 --- a/analyzer/src/lib.rs +++ b/analyzer/src/lib.rs @@ -19,6 +19,7 @@ use crate::namespace::scopes::{ use crate::namespace::types::{ Contract, FixedSize, + Struct, Type, }; use fe_parser::ast as fe; @@ -55,6 +56,7 @@ impl Location { Type::Array(_) => Ok(Location::Memory), Type::Tuple(_) => Ok(Location::Memory), Type::String(_) => Ok(Location::Memory), + Type::Struct(_) => Ok(Location::Memory), Type::Map(_) => Err(SemanticError::cannot_move()), } } @@ -71,6 +73,8 @@ pub struct ContractAttributes { pub events: Vec, /// Static strings that the contract defines pub string_literals: HashSet, + /// Structs that have been defined by the user + pub structs: Vec, /// External contracts that may be called from within this contract. pub external_contracts: Vec, } @@ -108,6 +112,14 @@ impl From> for ContractAttributes { } }); + let structs = scope.borrow().get_module_type_defs(|typ| { + if let Type::Struct(val) = typ { + Some(val.to_owned()) + } else { + None + } + }); + ContractAttributes { public_functions, init_function, @@ -118,6 +130,7 @@ impl From> for ContractAttributes { .map(|event| event.to_owned()) .collect::>(), string_literals: scope.borrow().string_defs.clone(), + structs, external_contracts, } } @@ -248,12 +261,8 @@ impl Context { } /// Attribute contextual information to an expression node. - pub fn add_expression( - &mut self, - spanned: &Spanned, - attributes: ExpressionAttributes, - ) { - self.expressions.insert(spanned.span, attributes); + pub fn add_expression>(&mut self, span: T, attributes: ExpressionAttributes) { + self.expressions.insert(span.into(), attributes); } /// Get information that has been attributed to an expression node. diff --git a/analyzer/src/namespace/operations.rs b/analyzer/src/namespace/operations.rs index 306d5378a7..f1553a927d 100644 --- a/analyzer/src/namespace/operations.rs +++ b/analyzer/src/namespace/operations.rs @@ -17,6 +17,7 @@ pub fn index(value: Type, index: Type) -> Result { Type::Tuple(_) => Err(SemanticError::not_subscriptable()), Type::String(_) => Err(SemanticError::not_subscriptable()), Type::Contract(_) => Err(SemanticError::not_subscriptable()), + Type::Struct(_) => Err(SemanticError::not_subscriptable()), } } diff --git a/analyzer/src/namespace/types.rs b/analyzer/src/namespace/types.rs index 36b9e2e99c..e083837b43 100644 --- a/analyzer/src/namespace/types.rs +++ b/analyzer/src/namespace/types.rs @@ -1,6 +1,9 @@ use crate::errors::SemanticError; use fe_parser::ast as fe; -use std::collections::HashMap; +use std::collections::{ + BTreeMap, + HashMap, +}; use std::convert::TryFrom; use std::num::{ IntErrorKind, @@ -103,6 +106,7 @@ pub enum Type { Tuple(Tuple), String(FeString), Contract(Contract), + Struct(Struct), } #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] @@ -112,6 +116,7 @@ pub enum FixedSize { Tuple(Tuple), String(FeString), Contract(Contract), + Struct(Struct), } #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] @@ -157,6 +162,12 @@ pub struct Tuple { pub items: Vec, } +#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] +pub struct Struct { + pub name: String, + fields: BTreeMap, +} + #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] pub struct FeString { pub max_size: usize, @@ -168,6 +179,40 @@ pub struct Contract { pub functions: Vec, } +impl Struct { + pub fn new(name: &str) -> Struct { + Struct { + name: name.to_string(), + fields: BTreeMap::new(), + } + } + + // Return `true` if the struct has any fields, otherwise return `false` + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// Add a field to the struct + pub fn add_field(&mut self, name: &str, value: &Base) -> Option { + self.fields.insert(name.to_string(), value.clone()) + } + + // Return the type of the given field name + pub fn get_field_type(&self, name: &str) -> Option<&Base> { + self.fields.get(name) + } + + // Return a vector of field types + pub fn get_field_types(&self) -> Vec { + self.fields.values().map(|val| val.clone().into()).collect() + } + + //Return a vector of field names + pub fn get_field_names(&self) -> Vec { + self.fields.keys().cloned().collect() + } +} + impl TryFrom<&str> for FeString { type Error = String; @@ -258,10 +303,17 @@ impl From for Type { FixedSize::Tuple(tuple) => Type::Tuple(tuple), FixedSize::String(string) => Type::String(string), FixedSize::Contract(contract) => Type::Contract(contract), + FixedSize::Struct(val) => Type::Struct(val), } } } +impl From for Type { + fn from(value: Base) -> Self { + Type::Base(value) + } +} + impl FeSized for FixedSize { fn size(&self) -> usize { match self { @@ -270,6 +322,7 @@ impl FeSized for FixedSize { FixedSize::Tuple(tuple) => tuple.size(), FixedSize::String(string) => string.size(), FixedSize::Contract(contract) => contract.size(), + FixedSize::Struct(val) => val.size(), } } } @@ -301,6 +354,7 @@ impl AbiEncoding for FixedSize { FixedSize::Tuple(tuple) => tuple.abi_name(), FixedSize::String(string) => string.abi_name(), FixedSize::Contract(contract) => contract.abi_name(), + FixedSize::Struct(val) => val.abi_name(), } } @@ -311,6 +365,7 @@ impl AbiEncoding for FixedSize { FixedSize::Tuple(tuple) => tuple.abi_safe_name(), FixedSize::String(string) => string.abi_safe_name(), FixedSize::Contract(contract) => contract.abi_safe_name(), + FixedSize::Struct(val) => val.abi_safe_name(), } } @@ -321,6 +376,7 @@ impl AbiEncoding for FixedSize { FixedSize::Tuple(tuple) => tuple.abi_type(), FixedSize::String(string) => string.abi_type(), FixedSize::Contract(contract) => contract.abi_type(), + FixedSize::Struct(val) => val.abi_type(), } } } @@ -348,6 +404,7 @@ impl TryFrom for FixedSize { Type::Base(base) => Ok(FixedSize::Base(base)), Type::Tuple(tuple) => Ok(FixedSize::Tuple(tuple)), Type::String(string) => Ok(FixedSize::String(string)), + Type::Struct(val) => Ok(FixedSize::Struct(val)), Type::Map(_) => Err(SemanticError::type_error()), Type::Contract(contract) => Ok(FixedSize::Contract(contract)), } @@ -543,6 +600,26 @@ impl FeSized for Tuple { } } +impl FeSized for Struct { + fn size(&self) -> usize { + self.fields.len() * 32 + } +} + +impl AbiEncoding for Struct { + fn abi_name(&self) -> String { + unimplemented!(); + } + + fn abi_safe_name(&self) -> String { + unimplemented!(); + } + + fn abi_type(&self) -> AbiType { + unimplemented!(); + } +} + impl AbiEncoding for Tuple { fn abi_name(&self) -> String { unimplemented!(); diff --git a/analyzer/src/traversal/expressions.rs b/analyzer/src/traversal/expressions.rs index 34d8585e48..29b743f04b 100644 --- a/analyzer/src/traversal/expressions.rs +++ b/analyzer/src/traversal/expressions.rs @@ -13,6 +13,7 @@ use crate::namespace::types::{ FeString, FixedSize, Integer, + Struct, Type, U256, }; @@ -45,7 +46,7 @@ pub fn expr( fe::Expr::Num(_) => expr_num(exp), fe::Expr::Bool(_) => expr_bool(exp), fe::Expr::Subscript { .. } => expr_subscript(scope, Rc::clone(&context), exp), - fe::Expr::Attribute { .. } => expr_attribute(scope, exp), + fe::Expr::Attribute { .. } => expr_attribute(scope, Rc::clone(&context), exp), fe::Expr::Ternary { .. } => expr_ternary(scope, Rc::clone(&context), exp), fe::Expr::BoolOperation { .. } => unimplemented!(), fe::Expr::BinOperation { .. } => expr_bin_operation(scope, Rc::clone(&context), exp), @@ -138,7 +139,6 @@ pub fn expr_name_str<'a>(exp: &Spanned>) -> Result<&'a str, Semanti if let fe::Expr::Name(name) = exp.node { return Ok(name); } - unreachable!() } @@ -198,6 +198,10 @@ fn expr_name( Location::Memory, )), Some(FixedSize::Tuple(_)) => unimplemented!(), + Some(FixedSize::Struct(val)) => Ok(ExpressionAttributes::new( + Type::Struct(val), + Location::Memory, + )), None => Err(SemanticError::undefined_value()), }; } @@ -275,6 +279,7 @@ fn expr_subscript( fn expr_attribute( scope: Shared, + context: Shared, exp: &Spanned, ) -> Result { if let fe::Expr::Attribute { value, attr } = &exp.node { @@ -286,10 +291,18 @@ fn expr_attribute( TxField, }; + let object_name = expr_name_str(value)?; + + // Before we try to match any pre-defined objects, try matching as a + // custom type + if let Some(FixedSize::Struct(_)) = scope.borrow().variable_def(object_name.to_string()) { + return expr_attribute_custom_type(Rc::clone(&scope), context, value, attr); + } + let val = |t| Ok(ExpressionAttributes::new(Type::Base(t), Location::Value)); let err = || Err(SemanticError::undefined_value()); - return match Object::from_str(expr_name_str(value)?) { + return match Object::from_str(object_name) { Ok(Object::Self_) => expr_attribute_self(scope, attr), Ok(Object::Block) => match BlockField::from_str(attr.node) { @@ -322,6 +335,35 @@ fn expr_attribute( unreachable!() } +fn expr_attribute_custom_type( + scope: Shared, + context: Shared, + value: &Spanned, + attr: &Spanned<&str>, +) -> Result { + let val_str = expr_name_str(value)?; + let custom_type = scope + .borrow() + .variable_def(val_str.to_string()) + .ok_or_else(SemanticError::undefined_value)?; + context.borrow_mut().add_expression( + value, + ExpressionAttributes::new(custom_type.clone().into(), Location::Memory), + ); + match custom_type { + FixedSize::Struct(val) => { + let field_type = val + .get_field_type(attr.node) + .ok_or_else(SemanticError::undefined_value)?; + Ok(ExpressionAttributes::new( + Type::Base(field_type.clone()), + Location::Memory, + )) + } + _ => Err(SemanticError::undefined_value()), + } +} + fn expr_attribute_self( scope: Shared, attr: &Spanned<&str>, @@ -431,12 +473,34 @@ fn expr_call( unreachable!() } +fn expr_call_struct_constructor( + scope: Shared, + context: Shared, + typ: Struct, + args: &Spanned>>, +) -> Result { + let argument_attributes = expr_call_args(Rc::clone(&scope), Rc::clone(&context), args)?; + + if typ.get_field_types() != expression_attributes_to_types(argument_attributes) { + return Err(SemanticError::type_error()); + } + + Ok(ExpressionAttributes::new( + Type::Struct(typ), + Location::Memory, + )) +} + fn expr_call_type_constructor( scope: Shared, context: Shared, typ: Type, args: &Spanned>>, ) -> Result { + if let Type::Struct(val) = typ { + return expr_call_struct_constructor(scope, context, val, args); + } + if args.node.len() != 1 { return Err(SemanticError::wrong_number_of_params()); } @@ -507,6 +571,17 @@ fn validate_str_literal_fits_type( Err(SemanticError::type_error()) } +fn expr_call_args( + scope: Shared, + context: Shared, + args: &Spanned>>, +) -> Result, SemanticError> { + args.node + .iter() + .map(|arg| call_arg(Rc::clone(&scope), Rc::clone(&context), arg)) + .collect::, _>>() +} + fn expr_call_self_attribute( scope: Shared, context: Shared, @@ -524,11 +599,7 @@ fn expr_call_self_attribute( .borrow() .function_def(func_name) { - let argument_attributes = args - .node - .iter() - .map(|arg| call_arg(Rc::clone(&scope), Rc::clone(&context), arg)) - .collect::, _>>()?; + let argument_attributes = expr_call_args(Rc::clone(&scope), Rc::clone(&context), args)?; if param_types.len() != argument_attributes.len() { return Err(SemanticError::wrong_number_of_params()); diff --git a/analyzer/src/traversal/mod.rs b/analyzer/src/traversal/mod.rs index 7ee3c758c7..249228cea2 100644 --- a/analyzer/src/traversal/mod.rs +++ b/analyzer/src/traversal/mod.rs @@ -5,4 +5,5 @@ mod declarations; mod expressions; mod functions; pub mod module; +mod structs; mod types; diff --git a/analyzer/src/traversal/module.rs b/analyzer/src/traversal/module.rs index ef4877eefe..07d5856ba4 100644 --- a/analyzer/src/traversal/module.rs +++ b/analyzer/src/traversal/module.rs @@ -4,7 +4,10 @@ use crate::namespace::scopes::{ Shared, }; use crate::namespace::types; -use crate::traversal::contracts; +use crate::traversal::{ + contracts, + structs, +}; use crate::Context; use fe_parser::ast as fe; use fe_parser::span::Spanned; @@ -17,6 +20,9 @@ pub fn module(context: Shared, module: &fe::Module) -> Result<(), Seman for stmt in module.body.iter() { match &stmt.node { fe::ModuleStmt::TypeDef { .. } => type_def(Rc::clone(&scope), stmt)?, + fe::ModuleStmt::StructDef { name, body } => { + structs::struct_def(Rc::clone(&scope), name.node, body)? + } fe::ModuleStmt::ContractDef { .. } => { contracts::contract_def(Rc::clone(&scope), Rc::clone(&context), stmt)? } diff --git a/analyzer/src/traversal/structs.rs b/analyzer/src/traversal/structs.rs new file mode 100644 index 0000000000..6a870521fd --- /dev/null +++ b/analyzer/src/traversal/structs.rs @@ -0,0 +1,36 @@ +use fe_parser::{ + ast::StructStmt, + span::Spanned, +}; + +use crate::errors::SemanticError; +use crate::namespace::scopes::{ + ModuleScope, + Shared, +}; +use crate::namespace::types::{ + type_desc, + Struct, + Type, +}; + +pub fn struct_def( + module_scope: Shared, + name: &str, + struct_stmts: &[Spanned], +) -> Result<(), SemanticError> { + let mut val = Struct::new(name); + for stmt in struct_stmts { + let StructStmt::StructField { name, typ, .. } = &stmt.node; + let field_type = type_desc(&module_scope.borrow().type_defs, &typ.node)?; + if let Type::Base(base_typ) = field_type { + val.add_field(name.node, &base_typ); + } else { + todo!("Non-Base type fields aren't yet supported") + } + } + module_scope + .borrow_mut() + .add_type_def(name.to_string(), Type::Struct(val)); + Ok(()) +} diff --git a/compiler/src/yul/mappers/expressions.rs b/compiler/src/yul/mappers/expressions.rs index b830023234..b46305a5ed 100644 --- a/compiler/src/yul/mappers/expressions.rs +++ b/compiler/src/yul/mappers/expressions.rs @@ -1,13 +1,19 @@ use crate::errors::CompileError; use crate::yul::names; -use crate::yul::operations::calls as call_operations; -use crate::yul::operations::data as data_operations; +use crate::yul::operations::{ + calls as call_operations, + data as data_operations, + structs as struct_operations, +}; use crate::yul::utils; -use fe_analyzer::builtins; use fe_analyzer::namespace::types::{ FixedSize, Type, }; +use fe_analyzer::{ + builtins, + ExpressionAttributes, +}; use fe_analyzer::{ CallType, Context, @@ -98,6 +104,9 @@ fn expr_call(context: &Context, exp: &Spanned) -> Result>()?; return match call_type { + CallType::TypeConstructor { + typ: Type::Struct(val), + } => Ok(struct_operations::new(val, yul_args)), CallType::TypeConstructor { .. } => Ok(yul_args[0].to_owned()), CallType::SelfAttribute { func_name } => { let func_name = names::func_name(func_name); @@ -341,7 +350,25 @@ fn expr_attribute( Object, TxField, }; - return match Object::from_str(expr_name_str(value)?) { + + let object_name = expr_name_str(value)?; + + // Before we try to match any known pre-defined objects, try matching as a + // custom type + if let Some(ExpressionAttributes { + typ: Type::Struct(val), + .. + }) = context.get_expression(&*value) + { + let custom_type = format!("${}", object_name); + return Ok(struct_operations::get_attribute( + val, + &custom_type, + attr.node, + )); + } + + return match Object::from_str(object_name) { Ok(Object::Self_) => expr_attribute_self(context, exp), Ok(Object::Block) => match BlockField::from_str(attr.node) { diff --git a/compiler/src/yul/mappers/module.rs b/compiler/src/yul/mappers/module.rs index f5cc1d4966..218e219368 100644 --- a/compiler/src/yul/mappers/module.rs +++ b/compiler/src/yul/mappers/module.rs @@ -22,6 +22,7 @@ pub fn module(context: &Context, module: &fe::Module) -> Result {} fe::ModuleStmt::FromImport { .. } => unimplemented!(), fe::ModuleStmt::SimpleImport { .. } => unimplemented!(), } diff --git a/compiler/src/yul/names.rs b/compiler/src/yul/names.rs index 34b58e155e..91287dbc3a 100644 --- a/compiler/src/yul/names.rs +++ b/compiler/src/yul/names.rs @@ -45,6 +45,23 @@ pub fn contract_call(contract_name: &str, func_name: &str) -> yul::Identifier { identifier! { (name) } } +/// Generates a function name for to interact with a certain struct type +pub fn struct_function_name(struct_name: &str, func_name: &str) -> yul::Identifier { + let name = format!("struct_{}_{}", struct_name, func_name); + identifier! { (name) } +} + +/// Generates a function name for creating a certain struct type +pub fn struct_new_call(struct_name: &str) -> yul::Identifier { + struct_function_name(struct_name, "new") +} + +/// Generates a function name for reading a named property of a certain struct +/// type +pub fn struct_getter_call(struct_name: &str, field_name: &str) -> yul::Identifier { + struct_function_name(struct_name, &format!("get_{}_ptr", field_name)) +} + #[cfg(test)] mod tests { use crate::yul::names::{ diff --git a/compiler/src/yul/operations/mod.rs b/compiler/src/yul/operations/mod.rs index 20f72dcd7e..a65780f2c8 100644 --- a/compiler/src/yul/operations/mod.rs +++ b/compiler/src/yul/operations/mod.rs @@ -1,3 +1,4 @@ pub mod abi; pub mod calls; pub mod data; +pub mod structs; diff --git a/compiler/src/yul/operations/structs.rs b/compiler/src/yul/operations/structs.rs new file mode 100644 index 0000000000..e8b1fda948 --- /dev/null +++ b/compiler/src/yul/operations/structs.rs @@ -0,0 +1,39 @@ +use crate::yul::names; +use fe_analyzer::namespace::types::Struct; +use yultsur::*; + +pub fn new(struct_type: &Struct, params: Vec) -> yul::Expression { + let function_name = names::struct_new_call(&struct_type.name); + expression! { [function_name]([params...]) } +} + +pub fn get_attribute(struct_type: &Struct, ptr_name: &str, field_name: &str) -> yul::Expression { + let function_name = names::struct_getter_call(&struct_type.name, field_name); + let ptr_name_exp = identifier_expression! {(ptr_name)}; + expression! { [function_name]([ptr_name_exp]) } +} + +#[cfg(test)] +mod tests { + use crate::yul::operations::structs; + use fe_analyzer::namespace::types::{ + Base, + Struct, + }; + use yultsur::*; + + #[test] + fn test_new() { + let mut val = Struct::new("Foo"); + val.add_field("bar", &Base::Bool); + val.add_field("bar2", &Base::Bool); + let params = vec![ + identifier_expression! { (1) }, + identifier_expression! { (2) }, + ]; + assert_eq!( + structs::new(&val, params).to_string(), + "struct_Foo_new(1, 2)" + ) + } +} diff --git a/compiler/src/yul/runtime/functions/mod.rs b/compiler/src/yul/runtime/functions/mod.rs index 43dd9ff692..bddfefacec 100644 --- a/compiler/src/yul/runtime/functions/mod.rs +++ b/compiler/src/yul/runtime/functions/mod.rs @@ -4,6 +4,7 @@ use yultsur::*; pub mod abi; pub mod calls; pub mod data; +pub mod structs; /// Returns all functions that should be available during runtime. pub fn std() -> Vec { diff --git a/compiler/src/yul/runtime/functions/structs.rs b/compiler/src/yul/runtime/functions/structs.rs new file mode 100644 index 0000000000..56fa06f66a --- /dev/null +++ b/compiler/src/yul/runtime/functions/structs.rs @@ -0,0 +1,136 @@ +use crate::yul::names; +use fe_analyzer::namespace::types::Struct; +use yultsur::*; + +/// Generate a YUL function that can be used to create an instance of +/// `struct_type` +pub fn generate_new_fn(struct_type: &Struct) -> yul::Statement { + let function_name = names::struct_new_call(&struct_type.name); + + if struct_type.is_empty() { + // We return 0 here because it is safe to assume that we never write to an empty + // struct. If we end up writing to an empty struct that's an actual Fe + // bug. + let body = statement! { return_val := 0 }; + return function_definition! { + function [function_name]() -> return_val { + [body] + } + }; + } + + let params = struct_type + .get_field_names() + .iter() + .map(|key| { + identifier! {(key)} + }) + .collect::>(); + + let body = struct_type + .get_field_names() + .iter() + .enumerate() + .map(|(index, key)| { + if index == 0 { + let param_identifier_exp = identifier_expression! {(key)}; + statements! { + (return_val := alloc(32)) + (mstore(return_val, [param_identifier_exp])) + } + } else { + let ptr_identifier = format!("{}_ptr", key); + let ptr_identifier = identifier! {(ptr_identifier)}; + let ptr_identifier_exp = identifier_expression! {(ptr_identifier)}; + let param_identifier_exp = identifier_expression! {(key)}; + statements! { + (let [ptr_identifier] := alloc(32)) + (mstore([ptr_identifier_exp], [param_identifier_exp])) + } + } + }) + .flatten() + .collect::>(); + + function_definition! { + function [function_name]([params...]) -> return_val { + [body...] + } + } +} + +/// Generate a YUL function that can be used to read a property of `struct_type` +pub fn generate_get_fn(struct_type: &Struct, field_name: &str) -> yul::Statement { + let function_name = names::struct_getter_call(&struct_type.name, field_name); + let field_index = struct_type + .get_field_names() + .iter() + .position(|field| field == field_name) + .unwrap_or_else(|| panic!("No field {} in {}", field_name, struct_type.name)); + let field_offset = field_index * 32; + + let offset = literal_expression! {(field_offset)}; + let return_expression = expression! { add(ptr, [offset]) }; + let body = statement! { (return_val := [return_expression]) }; + function_definition! { + function [function_name](ptr) -> return_val { + [body] + } + } +} + +/// Builds a set of functions used to interact with structs used in a contract +pub fn struct_apis(struct_type: Struct) -> Vec { + [ + vec![generate_new_fn(&struct_type)], + struct_type + .get_field_names() + .iter() + .map(|field| generate_get_fn(&struct_type, &field)) + .collect(), + ] + .concat() +} + +#[cfg(test)] +mod tests { + use crate::yul::runtime::functions::structs; + use fe_analyzer::namespace::types::{ + Base, + Struct, + }; + + #[test] + fn test_empty_struct() { + assert_eq!( + structs::generate_new_fn(&Struct::new("Foo")).to_string(), + "function struct_Foo_new() -> return_val { return_val := 0 }" + ) + } + + #[test] + fn test_struct_api_generation() { + let mut val = Struct::new("Foo"); + val.add_field("bar", &Base::Bool); + val.add_field("bar2", &Base::Bool); + assert_eq!( + structs::generate_new_fn(&val).to_string(), + "function struct_Foo_new(bar, bar2) -> return_val { return_val := alloc(32) mstore(return_val, bar) let bar2_ptr := alloc(32) mstore(bar2_ptr, bar2) }" + ) + } + + #[test] + fn test_struct_getter_generation() { + let mut val = Struct::new("Foo"); + val.add_field("bar", &Base::Bool); + val.add_field("bar2", &Base::Bool); + assert_eq!( + structs::generate_get_fn(&val, &val.get_field_names().get(0).unwrap()).to_string(), + "function struct_Foo_get_bar_ptr(ptr) -> return_val { return_val := add(ptr, 0) }" + ); + assert_eq!( + structs::generate_get_fn(&val, &val.get_field_names().get(1).unwrap()).to_string(), + "function struct_Foo_get_bar2_ptr(ptr) -> return_val { return_val := add(ptr, 32) }" + ); + } +} diff --git a/compiler/src/yul/runtime/mod.rs b/compiler/src/yul/runtime/mod.rs index f619013534..bbb2c857e6 100644 --- a/compiler/src/yul/runtime/mod.rs +++ b/compiler/src/yul/runtime/mod.rs @@ -86,7 +86,14 @@ pub fn build(context: &Context, contract: &Spanned) -> Vec>() + .concat(); + + return [std, encoding, decoding, contract_calls, struct_apis].concat(); } panic!("missing contract attributes") diff --git a/compiler/tests/evm_contracts.rs b/compiler/tests/evm_contracts.rs index 5f330b50ee..c3cda0a7c2 100644 --- a/compiler/tests/evm_contracts.rs +++ b/compiler/tests/evm_contracts.rs @@ -952,6 +952,14 @@ fn sized_vals_in_sto() { }); } +#[test] +fn structs() { + with_executor(&|mut executor| { + let harness = deploy_contract(&mut executor, "structs.fe", "Foo", &[]); + harness.test_function(&mut executor, "bar", &[], Some(&uint_token(2))); + }); +} + #[test] fn erc20_token() { with_executor(&|mut executor| { diff --git a/compiler/tests/fixtures/structs.fe b/compiler/tests/fixtures/structs.fe new file mode 100644 index 0000000000..4c86aa1829 --- /dev/null +++ b/compiler/tests/fixtures/structs.fe @@ -0,0 +1,21 @@ +struct House: + price: u256 + size: u256 + vacant: bool + +contract Foo: + + pub def bar() -> u256: + building: House = House(300, 500, true) + assert building.size == 500 + assert building.price == 300 + assert building.vacant + + building.vacant = false + building.price = 1 + building.size = 2 + + assert building.vacant == false + assert building.price == 1 + assert building.size == 2 + return building.size diff --git a/newsfragments/203.feature.md b/newsfragments/203.feature.md new file mode 100644 index 0000000000..d83b0bfa9f --- /dev/null +++ b/newsfragments/203.feature.md @@ -0,0 +1,22 @@ +Add basic support for structs. + +Example: + +``` +struct House: + price: u256 + size: u256 + vacant: bool + +contract City: + + pub def get_price() -> u256: + building: House = House(300, 500, true) + + assert building.size == 500 + assert building.price == 300 + assert building.vacant + + return building.price +``` + diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 7585fc056b..22fd134d70 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -33,6 +33,11 @@ pub enum ModuleStmt<'a> { #[serde(borrow)] body: Vec>>, }, + StructDef { + name: Spanned<&'a str>, + #[serde(borrow)] + body: Vec>>, + }, } #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] @@ -108,12 +113,28 @@ pub enum ContractStmt<'a> { }, } +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum StructStmt<'a> { + StructField { + qual: Option>, + #[serde(borrow)] + name: Spanned<&'a str>, + typ: Spanned>, + }, +} + #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub enum ContractFieldQual { Const, Pub, } +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum StructFieldQual { + Const, + Pub, +} + #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct EventField<'a> { pub qual: Option>, diff --git a/parser/src/ast_traits.rs b/parser/src/ast_traits.rs index 3e7ab13ffb..70a3402f07 100644 --- a/parser/src/ast_traits.rs +++ b/parser/src/ast_traits.rs @@ -24,6 +24,23 @@ impl TryFrom<&Token<'_>> for Spanned { } } +impl TryFrom<&Token<'_>> for Spanned { + type Error = &'static str; + + #[cfg_attr(tarpaulin, rustfmt::skip)] + fn try_from(tok: &Token) -> Result { + use StructFieldQual::*; + + let span = tok.span; + + Ok(match tok.string { + "const" => Spanned { node: Const, span }, + "pub" => Spanned { node: Pub, span }, + _ => return Err("unrecognized string"), + }) + } +} + impl TryFrom<&Token<'_>> for Spanned { type Error = &'static str; diff --git a/parser/src/parsers.rs b/parser/src/parsers.rs index bf921937b8..53c9b3b4ac 100644 --- a/parser/src/parsers.rs +++ b/parser/src/parsers.rs @@ -152,7 +152,7 @@ pub fn non_empty_file_input(input: Cursor) -> ParseResult> { /// Parse a module statement, such as a contract definition. pub fn module_stmt(input: Cursor) -> ParseResult> { - alt((import_stmt, type_def, contract_def))(input) + alt((import_stmt, type_def, contract_def, struct_def))(input) } /// Parse an import statement. @@ -470,6 +470,68 @@ pub fn contract_field(input: Cursor) -> ParseResult> { )) } +/// Parse a struct definition statement. +pub fn struct_def(input: Cursor) -> ParseResult> { + // "struct" name ":" NEWLINE + let (input, contract_kw) = name("struct")(input)?; + let (input, name_tok) = name_token(input)?; + let (input, _) = op(":")(input)?; + let (input, _) = newline_token(input)?; + + // INDENT struct_field+ DEDENT + let (input, _) = indent_token(input)?; + let (input, body) = many1(struct_field)(input)?; + let (input, _) = dedent_token(input)?; + + let last_stmt = body.last().unwrap(); + let span = Span::from_pair(contract_kw, last_stmt); + + Ok(( + input, + Spanned { + node: StructDef { + name: name_tok.into(), + body, + }, + span, + }, + )) +} + +/// Parse a struct field definition. +pub fn struct_field(input: Cursor) -> ParseResult> { + let (input, (qual, name_tok)) = alt(( + // Look for a qualifier and field name first... + map(pair(struct_field_qual, name_token), |res| { + let (qual, tok) = res; + (Some(qual), tok) + }), + // ...then fall back to just a field name + map(name_token, |tok| (None, tok)), + ))(input)?; + + let (input, _) = op(":")(input)?; + let (input, typ) = type_desc(input)?; + let (input, _) = newline_token(input)?; + + let span = match &qual { + Some(spanned) => Span::from_pair(spanned, &typ), + None => Span::from_pair(name_tok, &typ), + }; + + Ok(( + input, + Spanned { + node: StructStmt::StructField { + qual, + name: name_tok.into(), + typ, + }, + span, + }, + )) +} + /// Parse an event definition statement. pub fn event_def(input: Cursor) -> ParseResult> { // "event" name ":" NEWLINE @@ -759,6 +821,11 @@ pub fn contract_field_qual(input: Cursor) -> ParseResult ParseResult> { + try_from_tok(name("pub"))(input) +} + /// Parse an event field qualifier keyword i.e. "idx". pub fn event_field_qual(input: Cursor) -> ParseResult> { try_from_tok(name("idx"))(input) diff --git a/parser/tests/fixtures/parsers/struct_def.ron b/parser/tests/fixtures/parsers/struct_def.ron new file mode 100644 index 0000000000..57110882f4 --- /dev/null +++ b/parser/tests/fixtures/parsers/struct_def.ron @@ -0,0 +1,47 @@ +struct Foo: + x: address +--- +[ + Spanned( + node: StructDef( + name: Spanned( + node: "Foo", + span: Span( + start: 7, + end: 10, + ), + ), + body: [ + Spanned( + node: StructField( + qual: None, + name: Spanned( + node: "x", + span: Span( + start: 16, + end: 17, + ), + ), + typ: Spanned( + node: Base( + base: "address", + ), + span: Span( + start: 19, + end: 26, + ), + ), + ), + span: Span( + start: 16, + end: 26, + ), + ), + ], + ), + span: Span( + start: 0, + end: 26, + ), + ), +] diff --git a/parser/tests/test_parsers.rs b/parser/tests/test_parsers.rs index ec25144d36..7f651a3b3d 100644 --- a/parser/tests/test_parsers.rs +++ b/parser/tests/test_parsers.rs @@ -366,6 +366,12 @@ parser_fixture_tests! { write_contract_def, "fixtures/parsers/contract_def.ron", ), + ( + repeat(struct_def), + test_struct_def, + write_struct_def, + "fixtures/parsers/struct_def.ron", + ), ( repeat(contract_stmt), test_contract_stmt,