From 5ca99b128e9991b5272c00292208d85415e70edf Mon Sep 17 00:00:00 2001 From: Alex Vitkov <44268717+alexvitkov@users.noreply.github.com> Date: Mon, 25 Sep 2023 23:36:56 +0300 Subject: [PATCH] feat(traits): Implement trait bounds typechecker + monomorphizer passes (#2717) Co-authored-by: jfecher --- .../src/hir/resolution/resolver.rs | 2 +- .../noirc_frontend/src/hir/type_check/expr.rs | 117 ++++++++++++------ .../noirc_frontend/src/hir/type_check/mod.rs | 11 +- compiler/noirc_frontend/src/hir_def/expr.rs | 32 ++++- compiler/noirc_frontend/src/hir_def/traits.rs | 2 +- .../src/monomorphization/mod.rs | 57 ++++++++- .../trait_where_clause/Nargo.toml | 7 ++ .../trait_where_clause/src/main.nr | 41 ++++++ 8 files changed, 221 insertions(+), 48 deletions(-) create mode 100644 tooling/nargo_cli/tests/execution_success/trait_where_clause/Nargo.toml create mode 100644 tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 78388eacb94..36f50525d00 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -672,7 +672,7 @@ impl<'a> Resolver<'a> { ) -> Vec { vecmap(where_clause, |constraint| TraitConstraint { typ: self.resolve_type(constraint.typ.clone()), - trait_id: constraint.trait_bound.trait_id, + trait_id: constraint.trait_bound.trait_id.unwrap_or_else(TraitId::dummy_id), }) } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index fdec07ae62d..920f2071586 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -6,11 +6,11 @@ use crate::{ hir_def::{ expr::{ self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral, HirMethodCallExpression, - HirPrefixExpression, + HirMethodReference, HirPrefixExpression, }, types::Type, }, - node_interner::{DefinitionKind, ExprId, FuncId}, + node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId}, Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp, }; @@ -144,7 +144,7 @@ impl<'interner> TypeChecker<'interner> { let object_type = self.check_expression(&method_call.object).follow_bindings(); let method_name = method_call.method.0.contents.as_str(); match self.lookup_method(&object_type, method_name, expr_id) { - Some(method_id) => { + Some(method_ref) => { let mut args = vec![( object_type, method_call.object, @@ -160,22 +160,27 @@ impl<'interner> TypeChecker<'interner> { // so that the backend doesn't need to worry about methods let location = method_call.location; - // Automatically add `&mut` if the method expects a mutable reference and - // the object is not already one. - if method_id != FuncId::dummy_id() { - let func_meta = self.interner.function_meta(&method_id); - self.try_add_mutable_reference_to_object( - &mut method_call, - &func_meta.typ, - &mut args, - ); + if let HirMethodReference::FuncId(func_id) = method_ref { + // Automatically add `&mut` if the method expects a mutable reference and + // the object is not already one. + if func_id != FuncId::dummy_id() { + let func_meta = self.interner.function_meta(&func_id); + self.try_add_mutable_reference_to_object( + &mut method_call, + &func_meta.typ, + &mut args, + ); + } } - let (function_id, function_call) = - method_call.into_function_call(method_id, location, self.interner); + let (function_id, function_call) = method_call.into_function_call( + method_ref, + location, + self.interner, + ); let span = self.interner.expr_span(expr_id); - let ret = self.check_method_call(&function_id, &method_id, args, span); + let ret = self.check_method_call(&function_id, method_ref, args, span); self.interner.replace_expr(expr_id, function_call); ret @@ -286,6 +291,7 @@ impl<'interner> TypeChecker<'interner> { Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } + HirExpression::TraitMethodReference(_) => unreachable!("unexpected TraitMethodReference - they should be added after initial type checking"), }; self.interner.push_expr_type(expr_id, typ.clone()); @@ -477,34 +483,46 @@ impl<'interner> TypeChecker<'interner> { fn check_method_call( &mut self, function_ident_id: &ExprId, - func_id: &FuncId, + method_ref: HirMethodReference, arguments: Vec<(Type, ExprId, Span)>, span: Span, ) -> Type { - if func_id == &FuncId::dummy_id() { - Type::Error - } else { - let func_meta = self.interner.function_meta(func_id); + let (fntyp, param_len) = match method_ref { + HirMethodReference::FuncId(func_id) => { + if func_id == FuncId::dummy_id() { + return Type::Error; + } - // Check function call arity is correct - let param_len = func_meta.parameters.len(); - let arg_len = arguments.len(); + let func_meta = self.interner.function_meta(&func_id); + let param_len = func_meta.parameters.len(); - if param_len != arg_len { - self.errors.push(TypeCheckError::ArityMisMatch { - expected: param_len as u16, - found: arg_len as u16, - span, - }); + (func_meta.typ, param_len) } + HirMethodReference::TraitMethodId(method) => { + let the_trait = self.interner.get_trait(method.trait_id); + let the_trait = the_trait.borrow(); + let method = &the_trait.methods[method.method_index]; - let (function_type, instantiation_bindings) = func_meta.typ.instantiate(self.interner); + (method.get_type(), method.arguments.len()) + } + }; - self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings); - self.interner.push_expr_type(function_ident_id, function_type.clone()); + let arg_len = arguments.len(); - self.bind_function_type(function_type, arguments, span) + if param_len != arg_len { + self.errors.push(TypeCheckError::ArityMisMatch { + expected: param_len as u16, + found: arg_len as u16, + span, + }); } + + let (function_type, instantiation_bindings) = fntyp.instantiate(self.interner); + + self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings); + self.interner.push_expr_type(function_ident_id, function_type.clone()); + + self.bind_function_type(function_type, arguments, span) } fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type { @@ -818,11 +836,11 @@ impl<'interner> TypeChecker<'interner> { object_type: &Type, method_name: &str, expr_id: &ExprId, - ) -> Option { + ) -> Option { match object_type { Type::Struct(typ, _args) => { match self.interner.lookup_method(typ.borrow().id, method_name) { - Some(method_id) => Some(method_id), + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), None => { self.errors.push(TypeCheckError::UnresolvedMethodCall { method_name: method_name.to_string(), @@ -833,6 +851,33 @@ impl<'interner> TypeChecker<'interner> { } } } + Type::NamedGeneric(_, _) => { + let func_meta = self.interner.function_meta( + &self.current_function.expect("unexpected method outside a function"), + ); + + for constraint in func_meta.trait_constraints { + if *object_type == constraint.typ { + let the_trait = self.interner.get_trait(constraint.trait_id); + let the_trait = the_trait.borrow(); + + for (method_index, method) in the_trait.methods.iter().enumerate() { + if method.name.0.contents == method_name { + let trait_method = + TraitMethodId { trait_id: constraint.trait_id, method_index }; + return Some(HirMethodReference::TraitMethodId(trait_method)); + } + } + } + } + + self.errors.push(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span: self.interner.expr_span(expr_id), + }); + None + } // Mutable references to another type should resolve to methods of their element type. // This may be a struct or a primitive type. Type::MutableReference(element) => self.lookup_method(element, method_name, expr_id), @@ -843,7 +888,7 @@ impl<'interner> TypeChecker<'interner> { // In the future we could support methods for non-struct types if we have a context // (in the interner?) essentially resembling HashMap other => match self.interner.lookup_primitive_method(other, method_name) { - Some(method_id) => Some(method_id), + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), None => { self.errors.push(TypeCheckError::UnresolvedMethodCall { method_name: method_name.to_string(), diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index f3d8c58a426..c2afa44c495 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -27,6 +27,7 @@ pub struct TypeChecker<'interner> { delayed_type_checks: Vec, interner: &'interner mut NodeInterner, errors: Vec, + current_function: Option, } /// Type checks a function and assigns the @@ -40,6 +41,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec (noirc_e impl<'interner> TypeChecker<'interner> { fn new(interner: &'interner mut NodeInterner) -> Self { - Self { delayed_type_checks: Vec::new(), interner, errors: vec![] } + Self { delayed_type_checks: Vec::new(), interner, errors: vec![], current_function: None } } pub fn push_delayed_type_check(&mut self, f: TypeCheckFn) { @@ -127,7 +129,12 @@ impl<'interner> TypeChecker<'interner> { } pub fn check_global(id: &StmtId, interner: &'interner mut NodeInterner) -> Vec { - let mut this = Self { delayed_type_checks: Vec::new(), interner, errors: vec![] }; + let mut this = Self { + delayed_type_checks: Vec::new(), + interner, + errors: vec![], + current_function: None, + }; this.check_statement(id); this.errors } diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 3c49d3e4afc..4989dd12bd6 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -2,7 +2,7 @@ use acvm::FieldElement; use fm::FileId; use noirc_errors::Location; -use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId}; +use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId, TraitMethodId}; use crate::{BinaryOp, BinaryOpKind, Ident, Shared, UnaryOp}; use super::stmt::HirPattern; @@ -30,6 +30,7 @@ pub enum HirExpression { If(HirIfExpression), Tuple(Vec), Lambda(HirLambda), + TraitMethodReference(TraitMethodId), Error, } @@ -150,20 +151,39 @@ pub struct HirMethodCallExpression { pub location: Location, } +#[derive(Debug, Copy, Clone)] +pub enum HirMethodReference { + /// A method can be defined in a regular `impl` block, in which case + /// it's syntax sugar for a normal function call, and can be + /// translated to one during type checking + FuncId(FuncId), + + /// Or a method can come from a Trait impl block, in which case + /// the actual function called will depend on the instantiated type, + /// which can be only known during monomorphizaiton. + TraitMethodId(TraitMethodId), +} + impl HirMethodCallExpression { pub fn into_function_call( mut self, - func: FuncId, + method: HirMethodReference, location: Location, interner: &mut NodeInterner, ) -> (ExprId, HirExpression) { let mut arguments = vec![self.object]; arguments.append(&mut self.arguments); - let id = interner.function_definition_id(func); - let ident = HirExpression::Ident(HirIdent { location, id }); - let func = interner.push_expr(ident); - + let expr = match method { + HirMethodReference::FuncId(func_id) => { + let id = interner.function_definition_id(func_id); + HirExpression::Ident(HirIdent { location, id }) + } + HirMethodReference::TraitMethodId(method_id) => { + HirExpression::TraitMethodReference(method_id) + } + }; + let func = interner.push_expr(expr); (func, HirExpression::Call(HirCallExpression { func, arguments, location })) } } diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 4176e4fc89b..5e9b8723dbe 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -62,7 +62,7 @@ pub struct TraitImpl { #[derive(Debug, Clone)] pub struct TraitConstraint { pub typ: Type, - pub trait_id: Option, + pub trait_id: TraitId, // pub trait_generics: Generics, TODO } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index e519469ab2c..07b60ae48e9 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -11,7 +11,10 @@ use acvm::FieldElement; use iter_extended::{btree_map, vecmap}; use noirc_printable_type::PrintableType; -use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::{ + collections::{BTreeMap, HashMap, VecDeque}, + unreachable, +}; use crate::{ hir_def::{ @@ -20,7 +23,7 @@ use crate::{ stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement}, types, }, - node_interner::{self, DefinitionKind, NodeInterner, StmtId}, + node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitImplKey, TraitMethodId}, token::FunctionAttribute, ContractFunctionType, FunctionKind, Type, TypeBinding, TypeBindings, TypeVariableKind, Visibility, @@ -375,6 +378,17 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Lambda(lambda) => self.lambda(lambda, expr), + HirExpression::TraitMethodReference(method) => { + if let Type::Function(args, _, _) = self.interner.id_type(expr) { + let self_type = args[0].clone(); + self.resolve_trait_method_reference(self_type, expr, method) + } else { + unreachable!( + "Calling a non-function, this should've been caught in typechecking" + ); + } + } + HirExpression::MethodCall(_) => { unreachable!("Encountered HirExpression::MethodCall during monomorphization") } @@ -777,6 +791,45 @@ impl<'interner> Monomorphizer<'interner> { } } + fn resolve_trait_method_reference( + &mut self, + self_type: HirType, + expr_id: node_interner::ExprId, + method: TraitMethodId, + ) -> ast::Expression { + let function_type = self.interner.id_type(expr_id); + + // the substitute() here is to replace all internal occurences of the 'Self' typevar + // with whatever 'Self' is currently bound to, so we don't lose type information + // if we need to rebind the trait. + let trait_impl = self + .interner + .get_trait_implementation(&TraitImplKey { + typ: self_type.follow_bindings(), + trait_id: method.trait_id, + }) + .expect("ICE: missing trait impl - should be caught during type checking"); + + let hir_func_id = trait_impl.borrow().methods[method.method_index]; + + let func_def = self.lookup_function(hir_func_id, expr_id, &function_type); + let func_id = match func_def { + Definition::Function(func_id) => func_id, + _ => unreachable!(), + }; + + let the_trait = self.interner.get_trait(method.trait_id); + let the_trait = the_trait.borrow(); + + ast::Expression::Ident(ast::Ident { + definition: Definition::Function(func_id), + mutable: false, + location: None, + name: the_trait.methods[method.method_index].name.0.contents.clone(), + typ: self.convert_type(&function_type), + }) + } + fn function_call( &mut self, call: HirCallExpression, diff --git a/tooling/nargo_cli/tests/execution_success/trait_where_clause/Nargo.toml b/tooling/nargo_cli/tests/execution_success/trait_where_clause/Nargo.toml new file mode 100644 index 00000000000..9f17579976b --- /dev/null +++ b/tooling/nargo_cli/tests/execution_success/trait_where_clause/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "trait_where_clause" +type = "bin" +authors = [""] +compiler_version = "0.11.1" + +[dependencies] \ No newline at end of file diff --git a/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr b/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr new file mode 100644 index 00000000000..2d6d71d8df0 --- /dev/null +++ b/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr @@ -0,0 +1,41 @@ +// TODO(#2568): Currently we only support trait constraints on free functions. +// There's a bunch of other places where they can pop up: +// - trait methods (trait Foo where T: ... { ) +// - free impl blocks (impl Foo where T...) +// - trait impl blocks (impl Foo for Bar where T...) +// - structs (struct Foo where T: ...) + +trait Asd { + fn asd(self) -> Field; +} + +struct Add10 { x: Field, } +struct Add20 { x: Field, } +struct Add30 { x: Field, } +struct AddXY { x: Field, y: Field, } + +impl Asd for Add10 { fn asd(self) -> Field { self.x + 10 } } +impl Asd for Add20 { fn asd(self) -> Field { self.x + 20 } } +impl Asd for Add30 { fn asd(self) -> Field { self.x + 30 } } + +impl Asd for AddXY { + fn asd(self) -> Field { + self.x + self.y + } +} + +fn assert_asd_eq_100(t: T) where T: Asd { + assert(t.asd() == 100); +} + +fn main() { + let x = Add10{ x: 90 }; + let z = Add20{ x: 80 }; + let a = Add30{ x: 70 }; + let xy = AddXY{ x: 30, y: 70 }; + + assert_asd_eq_100(x); + assert_asd_eq_100(z); + assert_asd_eq_100(a); + assert_asd_eq_100(xy); +}