diff --git a/analyzer/src/traversal/assignments.rs b/analyzer/src/traversal/assignments.rs index 463a35e73d..ff1d145607 100644 --- a/analyzer/src/traversal/assignments.rs +++ b/analyzer/src/traversal/assignments.rs @@ -3,6 +3,10 @@ use crate::namespace::scopes::{ BlockScope, Shared, }; +use crate::namespace::types::{ + Base, + Type, +}; use crate::traversal::expressions; use crate::Context; use crate::Location; @@ -50,6 +54,36 @@ pub fn assign( unreachable!() } +/// Gather context information for assignments and check for type errors. +pub fn aug_assign( + scope: Shared, + context: Shared, + stmt: &Node, +) -> Result<(), SemanticError> { + if let fe::FuncStmt::AugAssign { + target, + op: _, + value, + } = &stmt.kind + { + let target_attributes = expressions::expr(Rc::clone(&scope), Rc::clone(&context), target)?; + let value_attributes = expressions::expr(scope, context, value)?; + + return match target_attributes.typ { + Type::Base(Base::Numeric(_)) => { + if target_attributes.typ == value_attributes.typ { + Ok(()) + } else { + Err(SemanticError::type_error()) + } + } + _ => Err(SemanticError::type_error()), + }; + } + + unreachable!() +} + #[cfg(test)] mod tests { use crate::errors::{ diff --git a/analyzer/src/traversal/functions.rs b/analyzer/src/traversal/functions.rs index 3c798413d6..75d273faf7 100644 --- a/analyzer/src/traversal/functions.rs +++ b/analyzer/src/traversal/functions.rs @@ -185,7 +185,7 @@ fn func_stmt( fe::FuncStmt::VarDecl { .. } => declarations::var_decl(scope, context, stmt), fe::FuncStmt::Assign { .. } => assignments::assign(scope, context, stmt), fe::FuncStmt::Emit { .. } => emit(scope, context, stmt), - fe::FuncStmt::AugAssign { .. } => unimplemented!(), + fe::FuncStmt::AugAssign { .. } => assignments::aug_assign(scope, context, stmt), fe::FuncStmt::For { .. } => for_loop(scope, context, stmt), fe::FuncStmt::While { .. } => while_loop(scope, context, stmt), fe::FuncStmt::If { .. } => if_statement(scope, context, stmt), diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index edb52165bf..da1c80e595 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -40,7 +40,7 @@ pub fn compile( let json_abis = abi::build(&context, &fe_module)?; // lower the AST - let lowered_fe_module = lowering::lower(&context, &fe_module); + let lowered_fe_module = lowering::lower(&context, fe_module.clone()); // analyze the lowered AST let context = fe_analyzer::analyze(&lowered_fe_module) @@ -81,8 +81,9 @@ pub fn compile( .collect::(); Ok(CompiledModule { - fe_tokens: format!("{:#?}", fe_tokens), - fe_ast: format!("{:#?}", fe_module), + src_tokens: format!("{:#?}", fe_tokens), + src_ast: format!("{:#?}", fe_module), + lowered_ast: format!("{:#?}", lowered_fe_module), contracts, }) } diff --git a/compiler/src/lowering/mappers/assignments.rs b/compiler/src/lowering/mappers/assignments.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/compiler/src/lowering/mappers/assignments.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/compiler/src/lowering/mappers/contracts.rs b/compiler/src/lowering/mappers/contracts.rs index 8b13789179..e68889ade5 100644 --- a/compiler/src/lowering/mappers/contracts.rs +++ b/compiler/src/lowering/mappers/contracts.rs @@ -1 +1,30 @@ +use fe_analyzer::Context; +use crate::lowering::mappers::functions; +use fe_parser::ast as fe; +use fe_parser::ast::ContractStmt; +use fe_parser::node::Node; + +/// Lowers a contract definition. +pub fn contract_def(context: &Context, stmt: Node) -> Node { + if let fe::ModuleStmt::ContractDef { name, body } = stmt.kind { + let lowered_body = body + .into_iter() + .map(|stmt| match stmt.kind { + ContractStmt::ContractField { .. } => stmt, + ContractStmt::EventDef { .. } => stmt, + ContractStmt::FuncDef { .. } => functions::func_def(context, stmt), + }) + .collect(); + + return Node::new( + fe::ModuleStmt::ContractDef { + name, + body: lowered_body, + }, + stmt.span, + ); + } + + unreachable!() +} diff --git a/compiler/src/lowering/mappers/declarations.rs b/compiler/src/lowering/mappers/declarations.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/compiler/src/lowering/mappers/declarations.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/compiler/src/lowering/mappers/expressions.rs b/compiler/src/lowering/mappers/expressions.rs index 8b13789179..2e4ca5d53b 100644 --- a/compiler/src/lowering/mappers/expressions.rs +++ b/compiler/src/lowering/mappers/expressions.rs @@ -1 +1,123 @@ +use fe_analyzer::Context; +use fe_parser::ast as fe; +use fe_parser::node::Node; +/// Lowers an expression and all sub expressions. +pub fn expr(context: &Context, exp: Node) -> Node { + let lowered_kind = match exp.kind { + fe::Expr::Name(_) => exp.kind, + fe::Expr::Num(_) => exp.kind, + fe::Expr::Bool(_) => exp.kind, + fe::Expr::Subscript { value, slices } => fe::Expr::Subscript { + value: boxed_expr(context, value), + slices: slices_index_expr(context, slices), + }, + fe::Expr::Attribute { value, attr } => fe::Expr::Attribute { + value: boxed_expr(context, value), + attr, + }, + fe::Expr::Ternary { + if_expr, + test, + else_expr, + } => fe::Expr::Ternary { + if_expr: boxed_expr(context, if_expr), + test: boxed_expr(context, test), + else_expr: boxed_expr(context, else_expr), + }, + fe::Expr::BoolOperation { left, op, right } => fe::Expr::BoolOperation { + left: boxed_expr(context, left), + op, + right: boxed_expr(context, right), + }, + fe::Expr::BinOperation { left, op, right } => fe::Expr::BinOperation { + left: boxed_expr(context, left), + op, + right: boxed_expr(context, right), + }, + fe::Expr::UnaryOperation { op, operand } => fe::Expr::UnaryOperation { + op, + operand: boxed_expr(context, operand), + }, + fe::Expr::CompOperation { left, op, right } => fe::Expr::CompOperation { + left: boxed_expr(context, left), + op, + right: boxed_expr(context, right), + }, + fe::Expr::Call { func, args } => fe::Expr::Call { + func: boxed_expr(context, func), + args: call_args(context, args), + }, + fe::Expr::List { .. } => unimplemented!(), + fe::Expr::ListComp { .. } => unimplemented!(), + // We only accept empty tuples for now. We may want to completely eliminate tuple + // expressions before the Yul codegen pass, tho. + fe::Expr::Tuple { .. } => exp.kind, + fe::Expr::Str(_) => exp.kind, + fe::Expr::Ellipsis => unimplemented!(), + }; + + Node::new(lowered_kind, exp.span) +} + +fn slices_index_expr( + context: &Context, + slices: Node>>, +) -> Node>> { + let first_slice = &slices.kind[0]; + + if let fe::Slice::Index(exp) = &first_slice.kind { + return Node::new( + vec![Node::new( + fe::Slice::Index(Box::new(expr(context, *exp.to_owned()))), + first_slice.span, + )], + slices.span, + ); + } + + unreachable!() +} + +/// Lowers and optional expression. +pub fn optional_expr(context: &Context, exp: Option>) -> Option> { + exp.map(|exp| expr(context, exp)) +} + +/// Lowers a boxed expression. +#[allow(clippy::boxed_local)] +pub fn boxed_expr(context: &Context, exp: Box>) -> Box> { + Box::new(expr(context, *exp)) +} + +/// Lowers a list of expression. +pub fn multiple_exprs(context: &Context, exp: Vec>) -> Vec> { + exp.into_iter().map(|exp| expr(context, exp)).collect() +} + +fn call_args( + context: &Context, + args: Node>>, +) -> Node>> { + let lowered_args = args + .kind + .into_iter() + .map(|arg| match arg.kind { + fe::CallArg::Arg(inner_arg) => { + Node::new(fe::CallArg::Arg(expr(context, inner_arg)), arg.span) + } + fe::CallArg::Kwarg(inner_arg) => { + Node::new(fe::CallArg::Kwarg(kwarg(context, inner_arg)), arg.span) + } + }) + .collect(); + + Node::new(lowered_args, args.span) +} + +fn kwarg(context: &Context, kwarg: fe::Kwarg) -> fe::Kwarg { + fe::Kwarg { + name: kwarg.name, + value: boxed_expr(context, kwarg.value), + } +} diff --git a/compiler/src/lowering/mappers/functions.rs b/compiler/src/lowering/mappers/functions.rs index 8b13789179..a00b01ca79 100644 --- a/compiler/src/lowering/mappers/functions.rs +++ b/compiler/src/lowering/mappers/functions.rs @@ -1 +1,130 @@ +use crate::lowering::mappers::expressions; +use fe_analyzer::Context; +use fe_parser::ast as fe; +use fe_parser::node::Node; +/// Lowers a function definition. +pub fn func_def(context: &Context, def: Node) -> Node { + if let fe::ContractStmt::FuncDef { + qual, + name, + args, + return_type, + body, + } = def.kind + { + let lowered_body = multiple_stmts(context, body); + + let lowered_kind = fe::ContractStmt::FuncDef { + qual, + name, + args, + return_type, + body: lowered_body, + }; + + return Node::new(lowered_kind, def.span); + } + + unreachable!() +} + +fn func_stmt(context: &Context, stmt: Node) -> Vec> { + let lowered_kinds = match stmt.kind { + fe::FuncStmt::Return { value } => vec![fe::FuncStmt::Return { + value: expressions::optional_expr(context, value), + }], + fe::FuncStmt::VarDecl { target, typ, value } => vec![fe::FuncStmt::VarDecl { + target: expressions::expr(context, target), + typ, + value: expressions::optional_expr(context, value), + }], + fe::FuncStmt::Assign { targets, value } => vec![fe::FuncStmt::Assign { + targets: expressions::multiple_exprs(context, targets), + value: expressions::expr(context, value), + }], + fe::FuncStmt::Emit { value } => vec![fe::FuncStmt::Emit { + value: expressions::expr(context, value), + }], + fe::FuncStmt::AugAssign { target, op, value } => aug_assign(context, target, op, value), + fe::FuncStmt::For { + target, + iter, + body, + or_else, + } => vec![fe::FuncStmt::For { + target: expressions::expr(context, target), + iter: expressions::expr(context, iter), + body: multiple_stmts(context, body), + or_else: multiple_stmts(context, or_else), + }], + fe::FuncStmt::While { + test, + body, + or_else, + } => vec![fe::FuncStmt::While { + test: expressions::expr(context, test), + body: multiple_stmts(context, body), + or_else: multiple_stmts(context, or_else), + }], + fe::FuncStmt::If { + test, + body, + or_else, + } => vec![fe::FuncStmt::If { + test: expressions::expr(context, test), + body: multiple_stmts(context, body), + or_else: multiple_stmts(context, or_else), + }], + fe::FuncStmt::Assert { test, msg } => vec![fe::FuncStmt::Assert { + test: expressions::expr(context, test), + msg: expressions::optional_expr(context, msg), + }], + fe::FuncStmt::Expr { value } => vec![fe::FuncStmt::Expr { + value: expressions::expr(context, value), + }], + fe::FuncStmt::Pass => vec![stmt.kind], + fe::FuncStmt::Break => vec![stmt.kind], + fe::FuncStmt::Continue => vec![stmt.kind], + fe::FuncStmt::Revert => vec![stmt.kind], + }; + let span = stmt.span; + + lowered_kinds + .into_iter() + .map(|kind| Node::new(kind, span)) + .collect() +} + +fn multiple_stmts(context: &Context, stmts: Vec>) -> Vec> { + stmts + .into_iter() + .map(|stmt| func_stmt(context, stmt)) + .collect::>>>() + .concat() +} + +fn aug_assign( + context: &Context, + target: Node, + op: Node, + value: Node, +) -> Vec { + let lowered_target = expressions::expr(context, target); + let original_value_span = value.span; + let lowered_value = expressions::expr(context, value); + + let new_value_kind = fe::Expr::BinOperation { + left: Box::new(lowered_target.clone().new_id()), + op, + right: Box::new(lowered_value), + }; + + let new_value = Node::new(new_value_kind, original_value_span); + + // the new statement is: `target = target value`. + vec![fe::FuncStmt::Assign { + targets: vec![lowered_target], + value: new_value, + }] +} diff --git a/compiler/src/lowering/mappers/mod.rs b/compiler/src/lowering/mappers/mod.rs index 4139016bb1..634e80eb2f 100644 --- a/compiler/src/lowering/mappers/mod.rs +++ b/compiler/src/lowering/mappers/mod.rs @@ -1,6 +1,4 @@ -mod assignments; mod contracts; -mod declarations; mod expressions; mod functions; pub mod module; diff --git a/compiler/src/lowering/mappers/module.rs b/compiler/src/lowering/mappers/module.rs index 8b13789179..a58366483f 100644 --- a/compiler/src/lowering/mappers/module.rs +++ b/compiler/src/lowering/mappers/module.rs @@ -1 +1,21 @@ +use fe_analyzer::Context; +use crate::lowering::mappers::contracts; +use fe_parser::ast as fe; + +/// Lowers a module. +pub fn module(context: &Context, module: fe::Module) -> fe::Module { + let lowered_body = module + .body + .into_iter() + .map(|stmt| match &stmt.kind { + fe::ModuleStmt::TypeDef { .. } => stmt, + fe::ModuleStmt::StructDef { .. } => stmt, + fe::ModuleStmt::FromImport { .. } => stmt, + fe::ModuleStmt::SimpleImport { .. } => stmt, + fe::ModuleStmt::ContractDef { .. } => contracts::contract_def(context, stmt), + }) + .collect(); + + fe::Module { body: lowered_body } +} diff --git a/compiler/src/lowering/mod.rs b/compiler/src/lowering/mod.rs index 279b0fbd0f..0edf6e334f 100644 --- a/compiler/src/lowering/mod.rs +++ b/compiler/src/lowering/mod.rs @@ -7,6 +7,6 @@ mod mappers; mod names; /// Lowers the Fe source AST to a Fe HIR AST. -pub fn lower(_context: &Context, module: &FeModuleAst) -> FeModuleAst { - module.clone() +pub fn lower(context: &Context, module: FeModuleAst) -> FeModuleAst { + mappers::module::module(context, module) } diff --git a/compiler/src/types.rs b/compiler/src/types.rs index 4fa39ba67d..06ab7bcfae 100644 --- a/compiler/src/types.rs +++ b/compiler/src/types.rs @@ -34,7 +34,8 @@ pub type NamedContracts = HashMap; /// The artifacts of a compiled module. pub struct CompiledModule { - pub fe_tokens: String, - pub fe_ast: String, + pub src_tokens: String, + pub src_ast: String, + pub lowered_ast: String, pub contracts: NamedContracts, } diff --git a/compiler/src/yul/mappers/expressions.rs b/compiler/src/yul/mappers/expressions.rs index b5decccd82..ae6e69d565 100644 --- a/compiler/src/yul/mappers/expressions.rs +++ b/compiler/src/yul/mappers/expressions.rs @@ -278,7 +278,7 @@ pub fn expr_bin_operation(context: &Context, exp: &Node) -> yul::Expre } _ => unreachable!(), }, - _ => unimplemented!(), + fe::BinOperator::FloorDiv => unimplemented!(), }; } diff --git a/compiler/tests/compile_errors.rs b/compiler/tests/compile_errors.rs index 102a92eee8..bf417ff8e2 100644 --- a/compiler/tests/compile_errors.rs +++ b/compiler/tests/compile_errors.rs @@ -86,7 +86,8 @@ use std::fs; case("struct_call_without_kw_args.fe", "KeyWordArgsRequired"), case("type_constructor_from_variable.fe", "NumericLiteralExpected"), case("unary_minus_on_bool.fe", "TypeError"), - case("unexpected_return.fe", "TypeError") + case("unexpected_return.fe", "TypeError"), + case("aug_assign_non_numeric.fe", "TypeError") )] fn test_compile_errors(fixture_file: &str, expected_error: &str) { let src = fs::read_to_string(format!("tests/fixtures/compile_errors/{}", fixture_file)) diff --git a/compiler/tests/features.rs b/compiler/tests/features.rs index d3814d2ac0..0614b14ccf 100644 --- a/compiler/tests/features.rs +++ b/compiler/tests/features.rs @@ -1186,3 +1186,34 @@ fn self_address() { ); }); } + +#[rstest( + target, + op, + value, + expected, + case(2, "add", 5, 7), + case(42, "sub", 26, 16), + case(10, "mul", 42, 420), + case(43, "div", 5, 8), + case(43, "mod", 5, 3), + case(3, "pow", 5, 243), + case(1, "lshift", 7, 128), + case(128, "rshift", 7, 1), + case(26, "bit_or", 42, 58), + case(26, "bit_xor", 42, 48), + case(26, "bit_and", 42, 10), + case(2, "add_from_sto", 5, 7), + case(2, "add_from_mem", 5, 7) +)] +fn aug_assign(target: usize, op: &str, value: usize, expected: usize) { + with_executor(&|mut executor| { + let harness = deploy_contract(&mut executor, "aug_assign.fe", "Foo", &[]); + harness.test_function( + &mut executor, + op, + &[uint_token(target), uint_token(value)], + Some(&uint_token(expected)), + ); + }); +} diff --git a/compiler/tests/fixtures/compile_errors/aug_assign_non_numeric.fe b/compiler/tests/fixtures/compile_errors/aug_assign_non_numeric.fe new file mode 100644 index 0000000000..79cb235c42 --- /dev/null +++ b/compiler/tests/fixtures/compile_errors/aug_assign_non_numeric.fe @@ -0,0 +1,3 @@ +contract Foo: + pub def bar(a: u256, b: bool): + a += b \ No newline at end of file diff --git a/compiler/tests/fixtures/features/aug_assign.fe b/compiler/tests/fixtures/features/aug_assign.fe new file mode 100644 index 0000000000..18f5303b99 --- /dev/null +++ b/compiler/tests/fixtures/features/aug_assign.fe @@ -0,0 +1,57 @@ +contract Foo: + my_num: u256 + + pub def add(a: u256, b: u256) -> u256: + a += b + return a + + pub def sub(a: u256, b: u256) -> u256: + a -= b + return a + + pub def mul(a: u256, b: u256) -> u256: + a *= b + return a + + pub def div(a: u256, b: u256) -> u256: + a /= b + return a + + pub def mod(a: u256, b: u256) -> u256: + a %= b + return a + + pub def pow(a: u256, b: u256) -> u256: + a **= b + return a + + pub def lshift(a: u8, b: u8) -> u8: + a <<= b + return a + + pub def rshift(a: u8, b: u8) -> u8: + a >>= b + return a + + pub def bit_or(a: u8, b: u8) -> u8: + a |= b + return a + + pub def bit_xor(a: u8, b: u8) -> u8: + a ^= b + return a + + pub def bit_and(a: u8, b: u8) -> u8: + a &= b + return a + + pub def add_from_sto(a: u256, b: u256) -> u256: + self.my_num = a + self.my_num += b + return self.my_num + + pub def add_from_mem(a: u256, b: u256) -> u256: + my_array: u256[10] + my_array[7] = a + my_array[7] += b + return my_array[7] diff --git a/parser/src/node.rs b/parser/src/node.rs index 0a6822c6d7..3e02ecdca5 100644 --- a/parser/src/node.rs +++ b/parser/src/node.rs @@ -61,6 +61,12 @@ impl Node { span, } } + + /// Sets a new node ID. + pub fn new_id(mut self) -> Self { + self.id = NodeId::create(); + self + } } impl From<&Node> for Span { diff --git a/src/main.rs b/src/main.rs index 903349d443..adb18a912b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,7 @@ arg_enum! { pub enum CompilationTarget { Abi, Ast, + LoweredAst, Bytecode, Tokens, Yul, @@ -59,7 +60,7 @@ pub fn main() { .short("e") .long("emit") .help("Comma separated compile targets e.g. -e=bytecode,yul") - .possible_values(&["abi", "bytecode", "ast", "tokens", "yul"]) + .possible_values(&["abi", "bytecode", "ast", "tokens", "yul", "loweredAst"]) .default_value("abi,bytecode") .use_delimiter(true) .takes_value(true), @@ -134,11 +135,15 @@ fn write_compiled_module( fs::create_dir_all(output_dir).map_err(ioerr_to_string)?; if targets.contains(&CompilationTarget::Ast) { - write_output(&output_dir.join("module.ast"), &module.fe_ast)?; + write_output(&output_dir.join("module.ast"), &module.src_ast)?; + } + + if targets.contains(&CompilationTarget::LoweredAst) { + write_output(&output_dir.join("lowered_module.ast"), &module.lowered_ast)?; } if targets.contains(&CompilationTarget::Tokens) { - write_output(&output_dir.join("module.tokens"), &module.fe_tokens)?; + write_output(&output_dir.join("module.tokens"), &module.src_tokens)?; } for (name, contract) in module.contracts.drain() {