From 68a40245315aeab5b0c0c56b2ea6854c8c7563b4 Mon Sep 17 00:00:00 2001 From: Christoph Burgdorf Date: Tue, 26 Jan 2021 16:32:46 +0100 Subject: [PATCH] Add basic support for structs --- analyzer/src/lib.rs | 21 ++- analyzer/src/namespace/operations.rs | 1 + analyzer/src/namespace/types.rs | 79 +++++++++- analyzer/src/traversal/expressions.rs | 87 +++++++++-- analyzer/src/traversal/mod.rs | 1 + analyzer/src/traversal/module.rs | 9 +- analyzer/src/traversal/structs.rs | 36 +++++ compiler/src/yul/mappers/expressions.rs | 35 ++++- compiler/src/yul/mappers/module.rs | 2 +- compiler/src/yul/names.rs | 17 +++ compiler/src/yul/operations/mod.rs | 1 + compiler/src/yul/operations/structs.rs | 39 +++++ compiler/src/yul/runtime/functions/mod.rs | 1 + compiler/src/yul/runtime/functions/structs.rs | 136 ++++++++++++++++++ compiler/src/yul/runtime/mod.rs | 9 +- compiler/tests/evm_contracts.rs | 3 +- compiler/tests/fixtures/structs.fe | 18 ++- newsfragments/203.feature.md | 22 +++ parser/src/ast.rs | 2 +- parser/src/parsers.rs | 2 +- parser/tests/fixtures/parsers/struct_def.ron | 1 + 21 files changed, 494 insertions(+), 28 deletions(-) create mode 100644 analyzer/src/traversal/structs.rs create mode 100644 compiler/src/yul/operations/structs.rs create mode 100644 compiler/src/yul/runtime/functions/structs.rs create mode 100644 newsfragments/203.feature.md diff --git a/analyzer/src/lib.rs b/analyzer/src/lib.rs index f6e523b735..e1f0dbd061 100644 --- a/analyzer/src/lib.rs +++ b/analyzer/src/lib.rs @@ -19,6 +19,7 @@ use crate::namespace::scopes::{ use crate::namespace::types::{ Contract, FixedSize, + Struct, Type, }; use fe_parser::ast as fe; @@ -55,6 +56,7 @@ impl Location { Type::Array(_) => Ok(Location::Memory), Type::Tuple(_) => Ok(Location::Memory), Type::String(_) => Ok(Location::Memory), + Type::Struct(_) => Ok(Location::Memory), Type::Map(_) => Err(SemanticError::cannot_move()), } } @@ -71,6 +73,8 @@ pub struct ContractAttributes { pub events: Vec, /// Static strings that the contract defines pub string_literals: HashSet, + /// Structs that have been defined by the user + pub structs: Vec, /// External contracts that may be called from within this contract. pub external_contracts: Vec, } @@ -108,6 +112,14 @@ impl From> for ContractAttributes { } }); + let structs = scope.borrow().get_module_type_defs(|typ| { + if let Type::Struct(val) = typ { + Some(val.to_owned()) + } else { + None + } + }); + ContractAttributes { public_functions, init_function, @@ -118,6 +130,7 @@ impl From> for ContractAttributes { .map(|event| event.to_owned()) .collect::>(), string_literals: scope.borrow().string_defs.clone(), + structs, external_contracts, } } @@ -248,12 +261,8 @@ impl Context { } /// Attribute contextual information to an expression node. - pub fn add_expression( - &mut self, - spanned: &Spanned, - attributes: ExpressionAttributes, - ) { - self.expressions.insert(spanned.span, attributes); + pub fn add_expression>(&mut self, span: T, attributes: ExpressionAttributes) { + self.expressions.insert(span.into(), attributes); } /// Get information that has been attributed to an expression node. diff --git a/analyzer/src/namespace/operations.rs b/analyzer/src/namespace/operations.rs index 306d5378a7..f1553a927d 100644 --- a/analyzer/src/namespace/operations.rs +++ b/analyzer/src/namespace/operations.rs @@ -17,6 +17,7 @@ pub fn index(value: Type, index: Type) -> Result { Type::Tuple(_) => Err(SemanticError::not_subscriptable()), Type::String(_) => Err(SemanticError::not_subscriptable()), Type::Contract(_) => Err(SemanticError::not_subscriptable()), + Type::Struct(_) => Err(SemanticError::not_subscriptable()), } } diff --git a/analyzer/src/namespace/types.rs b/analyzer/src/namespace/types.rs index 36b9e2e99c..e083837b43 100644 --- a/analyzer/src/namespace/types.rs +++ b/analyzer/src/namespace/types.rs @@ -1,6 +1,9 @@ use crate::errors::SemanticError; use fe_parser::ast as fe; -use std::collections::HashMap; +use std::collections::{ + BTreeMap, + HashMap, +}; use std::convert::TryFrom; use std::num::{ IntErrorKind, @@ -103,6 +106,7 @@ pub enum Type { Tuple(Tuple), String(FeString), Contract(Contract), + Struct(Struct), } #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] @@ -112,6 +116,7 @@ pub enum FixedSize { Tuple(Tuple), String(FeString), Contract(Contract), + Struct(Struct), } #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] @@ -157,6 +162,12 @@ pub struct Tuple { pub items: Vec, } +#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] +pub struct Struct { + pub name: String, + fields: BTreeMap, +} + #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] pub struct FeString { pub max_size: usize, @@ -168,6 +179,40 @@ pub struct Contract { pub functions: Vec, } +impl Struct { + pub fn new(name: &str) -> Struct { + Struct { + name: name.to_string(), + fields: BTreeMap::new(), + } + } + + // Return `true` if the struct has any fields, otherwise return `false` + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// Add a field to the struct + pub fn add_field(&mut self, name: &str, value: &Base) -> Option { + self.fields.insert(name.to_string(), value.clone()) + } + + // Return the type of the given field name + pub fn get_field_type(&self, name: &str) -> Option<&Base> { + self.fields.get(name) + } + + // Return a vector of field types + pub fn get_field_types(&self) -> Vec { + self.fields.values().map(|val| val.clone().into()).collect() + } + + //Return a vector of field names + pub fn get_field_names(&self) -> Vec { + self.fields.keys().cloned().collect() + } +} + impl TryFrom<&str> for FeString { type Error = String; @@ -258,10 +303,17 @@ impl From for Type { FixedSize::Tuple(tuple) => Type::Tuple(tuple), FixedSize::String(string) => Type::String(string), FixedSize::Contract(contract) => Type::Contract(contract), + FixedSize::Struct(val) => Type::Struct(val), } } } +impl From for Type { + fn from(value: Base) -> Self { + Type::Base(value) + } +} + impl FeSized for FixedSize { fn size(&self) -> usize { match self { @@ -270,6 +322,7 @@ impl FeSized for FixedSize { FixedSize::Tuple(tuple) => tuple.size(), FixedSize::String(string) => string.size(), FixedSize::Contract(contract) => contract.size(), + FixedSize::Struct(val) => val.size(), } } } @@ -301,6 +354,7 @@ impl AbiEncoding for FixedSize { FixedSize::Tuple(tuple) => tuple.abi_name(), FixedSize::String(string) => string.abi_name(), FixedSize::Contract(contract) => contract.abi_name(), + FixedSize::Struct(val) => val.abi_name(), } } @@ -311,6 +365,7 @@ impl AbiEncoding for FixedSize { FixedSize::Tuple(tuple) => tuple.abi_safe_name(), FixedSize::String(string) => string.abi_safe_name(), FixedSize::Contract(contract) => contract.abi_safe_name(), + FixedSize::Struct(val) => val.abi_safe_name(), } } @@ -321,6 +376,7 @@ impl AbiEncoding for FixedSize { FixedSize::Tuple(tuple) => tuple.abi_type(), FixedSize::String(string) => string.abi_type(), FixedSize::Contract(contract) => contract.abi_type(), + FixedSize::Struct(val) => val.abi_type(), } } } @@ -348,6 +404,7 @@ impl TryFrom for FixedSize { Type::Base(base) => Ok(FixedSize::Base(base)), Type::Tuple(tuple) => Ok(FixedSize::Tuple(tuple)), Type::String(string) => Ok(FixedSize::String(string)), + Type::Struct(val) => Ok(FixedSize::Struct(val)), Type::Map(_) => Err(SemanticError::type_error()), Type::Contract(contract) => Ok(FixedSize::Contract(contract)), } @@ -543,6 +600,26 @@ impl FeSized for Tuple { } } +impl FeSized for Struct { + fn size(&self) -> usize { + self.fields.len() * 32 + } +} + +impl AbiEncoding for Struct { + fn abi_name(&self) -> String { + unimplemented!(); + } + + fn abi_safe_name(&self) -> String { + unimplemented!(); + } + + fn abi_type(&self) -> AbiType { + unimplemented!(); + } +} + impl AbiEncoding for Tuple { fn abi_name(&self) -> String { unimplemented!(); diff --git a/analyzer/src/traversal/expressions.rs b/analyzer/src/traversal/expressions.rs index 34d8585e48..29b743f04b 100644 --- a/analyzer/src/traversal/expressions.rs +++ b/analyzer/src/traversal/expressions.rs @@ -13,6 +13,7 @@ use crate::namespace::types::{ FeString, FixedSize, Integer, + Struct, Type, U256, }; @@ -45,7 +46,7 @@ pub fn expr( fe::Expr::Num(_) => expr_num(exp), fe::Expr::Bool(_) => expr_bool(exp), fe::Expr::Subscript { .. } => expr_subscript(scope, Rc::clone(&context), exp), - fe::Expr::Attribute { .. } => expr_attribute(scope, exp), + fe::Expr::Attribute { .. } => expr_attribute(scope, Rc::clone(&context), exp), fe::Expr::Ternary { .. } => expr_ternary(scope, Rc::clone(&context), exp), fe::Expr::BoolOperation { .. } => unimplemented!(), fe::Expr::BinOperation { .. } => expr_bin_operation(scope, Rc::clone(&context), exp), @@ -138,7 +139,6 @@ pub fn expr_name_str<'a>(exp: &Spanned>) -> Result<&'a str, Semanti if let fe::Expr::Name(name) = exp.node { return Ok(name); } - unreachable!() } @@ -198,6 +198,10 @@ fn expr_name( Location::Memory, )), Some(FixedSize::Tuple(_)) => unimplemented!(), + Some(FixedSize::Struct(val)) => Ok(ExpressionAttributes::new( + Type::Struct(val), + Location::Memory, + )), None => Err(SemanticError::undefined_value()), }; } @@ -275,6 +279,7 @@ fn expr_subscript( fn expr_attribute( scope: Shared, + context: Shared, exp: &Spanned, ) -> Result { if let fe::Expr::Attribute { value, attr } = &exp.node { @@ -286,10 +291,18 @@ fn expr_attribute( TxField, }; + let object_name = expr_name_str(value)?; + + // Before we try to match any pre-defined objects, try matching as a + // custom type + if let Some(FixedSize::Struct(_)) = scope.borrow().variable_def(object_name.to_string()) { + return expr_attribute_custom_type(Rc::clone(&scope), context, value, attr); + } + let val = |t| Ok(ExpressionAttributes::new(Type::Base(t), Location::Value)); let err = || Err(SemanticError::undefined_value()); - return match Object::from_str(expr_name_str(value)?) { + return match Object::from_str(object_name) { Ok(Object::Self_) => expr_attribute_self(scope, attr), Ok(Object::Block) => match BlockField::from_str(attr.node) { @@ -322,6 +335,35 @@ fn expr_attribute( unreachable!() } +fn expr_attribute_custom_type( + scope: Shared, + context: Shared, + value: &Spanned, + attr: &Spanned<&str>, +) -> Result { + let val_str = expr_name_str(value)?; + let custom_type = scope + .borrow() + .variable_def(val_str.to_string()) + .ok_or_else(SemanticError::undefined_value)?; + context.borrow_mut().add_expression( + value, + ExpressionAttributes::new(custom_type.clone().into(), Location::Memory), + ); + match custom_type { + FixedSize::Struct(val) => { + let field_type = val + .get_field_type(attr.node) + .ok_or_else(SemanticError::undefined_value)?; + Ok(ExpressionAttributes::new( + Type::Base(field_type.clone()), + Location::Memory, + )) + } + _ => Err(SemanticError::undefined_value()), + } +} + fn expr_attribute_self( scope: Shared, attr: &Spanned<&str>, @@ -431,12 +473,34 @@ fn expr_call( unreachable!() } +fn expr_call_struct_constructor( + scope: Shared, + context: Shared, + typ: Struct, + args: &Spanned>>, +) -> Result { + let argument_attributes = expr_call_args(Rc::clone(&scope), Rc::clone(&context), args)?; + + if typ.get_field_types() != expression_attributes_to_types(argument_attributes) { + return Err(SemanticError::type_error()); + } + + Ok(ExpressionAttributes::new( + Type::Struct(typ), + Location::Memory, + )) +} + fn expr_call_type_constructor( scope: Shared, context: Shared, typ: Type, args: &Spanned>>, ) -> Result { + if let Type::Struct(val) = typ { + return expr_call_struct_constructor(scope, context, val, args); + } + if args.node.len() != 1 { return Err(SemanticError::wrong_number_of_params()); } @@ -507,6 +571,17 @@ fn validate_str_literal_fits_type( Err(SemanticError::type_error()) } +fn expr_call_args( + scope: Shared, + context: Shared, + args: &Spanned>>, +) -> Result, SemanticError> { + args.node + .iter() + .map(|arg| call_arg(Rc::clone(&scope), Rc::clone(&context), arg)) + .collect::, _>>() +} + fn expr_call_self_attribute( scope: Shared, context: Shared, @@ -524,11 +599,7 @@ fn expr_call_self_attribute( .borrow() .function_def(func_name) { - let argument_attributes = args - .node - .iter() - .map(|arg| call_arg(Rc::clone(&scope), Rc::clone(&context), arg)) - .collect::, _>>()?; + let argument_attributes = expr_call_args(Rc::clone(&scope), Rc::clone(&context), args)?; if param_types.len() != argument_attributes.len() { return Err(SemanticError::wrong_number_of_params()); diff --git a/analyzer/src/traversal/mod.rs b/analyzer/src/traversal/mod.rs index 7ee3c758c7..249228cea2 100644 --- a/analyzer/src/traversal/mod.rs +++ b/analyzer/src/traversal/mod.rs @@ -5,4 +5,5 @@ mod declarations; mod expressions; mod functions; pub mod module; +mod structs; mod types; diff --git a/analyzer/src/traversal/module.rs b/analyzer/src/traversal/module.rs index 90d4b084af..07d5856ba4 100644 --- a/analyzer/src/traversal/module.rs +++ b/analyzer/src/traversal/module.rs @@ -4,7 +4,10 @@ use crate::namespace::scopes::{ Shared, }; use crate::namespace::types; -use crate::traversal::contracts; +use crate::traversal::{ + contracts, + structs, +}; use crate::Context; use fe_parser::ast as fe; use fe_parser::span::Spanned; @@ -17,10 +20,12 @@ pub fn module(context: Shared, module: &fe::Module) -> Result<(), Seman for stmt in module.body.iter() { match &stmt.node { fe::ModuleStmt::TypeDef { .. } => type_def(Rc::clone(&scope), stmt)?, + fe::ModuleStmt::StructDef { name, body } => { + structs::struct_def(Rc::clone(&scope), name.node, body)? + } fe::ModuleStmt::ContractDef { .. } => { contracts::contract_def(Rc::clone(&scope), Rc::clone(&context), stmt)? } - fe::ModuleStmt::StructDef { .. } => unimplemented!(), fe::ModuleStmt::FromImport { .. } => unimplemented!(), fe::ModuleStmt::SimpleImport { .. } => unimplemented!(), } diff --git a/analyzer/src/traversal/structs.rs b/analyzer/src/traversal/structs.rs new file mode 100644 index 0000000000..6a870521fd --- /dev/null +++ b/analyzer/src/traversal/structs.rs @@ -0,0 +1,36 @@ +use fe_parser::{ + ast::StructStmt, + span::Spanned, +}; + +use crate::errors::SemanticError; +use crate::namespace::scopes::{ + ModuleScope, + Shared, +}; +use crate::namespace::types::{ + type_desc, + Struct, + Type, +}; + +pub fn struct_def( + module_scope: Shared, + name: &str, + struct_stmts: &[Spanned], +) -> Result<(), SemanticError> { + let mut val = Struct::new(name); + for stmt in struct_stmts { + let StructStmt::StructField { name, typ, .. } = &stmt.node; + let field_type = type_desc(&module_scope.borrow().type_defs, &typ.node)?; + if let Type::Base(base_typ) = field_type { + val.add_field(name.node, &base_typ); + } else { + todo!("Non-Base type fields aren't yet supported") + } + } + module_scope + .borrow_mut() + .add_type_def(name.to_string(), Type::Struct(val)); + Ok(()) +} diff --git a/compiler/src/yul/mappers/expressions.rs b/compiler/src/yul/mappers/expressions.rs index b830023234..b46305a5ed 100644 --- a/compiler/src/yul/mappers/expressions.rs +++ b/compiler/src/yul/mappers/expressions.rs @@ -1,13 +1,19 @@ use crate::errors::CompileError; use crate::yul::names; -use crate::yul::operations::calls as call_operations; -use crate::yul::operations::data as data_operations; +use crate::yul::operations::{ + calls as call_operations, + data as data_operations, + structs as struct_operations, +}; use crate::yul::utils; -use fe_analyzer::builtins; use fe_analyzer::namespace::types::{ FixedSize, Type, }; +use fe_analyzer::{ + builtins, + ExpressionAttributes, +}; use fe_analyzer::{ CallType, Context, @@ -98,6 +104,9 @@ fn expr_call(context: &Context, exp: &Spanned) -> Result>()?; return match call_type { + CallType::TypeConstructor { + typ: Type::Struct(val), + } => Ok(struct_operations::new(val, yul_args)), CallType::TypeConstructor { .. } => Ok(yul_args[0].to_owned()), CallType::SelfAttribute { func_name } => { let func_name = names::func_name(func_name); @@ -341,7 +350,25 @@ fn expr_attribute( Object, TxField, }; - return match Object::from_str(expr_name_str(value)?) { + + let object_name = expr_name_str(value)?; + + // Before we try to match any known pre-defined objects, try matching as a + // custom type + if let Some(ExpressionAttributes { + typ: Type::Struct(val), + .. + }) = context.get_expression(&*value) + { + let custom_type = format!("${}", object_name); + return Ok(struct_operations::get_attribute( + val, + &custom_type, + attr.node, + )); + } + + return match Object::from_str(object_name) { Ok(Object::Self_) => expr_attribute_self(context, exp), Ok(Object::Block) => match BlockField::from_str(attr.node) { diff --git a/compiler/src/yul/mappers/module.rs b/compiler/src/yul/mappers/module.rs index d33ac4d1ea..218e219368 100644 --- a/compiler/src/yul/mappers/module.rs +++ b/compiler/src/yul/mappers/module.rs @@ -22,7 +22,7 @@ pub fn module(context: &Context, module: &fe::Module) -> Result unimplemented!(), + fe::ModuleStmt::StructDef { .. } => {} fe::ModuleStmt::FromImport { .. } => unimplemented!(), fe::ModuleStmt::SimpleImport { .. } => unimplemented!(), } diff --git a/compiler/src/yul/names.rs b/compiler/src/yul/names.rs index 34b58e155e..91287dbc3a 100644 --- a/compiler/src/yul/names.rs +++ b/compiler/src/yul/names.rs @@ -45,6 +45,23 @@ pub fn contract_call(contract_name: &str, func_name: &str) -> yul::Identifier { identifier! { (name) } } +/// Generates a function name for to interact with a certain struct type +pub fn struct_function_name(struct_name: &str, func_name: &str) -> yul::Identifier { + let name = format!("struct_{}_{}", struct_name, func_name); + identifier! { (name) } +} + +/// Generates a function name for creating a certain struct type +pub fn struct_new_call(struct_name: &str) -> yul::Identifier { + struct_function_name(struct_name, "new") +} + +/// Generates a function name for reading a named property of a certain struct +/// type +pub fn struct_getter_call(struct_name: &str, field_name: &str) -> yul::Identifier { + struct_function_name(struct_name, &format!("get_{}_ptr", field_name)) +} + #[cfg(test)] mod tests { use crate::yul::names::{ diff --git a/compiler/src/yul/operations/mod.rs b/compiler/src/yul/operations/mod.rs index 20f72dcd7e..a65780f2c8 100644 --- a/compiler/src/yul/operations/mod.rs +++ b/compiler/src/yul/operations/mod.rs @@ -1,3 +1,4 @@ pub mod abi; pub mod calls; pub mod data; +pub mod structs; diff --git a/compiler/src/yul/operations/structs.rs b/compiler/src/yul/operations/structs.rs new file mode 100644 index 0000000000..e8b1fda948 --- /dev/null +++ b/compiler/src/yul/operations/structs.rs @@ -0,0 +1,39 @@ +use crate::yul::names; +use fe_analyzer::namespace::types::Struct; +use yultsur::*; + +pub fn new(struct_type: &Struct, params: Vec) -> yul::Expression { + let function_name = names::struct_new_call(&struct_type.name); + expression! { [function_name]([params...]) } +} + +pub fn get_attribute(struct_type: &Struct, ptr_name: &str, field_name: &str) -> yul::Expression { + let function_name = names::struct_getter_call(&struct_type.name, field_name); + let ptr_name_exp = identifier_expression! {(ptr_name)}; + expression! { [function_name]([ptr_name_exp]) } +} + +#[cfg(test)] +mod tests { + use crate::yul::operations::structs; + use fe_analyzer::namespace::types::{ + Base, + Struct, + }; + use yultsur::*; + + #[test] + fn test_new() { + let mut val = Struct::new("Foo"); + val.add_field("bar", &Base::Bool); + val.add_field("bar2", &Base::Bool); + let params = vec![ + identifier_expression! { (1) }, + identifier_expression! { (2) }, + ]; + assert_eq!( + structs::new(&val, params).to_string(), + "struct_Foo_new(1, 2)" + ) + } +} diff --git a/compiler/src/yul/runtime/functions/mod.rs b/compiler/src/yul/runtime/functions/mod.rs index 43dd9ff692..bddfefacec 100644 --- a/compiler/src/yul/runtime/functions/mod.rs +++ b/compiler/src/yul/runtime/functions/mod.rs @@ -4,6 +4,7 @@ use yultsur::*; pub mod abi; pub mod calls; pub mod data; +pub mod structs; /// Returns all functions that should be available during runtime. pub fn std() -> Vec { diff --git a/compiler/src/yul/runtime/functions/structs.rs b/compiler/src/yul/runtime/functions/structs.rs new file mode 100644 index 0000000000..56fa06f66a --- /dev/null +++ b/compiler/src/yul/runtime/functions/structs.rs @@ -0,0 +1,136 @@ +use crate::yul::names; +use fe_analyzer::namespace::types::Struct; +use yultsur::*; + +/// Generate a YUL function that can be used to create an instance of +/// `struct_type` +pub fn generate_new_fn(struct_type: &Struct) -> yul::Statement { + let function_name = names::struct_new_call(&struct_type.name); + + if struct_type.is_empty() { + // We return 0 here because it is safe to assume that we never write to an empty + // struct. If we end up writing to an empty struct that's an actual Fe + // bug. + let body = statement! { return_val := 0 }; + return function_definition! { + function [function_name]() -> return_val { + [body] + } + }; + } + + let params = struct_type + .get_field_names() + .iter() + .map(|key| { + identifier! {(key)} + }) + .collect::>(); + + let body = struct_type + .get_field_names() + .iter() + .enumerate() + .map(|(index, key)| { + if index == 0 { + let param_identifier_exp = identifier_expression! {(key)}; + statements! { + (return_val := alloc(32)) + (mstore(return_val, [param_identifier_exp])) + } + } else { + let ptr_identifier = format!("{}_ptr", key); + let ptr_identifier = identifier! {(ptr_identifier)}; + let ptr_identifier_exp = identifier_expression! {(ptr_identifier)}; + let param_identifier_exp = identifier_expression! {(key)}; + statements! { + (let [ptr_identifier] := alloc(32)) + (mstore([ptr_identifier_exp], [param_identifier_exp])) + } + } + }) + .flatten() + .collect::>(); + + function_definition! { + function [function_name]([params...]) -> return_val { + [body...] + } + } +} + +/// Generate a YUL function that can be used to read a property of `struct_type` +pub fn generate_get_fn(struct_type: &Struct, field_name: &str) -> yul::Statement { + let function_name = names::struct_getter_call(&struct_type.name, field_name); + let field_index = struct_type + .get_field_names() + .iter() + .position(|field| field == field_name) + .unwrap_or_else(|| panic!("No field {} in {}", field_name, struct_type.name)); + let field_offset = field_index * 32; + + let offset = literal_expression! {(field_offset)}; + let return_expression = expression! { add(ptr, [offset]) }; + let body = statement! { (return_val := [return_expression]) }; + function_definition! { + function [function_name](ptr) -> return_val { + [body] + } + } +} + +/// Builds a set of functions used to interact with structs used in a contract +pub fn struct_apis(struct_type: Struct) -> Vec { + [ + vec![generate_new_fn(&struct_type)], + struct_type + .get_field_names() + .iter() + .map(|field| generate_get_fn(&struct_type, &field)) + .collect(), + ] + .concat() +} + +#[cfg(test)] +mod tests { + use crate::yul::runtime::functions::structs; + use fe_analyzer::namespace::types::{ + Base, + Struct, + }; + + #[test] + fn test_empty_struct() { + assert_eq!( + structs::generate_new_fn(&Struct::new("Foo")).to_string(), + "function struct_Foo_new() -> return_val { return_val := 0 }" + ) + } + + #[test] + fn test_struct_api_generation() { + let mut val = Struct::new("Foo"); + val.add_field("bar", &Base::Bool); + val.add_field("bar2", &Base::Bool); + assert_eq!( + structs::generate_new_fn(&val).to_string(), + "function struct_Foo_new(bar, bar2) -> return_val { return_val := alloc(32) mstore(return_val, bar) let bar2_ptr := alloc(32) mstore(bar2_ptr, bar2) }" + ) + } + + #[test] + fn test_struct_getter_generation() { + let mut val = Struct::new("Foo"); + val.add_field("bar", &Base::Bool); + val.add_field("bar2", &Base::Bool); + assert_eq!( + structs::generate_get_fn(&val, &val.get_field_names().get(0).unwrap()).to_string(), + "function struct_Foo_get_bar_ptr(ptr) -> return_val { return_val := add(ptr, 0) }" + ); + assert_eq!( + structs::generate_get_fn(&val, &val.get_field_names().get(1).unwrap()).to_string(), + "function struct_Foo_get_bar2_ptr(ptr) -> return_val { return_val := add(ptr, 32) }" + ); + } +} diff --git a/compiler/src/yul/runtime/mod.rs b/compiler/src/yul/runtime/mod.rs index f619013534..bbb2c857e6 100644 --- a/compiler/src/yul/runtime/mod.rs +++ b/compiler/src/yul/runtime/mod.rs @@ -86,7 +86,14 @@ pub fn build(context: &Context, contract: &Spanned) -> Vec>() + .concat(); + + return [std, encoding, decoding, contract_calls, struct_apis].concat(); } panic!("missing contract attributes") diff --git a/compiler/tests/evm_contracts.rs b/compiler/tests/evm_contracts.rs index f571883ad6..c3cda0a7c2 100644 --- a/compiler/tests/evm_contracts.rs +++ b/compiler/tests/evm_contracts.rs @@ -955,7 +955,8 @@ fn sized_vals_in_sto() { #[test] fn structs() { with_executor(&|mut executor| { - let harness = deploy_contract(&mut executor, "structs.fe", "Foo", vec![]); + let harness = deploy_contract(&mut executor, "structs.fe", "Foo", &[]); + harness.test_function(&mut executor, "bar", &[], Some(&uint_token(2))); }); } diff --git a/compiler/tests/fixtures/structs.fe b/compiler/tests/fixtures/structs.fe index 912f9a4ca1..4c86aa1829 100644 --- a/compiler/tests/fixtures/structs.fe +++ b/compiler/tests/fixtures/structs.fe @@ -1,7 +1,21 @@ struct House: price: u256 size: u256 + vacant: bool +contract Foo: -contract City: - pub house: House + pub def bar() -> u256: + building: House = House(300, 500, true) + assert building.size == 500 + assert building.price == 300 + assert building.vacant + + building.vacant = false + building.price = 1 + building.size = 2 + + assert building.vacant == false + assert building.price == 1 + assert building.size == 2 + return building.size diff --git a/newsfragments/203.feature.md b/newsfragments/203.feature.md new file mode 100644 index 0000000000..d83b0bfa9f --- /dev/null +++ b/newsfragments/203.feature.md @@ -0,0 +1,22 @@ +Add basic support for structs. + +Example: + +``` +struct House: + price: u256 + size: u256 + vacant: bool + +contract City: + + pub def get_price() -> u256: + building: House = House(300, 500, true) + + assert building.size == 500 + assert building.price == 300 + assert building.vacant + + return building.price +``` + diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 706e355666..22fd134d70 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -116,7 +116,7 @@ pub enum ContractStmt<'a> { #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub enum StructStmt<'a> { StructField { - //qual: Option>, + qual: Option>, #[serde(borrow)] name: Spanned<&'a str>, typ: Spanned>, diff --git a/parser/src/parsers.rs b/parser/src/parsers.rs index 9eae0e721f..53c9b3b4ac 100644 --- a/parser/src/parsers.rs +++ b/parser/src/parsers.rs @@ -523,7 +523,7 @@ pub fn struct_field(input: Cursor) -> ParseResult> { input, Spanned { node: StructStmt::StructField { - //qual, + qual, name: name_tok.into(), typ, }, diff --git a/parser/tests/fixtures/parsers/struct_def.ron b/parser/tests/fixtures/parsers/struct_def.ron index f5ebf1df8d..57110882f4 100644 --- a/parser/tests/fixtures/parsers/struct_def.ron +++ b/parser/tests/fixtures/parsers/struct_def.ron @@ -14,6 +14,7 @@ struct Foo: body: [ Spanned( node: StructField( + qual: None, name: Spanned( node: "x", span: Span(