From 06951f4fa3908413bae628dda39eb239fa91b712 Mon Sep 17 00:00:00 2001 From: Yoshitomo Nakanishi Date: Mon, 28 Nov 2022 03:00:51 +0100 Subject: [PATCH] Implement function monomorphization in the CGUs generation phase --- crates/codegen/src/db.rs | 29 ++- crates/codegen/src/db/queries.rs | 1 + crates/codegen/src/db/queries/cgu.rs | 261 +++++++++++++++++++++ crates/codegen/src/db/queries/function.rs | 4 +- crates/codegen/src/lib.rs | 1 + crates/codegen/src/yul/isel/function.rs | 22 +- crates/codegen/src/yul/runtime/contract.rs | 4 +- crates/mir/src/ir/body_builder.rs | 3 + crates/mir/src/ir/inst.rs | 4 + crates/mir/src/lower/function.rs | 22 +- crates/mir/src/pretty_print/inst.rs | 1 + 11 files changed, 323 insertions(+), 29 deletions(-) create mode 100644 crates/codegen/src/db/queries/cgu.rs diff --git a/crates/codegen/src/db.rs b/crates/codegen/src/db.rs index de52a2a681..8eb9ac3c9c 100644 --- a/crates/codegen/src/db.rs +++ b/crates/codegen/src/db.rs @@ -1,23 +1,36 @@ use std::rc::Rc; use fe_abi::{contract::AbiContract, event::AbiEvent, function::AbiFunction, types::AbiType}; -use fe_analyzer::{db::AnalyzerDbStorage, namespace::items::ContractId, AnalyzerDb}; +use fe_analyzer::{ + db::AnalyzerDbStorage, + namespace::items::{ContractId, FunctionId, IngotId}, + AnalyzerDb, +}; use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; use fe_mir::{ db::{MirDb, MirDbStorage}, - ir::{FunctionBody, FunctionId, FunctionSignature, TypeId}, + ir::{FunctionBody, FunctionSigId, FunctionSignature, TypeId}, }; +use crate::cgu::{CguFunction, CguFunctionId, CodegenUnit, CodegenUnitId}; + mod queries; #[salsa::query_group(CodegenDbStorage)] pub trait CodegenDb: MirDb + Upcast + UpcastMut { + #[salsa::interned] + fn codegen_intern_cgu(&self, data: Rc) -> CodegenUnitId; + #[salsa::interned] + fn codegen_intern_cgu_func(&self, data: Rc) -> CguFunctionId; + #[salsa::invoke(queries::cgu::generate_cgu)] + fn codegen_generate_cgu(&self, ingot: IngotId) -> Rc<[CodegenUnitId]>; + #[salsa::invoke(queries::function::legalized_signature)] - fn codegen_legalized_signature(&self, function_id: FunctionId) -> Rc; + fn codegen_legalized_signature(&self, function_id: FunctionSigId) -> Rc; #[salsa::invoke(queries::function::legalized_body)] - fn codegen_legalized_body(&self, function_id: FunctionId) -> Rc; + fn codegen_legalized_body(&self, func: FunctionId) -> Rc; #[salsa::invoke(queries::function::symbol_name)] - fn codegen_function_symbol_name(&self, function_id: FunctionId) -> Rc; + fn codegen_function_symbol_name(&self, function_id: FunctionSigId) -> Rc; #[salsa::invoke(queries::types::legalized_type)] fn codegen_legalized_type(&self, ty: TypeId) -> TypeId; @@ -25,7 +38,7 @@ pub trait CodegenDb: MirDb + Upcast + UpcastMut { #[salsa::invoke(queries::abi::abi_type)] fn codegen_abi_type(&self, ty: TypeId) -> AbiType; #[salsa::invoke(queries::abi::abi_function)] - fn codegen_abi_function(&self, function_id: FunctionId) -> AbiFunction; + fn codegen_abi_function(&self, function_id: FunctionSigId) -> AbiFunction; #[salsa::invoke(queries::abi::abi_event)] fn codegen_abi_event(&self, ty: TypeId) -> AbiEvent; #[salsa::invoke(queries::abi::abi_contract)] @@ -35,9 +48,9 @@ pub trait CodegenDb: MirDb + Upcast + UpcastMut { #[salsa::invoke(queries::abi::abi_type_minimum_size)] fn codegen_abi_type_minimum_size(&self, ty: TypeId) -> usize; #[salsa::invoke(queries::abi::abi_function_argument_maximum_size)] - fn codegen_abi_function_argument_maximum_size(&self, contract: FunctionId) -> usize; + fn codegen_abi_function_argument_maximum_size(&self, contract: FunctionSigId) -> usize; #[salsa::invoke(queries::abi::abi_function_return_maximum_size)] - fn codegen_abi_function_return_maximum_size(&self, function: FunctionId) -> usize; + fn codegen_abi_function_return_maximum_size(&self, function: FunctionSigId) -> usize; #[salsa::invoke(queries::contract::symbol_name)] fn codegen_contract_symbol_name(&self, contract: ContractId) -> Rc; diff --git a/crates/codegen/src/db/queries.rs b/crates/codegen/src/db/queries.rs index 31cca43870..1849ee625e 100644 --- a/crates/codegen/src/db/queries.rs +++ b/crates/codegen/src/db/queries.rs @@ -1,4 +1,5 @@ pub mod abi; +pub mod cgu; pub mod constant; pub mod contract; pub mod function; diff --git a/crates/codegen/src/db/queries/cgu.rs b/crates/codegen/src/db/queries/cgu.rs new file mode 100644 index 0000000000..5a290ed9bb --- /dev/null +++ b/crates/codegen/src/db/queries/cgu.rs @@ -0,0 +1,261 @@ +//! This module contains algorithms to generate Codegen Units (CGUs) from the +//! MIR modules. +//! +//! The algorithm to collect CGUs is basically a normal worklist algorithm. +//! Therefore, the algorithm is summarized by the following steps: +//! +//! 1. Collect all functions which are not generic functions, and add them to +//! the worklist. +//! +//! 2. Pop a function from the worklist. If the function is not in a visited +//! set, add it to the visited set and proceed to step 3. +//! +//! 3. Iterate all the instructions in the current function to find a generic +//! function call. If generic function calls are found, monomorphizes them and +//! add them to the worklist. +//! +//! 4. Add the current function to the CGU of the module. +//! +//! 5. Repeat steps 2-4 until the worklist is empty. + +use std::rc::Rc; + +use fe_analyzer::namespace::{ + items::FunctionSigId as AnalyzerFuncSigId, + items::{IngotId, Item}, +}; +use fe_mir::ir::{ + body_cursor::{BodyCursor, CursorLocation}, + function::BodyDataStore, + inst::InstKind, + FunctionBody, FunctionParam, FunctionSigId, FunctionSignature, InstId, TypeId, TypeKind, + ValueId, +}; +use fxhash::{FxHashMap, FxHashSet}; +use indexmap::IndexSet; +use smol_str::SmolStr; + +use crate::{ + cgu::{CguFunction, CguFunctionId, CodegenUnit, CodegenUnitId}, + db::CodegenDb, +}; + +pub fn generate_cgu(db: &dyn CodegenDb, ingot: IngotId) -> Rc<[CodegenUnitId]> { + let mut worklist = Vec::new(); + + // Step 1. + for &module in ingot.all_modules(db.upcast()).iter() { + for &function in module.all_functions(db.upcast()).iter() { + if !function.is_generic(db.upcast()) { + let sig = function.sig(db.upcast()); + let sig = db.mir_lowered_func_signature(sig); + let body = (*db.mir_lowered_func_body(function)).clone(); + worklist.push((sig, body)); + } + } + } + + let mut cgu_map: FxHashMap<_, CodegenUnit> = FxHashMap::default(); + let mut visited = FxHashSet::default(); + + // Step 2-4. + while let Some((sig, body)) = worklist.pop() { + // Step 2. + if !visited.insert(sig) { + continue; + } + + // Step 3. + let (cgu_func_id, instantiated_funcs) = CguFunctionBuilder::new(db).build(sig, body); + worklist.extend_from_slice(instantiated_funcs.as_slice()); + + // Step 4. + let module = sig.module(db.upcast()); + cgu_map + .entry(module) + .or_default() + .functions + .push(cgu_func_id); + } + + cgu_map + .into_values() + .map(|cgu| db.codegen_intern_cgu(cgu.into())) + .collect() +} + +struct CguFunctionBuilder<'db> { + db: &'db dyn CodegenDb, +} + +impl<'db> CguFunctionBuilder<'db> { + fn new(db: &'db dyn CodegenDb) -> Self { + Self { db } + } + + fn build( + &self, + sig: FunctionSigId, + mut body: FunctionBody, + ) -> (CguFunctionId, Vec<(FunctionSigId, FunctionBody)>) { + let mut mono_funcs = Vec::new(); + let mut callees = IndexSet::new(); + + let mut cursor = BodyCursor::new_at_entry(&mut body); + loop { + match cursor.loc() { + CursorLocation::BlockTop(_) | CursorLocation::BlockBottom(_) => cursor.proceed(), + CursorLocation::NoWhere => { + break; + } + CursorLocation::Inst(inst) => { + if let InstKind::Call { func: callee, .. } = + &cursor.body().store.inst_data(inst).kind + { + if callee.is_generic(self.db.upcast()) { + let (mono_sig, mono_body) = + self.monomorphize_func(&cursor.body().store, inst); + // Rewrite the callee to the mono function. + cursor.body_mut().store.rewrite_callee(inst, mono_sig); + + if callees.insert(mono_sig) { + mono_funcs.push((mono_sig, mono_body)); + } + } else { + callees.insert(*callee); + } + } + cursor.proceed(); + } + } + } + let cgu_func = CguFunction { sig, body, callees }; + let cgu_func_id = self.db.codegen_intern_cgu_func(cgu_func.into()); + (cgu_func_id, mono_funcs) + } + + fn monomorphize_func( + &self, + store: &BodyDataStore, + inst: InstId, + ) -> (FunctionSigId, FunctionBody) { + let InstKind::Call {func: callee, args, ..} = &store.inst_data(inst).kind else { + panic!("expected a call instruction"); + }; + debug_assert!(callee.is_generic(self.db.upcast())); + + let callee = *callee; + let subst = self.get_subst(store, callee, args); + + let mono_sig = self.monomorphize_sig(callee, &subst); + let mono_body = self.monomorphize_body(callee, &subst); + + (self.db.mir_intern_function(mono_sig.into()), mono_body) + } + + fn monomorphize_sig( + &self, + sig: FunctionSigId, + subst: &FxHashMap, + ) -> FunctionSignature { + let params = sig + .data(self.db.upcast()) + .params + .iter() + .map(|param| { + let ty = match ¶m.ty.data(self.db.upcast()).kind { + TypeKind::TypeParam(def) => subst[&def.name], + _ => param.ty, + }; + + FunctionParam { + name: param.name.clone(), + ty, + source: param.source.clone(), + } + }) + .collect(); + + let return_type = sig.return_type(self.db.upcast()).clone(); + let module_id = sig.module(self.db.upcast()); + let linkage = sig.linkage(self.db.upcast()); + let analyzer_id = self.resolve_function(sig.analyzer_sig(self.db.upcast()), subst); + + FunctionSignature { + name: self.mono_func_name(analyzer_id, subst), + params, + return_type, + module_id, + analyzer_id, + linkage, + has_self: sig.has_self(self.db.upcast()), + } + } + + fn monomorphize_body( + &self, + func: FunctionSigId, + subst: &FxHashMap, + ) -> FunctionBody { + let func_id = func + .analyzer_sig(self.db.upcast()) + .function(self.db.upcast()) + .unwrap(); + + let mut body = (*self.db.mir_lowered_func_body(func_id)).clone(); + for value in body.store.values_mut() { + match &value.ty().data(self.db.upcast()).kind { + TypeKind::TypeParam(def) => { + let subst_ty = subst[&def.name]; + *value.ty_mut() = subst_ty; + } + _ => {} + } + } + + body + } + + /// Resolve the trait function signature into the corresponding concrete + /// implementation if the `callee` is a trait function. + fn resolve_function( + &self, + callee: AnalyzerFuncSigId, + subst: &FxHashMap, + ) -> AnalyzerFuncSigId { + let trait_id = match callee.parent(self.db.upcast()) { + Item::Trait(id) => id, + _ => return callee, + }; + + todo!() + } + + fn mono_func_name( + &self, + callee: AnalyzerFuncSigId, + subst: &FxHashMap, + ) -> SmolStr { + todo!() + } + + fn get_subst( + &self, + store: &BodyDataStore, + callee: FunctionSigId, + args: &[ValueId], + ) -> FxHashMap { + debug_assert_eq!(callee.data(self.db.upcast()).params.len(), args.len()); + + callee + .data(self.db.upcast()) + .params + .iter() + .zip(args) + .filter_map(|(param, arg)| match ¶m.ty.data(self.db.upcast()).kind { + TypeKind::TypeParam(def) => Some((def.name.clone(), store.value_ty(*arg))), + _ => None, + }) + .collect() + } +} diff --git a/crates/codegen/src/db/queries/function.rs b/crates/codegen/src/db/queries/function.rs index 03ec7bed91..532e1e8ff4 100644 --- a/crates/codegen/src/db/queries/function.rs +++ b/crates/codegen/src/db/queries/function.rs @@ -14,7 +14,7 @@ use smol_str::SmolStr; use crate::{db::CodegenDb, yul::legalize}; pub fn legalized_signature(db: &dyn CodegenDb, function: FunctionSigId) -> Rc { - let mut sig = function.signature(db.upcast()).as_ref().clone(); + let mut sig = function.data(db.upcast()).as_ref().clone(); legalize::legalize_func_signature(db, &mut sig); sig.into() } @@ -26,7 +26,7 @@ pub fn legalized_body(db: &dyn CodegenDb, func: FunctionId) -> Rc } pub fn symbol_name(db: &dyn CodegenDb, function: FunctionSigId) -> Rc { - let module = function.signature(db.upcast()).module_id; + let module = function.data(db.upcast()).module_id; let module_name = module.name(db.upcast()); let analyzer_func = function.analyzer_sig(db.upcast()); diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 37ec962db2..3b3df510cd 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -1,2 +1,3 @@ +pub mod cgu; pub mod db; pub mod yul; diff --git a/crates/codegen/src/yul/isel/function.rs b/crates/codegen/src/yul/isel/function.rs index 51078bdc9c..dbbd427cb4 100644 --- a/crates/codegen/src/yul/isel/function.rs +++ b/crates/codegen/src/yul/isel/function.rs @@ -2,6 +2,7 @@ use std::thread::Scope; use super::{context::Context, inst_order::InstSerializer}; +use fe_analyzer::namespace::items::FunctionId; use fe_common::numeric::to_hex_str; use fe_abi::function::{AbiFunction, AbiFunctionType}; @@ -12,7 +13,7 @@ use fe_mir::{ constant::ConstantValue, inst::{BinOp, CallType, CastKind, InstKind, UnOp}, value::AssignableValue, - Constant, FunctionBody, FunctionId, FunctionSignature, InstId, Type, TypeId, TypeKind, + Constant, FunctionBody, FunctionSigId, FunctionSignature, InstId, Type, TypeId, TypeKind, Value, ValueId, }, pretty_print::PrettyPrint, @@ -37,20 +38,19 @@ use crate::{ pub fn lower_function( db: &dyn CodegenDb, ctx: &mut Context, - function: FunctionId, + sig: FunctionSigId, + func: FunctionId, ) -> yul::FunctionDefinition { - debug_assert!(!ctx.lowered_functions.contains(&function)); - ctx.lowered_functions.insert(function); - let sig = &db.codegen_legalized_signature(function); - let body = &db.codegen_legalized_body(function); - FuncLowerHelper::new(db, ctx, function, sig, body).lower_func() + let sig_data = &db.codegen_legalized_signature(sig); + let body = &db.codegen_legalized_body(func); + FuncLowerHelper::new(db, ctx, sig, sig_data, body).lower_func() } struct FuncLowerHelper<'db, 'a> { db: &'db dyn CodegenDb, ctx: &'a mut Context, value_map: ScopedValueMap, - func: FunctionId, + sig_id: FunctionSigId, sig: &'a FunctionSignature, body: &'a FunctionBody, ret_value: Option, @@ -61,7 +61,7 @@ impl<'db, 'a> FuncLowerHelper<'db, 'a> { fn new( db: &'db dyn CodegenDb, ctx: &'a mut Context, - func: FunctionId, + sig_id: FunctionSigId, sig: &'a FunctionSignature, body: &'a FunctionBody, ) -> Self { @@ -87,7 +87,7 @@ impl<'db, 'a> FuncLowerHelper<'db, 'a> { db, ctx, value_map, - func, + sig_id, sig, body, ret_value, @@ -96,7 +96,7 @@ impl<'db, 'a> FuncLowerHelper<'db, 'a> { } fn lower_func(mut self) -> yul::FunctionDefinition { - let name = identifier! { (self.db.codegen_function_symbol_name(self.func)) }; + let name = identifier! { (self.db.codegen_function_symbol_name(self.sig_id)) }; let parameters = self .sig diff --git a/crates/codegen/src/yul/runtime/contract.rs b/crates/codegen/src/yul/runtime/contract.rs index d85321b377..3222f5f962 100644 --- a/crates/codegen/src/yul/runtime/contract.rs +++ b/crates/codegen/src/yul/runtime/contract.rs @@ -6,7 +6,7 @@ use crate::{ use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; use fe_analyzer::namespace::items::ContractId; -use fe_mir::ir::{FunctionId, Type, TypeKind}; +use fe_mir::ir::{FunctionSigId, Type, TypeKind}; use yultsur::*; @@ -66,7 +66,7 @@ pub(super) fn make_external_call( provider: &mut DefaultRuntimeProvider, db: &dyn CodegenDb, func_name: &str, - function: FunctionId, + function: FunctionSigId, ) -> RuntimeFunction { let func_name = YulVariable::new(func_name); let sig = db.codegen_legalized_signature(function); diff --git a/crates/mir/src/ir/body_builder.rs b/crates/mir/src/ir/body_builder.rs index 8dc9242e51..b9160e62f7 100644 --- a/crates/mir/src/ir/body_builder.rs +++ b/crates/mir/src/ir/body_builder.rs @@ -10,6 +10,7 @@ use crate::ir::{ use super::{ inst::{CallType, CastKind, SwitchTable, YulIntrinsicOp}, + types::TypeParamDef, ConstantId, Value, ValueId, }; @@ -218,12 +219,14 @@ impl BodyBuilder { func: FunctionSigId, args: Vec, call_type: CallType, + generic_type: Option, source: SourceInfo, ) -> InstId { let kind = InstKind::Call { func, args, call_type, + generic_type, }; let inst = Inst::new(kind, source); self.insert_inst(inst) diff --git a/crates/mir/src/ir/inst.rs b/crates/mir/src/ir/inst.rs index 78da857d76..b3aa050d2f 100644 --- a/crates/mir/src/ir/inst.rs +++ b/crates/mir/src/ir/inst.rs @@ -85,6 +85,10 @@ pub enum InstKind { func: FunctionSigId, args: Vec, call_type: CallType, + // Traits can not directly be used as types but can act as bounds for generics. This is the + // generic type that the method is called on. + // TODO: This is the temporary solution until we implement a trait solver. + generic_type: Option, }, /// Unconditional jump instruction. diff --git a/crates/mir/src/lower/function.rs b/crates/mir/src/lower/function.rs index 15975fa0a3..c8316194ad 100644 --- a/crates/mir/src/lower/function.rs +++ b/crates/mir/src/lower/function.rs @@ -937,7 +937,8 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { AnalyzerCallType::AssociatedFunction { sig, .. } | AnalyzerCallType::Pure(sig) => { let func_id = self.db.mir_lowered_func_signature(*sig); - self.builder.call(func_id, args, CallType::Internal, source) + self.builder + .call(func_id, args, CallType::Internal, None, source) } AnalyzerCallType::ValueMethod { sig, .. } => { @@ -945,7 +946,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { method_args.append(&mut args); let func_id = self.db.mir_lowered_func_signature(*sig); self.builder - .call(func_id, method_args, CallType::Internal, source) + .call(func_id, method_args, CallType::Internal, None, source) } AnalyzerCallType::TraitValueMethod { trait_id, @@ -957,12 +958,21 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let event = self.lower_method_receiver(func); self.builder.emit(event, source) } - AnalyzerCallType::TraitValueMethod { sig, .. } => { + AnalyzerCallType::TraitValueMethod { + sig, generic_type, .. + } => { let mut method_args = vec![self.lower_method_receiver(func)]; method_args.append(&mut args); let func_id = self.db.mir_lowered_func_signature(*sig); - self.builder - .call(func_id, method_args, CallType::Internal, source) + self.builder.call( + func_id, + method_args, + CallType::Internal, + Some(ir::types::TypeParamDef { + name: generic_type.name.clone(), + }), + source, + ) } AnalyzerCallType::External { sig, .. } => { let receiver = self.lower_method_receiver(func); @@ -972,7 +982,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { method_args.append(&mut args); let func_id = self.db.mir_lowered_func_signature(*sig); self.builder - .call(func_id, method_args, CallType::External, source) + .call(func_id, method_args, CallType::External, None, source) } AnalyzerCallType::TypeConstructor(to_ty) => { diff --git a/crates/mir/src/pretty_print/inst.rs b/crates/mir/src/pretty_print/inst.rs index 45f2df007c..c9db87c8bf 100644 --- a/crates/mir/src/pretty_print/inst.rs +++ b/crates/mir/src/pretty_print/inst.rs @@ -103,6 +103,7 @@ impl PrettyPrint for InstId { func, args, call_type, + .. } => { let name = func.debug_name(db); write!(w, "{}@{}(", name, call_type)?;