diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 261a97ec3fb..903751331b3 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -223,8 +223,8 @@ impl<'context> Elaborator<'context> { // // These are resolved after trait impls so that struct methods are chosen // over trait methods if there are name conflicts. - for ((typ, module), impls) in &mut items.impls { - this.collect_impls(typ, *module, impls); + for ((_self_type, module), impls) in &mut items.impls { + this.collect_impls(*module, impls); } // We must wait to resolve non-literal globals until after we resolve structs since struct @@ -284,8 +284,6 @@ impl<'context> Elaborator<'context> { self.scopes.start_function(); self.current_item = Some(DependencyId::Function(id)); - // Check whether the function has globals in the local module and add them to the scope - self.resolve_local_globals(); self.trait_bounds = function.def.where_clause.clone(); if function.def.is_unconstrained { @@ -474,18 +472,6 @@ impl<'context> Elaborator<'context> { None } - fn resolve_local_globals(&mut self) { - let globals = vecmap(self.interner.get_all_globals(), |global| { - (global.id, global.local_id, global.ident.clone()) - }); - for (id, local_module_id, name) in globals { - if local_module_id == self.local_module { - let definition = DefinitionKind::Global(id); - self.add_global_variable_decl(name, definition); - } - } - } - /// TODO: This is currently only respected for generic free functions /// there's a bunch of other places where trait constraints can pop up fn resolve_trait_constraints( @@ -494,21 +480,20 @@ impl<'context> Elaborator<'context> { ) -> Vec { where_clause .iter() - .cloned() .filter_map(|constraint| self.resolve_trait_constraint(constraint)) .collect() } pub fn resolve_trait_constraint( &mut self, - constraint: UnresolvedTraitConstraint, + constraint: &UnresolvedTraitConstraint, ) -> Option { - let typ = self.resolve_type(constraint.typ); + let typ = self.resolve_type(constraint.typ.clone()); let trait_generics = - vecmap(constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ)); + vecmap(&constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ.clone())); let span = constraint.trait_bound.trait_path.span(); - let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path)?; + let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path.clone())?; let trait_id = the_trait.id; let expected_generics = the_trait.generics.len(); @@ -548,9 +533,6 @@ impl<'context> Elaborator<'context> { self.scopes.start_function(); self.current_item = Some(DependencyId::Function(func_id)); - // Check whether the function has globals in the local module and add them to the scope - self.resolve_local_globals(); - let location = Location::new(func.name_ident().span(), self.file); let id = self.interner.function_definition_id(func_id); let name_ident = HirIdent::non_trait_method(id, location); @@ -864,40 +846,69 @@ impl<'context> Elaborator<'context> { self.file = trait_impl.file_id; self.local_module = trait_impl.module_id; - let unresolved_type = trait_impl.object_type; - let self_type_span = unresolved_type.span; - let old_generics_length = self.generics.len(); self.generics = trait_impl.resolved_generics; + self.current_trait_impl = trait_impl.impl_id; - let trait_generics = - vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())); + self.elaborate_functions(trait_impl.methods); - let self_type = trait_impl.resolved_object_type.unwrap_or(Type::Error); - let impl_id = - trait_impl.impl_id.expect("An impls' id should be set during define_function_metas"); + self.self_type = None; + self.current_trait_impl = None; + self.generics.clear(); + } - self.current_trait_impl = trait_impl.impl_id; + fn collect_impls( + &mut self, + module: LocalModuleId, + impls: &mut [(Vec, Span, UnresolvedFunctions)], + ) { + self.local_module = module; - let methods = trait_impl.methods.function_ids(); + for (generics, span, unresolved) in impls { + self.file = unresolved.file_id; + let old_generic_count = self.generics.len(); + self.add_generics(generics); + self.declare_methods_on_struct(false, unresolved, *span); + self.generics.truncate(old_generic_count); + } + } - self.elaborate_functions(trait_impl.methods); + fn collect_trait_impl(&mut self, trait_impl: &mut UnresolvedTraitImpl) { + self.local_module = trait_impl.module_id; + self.file = trait_impl.file_id; + trait_impl.trait_id = self.resolve_trait_by_path(trait_impl.trait_path.clone()); + + let self_type = trait_impl.methods.self_type.clone(); + let self_type = + self_type.expect("Expected struct type to be set before collect_trait_impl"); + + let self_type_span = trait_impl.object_type.span; if matches!(self_type, Type::MutableReference(_)) { let span = self_type_span.unwrap_or_else(|| trait_impl.trait_path.span()); self.push_err(DefCollectorErrorKind::MutableReferenceInTraitImpl { span }); } + assert!(trait_impl.trait_id.is_some()); if let Some(trait_id) = trait_impl.trait_id { + self.collect_trait_impl_methods(trait_id, trait_impl); + + let span = trait_impl.object_type.span.expect("All trait self types should have spans"); + self.generics = trait_impl.resolved_generics.clone(); + self.declare_methods_on_struct(true, &mut trait_impl.methods, span); + + let methods = trait_impl.methods.function_ids(); for func_id in &methods { self.interner.set_function_trait(*func_id, self_type.clone(), trait_id); } let where_clause = trait_impl .where_clause - .into_iter() + .iter() .flat_map(|item| self.resolve_trait_constraint(item)) .collect(); + let trait_generics = trait_impl.resolved_trait_generics.clone(); + let resolved_trait_impl = Shared::new(TraitImpl { ident: trait_impl.trait_path.last_segment().clone(), typ: self_type.clone(), @@ -914,7 +925,7 @@ impl<'context> Elaborator<'context> { self_type.clone(), trait_id, trait_generics, - impl_id, + trait_impl.impl_id.expect("impl_id should be set in define_function_metas"), generics, resolved_trait_impl, ) { @@ -931,45 +942,7 @@ impl<'context> Elaborator<'context> { } } - self.self_type = None; - self.current_trait_impl = None; - self.generics.truncate(old_generics_length); - } - - fn collect_impls( - &mut self, - self_type: &UnresolvedType, - module: LocalModuleId, - impls: &mut [(Vec, Span, UnresolvedFunctions)], - ) { - self.local_module = module; - - for (generics, span, unresolved) in impls { - self.file = unresolved.file_id; - self.recover_generics(|this| { - this.add_generics(generics); - this.declare_methods_on_struct(self_type, false, unresolved, *span); - }); - } - } - - fn collect_trait_impl(&mut self, trait_impl: &mut UnresolvedTraitImpl) { - self.local_module = trait_impl.module_id; - self.file = trait_impl.file_id; - trait_impl.trait_id = self.resolve_trait_by_path(trait_impl.trait_path.clone()); - - if let Some(trait_id) = trait_impl.trait_id { - self.collect_trait_impl_methods(trait_id, trait_impl); - - let span = trait_impl.object_type.span.expect("All trait self types should have spans"); - let object_type = &trait_impl.object_type; - - self.recover_generics(|this| { - this.add_generics(&trait_impl.generics); - trait_impl.resolved_generics = this.generics.clone(); - this.declare_methods_on_struct(object_type, true, &mut trait_impl.methods, span); - }); - } + self.generics.clear(); } fn get_module_mut(&mut self, module: ModuleId) -> &mut ModuleData { @@ -979,14 +952,13 @@ impl<'context> Elaborator<'context> { fn declare_methods_on_struct( &mut self, - self_type: &UnresolvedType, is_trait_impl: bool, functions: &mut UnresolvedFunctions, span: Span, ) { - let self_type = self.resolve_type(self_type.clone()); - - functions.self_type = Some(self_type.clone()); + let self_type = functions.self_type.as_ref(); + let self_type = + self_type.expect("Expected struct type to be set before declare_methods_on_struct"); let function_ids = functions.function_ids(); @@ -1016,11 +988,11 @@ impl<'context> Elaborator<'context> { } } - self.declare_struct_methods(&self_type, &function_ids); + self.declare_struct_methods(self_type, &function_ids); // We can define methods on primitive types only if we're in the stdlib - } else if !is_trait_impl && self_type != Type::Error { + } else if !is_trait_impl && *self_type != Type::Error { if self.crate_id.is_stdlib() { - self.declare_struct_methods(&self_type, &function_ids); + self.declare_struct_methods(self_type, &function_ids); } else { self.push_err(DefCollectorErrorKind::NonStructTypeInImpl { span }); } @@ -1145,8 +1117,8 @@ impl<'context> Elaborator<'context> { self.local_module = trait_impl.module_id; self.file = trait_impl.file_id; - let object_crate = match self.resolve_type(trait_impl.object_type.clone()) { - Type::Struct(struct_type, _) => struct_type.borrow().id.krate(), + let object_crate = match &trait_impl.resolved_object_type { + Some(Type::Struct(struct_type, _)) => struct_type.borrow().id.krate(), _ => CrateId::Dummy, }; @@ -1163,7 +1135,6 @@ impl<'context> Elaborator<'context> { self.local_module = alias.module_id; let generics = self.add_generics(&alias.type_alias_def.generics); - self.resolve_local_globals(); self.current_item = Some(DependencyId::Alias(alias_id)); let typ = self.resolve_type(alias.type_alias_def.typ); self.interner.set_type_alias(alias_id, typ, generics); @@ -1215,9 +1186,6 @@ impl<'context> Elaborator<'context> { self.recover_generics(|this| { let generics = this.add_generics(&unresolved.generics); - // Check whether the struct definition has globals in the local module and add them to the scope - this.resolve_local_globals(); - this.current_item = Some(DependencyId::Struct(struct_id)); this.resolving_ids.insert(struct_id); @@ -1253,7 +1221,7 @@ impl<'context> Elaborator<'context> { let (let_statement, _typ) = self.elaborate_let(let_stmt); let statement_id = self.interner.get_global(global_id).let_statement; - self.interner.get_global_definition_mut(global_id).kind = definition_kind; + self.interner.get_global_definition_mut(global_id).kind = definition_kind.clone(); self.interner.replace_statement(statement_id, let_statement); } @@ -1286,6 +1254,11 @@ impl<'context> Elaborator<'context> { let unresolved_type = &trait_impl.object_type; self.add_generics(&trait_impl.generics); + trait_impl.resolved_generics = self.generics.clone(); + + let trait_generics = + vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())); + trait_impl.resolved_trait_generics = trait_generics; let self_type = self.resolve_type(unresolved_type.clone()); diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index 76cdc592276..2a35467f55a 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -13,7 +13,7 @@ use crate::{ }, node_interner::{FuncId, TraitId}, token::Attributes, - Generics, Type, TypeVariable, TypeVariableKind, + Type, TypeVariableKind, }; use super::Elaborator; @@ -21,21 +21,21 @@ use super::Elaborator; impl<'context> Elaborator<'context> { pub fn collect_traits(&mut self, traits: BTreeMap) { for (trait_id, unresolved_trait) in traits { - let generics = vecmap(&unresolved_trait.trait_def.generics, |_| { - TypeVariable::unbound(self.interner.next_type_variable_id()) - }); - - // Resolve order - // 1. Trait Types ( Trait constants can have a trait type, therefore types before constants) - let _ = self.resolve_trait_types(&unresolved_trait); - // 2. Trait Constants ( Trait's methods can use trait types & constants, therefore they should be after) - let _ = self.resolve_trait_constants(&unresolved_trait); - // 3. Trait Methods - let methods = self.resolve_trait_methods(trait_id, &unresolved_trait, &generics); - - self.interner.update_trait(trait_id, |trait_def| { - trait_def.set_methods(methods); - trait_def.generics = generics; + self.recover_generics(|this| { + this.add_generics(&unresolved_trait.trait_def.generics); + + // Resolve order + // 1. Trait Types ( Trait constants can have a trait type, therefore types before constants) + let _ = this.resolve_trait_types(&unresolved_trait); + // 2. Trait Constants ( Trait's methods can use trait types & constants, therefore they should be after) + let _ = this.resolve_trait_constants(&unresolved_trait); + // 3. Trait Methods + let methods = this.resolve_trait_methods(trait_id, &unresolved_trait); + + this.interner.update_trait(trait_id, |trait_def| { + trait_def.set_methods(methods); + trait_def.generics = vecmap(&this.generics, |(_, generic, _)| generic.clone()); + }); }); // This check needs to be after the trait's methods are set since @@ -64,7 +64,6 @@ impl<'context> Elaborator<'context> { &mut self, trait_id: TraitId, unresolved_trait: &UnresolvedTrait, - trait_generics: &Generics, ) -> Vec { self.local_module = unresolved_trait.module_id; self.file = self.def_maps[&self.crate_id].file_id(unresolved_trait.module_id); @@ -81,59 +80,57 @@ impl<'context> Elaborator<'context> { body: _, } = item { - let old_generic_count = self.generics.len(); - - let the_trait = self.interner.get_trait(trait_id); - let self_typevar = the_trait.self_type_typevar.clone(); - let self_type = Type::TypeVariable(self_typevar.clone(), TypeVariableKind::Normal); - let name_span = the_trait.name.span(); - - self.add_generics(generics); - self.add_existing_generics(&unresolved_trait.trait_def.generics, trait_generics); - self.add_existing_generic("Self", name_span, self_typevar); - self.self_type = Some(self_type.clone()); - - let func_id = unresolved_trait.method_ids[&name.0.contents]; - self.resolve_trait_function( - name, - generics, - parameters, - return_type, - where_clause, - func_id, - ); - - let arguments = vecmap(parameters, |param| self.resolve_type(param.1.clone())); - let return_type = self.resolve_type(return_type.get_type().into_owned()); - - let generics = vecmap(&self.generics, |(_, type_var, _)| type_var.clone()); - - let default_impl_list: Vec<_> = unresolved_trait - .fns_with_default_impl - .functions - .iter() - .filter(|(_, _, q)| q.name() == name.0.contents) - .collect(); - - let default_impl = if default_impl_list.len() == 1 { - Some(Box::new(default_impl_list[0].2.clone())) - } else { - None - }; - - let no_environment = Box::new(Type::Unit); - let function_type = - Type::Function(arguments, Box::new(return_type), no_environment); - - functions.push(TraitFunction { - name: name.clone(), - typ: Type::Forall(generics, Box::new(function_type)), - location: Location::new(name.span(), unresolved_trait.file_id), - default_impl, - default_impl_module_id: unresolved_trait.module_id, + self.recover_generics(|this| { + let the_trait = this.interner.get_trait(trait_id); + let self_typevar = the_trait.self_type_typevar.clone(); + let self_type = + Type::TypeVariable(self_typevar.clone(), TypeVariableKind::Normal); + let name_span = the_trait.name.span(); + + this.add_existing_generic("Self", name_span, self_typevar); + this.add_generics(generics); + this.self_type = Some(self_type.clone()); + + let func_id = unresolved_trait.method_ids[&name.0.contents]; + this.resolve_trait_function( + name, + generics, + parameters, + return_type, + where_clause, + func_id, + ); + + let arguments = vecmap(parameters, |param| this.resolve_type(param.1.clone())); + let return_type = this.resolve_type(return_type.get_type().into_owned()); + + let generics = vecmap(&this.generics, |(_, type_var, _)| type_var.clone()); + + let default_impl_list: Vec<_> = unresolved_trait + .fns_with_default_impl + .functions + .iter() + .filter(|(_, _, q)| q.name() == name.0.contents) + .collect(); + + let default_impl = if default_impl_list.len() == 1 { + Some(Box::new(default_impl_list[0].2.clone())) + } else { + None + }; + + let no_environment = Box::new(Type::Unit); + let function_type = + Type::Function(arguments, Box::new(return_type), no_environment); + + functions.push(TraitFunction { + name: name.clone(), + typ: Type::Forall(generics, Box::new(function_type)), + location: Location::new(name.span(), unresolved_trait.file_id), + default_impl, + default_impl_module_id: unresolved_trait.module_id, + }); }); - - self.generics.truncate(old_generic_count); } } functions @@ -151,9 +148,6 @@ impl<'context> Elaborator<'context> { let old_generic_count = self.generics.len(); self.scopes.start_function(); - // Check whether the function has globals in the local module and add them to the scope - self.resolve_local_globals(); - self.trait_bounds = where_clause.to_vec(); let kind = FunctionKind::Normal; diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 838aac3f067..05147af5459 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -121,11 +121,15 @@ pub struct UnresolvedTraitImpl { pub generics: UnresolvedGenerics, pub where_clause: Vec, - // These fields are filled in later + // Every field after this line is filled in later in the elaborator pub trait_id: Option, pub impl_id: Option, pub resolved_object_type: Option, pub resolved_generics: Vec<(Rc, TypeVariable, Span)>, + + // The resolved generic on the trait itself. E.g. it is the `` in + // `impl Foo for Bar { ... }` + pub resolved_trait_generics: Vec, } #[derive(Clone)] diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 536332d8f8a..5f1ef6477af 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -197,6 +197,7 @@ impl<'a> ModCollector<'a> { impl_id: None, resolved_object_type: None, resolved_generics: Vec::new(), + resolved_trait_generics: Vec::new(), }; self.def_collector.items.trait_impls.push(unresolved_trait_impl);