diff --git a/Cargo.lock b/Cargo.lock index fc11dc0e6e..22e8567e48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -932,6 +932,22 @@ dependencies = [ "yultsur", ] +[[package]] +name = "fe-codegen2" +version = "0.23.0" +dependencies = [ + "fe-abi", + "fe-hir", + "fe-hir-analysis", + "fe-mir2", + "fxhash", + "indexmap", + "num-bigint", + "salsa-2022", + "smol_str", + "yultsur", +] + [[package]] name = "fe-common" version = "0.23.0" @@ -1100,6 +1116,14 @@ dependencies = [ "include_dir", ] +[[package]] +name = "fe-library2" +version = "0.23.0" +dependencies = [ + "fe-common2", + "include_dir", +] + [[package]] name = "fe-macros" version = "0.23.0" @@ -1130,6 +1154,44 @@ dependencies = [ "smol_str", ] +[[package]] +name = "fe-mir2" +version = "0.23.0" +dependencies = [ + "dot2", + "fe-common2", + "fe-hir", + "fe-hir-analysis", + "fe-library2", + "fe-parser2", + "fe-test-files", + "fxhash", + "id-arena", + "indexmap", + "num-bigint", + "num-integer", + "num-traits", + "salsa-2022", + "smol_str", +] + +[[package]] +name = "fe-mir2-analysis" +version = "0.23.0" +dependencies = [ + "dot2", + "fe-common2", + "fe-hir", + "fe-hir-analysis", + "fe-library", + "fe-mir2", + "fe-test-files", + "fxhash", + "id-arena", + "salsa-2022", + "smol_str", +] + [[package]] name = "fe-parser" version = "0.23.0" diff --git a/crates/codegen2/Cargo.toml b/crates/codegen2/Cargo.toml new file mode 100644 index 0000000000..8d18bca5e7 --- /dev/null +++ b/crates/codegen2/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "fe-codegen2" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" + +[dependencies] +hir-analysis = { path = "../hir-analysis", package = "fe-hir-analysis" } +hir = { path = "../hir", package = "fe-hir" } +mir = { path = "../mir2", package = "fe-mir2" } +fe-abi = { path = "../abi", version = "^0.23.0"} +salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } +num-bigint = "0.4.3" +fxhash = "0.2.1" +indexmap = "1.6.2" +smol_str = "0.1.21" +yultsur = { git = "https://github.com/fe-lang/yultsur", rev = "ae85470" } diff --git a/crates/codegen2/src/abi.rs b/crates/codegen2/src/abi.rs new file mode 100644 index 0000000000..8331be96e7 --- /dev/null +++ b/crates/codegen2/src/abi.rs @@ -0,0 +1,279 @@ +use fe_abi::{ + contract::AbiContract, + event::{AbiEvent, AbiEventField}, + function::{AbiFunction, AbiFunctionType, CtxParam, SelfParam, StateMutability}, + types::{AbiTupleField, AbiType}, +}; +use fe_analyzer::{ + constants::INDEXED, + namespace::{ + items::ContractId, + types::{CtxDecl, SelfDecl}, + }, +}; +use fe_mir::ir::{self, FunctionId, TypeId}; + +#[salsa::tracked(return_ref)] +pub fn abi_contract(db: &dyn CodegenDb, contract: ContractId) -> AbiContract { + let mut funcs = vec![]; + + if let Some(init) = contract.init_function(db.upcast()) { + let init_func = db.mir_lowered_func_signature(init); + let init_abi = db.codegen_abi_function(init_func); + funcs.push(init_abi); + } + + for &func in contract.all_functions(db.upcast()).as_ref() { + let mir_func = db.mir_lowered_func_signature(func); + if mir_func.linkage(db.upcast()).is_exported() { + let func_abi = db.codegen_abi_function(mir_func); + funcs.push(func_abi); + } + } + + let mut events = vec![]; + for &s in db.module_structs(contract.module(db.upcast())).as_ref() { + let struct_ty = s.as_type(db.upcast()); + // TODO: This is a hack to avoid generating an ABI for non-`emittable` structs. + if struct_ty.is_emittable(db.upcast()) { + let mir_event = db.mir_lowered_type(struct_ty); + let event = db.codegen_abi_event(mir_event); + events.push(event); + } + } + + AbiContract::new(funcs, events) +} + +#[salsa::tracked(return_ref)] +pub fn abi_function(db: &dyn CodegenDb, function: FunctionId) -> AbiFunction { + // We use a legalized signature. + let sig = db.codegen_legalized_signature(function); + + let name = function.name(db.upcast()); + let args = sig + .params + .iter() + .map(|param| (param.name.to_string(), db.codegen_abi_type(param.ty))) + .collect(); + let ret_ty = sig.return_type.map(|ty| db.codegen_abi_type(ty)); + + let func_type = if function.is_contract_init(db.upcast()) { + AbiFunctionType::Constructor + } else { + AbiFunctionType::Function + }; + + // The "stateMutability" field is derived from the presence & mutability of + // `self` and `ctx` params in the analyzer fn sig. + let analyzer_sig = sig.analyzer_func_id.signature(db.upcast()); + let self_param = match analyzer_sig.self_decl { + None => SelfParam::None, + Some(SelfDecl { mut_: None, .. }) => SelfParam::Imm, + Some(SelfDecl { mut_: Some(_), .. }) => SelfParam::Mut, + }; + let ctx_param = match analyzer_sig.ctx_decl { + None => CtxParam::None, + Some(CtxDecl { mut_: None, .. }) => CtxParam::Imm, + Some(CtxDecl { mut_: Some(_), .. }) => CtxParam::Mut, + }; + + let state_mutability = if name == "__init__" { + StateMutability::Payable + } else { + StateMutability::from_self_and_ctx_params(self_param, ctx_param) + }; + + AbiFunction::new(func_type, name.to_string(), args, ret_ty, state_mutability) +} + +#[salsa::tracked(return_ref)] +pub fn abi_function_argument_maximum_size(db: &dyn CodegenDb, function: FunctionId) -> usize { + let sig = db.codegen_legalized_signature(function); + sig.params.iter().fold(0, |acc, param| { + acc + db.codegen_abi_type_maximum_size(param.ty) + }) +} + +#[salsa::tracked(return_ref)] +pub fn abi_function_return_maximum_size(db: &dyn CodegenDb, function: FunctionId) -> usize { + let sig = db.codegen_legalized_signature(function); + sig.return_type + .map(|ty| db.codegen_abi_type_maximum_size(ty)) + .unwrap_or_default() +} + +#[salsa::tracked(return_ref)] +pub fn abi_type_maximum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { + let abi_type = db.codegen_abi_type(ty); + if abi_type.is_static() { + abi_type.header_size() + } else { + match &ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) if def.elem_ty.data(db.upcast()).kind == ir::TypeKind::U8 => { + debug_assert_eq!(abi_type, AbiType::Bytes); + 64 + ceil_32(def.len) + } + + ir::TypeKind::Array(def) => { + db.codegen_abi_type_maximum_size(def.elem_ty) * def.len + 32 + } + + ir::TypeKind::String(len) => abi_type.header_size() + 32 + ceil_32(*len), + _ if ty.is_aggregate(db.upcast()) => { + let mut maximum = 0; + for i in 0..ty.aggregate_field_num(db.upcast()) { + let field_ty = ty.projection_ty_imm(db.upcast(), i); + maximum += db.codegen_abi_type_maximum_size(field_ty) + } + maximum + 32 + } + ir::TypeKind::MPtr(ty) => abi_type_maximum_size(db, ty.deref(db.upcast())), + + _ => unreachable!(), + } + } +} + +#[salsa::tracked(return_ref)] +pub fn abi_type_minimum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { + let abi_type = db.codegen_abi_type(ty); + if abi_type.is_static() { + abi_type.header_size() + } else { + match &ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) if def.elem_ty.data(db.upcast()).kind == ir::TypeKind::U8 => { + debug_assert_eq!(abi_type, AbiType::Bytes); + 64 + } + ir::TypeKind::Array(def) => { + db.codegen_abi_type_minimum_size(def.elem_ty) * def.len + 32 + } + + ir::TypeKind::String(_) => abi_type.header_size() + 32, + + _ if ty.is_aggregate(db.upcast()) => { + let mut minimum = 0; + for i in 0..ty.aggregate_field_num(db.upcast()) { + let field_ty = ty.projection_ty_imm(db.upcast(), i); + minimum += db.codegen_abi_type_minimum_size(field_ty) + } + minimum + 32 + } + ir::TypeKind::MPtr(ty) => abi_type_minimum_size(db, ty.deref(db.upcast())), + _ => unreachable!(), + } + } +} + +#[salsa::tracked(return_ref)] +pub fn abi_type(db: &dyn CodegenDb, ty: TypeId) -> AbiType { + let legalized_ty = db.codegen_legalized_type(ty); + + if legalized_ty.is_zero_sized(db.upcast()) { + unreachable!("zero-sized type must be removed in legalization"); + } + + let ty_data = legalized_ty.data(db.upcast()); + + match &ty_data.kind { + ir::TypeKind::I8 => AbiType::Int(8), + ir::TypeKind::I16 => AbiType::Int(16), + ir::TypeKind::I32 => AbiType::Int(32), + ir::TypeKind::I64 => AbiType::Int(64), + ir::TypeKind::I128 => AbiType::Int(128), + ir::TypeKind::I256 => AbiType::Int(256), + ir::TypeKind::U8 => AbiType::UInt(8), + ir::TypeKind::U16 => AbiType::UInt(16), + ir::TypeKind::U32 => AbiType::UInt(32), + ir::TypeKind::U64 => AbiType::UInt(64), + ir::TypeKind::U128 => AbiType::UInt(128), + ir::TypeKind::U256 => AbiType::UInt(256), + ir::TypeKind::Bool => AbiType::Bool, + ir::TypeKind::Address => AbiType::Address, + ir::TypeKind::String(_) => AbiType::String, + ir::TypeKind::Unit => unreachable!("zero-sized type must be removed in legalization"), + ir::TypeKind::Array(def) => { + let elem_ty_data = &def.elem_ty.data(db.upcast()); + match &elem_ty_data.kind { + ir::TypeKind::U8 => AbiType::Bytes, + _ => { + let elem_ty = db.codegen_abi_type(def.elem_ty); + let len = def.len; + AbiType::Array { + elem_ty: elem_ty.into(), + len, + } + } + } + } + ir::TypeKind::Tuple(def) => { + let fields = def + .items + .iter() + .enumerate() + .map(|(i, item)| { + let field_ty = db.codegen_abi_type(*item); + AbiTupleField::new(format!("{i}"), field_ty) + }) + .collect(); + + AbiType::Tuple(fields) + } + ir::TypeKind::Struct(def) => { + let fields = def + .fields + .iter() + .map(|(name, ty)| { + let ty = db.codegen_abi_type(*ty); + AbiTupleField::new(name.to_string(), ty) + }) + .collect(); + + AbiType::Tuple(fields) + } + ir::TypeKind::MPtr(inner) => db.codegen_abi_type(*inner), + + ir::TypeKind::Contract(_) + | ir::TypeKind::Map(_) + | ir::TypeKind::Enum(_) + | ir::TypeKind::SPtr(_) => unreachable!(), + } +} + +#[salsa::tracked(return_ref)] +pub fn abi_event(db: &dyn CodegenDb, ty: TypeId) -> AbiEvent { + debug_assert!(ty.is_struct(db.upcast())); + + let legalized_ty = db.codegen_legalized_type(ty); + let analyzer_struct = ty + .analyzer_ty(db.upcast()) + .and_then(|val| val.as_struct(db.upcast())) + .unwrap(); + let legalized_ty_data = legalized_ty.data(db.upcast()); + let event_def = match &legalized_ty_data.kind { + ir::TypeKind::Struct(def) => def, + _ => unreachable!(), + }; + + let fields = event_def + .fields + .iter() + .map(|(name, ty)| { + let attr = analyzer_struct + .field(db.upcast(), name) + .unwrap() + .attributes(db.upcast()); + + let ty = db.codegen_abi_type(*ty); + let indexed = attr.iter().any(|attr| attr == INDEXED); + AbiEventField::new(name.to_string(), ty, indexed) + }) + .collect(); + + AbiEvent::new(event_def.name.to_string(), fields, false) +} + +fn ceil_32(value: usize) -> usize { + ((value + 31) / 32) * 32 +} diff --git a/crates/codegen2/src/constant.rs b/crates/codegen2/src/constant.rs new file mode 100644 index 0000000000..b848084e8b --- /dev/null +++ b/crates/codegen2/src/constant.rs @@ -0,0 +1,12 @@ +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + rc::Rc, +}; + +#[salsa::tracked(return_ref)] +pub fn string_symbol_name(_db: &dyn CodegenDb, data: String) -> Rc { + let mut hasher = DefaultHasher::new(); + data.hash(&mut hasher); + format! {"{}", hasher.finish()}.into() +} diff --git a/crates/codegen2/src/contract.rs b/crates/codegen2/src/contract.rs new file mode 100644 index 0000000000..b39fc140b0 --- /dev/null +++ b/crates/codegen2/src/contract.rs @@ -0,0 +1,18 @@ +use std::rc::Rc; + +#[salsa::tracked(return_ref)] +pub fn symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { + let module = contract.module(db.upcast()); + + format!( + "{}${}", + module.name(db.upcast()), + contract.name(db.upcast()) + ) + .into() +} + +#[salsa::tracked(return_ref)] +pub fn deployer_symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { + format!("deploy_{}", symbol_name(db, contract).as_ref()).into() +} diff --git a/crates/codegen2/src/function.rs b/crates/codegen2/src/function.rs new file mode 100644 index 0000000000..89f91c799e --- /dev/null +++ b/crates/codegen2/src/function.rs @@ -0,0 +1,74 @@ +use std::rc::Rc; + +use fe_mir::ir::{FunctionBody, FunctionId, FunctionSignature}; +use hir::hir_def::TypeId; +use smol_str::SmolStr; + +use crate::CodegenDb; + +#[salsa::tracked(return_ref)] +pub fn legalized_signature(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let mut sig = function.signature(db.upcast()).as_ref().clone(); + db.legalize_func_signature(&mut sig); + sig.into() +} + +#[salsa::tracked(return_ref)] +pub fn legalized_body(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let mut body = function.body(db.upcast()).as_ref().clone(); + db.legalize_func_body(&mut body); + body.into() +} + +#[salsa::tracked(return_ref)] +pub fn symbol_name(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let module = function.signature(db.upcast()).module_id; + let module_name = module.name(db.upcast()); + + let analyzer_func = function.analyzer_func(db.upcast()); + let func_name = format!( + "{}{}", + analyzer_func.name(db.upcast()), + type_suffix(function, db) + ); + + // let func_name = match analyzer_func.sig(db.upcast()).self_item(db.upcast()) { + // Some(Item::Impl(id)) => { + // let class_name = format!( + // "{}${}", + // id.trait_id(db.upcast()).name(db.upcast()), + // safe_name(db, id.receiver(db.upcast())) + // ); + // format!("{class_name}${func_name}") + // } + // Some(class) => { + // let class_name = class.name(db.upcast()); + // format!("{class_name}${func_name}") + // } + // _ => func_name, + // }; + + // format!("{module_name}${func_name}").into() + "".into() +} + +fn type_suffix(function: FunctionId, db: &dyn CodegenDb) -> SmolStr { + function + .signature(db.upcast()) + .resolved_generics + .values() + .fold(String::new(), |acc, param| { + format!("{}_{}", acc, safe_name(db, *param)) + }) + .into() +} + +fn safe_name(db: &dyn CodegenDb, ty: TypeId) -> SmolStr { + // match ty.typ(db.upcast()) { + // // TODO: Would be nice to get more human friendly names here + // Type::Array(_) => format!("array_{:?}", ty.as_intern_id()).into(), + // Type::Tuple(_) => format!("tuple_{:?}", ty.as_intern_id()).into(), + // _ => format!("{}", ty.display(db.upcast())).into(), + // } + "".into() +} diff --git a/crates/codegen2/src/lib.rs b/crates/codegen2/src/lib.rs new file mode 100644 index 0000000000..d1dbe89e9f --- /dev/null +++ b/crates/codegen2/src/lib.rs @@ -0,0 +1,96 @@ +use mir::MirDb; + +pub mod yul; + +// mod abi; +// mod constant; +// mod contract; +mod function; +// mod types; + +#[salsa::jar(db = CodegenDb)] +pub struct Jar( + function::legalized_signature, + function::legalized_body, + function::symbol_name, + // types::legalized_type, + // abi::abi_type, + // abi::abi_function, + // abi::abi_event, + // abi::abi_contract, + // abi::abi_type_maximum_size, + // abi::abi_type_minimum_size, + // abi::abi_function_argument_maximum_size, + // abi::abi_function_return_maximum_size, + // contract::symbol_name, + // contract::deployer_symbol_name, + // constant::string_symbol_name, +); + +pub trait CodegenDb: salsa::DbWithJar + MirDb { + fn as_hir_db(&self) -> &dyn CodegenDb { + >::as_jar_db::<'_>(self) + } +} +impl CodegenDb for DB where DB: salsa::DbWithJar + MirDb {} + +// #![allow(clippy::arc_with_non_send_sync)] +// use std::rc::Rc; + +// use fe_abi::{contract::AbiContract, event::AbiEvent, function::AbiFunction, types::AbiType}; +// use fe_analyzer::{db::AnalyzerDbStorage, namespace::items::ContractId, AnalyzerDb}; +// use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; +// use fe_mir::{ +// db::{MirDb, MirDbStorage}, +// ir::{FunctionBody, FunctionId, FunctionSignature, TypeId}, +// }; + +// mod queries; + +// #[salsa::query_group(CodegenDbStorage)] +// pub trait CodegenDb: MirDb + Upcast + UpcastMut { +// } + +// // TODO: Move this to driver. +// #[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage, CodegenDbStorage)] +// #[derive(Default)] +// pub struct Db { +// storage: salsa::Storage, +// } +// impl salsa::Database for Db {} + +// impl Upcast for Db { +// fn upcast(&self) -> &(dyn MirDb + 'static) { +// self +// } +// } + +// impl UpcastMut for Db { +// fn upcast_mut(&mut self) -> &mut (dyn MirDb + 'static) { +// &mut *self +// } +// } + +// impl Upcast for Db { +// fn upcast(&self) -> &(dyn SourceDb + 'static) { +// self +// } +// } + +// impl UpcastMut for Db { +// fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { +// &mut *self +// } +// } + +// impl Upcast for Db { +// fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { +// self +// } +// } + +// impl UpcastMut for Db { +// fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { +// &mut *self +// } +// } diff --git a/crates/codegen2/src/types.rs b/crates/codegen2/src/types.rs new file mode 100644 index 0000000000..6789857719 --- /dev/null +++ b/crates/codegen2/src/types.rs @@ -0,0 +1,101 @@ +use fe_mir::ir::{ + types::{ArrayDef, MapDef, StructDef, TupleDef}, + Type, TypeId, TypeKind, +}; + +#[salsa::tracked(return_ref)] +pub fn legalized_type(db: &dyn CodegenDb, ty: TypeId) -> TypeId { + let ty_data = ty.data(db.upcast()); + let ty_kind = match &ty.data(db.upcast()).kind { + TypeKind::Tuple(def) => { + let items = def + .items + .iter() + .filter_map(|item| { + if item.is_zero_sized(db.upcast()) { + None + } else { + Some(legalized_type(db, *item)) + } + }) + .collect(); + let new_def = TupleDef { items }; + TypeKind::Tuple(new_def) + } + + TypeKind::Array(def) => { + let new_def = ArrayDef { + elem_ty: legalized_type(db, def.elem_ty), + len: def.len, + }; + TypeKind::Array(new_def) + } + + TypeKind::Struct(def) => { + let fields = def + .fields + .iter() + .cloned() + .filter_map(|(name, ty)| { + if ty.is_zero_sized(db.upcast()) { + None + } else { + Some((name, legalized_type(db, ty))) + } + }) + .collect(); + let new_def = StructDef { + name: def.name.clone(), + fields, + span: def.span, + module_id: def.module_id, + }; + TypeKind::Struct(new_def) + } + + TypeKind::Contract(def) => { + let fields = def + .fields + .iter() + .cloned() + .filter_map(|(name, ty)| { + if ty.is_zero_sized(db.upcast()) { + None + } else { + Some((name, legalized_type(db, ty))) + } + }) + .collect(); + let new_def = StructDef { + name: def.name.clone(), + fields, + span: def.span, + module_id: def.module_id, + }; + TypeKind::Contract(new_def) + } + + TypeKind::Map(def) => { + let new_def = MapDef { + key_ty: legalized_type(db, def.key_ty), + value_ty: legalized_type(db, def.value_ty), + }; + TypeKind::Map(new_def) + } + + TypeKind::MPtr(ty) => { + let new_ty = legalized_type(db, *ty); + TypeKind::MPtr(new_ty) + } + + TypeKind::SPtr(ty) => { + let new_ty = legalized_type(db, *ty); + TypeKind::SPtr(new_ty) + } + + _ => return ty, + }; + + let analyzer_ty = ty_data.analyzer_ty; + db.mir_intern_type(Type::new(ty_kind, analyzer_ty).into()) +} diff --git a/crates/codegen2/src/yul/isel/context.rs b/crates/codegen2/src/yul/isel/context.rs new file mode 100644 index 0000000000..be2748e92e --- /dev/null +++ b/crates/codegen2/src/yul/isel/context.rs @@ -0,0 +1,81 @@ +use hir::hir_def::Contract; +use indexmap::IndexSet; + +use fe_mir::ir::FunctionId; +use fxhash::FxHashSet; +use yultsur::yul; + +use crate::{ + yul::runtime::{DefaultRuntimeProvider, RuntimeProvider}, + CodegenDb, +}; + +use super::{lower_contract_deployable, lower_function}; + +pub struct Context { + pub runtime: Box, + pub(super) contract_dependency: IndexSet, + pub(super) function_dependency: IndexSet, + pub(super) string_constants: IndexSet, + pub(super) lowered_functions: FxHashSet, +} + +// Currently, `clippy::derivable_impls` causes false positive result, +// see https://github.com/rust-lang/rust-clippy/issues/10158 for more details. +#[allow(clippy::derivable_impls)] +impl Default for Context { + fn default() -> Self { + Self { + runtime: Box::::default(), + contract_dependency: IndexSet::default(), + function_dependency: IndexSet::default(), + string_constants: IndexSet::default(), + lowered_functions: FxHashSet::default(), + } + } +} + +impl Context { + pub(super) fn resolve_function_dependency( + &mut self, + db: &dyn CodegenDb, + ) -> Vec { + let mut funcs = vec![]; + loop { + let dependencies = std::mem::take(&mut self.function_dependency); + if dependencies.is_empty() { + break; + } + for dependency in dependencies { + if self.lowered_functions.contains(&dependency) { + // Ignore dependency if it's already lowered. + continue; + } else { + funcs.push(lower_function(db, self, dependency)) + } + } + } + + funcs + } + + pub(super) fn resolve_constant_dependency(&self, db: &dyn CodegenDb) -> Vec { + self.string_constants + .iter() + .map(|s| { + let symbol = db.codegen_constant_string_symbol_name(s.to_string()); + yul::Data { + name: symbol.as_ref().clone(), + value: s.to_string(), + } + }) + .collect() + } + + pub(super) fn resolve_contract_dependency(&self, db: &dyn CodegenDb) -> Vec { + self.contract_dependency + .iter() + .map(|cid| lower_contract_deployable(db, *cid)) + .collect() + } +} diff --git a/crates/codegen2/src/yul/isel/contract.rs b/crates/codegen2/src/yul/isel/contract.rs new file mode 100644 index 0000000000..e218404d55 --- /dev/null +++ b/crates/codegen2/src/yul/isel/contract.rs @@ -0,0 +1,289 @@ +use fe_mir::ir::{function::Linkage, FunctionId}; +use hir::hir_def::Contract; +use yultsur::{yul, *}; + +use crate::{ + yul::{runtime::AbiSrcLocation, YulVariable}, + CodegenDb, +}; + +use super::context::Context; + +pub fn lower_contract_deployable(db: &dyn CodegenDb, contract: Contract) -> yul::Object { + let mut context = Context::default(); + + let constructor = if let Some(init) = contract.init_function(db.upcast()) { + let init = db.mir_lowered_func_signature(init); + make_init(db, &mut context, contract, init) + } else { + statements! {} + }; + + let deploy_code = make_deploy(db, contract); + + let dep_functions: Vec<_> = context + .resolve_function_dependency(db) + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let runtime_funcs: Vec<_> = context + .runtime + .collect_definitions() + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + + let deploy_block = block_statement! { + [constructor...] + [deploy_code...] + }; + + let code = code! { + [deploy_block] + [dep_functions...] + [runtime_funcs...] + }; + + let mut dep_contracts = context.resolve_contract_dependency(db); + dep_contracts.push(lower_contract(db, contract)); + let dep_constants = context.resolve_constant_dependency(db); + + let name = identifier! {( + db.codegen_contract_deployer_symbol_name(contract).as_ref() + )}; + let object = yul::Object { + name, + code, + objects: dep_contracts, + data: dep_constants, + }; + + normalize_object(object) +} + +pub fn lower_contract(db: &dyn CodegenDb, contract: Contract) -> yul::Object { + let exported_funcs: Vec<_> = db + .mir_lower_contract_all_functions(contract) + .iter() + .filter_map(|fid| { + if fid.signature(db.upcast()).linkage == Linkage::Export { + Some(*fid) + } else { + None + } + }) + .collect(); + + let mut context = Context::default(); + let dispatcher = if let Some(call_fn) = contract.call_function(db.upcast()) { + let call_fn = db.mir_lowered_func_signature(call_fn); + context.function_dependency.insert(call_fn); + let call_symbol = identifier! { (db.codegen_function_symbol_name(call_fn)) }; + statement! { + ([call_symbol]()) + } + } else { + make_dispatcher(db, &mut context, &exported_funcs) + }; + + let dep_functions: Vec<_> = context + .resolve_function_dependency(db) + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let runtime_funcs: Vec<_> = context + .runtime + .collect_definitions() + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + + let code = code! { + ([dispatcher]) + [dep_functions...] + [runtime_funcs...] + }; + + // Lower dependant contracts. + let dep_contracts = context.resolve_contract_dependency(db); + + // Collect string constants. + let dep_constants = context.resolve_constant_dependency(db); + let contract_symbol = identifier! { (db.codegen_contract_symbol_name(contract)) }; + + yul::Object { + name: contract_symbol, + code, + objects: dep_contracts, + data: dep_constants, + } +} + +fn make_dispatcher( + db: &dyn CodegenDb, + context: &mut Context, + funcs: &[FunctionId], +) -> yul::Statement { + let arms = funcs + .iter() + .map(|func| dispatch_arm(db, context, *func)) + .collect::>(); + + if arms.is_empty() { + statement! { return(0, 0) } + } else { + let selector = expression! { + and((shr((sub(256, 32)), (calldataload(0)))), 0xffffffff) + }; + switch! { + switch ([selector]) + [arms...] + (default { (return(0, 0)) }) + } + } +} + +fn dispatch_arm(db: &dyn CodegenDb, context: &mut Context, func: FunctionId) -> yul::Case { + context.function_dependency.insert(func); + let func_sig = db.codegen_legalized_signature(func); + let mut param_vars = Vec::with_capacity(func_sig.params.len()); + let mut param_tys = Vec::with_capacity(func_sig.params.len()); + func_sig.params.iter().for_each(|param| { + param_vars.push(YulVariable::new(param.name.as_str())); + param_tys.push(param.ty); + }); + + let decode_params = if func_sig.params.is_empty() { + statements! {} + } else { + let ident_params: Vec<_> = param_vars.iter().map(YulVariable::ident).collect(); + let param_size = YulVariable::new("param_size"); + statements! { + (let [param_size.ident()] := sub((calldatasize()), 4)) + (let [ident_params...] := [context.runtime.abi_decode(db, expression! { 4 }, param_size.expr(), ¶m_tys, AbiSrcLocation::CallData)]) + } + }; + + let call_and_encode_return = { + let name = identifier! { (db.codegen_function_symbol_name(func)) }; + // we pass in a `0` for the expected `Context` argument + let call = expression! {[name]([(param_vars.iter().map(YulVariable::expr).collect::>())...])}; + if let Some(mut return_type) = func_sig.return_type { + if return_type.is_aggregate(db.upcast()) { + return_type = return_type.make_mptr(db.upcast()); + } + + let ret = YulVariable::new("ret"); + let enc_start = YulVariable::new("enc_start"); + let enc_size = YulVariable::new("enc_size"); + let abi_encode = context.runtime.abi_encode_seq( + db, + &[ret.expr()], + enc_start.expr(), + &[return_type], + false, + ); + statements! { + (let [ret.ident()] := [call]) + (let [enc_start.ident()] := [context.runtime.avail(db)]) + (let [enc_size.ident()] := [abi_encode]) + (return([enc_start.expr()], [enc_size.expr()])) + } + } else { + statements! { + ([yul::Statement::Expression(call)]) + (return(0, 0)) + } + } + }; + + let abi_sig = db.codegen_abi_function(func); + let selector = literal! { (format!("0x{}", abi_sig.selector().hex())) }; + case! { + case [selector] { + [decode_params...] + [call_and_encode_return...] + } + } +} + +fn make_init( + db: &dyn CodegenDb, + context: &mut Context, + contract: Contract, + init: FunctionId, +) -> Vec { + context.function_dependency.insert(init); + let init_func_name = identifier! { (db.codegen_function_symbol_name(init)) }; + let contract_name = identifier_expression! { (format!{r#""{}""#, db.codegen_contract_deployer_symbol_name(contract)}) }; + + let func_sig = db.codegen_legalized_signature(init); + let mut param_vars = Vec::with_capacity(func_sig.params.len()); + let mut param_tys = Vec::with_capacity(func_sig.params.len()); + let program_size = YulVariable::new("$program_size"); + let arg_size = YulVariable::new("$arg_size"); + let code_size = YulVariable::new("$code_size"); + let memory_data_offset = YulVariable::new("$memory_data_offset"); + func_sig.params.iter().for_each(|param| { + param_vars.push(YulVariable::new(param.name.as_str())); + param_tys.push(param.ty); + }); + + let decode_params = if func_sig.params.is_empty() { + statements! {} + } else { + let ident_params: Vec<_> = param_vars.iter().map(YulVariable::ident).collect(); + statements! { + (let [ident_params...] := [context.runtime.abi_decode(db, memory_data_offset.expr(), arg_size.expr(), ¶m_tys, AbiSrcLocation::Memory)]) + } + }; + + let call = expression! {[init_func_name]([(param_vars.iter().map(YulVariable::expr).collect::>())...])}; + statements! { + (let [program_size.ident()] := datasize([contract_name])) + (let [code_size.ident()] := codesize()) + (let [arg_size.ident()] := sub([code_size.expr()], [program_size.expr()])) + (let [memory_data_offset.ident()] := [context.runtime.alloc(db, arg_size.expr())]) + (codecopy([memory_data_offset.expr()], [program_size.expr()], [arg_size.expr()])) + [decode_params...] + ([yul::Statement::Expression(call)]) + } +} + +fn make_deploy(db: &dyn CodegenDb, contract: Contract) -> Vec { + let contract_symbol = + identifier_expression! { (format!{r#""{}""#, db.codegen_contract_symbol_name(contract)}) }; + let size = YulVariable::new("$$size"); + statements! { + (let [size.ident()] := (datasize([contract_symbol.clone()]))) + (datacopy(0, (dataoffset([contract_symbol])), [size.expr()])) + (return (0, [size.expr()])) + } +} + +fn normalize_object(obj: yul::Object) -> yul::Object { + let data = obj + .data + .into_iter() + .map(|data| yul::Data { + name: data.name, + value: data + .value + .replace('\\', "\\\\\\\\") + .replace('\n', "\\\\n") + .replace('"', "\\\\\"") + .replace('\r', "\\\\r") + .replace('\t', "\\\\t"), + }) + .collect::>(); + yul::Object { + name: obj.name, + code: obj.code, + objects: obj + .objects + .into_iter() + .map(normalize_object) + .collect::>(), + data, + } +} diff --git a/crates/codegen2/src/yul/isel/function.rs b/crates/codegen2/src/yul/isel/function.rs new file mode 100644 index 0000000000..b7adb5ae24 --- /dev/null +++ b/crates/codegen2/src/yul/isel/function.rs @@ -0,0 +1,978 @@ +#![allow(unused)] +use std::thread::Scope; + +use super::{context::Context, inst_order::InstSerializer}; +use fe_common::numeric::to_hex_str; + +use fe_abi::function::{AbiFunction, AbiFunctionType}; +use fe_common::db::Upcast; +use fe_mir::{ + ir::{ + self, + constant::ConstantValue, + inst::{BinOp, CallType, CastKind, InstKind, UnOp}, + value::AssignableValue, + Constant, FunctionBody, FunctionId, FunctionSignature, InstId, Type, TypeId, TypeKind, + Value, ValueId, + }, + pretty_print::PrettyPrint, +}; +use fxhash::FxHashMap; +use smol_str::SmolStr; +use yultsur::{ + yul::{self, Statement}, + *, +}; + +use crate::{ + yul::{ + isel::inst_order::StructuralInst, + runtime::{self, RuntimeProvider}, + slot_size::{function_hash_type, yul_primitive_type, SLOT_SIZE}, + YulVariable, + }, + CodegenDb, +}; + +pub fn lower_function( + db: &dyn CodegenDb, + ctx: &mut Context, + function: FunctionId, +) -> yul::FunctionDefinition { + debug_assert!(!ctx.lowered_functions.contains(&function)); + ctx.lowered_functions.insert(function); + let sig = &db.codegen_legalized_signature(function); + let body = &db.codegen_legalized_body(function); + FuncLowerHelper::new(db, ctx, function, sig, body).lower_func() +} + +struct FuncLowerHelper<'db, 'a> { + db: &'db dyn CodegenDb, + ctx: &'a mut Context, + value_map: ScopedValueMap, + func: FunctionId, + sig: &'a FunctionSignature, + body: &'a FunctionBody, + ret_value: Option, + sink: Vec, +} + +impl<'db, 'a> FuncLowerHelper<'db, 'a> { + fn new( + db: &'db dyn CodegenDb, + ctx: &'a mut Context, + func: FunctionId, + sig: &'a FunctionSignature, + body: &'a FunctionBody, + ) -> Self { + let mut value_map = ScopedValueMap::default(); + // Register arguments to value_map. + for &value in body.store.locals() { + match body.store.value_data(value) { + Value::Local(local) if local.is_arg => { + let ident = YulVariable::new(local.name.as_str()).ident(); + value_map.insert(value, ident); + } + _ => {} + } + } + + let ret_value = if sig.return_type.is_some() { + Some(YulVariable::new("$ret").ident()) + } else { + None + }; + + Self { + db, + ctx, + value_map, + func, + sig, + body, + ret_value, + sink: Vec::new(), + } + } + + fn lower_func(mut self) -> yul::FunctionDefinition { + let name = identifier! { (self.db.codegen_function_symbol_name(self.func)) }; + + let parameters = self + .sig + .params + .iter() + .map(|param| YulVariable::new(param.name.as_str()).ident()) + .collect(); + + let ret = self + .ret_value + .clone() + .map(|value| vec![value]) + .unwrap_or_default(); + + let body = self.lower_body(); + + yul::FunctionDefinition { + name, + parameters, + returns: ret, + block: body, + } + } + + fn lower_body(mut self) -> yul::Block { + let inst_order = InstSerializer::new(self.body).serialize(); + + for inst in inst_order { + self.lower_structural_inst(inst) + } + + yul::Block { + statements: self.sink, + } + } + + fn lower_structural_inst(&mut self, inst: StructuralInst) { + match inst { + StructuralInst::Inst(inst) => self.lower_inst(inst), + StructuralInst::If { cond, then, else_ } => { + let if_block = self.lower_if(cond, then, else_); + self.sink.push(if_block) + } + StructuralInst::Switch { + scrutinee, + table, + default, + } => { + let switch_block = self.lower_switch(scrutinee, table, default); + self.sink.push(switch_block) + } + StructuralInst::For { body } => { + let for_block = self.lower_for(body); + self.sink.push(for_block) + } + StructuralInst::Break => self.sink.push(yul::Statement::Break), + StructuralInst::Continue => self.sink.push(yul::Statement::Continue), + }; + } + + fn lower_inst(&mut self, inst: InstId) { + if let Some(lhs) = self.body.store.inst_result(inst) { + self.declare_assignable_value(lhs) + } + + match &self.body.store.inst_data(inst).kind { + InstKind::Declare { local } => self.declare_value(*local), + + InstKind::Unary { op, value } => { + let inst_result = self.body.store.inst_result(inst).unwrap(); + let inst_result_ty = inst_result.ty(self.db.upcast(), &self.body.store); + let result = self.lower_unary(*op, *value); + self.assign_inst_result(inst, result, inst_result_ty.deref(self.db.upcast())) + } + + InstKind::Binary { op, lhs, rhs } => { + let inst_result = self.body.store.inst_result(inst).unwrap(); + let inst_result_ty = inst_result.ty(self.db.upcast(), &self.body.store); + let result = self.lower_binary(*op, *lhs, *rhs, inst); + self.assign_inst_result(inst, result, inst_result_ty.deref(self.db.upcast())) + } + + InstKind::Cast { kind, value, to } => { + let from_ty = self.body.store.value_ty(*value); + let result = match kind { + CastKind::Primitive => { + debug_assert!( + from_ty.is_primitive(self.db.upcast()) + && to.is_primitive(self.db.upcast()) + ); + let value = self.value_expr(*value); + self.ctx.runtime.primitive_cast(self.db, value, from_ty) + } + CastKind::Untag => { + let from_ty = from_ty.deref(self.db.upcast()); + debug_assert!(from_ty.is_enum(self.db.upcast())); + let value = self.value_expr(*value); + let offset = literal_expression! {(from_ty.enum_data_offset(self.db.upcast(), SLOT_SIZE))}; + expression! {add([value], [offset])} + } + }; + + self.assign_inst_result(inst, result, *to) + } + + InstKind::AggregateConstruct { ty, args } => { + let lhs = self.body.store.inst_result(inst).unwrap(); + let ptr = self.lower_assignable_value(lhs); + let ptr_ty = lhs.ty(self.db.upcast(), &self.body.store); + let arg_values = args.iter().map(|arg| self.value_expr(*arg)).collect(); + let arg_tys = args + .iter() + .map(|arg| self.body.store.value_ty(*arg)) + .collect(); + self.sink.push(yul::Statement::Expression( + self.ctx + .runtime + .aggregate_init(self.db, ptr, arg_values, ptr_ty, arg_tys), + )) + } + + InstKind::Bind { src } => { + match self.body.store.value_data(*src) { + Value::Constant { constant, .. } => { + // Need special handling when rhs is the string literal because it needs ptr + // copy. + if let ConstantValue::Str(s) = &constant.data(self.db.upcast()).value { + self.ctx.string_constants.insert(s.to_string()); + let size = self.value_ty_size_deref(*src); + let lhs = self.body.store.inst_result(inst).unwrap(); + let ptr = self.lower_assignable_value(lhs); + let inst_result_ty = lhs.ty(self.db.upcast(), &self.body.store); + self.sink.push(yul::Statement::Expression( + self.ctx.runtime.string_copy( + self.db, + ptr, + s, + inst_result_ty.is_sptr(self.db.upcast()), + ), + )) + } else { + let src_ty = self.body.store.value_ty(*src); + let src = self.value_expr(*src); + self.assign_inst_result(inst, src, src_ty) + } + } + _ => { + let src_ty = self.body.store.value_ty(*src); + let src = self.value_expr(*src); + self.assign_inst_result(inst, src, src_ty) + } + } + } + + InstKind::MemCopy { src } => { + let lhs = self.body.store.inst_result(inst).unwrap(); + let dst_ptr = self.lower_assignable_value(lhs); + let dst_ptr_ty = lhs.ty(self.db.upcast(), &self.body.store); + let src_ptr = self.value_expr(*src); + let src_ptr_ty = self.body.store.value_ty(*src); + let ty_size = literal_expression! { (self.value_ty_size_deref(*src)) }; + self.sink + .push(yul::Statement::Expression(self.ctx.runtime.ptr_copy( + self.db, + src_ptr, + dst_ptr, + ty_size, + src_ptr_ty.is_sptr(self.db.upcast()), + dst_ptr_ty.is_sptr(self.db.upcast()), + ))) + } + + InstKind::Load { src } => { + let src_ty = self.body.store.value_ty(*src); + let src = self.value_expr(*src); + debug_assert!(src_ty.is_ptr(self.db.upcast())); + + let result = self.body.store.inst_result(inst).unwrap(); + debug_assert!(!result + .ty(self.db.upcast(), &self.body.store) + .is_ptr(self.db.upcast())); + self.assign_inst_result(inst, src, src_ty) + } + + InstKind::AggregateAccess { value, indices } => { + let base = self.value_expr(*value); + let mut ptr = base; + let mut inner_ty = self.body.store.value_ty(*value); + for &idx in indices { + ptr = self.aggregate_elem_ptr(ptr, idx, inner_ty.deref(self.db.upcast())); + inner_ty = + inner_ty.projection_ty(self.db.upcast(), self.body.store.value_data(idx)); + } + + let result = self.body.store.inst_result(inst).unwrap(); + self.assign_inst_result(inst, ptr, inner_ty) + } + + InstKind::MapAccess { value, key } => { + let map_ty = self.body.store.value_ty(*value).deref(self.db.upcast()); + let value_expr = self.value_expr(*value); + let key_expr = self.value_expr(*key); + let key_ty = self.body.store.value_ty(*key); + let ptr = self + .ctx + .runtime + .map_value_ptr(self.db, value_expr, key_expr, key_ty); + let value_ty = match &map_ty.data(self.db.upcast()).kind { + TypeKind::Map(def) => def.value_ty, + _ => unreachable!(), + }; + + self.assign_inst_result(inst, ptr, value_ty.make_sptr(self.db.upcast())); + } + + InstKind::Call { + func, + args, + call_type, + } => { + let args: Vec<_> = args.iter().map(|arg| self.value_expr(*arg)).collect(); + let result = match call_type { + CallType::Internal => { + self.ctx.function_dependency.insert(*func); + let func_name = identifier! {(self.db.codegen_function_symbol_name(*func))}; + expression! {[func_name]([args...])} + } + CallType::External => self.ctx.runtime.external_call(self.db, *func, args), + }; + match self.db.codegen_legalized_signature(*func).return_type { + Some(mut result_ty) => { + if result_ty.is_aggregate(self.db.upcast()) + | result_ty.is_string(self.db.upcast()) + { + result_ty = result_ty.make_mptr(self.db.upcast()); + } + self.assign_inst_result(inst, result, result_ty) + } + _ => self.sink.push(Statement::Expression(result)), + } + } + + InstKind::Revert { arg } => match arg { + Some(arg) => { + let arg_ty = self.body.store.value_ty(*arg); + let deref_ty = arg_ty.deref(self.db.upcast()); + let ty_data = deref_ty.data(self.db.upcast()); + let arg_expr = if deref_ty.is_zero_sized(self.db.upcast()) { + None + } else { + Some(self.value_expr(*arg)) + }; + let name = match &ty_data.kind { + ir::TypeKind::Struct(def) => &def.name, + ir::TypeKind::String(def) => "Error", + _ => "Panic", + }; + self.sink.push(yul::Statement::Expression( + self.ctx.runtime.revert(self.db, arg_expr, name, arg_ty), + )); + } + None => self.sink.push(statement! {revert(0, 0)}), + }, + + InstKind::Emit { arg } => { + let event = self.value_expr(*arg); + let event_ty = self.body.store.value_ty(*arg); + let result = self.ctx.runtime.emit(self.db, event, event_ty); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty); + } + + InstKind::Return { arg } => { + if let Some(arg) = arg { + let arg = self.value_expr(*arg); + let ret_value = self.ret_value.clone().unwrap(); + self.sink.push(statement! {[ret_value] := [arg]}); + } + self.sink.push(yul::Statement::Leave) + } + + InstKind::Keccak256 { arg } => { + let result = self.keccak256(*arg); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty); + } + + InstKind::AbiEncode { arg } => { + let lhs = self.body.store.inst_result(inst).unwrap(); + let ptr = self.lower_assignable_value(lhs); + let ptr_ty = lhs.ty(self.db.upcast(), &self.body.store); + let src_expr = self.value_expr(*arg); + let src_ty = self.body.store.value_ty(*arg); + + let abi_encode = self.ctx.runtime.abi_encode( + self.db, + src_expr, + ptr, + src_ty, + ptr_ty.is_sptr(self.db.upcast()), + ); + self.sink.push(statement! { + pop([abi_encode]) + }); + } + + InstKind::Create { value, contract } => { + self.ctx.contract_dependency.insert(*contract); + + let value_expr = self.value_expr(*value); + let result = self.ctx.runtime.create(self.db, *contract, value_expr); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty) + } + + InstKind::Create2 { + value, + salt, + contract, + } => { + self.ctx.contract_dependency.insert(*contract); + + let value_expr = self.value_expr(*value); + let salt_expr = self.value_expr(*salt); + let result = self + .ctx + .runtime + .create2(self.db, *contract, value_expr, salt_expr); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty) + } + + InstKind::YulIntrinsic { op, args } => { + let args: Vec<_> = args.iter().map(|arg| self.value_expr(*arg)).collect(); + let op_name = identifier! { (format!("{op}").strip_prefix("__").unwrap()) }; + let result = expression! { [op_name]([args...]) }; + // Intrinsic operation never returns ptr type, so we can use u256_ty as a dummy + // type for the result. + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty) + } + + InstKind::Nop => {} + + // These flow control instructions are already legalized. + InstKind::Jump { .. } | InstKind::Branch { .. } | InstKind::Switch { .. } => { + unreachable!() + } + } + } + + fn lower_if( + &mut self, + cond: ValueId, + then: Vec, + else_: Vec, + ) -> yul::Statement { + let cond = self.value_expr(cond); + + self.enter_scope(); + let then_body = self.lower_branch_body(then); + self.leave_scope(); + + self.enter_scope(); + let else_body = self.lower_branch_body(else_); + self.leave_scope(); + + switch! { + switch ([cond]) + (case 1 {[then_body...]}) + (case 0 {[else_body...]}) + } + } + + fn lower_switch( + &mut self, + scrutinee: ValueId, + table: Vec<(ValueId, Vec)>, + default: Option>, + ) -> yul::Statement { + let scrutinee = self.value_expr(scrutinee); + + let mut cases = vec![]; + for (value, insts) in table { + let value = self.value_expr(value); + let value = match value { + yul::Expression::Literal(lit) => lit, + _ => panic!("switch table values must be literal"), + }; + + self.enter_scope(); + let body = self.lower_branch_body(insts); + self.leave_scope(); + cases.push(yul::Case { + literal: Some(value), + block: block! { [body...] }, + }) + } + + if let Some(insts) = default { + let block = self.lower_branch_body(insts); + cases.push(case! { + default {[block...]} + }); + } + + switch! { + switch ([scrutinee]) + [cases...] + } + } + + fn lower_branch_body(&mut self, insts: Vec) -> Vec { + let mut body = vec![]; + std::mem::swap(&mut self.sink, &mut body); + for inst in insts { + self.lower_structural_inst(inst); + } + std::mem::swap(&mut self.sink, &mut body); + body + } + + fn lower_for(&mut self, body: Vec) -> yul::Statement { + let mut body_stmts = vec![]; + std::mem::swap(&mut self.sink, &mut body_stmts); + for inst in body { + self.lower_structural_inst(inst); + } + std::mem::swap(&mut self.sink, &mut body_stmts); + + block_statement! {( + for {} (1) {} + { + [body_stmts...] + } + )} + } + + fn lower_assign(&mut self, lhs: &AssignableValue, rhs: ValueId) -> yul::Statement { + match lhs { + AssignableValue::Value(value) => { + let lhs = self.value_ident(*value); + let rhs = self.value_expr(rhs); + statement! { [lhs] := [rhs] } + } + AssignableValue::Aggregate { .. } | AssignableValue::Map { .. } => { + let dst_ty = lhs.ty(self.db.upcast(), &self.body.store); + let src_ty = self.body.store.value_ty(rhs); + debug_assert_eq!( + dst_ty.deref(self.db.upcast()), + src_ty.deref(self.db.upcast()) + ); + + let dst = self.lower_assignable_value(lhs); + let src = self.value_expr(rhs); + + if src_ty.is_ptr(self.db.upcast()) { + let ty_size = literal_expression! { (self.value_ty_size_deref(rhs)) }; + + let expr = self.ctx.runtime.ptr_copy( + self.db, + src, + dst, + ty_size, + src_ty.is_sptr(self.db.upcast()), + dst_ty.is_sptr(self.db.upcast()), + ); + yul::Statement::Expression(expr) + } else { + let expr = self.ctx.runtime.ptr_store(self.db, dst, src, dst_ty); + yul::Statement::Expression(expr) + } + } + } + } + + fn lower_unary(&mut self, op: UnOp, value: ValueId) -> yul::Expression { + let value_expr = self.value_expr(value); + match op { + UnOp::Not => expression! { iszero([value_expr])}, + UnOp::Neg => { + let zero = literal_expression! {0}; + if self.body.store.value_data(value).is_imm() { + // Literals are checked at compile time (e.g. -128) so there's no point + // in adding a runtime check. + expression! {sub([zero], [value_expr])} + } else { + let value_ty = self.body.store.value_ty(value); + self.ctx + .runtime + .safe_sub(self.db, zero, value_expr, value_ty) + } + } + UnOp::Inv => expression! { not([value_expr])}, + } + } + + fn lower_binary( + &mut self, + op: BinOp, + lhs: ValueId, + rhs: ValueId, + inst: InstId, + ) -> yul::Expression { + let lhs_expr = self.value_expr(lhs); + let rhs_expr = self.value_expr(rhs); + let is_result_signed = self + .body + .store + .inst_result(inst) + .map(|val| { + let ty = val.ty(self.db.upcast(), &self.body.store); + ty.is_signed(self.db.upcast()) + }) + .unwrap_or(false); + let is_lhs_signed = self.body.store.value_ty(lhs).is_signed(self.db.upcast()); + + let inst_result = self.body.store.inst_result(inst).unwrap(); + let inst_result_ty = inst_result + .ty(self.db.upcast(), &self.body.store) + .deref(self.db.upcast()); + match op { + BinOp::Add => self + .ctx + .runtime + .safe_add(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Sub => self + .ctx + .runtime + .safe_sub(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Mul => self + .ctx + .runtime + .safe_mul(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Div => self + .ctx + .runtime + .safe_div(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Mod => self + .ctx + .runtime + .safe_mod(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Pow => self + .ctx + .runtime + .safe_pow(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Shl => expression! {shl([rhs_expr], [lhs_expr])}, + BinOp::Shr if is_result_signed => expression! {sar([rhs_expr], [lhs_expr])}, + BinOp::Shr => expression! {shr([rhs_expr], [lhs_expr])}, + BinOp::BitOr | BinOp::LogicalOr => expression! {or([lhs_expr], [rhs_expr])}, + BinOp::BitXor => expression! {xor([lhs_expr], [rhs_expr])}, + BinOp::BitAnd | BinOp::LogicalAnd => expression! {and([lhs_expr], [rhs_expr])}, + BinOp::Eq => expression! {eq([lhs_expr], [rhs_expr])}, + BinOp::Ne => expression! {iszero((eq([lhs_expr], [rhs_expr])))}, + BinOp::Ge if is_lhs_signed => expression! {iszero((slt([lhs_expr], [rhs_expr])))}, + BinOp::Ge => expression! {iszero((lt([lhs_expr], [rhs_expr])))}, + BinOp::Gt if is_lhs_signed => expression! {sgt([lhs_expr], [rhs_expr])}, + BinOp::Gt => expression! {gt([lhs_expr], [rhs_expr])}, + BinOp::Le if is_lhs_signed => expression! {iszero((sgt([lhs_expr], [rhs_expr])))}, + BinOp::Le => expression! {iszero((gt([lhs_expr], [rhs_expr])))}, + BinOp::Lt if is_lhs_signed => expression! {slt([lhs_expr], [rhs_expr])}, + BinOp::Lt => expression! {lt([lhs_expr], [rhs_expr])}, + } + } + + fn lower_cast(&mut self, value: ValueId, to: TypeId) -> yul::Expression { + let from_ty = self.body.store.value_ty(value); + debug_assert!(from_ty.is_primitive(self.db.upcast())); + debug_assert!(to.is_primitive(self.db.upcast())); + + let value = self.value_expr(value); + self.ctx.runtime.primitive_cast(self.db, value, from_ty) + } + + fn assign_inst_result(&mut self, inst: InstId, rhs: yul::Expression, rhs_ty: TypeId) { + // NOTE: We don't have `deref` feature yet, so need a heuristics for an + // assignment. + let stmt = if let Some(result) = self.body.store.inst_result(inst) { + let lhs = self.lower_assignable_value(result); + let lhs_ty = result.ty(self.db.upcast(), &self.body.store); + match result { + AssignableValue::Value(value) => { + match ( + lhs_ty.is_ptr(self.db.upcast()), + rhs_ty.is_ptr(self.db.upcast()), + ) { + (true, true) => { + if lhs_ty.is_mptr(self.db.upcast()) == rhs_ty.is_mptr(self.db.upcast()) + { + let rhs = self.extend_value(rhs, lhs_ty); + let lhs_ident = self.value_ident(*value); + statement! { [lhs_ident] := [rhs] } + } else { + let ty_size = rhs_ty + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE); + yul::Statement::Expression(self.ctx.runtime.ptr_copy( + self.db, + rhs, + lhs, + literal_expression! { (ty_size) }, + rhs_ty.is_sptr(self.db.upcast()), + lhs_ty.is_sptr(self.db.upcast()), + )) + } + } + (true, false) => yul::Statement::Expression( + self.ctx.runtime.ptr_store(self.db, lhs, rhs, lhs_ty), + ), + + (false, true) => { + let rhs = self.ctx.runtime.ptr_load(self.db, rhs, rhs_ty); + let rhs = self.extend_value(rhs, lhs_ty); + let lhs_ident = self.value_ident(*value); + statement! { [lhs_ident] := [rhs] } + } + (false, false) => { + let rhs = self.extend_value(rhs, lhs_ty); + let lhs_ident = self.value_ident(*value); + statement! { [lhs_ident] := [rhs] } + } + } + } + AssignableValue::Aggregate { .. } | AssignableValue::Map { .. } => { + let expr = if rhs_ty.is_ptr(self.db.upcast()) { + let ty_size = rhs_ty + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE); + self.ctx.runtime.ptr_copy( + self.db, + rhs, + lhs, + literal_expression! { (ty_size) }, + rhs_ty.is_sptr(self.db.upcast()), + lhs_ty.is_sptr(self.db.upcast()), + ) + } else { + self.ctx.runtime.ptr_store(self.db, lhs, rhs, lhs_ty) + }; + yul::Statement::Expression(expr) + } + } + } else { + yul::Statement::Expression(rhs) + }; + + self.sink.push(stmt); + } + + /// Extend a value to 256 bits. + fn extend_value(&mut self, value: yul::Expression, ty: TypeId) -> yul::Expression { + if ty.is_primitive(self.db.upcast()) { + self.ctx.runtime.primitive_cast(self.db, value, ty) + } else { + value + } + } + + fn declare_assignable_value(&mut self, value: &AssignableValue) { + match value { + AssignableValue::Value(value) if !self.value_map.contains(*value) => { + self.declare_value(*value); + } + _ => {} + } + } + + fn declare_value(&mut self, value: ValueId) { + let var = YulVariable::new(format!("$tmp_{}", value.index())); + self.value_map.insert(value, var.ident()); + let value_ty = self.body.store.value_ty(value); + + // Allocate memory for a value if a value is a pointer type. + let init = if value_ty.is_mptr(self.db.upcast()) { + let deref_ty = value_ty.deref(self.db.upcast()); + let ty_size = deref_ty.size_of(self.db.upcast(), SLOT_SIZE); + let size = literal_expression! { (ty_size) }; + Some(self.ctx.runtime.alloc(self.db, size)) + } else { + None + }; + + self.sink.push(yul::Statement::VariableDeclaration( + yul::VariableDeclaration { + identifiers: vec![var.ident()], + expression: init, + }, + )) + } + + fn value_expr(&mut self, value: ValueId) -> yul::Expression { + match self.body.store.value_data(value) { + Value::Local(_) | Value::Temporary { .. } => { + let ident = self.value_map.lookup(value).unwrap(); + literal_expression! {(ident)} + } + Value::Immediate { imm, .. } => { + literal_expression! {(imm)} + } + Value::Constant { constant, .. } => match &constant.data(self.db.upcast()).value { + ConstantValue::Immediate(imm) => { + // YUL does not support representing negative integers with leading minus (e.g. + // `-1` in YUL would lead to an ICE). To mitigate that we + // convert all numeric values into hexadecimal representation. + literal_expression! {(to_hex_str(imm))} + } + ConstantValue::Str(s) => { + self.ctx.string_constants.insert(s.to_string()); + self.ctx.runtime.string_construct(self.db, s, s.len()) + } + ConstantValue::Bool(true) => { + literal_expression! {1} + } + ConstantValue::Bool(false) => { + literal_expression! {0} + } + }, + Value::Unit { .. } => unreachable!(), + } + } + + fn value_ident(&self, value: ValueId) -> yul::Identifier { + self.value_map.lookup(value).unwrap().clone() + } + + fn make_tmp(&mut self, tmp: ValueId) -> yul::Identifier { + let ident = YulVariable::new(format! {"$tmp_{}", tmp.index()}).ident(); + self.value_map.insert(tmp, ident.clone()); + ident + } + + fn keccak256(&mut self, value: ValueId) -> yul::Expression { + let value_ty = self.body.store.value_ty(value); + debug_assert!(value_ty.is_mptr(self.db.upcast())); + + let value_size = value_ty + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE); + let value_size_expr = literal_expression! {(value_size)}; + let value_expr = self.value_expr(value); + expression! {keccak256([value_expr], [value_size_expr])} + } + + fn lower_assignable_value(&mut self, value: &AssignableValue) -> yul::Expression { + match value { + AssignableValue::Value(value) => self.value_expr(*value), + + AssignableValue::Aggregate { lhs, idx } => { + let base_ptr = self.lower_assignable_value(lhs); + let ty = lhs + .ty(self.db.upcast(), &self.body.store) + .deref(self.db.upcast()); + self.aggregate_elem_ptr(base_ptr, *idx, ty) + } + AssignableValue::Map { lhs, key } => { + let map_ptr = self.lower_assignable_value(lhs); + let key_ty = self.body.store.value_ty(*key); + let key = self.value_expr(*key); + self.ctx + .runtime + .map_value_ptr(self.db, map_ptr, key, key_ty) + } + } + } + + fn aggregate_elem_ptr( + &mut self, + base_ptr: yul::Expression, + idx: ValueId, + base_ty: TypeId, + ) -> yul::Expression { + debug_assert!(base_ty.is_aggregate(self.db.upcast())); + + match &base_ty.data(self.db.upcast()).kind { + TypeKind::Array(def) => { + let elem_size = + literal_expression! {(base_ty.array_elem_size(self.db.upcast(), SLOT_SIZE))}; + self.validate_array_indexing(def.len, idx); + let idx = self.value_expr(idx); + let offset = expression! {mul([elem_size], [idx])}; + expression! { add([base_ptr], [offset]) } + } + _ => { + let elem_idx = match self.body.store.value_data(idx) { + Value::Immediate { imm, .. } => imm, + _ => panic!("only array type can use dynamic value indexing"), + }; + let offset = literal_expression! {(base_ty.aggregate_elem_offset(self.db.upcast(), elem_idx.clone(), SLOT_SIZE))}; + expression! {add([base_ptr], [offset])} + } + } + } + + fn validate_array_indexing(&mut self, array_len: usize, idx: ValueId) { + const PANIC_OUT_OF_BOUNDS: usize = 0x32; + + if let Value::Immediate { .. } = self.body.store.value_data(idx) { + return; + } + + let idx = self.value_expr(idx); + let max_idx = literal_expression! {(array_len - 1)}; + self.sink.push(statement!(if (gt([idx], [max_idx])) { + ([runtime::panic_revert_numeric( + self.ctx.runtime.as_mut(), + self.db, + literal_expression! {(PANIC_OUT_OF_BOUNDS)}, + )]) + })); + } + + fn value_ty_size(&self, value: ValueId) -> usize { + self.body + .store + .value_ty(value) + .size_of(self.db.upcast(), SLOT_SIZE) + } + + fn value_ty_size_deref(&self, value: ValueId) -> usize { + self.body + .store + .value_ty(value) + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE) + } + + fn enter_scope(&mut self) { + let value_map = std::mem::take(&mut self.value_map); + self.value_map = ScopedValueMap::with_parent(value_map); + } + + fn leave_scope(&mut self) { + let value_map = std::mem::take(&mut self.value_map); + self.value_map = value_map.into_parent(); + } +} + +#[derive(Debug, Default)] +struct ScopedValueMap { + parent: Option>, + map: FxHashMap, +} + +impl ScopedValueMap { + fn lookup(&self, value: ValueId) -> Option<&yul::Identifier> { + match self.map.get(&value) { + Some(ident) => Some(ident), + None => self.parent.as_ref().and_then(|p| p.lookup(value)), + } + } + + fn with_parent(parent: ScopedValueMap) -> Self { + Self { + parent: Some(parent.into()), + ..Self::default() + } + } + + fn into_parent(self) -> Self { + *self.parent.unwrap() + } + + fn insert(&mut self, value: ValueId, ident: yul::Identifier) { + self.map.insert(value, ident); + } + + fn contains(&self, value: ValueId) -> bool { + self.lookup(value).is_some() + } +} + +fn bit_mask(byte_size: usize) -> usize { + (1 << (byte_size * 8)) - 1 +} + +fn bit_mask_expr(byte_size: usize) -> yul::Expression { + let mask = format!("{:#x}", bit_mask(byte_size)); + literal_expression! {(mask)} +} diff --git a/crates/codegen2/src/yul/isel/inst_order.rs b/crates/codegen2/src/yul/isel/inst_order.rs new file mode 100644 index 0000000000..afc82f0016 --- /dev/null +++ b/crates/codegen2/src/yul/isel/inst_order.rs @@ -0,0 +1,1368 @@ +use fe_mir::{ + analysis::{ + domtree::DFSet, loop_tree::LoopId, post_domtree::PostIDom, ControlFlowGraph, DomTree, + LoopTree, PostDomTree, + }, + ir::{ + inst::{BranchInfo, SwitchTable}, + BasicBlockId, FunctionBody, InstId, ValueId, + }, +}; +use indexmap::{IndexMap, IndexSet}; + +#[derive(Debug, Clone)] +pub(super) enum StructuralInst { + Inst(InstId), + If { + cond: ValueId, + then: Vec, + else_: Vec, + }, + + Switch { + scrutinee: ValueId, + table: Vec<(ValueId, Vec)>, + default: Option>, + }, + + For { + body: Vec, + }, + + Break, + + Continue, +} + +pub(super) struct InstSerializer<'a> { + body: &'a FunctionBody, + cfg: ControlFlowGraph, + loop_tree: LoopTree, + df: DFSet, + domtree: DomTree, + pd_tree: PostDomTree, + scope: Option, +} + +impl<'a> InstSerializer<'a> { + pub(super) fn new(body: &'a FunctionBody) -> Self { + let cfg = ControlFlowGraph::compute(body); + let domtree = DomTree::compute(&cfg); + let df = domtree.compute_df(&cfg); + let pd_tree = PostDomTree::compute(body); + let loop_tree = LoopTree::compute(&cfg, &domtree); + + Self { + body, + cfg, + loop_tree, + df, + domtree, + pd_tree, + scope: None, + } + } + + pub(super) fn serialize(&mut self) -> Vec { + self.scope = None; + let entry = self.cfg.entry(); + let mut order = vec![]; + self.serialize_block(entry, &mut order); + order + } + + fn serialize_block(&mut self, block: BasicBlockId, order: &mut Vec) { + match self.loop_tree.loop_of_block(block) { + Some(lp) + if block == self.loop_tree.loop_header(lp) + && Some(block) != self.scope.as_ref().and_then(Scope::loop_header) => + { + let loop_exit = self.find_loop_exit(lp); + self.enter_loop_scope(lp, block, loop_exit); + let mut body = vec![]; + self.serialize_block(block, &mut body); + self.exit_scope(); + order.push(StructuralInst::For { body }); + + match loop_exit { + Some(exit) + if self + .scope + .as_ref() + .map(|scope| scope.branch_merge_block() != Some(exit)) + .unwrap_or(true) => + { + self.serialize_block(exit, order); + } + _ => {} + } + + return; + } + _ => {} + }; + + for inst in self.body.order.iter_inst(block) { + if self.body.store.is_terminator(inst) { + break; + } + if !self.body.store.is_nop(inst) { + order.push(StructuralInst::Inst(inst)); + } + } + + let terminator = self.body.order.terminator(&self.body.store, block).unwrap(); + match self.analyze_terminator(terminator) { + TerminatorInfo::If { + cond, + then, + else_, + merge_block, + } => self.serialize_if_terminator(cond, *then, *else_, merge_block, order), + + TerminatorInfo::Switch { + scrutinee, + table, + default, + merge_block, + } => self.serialize_switch_terminator( + scrutinee, + table, + default.map(|value| *value), + merge_block, + order, + ), + + TerminatorInfo::ToMergeBlock => {} + TerminatorInfo::Continue => order.push(StructuralInst::Continue), + TerminatorInfo::Break => order.push(StructuralInst::Break), + TerminatorInfo::FallThrough(next) => self.serialize_block(next, order), + TerminatorInfo::NormalInst(inst) => order.push(StructuralInst::Inst(inst)), + } + } + + fn serialize_if_terminator( + &mut self, + cond: ValueId, + then: TerminatorInfo, + else_: TerminatorInfo, + merge_block: Option, + order: &mut Vec, + ) { + let mut then_body = vec![]; + let mut else_body = vec![]; + + self.enter_branch_scope(merge_block); + self.serialize_branch_dest(then, &mut then_body, merge_block); + self.serialize_branch_dest(else_, &mut else_body, merge_block); + self.exit_scope(); + + order.push(StructuralInst::If { + cond, + then: then_body, + else_: else_body, + }); + if let Some(merge_block) = merge_block { + self.serialize_block(merge_block, order); + } + } + + fn serialize_switch_terminator( + &mut self, + scrutinee: ValueId, + table: Vec<(ValueId, TerminatorInfo)>, + default: Option, + merge_block: Option, + order: &mut Vec, + ) { + self.enter_branch_scope(merge_block); + + let mut serialized_table = Vec::with_capacity(table.len()); + for (value, dest) in table { + let mut body = vec![]; + self.serialize_branch_dest(dest, &mut body, merge_block); + serialized_table.push((value, body)); + } + + let serialized_default = default.map(|dest| { + let mut body = vec![]; + self.serialize_branch_dest(dest, &mut body, merge_block); + body + }); + + order.push(StructuralInst::Switch { + scrutinee, + table: serialized_table, + default: serialized_default, + }); + + self.exit_scope(); + + if let Some(merge_block) = merge_block { + self.serialize_block(merge_block, order); + } + } + + fn serialize_branch_dest( + &mut self, + dest: TerminatorInfo, + body: &mut Vec, + merge_block: Option, + ) { + match dest { + TerminatorInfo::Break => body.push(StructuralInst::Break), + TerminatorInfo::Continue => body.push(StructuralInst::Continue), + TerminatorInfo::ToMergeBlock => {} + TerminatorInfo::FallThrough(dest) => { + if Some(dest) != merge_block { + self.serialize_block(dest, body); + } + } + _ => unreachable!(), + }; + } + + fn enter_loop_scope(&mut self, lp: LoopId, header: BasicBlockId, exit: Option) { + let kind = ScopeKind::Loop { lp, header, exit }; + let current_scope = std::mem::take(&mut self.scope); + self.scope = Some(Scope { + kind, + parent: current_scope.map(Into::into), + }); + } + + fn enter_branch_scope(&mut self, merge_block: Option) { + let kind = ScopeKind::Branch { merge_block }; + let current_scope = std::mem::take(&mut self.scope); + self.scope = Some(Scope { + kind, + parent: current_scope.map(Into::into), + }); + } + + fn exit_scope(&mut self) { + let current_scope = std::mem::take(&mut self.scope); + self.scope = current_scope.unwrap().parent.map(|parent| *parent); + } + + // NOTE: We assume loop has at most one canonical loop exit. + fn find_loop_exit(&self, lp: LoopId) -> Option { + let mut exit_candidates = vec![]; + for block_in_loop in self.loop_tree.iter_blocks_post_order(&self.cfg, lp) { + for &succ in self.cfg.succs(block_in_loop) { + if !self.loop_tree.is_block_in_loop(succ, lp) { + exit_candidates.push(succ); + } + } + } + + if exit_candidates.is_empty() { + return None; + } + + if exit_candidates.len() == 1 { + let candidate = exit_candidates[0]; + let exit = if let Some(mut df) = self.df.frontiers(candidate) { + debug_assert_eq!(self.df.frontier_num(candidate), 1); + df.next() + } else { + Some(candidate) + }; + return exit; + } + + // If a candidate is a dominance frontier of all other nodes, then the candidate + // is a loop exit. + for &cand in &exit_candidates { + if exit_candidates.iter().all(|&block| { + if block == cand { + true + } else if let Some(mut df) = self.df.frontiers(block) { + df.any(|frontier| frontier == cand) + } else { + true + } + }) { + return Some(cand); + } + } + + // If all candidates have the same dominance frontier, then the frontier block + // is the canonicalized loop exit. + let mut frontier: IndexSet<_> = self + .df + .frontiers(exit_candidates.pop().unwrap()) + .map(std::iter::Iterator::collect) + .unwrap_or_default(); + for cand in exit_candidates { + for cand_frontier in self.df.frontiers(cand).unwrap() { + if !frontier.contains(&cand_frontier) { + frontier.remove(&cand_frontier); + } + } + } + debug_assert!(frontier.len() < 2); + frontier.iter().next().copied() + } + + fn analyze_terminator(&self, inst: InstId) -> TerminatorInfo { + debug_assert!(self.body.store.is_terminator(inst)); + + let inst_block = self.body.order.inst_block(inst); + match self.body.store.branch_info(inst) { + BranchInfo::Jump(dest) => self.analyze_jump(dest), + + BranchInfo::Branch(cond, then, else_) => self.analyze_if(inst_block, cond, then, else_), + + BranchInfo::Switch(scrutinee, table, default) => { + self.analyze_switch(inst_block, scrutinee, table, default) + } + + BranchInfo::NotBranch => TerminatorInfo::NormalInst(inst), + } + } + + fn analyze_if( + &self, + block: BasicBlockId, + cond: ValueId, + then_bb: BasicBlockId, + else_bb: BasicBlockId, + ) -> TerminatorInfo { + let then = Box::new(self.analyze_dest(then_bb)); + let else_ = Box::new(self.analyze_dest(else_bb)); + + let then_cands = self.find_merge_block_candidates(block, then_bb); + let else_cands = self.find_merge_block_candidates(block, else_bb); + debug_assert!(then_cands.len() < 2); + debug_assert!(else_cands.len() < 2); + + let merge_block = match (then_cands.as_slice(), else_cands.as_slice()) { + (&[then_cand], &[else_cand]) => { + if then_cand == else_cand { + Some(then_cand) + } else { + None + } + } + + (&[cand], []) => { + if cand == else_bb { + Some(cand) + } else { + None + } + } + + ([], &[cand]) => { + if cand == then_bb { + Some(cand) + } else { + None + } + } + + ([], []) => match self.pd_tree.post_idom(block) { + PostIDom::Block(block) => { + if let Some(lp) = self.scope.as_ref().and_then(Scope::loop_recursive) { + if self.loop_tree.is_block_in_loop(block, lp) { + Some(block) + } else { + None + } + } else { + Some(block) + } + } + _ => None, + }, + + (_, _) => unreachable!(), + }; + + TerminatorInfo::If { + cond, + then, + else_, + merge_block, + } + } + + fn analyze_switch( + &self, + block: BasicBlockId, + scrutinee: ValueId, + table: &SwitchTable, + default: Option, + ) -> TerminatorInfo { + let mut analyzed_table = Vec::with_capacity(table.len()); + + let mut merge_block_cands = IndexSet::default(); + for (value, dest) in table.iter() { + analyzed_table.push((value, self.analyze_dest(dest))); + merge_block_cands.extend(self.find_merge_block_candidates(block, dest)); + } + + let analyzed_default = default.map(|dest| { + merge_block_cands.extend(self.find_merge_block_candidates(block, dest)); + Box::new(self.analyze_dest(dest)) + }); + + TerminatorInfo::Switch { + scrutinee, + table: analyzed_table, + default: analyzed_default, + merge_block: self.select_switch_merge_block( + &merge_block_cands, + table.iter().map(|(_, d)| d).chain(default), + ), + } + } + + fn find_merge_block_candidates( + &self, + branch_inst_bb: BasicBlockId, + branch_dest_bb: BasicBlockId, + ) -> Vec { + if self.domtree.dominates(branch_dest_bb, branch_inst_bb) { + return vec![]; + } + + // a block `cand` can be a candidate of a `merge` block iff + // 1. `cand` is a dominance frontier of `branch_dest_bb`. + // 2. `cand` is NOT a dominator of `branch_dest_bb`. + // 3. `cand` is NOT a "merge" block of parent `if` or `switch`. + // 4. `cand` is NOT a "loop_exit" block of parent `loop`. + match self.df.frontiers(branch_dest_bb) { + Some(cands) => cands + .filter(|cand| { + !self.domtree.dominates(*cand, branch_dest_bb) + && Some(*cand) + != self + .scope + .as_ref() + .and_then(Scope::branch_merge_block_recursive) + && Some(*cand) != self.scope.as_ref().and_then(Scope::loop_exit_recursive) + }) + .collect(), + None => vec![], + } + } + + /// Each destination block of `switch` instruction could have multiple + /// candidates for the merge block because arm bodies can have multiple + /// predecessors, e.g., `default` arm. + /// So we need a heuristic to select the merge block from candidates. + /// + /// First, if one of the dominance frontiers of switch dests is a parent + /// merge block, then we stop searching the merge block because the parent + /// merge block should be the subsequent codes after the switch in terms of + /// high-level flow structure like Fe or yul. + /// + /// If no parent merge block is found, we start scoring the candidates by + /// the following function. + /// + /// The scoring function `F` is defined as follows: + /// 1. The initial score of each candidate('cand_bb`) is number of + /// predecessors of the candidate. + /// + /// 2. Find the `top_cand` of each `cand_bb`. `top_cand` can be found by + /// [`Self::try_find_top_cand`] method, see the method for details. + /// + /// 3. If `top_cand` is found, then add the `cand_bb` score to the + /// `top_cand` score, then set 0 to the `cand_bb` score. + /// + /// After the scoring, the candidates with the highest score will be + /// selected. + fn select_switch_merge_block( + &self, + cands: &IndexSet, + dests: impl Iterator, + ) -> Option { + let parent_merge = self + .scope + .as_ref() + .and_then(Scope::branch_merge_block_recursive); + for dest in dests { + if self + .df + .frontiers(dest) + .map(|mut frontieres| frontieres.any(|frontier| Some(frontier) == parent_merge)) + .unwrap_or_default() + { + return None; + } + } + + let mut cands_with_score = cands + .iter() + .map(|cand| (*cand, self.cfg.preds(*cand).len())) + .collect::>(); + + for cand_bb in cands_with_score.keys().copied().collect::>() { + if let Some(top_cand) = self.try_find_top_cand(&cands_with_score, cand_bb) { + let score = std::mem::take(cands_with_score.get_mut(&cand_bb).unwrap()); + *cands_with_score.get_mut(&top_cand).unwrap() += score; + } + } + + cands_with_score + .iter() + .max_by_key(|(_, score)| *score) + .map(|(&cand, _)| cand) + } + + /// Try to find the `top_cand` of the `cand_bb`. + /// A `top_cand` can be found by the following rules: + /// + /// 1. Find the block which is contained in DF of `cand_bb` and in + /// `cands_with_score`. + /// + /// 2. If a block is found in 1., and the score of the block is positive, + /// then the block is `top_cand`. + /// + /// 2'. If a block is found in 1., and the score of the block is 0, then the + /// `top_cand` of the block is `top_cand` of `cand_bb`. + /// + /// 2''. If a block is NOT found in 1., then there is no `top_cand` for + /// `cand_bb`. + fn try_find_top_cand( + &self, + cands_with_score: &IndexMap, + cand_bb: BasicBlockId, + ) -> Option { + let mut frontiers = match self.df.frontiers(cand_bb) { + Some(frontiers) => frontiers, + _ => return None, + }; + + while let Some(frontier_bb) = frontiers.next() { + if cands_with_score.contains_key(&frontier_bb) { + debug_assert!(frontiers.all(|bb| !cands_with_score.contains_key(&bb))); + if cands_with_score[&frontier_bb] != 0 { + return Some(frontier_bb); + } else { + return self.try_find_top_cand(cands_with_score, frontier_bb); + } + } + } + + None + } + + fn analyze_jump(&self, dest: BasicBlockId) -> TerminatorInfo { + self.analyze_dest(dest) + } + + fn analyze_dest(&self, dest: BasicBlockId) -> TerminatorInfo { + match &self.scope { + Some(scope) => { + if Some(dest) == scope.loop_header_recursive() { + TerminatorInfo::Continue + } else if Some(dest) == scope.loop_exit_recursive() { + TerminatorInfo::Break + } else if Some(dest) == scope.branch_merge_block_recursive() { + TerminatorInfo::ToMergeBlock + } else { + TerminatorInfo::FallThrough(dest) + } + } + + None => TerminatorInfo::FallThrough(dest), + } + } +} + +struct Scope { + kind: ScopeKind, + parent: Option>, +} + +#[derive(Debug, Clone, Copy)] +enum ScopeKind { + Loop { + lp: LoopId, + header: BasicBlockId, + exit: Option, + }, + Branch { + merge_block: Option, + }, +} + +impl Scope { + fn loop_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { lp, .. } => Some(lp), + _ => self.parent.as_ref()?.loop_recursive(), + } + } + + fn loop_header(&self) -> Option { + match self.kind { + ScopeKind::Loop { header, .. } => Some(header), + _ => None, + } + } + + fn loop_header_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { header, .. } => Some(header), + _ => self.parent.as_ref()?.loop_header_recursive(), + } + } + + fn loop_exit_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { exit, .. } => exit, + _ => self.parent.as_ref()?.loop_exit_recursive(), + } + } + + fn branch_merge_block(&self) -> Option { + match self.kind { + ScopeKind::Branch { merge_block } => merge_block, + _ => None, + } + } + + fn branch_merge_block_recursive(&self) -> Option { + match self.kind { + ScopeKind::Branch { + merge_block: Some(merge_block), + } => Some(merge_block), + _ => self.parent.as_ref()?.branch_merge_block_recursive(), + } + } +} + +#[derive(Debug, Clone)] +enum TerminatorInfo { + If { + cond: ValueId, + then: Box, + else_: Box, + merge_block: Option, + }, + + Switch { + scrutinee: ValueId, + table: Vec<(ValueId, TerminatorInfo)>, + default: Option>, + merge_block: Option, + }, + + ToMergeBlock, + Continue, + Break, + FallThrough(BasicBlockId), + NormalInst(InstId), +} + +#[cfg(test)] +mod tests { + use fe_mir::ir::{body_builder::BodyBuilder, inst::InstKind, FunctionId, SourceInfo, TypeId}; + + use super::*; + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + fn serialize_func_body(func: &mut FunctionBody) -> impl Iterator { + InstSerializer::new(func).serialize().into_iter() + } + + fn expect_if( + insts: &mut impl Iterator, + ) -> ( + impl Iterator, + impl Iterator, + ) { + match insts.next().unwrap() { + StructuralInst::If { then, else_, .. } => (then.into_iter(), else_.into_iter()), + _ => panic!("expect if inst"), + } + } + + fn expect_switch( + insts: &mut impl Iterator, + ) -> Vec> { + match insts.next().unwrap() { + StructuralInst::Switch { table, default, .. } => { + let mut arms: Vec<_> = table + .into_iter() + .map(|(_, insts)| insts.into_iter()) + .collect(); + if let Some(default) = default { + arms.push(default.into_iter()); + } + + arms + } + + _ => panic!("expect if inst"), + } + } + + fn expect_for( + insts: &mut impl Iterator, + ) -> impl Iterator { + match insts.next().unwrap() { + StructuralInst::For { body } => body.into_iter(), + _ => panic!("expect if inst"), + } + } + + fn expect_break(insts: &mut impl Iterator) { + assert!(matches!(insts.next().unwrap(), StructuralInst::Break)) + } + + fn expect_continue(insts: &mut impl Iterator) { + assert!(matches!(insts.next().unwrap(), StructuralInst::Continue)) + } + + fn expect_return(func: &FunctionBody, insts: &mut impl Iterator) { + let inst = insts.next().unwrap(); + match inst { + StructuralInst::Inst(inst) => { + assert!(matches!( + func.store.inst_data(inst).kind, + InstKind::Return { .. } + )) + } + _ => panic!("expect return"), + } + } + + fn expect_end(insts: &mut impl Iterator) { + assert!(insts.next().is_none()) + } + + #[test] + fn if_non_merge() { + // +------+ +-------+ + // | then | <-- | bb0 | + // +------+ +-------+ + // | + // | + // v + // +-------+ + // | else_ | + // +-------+ + let mut builder = body_builder(); + + let then = builder.make_block(); + let else_ = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, then, else_, SourceInfo::dummy()); + + builder.move_to_block(then); + builder.ret(unit, SourceInfo::dummy()); + + builder.move_to_block(else_); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then, mut else_) = expect_if(&mut order); + expect_return(&func, &mut then); + expect_end(&mut then); + expect_return(&func, &mut else_); + expect_end(&mut else_); + + expect_end(&mut order); + } + + #[test] + fn if_merge() { + // +------+ +-------+ + // | then | <-- | bb0 | + // +------+ +-------+ + // | | + // | | + // | v + // | +-------+ + // | | else_ | + // | +-------+ + // | | + // | | + // | v + // | +-------+ + // +--------> | merge | + // +-------+ + let mut builder = body_builder(); + + let then = builder.make_block(); + let else_ = builder.make_block(); + let merge = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, then, else_, SourceInfo::dummy()); + + builder.move_to_block(then); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(else_); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then, mut else_) = expect_if(&mut order); + expect_end(&mut then); + expect_end(&mut else_); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn nested_if() { + // +-----+ + // | bb0 | -+ + // +-----+ | + // | | + // | | + // v | + // +-----+ +-----+ | + // | bb3 | <-- | bb1 | | + // +-----+ +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb4 | | + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | <+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.ret(unit, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then1, mut else2) = expect_if(&mut order); + expect_end(&mut else2); + + let (mut then3, mut else4) = expect_if(&mut then1); + expect_end(&mut then1); + expect_return(&func, &mut then3); + expect_end(&mut then3); + expect_end(&mut else4); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn simple_loop() { + // +--------+ + // | bb0 | -+ + // +--------+ | + // | | + // | | + // v | + // +--------+ | + // +> | header | | + // | +--------+ | + // | | | + // | | | + // | v | + // | +--------+ | + // +- | latch | | + // +--------+ | + // | | + // | | + // v | + // +--------+ | + // | exit | <+ + // +--------+ + let mut builder = body_builder(); + + let header = builder.make_block(); + let latch = builder.make_block(); + let exit = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, header, exit, SourceInfo::dummy()); + + builder.move_to_block(header); + builder.jump(latch, SourceInfo::dummy()); + + builder.move_to_block(latch); + builder.branch(v0, header, exit, SourceInfo::dummy()); + + builder.move_to_block(exit); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(&mut order); + + let mut body = expect_for(&mut lp); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_end(&mut body); + + expect_continue(&mut continue_); + expect_end(&mut continue_); + + expect_break(&mut break_); + expect_end(&mut break_); + + expect_end(&mut empty); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn loop_with_continue() { + // +-----+ + // +- | bb0 | + // | +-----+ + // | | + // | | + // | v + // | +---------------+ +-----+ + // | | bb1 | --> | bb3 | + // | +---------------+ +-----+ + // | | ^ ^ | + // | | | +---------+ + // | v | + // | +-----+ | + // | | bb4 | -+ + // | +-----+ + // | | + // | | + // | v + // | +-----+ + // +> | bb2 | + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(&mut order); + expect_end(&mut empty); + + let mut body = expect_for(&mut lp); + + let (mut continue_, mut empty) = expect_if(&mut body); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_end(&mut empty); + + let (mut continue_, mut break_) = expect_if(&mut body); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_break(&mut break_); + expect_end(&mut break_); + + expect_end(&mut body); + expect_end(&mut lp); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn loop_with_break() { + // +-----+ + // +- | bb0 | + // | +-----+ + // | | + // | | +---------+ + // | v v | + // | +---------------+ +-----+ + // | | bb1 | --> | bb4 | + // | +---------------+ +-----+ + // | | | + // | | | + // | v | + // | +-----+ | + // | | bb3 | | + // | +-----+ | + // | | | + // | | | + // | v | + // | +-----+ | + // +> | bb2 | <---------------+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(&mut order); + expect_end(&mut empty); + + let mut body = expect_for(&mut lp); + + let (mut break_, mut latch) = expect_if(&mut body); + expect_break(&mut break_); + expect_end(&mut break_); + + let (mut continue_, mut break_) = expect_if(&mut latch); + expect_end(&mut latch); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_break(&mut break_); + expect_end(&mut break_); + + expect_end(&mut body); + expect_end(&mut lp); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn loop_no_guard() { + // +-----+ + // | bb0 | + // +-----+ + // | + // | + // v + // +-----+ + // | bb1 | <+ + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | -+ + // +-----+ + // | + // | + // v + // +-----+ + // | bb3 | + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.branch(v0, bb1, bb3, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut body = expect_for(&mut order); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_end(&mut body); + + expect_continue(&mut continue_); + expect_end(&mut continue_); + + expect_break(&mut break_); + expect_end(&mut break_); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn infinite_loop() { + // +-----+ + // | bb0 | + // +-----+ + // | + // | + // v + // +-----+ + // | bb1 | <+ + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | -+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.jump(bb1, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut body = expect_for(&mut order); + expect_continue(&mut body); + expect_end(&mut body); + + expect_end(&mut order); + } + + #[test] + fn switch_basic() { + // +-----+ +-------+ +-----+ + // | bb2 | <-- | bb0 | --> | bb3 | + // +-----+ +-------+ +-----+ + // | | | + // | | | + // | v | + // | +-------+ | + // | | bb1 | | + // | +-------+ | + // | | | + // | | | + // | v | + // | +-------+ | + // +-------> | merge | <-----+ + // +-------+ + let mut builder = body_builder(); + let dummy_ty = TypeId(0); + let dummy_value = builder.make_unit(dummy_ty); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let merge = builder.make_block(); + + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb1); + table.add_arm(dummy_value, bb2); + table.add_arm(dummy_value, bb3); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.jump(merge, SourceInfo::dummy()); + builder.move_to_block(bb3); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(dummy_value, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let arms = expect_switch(&mut order); + assert_eq!(arms.len(), 3); + for mut arm in arms { + expect_end(&mut arm); + } + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn switch_default() { + // +-----------+ + // | | + // | | + // | +----+--------+ + // v | | | + // +-----+ | +-------+ | +---------+ + // | bb2 | -+ | bb0 | -+> | bb3 | + // +-----+ +-------+ | +---------+ + // | | | | + // | | | | + // v v | v + // +-----+ +-------+ | +---------+ + // | bb5 | +- | bb1 | +> | default | <+ + // +-----+ | +-------+ +---------+ | + // | | | | | + // | | | | | + // | | v | | + // | | +-------+ | | + // | | | bb4 | | | + // | | +-------+ | | + // | | | | | + // +----+------+ | | | + // | | v | | + // | | +-------+ | | + // | +-------> | merge | <-----+ | + // | +-------+ | + // | | + // +-----------------------------------------+ + let mut builder = body_builder(); + let dummy_ty = TypeId(0); + let dummy_value = builder.make_unit(dummy_ty); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + let bb5 = builder.make_block(); + let default = builder.make_block(); + let merge = builder.make_block(); + + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb1); + table.add_arm(dummy_value, bb2); + table.add_arm(dummy_value, bb3); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb1); + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb4); + table.add_arm(dummy_value, default); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb2); + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb5); + table.add_arm(dummy_value, default); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(default, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(bb5); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(default); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(dummy_value, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut arms = expect_switch(&mut order); + assert_eq!(arms.len(), 3); + + let mut bb3_jump = arms.pop().unwrap(); + expect_end(&mut bb3_jump); + + let mut bb2_switch = arms.pop().unwrap(); + let bb2_switch_arms = expect_switch(&mut bb2_switch); + assert_eq!(bb2_switch_arms.len(), 2); + for mut bb2_switch_arm in bb2_switch_arms { + expect_end(&mut bb2_switch_arm); + } + expect_end(&mut bb2_switch); + + let mut bb1_switch = arms.pop().unwrap(); + let bb1_switch_arms = expect_switch(&mut bb1_switch); + assert_eq!(bb1_switch_arms.len(), 2); + for mut bb1_switch_arm in bb1_switch_arms { + expect_end(&mut bb1_switch_arm); + } + expect_end(&mut bb1_switch); + + expect_return(&func, &mut order); + expect_end(&mut order); + } +} diff --git a/crates/codegen2/src/yul/isel/mod.rs b/crates/codegen2/src/yul/isel/mod.rs new file mode 100644 index 0000000000..2507774ff8 --- /dev/null +++ b/crates/codegen2/src/yul/isel/mod.rs @@ -0,0 +1,9 @@ +pub mod context; +mod contract; +mod function; +mod inst_order; +mod test; + +pub use contract::{lower_contract, lower_contract_deployable}; +pub use function::lower_function; +pub use test::lower_test; diff --git a/crates/codegen2/src/yul/isel/test.rs b/crates/codegen2/src/yul/isel/test.rs new file mode 100644 index 0000000000..d4246e4edb --- /dev/null +++ b/crates/codegen2/src/yul/isel/test.rs @@ -0,0 +1,71 @@ +use crate::CodegenDb; + +use super::context::Context; +use fe_mir::ir::FunctionId; +use yultsur::{yul, *}; + +pub fn lower_test(db: &dyn CodegenDb, test: FunctionId) -> yul::Object { + let mut context = Context::default(); + let test = db.mir_lowered_func_signature(test); + context.function_dependency.insert(test); + + let dep_constants = context.resolve_constant_dependency(db); + let dep_functions: Vec<_> = context + .resolve_function_dependency(db) + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let dep_contracts = context.resolve_contract_dependency(db); + let runtime_funcs: Vec<_> = context + .runtime + .collect_definitions() + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let test_func_name = identifier! { (db.codegen_function_symbol_name(test)) }; + let call = function_call_statement! {[test_func_name]()}; + + let code = code! { + [dep_functions...] + [runtime_funcs...] + [call] + (stop()) + }; + + let name = identifier! { test }; + let object = yul::Object { + name, + code, + objects: dep_contracts, + data: dep_constants, + }; + + normalize_object(object) +} + +fn normalize_object(obj: yul::Object) -> yul::Object { + let data = obj + .data + .into_iter() + .map(|data| yul::Data { + name: data.name, + value: data + .value + .replace('\\', "\\\\\\\\") + .replace('\n', "\\\\n") + .replace('"', "\\\\\"") + .replace('\r', "\\\\r") + .replace('\t', "\\\\t"), + }) + .collect::>(); + yul::Object { + name: obj.name, + code: obj.code, + objects: obj + .objects + .into_iter() + .map(normalize_object) + .collect::>(), + data, + } +} diff --git a/crates/codegen2/src/yul/legalize/body.rs b/crates/codegen2/src/yul/legalize/body.rs new file mode 100644 index 0000000000..d3a556fa61 --- /dev/null +++ b/crates/codegen2/src/yul/legalize/body.rs @@ -0,0 +1,219 @@ +use fe_mir::ir::{ + body_cursor::{BodyCursor, CursorLocation}, + inst::InstKind, + value::AssignableValue, + FunctionBody, Inst, InstId, TypeId, TypeKind, Value, ValueId, +}; + +use crate::CodegenDb; + +use super::critical_edge::CriticalEdgeSplitter; + +pub fn legalize_func_body(db: &dyn CodegenDb, body: &mut FunctionBody) { + CriticalEdgeSplitter::new().run(body); + legalize_func_arg(db, body); + + let mut cursor = BodyCursor::new_at_entry(body); + loop { + match cursor.loc() { + CursorLocation::BlockTop(_) | CursorLocation::BlockBottom(_) => cursor.proceed(), + CursorLocation::Inst(inst) => { + legalize_inst(db, &mut cursor, inst); + } + CursorLocation::NoWhere => break, + } + } +} + +fn legalize_func_arg(db: &dyn CodegenDb, body: &mut FunctionBody) { + for value in body.store.func_args_mut() { + let ty = value.ty(); + if ty.is_contract(db.upcast()) { + let slot_ptr = make_storage_ptr(db, ty); + *value = slot_ptr; + } else if (ty.is_aggregate(db.upcast()) || ty.is_string(db.upcast())) + && !ty.is_zero_sized(db.upcast()) + { + change_ty(value, ty.make_mptr(db.upcast())) + } + } +} + +fn legalize_inst(db: &dyn CodegenDb, cursor: &mut BodyCursor, inst: InstId) { + if legalize_unit_construct(db, cursor, inst) { + return; + } + legalize_declared_ty(db, cursor.body_mut(), inst); + legalize_inst_arg(db, cursor.body_mut(), inst); + legalize_inst_result(db, cursor.body_mut(), inst); + cursor.proceed(); +} + +fn legalize_unit_construct(db: &dyn CodegenDb, cursor: &mut BodyCursor, inst: InstId) -> bool { + let should_remove = match &cursor.body().store.inst_data(inst).kind { + InstKind::Declare { local } => is_value_zst(db, cursor.body(), *local), + InstKind::AggregateConstruct { ty, .. } => ty.deref(db.upcast()).is_zero_sized(db.upcast()), + InstKind::AggregateAccess { .. } | InstKind::MapAccess { .. } | InstKind::Cast { .. } => { + let result_value = cursor.body().store.inst_result(inst).unwrap(); + is_lvalue_zst(db, cursor.body(), result_value) + } + + _ => false, + }; + + if should_remove { + cursor.remove_inst() + } + + should_remove +} + +fn legalize_declared_ty(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstId) { + let value = match &body.store.inst_data(inst_id).kind { + InstKind::Declare { local } => *local, + _ => return, + }; + + let value_ty = body.store.value_ty(value); + if value_ty.is_aggregate(db.upcast()) { + let new_ty = value_ty.make_mptr(db.upcast()); + let value_data = body.store.value_data_mut(value); + change_ty(value_data, new_ty) + } +} + +fn legalize_inst_arg(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstId) { + // Replace inst with dummy inst to avoid borrow checker complaining. + let dummy_inst = Inst::nop(); + let mut inst = body.store.replace_inst(inst_id, dummy_inst); + + for arg in inst.args() { + let ty = body.store.value_ty(arg); + if ty.is_string(db.upcast()) { + let string_ptr = ty.make_mptr(db.upcast()); + change_ty(body.store.value_data_mut(arg), string_ptr) + } + } + + match &mut inst.kind { + InstKind::AggregateConstruct { args, .. } => { + args.retain(|arg| !is_value_zst(db, body, *arg)); + } + + InstKind::Call { args, .. } => { + args.retain(|arg| !is_value_zst(db, body, *arg) && !is_value_contract(db, body, *arg)) + } + + InstKind::Return { arg } => { + if arg.map(|arg| is_value_zst(db, body, arg)).unwrap_or(false) { + *arg = None; + } + } + + InstKind::MapAccess { key: arg, .. } | InstKind::Emit { arg } => { + let arg_ty = body.store.value_ty(*arg); + if arg_ty.is_zero_sized(db.upcast()) { + *arg = body.store.store_value(make_zst_ptr(db, arg_ty)); + } + } + + InstKind::Cast { value, to, .. } => { + if to.is_aggregate(db.upcast()) && !to.is_zero_sized(db.upcast()) { + let value_ty = body.store.value_ty(*value); + if value_ty.is_mptr(db.upcast()) { + *to = to.make_mptr(db.upcast()); + } else if value_ty.is_sptr(db.upcast()) { + *to = to.make_sptr(db.upcast()); + } else { + unreachable!() + } + } + } + + _ => {} + } + + body.store.replace_inst(inst_id, inst); +} + +fn legalize_inst_result(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstId) { + let result_value = if let Some(result) = body.store.inst_result(inst_id) { + result + } else { + return; + }; + + if is_lvalue_zst(db, body, result_value) { + body.store.remove_inst_result(inst_id); + return; + }; + + let value_id = if let Some(value_id) = result_value.value_id() { + value_id + } else { + return; + }; + let result_ty = body.store.value_ty(value_id); + let new_ty = if result_ty.is_aggregate(db.upcast()) || result_ty.is_string(db.upcast()) { + match &body.store.inst_data(inst_id).kind { + InstKind::AggregateAccess { value, .. } => { + let value_ty = body.store.value_ty(*value); + match &value_ty.data(db.upcast()).kind { + TypeKind::MPtr(..) => result_ty.make_mptr(db.upcast()), + // Note: All SPtr aggregate access results should be SPtr already + _ => unreachable!(), + } + } + _ => result_ty.make_mptr(db.upcast()), + } + } else { + return; + }; + + let value = body.store.value_data_mut(value_id); + change_ty(value, new_ty); +} + +fn change_ty(value: &mut Value, new_ty: TypeId) { + match value { + Value::Local(val) => val.ty = new_ty, + Value::Immediate { ty, .. } + | Value::Temporary { ty, .. } + | Value::Unit { ty } + | Value::Constant { ty, .. } => *ty = new_ty, + } +} + +fn make_storage_ptr(db: &dyn CodegenDb, ty: TypeId) -> Value { + debug_assert!(ty.is_contract(db.upcast())); + let ty = ty.make_sptr(db.upcast()); + + Value::Immediate { imm: 0.into(), ty } +} + +fn make_zst_ptr(db: &dyn CodegenDb, ty: TypeId) -> Value { + debug_assert!(ty.is_zero_sized(db.upcast())); + let ty = ty.make_mptr(db.upcast()); + + Value::Immediate { imm: 0.into(), ty } +} + +/// Returns `true` if a value has a zero sized type. +fn is_value_zst(db: &dyn CodegenDb, body: &FunctionBody, value: ValueId) -> bool { + body.store + .value_ty(value) + .deref(db.upcast()) + .is_zero_sized(db.upcast()) +} + +fn is_value_contract(db: &dyn CodegenDb, body: &FunctionBody, value: ValueId) -> bool { + let ty = body.store.value_ty(value); + ty.deref(db.upcast()).is_contract(db.upcast()) +} + +fn is_lvalue_zst(db: &dyn CodegenDb, body: &FunctionBody, lvalue: &AssignableValue) -> bool { + lvalue + .ty(db.upcast(), &body.store) + .deref(db.upcast()) + .is_zero_sized(db.upcast()) +} diff --git a/crates/codegen2/src/yul/legalize/critical_edge.rs b/crates/codegen2/src/yul/legalize/critical_edge.rs new file mode 100644 index 0000000000..3e3d689ab3 --- /dev/null +++ b/crates/codegen2/src/yul/legalize/critical_edge.rs @@ -0,0 +1,121 @@ +use fe_mir::{ + analysis::ControlFlowGraph, + ir::{ + body_cursor::{BodyCursor, CursorLocation}, + inst::InstKind, + BasicBlock, BasicBlockId, FunctionBody, Inst, InstId, SourceInfo, + }, +}; + +#[derive(Debug)] +pub struct CriticalEdgeSplitter { + critical_edges: Vec, +} + +impl CriticalEdgeSplitter { + pub fn new() -> Self { + Self { + critical_edges: Vec::default(), + } + } + + pub fn run(&mut self, func: &mut FunctionBody) { + let cfg = ControlFlowGraph::compute(func); + + for block in cfg.post_order() { + let terminator = func.order.terminator(&func.store, block).unwrap(); + self.add_critical_edges(terminator, func, &cfg); + } + + self.split_edges(func); + } + + fn add_critical_edges( + &mut self, + terminator: InstId, + func: &FunctionBody, + cfg: &ControlFlowGraph, + ) { + for to in func.store.branch_info(terminator).block_iter() { + if cfg.preds(to).len() > 1 { + self.critical_edges.push(CriticalEdge { terminator, to }); + } + } + } + + fn split_edges(&mut self, func: &mut FunctionBody) { + for edge in std::mem::take(&mut self.critical_edges) { + let terminator = edge.terminator; + let source_block = func.order.inst_block(terminator); + let original_dest = edge.to; + + // Create new block that contains only jump inst. + let new_dest = func.store.store_block(BasicBlock {}); + let mut cursor = BodyCursor::new(func, CursorLocation::BlockTop(source_block)); + cursor.insert_block(new_dest); + cursor.set_loc(CursorLocation::BlockTop(new_dest)); + cursor.store_and_insert_inst(Inst::new( + InstKind::Jump { + dest: original_dest, + }, + SourceInfo::dummy(), + )); + + // Rewrite branch destination to the new dest. + func.store + .rewrite_branch_dest(terminator, original_dest, new_dest); + } + } +} + +#[derive(Debug)] +struct CriticalEdge { + terminator: InstId, + to: BasicBlockId, +} + +#[cfg(test)] +mod tests { + use fe_mir::ir::{body_builder::BodyBuilder, FunctionId, TypeId}; + + use super::*; + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + #[test] + fn critical_edge_remove() { + let mut builder = body_builder(); + let lp_header = builder.make_block(); + let lp_body = builder.make_block(); + let exit = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(false, dummy_ty); + builder.branch(v0, lp_header, exit, SourceInfo::dummy()); + + builder.move_to_block(lp_header); + builder.jump(lp_body, SourceInfo::dummy()); + + builder.move_to_block(lp_body); + builder.branch(v0, lp_header, exit, SourceInfo::dummy()); + + builder.move_to_block(exit); + builder.ret(v0, SourceInfo::dummy()); + + let mut func = builder.build(); + CriticalEdgeSplitter::new().run(&mut func); + let cfg = ControlFlowGraph::compute(&func); + + for &header_pred in cfg.preds(lp_header) { + debug_assert_eq!(cfg.succs(header_pred).len(), 1); + debug_assert_eq!(cfg.succs(header_pred)[0], lp_header); + } + + for &exit_pred in cfg.preds(exit) { + debug_assert_eq!(cfg.succs(exit_pred).len(), 1); + debug_assert_eq!(cfg.succs(exit_pred)[0], exit); + } + } +} diff --git a/crates/codegen2/src/yul/legalize/mod.rs b/crates/codegen2/src/yul/legalize/mod.rs new file mode 100644 index 0000000000..62e82f78fe --- /dev/null +++ b/crates/codegen2/src/yul/legalize/mod.rs @@ -0,0 +1,6 @@ +mod body; +mod critical_edge; +mod signature; + +pub use body::legalize_func_body; +pub use signature::legalize_func_signature; diff --git a/crates/codegen2/src/yul/legalize/signature.rs b/crates/codegen2/src/yul/legalize/signature.rs new file mode 100644 index 0000000000..ff032d751b --- /dev/null +++ b/crates/codegen2/src/yul/legalize/signature.rs @@ -0,0 +1,27 @@ +use fe_mir::ir::{FunctionSignature, TypeKind}; + +use crate::CodegenDb; + +pub fn legalize_func_signature(db: &dyn CodegenDb, sig: &mut FunctionSignature) { + // Remove param if the type is contract or zero-sized. + let params = &mut sig.params; + params.retain(|param| match param.ty.data(db.upcast()).kind { + TypeKind::Contract(_) => false, + _ => !param.ty.deref(db.upcast()).is_zero_sized(db.upcast()), + }); + + // Legalize param types. + for param in params.iter_mut() { + param.ty = db.codegen_legalized_type(param.ty); + } + + if let Some(ret_ty) = sig.return_type { + // Remove return type if the type is contract or zero-sized. + if ret_ty.is_contract(db.upcast()) || ret_ty.deref(db.upcast()).is_zero_sized(db.upcast()) { + sig.return_type = None; + } else { + // Legalize param types. + sig.return_type = Some(db.codegen_legalized_type(ret_ty)); + } + } +} diff --git a/crates/codegen2/src/yul/mod.rs b/crates/codegen2/src/yul/mod.rs new file mode 100644 index 0000000000..6e7e95457e --- /dev/null +++ b/crates/codegen2/src/yul/mod.rs @@ -0,0 +1,26 @@ +use std::borrow::Cow; + +pub mod isel; +pub mod legalize; +pub mod runtime; + +mod slot_size; + +use yultsur::*; + +/// A helper struct to abstract ident and expr. +struct YulVariable<'a>(Cow<'a, str>); + +impl<'a> YulVariable<'a> { + fn expr(&self) -> yul::Expression { + identifier_expression! {(format!{"${}", self.0})} + } + + fn ident(&self) -> yul::Identifier { + identifier! {(format!{"${}", self.0})} + } + + fn new(name: impl Into>) -> Self { + Self(name.into()) + } +} diff --git a/crates/codegen2/src/yul/runtime/abi.rs b/crates/codegen2/src/yul/runtime/abi.rs new file mode 100644 index 0000000000..68def8d5e2 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/abi.rs @@ -0,0 +1,950 @@ +use crate::{ + yul::{ + runtime::{error_revert_numeric, make_ptr}, + slot_size::{yul_primitive_type, SLOT_SIZE}, + YulVariable, + }, + CodegenDb, +}; + +use super::{AbiSrcLocation, DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_abi::types::AbiType; +use fe_mir::ir::{self, types::ArrayDef, TypeId, TypeKind}; +use yultsur::*; + +pub(super) fn make_abi_encode_primitive_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + ([src.ident()] := [provider.primitive_cast(db, src.expr(), legalized_ty)]) + ([yul::Statement::Expression(provider.ptr_store( + db, + dst.expr(), + src.expr(), + make_ptr(db, yul_primitive_type(db), is_dst_storage), + ))]) + ([enc_size.ident()] := 32) + } + }; + + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_abi_encode_static_array_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, +) -> RuntimeFunction { + let is_dst_storage = legalized_ty.is_sptr(db.upcast()); + let deref_ty = legalized_ty.deref(db.upcast()); + let (elem_ty, len) = match &deref_ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) => (def.elem_ty, def.len), + _ => unreachable!(), + }; + let elem_abi_ty = db.codegen_abi_type(elem_ty); + let elem_ptr_ty = make_ptr(db, elem_ty, false); + let elem_ty_size = deref_ty.array_elem_size(db.upcast(), SLOT_SIZE); + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let header_size = elem_abi_ty.header_size(); + let iter_count = literal_expression! {(len)}; + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + (for {(let i := 0)} (lt(i, [iter_count])) {(i := (add(i, 1)))} + { + + (pop([provider.abi_encode(db, src.expr(), dst.expr(), elem_ptr_ty, is_dst_storage)])) + ([src.ident()] := add([src.expr()], [literal_expression!{(elem_ty_size)}])) + ([dst.ident()] := add([dst.expr()], [literal_expression!{(header_size)}])) + }) + ([enc_size.ident()] := [literal_expression! {(header_size * len)}]) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_abi_encode_dynamic_array_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, +) -> RuntimeFunction { + let is_dst_storage = legalized_ty.is_sptr(db.upcast()); + let deref_ty = legalized_ty.deref(db.upcast()); + let (elem_ty, len) = match &deref_ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) => (def.elem_ty, def.len), + _ => unreachable!(), + }; + let elem_header_size = 32; + let total_header_size = elem_header_size * len; + let elem_ptr_ty = make_ptr(db, elem_ty, false); + let elem_ty_size = deref_ty.array_elem_size(db.upcast(), SLOT_SIZE); + let header_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let header_ptr = YulVariable::new("header_ptr"); + let data_ptr = YulVariable::new("data_ptr"); + let enc_size = YulVariable::new("enc_size"); + let iter_count = literal_expression! {(len)}; + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + (let [header_ptr.ident()] := [dst.expr()]) + (let [data_ptr.ident()] := add([dst.expr()], [literal_expression!{(total_header_size)}])) + ([enc_size.ident()] := [literal_expression!{(total_header_size)}]) + (for {(let i := 0)} (lt(i, [iter_count])) {(i := (add(i, 1)))} + { + + ([yul::Statement::Expression(provider.ptr_store(db, header_ptr.expr(), enc_size.expr(), header_ty))]) + ([enc_size.ident()] := add([provider.abi_encode(db, src.expr(), data_ptr.expr(), elem_ptr_ty, is_dst_storage)], [enc_size.expr()])) + ([header_ptr.ident()] := add([header_ptr.expr()], [literal_expression!{(elem_header_size)}])) + ([data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + ([src.ident()] := add([src.expr()], [literal_expression!{(elem_ty_size)}])) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_abi_encode_static_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let deref_ty = legalized_ty.deref(db.upcast()); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let field_enc_size = YulVariable::new("field_enc_size"); + let mut body = vec![ + statement! {[enc_size.ident()] := 0 }, + statement! {let [field_enc_size.ident()] := 0 }, + ]; + let field_num = deref_ty.aggregate_field_num(db.upcast()); + + for idx in 0..field_num { + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let field_ty_ptr = make_ptr(db, field_ty, false); + let field_offset = deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE); + let src_offset = expression! { add([src.expr()], [literal_expression!{(field_offset)}]) }; + body.push(statement!{ + [field_enc_size.ident()] := [provider.abi_encode(db, src_offset, dst.expr(), field_ty_ptr, is_dst_storage)] + }); + body.push(statement! { + [enc_size.ident()] := add([enc_size.expr()], [field_enc_size.expr()]) + }); + + if idx < field_num - 1 { + body.push(assignment! {[dst.ident()] := add([dst.expr()], [field_enc_size.expr()])}); + } + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident(), dst.ident()], + returns: vec![enc_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_abi_encode_dynamic_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let is_src_storage = legalized_ty.is_sptr(db.upcast()); + let deref_ty = legalized_ty.deref(db.upcast()); + let field_num = deref_ty.aggregate_field_num(db.upcast()); + + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let header_ptr = YulVariable::new("header_ptr"); + let enc_size = YulVariable::new("enc_size"); + let data_ptr = YulVariable::new("data_ptr"); + + let total_header_size = literal_expression! { ((0..field_num).fold(0, |acc, idx| { + let ty = deref_ty.projection_ty_imm(db.upcast(), idx); + acc + db.codegen_abi_type(ty).header_size() + })) }; + let mut body = statements! { + (let [header_ptr.ident()] := [dst.expr()]) + ([enc_size.ident()] := [total_header_size]) + (let [data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + }; + + for idx in 0..field_num { + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let field_abi_ty = db.codegen_abi_type(field_ty); + let field_offset = + literal_expression! { (deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE)) }; + let field_ptr = expression! { add([src.expr()], [field_offset]) }; + let field_ptr_ty = make_ptr(db, field_ty, is_src_storage); + + let stmts = if field_abi_ty.is_static() { + statements! { + (pop([provider.abi_encode(db, field_ptr, header_ptr.expr(), field_ptr_ty, is_dst_storage)])) + ([header_ptr.ident()] := add([header_ptr.expr()], [literal_expression! {(field_abi_ty.header_size())}])) + } + } else { + let header_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + statements! { + ([yul::Statement::Expression(provider.ptr_store(db, header_ptr.expr(), enc_size.expr(), header_ty))]) + ([enc_size.ident()] := add([provider.abi_encode(db, field_ptr, data_ptr.expr(), field_ptr_ty, is_dst_storage)], [enc_size.expr()])) + ([header_ptr.ident()] := add([header_ptr.expr()], 32)) + ([data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + } + }; + body.extend_from_slice(&stmts); + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident(), dst.ident()], + returns: vec![enc_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_abi_encode_string_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let string_len = YulVariable::new("string_len"); + let enc_size = YulVariable::new("enc_size"); + + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + (let [string_len.ident()] := mload([src.expr()])) + (let data_size := add(32, [string_len.expr()])) + ([enc_size.ident()] := mul((div((add(data_size, 31)), 32)), 32)) + (let padding_word_ptr := add([dst.expr()], (sub([enc_size.expr()], 32)))) + (mstore(padding_word_ptr, 0)) + ([yul::Statement::Expression(provider.ptr_copy(db, src.expr(), dst.expr(), literal_expression!{data_size}, false, is_dst_storage))]) + } + }; + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_abi_encode_bytes_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + len: usize, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let dst_len_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + ([enc_size.ident()] := [literal_expression!{ (ceil_32(32 + len)) }]) + (if (gt([enc_size.expr()], 0)) { + (let padding_word_ptr := add([dst.expr()], (sub([enc_size.expr()], 32)))) + (mstore(padding_word_ptr, 0)) + }) + ([yul::Statement::Expression(provider.ptr_store(db, dst.expr(), literal_expression!{ (len) }, dst_len_ty))]) + ([dst.ident()] := add(32, [dst.expr()])) + ([yul::Statement::Expression(provider.ptr_copy(db, src.expr(), dst.expr(), literal_expression!{(len)}, false, is_dst_storage))]) + } + }; + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_abi_encode_seq( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + value_tys: &[TypeId], + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let value_num = value_tys.len(); + let abi_tys: Vec<_> = value_tys + .iter() + .map(|ty| db.codegen_abi_type(ty.deref(db.upcast()))) + .collect(); + let dst = YulVariable::new("dst"); + let header_ptr = YulVariable::new("header_ptr"); + let enc_size = YulVariable::new("enc_size"); + let data_ptr = YulVariable::new("data_ptr"); + let values: Vec<_> = (0..value_num) + .map(|idx| YulVariable::new(format!("value{idx}"))) + .collect(); + + let total_header_size = + literal_expression! { (abi_tys.iter().fold(0, |acc, ty| acc + ty.header_size())) }; + let mut body = statements! { + (let [header_ptr.ident()] := [dst.expr()]) + ([enc_size.ident()] := [total_header_size]) + (let [data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + }; + + for i in 0..value_num { + let ty = value_tys[i]; + let abi_ty = &abi_tys[i]; + let value = &values[i]; + let stmts = if abi_ty.is_static() { + statements! { + (pop([provider.abi_encode(db, value.expr(), header_ptr.expr(), ty, is_dst_storage)])) + ([header_ptr.ident()] := add([header_ptr.expr()], [literal_expression!{ (abi_ty.header_size()) }])) + } + } else { + let header_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + statements! { + ([yul::Statement::Expression(provider.ptr_store(db, header_ptr.expr(), enc_size.expr(), header_ty))]) + ([enc_size.ident()] := add([provider.abi_encode(db, value.expr(), data_ptr.expr(), ty, is_dst_storage)], [enc_size.expr()])) + ([header_ptr.ident()] := add([header_ptr.expr()], 32)) + ([data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + } + }; + body.extend_from_slice(&stmts); + } + + let mut parameters = vec![dst.ident()]; + for value in values { + parameters.push(value.ident()); + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters, + returns: vec![enc_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_abi_decode( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + types: &[TypeId], + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let header_size = types + .iter() + .fold(0, |acc, ty| acc + db.codegen_abi_type(*ty).header_size()); + let src = YulVariable::new("$src"); + let enc_size = YulVariable::new("$enc_size"); + let header_ptr = YulVariable::new("header_ptr"); + let data_offset = YulVariable::new("data_offset"); + let tmp_offset = YulVariable::new("tmp_offset"); + let returns: Vec<_> = (0..types.len()) + .map(|i| YulVariable::new(format!("$ret{i}"))) + .collect(); + + let abi_enc_size = abi_enc_size(db, types); + let size_check = match abi_enc_size { + AbiEncodingSize::Static(size) => statements! { + (if (iszero((eq([enc_size.expr()], [literal_expression!{(size)}])))) + { [revert_with_invalid_abi_data(provider, db)] + }) + }, + AbiEncodingSize::Bounded { min, max } => statements! { + (if (or( + (lt([enc_size.expr()], [literal_expression!{(min)}])), + (gt([enc_size.expr()], [literal_expression!{(max)}])) + )) { + [revert_with_invalid_abi_data(provider, db)] + }) + }, + }; + + let mut body = statements! { + (let [header_ptr.ident()] := [src.expr()]) + (let [data_offset.ident()] := [literal_expression!{ (header_size) }]) + (let [tmp_offset.ident()] := 0) + }; + for i in 0..returns.len() { + let ret_value = &returns[i]; + let field_ty = types[i]; + let field_abi_ty = db.codegen_abi_type(field_ty.deref(db.upcast())); + if field_abi_ty.is_static() { + body.push(statement!{ [ret_value.ident()] := [provider.abi_decode_static(db, header_ptr.expr(), field_ty, abi_loc)] }); + } else { + let identifiers = identifiers! { + [ret_value.ident()] + [tmp_offset.ident()] + }; + body.push(yul::Statement::Assignment(yul::Assignment { + identifiers, + expression: provider.abi_decode_dynamic( + db, + expression! {add([src.expr()], [data_offset.expr()])}, + field_ty, + abi_loc, + ), + })); + body.push(statement! { ([data_offset.ident()] := add([data_offset.expr()], [tmp_offset.expr()])) }); + }; + + let field_header_size = literal_expression! { (field_abi_ty.header_size()) }; + body.push( + statement! { [header_ptr.ident()] := add([header_ptr.expr()], [field_header_size]) }, + ); + } + + let offset_check = match abi_enc_size { + AbiEncodingSize::Static(_) => vec![], + AbiEncodingSize::Bounded { .. } => statements! { + (if (iszero((eq([enc_size.expr()], [data_offset.expr()])))) { [revert_with_invalid_abi_data(provider, db)] }) + }, + }; + + let returns: Vec<_> = returns.iter().map(YulVariable::ident).collect(); + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [enc_size.ident()]) -> [returns...] { + [size_check...] + [body...] + [offset_check...] + } + }; + RuntimeFunction::from_statement(func_def) +} + +impl DefaultRuntimeProvider { + fn abi_decode_static( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + ty: TypeId, + abi_loc: AbiSrcLocation, + ) -> yul::Expression { + let ty = db.codegen_legalized_type(ty).deref(db.upcast()); + let abi_ty = db.codegen_abi_type(ty.deref(db.upcast())); + debug_assert!(abi_ty.is_static()); + + let func_name_postfix = match abi_loc { + AbiSrcLocation::CallData => "calldata", + AbiSrcLocation::Memory => "memory", + }; + + let args = vec![src]; + if ty.is_primitive(db.upcast()) { + let name = format! { + "$abi_decode_primitive_type_{}_from_{}", + ty.0, func_name_postfix, + }; + return self.create_then_call(&name, args, |provider| { + make_abi_decode_primitive_type(provider, db, &name, ty, abi_loc) + }); + } + + let name = format! { + "$abi_decode_static_aggregate_type_{}_from_{}", + ty.0, func_name_postfix, + }; + self.create_then_call(&name, args, |provider| { + make_abi_decode_static_aggregate_type(provider, db, &name, ty, abi_loc) + }) + } + + fn abi_decode_dynamic( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + ty: TypeId, + abi_loc: AbiSrcLocation, + ) -> yul::Expression { + let ty = db.codegen_legalized_type(ty).deref(db.upcast()); + let abi_ty = db.codegen_abi_type(ty.deref(db.upcast())); + debug_assert!(!abi_ty.is_static()); + + let func_name_postfix = match abi_loc { + AbiSrcLocation::CallData => "calldata", + AbiSrcLocation::Memory => "memory", + }; + + let mut args = vec![src]; + match abi_ty { + AbiType::String => { + let len = match &ty.data(db.upcast()).kind { + TypeKind::String(len) => *len, + _ => unreachable!(), + }; + args.push(literal_expression! {(len)}); + let name = format! {"$abi_decode_string_from_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_string_type(provider, db, &name, abi_loc) + }) + } + + AbiType::Bytes => { + let len = match &ty.data(db.upcast()).kind { + TypeKind::Array(ArrayDef { len, .. }) => *len, + _ => unreachable!(), + }; + args.push(literal_expression! {(len)}); + let name = format! {"$abi_decode_bytes_from_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_bytes_type(provider, db, &name, abi_loc) + }) + } + + AbiType::Array { .. } => { + let name = + format! {"$abi_decode_dynamic_array_{}_from_{}", ty.0, func_name_postfix}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_dynamic_elem_array_type(provider, db, &name, ty, abi_loc) + }) + } + + AbiType::Tuple(_) => { + let name = + format! {"$abi_decode_dynamic_aggregate_{}_from_{}", ty.0, func_name_postfix}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_dynamic_aggregate_type(provider, db, &name, ty, abi_loc) + }) + } + + _ => unreachable!(), + } + } +} + +fn make_abi_decode_primitive_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + debug_assert! {ty.is_primitive(db.upcast())} + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let ret = YulVariable::new("ret"); + + let decode = match abi_loc { + AbiSrcLocation::CallData => { + statement! { [ret.ident()] := calldataload([src.expr()]) } + } + AbiSrcLocation::Memory => { + statement! { [ret.ident()] := mload([src.expr()]) } + } + }; + + let ty_size_bits = ty.size_of(db.upcast(), SLOT_SIZE) * 8; + let validation = if ty_size_bits == 256 { + statements! {} + } else if ty.is_signed(db.upcast()) { + let shift_num = literal_expression! { ( ty_size_bits - 1) }; + let tmp1 = YulVariable::new("tmp1"); + let tmp2 = YulVariable::new("tmp2"); + statements! { + (let [tmp1.ident()] := iszero((shr([shift_num.clone()], [ret.expr()])))) + (let [tmp2.ident()] := iszero((shr([shift_num], (not([ret.expr()])))))) + (if (iszero((or([tmp1.expr()], [tmp2.expr()])))) { + [revert_with_invalid_abi_data(provider, db)] + }) + } + } else { + let shift_num = literal_expression! { ( ty_size_bits) }; + let tmp = YulVariable::new("tmp"); + statements! { + (let [tmp.ident()] := iszero((shr([shift_num], [ret.expr()])))) + (if (iszero([tmp.expr()])) { + [revert_with_invalid_abi_data(provider, db)] + }) + } + }; + + let func = function_definition! { + function [func_name.ident()]([src.ident()]) -> [ret.ident()] { + ([decode]) + [validation...] + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_static_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + debug_assert!(ty.is_aggregate(db.upcast())); + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let ret = YulVariable::new("ret"); + let field_data = YulVariable::new("field_data"); + let type_size = literal_expression! { (ty.size_of(db.upcast(), SLOT_SIZE)) }; + + let mut body = statements! { + (let [field_data.ident()] := 0) + ([ret.ident()] := [provider.alloc(db, type_size)]) + }; + + let field_num = ty.aggregate_field_num(db.upcast()); + for idx in 0..field_num { + let field_ty = ty.projection_ty_imm(db.upcast(), idx); + let field_ty_size = field_ty.size_of(db.upcast(), SLOT_SIZE); + body.push(statement! { [field_data.ident()] := [provider.abi_decode_static(db, src.expr(), field_ty, abi_loc)] }); + + let dst_offset = + literal_expression! { (ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE)) }; + let field_ty_ptr = make_ptr(db, field_ty, false); + if field_ty.is_primitive(db.upcast()) { + body.push(yul::Statement::Expression(provider.ptr_store( + db, + expression! {add([ret.expr()], [dst_offset])}, + field_data.expr(), + field_ty_ptr, + ))); + } else { + body.push(yul::Statement::Expression(provider.ptr_copy( + db, + field_data.expr(), + expression! {add([ret.expr()], [dst_offset])}, + literal_expression! { (field_ty_size) }, + false, + false, + ))); + } + + if idx < field_num - 1 { + let abi_field_ty = db.codegen_abi_type(field_ty); + let field_abi_ty_size = literal_expression! { (abi_field_ty.header_size()) }; + body.push(assignment! {[src.ident()] := add([src.expr()], [field_abi_ty_size])}); + } + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident()], + returns: vec![ret.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +fn make_abi_decode_string_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let max_len = YulVariable::new("max_len"); + let string_size = YulVariable::new("string_size"); + let dst_size = YulVariable::new("dst_size"); + let end_word = YulVariable::new("end_word"); + let end_word_ptr = YulVariable::new("end_word_ptr"); + let padding_size_bits = YulVariable::new("padding_size_bits"); + let primitive_ty_ptr = make_ptr(db, yul_primitive_type(db), false); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [max_len.ident()]) -> [(vec![decoded_data.ident(), decoded_size.ident()])...] { + (let string_len := [provider.abi_decode_static(db, src.expr(), primitive_ty_ptr, abi_loc)]) + (if (gt(string_len, [max_len.expr()])) { [revert_with_invalid_abi_data(provider, db)] } ) + (let [string_size.ident()] := add(string_len, 32)) + ([decoded_size.ident()] := mul((div((add([string_size.expr()], 31)), 32)), 32)) + (let [end_word_ptr.ident()] := sub((add([src.expr()], [decoded_size.expr()])), 32)) + (let [end_word.ident()] := [provider.abi_decode_static(db, end_word_ptr.expr(), primitive_ty_ptr, abi_loc)]) + (let [padding_size_bits.ident()] := mul((sub([decoded_size.expr()], [string_size.expr()])), 8)) + [(check_right_padding(provider, db, end_word.expr(), padding_size_bits.expr()))...] + (let [dst_size.ident()] := add([max_len.expr()], 32)) + ([decoded_data.ident()] := [provider.alloc(db, dst_size.expr())]) + ([ptr_copy_decode(provider, db, src.expr(), decoded_data.expr(), string_size.expr(), abi_loc)]) + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_bytes_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let max_len = YulVariable::new("max_len"); + let bytes_size = YulVariable::new("bytes_size"); + let end_word = YulVariable::new("end_word"); + let end_word_ptr = YulVariable::new("end_word_ptr"); + let padding_size_bits = YulVariable::new("padding_size_bits"); + let primitive_ty_ptr = make_ptr(db, yul_primitive_type(db), false); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [max_len.ident()]) -> [(vec![decoded_data.ident(),decoded_size.ident()])...] { + (let [bytes_size.ident()] := [provider.abi_decode_static(db, src.expr(), primitive_ty_ptr, abi_loc)]) + (if (iszero((eq([bytes_size.expr()], [max_len.expr()])))) { [revert_with_invalid_abi_data(provider, db)] } ) + ([src.ident()] := add([src.expr()], 32)) + (let padded_data_size := mul((div((add([bytes_size.expr()], 31)), 32)), 32)) + ([decoded_size.ident()] := add(padded_data_size, 32)) + (let [end_word_ptr.ident()] := sub((add([src.expr()], padded_data_size)), 32)) + (let [end_word.ident()] := [provider.abi_decode_static(db, end_word_ptr.expr(), primitive_ty_ptr, abi_loc)]) + (let [padding_size_bits.ident()] := mul((sub(padded_data_size, [bytes_size.expr()])), 8)) + [(check_right_padding(provider, db, end_word.expr(), padding_size_bits.expr()))...] + ([decoded_data.ident()] := [provider.alloc(db, max_len.expr())]) + ([ptr_copy_decode(provider, db, src.expr(), decoded_data.expr(), bytes_size.expr(), abi_loc)]) + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_dynamic_elem_array_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let deref_ty = legalized_ty.deref(db.upcast()); + let (elem_ty, len) = match &deref_ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) => (def.elem_ty, def.len), + _ => unreachable!(), + }; + let elem_ty_size = literal_expression! { (deref_ty.array_elem_size(db.upcast(), SLOT_SIZE)) }; + let total_header_size = literal_expression! { (32 * len) }; + let iter_count = literal_expression! { (len) }; + let ret_size = literal_expression! { (deref_ty.size_of(db.upcast(), SLOT_SIZE)) }; + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let header_ptr = YulVariable::new("header_ptr"); + let data_ptr = YulVariable::new("data_ptr"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let decoded_size_tmp = YulVariable::new("decoded_size_tmp"); + let ret_elem_ptr = YulVariable::new("ret_elem_ptr"); + let elem_data = YulVariable::new("elem_data"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()]) -> [decoded_data.ident()], [decoded_size.ident()] { + ([decoded_data.ident()] := [provider.alloc(db, ret_size)]) + ([decoded_size.ident()] := [total_header_size]) + (let [decoded_size_tmp.ident()] := 0) + (let [header_ptr.ident()] := [src.expr()]) + (let [data_ptr.ident()] := 0) + (let [elem_data.ident()] := 0) + (let [ret_elem_ptr.ident()] := [decoded_data.expr()]) + + (for {(let i := 0)} (lt(i, [iter_count])) {(i := (add(i, 1)))} + { + ([data_ptr.ident()] := add([src.expr()], [provider.abi_decode_static(db, header_ptr.expr(), yul_primitive_type(db), abi_loc)])) + ([assignment! {[elem_data.ident()], [decoded_size_tmp.ident()] := [provider.abi_decode_dynamic(db, data_ptr.expr(), elem_ty, abi_loc)] }]) + ([decoded_size.ident()] := add([decoded_size.expr()], [decoded_size_tmp.expr()])) + ([yul::Statement::Expression(provider.ptr_copy(db, elem_data.expr(), ret_elem_ptr.expr(), elem_ty_size.clone(), false, false))]) + ([header_ptr.ident()] := add([header_ptr.expr()], 32)) + ([ret_elem_ptr.ident()] := add([ret_elem_ptr.expr()], [elem_ty_size])) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_dynamic_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let deref_ty = legalized_ty.deref(db.upcast()); + let type_size = literal_expression! { (deref_ty.size_of(db.upcast(), SLOT_SIZE)) }; + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let header_ptr = YulVariable::new("header_ptr"); + let data_offset = YulVariable::new("data_offset"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let decoded_size_tmp = YulVariable::new("decoded_size_tmp"); + let ret_field_ptr = YulVariable::new("ret_field_ptr"); + let field_data = YulVariable::new("field_data"); + + let mut body = statements! { + ([decoded_data.ident()] := [provider.alloc(db, type_size)]) + ([decoded_size.ident()] := 0) + (let [decoded_size_tmp.ident()] := 0) + (let [header_ptr.ident()] := [src.expr()]) + (let [data_offset.ident()] := 0) + (let [field_data.ident()] := 0) + (let [ret_field_ptr.ident()] := 0) + }; + + for i in 0..deref_ty.aggregate_field_num(db.upcast()) { + let field_ty = deref_ty.projection_ty_imm(db.upcast(), i); + let field_size = field_ty.size_of(db.upcast(), SLOT_SIZE); + let field_abi_ty = db.codegen_abi_type(field_ty); + let field_offset = deref_ty.aggregate_elem_offset(db.upcast(), i, SLOT_SIZE); + + let decode_data = if field_abi_ty.is_static() { + statements! { + ([field_data.ident()] := [provider.abi_decode_static(db, header_ptr.expr(), field_ty, abi_loc)]) + ([decoded_size_tmp.ident()] := [literal_expression!{ (field_abi_ty.header_size()) }]) + } + } else { + statements! { + ([data_offset.ident()] := [provider.abi_decode_static(db, header_ptr.expr(), yul_primitive_type(db), abi_loc)]) + ([assignment! { + [field_data.ident()], [decoded_size_tmp.ident()] := + [provider.abi_decode_dynamic( + db, + expression!{ add([src.expr()], [data_offset.expr()]) }, + field_ty, + abi_loc + )] + }]) + ([decoded_size_tmp.ident()] := add([decoded_size_tmp.expr()], 32)) + } + }; + body.extend_from_slice(&decode_data); + body.push(assignment!{ [decoded_size.ident()] := add([decoded_size.expr()], [decoded_size_tmp.expr()]) }); + + body.push(assignment! { [ret_field_ptr.ident()] := add([decoded_data.expr()], [literal_expression!{ (field_offset) }])}); + let copy_to_ret = if field_ty.is_primitive(db.upcast()) { + let field_ptr_ty = make_ptr(db, field_ty, false); + yul::Statement::Expression(provider.ptr_store( + db, + ret_field_ptr.expr(), + field_data.expr(), + field_ptr_ty, + )) + } else { + yul::Statement::Expression(provider.ptr_copy( + db, + field_data.expr(), + ret_field_ptr.expr(), + literal_expression! { (field_size) }, + false, + false, + )) + }; + body.push(copy_to_ret); + + let header_size = literal_expression! { (field_abi_ty.header_size()) }; + body.push(statement! { + [header_ptr.ident()] := add([header_ptr.expr()], [header_size]) + }); + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident()], + returns: vec![decoded_data.ident(), decoded_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +enum AbiEncodingSize { + Static(usize), + Bounded { min: usize, max: usize }, +} + +fn abi_enc_size(db: &dyn CodegenDb, types: &[TypeId]) -> AbiEncodingSize { + let mut min = 0; + let mut max = 0; + for &ty in types { + let legalized_ty = db.codegen_legalized_type(ty); + min += db.codegen_abi_type_minimum_size(legalized_ty); + max += db.codegen_abi_type_maximum_size(legalized_ty); + } + + if min == max { + AbiEncodingSize::Static(min) + } else { + AbiEncodingSize::Bounded { min, max } + } +} + +fn revert_with_invalid_abi_data( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, +) -> yul::Statement { + const ERROR_INVALID_ABI_DATA: usize = 0x103; + let error_code = literal_expression! { (ERROR_INVALID_ABI_DATA) }; + error_revert_numeric(provider, db, error_code) +} + +fn check_right_padding( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, + val: yul::Expression, + size_bits: yul::Expression, +) -> Vec { + statements! { + (let bits_shifted := sub(256, [size_bits])) + (let is_ok := iszero((shl(bits_shifted, [val])))) + (if (iszero((is_ok))) { + [revert_with_invalid_abi_data(provider, db)] + }) + } +} + +fn ptr_copy_decode( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + len: yul::Expression, + loc: AbiSrcLocation, +) -> yul::Statement { + match loc { + AbiSrcLocation::CallData => { + statement! { calldatacopy([dst], [src], [len]) } + } + AbiSrcLocation::Memory => { + yul::Statement::Expression(provider.ptr_copy(db, src, dst, len, false, false)) + } + } +} + +fn ceil_32(len: usize) -> usize { + ((len + 31) / 32) * 32 +} diff --git a/crates/codegen2/src/yul/runtime/contract.rs b/crates/codegen2/src/yul/runtime/contract.rs new file mode 100644 index 0000000000..194679dcba --- /dev/null +++ b/crates/codegen2/src/yul/runtime/contract.rs @@ -0,0 +1,127 @@ +use crate::{ + yul::{runtime::AbiSrcLocation, YulVariable}, + CodegenDb, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::{FunctionId, Type, TypeKind}; + +use hir::hir_def::Contract; +use yultsur::*; + +pub(super) fn make_create( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + contract: Contract, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let contract_symbol = literal_expression! { + (format!(r#""{}""#, db.codegen_contract_deployer_symbol_name(contract))) + }; + + let size = YulVariable::new("size"); + let value = YulVariable::new("value"); + let func = function_definition! { + function [func_name.ident()]([value.ident()]) -> addr { + (let [size.ident()] := datasize([contract_symbol.clone()])) + (let mem_ptr := [provider.avail(db)]) + (let contract_ptr := dataoffset([contract_symbol])) + (datacopy(mem_ptr, contract_ptr, [size.expr()])) + (addr := create([value.expr()], mem_ptr, [size.expr()])) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_create2( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + contract: Contract, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let contract_symbol = literal_expression! { + (format!(r#""{}""#, db.codegen_contract_deployer_symbol_name(contract))) + }; + + let size = YulVariable::new("size"); + let value = YulVariable::new("value"); + let func = function_definition! { + function [func_name.ident()]([value.ident()], salt) -> addr { + (let [size.ident()] := datasize([contract_symbol.clone()])) + (let mem_ptr := [provider.avail(db)]) + (let contract_ptr := dataoffset([contract_symbol])) + (datacopy(mem_ptr, contract_ptr, [size.expr()])) + (addr := create2([value.expr()], mem_ptr, [size.expr()], salt)) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_external_call( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + function: FunctionId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let sig = db.codegen_legalized_signature(function); + let param_num = sig.params.len(); + + let mut args = Vec::with_capacity(param_num); + let mut arg_tys = Vec::with_capacity(param_num); + for param in &sig.params { + args.push(YulVariable::new(param.name.as_str())); + arg_tys.push(param.ty); + } + let ret_ty = sig.return_type; + + let func_addr = YulVariable::new("func_addr"); + let params: Vec<_> = args.iter().map(YulVariable::ident).collect(); + let params_expr: Vec<_> = args.iter().map(YulVariable::expr).collect(); + let input = YulVariable::new("input"); + let input_size = YulVariable::new("input_size"); + let output_size = YulVariable::new("output_size"); + let output = YulVariable::new("output"); + + let func_selector = literal_expression! { (format!{"0x{}", db.codegen_abi_function(function).selector().hex()}) }; + let selector_ty = db.mir_intern_type(Type::new(TypeKind::U32, None).into()); + + let mut body = statements! { + (let [input.ident()] := [provider.avail(db)]) + [yul::Statement::Expression(provider.ptr_store(db, input.expr(), func_selector, selector_ty.make_mptr(db.upcast())))] + (let [input_size.ident()] := add(4, [provider.abi_encode_seq(db, ¶ms_expr, expression!{ add([input.expr()], 4) }, &arg_tys, false)])) + (let [output.ident()] := add([provider.avail(db)], [input_size.expr()])) + (let success := call((gas()), [func_addr.expr()], 0, [input.expr()], [input_size.expr()], 0, 0)) + (let [output_size.ident()] := returndatasize()) + (returndatacopy([output.expr()], 0, [output_size.expr()])) + (if (iszero(success)) { + (revert([output.expr()], [output_size.expr()])) + }) + }; + let func = if let Some(ret_ty) = ret_ty { + let ret = YulVariable::new("$ret"); + body.push( + statement!{ + [ret.ident()] := [provider.abi_decode(db, output.expr(), output_size.expr(), &[ret_ty], AbiSrcLocation::Memory)] + } + ); + function_definition! { + function [func_name.ident()]([func_addr.ident()], [params...]) -> [ret.ident()] { + [body...] + } + } + } else { + function_definition! { + function [func_name.ident()]([func_addr.ident()], [params...]) { + [body...] + } + } + }; + + RuntimeFunction::from_statement(func) +} diff --git a/crates/codegen2/src/yul/runtime/data.rs b/crates/codegen2/src/yul/runtime/data.rs new file mode 100644 index 0000000000..85eccd5704 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/data.rs @@ -0,0 +1,461 @@ +use crate::{ + yul::{ + runtime::{make_ptr, BitMask}, + slot_size::{yul_primitive_type, SLOT_SIZE}, + YulVariable, + }, + CodegenDb, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::{types::TupleDef, Type, TypeId, TypeKind}; + +use yultsur::*; + +const HASH_SCRATCH_SPACE_START: usize = 0x00; +const HASH_SCRATCH_SPACE_SIZE: usize = 64; +const FREE_MEMORY_ADDRESS_STORE: usize = HASH_SCRATCH_SPACE_START + HASH_SCRATCH_SPACE_SIZE; +const FREE_MEMORY_START: usize = FREE_MEMORY_ADDRESS_STORE + 32; + +pub(super) fn make_alloc(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let free_address_ptr = literal_expression! {(FREE_MEMORY_ADDRESS_STORE)}; + let free_memory_start = literal_expression! {(FREE_MEMORY_START)}; + let func = function_definition! { + function [func_name.ident()](size) -> ptr { + (ptr := mload([free_address_ptr.clone()])) + (if (eq(ptr, 0x00)) { (ptr := [free_memory_start]) }) + (mstore([free_address_ptr], (add(ptr, size)))) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_avail(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let free_address_ptr = literal_expression! {(FREE_MEMORY_ADDRESS_STORE)}; + let free_memory_start = literal_expression! {(FREE_MEMORY_START)}; + let func = function_definition! { + function [func_name.ident()]() -> ptr { + (ptr := mload([free_address_ptr])) + (if (eq(ptr, 0x00)) { (ptr := [free_memory_start]) }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mcopym(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + (let iter_count := div([size.expr()], 32)) + (let original_src := [src.expr()]) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (mstore([dst.expr()], (mload([src.expr()])))) + ([src.ident()] := add([src.expr()], 32)) + ([dst.ident()] := add([dst.expr()], 32)) + }) + + (let rem := sub([size.expr()], (sub([src.expr()], original_src)))) + (if (gt(rem, 0)) { + (let rem_bits := mul(rem, 8)) + (let dst_mask := sub((shl((sub(256, rem_bits)), 1)), 1)) + (let src_mask := not(dst_mask)) + (let src_value := and((mload([src.expr()])), src_mask)) + (let dst_value := and((mload([dst.expr()])), dst_mask)) + (mstore([dst.expr()], (or(src_value, dst_value)))) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mcopys(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + ([dst.ident()] := div([dst.expr()], 32)) + (let iter_count := div([size.expr()], 32)) + (let original_src := [src.expr()]) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (sstore([dst.expr()], (mload([src.expr()])))) + ([src.ident()] := add([src.expr()], 32)) + ([dst.ident()] := add([dst.expr()], 1)) + }) + + (let rem := sub([size.expr()], (sub([src.expr()], original_src)))) + (if (gt(rem, 0)) { + (let rem_bits := mul(rem, 8)) + (let dst_mask := sub((shl((sub(256, rem_bits)), 1)), 1)) + (let src_mask := not(dst_mask)) + (let src_value := and((mload([src.expr()])), src_mask)) + (let dst_value := and((sload([dst.expr()])), dst_mask)) + (sstore([dst.expr()], (or(src_value, dst_value)))) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_scopym(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + ([src.ident()] := div([src.expr()], 32)) + (let iter_count := div([size.expr()], 32)) + (let original_dst := [dst.expr()]) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (mstore([dst.expr()], (sload([src.expr()])))) + ([src.ident()] := add([src.expr()], 1)) + ([dst.ident()] := add([dst.expr()], 32)) + }) + + (let rem := sub([size.expr()], (sub([dst.expr()], original_dst)))) + (if (gt(rem, 0)) { + (let rem_bits := mul(rem, 8)) + (let dst_mask := sub((shl((sub(256, rem_bits)), 1)), 1)) + (let src_mask := not(dst_mask)) + (let src_value := and((sload([src.expr()])), src_mask)) + (let dst_value := and((mload([dst.expr()])), dst_mask)) + (mstore([dst.expr()], (or(src_value, dst_value)))) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_scopys(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + ([src.ident()] := div([src.expr()], 32)) + ([dst.ident()] := div([dst.expr()], 32)) + (let iter_count := div((add([size.expr()], 31)), 32)) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (sstore([dst.expr()], (sload([src.expr()])))) + ([src.ident()] := add([src.expr()], 1)) + ([dst.ident()] := add([dst.expr()], 1)) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_sptr_store(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, value, size_bits) { + (let rem_bits := mul((mod(ptr, 32)), 8)) + (let shift_bits := sub(256, (add(rem_bits, size_bits)))) + (let mask := (shl(shift_bits, (sub((shl(size_bits, 1)), 1))))) + (let inv_mask := not(mask)) + (let slot := div(ptr, 32)) + (let new_value := or((and((sload(slot)), inv_mask)), (and((shl(shift_bits, value)), mask)))) + (sstore(slot, new_value)) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mptr_store(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, value, shift_num, mask) { + (value := shl(shift_num, value)) + (let ptr_value := and((mload(ptr)), mask)) + (value := or(value, ptr_value)) + (mstore(ptr, value)) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_sptr_load(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, size_bits) -> ret { + (let rem_bits := mul((mod(ptr, 32)), 8)) + (let shift_num := sub(256, (add(rem_bits, size_bits)))) + (let slot := div(ptr, 32)) + (ret := shr(shift_num, (sload(slot)))) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mptr_load(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, shift_num) -> ret { + (ret := shr(shift_num, (mload(ptr)))) + } + }; + + RuntimeFunction::from_statement(func) +} + +// TODO: We can optimize aggregate initialization by combining multiple +// `ptr_store` operations into single `ptr_store` operation. +pub(super) fn make_aggregate_init( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + arg_tys: Vec, +) -> RuntimeFunction { + debug_assert!(legalized_ty.is_ptr(db.upcast())); + let is_sptr = legalized_ty.is_sptr(db.upcast()); + let inner_ty = legalized_ty.deref(db.upcast()); + let ptr = YulVariable::new("ptr"); + let field_num = inner_ty.aggregate_field_num(db.upcast()); + + let iter_field_args = || (0..field_num).map(|i| YulVariable::new(format! {"arg{i}"})); + + let mut body = vec![]; + for (idx, field_arg) in iter_field_args().enumerate() { + let field_arg_ty = arg_tys[idx]; + let field_ty = inner_ty + .projection_ty_imm(db.upcast(), idx) + .deref(db.upcast()); + let field_ty_size = field_ty.size_of(db.upcast(), SLOT_SIZE); + let field_ptr_ty = make_ptr(db, field_ty, is_sptr); + let field_offset = + literal_expression! {(inner_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE))}; + + let field_ptr = expression! { add([ptr.expr()], [field_offset] )}; + let copy_expr = if field_ty.is_aggregate(db.upcast()) || field_ty.is_string(db.upcast()) { + // Call ptr copy function if field type is aggregate. + debug_assert!(field_arg_ty.is_ptr(db.upcast())); + provider.ptr_copy( + db, + field_arg.expr(), + field_ptr, + literal_expression! {(field_ty_size)}, + field_arg_ty.is_sptr(db.upcast()), + is_sptr, + ) + } else { + // Call store function if field type is not aggregate. + provider.ptr_store(db, field_ptr, field_arg.expr(), field_ptr_ty) + }; + body.push(yul::Statement::Expression(copy_expr)); + } + + let func_name = identifier! {(func_name)}; + let parameters = std::iter::once(ptr) + .chain(iter_field_args()) + .map(|var| var.ident()) + .collect(); + let func_def = yul::FunctionDefinition { + name: func_name, + parameters, + returns: vec![], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_enum_init( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + arg_tys: Vec, +) -> RuntimeFunction { + debug_assert!(arg_tys.len() > 1); + + let func_name = YulVariable::new(func_name); + let is_sptr = legalized_ty.is_sptr(db.upcast()); + let ptr = YulVariable::new("ptr"); + let disc = YulVariable::new("disc"); + let disc_ty = arg_tys[0]; + let enum_data = || (0..arg_tys.len() - 1).map(|i| YulVariable::new(format! {"arg{i}"})); + + let tuple_def = TupleDef { + items: arg_tys + .iter() + .map(|ty| ty.deref(db.upcast())) + .skip(1) + .collect(), + }; + let tuple_ty = db.mir_intern_type( + Type { + kind: TypeKind::Tuple(tuple_def), + analyzer_ty: None, + } + .into(), + ); + let data_ptr_ty = make_ptr(db, tuple_ty, is_sptr); + let data_offset = legalized_ty + .deref(db.upcast()) + .enum_data_offset(db.upcast(), SLOT_SIZE); + let enum_data_init = statements! { + [statement! {[ptr.ident()] := add([ptr.expr()], [literal_expression!{(data_offset)}])}] + [yul::Statement::Expression(provider.aggregate_init( + db, + ptr.expr(), + enum_data().map(|arg| arg.expr()).collect(), + data_ptr_ty, arg_tys.iter().skip(1).copied().collect()))] + }; + + let enum_data_args: Vec<_> = enum_data().map(|var| var.ident()).collect(); + let func_def = function_definition! { + function [func_name.ident()]([ptr.ident()], [disc.ident()], [enum_data_args...]) { + [yul::Statement::Expression(provider.ptr_store(db, ptr.expr(), disc.expr(), make_ptr(db, disc_ty, is_sptr)))] + [enum_data_init...] + } + }; + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_string_copy( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + data: &str, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let dst_ptr = YulVariable::new("dst_ptr"); + let symbol_name = literal_expression! { (format!(r#""{}""#, db.codegen_constant_string_symbol_name(data.to_string()))) }; + + let func = if is_dst_storage { + let tmp_ptr = YulVariable::new("tmp_ptr"); + let data_size = YulVariable::new("data_size"); + function_definition! { + function [func_name.ident()]([dst_ptr.ident()]) { + (let [tmp_ptr.ident()] := [provider.avail(db)]) + (let data_offset := dataoffset([symbol_name.clone()])) + (let [data_size.ident()] := datasize([symbol_name])) + (let len_slot := div([dst_ptr.expr()], 32)) + (sstore(len_slot, [data_size.expr()])) + (datacopy([tmp_ptr.expr()], data_offset, [data_size.expr()])) + ([dst_ptr.ident()] := add([dst_ptr.expr()], 32)) + ([yul::Statement::Expression( + provider.ptr_copy(db, tmp_ptr.expr(), dst_ptr.expr(), data_size.expr(), false, true)) + ]) + } + } + } else { + function_definition! { + function [func_name.ident()]([dst_ptr.ident()]) { + (let data_offset := dataoffset([symbol_name.clone()])) + (let data_size := datasize([symbol_name])) + (mstore([dst_ptr.expr()], data_size)) + ([dst_ptr.ident()] := add([dst_ptr.expr()], 32)) + (datacopy([dst_ptr.expr()], data_offset, data_size)) + } + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_string_construct( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + data: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let ptr_size = YulVariable::new("ptr_size"); + let string_ptr = YulVariable::new("string_ptr"); + + let func = function_definition! { + function [func_name.ident()]([ptr_size.ident()]) -> [string_ptr.ident()] { + ([string_ptr.ident()] := [provider.alloc(db, ptr_size.expr())]) + ([yul::Statement::Expression(provider.string_copy(db, string_ptr.expr(), data, false))]) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_map_value_ptr_with_primitive_key( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + key_ty: TypeId, +) -> RuntimeFunction { + debug_assert!(key_ty.is_primitive(db.upcast())); + let scratch_space = literal_expression! {(HASH_SCRATCH_SPACE_START)}; + let scratch_size = literal_expression! {(HASH_SCRATCH_SPACE_SIZE)}; + let func_name = YulVariable::new(func_name); + let map_ptr = YulVariable::new("map_ptr"); + let key = YulVariable::new("key"); + let yul_primitive_type = yul_primitive_type(db); + + let mask = BitMask::new(1).not(); + + let func = function_definition! { + function [func_name.ident()]([map_ptr.ident()], [key.ident()]) -> ret { + ([yul::Statement::Expression(provider.ptr_store( + db, + scratch_space.clone(), + key.expr(), + yul_primitive_type.make_mptr(db.upcast()), + ))]) + ([yul::Statement::Expression(provider.ptr_store( + db, + expression!(add([scratch_space.clone()], 32)), + map_ptr.expr(), + yul_primitive_type.make_mptr(db.upcast()), + ))]) + (ret := and([mask.as_expr()], (keccak256([scratch_space], [scratch_size])))) + }}; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_map_value_ptr_with_ptr_key( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + key_ty: TypeId, +) -> RuntimeFunction { + debug_assert!(key_ty.is_ptr(db.upcast())); + + let func_name = YulVariable::new(func_name); + let size = literal_expression! {(key_ty.deref(db.upcast()).size_of(db.upcast(), SLOT_SIZE))}; + let map_ptr = YulVariable::new("map_ptr"); + let key = YulVariable::new("key"); + + let key_hash = expression! { keccak256([key.expr()], [size]) }; + let u256_ty = yul_primitive_type(db); + let def = function_definition! { + function [func_name.ident()]([map_ptr.ident()], [key.ident()]) -> ret { + (ret := [provider.map_value_ptr(db, map_ptr.expr(), key_hash, u256_ty)]) + } + }; + RuntimeFunction::from_statement(def) +} diff --git a/crates/codegen2/src/yul/runtime/emit.rs b/crates/codegen2/src/yul/runtime/emit.rs new file mode 100644 index 0000000000..cfe0920ee4 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/emit.rs @@ -0,0 +1,74 @@ +use crate::{ + yul::{runtime::make_ptr, slot_size::SLOT_SIZE, YulVariable}, + CodegenDb, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::TypeId; + +use yultsur::*; + +pub(super) fn make_emit( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let event_ptr = YulVariable::new("event_ptr"); + let deref_ty = legalized_ty.deref(db.upcast()); + + let abi = db.codegen_abi_event(deref_ty); + let mut topics = vec![literal_expression! {(format!("0x{}", abi.signature().hash_hex()))}]; + for (idx, field) in abi.inputs.iter().enumerate() { + if !field.indexed { + continue; + } + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let offset = + literal_expression! {(deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE))}; + let elem_ptr = expression! { add([event_ptr.expr()], [offset]) }; + let topic = if field_ty.is_aggregate(db.upcast()) { + todo!() + } else { + let topic = provider.ptr_load( + db, + elem_ptr, + make_ptr(db, field_ty, legalized_ty.is_sptr(db.upcast())), + ); + provider.primitive_cast(db, topic, field_ty) + }; + + topics.push(topic) + } + + let mut event_data_tys = vec![]; + let mut event_data_values = vec![]; + for (idx, field) in abi.inputs.iter().enumerate() { + if field.indexed { + continue; + } + + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let field_offset = + literal_expression! { (deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE)) }; + event_data_tys.push(make_ptr(db, field_ty, legalized_ty.is_sptr(db.upcast()))); + event_data_values.push(expression! { add([event_ptr.expr()], [field_offset]) }); + } + + debug_assert!(topics.len() < 5); + let log_func = identifier! { (format!("log{}", topics.len()))}; + + let event_data_ptr = YulVariable::new("event_data_ptr"); + let event_enc_size = YulVariable::new("event_enc_size"); + let func = function_definition! { + function [func_name.ident()]([event_ptr.ident()]) { + (let [event_data_ptr.ident()] := [provider.avail(db)]) + (let [event_enc_size.ident()] := [provider.abi_encode_seq(db, &event_data_values, event_data_ptr.expr(), &event_data_tys, false )]) + ([log_func]([event_data_ptr.expr()], [event_enc_size.expr()], [topics...])) + } + }; + + RuntimeFunction::from_statement(func) +} diff --git a/crates/codegen2/src/yul/runtime/mod.rs b/crates/codegen2/src/yul/runtime/mod.rs new file mode 100644 index 0000000000..e5b7856d91 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/mod.rs @@ -0,0 +1,828 @@ +mod abi; +mod contract; +mod data; +mod emit; +mod revert; +mod safe_math; + +use std::fmt::Write; + +use fe_abi::types::AbiType; +use fe_mir::ir::{types::ArrayDef, FunctionId, TypeId, TypeKind}; +use hir::hir_def::Contract; +use indexmap::IndexMap; +use yultsur::*; + +use num_bigint::BigInt; + +use crate::{yul::slot_size::SLOT_SIZE, CodegenDb}; + +use super::slot_size::yul_primitive_type; + +pub trait RuntimeProvider { + fn collect_definitions(&self) -> Vec; + + fn alloc(&mut self, db: &dyn CodegenDb, size: yul::Expression) -> yul::Expression; + + fn avail(&mut self, db: &dyn CodegenDb) -> yul::Expression; + + fn create( + &mut self, + db: &dyn CodegenDb, + contract: Contract, + value: yul::Expression, + ) -> yul::Expression; + + fn create2( + &mut self, + db: &dyn CodegenDb, + contract: Contract, + value: yul::Expression, + salt: yul::Expression, + ) -> yul::Expression; + + fn emit( + &mut self, + db: &dyn CodegenDb, + event: yul::Expression, + event_ty: TypeId, + ) -> yul::Expression; + + fn revert( + &mut self, + db: &dyn CodegenDb, + arg: Option, + arg_name: &str, + arg_ty: TypeId, + ) -> yul::Expression; + + fn external_call( + &mut self, + db: &dyn CodegenDb, + function: FunctionId, + args: Vec, + ) -> yul::Expression; + + fn map_value_ptr( + &mut self, + db: &dyn CodegenDb, + map_ptr: yul::Expression, + key: yul::Expression, + key_ty: TypeId, + ) -> yul::Expression; + + fn aggregate_init( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + args: Vec, + ptr_ty: TypeId, + arg_tys: Vec, + ) -> yul::Expression; + + fn string_copy( + &mut self, + db: &dyn CodegenDb, + dst: yul::Expression, + data: &str, + is_dst_storage: bool, + ) -> yul::Expression; + + fn string_construct( + &mut self, + db: &dyn CodegenDb, + data: &str, + string_len: usize, + ) -> yul::Expression; + + /// Copy data from `src` to `dst`. + /// NOTE: src and dst must be aligned by 32 when a ptr is storage ptr. + fn ptr_copy( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + size: yul::Expression, + is_src_storage: bool, + is_dst_storage: bool, + ) -> yul::Expression; + + fn ptr_store( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + imm: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression; + + fn ptr_load( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression; + + fn abi_encode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + src_ty: TypeId, + is_dst_storage: bool, + ) -> yul::Expression; + + fn abi_encode_seq( + &mut self, + db: &dyn CodegenDb, + src: &[yul::Expression], + dst: yul::Expression, + src_tys: &[TypeId], + is_dst_storage: bool, + ) -> yul::Expression; + + fn abi_decode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + size: yul::Expression, + types: &[TypeId], + abi_loc: AbiSrcLocation, + ) -> yul::Expression; + + fn primitive_cast( + &mut self, + db: &dyn CodegenDb, + value: yul::Expression, + from_ty: TypeId, + ) -> yul::Expression { + debug_assert!(from_ty.is_primitive(db.upcast())); + let from_size = from_ty.size_of(db.upcast(), SLOT_SIZE); + + if from_ty.is_signed(db.upcast()) { + let significant = literal_expression! {(from_size-1)}; + expression! { signextend([significant], [value]) } + } else { + let mask = BitMask::new(from_size); + expression! { and([value], [mask.as_expr()]) } + } + } + + // TODO: The all functions below will be reimplemented in `std`. + fn safe_add( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_sub( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_mul( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_div( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_mod( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_pow( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; +} + +#[derive(Clone, Copy, Debug)] +pub enum AbiSrcLocation { + CallData, + Memory, +} + +#[derive(Debug, Default)] +pub struct DefaultRuntimeProvider { + functions: IndexMap, +} + +impl DefaultRuntimeProvider { + fn create_then_call( + &mut self, + name: &str, + args: Vec, + func_builder: F, + ) -> yul::Expression + where + F: FnOnce(&mut Self) -> RuntimeFunction, + { + if let Some(func) = self.functions.get(name) { + func.call(args) + } else { + let func = func_builder(self); + let result = func.call(args); + self.functions.insert(name.to_string(), func); + result + } + } +} + +impl RuntimeProvider for DefaultRuntimeProvider { + fn collect_definitions(&self) -> Vec { + self.functions + .values() + .map(RuntimeFunction::definition) + .collect() + } + + fn alloc(&mut self, _db: &dyn CodegenDb, bytes: yul::Expression) -> yul::Expression { + let name = "$alloc"; + let arg = vec![bytes]; + self.create_then_call(name, arg, |_| data::make_alloc(name)) + } + + fn avail(&mut self, _db: &dyn CodegenDb) -> yul::Expression { + let name = "$avail"; + let arg = vec![]; + self.create_then_call(name, arg, |_| data::make_avail(name)) + } + + fn create( + &mut self, + db: &dyn CodegenDb, + contract: Contract, + value: yul::Expression, + ) -> yul::Expression { + let name = format!("$create_{}", db.codegen_contract_symbol_name(contract)); + let arg = vec![value]; + self.create_then_call(&name, arg, |provider| { + contract::make_create(provider, db, &name, contract) + }) + } + + fn create2( + &mut self, + db: &dyn CodegenDb, + contract: Contract, + value: yul::Expression, + salt: yul::Expression, + ) -> yul::Expression { + let name = format!("$create2_{}", db.codegen_contract_symbol_name(contract)); + let arg = vec![value, salt]; + self.create_then_call(&name, arg, |provider| { + contract::make_create2(provider, db, &name, contract) + }) + } + + fn emit( + &mut self, + db: &dyn CodegenDb, + event: yul::Expression, + event_ty: TypeId, + ) -> yul::Expression { + let name = format!("$emit_{}", event_ty.0); + let legalized_ty = db.codegen_legalized_type(event_ty); + self.create_then_call(&name, vec![event], |provider| { + emit::make_emit(provider, db, &name, legalized_ty) + }) + } + + fn revert( + &mut self, + db: &dyn CodegenDb, + arg: Option, + arg_name: &str, + arg_ty: TypeId, + ) -> yul::Expression { + let func_name = format! {"$revert_{}_{}", arg_name, arg_ty.0}; + let args = match arg { + Some(arg) => vec![arg], + None => vec![], + }; + self.create_then_call(&func_name, args, |provider| { + revert::make_revert(provider, db, &func_name, arg_name, arg_ty) + }) + } + + fn external_call( + &mut self, + db: &dyn CodegenDb, + function: FunctionId, + args: Vec, + ) -> yul::Expression { + let name = format!( + "$call_external__{}", + db.codegen_function_symbol_name(function) + ); + self.create_then_call(&name, args, |provider| { + contract::make_external_call(provider, db, &name, function) + }) + } + + fn map_value_ptr( + &mut self, + db: &dyn CodegenDb, + map_ptr: yul::Expression, + key: yul::Expression, + key_ty: TypeId, + ) -> yul::Expression { + if key_ty.is_primitive(db.upcast()) { + let name = "$map_value_ptr_with_primitive_key"; + self.create_then_call(name, vec![map_ptr, key], |provider| { + data::make_map_value_ptr_with_primitive_key(provider, db, name, key_ty) + }) + } else if key_ty.is_mptr(db.upcast()) { + let name = "$map_value_ptr_with_ptr_key"; + self.create_then_call(name, vec![map_ptr, key], |provider| { + data::make_map_value_ptr_with_ptr_key(provider, db, name, key_ty) + }) + } else { + unreachable!() + } + } + + fn aggregate_init( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + mut args: Vec, + ptr_ty: TypeId, + arg_tys: Vec, + ) -> yul::Expression { + debug_assert!(ptr_ty.is_ptr(db.upcast())); + let deref_ty = ptr_ty.deref(db.upcast()); + + // Handle unit enum variant. + if args.len() == 1 && deref_ty.is_enum(db.upcast()) { + let tag = args.pop().unwrap(); + let tag_ty = arg_tys[0]; + let is_sptr = ptr_ty.is_sptr(db.upcast()); + return self.ptr_store(db, ptr, tag, make_ptr(db, tag_ty, is_sptr)); + } + + let deref_ty = ptr_ty.deref(db.upcast()); + let args = std::iter::once(ptr).chain(args).collect(); + let legalized_ty = db.codegen_legalized_type(ptr_ty); + if deref_ty.is_enum(db.upcast()) { + let mut name = format!("enum_init_{}", ptr_ty.0); + for ty in &arg_tys { + write!(&mut name, "_{}", ty.0).unwrap(); + } + self.create_then_call(&name, args, |provider| { + data::make_enum_init(provider, db, &name, legalized_ty, arg_tys) + }) + } else { + let name = format!("$aggregate_init_{}", ptr_ty.0); + self.create_then_call(&name, args, |provider| { + data::make_aggregate_init(provider, db, &name, legalized_ty, arg_tys) + }) + } + } + + fn string_copy( + &mut self, + db: &dyn CodegenDb, + dst: yul::Expression, + data: &str, + is_dst_storage: bool, + ) -> yul::Expression { + debug_assert!(data.is_ascii()); + let symbol_name = db.codegen_constant_string_symbol_name(data.to_string()); + + let name = if is_dst_storage { + format!("$string_copy_{symbol_name}_storage") + } else { + format!("$string_copy_{symbol_name}_memory") + }; + + self.create_then_call(&name, vec![dst], |provider| { + data::make_string_copy(provider, db, &name, data, is_dst_storage) + }) + } + + fn string_construct( + &mut self, + db: &dyn CodegenDb, + data: &str, + string_len: usize, + ) -> yul::Expression { + debug_assert!(data.is_ascii()); + debug_assert!(string_len >= data.len()); + let symbol_name = db.codegen_constant_string_symbol_name(data.to_string()); + + let name = format!("$string_construct_{symbol_name}"); + let arg = literal_expression!((32 + string_len)); + self.create_then_call(&name, vec![arg], |provider| { + data::make_string_construct(provider, db, &name, data) + }) + } + + fn ptr_copy( + &mut self, + _db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + size: yul::Expression, + is_src_storage: bool, + is_dst_storage: bool, + ) -> yul::Expression { + let args = vec![src, dst, size]; + match (is_src_storage, is_dst_storage) { + (true, true) => { + let name = "scopys"; + self.create_then_call(name, args, |_| data::make_scopys(name)) + } + (true, false) => { + let name = "scopym"; + self.create_then_call(name, args, |_| data::make_scopym(name)) + } + (false, true) => { + let name = "mcopys"; + self.create_then_call(name, args, |_| data::make_mcopys(name)) + } + (false, false) => { + let name = "mcopym"; + self.create_then_call(name, args, |_| data::make_mcopym(name)) + } + } + } + + fn ptr_store( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + imm: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression { + debug_assert!(ptr_ty.is_ptr(db.upcast())); + let size = ptr_ty.deref(db.upcast()).size_of(db.upcast(), SLOT_SIZE); + debug_assert!(size <= 32); + + let size_bits = size * 8; + if ptr_ty.is_sptr(db.upcast()) { + let name = "$sptr_store"; + let args = vec![ptr, imm, literal_expression! {(size_bits)}]; + self.create_then_call(name, args, |_| data::make_sptr_store(name)) + } else if ptr_ty.is_mptr(db.upcast()) { + let name = "$mptr_store"; + let shift_num = literal_expression! {(256 - size_bits)}; + let mask = BitMask::new(32 - size); + let args = vec![ptr, imm, shift_num, mask.as_expr()]; + self.create_then_call(name, args, |_| data::make_mptr_store(name)) + } else { + unreachable!() + } + } + + fn ptr_load( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression { + debug_assert!(ptr_ty.is_ptr(db.upcast())); + let size = ptr_ty.deref(db.upcast()).size_of(db.upcast(), SLOT_SIZE); + debug_assert!(size <= 32); + + let size_bits = size * 8; + if ptr_ty.is_sptr(db.upcast()) { + let name = "$sptr_load"; + let args = vec![ptr, literal_expression! {(size_bits)}]; + self.create_then_call(name, args, |_| data::make_sptr_load(name)) + } else if ptr_ty.is_mptr(db.upcast()) { + let name = "$mptr_load"; + let shift_num = literal_expression! {(256 - size_bits)}; + let args = vec![ptr, shift_num]; + self.create_then_call(name, args, |_| data::make_mptr_load(name)) + } else { + unreachable!() + } + } + + fn abi_encode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + src_ty: TypeId, + is_dst_storage: bool, + ) -> yul::Expression { + let legalized_ty = db.codegen_legalized_type(src_ty); + let args = vec![src.clone(), dst.clone()]; + + let func_name_postfix = if is_dst_storage { "storage" } else { "memory" }; + + if legalized_ty.is_primitive(db.upcast()) { + let name = format!( + "$abi_encode_primitive_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + return self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_primitive_type( + provider, + db, + &name, + legalized_ty, + is_dst_storage, + ) + }); + } + + let deref_ty = legalized_ty.deref(db.upcast()); + let abi_ty = db.codegen_abi_type(deref_ty); + match abi_ty { + AbiType::UInt(_) | AbiType::Int(_) | AbiType::Bool | AbiType::Address => { + let value = self.ptr_load(db, src, src_ty); + let extended_value = self.primitive_cast(db, value, deref_ty); + self.abi_encode(db, extended_value, dst, deref_ty, is_dst_storage) + } + AbiType::Array { elem_ty, .. } => { + if elem_ty.is_static() { + let name = format!( + "$abi_encode_static_array_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_static_array_type(provider, db, &name, legalized_ty) + }) + } else { + let name = format! { + "$abi_encode_dynamic_array_type{}_to_{}", src_ty.0, func_name_postfix + }; + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_dynamic_array_type(provider, db, &name, legalized_ty) + }) + } + } + AbiType::Tuple(_) => { + if abi_ty.is_static() { + let name = format!( + "$abi_encode_static_aggregate_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_static_aggregate_type( + provider, + db, + &name, + legalized_ty, + is_dst_storage, + ) + }) + } else { + let name = format!( + "$abi_encode_dynamic_aggregate_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_dynamic_aggregate_type( + provider, + db, + &name, + legalized_ty, + is_dst_storage, + ) + }) + } + } + AbiType::Bytes => { + let len = match &deref_ty.data(db.upcast()).kind { + TypeKind::Array(ArrayDef { len, .. }) => *len, + _ => unreachable!(), + }; + let name = format! {"$abi_encode_bytes{len}_type_to_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_bytes_type(provider, db, &name, len, is_dst_storage) + }) + } + AbiType::String => { + let name = format! {"$abi_encode_string_type_to_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_string_type(provider, db, &name, is_dst_storage) + }) + } + AbiType::Function => unreachable!(), + } + } + + fn abi_encode_seq( + &mut self, + db: &dyn CodegenDb, + src: &[yul::Expression], + dst: yul::Expression, + src_tys: &[TypeId], + is_dst_storage: bool, + ) -> yul::Expression { + let mut name = "$abi_encode_value_seq".to_string(); + for ty in src_tys { + write!(&mut name, "_{}", ty.0).unwrap(); + } + + let mut args = vec![dst]; + args.extend(src.iter().cloned()); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_seq(provider, db, &name, src_tys, is_dst_storage) + }) + } + + fn abi_decode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + size: yul::Expression, + types: &[TypeId], + abi_loc: AbiSrcLocation, + ) -> yul::Expression { + let mut name = "$abi_decode".to_string(); + for ty in types { + write!(name, "_{}", ty.0).unwrap(); + } + + match abi_loc { + AbiSrcLocation::CallData => write!(name, "_from_calldata").unwrap(), + AbiSrcLocation::Memory => write!(name, "_from_memory").unwrap(), + }; + + self.create_then_call(&name, vec![src, size], |provider| { + abi::make_abi_decode(provider, db, &name, types, abi_loc) + }) + } + + fn safe_add( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_add(self, db, lhs, rhs, ty) + } + + fn safe_sub( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_sub(self, db, lhs, rhs, ty) + } + + fn safe_mul( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_mul(self, db, lhs, rhs, ty) + } + + fn safe_div( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_div(self, db, lhs, rhs, ty) + } + + fn safe_mod( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_mod(self, db, lhs, rhs, ty) + } + + fn safe_pow( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_pow(self, db, lhs, rhs, ty) + } +} + +#[derive(Debug)] +struct RuntimeFunction(yul::FunctionDefinition); + +impl RuntimeFunction { + fn arg_num(&self) -> usize { + self.0.parameters.len() + } + + fn definition(&self) -> yul::FunctionDefinition { + self.0.clone() + } + + /// # Panics + /// Panics if a number of arguments doesn't match the definition. + fn call(&self, args: Vec) -> yul::Expression { + debug_assert_eq!(self.arg_num(), args.len()); + + yul::Expression::FunctionCall(yul::FunctionCall { + identifier: self.0.name.clone(), + arguments: args, + }) + } + + /// Remove this when `yultsur::function_definition!` becomes to return + /// `FunctionDefinition`. + fn from_statement(func: yul::Statement) -> Self { + match func { + yul::Statement::FunctionDefinition(def) => Self(def), + _ => unreachable!(), + } + } +} + +fn make_ptr(db: &dyn CodegenDb, inner: TypeId, is_sptr: bool) -> TypeId { + if is_sptr { + inner.make_sptr(db.upcast()) + } else { + inner.make_mptr(db.upcast()) + } +} + +struct BitMask(BigInt); + +impl BitMask { + fn new(byte_size: usize) -> Self { + debug_assert!(byte_size <= 32); + let one: BigInt = 1usize.into(); + Self((one << (byte_size * 8)) - 1) + } + + fn not(&self) -> Self { + // Bigint is variable length integer, so we need special handling for `not` + // operation. + let one: BigInt = 1usize.into(); + let u256_max = (one << 256) - 1; + Self(u256_max ^ &self.0) + } + + fn as_expr(&self) -> yul::Expression { + let mask = format!("{:#x}", self.0); + literal_expression! {(mask)} + } +} + +pub(super) fn error_revert_numeric( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, + error_code: yul::Expression, +) -> yul::Statement { + yul::Statement::Expression(provider.revert( + db, + Some(error_code), + "Error", + yul_primitive_type(db), + )) +} + +pub(super) fn panic_revert_numeric( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, + error_code: yul::Expression, +) -> yul::Statement { + yul::Statement::Expression(provider.revert( + db, + Some(error_code), + "Panic", + yul_primitive_type(db), + )) +} diff --git a/crates/codegen2/src/yul/runtime/revert.rs b/crates/codegen2/src/yul/runtime/revert.rs new file mode 100644 index 0000000000..396e07e76d --- /dev/null +++ b/crates/codegen2/src/yul/runtime/revert.rs @@ -0,0 +1,91 @@ +use crate::{ + yul::{slot_size::function_hash_type, YulVariable}, + CodegenDb, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_abi::function::{AbiFunction, AbiFunctionType, StateMutability}; +use fe_mir::ir::{self, TypeId}; +use yultsur::*; + +pub(super) fn make_revert( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + arg_name: &str, + arg_ty: TypeId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let arg = YulVariable::new("arg"); + + let abi_size = YulVariable::new("abi_size"); + let abi_tmp_ptr = YulVariable::new("$abi_tmp_ptr"); + let signature = type_signature_for_revert(db, arg_name, arg_ty); + + let signature_store = yul::Statement::Expression(provider.ptr_store( + db, + abi_tmp_ptr.expr(), + signature, + function_hash_type(db).make_mptr(db.upcast()), + )); + + let func = if arg_ty.deref(db.upcast()).is_zero_sized(db.upcast()) { + function_definition! { + function [func_name.ident()]() { + (let [abi_tmp_ptr.ident()] := [provider.avail(db)]) + ([signature_store]) + (revert([abi_tmp_ptr.expr()], [literal_expression!{4}])) + } + } + } else { + let encode = provider.abi_encode_seq( + db, + &[arg.expr()], + expression! { add([abi_tmp_ptr.expr()], 4) }, + &[arg_ty], + false, + ); + + function_definition! { + function [func_name.ident()]([arg.ident()]) { + (let [abi_tmp_ptr.ident()] := [provider.avail(db)]) + ([signature_store]) + (let [abi_size.ident()] := add([encode], 4)) + (revert([abi_tmp_ptr.expr()], [abi_size.expr()])) + } + } + }; + + RuntimeFunction::from_statement(func) +} + +/// Returns signature hash of the type. +fn type_signature_for_revert(db: &dyn CodegenDb, name: &str, ty: TypeId) -> yul::Expression { + let deref_ty = ty.deref(db.upcast()); + let ty_data = deref_ty.data(db.upcast()); + let args = match &ty_data.kind { + ir::TypeKind::Struct(def) => def + .fields + .iter() + .map(|(_, ty)| ("".to_string(), db.codegen_abi_type(*ty))) + .collect(), + _ => { + let abi_ty = db.codegen_abi_type(deref_ty); + vec![("_".to_string(), abi_ty)] + } + }; + + // selector and state mutability is independent we can set has_self and has_ctx + // any value. + let selector = AbiFunction::new( + AbiFunctionType::Function, + name.to_string(), + args, + None, + StateMutability::Pure, + ) + .selector(); + let type_sig = selector.hex(); + literal_expression! {(format!{"0x{type_sig}" })} +} diff --git a/crates/codegen2/src/yul/runtime/safe_math.rs b/crates/codegen2/src/yul/runtime/safe_math.rs new file mode 100644 index 0000000000..12eb87b1f8 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/safe_math.rs @@ -0,0 +1,628 @@ +use crate::{yul::YulVariable, CodegenDb}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::{TypeId, TypeKind}; + +use yultsur::*; + +pub(super) fn dispatch_safe_add( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_add_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_add_signed(provider, db, name) + }) + } else { + let name = "$safe_add_unsigned"; + let args = vec![lhs, rhs, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_add_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_sub( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_sub_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_sub_signed(provider, db, name) + }) + } else { + let name = "$safe_sub_unsigned"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_sub_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_mul( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_mul_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_mul_signed(provider, db, name) + }) + } else { + let name = "$safe_mul_unsigned"; + let args = vec![lhs, rhs, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_mul_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_div( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_div_signed"; + let args = vec![lhs, rhs, min_value]; + provider.create_then_call(name, args, |provider| { + make_safe_div_signed(provider, db, name) + }) + } else { + let name = "$safe_div_unsigned"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_div_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_mod( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + if ty.is_signed(db.upcast()) { + let name = "$safe_mod_signed"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_mod_signed(provider, db, name) + }) + } else { + let name = "$safe_mod_unsigned"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_mod_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_pow( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_pow_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_pow_signed(provider, db, name) + }) + } else { + let name = "$safe_pow_unsigned"; + let args = vec![lhs, rhs, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_pow_unsigned(provider, db, name) + }) + } +} + +fn make_safe_add_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()], [max_value.ident()]) -> [ret.ident()] { + (if (and((iszero((slt([lhs.expr()], 0)))), (sgt([rhs.expr()], (sub([max_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + (if (and((slt([lhs.expr()], 0)), (slt([rhs.expr()], (sub([min_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := add([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_add_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [max_value.ident()]) -> [ret.ident()] { + (if (gt([lhs.expr()], (sub([max_value.expr()], [rhs.expr()])))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := add([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_sub_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()], [max_value.ident()]) -> [ret.ident()] { + (if (and((iszero((slt([rhs.expr()], 0)))), (slt([lhs.expr()], (add([min_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + (if (and((slt([rhs.expr()], 0)), (sgt([lhs.expr()], (add([max_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := sub([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_sub_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (lt([lhs.expr()], [rhs.expr()])) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := sub([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mul_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()], [max_value.ident()]) -> [ret.ident()] { + // overflow, if lhs > 0, rhs > 0 and lhs > (max_value / rhs) + (if (and((and((sgt([lhs.expr()], 0)), (sgt([rhs.expr()], 0)))), (gt([lhs.expr()], (div([max_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + // underflow, if lhs > 0, rhs < 0 and rhs < (min_value / lhs) + (if (and((and((sgt([lhs.expr()], 0)), (slt([rhs.expr()], 0)))), (slt([rhs.expr()], (sdiv([min_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + // underflow, if lhs < 0, rhs > 0 and lhs < (min_value / rhs) + (if (and((and((slt([lhs.expr()], 0)), (sgt([rhs.expr()], 0)))), (slt([lhs.expr()], (sdiv([min_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + // overflow, if lhs < 0, rhs < 0 and lhs < (max_value / rhs) + (if (and((and((slt([lhs.expr()], 0)), (slt([rhs.expr()], 0)))), (slt([lhs.expr()], (sdiv([max_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := mul([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mul_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [max_value.ident()]) -> [ret.ident()] { + // overflow, if lhs != 0 and rhs > (max_value / lhs) + (if (and((iszero((iszero([lhs.expr()])))), (gt([rhs.expr()], (div([max_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider ,db)] }) + ([ret.ident()] := mul([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_div_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + (if (and( (eq([lhs.expr()], [min_value.expr()])), (eq([rhs.expr()], (sub(0, 1))))) ) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := sdiv([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_div_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + ([ret.ident()] := div([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mod_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + ([ret.ident()] := smod([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mod_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + ([ret.ident()] := mod([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +const SAFE_POW_HELPER_NAME: &str = "safe_pow_helper"; + +fn make_safe_pow_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let base = YulVariable::new("base"); + let exponent = YulVariable::new("exponent"); + let max_value = YulVariable::new("max_value"); + let power = YulVariable::new("power"); + + let safe_pow_helper_call = yul::Statement::Assignment(yul::Assignment { + identifiers: vec![base.ident(), power.ident()], + expression: { + let args = vec![ + base.expr(), + exponent.expr(), + literal_expression! {1}, + max_value.expr(), + ]; + provider.create_then_call(SAFE_POW_HELPER_NAME, args, |provider| { + make_safe_exp_helper(provider, db, SAFE_POW_HELPER_NAME) + }) + }, + }); + + let func = function_definition! { + function [func_name.ident()]([base.ident()], [exponent.ident()], [max_value.ident()]) -> [power.ident()] { + // Currently, `leave` avoids this function being inlined. + // YUL team is working on optimizer improvements to fix that. + + // Note that 0**0 == 1 + (if (iszero([exponent.expr()])) { + ([power.ident()] := 1 ) + (leave) + }) + (if (iszero([base.expr()])) { + ([power.ident()] := 0 ) + (leave) + }) + // Specializations for small bases + ([switch! { + switch [base.expr()] + // 0 is handled above + (case 1 { + ([power.ident()] := 1 ) + (leave) + }) + (case 2 { + (if (gt([exponent.expr()], 255)) { + [revert_with_overflow(provider, db)] + }) + ([power.ident()] := (exp(2, [exponent.expr()]))) + (if (gt([power.expr()], [max_value.expr()])) { + [revert_with_overflow(provider, db)] + }) + (leave) + }) + }]) + (if (and((sgt([power.expr()], 0)), (gt([power.expr()], (div([max_value.expr()], [base.expr()])))))) { [revert_with_overflow(provider, db)] }) + + (if (or((and((lt([base.expr()], 11)), (lt([exponent.expr()], 78)))), (and((lt([base.expr()], 307)), (lt([exponent.expr()], 32)))))) { + ([power.ident()] := (exp([base.expr()], [exponent.expr()]))) + (if (gt([power.expr()], [max_value.expr()])) { + [revert_with_overflow(provider, db)] + }) + (leave) + }) + + ([safe_pow_helper_call]) + (if (gt([power.expr()], (div([max_value.expr()], [base.expr()])))) { + [revert_with_overflow(provider, db)] + }) + ([power.ident()] := (mul([power.expr()], [base.expr()]))) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_pow_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let base = YulVariable::new("base"); + let exponent = YulVariable::new("exponent"); + let min_value = YulVariable::new("min_value"); + let max_value = YulVariable::new("max_value"); + let power = YulVariable::new("power"); + + let safe_pow_helper_call = yul::Statement::Assignment(yul::Assignment { + identifiers: vec![base.ident(), power.ident()], + expression: { + let args = vec![base.expr(), exponent.expr(), power.expr(), max_value.expr()]; + provider.create_then_call(SAFE_POW_HELPER_NAME, args, |provider| { + make_safe_exp_helper(provider, db, SAFE_POW_HELPER_NAME) + }) + }, + }); + + let func = function_definition! { + function [func_name.ident()]([base.ident()], [exponent.ident()], [min_value.ident()], [max_value.ident()]) -> [power.ident()] { + // Currently, `leave` avoids this function being inlined. + // YUL team is working on optimizer improvements to fix that. + + // Note that 0**0 == 1 + ([switch! { + switch [exponent.expr()] + (case 0 { + ([power.ident()] := 1 ) + (leave) + }) + (case 1 { + ([power.ident()] := [base.expr()] ) + (leave) + }) + }]) + (if (iszero([base.expr()])) { + ([power.ident()] := 0 ) + (leave) + }) + ([power.ident()] := 1 ) + // We pull out the first iteration because it is the only one in which + // base can be negative. + // Exponent is at least 2 here. + // overflow check for base * base + ([switch! { + switch (sgt([base.expr()], 0)) + (case 1 { + (if (gt([base.expr()], (div([max_value.expr()], [base.expr()])))) { + [revert_with_overflow(provider, db)] + }) + }) + (case 0 { + (if (slt([base.expr()], (sdiv([max_value.expr()], [base.expr()])))) { + [revert_with_overflow(provider, db)] + }) + }) + }]) + (if (and([exponent.expr()], 1)) { + ([power.ident()] := [base.expr()] ) + }) + ([base.ident()] := (mul([base.expr()], [base.expr()]))) + ([exponent.ident()] := shr(1, [exponent.expr()])) + // // Below this point, base is always positive. + ([safe_pow_helper_call]) // power = 1, base = 16 which is wrong + (if (and((sgt([power.expr()], 0)), (gt([power.expr()], (div([max_value.expr()], [base.expr()])))))) { [revert_with_overflow(provider , db)] }) + (if (and((slt([power.expr()], 0)), (slt([power.expr()], (sdiv([min_value.expr()], [base.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([power.ident()] := (mul([power.expr()], [base.expr()]))) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_exp_helper( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let base = YulVariable::new("base"); + let exponent = YulVariable::new("exponent"); + let power = YulVariable::new("power"); + let max_value = YulVariable::new("max_value"); + let ret_power = YulVariable::new("ret_power"); + let ret_base = YulVariable::new("ret_base"); + + let func = function_definition! { + function [func_name.ident()]([base.ident()], [exponent.ident()], [power.ident()], [max_value.ident()]) + -> [(vec![ret_base.ident(), ret_power.ident()])...] { + ([ret_base.ident()] := [base.expr()]) + ([ret_power.ident()] := [power.expr()]) + (for {} (gt([exponent.expr()], 1)) {} + { + // overflow check for base * base + (if (gt([ret_base.expr()], (div([max_value.expr()], [ret_base.expr()])))) { [revert_with_overflow(provider, db)] }) + (if (and([exponent.expr()], 1)) { + // No checks for power := mul(power, base) needed, because the check + // for base * base above is sufficient, since: + // |power| <= base (proof by induction) and thus: + // |power * base| <= base * base <= max <= |min| (for signed) + // (this is equally true for signed and unsigned exp) + ([ret_power.ident()] := (mul([ret_power.expr()], [ret_base.expr()]))) + }) + ([ret_base.ident()] := mul([ret_base.expr()], [ret_base.expr()])) + ([exponent.ident()] := shr(1, [exponent.expr()])) + }) + } + }; + RuntimeFunction::from_statement(func) +} + +fn revert_with_overflow(provider: &mut dyn RuntimeProvider, db: &dyn CodegenDb) -> yul::Statement { + const PANIC_OVERFLOW: usize = 0x11; + + super::panic_revert_numeric(provider, db, literal_expression! {(PANIC_OVERFLOW)}) +} + +fn revert_with_zero_division( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, +) -> yul::Statement { + pub const PANIC_ZERO_DIVISION: usize = 0x12; + + super::panic_revert_numeric(provider, db, literal_expression! {(PANIC_ZERO_DIVISION)}) +} + +fn get_max_value(db: &dyn CodegenDb, ty: TypeId) -> yul::Expression { + match &ty.data(db.upcast()).kind { + TypeKind::I8 => literal_expression! {0x7f}, + TypeKind::I16 => literal_expression! {0x7fff}, + TypeKind::I32 => literal_expression! {0x7fffffff}, + TypeKind::I64 => literal_expression! {0x7fffffffffffffff}, + TypeKind::I128 => literal_expression! {0x7fffffffffffffffffffffffffffffff}, + TypeKind::I256 => { + literal_expression! {0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff} + } + TypeKind::U8 => literal_expression! {0xff}, + TypeKind::U16 => literal_expression! {0xffff}, + TypeKind::U32 => literal_expression! {0xffffffff}, + TypeKind::U64 => literal_expression! {0xffffffffffffffff}, + TypeKind::U128 => literal_expression! {0xffffffffffffffffffffffffffffffff}, + TypeKind::U256 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff} + } + _ => unreachable!(), + } +} + +fn get_min_value(db: &dyn CodegenDb, ty: TypeId) -> yul::Expression { + debug_assert! {ty.is_integral(db.upcast())}; + + match &ty.data(db.upcast()).kind { + TypeKind::I8 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff80} + } + TypeKind::I16 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff8000} + } + TypeKind::I32 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffff80000000} + } + TypeKind::I64 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffff8000000000000000} + } + TypeKind::I128 => { + literal_expression! {0xffffffffffffffffffffffffffffffff80000000000000000000000000000000} + } + TypeKind::I256 => { + literal_expression! {0x8000000000000000000000000000000000000000000000000000000000000000} + } + + _ => literal_expression! {0x0}, + } +} diff --git a/crates/codegen2/src/yul/slot_size.rs b/crates/codegen2/src/yul/slot_size.rs new file mode 100644 index 0000000000..fdf4964b09 --- /dev/null +++ b/crates/codegen2/src/yul/slot_size.rs @@ -0,0 +1,16 @@ +use fe_mir::ir::{Type, TypeId, TypeKind}; + +use crate::CodegenDb; + +// We use the same slot size between memory and storage to simplify the +// implementation and minimize gas consumption in memory <-> storage copy +// instructions. +pub(crate) const SLOT_SIZE: usize = 32; + +pub(crate) fn yul_primitive_type(db: &dyn CodegenDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::U256, None).into()) +} + +pub(crate) fn function_hash_type(db: &dyn CodegenDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::U32, None).into()) +} diff --git a/crates/library2/Cargo.toml b/crates/library2/Cargo.toml new file mode 100644 index 0000000000..d2ddd0ed7d --- /dev/null +++ b/crates/library2/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "fe-library2" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/ethereum/fe" + +[dependencies] +include_dir = "0.7.2" +common = { path = "../common2", package = "fe-common2" } diff --git a/crates/library2/build.rs b/crates/library2/build.rs new file mode 100644 index 0000000000..0ce78ee5e7 --- /dev/null +++ b/crates/library2/build.rs @@ -0,0 +1,3 @@ +fn main() { + println!("cargo:rerun-if-changed=./std"); +} diff --git a/crates/library2/src/lib.rs b/crates/library2/src/lib.rs new file mode 100644 index 0000000000..48310e1eda --- /dev/null +++ b/crates/library2/src/lib.rs @@ -0,0 +1,54 @@ +use std::collections::BTreeSet; + +pub use ::include_dir; +use common::{ + input::{IngotKind, Version}, + InputDb, InputFile, InputIngot, +}; +use include_dir::{include_dir, Dir}; + +pub const STD: Dir = include_dir!("$CARGO_MANIFEST_DIR/std"); + +fn std_src_input_files(db: &mut dyn InputDb, ingot: InputIngot) -> BTreeSet { + static_dir_files(&STD) + .into_iter() + .map(|(path, content)| InputFile::new(db, ingot, path.into(), content.into())) + .collect() +} + +pub fn std_lib_input_ingot(db: &mut dyn InputDb) -> InputIngot { + let ingot = InputIngot::new( + db, + "std", + IngotKind::Std, + Version::new(0, 0, 0), + BTreeSet::default(), + ); + + let input_files = std_src_input_files(db, ingot); + ingot.set_files(db, input_files); + ingot +} + +// pub fn std_src_files() -> Vec<(&'static str, &'static str)> { +// static_dir_files(STD.get_dir("src").unwrap()) +// } + +pub fn static_dir_files(dir: &'static Dir) -> Vec<(&'static str, &'static str)> { + fn add_files(dir: &'static Dir, accum: &mut Vec<(&'static str, &'static str)>) { + accum.extend(dir.files().map(|file| { + ( + file.path().to_str().unwrap(), + file.contents_utf8().expect("non-utf8 static file"), + ) + })); + + for sub_dir in dir.dirs() { + add_files(sub_dir, accum) + } + } + + let mut files = vec![]; + add_files(dir, &mut files); + files +} diff --git a/crates/library2/std/src/buf.fe b/crates/library2/std/src/buf.fe new file mode 100644 index 0000000000..a1d97af4e6 --- /dev/null +++ b/crates/library2/std/src/buf.fe @@ -0,0 +1,299 @@ +use ingot::evm +use ingot::math + +unsafe fn avail() -> u256 { + let ptr: u256 = evm::mload(offset: 64) + + if ptr == 0x00 { + return 96 + } else { + return ptr + } +} + +unsafe fn alloc(len: u256) -> u256 { + let ptr: u256 = avail() + evm::mstore(offset: 64, value: ptr + len) + return ptr +} + +struct Cursor { + cur: u256 + len: u256 + + pub fn new(len: u256) -> Self { + return Cursor(cur: 0, len) + } + + /// Increment the value of `cur` by `len` and return the value of `cur` before being incremented. + /// Reverts if the cursor is advanced beyond the given length. + pub fn advance(mut self, len: u256) -> u256 { + let cur: u256 = self.cur + assert cur + len < self.len + 1 + self.cur += len + return cur + } + + /// Length of the cursor remaining. + pub fn remainder(self) -> u256 { + return self.len - self.cur + } +} + +/// EVM memory buffer abstraction. +pub struct MemoryBuffer { + offset: u256 + len: u256 + + pub fn new(len: u256) -> Self { + unsafe { + return MemoryBuffer(offset: alloc(len: len + 30), len) + } + } + + pub fn from_u8(value: u8) -> Self { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 1) + let mut writer: MemoryBufferWriter = buf.writer() + writer.write(value) + return buf + } + + /// Length of the buffer in bytes. + pub fn len(self) -> u256 { + return self.len + } + + /// The start of the buffer in EVM memory. + pub fn offset(self) -> u256 { + return self.offset + } + + /// Returns a new buffer reader. + pub fn reader(self) -> MemoryBufferReader { + return MemoryBufferReader::new(buf: self) + } + + /// Returns a new buffer writer. + pub fn writer(mut self) -> MemoryBufferWriter { + return MemoryBufferWriter::new(buf: self) + } +} + +/// Memory buffer writer abstraction. +pub struct MemoryBufferWriter { + buf: MemoryBuffer + cur: Cursor + + /// Returns a new writer for the given buffer. + pub fn new(mut buf: MemoryBuffer) -> Self { + return MemoryBufferWriter( + buf, + cur: Cursor::new(len: buf.len()) + ) + } + + /// The number of bytes remaining to be written. + pub fn remainder(self) -> u256 { + return self.cur.remainder() + } + + pub fn write_offset(mut self, len: u256) -> u256 { + return self.buf.offset() + self.cur.advance(len) + } + + pub fn write_n(mut self, value: u256, len: u256) { + let offset: u256 = self.write_offset(len) + let shifted_value: u256 = evm::shl(bits: 256 - len * 8, value) + unsafe { evm::mstore(offset, value: shifted_value) } + } + + pub fn write_buf(mut self, buf: MemoryBuffer) { + let mut reader: MemoryBufferReader = buf.reader() + + while true { + let bytes_remaining: u256 = reader.remainder() + + if bytes_remaining >= 32 { + self.write(value: reader.read_u256()) + } else if bytes_remaining == 0 { + break + } else { + self.write(value: reader.read_u8()) + } + } + } + + pub fn write(mut self, value: T) { + value.write_buf(writer: self) + } +} + +pub trait MemoryBufferWrite { + fn write_buf(self, mut writer: MemoryBufferWriter); +} + +impl MemoryBufferWrite for u256 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + let offset: u256 = writer.write_offset(len: 32) + unsafe { evm::mstore(offset, value: self) } + } +} + +impl MemoryBufferWrite for u128 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 16) + } +} + +impl MemoryBufferWrite for u64 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 8) + } +} + +impl MemoryBufferWrite for u32 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 4) + } +} + +impl MemoryBufferWrite for u16 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 2) + } +} + +impl MemoryBufferWrite for u8 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + let offset: u256 = writer.write_offset(len: 1) + unsafe { evm::mstore8(offset, value: self) } + } +} + +// This is needed to prevent the `mir_lower_std_lib` to crash the compiler +impl MemoryBufferWrite for () { + fn write_buf(self, mut writer: MemoryBufferWriter) {} +} + +/// Memory buffer reader abstraction. +pub struct MemoryBufferReader { + buf: MemoryBuffer + cur: Cursor + + /// Returns a new reader for the given buffer. + pub fn new(buf: MemoryBuffer) -> Self { + return MemoryBufferReader(buf, cur: Cursor::new(len: buf.len())) + } + + /// The number of bytes remaining to be read. + pub fn remainder(self) -> u256 { + return self.cur.remainder() + } + + fn read_offset(mut self, len: u256) -> u256 { + return self.buf.offset() + self.cur.advance(len) + } + + fn read_n(mut self, len: u256) -> u256 { + let offset: u256 = self.read_offset(len) + unsafe { + let value: u256 = evm::mload(offset) + return evm::shr(bits: 256 - len * 8, value) + } + } + + pub fn read_u8(mut self) -> u8 { + return u8(self.read_n(len: 1)) + } + + pub fn read_u16(mut self) -> u16 { + return u16(self.read_n(len: 2)) + } + + pub fn read_u32(mut self) -> u32 { + return u32(self.read_n(len: 4)) + } + + pub fn read_u64(mut self) -> u64 { + return u64(self.read_n(len: 8)) + } + + pub fn read_u128(mut self) -> u128 { + return u128(self.read_n(len: 16)) + } + + pub fn read_u256(mut self) -> u256 { + let offset: u256 = self.read_offset(len: 32) + unsafe { + let value: u256 = evm::mload(offset) + return value + } + } + + pub fn read_buf(mut self, len: u256) -> MemoryBuffer { + let mut buf: MemoryBuffer = MemoryBuffer::new(len) + let mut writer: MemoryBufferWriter = buf.writer() + + while true { + let bytes_remaining: u256 = writer.remainder() + + if bytes_remaining >= 32 { + writer.write(value: self.read_u256()) + } else if bytes_remaining == 0 { + break + } else { + writer.write(value: self.read_u8()) + } + } + + return buf + } + + // `T` has not been defined + // pub fn read(mut self) -> T { + // T::read_buf(writer: self) + // } +} + +// pub trait MemoryBufferRead { +// fn read_buf(self, mut reader: MemoryBufferReader) -> Self; +// } +// +// impl MemoryBufferRead for u256 { .. } +// . +// . +// impl MemoryBufferRead for u8 { .. } + +/// `MemoryBuffer` wrapper for raw calls to other contracts. +pub struct RawCallBuffer { + input_len: u256 + output_len: u256 + buf: MemoryBuffer + + pub fn new(input_len: u256, output_len: u256) -> Self { + let len: u256 = math::max(input_len, output_len) + let buf: MemoryBuffer = MemoryBuffer::new(len) + + return RawCallBuffer(input_len, output_len, buf) + } + + pub fn input_len(self) -> u256 { + return self.input_len + } + + pub fn output_len(self) -> u256 { + return self.output_len + } + + pub fn offset(self) -> u256 { + return self.buf.offset() + } + + pub fn reader(self) -> MemoryBufferReader { + return self.buf.reader() + } + + pub fn writer(mut self) -> MemoryBufferWriter { + return self.buf.writer() + } +} diff --git a/crates/library2/std/src/context.fe b/crates/library2/std/src/context.fe new file mode 100644 index 0000000000..9b51f9ca9e --- /dev/null +++ b/crates/library2/std/src/context.fe @@ -0,0 +1,174 @@ +use ingot::evm +use ingot::error::{ + ERROR_INSUFFICIENT_FUNDS_TO_SEND_VALUE, + ERROR_FAILED_SEND_VALUE, + Error +} +use ingot::buf::{ + RawCallBuffer, + MemoryBufferReader, + MemoryBufferWriter +} + +struct OutOfReachMarker {} + +// ctx.emit(my_event) should be the only way to emit an event. We achieve this by defining the +// private `OutOfReachMarker` here to which only the `Context` has access. +// Now there is no way to call `emit` directly on an Emittable. +pub trait Emittable { + fn emit(self, _ val: OutOfReachMarker); +} + +pub struct CalldataReader { + cur_offset: u256 + len: u256 + + pub unsafe fn new(len: u256) -> CalldataReader { + return CalldataReader(cur_offset: 0, len) + } + + pub fn remainder(self) -> u256 { + return self.len - self.cur_offset + } + + pub fn advance(mut self, len: u256) -> u256 { + self.cur_offset += len + assert self.cur_offset <= self.len + return self.cur_offset + } + + fn read_n(mut self, len: u256) -> u256 { + unsafe { + let value: u256 = evm::call_data_load(offset: self.cur_offset) + self.advance(len) + return evm::shr(bits: 256 - len * 8, value) + } + } + + pub fn read_u8(mut self) -> u8 { + return u8(self.read_n(len: 1)) + } + + pub fn read_u16(mut self) -> u16 { + return u16(self.read_n(len: 2)) + } + + pub fn read_u32(mut self) -> u32 { + return u32(self.read_n(len: 4)) + } + + pub fn read_u64(mut self) -> u64 { + return u64(self.read_n(len: 8)) + } + + pub fn read_u128(mut self) -> u128 { + return u128(self.read_n(len: 16)) + } + pub fn read_u256(mut self) -> u256 { + unsafe { + let value: u256 = evm::call_data_load(offset: self.cur_offset) + self.advance(len: 32) + return value + } + } +} + +pub struct Context { + pub fn base_fee(self) -> u256 { + unsafe { return evm::base_fee() } + } + + pub fn block_coinbase(self) -> address { + unsafe { return evm::coinbase() } + } + + pub fn prevrandao(self) -> u256 { + unsafe { return evm::prevrandao() } + } + + pub fn block_number(self) -> u256 { + unsafe { return evm::block_number() } + } + + pub fn block_timestamp(self) -> u256 { + unsafe { return evm::timestamp() } + } + + pub fn chain_id(self) -> u256 { + unsafe { return evm::chain_id() } + } + + pub fn msg_sender(self) -> address { + unsafe { return evm::caller() } + } + + pub fn msg_value(self) -> u256 { + unsafe { return evm::call_value() } + } + + pub fn tx_gas_price(self) -> u256 { + unsafe { return evm::gas_price() } + } + + pub fn tx_origin(self) -> address { + unsafe { return evm::origin() } + } + + pub fn msg_sig(self) -> u256 { + unsafe { return evm::shr(bits: 224, value: evm::call_data_load(offset: 0)) } + } + + pub fn balance_of(self, _ account: address) -> u256 { + unsafe { return evm::balance_of(account) } + } + + pub fn self_balance(self) -> u256 { + unsafe { return evm::balance() } + } + + pub fn self_address(self) -> address { + unsafe { return address(__address()) } + } + + pub fn calldata_reader(self) -> CalldataReader { + unsafe { + let len: u256 = evm::call_data_size() + return CalldataReader::new(len) + } + } + + pub fn send_value(mut self, to: address, wei: u256) { + unsafe { + if evm::balance() < wei { + revert Error(code: ERROR_INSUFFICIENT_FUNDS_TO_SEND_VALUE) + } + let mut buf: RawCallBuffer = RawCallBuffer::new(input_len: 0, output_len: 0) + let success: bool = evm::call(gas: evm::gas_remaining(), addr: to, value: wei, + buf) + if not success { + revert Error(code: ERROR_FAILED_SEND_VALUE) + } + } + } + + /// Makes a call to the given address. + pub fn raw_call( + self, + addr: address, + value: u256, + mut buf: RawCallBuffer + ) -> bool { + unsafe { + return evm::call( + gas: evm::gas_remaining(), + addr, + value, + buf + ) + } + } + + pub fn emit(mut self, _ val: T) { + val.emit(OutOfReachMarker()) + } +} \ No newline at end of file diff --git a/crates/library2/std/src/error.fe b/crates/library2/std/src/error.fe new file mode 100644 index 0000000000..7ae066af4c --- /dev/null +++ b/crates/library2/std/src/error.fe @@ -0,0 +1,6 @@ +pub const ERROR_INSUFFICIENT_FUNDS_TO_SEND_VALUE: u256 = 0x100 +pub const ERROR_FAILED_SEND_VALUE: u256 = 0x101 + +pub struct Error { + pub code: u256 +} \ No newline at end of file diff --git a/crates/library2/std/src/evm.fe b/crates/library2/std/src/evm.fe new file mode 100644 index 0000000000..09c9382556 --- /dev/null +++ b/crates/library2/std/src/evm.fe @@ -0,0 +1,325 @@ +use ingot::buf::{MemoryBuffer, RawCallBuffer} + +// Basic context accessor functions. +pub unsafe fn chain_id() -> u256 { + return __chainid() +} + +pub unsafe fn base_fee() -> u256 { + return __basefee() +} + +pub unsafe fn origin() -> address { + return address(__origin()) +} + +pub unsafe fn gas_price() -> u256 { + return __gasprice() +} + +pub unsafe fn gas_limit() -> u256 { + return __gaslimit() +} + +pub unsafe fn gas_remaining() -> u256 { + return __gas() +} + +pub unsafe fn block_hash(_ b: u256) -> u256 { + return __blockhash(b) +} + +pub unsafe fn coinbase() -> address { + return address(__coinbase()) +} + +pub unsafe fn timestamp() -> u256 { + return __timestamp() +} + +pub unsafe fn block_number() -> u256 { + return __number() +} + +pub unsafe fn prevrandao() -> u256 { + return __prevrandao() +} + +pub unsafe fn self_address() -> address { + return address(__address()) +} + +pub unsafe fn balance_of(_ addr: address) -> u256 { + return __balance(u256(addr)) +} + +pub unsafe fn balance() -> u256 { + return __selfbalance() +} + +pub unsafe fn caller() -> address { + return address(__caller()) +} + +pub unsafe fn call_value() -> u256 { + return __callvalue() +} + + +// Overflowing math ops. Should these be unsafe or named +// `overflowing_add`, etc? +pub fn add(_ x: u256, _ y: u256) -> u256 { + unsafe { return __add(x, y) } +} + +pub fn sub(_ x: u256, _ y: u256) -> u256 { + unsafe { return __sub(x, y) } +} + +pub fn mul(_ x: u256, _ y: u256) -> u256 { + unsafe { return __mul(x, y) } +} + +pub fn div(_ x: u256, _ y: u256) -> u256 { + unsafe { return __div(x, y) } +} + +pub fn sdiv(_ x: u256, _ y: u256) -> u256 { + unsafe { return __sdiv(x, y) } +} + +pub fn mod(_ x: u256, _ y: u256) -> u256 { + unsafe { return __mod(x, y) } +} + +pub fn smod(_ x: u256, _ y: u256) -> u256 { + unsafe { return __smod(x, y) } +} + +pub fn exp(_ x: u256, _ y: u256) -> u256 { + unsafe { return __exp(x, y) } +} + +pub fn addmod(_ x: u256, _ y: u256, _ m: u256) -> u256 { + unsafe { return __addmod(x, y, m) } +} + +pub fn mulmod(_ x: u256, _ y: u256, _ m: u256) -> u256 { + unsafe { return __mulmod(x, y, m) } +} + +pub fn sign_extend(_ i: u256, _ x: u256) -> u256 { + unsafe { return __signextend(i, x) } +} + + +// Comparison ops +// TODO: return bool (see issue //653) +pub fn lt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __lt(x, y) } +} + +pub fn gt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __gt(x, y) } +} + +pub fn slt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __slt(x, y) } +} + +pub fn sgt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __sgt(x, y) } +} + +pub fn eq(_ x: u256, _ y: u256) -> u256 { + unsafe { return __eq(x, y) } +} + +pub fn is_zero(_ x: u256) -> u256 { + unsafe { return __iszero(x) } +} + + +// Bitwise ops +pub fn bitwise_and(_ x: u256, _ y: u256) -> u256 { + unsafe { return __and(x, y) } +} + +pub fn bitwise_or(_ x: u256, _ y: u256) -> u256 { + unsafe { return __or(x, y) } +} + +pub fn bitwise_not(_ x: u256) -> u256 { + unsafe { return __not(x) } +} + +pub fn xor(_ x: u256, _ y: u256) -> u256 { + unsafe { return __xor(x, y) } +} + +pub fn byte(offset: u256, value: u256) -> u256 { + unsafe { return __byte(offset, value) } +} + +pub fn shl(bits: u256, value: u256) -> u256 { + unsafe { return __shl(bits, value) } +} + +pub fn shr(bits: u256, value: u256) -> u256 { + unsafe { return __shr(bits, value) } +} + +pub fn sar(bits: u256, value: u256) -> u256 { + unsafe { return __sar(bits, value) } +} + + +// Evm state access and control +pub fn return_mem(buf: MemoryBuffer) { + unsafe{ __return(buf.offset(), buf.len()) } +} + +pub fn revert_mem(buf: MemoryBuffer) { + unsafe { __revert(buf.offset(), buf.len()) } +} + +pub unsafe fn selfdestruct(_ addr: address) { + __selfdestruct(u256(addr)) +} + +// Invalid opcode. Equivalent to revert(0, 0), +// except that all remaining gas in the current context +// is consumed. +pub unsafe fn invalid() { + __invalid() +} + +pub unsafe fn stop() { + __stop() +} + +pub unsafe fn pc() -> u256 { + return __pc() +} + +// TODO: dunno if we should enable this +// pub unsafe fn pop(_ x: u256) { +// return __pop(x) +// } + +pub unsafe fn mload(offset p: u256) -> u256 { + return __mload(p) +} + +pub unsafe fn mstore(offset p: u256, value v: u256) { + __mstore(p, v) +} +pub unsafe fn mstore8(offset p: u256, value v: u256) { + __mstore8(p, v) +} + +pub unsafe fn sload(offset p: u256) -> u256 { + return __sload(p) +} + +pub unsafe fn sstore(offset p: u256, value v: u256) { + __sstore(p, v) +} + +pub unsafe fn msize() -> u256 { + return __msize() +} + +pub unsafe fn call_data_load(offset p: u256) -> u256 { + return __calldataload(p) +} + +pub unsafe fn call_data_size() -> u256 { + return __calldatasize() +} + +pub fn call_data_copy(buf: MemoryBuffer, from_offset f: u256) { + unsafe { __calldatacopy(buf.offset(), f, buf.len()) } +} + +pub unsafe fn code_size() -> u256 { + return __codesize() +} + +pub unsafe fn code_copy(to_offset t: u256, from_offset f: u256, len: u256) { + __codecopy(t, f, len) +} + +pub unsafe fn return_data_size() -> u256 { + return __returndatasize() +} + +pub unsafe fn return_data_copy(to_offset t: u256, from_offset f: u256, len: u256) { + __returndatacopy(t, f, len) +} + +pub unsafe fn extcodesize(_ addr: address) -> u256 { + return __extcodesize(u256(addr)) +} + +pub unsafe fn ext_code_copy(_ addr: address, to_offset t: u256, from_offset f: u256, len: u256) { + __extcodecopy(u256(addr), t, f, len) +} + +pub unsafe fn ext_code_hash(_ addr: address) -> u256 { + return __extcodehash(u256(addr)) +} + +pub fn keccak256_mem(buf: MemoryBuffer) -> u256 { + unsafe { return __keccak256(buf.offset(), buf.len()) } +} + + +// Contract creation and calling + +pub fn create(value v: u256, buf: MemoryBuffer) -> address { + unsafe { return address(__create(v, buf.offset(), buf.len())) } +} + +pub fn create2(value v: u256, buf: MemoryBuffer, salt s: u256) -> address { + unsafe { return address(__create2(v, buf.offset(), buf.len(), s)) } +} + +// TODO: return bool (success) +pub fn call(gas: u256, addr: address, value: u256, mut buf: RawCallBuffer) -> bool { + unsafe{ return __call(gas, u256(addr), value, buf.offset(), buf.input_len(), buf.offset(), buf.output_len()) == 1 } +} + +pub unsafe fn call_code(gas: u256, addr: address, value: u256, input_offset: u256, input_len: u256, output_offset: u256, output_len: u256) -> u256 { + return __callcode(gas, u256(addr), value, input_offset, input_len, output_offset, output_len) +} + +pub unsafe fn delegate_call(gas: u256, addr: address, value: u256, input_offset: u256, input_len: u256, output_offset: u256, output_len: u256) -> u256 { + return __delegatecall(gas, u256(addr), input_offset, input_len, output_offset, output_len) +} + +pub unsafe fn static_call(gas: u256, addr: address, input_offset: u256, input_len: u256, output_offset: u256, output_len: u256) -> u256 { + return __staticcall(gas, u256(addr), input_offset, input_len, output_offset, output_len) +} + +// Logging functions + +pub fn log0(buf: MemoryBuffer) { + unsafe { return __log0(buf.offset(), buf.len()) } +} + +pub fn log1(buf: MemoryBuffer, topic1 t1: u256) { + unsafe { return __log1(buf.offset(), buf.len(), t1) } +} + +pub fn log2(buf: MemoryBuffer, topic1 t1: u256, topic2 t2: u256) { + unsafe { return __log2(buf.offset(), buf.len(), t1, t2) } +} + +pub fn log3(buf: MemoryBuffer, topic1 t1: u256, topic2 t2: u256, topic3 t3: u256) { + unsafe { return __log3(buf.offset(), buf.len(), t1, t2, t3) } +} + +pub fn log4(buf: MemoryBuffer, topic1 t1: u256, topic2 t2: u256, topic3 t3: u256, topic4 t4: u256) { + unsafe { return __log4(buf.offset(), buf.len(), t1, t2, t3, t4) } +} diff --git a/crates/library2/std/src/lib.fe b/crates/library2/std/src/lib.fe new file mode 100644 index 0000000000..8a94dde71d --- /dev/null +++ b/crates/library2/std/src/lib.fe @@ -0,0 +1,3 @@ +pub fn get_42() -> u256 { + return 42 +} \ No newline at end of file diff --git a/crates/library2/std/src/math.fe b/crates/library2/std/src/math.fe new file mode 100644 index 0000000000..bc37ee6739 --- /dev/null +++ b/crates/library2/std/src/math.fe @@ -0,0 +1,15 @@ +pub fn min(_ x: u256, _ y: u256) -> u256 { + if x < y { + return x + } else { + return y + } +} + +pub fn max(_ x: u256, _ y: u256) -> u256 { + if x > y { + return x + } else { + return y + } +} \ No newline at end of file diff --git a/crates/library2/std/src/precompiles.fe b/crates/library2/std/src/precompiles.fe new file mode 100644 index 0000000000..ba9d59f138 --- /dev/null +++ b/crates/library2/std/src/precompiles.fe @@ -0,0 +1,191 @@ +use ingot::buf::{MemoryBuffer, MemoryBufferWriter, MemoryBufferReader} +use ingot::evm + +enum Precompile { + EcRecover + Sha2256 + Ripemd160 + Identity + ModExp + EcAdd + EcMul + EcPairing + Blake2f + + pub fn addr(self) -> address { + match self { + Precompile::EcRecover => { return 0x01 } + Precompile::Sha2256 => { return 0x02 } + Precompile::Ripemd160 => { return 0x03 } + Precompile::Identity => { return 0x04 } + Precompile::ModExp => { return 0x05 } + Precompile::EcAdd => { return 0x06 } + Precompile::EcMul => { return 0x07 } + Precompile::EcPairing => { return 0x08 } + Precompile::Blake2f => { return 0x09 } + } + } + + pub fn single_buf_call(self, mut buf: MemoryBuffer) { + unsafe { + assert evm::static_call( + gas: evm::gas_remaining(), + addr: self.addr(), + input_offset: buf.offset(), + input_len: buf.len(), + output_offset: buf.offset(), + output_len: buf.len() + ) == 1 + } + } + + pub fn call(self, input: MemoryBuffer, mut output: MemoryBuffer) { + unsafe { + assert evm::static_call( + gas: evm::gas_remaining(), + addr: self.addr(), + input_offset: input.offset(), + input_len: input.len(), + output_offset: output.offset(), + output_len: output.len() + ) == 1 + } + } +} + +/// EC Recover precompile call. +pub fn ec_recover(hash: u256, v: u256, r: u256, s: u256) -> address { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 128) + + let mut writer: MemoryBufferWriter = buf.writer() + writer.write(value: hash) + writer.write(value: v) + writer.write(value: r) + writer.write(value: s) + + Precompile::EcRecover.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return address(reader.read_u256()) +} + +/// SHA2 256 precompile call. +pub fn sha2_256(buf: MemoryBuffer) -> u256 { + let mut output: MemoryBuffer = MemoryBuffer::new(len: 32) + let mut reader: MemoryBufferReader = output.reader() + Precompile::Sha2256.call(input: buf, output) + return reader.read_u256() +} + +/// Ripemd 160 precompile call. +pub fn ripemd_160(buf: MemoryBuffer) -> u256 { + let mut output: MemoryBuffer = MemoryBuffer::new(len: 32) + let mut reader: MemoryBufferReader = output.reader() + Precompile::Ripemd160.call(input: buf, output) + return reader.read_u256() +} + +/// Identity precompile call. +pub fn identity(buf: MemoryBuffer) -> MemoryBuffer { + let mut output: MemoryBuffer = MemoryBuffer::new(len: buf.len()) + Precompile::Identity.call(input: buf, output) + return output +} + +/// Mod exp preocmpile call. +pub fn mod_exp( + b_size: u256, + e_size: u256, + m_size: u256, + b: MemoryBuffer, + e: MemoryBuffer, + m: MemoryBuffer, +) -> MemoryBuffer { + let mut buf: MemoryBuffer = MemoryBuffer::new( + len: 96 + b_size + e_size + m_size + ) + + let mut writer: MemoryBufferWriter = buf.writer() + writer.write(value: b_size) + writer.write(value: e_size) + writer.write(value: m_size) + writer.write_buf(buf: b) + writer.write_buf(buf: e) + writer.write_buf(buf: m) + + Precompile::ModExp.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return reader.read_buf(len: m_size) +} + +/// EC add precompile call. +pub fn ec_add(x1: u256, y1: u256, x2: u256, y2: u256) -> (u256, u256) { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 128) + let mut writer: MemoryBufferWriter = buf.writer() + + writer.write(value: x1) + writer.write(value: y1) + writer.write(value: x2) + writer.write(value: y2) + + Precompile::EcAdd.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return (reader.read_u256(), reader.read_u256()) +} + +/// EC mul precompile call. +pub fn ec_mul(x: u256, y: u256, s: u256) -> (u256, u256) { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 128) + let mut writer: MemoryBufferWriter = buf.writer() + + writer.write(value: x) + writer.write(value: y) + writer.write(value: s) + + Precompile::EcMul.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return (reader.read_u256(), reader.read_u256()) +} + +/// EC pairing precompile call. +pub fn ec_pairing(buf: MemoryBuffer) -> bool { + let mut output: MemoryBuffer = MemoryBuffer::new(len: 32) + let mut reader: MemoryBufferReader = output.reader() + Precompile::EcPairing.call(input: buf, output) + return reader.read_u256() == 1 +} + +/// Blake 2f precompile call. +pub fn blake_2f( + rounds: u32, + h: Array, + m: Array, + t: Array, + f: bool +) -> Array { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 213) + let mut writer: MemoryBufferWriter = buf.writer() + + writer.write(value: rounds) + for value in h { writer.write(value) } + for value in m { writer.write(value) } + for value in t { writer.write(value) } + writer.write(value: u8(1) if f else u8(0)) + + Precompile::Blake2f.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return [ + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64() + ] +} diff --git a/crates/library2/std/src/prelude.fe b/crates/library2/std/src/prelude.fe new file mode 100644 index 0000000000..715ec70cf9 --- /dev/null +++ b/crates/library2/std/src/prelude.fe @@ -0,0 +1 @@ +use ingot::context::Context \ No newline at end of file diff --git a/crates/library2/std/src/traits.fe b/crates/library2/std/src/traits.fe new file mode 100644 index 0000000000..43a0f44727 --- /dev/null +++ b/crates/library2/std/src/traits.fe @@ -0,0 +1,160 @@ +// Dummy trait used in testing. We can remove this once we have more useful traits + +pub trait Dummy {} + +pub trait Min { + fn min() -> Self; +} + +impl Min for u8 { + fn min() -> Self { + return 0 + } +} + +impl Min for u16 { + fn min() -> Self { + return 0 + } +} + +impl Min for u32 { + fn min() -> Self { + return 0 + } +} + +impl Min for u64 { + fn min() -> Self { + return 0 + } +} + +impl Min for u128 { + fn min() -> Self { + return 0 + } +} + +impl Min for u256 { + fn min() -> Self { + return 0 + } +} + +impl Min for i8 { + fn min() -> Self { + return -128 + } +} + +impl Min for i16 { + fn min() -> Self { + return -32768 + } +} + +impl Min for i32 { + fn min() -> Self { + return -2147483648 + } +} + +impl Min for i64 { + fn min() -> Self { + return -9223372036854775808 + } +} + +impl Min for i128 { + fn min() -> Self { + return -170141183460469231731687303715884105728 + } +} + +impl Min for i256 { + fn min() -> Self { + return -57896044618658097711785492504343953926634992332820282019728792003956564819968 + } +} + + + + + + +pub trait Max { + fn max() -> Self; +} + +impl Max for u8 { + fn max() -> Self { + return 255 + } +} + +impl Max for u16 { + fn max() -> Self { + return 65535 + } +} + +impl Max for u32 { + fn max() -> Self { + return 4294967295 + } +} + +impl Max for u64 { + fn max() -> Self { + return 18446744073709551615 + } +} + +impl Max for u128 { + fn max() -> Self { + return 340282366920938463463374607431768211455 + } +} + +impl Max for u256 { + fn max() -> Self { + return 115792089237316195423570985008687907853269984665640564039457584007913129639935 + } +} + +impl Max for i8 { + fn max() -> Self { + return 127 + } +} + +impl Max for i16 { + fn max() -> Self { + return 32767 + } +} + +impl Max for i32 { + fn max() -> Self { + return 2147483647 + } +} + +impl Max for i64 { + fn max() -> Self { + return 9223372036854775807 + } +} + +impl Max for i128 { + fn max() -> Self { + return 170141183460469231731687303715884105727 + } +} + +impl Max for i256 { + fn max() -> Self { + return 57896044618658097711785492504343953926634992332820282019728792003956564819967 + } +} diff --git a/crates/mir2-analysis/Cargo.toml b/crates/mir2-analysis/Cargo.toml new file mode 100644 index 0000000000..24a92769dd --- /dev/null +++ b/crates/mir2-analysis/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "fe-mir2-analysis" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/ethereum/fe" + +[dependencies] +fe-common2 = { path = "../common2", version = "^0.23.0" } +fe-hir-analysis = { path = "../hir-analysis", version = "^0.23.0" } +fe-hir = { path = "../hir", version = "^0.23.0" } +fe-mir2 = { path = "../mir2", version = "^0.23.0" } +salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } +smol_str = "0.1.21" +id-arena = "2.2.1" +fxhash = "0.2.1" +dot2 = "1.0.0" + +[dev-dependencies] +test-files = { path = "../test-files", package = "fe-test-files" } +fe-library = { path = "../library" } diff --git a/crates/mir2-analysis/src/db/queries.rs b/crates/mir2-analysis/src/db/queries.rs new file mode 100644 index 0000000000..8cdae44831 --- /dev/null +++ b/crates/mir2-analysis/src/db/queries.rs @@ -0,0 +1,7 @@ +pub mod constant; +pub mod contract; +pub mod enums; +pub mod function; +pub mod module; +pub mod structs; +pub mod types; diff --git a/crates/mir2-analysis/src/db/queries/contract.rs b/crates/mir2-analysis/src/db/queries/contract.rs new file mode 100644 index 0000000000..d7bcf742a4 --- /dev/null +++ b/crates/mir2-analysis/src/db/queries/contract.rs @@ -0,0 +1,15 @@ +use std::rc::Rc; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_contract_all_functions( + db: &dyn MirDb, + contract: analyzer_items::ContractId, +) -> Rc> { + contract + .all_functions(db.upcast()) + .iter() + .map(|func| db.mir_lowered_func_signature(*func)) + .collect::>() + .into() +} diff --git a/crates/mir2-analysis/src/db/queries/enums.rs b/crates/mir2-analysis/src/db/queries/enums.rs new file mode 100644 index 0000000000..5082d76e42 --- /dev/null +++ b/crates/mir2-analysis/src/db/queries/enums.rs @@ -0,0 +1,15 @@ +use std::rc::Rc; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_enum_all_functions( + db: &dyn MirDb, + enum_: analyzer_items::EnumId, +) -> Rc> { + enum_ + .all_functions(db.upcast()) + .iter() + .map(|func| db.mir_lowered_func_signature(*func)) + .collect::>() + .into() +} diff --git a/crates/mir2-analysis/src/db/queries/function.rs b/crates/mir2-analysis/src/db/queries/function.rs new file mode 100644 index 0000000000..211a5cadfb --- /dev/null +++ b/crates/mir2-analysis/src/db/queries/function.rs @@ -0,0 +1,130 @@ +use std::{collections::BTreeMap, rc::Rc}; + +use fe_analyzer2::{ + display::Displayable, + namespace::{items as analyzer_items, items::Item, types as analyzer_types}, +}; + +use smol_str::SmolStr; + +use crate::{ + db::MirDb, + ir::{self, function::Linkage, FunctionSignature, TypeId}, + lower::function::{lower_func_body, lower_func_signature, lower_monomorphized_func_signature}, +}; + +pub fn mir_lowered_func_signature( + db: &dyn MirDb, + analyzer_func: analyzer_items::FunctionId, +) -> ir::FunctionId { + lower_func_signature(db, analyzer_func) +} + +pub fn mir_lowered_monomorphized_func_signature( + db: &dyn MirDb, + analyzer_func: analyzer_items::FunctionId, + resolved_generics: BTreeMap, +) -> ir::FunctionId { + lower_monomorphized_func_signature(db, analyzer_func, resolved_generics) +} + +/// Generate MIR function and monomorphize generic parameters as if they were +/// called with unit type NOTE: THIS SHOULD ONLY BE USED IN TEST CODE +pub fn mir_lowered_pseudo_monomorphized_func_signature( + db: &dyn MirDb, + analyzer_func: analyzer_items::FunctionId, +) -> ir::FunctionId { + let resolved_generics = analyzer_func + .sig(db.upcast()) + .generic_params(db.upcast()) + .iter() + .map(|generic| (generic.name(), analyzer_types::TypeId::unit(db.upcast()))) + .collect::>(); + lower_monomorphized_func_signature(db, analyzer_func, resolved_generics) +} + +pub fn mir_lowered_func_body(db: &dyn MirDb, func: ir::FunctionId) -> Rc { + lower_func_body(db, func) +} + +impl ir::FunctionId { + pub fn signature(self, db: &dyn MirDb) -> Rc { + db.lookup_mir_intern_function(self) + } + + pub fn return_type(self, db: &dyn MirDb) -> Option { + self.signature(db).return_type + } + + pub fn linkage(self, db: &dyn MirDb) -> Linkage { + self.signature(db).linkage + } + + pub fn analyzer_func(self, db: &dyn MirDb) -> analyzer_items::FunctionId { + self.signature(db).analyzer_func_id + } + + pub fn body(self, db: &dyn MirDb) -> Rc { + db.mir_lowered_func_body(self) + } + + pub fn module(self, db: &dyn MirDb) -> analyzer_items::ModuleId { + let analyzer_func = self.analyzer_func(db); + analyzer_func.module(db.upcast()) + } + + pub fn is_contract_init(self, db: &dyn MirDb) -> bool { + self.analyzer_func(db) + .data(db.upcast()) + .sig + .is_constructor(db.upcast()) + } + + /// Returns a type suffix if a generic function was monomorphized + pub fn type_suffix(&self, db: &dyn MirDb) -> SmolStr { + self.signature(db) + .resolved_generics + .values() + .fold(String::new(), |acc, param| { + format!("{}_{}", acc, param.display(db.upcast())) + }) + .into() + } + + pub fn name(&self, db: &dyn MirDb) -> SmolStr { + let analyzer_func = self.analyzer_func(db); + analyzer_func.name(db.upcast()) + } + + /// Returns `class_name::fn_name` if a function is a method else `fn_name`. + pub fn debug_name(self, db: &dyn MirDb) -> SmolStr { + let analyzer_func = self.analyzer_func(db); + let func_name = format!( + "{}{}", + analyzer_func.name(db.upcast()), + self.type_suffix(db) + ); + + match analyzer_func.sig(db.upcast()).self_item(db.upcast()) { + Some(Item::Impl(id)) => { + let class_name = format!( + "<{} as {}>", + id.receiver(db.upcast()).display(db.upcast()), + id.trait_id(db.upcast()).name(db.upcast()) + ); + format!("{class_name}::{func_name}").into() + } + Some(class) => { + let class_name = class.name(db.upcast()); + format!("{class_name}::{func_name}").into() + } + _ => func_name.into(), + } + } + + pub fn returns_aggregate(self, db: &dyn MirDb) -> bool { + self.return_type(db) + .map(|ty| ty.is_aggregate(db)) + .unwrap_or_default() + } +} diff --git a/crates/mir2-analysis/src/db/queries/module.rs b/crates/mir2-analysis/src/db/queries/module.rs new file mode 100644 index 0000000000..b7d00521ac --- /dev/null +++ b/crates/mir2-analysis/src/db/queries/module.rs @@ -0,0 +1,35 @@ +use std::rc::Rc; + +use fe_analyzer2::namespace::items::{self as analyzer_items, TypeDef}; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_module_all_functions( + db: &dyn MirDb, + module: analyzer_items::ModuleId, +) -> Rc> { + let mut functions = vec![]; + + let items = module.all_items(db.upcast()); + items.iter().for_each(|item| match item { + analyzer_items::Item::Function(func) => { + functions.push(db.mir_lowered_func_signature(*func)) + } + + analyzer_items::Item::Type(TypeDef::Contract(contract)) => { + functions.extend_from_slice(&db.mir_lower_contract_all_functions(*contract)) + } + + analyzer_items::Item::Type(TypeDef::Struct(struct_)) => { + functions.extend_from_slice(&db.mir_lower_struct_all_functions(*struct_)) + } + + analyzer_items::Item::Type(TypeDef::Enum(enum_)) => { + functions.extend_from_slice(&db.mir_lower_enum_all_functions(*enum_)) + } + + _ => {} + }); + + functions.into() +} diff --git a/crates/mir2-analysis/src/db/queries/structs.rs b/crates/mir2-analysis/src/db/queries/structs.rs new file mode 100644 index 0000000000..c1a859718d --- /dev/null +++ b/crates/mir2-analysis/src/db/queries/structs.rs @@ -0,0 +1,17 @@ +use std::rc::Rc; + +use fe_analyzer2::namespace::items::{self as analyzer_items}; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_struct_all_functions( + db: &dyn MirDb, + struct_: analyzer_items::StructId, +) -> Rc> { + struct_ + .all_functions(db.upcast()) + .iter() + .map(|func| db.mir_lowered_pseudo_monomorphized_func_signature(*func)) + .collect::>() + .into() +} diff --git a/crates/mir2-analysis/src/db/queries/types.rs b/crates/mir2-analysis/src/db/queries/types.rs new file mode 100644 index 0000000000..fe1261bfc5 --- /dev/null +++ b/crates/mir2-analysis/src/db/queries/types.rs @@ -0,0 +1,657 @@ +use std::{fmt, rc::Rc, str::FromStr}; + +use fe_analyzer2::namespace::{items::EnumVariantId, types as analyzer_types}; + +use num_bigint::BigInt; +use num_traits::ToPrimitive; + +use crate::{ + db::MirDb, + ir::{ + types::{ArrayDef, TupleDef, TypeKind}, + Type, TypeId, Value, + }, + lower::types::lower_type, +}; + +pub fn mir_lowered_type(db: &dyn MirDb, analyzer_type: analyzer_types::TypeId) -> TypeId { + lower_type(db, analyzer_type) +} + +impl TypeId { + pub fn data(self, db: &dyn MirDb) -> Rc { + db.lookup_mir_intern_type(self) + } + + pub fn analyzer_ty(self, db: &dyn MirDb) -> Option { + self.data(db).analyzer_ty + } + + pub fn projection_ty(self, db: &dyn MirDb, access: &Value) -> TypeId { + let ty = self.deref(db); + let pty = match &ty.data(db).kind { + TypeKind::Array(ArrayDef { elem_ty, .. }) => *elem_ty, + TypeKind::Tuple(def) => { + let index = expect_projection_index(access); + def.items[index] + } + TypeKind::Struct(def) | TypeKind::Contract(def) => { + let index = expect_projection_index(access); + def.fields[index].1 + } + TypeKind::Enum(_) => { + let index = expect_projection_index(access); + debug_assert_eq!(index, 0); + ty.projection_ty_imm(db, 0) + } + _ => panic!("{:?} can't project onto the `access`", self.as_string(db)), + }; + match &self.data(db).kind { + TypeKind::SPtr(_) | TypeKind::Contract(_) => pty.make_sptr(db), + TypeKind::MPtr(_) => pty.make_mptr(db), + _ => pty, + } + } + + pub fn deref(self, db: &dyn MirDb) -> TypeId { + match self.data(db).kind { + TypeKind::SPtr(inner) => inner, + TypeKind::MPtr(inner) => inner.deref(db), + _ => self, + } + } + + pub fn make_sptr(self, db: &dyn MirDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::SPtr(self), None).into()) + } + + pub fn make_mptr(self, db: &dyn MirDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::MPtr(self), None).into()) + } + + pub fn projection_ty_imm(self, db: &dyn MirDb, index: usize) -> TypeId { + match &self.data(db).kind { + TypeKind::Array(ArrayDef { elem_ty, .. }) => *elem_ty, + TypeKind::Tuple(def) => def.items[index], + TypeKind::Struct(def) | TypeKind::Contract(def) => def.fields[index].1, + TypeKind::Enum(_) => { + debug_assert_eq!(index, 0); + self.enum_disc_type(db) + } + _ => panic!("{:?} can't project onto the `index`", self.as_string(db)), + } + } + + pub fn aggregate_field_num(self, db: &dyn MirDb) -> usize { + match &self.data(db).kind { + TypeKind::Array(ArrayDef { len, .. }) => *len, + TypeKind::Tuple(def) => def.items.len(), + TypeKind::Struct(def) | TypeKind::Contract(def) => def.fields.len(), + TypeKind::Enum(_) => 2, + _ => unreachable!(), + } + } + + pub fn enum_disc_type(self, db: &dyn MirDb) -> TypeId { + let kind = match &self.deref(db).data(db).kind { + TypeKind::Enum(def) => def.tag_type(), + _ => unreachable!(), + }; + let analyzer_type = match kind { + TypeKind::U8 => Some(analyzer_types::Integer::U8), + TypeKind::U16 => Some(analyzer_types::Integer::U16), + TypeKind::U32 => Some(analyzer_types::Integer::U32), + TypeKind::U64 => Some(analyzer_types::Integer::U64), + TypeKind::U128 => Some(analyzer_types::Integer::U128), + TypeKind::U256 => Some(analyzer_types::Integer::U256), + _ => None, + } + .map(|int| analyzer_types::TypeId::int(db.upcast(), int)); + + db.mir_intern_type(Type::new(kind, analyzer_type).into()) + } + + pub fn enum_data_offset(self, db: &dyn MirDb, slot_size: usize) -> usize { + match &self.data(db).kind { + TypeKind::Enum(def) => { + let disc_size = self.enum_disc_type(db).size_of(db, slot_size); + let mut align = 1; + for variant in def.variants.iter() { + let variant_align = variant.ty.align_of(db, slot_size); + align = num_integer::lcm(align, variant_align); + } + round_up(disc_size, align) + } + _ => unreachable!(), + } + } + + pub fn enum_variant_type(self, db: &dyn MirDb, variant_id: EnumVariantId) -> TypeId { + let name = variant_id.name(db.upcast()); + match &self.deref(db).data(db).kind { + TypeKind::Enum(def) => def + .variants + .iter() + .find(|variant| variant.name == name) + .map(|variant| variant.ty) + .unwrap(), + _ => unreachable!(), + } + } + + pub fn index_from_fname(self, db: &dyn MirDb, fname: &str) -> BigInt { + let ty = self.deref(db); + match &ty.data(db).kind { + TypeKind::Tuple(_) => { + // TODO: Fix this when the syntax for tuple access changes. + let index_str = &fname[4..]; + BigInt::from_str(index_str).unwrap() + } + + TypeKind::Struct(def) | TypeKind::Contract(def) => def + .fields + .iter() + .enumerate() + .find_map(|(i, field)| (field.0 == fname).then(|| i.into())) + .unwrap(), + + other => unreachable!("{:?} does not have fields", other), + } + } + + pub fn is_primitive(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::I256 + | TypeKind::U8 + | TypeKind::U16 + | TypeKind::U32 + | TypeKind::U64 + | TypeKind::U128 + | TypeKind::U256 + | TypeKind::Bool + | TypeKind::Address + | TypeKind::Unit + ) + } + + pub fn is_integral(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::I256 + | TypeKind::U8 + | TypeKind::U16 + | TypeKind::U32 + | TypeKind::U64 + | TypeKind::U128 + | TypeKind::U256 + ) + } + + pub fn is_address(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).kind, TypeKind::Address) + } + + pub fn is_unit(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).as_ref().kind, TypeKind::Unit) + } + + pub fn is_enum(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).as_ref().kind, TypeKind::Enum(_)) + } + + pub fn is_signed(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::I256 + ) + } + + /// Returns size of the type in bytes. + pub fn size_of(self, db: &dyn MirDb, slot_size: usize) -> usize { + match &self.data(db).kind { + TypeKind::Bool | TypeKind::I8 | TypeKind::U8 => 1, + TypeKind::I16 | TypeKind::U16 => 2, + TypeKind::I32 | TypeKind::U32 => 4, + TypeKind::I64 | TypeKind::U64 => 8, + TypeKind::I128 | TypeKind::U128 => 16, + TypeKind::String(len) => 32 + len, + TypeKind::MPtr(..) + | TypeKind::SPtr(..) + | TypeKind::I256 + | TypeKind::U256 + | TypeKind::Map(_) => 32, + TypeKind::Address => 20, + TypeKind::Unit => 0, + + TypeKind::Array(def) => array_elem_size_imp(db, def, slot_size) * def.len, + + TypeKind::Tuple(def) => { + if def.items.is_empty() { + return 0; + } + let last_idx = def.items.len() - 1; + self.aggregate_elem_offset(db, last_idx, slot_size) + + def.items[last_idx].size_of(db, slot_size) + } + + TypeKind::Struct(def) | TypeKind::Contract(def) => { + if def.fields.is_empty() { + return 0; + } + let last_idx = def.fields.len() - 1; + self.aggregate_elem_offset(db, last_idx, slot_size) + + def.fields[last_idx].1.size_of(db, slot_size) + } + + TypeKind::Enum(def) => { + let data_offset = self.enum_data_offset(db, slot_size); + let maximum_data_size = def + .variants + .iter() + .map(|variant| variant.ty.size_of(db, slot_size)) + .max() + .unwrap_or(0); + data_offset + maximum_data_size + } + } + } + + pub fn is_zero_sized(self, db: &dyn MirDb) -> bool { + // It's ok to use 1 as a slot size because slot size doesn't affect whether a + // type is zero sized or not. + self.size_of(db, 1) == 0 + } + + pub fn align_of(self, db: &dyn MirDb, slot_size: usize) -> usize { + if self.is_primitive(db) { + 1 + } else { + // TODO: Too naive, we could implement more efficient layout for aggregate + // types. + slot_size + } + } + + /// Returns an offset of the element of aggregate type. + pub fn aggregate_elem_offset(self, db: &dyn MirDb, elem_idx: T, slot_size: usize) -> usize + where + T: num_traits::ToPrimitive, + { + debug_assert!(self.is_aggregate(db)); + debug_assert!(elem_idx.to_usize().unwrap() < self.aggregate_field_num(db)); + let elem_idx = elem_idx.to_usize().unwrap(); + + if elem_idx == 0 { + return 0; + } + + match &self.data(db).kind { + TypeKind::Array(def) => array_elem_size_imp(db, def, slot_size) * elem_idx, + TypeKind::Enum(_) => self.enum_data_offset(db, slot_size), + _ => { + let mut offset = self.aggregate_elem_offset(db, elem_idx - 1, slot_size) + + self + .projection_ty_imm(db, elem_idx - 1) + .size_of(db, slot_size); + + let elem_ty = self.projection_ty_imm(db, elem_idx); + if (offset % slot_size + elem_ty.size_of(db, slot_size)) > slot_size { + offset = round_up(offset, slot_size); + } + + round_up(offset, elem_ty.align_of(db, slot_size)) + } + } + } + + pub fn is_aggregate(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::Array(_) + | TypeKind::Tuple(_) + | TypeKind::Struct(_) + | TypeKind::Enum(_) + | TypeKind::Contract(_) + ) + } + + pub fn is_struct(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).as_ref().kind, TypeKind::Struct(_)) + } + + pub fn is_array(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).kind, TypeKind::Array(_)) + } + + pub fn is_string(self, db: &dyn MirDb) -> bool { + matches! { + &self.data(db).kind, + TypeKind::String(_) + } + } + + pub fn is_ptr(self, db: &dyn MirDb) -> bool { + self.is_mptr(db) || self.is_sptr(db) + } + + pub fn is_mptr(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::MPtr(_)) + } + + pub fn is_sptr(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::SPtr(_)) + } + + pub fn is_map(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::Map(_)) + } + + pub fn is_contract(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::Contract(_)) + } + + pub fn array_elem_size(self, db: &dyn MirDb, slot_size: usize) -> usize { + let data = self.data(db); + if let TypeKind::Array(def) = &data.kind { + array_elem_size_imp(db, def, slot_size) + } else { + panic!("expected `Array` type; but got {:?}", data.as_ref()) + } + } + + pub fn print(&self, db: &dyn MirDb, w: &mut W) -> fmt::Result { + match &self.data(db).kind { + TypeKind::I8 => write!(w, "i8"), + TypeKind::I16 => write!(w, "i16"), + TypeKind::I32 => write!(w, "i32"), + TypeKind::I64 => write!(w, "i64"), + TypeKind::I128 => write!(w, "i128"), + TypeKind::I256 => write!(w, "i256"), + TypeKind::U8 => write!(w, "u8"), + TypeKind::U16 => write!(w, "u16"), + TypeKind::U32 => write!(w, "u32"), + TypeKind::U64 => write!(w, "u64"), + TypeKind::U128 => write!(w, "u128"), + TypeKind::U256 => write!(w, "u256"), + TypeKind::Bool => write!(w, "bool"), + TypeKind::Address => write!(w, "address"), + TypeKind::Unit => write!(w, "()"), + TypeKind::String(size) => write!(w, "Str<{size}>"), + TypeKind::Array(ArrayDef { elem_ty, len }) => { + write!(w, "[")?; + elem_ty.print(db, w)?; + write!(w, "; {len}]") + } + TypeKind::Tuple(TupleDef { items }) => { + write!(w, "(")?; + if items.is_empty() { + return write!(w, ")"); + } + + let len = items.len(); + for item in &items[0..len - 1] { + item.print(db, w)?; + write!(w, ", ")?; + } + items.last().unwrap().print(db, w)?; + write!(w, ")") + } + TypeKind::Struct(def) => { + write!(w, "{}", def.name) + } + TypeKind::Enum(def) => { + write!(w, "{}", def.name) + } + TypeKind::Contract(def) => { + write!(w, "{}", def.name) + } + TypeKind::Map(def) => { + write!(w, "Map<")?; + def.key_ty.print(db, w)?; + write!(w, ",")?; + def.value_ty.print(db, w)?; + write!(w, ">") + } + TypeKind::MPtr(inner) => { + write!(w, "*@m ")?; + inner.print(db, w) + } + TypeKind::SPtr(inner) => { + write!(w, "*@s ")?; + inner.print(db, w) + } + } + } + + pub fn as_string(&self, db: &dyn MirDb) -> String { + let mut s = String::new(); + self.print(db, &mut s).unwrap(); + s + } +} + +fn array_elem_size_imp(db: &dyn MirDb, arr: &ArrayDef, slot_size: usize) -> usize { + let elem_ty = arr.elem_ty; + let elem = elem_ty.size_of(db, slot_size); + let align = if elem_ty.is_address(db) { + slot_size + } else { + elem_ty.align_of(db, slot_size) + }; + round_up(elem, align) +} + +fn expect_projection_index(value: &Value) -> usize { + match value { + Value::Immediate { imm, .. } => imm.to_usize().unwrap(), + _ => panic!("given `value` is not an immediate"), + } +} + +fn round_up(value: usize, slot_size: usize) -> usize { + ((value + slot_size - 1) / slot_size) * slot_size +} + +#[cfg(test)] +mod tests { + use fe_analyzer2::namespace::items::ModuleId; + use fe_common2::Span; + + use super::*; + use crate::{ + db::{MirDb, NewDb}, + ir::types::StructDef, + }; + + #[test] + fn test_primitive_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let bool = db.mir_intern_type(Type::new(TypeKind::Bool, None).into()); + + debug_assert_eq!(i8.size_of(&db, 1), 1); + debug_assert_eq!(i8.size_of(&db, 32), 1); + debug_assert_eq!(i8.align_of(&db, 1), 1); + debug_assert_eq!(i8.align_of(&db, 32), 1); + debug_assert_eq!(bool.size_of(&db, 1), 1); + debug_assert_eq!(bool.size_of(&db, 32), 1); + debug_assert_eq!(i8.align_of(&db, 32), 1); + debug_assert_eq!(i8.align_of(&db, 32), 1); + + let u32 = db.mir_intern_type(Type::new(TypeKind::U32, None).into()); + debug_assert_eq!(u32.size_of(&db, 1), 4); + debug_assert_eq!(u32.size_of(&db, 32), 4); + debug_assert_eq!(u32.align_of(&db, 32), 1); + + let address = db.mir_intern_type(Type::new(TypeKind::Address, None).into()); + debug_assert_eq!(address.size_of(&db, 1), 20); + debug_assert_eq!(address.size_of(&db, 32), 20); + debug_assert_eq!(address.align_of(&db, 32), 1); + } + + #[test] + fn test_primitive_elem_array_type_info() { + let db = NewDb::default(); + let i32 = db.mir_intern_type(Type::new(TypeKind::I32, None).into()); + + let array_len = 10; + let array_def = ArrayDef { + elem_ty: i32, + len: array_len, + }; + let array = db.mir_intern_type(Type::new(TypeKind::Array(array_def), None).into()); + + let elem_size = array.array_elem_size(&db, 1); + debug_assert_eq!(elem_size, 4); + debug_assert_eq!(array.array_elem_size(&db, 32), elem_size); + + debug_assert_eq!(array.size_of(&db, 1), elem_size * array_len); + debug_assert_eq!(array.size_of(&db, 32), elem_size * array_len); + debug_assert_eq!(array.align_of(&db, 1), 1); + debug_assert_eq!(array.align_of(&db, 32), 32); + + debug_assert_eq!(array.aggregate_elem_offset(&db, 3, 32), elem_size * 3); + debug_assert_eq!(array.aggregate_elem_offset(&db, 9, 1), elem_size * 9); + } + + #[test] + fn test_aggregate_elem_array_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let i64 = db.mir_intern_type(Type::new(TypeKind::I64, None).into()); + let i128 = db.mir_intern_type(Type::new(TypeKind::I128, None).into()); + + let fields = vec![ + ("".into(), i64), + ("".into(), i64), + ("".into(), i8), + ("".into(), i128), + ("".into(), i8), + ]; + + let struct_def = StructDef { + name: "".into(), + fields, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate = db.mir_intern_type(Type::new(TypeKind::Struct(struct_def), None).into()); + + let array_len = 10; + let array_def = ArrayDef { + elem_ty: aggregate, + len: array_len, + }; + let array = db.mir_intern_type(Type::new(TypeKind::Array(array_def), None).into()); + + debug_assert_eq!(array.array_elem_size(&db, 1), 34); + debug_assert_eq!(array.array_elem_size(&db, 32), 64); + + debug_assert_eq!(array.size_of(&db, 1), 34 * array_len); + debug_assert_eq!(array.size_of(&db, 32), 64 * array_len); + + debug_assert_eq!(array.align_of(&db, 1), 1); + debug_assert_eq!(array.align_of(&db, 32), 32); + + debug_assert_eq!(array.aggregate_elem_offset(&db, 3, 1), 102); + debug_assert_eq!(array.aggregate_elem_offset(&db, 3, 32), 192); + } + + #[test] + fn test_primitive_elem_aggregate_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let i64 = db.mir_intern_type(Type::new(TypeKind::I64, None).into()); + let i128 = db.mir_intern_type(Type::new(TypeKind::I128, None).into()); + + let fields = vec![ + ("".into(), i64), + ("".into(), i64), + ("".into(), i8), + ("".into(), i128), + ("".into(), i8), + ]; + + let struct_def = StructDef { + name: "".into(), + fields, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate = db.mir_intern_type(Type::new(TypeKind::Struct(struct_def), None).into()); + + debug_assert_eq!(aggregate.size_of(&db, 1), 34); + debug_assert_eq!(aggregate.size_of(&db, 32), 49); + + debug_assert_eq!(aggregate.align_of(&db, 1), 1); + debug_assert_eq!(aggregate.align_of(&db, 32), 32); + + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 1), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 32), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 3, 1), 17); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 3, 32), 32); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 4, 1), 33); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 4, 32), 48); + } + + #[test] + fn test_aggregate_elem_aggregate_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let i64 = db.mir_intern_type(Type::new(TypeKind::I64, None).into()); + let i128 = db.mir_intern_type(Type::new(TypeKind::I128, None).into()); + + let fields_inner = vec![ + ("".into(), i64), + ("".into(), i64), + ("".into(), i8), + ("".into(), i128), + ("".into(), i8), + ]; + + let struct_def_inner = StructDef { + name: "".into(), + fields: fields_inner, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate_inner = + db.mir_intern_type(Type::new(TypeKind::Struct(struct_def_inner), None).into()); + + let fields = vec![("".into(), i8), ("".into(), aggregate_inner)]; + let struct_def = StructDef { + name: "".into(), + fields, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate = db.mir_intern_type(Type::new(TypeKind::Struct(struct_def), None).into()); + + debug_assert_eq!(aggregate.size_of(&db, 1), 35); + debug_assert_eq!(aggregate.size_of(&db, 32), 81); + + debug_assert_eq!(aggregate.align_of(&db, 1), 1); + debug_assert_eq!(aggregate.align_of(&db, 32), 32); + + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 1), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 32), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 1, 1), 1); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 1, 32), 32); + } +} diff --git a/crates/mir2-analysis/src/lib.rs b/crates/mir2-analysis/src/lib.rs new file mode 100644 index 0000000000..b246e3cfac --- /dev/null +++ b/crates/mir2-analysis/src/lib.rs @@ -0,0 +1,121 @@ +use fe_mir2::{ir, MirDb}; + +#[salsa::jar(db = MirAnalysisDb)] +pub struct Jar(ir::ConstantId, ir::FunctionId); + +pub trait HirAnalysisDb: salsa::DbWithJar + HirDb { + fn as_hir_analysis_db(&self) -> &dyn HirAnalysisDb { + >::as_jar_db::<'_>(self) + } +} +impl HirAnalysisDb for DB where DB: ?Sized + salsa::DbWithJar + HirDb {} + +pub mod name_resolution; +pub mod ty; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Spanned { + pub data: T, + pub span: DynLazySpan, +} + +// old mir db.rs +// +// #![allow(clippy::arc_with_non_send_sync)] +// use std::{collections::BTreeMap, rc::Rc}; + +// use smol_str::SmolStr; + +// use crate::ir::{self, ConstantId, TypeId}; + +// mod queries; + +// #[salsa::query_group(MirDbStorage)] +// pub trait MirDb: AnalyzerDb + Upcast + UpcastMut { +// #[salsa::interned] +// fn mir_intern_const(&self, data: Rc) -> ir::ConstantId; +// #[salsa::interned] +// fn mir_intern_type(&self, data: Rc) -> ir::TypeId; +// #[salsa::interned] +// fn mir_intern_function(&self, data: Rc) -> ir::FunctionId; + +// #[salsa::invoke(queries::module::mir_lower_module_all_functions)] +// fn mir_lower_module_all_functions( +// &self, +// module: analyzer_items::ModuleId, +// ) -> Rc>; + +// #[salsa::invoke(queries::contract::mir_lower_contract_all_functions)] +// fn mir_lower_contract_all_functions( +// &self, +// contract: analyzer_items::ContractId, +// ) -> Rc>; + +// #[salsa::invoke(queries::structs::mir_lower_struct_all_functions)] +// fn mir_lower_struct_all_functions( +// &self, +// struct_: analyzer_items::StructId, +// ) -> Rc>; + +// #[salsa::invoke(queries::enums::mir_lower_enum_all_functions)] +// fn mir_lower_enum_all_functions( +// &self, +// enum_: analyzer_items::EnumId, +// ) -> Rc>; + +// #[salsa::invoke(queries::types::mir_lowered_type)] +// fn mir_lowered_type(&self, analyzer_type: analyzer_types::TypeId) -> TypeId; + +// #[salsa::invoke(queries::constant::mir_lowered_constant)] +// fn mir_lowered_constant(&self, analyzer_const: analyzer_items::ModuleConstantId) -> ConstantId; + +// #[salsa::invoke(queries::function::mir_lowered_func_signature)] +// fn mir_lowered_func_signature( +// &self, +// analyzer_func: analyzer_items::FunctionId, +// ) -> ir::FunctionId; +// #[salsa::invoke(queries::function::mir_lowered_monomorphized_func_signature)] +// fn mir_lowered_monomorphized_func_signature( +// &self, +// analyzer_func: analyzer_items::FunctionId, +// resolved_generics: BTreeMap, +// ) -> ir::FunctionId; +// #[salsa::invoke(queries::function::mir_lowered_pseudo_monomorphized_func_signature)] +// fn mir_lowered_pseudo_monomorphized_func_signature( +// &self, +// analyzer_func: analyzer_items::FunctionId, +// ) -> ir::FunctionId; +// #[salsa::invoke(queries::function::mir_lowered_func_body)] +// fn mir_lowered_func_body(&self, func: ir::FunctionId) -> Rc; +// } + +// #[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage)] +// #[derive(Default)] +// pub struct NewDb { +// storage: salsa::Storage, +// } +// impl salsa::Database for NewDb {} + +// impl Upcast for NewDb { +// fn upcast(&self) -> &(dyn SourceDb + 'static) { +// self +// } +// } + +// impl UpcastMut for NewDb { +// fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { +// &mut *self +// } +// } + +// impl Upcast for NewDb { +// fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { +// self +// } +// } + +// impl UpcastMut for NewDb { +// fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { +// &mut *self +// } +// } diff --git a/crates/mir2/Cargo.toml b/crates/mir2/Cargo.toml new file mode 100644 index 0000000000..941bf8c549 --- /dev/null +++ b/crates/mir2/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "fe-mir2" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/ethereum/fe" + +[dependencies] +common = { path = "../common2", package = "fe-common2" } +parser = { path = "../parser2", package = "fe-parser2" } +hir-analysis = { path = "../hir-analysis", package = "fe-hir-analysis" } +hir = { path = "../hir", package = "fe-hir" } +salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } +smol_str = "0.1.21" +num-bigint = "0.4.3" +num-traits = "0.2.14" +num-integer = "0.1.45" +id-arena = "2.2.1" +fxhash = "0.2.1" +dot2 = "1.0.0" +indexmap = "1.6.2" + +[dev-dependencies] +test-files = { path = "../test-files", package = "fe-test-files" } +library = { path = "../library2" , package = "fe-library2"} diff --git a/crates/mir2/src/analysis/cfg.rs b/crates/mir2/src/analysis/cfg.rs new file mode 100644 index 0000000000..d4de8b2cfa --- /dev/null +++ b/crates/mir2/src/analysis/cfg.rs @@ -0,0 +1,164 @@ +use fxhash::FxHashMap; + +use crate::ir::{BasicBlockId, FunctionBody, InstId}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ControlFlowGraph { + entry: BasicBlockId, + blocks: FxHashMap, + pub(super) exits: Vec, +} + +impl ControlFlowGraph { + pub fn compute(func: &FunctionBody) -> Self { + let entry = func.order.entry(); + let mut cfg = Self { + entry, + blocks: FxHashMap::default(), + exits: vec![], + }; + + for block in func.order.iter_block() { + let terminator = func + .order + .terminator(&func.store, block) + .expect("a block must have terminator"); + cfg.analyze_terminator(func, terminator); + } + + cfg + } + + pub fn entry(&self) -> BasicBlockId { + self.entry + } + + pub fn preds(&self, block: BasicBlockId) -> &[BasicBlockId] { + self.blocks[&block].preds() + } + + pub fn succs(&self, block: BasicBlockId) -> &[BasicBlockId] { + self.blocks[&block].succs() + } + + pub fn post_order(&self) -> CfgPostOrder { + CfgPostOrder::new(self) + } + + pub(super) fn add_edge(&mut self, from: BasicBlockId, to: BasicBlockId) { + self.node_mut(to).push_pred(from); + self.node_mut(from).push_succ(to); + } + + pub(super) fn reverse_edge(&mut self, new_entry: BasicBlockId, new_exits: Vec) { + for (_, block) in self.blocks.iter_mut() { + block.reverse_edge() + } + + self.entry = new_entry; + self.exits = new_exits; + } + + fn analyze_terminator(&mut self, func: &FunctionBody, terminator: InstId) { + let block = func.order.inst_block(terminator); + let branch_info = func.store.branch_info(terminator); + if branch_info.is_not_a_branch() { + self.node_mut(block); + self.exits.push(block) + } else { + for dest in branch_info.block_iter() { + self.add_edge(block, dest) + } + } + } + + fn node_mut(&mut self, block: BasicBlockId) -> &mut BlockNode { + self.blocks.entry(block).or_default() + } +} + +#[derive(Default, Clone, Debug, PartialEq, Eq)] +struct BlockNode { + preds: Vec, + succs: Vec, +} + +impl BlockNode { + fn push_pred(&mut self, pred: BasicBlockId) { + self.preds.push(pred); + } + + fn push_succ(&mut self, succ: BasicBlockId) { + self.succs.push(succ); + } + + fn preds(&self) -> &[BasicBlockId] { + &self.preds + } + + fn succs(&self) -> &[BasicBlockId] { + &self.succs + } + + fn reverse_edge(&mut self) { + std::mem::swap(&mut self.preds, &mut self.succs) + } +} + +pub struct CfgPostOrder<'a> { + cfg: &'a ControlFlowGraph, + node_state: FxHashMap, + stack: Vec, +} + +impl<'a> CfgPostOrder<'a> { + fn new(cfg: &'a ControlFlowGraph) -> Self { + let stack = vec![cfg.entry()]; + + Self { + cfg, + node_state: FxHashMap::default(), + stack, + } + } +} + +impl<'a> Iterator for CfgPostOrder<'a> { + type Item = BasicBlockId; + + fn next(&mut self) -> Option { + while let Some(&block) = self.stack.last() { + let node_state = self.node_state.entry(block).or_default(); + if *node_state == NodeState::Unvisited { + *node_state = NodeState::Visited; + for &succ in self.cfg.succs(block) { + let pred_state = self.node_state.entry(succ).or_default(); + if *pred_state == NodeState::Unvisited { + self.stack.push(succ); + } + } + } else { + self.stack.pop().unwrap(); + if *node_state != NodeState::Finished { + *node_state = NodeState::Finished; + return Some(block); + } + } + } + + None + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum NodeState { + Unvisited, + Visited, + Finished, +} + +impl Default for NodeState { + fn default() -> Self { + Self::Unvisited + } +} diff --git a/crates/mir2/src/analysis/domtree.rs b/crates/mir2/src/analysis/domtree.rs new file mode 100644 index 0000000000..cf30381e46 --- /dev/null +++ b/crates/mir2/src/analysis/domtree.rs @@ -0,0 +1,343 @@ +//! This module contains dominantor tree related structs. +//! +//! The algorithm is based on Keith D. Cooper., Timothy J. Harvey., and Ken +//! Kennedy.: A Simple, Fast Dominance Algorithm: + +use std::collections::BTreeSet; + +use fxhash::FxHashMap; + +use crate::ir::BasicBlockId; + +use super::cfg::ControlFlowGraph; + +#[derive(Debug, Clone)] +pub struct DomTree { + doms: FxHashMap, + /// CFG sorted in reverse post order. + rpo: Vec, +} + +impl DomTree { + pub fn compute(cfg: &ControlFlowGraph) -> Self { + let mut doms = FxHashMap::default(); + doms.insert(cfg.entry(), cfg.entry()); + let mut rpo: Vec<_> = cfg.post_order().collect(); + rpo.reverse(); + + let mut domtree = Self { doms, rpo }; + + let block_num = domtree.rpo.len(); + + let mut rpo_nums = FxHashMap::default(); + for (i, &block) in domtree.rpo.iter().enumerate() { + rpo_nums.insert(block, (block_num - i) as u32); + } + + let mut changed = true; + while changed { + changed = false; + for &block in domtree.rpo.iter().skip(1) { + let processed_pred = match cfg + .preds(block) + .iter() + .find(|pred| domtree.doms.contains_key(pred)) + { + Some(pred) => *pred, + _ => continue, + }; + let mut new_dom = processed_pred; + + for &pred in cfg.preds(block) { + if pred != processed_pred && domtree.doms.contains_key(&pred) { + new_dom = domtree.intersect(new_dom, pred, &rpo_nums); + } + } + if Some(new_dom) != domtree.doms.get(&block).copied() { + changed = true; + domtree.doms.insert(block, new_dom); + } + } + } + + domtree + } + + /// Returns the immediate dominator of the `block`. + /// Returns None if the `block` is unreachable from the entry block, or the + /// `block` is the entry block itself. + pub fn idom(&self, block: BasicBlockId) -> Option { + if self.rpo[0] == block { + return None; + } + self.doms.get(&block).copied() + } + + /// Returns `true` if block1 strictly dominates block2. + pub fn strictly_dominates(&self, block1: BasicBlockId, block2: BasicBlockId) -> bool { + let mut current_block = block2; + while let Some(block) = self.idom(current_block) { + if block == block1 { + return true; + } + current_block = block; + } + + false + } + + /// Returns `true` if block1 dominates block2. + pub fn dominates(&self, block1: BasicBlockId, block2: BasicBlockId) -> bool { + if block1 == block2 { + return true; + } + + self.strictly_dominates(block1, block2) + } + + /// Returns `true` if block is reachable from the entry block. + pub fn is_reachable(&self, block: BasicBlockId) -> bool { + self.idom(block).is_some() + } + + /// Returns blocks in RPO. + pub fn rpo(&self) -> &[BasicBlockId] { + &self.rpo + } + + fn intersect( + &self, + mut b1: BasicBlockId, + mut b2: BasicBlockId, + rpo_nums: &FxHashMap, + ) -> BasicBlockId { + while b1 != b2 { + while rpo_nums[&b1] < rpo_nums[&b2] { + b1 = self.doms[&b1]; + } + while rpo_nums[&b2] < rpo_nums[&b1] { + b2 = self.doms[&b2] + } + } + + b1 + } + + /// Compute dominance frontiers of each blocks. + pub fn compute_df(&self, cfg: &ControlFlowGraph) -> DFSet { + let mut df = DFSet::default(); + + for &block in &self.rpo { + let preds = cfg.preds(block); + if preds.len() < 2 { + continue; + } + + for pred in preds { + let mut runner = *pred; + while self.doms.get(&block) != Some(&runner) && self.is_reachable(runner) { + df.0.entry(runner).or_default().insert(block); + runner = self.doms[&runner]; + } + } + } + + df + } +} + +/// Dominance frontiers of each blocks. +#[derive(Default, Debug)] +pub struct DFSet(FxHashMap>); + +impl DFSet { + /// Returns all dominance frontieres of a `block`. + pub fn frontiers( + &self, + block: BasicBlockId, + ) -> Option + '_> { + self.0.get(&block).map(|set| set.iter().copied()) + } + + /// Returns number of frontier blocks of a `block`. + pub fn frontier_num(&self, block: BasicBlockId) -> usize { + self.0.get(&block).map(BTreeSet::len).unwrap_or(0) + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; + +// use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; + +// fn calc_dom(func: &FunctionBody) -> (DomTree, DFSet) { +// let cfg = ControlFlowGraph::compute(func); +// let domtree = DomTree::compute(&cfg); +// let df = domtree.compute_df(&cfg); +// (domtree, df) +// } + +// fn body_builder() -> BodyBuilder { +// BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) +// } + +// #[test] +// fn dom_tree_if_else() { +// let mut builder = body_builder(); + +// let then_block = builder.make_block(); +// let else_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(else_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let (dom_tree, df) = calc_dom(&func); +// let entry_block = func.order.entry(); +// assert_eq!(dom_tree.idom(entry_block), None); +// assert_eq!(dom_tree.idom(then_block), Some(entry_block)); +// assert_eq!(dom_tree.idom(else_block), Some(entry_block)); +// assert_eq!(dom_tree.idom(merge_block), Some(entry_block)); + +// assert_eq!(df.frontier_num(entry_block), 0); +// assert_eq!(df.frontier_num(then_block), 1); +// assert_eq!( +// df.frontiers(then_block).unwrap().next().unwrap(), +// merge_block +// ); +// assert_eq!( +// df.frontiers(else_block).unwrap().next().unwrap(), +// merge_block +// ); +// assert_eq!(df.frontier_num(merge_block), 0); +// } + +// #[test] +// fn unreachable_edge() { +// let mut builder = body_builder(); + +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, block1, block2, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.jump(block4, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.jump(block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// builder.jump(block4, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let (dom_tree, _) = calc_dom(&func); +// let entry_block = func.order.entry(); +// assert_eq!(dom_tree.idom(entry_block), None); +// assert_eq!(dom_tree.idom(block1), Some(entry_block)); +// assert_eq!(dom_tree.idom(block2), Some(entry_block)); +// assert_eq!(dom_tree.idom(block3), None); +// assert!(!dom_tree.is_reachable(block3)); +// assert_eq!(dom_tree.idom(block4), Some(entry_block)); +// } + +// #[test] +// fn dom_tree_complex() { +// let mut builder = body_builder(); + +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); +// let block5 = builder.make_block(); +// let block6 = builder.make_block(); +// let block7 = builder.make_block(); +// let block8 = builder.make_block(); +// let block9 = builder.make_block(); +// let block10 = builder.make_block(); +// let block11 = builder.make_block(); +// let block12 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, block2, block1, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.branch(v0, block6, block3, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.branch(v0, block7, block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// builder.branch(v0, block6, block5, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// builder.branch(v0, block7, block2, SourceInfo::dummy()); + +// builder.move_to_block(block5); +// builder.branch(v0, block10, block8, SourceInfo::dummy()); + +// builder.move_to_block(block6); +// builder.jump(block9, SourceInfo::dummy()); + +// builder.move_to_block(block7); +// builder.jump(block12, SourceInfo::dummy()); + +// builder.move_to_block(block8); +// builder.jump(block11, SourceInfo::dummy()); + +// builder.move_to_block(block9); +// builder.jump(block8, SourceInfo::dummy()); + +// builder.move_to_block(block10); +// builder.jump(block11, SourceInfo::dummy()); + +// builder.move_to_block(block11); +// builder.branch(v0, block12, block2, SourceInfo::dummy()); + +// builder.move_to_block(block12); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let (dom_tree, _) = calc_dom(&func); +// let entry_block = func.order.entry(); +// assert_eq!(dom_tree.idom(entry_block), None); +// assert_eq!(dom_tree.idom(block1), Some(entry_block)); +// assert_eq!(dom_tree.idom(block2), Some(entry_block)); +// assert_eq!(dom_tree.idom(block3), Some(block1)); +// assert_eq!(dom_tree.idom(block4), Some(block2)); +// assert_eq!(dom_tree.idom(block5), Some(block3)); +// assert_eq!(dom_tree.idom(block6), Some(block1)); +// assert_eq!(dom_tree.idom(block7), Some(block2)); +// assert_eq!(dom_tree.idom(block8), Some(block1)); +// assert_eq!(dom_tree.idom(block9), Some(block6)); +// assert_eq!(dom_tree.idom(block10), Some(block5)); +// assert_eq!(dom_tree.idom(block11), Some(block1)); +// assert_eq!(dom_tree.idom(block12), Some(entry_block)); +// } +// } diff --git a/crates/mir2/src/analysis/loop_tree.rs b/crates/mir2/src/analysis/loop_tree.rs new file mode 100644 index 0000000000..c818019b02 --- /dev/null +++ b/crates/mir2/src/analysis/loop_tree.rs @@ -0,0 +1,352 @@ +use id_arena::{Arena, Id}; + +use fxhash::FxHashMap; + +use super::{cfg::ControlFlowGraph, domtree::DomTree}; + +use crate::ir::BasicBlockId; + +#[derive(Debug, Default, Clone)] +pub struct LoopTree { + /// Stores loops. + /// The index of an outer loops is guaranteed to be lower than its inner + /// loops because loops are found in RPO. + loops: Arena, + + /// Maps blocks to its contained loop. + /// If the block is contained by multiple nested loops, then the block is + /// mapped to the innermost loop. + block_to_loop: FxHashMap, +} + +pub type LoopId = Id; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Loop { + /// A header of the loop. + pub header: BasicBlockId, + + /// A parent loop that includes the loop. + pub parent: Option, + + /// Child loops that the loop includes. + pub children: Vec, +} + +impl LoopTree { + pub fn compute(cfg: &ControlFlowGraph, domtree: &DomTree) -> Self { + let mut tree = LoopTree::default(); + + // Find loop headers in RPO, this means outer loops are guaranteed to be + // inserted first, then its inner loops are inserted. + for &block in domtree.rpo() { + for &pred in cfg.preds(block) { + if domtree.dominates(block, pred) { + let loop_data = Loop { + header: block, + parent: None, + children: Vec::new(), + }; + + tree.loops.alloc(loop_data); + break; + } + } + } + + tree.analyze_loops(cfg, domtree); + + tree + } + + /// Returns all blocks in the loop. + pub fn iter_blocks_post_order<'a, 'b>( + &'a self, + cfg: &'b ControlFlowGraph, + lp: LoopId, + ) -> BlocksInLoopPostOrder<'a, 'b> { + BlocksInLoopPostOrder::new(self, cfg, lp) + } + + /// Returns all loops in a function body. + /// An outer loop is guaranteed to be iterated before its inner loops. + pub fn loops(&self) -> impl Iterator + '_ { + self.loops.iter().map(|(id, _)| id) + } + + /// Returns number of loops found. + pub fn loop_num(&self) -> usize { + self.loops.len() + } + + /// Returns `true` if the `block` is in the `lp`. + pub fn is_block_in_loop(&self, block: BasicBlockId, lp: LoopId) -> bool { + let mut loop_of_block = self.loop_of_block(block); + while let Some(cur_lp) = loop_of_block { + if lp == cur_lp { + return true; + } + loop_of_block = self.parent_loop(cur_lp); + } + false + } + + /// Returns header block of the `lp`. + pub fn loop_header(&self, lp: LoopId) -> BasicBlockId { + self.loops[lp].header + } + + /// Get parent loop of the `lp` if exists. + pub fn parent_loop(&self, lp: LoopId) -> Option { + self.loops[lp].parent + } + + /// Returns the loop that the `block` belongs to. + /// If the `block` belongs to multiple loops, then returns the innermost + /// loop. + pub fn loop_of_block(&self, block: BasicBlockId) -> Option { + self.block_to_loop.get(&block).copied() + } + + /// Analyze loops. This method does + /// 1. Mapping each blocks to its contained loop. + /// 2. Setting parent and child of the loops. + fn analyze_loops(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) { + let mut worklist = vec![]; + + // Iterate loops reversely to ensure analyze inner loops first. + let loops_rev: Vec<_> = self.loops.iter().rev().map(|(id, _)| id).collect(); + for cur_lp in loops_rev { + let cur_lp_header = self.loop_header(cur_lp); + + // Add predecessors of the loop header to worklist. + for &block in cfg.preds(cur_lp_header) { + if domtree.dominates(cur_lp_header, block) { + worklist.push(block); + } + } + + while let Some(block) = worklist.pop() { + match self.block_to_loop.get(&block).copied() { + Some(lp_of_block) => { + let outermost_parent = self.outermost_parent(lp_of_block); + + // If outermost parent is current loop, then the block is already visited. + if outermost_parent == cur_lp { + continue; + } else { + self.loops[cur_lp].children.push(outermost_parent); + self.loops[outermost_parent].parent = cur_lp.into(); + + let lp_header_of_block = self.loop_header(lp_of_block); + worklist.extend(cfg.preds(lp_header_of_block)); + } + } + + // If the block is not mapped to any loops, then map it to the loop. + None => { + self.map_block(block, cur_lp); + // If block is not loop header, then add its predecessors to the worklist. + if block != cur_lp_header { + worklist.extend(cfg.preds(block)); + } + } + } + } + } + } + + /// Returns the outermost parent loop of `lp`. If `lp` doesn't have any + /// parent, then returns `lp` itself. + fn outermost_parent(&self, mut lp: LoopId) -> LoopId { + while let Some(parent) = self.parent_loop(lp) { + lp = parent; + } + lp + } + + /// Map `block` to `lp`. + fn map_block(&mut self, block: BasicBlockId, lp: LoopId) { + self.block_to_loop.insert(block, lp); + } +} + +pub struct BlocksInLoopPostOrder<'a, 'b> { + lpt: &'a LoopTree, + cfg: &'b ControlFlowGraph, + lp: LoopId, + stack: Vec, + block_state: FxHashMap, +} + +impl<'a, 'b> BlocksInLoopPostOrder<'a, 'b> { + fn new(lpt: &'a LoopTree, cfg: &'b ControlFlowGraph, lp: LoopId) -> Self { + let loop_header = lpt.loop_header(lp); + + Self { + lpt, + cfg, + lp, + stack: vec![loop_header], + block_state: FxHashMap::default(), + } + } +} + +impl<'a, 'b> Iterator for BlocksInLoopPostOrder<'a, 'b> { + type Item = BasicBlockId; + + fn next(&mut self) -> Option { + while let Some(&block) = self.stack.last() { + match self.block_state.get(&block) { + // The block is already visited, but not returned from the iterator, + // so mark the block as `Finished` and return the block. + Some(BlockState::Visited) => { + let block = self.stack.pop().unwrap(); + self.block_state.insert(block, BlockState::Finished); + return Some(block); + } + + // The block is already returned, so just remove the block from the stack. + Some(BlockState::Finished) => { + self.stack.pop().unwrap(); + } + + // The block is not visited yet, so push its unvisited in-loop successors to the + // stack and mark the block as `Visited`. + None => { + self.block_state.insert(block, BlockState::Visited); + for &succ in self.cfg.succs(block) { + if self.block_state.get(&succ).is_none() + && self.lpt.is_block_in_loop(succ, self.lp) + { + self.stack.push(succ); + } + } + } + } + } + + None + } +} + +enum BlockState { + Visited, + Finished, +} + +// #[cfg(test)] +// mod tests { +// use super::*; + +// // use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; +// use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId}; + +// fn compute_loop(func: &FunctionBody) -> LoopTree { +// let cfg = ControlFlowGraph::compute(func); +// let domtree = DomTree::compute(&cfg); +// LoopTree::compute(&cfg, &domtree) +// } + +// fn body_builder() -> BodyBuilder { +// // BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) +// BodyBuilder::new(FunctionId(0)) +// } + +// #[test] +// fn simple_loop() { +// let mut builder = body_builder(); + +// let entry = builder.current_block(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(false, dummy_ty); +// // builder.branch(v0, block1, block2, SourceInfo::dummy()); +// builder.branch(v0, block1, block2); + +// builder.move_to_block(block1); +// // builder.jump(entry, SourceInfo::dummy()); +// builder.jump(entry); + +// builder.move_to_block(block2); +// let dummy_value = builder.make_unit(dummy_ty); +// // builder.ret(dummy_value, SourceInfo::dummy()); +// builder.ret(dummy_value); + +// let func = builder.build(); + +// let lpt = compute_loop(&func); + +// assert_eq!(lpt.loop_num(), 1); +// let lp = lpt.loops().next().unwrap(); + +// assert!(lpt.is_block_in_loop(entry, lp)); +// assert_eq!(lpt.loop_of_block(entry), Some(lp)); + +// assert!(lpt.is_block_in_loop(block1, lp)); +// assert_eq!(lpt.loop_of_block(block1), Some(lp)); + +// assert!(!lpt.is_block_in_loop(block2, lp)); +// assert!(lpt.loop_of_block(block2).is_none()); + +// assert_eq!(lpt.loop_header(lp), entry); +// } + +// #[test] +// fn nested_loop() { +// let mut builder = body_builder(); + +// let entry = builder.current_block(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(false, dummy_ty); +// builder.branch(v0, block1, block3, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.branch(v0, entry, block2, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.jump(block1, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let lpt = compute_loop(&func); + +// assert_eq!(lpt.loop_num(), 2); +// let mut loops = lpt.loops(); +// let outer_lp = loops.next().unwrap(); +// let inner_lp = loops.next().unwrap(); + +// assert!(lpt.is_block_in_loop(entry, outer_lp)); +// assert!(!lpt.is_block_in_loop(entry, inner_lp)); +// assert_eq!(lpt.loop_of_block(entry), Some(outer_lp)); + +// assert!(lpt.is_block_in_loop(block1, outer_lp)); +// assert!(lpt.is_block_in_loop(block1, inner_lp)); +// assert_eq!(lpt.loop_of_block(block1), Some(inner_lp)); + +// assert!(lpt.is_block_in_loop(block2, outer_lp)); +// assert!(lpt.is_block_in_loop(block2, inner_lp)); +// assert_eq!(lpt.loop_of_block(block2), Some(inner_lp)); + +// assert!(!lpt.is_block_in_loop(block3, outer_lp)); +// assert!(!lpt.is_block_in_loop(block3, inner_lp)); +// assert!(lpt.loop_of_block(block3).is_none()); + +// assert!(lpt.parent_loop(outer_lp).is_none()); +// assert_eq!(lpt.parent_loop(inner_lp), Some(outer_lp)); + +// assert_eq!(lpt.loop_header(outer_lp), entry); +// assert_eq!(lpt.loop_header(inner_lp), block1); +// } +// } diff --git a/crates/mir2/src/analysis/mod.rs b/crates/mir2/src/analysis/mod.rs new file mode 100644 index 0000000000..b895cc02a7 --- /dev/null +++ b/crates/mir2/src/analysis/mod.rs @@ -0,0 +1,9 @@ +pub mod cfg; +pub mod domtree; +pub mod loop_tree; +pub mod post_domtree; + +pub use cfg::ControlFlowGraph; +pub use domtree::DomTree; +pub use loop_tree::LoopTree; +pub use post_domtree::PostDomTree; diff --git a/crates/mir2/src/analysis/post_domtree.rs b/crates/mir2/src/analysis/post_domtree.rs new file mode 100644 index 0000000000..9d034d2bb6 --- /dev/null +++ b/crates/mir2/src/analysis/post_domtree.rs @@ -0,0 +1,284 @@ +//! This module contains implementation of `Post Dominator Tree`. + +use id_arena::{ArenaBehavior, DefaultArenaBehavior}; + +use super::{cfg::ControlFlowGraph, domtree::DomTree}; + +use crate::ir::{BasicBlock, BasicBlockId, FunctionBody}; + +#[derive(Debug)] +pub struct PostDomTree { + /// Dummy entry block to calculate post dom tree. + dummy_entry: BasicBlockId, + /// Canonical dummy exit block to calculate post dom tree. All blocks ends + /// with `return` has an edge to this block. + dummy_exit: BasicBlockId, + + /// Dominator tree of reverse control flow graph. + domtree: DomTree, +} + +impl PostDomTree { + pub fn compute(func: &FunctionBody) -> Self { + let mut rcfg = ControlFlowGraph::compute(func); + + let real_entry = rcfg.entry(); + + let dummy_entry = Self::make_dummy_block(); + let dummy_exit = Self::make_dummy_block(); + // Add edges from dummy entry block to real entry block and dummy exit block. + rcfg.add_edge(dummy_entry, real_entry); + rcfg.add_edge(dummy_entry, dummy_exit); + + // Add edges from real exit blocks to dummy exit block. + for exit in std::mem::take(&mut rcfg.exits) { + rcfg.add_edge(exit, dummy_exit); + } + + rcfg.reverse_edge(dummy_exit, vec![dummy_entry]); + let domtree = DomTree::compute(&rcfg); + + Self { + dummy_entry, + dummy_exit, + domtree, + } + } + + pub fn post_idom(&self, block: BasicBlockId) -> PostIDom { + match self.domtree.idom(block).unwrap() { + block if block == self.dummy_entry => PostIDom::DummyEntry, + block if block == self.dummy_exit => PostIDom::DummyExit, + other => PostIDom::Block(other), + } + } + + /// Returns `true` if block is reachable from the exit blocks. + pub fn is_reachable(&self, block: BasicBlockId) -> bool { + self.domtree.is_reachable(block) + } + + fn make_dummy_block() -> BasicBlockId { + let arena_id = DefaultArenaBehavior::::new_arena_id(); + DefaultArenaBehavior::new_id(arena_id, 0) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PostIDom { + DummyEntry, + DummyExit, + Block(BasicBlockId), +} + +// #[cfg(test)] +// mod tests { +// use super::*; + +// use crate::ir::{body_builder::BodyBuilder, FunctionId, SourceInfo, TypeId}; + +// fn body_builder() -> BodyBuilder { +// BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) +// } + +// #[test] +// fn test_if_else_merge() { +// let mut builder = body_builder(); +// let then_block = builder.make_block(); +// let else_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(else_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(merge_block) +// ); +// assert_eq!( +// post_dom_tree.post_idom(then_block), +// PostIDom::Block(merge_block) +// ); +// assert_eq!( +// post_dom_tree.post_idom(else_block), +// PostIDom::Block(merge_block) +// ); +// assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); +// } + +// #[test] +// fn test_if_else_return() { +// let mut builder = body_builder(); +// let then_block = builder.make_block(); +// let else_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let dummy_value = builder.make_unit(dummy_ty); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(else_block); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!(post_dom_tree.post_idom(entry_block), PostIDom::DummyExit,); +// assert_eq!( +// post_dom_tree.post_idom(then_block), +// PostIDom::Block(merge_block), +// ); +// assert_eq!(post_dom_tree.post_idom(else_block), PostIDom::DummyExit); +// assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); +// } + +// #[test] +// fn test_if_non_else() { +// let mut builder = body_builder(); +// let then_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let dummy_value = builder.make_unit(dummy_ty); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, merge_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(merge_block), +// ); +// assert_eq!( +// post_dom_tree.post_idom(then_block), +// PostIDom::Block(merge_block), +// ); +// assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); +// } + +// #[test] +// fn test_loop() { +// let mut builder = body_builder(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); + +// builder.branch(v0, block1, block2, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.jump(block3, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.branch(v0, block3, block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// builder.jump(block2, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(block3), +// ); +// assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block3)); +// assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block3)); +// assert_eq!(post_dom_tree.post_idom(block3), PostIDom::DummyExit); +// assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block2)); +// } + +// #[test] +// fn test_pd_complex() { +// let mut builder = body_builder(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); +// let block5 = builder.make_block(); +// let block6 = builder.make_block(); +// let block7 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); + +// builder.branch(v0, block1, block2, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.jump(block6, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.branch(v0, block3, block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// builder.jump(block5, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// builder.jump(block5, SourceInfo::dummy()); + +// builder.move_to_block(block5); +// builder.jump(block6, SourceInfo::dummy()); + +// builder.move_to_block(block6); +// builder.jump(block7, SourceInfo::dummy()); + +// builder.move_to_block(block7); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(block6), +// ); +// assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block6)); +// assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block5)); +// assert_eq!(post_dom_tree.post_idom(block3), PostIDom::Block(block5)); +// assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block5)); +// assert_eq!(post_dom_tree.post_idom(block5), PostIDom::Block(block6)); +// assert_eq!(post_dom_tree.post_idom(block6), PostIDom::Block(block7)); +// assert_eq!(post_dom_tree.post_idom(block7), PostIDom::DummyExit); +// } +// } diff --git a/crates/mir2/src/graphviz/block.rs b/crates/mir2/src/graphviz/block.rs new file mode 100644 index 0000000000..9121d8e281 --- /dev/null +++ b/crates/mir2/src/graphviz/block.rs @@ -0,0 +1,62 @@ +use std::fmt::Write; + +use dot2::{label, Id}; + +use crate::{ + analysis::ControlFlowGraph, + db::MirDb, + ir::{BasicBlockId, FunctionId}, + pretty_print::PrettyPrint, +}; + +#[derive(Debug, Clone, Copy)] +pub(super) struct BlockNode { + func: FunctionId, + pub block: BasicBlockId, +} + +impl BlockNode { + pub(super) fn new(func: FunctionId, block: BasicBlockId) -> Self { + Self { func, block } + } + pub(super) fn id(self) -> dot2::Result> { + Id::new(format!("fn{}_bb{}", self.func.0, self.block.index())) + } + + pub(super) fn label(self, db: &dyn MirDb) -> label::Text<'static> { + let mut label = r#""#.to_string(); + + // Write block header. + write!( + &mut label, + r#""#, + self.block.index() + ) + .unwrap(); + + // Write block body. + let func_body = self.func.body(db); + write!(label, r#""#).unwrap(); + + write!(label, "
BB{}
"#).unwrap(); + for inst in func_body.order.iter_inst(self.block) { + let mut inst_string = String::new(); + inst.pretty_print(db, &func_body.store, &mut inst_string) + .unwrap(); + write!(label, "{}", dot2::escape_html(&inst_string)).unwrap(); + write!(label, "
").unwrap(); + } + write!(label, r#"
").unwrap(); + + label::Text::HtmlStr(label.into()) + } + + pub(super) fn succs(self, db: &dyn MirDb) -> Vec { + let func_body = self.func.body(db); + let cfg = ControlFlowGraph::compute(&func_body); + cfg.succs(self.block) + .iter() + .map(|block| Self::new(self.func, *block)) + .collect() + } +} diff --git a/crates/mir2/src/graphviz/function.rs b/crates/mir2/src/graphviz/function.rs new file mode 100644 index 0000000000..fa78d21719 --- /dev/null +++ b/crates/mir2/src/graphviz/function.rs @@ -0,0 +1,78 @@ +use std::fmt::Write; + +use dot2::{label, Id}; + +use crate::{analysis::ControlFlowGraph, db::MirDb, ir::FunctionId, pretty_print::PrettyPrint}; + +use super::block::BlockNode; + +#[derive(Debug, Clone, Copy)] +pub(super) struct FunctionNode { + pub(super) func: FunctionId, +} + +impl FunctionNode { + pub(super) fn new(func: FunctionId) -> Self { + Self { func } + } + + pub(super) fn subgraph_id(self) -> Option> { + dot2::Id::new(format!("cluster_{}", self.func.0)).ok() + } + + pub(super) fn label(self, db: &dyn MirDb) -> label::Text<'static> { + let mut label = self.signature(db); + write!(label, r#"

"#).unwrap(); + + // Maps local value id to local name. + let body = self.func.body(db); + for local in body.store.locals() { + local.pretty_print(db, &body.store, &mut label).unwrap(); + write!( + label, + r#" => {}
"#, + body.store.local_name(*local).unwrap() + ) + .unwrap(); + } + + label::Text::HtmlStr(label.into()) + } + + pub(super) fn blocks(self, db: &dyn MirDb) -> Vec { + let body = self.func.body(db); + // We use control flow graph to collect reachable blocks. + let cfg = ControlFlowGraph::compute(&body); + cfg.post_order() + .map(|block| BlockNode::new(self.func, block)) + .collect() + } + + fn signature(self, db: &dyn MirDb) -> String { + let body = self.func.body(db); + + let sig_data = self.func.signature(db); + let mut sig = format!("fn {}(", self.func.debug_name(db)); + + let params = &sig_data.params; + let param_len = params.len(); + for (i, param) in params.iter().enumerate() { + let name = ¶m.name; + let ty = param.ty; + write!(&mut sig, "{name}: ").unwrap(); + ty.pretty_print(db, &body.store, &mut sig).unwrap(); + if param_len - 1 != i { + write!(sig, ", ").unwrap(); + } + } + write!(sig, ")").unwrap(); + + let ret_ty = self.func.return_type(db); + if let Some(ret_ty) = ret_ty { + write!(sig, " -> ").unwrap(); + ret_ty.pretty_print(db, &body.store, &mut sig).unwrap(); + } + + dot2::escape_html(&sig) + } +} diff --git a/crates/mir2/src/graphviz/mod.rs b/crates/mir2/src/graphviz/mod.rs new file mode 100644 index 0000000000..c79335a04e --- /dev/null +++ b/crates/mir2/src/graphviz/mod.rs @@ -0,0 +1,22 @@ +use std::io; + +use fe_analyzer2::namespace::items::ModuleId; + +use crate::db::MirDb; + +mod block; +mod function; +mod module; + +/// Writes mir graphs of functions in a `module`. +pub fn write_mir_graphs( + db: &dyn MirDb, + module: ModuleId, + w: &mut W, +) -> io::Result<()> { + let module_graph = module::ModuleGraph::new(db, module); + dot2::render(&module_graph, w).map_err(|err| match err { + dot2::Error::Io(err) => err, + _ => panic!("invalid graphviz id"), + }) +} diff --git a/crates/mir2/src/graphviz/module.rs b/crates/mir2/src/graphviz/module.rs new file mode 100644 index 0000000000..8280e76c7f --- /dev/null +++ b/crates/mir2/src/graphviz/module.rs @@ -0,0 +1,158 @@ +use dot2::{label::Text, GraphWalk, Id, Kind, Labeller}; +use fe_analyzer2::namespace::items::ModuleId; + +use crate::{ + db::MirDb, + ir::{inst::BranchInfo, FunctionId}, + pretty_print::PrettyPrint, +}; + +use super::{block::BlockNode, function::FunctionNode}; + +pub(super) struct ModuleGraph<'db> { + db: &'db dyn MirDb, + module: ModuleId, +} + +impl<'db> ModuleGraph<'db> { + pub(super) fn new(db: &'db dyn MirDb, module: ModuleId) -> Self { + Self { db, module } + } +} + +impl<'db> GraphWalk<'db> for ModuleGraph<'db> { + type Node = BlockNode; + type Edge = ModuleGraphEdge; + type Subgraph = FunctionNode; + + fn nodes(&self) -> dot2::Nodes<'db, Self::Node> { + let mut nodes = Vec::new(); + + // Collect function nodes. + for func in self + .db + .mir_lower_module_all_functions(self.module) + .iter() + .map(|id| FunctionNode::new(*id)) + { + nodes.extend(func.blocks(self.db).into_iter()) + } + + nodes.into() + } + + fn edges(&self) -> dot2::Edges<'db, Self::Edge> { + let mut edges = vec![]; + for func in self.db.mir_lower_module_all_functions(self.module).iter() { + for block in FunctionNode::new(*func).blocks(self.db) { + for succ in block.succs(self.db) { + let edge = ModuleGraphEdge { + from: block, + to: succ, + func: *func, + }; + edges.push(edge); + } + } + } + + edges.into() + } + + fn source(&self, edge: &Self::Edge) -> Self::Node { + edge.from + } + + fn target(&self, edge: &Self::Edge) -> Self::Node { + edge.to + } + + fn subgraphs(&self) -> dot2::Subgraphs<'db, Self::Subgraph> { + self.db + .mir_lower_module_all_functions(self.module) + .iter() + .map(|id| FunctionNode::new(*id)) + .collect::>() + .into() + } + + fn subgraph_nodes(&self, s: &Self::Subgraph) -> dot2::Nodes<'db, Self::Node> { + s.blocks(self.db).into_iter().collect::>().into() + } +} + +impl<'db> Labeller<'db> for ModuleGraph<'db> { + type Node = BlockNode; + type Edge = ModuleGraphEdge; + type Subgraph = FunctionNode; + + fn graph_id(&self) -> dot2::Result> { + let module_name = self.module.name(self.db.upcast()); + dot2::Id::new(module_name.to_string()) + } + + fn node_id(&self, n: &Self::Node) -> dot2::Result> { + n.id() + } + + fn node_shape(&self, _n: &Self::Node) -> Option> { + Some(Text::LabelStr("none".into())) + } + + fn node_label(&self, n: &Self::Node) -> dot2::Result> { + Ok(n.label(self.db)) + } + + fn edge_label<'a>(&self, e: &Self::Edge) -> Text<'db> { + Text::LabelStr(e.label(self.db).into()) + } + + fn subgraph_id(&self, s: &Self::Subgraph) -> Option> { + s.subgraph_id() + } + + fn subgraph_label(&self, s: &Self::Subgraph) -> Text<'db> { + s.label(self.db) + } + + fn kind(&self) -> Kind { + Kind::Digraph + } +} + +#[derive(Debug, Clone)] +pub(super) struct ModuleGraphEdge { + from: BlockNode, + to: BlockNode, + func: FunctionId, +} + +impl ModuleGraphEdge { + fn label(&self, db: &dyn MirDb) -> String { + let body = self.func.body(db); + let terminator = body.order.terminator(&body.store, self.from.block).unwrap(); + let to = self.to.block; + match body.store.branch_info(terminator) { + BranchInfo::NotBranch => unreachable!(), + BranchInfo::Jump(_) => String::new(), + BranchInfo::Branch(_, true_bb, _) => { + format! {"{}", true_bb == to} + } + BranchInfo::Switch(_, table, default) => { + if default == Some(to) { + return "*".to_string(); + } + + for (value, bb) in table.iter() { + if bb == to { + let mut s = String::new(); + value.pretty_print(db, &body.store, &mut s).unwrap(); + return s; + } + } + + unreachable!() + } + } + } +} diff --git a/crates/mir2/src/ir/basic_block.rs b/crates/mir2/src/ir/basic_block.rs new file mode 100644 index 0000000000..359c4c76f6 --- /dev/null +++ b/crates/mir2/src/ir/basic_block.rs @@ -0,0 +1,6 @@ +use id_arena::Id; + +pub type BasicBlockId = Id; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BasicBlock {} diff --git a/crates/mir2/src/ir/body_builder.rs b/crates/mir2/src/ir/body_builder.rs new file mode 100644 index 0000000000..602ccdf496 --- /dev/null +++ b/crates/mir2/src/ir/body_builder.rs @@ -0,0 +1,381 @@ +use hir::hir_def::TypeId; +use num_bigint::BigInt; + +use crate::ir::{ + body_cursor::{BodyCursor, CursorLocation}, + inst::{BinOp, Inst, InstKind, UnOp}, + value::{AssignableValue, Local}, + BasicBlock, BasicBlockId, FunctionBody, FunctionId, InstId, +}; + +use super::{ + inst::{CallType, CastKind, SwitchTable, YulIntrinsicOp}, + ConstantId, Value, ValueId, +}; + +#[derive(Debug)] +pub struct BodyBuilder { + pub body: FunctionBody, + loc: CursorLocation, +} + +macro_rules! impl_unary_inst { + ($name:ident, $code:path) => { + // pub fn $name(&mut self, value: ValueId, source: SourceInfo) -> InstId { + pub fn $name(&mut self, value: ValueId) -> InstId { + // let inst = Inst::unary($code, value, source); + let inst = Inst::unary($code, value); + self.insert_inst(inst) + } + }; +} + +macro_rules! impl_binary_inst { + ($name:ident, $code:path) => { + // pub fn $name(&mut self, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> InstId { + pub fn $name(&mut self, lhs: ValueId, rhs: ValueId) -> InstId { + // let inst = Inst::binary($code, lhs, rhs, source); + let inst = Inst::binary($code, lhs, rhs); + self.insert_inst(inst) + } + }; +} + +impl BodyBuilder { + // pub fn new(fid: FunctionId, source: SourceInfo) -> Self { + pub fn new(fid: FunctionId) -> Self { + // let body = FunctionBody::new(fid, source); + let body = FunctionBody::new(fid); + let entry_block = body.order.entry(); + Self { + body, + loc: CursorLocation::BlockTop(entry_block), + } + } + + pub fn build(self) -> FunctionBody { + self.body + } + + pub fn func_id(&self) -> FunctionId { + self.body.fid + } + + pub fn make_block(&mut self) -> BasicBlockId { + let block = BasicBlock {}; + let block_id = self.body.store.store_block(block); + self.body.order.append_block(block_id); + block_id + } + + pub fn make_value(&mut self, value: impl Into) -> ValueId { + self.body.store.store_value(value.into()) + } + + pub fn map_result(&mut self, inst: InstId, result: AssignableValue) { + self.body.store.map_result(inst, result) + } + + pub fn inst_result(&mut self, inst: InstId) -> Option<&AssignableValue> { + self.body.store.inst_result(inst) + } + + pub fn move_to_block(&mut self, block: BasicBlockId) { + self.loc = CursorLocation::BlockBottom(block) + } + + pub fn move_to_block_top(&mut self, block: BasicBlockId) { + self.loc = CursorLocation::BlockTop(block) + } + + pub fn make_unit(&mut self, unit_ty: TypeId) -> ValueId { + self.body.store.store_value(Value::Unit { ty: unit_ty }) + } + + pub fn make_imm(&mut self, imm: BigInt, ty: TypeId) -> ValueId { + self.body.store.store_value(Value::Immediate { imm, ty }) + } + + pub fn make_imm_from_bool(&mut self, imm: bool, ty: TypeId) -> ValueId { + if imm { + self.make_imm(1u8.into(), ty) + } else { + self.make_imm(0u8.into(), ty) + } + } + + pub fn make_constant(&mut self, constant: ConstantId, ty: TypeId) -> ValueId { + self.body + .store + .store_value(Value::Constant { constant, ty }) + } + + pub fn declare(&mut self, local: Local) -> ValueId { + // let source = local.source.clone(); + let local_id = self.body.store.store_value(Value::Local(local)); + + let kind = InstKind::Declare { local: local_id }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst); + local_id + } + + pub fn store_func_arg(&mut self, local: Local) -> ValueId { + self.body.store.store_value(Value::Local(local)) + } + + impl_unary_inst!(not, UnOp::Not); + impl_unary_inst!(neg, UnOp::Neg); + impl_unary_inst!(inv, UnOp::Inv); + + impl_binary_inst!(add, BinOp::Add); + impl_binary_inst!(sub, BinOp::Sub); + impl_binary_inst!(mul, BinOp::Mul); + impl_binary_inst!(div, BinOp::Div); + impl_binary_inst!(modulo, BinOp::Mod); + impl_binary_inst!(pow, BinOp::Pow); + impl_binary_inst!(shl, BinOp::Shl); + impl_binary_inst!(shr, BinOp::Shr); + impl_binary_inst!(bit_or, BinOp::BitOr); + impl_binary_inst!(bit_xor, BinOp::BitXor); + impl_binary_inst!(bit_and, BinOp::BitAnd); + impl_binary_inst!(logical_and, BinOp::LogicalAnd); + impl_binary_inst!(logical_or, BinOp::LogicalOr); + impl_binary_inst!(eq, BinOp::Eq); + impl_binary_inst!(ne, BinOp::Ne); + impl_binary_inst!(ge, BinOp::Ge); + impl_binary_inst!(gt, BinOp::Gt); + impl_binary_inst!(le, BinOp::Le); + impl_binary_inst!(lt, BinOp::Lt); + + pub fn primitive_cast( + &mut self, + value: ValueId, + result_ty: TypeId, + // source: SourceInfo, + ) -> InstId { + let kind = InstKind::Cast { + kind: CastKind::Primitive, + value, + to: result_ty, + }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn untag_cast(&mut self, value: ValueId, result_ty: TypeId, source: SourceInfo) -> InstId { + pub fn untag_cast(&mut self, value: ValueId, result_ty: TypeId) -> InstId { + let kind = InstKind::Cast { + kind: CastKind::Untag, + value, + to: result_ty, + }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn aggregate_construct( + &mut self, + ty: TypeId, + args: Vec, + // source: SourceInfo, + ) -> InstId { + let kind = InstKind::AggregateConstruct { ty, args }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn bind(&mut self, src: ValueId, source: SourceInfo) -> InstId { + pub fn bind(&mut self, src: ValueId) -> InstId { + let kind = InstKind::Bind { src }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn mem_copy(&mut self, src: ValueId, source: SourceInfo) -> InstId { + pub fn mem_copy(&mut self, src: ValueId) -> InstId { + let kind = InstKind::MemCopy { src }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn load(&mut self, src: ValueId, source: SourceInfo) -> InstId { + pub fn load(&mut self, src: ValueId) -> InstId { + let kind = InstKind::Load { src }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn aggregate_access( + &mut self, + value: ValueId, + indices: Vec, + // source: SourceInfo, + ) -> InstId { + let kind = InstKind::AggregateAccess { value, indices }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn map_access(&mut self, value: ValueId, key: ValueId, source: SourceInfo) -> InstId { + pub fn map_access(&mut self, value: ValueId, key: ValueId) -> InstId { + let kind = InstKind::MapAccess { value, key }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn call( + &mut self, + func: FunctionId, + args: Vec, + call_type: CallType, + // source: SourceInfo, + ) -> InstId { + let kind = InstKind::Call { + func, + args, + call_type, + }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn yul_intrinsic( + &mut self, + op: YulIntrinsicOp, + args: Vec, + // source: SourceInfo, + ) -> InstId { + // let inst = Inst::intrinsic(op, args, source); + let inst = Inst::intrinsic(op, args); + self.insert_inst(inst) + } + + // pub fn jump(&mut self, dest: BasicBlockId, source: SourceInfo) -> InstId { + pub fn jump(&mut self, dest: BasicBlockId) -> InstId { + let kind = InstKind::Jump { dest }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn branch( + &mut self, + cond: ValueId, + then: BasicBlockId, + else_: BasicBlockId, + // source: SourceInfo, + ) -> InstId { + let kind = InstKind::Branch { cond, then, else_ }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn switch( + &mut self, + disc: ValueId, + table: SwitchTable, + default: Option, + // source: SourceInfo, + ) -> InstId { + let kind = InstKind::Switch { + disc, + table, + default, + }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn revert(&mut self, arg: Option, source: SourceInfo) -> InstId { + pub fn revert(&mut self, arg: Option) -> InstId { + let kind = InstKind::Revert { arg }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn emit(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + pub fn emit(&mut self, arg: ValueId) -> InstId { + let kind = InstKind::Emit { arg }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn ret(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + pub fn ret(&mut self, arg: ValueId) -> InstId { + let kind = InstKind::Return { arg: arg.into() }; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + // pub fn nop(&mut self, source: SourceInfo) -> InstId { + pub fn nop(&mut self) -> InstId { + let kind = InstKind::Nop; + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); + self.insert_inst(inst) + } + + pub fn value_ty(&mut self, value: ValueId) -> TypeId { + self.body.store.value_ty(value) + } + + pub fn value_data(&mut self, value: ValueId) -> &Value { + self.body.store.value_data(value) + } + + /// Returns `true` if current block is terminated. + pub fn is_block_terminated(&mut self, block: BasicBlockId) -> bool { + self.body.order.is_terminated(&self.body.store, block) + } + + pub fn is_current_block_terminated(&mut self) -> bool { + let current_block = self.current_block(); + self.is_block_terminated(current_block) + } + + pub fn current_block(&mut self) -> BasicBlockId { + self.cursor().expect_block() + } + + pub fn remove_inst(&mut self, inst: InstId) { + let mut cursor = BodyCursor::new(&mut self.body, CursorLocation::Inst(inst)); + if self.loc == cursor.loc() { + self.loc = cursor.prev_loc(); + } + cursor.remove_inst(); + } + + pub fn inst_data(&self, inst: InstId) -> &Inst { + self.body.store.inst_data(inst) + } + + fn insert_inst(&mut self, inst: Inst) -> InstId { + let mut cursor = self.cursor(); + let inst_id = cursor.store_and_insert_inst(inst); + + // Set cursor to the new inst. + self.loc = CursorLocation::Inst(inst_id); + + inst_id + } + + fn cursor(&mut self) -> BodyCursor { + BodyCursor::new(&mut self.body, self.loc) + } +} diff --git a/crates/mir2/src/ir/body_cursor.rs b/crates/mir2/src/ir/body_cursor.rs new file mode 100644 index 0000000000..ed4199a345 --- /dev/null +++ b/crates/mir2/src/ir/body_cursor.rs @@ -0,0 +1,231 @@ +//! This module provides a collection of structs to modify function body +//! in-place. +// The design used here is greatly inspired by [`cranelift`](https://crates.io/crates/cranelift) + +use super::{ + value::AssignableValue, BasicBlock, BasicBlockId, FunctionBody, Inst, InstId, ValueId, +}; + +/// Specify a current location of [`BodyCursor`] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CursorLocation { + Inst(InstId), + BlockTop(BasicBlockId), + BlockBottom(BasicBlockId), + NoWhere, +} + +pub struct BodyCursor<'a> { + body: &'a mut FunctionBody, + loc: CursorLocation, +} + +impl<'a> BodyCursor<'a> { + pub fn new(body: &'a mut FunctionBody, loc: CursorLocation) -> Self { + Self { body, loc } + } + + pub fn new_at_entry(body: &'a mut FunctionBody) -> Self { + let entry = body.order.entry(); + Self { + body, + loc: CursorLocation::BlockTop(entry), + } + } + pub fn set_loc(&mut self, loc: CursorLocation) { + self.loc = loc; + } + + pub fn loc(&self) -> CursorLocation { + self.loc + } + + pub fn next_loc(&self) -> CursorLocation { + match self.loc() { + CursorLocation::Inst(inst) => self.body.order.next_inst(inst).map_or_else( + || CursorLocation::BlockBottom(self.body.order.inst_block(inst)), + CursorLocation::Inst, + ), + CursorLocation::BlockTop(block) => self + .body + .order + .first_inst(block) + .map_or_else(|| CursorLocation::BlockBottom(block), CursorLocation::Inst), + CursorLocation::BlockBottom(block) => self + .body() + .order + .next_block(block) + .map_or(CursorLocation::NoWhere, |next_block| { + CursorLocation::BlockTop(next_block) + }), + CursorLocation::NoWhere => CursorLocation::NoWhere, + } + } + + pub fn prev_loc(&self) -> CursorLocation { + match self.loc() { + CursorLocation::Inst(inst) => self.body.order.prev_inst(inst).map_or_else( + || CursorLocation::BlockTop(self.body.order.inst_block(inst)), + CursorLocation::Inst, + ), + CursorLocation::BlockTop(block) => self + .body + .order + .prev_block(block) + .map_or(CursorLocation::NoWhere, |prev_block| { + CursorLocation::BlockBottom(prev_block) + }), + CursorLocation::BlockBottom(block) => self + .body + .order + .last_inst(block) + .map_or_else(|| CursorLocation::BlockTop(block), CursorLocation::Inst), + CursorLocation::NoWhere => CursorLocation::NoWhere, + } + } + + pub fn next_block(&self) -> Option { + let block = self.expect_block(); + self.body.order.next_block(block) + } + + pub fn prev_block(&self) -> Option { + let block = self.expect_block(); + self.body.order.prev_block(block) + } + + pub fn proceed(&mut self) { + self.set_loc(self.next_loc()) + } + + pub fn back(&mut self) { + self.set_loc(self.prev_loc()); + } + + pub fn body(&self) -> &FunctionBody { + self.body + } + + pub fn body_mut(&mut self) -> &mut FunctionBody { + self.body + } + + /// Sets a cursor to an entry block. + pub fn set_to_entry(&mut self) { + let entry_bb = self.body().order.entry(); + let loc = CursorLocation::BlockTop(entry_bb); + self.set_loc(loc); + } + + /// Insert [`InstId`] to a location where a cursor points. + /// If you need to store and insert [`Inst`], use [`store_and_insert_inst`]. + /// + /// # Panics + /// Panics if a cursor points [`CursorLocation::NoWhere`]. + pub fn insert_inst(&mut self, inst: InstId) { + match self.loc() { + CursorLocation::Inst(at) => self.body.order.insert_inst_after(inst, at), + CursorLocation::BlockTop(block) => self.body.order.prepend_inst(inst, block), + CursorLocation::BlockBottom(block) => self.body.order.append_inst(inst, block), + CursorLocation::NoWhere => panic!("cursor loc points to `NoWhere`"), + } + } + + pub fn store_and_insert_inst(&mut self, data: Inst) -> InstId { + let inst = self.body.store.store_inst(data); + self.insert_inst(inst); + inst + } + + /// Remove a current pointed [`Inst`] from a function body. A cursor + /// proceeds to a next inst. + /// + /// # Panics + /// Panics if a cursor doesn't point [`CursorLocation::Inst`]. + pub fn remove_inst(&mut self) { + let inst = self.expect_inst(); + let next_loc = self.next_loc(); + self.body.order.remove_inst(inst); + self.set_loc(next_loc); + } + + /// Remove a current pointed `block` and contained insts from a function + /// body. A cursor proceeds to a next block. + /// + /// # Panics + /// Panics if a cursor doesn't point [`CursorLocation::Inst`]. + pub fn remove_block(&mut self) { + let block = match self.loc() { + CursorLocation::Inst(inst) => self.body.order.inst_block(inst), + CursorLocation::BlockTop(block) | CursorLocation::BlockBottom(block) => block, + CursorLocation::NoWhere => panic!("cursor loc points `NoWhere`"), + }; + + // Store next block of the current block for later use. + let next_block = self.body.order.next_block(block); + + // Remove all insts in the current block. + if let Some(first_inst) = self.body.order.first_inst(block) { + self.set_loc(CursorLocation::Inst(first_inst)); + while matches!(self.loc(), CursorLocation::Inst(..)) { + self.remove_inst(); + } + } + // Remove current block. + self.body.order.remove_block(block); + + // Set cursor location to next block if exists. + if let Some(next_block) = next_block { + self.set_loc(CursorLocation::BlockTop(next_block)) + } else { + self.set_loc(CursorLocation::NoWhere) + } + } + + /// Insert [`BasicBlockId`] to a location where a cursor points. + /// If you need to store and insert [`BasicBlock`], use + /// [`store_and_insert_block`]. + /// + /// # Panics + /// Panics if a cursor points [`CursorLocation::NoWhere`]. + pub fn insert_block(&mut self, block: BasicBlockId) { + let current = self.expect_block(); + self.body.order.insert_block_after_block(block, current) + } + + pub fn store_and_insert_block(&mut self, block: BasicBlock) -> BasicBlockId { + let block_id = self.body.store.store_block(block); + self.insert_block(block_id); + block_id + } + + pub fn map_result(&mut self, result: AssignableValue) -> Option { + let inst = self.expect_inst(); + let result_value = result.value_id(); + self.body.store.map_result(inst, result); + result_value + } + + /// Returns current inst that cursor points. + /// + /// # Panics + /// Panics if a cursor doesn't point [`CursorLocation::Inst`]. + pub fn expect_inst(&self) -> InstId { + match self.loc { + CursorLocation::Inst(inst) => inst, + _ => panic!("Cursor doesn't point any inst."), + } + } + + /// Returns current block that cursor points. + /// + /// # Panics + /// Panics if a cursor points [`CursorLocation::NoWhere`]. + pub fn expect_block(&self) -> BasicBlockId { + match self.loc { + CursorLocation::Inst(inst) => self.body.order.inst_block(inst), + CursorLocation::BlockTop(block) | CursorLocation::BlockBottom(block) => block, + CursorLocation::NoWhere => panic!("cursor loc points `NoWhere`"), + } + } +} diff --git a/crates/mir2/src/ir/body_order.rs b/crates/mir2/src/ir/body_order.rs new file mode 100644 index 0000000000..70df3cf76a --- /dev/null +++ b/crates/mir2/src/ir/body_order.rs @@ -0,0 +1,473 @@ +use fxhash::FxHashMap; + +use super::{basic_block::BasicBlockId, function::BodyDataStore, inst::InstId}; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// Represents basic block order and instruction order. +pub struct BodyOrder { + blocks: FxHashMap, + insts: FxHashMap, + entry_block: BasicBlockId, + last_block: BasicBlockId, +} +impl BodyOrder { + pub fn new(entry_block: BasicBlockId) -> Self { + let entry_block_node = BlockNode::default(); + let mut blocks = FxHashMap::default(); + blocks.insert(entry_block, entry_block_node); + + Self { + blocks, + insts: FxHashMap::default(), + entry_block, + last_block: entry_block, + } + } + + /// Returns an entry block of a function body. + pub fn entry(&self) -> BasicBlockId { + self.entry_block + } + + /// Returns a last block of a function body. + pub fn last_block(&self) -> BasicBlockId { + self.last_block + } + + /// Returns `true` if a block doesn't contain any block. + pub fn is_block_empty(&self, block: BasicBlockId) -> bool { + self.first_inst(block).is_none() + } + + /// Returns `true` if a function body contains a given `block`. + pub fn is_block_inserted(&self, block: BasicBlockId) -> bool { + self.blocks.contains_key(&block) + } + + /// Returns a number of block in a function. + pub fn block_num(&self) -> usize { + self.blocks.len() + } + + /// Returns a previous block of a given block. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn prev_block(&self, block: BasicBlockId) -> Option { + self.blocks[&block].prev + } + + /// Returns a next block of a given block. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn next_block(&self, block: BasicBlockId) -> Option { + self.blocks[&block].next + } + + /// Returns `true` is a given `inst` is inserted. + pub fn is_inst_inserted(&self, inst: InstId) -> bool { + self.insts.contains_key(&inst) + } + + /// Returns first instruction of a block if exists. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn first_inst(&self, block: BasicBlockId) -> Option { + self.blocks[&block].first_inst + } + + /// Returns a terminator instruction of a block. + /// + /// # Panics + /// Panics if + /// 1. `block` is not inserted yet. + pub fn terminator(&self, store: &BodyDataStore, block: BasicBlockId) -> Option { + let last_inst = self.last_inst(block)?; + if store.is_terminator(last_inst) { + Some(last_inst) + } else { + None + } + } + + /// Returns `true` if a `block` is terminated. + pub fn is_terminated(&self, store: &BodyDataStore, block: BasicBlockId) -> bool { + self.terminator(store, block).is_some() + } + + /// Returns a last instruction of a block. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn last_inst(&self, block: BasicBlockId) -> Option { + self.blocks[&block].last_inst + } + + /// Returns a previous instruction of a given `inst`. + /// + /// # Panics + /// Panics if `inst` is not inserted yet. + pub fn prev_inst(&self, inst: InstId) -> Option { + self.insts[&inst].prev + } + + /// Returns a next instruction of a given `inst`. + /// + /// # Panics + /// Panics if `inst` is not inserted yet. + pub fn next_inst(&self, inst: InstId) -> Option { + self.insts[&inst].next + } + + /// Returns a block to which a given `inst` belongs. + /// + /// # Panics + /// Panics if `inst` is not inserted yet. + pub fn inst_block(&self, inst: InstId) -> BasicBlockId { + self.insts[&inst].block + } + + /// Returns an iterator which iterates all basic blocks in a function body + /// in pre-order. + pub fn iter_block(&self) -> impl Iterator + '_ { + BlockIter { + next: Some(self.entry_block), + blocks: &self.blocks, + } + } + + /// Returns an iterator which iterates all instruction in a given `block` in + /// pre-order. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn iter_inst(&self, block: BasicBlockId) -> impl Iterator + '_ { + InstIter { + next: self.blocks[&block].first_inst, + insts: &self.insts, + } + } + + /// Appends a given `block` to a function body. + /// + /// # Panics + /// Panics if a given `block` is already inserted to a function. + pub fn append_block(&mut self, block: BasicBlockId) { + debug_assert!(!self.is_block_inserted(block)); + + let mut block_node = BlockNode::default(); + let last_block = self.last_block; + let last_block_node = &mut self.block_mut(last_block); + last_block_node.next = Some(block); + block_node.prev = Some(last_block); + + self.blocks.insert(block, block_node); + self.last_block = block; + } + + /// Inserts a given `block` before a `before` block. + /// + /// # Panics + /// Panics if + /// 1. a given `block` is already inserted. + /// 2. a given `before` block is NOTE inserted yet. + pub fn insert_block_before_block(&mut self, block: BasicBlockId, before: BasicBlockId) { + debug_assert!(self.is_block_inserted(before)); + debug_assert!(!self.is_block_inserted(block)); + + let mut block_node = BlockNode::default(); + + match self.blocks[&before].prev { + Some(prev) => { + block_node.prev = Some(prev); + self.block_mut(prev).next = Some(block); + } + None => self.entry_block = block, + } + + block_node.next = Some(before); + self.block_mut(before).prev = Some(block); + self.blocks.insert(block, block_node); + } + + /// Inserts a given `block` after a `after` block. + /// + /// # Panics + /// Panics if + /// 1. a given `block` is already inserted. + /// 2. a given `after` block is NOTE inserted yet. + pub fn insert_block_after_block(&mut self, block: BasicBlockId, after: BasicBlockId) { + debug_assert!(self.is_block_inserted(after)); + debug_assert!(!self.is_block_inserted(block)); + + let mut block_node = BlockNode::default(); + + match self.blocks[&after].next { + Some(next) => { + block_node.next = Some(next); + self.block_mut(next).prev = Some(block); + } + None => self.last_block = block, + } + block_node.prev = Some(after); + self.block_mut(after).next = Some(block); + self.blocks.insert(block, block_node); + } + + /// Remove a given `block` from a function. All instructions in a block are + /// also removed. + /// + /// # Panics + /// Panics if + /// 1. a given `block` is NOT inserted. + /// 2. a `block` is the last one block in a function. + pub fn remove_block(&mut self, block: BasicBlockId) { + debug_assert!(self.is_block_inserted(block)); + debug_assert!(self.block_num() > 1); + + // Remove all insts in a `block`. + let mut next_inst = self.first_inst(block); + while let Some(inst) = next_inst { + next_inst = self.next_inst(inst); + self.remove_inst(inst); + } + + // Remove `block`. + let block_node = &self.blocks[&block]; + let prev_block = block_node.prev; + let next_block = block_node.next; + match (prev_block, next_block) { + // `block` is in the middle of a function. + (Some(prev), Some(next)) => { + self.block_mut(prev).next = Some(next); + self.block_mut(next).prev = Some(prev); + } + // `block` is the last block of a function. + (Some(prev), None) => { + self.block_mut(prev).next = None; + self.last_block = prev; + } + // `block` is the first block of a function. + (None, Some(next)) => { + self.block_mut(next).prev = None; + self.entry_block = next + } + (None, None) => { + unreachable!() + } + } + + self.blocks.remove(&block); + } + + /// Appends `inst` to the end of a `block` + /// + /// # Panics + /// Panics if + /// 1. a given `block` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn append_inst(&mut self, inst: InstId, block: BasicBlockId) { + debug_assert!(self.is_block_inserted(block)); + debug_assert!(!self.is_inst_inserted(inst)); + + let mut inst_node = InstNode::new(block); + + if let Some(last_inst) = self.blocks[&block].last_inst { + inst_node.prev = Some(last_inst); + self.inst_mut(last_inst).next = Some(inst); + } else { + self.block_mut(block).first_inst = Some(inst); + } + + self.block_mut(block).last_inst = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Prepends `inst` to the beginning of a `block` + /// + /// # Panics + /// Panics if + /// 1. a given `block` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn prepend_inst(&mut self, inst: InstId, block: BasicBlockId) { + debug_assert!(self.is_block_inserted(block)); + debug_assert!(!self.is_inst_inserted(inst)); + + let mut inst_node = InstNode::new(block); + + if let Some(first_inst) = self.blocks[&block].first_inst { + inst_node.next = Some(first_inst); + self.inst_mut(first_inst).prev = Some(inst); + } else { + self.block_mut(block).last_inst = Some(inst); + } + + self.block_mut(block).first_inst = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Insert `inst` before `before` inst. + /// + /// # Panics + /// Panics if + /// 1. a given `before` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn insert_inst_before_inst(&mut self, inst: InstId, before: InstId) { + debug_assert!(self.is_inst_inserted(before)); + debug_assert!(!self.is_inst_inserted(inst)); + + let before_inst_node = &self.insts[&before]; + let block = before_inst_node.block; + let mut inst_node = InstNode::new(block); + + match before_inst_node.prev { + Some(prev) => { + inst_node.prev = Some(prev); + self.inst_mut(prev).next = Some(inst); + } + None => self.block_mut(block).first_inst = Some(inst), + } + inst_node.next = Some(before); + self.inst_mut(before).prev = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Insert `inst` after `after` inst. + /// + /// # Panics + /// Panics if + /// 1. a given `after` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn insert_inst_after(&mut self, inst: InstId, after: InstId) { + debug_assert!(self.is_inst_inserted(after)); + debug_assert!(!self.is_inst_inserted(inst)); + + let after_inst_node = &self.insts[&after]; + let block = after_inst_node.block; + let mut inst_node = InstNode::new(block); + + match after_inst_node.next { + Some(next) => { + inst_node.next = Some(next); + self.inst_mut(next).prev = Some(inst); + } + None => self.block_mut(block).last_inst = Some(inst), + } + inst_node.prev = Some(after); + self.inst_mut(after).next = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Remove instruction from the function body. + /// + /// # Panics + /// Panics if a given `inst` is not inserted. + pub fn remove_inst(&mut self, inst: InstId) { + debug_assert!(self.is_inst_inserted(inst)); + + let inst_node = &self.insts[&inst]; + let inst_block = inst_node.block; + let prev_inst = inst_node.prev; + let next_inst = inst_node.next; + match (prev_inst, next_inst) { + (Some(prev), Some(next)) => { + self.inst_mut(prev).next = Some(next); + self.inst_mut(next).prev = Some(prev); + } + (Some(prev), None) => { + self.inst_mut(prev).next = None; + self.block_mut(inst_block).last_inst = Some(prev); + } + (None, Some(next)) => { + self.inst_mut(next).prev = None; + self.block_mut(inst_block).first_inst = Some(next); + } + (None, None) => { + let block_node = self.block_mut(inst_block); + block_node.first_inst = None; + block_node.last_inst = None; + } + } + + self.insts.remove(&inst); + } + + fn block_mut(&mut self, block: BasicBlockId) -> &mut BlockNode { + self.blocks.get_mut(&block).unwrap() + } + + fn inst_mut(&mut self, inst: InstId) -> &mut InstNode { + self.insts.get_mut(&inst).unwrap() + } +} + +struct BlockIter<'a> { + next: Option, + blocks: &'a FxHashMap, +} + +impl<'a> Iterator for BlockIter<'a> { + type Item = BasicBlockId; + + fn next(&mut self) -> Option { + let next = self.next?; + self.next = self.blocks[&next].next; + Some(next) + } +} + +struct InstIter<'a> { + next: Option, + insts: &'a FxHashMap, +} + +impl<'a> Iterator for InstIter<'a> { + type Item = InstId; + + fn next(&mut self) -> Option { + let next = self.next?; + self.next = self.insts[&next].next; + Some(next) + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +/// A helper struct to track a basic block order in a function body. +struct BlockNode { + /// A previous block. + prev: Option, + + /// A next block. + next: Option, + + /// A first instruction of a block. + first_inst: Option, + + /// A last instruction of a block. + last_inst: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +/// A helper struct to track a instruction order in a basic block. +struct InstNode { + /// An block to which a inst belongs. + block: BasicBlockId, + + /// A previous instruction. + prev: Option, + + /// A next instruction. + next: Option, +} + +impl InstNode { + fn new(block: BasicBlockId) -> Self { + Self { + block, + prev: None, + next: None, + } + } +} diff --git a/crates/mir2/src/ir/constant.rs b/crates/mir2/src/ir/constant.rs new file mode 100644 index 0000000000..34ab0e5961 --- /dev/null +++ b/crates/mir2/src/ir/constant.rs @@ -0,0 +1,39 @@ +use hir::hir_def; +use num_bigint::BigInt; +use smol_str::SmolStr; + +#[salsa::interned] +pub struct ConstId { + #[return_ref] + pub data: Const, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Const { + pub value: ConstantValue, + + #[return_ref] + pub(crate) origin: hir_def::Const, +} + +// /// An interned Id for [`Constant`]. +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct ConstantId(pub(crate) u32); +// impl_intern_key!(ConstantId); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ConstantValue { + Immediate(BigInt), + Str(SmolStr), + Bool(bool), +} + +// impl From for ConstantValue { +// fn from(value: context::Constant) -> Self { +// match value { +// context::Constant::Int(num) | context::Constant::Address(num) => Self::Immediate(num), +// context::Constant::Str(s) => Self::Str(s), +// context::Constant::Bool(b) => Self::Bool(b), +// } +// } +// } diff --git a/crates/mir2/src/ir/function.rs b/crates/mir2/src/ir/function.rs new file mode 100644 index 0000000000..89eeabac1e --- /dev/null +++ b/crates/mir2/src/ir/function.rs @@ -0,0 +1,273 @@ +use fxhash::FxHashMap; +use hir::hir_def::{ModuleTreeNodeId, TypeId}; +use id_arena::Arena; +use num_bigint::BigInt; +use smol_str::SmolStr; +use std::collections::BTreeMap; + +use super::{ + basic_block::BasicBlock, + body_order::BodyOrder, + inst::{BranchInfo, Inst, InstId, InstKind}, + // types::TypeId, + value::{AssignableValue, Local, Value, ValueId}, + BasicBlockId, +}; + +/// Represents function signature. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FunctionSignature { + pub params: Vec, + pub resolved_generics: BTreeMap, + pub return_type: Option, + pub module_id: ModuleTreeNodeId, + pub analyzer_func_id: FunctionId, + pub linkage: Linkage, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FunctionParam { + pub name: SmolStr, + pub ty: TypeId, + // pub source: SourceInfo, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct FunctionId(pub u32); +// impl_intern_key!(FunctionId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Linkage { + /// A function can only be called within the same module. + Private, + + /// A function can be called from other modules, but can NOT be called from + /// other accounts and transactions. + Public, + + /// A function can be called from other modules, and also can be called from + /// other accounts and transactions. + Export, +} + +impl Linkage { + pub fn is_exported(self) -> bool { + self == Linkage::Export + } +} + +/// A function body, which is not stored in salsa db to enable in-place +/// transformation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionBody { + pub fid: FunctionId, + + pub store: BodyDataStore, + + /// Tracks order of basic blocks and instructions in a function body. + pub order: BodyOrder, + // pub source: SourceInfo, +} + +impl FunctionBody { + pub fn new(fid: FunctionId) -> Self { + // pub fn new(fid: FunctionId, source: SourceInfo) -> Self { + let mut store = BodyDataStore::default(); + let entry_bb = store.store_block(BasicBlock {}); + Self { + fid, + store, + order: BodyOrder::new(entry_bb), + // source, + } + } +} + +/// A collection of basic block, instructions and values appear in a function +/// body. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct BodyDataStore { + /// Instructions appear in a function body. + insts: Arena, + + /// All values in a function. + values: Arena, + + blocks: Arena, + + /// Maps an immediate to a value to ensure the same immediate results in the + /// same value. + immediates: FxHashMap<(BigInt, TypeId), ValueId>, + + unit_value: Option, + + /// Maps an instruction to a value. + inst_results: FxHashMap, + + /// All declared local variables in a function. + locals: Vec, +} + +impl BodyDataStore { + pub fn store_inst(&mut self, inst: Inst) -> InstId { + self.insts.alloc(inst) + } + + pub fn inst_data(&self, inst: InstId) -> &Inst { + &self.insts[inst] + } + + pub fn inst_data_mut(&mut self, inst: InstId) -> &mut Inst { + &mut self.insts[inst] + } + + pub fn replace_inst(&mut self, inst: InstId, new: Inst) -> Inst { + let old = &mut self.insts[inst]; + std::mem::replace(old, new) + } + + pub fn store_value(&mut self, value: Value) -> ValueId { + match value { + Value::Immediate { imm, ty } => self.store_immediate(imm, ty), + + Value::Unit { .. } => { + if let Some(unit_value) = self.unit_value { + unit_value + } else { + let unit_value = self.values.alloc(value); + self.unit_value = Some(unit_value); + unit_value + } + } + + Value::Local(ref local) => { + let is_user_defined = !local.is_tmp; + let value_id = self.values.alloc(value); + if is_user_defined { + self.locals.push(value_id); + } + value_id + } + + _ => self.values.alloc(value), + } + } + + pub fn is_nop(&self, inst: InstId) -> bool { + matches!(&self.inst_data(inst).kind, InstKind::Nop) + } + + pub fn is_terminator(&self, inst: InstId) -> bool { + self.inst_data(inst).is_terminator() + } + + pub fn branch_info(&self, inst: InstId) -> BranchInfo { + self.inst_data(inst).branch_info() + } + + pub fn value_data(&self, value: ValueId) -> &Value { + &self.values[value] + } + + pub fn value_data_mut(&mut self, value: ValueId) -> &mut Value { + &mut self.values[value] + } + + pub fn values(&self) -> impl Iterator { + self.values.iter().map(|(_, value_data)| value_data) + } + + pub fn values_mut(&mut self) -> impl Iterator { + self.values.iter_mut().map(|(_, value_data)| value_data) + } + + pub fn store_block(&mut self, block: BasicBlock) -> BasicBlockId { + self.blocks.alloc(block) + } + + /// Returns an instruction result + pub fn inst_result(&self, inst: InstId) -> Option<&AssignableValue> { + self.inst_results.get(&inst) + } + + pub fn map_result(&mut self, inst: InstId, result: AssignableValue) { + self.inst_results.insert(inst, result); + } + + pub fn remove_inst_result(&mut self, inst: InstId) -> Option { + self.inst_results.remove(&inst) + } + + pub fn rewrite_branch_dest(&mut self, inst: InstId, from: BasicBlockId, to: BasicBlockId) { + match &mut self.inst_data_mut(inst).kind { + InstKind::Jump { dest } => { + if *dest == from { + *dest = to; + } + } + InstKind::Branch { then, else_, .. } => { + if *then == from { + *then = to; + } + if *else_ == from { + *else_ = to; + } + } + _ => unreachable!("inst is not a branch"), + } + } + + pub fn value_ty(&self, vid: ValueId) -> TypeId { + self.values[vid].ty() + } + + pub fn locals(&self) -> &[ValueId] { + &self.locals + } + + pub fn locals_mut(&mut self) -> &[ValueId] { + &mut self.locals + } + + pub fn func_args(&self) -> impl Iterator + '_ { + self.locals() + .iter() + .filter(|value| match self.value_data(**value) { + Value::Local(local) => local.is_arg, + _ => unreachable!(), + }) + .copied() + } + + pub fn func_args_mut(&mut self) -> impl Iterator { + self.values_mut().filter(|value| match value { + Value::Local(local) => local.is_arg, + _ => false, + }) + } + + /// Returns Some(`local_name`) if value is `Value::Local`. + pub fn local_name(&self, value: ValueId) -> Option<&str> { + match self.value_data(value) { + Value::Local(Local { name, .. }) => Some(name), + _ => None, + } + } + + pub fn replace_value(&mut self, value: ValueId, to: Value) -> Value { + std::mem::replace(&mut self.values[value], to) + } + + fn store_immediate(&mut self, imm: BigInt, ty: TypeId) -> ValueId { + if let Some(value) = self.immediates.get(&(imm.clone(), ty)) { + *value + } else { + let id = self.values.alloc(Value::Immediate { + imm: imm.clone(), + ty, + }); + self.immediates.insert((imm, ty), id); + id + } + } +} diff --git a/crates/mir2/src/ir/inst.rs b/crates/mir2/src/ir/inst.rs new file mode 100644 index 0000000000..9925d20ded --- /dev/null +++ b/crates/mir2/src/ir/inst.rs @@ -0,0 +1,772 @@ +use std::fmt; + +use hir::hir_def::{Contract, TypeId}; +use id_arena::Id; + +use super::{basic_block::BasicBlockId, function::FunctionId, value::ValueId}; + +pub type InstId = Id; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Inst { + pub kind: InstKind, + // pub source: SourceInfo, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum InstKind { + /// This is not a real instruction, just used to tag a position where a + /// local is declared. + Declare { + local: ValueId, + }, + + /// Unary instruction. + Unary { + op: UnOp, + value: ValueId, + }, + + /// Binary instruction. + Binary { + op: BinOp, + lhs: ValueId, + rhs: ValueId, + }, + + Cast { + kind: CastKind, + value: ValueId, + to: TypeId, + }, + + /// Constructs aggregate value, i.e. struct, tuple and array. + AggregateConstruct { + ty: TypeId, + args: Vec, + }, + + Bind { + src: ValueId, + }, + + MemCopy { + src: ValueId, + }, + + /// Load a primitive value from a ptr + Load { + src: ValueId, + }, + + /// Access to aggregate fields or elements. + /// # Example + /// + /// ```fe + /// struct Foo: + /// x: i32 + /// y: Array + /// ``` + /// `foo.y` is lowered into `AggregateAccess(foo, [1])' for example. + AggregateAccess { + value: ValueId, + indices: Vec, + }, + + MapAccess { + key: ValueId, + value: ValueId, + }, + + Call { + func: FunctionId, + args: Vec, + call_type: CallType, + }, + + /// Unconditional jump instruction. + Jump { + dest: BasicBlockId, + }, + + /// Conditional branching instruction. + Branch { + cond: ValueId, + then: BasicBlockId, + else_: BasicBlockId, + }, + + Switch { + disc: ValueId, + table: SwitchTable, + default: Option, + }, + + Revert { + arg: Option, + }, + + Emit { + arg: ValueId, + }, + + Return { + arg: Option, + }, + + Keccak256 { + arg: ValueId, + }, + + AbiEncode { + arg: ValueId, + }, + + Nop, + + Create { + value: ValueId, + contract: Contract, + }, + + Create2 { + value: ValueId, + salt: ValueId, + contract: Contract, + }, + + YulIntrinsic { + op: YulIntrinsicOp, + args: Vec, + }, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] +pub struct SwitchTable { + values: Vec, + blocks: Vec, +} + +impl SwitchTable { + pub fn iter(&self) -> impl Iterator + '_ { + self.values.iter().copied().zip(self.blocks.iter().copied()) + } + + pub fn len(&self) -> usize { + debug_assert!(self.values.len() == self.blocks.len()); + self.values.len() + } + + pub fn is_empty(&self) -> bool { + debug_assert!(self.values.len() == self.blocks.len()); + self.values.is_empty() + } + + pub fn add_arm(&mut self, value: ValueId, block: BasicBlockId) { + self.values.push(value); + self.blocks.push(block); + } +} + +impl Inst { + // pub fn new(kind: InstKind, source: SourceInfo) -> Self { + pub fn new(kind: InstKind) -> Self { + // Self { kind, source } + Self { kind } + } + + // pub fn unary(op: UnOp, value: ValueId, source: SourceInfo) -> Self { + pub fn unary(op: UnOp, value: ValueId) -> Self { + let kind = InstKind::Unary { op, value }; + // Self::new(kind, source) + Self::new(kind) + } + + // pub fn binary(op: BinOp, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> Self { + pub fn binary(op: BinOp, lhs: ValueId, rhs: ValueId) -> Self { + let kind = InstKind::Binary { op, lhs, rhs }; + // Self::new(kind, source) + Self::new(kind) + } + + // pub fn intrinsic(op: YulIntrinsicOp, args: Vec, source: SourceInfo) -> Self { + pub fn intrinsic(op: YulIntrinsicOp, args: Vec) -> Self { + let kind = InstKind::YulIntrinsic { op, args }; + // Self::new(kind, source) + Self::new(kind) + } + + pub fn nop() -> Self { + Self { + kind: InstKind::Nop, + // source: SourceInfo::dummy(), + } + } + + pub fn is_terminator(&self) -> bool { + match self.kind { + InstKind::Jump { .. } + | InstKind::Branch { .. } + | InstKind::Switch { .. } + | InstKind::Revert { .. } + | InstKind::Return { .. } => true, + InstKind::YulIntrinsic { op, .. } => op.is_terminator(), + _ => false, + } + } + + pub fn branch_info(&self) -> BranchInfo { + match self.kind { + InstKind::Jump { dest } => BranchInfo::Jump(dest), + InstKind::Branch { cond, then, else_ } => BranchInfo::Branch(cond, then, else_), + InstKind::Switch { + disc, + ref table, + default, + } => BranchInfo::Switch(disc, table, default), + _ => BranchInfo::NotBranch, + } + } + + pub fn args(&self) -> ValueIter { + use InstKind::*; + match &self.kind { + Declare { local: arg } + | Bind { src: arg } + | MemCopy { src: arg } + | Load { src: arg } + | Unary { value: arg, .. } + | Cast { value: arg, .. } + | Emit { arg } + | Keccak256 { arg } + | AbiEncode { arg } + | Create { value: arg, .. } + | Branch { cond: arg, .. } => ValueIter::one(*arg), + + Switch { disc, table, .. } => { + ValueIter::one(*disc).chain(ValueIter::Slice(table.values.iter())) + } + + Binary { lhs, rhs, .. } + | MapAccess { + value: lhs, + key: rhs, + } + | Create2 { + value: lhs, + salt: rhs, + .. + } => ValueIter::one(*lhs).chain(ValueIter::one(*rhs)), + + Revert { arg } | Return { arg } => ValueIter::One(*arg), + + Nop | Jump { .. } => ValueIter::Zero, + + AggregateAccess { value, indices } => { + ValueIter::one(*value).chain(ValueIter::Slice(indices.iter())) + } + + AggregateConstruct { args, .. } | Call { args, .. } | YulIntrinsic { args, .. } => { + ValueIter::Slice(args.iter()) + } + } + } + + pub fn args_mut(&mut self) -> ValueIterMut { + use InstKind::*; + match &mut self.kind { + Declare { local: arg } + | Bind { src: arg } + | MemCopy { src: arg } + | Load { src: arg } + | Unary { value: arg, .. } + | Cast { value: arg, .. } + | Emit { arg } + | Keccak256 { arg } + | AbiEncode { arg } + | Create { value: arg, .. } + | Branch { cond: arg, .. } => ValueIterMut::one(arg), + + Switch { disc, table, .. } => { + ValueIterMut::one(disc).chain(ValueIterMut::Slice(table.values.iter_mut())) + } + + Binary { lhs, rhs, .. } + | MapAccess { + value: lhs, + key: rhs, + } + | Create2 { + value: lhs, + salt: rhs, + .. + } => ValueIterMut::one(lhs).chain(ValueIterMut::one(rhs)), + + Revert { arg } | Return { arg } => ValueIterMut::One(arg.as_mut()), + + Nop | Jump { .. } => ValueIterMut::Zero, + + AggregateAccess { value, indices } => { + ValueIterMut::one(value).chain(ValueIterMut::Slice(indices.iter_mut())) + } + + AggregateConstruct { args, .. } | Call { args, .. } | YulIntrinsic { args, .. } => { + ValueIterMut::Slice(args.iter_mut()) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum UnOp { + /// `not` operator for logical inversion. + Not, + /// `-` operator for negation. + Neg, + /// `~` operator for bitwise inversion. + Inv, +} + +impl fmt::Display for UnOp { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Not => write!(w, "not"), + Self::Neg => write!(w, "-"), + Self::Inv => write!(w, "~"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Mod, + Pow, + Shl, + Shr, + BitOr, + BitXor, + BitAnd, + LogicalAnd, + LogicalOr, + Eq, + Ne, + Ge, + Gt, + Le, + Lt, +} + +impl fmt::Display for BinOp { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Add => write!(w, "+"), + Self::Sub => write!(w, "-"), + Self::Mul => write!(w, "*"), + Self::Div => write!(w, "/"), + Self::Mod => write!(w, "%"), + Self::Pow => write!(w, "**"), + Self::Shl => write!(w, "<<"), + Self::Shr => write!(w, ">>"), + Self::BitOr => write!(w, "|"), + Self::BitXor => write!(w, "^"), + Self::BitAnd => write!(w, "&"), + Self::LogicalAnd => write!(w, "and"), + Self::LogicalOr => write!(w, "or"), + Self::Eq => write!(w, "=="), + Self::Ne => write!(w, "!="), + Self::Ge => write!(w, ">="), + Self::Gt => write!(w, ">"), + Self::Le => write!(w, "<="), + Self::Lt => write!(w, "<"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CallType { + Internal, + External, +} + +impl fmt::Display for CallType { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Internal => write!(w, "internal"), + Self::External => write!(w, "external"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum CastKind { + /// A cast from a primitive type to a primitive type. + Primitive, + + /// A cast from an enum type to its underlying type. + Untag, +} + +// TODO: We don't need all yul intrinsics. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum YulIntrinsicOp { + Stop, + Add, + Sub, + Mul, + Div, + Sdiv, + Mod, + Smod, + Exp, + Not, + Lt, + Gt, + Slt, + Sgt, + Eq, + Iszero, + And, + Or, + Xor, + Byte, + Shl, + Shr, + Sar, + Addmod, + Mulmod, + Signextend, + Keccak256, + Pc, + Pop, + Mload, + Mstore, + Mstore8, + Sload, + Sstore, + Msize, + Gas, + Address, + Balance, + Selfbalance, + Caller, + Callvalue, + Calldataload, + Calldatasize, + Calldatacopy, + Codesize, + Codecopy, + Extcodesize, + Extcodecopy, + Returndatasize, + Returndatacopy, + Extcodehash, + Create, + Create2, + Call, + Callcode, + Delegatecall, + Staticcall, + Return, + Revert, + Selfdestruct, + Invalid, + Log0, + Log1, + Log2, + Log3, + Log4, + Chainid, + Basefee, + Origin, + Gasprice, + Blockhash, + Coinbase, + Timestamp, + Number, + Prevrandao, + Gaslimit, +} +impl YulIntrinsicOp { + pub fn is_terminator(self) -> bool { + matches!( + self, + Self::Return | Self::Revert | Self::Selfdestruct | Self::Invalid + ) + } +} + +impl fmt::Display for YulIntrinsicOp { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + let op = match self { + Self::Stop => "__stop", + Self::Add => "__add", + Self::Sub => "__sub", + Self::Mul => "__mul", + Self::Div => "__div", + Self::Sdiv => "__sdiv", + Self::Mod => "__mod", + Self::Smod => "__smod", + Self::Exp => "__exp", + Self::Not => "__not", + Self::Lt => "__lt", + Self::Gt => "__gt", + Self::Slt => "__slt", + Self::Sgt => "__sgt", + Self::Eq => "__eq", + Self::Iszero => "__iszero", + Self::And => "__and", + Self::Or => "__or", + Self::Xor => "__xor", + Self::Byte => "__byte", + Self::Shl => "__shl", + Self::Shr => "__shr", + Self::Sar => "__sar", + Self::Addmod => "__addmod", + Self::Mulmod => "__mulmod", + Self::Signextend => "__signextend", + Self::Keccak256 => "__keccak256", + Self::Pc => "__pc", + Self::Pop => "__pop", + Self::Mload => "__mload", + Self::Mstore => "__mstore", + Self::Mstore8 => "__mstore8", + Self::Sload => "__sload", + Self::Sstore => "__sstore", + Self::Msize => "__msize", + Self::Gas => "__gas", + Self::Address => "__address", + Self::Balance => "__balance", + Self::Selfbalance => "__selfbalance", + Self::Caller => "__caller", + Self::Callvalue => "__callvalue", + Self::Calldataload => "__calldataload", + Self::Calldatasize => "__calldatasize", + Self::Calldatacopy => "__calldatacopy", + Self::Codesize => "__codesize", + Self::Codecopy => "__codecopy", + Self::Extcodesize => "__extcodesize", + Self::Extcodecopy => "__extcodecopy", + Self::Returndatasize => "__returndatasize", + Self::Returndatacopy => "__returndatacopy", + Self::Extcodehash => "__extcodehash", + Self::Create => "__create", + Self::Create2 => "__create2", + Self::Call => "__call", + Self::Callcode => "__callcode", + Self::Delegatecall => "__delegatecall", + Self::Staticcall => "__staticcall", + Self::Return => "__return", + Self::Revert => "__revert", + Self::Selfdestruct => "__selfdestruct", + Self::Invalid => "__invalid", + Self::Log0 => "__log0", + Self::Log1 => "__log1", + Self::Log2 => "__log2", + Self::Log3 => "__log3", + Self::Log4 => "__log4", + Self::Chainid => "__chainid", + Self::Basefee => "__basefee", + Self::Origin => "__origin", + Self::Gasprice => "__gasprice", + Self::Blockhash => "__blockhash", + Self::Coinbase => "__coinbase", + Self::Timestamp => "__timestamp", + Self::Number => "__number", + Self::Prevrandao => "__prevrandao", + Self::Gaslimit => "__gaslimit", + }; + + write!(w, "{op}") + } +} + +// impl From for YulIntrinsicOp { +// fn from(val: fe_analyzer2::builtins::Intrinsic) -> Self { +// use fe_analyzer2::builtins::Intrinsic; +// match val { +// Intrinsic::__stop => Self::Stop, +// Intrinsic::__add => Self::Add, +// Intrinsic::__sub => Self::Sub, +// Intrinsic::__mul => Self::Mul, +// Intrinsic::__div => Self::Div, +// Intrinsic::__sdiv => Self::Sdiv, +// Intrinsic::__mod => Self::Mod, +// Intrinsic::__smod => Self::Smod, +// Intrinsic::__exp => Self::Exp, +// Intrinsic::__not => Self::Not, +// Intrinsic::__lt => Self::Lt, +// Intrinsic::__gt => Self::Gt, +// Intrinsic::__slt => Self::Slt, +// Intrinsic::__sgt => Self::Sgt, +// Intrinsic::__eq => Self::Eq, +// Intrinsic::__iszero => Self::Iszero, +// Intrinsic::__and => Self::And, +// Intrinsic::__or => Self::Or, +// Intrinsic::__xor => Self::Xor, +// Intrinsic::__byte => Self::Byte, +// Intrinsic::__shl => Self::Shl, +// Intrinsic::__shr => Self::Shr, +// Intrinsic::__sar => Self::Sar, +// Intrinsic::__addmod => Self::Addmod, +// Intrinsic::__mulmod => Self::Mulmod, +// Intrinsic::__signextend => Self::Signextend, +// Intrinsic::__keccak256 => Self::Keccak256, +// Intrinsic::__pc => Self::Pc, +// Intrinsic::__pop => Self::Pop, +// Intrinsic::__mload => Self::Mload, +// Intrinsic::__mstore => Self::Mstore, +// Intrinsic::__mstore8 => Self::Mstore8, +// Intrinsic::__sload => Self::Sload, +// Intrinsic::__sstore => Self::Sstore, +// Intrinsic::__msize => Self::Msize, +// Intrinsic::__gas => Self::Gas, +// Intrinsic::__address => Self::Address, +// Intrinsic::__balance => Self::Balance, +// Intrinsic::__selfbalance => Self::Selfbalance, +// Intrinsic::__caller => Self::Caller, +// Intrinsic::__callvalue => Self::Callvalue, +// Intrinsic::__calldataload => Self::Calldataload, +// Intrinsic::__calldatasize => Self::Calldatasize, +// Intrinsic::__calldatacopy => Self::Calldatacopy, +// Intrinsic::__codesize => Self::Codesize, +// Intrinsic::__codecopy => Self::Codecopy, +// Intrinsic::__extcodesize => Self::Extcodesize, +// Intrinsic::__extcodecopy => Self::Extcodecopy, +// Intrinsic::__returndatasize => Self::Returndatasize, +// Intrinsic::__returndatacopy => Self::Returndatacopy, +// Intrinsic::__extcodehash => Self::Extcodehash, +// Intrinsic::__create => Self::Create, +// Intrinsic::__create2 => Self::Create2, +// Intrinsic::__call => Self::Call, +// Intrinsic::__callcode => Self::Callcode, +// Intrinsic::__delegatecall => Self::Delegatecall, +// Intrinsic::__staticcall => Self::Staticcall, +// Intrinsic::__return => Self::Return, +// Intrinsic::__revert => Self::Revert, +// Intrinsic::__selfdestruct => Self::Selfdestruct, +// Intrinsic::__invalid => Self::Invalid, +// Intrinsic::__log0 => Self::Log0, +// Intrinsic::__log1 => Self::Log1, +// Intrinsic::__log2 => Self::Log2, +// Intrinsic::__log3 => Self::Log3, +// Intrinsic::__log4 => Self::Log4, +// Intrinsic::__chainid => Self::Chainid, +// Intrinsic::__basefee => Self::Basefee, +// Intrinsic::__origin => Self::Origin, +// Intrinsic::__gasprice => Self::Gasprice, +// Intrinsic::__blockhash => Self::Blockhash, +// Intrinsic::__coinbase => Self::Coinbase, +// Intrinsic::__timestamp => Self::Timestamp, +// Intrinsic::__number => Self::Number, +// Intrinsic::__prevrandao => Self::Prevrandao, +// Intrinsic::__gaslimit => Self::Gaslimit, +// } +// } +// } + +pub enum BranchInfo<'a> { + NotBranch, + Jump(BasicBlockId), + Branch(ValueId, BasicBlockId, BasicBlockId), + Switch(ValueId, &'a SwitchTable, Option), +} + +impl<'a> BranchInfo<'a> { + pub fn is_not_a_branch(&self) -> bool { + matches!(self, BranchInfo::NotBranch) + } + + pub fn block_iter(&self) -> BlockIter { + match self { + Self::NotBranch => BlockIter::Zero, + Self::Jump(block) => BlockIter::one(*block), + Self::Branch(_, then, else_) => BlockIter::one(*then).chain(BlockIter::one(*else_)), + Self::Switch(_, table, default) => { + BlockIter::Slice(table.blocks.iter()).chain(BlockIter::One(*default)) + } + } + } +} + +pub type BlockIter<'a> = IterBase<'a, BasicBlockId>; +pub type ValueIter<'a> = IterBase<'a, ValueId>; +pub type ValueIterMut<'a> = IterMutBase<'a, ValueId>; + +pub enum IterBase<'a, T> { + Zero, + One(Option), + Slice(std::slice::Iter<'a, T>), + Chain(Box>, Box>), +} + +impl<'a, T> IterBase<'a, T> { + fn one(value: T) -> Self { + Self::One(Some(value)) + } + + fn chain(self, rhs: Self) -> Self { + Self::Chain(self.into(), rhs.into()) + } +} + +impl<'a, T> Iterator for IterBase<'a, T> +where + T: Copy, +{ + type Item = T; + + fn next(&mut self) -> Option { + match self { + Self::Zero => None, + Self::One(value) => value.take(), + Self::Slice(s) => s.next().copied(), + Self::Chain(first, second) => { + if let Some(value) = first.next() { + Some(value) + } else { + second.next() + } + } + } + } +} + +pub enum IterMutBase<'a, T> { + Zero, + One(Option<&'a mut T>), + Slice(std::slice::IterMut<'a, T>), + Chain(Box>, Box>), +} + +impl<'a, T> IterMutBase<'a, T> { + fn one(value: &'a mut T) -> Self { + Self::One(Some(value)) + } + + fn chain(self, rhs: Self) -> Self { + Self::Chain(self.into(), rhs.into()) + } +} + +impl<'a, T> Iterator for IterMutBase<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + match self { + Self::Zero => None, + Self::One(value) => value.take(), + Self::Slice(s) => s.next(), + Self::Chain(first, second) => { + if let Some(value) = first.next() { + Some(value) + } else { + second.next() + } + } + } + } +} diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs new file mode 100644 index 0000000000..bf27440d41 --- /dev/null +++ b/crates/mir2/src/ir/mod.rs @@ -0,0 +1,46 @@ +pub mod basic_block; +pub mod body_builder; +pub mod body_cursor; +pub mod body_order; +pub mod constant; +pub mod function; +pub mod inst; +// pub mod types; +pub mod value; + +pub use basic_block::{BasicBlock, BasicBlockId}; +pub use constant::{Const, ConstId}; +pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; +pub use inst::{Inst, InstId}; +// pub use types::{Type, TypeId, TypeKind}; +pub use value::{Value, ValueId}; + +// /// An original source information that indicates where `mir` entities derive +// /// from. `SourceInfo` is mainly used for diagnostics. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct SourceInfo { +// pub span: Span, +// pub id: NodeId, +// } + +// impl SourceInfo { +// pub fn dummy() -> Self { +// Self { +// span: Span::dummy(), +// id: NodeId::dummy(), +// } +// } + +// pub fn is_dummy(&self) -> bool { +// self == &Self::dummy() +// } +// } + +// impl From<&Node> for SourceInfo { +// fn from(node: &Node) -> Self { +// Self { +// span: node.span, +// id: node.id, +// } +// } +// } diff --git a/crates/mir2/src/ir/types.rs b/crates/mir2/src/ir/types.rs new file mode 100644 index 0000000000..a368cce2dd --- /dev/null +++ b/crates/mir2/src/ir/types.rs @@ -0,0 +1,117 @@ +// use smol_str::SmolStr; + +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct Type { +// pub kind: TypeKind, +// pub analyzer_ty: Option, +// } + +// impl Type { +// pub fn new(kind: TypeKind, analyzer_ty: Option) -> Self { +// Self { kind, analyzer_ty } +// } +// } + +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub enum TypeKind { +// I8, +// I16, +// I32, +// I64, +// I128, +// I256, +// U8, +// U16, +// U32, +// U64, +// U128, +// U256, +// Bool, +// Address, +// Unit, +// Array(ArrayDef), +// // TODO: we should consider whether we really need `String` type. +// String(usize), +// Tuple(TupleDef), +// Struct(StructDef), +// Enum(EnumDef), +// Contract(StructDef), +// Map(MapDef), +// MPtr(TypeId), +// SPtr(TypeId), +// } + +// /// An interned Id for [`ArrayDef`]. +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct TypeId(pub u32); +// impl_intern_key!(TypeId); + +// /// A static array type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct ArrayDef { +// pub elem_ty: TypeId, +// pub len: usize, +// } + +// /// A tuple type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct TupleDef { +// pub items: Vec, +// } + +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct StructDef { +// pub name: SmolStr, +// pub fields: Vec<(SmolStr, TypeId)>, +// pub span: Span, +// pub module_id: analyzer_items::ModuleId, +// } + +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct EnumDef { +// pub name: SmolStr, +// pub variants: Vec, +// pub span: Span, +// pub module_id: analyzer_items::ModuleId, +// } + +// impl EnumDef { +// pub fn tag_type(&self) -> TypeKind { +// let variant_num = self.variants.len() as u64; +// if variant_num <= u8::MAX as u64 { +// TypeKind::U8 +// } else if variant_num <= u16::MAX as u64 { +// TypeKind::U16 +// } else if variant_num <= u32::MAX as u64 { +// TypeKind::U32 +// } else { +// TypeKind::U64 +// } +// } +// } + +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct EnumVariant { +// pub name: SmolStr, +// pub span: Span, +// pub ty: TypeId, +// } + +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct EventDef { +// pub name: SmolStr, +// pub fields: Vec<(SmolStr, TypeId, bool)>, +// pub span: Span, +// pub module_id: analyzer_items::ModuleId, +// } + +// /// A map type definition. +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct MapDef { +// pub key_ty: TypeId, +// pub value_ty: TypeId, +// } diff --git a/crates/mir2/src/ir/value.rs b/crates/mir2/src/ir/value.rs new file mode 100644 index 0000000000..4c31129a6e --- /dev/null +++ b/crates/mir2/src/ir/value.rs @@ -0,0 +1,138 @@ +use hir::hir_def::TypeId; +use id_arena::Id; +use num_bigint::BigInt; +use smol_str::SmolStr; + +// use crate::db::MirDb; + +use super::{constant::ConstantId, inst::InstId}; + +pub type ValueId = Id; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Value { + /// A value resulted from an instruction. + Temporary { inst: InstId, ty: TypeId }, + + /// A local variable declared in a function body. + Local(Local), + + /// An immediate value. + Immediate { imm: BigInt, ty: TypeId }, + + /// A constant value. + Constant { constant: ConstantId, ty: TypeId }, + + /// A singleton value representing `Unit` type. + Unit { ty: TypeId }, +} + +impl Value { + pub fn ty(&self) -> TypeId { + match self { + Self::Local(val) => val.ty, + Self::Immediate { ty, .. } + | Self::Temporary { ty, .. } + | Self::Unit { ty } + | Self::Constant { ty, .. } => *ty, + } + } + + pub fn is_imm(&self) -> bool { + matches!(self, Self::Immediate { .. }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum AssignableValue { + Value(ValueId), + Aggregate { + lhs: Box, + idx: ValueId, + }, + Map { + lhs: Box, + key: ValueId, + }, +} + +impl From for AssignableValue { + fn from(value: ValueId) -> Self { + Self::Value(value) + } +} + +impl AssignableValue { + // pub fn ty(&self, db: &dyn MirDb, store: &BodyDataStore) -> TypeId { + // match self { + // Self::Value(value) => store.value_ty(*value), + // Self::Aggregate { lhs, idx } => { + // let lhs_ty = lhs.ty(db, store); + // lhs_ty.projection_ty(db, store.value_data(*idx)) + // } + // Self::Map { lhs, .. } => { + // let lhs_ty = lhs.ty(db, store).deref(db); + // match lhs_ty.data(db).kind { + // TypeKind::Map(def) => def.value_ty.make_sptr(db), + // _ => unreachable!(), + // } + // } + // } + // } + + pub fn value_id(&self) -> Option { + match self { + Self::Value(value) => Some(*value), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Local { + /// An original name of a local variable. + pub name: SmolStr, + + pub ty: TypeId, + + /// `true` if a local is a function argument. + pub is_arg: bool, + + /// `true` if a local is introduced in MIR. + pub is_tmp: bool, + // pub source: SourceInfo, +} + +impl Local { + // pub fn user_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + pub fn user_local(name: SmolStr, ty: TypeId) -> Local { + Self { + name, + ty, + is_arg: false, + is_tmp: false, + // source, + } + } + + // pub fn arg_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + pub fn arg_local(name: SmolStr, ty: TypeId) -> Local { + Self { + name, + ty, + is_arg: true, + is_tmp: false, + // source, + } + } + + pub fn tmp_local(name: SmolStr, ty: TypeId) -> Local { + Self { + name, + ty, + is_arg: false, + is_tmp: true, + // source: SourceInfo::dummy(), + } + } +} diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs new file mode 100644 index 0000000000..2ee69d287f --- /dev/null +++ b/crates/mir2/src/lib.rs @@ -0,0 +1,59 @@ +use hir::HirDb; + +pub mod analysis; +// pub mod graphviz; +pub mod ir; +// pub mod pretty_print; + +mod lower; + +#[salsa::jar(db = MirDb)] +pub struct Jar( + // ir::Constant, + ir::ConstId, + // ir::FunctionBody, + // ir::FunctionId, + // ir::FunctionParam, + // ir::FunctionSignature, + // ir::Inst, + // ir::InstId, + // ir::Value, + // ir::ValueId, + // mir_intern_const, + // mir_intern_type, + // mir_intern_function, + // mir_lower_module_all_functions, + // mir_lower_contract_all_functions, + // mir_lower_struct_all_functions, + // mir_lower_enum_all_functions, + // mir_lowered_type, + lower::constant::mir_lowered_constant, + // mir_lowered_func_signature, + // mir_lowered_monomorphized_func_signature, + // mir_lowered_pseudo_monomorphized_func_signature, + // mir_lowered_func_body, +); + +#[salsa::jar(db = LowerMirDb)] +pub struct LowerJar(); + +pub trait MirDb: salsa::DbWithJar + HirDb { + fn prefill(&self) + where + Self: Sized, + { + // IdentId::prefill(self) + } + + // fn as_hir_db(&self) -> &dyn MirDb { + // >::as_jar_db::<'_>(self) + // } +} +impl MirDb for DB where DB: salsa::DbWithJar + HirDb {} + +pub trait LowerMirDb: salsa::DbWithJar + HirDb { + fn as_lower_hir_db(&self) -> &dyn LowerMirDb { + >::as_jar_db::<'_>(self) + } +} +impl LowerMirDb for DB where DB: salsa::DbWithJar + MirDb {} diff --git a/crates/mir2/src/lower/constant.rs b/crates/mir2/src/lower/constant.rs new file mode 100644 index 0000000000..6f6f0decfa --- /dev/null +++ b/crates/mir2/src/lower/constant.rs @@ -0,0 +1,18 @@ +use std::rc::Rc; + +use hir::hir_def::{Const, TypeId}; + +use crate::{ + ir::{Const, ConstId}, + MirDb, +}; + +#[salsa::tracked] +pub fn mir_lowered_constant(db: &dyn MirDb, hir_const: Const) -> ConstId { + let value = hir_const.constant_value(db.as_hir_db()).unwrap(); + + let constant = Const { + value: value.into(), + origin: hir_const, + } +} diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs new file mode 100644 index 0000000000..567eb47984 --- /dev/null +++ b/crates/mir2/src/lower/function.rs @@ -0,0 +1,1213 @@ +use std::rc::Rc; + +use fxhash::FxHashMap; +use hir::hir_def::{self, TypeId}; +use id_arena::{Arena, Id}; +use num_bigint::BigInt; +use smol_str::SmolStr; + +use crate::{ + ir::{ + body_builder::BodyBuilder, inst::InstKind, value::Local, BasicBlockId, FunctionBody, + FunctionId, FunctionParam, InstId, Value, ValueId, + }, + MirDb, +}; + +type ScopeId = Id; + +// pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { +// let analyzer_func = func.analyzer_func(db); +// let ast = &analyzer_func.data(db.upcast()).ast; +// let analyzer_body = analyzer_func.body(db.upcast()); + +// BodyLowerHelper::new(db, func, ast, analyzer_body.as_ref()) +// .lower() +// .into() +// } + +pub(super) struct BodyLowerHelper<'db, 'a> { + pub(super) db: &'db dyn MirDb, + pub(super) builder: BodyBuilder, + ast: &'a hir_def::Func, + func: FunctionId, + // analyzer_body: &'a fe_analyzer2::context::FunctionBody, + scopes: Arena, + current_scope: ScopeId, +} + +impl<'db, 'a> BodyLowerHelper<'db, 'a> { + pub(super) fn lower_stmt(&mut self, stmt: &hir_def::Stmt) { + // match &stmt.kind { + // hir_def::Stmt::Return(value) => { + // let value = if let Some(expr) = value { + // self.lower_expr_to_value(expr) + // } else { + // self.make_unit() + // }; + // self.builder.ret(value, stmt.into()); + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::VarDecl { target, value, .. } => { + // self.lower_var_decl(target, value.as_ref(), stmt.into()); + // } + + // hir_def::Stmt::ConstantDecl { name, value, .. } => { + // let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&name.id]); + + // let value = self.analyzer_body.expressions[&value.id] + // .const_value + // .clone() + // .unwrap(); + + // let constant = + // self.make_local_constant(name.kind.clone(), ty, value.into(), stmt.into()); + // self.scope_mut().declare_var(&name.kind, constant); + // } + + // hir_def::Stmt::Assign(target, value) => { + // let result = self.lower_assignable_value(target); + // let (expr, _ty) = self.lower_expr(value); + // self.builder.map_result(expr, result) + // } + + // hir_def::Stmt::AugAssign { target, op, value } => { + // let result = self.lower_assignable_value(target); + // let lhs = self.lower_expr_to_value(target); + // let rhs = self.lower_expr_to_value(value); + + // let inst = self.lower_binop(op.kind, lhs, rhs, stmt.into()); + // self.builder.map_result(inst, result) + // } + + // hir_def::Stmt::For(target, iter, body) => self.lower_for_loop(target, iter, body), + + // hir_def::Stmt::While(test, body) => { + // let header_bb = self.builder.make_block(); + // let exit_bb = self.builder.make_block(); + + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, header_bb, exit_bb()); + + // // Lower while body. + // self.builder.move_to_block(header_bb); + // self.enter_loop_scope(header_bb, exit_bb); + // for stmt in body { + // self.lower_stmt(stmt); + // } + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, header_bb, exit_bb()); + + // self.leave_scope(); + + // // Move to while exit bb. + // self.builder.move_to_block(exit_bb); + // } + + // hir_def::Stmt::If { + // test, + // body, + // or_else, + // } => self.lower_if(test, body, or_else), + + // hir_def::Stmt::Match { expr, arms } => { + // let matrix = &self.analyzer_body.matches[&stmt.id]; + // super::pattern_match::lower_match(self, matrix, expr, arms); + // } + + // hir_def::Stmt::Assert { test, msg } => { + // let then_bb = self.builder.make_block(); + // let false_bb = self.builder.make_block(); + + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, then_bb, false_bb()); + + // self.builder.move_to_block(false_bb); + + // let msg = match msg { + // Some(msg) => self.lower_expr_to_value(msg), + // None => self.make_u256_imm(1), + // }; + // self.builder.revert(Some(msg), stmt.into()); + // self.builder.move_to_block(then_bb); + // } + + // hir_def::Stmt::Expr(value) => { + // self.lower_expr_to_value(value); + // } + + // hir_def::Stmt::Break => { + // let exit = self.scope().loop_exit(&self.scopes); + // self.builder.jump(exit, stmt.into()); + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::Continue => { + // let entry = self.scope().loop_entry(&self.scopes); + // if let Some(loop_idx) = self.scope().loop_idx(&self.scopes) { + // let imm_one = self.make_u256_imm(1u32); + // let inc = self.builder.add(loop_idx, imm_one()); + // self.builder.map_result(inc, loop_idx.into()); + // let maximum_iter_count = self.scope().maximum_iter_count(&self.scopes).unwrap(); + // let exit = self.scope().loop_exit(&self.scopes); + // self.branch_eq(loop_idx, maximum_iter_count, exit, entry, stmt.into()); + // } else { + // self.builder.jump(entry, stmt.into()); + // } + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::Revert { error } => { + // let error = error.as_ref().map(|err| self.lower_expr_to_value(err)); + // self.builder.revert(error, stmt.into()); + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::Unsafe(stmts) => { + // self.enter_scope(); + // for stmt in stmts { + // self.lower_stmt(stmt) + // } + // self.leave_scope() + // } + // } + panic!() + } + + pub(super) fn lower_var_decl(&mut self, var: &hir_def::PatId, init: Option<&hir_def::Expr>) { + // match &var.kind { + // hir_def::VarDeclTarget::Name(name) => { + // let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); + // let value = self.declare_var(name, ty, var.into()); + // if let Some(init) = init { + // let (init, _init_ty) = self.lower_expr(init); + // // debug_assert_eq!(ty.deref(self.db), init_ty, "vardecl init type mismatch: {} + // // != {}", ty.as_string(self.db), + // // init_ty.as_string(self.db)); + // self.builder.map_result(init, value.into()); + // } + // } + + // hir_def::VarDeclTarget::Tuple(decls) => { + // if let Some(init) = init { + // if let hir_def::Expr::Tuple(elts) = &init.kind { + // debug_assert_eq!(decls.len(), elts.len()); + // for (decl, init_elem) in decls.iter().zip(elts.iter()) { + // self.lower_var_decl(decl, Some(init_elem), source.clone()); + // } + // } else { + // let init_ty = self.expr_ty(init); + // let init_value = self.lower_expr_to_value(init); + // self.lower_var_decl_unpack(var, init_value, init_ty, source); + // }; + // } else { + // for decl in decls { + // self.lower_var_decl(decl, None, source.clone()) + // } + // } + // } + // } + panic!() + } + + pub(super) fn declare_var(&mut self, name: &SmolStr, ty: TypeId) -> ValueId { + let local = Local::user_local(name.clone(), ty); + let value = self.builder.declare(local); + self.scope_mut().declare_var(name, value); + value + } + + pub(super) fn lower_var_decl_unpack( + &mut self, + var: &hir_def::PatId, + init: ValueId, + init_ty: TypeId, + ) { + // match &var.kind { + // hir_def::VarDeclTarget::Name(name) => { + // let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); + // let local = Local::user_local(name.clone(), ty, var.into()); + + // let lhs = self.builder.declare(local); + // self.scope_mut().declare_var(name, lhs); + // let bind = self.builder.bind(init, source); + // self.builder.map_result(bind, lhs.into()); + // } + + // hir_def::VarDeclTarget::Tuple(decls) => { + // for (index, decl) in decls.iter().enumerate() { + // let elem_ty = init_ty.projection_ty_imm(self.db, index); + // let index_value = self.make_u256_imm(index); + // let elem_inst = + // self.builder + // .aggregate_access(init, vec![index_value], source.clone()); + // let elem_value = self.map_to_tmp(elem_inst, elem_ty); + // self.lower_var_decl_unpack(decl, elem_value, elem_ty, source.clone()) + // } + // } + // } + panic!() + } + + pub(super) fn lower_expr(&mut self, expr: &hir_def::Expr) -> (InstId, TypeId) { + // let mut ty = self.expr_ty(expr); + // let mut inst = match &expr.kind { + // hir_def::Expr::Ternary { + // if_expr, + // test, + // else_expr, + // } => { + // let true_bb = self.builder.make_block(); + // let false_bb = self.builder.make_block(); + // let merge_bb = self.builder.make_block(); + + // let tmp = self + // .builder + // .declare(Local::tmp_local("$ternary_tmp".into(), ty)); + + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, true_bb, false_bb()); + + // self.builder.move_to_block(true_bb); + // let (value, _) = self.lower_expr(if_expr); + // self.builder.map_result(value, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(false_bb); + // let (value, _) = self.lower_expr(else_expr); + // self.builder.map_result(value, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(merge_bb); + // self.builder.bind(tmp()) + // } + + // hir_def::Expr::BoolOperation { left, op, right } => { + // self.lower_bool_op(op.kind, left, right, ty) + // } + + // hir_def::Expr::BinOperation { left, op, right } => { + // let lhs = self.lower_expr_to_value(left); + // let rhs = self.lower_expr_to_value(right); + // self.lower_binop(op.kind, lhs, rhs, expr.into()) + // } + + // hir_def::Expr::UnaryOperation { op, operand } => { + // let value = self.lower_expr_to_value(operand); + // match op.kind { + // hir_def::UnOp::Invert => self.builder.inv(value, expr.into()), + // hir_def::UnOp::Not => self.builder.not(value, expr.into()), + // hir_def::UnOp::USub => self.builder.neg(value, expr.into()), + // } + // } + + // hir_def::Expr::CompOperation { left, op, right } => { + // let lhs = self.lower_expr_to_value(left); + // let rhs = self.lower_expr_to_value(right); + // self.lower_comp_op(op.kind, lhs, rhs, expr.into()) + // } + + // hir_def::Expr::Attribute { .. } => { + // let mut indices = vec![]; + // let value = self.lower_aggregate_access(expr, &mut indices); + // self.builder.aggregate_access(value, indices, expr.into()) + // } + + // hir_def::Expr::Subscript { value, index } => { + // let value_ty = self.expr_ty(value).deref(self.db); + // if value_ty.is_aggregate(self.db) { + // let mut indices = vec![]; + // let value = self.lower_aggregate_access(expr, &mut indices); + // self.builder.aggregate_access(value, indices, expr.into()) + // } else if value_ty.is_map(self.db) { + // let value = self.lower_expr_to_value(value); + // let key = self.lower_expr_to_value(index); + // self.builder.map_access(value, key, expr.into()) + // } else { + // unreachable!() + // } + // } + + // hir_def::Expr::Call(func, generic_args, args) => { + // let ty = self.expr_ty(expr); + // self.lower_call(func, generic_args, &args.kind, ty, expr.into()) + // } + + // hir_def::Expr::List { elts } | hir_def::Expr::Tuple { elts } => { + // let args = elts + // .iter() + // .map(|elem| self.lower_expr_to_value(elem)) + // .collect(); + // let ty = self.expr_ty(expr); + // self.builder.aggregate_construct(ty, args, expr.into()) + // } + + // hir_def::Expr::Repeat { value, len: _ } => { + // let array_type = if let Type::Array(array_type) = self.analyzer_body.expressions + // [&expr.id] + // .typ + // .typ(self.db.upcast()) + // { + // array_type + // } else { + // panic!("not an array"); + // }; + + // let args = vec![self.lower_expr_to_value(value); array_type.size]; + // let ty = self.expr_ty(expr); + // self.builder.aggregate_construct(ty, args, expr.into()) + // } + + // hir_def::Expr::Bool(b) => { + // let imm = self.builder.make_imm_from_bool(*b, ty); + // self.builder.bind(imm, expr.into()) + // } + + // hir_def::Expr::Name(name) => { + // let value = self.resolve_name(name); + // self.builder.bind(value, expr.into()) + // } + + // hir_def::Expr::Path(path) => { + // let value = self.resolve_path(path, expr.into()); + // self.builder.bind(value, expr.into()) + // } + + // hir_def::Expr::Num(num) => { + // let imm = Literal::new(num).parse().unwrap(); + // let imm = self.builder.make_imm(imm, ty); + // self.builder.bind(imm, expr.into()) + // } + + // hir_def::Expr::Str(s) => { + // let ty = self.expr_ty(expr); + // let const_value = self.make_local_constant( + // "str_in_func".into(), + // ty, + // ConstantValue::Str(s.clone()), + // expr.into(), + // ); + // self.builder.bind(const_value, expr.into()) + // } + + // hir_def::Expr::Unit => { + // let value = self.make_unit(); + // self.builder.bind(value, expr.into()) + // } + // }; + + // for Adjustment { into, kind } in &self.analyzer_body.expressions[&expr.id].type_adjustments + // { + // let into_ty = self.lower_analyzer_type(*into); + + // match kind { + // AdjustmentKind::Copy => { + // let val = self.inst_result_or_tmp(inst, ty); + // inst = self.builder.mem_copy(val, expr.into()); + // } + // AdjustmentKind::Load => { + // let val = self.inst_result_or_tmp(inst, ty); + // inst = self.builder.load(val, expr.into()); + // } + // AdjustmentKind::IntSizeIncrease => { + // let val = self.inst_result_or_tmp(inst, ty); + // inst = self.builder.primitive_cast(val, into_ty, expr.into()) + // } + // AdjustmentKind::StringSizeIncrease => {} // XXX + // } + // ty = into_ty; + // } + // (inst, ty) + panic!() + } + + // fn inst_result_or_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { + // self.builder + // .inst_result(inst) + // .and_then(|r| r.value_id()) + // .unwrap_or_else(|| self.map_to_tmp(inst, ty)) + // } + + // pub(super) fn lower_expr_to_value(&mut self, expr: &hir_def::Expr) -> ValueId { + // let (inst, ty) = self.lower_expr(expr); + // self.map_to_tmp(inst, ty) + // } + + pub(super) fn enter_scope(&mut self) { + let new_scope = Scope::with_parent(self.current_scope); + self.current_scope = self.scopes.alloc(new_scope); + } + + pub(super) fn leave_scope(&mut self) { + self.current_scope = self.scopes[self.current_scope].parent.unwrap(); + } + + pub(super) fn make_imm(&mut self, imm: impl Into, ty: TypeId) -> ValueId { + self.builder.make_value(Value::Immediate { + imm: imm.into(), + ty, + }) + } + + // pub(super) fn make_u256_imm(&mut self, value: impl Into) -> ValueId { + // let u256_ty = self.u256_ty(); + // self.make_imm(value, u256_ty) + // } + + // pub(super) fn map_to_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { + // match &self.builder.inst_data(inst).kind { + // &InstKind::Bind { src } => { + // let value = *src; + // self.builder.remove_inst(inst); + // value + // } + // _ => { + // let tmp = Value::Temporary { inst, ty }; + // let result = self.builder.make_value(tmp); + // self.builder.map_result(inst, result.into()); + // result + // } + // } + // } + + // fn new( + // db: &'db dyn MirDb, + // func: FunctionId, + // ast: &'a Node, + // analyzer_body: &'a fe_analyzer2::context::FunctionBody, + // ) -> Self { + // let mut builder = BodyBuilder::new(func, ast.into()); + // let mut scopes = Arena::new(); + + // // Make a root scope. A root scope collects function parameters and module + // // constants. + // let root = Scope::root(db, func, &mut builder); + // let current_scope = scopes.alloc(root); + // Self { + // db, + // builder, + // ast, + // func, + // analyzer_body, + // scopes, + // current_scope, + // } + // } + + // fn lower_analyzer_type(&self, analyzer_ty: analyzer_types::TypeId) -> TypeId { + // // If the analyzer type is generic we first need to resolve it to its concrete + // // type before lowering to a MIR type + // if let analyzer_types::Type::Generic(generic) = analyzer_ty.deref_typ(self.db.upcast()) { + // let resolved_type = self + // .func + // .signature(self.db) + // .resolved_generics + // .get(&generic.name) + // .cloned() + // .expect("expected generic to be resolved"); + + // return self.db.mir_lowered_type(resolved_type); + // } + + // self.db.mir_lowered_type(analyzer_ty) + // } + + // fn lower(mut self) -> FunctionBody { + // for stmt in &self.ast.kind.body { + // self.lower_stmt(stmt) + // } + + // let last_block = self.builder.current_block(); + // if !self.builder.is_block_terminated(last_block) { + // let unit = self.make_unit(); + // self.builder.ret(unit()); + // } + + // self.builder.build() + // } + + // fn branch_eq( + // &mut self, + // v1: ValueId, + // v2: ValueId, + // true_bb: BasicBlockId, + // false_bb: BasicBlockId, + // ) { + // let cond = self.builder.eq(v1, v2); + // let bool_ty = self.bool_ty(); + // let cond = self.map_to_tmp(cond, bool_ty); + // self.builder.branch(cond, true_bb, false_bb); + // } + + // fn lower_if(&mut self, cond: &hir_def::Expr, then: &[hir_def::Stmt], else_: &[hir_def::Stmt]) { + // let cond = self.lower_expr_to_value(cond); + + // if else_.is_empty() { + // let then_bb = self.builder.make_block(); + // let merge_bb = self.builder.make_block(); + + // self.builder.branch(cond, then_bb, merge_bb()); + + // // Lower then block. + // self.builder.move_to_block(then_bb); + // self.enter_scope(); + // for stmt in then { + // self.lower_stmt(stmt); + // } + // self.builder.jump(merge_bb()); + // self.builder.move_to_block(merge_bb); + // self.leave_scope(); + // } else { + // let then_bb = self.builder.make_block(); + // let else_bb = self.builder.make_block(); + + // self.builder.branch(cond, then_bb, else_bb()); + + // // Lower then block. + // self.builder.move_to_block(then_bb); + // self.enter_scope(); + // for stmt in then { + // self.lower_stmt(stmt); + // } + // self.leave_scope(); + // let then_block_end_bb = self.builder.current_block(); + + // // Lower else_block. + // self.builder.move_to_block(else_bb); + // self.enter_scope(); + // for stmt in else_ { + // self.lower_stmt(stmt); + // } + // self.leave_scope(); + // let else_block_end_bb = self.builder.current_block(); + + // let merge_bb = self.builder.make_block(); + // if !self.builder.is_block_terminated(then_block_end_bb) { + // self.builder.move_to_block(then_block_end_bb); + // self.builder.jump(merge_bb()); + // } + // if !self.builder.is_block_terminated(else_block_end_bb) { + // self.builder.move_to_block(else_block_end_bb); + // self.builder.jump(merge_bb()); + // } + // self.builder.move_to_block(merge_bb); + // } + // } + + // NOTE: we assume a type of `iter` is array. + // TODO: Desugar to `loop` + `match` like rustc in HIR to generate better MIR. + fn lower_for_loop( + &mut self, + loop_variable: &hir_def::IdentId, + iter: &hir_def::Expr, + body: &[hir_def::Stmt], + ) { + // let preheader_bb = self.builder.make_block(); + // let entry_bb = self.builder.make_block(); + // let exit_bb = self.builder.make_block(); + + // let iter_elem_ty = self.analyzer_body.var_types[&loop_variable.id]; + // let iter_elem_ty = self.lower_analyzer_type(iter_elem_ty); + + // self.builder.jump(preheader_bb()); + + // // `For` has its scope from preheader block. + // self.enter_loop_scope(entry_bb, exit_bb); + + // /* Lower preheader. */ + // self.builder.move_to_block(preheader_bb); + + // // Declare loop_variable. + // let loop_value = self.builder.declare(Local::user_local( + // loop_variable.kind.clone(), + // iter_elem_ty, + // loop_variable.into(), + // )); + // self.scope_mut() + // .declare_var(&loop_variable.kind, loop_value); + + // // Declare and initialize `loop_idx` to 0. + // let loop_idx = Local::tmp_local("$loop_idx_tmp".into(), self.u256_ty()); + // let loop_idx = self.builder.declare(loop_idx); + // let imm_zero = self.make_u256_imm(0u32); + // let imm_zero = self.builder.bind(imm_zero()); + // self.builder.map_result(imm_zero, loop_idx.into()); + + // // Evaluates loop variable. + // let iter_ty = self.expr_ty(iter); + // let iter = self.lower_expr_to_value(iter); + + // // Create maximum loop count. + // let maximum_iter_count = match &iter_ty.deref(self.db).data(self.db).kind { + // ir::TypeKind::Array(ir::types::ArrayDef { len, .. }) => *len, + // _ => unreachable!(), + // }; + // let maximum_iter_count = self.make_u256_imm(maximum_iter_count); + // self.branch_eq(loop_idx, maximum_iter_count, exit_bb, entry_bb); + // self.scope_mut().loop_idx = Some(loop_idx); + // self.scope_mut().maximum_iter_count = Some(maximum_iter_count); + + // /* Lower body. */ + // self.builder.move_to_block(entry_bb); + + // // loop_variable = array[loop_idx] + // let iter_elem = self.builder.aggregate_access(iter, vec![loop_idx]()); + // self.builder + // .map_result(iter_elem, AssignableValue::Value(loop_value)); + + // for stmt in body { + // self.lower_stmt(stmt); + // } + + // // loop_idx += 1 + // let imm_one = self.make_u256_imm(1u32); + // let inc = self.builder.add(loop_idx, imm_one()); + // self.builder + // .map_result(inc, AssignableValue::Value(loop_idx)); + // self.branch_eq(loop_idx, maximum_iter_count, exit_bb, entry_bb); + + // /* Move to exit bb */ + // self.leave_scope(); + // self.builder.move_to_block(exit_bb); + } + + // fn lower_assignable_value(&mut self, expr: &hir_def::Expr) -> AssignableValue { + // match &expr.kind { + // hir_def::Expr::Attribute { value, attr } => { + // let idx = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + // let idx = self.make_u256_imm(idx); + // let lhs = self.lower_assignable_value(value).into(); + // AssignableValue::Aggregate { lhs, idx } + // } + // hir_def::Expr::Subscript { value, index } => { + // let lhs = self.lower_assignable_value(value).into(); + // let attr = self.lower_expr_to_value(index); + // let value_ty = self.expr_ty(value).deref(self.db); + // if value_ty.is_aggregate(self.db) { + // AssignableValue::Aggregate { lhs, idx: attr } + // } else if value_ty.is_map(self.db) { + // AssignableValue::Map { lhs, key: attr } + // } else { + // unreachable!() + // } + // } + // hir_def::Expr::Name(name) => self.resolve_name(name).into(), + // hir_def::Expr::Path(path) => self.resolve_path(path, expr.into()).into(), + // _ => self.lower_expr_to_value(expr).into(), + // } + // } + + // /// Returns the pre-adjustment type of the given `Expr` + // fn expr_ty(&self, expr: &hir_def::Expr) -> TypeId { + // let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ; + // self.lower_analyzer_type(analyzer_ty) + // } + + // fn lower_bool_op( + // &mut self, + // op: hir_def::LogicalBinOp, + // lhs: &hir_def::Expr, + // rhs: &hir_def::Expr, + // ty: TypeId, + // ) -> InstId { + // let true_bb = self.builder.make_block(); + // let false_bb = self.builder.make_block(); + // let merge_bb = self.builder.make_block(); + + // let lhs = self.lower_expr_to_value(lhs); + // let tmp = self + // .builder + // .declare(Local::tmp_local(format!("${op}_tmp").into(), ty)); + + // match op { + // hir_def::LogicalBinOp::And => { + // self.builder.branch(lhs, true_bb, false_bb()); + + // self.builder.move_to_block(true_bb); + // let (rhs, _rhs_ty) = self.lower_expr(rhs); + // self.builder.map_result(rhs, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(false_bb); + // let false_imm = self.builder.make_imm_from_bool(false, ty); + // let false_imm_copy = self.builder.bind(false_imm()); + // self.builder.map_result(false_imm_copy, tmp.into()); + // self.builder.jump(merge_bb()); + // } + + // hir_def::LogicalBinOp::Or => { + // self.builder.branch(lhs, true_bb, false_bb()); + + // self.builder.move_to_block(true_bb); + // let true_imm = self.builder.make_imm_from_bool(true, ty); + // let true_imm_copy = self.builder.bind(true_imm()); + // self.builder.map_result(true_imm_copy, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(false_bb); + // let (rhs, _rhs_ty) = self.lower_expr(rhs); + // self.builder.map_result(rhs, tmp.into()); + // self.builder.jump(merge_bb()); + // } + // } + + // self.builder.move_to_block(merge_bb); + // self.builder.bind(tmp()) + // } + + // fn lower_binop( + // &mut self, + // op: hir_def::BinOp, + // lhs: ValueId, + // rhs: ValueId, + // // source: SourceInfo, + // ) -> InstId { + // match op { + // hir_def::BinOp::Add => self.builder.add(lhs, rhs), + // hir_def::BinOp::Sub => self.builder.sub(lhs, rhs), + // hir_def::BinOp::Mult => self.builder.mul(lhs, rhs), + // hir_def::BinOp::Div => self.builder.div(lhs, rhs), + // hir_def::BinOp::Mod => self.builder.modulo(lhs, rhs), + // hir_def::BinOp::Pow => self.builder.pow(lhs, rhs), + // hir_def::BinOp::LShift => self.builder.shl(lhs, rhs), + // hir_def::BinOp::RShift => self.builder.shr(lhs, rhs), + // hir_def::BinOp::BitOr => self.builder.bit_or(lhs, rhs), + // hir_def::BinOp::BitXor => self.builder.bit_xor(lhs, rhs), + // hir_def::BinOp::BitAnd => self.builder.bit_and(lhs, rhs), + // } + // } + + // fn lower_comp_op( + // &mut self, + // op: hir_def::CompBinOp, + // lhs: ValueId, + // rhs: ValueId, + // // source: SourceInfo, + // ) -> InstId { + // match op { + // hir_def::CompBinOp::Eq => self.builder.eq(lhs, rhs), + // hir_def::CompBinOp::NotEq => self.builder.ne(lhs, rhs), + // hir_def::CompBinOp::Lt => self.builder.lt(lhs, rhs), + // hir_def::CompBinOp::LtE => self.builder.le(lhs, rhs), + // hir_def::CompBinOp::Gt => self.builder.gt(lhs, rhs), + // hir_def::CompBinOp::GtE => self.builder.ge(lhs, rhs), + // } + // } + + // fn resolve_generics_args( + // &mut self, + // method: &analyzer_items::FunctionId, + // args: &[Id], + // ) -> BTreeMap { + // method + // .signature(self.db.upcast()) + // .params + // .iter() + // .zip(args.iter().map(|val| { + // self.builder + // .value_ty(*val) + // .analyzer_ty(self.db) + // .expect("invalid parameter") + // })) + // .filter_map(|(param, typ)| { + // if let Type::Generic(generic) = + // param.typ.clone().unwrap().deref_typ(self.db.upcast()) + // { + // Some((generic.name, typ)) + // } else { + // None + // } + // }) + // .collect::>() + // } + + // fn lower_function_id(&mut self, function: &hir_def::Func, args: &[Id]) -> FunctionId { + // let resolved_generics = self.resolve_generics_args(function, args); + // if function.is_generic(self.db.upcast()) { + // self.db + // .mir_lowered_monomorphized_func_signature(*function, resolved_generics) + // } else { + // self.db.mir_lowered_func_signature(*function) + // } + // } + + fn lower_call( + &mut self, + func: &hir_def::Expr, + _generic_args: &Option>, + args: &[hir_def::CallArg], + ty: TypeId, + ) -> InstId { + // let call_type = &self.analyzer_body.calls[&func.id]; + + // let mut args: Vec<_> = args + // .iter() + // .map(|arg| self.lower_expr_to_value(&arg.kind.value)) + // .collect(); + + // match call_type { + // AnalyzerCallType::BuiltinFunction(GlobalFunction::Keccak256) => { + // self.builder.keccak256(args[0], source) + // } + + // AnalyzerCallType::Intrinsic(intrinsic) => { + // self.builder + // .yul_intrinsic((*intrinsic).into(), args, source) + // } + + // AnalyzerCallType::BuiltinValueMethod { method, .. } => { + // let arg = self.lower_method_receiver(func); + // match method { + // ValueMethod::ToMem => self.builder.mem_copy(arg, source), + // ValueMethod::AbiEncode => self.builder.abi_encode(arg, source), + // } + // } + + // // We ignores `args[0]', which represents `context` and not used for now. + // AnalyzerCallType::BuiltinAssociatedFunction { contract, function } => match function { + // ContractTypeMethod::Create => self.builder.create(args[1], *contract, source), + // ContractTypeMethod::Create2 => { + // self.builder.create2(args[1], args[2], *contract, source) + // } + // }, + + // AnalyzerCallType::AssociatedFunction { function, .. } + // | AnalyzerCallType::Pure(function) => { + // let func_id = self.lower_function_id(function, &args); + // self.builder.call(func_id, args, CallType::Internal, source) + // } + + // AnalyzerCallType::ValueMethod { method, .. } => { + // let mut method_args = vec![self.lower_method_receiver(func)]; + // let func_id = self.lower_function_id(method, &args); + + // method_args.append(&mut args); + + // self.builder + // .call(func_id, method_args, CallType::Internal, source) + // } + // AnalyzerCallType::TraitValueMethod { + // trait_id, method, .. + // } if trait_id.is_std_trait(self.db.upcast(), EMITTABLE_TRAIT_NAME) + // && method.name(self.db.upcast()) == EMIT_FN_NAME => + // { + // let event = self.lower_method_receiver(func); + // self.builder.emit(event, source) + // } + // AnalyzerCallType::TraitValueMethod { + // method, + // trait_id, + // generic_type, + // .. + // } => { + // let mut method_args = vec![self.lower_method_receiver(func)]; + // method_args.append(&mut args); + + // let concrete_type = self + // .func + // .signature(self.db) + // .resolved_generics + // .get(&generic_type.name) + // .cloned() + // .expect("unresolved generic type"); + + // let impl_ = concrete_type + // .get_impl_for(self.db.upcast(), *trait_id) + // .expect("missing impl"); + + // let function = impl_ + // .function(self.db.upcast(), &method.name(self.db.upcast())) + // .expect("missing function"); + + // let func_id = self.db.mir_lowered_func_signature(function); + // self.builder + // .call(func_id, method_args, CallType::Internal, source) + // } + // AnalyzerCallType::External { function, .. } => { + // let receiver = self.lower_method_receiver(func); + // debug_assert!(self.builder.value_ty(receiver).is_address(self.db)); + + // let mut method_args = vec![receiver]; + // method_args.append(&mut args); + // let func_id = self.db.mir_lowered_func_signature(*function); + // self.builder + // .call(func_id, method_args, CallType::External, source) + // } + + // AnalyzerCallType::TypeConstructor(to_ty) => { + // if to_ty.is_string(self.db.upcast()) { + // let arg = *args.last().unwrap(); + // self.builder.mem_copy(arg, source) + // } else if ty.is_primitive(self.db) { + // // TODO: Ignore `ctx` for now. + // let arg = *args.last().unwrap(); + // let arg_ty = self.builder.value_ty(arg); + // if arg_ty == ty { + // self.builder.bind(arg, source) + // } else { + // debug_assert!(!arg_ty.is_ptr(self.db)); // Should be explicitly `Load`ed + // self.builder.primitive_cast(arg, ty, source) + // } + // } else if ty.is_aggregate(self.db) { + // self.builder.aggregate_construct(ty, args, source) + // } else { + // unreachable!() + // } + // } + + // AnalyzerCallType::EnumConstructor(variant) => { + // let tag_type = ty.enum_disc_type(self.db); + // let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); + // let data_ty = ty.enum_variant_type(self.db, *variant); + // let enum_args = if data_ty.is_unit(self.db) { + // vec![tag, self.make_unit()] + // } else { + // std::iter::once(tag).chain(args).collect() + // }; + // self.builder.aggregate_construct(ty, enum_args, source) + // } + // } + todo!(); + } + + // // FIXME: This is ugly hack to properly analyze method call. Remove this when https://github.com/ethereum/fe/issues/670 is resolved. + // fn lower_method_receiver(&mut self, receiver: &hir_def::Expr) -> ValueId { + // match &receiver.kind { + // hir_def::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), + // _ => unreachable!(), + // } + // } + + // fn lower_aggregate_access( + // &mut self, + // expr: &hir_def::Expr, + // indices: &mut Vec, + // ) -> ValueId { + // match &expr.kind { + // hir_def::Expr::Attribute { value, attr } => { + // let index = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + // let value = self.lower_aggregate_access(value, indices); + // indices.push(self.make_u256_imm(index)); + // value + // } + + // hir_def::Expr::Subscript { value, index } + // if self.expr_ty(value).deref(self.db).is_aggregate(self.db) => + // { + // let value = self.lower_aggregate_access(value, indices); + // indices.push(self.lower_expr_to_value(index)); + // value + // } + + // _ => self.lower_expr_to_value(expr), + // } + // } + + // fn make_unit(&mut self) -> ValueId { + // let unit_ty = analyzer_types::TypeId::unit(self.db.upcast()); + // let unit_ty = self.db.mir_lowered_type(unit_ty); + // self.builder.make_unit(unit_ty) + // } + + // fn make_local_constant( + // &mut self, + // name: SmolStr, + // ty: TypeId, + // value: ConstantValue, + // source: SourceInfo, + // ) -> ValueId { + // let function_id = self.builder.func_id(); + // let constant = Constant { + // name, + // value, + // ty, + // module_id: function_id.module(self.db), + // source, + // }; + + // let constant_id = self.db.mir_intern_const(constant.into()); + // self.builder.make_constant(constant_id, ty) + // } + + // fn u256_ty(&mut self) -> TypeId { + // self.db + // .mir_intern_type(ir::Type::new(ir::TypeKind::U256, None).into()) + // } + + // fn bool_ty(&mut self) -> TypeId { + // self.db + // .mir_intern_type(ir::Type::new(ir::TypeKind::Bool, None).into()) + // } + + // fn enter_loop_scope(&mut self, entry: BasicBlockId, exit: BasicBlockId) { + // let new_scope = Scope::loop_scope(self.current_scope, entry, exit); + // self.current_scope = self.scopes.alloc(new_scope); + // } + + // /// Resolve a name appeared in an expression. + // /// NOTE: Don't call this to resolve method receiver. + // fn resolve_name(&mut self, name: &str) -> ValueId { + // if let Some(value) = self.scopes[self.current_scope].resolve_name(&self.scopes, name) { + // // Name is defined in local. + // value + // } else { + // // Name is defined in global. + // let func_id = self.builder.func_id(); + // let module = func_id.module(self.db); + // let constant = match module + // .resolve_name(self.db.upcast(), name) + // .unwrap() + // .unwrap() + // { + // NamedThing::Item(analyzer_items::Item::Constant(id)) => { + // self.db.mir_lowered_constant(id) + // } + // _ => panic!("name defined in global must be constant"), + // }; + // let ty = constant.ty(self.db); + // self.builder.make_constant(constant, ty) + // } + // } + + // /// Resolve a path appeared in an expression. + // /// NOTE: Don't call this to resolve method receiver. + // fn resolve_path(&mut self, path: &hir_def::Path, source: SourceInfo) -> ValueId { + // let func_id = self.builder.func_id(); + // let module = func_id.module(self.db); + // match module.resolve_path(self.db.upcast(), path).value.unwrap() { + // NamedThing::Item(analyzer_items::Item::Constant(id)) => { + // let constant = self.db.mir_lowered_constant(id); + // let ty = constant.ty(self.db); + // self.builder.make_constant(constant, ty) + // } + // NamedThing::EnumVariant(variant) => { + // let enum_ty = self + // .db + // .mir_lowered_type(variant.parent(self.db.upcast()).as_type(self.db.upcast())); + // let tag_type = enum_ty.enum_disc_type(self.db); + // let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); + // let data = self.make_unit(); + // let enum_args = vec![tag, data]; + // let inst = self.builder.aggregate_construct(enum_ty, enum_args, source); + // self.map_to_tmp(inst, enum_ty) + // } + // _ => panic!("path defined in global must be constant"), + // } + // } + + fn scope(&self) -> &Scope { + &self.scopes[self.current_scope] + } + + fn scope_mut(&mut self) -> &mut Scope { + &mut self.scopes[self.current_scope] + } +} + +#[derive(Debug)] +struct Scope { + parent: Option, + loop_entry: Option, + loop_exit: Option, + variables: FxHashMap, + // TODO: Remove the below two fields when `for` loop desugaring is implemented. + loop_idx: Option, + maximum_iter_count: Option, +} + +impl Scope { + fn root(db: &dyn MirDb, func: FunctionId, builder: &mut BodyBuilder) -> Self { + let mut root = Self { + parent: None, + loop_entry: None, + loop_exit: None, + variables: FxHashMap::default(), + loop_idx: None, + maximum_iter_count: None, + }; + + // // Declare function parameters. + // for param in &func.signature(db).params { + // let local = Local::arg_local(param.name.clone(), param.ty); + // let value_id = builder.store_func_arg(local); + // root.declare_var(¶m.name, value_id) + // } + + root + } + + fn with_parent(parent: ScopeId) -> Self { + Self { + parent: parent.into(), + loop_entry: None, + loop_exit: None, + variables: FxHashMap::default(), + loop_idx: None, + maximum_iter_count: None, + } + } + + fn loop_scope(parent: ScopeId, loop_entry: BasicBlockId, loop_exit: BasicBlockId) -> Self { + Self { + parent: parent.into(), + loop_entry: loop_entry.into(), + loop_exit: loop_exit.into(), + variables: FxHashMap::default(), + loop_idx: None, + maximum_iter_count: None, + } + } + + fn loop_entry(&self, scopes: &Arena) -> BasicBlockId { + match self.loop_entry { + Some(entry) => entry, + None => scopes[self.parent.unwrap()].loop_entry(scopes), + } + } + + fn loop_exit(&self, scopes: &Arena) -> BasicBlockId { + match self.loop_exit { + Some(exit) => exit, + None => scopes[self.parent.unwrap()].loop_exit(scopes), + } + } + + fn loop_idx(&self, scopes: &Arena) -> Option { + match self.loop_idx { + Some(idx) => Some(idx), + None => scopes[self.parent?].loop_idx(scopes), + } + } + + fn maximum_iter_count(&self, scopes: &Arena) -> Option { + match self.maximum_iter_count { + Some(count) => Some(count), + None => scopes[self.parent?].maximum_iter_count(scopes), + } + } + + fn declare_var(&mut self, name: &SmolStr, value: ValueId) { + debug_assert!(!self.variables.contains_key(name)); + + self.variables.insert(name.clone(), value); + } + + fn resolve_name(&self, scopes: &Arena, name: &str) -> Option { + match self.variables.get(name) { + Some(id) => Some(*id), + None => scopes[self.parent?].resolve_name(scopes, name), + } + } +} + +// fn make_param(db: &dyn MirDb, name: impl Into, ty: TypeId) -> FunctionParam { +// FunctionParam { +// name: name.into(), +// ty: db.mir_lowered_type(ty), +// } +// } diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs new file mode 100644 index 0000000000..7161774caf --- /dev/null +++ b/crates/mir2/src/lower/mod.rs @@ -0,0 +1,5 @@ +pub mod constant; +pub mod function; +pub mod types; + +// mod pattern_match; diff --git a/crates/mir2/src/lower/pattern_match/decision_tree.rs b/crates/mir2/src/lower/pattern_match/decision_tree.rs new file mode 100644 index 0000000000..5ab0bca5ef --- /dev/null +++ b/crates/mir2/src/lower/pattern_match/decision_tree.rs @@ -0,0 +1,569 @@ +//! This module contains the decision tree definition and its construction +//! function. +//! The algorithm for efficient decision tree construction is mainly based on [Compiling pattern matching to good decision trees](https://dl.acm.org/doi/10.1145/1411304.1411311). +use std::io; + +use indexmap::IndexMap; +use smol_str::SmolStr; + +use super::tree_vis::TreeRenderer; + +pub fn build_decision_tree( + db: &dyn AnalyzerDb, + pattern_matrix: &PatternMatrix, + policy: ColumnSelectionPolicy, +) -> DecisionTree { + let builder = DecisionTreeBuilder::new(policy); + let simplified_arms = SimplifiedArmMatrix::new(pattern_matrix); + + builder.build(db, simplified_arms) +} + +#[derive(Debug)] +pub enum DecisionTree { + Leaf(LeafNode), + Switch(SwitchNode), +} + +impl DecisionTree { + #[allow(unused)] + pub fn dump_dot(&self, db: &dyn AnalyzerDb, w: &mut W) -> io::Result<()> + where + W: io::Write, + { + let renderer = TreeRenderer::new(db, self); + dot2::render(&renderer, w).map_err(|err| match err { + dot2::Error::Io(err) => err, + _ => panic!("invalid graphviz id"), + }) + } +} + +#[derive(Debug)] +pub struct LeafNode { + pub arm_idx: usize, + pub binds: IndexMap<(SmolStr, usize), Occurrence>, +} + +impl LeafNode { + fn new(arm: SimplifiedArm, occurrences: &[Occurrence]) -> Self { + let arm_idx = arm.body; + let binds = arm.finalize_binds(occurrences); + Self { arm_idx, binds } + } +} + +#[derive(Debug)] +pub struct SwitchNode { + pub occurrence: Occurrence, + pub arms: Vec<(Case, DecisionTree)>, +} + +#[derive(Debug, Clone, Copy)] +pub enum Case { + Ctor(ConstructorKind), + Default, +} + +#[derive(Debug, Clone, Default)] +pub struct ColumnSelectionPolicy(Vec); + +impl ColumnSelectionPolicy { + /// The score of column i is the sum of the negation of the arities of + /// constructors in sigma(i). + pub fn arity(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::Arity) + } + + /// The score is the negation of the cardinal of sigma(i), C(Sigma(i)). + /// If sigma(i) is NOT complete, the resulting score is C(Sigma(i)) - 1. + pub fn small_branching(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::SmallBranching) + } + + /// The score is the number of needed rows of column i in the necessity + /// matrix. + #[allow(unused)] + pub fn needed_column(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::NeededColumn) + } + + /// The score is the larger row index j such that column i is needed for all + /// rows j′; 1 ≤ j′ ≤ j. + pub fn needed_prefix(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::NeededPrefix) + } + + fn select_column(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> usize { + let mut candidates: Vec<_> = (0..mat.ncols()).collect(); + + for scoring_fn in &self.0 { + let mut max_score = i32::MIN; + for col in std::mem::take(&mut candidates) { + let score = scoring_fn.score(db, mat, col); + match score.cmp(&max_score) { + std::cmp::Ordering::Less => {} + std::cmp::Ordering::Equal => { + candidates.push(col); + } + std::cmp::Ordering::Greater => { + candidates = vec![col]; + max_score = score; + } + } + } + + if candidates.len() == 1 { + return candidates.pop().unwrap(); + } + } + + // If there are more than one candidates remained, filter the columns with the + // shortest occurrences among the candidates, then select the rightmost one. + // This heuristics corresponds to the R pseudo heuristic in the paper. + let mut shortest_occurrences = usize::MAX; + for col in std::mem::take(&mut candidates) { + let occurrences = mat.occurrences[col].len(); + match occurrences.cmp(&shortest_occurrences) { + std::cmp::Ordering::Less => { + candidates = vec![col]; + shortest_occurrences = occurrences; + } + std::cmp::Ordering::Equal => { + candidates.push(col); + } + std::cmp::Ordering::Greater => {} + } + } + + candidates.pop().unwrap() + } + + fn add_heuristic(&mut self, heuristic: ColumnScoringFunction) -> &mut Self { + debug_assert!(!self.0.contains(&heuristic)); + self.0.push(heuristic); + self + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Occurrence(Vec); + +impl Occurrence { + pub fn new() -> Self { + Self(vec![]) + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn parent(&self) -> Option { + let mut inner = self.0.clone(); + inner.pop().map(|_| Occurrence(inner)) + } + + pub fn last_index(&self) -> Option { + self.0.last().cloned() + } + + fn phi_specialize(&self, db: &dyn AnalyzerDb, ctor: ConstructorKind) -> Vec { + let arity = ctor.arity(db); + (0..arity) + .map(|i| { + let mut inner = self.0.clone(); + inner.push(i); + Self(inner) + }) + .collect() + } + + fn len(&self) -> usize { + self.0.len() + } +} + +struct DecisionTreeBuilder { + policy: ColumnSelectionPolicy, +} + +impl DecisionTreeBuilder { + fn new(policy: ColumnSelectionPolicy) -> Self { + DecisionTreeBuilder { policy } + } + + fn build(&self, db: &dyn AnalyzerDb, mut mat: SimplifiedArmMatrix) -> DecisionTree { + debug_assert!(mat.nrows() > 0, "unexhausted pattern matrix"); + + if mat.is_first_arm_satisfied() { + mat.arms.truncate(1); + return DecisionTree::Leaf(LeafNode::new(mat.arms.pop().unwrap(), &mat.occurrences)); + } + + let col = self.policy.select_column(db, &mat); + mat.swap(col); + + let mut switch_arms = vec![]; + let occurrence = &mat.occurrences[0]; + let sigma_set = mat.sigma_set(0); + for &ctor in sigma_set.iter() { + let destructured_mat = mat.phi_specialize(db, ctor, occurrence); + let subtree = self.build(db, destructured_mat); + switch_arms.push((Case::Ctor(ctor), subtree)); + } + + if !sigma_set.is_complete(db) { + let destructured_mat = mat.d_specialize(db, occurrence); + let subtree = self.build(db, destructured_mat); + switch_arms.push((Case::Default, subtree)); + } + + DecisionTree::Switch(SwitchNode { + occurrence: occurrence.clone(), + arms: switch_arms, + }) + } +} + +#[derive(Clone, Debug)] +struct SimplifiedArmMatrix { + arms: Vec, + occurrences: Vec, +} + +impl SimplifiedArmMatrix { + fn new(mat: &PatternMatrix) -> Self { + let cols = mat.ncols(); + let arms: Vec<_> = mat + .rows() + .iter() + .enumerate() + .map(|(body, pat)| SimplifiedArm::new(pat, body)) + .collect(); + let occurrences = vec![Occurrence::new(); cols]; + + SimplifiedArmMatrix { arms, occurrences } + } + + fn nrows(&self) -> usize { + self.arms.len() + } + + fn ncols(&self) -> usize { + self.arms[0].pat_vec.len() + } + + fn pat(&self, row: usize, col: usize) -> &SimplifiedPattern { + self.arms[row].pat(col) + } + + fn necessity_matrix(&self, db: &dyn AnalyzerDb) -> NecessityMatrix { + NecessityMatrix::from_mat(db, self) + } + + fn reduced_pat_mat(&self, col: usize) -> PatternMatrix { + let mut rows = Vec::with_capacity(self.nrows()); + for arm in self.arms.iter() { + let reduced_pat_vec = arm + .pat_vec + .pats() + .iter() + .enumerate() + .filter(|(i, _)| (*i != col)) + .map(|(_, pat)| pat.clone()) + .collect(); + rows.push(PatternRowVec::new(reduced_pat_vec)); + } + + PatternMatrix::new(rows) + } + + /// Returns the constructor set in the column i. + fn sigma_set(&self, col: usize) -> SigmaSet { + SigmaSet::from_rows(self.arms.iter().map(|arm| &arm.pat_vec), col) + } + + fn is_first_arm_satisfied(&self) -> bool { + self.arms[0] + .pat_vec + .pats() + .iter() + .all(SimplifiedPattern::is_wildcard) + } + + fn phi_specialize( + &self, + db: &dyn AnalyzerDb, + ctor: ConstructorKind, + occurrence: &Occurrence, + ) -> Self { + let mut new_arms = Vec::new(); + for arm in &self.arms { + new_arms.extend_from_slice(&arm.phi_specialize(db, ctor, occurrence)); + } + + let mut new_occurrences = self.occurrences[0].phi_specialize(db, ctor); + new_occurrences.extend_from_slice(&self.occurrences.as_slice()[1..]); + + Self { + arms: new_arms, + occurrences: new_occurrences, + } + } + + fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Self { + let mut new_arms = Vec::new(); + for arm in &self.arms { + new_arms.extend_from_slice(&arm.d_specialize(db, occurrence)); + } + + Self { + arms: new_arms, + occurrences: self.occurrences.as_slice()[1..].to_vec(), + } + } + + fn swap(&mut self, i: usize) { + for arm in &mut self.arms { + arm.swap(0, i) + } + self.occurrences.swap(0, i); + } +} + +#[derive(Clone, Debug)] +struct SimplifiedArm { + pat_vec: PatternRowVec, + body: usize, + binds: IndexMap<(SmolStr, usize), Occurrence>, +} + +impl SimplifiedArm { + fn new(pat: &PatternRowVec, body: usize) -> Self { + let pat = PatternRowVec::new(pat.inner.iter().map(generalize_pattern).collect()); + Self { + pat_vec: pat, + body, + binds: IndexMap::new(), + } + } + + fn len(&self) -> usize { + self.pat_vec.len() + } + + fn pat(&self, col: usize) -> &SimplifiedPattern { + &self.pat_vec.inner[col] + } + + fn phi_specialize( + &self, + db: &dyn AnalyzerDb, + ctor: ConstructorKind, + occurrence: &Occurrence, + ) -> Vec { + let body = self.body; + let binds = self.new_binds(occurrence); + + self.pat_vec + .phi_specialize(db, ctor) + .into_iter() + .map(|pat| SimplifiedArm { + pat_vec: pat, + body, + binds: binds.clone(), + }) + .collect() + } + + fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Vec { + let body = self.body; + let binds = self.new_binds(occurrence); + + self.pat_vec + .d_specialize(db) + .into_iter() + .map(|pat| SimplifiedArm { + pat_vec: pat, + body, + binds: binds.clone(), + }) + .collect() + } + + fn new_binds(&self, occurrence: &Occurrence) -> IndexMap<(SmolStr, usize), Occurrence> { + let mut binds = self.binds.clone(); + if let Some(SimplifiedPatternKind::WildCard(Some(bind))) = + self.pat_vec.head().map(|pat| &pat.kind) + { + binds.entry(bind.clone()).or_insert(occurrence.clone()); + } + binds + } + + fn finalize_binds(self, occurrences: &[Occurrence]) -> IndexMap<(SmolStr, usize), Occurrence> { + debug_assert!(self.len() == occurrences.len()); + + let mut binds = self.binds; + for (pat, occurrence) in self.pat_vec.pats().iter().zip(occurrences.iter()) { + debug_assert!(pat.is_wildcard()); + + if let SimplifiedPatternKind::WildCard(Some(bind)) = &pat.kind { + binds.entry(bind.clone()).or_insert(occurrence.clone()); + } + } + + binds + } + + fn swap(&mut self, i: usize, j: usize) { + self.pat_vec.swap(i, j); + } +} + +struct NecessityMatrix { + data: Vec, + ncol: usize, + nrow: usize, +} + +impl NecessityMatrix { + fn new(ncol: usize, nrow: usize) -> Self { + let data = vec![false; ncol * nrow]; + Self { data, ncol, nrow } + } + + fn from_mat(db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> Self { + let nrow = mat.nrows(); + let ncol = mat.ncols(); + let mut necessity_mat = Self::new(ncol, nrow); + + necessity_mat.compute(db, mat); + necessity_mat + } + + fn compute(&mut self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) { + for row in 0..self.nrow { + for col in 0..self.ncol { + let pat = mat.pat(row, col); + let pos = self.pos(row, col); + + if !pat.is_wildcard() { + self.data[pos] = true; + } else { + let reduced_pat_mat = mat.reduced_pat_mat(col); + self.data[pos] = !reduced_pat_mat.is_row_useful(db, row); + } + } + } + } + + fn compute_needed_column_score(&self, col: usize) -> i32 { + let mut num = 0; + for i in 0..self.nrow { + if self.data[self.pos(i, col)] { + num += 1; + } + } + + num + } + + fn compute_needed_prefix_score(&self, col: usize) -> i32 { + let mut current_row = 0; + for i in 0..self.nrow { + if self.data[self.pos(i, col)] { + current_row += 1; + } else { + return current_row; + } + } + + current_row + } + + fn pos(&self, row: usize, col: usize) -> usize { + self.ncol * row + col + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ColumnScoringFunction { + /// The score of column i is the sum of the negation of the arities of + /// constructors in sigma(i). + Arity, + + /// The score is the negation of the cardinal of sigma(i), C(Sigma(i)). + /// If sigma(i) is NOT complete, the resulting score is C(Sigma(i)) - 1. + SmallBranching, + + /// The score is the number of needed rows of column i in the necessity + /// matrix. + NeededColumn, + + NeededPrefix, +} + +impl ColumnScoringFunction { + fn score(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix, col: usize) -> i32 { + match self { + ColumnScoringFunction::Arity => mat + .sigma_set(col) + .iter() + .map(|c| -(c.arity(db) as i32)) + .sum(), + + ColumnScoringFunction::SmallBranching => { + let sigma_set = mat.sigma_set(col); + let score = -(mat.sigma_set(col).len() as i32); + if sigma_set.is_complete(db) { + score + } else { + score - 1 + } + } + + ColumnScoringFunction::NeededColumn => { + mat.necessity_matrix(db).compute_needed_column_score(col) + } + + ColumnScoringFunction::NeededPrefix => { + mat.necessity_matrix(db).compute_needed_prefix_score(col) + } + } + } +} + +fn generalize_pattern(pat: &SimplifiedPattern) -> SimplifiedPattern { + match &pat.kind { + SimplifiedPatternKind::WildCard(_) => pat.clone(), + + SimplifiedPatternKind::Constructor { kind, fields } => { + let fields = fields.iter().map(generalize_pattern).collect(); + let kind = SimplifiedPatternKind::Constructor { + kind: *kind, + fields, + }; + SimplifiedPattern::new(kind, pat.ty) + } + + SimplifiedPatternKind::Or(pats) => { + let mut gen_pats = vec![]; + for pat in pats { + let gen_pad = generalize_pattern(pat); + if gen_pad.is_wildcard() { + gen_pats.push(gen_pad); + break; + } else { + gen_pats.push(gen_pad); + } + } + + if gen_pats.len() == 1 { + gen_pats.pop().unwrap() + } else { + SimplifiedPattern::new(SimplifiedPatternKind::Or(gen_pats), pat.ty) + } + } + } +} diff --git a/crates/mir2/src/lower/pattern_match/mod.rs b/crates/mir2/src/lower/pattern_match/mod.rs new file mode 100644 index 0000000000..6f38beec0a --- /dev/null +++ b/crates/mir2/src/lower/pattern_match/mod.rs @@ -0,0 +1,326 @@ +use fe_analyzer2::pattern_analysis::{ConstructorKind, PatternMatrix}; +use fe_parser2::{ + ast::{Expr, LiteralPattern, MatchArm}, + node::Node, +}; +use fxhash::FxHashMap; +use id_arena::{Arena, Id}; +use smol_str::SmolStr; + +use crate::ir::{ + body_builder::BodyBuilder, inst::SwitchTable, BasicBlockId, SourceInfo, TypeId, ValueId, +}; + +use self::decision_tree::{ + Case, ColumnSelectionPolicy, DecisionTree, LeafNode, Occurrence, SwitchNode, +}; + +use super::function::BodyLowerHelper; + +pub mod decision_tree; +mod tree_vis; + +pub(super) fn lower_match<'b>( + helper: &'b mut BodyLowerHelper<'_, '_>, + mat: &PatternMatrix, + scrutinee: &Node, + arms: &'b [Node], +) { + let mut policy = ColumnSelectionPolicy::default(); + // PBA heuristics described in the paper. + policy.needed_prefix().small_branching().arity(); + + let scrutinee = helper.lower_expr_to_value(scrutinee); + let decision_tree = decision_tree::build_decision_tree(helper.db.upcast(), mat, policy); + + DecisionTreeLowerHelper::new(helper, scrutinee, arms).lower(decision_tree); +} + +struct DecisionTreeLowerHelper<'db, 'a, 'b> { + helper: &'b mut BodyLowerHelper<'db, 'a>, + scopes: Arena, + current_scope: ScopeId, + root_block: BasicBlockId, + declared_vars: FxHashMap<(SmolStr, usize), ValueId>, + arms: &'b [Node], + lowered_arms: FxHashMap, + match_exit: BasicBlockId, +} + +impl<'db, 'a, 'b> DecisionTreeLowerHelper<'db, 'a, 'b> { + fn new( + helper: &'b mut BodyLowerHelper<'db, 'a>, + scrutinee: ValueId, + arms: &'b [Node], + ) -> Self { + let match_exit = helper.builder.make_block(); + + let mut scope = Scope::default(); + scope.register_occurrence(Occurrence::new(), scrutinee); + let mut scopes = Arena::new(); + let current_scope = scopes.alloc(scope); + + let root_block = helper.builder.current_block(); + + DecisionTreeLowerHelper { + helper, + scopes, + current_scope, + root_block, + declared_vars: FxHashMap::default(), + arms, + lowered_arms: FxHashMap::default(), + match_exit, + } + } + + fn lower(&mut self, tree: DecisionTree) { + self.lower_tree(tree); + + let match_exit = self.match_exit; + self.builder().move_to_block(match_exit); + } + + fn lower_tree(&mut self, tree: DecisionTree) { + match tree { + DecisionTree::Leaf(leaf) => self.lower_leaf(leaf), + DecisionTree::Switch(switch) => self.lower_switch(switch), + } + } + + fn lower_leaf(&mut self, leaf: LeafNode) { + for (var, occurrence) in leaf.binds { + let occurrence_value = self.resolve_occurrence(&occurrence); + let ty = self.builder().value_ty(occurrence_value); + let var_value = self.declare_or_use_var(&var, ty); + + let inst = self.builder().bind(occurrence_value, SourceInfo::dummy()); + self.builder().map_result(inst, var_value.into()); + } + + let arm_body = self.lower_arm_body(leaf.arm_idx); + self.builder().jump(arm_body, SourceInfo::dummy()); + } + + fn lower_switch(&mut self, mut switch: SwitchNode) { + let current_bb = self.builder().current_block(); + let occurrence_value = self.resolve_occurrence(&switch.occurrence); + + if switch.arms.len() == 1 { + let arm = switch.arms.pop().unwrap(); + let arm_bb = self.enter_arm(&switch.occurrence, &arm.0); + self.lower_tree(arm.1); + self.builder().move_to_block(current_bb); + self.builder().jump(arm_bb, SourceInfo::dummy()); + return; + } + + let mut table = SwitchTable::default(); + let mut default_arm = None; + let occurrence_ty = self.builder().value_ty(occurrence_value); + + for (case, tree) in switch.arms { + let arm_bb = self.enter_arm(&switch.occurrence, &case); + self.lower_tree(tree); + self.leave_arm(); + + if let Some(disc) = self.case_to_disc(&case, occurrence_ty) { + table.add_arm(disc, arm_bb); + } else { + debug_assert!(default_arm.is_none()); + default_arm = Some(arm_bb); + } + } + + self.builder().move_to_block(current_bb); + let disc = self.extract_disc(occurrence_value); + self.builder() + .switch(disc, table, default_arm, SourceInfo::dummy()); + } + + fn lower_arm_body(&mut self, index: usize) -> BasicBlockId { + if let Some(block) = self.lowered_arms.get(&index) { + *block + } else { + let current_bb = self.builder().current_block(); + let body_bb = self.builder().make_block(); + + self.builder().move_to_block(body_bb); + for stmt in &self.arms[index].kind.body { + self.helper.lower_stmt(stmt); + } + + if !self.builder().is_current_block_terminated() { + let match_exit = self.match_exit; + self.builder().jump(match_exit, SourceInfo::dummy()); + } + + self.lowered_arms.insert(index, body_bb); + self.builder().move_to_block(current_bb); + body_bb + } + } + + fn enter_arm(&mut self, occurrence: &Occurrence, case: &Case) -> BasicBlockId { + self.helper.enter_scope(); + + let bb = self.builder().make_block(); + self.builder().move_to_block(bb); + + let scope = Scope::with_parent(self.current_scope); + self.current_scope = self.scopes.alloc(scope); + + self.update_occurrence(occurrence, case); + bb + } + + fn leave_arm(&mut self) { + self.current_scope = self.scopes[self.current_scope].parent.unwrap(); + self.helper.leave_scope(); + } + + fn case_to_disc(&mut self, case: &Case, occurrence_ty: TypeId) -> Option { + match case { + Case::Ctor(ConstructorKind::Enum(variant)) => { + let disc_ty = occurrence_ty.enum_disc_type(self.helper.db); + let disc = variant.disc(self.helper.db.upcast()); + Some(self.helper.make_imm(disc, disc_ty)) + } + + Case::Ctor(ConstructorKind::Literal((LiteralPattern::Bool(b), ty))) => { + let ty = self.helper.db.mir_lowered_type(*ty); + Some(self.builder().make_imm_from_bool(*b, ty)) + } + + Case::Ctor(ConstructorKind::Tuple(_)) + | Case::Ctor(ConstructorKind::Struct(_)) + | Case::Default => None, + } + } + + fn update_occurrence(&mut self, occurrence: &Occurrence, case: &Case) { + let old_value = self.resolve_occurrence(occurrence); + let old_ty = self.builder().value_ty(old_value); + + match case { + Case::Ctor(ConstructorKind::Enum(variant)) => { + let new_ty = old_ty.enum_variant_type(self.helper.db, *variant); + let cast = self + .builder() + .untag_cast(old_value, new_ty, SourceInfo::dummy()); + let value = self.helper.map_to_tmp(cast, new_ty); + self.current_scope_mut() + .register_occurrence(occurrence.clone(), value) + } + + Case::Ctor(ConstructorKind::Literal((LiteralPattern::Bool(b), _))) => { + let value = self.builder().make_imm_from_bool(*b, old_ty); + self.current_scope_mut() + .register_occurrence(occurrence.clone(), value) + } + + Case::Ctor(ConstructorKind::Tuple(_)) + | Case::Ctor(ConstructorKind::Struct(_)) + | Case::Default => {} + } + } + + fn extract_disc(&mut self, value: ValueId) -> ValueId { + let value_ty = self.builder().value_ty(value); + match value_ty { + _ if value_ty.deref(self.helper.db).is_enum(self.helper.db) => { + let disc_ty = value_ty.enum_disc_type(self.helper.db); + let disc_index = self.helper.make_u256_imm(0); + let inst = + self.builder() + .aggregate_access(value, vec![disc_index], SourceInfo::dummy()); + self.helper.map_to_tmp(inst, disc_ty) + } + + _ => value, + } + } + + fn declare_or_use_var(&mut self, var: &(SmolStr, usize), ty: TypeId) -> ValueId { + if let Some(value) = self.declared_vars.get(var) { + *value + } else { + let current_block = self.builder().current_block(); + let root_block = self.root_block; + self.builder().move_to_block_top(root_block); + let value = self.helper.declare_var(&var.0, ty, SourceInfo::dummy()); + self.builder().move_to_block(current_block); + self.declared_vars.insert(var.clone(), value); + value + } + } + + fn builder(&mut self) -> &mut BodyBuilder { + &mut self.helper.builder + } + + fn resolve_occurrence(&mut self, occurrence: &Occurrence) -> ValueId { + if let Some(value) = self + .current_scope() + .resolve_occurrence(&self.scopes, occurrence) + { + return value; + } + + let parent = occurrence.parent().unwrap(); + let parent_value = self.resolve_occurrence(&parent); + let parent_value_ty = self.builder().value_ty(parent_value); + + let index = occurrence.last_index().unwrap(); + let index_value = self.helper.make_u256_imm(occurrence.last_index().unwrap()); + let inst = + self.builder() + .aggregate_access(parent_value, vec![index_value], SourceInfo::dummy()); + + let ty = parent_value_ty.projection_ty_imm(self.helper.db, index); + let value = self.helper.map_to_tmp(inst, ty); + self.current_scope_mut() + .register_occurrence(occurrence.clone(), value); + value + } + + fn current_scope(&self) -> &Scope { + self.scopes.get(self.current_scope).unwrap() + } + + fn current_scope_mut(&mut self) -> &mut Scope { + self.scopes.get_mut(self.current_scope).unwrap() + } +} + +type ScopeId = Id; + +#[derive(Debug, Default)] +struct Scope { + parent: Option, + occurrences: FxHashMap, +} + +impl Scope { + pub fn with_parent(parent: ScopeId) -> Self { + Self { + parent: Some(parent), + ..Default::default() + } + } + + pub fn register_occurrence(&mut self, occurrence: Occurrence, value: ValueId) { + self.occurrences.insert(occurrence, value); + } + + pub fn resolve_occurrence( + &self, + arena: &Arena, + occurrence: &Occurrence, + ) -> Option { + match self.occurrences.get(occurrence) { + Some(value) => Some(*value), + None => arena[self.parent?].resolve_occurrence(arena, occurrence), + } + } +} diff --git a/crates/mir2/src/lower/pattern_match/tree_vis.rs b/crates/mir2/src/lower/pattern_match/tree_vis.rs new file mode 100644 index 0000000000..30c94c3494 --- /dev/null +++ b/crates/mir2/src/lower/pattern_match/tree_vis.rs @@ -0,0 +1,150 @@ +use std::fmt::Write; + +use dot2::{label::Text, Id}; +use fxhash::FxHashMap; +use hir::HirDb; +use indexmap::IndexMap; +use smol_str::SmolStr; + +use super::decision_tree::{Case, DecisionTree, LeafNode, Occurrence, SwitchNode}; + +pub(super) struct TreeRenderer<'db> { + nodes: Vec, + edges: FxHashMap<(usize, usize), Case>, + db: &'db dyn HirDb, +} + +impl<'db> TreeRenderer<'db> { + #[allow(unused)] + pub(super) fn new(db: &'db dyn HirDb, tree: &DecisionTree) -> Self { + let mut renderer = Self { + nodes: Vec::new(), + edges: FxHashMap::default(), + db, + }; + + match tree { + DecisionTree::Leaf(leaf) => { + renderer.nodes.push(Node::from(leaf)); + } + DecisionTree::Switch(switch) => { + renderer.nodes.push(Node::from(switch)); + let node_id = renderer.nodes.len() - 1; + for arm in &switch.arms { + renderer.switch_from(&arm.1, node_id, arm.0); + } + } + } + renderer + } + + fn switch_from(&mut self, tree: &DecisionTree, node_id: usize, case: Case) { + match tree { + DecisionTree::Leaf(leaf) => { + self.nodes.push(Node::from(leaf)); + self.edges.insert((node_id, self.nodes.len() - 1), case); + } + + DecisionTree::Switch(switch) => { + self.nodes.push(Node::from(switch)); + let switch_id = self.nodes.len() - 1; + self.edges.insert((node_id, switch_id), case); + for arm in &switch.arms { + self.switch_from(&arm.1, switch_id, arm.0); + } + } + } + } +} + +impl<'db> dot2::Labeller<'db> for TreeRenderer<'db> { + type Node = usize; + type Edge = (Self::Node, Self::Node); + type Subgraph = (); + + fn graph_id(&self) -> dot2::Result> { + dot2::Id::new("DecisionTree") + } + + fn node_id(&self, n: &Self::Node) -> dot2::Result> { + dot2::Id::new(format!("N{}", *n)) + } + + fn node_label(&self, n: &Self::Node) -> dot2::Result> { + let node = &self.nodes[*n]; + let label = match node { + Node::Leaf { arm_idx, .. } => { + format!("arm_idx: {arm_idx}") + } + Node::Switch(occurrence) => { + let mut s = "expr".to_string(); + for num in occurrence.iter() { + write!(&mut s, ".{num}").unwrap(); + } + s + } + }; + + Ok(Text::LabelStr(label.into())) + } + + // fn edge_label(&self, e: &Self::Edge) -> Text<'db> { + // let label = match &self.edges[e] { + // Case::Ctor(ConstructorKind ::Enum(variant)) => { + // variant.name_with_parent(self.db).to_string() + // } + // Case::Ctor(ConstructorKind::Tuple(_)) => "()".to_string(), + // Case::Ctor(ConstructorKind::Struct(sid)) => sid.name(self.db).into(), + // Case::Ctor(ConstructorKind::Literal((lit, _))) => lit.to_string(), + // Case::Default => "_".into(), + // }; + + // Text::LabelStr(label.into()) + // } +} + +impl<'db> dot2::GraphWalk<'db> for TreeRenderer<'db> { + type Node = usize; + type Edge = (Self::Node, Self::Node); + type Subgraph = (); + + fn nodes(&self) -> dot2::Nodes<'db, Self::Node> { + (0..self.nodes.len()).collect() + } + + fn edges(&self) -> dot2::Edges<'db, Self::Edge> { + self.edges.keys().cloned().collect::>().into() + } + + fn source(&self, e: &Self::Edge) -> Self::Node { + e.0 + } + + fn target(&self, e: &Self::Edge) -> Self::Node { + e.1 + } +} + +enum Node { + Leaf { + arm_idx: usize, + #[allow(unused)] + binds: IndexMap<(SmolStr, usize), Occurrence>, + }, + Switch(Occurrence), +} + +impl From<&LeafNode> for Node { + fn from(node: &LeafNode) -> Self { + Node::Leaf { + arm_idx: node.arm_idx, + binds: node.binds.clone(), + } + } +} + +impl From<&SwitchNode> for Node { + fn from(node: &SwitchNode) -> Self { + Node::Switch(node.occurrence.clone()) + } +} diff --git a/crates/mir2/src/lower/types.rs b/crates/mir2/src/lower/types.rs new file mode 100644 index 0000000000..924841aa1c --- /dev/null +++ b/crates/mir2/src/lower/types.rs @@ -0,0 +1,194 @@ +// use crate::{ +// db::MirDb, +// ir::{ +// types::{ArrayDef, EnumDef, EnumVariant, MapDef, StructDef, TupleDef}, +// Type, TypeId, TypeKind, +// }, +// }; + +// use fe_analyzer::namespace::{ +// items as analyzer_items, +// types::{self as analyzer_types, TraitOrType}, +// }; + +// pub fn lower_type(db: &dyn MirDb, analyzer_ty: analyzer_types::TypeId) -> TypeId { +// let ty_kind = match analyzer_ty.typ(db.upcast()) { +// analyzer_types::Type::SPtr(inner) => TypeKind::SPtr(lower_type(db, inner)), + +// // NOTE: this results in unexpected MIR TypeId inequalities +// // (when different analyzer types map to the same MIR type). +// // We could (should?) remove .analyzer_ty from Type. +// analyzer_types::Type::Mut(inner) => match inner.typ(db.upcast()) { +// analyzer_types::Type::SPtr(t) => TypeKind::SPtr(lower_type(db, t)), +// analyzer_types::Type::Base(t) => lower_base(t), +// analyzer_types::Type::Contract(_) => TypeKind::Address, +// _ => TypeKind::MPtr(lower_type(db, inner)), +// }, +// analyzer_types::Type::SelfType(inner) => match inner { +// TraitOrType::TypeId(id) => return lower_type(db, id), +// TraitOrType::TraitId(_) => panic!("traits aren't lowered"), +// }, +// analyzer_types::Type::Base(base) => lower_base(base), +// analyzer_types::Type::Array(arr) => lower_array(db, &arr), +// analyzer_types::Type::Map(map) => lower_map(db, &map), +// analyzer_types::Type::Tuple(tup) => lower_tuple(db, &tup), +// analyzer_types::Type::String(string) => TypeKind::String(string.max_size), +// analyzer_types::Type::Contract(_) => TypeKind::Address, +// analyzer_types::Type::SelfContract(contract) => lower_contract(db, contract), +// analyzer_types::Type::Struct(struct_) => lower_struct(db, struct_), +// analyzer_types::Type::Enum(enum_) => lower_enum(db, enum_), +// analyzer_types::Type::Generic(_) => { +// panic!("should be lowered in `lower_analyzer_type`") +// } +// }; + +// intern_type(db, ty_kind, Some(analyzer_ty.deref(db.upcast()))) +// } + +// fn lower_base(base: analyzer_types::Base) -> TypeKind { +// use analyzer_types::{Base, Integer}; + +// match base { +// Base::Numeric(int_ty) => match int_ty { +// Integer::I8 => TypeKind::I8, +// Integer::I16 => TypeKind::I16, +// Integer::I32 => TypeKind::I32, +// Integer::I64 => TypeKind::I64, +// Integer::I128 => TypeKind::I128, +// Integer::I256 => TypeKind::I256, +// Integer::U8 => TypeKind::U8, +// Integer::U16 => TypeKind::U16, +// Integer::U32 => TypeKind::U32, +// Integer::U64 => TypeKind::U64, +// Integer::U128 => TypeKind::U128, +// Integer::U256 => TypeKind::U256, +// }, + +// Base::Bool => TypeKind::Bool, +// Base::Address => TypeKind::Address, +// Base::Unit => TypeKind::Unit, +// } +// } + +// fn lower_array(db: &dyn MirDb, arr: &analyzer_types::Array) -> TypeKind { +// let len = arr.size; +// let elem_ty = db.mir_lowered_type(arr.inner); + +// let def = ArrayDef { elem_ty, len }; +// TypeKind::Array(def) +// } + +// fn lower_map(db: &dyn MirDb, map: &analyzer_types::Map) -> TypeKind { +// let key_ty = db.mir_lowered_type(map.key); +// let value_ty = db.mir_lowered_type(map.value); + +// let def = MapDef { key_ty, value_ty }; +// TypeKind::Map(def) +// } + +// fn lower_tuple(db: &dyn MirDb, tup: &analyzer_types::Tuple) -> TypeKind { +// let items = tup +// .items +// .iter() +// .map(|item| db.mir_lowered_type(*item)) +// .collect(); + +// let def = TupleDef { items }; +// TypeKind::Tuple(def) +// } + +// fn lower_contract(db: &dyn MirDb, contract: analyzer_items::ContractId) -> TypeKind { +// let name = contract.name(db.upcast()); + +// // Note: contract field types are wrapped in SPtr in TypeId::projection_ty +// let fields = contract +// .fields(db.upcast()) +// .iter() +// .map(|(fname, fid)| { +// let analyzer_type = fid.typ(db.upcast()).unwrap(); +// let ty = db.mir_lowered_type(analyzer_type); +// (fname.clone(), ty) +// }) +// .collect(); + +// // Obtain span. +// let span = contract.span(db.upcast()); + +// let module_id = contract.module(db.upcast()); + +// let def = StructDef { +// name, +// fields, +// span, +// module_id, +// }; +// TypeKind::Contract(def) +// } + +// fn lower_struct(db: &dyn MirDb, id: analyzer_items::StructId) -> TypeKind { +// let name = id.name(db.upcast()); + +// // Lower struct fields. +// let fields = id +// .fields(db.upcast()) +// .iter() +// .map(|(fname, fid)| { +// let analyzer_types = fid.typ(db.upcast()).unwrap(); +// let ty = db.mir_lowered_type(analyzer_types); +// (fname.clone(), ty) +// }) +// .collect(); + +// // obtain span. +// let span = id.span(db.upcast()); + +// let module_id = id.module(db.upcast()); + +// let def = StructDef { +// name, +// fields, +// span, +// module_id, +// }; +// TypeKind::Struct(def) +// } + +// fn lower_enum(db: &dyn MirDb, id: analyzer_items::EnumId) -> TypeKind { +// let analyzer_variants = id.variants(db.upcast()); +// let mut variants = Vec::with_capacity(analyzer_variants.len()); +// for variant in analyzer_variants.values() { +// let variant_ty = match variant.kind(db.upcast()).unwrap() { +// analyzer_items::EnumVariantKind::Tuple(elts) => { +// let tuple_ty = analyzer_types::TypeId::tuple(db.upcast(), &elts); +// db.mir_lowered_type(tuple_ty) +// } +// analyzer_items::EnumVariantKind::Unit => { +// let unit_ty = analyzer_types::TypeId::unit(db.upcast()); +// db.mir_lowered_type(unit_ty) +// } +// }; + +// variants.push(EnumVariant { +// name: variant.name(db.upcast()), +// span: variant.span(db.upcast()), +// ty: variant_ty, +// }); +// } + +// let def = EnumDef { +// name: id.name(db.upcast()), +// span: id.span(db.upcast()), +// variants, +// module_id: id.module(db.upcast()), +// }; + +// TypeKind::Enum(def) +// } + +// fn intern_type( +// db: &dyn MirDb, +// ty_kind: TypeKind, +// analyzer_type: Option, +// ) -> TypeId { +// db.mir_intern_type(Type::new(ty_kind, analyzer_type).into()) +// } diff --git a/crates/mir2/src/pretty_print/inst.rs b/crates/mir2/src/pretty_print/inst.rs new file mode 100644 index 0000000000..f3612e7c50 --- /dev/null +++ b/crates/mir2/src/pretty_print/inst.rs @@ -0,0 +1,206 @@ +use std::fmt::{self, Write}; + +use crate::{ + ir::{function::BodyDataStore, inst::InstKind, InstId}, + MirDb, +}; + +use super::PrettyPrint; + +impl PrettyPrint for InstId { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + if let Some(result) = store.inst_result(*self) { + result.pretty_print(db, store, w)?; + write!(w, ": ")?; + + let result_ty = result.ty(db, store); + result_ty.pretty_print(db, store, w)?; + write!(w, " = ")?; + } + + match &store.inst_data(*self).kind { + InstKind::Declare { local } => { + write!(w, "let ")?; + local.pretty_print(db, store, w)?; + write!(w, ": ")?; + store.value_ty(*local).pretty_print(db, store, w) + } + + InstKind::Unary { op, value } => { + write!(w, "{op}")?; + value.pretty_print(db, store, w) + } + + InstKind::Binary { op, lhs, rhs } => { + lhs.pretty_print(db, store, w)?; + write!(w, " {op} ")?; + rhs.pretty_print(db, store, w) + } + + InstKind::Cast { value, to, .. } => { + value.pretty_print(db, store, w)?; + write!(w, " as ")?; + to.pretty_print(db, store, w) + } + + InstKind::AggregateConstruct { ty, args } => { + ty.pretty_print(db, store, w)?; + write!(w, "{{")?; + if args.is_empty() { + return write!(w, "}}"); + } + + let arg_len = args.len(); + for (arg_idx, arg) in args.iter().enumerate().take(arg_len - 1) { + write!(w, "<{arg_idx}>: ")?; + arg.pretty_print(db, store, w)?; + write!(w, ", ")?; + } + let arg = args[arg_len - 1]; + write!(w, "<{}>: ", arg_len - 1)?; + arg.pretty_print(db, store, w)?; + write!(w, "}}") + } + + InstKind::Bind { src } => { + write!(w, "bind ")?; + src.pretty_print(db, store, w) + } + + InstKind::MemCopy { src } => { + write!(w, "memcopy ")?; + src.pretty_print(db, store, w) + } + + InstKind::Load { src } => { + write!(w, "load ")?; + src.pretty_print(db, store, w) + } + + InstKind::AggregateAccess { value, indices } => { + value.pretty_print(db, store, w)?; + for index in indices { + write!(w, ".<")?; + index.pretty_print(db, store, w)?; + write!(w, ">")? + } + Ok(()) + } + + InstKind::MapAccess { value, key } => { + value.pretty_print(db, store, w)?; + write!(w, "{{")?; + key.pretty_print(db, store, w)?; + write!(w, "}}") + } + + InstKind::Call { + func, + args, + call_type, + } => { + let name = func.debug_name(db); + write!(w, "{name}@{call_type}(")?; + args.as_slice().pretty_print(db, store, w)?; + write!(w, ")") + } + + InstKind::Jump { dest } => { + write!(w, "jump BB{}", dest.index()) + } + + InstKind::Branch { cond, then, else_ } => { + write!(w, "branch ")?; + cond.pretty_print(db, store, w)?; + write!(w, " then: BB{} else: BB{}", then.index(), else_.index()) + } + + InstKind::Switch { + disc, + table, + default, + } => { + write!(w, "switch ")?; + disc.pretty_print(db, store, w)?; + for (value, block) in table.iter() { + write!(w, " ")?; + value.pretty_print(db, store, w)?; + write!(w, ": BB{}", block.index())?; + } + + if let Some(default) = default { + write!(w, " default: BB{}", default.index()) + } else { + Ok(()) + } + } + + InstKind::Revert { arg } => { + write!(w, "revert ")?; + if let Some(arg) = arg { + arg.pretty_print(db, store, w)?; + } + Ok(()) + } + + InstKind::Emit { arg } => { + write!(w, "emit ")?; + arg.pretty_print(db, store, w) + } + + InstKind::Return { arg } => { + if let Some(arg) = arg { + write!(w, "return ")?; + arg.pretty_print(db, store, w) + } else { + write!(w, "return") + } + } + + InstKind::Keccak256 { arg } => { + write!(w, "keccak256 ")?; + arg.pretty_print(db, store, w) + } + + InstKind::AbiEncode { arg } => { + write!(w, "abi_encode ")?; + arg.pretty_print(db, store, w) + } + + InstKind::Nop => { + write!(w, "nop") + } + + InstKind::Create { value, contract } => { + write!(w, "create ")?; + let contract_name = contract.name(db.upcast()); + write!(w, "{contract_name} ")?; + value.pretty_print(db, store, w) + } + + InstKind::Create2 { + value, + salt, + contract, + } => { + write!(w, "create2 ")?; + let contract_name = contract.name(db.upcast()); + write!(w, "{contract_name} ")?; + value.pretty_print(db, store, w)?; + write!(w, " ")?; + salt.pretty_print(db, store, w) + } + + InstKind::YulIntrinsic { op, args } => { + write!(w, "{op}(")?; + args.as_slice().pretty_print(db, store, w)?; + write!(w, ")") + } + } + } +} diff --git a/crates/mir2/src/pretty_print/mod.rs b/crates/mir2/src/pretty_print/mod.rs new file mode 100644 index 0000000000..853b8cfa24 --- /dev/null +++ b/crates/mir2/src/pretty_print/mod.rs @@ -0,0 +1,22 @@ +use std::fmt; + +use crate::{ir::function::BodyDataStore, MirDb}; + +mod inst; +mod types; +mod value; + +pub trait PrettyPrint { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result; + + fn pretty_string(&self, db: &dyn MirDb, store: &BodyDataStore) -> String { + let mut s = String::new(); + self.pretty_print(db, store, &mut s).unwrap(); + s + } +} diff --git a/crates/mir2/src/pretty_print/types.rs b/crates/mir2/src/pretty_print/types.rs new file mode 100644 index 0000000000..1d146c8292 --- /dev/null +++ b/crates/mir2/src/pretty_print/types.rs @@ -0,0 +1,18 @@ +use std::fmt::{self, Write}; + +use hir::hir_def::TypeId; + +use crate::{ir::function::BodyDataStore, MirDb}; + +use super::PrettyPrint; + +impl PrettyPrint for TypeId { + fn pretty_print( + &self, + db: &dyn MirDb, + _store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + self.print(db, w) + } +} diff --git a/crates/mir2/src/pretty_print/value.rs b/crates/mir2/src/pretty_print/value.rs new file mode 100644 index 0000000000..3a1650b72b --- /dev/null +++ b/crates/mir2/src/pretty_print/value.rs @@ -0,0 +1,81 @@ +use std::fmt::{self, Write}; + +use crate::{ + ir::{ + constant::ConstantValue, function::BodyDataStore, value::AssignableValue, Value, ValueId, + }, + MirDb, +}; + +use super::PrettyPrint; + +impl PrettyPrint for ValueId { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + match store.value_data(*self) { + Value::Temporary { .. } | Value::Local(_) => write!(w, "_{}", self.index()), + Value::Immediate { imm, .. } => write!(w, "{imm}"), + Value::Constant { constant, .. } => { + let const_value = constant.data(db); + write!(w, "const ")?; + match &const_value.value { + ConstantValue::Immediate(num) => write!(w, "{num}"), + ConstantValue::Str(s) => write!(w, r#""{s}""#), + ConstantValue::Bool(b) => write!(w, "{b}"), + } + } + Value::Unit { .. } => write!(w, "()"), + } + } +} + +impl PrettyPrint for &[ValueId] { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + if self.is_empty() { + return Ok(()); + } + + let arg_len = self.len(); + for arg in self.iter().take(arg_len - 1) { + arg.pretty_print(db, store, w)?; + write!(w, ", ")?; + } + let arg = self[arg_len - 1]; + arg.pretty_print(db, store, w) + } +} + +impl PrettyPrint for AssignableValue { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + match self { + Self::Value(value) => value.pretty_print(db, store, w), + Self::Aggregate { lhs, idx } => { + lhs.pretty_print(db, store, w)?; + write!(w, ".<")?; + idx.pretty_print(db, store, w)?; + write!(w, ">") + } + + Self::Map { lhs, key } => { + lhs.pretty_print(db, store, w)?; + write!(w, "{{")?; + key.pretty_print(db, store, w)?; + write!(w, "}}") + } + } + } +} diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs new file mode 100644 index 0000000000..be4354be3f --- /dev/null +++ b/crates/mir2/tests/lowering.rs @@ -0,0 +1,108 @@ +use common::InputDb; +use hir::hir_def::IngotId; +use test_db::{initialize_analysis_pass, LowerMirTestDb}; + +mod test_db; + +// macro_rules! test_lowering { +// ($name:ident, $path:expr) => { +// #[test] +// fn $name() { +// let mut db = NewDb::default(); + +// let file_name = Utf8Path::new($path).file_name().unwrap(); +// let module = ModuleId::new_standalone(&mut db, file_name, test_files::fixture($path)); + +// let diags = module.diagnostics(&db); +// if !diags.is_empty() { +// panic!("lowering failed") +// } + +// for func in db.mir_lower_module_all_functions(module).iter() { +// let body = func.body(&db); +// ControlFlowGraph::compute(&body); +// } +// } +// }; +// } + +#[test] +fn mir_lower_std_lib() { + let mut db = LowerMirTestDb::default(); + let top_mod = db.new_std_lib(); + + let mut pm = initialize_analysis_pass(&db); + let diags = pm.run_on_module(top_mod); + + if !diags.is_empty() { + panic!("std lib analysis failed") + } + + // for &module in std_ingot.all_modules(&db).iter() { + // for func in db.mir_lower_module_all_functions(module).iter() { + // let body = func.body(&db); + // let cfg = ControlFlowGraph::compute(&body); + // let domtree = DomTree::compute(&cfg); + // LoopTree::compute(&cfg, &domtree); + // PostDomTree::compute(&body); + // } + // } +} + +// test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} +// test_lowering! { mir_guest_book, "demos/guest_book.fe"} +// test_lowering! { mir_uniswap, "demos/uniswap.fe"} +// test_lowering! { mir_assert, "features/assert.fe"} +// test_lowering! { mir_aug_assign, "features/aug_assign.fe"} +// test_lowering! { mir_call_statement_with_args, "features/call_statement_with_args.fe"} +// test_lowering! { mir_call_statement_with_args_2, "features/call_statement_with_args_2.fe"} +// test_lowering! { mir_call_statement_without_args, "features/call_statement_without_args.fe"} +// test_lowering! { mir_checked_arithmetic, "features/checked_arithmetic.fe"} +// test_lowering! { mir_constructor, "features/constructor.fe"} +// test_lowering! { mir_create2_contract, "features/create2_contract.fe"} +// test_lowering! { mir_create_contract, "features/create_contract.fe"} +// test_lowering! { mir_create_contract_from_init, "features/create_contract_from_init.fe"} +// test_lowering! { mir_empty, "features/empty.fe"} +// test_lowering! { mir_events, "features/events.fe"} +// test_lowering! { mir_module_level_events, "features/module_level_events.fe"} +// test_lowering! { mir_external_contract, "features/external_contract.fe"} +// test_lowering! { mir_for_loop_with_break, "features/for_loop_with_break.fe"} +// test_lowering! { mir_for_loop_with_continue, "features/for_loop_with_continue.fe"} +// test_lowering! { mir_for_loop_with_static_array, "features/for_loop_with_static_array.fe"} +// test_lowering! { mir_if_statement, "features/if_statement.fe"} +// test_lowering! { mir_if_statement_2, "features/if_statement_2.fe"} +// test_lowering! { mir_if_statement_with_block_declaration, "features/if_statement_with_block_declaration.fe"} +// test_lowering! { mir_keccak, "features/keccak.fe"} +// test_lowering! { mir_math, "features/math.fe"} +// test_lowering! { mir_module_const, "features/module_const.fe"} +// test_lowering! { mir_multi_param, "features/multi_param.fe"} +// test_lowering! { mir_nested_map, "features/nested_map.fe"} +// test_lowering! { mir_numeric_sizes, "features/numeric_sizes.fe"} +// test_lowering! { mir_ownable, "features/ownable.fe"} +// test_lowering! { mir_pure_fn_standalone, "features/pure_fn_standalone.fe"} +// test_lowering! { mir_revert, "features/revert.fe"} +// test_lowering! { mir_self_address, "features/self_address.fe"} +// test_lowering! { mir_send_value, "features/send_value.fe"} +// test_lowering! { mir_balances, "features/balances.fe"} +// test_lowering! { mir_sized_vals_in_sto, "features/sized_vals_in_sto.fe"} +// test_lowering! { mir_strings, "features/strings.fe"} +// test_lowering! { mir_structs, "features/structs.fe"} +// test_lowering! { mir_struct_fns, "features/struct_fns.fe"} +// test_lowering! { mir_ternary_expression, "features/ternary_expression.fe"} +// test_lowering! { mir_two_contracts, "features/two_contracts.fe"} +// test_lowering! { mir_u8_u8_map, "features/u8_u8_map.fe"} +// test_lowering! { mir_u16_u16_map, "features/u16_u16_map.fe"} +// test_lowering! { mir_u32_u32_map, "features/u32_u32_map.fe"} +// test_lowering! { mir_u64_u64_map, "features/u64_u64_map.fe"} +// test_lowering! { mir_u128_u128_map, "features/u128_u128_map.fe"} +// test_lowering! { mir_u256_u256_map, "features/u256_u256_map.fe"} +// test_lowering! { mir_while_loop, "features/while_loop.fe"} +// test_lowering! { mir_while_loop_with_break, "features/while_loop_with_break.fe"} +// test_lowering! { mir_while_loop_with_break_2, "features/while_loop_with_break_2.fe"} +// test_lowering! { mir_while_loop_with_continue, "features/while_loop_with_continue.fe"} +// test_lowering! { mir_abi_encoding_stress, "stress/abi_encoding_stress.fe"} +// test_lowering! { mir_data_copying_stress, "stress/data_copying_stress.fe"} +// test_lowering! { mir_tuple_stress, "stress/tuple_stress.fe"} +// test_lowering! { mir_type_aliases, "features/type_aliases.fe"} +// test_lowering! { mir_const_generics, "features/const_generics.fe" } +// test_lowering! { mir_const_local, "features/const_local.fe" } diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs new file mode 100644 index 0000000000..724c0a1d2b --- /dev/null +++ b/crates/mir2/tests/test_db.rs @@ -0,0 +1,174 @@ +use std::collections::{BTreeMap, BTreeSet}; + +// use codespan_reporting::{ +// diagnostic::{Diagnostic, Label}, +// files::SimpleFiles, +// term::{ +// self, +// termcolor::{BufferWriter, ColorChoice}, +// }, +// }; +use common::{ + diagnostics::Span, + input::{IngotKind, Version}, + InputFile, InputIngot, +}; +use hir::{analysis_pass::AnalysisPassManager, hir_def::TopLevelMod, ParsingPass}; +use hir_analysis::{ + name_resolution::{DefConflictAnalysisPass, ImportAnalysisPass, PathAnalysisPass}, + ty::{ + FuncAnalysisPass, ImplAnalysisPass, ImplTraitAnalysisPass, TraitAnalysisPass, + TypeAliasAnalysisPass, TypeDefAnalysisPass, + }, +}; +// use hir::{ +// hir_def::TopLevelMod, +// lower, +// span::{DynLazySpan, LazySpan}, +// HirDb, SpannedHirDb, +// }; +// use rustc_hash::FxHashMap; + +type CodeSpanFileId = usize; + +#[salsa::db( + common::Jar, + hir::Jar, + hir::SpannedJar, + hir::LowerJar, + hir_analysis::Jar +)] +pub struct LowerMirTestDb { + storage: salsa::Storage, +} + +impl LowerMirTestDb { + pub fn new_stand_alone(&mut self, file_name: &str, text: &str) { + let kind = IngotKind::StandAlone; + let version = Version::new(0, 0, 1); + let ingot = InputIngot::new(self, file_name, kind, version, BTreeSet::default()); + let root = InputFile::new(self, ingot, "test_file.fe".into(), text.to_string()); + ingot.set_root_file(self, root); + ingot.set_files(self, [root].into()); + + // let top_mod = lower::map_file_to_mod(self, input_file); + + // let mut prop_formatter = HirPropertyFormatter::default(); + // let top_mod = self.register_file(&mut prop_formatter, root); + // let top_mod = self.register_file(root); + // top_mod + } + + pub fn new_std_lib(&mut self) -> TopLevelMod { + let input = library::std_lib_input_ingot(self); + panic!(""); + // lower::map_file_to_mod(self, input_file) + } + + fn register_file(&self, input_file: InputFile) { + // let top_mod = lower::map_file_to_mod(self, input_file); + // let path = input_file.path(self); + // let text = input_file.text(self); + // prop_formatter.register_top_mod(path.as_str(), text, top_mod); + // top_mod + } +} + +impl Default for LowerMirTestDb { + fn default() -> Self { + let db = Self { + storage: Default::default(), + }; + // db.prefill(); + db + } +} + +// pub struct HirPropertyFormatter { +// properties: BTreeMap>, +// top_mod_to_file: FxHashMap, +// code_span_files: SimpleFiles, +// } + +// impl HirPropertyFormatter { +// pub fn push_prop(&mut self, top_mod: TopLevelMod, span: DynLazySpan, prop: String) { +// self.properties +// .entry(top_mod) +// .or_default() +// .push((prop, span)); +// } + +// pub fn finish(&mut self, db: &dyn SpannedHirDb) -> String { +// let writer = BufferWriter::stderr(ColorChoice::Never); +// let mut buffer = writer.buffer(); +// let config = term::Config::default(); + +// for top_mod in self.top_mod_to_file.keys() { +// if !self.properties.contains_key(top_mod) { +// continue; +// } + +// let diags = self.properties[top_mod] +// .iter() +// .map(|(prop, span)| { +// let (span, diag) = self.property_to_diag(db, *top_mod, prop, span.clone()); +// ((span.file, span.range.start()), diag) +// }) +// .collect::>(); + +// for diag in diags.values() { +// term::emit(&mut buffer, &config, &self.code_span_files, diag).unwrap(); +// } +// } + +// std::str::from_utf8(buffer.as_slice()).unwrap().to_string() +// } + +// fn property_to_diag( +// &self, +// db: &dyn SpannedHirDb, +// top_mod: TopLevelMod, +// prop: &str, +// span: DynLazySpan, +// ) -> (Span, Diagnostic) { +// let file_id = self.top_mod_to_file[&top_mod]; +// let span = span.resolve(db).unwrap(); +// let diag = Diagnostic::note() +// .with_labels(vec![Label::primary(file_id, span.range).with_message(prop)]); +// (span, diag) +// } + +// fn register_top_mod(&mut self, path: &str, text: &str, top_mod: TopLevelMod) { +// let file_id = self.code_span_files.add(path.to_string(), text.to_string()); +// self.top_mod_to_file.insert(top_mod, file_id); +// } +// } + +// impl Default for HirPropertyFormatter { +// fn default() -> Self { +// Self { +// properties: Default::default(), +// top_mod_to_file: Default::default(), +// code_span_files: SimpleFiles::new(), +// } +// } +// } + +impl salsa::Database for LowerMirTestDb { + fn salsa_event(&self, _: salsa::Event) {} +} + +pub fn initialize_analysis_pass(db: &LowerMirTestDb) -> AnalysisPassManager<'_> { + let mut pass_manager = AnalysisPassManager::new(); + pass_manager.add_module_pass(Box::new(ParsingPass::new(db))); + pass_manager.add_module_pass(Box::new(DefConflictAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(ImportAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(PathAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(TypeDefAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(TypeAliasAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(TraitAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(ImplAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(ImplTraitAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(FuncAnalysisPass::new(db))); + pass_manager +}