diff --git a/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/Nargo.toml b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/Nargo.toml new file mode 100644 index 00000000000..72ba378b5d9 --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "closures_passed_to_higher_order_fns" +type = "bin" +authors = [""] +compiler_version = "0.1" + +[dependencies] diff --git a/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/Prover.toml b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/src/main.nr b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/src/main.nr new file mode 100644 index 00000000000..07aafeb7773 --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/src/main.nr @@ -0,0 +1,31 @@ +use dep::std; + +fn main() { + // give some numbers names to turn them into variables for capturing + let _0: u32 = 0; + let _1: u32 = 1; + let _2: u32 = 2; + let _5: u32 = 5; + + let arr: [u32; 3] = [1, 2, 3]; + let evens = arr.map(|n| n * _2); + + assert(arr.any(|n| n > _1)); + assert(arr.all(|n| n > _0 & n < _5)); + + assert(evens.fold(0, |a, b| a + b + _1) == 15); + assert(evens.reduce(|a, b| a + b + _1) == 14); + + // call with 1 captured variable + assert(add_results(|| _5, || _2) == 7); + + // call with 2 captured variables + assert(add_results(|| _5 + _2, || _2 + _1) == 10); + + // call with different environment types for the 2 args + assert(add_results(|| _5 + _2, || _2) == 9); +} + +fn add_results(a: fn () -> u32, b: fn () -> u32) -> u32 { + a() + b() +} \ No newline at end of file diff --git a/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/target/higher_order_functions.json b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/target/higher_order_functions.json new file mode 100644 index 00000000000..7e9b8feec6e --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/target/higher_order_functions.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[],"param_witnesses":{},"return_type":{"kind":"field"},"return_witnesses":[83]},"bytecode":"H4sIAAAAAAAA/+1cDU8bRxAdA4UQAg1pCIEkcKZJHfJB9nz+OBNSh3wQaCg0qVqplaoqVszf6V9uV9m7zl7PjeJ9szpbuxJaaxXezs7beTM7EAZE9Bd9GrPma9R4bmY13oizWRBbCWDne8wwzJqZ+doiW8vW59m/z3w798/XTglWrYC7w76n7N/URuAssrXs+1eYLYTziZpn+6IwVxgm2uC4ZpxbY18ZYVFhz5nC3onqtFrDbnMYJ/EH1ewN0rZqtQedNE7jdtr+2EyTZJi20m5v0OuqXtxKhvFFu9e8YHs7YiUGS80QLiA9Bb6aMOx8Dx7kPIizEQIfg+kl8GfJDnxNSlTYEx34s4QL/DmSCRrQmXM7kWL3FfAeSJ4ZyfM88MzCAq8uzBDA9i7wC2a+xNaCwGMwvQj8AtkCr4mMCnuiBX6BcIF/ibABWcOc1RLjeQHcX3B3wRKLWbCdyMSGTBi/Toj/FnF2xsA7EwP91+T+QxcZSK25DPSfxkBrgr4rlwVwf6PJiJUFoVhxtet3T/5TbgMZ0zHwzsRA/zVDcf//e/DifsnMV9haKO4xmF6K+yWyi3tNZFTYE13cLxFOOK8QLiA9Bf5HAWzvgb9s5hW2FgIfg+kl8JfJDny9aVTYEx34y4QL/BWSCUj06wIpdl8D74HkmZE8XwWe2ZPADwSwvQv8qpmvsbUg8BhMLwK/SrbAayKjwp5ogV8lXOBfI2xAgkmMtRhfJXwr5RucjaI/j0RyfR3IiwTXmpPrzJeoc0snJDOaAtjeE9KamW+wtZCQMJheEtIa2QlJExkV9kQnpDXCBesNnE9G9hjRwgW0uVVi7sSJyLqZb7K1ICIYTC8isk62iGgio8Ke6B968CByFaR1wgnSTdwZW74ECWhzp8TciROkDTNvsrUgSBhML4K0QbYgaSKjwp5oQeJB5CpIG4QTpE2SCe4ZsP82gWe+BcAafmpTtLNfpS4OtAjfwmEpbu9t9nmuwJ0eWRwIBLvVzy7zo6gYSJF0WwD3DuEuv9S57+A5sjJ6lX3qqSc0FdXTlpm32VqonjCYXqqnLbKrJ01kVNgT3RPaIlywbuN80vH1BEPaXGLuxIlIZOY6WwsigsH0IiIR2SJSJ/knGA8iV0GKCCdIdZIJbvQTrA488w4Ay/cTbAeHZT3BvmWfwxPMEXPHOBSNe5eq/QTT576L50j0CYb26Sg7XbHv4c7clLTzO6CdvipboM0i1WcqiG2GVdk2zHyfrX1JZVsv4apY2dbp85VtGU6obEePvLJtkF3ZaiK3C3uiK1seROOKSNdgNQgnSPdxZ1S+BAko9lMhSLtmfsDWgiBhML0I0i7ZgqSJlBYkHkSugrRLOEF6QNjgRlVw5il7cc/wJcmHGm/kPT4kHw8rzodOrA2STdhqvJHzgUzYj2jyEjb6DknZifRtiVtFkv9jM++xtS9J/hH915/F5B/R55N/GU5I/qNHnvw1gdv0b/LfI/nkzy+6qyDtAbGeALEyfz5h/kS3cx4Z/tC4f5CMEKHv0UOhe6TcRoz2H7po0QWBRBHZAGIhi0hFfgoAVztjnJ0tSTubODu9tZGBNk9F1yYxM78ooWuDwfRSuCVkd200kdKFGw8i165NQjhBauHO6O1VChT7qRCktpn5b6gEQcJgehGkNtmCpImUFiQeRK6C1CacIHUIG9zoF0Bs+JLkQ4038rYlko9uxfnQiTUh2YStxhs5H8iEndLkJWz0HZKyE+nbEreKJP+emffZWmgjYzC9JH9NIG8j75N88ucX3VWQ9oFYT4FYmT+fklwbOTX8oXH/JBkhQt+jrtA9Um4jRvsPXbTogkCiiEyAWMgi8oD8FACudj7D2dmRtPN7nJ3e2shAm6eia9MvYOsRujYYTC+FW5/sro3Gly7ceBC5dm36hBOk57gzenuVAsV+KgTp0Mwv2FoQJAymF0E6JFuQNJHSgsSDyFWQDgknSC8IG9zoF8Azw5ckH2q8kbctkXy8rDgfOrH2STZhq/FGzgcyYb+iyUvY6DskZSfStyVuFUn+r818xNZCGxmD6SX5awJ5G/mI5JM/v+iugnQExHoDxMr8+Ybk2sivDH9o3A8kI0Toe/RS6B4ptxGj/YcuWnRBIFFE9oFYyCLyuOJ86P9RoAT4eOyOFfODo/g4AWL5+ktDJzgs6y8N/cA+h7805Ih5YhyKxn1LuAsrde63eI5ExE7/3PuA8GLXI5zYHQD5PgX6zpfYnZKM2P3IPgexc8Q8NQ5F455RtcVOn/sMz5GI2OnX2THhxe414cTuGMj3Oc53Q19id04yYvcT+xzEzhHz3DgUjfuOqi12+tzv8ByJ2Kor0DPCt6PeV/zcmp/3JRxV+e9n/ow7f5sKI7NV7/E3EXMPTcmZAAA=","proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/target/witness.tr b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/target/witness.tr new file mode 100644 index 00000000000..28701fee014 Binary files /dev/null and b/crates/nargo_cli/tests/execution_success/closures_passed_to_higher_order_fns/target/witness.tr differ diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 0210eb02e11..20b5e48cc5a 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -1090,10 +1090,6 @@ impl Type { /// given bindings if found. If a type variable is not found within /// the given TypeBindings, it is unchanged. pub fn substitute(&self, type_bindings: &TypeBindings) -> Type { - if type_bindings.is_empty() { - return self.clone(); - } - let substitute_binding = |binding: &TypeVariable| match &*binding.borrow() { TypeBinding::Bound(binding) => binding.substitute(type_bindings), TypeBinding::Unbound(id) => match type_bindings.get(id) { diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 998f3093d49..c73cd42af76 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -29,6 +29,7 @@ use self::ast::{Definition, FuncId, Function, LocalId, Program}; pub mod ast; pub mod printer; +pub mod propagate; struct LambdaContext { env_ident: ast::Ident, @@ -54,7 +55,7 @@ struct Monomorphizer<'interner> { locals: HashMap, /// Queue of functions to monomorphize next - queue: VecDeque<(node_interner::FuncId, FuncId, TypeBindings)>, + queue: VecDeque<(node_interner::FuncId, FuncId, TypeBindings, Type)>, /// When a function finishes being monomorphized, the monomorphized ast::Function is /// stored here along with its FuncId. @@ -67,9 +68,20 @@ struct Monomorphizer<'interner> { next_local_id: u32, next_function_id: u32, + + call_expr_to_arguments: HashMap>, + + /// `overridden_types` allows us to change the type an of expression in the current function instantiation only + /// e.g. for `fn foo(arg: fn() -> u32)` `arg` is of type Function([], u32, Type::Unit) + /// but if the instantiation we're processing has been called with a closure, we add the proper capture environment + overridden_types: HashMap, + + current_function_id: FuncId, } type HirType = crate::Type; +type OverriddenTypes = HashMap; +type ConvertedParameters = Vec<(ast::LocalId, bool, String, ast::Type)>; /// Starting from the given `main` function, monomorphize the entire program, /// replacing all references to type variables and NamedGenerics with concrete @@ -87,11 +99,11 @@ pub fn monomorphize(main: node_interner::FuncId, interner: &NodeInterner) -> Pro let function_sig = monomorphizer.compile_main(main); while !monomorphizer.queue.is_empty() { - let (next_fn_id, new_id, bindings) = monomorphizer.queue.pop_front().unwrap(); + let (next_fn_id, new_id, bindings, typ) = monomorphizer.queue.pop_front().unwrap(); monomorphizer.locals.clear(); perform_instantiation_bindings(&bindings); - monomorphizer.function(next_fn_id, new_id); + monomorphizer.function(next_fn_id, new_id, Some(typ)); undo_instantiation_bindings(bindings); } @@ -111,6 +123,9 @@ impl<'interner> Monomorphizer<'interner> { next_function_id: 0, interner, lambda_envs_stack: Vec::new(), + call_expr_to_arguments: HashMap::new(), + overridden_types: HashMap::new(), + current_function_id: FuncId(0), } } @@ -130,13 +145,41 @@ impl<'interner> Monomorphizer<'interner> { self.locals.get(&id).copied().map(Definition::Local) } + fn add_capture_environments(&self, typ: &Type, args: &Vec) -> Type { + match typ { + Type::Function(params, ret_type, env) => { + let new_params = params + .iter() + .zip(args) + .map(|(param, arg)| { + let arg_type = self.get_hir_type(arg); + + match (param, arg_type) { + (Type::Function(params, ret_type, _), Type::Function(_, _, env)) => { + Type::Function(params.clone(), ret_type.clone(), env) + } + _ => param.clone(), + } + }) + .collect(); + Type::Function(new_params, ret_type.clone(), env.clone()) + } + _ => typ.clone(), + } + } + fn lookup_function( &mut self, id: node_interner::FuncId, expr_id: node_interner::ExprId, typ: &HirType, ) -> Definition { - let typ = typ.follow_bindings(); + let mut typ = typ.follow_bindings(); + + if let Some(args) = self.call_expr_to_arguments.get(&expr_id) { + typ = self.add_capture_environments(&typ, args); + } + match self.globals.get(&id).and_then(|inner_map| inner_map.get(&typ)) { Some(id) => Definition::Function(*id), None => { @@ -186,19 +229,40 @@ impl<'interner> Monomorphizer<'interner> { fn compile_main(&mut self, main_id: node_interner::FuncId) -> FunctionSignature { let new_main_id = self.next_function_id(); assert_eq!(new_main_id, Program::main_id()); - self.function(main_id, new_main_id); + self.function(main_id, new_main_id, None); let main_meta = self.interner.function_meta(&main_id); main_meta.into_function_signature(self.interner) } - fn function(&mut self, f: node_interner::FuncId, id: FuncId) { + fn function(&mut self, f: node_interner::FuncId, id: FuncId, typ: Option) { + self.current_function_id = id; + let meta = self.interner.function_meta(&f); let name = self.interner.function_name(&f).to_owned(); - let return_type = Self::convert_type(meta.return_type()); - let parameters = self.parameters(meta.parameters); - let body = self.expr(*self.interner.function(&f).as_expr()); + let mut parameters: Parameters = meta.parameters.clone(); + + if let Some(Type::Function(params, _, _)) = typ { + parameters.0 = parameters + .0 + .iter() + .zip(params) + .map(|(param, arg)| Param(param.0.clone(), arg, param.2)) + .collect(); + } + + self.overridden_types.insert(id, OverriddenTypes::new()); + let parameters = self.parameters(parameters, &id); + + let hir_body = *self.interner.function(&f).as_expr(); + + self.propagate_overridden_types_expr(&id, &hir_body); + let body = self.expr(hir_body); + + let hir_return_type = self.get_hir_type(hir_body); + let return_type = Self::convert_type(&hir_return_type); + let unconstrained = meta.is_unconstrained || matches!(meta.contract_function_type, Some(ContractFunctionType::Open)); @@ -213,10 +277,10 @@ impl<'interner> Monomorphizer<'interner> { /// Monomorphize each parameter, expanding tuple/struct patterns into multiple parameters /// and binding any generic types found. - fn parameters(&mut self, params: Parameters) -> Vec<(ast::LocalId, bool, String, ast::Type)> { + fn parameters(&mut self, params: Parameters, func_id: &FuncId) -> ConvertedParameters { let mut new_params = Vec::with_capacity(params.len()); for parameter in params { - self.parameter(parameter.0, ¶meter.1, &mut new_params); + self.parameter(parameter.0, ¶meter.1, &mut new_params, func_id); } new_params } @@ -226,6 +290,7 @@ impl<'interner> Monomorphizer<'interner> { param: HirPattern, typ: &HirType, new_params: &mut Vec<(ast::LocalId, bool, String, ast::Type)>, + func_id: &FuncId, ) { match param { HirPattern::Identifier(ident) => { @@ -233,14 +298,19 @@ impl<'interner> Monomorphizer<'interner> { let definition = self.interner.definition(ident.id); let name = definition.name.clone(); new_params.push((new_id, definition.mutable, name, Self::convert_type(typ))); + + self.overridden_types + .get_mut(func_id) + .unwrap() + .insert(ident.id.into(), typ.clone()); self.define_local(ident.id, new_id); } - HirPattern::Mutable(pattern, _) => self.parameter(*pattern, typ, new_params), + HirPattern::Mutable(pattern, _) => self.parameter(*pattern, typ, new_params, func_id), HirPattern::Tuple(fields, _) => { let tuple_field_types = unwrap_tuple_type(typ); for (field, typ) in fields.into_iter().zip(tuple_field_types) { - self.parameter(field, &typ, new_params); + self.parameter(field, &typ, new_params, func_id); } } HirPattern::Struct(_, fields, _) => { @@ -256,12 +326,21 @@ impl<'interner> Monomorphizer<'interner> { unreachable!("Expected a field named '{field_name}' in the struct pattern") }); - self.parameter(field, &field_type, new_params); + self.parameter(field, &field_type, new_params, func_id); } } } } + fn get_hir_type(&self, expr: impl Into) -> HirType { + let idx: arena::Index = expr.into(); + + match self.overridden_types.get(&self.current_function_id).unwrap().get(&idx) { + Some(t) => t.clone(), + None => self.interner.id_type(idx), + } + } + fn expr(&mut self, expr: node_interner::ExprId) -> ast::Expression { use ast::Expression::Literal; use ast::Literal::*; @@ -279,7 +358,7 @@ impl<'interner> Monomorphizer<'interner> { } HirExpression::Literal(HirLiteral::Bool(value)) => Literal(Bool(value)), HirExpression::Literal(HirLiteral::Integer(value)) => { - let typ = Self::convert_type(&self.interner.id_type(expr)); + let typ = Self::convert_type(&self.get_hir_type(expr)); Literal(Integer(value, typ)) } HirExpression::Literal(HirLiteral::Array(array)) => match array { @@ -294,7 +373,7 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Prefix(prefix) => ast::Expression::Unary(ast::Unary { operator: prefix.operator, rhs: Box::new(self.expr(prefix.rhs)), - result_type: Self::convert_type(&self.interner.id_type(expr)), + result_type: Self::convert_type(&self.get_hir_type(expr)), }), HirExpression::Infix(infix) => { @@ -332,7 +411,7 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::For(ast::For { index_variable, index_name: self.interner.definition_name(for_expr.identifier.id).to_owned(), - index_type: Self::convert_type(&self.interner.id_type(for_expr.start_range)), + index_type: Self::convert_type(&self.get_hir_type(for_expr.start_range)), start_range: Box::new(start), end_range: Box::new(end), start_range_location: self.interner.expr_location(&for_expr.start_range), @@ -349,7 +428,7 @@ impl<'interner> Monomorphizer<'interner> { condition: Box::new(cond), consequence: Box::new(then), alternative: else_, - typ: Self::convert_type(&self.interner.id_type(expr)), + typ: Self::convert_type(&self.get_hir_type(expr)), }) } @@ -373,7 +452,7 @@ impl<'interner> Monomorphizer<'interner> { array: node_interner::ExprId, array_elements: Vec, ) -> ast::Expression { - let typ = Self::convert_type(&self.interner.id_type(array)); + let typ = Self::convert_type(&self.get_hir_type(array)); let contents = vecmap(array_elements, |id| self.expr(id)); ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral { contents, typ })) } @@ -384,7 +463,7 @@ impl<'interner> Monomorphizer<'interner> { repeated_element: node_interner::ExprId, length: HirType, ) -> ast::Expression { - let typ = Self::convert_type(&self.interner.id_type(array)); + let typ = Self::convert_type(&self.get_hir_type(array)); let contents = self.expr(repeated_element); let length = length @@ -396,7 +475,7 @@ impl<'interner> Monomorphizer<'interner> { } fn index(&mut self, id: node_interner::ExprId, index: HirIndexExpression) -> ast::Expression { - let element_type = Self::convert_type(&self.interner.id_type(id)); + let element_type = Self::convert_type(&self.get_hir_type(id)); let collection = Box::new(self.expr(index.collection)); let index = Box::new(self.expr(index.index)); @@ -421,7 +500,7 @@ impl<'interner> Monomorphizer<'interner> { fn let_statement(&mut self, let_statement: HirLetStatement) -> ast::Expression { let expr = self.expr(let_statement.expression); - let expected_type = self.interner.id_type(let_statement.expression); + let expected_type = self.get_hir_type(let_statement.expression); self.unpack_pattern(let_statement.pattern, expr, &expected_type) } @@ -430,7 +509,7 @@ impl<'interner> Monomorphizer<'interner> { constructor: HirConstructorExpression, id: node_interner::ExprId, ) -> ast::Expression { - let typ = self.interner.id_type(id); + let typ = self.get_hir_type(id); let field_types = unwrap_struct_type(&typ); let field_type_map = btree_map(&field_types, |x| x.clone()); @@ -581,7 +660,8 @@ impl<'interner> Monomorphizer<'interner> { let mutable = definition.mutable; let definition = self.lookup_local(ident.id)?; - let typ = Self::convert_type(&self.interner.id_type(ident.id)); + + let typ = Self::convert_type(&self.get_hir_type(ident.id)); Some(ast::Ident { location: Some(ident.location), mutable, definition, name, typ }) } @@ -593,10 +673,11 @@ impl<'interner> Monomorphizer<'interner> { let mutable = definition.mutable; let location = Some(ident.location); let name = definition.name.clone(); - let typ = self.interner.id_type(expr_id); + let hir_typ = self.get_hir_type(expr_id); + + let definition = self.lookup_function(*func_id, expr_id, &hir_typ); + let typ = Self::convert_type(&hir_typ); - let definition = self.lookup_function(*func_id, expr_id, &typ); - let typ = Self::convert_type(&typ); let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; let ident_expression = ast::Expression::Ident(ident); if self.is_function_closure_type(&typ) { @@ -734,7 +815,7 @@ impl<'interner> Monomorphizer<'interner> { } fn is_function_closure(&self, raw_func_id: node_interner::ExprId) -> bool { - let t = Self::convert_type(&self.interner.id_type(raw_func_id)); + let t = Self::convert_type(&self.get_hir_type(raw_func_id)); if self.is_function_closure_type(&t) { true } else if let ast::Type::Tuple(elements) = t { @@ -762,12 +843,17 @@ impl<'interner> Monomorphizer<'interner> { call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); + + self.call_expr_to_arguments.insert(call.func, call.arguments.clone()); + let original_func = Box::new(self.expr(call.func)); + self.call_expr_to_arguments.remove(&call.func); + let func: Box; - let return_type = self.interner.id_type(id); + let return_type = self.get_hir_type(id); let return_type = Self::convert_type(&return_type); + let location = call.location; if let ast::Expression::Ident(ident) = original_func.as_ref() { @@ -833,7 +919,7 @@ impl<'interner> Monomorphizer<'interner> { ) { match hir_argument { HirExpression::Ident(ident) => { - let typ = self.interner.id_type(ident.id); + let typ = self.get_hir_type(ident.id); let typ: Type = typ.follow_bindings(); let is_fmt_str = match typ { // A format string has many different possible types that need to be handled. @@ -940,12 +1026,12 @@ impl<'interner> Monomorphizer<'interner> { function_type: HirType, ) -> FuncId { let new_id = self.next_function_id(); - self.define_global(id, function_type, new_id); + self.define_global(id, function_type.clone(), new_id); let bindings = self.interner.get_instantiation_bindings(expr_id); let bindings = self.follow_bindings(bindings); - self.queue.push_back((id, new_id, bindings)); + self.queue.push_back((id, new_id, bindings, function_type)); new_id } @@ -1016,10 +1102,11 @@ impl<'interner> Monomorphizer<'interner> { Param(pattern, typ, noirc_abi::AbiVisibility::Private) })); - let parameters = self.parameters(parameters); + let id = self.next_function_id(); + self.overridden_types.insert(id, OverriddenTypes::new()); + let parameters = self.parameters(parameters, &id); let body = self.expr(lambda.body); - let id = self.next_function_id(); let return_type = ret_type.clone(); let name = lambda_name.to_owned(); let unconstrained = false; @@ -1068,9 +1155,10 @@ impl<'interner> Monomorphizer<'interner> { Param(pattern, typ, noirc_abi::AbiVisibility::Private) })); - let mut converted_parameters = self.parameters(parameters); - let id = self.next_function_id(); + self.overridden_types.insert(id, OverriddenTypes::new()); + let mut converted_parameters = self.parameters(parameters, &id); + let name = lambda_name.to_owned(); let return_type = ret_type.clone(); @@ -1093,11 +1181,11 @@ impl<'interner> Monomorphizer<'interner> { } } })); - let expr_type = self.interner.id_type(expr); + let expr_type = self.get_hir_type(expr); let env_typ = if let types::Type::Function(_, _, function_env_type) = expr_type { Self::convert_type(&function_env_type) } else { - unreachable!("expected a Function type for a Lambda node") + unreachable!("expected a Function type for a Lambda node, got {expr_type}") }; let env_let_stmt = ast::Expression::Let(ast::Let { diff --git a/crates/noirc_frontend/src/monomorphization/propagate.rs b/crates/noirc_frontend/src/monomorphization/propagate.rs new file mode 100644 index 00000000000..4797e94abc7 --- /dev/null +++ b/crates/noirc_frontend/src/monomorphization/propagate.rs @@ -0,0 +1,163 @@ +use crate::{ + hir_def::{expr::*, stmt::*}, + node_interner::{self, DefinitionKind}, + Type, +}; + +use super::{ast, Monomorphizer}; + +impl<'interner> Monomorphizer<'interner> { + /// Whenever we change the type of an expression during monomorphization + /// (which is currently only used to add lambda environments to higher-order function parameters) + /// we need to walk over the HIR and update the types of all other expressions that reference it + pub(crate) fn propagate_overridden_types_expr( + &mut self, + func: &ast::FuncId, + expr_id: &node_interner::ExprId, + ) { + let expr = self.interner.expression(expr_id); + + let typ: Option = match expr { + HirExpression::Ident(ident) => { + let definition = self.interner.definition(ident.id); + match &definition.kind { + DefinitionKind::Local(_) => { + let id: arena::Index = ident.id.into(); + self.overridden_types.get(func).unwrap().get(&id).cloned() + } + _ => None, + } + } + HirExpression::Block(block) => { + for stmt in &block.0 { + self.propagate_overridden_types_stmt(func, stmt); + } + + let mut typ = Type::Unit; + for stmt in block.0.iter().by_ref().rev() { + if let HirStatement::Expression(expr) = self.interner.statement(stmt) { + typ = self.get_hir_type(expr); + break; + } + } + Some(typ) + } + HirExpression::Call(call) => { + self.propagate_overridden_types_expr(func, &call.func); + for arg in call.arguments { + self.propagate_overridden_types_expr(func, &arg); + } + + let func = self.get_hir_type(call.func); + match func { + Type::Function(_, ret_type, _) => Some(*ret_type), + _ => None, + } + } + HirExpression::Prefix(prefix) => { + self.propagate_overridden_types_expr(func, &prefix.rhs); + None + } + HirExpression::Infix(infix) => { + self.propagate_overridden_types_expr(func, &infix.lhs); + self.propagate_overridden_types_expr(func, &infix.rhs); + None + } + HirExpression::Index(index) => { + self.propagate_overridden_types_expr(func, &index.collection); + self.propagate_overridden_types_expr(func, &index.index); + None + } + HirExpression::Literal(HirLiteral::Array(HirArrayLiteral::Standard(exprs))) => { + for expr in &exprs { + self.propagate_overridden_types_expr(func, expr); + } + None + } + HirExpression::Lambda(lambda) => { + for param in &lambda.parameters { + self.propagate_overridden_types_pattern(func, ¶m.0, param.1.clone()); + } + + self.propagate_overridden_types_expr(func, &lambda.body); + + None // TODO-AV patch the return type? + } + HirExpression::If(if_expr) => { + self.propagate_overridden_types_expr(func, &if_expr.condition); + self.propagate_overridden_types_expr(func, &if_expr.consequence); + + match &if_expr.alternative { + Some(alternative) => { + self.propagate_overridden_types_expr(func, alternative); + Some(self.get_hir_type(if_expr.consequence)) + } + None => None, + } + } + HirExpression::For(for_expr) => { + self.propagate_overridden_types_expr(func, &for_expr.start_range); + self.propagate_overridden_types_expr(func, &for_expr.end_range); + self.propagate_overridden_types_expr(func, &for_expr.block); + None + } + HirExpression::MethodCall(_) => { + unreachable!("unexpected method call during monomorphization") + } + HirExpression::Literal(_) => None, + HirExpression::Constructor(_) => None, + HirExpression::MemberAccess(_) => None, + HirExpression::Cast(_) => None, + HirExpression::Tuple(_) => None, + HirExpression::Error => None, + }; + + typ.map(|typ| self.overridden_types.get_mut(func).unwrap().insert(expr_id.into(), typ)); + } + + fn propagate_overridden_types_stmt( + &mut self, + func: &ast::FuncId, + stmt_id: &node_interner::StmtId, + ) { + let stmt = self.interner.statement(stmt_id); + match stmt { + HirStatement::Semi(expr) => { + self.propagate_overridden_types_expr(func, &expr); + } + HirStatement::Expression(expr) => { + self.propagate_overridden_types_expr(func, &expr); + } + HirStatement::Let(let_statement) => { + self.propagate_overridden_types_expr(func, &let_statement.expression); + let typ = self.get_hir_type(let_statement.expression); + + self.propagate_overridden_types_pattern(func, &let_statement.pattern, typ); + } + HirStatement::Constrain(constrain) => { + self.propagate_overridden_types_expr(func, &constrain.0); + } + HirStatement::Assign(assign) => { + self.propagate_overridden_types_expr(func, &assign.expression); + } + HirStatement::Error => (), + } + } + + fn propagate_overridden_types_pattern( + &mut self, + func: &ast::FuncId, + pattern: &HirPattern, + typ: Type, + ) { + match pattern { + HirPattern::Identifier(ident) => { + self.overridden_types.get_mut(func).unwrap().insert(ident.id.into(), typ); + } + HirPattern::Mutable(inner, _) => { + self.propagate_overridden_types_pattern(func, inner.as_ref(), typ); + } + _ => {} + }; + } +}