From 9d1932bbe0cac289bb9413fa2f39126d3175d623 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Sat, 23 Dec 2023 18:53:53 -0700 Subject: [PATCH 01/22] hacking --- Cargo.lock | 19 + crates/mir2/Cargo.toml | 24 + crates/mir2/src/analysis/cfg.rs | 164 ++ crates/mir2/src/analysis/domtree.rs | 343 +++++ crates/mir2/src/analysis/loop_tree.rs | 347 +++++ crates/mir2/src/analysis/mod.rs | 9 + crates/mir2/src/analysis/post_domtree.rs | 284 ++++ crates/mir2/src/db.rs | 104 ++ crates/mir2/src/db/queries.rs | 7 + crates/mir2/src/db/queries/constant.rs | 43 + crates/mir2/src/db/queries/contract.rs | 17 + crates/mir2/src/db/queries/enums.rs | 17 + crates/mir2/src/db/queries/function.rs | 130 ++ crates/mir2/src/db/queries/module.rs | 35 + crates/mir2/src/db/queries/structs.rs | 17 + crates/mir2/src/db/queries/types.rs | 657 ++++++++ crates/mir2/src/graphviz/block.rs | 62 + crates/mir2/src/graphviz/function.rs | 78 + crates/mir2/src/graphviz/mod.rs | 22 + crates/mir2/src/graphviz/module.rs | 158 ++ crates/mir2/src/ir/basic_block.rs | 6 + crates/mir2/src/ir/body_builder.rs | 381 +++++ crates/mir2/src/ir/body_cursor.rs | 231 +++ crates/mir2/src/ir/body_order.rs | 473 ++++++ crates/mir2/src/ir/constant.rs | 47 + crates/mir2/src/ir/function.rs | 274 ++++ crates/mir2/src/ir/inst.rs | 764 +++++++++ crates/mir2/src/ir/mod.rs | 49 + crates/mir2/src/ir/types.rs | 119 ++ crates/mir2/src/ir/value.rs | 142 ++ crates/mir2/src/lib.rs | 7 + crates/mir2/src/lower/function.rs | 1367 +++++++++++++++++ crates/mir2/src/lower/mod.rs | 4 + .../src/lower/pattern_match/decision_tree.rs | 576 +++++++ crates/mir2/src/lower/pattern_match/mod.rs | 326 ++++ .../mir2/src/lower/pattern_match/tree_vis.rs | 150 ++ crates/mir2/src/lower/types.rs | 194 +++ crates/mir2/src/pretty_print/inst.rs | 206 +++ crates/mir2/src/pretty_print/mod.rs | 22 + crates/mir2/src/pretty_print/types.rs | 19 + crates/mir2/src/pretty_print/value.rs | 81 + crates/mir2/tests/lowering.rs | 109 ++ 42 files changed, 8084 insertions(+) create mode 100644 crates/mir2/Cargo.toml create mode 100644 crates/mir2/src/analysis/cfg.rs create mode 100644 crates/mir2/src/analysis/domtree.rs create mode 100644 crates/mir2/src/analysis/loop_tree.rs create mode 100644 crates/mir2/src/analysis/mod.rs create mode 100644 crates/mir2/src/analysis/post_domtree.rs create mode 100644 crates/mir2/src/db.rs create mode 100644 crates/mir2/src/db/queries.rs create mode 100644 crates/mir2/src/db/queries/constant.rs create mode 100644 crates/mir2/src/db/queries/contract.rs create mode 100644 crates/mir2/src/db/queries/enums.rs create mode 100644 crates/mir2/src/db/queries/function.rs create mode 100644 crates/mir2/src/db/queries/module.rs create mode 100644 crates/mir2/src/db/queries/structs.rs create mode 100644 crates/mir2/src/db/queries/types.rs create mode 100644 crates/mir2/src/graphviz/block.rs create mode 100644 crates/mir2/src/graphviz/function.rs create mode 100644 crates/mir2/src/graphviz/mod.rs create mode 100644 crates/mir2/src/graphviz/module.rs create mode 100644 crates/mir2/src/ir/basic_block.rs create mode 100644 crates/mir2/src/ir/body_builder.rs create mode 100644 crates/mir2/src/ir/body_cursor.rs create mode 100644 crates/mir2/src/ir/body_order.rs create mode 100644 crates/mir2/src/ir/constant.rs create mode 100644 crates/mir2/src/ir/function.rs create mode 100644 crates/mir2/src/ir/inst.rs create mode 100644 crates/mir2/src/ir/mod.rs create mode 100644 crates/mir2/src/ir/types.rs create mode 100644 crates/mir2/src/ir/value.rs create mode 100644 crates/mir2/src/lib.rs create mode 100644 crates/mir2/src/lower/function.rs create mode 100644 crates/mir2/src/lower/mod.rs create mode 100644 crates/mir2/src/lower/pattern_match/decision_tree.rs create mode 100644 crates/mir2/src/lower/pattern_match/mod.rs create mode 100644 crates/mir2/src/lower/pattern_match/tree_vis.rs create mode 100644 crates/mir2/src/lower/types.rs create mode 100644 crates/mir2/src/pretty_print/inst.rs create mode 100644 crates/mir2/src/pretty_print/mod.rs create mode 100644 crates/mir2/src/pretty_print/types.rs create mode 100644 crates/mir2/src/pretty_print/value.rs create mode 100644 crates/mir2/tests/lowering.rs diff --git a/Cargo.lock b/Cargo.lock index fc11dc0e6e..559d3b5735 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1130,6 +1130,25 @@ dependencies = [ "smol_str", ] +[[package]] +name = "fe-mir2" +version = "0.23.0" +dependencies = [ + "dot2", + "fe-common2", + "fe-library", + "fe-parser2", + "fe-test-files", + "fxhash", + "id-arena", + "indexmap", + "num-bigint", + "num-integer", + "num-traits", + "salsa", + "smol_str", +] + [[package]] name = "fe-parser" version = "0.23.0" diff --git a/crates/mir2/Cargo.toml b/crates/mir2/Cargo.toml new file mode 100644 index 0000000000..6ca3c5445d --- /dev/null +++ b/crates/mir2/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "fe-mir2" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/ethereum/fe" + +[dependencies] +fe-common2 = { path = "../common2", version = "^0.23.0" } +fe-parser2 = { path = "../parser2", version = "^0.23.0" } +salsa = "0.16.1" +smol_str = "0.1.21" +num-bigint = "0.4.3" +num-traits = "0.2.14" +num-integer = "0.1.45" +id-arena = "2.2.1" +fxhash = "0.2.1" +dot2 = "1.0.0" +indexmap = "1.6.2" + +[dev-dependencies] +test-files = { path = "../test-files", package = "fe-test-files" } +fe-library = { path = "../library" } diff --git a/crates/mir2/src/analysis/cfg.rs b/crates/mir2/src/analysis/cfg.rs new file mode 100644 index 0000000000..d4de8b2cfa --- /dev/null +++ b/crates/mir2/src/analysis/cfg.rs @@ -0,0 +1,164 @@ +use fxhash::FxHashMap; + +use crate::ir::{BasicBlockId, FunctionBody, InstId}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ControlFlowGraph { + entry: BasicBlockId, + blocks: FxHashMap, + pub(super) exits: Vec, +} + +impl ControlFlowGraph { + pub fn compute(func: &FunctionBody) -> Self { + let entry = func.order.entry(); + let mut cfg = Self { + entry, + blocks: FxHashMap::default(), + exits: vec![], + }; + + for block in func.order.iter_block() { + let terminator = func + .order + .terminator(&func.store, block) + .expect("a block must have terminator"); + cfg.analyze_terminator(func, terminator); + } + + cfg + } + + pub fn entry(&self) -> BasicBlockId { + self.entry + } + + pub fn preds(&self, block: BasicBlockId) -> &[BasicBlockId] { + self.blocks[&block].preds() + } + + pub fn succs(&self, block: BasicBlockId) -> &[BasicBlockId] { + self.blocks[&block].succs() + } + + pub fn post_order(&self) -> CfgPostOrder { + CfgPostOrder::new(self) + } + + pub(super) fn add_edge(&mut self, from: BasicBlockId, to: BasicBlockId) { + self.node_mut(to).push_pred(from); + self.node_mut(from).push_succ(to); + } + + pub(super) fn reverse_edge(&mut self, new_entry: BasicBlockId, new_exits: Vec) { + for (_, block) in self.blocks.iter_mut() { + block.reverse_edge() + } + + self.entry = new_entry; + self.exits = new_exits; + } + + fn analyze_terminator(&mut self, func: &FunctionBody, terminator: InstId) { + let block = func.order.inst_block(terminator); + let branch_info = func.store.branch_info(terminator); + if branch_info.is_not_a_branch() { + self.node_mut(block); + self.exits.push(block) + } else { + for dest in branch_info.block_iter() { + self.add_edge(block, dest) + } + } + } + + fn node_mut(&mut self, block: BasicBlockId) -> &mut BlockNode { + self.blocks.entry(block).or_default() + } +} + +#[derive(Default, Clone, Debug, PartialEq, Eq)] +struct BlockNode { + preds: Vec, + succs: Vec, +} + +impl BlockNode { + fn push_pred(&mut self, pred: BasicBlockId) { + self.preds.push(pred); + } + + fn push_succ(&mut self, succ: BasicBlockId) { + self.succs.push(succ); + } + + fn preds(&self) -> &[BasicBlockId] { + &self.preds + } + + fn succs(&self) -> &[BasicBlockId] { + &self.succs + } + + fn reverse_edge(&mut self) { + std::mem::swap(&mut self.preds, &mut self.succs) + } +} + +pub struct CfgPostOrder<'a> { + cfg: &'a ControlFlowGraph, + node_state: FxHashMap, + stack: Vec, +} + +impl<'a> CfgPostOrder<'a> { + fn new(cfg: &'a ControlFlowGraph) -> Self { + let stack = vec![cfg.entry()]; + + Self { + cfg, + node_state: FxHashMap::default(), + stack, + } + } +} + +impl<'a> Iterator for CfgPostOrder<'a> { + type Item = BasicBlockId; + + fn next(&mut self) -> Option { + while let Some(&block) = self.stack.last() { + let node_state = self.node_state.entry(block).or_default(); + if *node_state == NodeState::Unvisited { + *node_state = NodeState::Visited; + for &succ in self.cfg.succs(block) { + let pred_state = self.node_state.entry(succ).or_default(); + if *pred_state == NodeState::Unvisited { + self.stack.push(succ); + } + } + } else { + self.stack.pop().unwrap(); + if *node_state != NodeState::Finished { + *node_state = NodeState::Finished; + return Some(block); + } + } + } + + None + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum NodeState { + Unvisited, + Visited, + Finished, +} + +impl Default for NodeState { + fn default() -> Self { + Self::Unvisited + } +} diff --git a/crates/mir2/src/analysis/domtree.rs b/crates/mir2/src/analysis/domtree.rs new file mode 100644 index 0000000000..9775db6335 --- /dev/null +++ b/crates/mir2/src/analysis/domtree.rs @@ -0,0 +1,343 @@ +//! This module contains dominantor tree related structs. +//! +//! The algorithm is based on Keith D. Cooper., Timothy J. Harvey., and Ken +//! Kennedy.: A Simple, Fast Dominance Algorithm: + +use std::collections::BTreeSet; + +use fxhash::FxHashMap; + +use crate::ir::BasicBlockId; + +use super::cfg::ControlFlowGraph; + +#[derive(Debug, Clone)] +pub struct DomTree { + doms: FxHashMap, + /// CFG sorted in reverse post order. + rpo: Vec, +} + +impl DomTree { + pub fn compute(cfg: &ControlFlowGraph) -> Self { + let mut doms = FxHashMap::default(); + doms.insert(cfg.entry(), cfg.entry()); + let mut rpo: Vec<_> = cfg.post_order().collect(); + rpo.reverse(); + + let mut domtree = Self { doms, rpo }; + + let block_num = domtree.rpo.len(); + + let mut rpo_nums = FxHashMap::default(); + for (i, &block) in domtree.rpo.iter().enumerate() { + rpo_nums.insert(block, (block_num - i) as u32); + } + + let mut changed = true; + while changed { + changed = false; + for &block in domtree.rpo.iter().skip(1) { + let processed_pred = match cfg + .preds(block) + .iter() + .find(|pred| domtree.doms.contains_key(pred)) + { + Some(pred) => *pred, + _ => continue, + }; + let mut new_dom = processed_pred; + + for &pred in cfg.preds(block) { + if pred != processed_pred && domtree.doms.contains_key(&pred) { + new_dom = domtree.intersect(new_dom, pred, &rpo_nums); + } + } + if Some(new_dom) != domtree.doms.get(&block).copied() { + changed = true; + domtree.doms.insert(block, new_dom); + } + } + } + + domtree + } + + /// Returns the immediate dominator of the `block`. + /// Returns None if the `block` is unreachable from the entry block, or the + /// `block` is the entry block itself. + pub fn idom(&self, block: BasicBlockId) -> Option { + if self.rpo[0] == block { + return None; + } + self.doms.get(&block).copied() + } + + /// Returns `true` if block1 strictly dominates block2. + pub fn strictly_dominates(&self, block1: BasicBlockId, block2: BasicBlockId) -> bool { + let mut current_block = block2; + while let Some(block) = self.idom(current_block) { + if block == block1 { + return true; + } + current_block = block; + } + + false + } + + /// Returns `true` if block1 dominates block2. + pub fn dominates(&self, block1: BasicBlockId, block2: BasicBlockId) -> bool { + if block1 == block2 { + return true; + } + + self.strictly_dominates(block1, block2) + } + + /// Returns `true` if block is reachable from the entry block. + pub fn is_reachable(&self, block: BasicBlockId) -> bool { + self.idom(block).is_some() + } + + /// Returns blocks in RPO. + pub fn rpo(&self) -> &[BasicBlockId] { + &self.rpo + } + + fn intersect( + &self, + mut b1: BasicBlockId, + mut b2: BasicBlockId, + rpo_nums: &FxHashMap, + ) -> BasicBlockId { + while b1 != b2 { + while rpo_nums[&b1] < rpo_nums[&b2] { + b1 = self.doms[&b1]; + } + while rpo_nums[&b2] < rpo_nums[&b1] { + b2 = self.doms[&b2] + } + } + + b1 + } + + /// Compute dominance frontiers of each blocks. + pub fn compute_df(&self, cfg: &ControlFlowGraph) -> DFSet { + let mut df = DFSet::default(); + + for &block in &self.rpo { + let preds = cfg.preds(block); + if preds.len() < 2 { + continue; + } + + for pred in preds { + let mut runner = *pred; + while self.doms.get(&block) != Some(&runner) && self.is_reachable(runner) { + df.0.entry(runner).or_default().insert(block); + runner = self.doms[&runner]; + } + } + } + + df + } +} + +/// Dominance frontiers of each blocks. +#[derive(Default, Debug)] +pub struct DFSet(FxHashMap>); + +impl DFSet { + /// Returns all dominance frontieres of a `block`. + pub fn frontiers( + &self, + block: BasicBlockId, + ) -> Option + '_> { + self.0.get(&block).map(|set| set.iter().copied()) + } + + /// Returns number of frontier blocks of a `block`. + pub fn frontier_num(&self, block: BasicBlockId) -> usize { + self.0.get(&block).map(BTreeSet::len).unwrap_or(0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; + + fn calc_dom(func: &FunctionBody) -> (DomTree, DFSet) { + let cfg = ControlFlowGraph::compute(func); + let domtree = DomTree::compute(&cfg); + let df = domtree.compute_df(&cfg); + (domtree, df) + } + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + #[test] + fn dom_tree_if_else() { + let mut builder = body_builder(); + + let then_block = builder.make_block(); + let else_block = builder.make_block(); + let merge_block = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + + builder.move_to_block(then_block); + builder.jump(merge_block, SourceInfo::dummy()); + + builder.move_to_block(else_block); + builder.jump(merge_block, SourceInfo::dummy()); + + builder.move_to_block(merge_block); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let (dom_tree, df) = calc_dom(&func); + let entry_block = func.order.entry(); + assert_eq!(dom_tree.idom(entry_block), None); + assert_eq!(dom_tree.idom(then_block), Some(entry_block)); + assert_eq!(dom_tree.idom(else_block), Some(entry_block)); + assert_eq!(dom_tree.idom(merge_block), Some(entry_block)); + + assert_eq!(df.frontier_num(entry_block), 0); + assert_eq!(df.frontier_num(then_block), 1); + assert_eq!( + df.frontiers(then_block).unwrap().next().unwrap(), + merge_block + ); + assert_eq!( + df.frontiers(else_block).unwrap().next().unwrap(), + merge_block + ); + assert_eq!(df.frontier_num(merge_block), 0); + } + + #[test] + fn unreachable_edge() { + let mut builder = body_builder(); + + let block1 = builder.make_block(); + let block2 = builder.make_block(); + let block3 = builder.make_block(); + let block4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + builder.branch(v0, block1, block2, SourceInfo::dummy()); + + builder.move_to_block(block1); + builder.jump(block4, SourceInfo::dummy()); + + builder.move_to_block(block2); + builder.jump(block4, SourceInfo::dummy()); + + builder.move_to_block(block3); + builder.jump(block4, SourceInfo::dummy()); + + builder.move_to_block(block4); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let (dom_tree, _) = calc_dom(&func); + let entry_block = func.order.entry(); + assert_eq!(dom_tree.idom(entry_block), None); + assert_eq!(dom_tree.idom(block1), Some(entry_block)); + assert_eq!(dom_tree.idom(block2), Some(entry_block)); + assert_eq!(dom_tree.idom(block3), None); + assert!(!dom_tree.is_reachable(block3)); + assert_eq!(dom_tree.idom(block4), Some(entry_block)); + } + + #[test] + fn dom_tree_complex() { + let mut builder = body_builder(); + + let block1 = builder.make_block(); + let block2 = builder.make_block(); + let block3 = builder.make_block(); + let block4 = builder.make_block(); + let block5 = builder.make_block(); + let block6 = builder.make_block(); + let block7 = builder.make_block(); + let block8 = builder.make_block(); + let block9 = builder.make_block(); + let block10 = builder.make_block(); + let block11 = builder.make_block(); + let block12 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + builder.branch(v0, block2, block1, SourceInfo::dummy()); + + builder.move_to_block(block1); + builder.branch(v0, block6, block3, SourceInfo::dummy()); + + builder.move_to_block(block2); + builder.branch(v0, block7, block4, SourceInfo::dummy()); + + builder.move_to_block(block3); + builder.branch(v0, block6, block5, SourceInfo::dummy()); + + builder.move_to_block(block4); + builder.branch(v0, block7, block2, SourceInfo::dummy()); + + builder.move_to_block(block5); + builder.branch(v0, block10, block8, SourceInfo::dummy()); + + builder.move_to_block(block6); + builder.jump(block9, SourceInfo::dummy()); + + builder.move_to_block(block7); + builder.jump(block12, SourceInfo::dummy()); + + builder.move_to_block(block8); + builder.jump(block11, SourceInfo::dummy()); + + builder.move_to_block(block9); + builder.jump(block8, SourceInfo::dummy()); + + builder.move_to_block(block10); + builder.jump(block11, SourceInfo::dummy()); + + builder.move_to_block(block11); + builder.branch(v0, block12, block2, SourceInfo::dummy()); + + builder.move_to_block(block12); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let (dom_tree, _) = calc_dom(&func); + let entry_block = func.order.entry(); + assert_eq!(dom_tree.idom(entry_block), None); + assert_eq!(dom_tree.idom(block1), Some(entry_block)); + assert_eq!(dom_tree.idom(block2), Some(entry_block)); + assert_eq!(dom_tree.idom(block3), Some(block1)); + assert_eq!(dom_tree.idom(block4), Some(block2)); + assert_eq!(dom_tree.idom(block5), Some(block3)); + assert_eq!(dom_tree.idom(block6), Some(block1)); + assert_eq!(dom_tree.idom(block7), Some(block2)); + assert_eq!(dom_tree.idom(block8), Some(block1)); + assert_eq!(dom_tree.idom(block9), Some(block6)); + assert_eq!(dom_tree.idom(block10), Some(block5)); + assert_eq!(dom_tree.idom(block11), Some(block1)); + assert_eq!(dom_tree.idom(block12), Some(entry_block)); + } +} diff --git a/crates/mir2/src/analysis/loop_tree.rs b/crates/mir2/src/analysis/loop_tree.rs new file mode 100644 index 0000000000..ca13db5dcb --- /dev/null +++ b/crates/mir2/src/analysis/loop_tree.rs @@ -0,0 +1,347 @@ +use id_arena::{Arena, Id}; + +use fxhash::FxHashMap; + +use super::{cfg::ControlFlowGraph, domtree::DomTree}; + +use crate::ir::BasicBlockId; + +#[derive(Debug, Default, Clone)] +pub struct LoopTree { + /// Stores loops. + /// The index of an outer loops is guaranteed to be lower than its inner + /// loops because loops are found in RPO. + loops: Arena, + + /// Maps blocks to its contained loop. + /// If the block is contained by multiple nested loops, then the block is + /// mapped to the innermost loop. + block_to_loop: FxHashMap, +} + +pub type LoopId = Id; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Loop { + /// A header of the loop. + pub header: BasicBlockId, + + /// A parent loop that includes the loop. + pub parent: Option, + + /// Child loops that the loop includes. + pub children: Vec, +} + +impl LoopTree { + pub fn compute(cfg: &ControlFlowGraph, domtree: &DomTree) -> Self { + let mut tree = LoopTree::default(); + + // Find loop headers in RPO, this means outer loops are guaranteed to be + // inserted first, then its inner loops are inserted. + for &block in domtree.rpo() { + for &pred in cfg.preds(block) { + if domtree.dominates(block, pred) { + let loop_data = Loop { + header: block, + parent: None, + children: Vec::new(), + }; + + tree.loops.alloc(loop_data); + break; + } + } + } + + tree.analyze_loops(cfg, domtree); + + tree + } + + /// Returns all blocks in the loop. + pub fn iter_blocks_post_order<'a, 'b>( + &'a self, + cfg: &'b ControlFlowGraph, + lp: LoopId, + ) -> BlocksInLoopPostOrder<'a, 'b> { + BlocksInLoopPostOrder::new(self, cfg, lp) + } + + /// Returns all loops in a function body. + /// An outer loop is guaranteed to be iterated before its inner loops. + pub fn loops(&self) -> impl Iterator + '_ { + self.loops.iter().map(|(id, _)| id) + } + + /// Returns number of loops found. + pub fn loop_num(&self) -> usize { + self.loops.len() + } + + /// Returns `true` if the `block` is in the `lp`. + pub fn is_block_in_loop(&self, block: BasicBlockId, lp: LoopId) -> bool { + let mut loop_of_block = self.loop_of_block(block); + while let Some(cur_lp) = loop_of_block { + if lp == cur_lp { + return true; + } + loop_of_block = self.parent_loop(cur_lp); + } + false + } + + /// Returns header block of the `lp`. + pub fn loop_header(&self, lp: LoopId) -> BasicBlockId { + self.loops[lp].header + } + + /// Get parent loop of the `lp` if exists. + pub fn parent_loop(&self, lp: LoopId) -> Option { + self.loops[lp].parent + } + + /// Returns the loop that the `block` belongs to. + /// If the `block` belongs to multiple loops, then returns the innermost + /// loop. + pub fn loop_of_block(&self, block: BasicBlockId) -> Option { + self.block_to_loop.get(&block).copied() + } + + /// Analyze loops. This method does + /// 1. Mapping each blocks to its contained loop. + /// 2. Setting parent and child of the loops. + fn analyze_loops(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) { + let mut worklist = vec![]; + + // Iterate loops reversely to ensure analyze inner loops first. + let loops_rev: Vec<_> = self.loops.iter().rev().map(|(id, _)| id).collect(); + for cur_lp in loops_rev { + let cur_lp_header = self.loop_header(cur_lp); + + // Add predecessors of the loop header to worklist. + for &block in cfg.preds(cur_lp_header) { + if domtree.dominates(cur_lp_header, block) { + worklist.push(block); + } + } + + while let Some(block) = worklist.pop() { + match self.block_to_loop.get(&block).copied() { + Some(lp_of_block) => { + let outermost_parent = self.outermost_parent(lp_of_block); + + // If outermost parent is current loop, then the block is already visited. + if outermost_parent == cur_lp { + continue; + } else { + self.loops[cur_lp].children.push(outermost_parent); + self.loops[outermost_parent].parent = cur_lp.into(); + + let lp_header_of_block = self.loop_header(lp_of_block); + worklist.extend(cfg.preds(lp_header_of_block)); + } + } + + // If the block is not mapped to any loops, then map it to the loop. + None => { + self.map_block(block, cur_lp); + // If block is not loop header, then add its predecessors to the worklist. + if block != cur_lp_header { + worklist.extend(cfg.preds(block)); + } + } + } + } + } + } + + /// Returns the outermost parent loop of `lp`. If `lp` doesn't have any + /// parent, then returns `lp` itself. + fn outermost_parent(&self, mut lp: LoopId) -> LoopId { + while let Some(parent) = self.parent_loop(lp) { + lp = parent; + } + lp + } + + /// Map `block` to `lp`. + fn map_block(&mut self, block: BasicBlockId, lp: LoopId) { + self.block_to_loop.insert(block, lp); + } +} + +pub struct BlocksInLoopPostOrder<'a, 'b> { + lpt: &'a LoopTree, + cfg: &'b ControlFlowGraph, + lp: LoopId, + stack: Vec, + block_state: FxHashMap, +} + +impl<'a, 'b> BlocksInLoopPostOrder<'a, 'b> { + fn new(lpt: &'a LoopTree, cfg: &'b ControlFlowGraph, lp: LoopId) -> Self { + let loop_header = lpt.loop_header(lp); + + Self { + lpt, + cfg, + lp, + stack: vec![loop_header], + block_state: FxHashMap::default(), + } + } +} + +impl<'a, 'b> Iterator for BlocksInLoopPostOrder<'a, 'b> { + type Item = BasicBlockId; + + fn next(&mut self) -> Option { + while let Some(&block) = self.stack.last() { + match self.block_state.get(&block) { + // The block is already visited, but not returned from the iterator, + // so mark the block as `Finished` and return the block. + Some(BlockState::Visited) => { + let block = self.stack.pop().unwrap(); + self.block_state.insert(block, BlockState::Finished); + return Some(block); + } + + // The block is already returned, so just remove the block from the stack. + Some(BlockState::Finished) => { + self.stack.pop().unwrap(); + } + + // The block is not visited yet, so push its unvisited in-loop successors to the + // stack and mark the block as `Visited`. + None => { + self.block_state.insert(block, BlockState::Visited); + for &succ in self.cfg.succs(block) { + if self.block_state.get(&succ).is_none() + && self.lpt.is_block_in_loop(succ, self.lp) + { + self.stack.push(succ); + } + } + } + } + } + + None + } +} + +enum BlockState { + Visited, + Finished, +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; + + fn compute_loop(func: &FunctionBody) -> LoopTree { + let cfg = ControlFlowGraph::compute(func); + let domtree = DomTree::compute(&cfg); + LoopTree::compute(&cfg, &domtree) + } + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + #[test] + fn simple_loop() { + let mut builder = body_builder(); + + let entry = builder.current_block(); + let block1 = builder.make_block(); + let block2 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(false, dummy_ty); + builder.branch(v0, block1, block2, SourceInfo::dummy()); + + builder.move_to_block(block1); + builder.jump(entry, SourceInfo::dummy()); + + builder.move_to_block(block2); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let lpt = compute_loop(&func); + + assert_eq!(lpt.loop_num(), 1); + let lp = lpt.loops().next().unwrap(); + + assert!(lpt.is_block_in_loop(entry, lp)); + assert_eq!(lpt.loop_of_block(entry), Some(lp)); + + assert!(lpt.is_block_in_loop(block1, lp)); + assert_eq!(lpt.loop_of_block(block1), Some(lp)); + + assert!(!lpt.is_block_in_loop(block2, lp)); + assert!(lpt.loop_of_block(block2).is_none()); + + assert_eq!(lpt.loop_header(lp), entry); + } + + #[test] + fn nested_loop() { + let mut builder = body_builder(); + + let entry = builder.current_block(); + let block1 = builder.make_block(); + let block2 = builder.make_block(); + let block3 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(false, dummy_ty); + builder.branch(v0, block1, block3, SourceInfo::dummy()); + + builder.move_to_block(block1); + builder.branch(v0, entry, block2, SourceInfo::dummy()); + + builder.move_to_block(block2); + builder.jump(block1, SourceInfo::dummy()); + + builder.move_to_block(block3); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let lpt = compute_loop(&func); + + assert_eq!(lpt.loop_num(), 2); + let mut loops = lpt.loops(); + let outer_lp = loops.next().unwrap(); + let inner_lp = loops.next().unwrap(); + + assert!(lpt.is_block_in_loop(entry, outer_lp)); + assert!(!lpt.is_block_in_loop(entry, inner_lp)); + assert_eq!(lpt.loop_of_block(entry), Some(outer_lp)); + + assert!(lpt.is_block_in_loop(block1, outer_lp)); + assert!(lpt.is_block_in_loop(block1, inner_lp)); + assert_eq!(lpt.loop_of_block(block1), Some(inner_lp)); + + assert!(lpt.is_block_in_loop(block2, outer_lp)); + assert!(lpt.is_block_in_loop(block2, inner_lp)); + assert_eq!(lpt.loop_of_block(block2), Some(inner_lp)); + + assert!(!lpt.is_block_in_loop(block3, outer_lp)); + assert!(!lpt.is_block_in_loop(block3, inner_lp)); + assert!(lpt.loop_of_block(block3).is_none()); + + assert!(lpt.parent_loop(outer_lp).is_none()); + assert_eq!(lpt.parent_loop(inner_lp), Some(outer_lp)); + + assert_eq!(lpt.loop_header(outer_lp), entry); + assert_eq!(lpt.loop_header(inner_lp), block1); + } +} diff --git a/crates/mir2/src/analysis/mod.rs b/crates/mir2/src/analysis/mod.rs new file mode 100644 index 0000000000..b895cc02a7 --- /dev/null +++ b/crates/mir2/src/analysis/mod.rs @@ -0,0 +1,9 @@ +pub mod cfg; +pub mod domtree; +pub mod loop_tree; +pub mod post_domtree; + +pub use cfg::ControlFlowGraph; +pub use domtree::DomTree; +pub use loop_tree::LoopTree; +pub use post_domtree::PostDomTree; diff --git a/crates/mir2/src/analysis/post_domtree.rs b/crates/mir2/src/analysis/post_domtree.rs new file mode 100644 index 0000000000..ba33aab5f0 --- /dev/null +++ b/crates/mir2/src/analysis/post_domtree.rs @@ -0,0 +1,284 @@ +//! This module contains implementation of `Post Dominator Tree`. + +use id_arena::{ArenaBehavior, DefaultArenaBehavior}; + +use super::{cfg::ControlFlowGraph, domtree::DomTree}; + +use crate::ir::{BasicBlock, BasicBlockId, FunctionBody}; + +#[derive(Debug)] +pub struct PostDomTree { + /// Dummy entry block to calculate post dom tree. + dummy_entry: BasicBlockId, + /// Canonical dummy exit block to calculate post dom tree. All blocks ends + /// with `return` has an edge to this block. + dummy_exit: BasicBlockId, + + /// Dominator tree of reverse control flow graph. + domtree: DomTree, +} + +impl PostDomTree { + pub fn compute(func: &FunctionBody) -> Self { + let mut rcfg = ControlFlowGraph::compute(func); + + let real_entry = rcfg.entry(); + + let dummy_entry = Self::make_dummy_block(); + let dummy_exit = Self::make_dummy_block(); + // Add edges from dummy entry block to real entry block and dummy exit block. + rcfg.add_edge(dummy_entry, real_entry); + rcfg.add_edge(dummy_entry, dummy_exit); + + // Add edges from real exit blocks to dummy exit block. + for exit in std::mem::take(&mut rcfg.exits) { + rcfg.add_edge(exit, dummy_exit); + } + + rcfg.reverse_edge(dummy_exit, vec![dummy_entry]); + let domtree = DomTree::compute(&rcfg); + + Self { + dummy_entry, + dummy_exit, + domtree, + } + } + + pub fn post_idom(&self, block: BasicBlockId) -> PostIDom { + match self.domtree.idom(block).unwrap() { + block if block == self.dummy_entry => PostIDom::DummyEntry, + block if block == self.dummy_exit => PostIDom::DummyExit, + other => PostIDom::Block(other), + } + } + + /// Returns `true` if block is reachable from the exit blocks. + pub fn is_reachable(&self, block: BasicBlockId) -> bool { + self.domtree.is_reachable(block) + } + + fn make_dummy_block() -> BasicBlockId { + let arena_id = DefaultArenaBehavior::::new_arena_id(); + DefaultArenaBehavior::new_id(arena_id, 0) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PostIDom { + DummyEntry, + DummyExit, + Block(BasicBlockId), +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::ir::{body_builder::BodyBuilder, FunctionId, SourceInfo, TypeId}; + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + #[test] + fn test_if_else_merge() { + let mut builder = body_builder(); + let then_block = builder.make_block(); + let else_block = builder.make_block(); + let merge_block = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + + builder.move_to_block(then_block); + builder.jump(merge_block, SourceInfo::dummy()); + + builder.move_to_block(else_block); + builder.jump(merge_block, SourceInfo::dummy()); + + builder.move_to_block(merge_block); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let post_dom_tree = PostDomTree::compute(&func); + let entry_block = func.order.entry(); + assert_eq!( + post_dom_tree.post_idom(entry_block), + PostIDom::Block(merge_block) + ); + assert_eq!( + post_dom_tree.post_idom(then_block), + PostIDom::Block(merge_block) + ); + assert_eq!( + post_dom_tree.post_idom(else_block), + PostIDom::Block(merge_block) + ); + assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); + } + + #[test] + fn test_if_else_return() { + let mut builder = body_builder(); + let then_block = builder.make_block(); + let else_block = builder.make_block(); + let merge_block = builder.make_block(); + + let dummy_ty = TypeId(0); + let dummy_value = builder.make_unit(dummy_ty); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + + builder.move_to_block(then_block); + builder.jump(merge_block, SourceInfo::dummy()); + + builder.move_to_block(else_block); + builder.ret(dummy_value, SourceInfo::dummy()); + + builder.move_to_block(merge_block); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let post_dom_tree = PostDomTree::compute(&func); + let entry_block = func.order.entry(); + assert_eq!(post_dom_tree.post_idom(entry_block), PostIDom::DummyExit,); + assert_eq!( + post_dom_tree.post_idom(then_block), + PostIDom::Block(merge_block), + ); + assert_eq!(post_dom_tree.post_idom(else_block), PostIDom::DummyExit); + assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); + } + + #[test] + fn test_if_non_else() { + let mut builder = body_builder(); + let then_block = builder.make_block(); + let merge_block = builder.make_block(); + + let dummy_ty = TypeId(0); + let dummy_value = builder.make_unit(dummy_ty); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + builder.branch(v0, then_block, merge_block, SourceInfo::dummy()); + + builder.move_to_block(then_block); + builder.jump(merge_block, SourceInfo::dummy()); + + builder.move_to_block(merge_block); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let post_dom_tree = PostDomTree::compute(&func); + let entry_block = func.order.entry(); + assert_eq!( + post_dom_tree.post_idom(entry_block), + PostIDom::Block(merge_block), + ); + assert_eq!( + post_dom_tree.post_idom(then_block), + PostIDom::Block(merge_block), + ); + assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); + } + + #[test] + fn test_loop() { + let mut builder = body_builder(); + let block1 = builder.make_block(); + let block2 = builder.make_block(); + let block3 = builder.make_block(); + let block4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + + builder.branch(v0, block1, block2, SourceInfo::dummy()); + + builder.move_to_block(block1); + builder.jump(block3, SourceInfo::dummy()); + + builder.move_to_block(block2); + builder.branch(v0, block3, block4, SourceInfo::dummy()); + + builder.move_to_block(block3); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + builder.move_to_block(block4); + builder.jump(block2, SourceInfo::dummy()); + + let func = builder.build(); + + let post_dom_tree = PostDomTree::compute(&func); + let entry_block = func.order.entry(); + assert_eq!( + post_dom_tree.post_idom(entry_block), + PostIDom::Block(block3), + ); + assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block3)); + assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block3)); + assert_eq!(post_dom_tree.post_idom(block3), PostIDom::DummyExit); + assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block2)); + } + + #[test] + fn test_pd_complex() { + let mut builder = body_builder(); + let block1 = builder.make_block(); + let block2 = builder.make_block(); + let block3 = builder.make_block(); + let block4 = builder.make_block(); + let block5 = builder.make_block(); + let block6 = builder.make_block(); + let block7 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + + builder.branch(v0, block1, block2, SourceInfo::dummy()); + + builder.move_to_block(block1); + builder.jump(block6, SourceInfo::dummy()); + + builder.move_to_block(block2); + builder.branch(v0, block3, block4, SourceInfo::dummy()); + + builder.move_to_block(block3); + builder.jump(block5, SourceInfo::dummy()); + + builder.move_to_block(block4); + builder.jump(block5, SourceInfo::dummy()); + + builder.move_to_block(block5); + builder.jump(block6, SourceInfo::dummy()); + + builder.move_to_block(block6); + builder.jump(block7, SourceInfo::dummy()); + + builder.move_to_block(block7); + let dummy_value = builder.make_unit(dummy_ty); + builder.ret(dummy_value, SourceInfo::dummy()); + + let func = builder.build(); + + let post_dom_tree = PostDomTree::compute(&func); + let entry_block = func.order.entry(); + assert_eq!( + post_dom_tree.post_idom(entry_block), + PostIDom::Block(block6), + ); + assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block6)); + assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block5)); + assert_eq!(post_dom_tree.post_idom(block3), PostIDom::Block(block5)); + assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block5)); + assert_eq!(post_dom_tree.post_idom(block5), PostIDom::Block(block6)); + assert_eq!(post_dom_tree.post_idom(block6), PostIDom::Block(block7)); + assert_eq!(post_dom_tree.post_idom(block7), PostIDom::DummyExit); + } +} diff --git a/crates/mir2/src/db.rs b/crates/mir2/src/db.rs new file mode 100644 index 0000000000..fc930318de --- /dev/null +++ b/crates/mir2/src/db.rs @@ -0,0 +1,104 @@ +#![allow(clippy::arc_with_non_send_sync)] +use std::{collections::BTreeMap, rc::Rc}; + +use fe_analyzer::{ + db::AnalyzerDbStorage, + namespace::{items as analyzer_items, types as analyzer_types}, + AnalyzerDb, +}; +use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; +use smol_str::SmolStr; + +use crate::ir::{self, ConstantId, TypeId}; + +mod queries; + +#[salsa::query_group(MirDbStorage)] +pub trait MirDb: AnalyzerDb + Upcast + UpcastMut { + #[salsa::interned] + fn mir_intern_const(&self, data: Rc) -> ir::ConstantId; + #[salsa::interned] + fn mir_intern_type(&self, data: Rc) -> ir::TypeId; + #[salsa::interned] + fn mir_intern_function(&self, data: Rc) -> ir::FunctionId; + + #[salsa::invoke(queries::module::mir_lower_module_all_functions)] + fn mir_lower_module_all_functions( + &self, + module: analyzer_items::ModuleId, + ) -> Rc>; + + #[salsa::invoke(queries::contract::mir_lower_contract_all_functions)] + fn mir_lower_contract_all_functions( + &self, + contract: analyzer_items::ContractId, + ) -> Rc>; + + #[salsa::invoke(queries::structs::mir_lower_struct_all_functions)] + fn mir_lower_struct_all_functions( + &self, + struct_: analyzer_items::StructId, + ) -> Rc>; + + #[salsa::invoke(queries::enums::mir_lower_enum_all_functions)] + fn mir_lower_enum_all_functions( + &self, + enum_: analyzer_items::EnumId, + ) -> Rc>; + + #[salsa::invoke(queries::types::mir_lowered_type)] + fn mir_lowered_type(&self, analyzer_type: analyzer_types::TypeId) -> TypeId; + + #[salsa::invoke(queries::constant::mir_lowered_constant)] + fn mir_lowered_constant(&self, analyzer_const: analyzer_items::ModuleConstantId) -> ConstantId; + + #[salsa::invoke(queries::function::mir_lowered_func_signature)] + fn mir_lowered_func_signature( + &self, + analyzer_func: analyzer_items::FunctionId, + ) -> ir::FunctionId; + #[salsa::invoke(queries::function::mir_lowered_monomorphized_func_signature)] + fn mir_lowered_monomorphized_func_signature( + &self, + analyzer_func: analyzer_items::FunctionId, + resolved_generics: BTreeMap, + ) -> ir::FunctionId; + #[salsa::invoke(queries::function::mir_lowered_pseudo_monomorphized_func_signature)] + fn mir_lowered_pseudo_monomorphized_func_signature( + &self, + analyzer_func: analyzer_items::FunctionId, + ) -> ir::FunctionId; + #[salsa::invoke(queries::function::mir_lowered_func_body)] + fn mir_lowered_func_body(&self, func: ir::FunctionId) -> Rc; +} + +#[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage)] +#[derive(Default)] +pub struct NewDb { + storage: salsa::Storage, +} +impl salsa::Database for NewDb {} + +impl Upcast for NewDb { + fn upcast(&self) -> &(dyn SourceDb + 'static) { + self + } +} + +impl UpcastMut for NewDb { + fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { + &mut *self + } +} + +impl Upcast for NewDb { + fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { + self + } +} + +impl UpcastMut for NewDb { + fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { + &mut *self + } +} diff --git a/crates/mir2/src/db/queries.rs b/crates/mir2/src/db/queries.rs new file mode 100644 index 0000000000..8cdae44831 --- /dev/null +++ b/crates/mir2/src/db/queries.rs @@ -0,0 +1,7 @@ +pub mod constant; +pub mod contract; +pub mod enums; +pub mod function; +pub mod module; +pub mod structs; +pub mod types; diff --git a/crates/mir2/src/db/queries/constant.rs b/crates/mir2/src/db/queries/constant.rs new file mode 100644 index 0000000000..1e012420c1 --- /dev/null +++ b/crates/mir2/src/db/queries/constant.rs @@ -0,0 +1,43 @@ +use std::rc::Rc; + +use fe_analyzer::namespace::items as analyzer_items; + +use crate::{ + db::MirDb, + ir::{Constant, ConstantId, SourceInfo, TypeId}, +}; + +pub fn mir_lowered_constant( + db: &dyn MirDb, + analyzer_const: analyzer_items::ModuleConstantId, +) -> ConstantId { + let name = analyzer_const.name(db.upcast()); + let value = analyzer_const.constant_value(db.upcast()).unwrap(); + let ty = analyzer_const.typ(db.upcast()).unwrap(); + let module_id = analyzer_const.module(db.upcast()); + let span = analyzer_const.span(db.upcast()); + let id = analyzer_const.node_id(db.upcast()); + + let ty = db.mir_lowered_type(ty); + let source = SourceInfo { span, id }; + + let constant = Constant { + name, + value: value.into(), + ty, + module_id, + source, + }; + + db.mir_intern_const(constant.into()) +} + +impl ConstantId { + pub fn data(self, db: &dyn MirDb) -> Rc { + db.lookup_mir_intern_const(self) + } + + pub fn ty(self, db: &dyn MirDb) -> TypeId { + self.data(db).ty + } +} diff --git a/crates/mir2/src/db/queries/contract.rs b/crates/mir2/src/db/queries/contract.rs new file mode 100644 index 0000000000..b36b1893e1 --- /dev/null +++ b/crates/mir2/src/db/queries/contract.rs @@ -0,0 +1,17 @@ +use std::rc::Rc; + +use fe_analyzer::namespace::items::{self as analyzer_items}; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_contract_all_functions( + db: &dyn MirDb, + contract: analyzer_items::ContractId, +) -> Rc> { + contract + .all_functions(db.upcast()) + .iter() + .map(|func| db.mir_lowered_func_signature(*func)) + .collect::>() + .into() +} diff --git a/crates/mir2/src/db/queries/enums.rs b/crates/mir2/src/db/queries/enums.rs new file mode 100644 index 0000000000..2fb26cb478 --- /dev/null +++ b/crates/mir2/src/db/queries/enums.rs @@ -0,0 +1,17 @@ +use std::rc::Rc; + +use fe_analyzer::namespace::items::{self as analyzer_items}; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_enum_all_functions( + db: &dyn MirDb, + enum_: analyzer_items::EnumId, +) -> Rc> { + enum_ + .all_functions(db.upcast()) + .iter() + .map(|func| db.mir_lowered_func_signature(*func)) + .collect::>() + .into() +} diff --git a/crates/mir2/src/db/queries/function.rs b/crates/mir2/src/db/queries/function.rs new file mode 100644 index 0000000000..e9f0e9f282 --- /dev/null +++ b/crates/mir2/src/db/queries/function.rs @@ -0,0 +1,130 @@ +use std::{collections::BTreeMap, rc::Rc}; + +use fe_analyzer::{ + display::Displayable, + namespace::{items as analyzer_items, items::Item, types as analyzer_types}, +}; + +use smol_str::SmolStr; + +use crate::{ + db::MirDb, + ir::{self, function::Linkage, FunctionSignature, TypeId}, + lower::function::{lower_func_body, lower_func_signature, lower_monomorphized_func_signature}, +}; + +pub fn mir_lowered_func_signature( + db: &dyn MirDb, + analyzer_func: analyzer_items::FunctionId, +) -> ir::FunctionId { + lower_func_signature(db, analyzer_func) +} + +pub fn mir_lowered_monomorphized_func_signature( + db: &dyn MirDb, + analyzer_func: analyzer_items::FunctionId, + resolved_generics: BTreeMap, +) -> ir::FunctionId { + lower_monomorphized_func_signature(db, analyzer_func, resolved_generics) +} + +/// Generate MIR function and monomorphize generic parameters as if they were +/// called with unit type NOTE: THIS SHOULD ONLY BE USED IN TEST CODE +pub fn mir_lowered_pseudo_monomorphized_func_signature( + db: &dyn MirDb, + analyzer_func: analyzer_items::FunctionId, +) -> ir::FunctionId { + let resolved_generics = analyzer_func + .sig(db.upcast()) + .generic_params(db.upcast()) + .iter() + .map(|generic| (generic.name(), analyzer_types::TypeId::unit(db.upcast()))) + .collect::>(); + lower_monomorphized_func_signature(db, analyzer_func, resolved_generics) +} + +pub fn mir_lowered_func_body(db: &dyn MirDb, func: ir::FunctionId) -> Rc { + lower_func_body(db, func) +} + +impl ir::FunctionId { + pub fn signature(self, db: &dyn MirDb) -> Rc { + db.lookup_mir_intern_function(self) + } + + pub fn return_type(self, db: &dyn MirDb) -> Option { + self.signature(db).return_type + } + + pub fn linkage(self, db: &dyn MirDb) -> Linkage { + self.signature(db).linkage + } + + pub fn analyzer_func(self, db: &dyn MirDb) -> analyzer_items::FunctionId { + self.signature(db).analyzer_func_id + } + + pub fn body(self, db: &dyn MirDb) -> Rc { + db.mir_lowered_func_body(self) + } + + pub fn module(self, db: &dyn MirDb) -> analyzer_items::ModuleId { + let analyzer_func = self.analyzer_func(db); + analyzer_func.module(db.upcast()) + } + + pub fn is_contract_init(self, db: &dyn MirDb) -> bool { + self.analyzer_func(db) + .data(db.upcast()) + .sig + .is_constructor(db.upcast()) + } + + /// Returns a type suffix if a generic function was monomorphized + pub fn type_suffix(&self, db: &dyn MirDb) -> SmolStr { + self.signature(db) + .resolved_generics + .values() + .fold(String::new(), |acc, param| { + format!("{}_{}", acc, param.display(db.upcast())) + }) + .into() + } + + pub fn name(&self, db: &dyn MirDb) -> SmolStr { + let analyzer_func = self.analyzer_func(db); + analyzer_func.name(db.upcast()) + } + + /// Returns `class_name::fn_name` if a function is a method else `fn_name`. + pub fn debug_name(self, db: &dyn MirDb) -> SmolStr { + let analyzer_func = self.analyzer_func(db); + let func_name = format!( + "{}{}", + analyzer_func.name(db.upcast()), + self.type_suffix(db) + ); + + match analyzer_func.sig(db.upcast()).self_item(db.upcast()) { + Some(Item::Impl(id)) => { + let class_name = format!( + "<{} as {}>", + id.receiver(db.upcast()).display(db.upcast()), + id.trait_id(db.upcast()).name(db.upcast()) + ); + format!("{class_name}::{func_name}").into() + } + Some(class) => { + let class_name = class.name(db.upcast()); + format!("{class_name}::{func_name}").into() + } + _ => func_name.into(), + } + } + + pub fn returns_aggregate(self, db: &dyn MirDb) -> bool { + self.return_type(db) + .map(|ty| ty.is_aggregate(db)) + .unwrap_or_default() + } +} diff --git a/crates/mir2/src/db/queries/module.rs b/crates/mir2/src/db/queries/module.rs new file mode 100644 index 0000000000..00f0ea6a3c --- /dev/null +++ b/crates/mir2/src/db/queries/module.rs @@ -0,0 +1,35 @@ +use std::rc::Rc; + +use fe_analyzer::namespace::items::{self as analyzer_items, TypeDef}; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_module_all_functions( + db: &dyn MirDb, + module: analyzer_items::ModuleId, +) -> Rc> { + let mut functions = vec![]; + + let items = module.all_items(db.upcast()); + items.iter().for_each(|item| match item { + analyzer_items::Item::Function(func) => { + functions.push(db.mir_lowered_func_signature(*func)) + } + + analyzer_items::Item::Type(TypeDef::Contract(contract)) => { + functions.extend_from_slice(&db.mir_lower_contract_all_functions(*contract)) + } + + analyzer_items::Item::Type(TypeDef::Struct(struct_)) => { + functions.extend_from_slice(&db.mir_lower_struct_all_functions(*struct_)) + } + + analyzer_items::Item::Type(TypeDef::Enum(enum_)) => { + functions.extend_from_slice(&db.mir_lower_enum_all_functions(*enum_)) + } + + _ => {} + }); + + functions.into() +} diff --git a/crates/mir2/src/db/queries/structs.rs b/crates/mir2/src/db/queries/structs.rs new file mode 100644 index 0000000000..8ca121f94a --- /dev/null +++ b/crates/mir2/src/db/queries/structs.rs @@ -0,0 +1,17 @@ +use std::rc::Rc; + +use fe_analyzer::namespace::items::{self as analyzer_items}; + +use crate::{db::MirDb, ir::FunctionId}; + +pub fn mir_lower_struct_all_functions( + db: &dyn MirDb, + struct_: analyzer_items::StructId, +) -> Rc> { + struct_ + .all_functions(db.upcast()) + .iter() + .map(|func| db.mir_lowered_pseudo_monomorphized_func_signature(*func)) + .collect::>() + .into() +} diff --git a/crates/mir2/src/db/queries/types.rs b/crates/mir2/src/db/queries/types.rs new file mode 100644 index 0000000000..a0d13511d5 --- /dev/null +++ b/crates/mir2/src/db/queries/types.rs @@ -0,0 +1,657 @@ +use std::{fmt, rc::Rc, str::FromStr}; + +use fe_analyzer::namespace::{items::EnumVariantId, types as analyzer_types}; + +use num_bigint::BigInt; +use num_traits::ToPrimitive; + +use crate::{ + db::MirDb, + ir::{ + types::{ArrayDef, TupleDef, TypeKind}, + Type, TypeId, Value, + }, + lower::types::lower_type, +}; + +pub fn mir_lowered_type(db: &dyn MirDb, analyzer_type: analyzer_types::TypeId) -> TypeId { + lower_type(db, analyzer_type) +} + +impl TypeId { + pub fn data(self, db: &dyn MirDb) -> Rc { + db.lookup_mir_intern_type(self) + } + + pub fn analyzer_ty(self, db: &dyn MirDb) -> Option { + self.data(db).analyzer_ty + } + + pub fn projection_ty(self, db: &dyn MirDb, access: &Value) -> TypeId { + let ty = self.deref(db); + let pty = match &ty.data(db).kind { + TypeKind::Array(ArrayDef { elem_ty, .. }) => *elem_ty, + TypeKind::Tuple(def) => { + let index = expect_projection_index(access); + def.items[index] + } + TypeKind::Struct(def) | TypeKind::Contract(def) => { + let index = expect_projection_index(access); + def.fields[index].1 + } + TypeKind::Enum(_) => { + let index = expect_projection_index(access); + debug_assert_eq!(index, 0); + ty.projection_ty_imm(db, 0) + } + _ => panic!("{:?} can't project onto the `access`", self.as_string(db)), + }; + match &self.data(db).kind { + TypeKind::SPtr(_) | TypeKind::Contract(_) => pty.make_sptr(db), + TypeKind::MPtr(_) => pty.make_mptr(db), + _ => pty, + } + } + + pub fn deref(self, db: &dyn MirDb) -> TypeId { + match self.data(db).kind { + TypeKind::SPtr(inner) => inner, + TypeKind::MPtr(inner) => inner.deref(db), + _ => self, + } + } + + pub fn make_sptr(self, db: &dyn MirDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::SPtr(self), None).into()) + } + + pub fn make_mptr(self, db: &dyn MirDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::MPtr(self), None).into()) + } + + pub fn projection_ty_imm(self, db: &dyn MirDb, index: usize) -> TypeId { + match &self.data(db).kind { + TypeKind::Array(ArrayDef { elem_ty, .. }) => *elem_ty, + TypeKind::Tuple(def) => def.items[index], + TypeKind::Struct(def) | TypeKind::Contract(def) => def.fields[index].1, + TypeKind::Enum(_) => { + debug_assert_eq!(index, 0); + self.enum_disc_type(db) + } + _ => panic!("{:?} can't project onto the `index`", self.as_string(db)), + } + } + + pub fn aggregate_field_num(self, db: &dyn MirDb) -> usize { + match &self.data(db).kind { + TypeKind::Array(ArrayDef { len, .. }) => *len, + TypeKind::Tuple(def) => def.items.len(), + TypeKind::Struct(def) | TypeKind::Contract(def) => def.fields.len(), + TypeKind::Enum(_) => 2, + _ => unreachable!(), + } + } + + pub fn enum_disc_type(self, db: &dyn MirDb) -> TypeId { + let kind = match &self.deref(db).data(db).kind { + TypeKind::Enum(def) => def.tag_type(), + _ => unreachable!(), + }; + let analyzer_type = match kind { + TypeKind::U8 => Some(analyzer_types::Integer::U8), + TypeKind::U16 => Some(analyzer_types::Integer::U16), + TypeKind::U32 => Some(analyzer_types::Integer::U32), + TypeKind::U64 => Some(analyzer_types::Integer::U64), + TypeKind::U128 => Some(analyzer_types::Integer::U128), + TypeKind::U256 => Some(analyzer_types::Integer::U256), + _ => None, + } + .map(|int| analyzer_types::TypeId::int(db.upcast(), int)); + + db.mir_intern_type(Type::new(kind, analyzer_type).into()) + } + + pub fn enum_data_offset(self, db: &dyn MirDb, slot_size: usize) -> usize { + match &self.data(db).kind { + TypeKind::Enum(def) => { + let disc_size = self.enum_disc_type(db).size_of(db, slot_size); + let mut align = 1; + for variant in def.variants.iter() { + let variant_align = variant.ty.align_of(db, slot_size); + align = num_integer::lcm(align, variant_align); + } + round_up(disc_size, align) + } + _ => unreachable!(), + } + } + + pub fn enum_variant_type(self, db: &dyn MirDb, variant_id: EnumVariantId) -> TypeId { + let name = variant_id.name(db.upcast()); + match &self.deref(db).data(db).kind { + TypeKind::Enum(def) => def + .variants + .iter() + .find(|variant| variant.name == name) + .map(|variant| variant.ty) + .unwrap(), + _ => unreachable!(), + } + } + + pub fn index_from_fname(self, db: &dyn MirDb, fname: &str) -> BigInt { + let ty = self.deref(db); + match &ty.data(db).kind { + TypeKind::Tuple(_) => { + // TODO: Fix this when the syntax for tuple access changes. + let index_str = &fname[4..]; + BigInt::from_str(index_str).unwrap() + } + + TypeKind::Struct(def) | TypeKind::Contract(def) => def + .fields + .iter() + .enumerate() + .find_map(|(i, field)| (field.0 == fname).then(|| i.into())) + .unwrap(), + + other => unreachable!("{:?} does not have fields", other), + } + } + + pub fn is_primitive(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::I256 + | TypeKind::U8 + | TypeKind::U16 + | TypeKind::U32 + | TypeKind::U64 + | TypeKind::U128 + | TypeKind::U256 + | TypeKind::Bool + | TypeKind::Address + | TypeKind::Unit + ) + } + + pub fn is_integral(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::I256 + | TypeKind::U8 + | TypeKind::U16 + | TypeKind::U32 + | TypeKind::U64 + | TypeKind::U128 + | TypeKind::U256 + ) + } + + pub fn is_address(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).kind, TypeKind::Address) + } + + pub fn is_unit(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).as_ref().kind, TypeKind::Unit) + } + + pub fn is_enum(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).as_ref().kind, TypeKind::Enum(_)) + } + + pub fn is_signed(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::I256 + ) + } + + /// Returns size of the type in bytes. + pub fn size_of(self, db: &dyn MirDb, slot_size: usize) -> usize { + match &self.data(db).kind { + TypeKind::Bool | TypeKind::I8 | TypeKind::U8 => 1, + TypeKind::I16 | TypeKind::U16 => 2, + TypeKind::I32 | TypeKind::U32 => 4, + TypeKind::I64 | TypeKind::U64 => 8, + TypeKind::I128 | TypeKind::U128 => 16, + TypeKind::String(len) => 32 + len, + TypeKind::MPtr(..) + | TypeKind::SPtr(..) + | TypeKind::I256 + | TypeKind::U256 + | TypeKind::Map(_) => 32, + TypeKind::Address => 20, + TypeKind::Unit => 0, + + TypeKind::Array(def) => array_elem_size_imp(db, def, slot_size) * def.len, + + TypeKind::Tuple(def) => { + if def.items.is_empty() { + return 0; + } + let last_idx = def.items.len() - 1; + self.aggregate_elem_offset(db, last_idx, slot_size) + + def.items[last_idx].size_of(db, slot_size) + } + + TypeKind::Struct(def) | TypeKind::Contract(def) => { + if def.fields.is_empty() { + return 0; + } + let last_idx = def.fields.len() - 1; + self.aggregate_elem_offset(db, last_idx, slot_size) + + def.fields[last_idx].1.size_of(db, slot_size) + } + + TypeKind::Enum(def) => { + let data_offset = self.enum_data_offset(db, slot_size); + let maximum_data_size = def + .variants + .iter() + .map(|variant| variant.ty.size_of(db, slot_size)) + .max() + .unwrap_or(0); + data_offset + maximum_data_size + } + } + } + + pub fn is_zero_sized(self, db: &dyn MirDb) -> bool { + // It's ok to use 1 as a slot size because slot size doesn't affect whether a + // type is zero sized or not. + self.size_of(db, 1) == 0 + } + + pub fn align_of(self, db: &dyn MirDb, slot_size: usize) -> usize { + if self.is_primitive(db) { + 1 + } else { + // TODO: Too naive, we could implement more efficient layout for aggregate + // types. + slot_size + } + } + + /// Returns an offset of the element of aggregate type. + pub fn aggregate_elem_offset(self, db: &dyn MirDb, elem_idx: T, slot_size: usize) -> usize + where + T: num_traits::ToPrimitive, + { + debug_assert!(self.is_aggregate(db)); + debug_assert!(elem_idx.to_usize().unwrap() < self.aggregate_field_num(db)); + let elem_idx = elem_idx.to_usize().unwrap(); + + if elem_idx == 0 { + return 0; + } + + match &self.data(db).kind { + TypeKind::Array(def) => array_elem_size_imp(db, def, slot_size) * elem_idx, + TypeKind::Enum(_) => self.enum_data_offset(db, slot_size), + _ => { + let mut offset = self.aggregate_elem_offset(db, elem_idx - 1, slot_size) + + self + .projection_ty_imm(db, elem_idx - 1) + .size_of(db, slot_size); + + let elem_ty = self.projection_ty_imm(db, elem_idx); + if (offset % slot_size + elem_ty.size_of(db, slot_size)) > slot_size { + offset = round_up(offset, slot_size); + } + + round_up(offset, elem_ty.align_of(db, slot_size)) + } + } + } + + pub fn is_aggregate(self, db: &dyn MirDb) -> bool { + matches!( + &self.data(db).kind, + TypeKind::Array(_) + | TypeKind::Tuple(_) + | TypeKind::Struct(_) + | TypeKind::Enum(_) + | TypeKind::Contract(_) + ) + } + + pub fn is_struct(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).as_ref().kind, TypeKind::Struct(_)) + } + + pub fn is_array(self, db: &dyn MirDb) -> bool { + matches!(&self.data(db).kind, TypeKind::Array(_)) + } + + pub fn is_string(self, db: &dyn MirDb) -> bool { + matches! { + &self.data(db).kind, + TypeKind::String(_) + } + } + + pub fn is_ptr(self, db: &dyn MirDb) -> bool { + self.is_mptr(db) || self.is_sptr(db) + } + + pub fn is_mptr(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::MPtr(_)) + } + + pub fn is_sptr(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::SPtr(_)) + } + + pub fn is_map(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::Map(_)) + } + + pub fn is_contract(self, db: &dyn MirDb) -> bool { + matches!(self.data(db).kind, TypeKind::Contract(_)) + } + + pub fn array_elem_size(self, db: &dyn MirDb, slot_size: usize) -> usize { + let data = self.data(db); + if let TypeKind::Array(def) = &data.kind { + array_elem_size_imp(db, def, slot_size) + } else { + panic!("expected `Array` type; but got {:?}", data.as_ref()) + } + } + + pub fn print(&self, db: &dyn MirDb, w: &mut W) -> fmt::Result { + match &self.data(db).kind { + TypeKind::I8 => write!(w, "i8"), + TypeKind::I16 => write!(w, "i16"), + TypeKind::I32 => write!(w, "i32"), + TypeKind::I64 => write!(w, "i64"), + TypeKind::I128 => write!(w, "i128"), + TypeKind::I256 => write!(w, "i256"), + TypeKind::U8 => write!(w, "u8"), + TypeKind::U16 => write!(w, "u16"), + TypeKind::U32 => write!(w, "u32"), + TypeKind::U64 => write!(w, "u64"), + TypeKind::U128 => write!(w, "u128"), + TypeKind::U256 => write!(w, "u256"), + TypeKind::Bool => write!(w, "bool"), + TypeKind::Address => write!(w, "address"), + TypeKind::Unit => write!(w, "()"), + TypeKind::String(size) => write!(w, "Str<{size}>"), + TypeKind::Array(ArrayDef { elem_ty, len }) => { + write!(w, "[")?; + elem_ty.print(db, w)?; + write!(w, "; {len}]") + } + TypeKind::Tuple(TupleDef { items }) => { + write!(w, "(")?; + if items.is_empty() { + return write!(w, ")"); + } + + let len = items.len(); + for item in &items[0..len - 1] { + item.print(db, w)?; + write!(w, ", ")?; + } + items.last().unwrap().print(db, w)?; + write!(w, ")") + } + TypeKind::Struct(def) => { + write!(w, "{}", def.name) + } + TypeKind::Enum(def) => { + write!(w, "{}", def.name) + } + TypeKind::Contract(def) => { + write!(w, "{}", def.name) + } + TypeKind::Map(def) => { + write!(w, "Map<")?; + def.key_ty.print(db, w)?; + write!(w, ",")?; + def.value_ty.print(db, w)?; + write!(w, ">") + } + TypeKind::MPtr(inner) => { + write!(w, "*@m ")?; + inner.print(db, w) + } + TypeKind::SPtr(inner) => { + write!(w, "*@s ")?; + inner.print(db, w) + } + } + } + + pub fn as_string(&self, db: &dyn MirDb) -> String { + let mut s = String::new(); + self.print(db, &mut s).unwrap(); + s + } +} + +fn array_elem_size_imp(db: &dyn MirDb, arr: &ArrayDef, slot_size: usize) -> usize { + let elem_ty = arr.elem_ty; + let elem = elem_ty.size_of(db, slot_size); + let align = if elem_ty.is_address(db) { + slot_size + } else { + elem_ty.align_of(db, slot_size) + }; + round_up(elem, align) +} + +fn expect_projection_index(value: &Value) -> usize { + match value { + Value::Immediate { imm, .. } => imm.to_usize().unwrap(), + _ => panic!("given `value` is not an immediate"), + } +} + +fn round_up(value: usize, slot_size: usize) -> usize { + ((value + slot_size - 1) / slot_size) * slot_size +} + +#[cfg(test)] +mod tests { + use fe_analyzer::namespace::items::ModuleId; + use fe_common::Span; + + use super::*; + use crate::{ + db::{MirDb, NewDb}, + ir::types::StructDef, + }; + + #[test] + fn test_primitive_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let bool = db.mir_intern_type(Type::new(TypeKind::Bool, None).into()); + + debug_assert_eq!(i8.size_of(&db, 1), 1); + debug_assert_eq!(i8.size_of(&db, 32), 1); + debug_assert_eq!(i8.align_of(&db, 1), 1); + debug_assert_eq!(i8.align_of(&db, 32), 1); + debug_assert_eq!(bool.size_of(&db, 1), 1); + debug_assert_eq!(bool.size_of(&db, 32), 1); + debug_assert_eq!(i8.align_of(&db, 32), 1); + debug_assert_eq!(i8.align_of(&db, 32), 1); + + let u32 = db.mir_intern_type(Type::new(TypeKind::U32, None).into()); + debug_assert_eq!(u32.size_of(&db, 1), 4); + debug_assert_eq!(u32.size_of(&db, 32), 4); + debug_assert_eq!(u32.align_of(&db, 32), 1); + + let address = db.mir_intern_type(Type::new(TypeKind::Address, None).into()); + debug_assert_eq!(address.size_of(&db, 1), 20); + debug_assert_eq!(address.size_of(&db, 32), 20); + debug_assert_eq!(address.align_of(&db, 32), 1); + } + + #[test] + fn test_primitive_elem_array_type_info() { + let db = NewDb::default(); + let i32 = db.mir_intern_type(Type::new(TypeKind::I32, None).into()); + + let array_len = 10; + let array_def = ArrayDef { + elem_ty: i32, + len: array_len, + }; + let array = db.mir_intern_type(Type::new(TypeKind::Array(array_def), None).into()); + + let elem_size = array.array_elem_size(&db, 1); + debug_assert_eq!(elem_size, 4); + debug_assert_eq!(array.array_elem_size(&db, 32), elem_size); + + debug_assert_eq!(array.size_of(&db, 1), elem_size * array_len); + debug_assert_eq!(array.size_of(&db, 32), elem_size * array_len); + debug_assert_eq!(array.align_of(&db, 1), 1); + debug_assert_eq!(array.align_of(&db, 32), 32); + + debug_assert_eq!(array.aggregate_elem_offset(&db, 3, 32), elem_size * 3); + debug_assert_eq!(array.aggregate_elem_offset(&db, 9, 1), elem_size * 9); + } + + #[test] + fn test_aggregate_elem_array_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let i64 = db.mir_intern_type(Type::new(TypeKind::I64, None).into()); + let i128 = db.mir_intern_type(Type::new(TypeKind::I128, None).into()); + + let fields = vec![ + ("".into(), i64), + ("".into(), i64), + ("".into(), i8), + ("".into(), i128), + ("".into(), i8), + ]; + + let struct_def = StructDef { + name: "".into(), + fields, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate = db.mir_intern_type(Type::new(TypeKind::Struct(struct_def), None).into()); + + let array_len = 10; + let array_def = ArrayDef { + elem_ty: aggregate, + len: array_len, + }; + let array = db.mir_intern_type(Type::new(TypeKind::Array(array_def), None).into()); + + debug_assert_eq!(array.array_elem_size(&db, 1), 34); + debug_assert_eq!(array.array_elem_size(&db, 32), 64); + + debug_assert_eq!(array.size_of(&db, 1), 34 * array_len); + debug_assert_eq!(array.size_of(&db, 32), 64 * array_len); + + debug_assert_eq!(array.align_of(&db, 1), 1); + debug_assert_eq!(array.align_of(&db, 32), 32); + + debug_assert_eq!(array.aggregate_elem_offset(&db, 3, 1), 102); + debug_assert_eq!(array.aggregate_elem_offset(&db, 3, 32), 192); + } + + #[test] + fn test_primitive_elem_aggregate_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let i64 = db.mir_intern_type(Type::new(TypeKind::I64, None).into()); + let i128 = db.mir_intern_type(Type::new(TypeKind::I128, None).into()); + + let fields = vec![ + ("".into(), i64), + ("".into(), i64), + ("".into(), i8), + ("".into(), i128), + ("".into(), i8), + ]; + + let struct_def = StructDef { + name: "".into(), + fields, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate = db.mir_intern_type(Type::new(TypeKind::Struct(struct_def), None).into()); + + debug_assert_eq!(aggregate.size_of(&db, 1), 34); + debug_assert_eq!(aggregate.size_of(&db, 32), 49); + + debug_assert_eq!(aggregate.align_of(&db, 1), 1); + debug_assert_eq!(aggregate.align_of(&db, 32), 32); + + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 1), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 32), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 3, 1), 17); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 3, 32), 32); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 4, 1), 33); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 4, 32), 48); + } + + #[test] + fn test_aggregate_elem_aggregate_type_info() { + let db = NewDb::default(); + let i8 = db.mir_intern_type(Type::new(TypeKind::I8, None).into()); + let i64 = db.mir_intern_type(Type::new(TypeKind::I64, None).into()); + let i128 = db.mir_intern_type(Type::new(TypeKind::I128, None).into()); + + let fields_inner = vec![ + ("".into(), i64), + ("".into(), i64), + ("".into(), i8), + ("".into(), i128), + ("".into(), i8), + ]; + + let struct_def_inner = StructDef { + name: "".into(), + fields: fields_inner, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate_inner = + db.mir_intern_type(Type::new(TypeKind::Struct(struct_def_inner), None).into()); + + let fields = vec![("".into(), i8), ("".into(), aggregate_inner)]; + let struct_def = StructDef { + name: "".into(), + fields, + span: Span::dummy(), + module_id: ModuleId::from_raw_internal(0), + }; + let aggregate = db.mir_intern_type(Type::new(TypeKind::Struct(struct_def), None).into()); + + debug_assert_eq!(aggregate.size_of(&db, 1), 35); + debug_assert_eq!(aggregate.size_of(&db, 32), 81); + + debug_assert_eq!(aggregate.align_of(&db, 1), 1); + debug_assert_eq!(aggregate.align_of(&db, 32), 32); + + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 1), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 0, 32), 0); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 1, 1), 1); + debug_assert_eq!(aggregate.aggregate_elem_offset(&db, 1, 32), 32); + } +} diff --git a/crates/mir2/src/graphviz/block.rs b/crates/mir2/src/graphviz/block.rs new file mode 100644 index 0000000000..9121d8e281 --- /dev/null +++ b/crates/mir2/src/graphviz/block.rs @@ -0,0 +1,62 @@ +use std::fmt::Write; + +use dot2::{label, Id}; + +use crate::{ + analysis::ControlFlowGraph, + db::MirDb, + ir::{BasicBlockId, FunctionId}, + pretty_print::PrettyPrint, +}; + +#[derive(Debug, Clone, Copy)] +pub(super) struct BlockNode { + func: FunctionId, + pub block: BasicBlockId, +} + +impl BlockNode { + pub(super) fn new(func: FunctionId, block: BasicBlockId) -> Self { + Self { func, block } + } + pub(super) fn id(self) -> dot2::Result> { + Id::new(format!("fn{}_bb{}", self.func.0, self.block.index())) + } + + pub(super) fn label(self, db: &dyn MirDb) -> label::Text<'static> { + let mut label = r#""#.to_string(); + + // Write block header. + write!( + &mut label, + r#""#, + self.block.index() + ) + .unwrap(); + + // Write block body. + let func_body = self.func.body(db); + write!(label, r#""#).unwrap(); + + write!(label, "
BB{}
"#).unwrap(); + for inst in func_body.order.iter_inst(self.block) { + let mut inst_string = String::new(); + inst.pretty_print(db, &func_body.store, &mut inst_string) + .unwrap(); + write!(label, "{}", dot2::escape_html(&inst_string)).unwrap(); + write!(label, "
").unwrap(); + } + write!(label, r#"
").unwrap(); + + label::Text::HtmlStr(label.into()) + } + + pub(super) fn succs(self, db: &dyn MirDb) -> Vec { + let func_body = self.func.body(db); + let cfg = ControlFlowGraph::compute(&func_body); + cfg.succs(self.block) + .iter() + .map(|block| Self::new(self.func, *block)) + .collect() + } +} diff --git a/crates/mir2/src/graphviz/function.rs b/crates/mir2/src/graphviz/function.rs new file mode 100644 index 0000000000..fa78d21719 --- /dev/null +++ b/crates/mir2/src/graphviz/function.rs @@ -0,0 +1,78 @@ +use std::fmt::Write; + +use dot2::{label, Id}; + +use crate::{analysis::ControlFlowGraph, db::MirDb, ir::FunctionId, pretty_print::PrettyPrint}; + +use super::block::BlockNode; + +#[derive(Debug, Clone, Copy)] +pub(super) struct FunctionNode { + pub(super) func: FunctionId, +} + +impl FunctionNode { + pub(super) fn new(func: FunctionId) -> Self { + Self { func } + } + + pub(super) fn subgraph_id(self) -> Option> { + dot2::Id::new(format!("cluster_{}", self.func.0)).ok() + } + + pub(super) fn label(self, db: &dyn MirDb) -> label::Text<'static> { + let mut label = self.signature(db); + write!(label, r#"

"#).unwrap(); + + // Maps local value id to local name. + let body = self.func.body(db); + for local in body.store.locals() { + local.pretty_print(db, &body.store, &mut label).unwrap(); + write!( + label, + r#" => {}
"#, + body.store.local_name(*local).unwrap() + ) + .unwrap(); + } + + label::Text::HtmlStr(label.into()) + } + + pub(super) fn blocks(self, db: &dyn MirDb) -> Vec { + let body = self.func.body(db); + // We use control flow graph to collect reachable blocks. + let cfg = ControlFlowGraph::compute(&body); + cfg.post_order() + .map(|block| BlockNode::new(self.func, block)) + .collect() + } + + fn signature(self, db: &dyn MirDb) -> String { + let body = self.func.body(db); + + let sig_data = self.func.signature(db); + let mut sig = format!("fn {}(", self.func.debug_name(db)); + + let params = &sig_data.params; + let param_len = params.len(); + for (i, param) in params.iter().enumerate() { + let name = ¶m.name; + let ty = param.ty; + write!(&mut sig, "{name}: ").unwrap(); + ty.pretty_print(db, &body.store, &mut sig).unwrap(); + if param_len - 1 != i { + write!(sig, ", ").unwrap(); + } + } + write!(sig, ")").unwrap(); + + let ret_ty = self.func.return_type(db); + if let Some(ret_ty) = ret_ty { + write!(sig, " -> ").unwrap(); + ret_ty.pretty_print(db, &body.store, &mut sig).unwrap(); + } + + dot2::escape_html(&sig) + } +} diff --git a/crates/mir2/src/graphviz/mod.rs b/crates/mir2/src/graphviz/mod.rs new file mode 100644 index 0000000000..8ab37cd37e --- /dev/null +++ b/crates/mir2/src/graphviz/mod.rs @@ -0,0 +1,22 @@ +use std::io; + +use fe_analyzer::namespace::items::ModuleId; + +use crate::db::MirDb; + +mod block; +mod function; +mod module; + +/// Writes mir graphs of functions in a `module`. +pub fn write_mir_graphs( + db: &dyn MirDb, + module: ModuleId, + w: &mut W, +) -> io::Result<()> { + let module_graph = module::ModuleGraph::new(db, module); + dot2::render(&module_graph, w).map_err(|err| match err { + dot2::Error::Io(err) => err, + _ => panic!("invalid graphviz id"), + }) +} diff --git a/crates/mir2/src/graphviz/module.rs b/crates/mir2/src/graphviz/module.rs new file mode 100644 index 0000000000..4b0c395b25 --- /dev/null +++ b/crates/mir2/src/graphviz/module.rs @@ -0,0 +1,158 @@ +use dot2::{label::Text, GraphWalk, Id, Kind, Labeller}; +use fe_analyzer::namespace::items::ModuleId; + +use crate::{ + db::MirDb, + ir::{inst::BranchInfo, FunctionId}, + pretty_print::PrettyPrint, +}; + +use super::{block::BlockNode, function::FunctionNode}; + +pub(super) struct ModuleGraph<'db> { + db: &'db dyn MirDb, + module: ModuleId, +} + +impl<'db> ModuleGraph<'db> { + pub(super) fn new(db: &'db dyn MirDb, module: ModuleId) -> Self { + Self { db, module } + } +} + +impl<'db> GraphWalk<'db> for ModuleGraph<'db> { + type Node = BlockNode; + type Edge = ModuleGraphEdge; + type Subgraph = FunctionNode; + + fn nodes(&self) -> dot2::Nodes<'db, Self::Node> { + let mut nodes = Vec::new(); + + // Collect function nodes. + for func in self + .db + .mir_lower_module_all_functions(self.module) + .iter() + .map(|id| FunctionNode::new(*id)) + { + nodes.extend(func.blocks(self.db).into_iter()) + } + + nodes.into() + } + + fn edges(&self) -> dot2::Edges<'db, Self::Edge> { + let mut edges = vec![]; + for func in self.db.mir_lower_module_all_functions(self.module).iter() { + for block in FunctionNode::new(*func).blocks(self.db) { + for succ in block.succs(self.db) { + let edge = ModuleGraphEdge { + from: block, + to: succ, + func: *func, + }; + edges.push(edge); + } + } + } + + edges.into() + } + + fn source(&self, edge: &Self::Edge) -> Self::Node { + edge.from + } + + fn target(&self, edge: &Self::Edge) -> Self::Node { + edge.to + } + + fn subgraphs(&self) -> dot2::Subgraphs<'db, Self::Subgraph> { + self.db + .mir_lower_module_all_functions(self.module) + .iter() + .map(|id| FunctionNode::new(*id)) + .collect::>() + .into() + } + + fn subgraph_nodes(&self, s: &Self::Subgraph) -> dot2::Nodes<'db, Self::Node> { + s.blocks(self.db).into_iter().collect::>().into() + } +} + +impl<'db> Labeller<'db> for ModuleGraph<'db> { + type Node = BlockNode; + type Edge = ModuleGraphEdge; + type Subgraph = FunctionNode; + + fn graph_id(&self) -> dot2::Result> { + let module_name = self.module.name(self.db.upcast()); + dot2::Id::new(module_name.to_string()) + } + + fn node_id(&self, n: &Self::Node) -> dot2::Result> { + n.id() + } + + fn node_shape(&self, _n: &Self::Node) -> Option> { + Some(Text::LabelStr("none".into())) + } + + fn node_label(&self, n: &Self::Node) -> dot2::Result> { + Ok(n.label(self.db)) + } + + fn edge_label<'a>(&self, e: &Self::Edge) -> Text<'db> { + Text::LabelStr(e.label(self.db).into()) + } + + fn subgraph_id(&self, s: &Self::Subgraph) -> Option> { + s.subgraph_id() + } + + fn subgraph_label(&self, s: &Self::Subgraph) -> Text<'db> { + s.label(self.db) + } + + fn kind(&self) -> Kind { + Kind::Digraph + } +} + +#[derive(Debug, Clone)] +pub(super) struct ModuleGraphEdge { + from: BlockNode, + to: BlockNode, + func: FunctionId, +} + +impl ModuleGraphEdge { + fn label(&self, db: &dyn MirDb) -> String { + let body = self.func.body(db); + let terminator = body.order.terminator(&body.store, self.from.block).unwrap(); + let to = self.to.block; + match body.store.branch_info(terminator) { + BranchInfo::NotBranch => unreachable!(), + BranchInfo::Jump(_) => String::new(), + BranchInfo::Branch(_, true_bb, _) => { + format! {"{}", true_bb == to} + } + BranchInfo::Switch(_, table, default) => { + if default == Some(to) { + return "*".to_string(); + } + + for (value, bb) in table.iter() { + if bb == to { + let mut s = String::new(); + value.pretty_print(db, &body.store, &mut s).unwrap(); + return s; + } + } + + unreachable!() + } + } + } +} diff --git a/crates/mir2/src/ir/basic_block.rs b/crates/mir2/src/ir/basic_block.rs new file mode 100644 index 0000000000..359c4c76f6 --- /dev/null +++ b/crates/mir2/src/ir/basic_block.rs @@ -0,0 +1,6 @@ +use id_arena::Id; + +pub type BasicBlockId = Id; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BasicBlock {} diff --git a/crates/mir2/src/ir/body_builder.rs b/crates/mir2/src/ir/body_builder.rs new file mode 100644 index 0000000000..3dee893a73 --- /dev/null +++ b/crates/mir2/src/ir/body_builder.rs @@ -0,0 +1,381 @@ +use fe_analyzer::namespace::items::ContractId; +use num_bigint::BigInt; + +use crate::ir::{ + body_cursor::{BodyCursor, CursorLocation}, + inst::{BinOp, Inst, InstKind, UnOp}, + value::{AssignableValue, Local}, + BasicBlock, BasicBlockId, FunctionBody, FunctionId, InstId, SourceInfo, TypeId, +}; + +use super::{ + inst::{CallType, CastKind, SwitchTable, YulIntrinsicOp}, + ConstantId, Value, ValueId, +}; + +#[derive(Debug)] +pub struct BodyBuilder { + pub body: FunctionBody, + loc: CursorLocation, +} + +macro_rules! impl_unary_inst { + ($name:ident, $code:path) => { + pub fn $name(&mut self, value: ValueId, source: SourceInfo) -> InstId { + let inst = Inst::unary($code, value, source); + self.insert_inst(inst) + } + }; +} + +macro_rules! impl_binary_inst { + ($name:ident, $code:path) => { + pub fn $name(&mut self, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> InstId { + let inst = Inst::binary($code, lhs, rhs, source); + self.insert_inst(inst) + } + }; +} + +impl BodyBuilder { + pub fn new(fid: FunctionId, source: SourceInfo) -> Self { + let body = FunctionBody::new(fid, source); + let entry_block = body.order.entry(); + Self { + body, + loc: CursorLocation::BlockTop(entry_block), + } + } + + pub fn build(self) -> FunctionBody { + self.body + } + + pub fn func_id(&self) -> FunctionId { + self.body.fid + } + + pub fn make_block(&mut self) -> BasicBlockId { + let block = BasicBlock {}; + let block_id = self.body.store.store_block(block); + self.body.order.append_block(block_id); + block_id + } + + pub fn make_value(&mut self, value: impl Into) -> ValueId { + self.body.store.store_value(value.into()) + } + + pub fn map_result(&mut self, inst: InstId, result: AssignableValue) { + self.body.store.map_result(inst, result) + } + + pub fn inst_result(&mut self, inst: InstId) -> Option<&AssignableValue> { + self.body.store.inst_result(inst) + } + + pub fn move_to_block(&mut self, block: BasicBlockId) { + self.loc = CursorLocation::BlockBottom(block) + } + + pub fn move_to_block_top(&mut self, block: BasicBlockId) { + self.loc = CursorLocation::BlockTop(block) + } + + pub fn make_unit(&mut self, unit_ty: TypeId) -> ValueId { + self.body.store.store_value(Value::Unit { ty: unit_ty }) + } + + pub fn make_imm(&mut self, imm: BigInt, ty: TypeId) -> ValueId { + self.body.store.store_value(Value::Immediate { imm, ty }) + } + + pub fn make_imm_from_bool(&mut self, imm: bool, ty: TypeId) -> ValueId { + if imm { + self.make_imm(1u8.into(), ty) + } else { + self.make_imm(0u8.into(), ty) + } + } + + pub fn make_constant(&mut self, constant: ConstantId, ty: TypeId) -> ValueId { + self.body + .store + .store_value(Value::Constant { constant, ty }) + } + + pub fn declare(&mut self, local: Local) -> ValueId { + let source = local.source.clone(); + let local_id = self.body.store.store_value(Value::Local(local)); + + let kind = InstKind::Declare { local: local_id }; + let inst = Inst::new(kind, source); + self.insert_inst(inst); + local_id + } + + pub fn store_func_arg(&mut self, local: Local) -> ValueId { + self.body.store.store_value(Value::Local(local)) + } + + impl_unary_inst!(not, UnOp::Not); + impl_unary_inst!(neg, UnOp::Neg); + impl_unary_inst!(inv, UnOp::Inv); + + impl_binary_inst!(add, BinOp::Add); + impl_binary_inst!(sub, BinOp::Sub); + impl_binary_inst!(mul, BinOp::Mul); + impl_binary_inst!(div, BinOp::Div); + impl_binary_inst!(modulo, BinOp::Mod); + impl_binary_inst!(pow, BinOp::Pow); + impl_binary_inst!(shl, BinOp::Shl); + impl_binary_inst!(shr, BinOp::Shr); + impl_binary_inst!(bit_or, BinOp::BitOr); + impl_binary_inst!(bit_xor, BinOp::BitXor); + impl_binary_inst!(bit_and, BinOp::BitAnd); + impl_binary_inst!(logical_and, BinOp::LogicalAnd); + impl_binary_inst!(logical_or, BinOp::LogicalOr); + impl_binary_inst!(eq, BinOp::Eq); + impl_binary_inst!(ne, BinOp::Ne); + impl_binary_inst!(ge, BinOp::Ge); + impl_binary_inst!(gt, BinOp::Gt); + impl_binary_inst!(le, BinOp::Le); + impl_binary_inst!(lt, BinOp::Lt); + + pub fn primitive_cast( + &mut self, + value: ValueId, + result_ty: TypeId, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::Cast { + kind: CastKind::Primitive, + value, + to: result_ty, + }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn untag_cast(&mut self, value: ValueId, result_ty: TypeId, source: SourceInfo) -> InstId { + let kind = InstKind::Cast { + kind: CastKind::Untag, + value, + to: result_ty, + }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn aggregate_construct( + &mut self, + ty: TypeId, + args: Vec, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::AggregateConstruct { ty, args }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn bind(&mut self, src: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::Bind { src }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn mem_copy(&mut self, src: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::MemCopy { src }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn load(&mut self, src: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::Load { src }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn aggregate_access( + &mut self, + value: ValueId, + indices: Vec, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::AggregateAccess { value, indices }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn map_access(&mut self, value: ValueId, key: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::MapAccess { value, key }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn call( + &mut self, + func: FunctionId, + args: Vec, + call_type: CallType, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::Call { + func, + args, + call_type, + }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn keccak256(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::Keccak256 { arg }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn abi_encode(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::AbiEncode { arg }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn create(&mut self, value: ValueId, contract: ContractId, source: SourceInfo) -> InstId { + let kind = InstKind::Create { value, contract }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn create2( + &mut self, + value: ValueId, + salt: ValueId, + contract: ContractId, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::Create2 { + value, + salt, + contract, + }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn yul_intrinsic( + &mut self, + op: YulIntrinsicOp, + args: Vec, + source: SourceInfo, + ) -> InstId { + let inst = Inst::intrinsic(op, args, source); + self.insert_inst(inst) + } + + pub fn jump(&mut self, dest: BasicBlockId, source: SourceInfo) -> InstId { + let kind = InstKind::Jump { dest }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn branch( + &mut self, + cond: ValueId, + then: BasicBlockId, + else_: BasicBlockId, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::Branch { cond, then, else_ }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn switch( + &mut self, + disc: ValueId, + table: SwitchTable, + default: Option, + source: SourceInfo, + ) -> InstId { + let kind = InstKind::Switch { + disc, + table, + default, + }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn revert(&mut self, arg: Option, source: SourceInfo) -> InstId { + let kind = InstKind::Revert { arg }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn emit(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::Emit { arg }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn ret(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + let kind = InstKind::Return { arg: arg.into() }; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn nop(&mut self, source: SourceInfo) -> InstId { + let kind = InstKind::Nop; + let inst = Inst::new(kind, source); + self.insert_inst(inst) + } + + pub fn value_ty(&mut self, value: ValueId) -> TypeId { + self.body.store.value_ty(value) + } + + pub fn value_data(&mut self, value: ValueId) -> &Value { + self.body.store.value_data(value) + } + + /// Returns `true` if current block is terminated. + pub fn is_block_terminated(&mut self, block: BasicBlockId) -> bool { + self.body.order.is_terminated(&self.body.store, block) + } + + pub fn is_current_block_terminated(&mut self) -> bool { + let current_block = self.current_block(); + self.is_block_terminated(current_block) + } + + pub fn current_block(&mut self) -> BasicBlockId { + self.cursor().expect_block() + } + + pub fn remove_inst(&mut self, inst: InstId) { + let mut cursor = BodyCursor::new(&mut self.body, CursorLocation::Inst(inst)); + if self.loc == cursor.loc() { + self.loc = cursor.prev_loc(); + } + cursor.remove_inst(); + } + + pub fn inst_data(&self, inst: InstId) -> &Inst { + self.body.store.inst_data(inst) + } + + fn insert_inst(&mut self, inst: Inst) -> InstId { + let mut cursor = self.cursor(); + let inst_id = cursor.store_and_insert_inst(inst); + + // Set cursor to the new inst. + self.loc = CursorLocation::Inst(inst_id); + + inst_id + } + + fn cursor(&mut self) -> BodyCursor { + BodyCursor::new(&mut self.body, self.loc) + } +} diff --git a/crates/mir2/src/ir/body_cursor.rs b/crates/mir2/src/ir/body_cursor.rs new file mode 100644 index 0000000000..ed4199a345 --- /dev/null +++ b/crates/mir2/src/ir/body_cursor.rs @@ -0,0 +1,231 @@ +//! This module provides a collection of structs to modify function body +//! in-place. +// The design used here is greatly inspired by [`cranelift`](https://crates.io/crates/cranelift) + +use super::{ + value::AssignableValue, BasicBlock, BasicBlockId, FunctionBody, Inst, InstId, ValueId, +}; + +/// Specify a current location of [`BodyCursor`] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CursorLocation { + Inst(InstId), + BlockTop(BasicBlockId), + BlockBottom(BasicBlockId), + NoWhere, +} + +pub struct BodyCursor<'a> { + body: &'a mut FunctionBody, + loc: CursorLocation, +} + +impl<'a> BodyCursor<'a> { + pub fn new(body: &'a mut FunctionBody, loc: CursorLocation) -> Self { + Self { body, loc } + } + + pub fn new_at_entry(body: &'a mut FunctionBody) -> Self { + let entry = body.order.entry(); + Self { + body, + loc: CursorLocation::BlockTop(entry), + } + } + pub fn set_loc(&mut self, loc: CursorLocation) { + self.loc = loc; + } + + pub fn loc(&self) -> CursorLocation { + self.loc + } + + pub fn next_loc(&self) -> CursorLocation { + match self.loc() { + CursorLocation::Inst(inst) => self.body.order.next_inst(inst).map_or_else( + || CursorLocation::BlockBottom(self.body.order.inst_block(inst)), + CursorLocation::Inst, + ), + CursorLocation::BlockTop(block) => self + .body + .order + .first_inst(block) + .map_or_else(|| CursorLocation::BlockBottom(block), CursorLocation::Inst), + CursorLocation::BlockBottom(block) => self + .body() + .order + .next_block(block) + .map_or(CursorLocation::NoWhere, |next_block| { + CursorLocation::BlockTop(next_block) + }), + CursorLocation::NoWhere => CursorLocation::NoWhere, + } + } + + pub fn prev_loc(&self) -> CursorLocation { + match self.loc() { + CursorLocation::Inst(inst) => self.body.order.prev_inst(inst).map_or_else( + || CursorLocation::BlockTop(self.body.order.inst_block(inst)), + CursorLocation::Inst, + ), + CursorLocation::BlockTop(block) => self + .body + .order + .prev_block(block) + .map_or(CursorLocation::NoWhere, |prev_block| { + CursorLocation::BlockBottom(prev_block) + }), + CursorLocation::BlockBottom(block) => self + .body + .order + .last_inst(block) + .map_or_else(|| CursorLocation::BlockTop(block), CursorLocation::Inst), + CursorLocation::NoWhere => CursorLocation::NoWhere, + } + } + + pub fn next_block(&self) -> Option { + let block = self.expect_block(); + self.body.order.next_block(block) + } + + pub fn prev_block(&self) -> Option { + let block = self.expect_block(); + self.body.order.prev_block(block) + } + + pub fn proceed(&mut self) { + self.set_loc(self.next_loc()) + } + + pub fn back(&mut self) { + self.set_loc(self.prev_loc()); + } + + pub fn body(&self) -> &FunctionBody { + self.body + } + + pub fn body_mut(&mut self) -> &mut FunctionBody { + self.body + } + + /// Sets a cursor to an entry block. + pub fn set_to_entry(&mut self) { + let entry_bb = self.body().order.entry(); + let loc = CursorLocation::BlockTop(entry_bb); + self.set_loc(loc); + } + + /// Insert [`InstId`] to a location where a cursor points. + /// If you need to store and insert [`Inst`], use [`store_and_insert_inst`]. + /// + /// # Panics + /// Panics if a cursor points [`CursorLocation::NoWhere`]. + pub fn insert_inst(&mut self, inst: InstId) { + match self.loc() { + CursorLocation::Inst(at) => self.body.order.insert_inst_after(inst, at), + CursorLocation::BlockTop(block) => self.body.order.prepend_inst(inst, block), + CursorLocation::BlockBottom(block) => self.body.order.append_inst(inst, block), + CursorLocation::NoWhere => panic!("cursor loc points to `NoWhere`"), + } + } + + pub fn store_and_insert_inst(&mut self, data: Inst) -> InstId { + let inst = self.body.store.store_inst(data); + self.insert_inst(inst); + inst + } + + /// Remove a current pointed [`Inst`] from a function body. A cursor + /// proceeds to a next inst. + /// + /// # Panics + /// Panics if a cursor doesn't point [`CursorLocation::Inst`]. + pub fn remove_inst(&mut self) { + let inst = self.expect_inst(); + let next_loc = self.next_loc(); + self.body.order.remove_inst(inst); + self.set_loc(next_loc); + } + + /// Remove a current pointed `block` and contained insts from a function + /// body. A cursor proceeds to a next block. + /// + /// # Panics + /// Panics if a cursor doesn't point [`CursorLocation::Inst`]. + pub fn remove_block(&mut self) { + let block = match self.loc() { + CursorLocation::Inst(inst) => self.body.order.inst_block(inst), + CursorLocation::BlockTop(block) | CursorLocation::BlockBottom(block) => block, + CursorLocation::NoWhere => panic!("cursor loc points `NoWhere`"), + }; + + // Store next block of the current block for later use. + let next_block = self.body.order.next_block(block); + + // Remove all insts in the current block. + if let Some(first_inst) = self.body.order.first_inst(block) { + self.set_loc(CursorLocation::Inst(first_inst)); + while matches!(self.loc(), CursorLocation::Inst(..)) { + self.remove_inst(); + } + } + // Remove current block. + self.body.order.remove_block(block); + + // Set cursor location to next block if exists. + if let Some(next_block) = next_block { + self.set_loc(CursorLocation::BlockTop(next_block)) + } else { + self.set_loc(CursorLocation::NoWhere) + } + } + + /// Insert [`BasicBlockId`] to a location where a cursor points. + /// If you need to store and insert [`BasicBlock`], use + /// [`store_and_insert_block`]. + /// + /// # Panics + /// Panics if a cursor points [`CursorLocation::NoWhere`]. + pub fn insert_block(&mut self, block: BasicBlockId) { + let current = self.expect_block(); + self.body.order.insert_block_after_block(block, current) + } + + pub fn store_and_insert_block(&mut self, block: BasicBlock) -> BasicBlockId { + let block_id = self.body.store.store_block(block); + self.insert_block(block_id); + block_id + } + + pub fn map_result(&mut self, result: AssignableValue) -> Option { + let inst = self.expect_inst(); + let result_value = result.value_id(); + self.body.store.map_result(inst, result); + result_value + } + + /// Returns current inst that cursor points. + /// + /// # Panics + /// Panics if a cursor doesn't point [`CursorLocation::Inst`]. + pub fn expect_inst(&self) -> InstId { + match self.loc { + CursorLocation::Inst(inst) => inst, + _ => panic!("Cursor doesn't point any inst."), + } + } + + /// Returns current block that cursor points. + /// + /// # Panics + /// Panics if a cursor points [`CursorLocation::NoWhere`]. + pub fn expect_block(&self) -> BasicBlockId { + match self.loc { + CursorLocation::Inst(inst) => self.body.order.inst_block(inst), + CursorLocation::BlockTop(block) | CursorLocation::BlockBottom(block) => block, + CursorLocation::NoWhere => panic!("cursor loc points `NoWhere`"), + } + } +} diff --git a/crates/mir2/src/ir/body_order.rs b/crates/mir2/src/ir/body_order.rs new file mode 100644 index 0000000000..70df3cf76a --- /dev/null +++ b/crates/mir2/src/ir/body_order.rs @@ -0,0 +1,473 @@ +use fxhash::FxHashMap; + +use super::{basic_block::BasicBlockId, function::BodyDataStore, inst::InstId}; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// Represents basic block order and instruction order. +pub struct BodyOrder { + blocks: FxHashMap, + insts: FxHashMap, + entry_block: BasicBlockId, + last_block: BasicBlockId, +} +impl BodyOrder { + pub fn new(entry_block: BasicBlockId) -> Self { + let entry_block_node = BlockNode::default(); + let mut blocks = FxHashMap::default(); + blocks.insert(entry_block, entry_block_node); + + Self { + blocks, + insts: FxHashMap::default(), + entry_block, + last_block: entry_block, + } + } + + /// Returns an entry block of a function body. + pub fn entry(&self) -> BasicBlockId { + self.entry_block + } + + /// Returns a last block of a function body. + pub fn last_block(&self) -> BasicBlockId { + self.last_block + } + + /// Returns `true` if a block doesn't contain any block. + pub fn is_block_empty(&self, block: BasicBlockId) -> bool { + self.first_inst(block).is_none() + } + + /// Returns `true` if a function body contains a given `block`. + pub fn is_block_inserted(&self, block: BasicBlockId) -> bool { + self.blocks.contains_key(&block) + } + + /// Returns a number of block in a function. + pub fn block_num(&self) -> usize { + self.blocks.len() + } + + /// Returns a previous block of a given block. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn prev_block(&self, block: BasicBlockId) -> Option { + self.blocks[&block].prev + } + + /// Returns a next block of a given block. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn next_block(&self, block: BasicBlockId) -> Option { + self.blocks[&block].next + } + + /// Returns `true` is a given `inst` is inserted. + pub fn is_inst_inserted(&self, inst: InstId) -> bool { + self.insts.contains_key(&inst) + } + + /// Returns first instruction of a block if exists. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn first_inst(&self, block: BasicBlockId) -> Option { + self.blocks[&block].first_inst + } + + /// Returns a terminator instruction of a block. + /// + /// # Panics + /// Panics if + /// 1. `block` is not inserted yet. + pub fn terminator(&self, store: &BodyDataStore, block: BasicBlockId) -> Option { + let last_inst = self.last_inst(block)?; + if store.is_terminator(last_inst) { + Some(last_inst) + } else { + None + } + } + + /// Returns `true` if a `block` is terminated. + pub fn is_terminated(&self, store: &BodyDataStore, block: BasicBlockId) -> bool { + self.terminator(store, block).is_some() + } + + /// Returns a last instruction of a block. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn last_inst(&self, block: BasicBlockId) -> Option { + self.blocks[&block].last_inst + } + + /// Returns a previous instruction of a given `inst`. + /// + /// # Panics + /// Panics if `inst` is not inserted yet. + pub fn prev_inst(&self, inst: InstId) -> Option { + self.insts[&inst].prev + } + + /// Returns a next instruction of a given `inst`. + /// + /// # Panics + /// Panics if `inst` is not inserted yet. + pub fn next_inst(&self, inst: InstId) -> Option { + self.insts[&inst].next + } + + /// Returns a block to which a given `inst` belongs. + /// + /// # Panics + /// Panics if `inst` is not inserted yet. + pub fn inst_block(&self, inst: InstId) -> BasicBlockId { + self.insts[&inst].block + } + + /// Returns an iterator which iterates all basic blocks in a function body + /// in pre-order. + pub fn iter_block(&self) -> impl Iterator + '_ { + BlockIter { + next: Some(self.entry_block), + blocks: &self.blocks, + } + } + + /// Returns an iterator which iterates all instruction in a given `block` in + /// pre-order. + /// + /// # Panics + /// Panics if `block` is not inserted yet. + pub fn iter_inst(&self, block: BasicBlockId) -> impl Iterator + '_ { + InstIter { + next: self.blocks[&block].first_inst, + insts: &self.insts, + } + } + + /// Appends a given `block` to a function body. + /// + /// # Panics + /// Panics if a given `block` is already inserted to a function. + pub fn append_block(&mut self, block: BasicBlockId) { + debug_assert!(!self.is_block_inserted(block)); + + let mut block_node = BlockNode::default(); + let last_block = self.last_block; + let last_block_node = &mut self.block_mut(last_block); + last_block_node.next = Some(block); + block_node.prev = Some(last_block); + + self.blocks.insert(block, block_node); + self.last_block = block; + } + + /// Inserts a given `block` before a `before` block. + /// + /// # Panics + /// Panics if + /// 1. a given `block` is already inserted. + /// 2. a given `before` block is NOTE inserted yet. + pub fn insert_block_before_block(&mut self, block: BasicBlockId, before: BasicBlockId) { + debug_assert!(self.is_block_inserted(before)); + debug_assert!(!self.is_block_inserted(block)); + + let mut block_node = BlockNode::default(); + + match self.blocks[&before].prev { + Some(prev) => { + block_node.prev = Some(prev); + self.block_mut(prev).next = Some(block); + } + None => self.entry_block = block, + } + + block_node.next = Some(before); + self.block_mut(before).prev = Some(block); + self.blocks.insert(block, block_node); + } + + /// Inserts a given `block` after a `after` block. + /// + /// # Panics + /// Panics if + /// 1. a given `block` is already inserted. + /// 2. a given `after` block is NOTE inserted yet. + pub fn insert_block_after_block(&mut self, block: BasicBlockId, after: BasicBlockId) { + debug_assert!(self.is_block_inserted(after)); + debug_assert!(!self.is_block_inserted(block)); + + let mut block_node = BlockNode::default(); + + match self.blocks[&after].next { + Some(next) => { + block_node.next = Some(next); + self.block_mut(next).prev = Some(block); + } + None => self.last_block = block, + } + block_node.prev = Some(after); + self.block_mut(after).next = Some(block); + self.blocks.insert(block, block_node); + } + + /// Remove a given `block` from a function. All instructions in a block are + /// also removed. + /// + /// # Panics + /// Panics if + /// 1. a given `block` is NOT inserted. + /// 2. a `block` is the last one block in a function. + pub fn remove_block(&mut self, block: BasicBlockId) { + debug_assert!(self.is_block_inserted(block)); + debug_assert!(self.block_num() > 1); + + // Remove all insts in a `block`. + let mut next_inst = self.first_inst(block); + while let Some(inst) = next_inst { + next_inst = self.next_inst(inst); + self.remove_inst(inst); + } + + // Remove `block`. + let block_node = &self.blocks[&block]; + let prev_block = block_node.prev; + let next_block = block_node.next; + match (prev_block, next_block) { + // `block` is in the middle of a function. + (Some(prev), Some(next)) => { + self.block_mut(prev).next = Some(next); + self.block_mut(next).prev = Some(prev); + } + // `block` is the last block of a function. + (Some(prev), None) => { + self.block_mut(prev).next = None; + self.last_block = prev; + } + // `block` is the first block of a function. + (None, Some(next)) => { + self.block_mut(next).prev = None; + self.entry_block = next + } + (None, None) => { + unreachable!() + } + } + + self.blocks.remove(&block); + } + + /// Appends `inst` to the end of a `block` + /// + /// # Panics + /// Panics if + /// 1. a given `block` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn append_inst(&mut self, inst: InstId, block: BasicBlockId) { + debug_assert!(self.is_block_inserted(block)); + debug_assert!(!self.is_inst_inserted(inst)); + + let mut inst_node = InstNode::new(block); + + if let Some(last_inst) = self.blocks[&block].last_inst { + inst_node.prev = Some(last_inst); + self.inst_mut(last_inst).next = Some(inst); + } else { + self.block_mut(block).first_inst = Some(inst); + } + + self.block_mut(block).last_inst = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Prepends `inst` to the beginning of a `block` + /// + /// # Panics + /// Panics if + /// 1. a given `block` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn prepend_inst(&mut self, inst: InstId, block: BasicBlockId) { + debug_assert!(self.is_block_inserted(block)); + debug_assert!(!self.is_inst_inserted(inst)); + + let mut inst_node = InstNode::new(block); + + if let Some(first_inst) = self.blocks[&block].first_inst { + inst_node.next = Some(first_inst); + self.inst_mut(first_inst).prev = Some(inst); + } else { + self.block_mut(block).last_inst = Some(inst); + } + + self.block_mut(block).first_inst = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Insert `inst` before `before` inst. + /// + /// # Panics + /// Panics if + /// 1. a given `before` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn insert_inst_before_inst(&mut self, inst: InstId, before: InstId) { + debug_assert!(self.is_inst_inserted(before)); + debug_assert!(!self.is_inst_inserted(inst)); + + let before_inst_node = &self.insts[&before]; + let block = before_inst_node.block; + let mut inst_node = InstNode::new(block); + + match before_inst_node.prev { + Some(prev) => { + inst_node.prev = Some(prev); + self.inst_mut(prev).next = Some(inst); + } + None => self.block_mut(block).first_inst = Some(inst), + } + inst_node.next = Some(before); + self.inst_mut(before).prev = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Insert `inst` after `after` inst. + /// + /// # Panics + /// Panics if + /// 1. a given `after` is NOT inserted. + /// 2. a given `inst` is already inserted. + pub fn insert_inst_after(&mut self, inst: InstId, after: InstId) { + debug_assert!(self.is_inst_inserted(after)); + debug_assert!(!self.is_inst_inserted(inst)); + + let after_inst_node = &self.insts[&after]; + let block = after_inst_node.block; + let mut inst_node = InstNode::new(block); + + match after_inst_node.next { + Some(next) => { + inst_node.next = Some(next); + self.inst_mut(next).prev = Some(inst); + } + None => self.block_mut(block).last_inst = Some(inst), + } + inst_node.prev = Some(after); + self.inst_mut(after).next = Some(inst); + self.insts.insert(inst, inst_node); + } + + /// Remove instruction from the function body. + /// + /// # Panics + /// Panics if a given `inst` is not inserted. + pub fn remove_inst(&mut self, inst: InstId) { + debug_assert!(self.is_inst_inserted(inst)); + + let inst_node = &self.insts[&inst]; + let inst_block = inst_node.block; + let prev_inst = inst_node.prev; + let next_inst = inst_node.next; + match (prev_inst, next_inst) { + (Some(prev), Some(next)) => { + self.inst_mut(prev).next = Some(next); + self.inst_mut(next).prev = Some(prev); + } + (Some(prev), None) => { + self.inst_mut(prev).next = None; + self.block_mut(inst_block).last_inst = Some(prev); + } + (None, Some(next)) => { + self.inst_mut(next).prev = None; + self.block_mut(inst_block).first_inst = Some(next); + } + (None, None) => { + let block_node = self.block_mut(inst_block); + block_node.first_inst = None; + block_node.last_inst = None; + } + } + + self.insts.remove(&inst); + } + + fn block_mut(&mut self, block: BasicBlockId) -> &mut BlockNode { + self.blocks.get_mut(&block).unwrap() + } + + fn inst_mut(&mut self, inst: InstId) -> &mut InstNode { + self.insts.get_mut(&inst).unwrap() + } +} + +struct BlockIter<'a> { + next: Option, + blocks: &'a FxHashMap, +} + +impl<'a> Iterator for BlockIter<'a> { + type Item = BasicBlockId; + + fn next(&mut self) -> Option { + let next = self.next?; + self.next = self.blocks[&next].next; + Some(next) + } +} + +struct InstIter<'a> { + next: Option, + insts: &'a FxHashMap, +} + +impl<'a> Iterator for InstIter<'a> { + type Item = InstId; + + fn next(&mut self) -> Option { + let next = self.next?; + self.next = self.insts[&next].next; + Some(next) + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +/// A helper struct to track a basic block order in a function body. +struct BlockNode { + /// A previous block. + prev: Option, + + /// A next block. + next: Option, + + /// A first instruction of a block. + first_inst: Option, + + /// A last instruction of a block. + last_inst: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +/// A helper struct to track a instruction order in a basic block. +struct InstNode { + /// An block to which a inst belongs. + block: BasicBlockId, + + /// A previous instruction. + prev: Option, + + /// A next instruction. + next: Option, +} + +impl InstNode { + fn new(block: BasicBlockId) -> Self { + Self { + block, + prev: None, + next: None, + } + } +} diff --git a/crates/mir2/src/ir/constant.rs b/crates/mir2/src/ir/constant.rs new file mode 100644 index 0000000000..68466ff3e7 --- /dev/null +++ b/crates/mir2/src/ir/constant.rs @@ -0,0 +1,47 @@ +use fe_common::impl_intern_key; +use num_bigint::BigInt; +use smol_str::SmolStr; + +use fe_analyzer::{context, namespace::items as analyzer_items}; + +use super::{SourceInfo, TypeId}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Constant { + /// A name of a constant. + pub name: SmolStr, + + /// A value of a constant. + pub value: ConstantValue, + + /// A type of a constant. + pub ty: TypeId, + + /// A module where a constant is declared. + pub module_id: analyzer_items::ModuleId, + + /// A span where a constant is declared. + pub source: SourceInfo, +} + +/// An interned Id for [`Constant`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ConstantId(pub(crate) u32); +impl_intern_key!(ConstantId); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ConstantValue { + Immediate(BigInt), + Str(SmolStr), + Bool(bool), +} + +impl From for ConstantValue { + fn from(value: context::Constant) -> Self { + match value { + context::Constant::Int(num) | context::Constant::Address(num) => Self::Immediate(num), + context::Constant::Str(s) => Self::Str(s), + context::Constant::Bool(b) => Self::Bool(b), + } + } +} diff --git a/crates/mir2/src/ir/function.rs b/crates/mir2/src/ir/function.rs new file mode 100644 index 0000000000..c359f20f71 --- /dev/null +++ b/crates/mir2/src/ir/function.rs @@ -0,0 +1,274 @@ +use fe_analyzer::namespace::{items as analyzer_items, types as analyzer_types}; +use fe_common::impl_intern_key; +use fxhash::FxHashMap; +use id_arena::Arena; +use num_bigint::BigInt; +use smol_str::SmolStr; +use std::collections::BTreeMap; + +use super::{ + basic_block::BasicBlock, + body_order::BodyOrder, + inst::{BranchInfo, Inst, InstId, InstKind}, + types::TypeId, + value::{AssignableValue, Local, Value, ValueId}, + BasicBlockId, SourceInfo, +}; + +/// Represents function signature. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FunctionSignature { + pub params: Vec, + pub resolved_generics: BTreeMap, + pub return_type: Option, + pub module_id: analyzer_items::ModuleId, + pub analyzer_func_id: analyzer_items::FunctionId, + pub linkage: Linkage, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FunctionParam { + pub name: SmolStr, + pub ty: TypeId, + pub source: SourceInfo, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct FunctionId(pub u32); +impl_intern_key!(FunctionId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Linkage { + /// A function can only be called within the same module. + Private, + + /// A function can be called from other modules, but can NOT be called from + /// other accounts and transactions. + Public, + + /// A function can be called from other modules, and also can be called from + /// other accounts and transactions. + Export, +} + +impl Linkage { + pub fn is_exported(self) -> bool { + self == Linkage::Export + } +} + +/// A function body, which is not stored in salsa db to enable in-place +/// transformation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionBody { + pub fid: FunctionId, + + pub store: BodyDataStore, + + /// Tracks order of basic blocks and instructions in a function body. + pub order: BodyOrder, + + pub source: SourceInfo, +} + +impl FunctionBody { + pub fn new(fid: FunctionId, source: SourceInfo) -> Self { + let mut store = BodyDataStore::default(); + let entry_bb = store.store_block(BasicBlock {}); + Self { + fid, + store, + order: BodyOrder::new(entry_bb), + source, + } + } +} + +/// A collection of basic block, instructions and values appear in a function +/// body. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct BodyDataStore { + /// Instructions appear in a function body. + insts: Arena, + + /// All values in a function. + values: Arena, + + blocks: Arena, + + /// Maps an immediate to a value to ensure the same immediate results in the + /// same value. + immediates: FxHashMap<(BigInt, TypeId), ValueId>, + + unit_value: Option, + + /// Maps an instruction to a value. + inst_results: FxHashMap, + + /// All declared local variables in a function. + locals: Vec, +} + +impl BodyDataStore { + pub fn store_inst(&mut self, inst: Inst) -> InstId { + self.insts.alloc(inst) + } + + pub fn inst_data(&self, inst: InstId) -> &Inst { + &self.insts[inst] + } + + pub fn inst_data_mut(&mut self, inst: InstId) -> &mut Inst { + &mut self.insts[inst] + } + + pub fn replace_inst(&mut self, inst: InstId, new: Inst) -> Inst { + let old = &mut self.insts[inst]; + std::mem::replace(old, new) + } + + pub fn store_value(&mut self, value: Value) -> ValueId { + match value { + Value::Immediate { imm, ty } => self.store_immediate(imm, ty), + + Value::Unit { .. } => { + if let Some(unit_value) = self.unit_value { + unit_value + } else { + let unit_value = self.values.alloc(value); + self.unit_value = Some(unit_value); + unit_value + } + } + + Value::Local(ref local) => { + let is_user_defined = !local.is_tmp; + let value_id = self.values.alloc(value); + if is_user_defined { + self.locals.push(value_id); + } + value_id + } + + _ => self.values.alloc(value), + } + } + + pub fn is_nop(&self, inst: InstId) -> bool { + matches!(&self.inst_data(inst).kind, InstKind::Nop) + } + + pub fn is_terminator(&self, inst: InstId) -> bool { + self.inst_data(inst).is_terminator() + } + + pub fn branch_info(&self, inst: InstId) -> BranchInfo { + self.inst_data(inst).branch_info() + } + + pub fn value_data(&self, value: ValueId) -> &Value { + &self.values[value] + } + + pub fn value_data_mut(&mut self, value: ValueId) -> &mut Value { + &mut self.values[value] + } + + pub fn values(&self) -> impl Iterator { + self.values.iter().map(|(_, value_data)| value_data) + } + + pub fn values_mut(&mut self) -> impl Iterator { + self.values.iter_mut().map(|(_, value_data)| value_data) + } + + pub fn store_block(&mut self, block: BasicBlock) -> BasicBlockId { + self.blocks.alloc(block) + } + + /// Returns an instruction result + pub fn inst_result(&self, inst: InstId) -> Option<&AssignableValue> { + self.inst_results.get(&inst) + } + + pub fn map_result(&mut self, inst: InstId, result: AssignableValue) { + self.inst_results.insert(inst, result); + } + + pub fn remove_inst_result(&mut self, inst: InstId) -> Option { + self.inst_results.remove(&inst) + } + + pub fn rewrite_branch_dest(&mut self, inst: InstId, from: BasicBlockId, to: BasicBlockId) { + match &mut self.inst_data_mut(inst).kind { + InstKind::Jump { dest } => { + if *dest == from { + *dest = to; + } + } + InstKind::Branch { then, else_, .. } => { + if *then == from { + *then = to; + } + if *else_ == from { + *else_ = to; + } + } + _ => unreachable!("inst is not a branch"), + } + } + + pub fn value_ty(&self, vid: ValueId) -> TypeId { + self.values[vid].ty() + } + + pub fn locals(&self) -> &[ValueId] { + &self.locals + } + + pub fn locals_mut(&mut self) -> &[ValueId] { + &mut self.locals + } + + pub fn func_args(&self) -> impl Iterator + '_ { + self.locals() + .iter() + .filter(|value| match self.value_data(**value) { + Value::Local(local) => local.is_arg, + _ => unreachable!(), + }) + .copied() + } + + pub fn func_args_mut(&mut self) -> impl Iterator { + self.values_mut().filter(|value| match value { + Value::Local(local) => local.is_arg, + _ => false, + }) + } + + /// Returns Some(`local_name`) if value is `Value::Local`. + pub fn local_name(&self, value: ValueId) -> Option<&str> { + match self.value_data(value) { + Value::Local(Local { name, .. }) => Some(name), + _ => None, + } + } + + pub fn replace_value(&mut self, value: ValueId, to: Value) -> Value { + std::mem::replace(&mut self.values[value], to) + } + + fn store_immediate(&mut self, imm: BigInt, ty: TypeId) -> ValueId { + if let Some(value) = self.immediates.get(&(imm.clone(), ty)) { + *value + } else { + let id = self.values.alloc(Value::Immediate { + imm: imm.clone(), + ty, + }); + self.immediates.insert((imm, ty), id); + id + } + } +} diff --git a/crates/mir2/src/ir/inst.rs b/crates/mir2/src/ir/inst.rs new file mode 100644 index 0000000000..4ef76fa906 --- /dev/null +++ b/crates/mir2/src/ir/inst.rs @@ -0,0 +1,764 @@ +use std::fmt; + +use fe_analyzer::namespace::items::ContractId; +use id_arena::Id; + +use super::{basic_block::BasicBlockId, function::FunctionId, value::ValueId, SourceInfo, TypeId}; + +pub type InstId = Id; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Inst { + pub kind: InstKind, + pub source: SourceInfo, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum InstKind { + /// This is not a real instruction, just used to tag a position where a + /// local is declared. + Declare { + local: ValueId, + }, + + /// Unary instruction. + Unary { + op: UnOp, + value: ValueId, + }, + + /// Binary instruction. + Binary { + op: BinOp, + lhs: ValueId, + rhs: ValueId, + }, + + Cast { + kind: CastKind, + value: ValueId, + to: TypeId, + }, + + /// Constructs aggregate value, i.e. struct, tuple and array. + AggregateConstruct { + ty: TypeId, + args: Vec, + }, + + Bind { + src: ValueId, + }, + + MemCopy { + src: ValueId, + }, + + /// Load a primitive value from a ptr + Load { + src: ValueId, + }, + + /// Access to aggregate fields or elements. + /// # Example + /// + /// ```fe + /// struct Foo: + /// x: i32 + /// y: Array + /// ``` + /// `foo.y` is lowered into `AggregateAccess(foo, [1])' for example. + AggregateAccess { + value: ValueId, + indices: Vec, + }, + + MapAccess { + key: ValueId, + value: ValueId, + }, + + Call { + func: FunctionId, + args: Vec, + call_type: CallType, + }, + + /// Unconditional jump instruction. + Jump { + dest: BasicBlockId, + }, + + /// Conditional branching instruction. + Branch { + cond: ValueId, + then: BasicBlockId, + else_: BasicBlockId, + }, + + Switch { + disc: ValueId, + table: SwitchTable, + default: Option, + }, + + Revert { + arg: Option, + }, + + Emit { + arg: ValueId, + }, + + Return { + arg: Option, + }, + + Keccak256 { + arg: ValueId, + }, + + AbiEncode { + arg: ValueId, + }, + + Nop, + + Create { + value: ValueId, + contract: ContractId, + }, + + Create2 { + value: ValueId, + salt: ValueId, + contract: ContractId, + }, + + YulIntrinsic { + op: YulIntrinsicOp, + args: Vec, + }, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] +pub struct SwitchTable { + values: Vec, + blocks: Vec, +} + +impl SwitchTable { + pub fn iter(&self) -> impl Iterator + '_ { + self.values.iter().copied().zip(self.blocks.iter().copied()) + } + + pub fn len(&self) -> usize { + debug_assert!(self.values.len() == self.blocks.len()); + self.values.len() + } + + pub fn is_empty(&self) -> bool { + debug_assert!(self.values.len() == self.blocks.len()); + self.values.is_empty() + } + + pub fn add_arm(&mut self, value: ValueId, block: BasicBlockId) { + self.values.push(value); + self.blocks.push(block); + } +} + +impl Inst { + pub fn new(kind: InstKind, source: SourceInfo) -> Self { + Self { kind, source } + } + + pub fn unary(op: UnOp, value: ValueId, source: SourceInfo) -> Self { + let kind = InstKind::Unary { op, value }; + Self::new(kind, source) + } + + pub fn binary(op: BinOp, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> Self { + let kind = InstKind::Binary { op, lhs, rhs }; + Self::new(kind, source) + } + + pub fn intrinsic(op: YulIntrinsicOp, args: Vec, source: SourceInfo) -> Self { + let kind = InstKind::YulIntrinsic { op, args }; + Self::new(kind, source) + } + + pub fn nop() -> Self { + Self { + kind: InstKind::Nop, + source: SourceInfo::dummy(), + } + } + + pub fn is_terminator(&self) -> bool { + match self.kind { + InstKind::Jump { .. } + | InstKind::Branch { .. } + | InstKind::Switch { .. } + | InstKind::Revert { .. } + | InstKind::Return { .. } => true, + InstKind::YulIntrinsic { op, .. } => op.is_terminator(), + _ => false, + } + } + + pub fn branch_info(&self) -> BranchInfo { + match self.kind { + InstKind::Jump { dest } => BranchInfo::Jump(dest), + InstKind::Branch { cond, then, else_ } => BranchInfo::Branch(cond, then, else_), + InstKind::Switch { + disc, + ref table, + default, + } => BranchInfo::Switch(disc, table, default), + _ => BranchInfo::NotBranch, + } + } + + pub fn args(&self) -> ValueIter { + use InstKind::*; + match &self.kind { + Declare { local: arg } + | Bind { src: arg } + | MemCopy { src: arg } + | Load { src: arg } + | Unary { value: arg, .. } + | Cast { value: arg, .. } + | Emit { arg } + | Keccak256 { arg } + | AbiEncode { arg } + | Create { value: arg, .. } + | Branch { cond: arg, .. } => ValueIter::one(*arg), + + Switch { disc, table, .. } => { + ValueIter::one(*disc).chain(ValueIter::Slice(table.values.iter())) + } + + Binary { lhs, rhs, .. } + | MapAccess { + value: lhs, + key: rhs, + } + | Create2 { + value: lhs, + salt: rhs, + .. + } => ValueIter::one(*lhs).chain(ValueIter::one(*rhs)), + + Revert { arg } | Return { arg } => ValueIter::One(*arg), + + Nop | Jump { .. } => ValueIter::Zero, + + AggregateAccess { value, indices } => { + ValueIter::one(*value).chain(ValueIter::Slice(indices.iter())) + } + + AggregateConstruct { args, .. } | Call { args, .. } | YulIntrinsic { args, .. } => { + ValueIter::Slice(args.iter()) + } + } + } + + pub fn args_mut(&mut self) -> ValueIterMut { + use InstKind::*; + match &mut self.kind { + Declare { local: arg } + | Bind { src: arg } + | MemCopy { src: arg } + | Load { src: arg } + | Unary { value: arg, .. } + | Cast { value: arg, .. } + | Emit { arg } + | Keccak256 { arg } + | AbiEncode { arg } + | Create { value: arg, .. } + | Branch { cond: arg, .. } => ValueIterMut::one(arg), + + Switch { disc, table, .. } => { + ValueIterMut::one(disc).chain(ValueIterMut::Slice(table.values.iter_mut())) + } + + Binary { lhs, rhs, .. } + | MapAccess { + value: lhs, + key: rhs, + } + | Create2 { + value: lhs, + salt: rhs, + .. + } => ValueIterMut::one(lhs).chain(ValueIterMut::one(rhs)), + + Revert { arg } | Return { arg } => ValueIterMut::One(arg.as_mut()), + + Nop | Jump { .. } => ValueIterMut::Zero, + + AggregateAccess { value, indices } => { + ValueIterMut::one(value).chain(ValueIterMut::Slice(indices.iter_mut())) + } + + AggregateConstruct { args, .. } | Call { args, .. } | YulIntrinsic { args, .. } => { + ValueIterMut::Slice(args.iter_mut()) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum UnOp { + /// `not` operator for logical inversion. + Not, + /// `-` operator for negation. + Neg, + /// `~` operator for bitwise inversion. + Inv, +} + +impl fmt::Display for UnOp { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Not => write!(w, "not"), + Self::Neg => write!(w, "-"), + Self::Inv => write!(w, "~"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Mod, + Pow, + Shl, + Shr, + BitOr, + BitXor, + BitAnd, + LogicalAnd, + LogicalOr, + Eq, + Ne, + Ge, + Gt, + Le, + Lt, +} + +impl fmt::Display for BinOp { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Add => write!(w, "+"), + Self::Sub => write!(w, "-"), + Self::Mul => write!(w, "*"), + Self::Div => write!(w, "/"), + Self::Mod => write!(w, "%"), + Self::Pow => write!(w, "**"), + Self::Shl => write!(w, "<<"), + Self::Shr => write!(w, ">>"), + Self::BitOr => write!(w, "|"), + Self::BitXor => write!(w, "^"), + Self::BitAnd => write!(w, "&"), + Self::LogicalAnd => write!(w, "and"), + Self::LogicalOr => write!(w, "or"), + Self::Eq => write!(w, "=="), + Self::Ne => write!(w, "!="), + Self::Ge => write!(w, ">="), + Self::Gt => write!(w, ">"), + Self::Le => write!(w, "<="), + Self::Lt => write!(w, "<"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CallType { + Internal, + External, +} + +impl fmt::Display for CallType { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Internal => write!(w, "internal"), + Self::External => write!(w, "external"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum CastKind { + /// A cast from a primitive type to a primitive type. + Primitive, + + /// A cast from an enum type to its underlying type. + Untag, +} + +// TODO: We don't need all yul intrinsics. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum YulIntrinsicOp { + Stop, + Add, + Sub, + Mul, + Div, + Sdiv, + Mod, + Smod, + Exp, + Not, + Lt, + Gt, + Slt, + Sgt, + Eq, + Iszero, + And, + Or, + Xor, + Byte, + Shl, + Shr, + Sar, + Addmod, + Mulmod, + Signextend, + Keccak256, + Pc, + Pop, + Mload, + Mstore, + Mstore8, + Sload, + Sstore, + Msize, + Gas, + Address, + Balance, + Selfbalance, + Caller, + Callvalue, + Calldataload, + Calldatasize, + Calldatacopy, + Codesize, + Codecopy, + Extcodesize, + Extcodecopy, + Returndatasize, + Returndatacopy, + Extcodehash, + Create, + Create2, + Call, + Callcode, + Delegatecall, + Staticcall, + Return, + Revert, + Selfdestruct, + Invalid, + Log0, + Log1, + Log2, + Log3, + Log4, + Chainid, + Basefee, + Origin, + Gasprice, + Blockhash, + Coinbase, + Timestamp, + Number, + Prevrandao, + Gaslimit, +} +impl YulIntrinsicOp { + pub fn is_terminator(self) -> bool { + matches!( + self, + Self::Return | Self::Revert | Self::Selfdestruct | Self::Invalid + ) + } +} + +impl fmt::Display for YulIntrinsicOp { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + let op = match self { + Self::Stop => "__stop", + Self::Add => "__add", + Self::Sub => "__sub", + Self::Mul => "__mul", + Self::Div => "__div", + Self::Sdiv => "__sdiv", + Self::Mod => "__mod", + Self::Smod => "__smod", + Self::Exp => "__exp", + Self::Not => "__not", + Self::Lt => "__lt", + Self::Gt => "__gt", + Self::Slt => "__slt", + Self::Sgt => "__sgt", + Self::Eq => "__eq", + Self::Iszero => "__iszero", + Self::And => "__and", + Self::Or => "__or", + Self::Xor => "__xor", + Self::Byte => "__byte", + Self::Shl => "__shl", + Self::Shr => "__shr", + Self::Sar => "__sar", + Self::Addmod => "__addmod", + Self::Mulmod => "__mulmod", + Self::Signextend => "__signextend", + Self::Keccak256 => "__keccak256", + Self::Pc => "__pc", + Self::Pop => "__pop", + Self::Mload => "__mload", + Self::Mstore => "__mstore", + Self::Mstore8 => "__mstore8", + Self::Sload => "__sload", + Self::Sstore => "__sstore", + Self::Msize => "__msize", + Self::Gas => "__gas", + Self::Address => "__address", + Self::Balance => "__balance", + Self::Selfbalance => "__selfbalance", + Self::Caller => "__caller", + Self::Callvalue => "__callvalue", + Self::Calldataload => "__calldataload", + Self::Calldatasize => "__calldatasize", + Self::Calldatacopy => "__calldatacopy", + Self::Codesize => "__codesize", + Self::Codecopy => "__codecopy", + Self::Extcodesize => "__extcodesize", + Self::Extcodecopy => "__extcodecopy", + Self::Returndatasize => "__returndatasize", + Self::Returndatacopy => "__returndatacopy", + Self::Extcodehash => "__extcodehash", + Self::Create => "__create", + Self::Create2 => "__create2", + Self::Call => "__call", + Self::Callcode => "__callcode", + Self::Delegatecall => "__delegatecall", + Self::Staticcall => "__staticcall", + Self::Return => "__return", + Self::Revert => "__revert", + Self::Selfdestruct => "__selfdestruct", + Self::Invalid => "__invalid", + Self::Log0 => "__log0", + Self::Log1 => "__log1", + Self::Log2 => "__log2", + Self::Log3 => "__log3", + Self::Log4 => "__log4", + Self::Chainid => "__chainid", + Self::Basefee => "__basefee", + Self::Origin => "__origin", + Self::Gasprice => "__gasprice", + Self::Blockhash => "__blockhash", + Self::Coinbase => "__coinbase", + Self::Timestamp => "__timestamp", + Self::Number => "__number", + Self::Prevrandao => "__prevrandao", + Self::Gaslimit => "__gaslimit", + }; + + write!(w, "{op}") + } +} + +impl From for YulIntrinsicOp { + fn from(val: fe_analyzer::builtins::Intrinsic) -> Self { + use fe_analyzer::builtins::Intrinsic; + match val { + Intrinsic::__stop => Self::Stop, + Intrinsic::__add => Self::Add, + Intrinsic::__sub => Self::Sub, + Intrinsic::__mul => Self::Mul, + Intrinsic::__div => Self::Div, + Intrinsic::__sdiv => Self::Sdiv, + Intrinsic::__mod => Self::Mod, + Intrinsic::__smod => Self::Smod, + Intrinsic::__exp => Self::Exp, + Intrinsic::__not => Self::Not, + Intrinsic::__lt => Self::Lt, + Intrinsic::__gt => Self::Gt, + Intrinsic::__slt => Self::Slt, + Intrinsic::__sgt => Self::Sgt, + Intrinsic::__eq => Self::Eq, + Intrinsic::__iszero => Self::Iszero, + Intrinsic::__and => Self::And, + Intrinsic::__or => Self::Or, + Intrinsic::__xor => Self::Xor, + Intrinsic::__byte => Self::Byte, + Intrinsic::__shl => Self::Shl, + Intrinsic::__shr => Self::Shr, + Intrinsic::__sar => Self::Sar, + Intrinsic::__addmod => Self::Addmod, + Intrinsic::__mulmod => Self::Mulmod, + Intrinsic::__signextend => Self::Signextend, + Intrinsic::__keccak256 => Self::Keccak256, + Intrinsic::__pc => Self::Pc, + Intrinsic::__pop => Self::Pop, + Intrinsic::__mload => Self::Mload, + Intrinsic::__mstore => Self::Mstore, + Intrinsic::__mstore8 => Self::Mstore8, + Intrinsic::__sload => Self::Sload, + Intrinsic::__sstore => Self::Sstore, + Intrinsic::__msize => Self::Msize, + Intrinsic::__gas => Self::Gas, + Intrinsic::__address => Self::Address, + Intrinsic::__balance => Self::Balance, + Intrinsic::__selfbalance => Self::Selfbalance, + Intrinsic::__caller => Self::Caller, + Intrinsic::__callvalue => Self::Callvalue, + Intrinsic::__calldataload => Self::Calldataload, + Intrinsic::__calldatasize => Self::Calldatasize, + Intrinsic::__calldatacopy => Self::Calldatacopy, + Intrinsic::__codesize => Self::Codesize, + Intrinsic::__codecopy => Self::Codecopy, + Intrinsic::__extcodesize => Self::Extcodesize, + Intrinsic::__extcodecopy => Self::Extcodecopy, + Intrinsic::__returndatasize => Self::Returndatasize, + Intrinsic::__returndatacopy => Self::Returndatacopy, + Intrinsic::__extcodehash => Self::Extcodehash, + Intrinsic::__create => Self::Create, + Intrinsic::__create2 => Self::Create2, + Intrinsic::__call => Self::Call, + Intrinsic::__callcode => Self::Callcode, + Intrinsic::__delegatecall => Self::Delegatecall, + Intrinsic::__staticcall => Self::Staticcall, + Intrinsic::__return => Self::Return, + Intrinsic::__revert => Self::Revert, + Intrinsic::__selfdestruct => Self::Selfdestruct, + Intrinsic::__invalid => Self::Invalid, + Intrinsic::__log0 => Self::Log0, + Intrinsic::__log1 => Self::Log1, + Intrinsic::__log2 => Self::Log2, + Intrinsic::__log3 => Self::Log3, + Intrinsic::__log4 => Self::Log4, + Intrinsic::__chainid => Self::Chainid, + Intrinsic::__basefee => Self::Basefee, + Intrinsic::__origin => Self::Origin, + Intrinsic::__gasprice => Self::Gasprice, + Intrinsic::__blockhash => Self::Blockhash, + Intrinsic::__coinbase => Self::Coinbase, + Intrinsic::__timestamp => Self::Timestamp, + Intrinsic::__number => Self::Number, + Intrinsic::__prevrandao => Self::Prevrandao, + Intrinsic::__gaslimit => Self::Gaslimit, + } + } +} + +pub enum BranchInfo<'a> { + NotBranch, + Jump(BasicBlockId), + Branch(ValueId, BasicBlockId, BasicBlockId), + Switch(ValueId, &'a SwitchTable, Option), +} + +impl<'a> BranchInfo<'a> { + pub fn is_not_a_branch(&self) -> bool { + matches!(self, BranchInfo::NotBranch) + } + + pub fn block_iter(&self) -> BlockIter { + match self { + Self::NotBranch => BlockIter::Zero, + Self::Jump(block) => BlockIter::one(*block), + Self::Branch(_, then, else_) => BlockIter::one(*then).chain(BlockIter::one(*else_)), + Self::Switch(_, table, default) => { + BlockIter::Slice(table.blocks.iter()).chain(BlockIter::One(*default)) + } + } + } +} + +pub type BlockIter<'a> = IterBase<'a, BasicBlockId>; +pub type ValueIter<'a> = IterBase<'a, ValueId>; +pub type ValueIterMut<'a> = IterMutBase<'a, ValueId>; + +pub enum IterBase<'a, T> { + Zero, + One(Option), + Slice(std::slice::Iter<'a, T>), + Chain(Box>, Box>), +} + +impl<'a, T> IterBase<'a, T> { + fn one(value: T) -> Self { + Self::One(Some(value)) + } + + fn chain(self, rhs: Self) -> Self { + Self::Chain(self.into(), rhs.into()) + } +} + +impl<'a, T> Iterator for IterBase<'a, T> +where + T: Copy, +{ + type Item = T; + + fn next(&mut self) -> Option { + match self { + Self::Zero => None, + Self::One(value) => value.take(), + Self::Slice(s) => s.next().copied(), + Self::Chain(first, second) => { + if let Some(value) = first.next() { + Some(value) + } else { + second.next() + } + } + } + } +} + +pub enum IterMutBase<'a, T> { + Zero, + One(Option<&'a mut T>), + Slice(std::slice::IterMut<'a, T>), + Chain(Box>, Box>), +} + +impl<'a, T> IterMutBase<'a, T> { + fn one(value: &'a mut T) -> Self { + Self::One(Some(value)) + } + + fn chain(self, rhs: Self) -> Self { + Self::Chain(self.into(), rhs.into()) + } +} + +impl<'a, T> Iterator for IterMutBase<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + match self { + Self::Zero => None, + Self::One(value) => value.take(), + Self::Slice(s) => s.next(), + Self::Chain(first, second) => { + if let Some(value) = first.next() { + Some(value) + } else { + second.next() + } + } + } + } +} diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs new file mode 100644 index 0000000000..0327b0348b --- /dev/null +++ b/crates/mir2/src/ir/mod.rs @@ -0,0 +1,49 @@ +use fe_common::Span; +use fe_parser2::node::{Node, NodeId}; + +pub mod basic_block; +pub mod body_builder; +pub mod body_cursor; +pub mod body_order; +pub mod constant; +pub mod function; +pub mod inst; +pub mod types; +pub mod value; + +pub use basic_block::{BasicBlock, BasicBlockId}; +pub use constant::{Constant, ConstantId}; +pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; +pub use inst::{Inst, InstId}; +pub use types::{Type, TypeId, TypeKind}; +pub use value::{Value, ValueId}; + +/// An original source information that indicates where `mir` entities derive +/// from. `SourceInfo` is mainly used for diagnostics. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SourceInfo { + pub span: Span, + pub id: NodeId, +} + +impl SourceInfo { + pub fn dummy() -> Self { + Self { + span: Span::dummy(), + id: NodeId::dummy(), + } + } + + pub fn is_dummy(&self) -> bool { + self == &Self::dummy() + } +} + +impl From<&Node> for SourceInfo { + fn from(node: &Node) -> Self { + Self { + span: node.span, + id: node.id, + } + } +} diff --git a/crates/mir2/src/ir/types.rs b/crates/mir2/src/ir/types.rs new file mode 100644 index 0000000000..8bdd9995c2 --- /dev/null +++ b/crates/mir2/src/ir/types.rs @@ -0,0 +1,119 @@ +use fe_analyzer::namespace::{items as analyzer_items, types as analyzer_types}; +use fe_common::{impl_intern_key, Span}; +use smol_str::SmolStr; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Type { + pub kind: TypeKind, + pub analyzer_ty: Option, +} + +impl Type { + pub fn new(kind: TypeKind, analyzer_ty: Option) -> Self { + Self { kind, analyzer_ty } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum TypeKind { + I8, + I16, + I32, + I64, + I128, + I256, + U8, + U16, + U32, + U64, + U128, + U256, + Bool, + Address, + Unit, + Array(ArrayDef), + // TODO: we should consider whether we really need `String` type. + String(usize), + Tuple(TupleDef), + Struct(StructDef), + Enum(EnumDef), + Contract(StructDef), + Map(MapDef), + MPtr(TypeId), + SPtr(TypeId), +} + +/// An interned Id for [`ArrayDef`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TypeId(pub u32); +impl_intern_key!(TypeId); + +/// A static array type definition. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ArrayDef { + pub elem_ty: TypeId, + pub len: usize, +} + +/// A tuple type definition. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TupleDef { + pub items: Vec, +} + +/// A user defined struct type definition. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StructDef { + pub name: SmolStr, + pub fields: Vec<(SmolStr, TypeId)>, + pub span: Span, + pub module_id: analyzer_items::ModuleId, +} + +/// A user defined struct type definition. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EnumDef { + pub name: SmolStr, + pub variants: Vec, + pub span: Span, + pub module_id: analyzer_items::ModuleId, +} + +impl EnumDef { + pub fn tag_type(&self) -> TypeKind { + let variant_num = self.variants.len() as u64; + if variant_num <= u8::MAX as u64 { + TypeKind::U8 + } else if variant_num <= u16::MAX as u64 { + TypeKind::U16 + } else if variant_num <= u32::MAX as u64 { + TypeKind::U32 + } else { + TypeKind::U64 + } + } +} + +/// A user defined struct type definition. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EnumVariant { + pub name: SmolStr, + pub span: Span, + pub ty: TypeId, +} + +/// A user defined struct type definition. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EventDef { + pub name: SmolStr, + pub fields: Vec<(SmolStr, TypeId, bool)>, + pub span: Span, + pub module_id: analyzer_items::ModuleId, +} + +/// A map type definition. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MapDef { + pub key_ty: TypeId, + pub value_ty: TypeId, +} diff --git a/crates/mir2/src/ir/value.rs b/crates/mir2/src/ir/value.rs new file mode 100644 index 0000000000..f4aad28b63 --- /dev/null +++ b/crates/mir2/src/ir/value.rs @@ -0,0 +1,142 @@ +use id_arena::Id; +use num_bigint::BigInt; +use smol_str::SmolStr; + +use crate::db::MirDb; + +use super::{ + constant::ConstantId, + function::BodyDataStore, + inst::InstId, + types::{TypeId, TypeKind}, + SourceInfo, +}; + +pub type ValueId = Id; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Value { + /// A value resulted from an instruction. + Temporary { inst: InstId, ty: TypeId }, + + /// A local variable declared in a function body. + Local(Local), + + /// An immediate value. + Immediate { imm: BigInt, ty: TypeId }, + + /// A constant value. + Constant { constant: ConstantId, ty: TypeId }, + + /// A singleton value representing `Unit` type. + Unit { ty: TypeId }, +} + +impl Value { + pub fn ty(&self) -> TypeId { + match self { + Self::Local(val) => val.ty, + Self::Immediate { ty, .. } + | Self::Temporary { ty, .. } + | Self::Unit { ty } + | Self::Constant { ty, .. } => *ty, + } + } + + pub fn is_imm(&self) -> bool { + matches!(self, Self::Immediate { .. }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum AssignableValue { + Value(ValueId), + Aggregate { + lhs: Box, + idx: ValueId, + }, + Map { + lhs: Box, + key: ValueId, + }, +} + +impl From for AssignableValue { + fn from(value: ValueId) -> Self { + Self::Value(value) + } +} + +impl AssignableValue { + pub fn ty(&self, db: &dyn MirDb, store: &BodyDataStore) -> TypeId { + match self { + Self::Value(value) => store.value_ty(*value), + Self::Aggregate { lhs, idx } => { + let lhs_ty = lhs.ty(db, store); + lhs_ty.projection_ty(db, store.value_data(*idx)) + } + Self::Map { lhs, .. } => { + let lhs_ty = lhs.ty(db, store).deref(db); + match lhs_ty.data(db).kind { + TypeKind::Map(def) => def.value_ty.make_sptr(db), + _ => unreachable!(), + } + } + } + } + + pub fn value_id(&self) -> Option { + match self { + Self::Value(value) => Some(*value), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Local { + /// An original name of a local variable. + pub name: SmolStr, + + pub ty: TypeId, + + /// `true` if a local is a function argument. + pub is_arg: bool, + + /// `true` if a local is introduced in MIR. + pub is_tmp: bool, + + pub source: SourceInfo, +} + +impl Local { + pub fn user_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + Self { + name, + ty, + is_arg: false, + is_tmp: false, + source, + } + } + + pub fn arg_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + Self { + name, + ty, + is_arg: true, + is_tmp: false, + source, + } + } + + pub fn tmp_local(name: SmolStr, ty: TypeId) -> Local { + Self { + name, + ty, + is_arg: false, + is_tmp: true, + source: SourceInfo::dummy(), + } + } +} diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs new file mode 100644 index 0000000000..f938352df8 --- /dev/null +++ b/crates/mir2/src/lib.rs @@ -0,0 +1,7 @@ +pub mod analysis; +pub mod db; +pub mod graphviz; +pub mod ir; +pub mod pretty_print; + +mod lower; diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs new file mode 100644 index 0000000000..a1d0212f6c --- /dev/null +++ b/crates/mir2/src/lower/function.rs @@ -0,0 +1,1367 @@ +use std::{collections::BTreeMap, rc::Rc, vec}; + +use fe_analyzer2::{ + builtins::{ContractTypeMethod, GlobalFunction, ValueMethod}, + constants::{EMITTABLE_TRAIT_NAME, EMIT_FN_NAME}, + context::{Adjustment, AdjustmentKind, CallType as AnalyzerCallType, NamedThing}, + namespace::{ + items as analyzer_items, + types::{self as analyzer_types, Type}, + }, +}; +use fe_common::numeric::Literal; +use fe_parser2::{ast, node::Node}; +use fxhash::FxHashMap; +use id_arena::{Arena, Id}; +use num_bigint::BigInt; +use smol_str::SmolStr; + +use crate::{ + db::MirDb, + ir::{ + self, + body_builder::BodyBuilder, + constant::ConstantValue, + function::Linkage, + inst::{CallType, InstKind}, + value::{AssignableValue, Local}, + BasicBlockId, Constant, FunctionBody, FunctionId, FunctionParam, FunctionSignature, InstId, + SourceInfo, TypeId, Value, ValueId, + }, +}; + +type ScopeId = Id; + +pub fn lower_func_signature(db: &dyn MirDb, func: analyzer_items::FunctionId) -> FunctionId { + lower_monomorphized_func_signature(db, func, BTreeMap::new()) +} +pub fn lower_monomorphized_func_signature( + db: &dyn MirDb, + func: analyzer_items::FunctionId, + resolved_generics: BTreeMap, +) -> FunctionId { + // TODO: Remove this when an analyzer's function signature contains `self` type. + let mut params = vec![]; + + if func.takes_self(db.upcast()) { + let self_ty = func.self_type(db.upcast()).unwrap(); + let source = self_arg_source(db, func); + params.push(make_param(db, "self", self_ty, source)); + } + let analyzer_signature = func.signature(db.upcast()); + + for param in analyzer_signature.params.iter() { + let source = arg_source(db, func, ¶m.name); + + let param_type = + if let Type::Generic(generic) = param.typ.clone().unwrap().deref_typ(db.upcast()) { + *resolved_generics.get(&generic.name).unwrap() + } else { + param.typ.clone().unwrap() + }; + + params.push(make_param(db, param.clone().name, param_type, source)) + } + + let return_type = db.mir_lowered_type(analyzer_signature.return_type.clone().unwrap()); + + let linkage = if func.is_public(db.upcast()) { + if func.is_contract_func(db.upcast()) && !func.is_constructor(db.upcast()) { + Linkage::Export + } else { + Linkage::Public + } + } else { + Linkage::Private + }; + + let sig = FunctionSignature { + params, + resolved_generics, + return_type: Some(return_type), + module_id: func.module(db.upcast()), + analyzer_func_id: func, + linkage, + }; + + db.mir_intern_function(sig.into()) +} + +pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { + let analyzer_func = func.analyzer_func(db); + let ast = &analyzer_func.data(db.upcast()).ast; + let analyzer_body = analyzer_func.body(db.upcast()); + + BodyLowerHelper::new(db, func, ast, analyzer_body.as_ref()) + .lower() + .into() +} + +pub(super) struct BodyLowerHelper<'db, 'a> { + pub(super) db: &'db dyn MirDb, + pub(super) builder: BodyBuilder, + ast: &'a Node, + func: FunctionId, + analyzer_body: &'a fe_analyzer::context::FunctionBody, + scopes: Arena, + current_scope: ScopeId, +} + +impl<'db, 'a> BodyLowerHelper<'db, 'a> { + pub(super) fn lower_stmt(&mut self, stmt: &Node) { + match &stmt.kind { + ast::FuncStmt::Return { value } => { + let value = if let Some(expr) = value { + self.lower_expr_to_value(expr) + } else { + self.make_unit() + }; + self.builder.ret(value, stmt.into()); + let next_block = self.builder.make_block(); + self.builder.move_to_block(next_block); + } + + ast::FuncStmt::VarDecl { target, value, .. } => { + self.lower_var_decl(target, value.as_ref(), stmt.into()); + } + + ast::FuncStmt::ConstantDecl { name, value, .. } => { + let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&name.id]); + + let value = self.analyzer_body.expressions[&value.id] + .const_value + .clone() + .unwrap(); + + let constant = + self.make_local_constant(name.kind.clone(), ty, value.into(), stmt.into()); + self.scope_mut().declare_var(&name.kind, constant); + } + + ast::FuncStmt::Assign { target, value } => { + let result = self.lower_assignable_value(target); + let (expr, _ty) = self.lower_expr(value); + self.builder.map_result(expr, result) + } + + ast::FuncStmt::AugAssign { target, op, value } => { + let result = self.lower_assignable_value(target); + let lhs = self.lower_expr_to_value(target); + let rhs = self.lower_expr_to_value(value); + + let inst = self.lower_binop(op.kind, lhs, rhs, stmt.into()); + self.builder.map_result(inst, result) + } + + ast::FuncStmt::For { target, iter, body } => self.lower_for_loop(target, iter, body), + + ast::FuncStmt::While { test, body } => { + let header_bb = self.builder.make_block(); + let exit_bb = self.builder.make_block(); + + let cond = self.lower_expr_to_value(test); + self.builder + .branch(cond, header_bb, exit_bb, SourceInfo::dummy()); + + // Lower while body. + self.builder.move_to_block(header_bb); + self.enter_loop_scope(header_bb, exit_bb); + for stmt in body { + self.lower_stmt(stmt); + } + let cond = self.lower_expr_to_value(test); + self.builder + .branch(cond, header_bb, exit_bb, SourceInfo::dummy()); + + self.leave_scope(); + + // Move to while exit bb. + self.builder.move_to_block(exit_bb); + } + + ast::FuncStmt::If { + test, + body, + or_else, + } => self.lower_if(test, body, or_else), + + ast::FuncStmt::Match { expr, arms } => { + let matrix = &self.analyzer_body.matches[&stmt.id]; + super::pattern_match::lower_match(self, matrix, expr, arms); + } + + ast::FuncStmt::Assert { test, msg } => { + let then_bb = self.builder.make_block(); + let false_bb = self.builder.make_block(); + + let cond = self.lower_expr_to_value(test); + self.builder + .branch(cond, then_bb, false_bb, SourceInfo::dummy()); + + self.builder.move_to_block(false_bb); + + let msg = match msg { + Some(msg) => self.lower_expr_to_value(msg), + None => self.make_u256_imm(1), + }; + self.builder.revert(Some(msg), stmt.into()); + self.builder.move_to_block(then_bb); + } + + ast::FuncStmt::Expr { value } => { + self.lower_expr_to_value(value); + } + + ast::FuncStmt::Break => { + let exit = self.scope().loop_exit(&self.scopes); + self.builder.jump(exit, stmt.into()); + let next_block = self.builder.make_block(); + self.builder.move_to_block(next_block); + } + + ast::FuncStmt::Continue => { + let entry = self.scope().loop_entry(&self.scopes); + if let Some(loop_idx) = self.scope().loop_idx(&self.scopes) { + let imm_one = self.make_u256_imm(1u32); + let inc = self.builder.add(loop_idx, imm_one, SourceInfo::dummy()); + self.builder.map_result(inc, loop_idx.into()); + let maximum_iter_count = self.scope().maximum_iter_count(&self.scopes).unwrap(); + let exit = self.scope().loop_exit(&self.scopes); + self.branch_eq(loop_idx, maximum_iter_count, exit, entry, stmt.into()); + } else { + self.builder.jump(entry, stmt.into()); + } + let next_block = self.builder.make_block(); + self.builder.move_to_block(next_block); + } + + ast::FuncStmt::Revert { error } => { + let error = error.as_ref().map(|err| self.lower_expr_to_value(err)); + self.builder.revert(error, stmt.into()); + let next_block = self.builder.make_block(); + self.builder.move_to_block(next_block); + } + + ast::FuncStmt::Unsafe(stmts) => { + self.enter_scope(); + for stmt in stmts { + self.lower_stmt(stmt) + } + self.leave_scope() + } + } + } + + pub(super) fn lower_var_decl( + &mut self, + var: &Node, + init: Option<&Node>, + source: SourceInfo, + ) { + match &var.kind { + ast::VarDeclTarget::Name(name) => { + let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); + let value = self.declare_var(name, ty, var.into()); + if let Some(init) = init { + let (init, _init_ty) = self.lower_expr(init); + // debug_assert_eq!(ty.deref(self.db), init_ty, "vardecl init type mismatch: {} + // != {}", ty.as_string(self.db), + // init_ty.as_string(self.db)); + self.builder.map_result(init, value.into()); + } + } + + ast::VarDeclTarget::Tuple(decls) => { + if let Some(init) = init { + if let ast::Expr::Tuple { elts } = &init.kind { + debug_assert_eq!(decls.len(), elts.len()); + for (decl, init_elem) in decls.iter().zip(elts.iter()) { + self.lower_var_decl(decl, Some(init_elem), source.clone()); + } + } else { + let init_ty = self.expr_ty(init); + let init_value = self.lower_expr_to_value(init); + self.lower_var_decl_unpack(var, init_value, init_ty, source); + }; + } else { + for decl in decls { + self.lower_var_decl(decl, None, source.clone()) + } + } + } + } + } + + pub(super) fn declare_var( + &mut self, + name: &SmolStr, + ty: TypeId, + source: SourceInfo, + ) -> ValueId { + let local = Local::user_local(name.clone(), ty, source); + let value = self.builder.declare(local); + self.scope_mut().declare_var(name, value); + value + } + + pub(super) fn lower_var_decl_unpack( + &mut self, + var: &Node, + init: ValueId, + init_ty: TypeId, + source: SourceInfo, + ) { + match &var.kind { + ast::VarDeclTarget::Name(name) => { + let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); + let local = Local::user_local(name.clone(), ty, var.into()); + + let lhs = self.builder.declare(local); + self.scope_mut().declare_var(name, lhs); + let bind = self.builder.bind(init, source); + self.builder.map_result(bind, lhs.into()); + } + + ast::VarDeclTarget::Tuple(decls) => { + for (index, decl) in decls.iter().enumerate() { + let elem_ty = init_ty.projection_ty_imm(self.db, index); + let index_value = self.make_u256_imm(index); + let elem_inst = + self.builder + .aggregate_access(init, vec![index_value], source.clone()); + let elem_value = self.map_to_tmp(elem_inst, elem_ty); + self.lower_var_decl_unpack(decl, elem_value, elem_ty, source.clone()) + } + } + } + } + + pub(super) fn lower_expr(&mut self, expr: &Node) -> (InstId, TypeId) { + let mut ty = self.expr_ty(expr); + let mut inst = match &expr.kind { + ast::Expr::Ternary { + if_expr, + test, + else_expr, + } => { + let true_bb = self.builder.make_block(); + let false_bb = self.builder.make_block(); + let merge_bb = self.builder.make_block(); + + let tmp = self + .builder + .declare(Local::tmp_local("$ternary_tmp".into(), ty)); + + let cond = self.lower_expr_to_value(test); + self.builder + .branch(cond, true_bb, false_bb, SourceInfo::dummy()); + + self.builder.move_to_block(true_bb); + let (value, _) = self.lower_expr(if_expr); + self.builder.map_result(value, tmp.into()); + self.builder.jump(merge_bb, SourceInfo::dummy()); + + self.builder.move_to_block(false_bb); + let (value, _) = self.lower_expr(else_expr); + self.builder.map_result(value, tmp.into()); + self.builder.jump(merge_bb, SourceInfo::dummy()); + + self.builder.move_to_block(merge_bb); + self.builder.bind(tmp, SourceInfo::dummy()) + } + + ast::Expr::BoolOperation { left, op, right } => { + self.lower_bool_op(op.kind, left, right, ty) + } + + ast::Expr::BinOperation { left, op, right } => { + let lhs = self.lower_expr_to_value(left); + let rhs = self.lower_expr_to_value(right); + self.lower_binop(op.kind, lhs, rhs, expr.into()) + } + + ast::Expr::UnaryOperation { op, operand } => { + let value = self.lower_expr_to_value(operand); + match op.kind { + ast::UnaryOperator::Invert => self.builder.inv(value, expr.into()), + ast::UnaryOperator::Not => self.builder.not(value, expr.into()), + ast::UnaryOperator::USub => self.builder.neg(value, expr.into()), + } + } + + ast::Expr::CompOperation { left, op, right } => { + let lhs = self.lower_expr_to_value(left); + let rhs = self.lower_expr_to_value(right); + self.lower_comp_op(op.kind, lhs, rhs, expr.into()) + } + + ast::Expr::Attribute { .. } => { + let mut indices = vec![]; + let value = self.lower_aggregate_access(expr, &mut indices); + self.builder.aggregate_access(value, indices, expr.into()) + } + + ast::Expr::Subscript { value, index } => { + let value_ty = self.expr_ty(value).deref(self.db); + if value_ty.is_aggregate(self.db) { + let mut indices = vec![]; + let value = self.lower_aggregate_access(expr, &mut indices); + self.builder.aggregate_access(value, indices, expr.into()) + } else if value_ty.is_map(self.db) { + let value = self.lower_expr_to_value(value); + let key = self.lower_expr_to_value(index); + self.builder.map_access(value, key, expr.into()) + } else { + unreachable!() + } + } + + ast::Expr::Call { + func, + generic_args, + args, + } => { + let ty = self.expr_ty(expr); + self.lower_call(func, generic_args, &args.kind, ty, expr.into()) + } + + ast::Expr::List { elts } | ast::Expr::Tuple { elts } => { + let args = elts + .iter() + .map(|elem| self.lower_expr_to_value(elem)) + .collect(); + let ty = self.expr_ty(expr); + self.builder.aggregate_construct(ty, args, expr.into()) + } + + ast::Expr::Repeat { value, len: _ } => { + let array_type = if let Type::Array(array_type) = self.analyzer_body.expressions + [&expr.id] + .typ + .typ(self.db.upcast()) + { + array_type + } else { + panic!("not an array"); + }; + + let args = vec![self.lower_expr_to_value(value); array_type.size]; + let ty = self.expr_ty(expr); + self.builder.aggregate_construct(ty, args, expr.into()) + } + + ast::Expr::Bool(b) => { + let imm = self.builder.make_imm_from_bool(*b, ty); + self.builder.bind(imm, expr.into()) + } + + ast::Expr::Name(name) => { + let value = self.resolve_name(name); + self.builder.bind(value, expr.into()) + } + + ast::Expr::Path(path) => { + let value = self.resolve_path(path, expr.into()); + self.builder.bind(value, expr.into()) + } + + ast::Expr::Num(num) => { + let imm = Literal::new(num).parse().unwrap(); + let imm = self.builder.make_imm(imm, ty); + self.builder.bind(imm, expr.into()) + } + + ast::Expr::Str(s) => { + let ty = self.expr_ty(expr); + let const_value = self.make_local_constant( + "str_in_func".into(), + ty, + ConstantValue::Str(s.clone()), + expr.into(), + ); + self.builder.bind(const_value, expr.into()) + } + + ast::Expr::Unit => { + let value = self.make_unit(); + self.builder.bind(value, expr.into()) + } + }; + + for Adjustment { into, kind } in &self.analyzer_body.expressions[&expr.id].type_adjustments + { + let into_ty = self.lower_analyzer_type(*into); + + match kind { + AdjustmentKind::Copy => { + let val = self.inst_result_or_tmp(inst, ty); + inst = self.builder.mem_copy(val, expr.into()); + } + AdjustmentKind::Load => { + let val = self.inst_result_or_tmp(inst, ty); + inst = self.builder.load(val, expr.into()); + } + AdjustmentKind::IntSizeIncrease => { + let val = self.inst_result_or_tmp(inst, ty); + inst = self.builder.primitive_cast(val, into_ty, expr.into()) + } + AdjustmentKind::StringSizeIncrease => {} // XXX + } + ty = into_ty; + } + (inst, ty) + } + + fn inst_result_or_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { + self.builder + .inst_result(inst) + .and_then(|r| r.value_id()) + .unwrap_or_else(|| self.map_to_tmp(inst, ty)) + } + + pub(super) fn lower_expr_to_value(&mut self, expr: &Node) -> ValueId { + let (inst, ty) = self.lower_expr(expr); + self.map_to_tmp(inst, ty) + } + + pub(super) fn enter_scope(&mut self) { + let new_scope = Scope::with_parent(self.current_scope); + self.current_scope = self.scopes.alloc(new_scope); + } + + pub(super) fn leave_scope(&mut self) { + self.current_scope = self.scopes[self.current_scope].parent.unwrap(); + } + + pub(super) fn make_imm(&mut self, imm: impl Into, ty: TypeId) -> ValueId { + self.builder.make_value(Value::Immediate { + imm: imm.into(), + ty, + }) + } + + pub(super) fn make_u256_imm(&mut self, value: impl Into) -> ValueId { + let u256_ty = self.u256_ty(); + self.make_imm(value, u256_ty) + } + + pub(super) fn map_to_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { + match &self.builder.inst_data(inst).kind { + InstKind::Bind { src } => { + let value = *src; + self.builder.remove_inst(inst); + value + } + _ => { + let tmp = Value::Temporary { inst, ty }; + let result = self.builder.make_value(tmp); + self.builder.map_result(inst, result.into()); + result + } + } + } + + fn new( + db: &'db dyn MirDb, + func: FunctionId, + ast: &'a Node, + analyzer_body: &'a fe_analyzer::context::FunctionBody, + ) -> Self { + let mut builder = BodyBuilder::new(func, ast.into()); + let mut scopes = Arena::new(); + + // Make a root scope. A root scope collects function parameters and module + // constants. + let root = Scope::root(db, func, &mut builder); + let current_scope = scopes.alloc(root); + Self { + db, + builder, + ast, + func, + analyzer_body, + scopes, + current_scope, + } + } + + fn lower_analyzer_type(&self, analyzer_ty: analyzer_types::TypeId) -> TypeId { + // If the analyzer type is generic we first need to resolve it to its concrete + // type before lowering to a MIR type + if let analyzer_types::Type::Generic(generic) = analyzer_ty.deref_typ(self.db.upcast()) { + let resolved_type = self + .func + .signature(self.db) + .resolved_generics + .get(&generic.name) + .cloned() + .expect("expected generic to be resolved"); + + return self.db.mir_lowered_type(resolved_type); + } + + self.db.mir_lowered_type(analyzer_ty) + } + + fn lower(mut self) -> FunctionBody { + for stmt in &self.ast.kind.body { + self.lower_stmt(stmt) + } + + let last_block = self.builder.current_block(); + if !self.builder.is_block_terminated(last_block) { + let unit = self.make_unit(); + self.builder.ret(unit, SourceInfo::dummy()); + } + + self.builder.build() + } + + fn branch_eq( + &mut self, + v1: ValueId, + v2: ValueId, + true_bb: BasicBlockId, + false_bb: BasicBlockId, + source: SourceInfo, + ) { + let cond = self.builder.eq(v1, v2, source.clone()); + let bool_ty = self.bool_ty(); + let cond = self.map_to_tmp(cond, bool_ty); + self.builder.branch(cond, true_bb, false_bb, source); + } + + fn lower_if( + &mut self, + cond: &Node, + then: &[Node], + else_: &[Node], + ) { + let cond = self.lower_expr_to_value(cond); + + if else_.is_empty() { + let then_bb = self.builder.make_block(); + let merge_bb = self.builder.make_block(); + + self.builder + .branch(cond, then_bb, merge_bb, SourceInfo::dummy()); + + // Lower then block. + self.builder.move_to_block(then_bb); + self.enter_scope(); + for stmt in then { + self.lower_stmt(stmt); + } + self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.move_to_block(merge_bb); + self.leave_scope(); + } else { + let then_bb = self.builder.make_block(); + let else_bb = self.builder.make_block(); + + self.builder + .branch(cond, then_bb, else_bb, SourceInfo::dummy()); + + // Lower then block. + self.builder.move_to_block(then_bb); + self.enter_scope(); + for stmt in then { + self.lower_stmt(stmt); + } + self.leave_scope(); + let then_block_end_bb = self.builder.current_block(); + + // Lower else_block. + self.builder.move_to_block(else_bb); + self.enter_scope(); + for stmt in else_ { + self.lower_stmt(stmt); + } + self.leave_scope(); + let else_block_end_bb = self.builder.current_block(); + + let merge_bb = self.builder.make_block(); + if !self.builder.is_block_terminated(then_block_end_bb) { + self.builder.move_to_block(then_block_end_bb); + self.builder.jump(merge_bb, SourceInfo::dummy()); + } + if !self.builder.is_block_terminated(else_block_end_bb) { + self.builder.move_to_block(else_block_end_bb); + self.builder.jump(merge_bb, SourceInfo::dummy()); + } + self.builder.move_to_block(merge_bb); + } + } + + // NOTE: we assume a type of `iter` is array. + // TODO: Desugar to `loop` + `match` like rustc in HIR to generate better MIR. + fn lower_for_loop( + &mut self, + loop_variable: &Node, + iter: &Node, + body: &[Node], + ) { + let preheader_bb = self.builder.make_block(); + let entry_bb = self.builder.make_block(); + let exit_bb = self.builder.make_block(); + + let iter_elem_ty = self.analyzer_body.var_types[&loop_variable.id]; + let iter_elem_ty = self.lower_analyzer_type(iter_elem_ty); + + self.builder.jump(preheader_bb, SourceInfo::dummy()); + + // `For` has its scope from preheader block. + self.enter_loop_scope(entry_bb, exit_bb); + + /* Lower preheader. */ + self.builder.move_to_block(preheader_bb); + + // Declare loop_variable. + let loop_value = self.builder.declare(Local::user_local( + loop_variable.kind.clone(), + iter_elem_ty, + loop_variable.into(), + )); + self.scope_mut() + .declare_var(&loop_variable.kind, loop_value); + + // Declare and initialize `loop_idx` to 0. + let loop_idx = Local::tmp_local("$loop_idx_tmp".into(), self.u256_ty()); + let loop_idx = self.builder.declare(loop_idx); + let imm_zero = self.make_u256_imm(0u32); + let imm_zero = self.builder.bind(imm_zero, SourceInfo::dummy()); + self.builder.map_result(imm_zero, loop_idx.into()); + + // Evaluates loop variable. + let iter_ty = self.expr_ty(iter); + let iter = self.lower_expr_to_value(iter); + + // Create maximum loop count. + let maximum_iter_count = match &iter_ty.deref(self.db).data(self.db).kind { + ir::TypeKind::Array(ir::types::ArrayDef { len, .. }) => *len, + _ => unreachable!(), + }; + let maximum_iter_count = self.make_u256_imm(maximum_iter_count); + self.branch_eq( + loop_idx, + maximum_iter_count, + exit_bb, + entry_bb, + SourceInfo::dummy(), + ); + self.scope_mut().loop_idx = Some(loop_idx); + self.scope_mut().maximum_iter_count = Some(maximum_iter_count); + + /* Lower body. */ + self.builder.move_to_block(entry_bb); + + // loop_variable = array[loop_idx] + let iter_elem = self + .builder + .aggregate_access(iter, vec![loop_idx], SourceInfo::dummy()); + self.builder + .map_result(iter_elem, AssignableValue::Value(loop_value)); + + for stmt in body { + self.lower_stmt(stmt); + } + + // loop_idx += 1 + let imm_one = self.make_u256_imm(1u32); + let inc = self.builder.add(loop_idx, imm_one, SourceInfo::dummy()); + self.builder + .map_result(inc, AssignableValue::Value(loop_idx)); + self.branch_eq( + loop_idx, + maximum_iter_count, + exit_bb, + entry_bb, + SourceInfo::dummy(), + ); + + /* Move to exit bb */ + self.leave_scope(); + self.builder.move_to_block(exit_bb); + } + + fn lower_assignable_value(&mut self, expr: &Node) -> AssignableValue { + match &expr.kind { + ast::Expr::Attribute { value, attr } => { + let idx = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + let idx = self.make_u256_imm(idx); + let lhs = self.lower_assignable_value(value).into(); + AssignableValue::Aggregate { lhs, idx } + } + ast::Expr::Subscript { value, index } => { + let lhs = self.lower_assignable_value(value).into(); + let attr = self.lower_expr_to_value(index); + let value_ty = self.expr_ty(value).deref(self.db); + if value_ty.is_aggregate(self.db) { + AssignableValue::Aggregate { lhs, idx: attr } + } else if value_ty.is_map(self.db) { + AssignableValue::Map { lhs, key: attr } + } else { + unreachable!() + } + } + ast::Expr::Name(name) => self.resolve_name(name).into(), + ast::Expr::Path(path) => self.resolve_path(path, expr.into()).into(), + _ => self.lower_expr_to_value(expr).into(), + } + } + + /// Returns the pre-adjustment type of the given `Expr` + fn expr_ty(&self, expr: &Node) -> TypeId { + let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ; + self.lower_analyzer_type(analyzer_ty) + } + + fn lower_bool_op( + &mut self, + op: ast::BoolOperator, + lhs: &Node, + rhs: &Node, + ty: TypeId, + ) -> InstId { + let true_bb = self.builder.make_block(); + let false_bb = self.builder.make_block(); + let merge_bb = self.builder.make_block(); + + let lhs = self.lower_expr_to_value(lhs); + let tmp = self + .builder + .declare(Local::tmp_local(format!("${op}_tmp").into(), ty)); + + match op { + ast::BoolOperator::And => { + self.builder + .branch(lhs, true_bb, false_bb, SourceInfo::dummy()); + + self.builder.move_to_block(true_bb); + let (rhs, _rhs_ty) = self.lower_expr(rhs); + self.builder.map_result(rhs, tmp.into()); + self.builder.jump(merge_bb, SourceInfo::dummy()); + + self.builder.move_to_block(false_bb); + let false_imm = self.builder.make_imm_from_bool(false, ty); + let false_imm_copy = self.builder.bind(false_imm, SourceInfo::dummy()); + self.builder.map_result(false_imm_copy, tmp.into()); + self.builder.jump(merge_bb, SourceInfo::dummy()); + } + + ast::BoolOperator::Or => { + self.builder + .branch(lhs, true_bb, false_bb, SourceInfo::dummy()); + + self.builder.move_to_block(true_bb); + let true_imm = self.builder.make_imm_from_bool(true, ty); + let true_imm_copy = self.builder.bind(true_imm, SourceInfo::dummy()); + self.builder.map_result(true_imm_copy, tmp.into()); + self.builder.jump(merge_bb, SourceInfo::dummy()); + + self.builder.move_to_block(false_bb); + let (rhs, _rhs_ty) = self.lower_expr(rhs); + self.builder.map_result(rhs, tmp.into()); + self.builder.jump(merge_bb, SourceInfo::dummy()); + } + } + + self.builder.move_to_block(merge_bb); + self.builder.bind(tmp, SourceInfo::dummy()) + } + + fn lower_binop( + &mut self, + op: ast::BinOperator, + lhs: ValueId, + rhs: ValueId, + source: SourceInfo, + ) -> InstId { + match op { + ast::BinOperator::Add => self.builder.add(lhs, rhs, source), + ast::BinOperator::Sub => self.builder.sub(lhs, rhs, source), + ast::BinOperator::Mult => self.builder.mul(lhs, rhs, source), + ast::BinOperator::Div => self.builder.div(lhs, rhs, source), + ast::BinOperator::Mod => self.builder.modulo(lhs, rhs, source), + ast::BinOperator::Pow => self.builder.pow(lhs, rhs, source), + ast::BinOperator::LShift => self.builder.shl(lhs, rhs, source), + ast::BinOperator::RShift => self.builder.shr(lhs, rhs, source), + ast::BinOperator::BitOr => self.builder.bit_or(lhs, rhs, source), + ast::BinOperator::BitXor => self.builder.bit_xor(lhs, rhs, source), + ast::BinOperator::BitAnd => self.builder.bit_and(lhs, rhs, source), + } + } + + fn lower_comp_op( + &mut self, + op: ast::CompOperator, + lhs: ValueId, + rhs: ValueId, + source: SourceInfo, + ) -> InstId { + match op { + ast::CompOperator::Eq => self.builder.eq(lhs, rhs, source), + ast::CompOperator::NotEq => self.builder.ne(lhs, rhs, source), + ast::CompOperator::Lt => self.builder.lt(lhs, rhs, source), + ast::CompOperator::LtE => self.builder.le(lhs, rhs, source), + ast::CompOperator::Gt => self.builder.gt(lhs, rhs, source), + ast::CompOperator::GtE => self.builder.ge(lhs, rhs, source), + } + } + + fn resolve_generics_args( + &mut self, + method: &analyzer_items::FunctionId, + args: &[Id], + ) -> BTreeMap { + method + .signature(self.db.upcast()) + .params + .iter() + .zip(args.iter().map(|val| { + self.builder + .value_ty(*val) + .analyzer_ty(self.db) + .expect("invalid parameter") + })) + .filter_map(|(param, typ)| { + if let Type::Generic(generic) = + param.typ.clone().unwrap().deref_typ(self.db.upcast()) + { + Some((generic.name, typ)) + } else { + None + } + }) + .collect::>() + } + + fn lower_function_id( + &mut self, + function: &analyzer_items::FunctionId, + args: &[Id], + ) -> FunctionId { + let resolved_generics = self.resolve_generics_args(function, args); + if function.is_generic(self.db.upcast()) { + self.db + .mir_lowered_monomorphized_func_signature(*function, resolved_generics) + } else { + self.db.mir_lowered_func_signature(*function) + } + } + + fn lower_call( + &mut self, + func: &Node, + _generic_args: &Option>>, + args: &[Node], + ty: TypeId, + source: SourceInfo, + ) -> InstId { + let call_type = &self.analyzer_body.calls[&func.id]; + + let mut args: Vec<_> = args + .iter() + .map(|arg| self.lower_expr_to_value(&arg.kind.value)) + .collect(); + + match call_type { + AnalyzerCallType::BuiltinFunction(GlobalFunction::Keccak256) => { + self.builder.keccak256(args[0], source) + } + + AnalyzerCallType::Intrinsic(intrinsic) => { + self.builder + .yul_intrinsic((*intrinsic).into(), args, source) + } + + AnalyzerCallType::BuiltinValueMethod { method, .. } => { + let arg = self.lower_method_receiver(func); + match method { + ValueMethod::ToMem => self.builder.mem_copy(arg, source), + ValueMethod::AbiEncode => self.builder.abi_encode(arg, source), + } + } + + // We ignores `args[0]', which represents `context` and not used for now. + AnalyzerCallType::BuiltinAssociatedFunction { contract, function } => match function { + ContractTypeMethod::Create => self.builder.create(args[1], *contract, source), + ContractTypeMethod::Create2 => { + self.builder.create2(args[1], args[2], *contract, source) + } + }, + + AnalyzerCallType::AssociatedFunction { function, .. } + | AnalyzerCallType::Pure(function) => { + let func_id = self.lower_function_id(function, &args); + self.builder.call(func_id, args, CallType::Internal, source) + } + + AnalyzerCallType::ValueMethod { method, .. } => { + let mut method_args = vec![self.lower_method_receiver(func)]; + let func_id = self.lower_function_id(method, &args); + + method_args.append(&mut args); + + self.builder + .call(func_id, method_args, CallType::Internal, source) + } + AnalyzerCallType::TraitValueMethod { + trait_id, method, .. + } if trait_id.is_std_trait(self.db.upcast(), EMITTABLE_TRAIT_NAME) + && method.name(self.db.upcast()) == EMIT_FN_NAME => + { + let event = self.lower_method_receiver(func); + self.builder.emit(event, source) + } + AnalyzerCallType::TraitValueMethod { + method, + trait_id, + generic_type, + .. + } => { + let mut method_args = vec![self.lower_method_receiver(func)]; + method_args.append(&mut args); + + let concrete_type = self + .func + .signature(self.db) + .resolved_generics + .get(&generic_type.name) + .cloned() + .expect("unresolved generic type"); + + let impl_ = concrete_type + .get_impl_for(self.db.upcast(), *trait_id) + .expect("missing impl"); + + let function = impl_ + .function(self.db.upcast(), &method.name(self.db.upcast())) + .expect("missing function"); + + let func_id = self.db.mir_lowered_func_signature(function); + self.builder + .call(func_id, method_args, CallType::Internal, source) + } + AnalyzerCallType::External { function, .. } => { + let receiver = self.lower_method_receiver(func); + debug_assert!(self.builder.value_ty(receiver).is_address(self.db)); + + let mut method_args = vec![receiver]; + method_args.append(&mut args); + let func_id = self.db.mir_lowered_func_signature(*function); + self.builder + .call(func_id, method_args, CallType::External, source) + } + + AnalyzerCallType::TypeConstructor(to_ty) => { + if to_ty.is_string(self.db.upcast()) { + let arg = *args.last().unwrap(); + self.builder.mem_copy(arg, source) + } else if ty.is_primitive(self.db) { + // TODO: Ignore `ctx` for now. + let arg = *args.last().unwrap(); + let arg_ty = self.builder.value_ty(arg); + if arg_ty == ty { + self.builder.bind(arg, source) + } else { + debug_assert!(!arg_ty.is_ptr(self.db)); // Should be explicitly `Load`ed + self.builder.primitive_cast(arg, ty, source) + } + } else if ty.is_aggregate(self.db) { + self.builder.aggregate_construct(ty, args, source) + } else { + unreachable!() + } + } + + AnalyzerCallType::EnumConstructor(variant) => { + let tag_type = ty.enum_disc_type(self.db); + let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); + let data_ty = ty.enum_variant_type(self.db, *variant); + let enum_args = if data_ty.is_unit(self.db) { + vec![tag, self.make_unit()] + } else { + std::iter::once(tag).chain(args).collect() + }; + self.builder.aggregate_construct(ty, enum_args, source) + } + } + } + + // FIXME: This is ugly hack to properly analyze method call. Remove this when https://github.com/ethereum/fe/issues/670 is resolved. + fn lower_method_receiver(&mut self, receiver: &Node) -> ValueId { + match &receiver.kind { + ast::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), + _ => unreachable!(), + } + } + + fn lower_aggregate_access( + &mut self, + expr: &Node, + indices: &mut Vec, + ) -> ValueId { + match &expr.kind { + ast::Expr::Attribute { value, attr } => { + let index = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + let value = self.lower_aggregate_access(value, indices); + indices.push(self.make_u256_imm(index)); + value + } + + ast::Expr::Subscript { value, index } + if self.expr_ty(value).deref(self.db).is_aggregate(self.db) => + { + let value = self.lower_aggregate_access(value, indices); + indices.push(self.lower_expr_to_value(index)); + value + } + + _ => self.lower_expr_to_value(expr), + } + } + + fn make_unit(&mut self) -> ValueId { + let unit_ty = analyzer_types::TypeId::unit(self.db.upcast()); + let unit_ty = self.db.mir_lowered_type(unit_ty); + self.builder.make_unit(unit_ty) + } + + fn make_local_constant( + &mut self, + name: SmolStr, + ty: TypeId, + value: ConstantValue, + source: SourceInfo, + ) -> ValueId { + let function_id = self.builder.func_id(); + let constant = Constant { + name, + value, + ty, + module_id: function_id.module(self.db), + source, + }; + + let constant_id = self.db.mir_intern_const(constant.into()); + self.builder.make_constant(constant_id, ty) + } + + fn u256_ty(&mut self) -> TypeId { + self.db + .mir_intern_type(ir::Type::new(ir::TypeKind::U256, None).into()) + } + + fn bool_ty(&mut self) -> TypeId { + self.db + .mir_intern_type(ir::Type::new(ir::TypeKind::Bool, None).into()) + } + + fn enter_loop_scope(&mut self, entry: BasicBlockId, exit: BasicBlockId) { + let new_scope = Scope::loop_scope(self.current_scope, entry, exit); + self.current_scope = self.scopes.alloc(new_scope); + } + + /// Resolve a name appeared in an expression. + /// NOTE: Don't call this to resolve method receiver. + fn resolve_name(&mut self, name: &str) -> ValueId { + if let Some(value) = self.scopes[self.current_scope].resolve_name(&self.scopes, name) { + // Name is defined in local. + value + } else { + // Name is defined in global. + let func_id = self.builder.func_id(); + let module = func_id.module(self.db); + let constant = match module + .resolve_name(self.db.upcast(), name) + .unwrap() + .unwrap() + { + NamedThing::Item(analyzer_items::Item::Constant(id)) => { + self.db.mir_lowered_constant(id) + } + _ => panic!("name defined in global must be constant"), + }; + let ty = constant.ty(self.db); + self.builder.make_constant(constant, ty) + } + } + + /// Resolve a path appeared in an expression. + /// NOTE: Don't call this to resolve method receiver. + fn resolve_path(&mut self, path: &ast::Path, source: SourceInfo) -> ValueId { + let func_id = self.builder.func_id(); + let module = func_id.module(self.db); + match module.resolve_path(self.db.upcast(), path).value.unwrap() { + NamedThing::Item(analyzer_items::Item::Constant(id)) => { + let constant = self.db.mir_lowered_constant(id); + let ty = constant.ty(self.db); + self.builder.make_constant(constant, ty) + } + NamedThing::EnumVariant(variant) => { + let enum_ty = self + .db + .mir_lowered_type(variant.parent(self.db.upcast()).as_type(self.db.upcast())); + let tag_type = enum_ty.enum_disc_type(self.db); + let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); + let data = self.make_unit(); + let enum_args = vec![tag, data]; + let inst = self.builder.aggregate_construct(enum_ty, enum_args, source); + self.map_to_tmp(inst, enum_ty) + } + _ => panic!("path defined in global must be constant"), + } + } + + fn scope(&self) -> &Scope { + &self.scopes[self.current_scope] + } + + fn scope_mut(&mut self) -> &mut Scope { + &mut self.scopes[self.current_scope] + } +} + +#[derive(Debug)] +struct Scope { + parent: Option, + loop_entry: Option, + loop_exit: Option, + variables: FxHashMap, + // TODO: Remove the below two fields when `for` loop desugaring is implemented. + loop_idx: Option, + maximum_iter_count: Option, +} + +impl Scope { + fn root(db: &dyn MirDb, func: FunctionId, builder: &mut BodyBuilder) -> Self { + let mut root = Self { + parent: None, + loop_entry: None, + loop_exit: None, + variables: FxHashMap::default(), + loop_idx: None, + maximum_iter_count: None, + }; + + // Declare function parameters. + for param in &func.signature(db).params { + let local = Local::arg_local(param.name.clone(), param.ty, param.source.clone()); + let value_id = builder.store_func_arg(local); + root.declare_var(¶m.name, value_id) + } + + root + } + + fn with_parent(parent: ScopeId) -> Self { + Self { + parent: parent.into(), + loop_entry: None, + loop_exit: None, + variables: FxHashMap::default(), + loop_idx: None, + maximum_iter_count: None, + } + } + + fn loop_scope(parent: ScopeId, loop_entry: BasicBlockId, loop_exit: BasicBlockId) -> Self { + Self { + parent: parent.into(), + loop_entry: loop_entry.into(), + loop_exit: loop_exit.into(), + variables: FxHashMap::default(), + loop_idx: None, + maximum_iter_count: None, + } + } + + fn loop_entry(&self, scopes: &Arena) -> BasicBlockId { + match self.loop_entry { + Some(entry) => entry, + None => scopes[self.parent.unwrap()].loop_entry(scopes), + } + } + + fn loop_exit(&self, scopes: &Arena) -> BasicBlockId { + match self.loop_exit { + Some(exit) => exit, + None => scopes[self.parent.unwrap()].loop_exit(scopes), + } + } + + fn loop_idx(&self, scopes: &Arena) -> Option { + match self.loop_idx { + Some(idx) => Some(idx), + None => scopes[self.parent?].loop_idx(scopes), + } + } + + fn maximum_iter_count(&self, scopes: &Arena) -> Option { + match self.maximum_iter_count { + Some(count) => Some(count), + None => scopes[self.parent?].maximum_iter_count(scopes), + } + } + + fn declare_var(&mut self, name: &SmolStr, value: ValueId) { + debug_assert!(!self.variables.contains_key(name)); + + self.variables.insert(name.clone(), value); + } + + fn resolve_name(&self, scopes: &Arena, name: &str) -> Option { + match self.variables.get(name) { + Some(id) => Some(*id), + None => scopes[self.parent?].resolve_name(scopes, name), + } + } +} + +fn self_arg_source(db: &dyn MirDb, func: analyzer_items::FunctionId) -> SourceInfo { + func.data(db.upcast()) + .ast + .kind + .sig + .kind + .args + .iter() + .find(|arg| matches!(arg.kind, ast::FunctionArg::Self_ { .. })) + .unwrap() + .into() +} + +fn arg_source(db: &dyn MirDb, func: analyzer_items::FunctionId, arg_name: &str) -> SourceInfo { + func.data(db.upcast()) + .ast + .kind + .sig + .kind + .args + .iter() + .find_map(|arg| match &arg.kind { + ast::FunctionArg::Regular { name, .. } => { + if name.kind == arg_name { + Some(name.into()) + } else { + None + } + } + ast::FunctionArg::Self_ { .. } => None, + }) + .unwrap() +} + +fn make_param( + db: &dyn MirDb, + name: impl Into, + ty: analyzer_types::TypeId, + source: SourceInfo, +) -> FunctionParam { + FunctionParam { + name: name.into(), + ty: db.mir_lowered_type(ty), + source, + } +} diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs new file mode 100644 index 0000000000..36e43653a6 --- /dev/null +++ b/crates/mir2/src/lower/mod.rs @@ -0,0 +1,4 @@ +pub mod function; +pub mod types; + +mod pattern_match; diff --git a/crates/mir2/src/lower/pattern_match/decision_tree.rs b/crates/mir2/src/lower/pattern_match/decision_tree.rs new file mode 100644 index 0000000000..852dfb921a --- /dev/null +++ b/crates/mir2/src/lower/pattern_match/decision_tree.rs @@ -0,0 +1,576 @@ +//! This module contains the decision tree definition and its construction +//! function. +//! The algorithm for efficient decision tree construction is mainly based on [Compiling pattern matching to good decision trees](https://dl.acm.org/doi/10.1145/1411304.1411311). +use std::io; + +use fe_analyzer::{ + pattern_analysis::{ + ConstructorKind, PatternMatrix, PatternRowVec, SigmaSet, SimplifiedPattern, + SimplifiedPatternKind, + }, + AnalyzerDb, +}; +use indexmap::IndexMap; +use smol_str::SmolStr; + +use super::tree_vis::TreeRenderer; + +pub fn build_decision_tree( + db: &dyn AnalyzerDb, + pattern_matrix: &PatternMatrix, + policy: ColumnSelectionPolicy, +) -> DecisionTree { + let builder = DecisionTreeBuilder::new(policy); + let simplified_arms = SimplifiedArmMatrix::new(pattern_matrix); + + builder.build(db, simplified_arms) +} + +#[derive(Debug)] +pub enum DecisionTree { + Leaf(LeafNode), + Switch(SwitchNode), +} + +impl DecisionTree { + #[allow(unused)] + pub fn dump_dot(&self, db: &dyn AnalyzerDb, w: &mut W) -> io::Result<()> + where + W: io::Write, + { + let renderer = TreeRenderer::new(db, self); + dot2::render(&renderer, w).map_err(|err| match err { + dot2::Error::Io(err) => err, + _ => panic!("invalid graphviz id"), + }) + } +} + +#[derive(Debug)] +pub struct LeafNode { + pub arm_idx: usize, + pub binds: IndexMap<(SmolStr, usize), Occurrence>, +} + +impl LeafNode { + fn new(arm: SimplifiedArm, occurrences: &[Occurrence]) -> Self { + let arm_idx = arm.body; + let binds = arm.finalize_binds(occurrences); + Self { arm_idx, binds } + } +} + +#[derive(Debug)] +pub struct SwitchNode { + pub occurrence: Occurrence, + pub arms: Vec<(Case, DecisionTree)>, +} + +#[derive(Debug, Clone, Copy)] +pub enum Case { + Ctor(ConstructorKind), + Default, +} + +#[derive(Debug, Clone, Default)] +pub struct ColumnSelectionPolicy(Vec); + +impl ColumnSelectionPolicy { + /// The score of column i is the sum of the negation of the arities of + /// constructors in sigma(i). + pub fn arity(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::Arity) + } + + /// The score is the negation of the cardinal of sigma(i), C(Sigma(i)). + /// If sigma(i) is NOT complete, the resulting score is C(Sigma(i)) - 1. + pub fn small_branching(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::SmallBranching) + } + + /// The score is the number of needed rows of column i in the necessity + /// matrix. + #[allow(unused)] + pub fn needed_column(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::NeededColumn) + } + + /// The score is the larger row index j such that column i is needed for all + /// rows j′; 1 ≤ j′ ≤ j. + pub fn needed_prefix(&mut self) -> &mut Self { + self.add_heuristic(ColumnScoringFunction::NeededPrefix) + } + + fn select_column(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> usize { + let mut candidates: Vec<_> = (0..mat.ncols()).collect(); + + for scoring_fn in &self.0 { + let mut max_score = i32::MIN; + for col in std::mem::take(&mut candidates) { + let score = scoring_fn.score(db, mat, col); + match score.cmp(&max_score) { + std::cmp::Ordering::Less => {} + std::cmp::Ordering::Equal => { + candidates.push(col); + } + std::cmp::Ordering::Greater => { + candidates = vec![col]; + max_score = score; + } + } + } + + if candidates.len() == 1 { + return candidates.pop().unwrap(); + } + } + + // If there are more than one candidates remained, filter the columns with the + // shortest occurrences among the candidates, then select the rightmost one. + // This heuristics corresponds to the R pseudo heuristic in the paper. + let mut shortest_occurrences = usize::MAX; + for col in std::mem::take(&mut candidates) { + let occurrences = mat.occurrences[col].len(); + match occurrences.cmp(&shortest_occurrences) { + std::cmp::Ordering::Less => { + candidates = vec![col]; + shortest_occurrences = occurrences; + } + std::cmp::Ordering::Equal => { + candidates.push(col); + } + std::cmp::Ordering::Greater => {} + } + } + + candidates.pop().unwrap() + } + + fn add_heuristic(&mut self, heuristic: ColumnScoringFunction) -> &mut Self { + debug_assert!(!self.0.contains(&heuristic)); + self.0.push(heuristic); + self + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Occurrence(Vec); + +impl Occurrence { + pub fn new() -> Self { + Self(vec![]) + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn parent(&self) -> Option { + let mut inner = self.0.clone(); + inner.pop().map(|_| Occurrence(inner)) + } + + pub fn last_index(&self) -> Option { + self.0.last().cloned() + } + + fn phi_specialize(&self, db: &dyn AnalyzerDb, ctor: ConstructorKind) -> Vec { + let arity = ctor.arity(db); + (0..arity) + .map(|i| { + let mut inner = self.0.clone(); + inner.push(i); + Self(inner) + }) + .collect() + } + + fn len(&self) -> usize { + self.0.len() + } +} + +struct DecisionTreeBuilder { + policy: ColumnSelectionPolicy, +} + +impl DecisionTreeBuilder { + fn new(policy: ColumnSelectionPolicy) -> Self { + DecisionTreeBuilder { policy } + } + + fn build(&self, db: &dyn AnalyzerDb, mut mat: SimplifiedArmMatrix) -> DecisionTree { + debug_assert!(mat.nrows() > 0, "unexhausted pattern matrix"); + + if mat.is_first_arm_satisfied() { + mat.arms.truncate(1); + return DecisionTree::Leaf(LeafNode::new(mat.arms.pop().unwrap(), &mat.occurrences)); + } + + let col = self.policy.select_column(db, &mat); + mat.swap(col); + + let mut switch_arms = vec![]; + let occurrence = &mat.occurrences[0]; + let sigma_set = mat.sigma_set(0); + for &ctor in sigma_set.iter() { + let destructured_mat = mat.phi_specialize(db, ctor, occurrence); + let subtree = self.build(db, destructured_mat); + switch_arms.push((Case::Ctor(ctor), subtree)); + } + + if !sigma_set.is_complete(db) { + let destructured_mat = mat.d_specialize(db, occurrence); + let subtree = self.build(db, destructured_mat); + switch_arms.push((Case::Default, subtree)); + } + + DecisionTree::Switch(SwitchNode { + occurrence: occurrence.clone(), + arms: switch_arms, + }) + } +} + +#[derive(Clone, Debug)] +struct SimplifiedArmMatrix { + arms: Vec, + occurrences: Vec, +} + +impl SimplifiedArmMatrix { + fn new(mat: &PatternMatrix) -> Self { + let cols = mat.ncols(); + let arms: Vec<_> = mat + .rows() + .iter() + .enumerate() + .map(|(body, pat)| SimplifiedArm::new(pat, body)) + .collect(); + let occurrences = vec![Occurrence::new(); cols]; + + SimplifiedArmMatrix { arms, occurrences } + } + + fn nrows(&self) -> usize { + self.arms.len() + } + + fn ncols(&self) -> usize { + self.arms[0].pat_vec.len() + } + + fn pat(&self, row: usize, col: usize) -> &SimplifiedPattern { + self.arms[row].pat(col) + } + + fn necessity_matrix(&self, db: &dyn AnalyzerDb) -> NecessityMatrix { + NecessityMatrix::from_mat(db, self) + } + + fn reduced_pat_mat(&self, col: usize) -> PatternMatrix { + let mut rows = Vec::with_capacity(self.nrows()); + for arm in self.arms.iter() { + let reduced_pat_vec = arm + .pat_vec + .pats() + .iter() + .enumerate() + .filter(|(i, _)| (*i != col)) + .map(|(_, pat)| pat.clone()) + .collect(); + rows.push(PatternRowVec::new(reduced_pat_vec)); + } + + PatternMatrix::new(rows) + } + + /// Returns the constructor set in the column i. + fn sigma_set(&self, col: usize) -> SigmaSet { + SigmaSet::from_rows(self.arms.iter().map(|arm| &arm.pat_vec), col) + } + + fn is_first_arm_satisfied(&self) -> bool { + self.arms[0] + .pat_vec + .pats() + .iter() + .all(SimplifiedPattern::is_wildcard) + } + + fn phi_specialize( + &self, + db: &dyn AnalyzerDb, + ctor: ConstructorKind, + occurrence: &Occurrence, + ) -> Self { + let mut new_arms = Vec::new(); + for arm in &self.arms { + new_arms.extend_from_slice(&arm.phi_specialize(db, ctor, occurrence)); + } + + let mut new_occurrences = self.occurrences[0].phi_specialize(db, ctor); + new_occurrences.extend_from_slice(&self.occurrences.as_slice()[1..]); + + Self { + arms: new_arms, + occurrences: new_occurrences, + } + } + + fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Self { + let mut new_arms = Vec::new(); + for arm in &self.arms { + new_arms.extend_from_slice(&arm.d_specialize(db, occurrence)); + } + + Self { + arms: new_arms, + occurrences: self.occurrences.as_slice()[1..].to_vec(), + } + } + + fn swap(&mut self, i: usize) { + for arm in &mut self.arms { + arm.swap(0, i) + } + self.occurrences.swap(0, i); + } +} + +#[derive(Clone, Debug)] +struct SimplifiedArm { + pat_vec: PatternRowVec, + body: usize, + binds: IndexMap<(SmolStr, usize), Occurrence>, +} + +impl SimplifiedArm { + fn new(pat: &PatternRowVec, body: usize) -> Self { + let pat = PatternRowVec::new(pat.inner.iter().map(generalize_pattern).collect()); + Self { + pat_vec: pat, + body, + binds: IndexMap::new(), + } + } + + fn len(&self) -> usize { + self.pat_vec.len() + } + + fn pat(&self, col: usize) -> &SimplifiedPattern { + &self.pat_vec.inner[col] + } + + fn phi_specialize( + &self, + db: &dyn AnalyzerDb, + ctor: ConstructorKind, + occurrence: &Occurrence, + ) -> Vec { + let body = self.body; + let binds = self.new_binds(occurrence); + + self.pat_vec + .phi_specialize(db, ctor) + .into_iter() + .map(|pat| SimplifiedArm { + pat_vec: pat, + body, + binds: binds.clone(), + }) + .collect() + } + + fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Vec { + let body = self.body; + let binds = self.new_binds(occurrence); + + self.pat_vec + .d_specialize(db) + .into_iter() + .map(|pat| SimplifiedArm { + pat_vec: pat, + body, + binds: binds.clone(), + }) + .collect() + } + + fn new_binds(&self, occurrence: &Occurrence) -> IndexMap<(SmolStr, usize), Occurrence> { + let mut binds = self.binds.clone(); + if let Some(SimplifiedPatternKind::WildCard(Some(bind))) = + self.pat_vec.head().map(|pat| &pat.kind) + { + binds.entry(bind.clone()).or_insert(occurrence.clone()); + } + binds + } + + fn finalize_binds(self, occurrences: &[Occurrence]) -> IndexMap<(SmolStr, usize), Occurrence> { + debug_assert!(self.len() == occurrences.len()); + + let mut binds = self.binds; + for (pat, occurrence) in self.pat_vec.pats().iter().zip(occurrences.iter()) { + debug_assert!(pat.is_wildcard()); + + if let SimplifiedPatternKind::WildCard(Some(bind)) = &pat.kind { + binds.entry(bind.clone()).or_insert(occurrence.clone()); + } + } + + binds + } + + fn swap(&mut self, i: usize, j: usize) { + self.pat_vec.swap(i, j); + } +} + +struct NecessityMatrix { + data: Vec, + ncol: usize, + nrow: usize, +} + +impl NecessityMatrix { + fn new(ncol: usize, nrow: usize) -> Self { + let data = vec![false; ncol * nrow]; + Self { data, ncol, nrow } + } + + fn from_mat(db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> Self { + let nrow = mat.nrows(); + let ncol = mat.ncols(); + let mut necessity_mat = Self::new(ncol, nrow); + + necessity_mat.compute(db, mat); + necessity_mat + } + + fn compute(&mut self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) { + for row in 0..self.nrow { + for col in 0..self.ncol { + let pat = mat.pat(row, col); + let pos = self.pos(row, col); + + if !pat.is_wildcard() { + self.data[pos] = true; + } else { + let reduced_pat_mat = mat.reduced_pat_mat(col); + self.data[pos] = !reduced_pat_mat.is_row_useful(db, row); + } + } + } + } + + fn compute_needed_column_score(&self, col: usize) -> i32 { + let mut num = 0; + for i in 0..self.nrow { + if self.data[self.pos(i, col)] { + num += 1; + } + } + + num + } + + fn compute_needed_prefix_score(&self, col: usize) -> i32 { + let mut current_row = 0; + for i in 0..self.nrow { + if self.data[self.pos(i, col)] { + current_row += 1; + } else { + return current_row; + } + } + + current_row + } + + fn pos(&self, row: usize, col: usize) -> usize { + self.ncol * row + col + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ColumnScoringFunction { + /// The score of column i is the sum of the negation of the arities of + /// constructors in sigma(i). + Arity, + + /// The score is the negation of the cardinal of sigma(i), C(Sigma(i)). + /// If sigma(i) is NOT complete, the resulting score is C(Sigma(i)) - 1. + SmallBranching, + + /// The score is the number of needed rows of column i in the necessity + /// matrix. + NeededColumn, + + NeededPrefix, +} + +impl ColumnScoringFunction { + fn score(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix, col: usize) -> i32 { + match self { + ColumnScoringFunction::Arity => mat + .sigma_set(col) + .iter() + .map(|c| -(c.arity(db) as i32)) + .sum(), + + ColumnScoringFunction::SmallBranching => { + let sigma_set = mat.sigma_set(col); + let score = -(mat.sigma_set(col).len() as i32); + if sigma_set.is_complete(db) { + score + } else { + score - 1 + } + } + + ColumnScoringFunction::NeededColumn => { + mat.necessity_matrix(db).compute_needed_column_score(col) + } + + ColumnScoringFunction::NeededPrefix => { + mat.necessity_matrix(db).compute_needed_prefix_score(col) + } + } + } +} + +fn generalize_pattern(pat: &SimplifiedPattern) -> SimplifiedPattern { + match &pat.kind { + SimplifiedPatternKind::WildCard(_) => pat.clone(), + + SimplifiedPatternKind::Constructor { kind, fields } => { + let fields = fields.iter().map(generalize_pattern).collect(); + let kind = SimplifiedPatternKind::Constructor { + kind: *kind, + fields, + }; + SimplifiedPattern::new(kind, pat.ty) + } + + SimplifiedPatternKind::Or(pats) => { + let mut gen_pats = vec![]; + for pat in pats { + let gen_pad = generalize_pattern(pat); + if gen_pad.is_wildcard() { + gen_pats.push(gen_pad); + break; + } else { + gen_pats.push(gen_pad); + } + } + + if gen_pats.len() == 1 { + gen_pats.pop().unwrap() + } else { + SimplifiedPattern::new(SimplifiedPatternKind::Or(gen_pats), pat.ty) + } + } + } +} diff --git a/crates/mir2/src/lower/pattern_match/mod.rs b/crates/mir2/src/lower/pattern_match/mod.rs new file mode 100644 index 0000000000..2172dd7e94 --- /dev/null +++ b/crates/mir2/src/lower/pattern_match/mod.rs @@ -0,0 +1,326 @@ +use fe_analyzer::pattern_analysis::{ConstructorKind, PatternMatrix}; +use fe_parser2::{ + ast::{Expr, LiteralPattern, MatchArm}, + node::Node, +}; +use fxhash::FxHashMap; +use id_arena::{Arena, Id}; +use smol_str::SmolStr; + +use crate::ir::{ + body_builder::BodyBuilder, inst::SwitchTable, BasicBlockId, SourceInfo, TypeId, ValueId, +}; + +use self::decision_tree::{ + Case, ColumnSelectionPolicy, DecisionTree, LeafNode, Occurrence, SwitchNode, +}; + +use super::function::BodyLowerHelper; + +pub mod decision_tree; +mod tree_vis; + +pub(super) fn lower_match<'b>( + helper: &'b mut BodyLowerHelper<'_, '_>, + mat: &PatternMatrix, + scrutinee: &Node, + arms: &'b [Node], +) { + let mut policy = ColumnSelectionPolicy::default(); + // PBA heuristics described in the paper. + policy.needed_prefix().small_branching().arity(); + + let scrutinee = helper.lower_expr_to_value(scrutinee); + let decision_tree = decision_tree::build_decision_tree(helper.db.upcast(), mat, policy); + + DecisionTreeLowerHelper::new(helper, scrutinee, arms).lower(decision_tree); +} + +struct DecisionTreeLowerHelper<'db, 'a, 'b> { + helper: &'b mut BodyLowerHelper<'db, 'a>, + scopes: Arena, + current_scope: ScopeId, + root_block: BasicBlockId, + declared_vars: FxHashMap<(SmolStr, usize), ValueId>, + arms: &'b [Node], + lowered_arms: FxHashMap, + match_exit: BasicBlockId, +} + +impl<'db, 'a, 'b> DecisionTreeLowerHelper<'db, 'a, 'b> { + fn new( + helper: &'b mut BodyLowerHelper<'db, 'a>, + scrutinee: ValueId, + arms: &'b [Node], + ) -> Self { + let match_exit = helper.builder.make_block(); + + let mut scope = Scope::default(); + scope.register_occurrence(Occurrence::new(), scrutinee); + let mut scopes = Arena::new(); + let current_scope = scopes.alloc(scope); + + let root_block = helper.builder.current_block(); + + DecisionTreeLowerHelper { + helper, + scopes, + current_scope, + root_block, + declared_vars: FxHashMap::default(), + arms, + lowered_arms: FxHashMap::default(), + match_exit, + } + } + + fn lower(&mut self, tree: DecisionTree) { + self.lower_tree(tree); + + let match_exit = self.match_exit; + self.builder().move_to_block(match_exit); + } + + fn lower_tree(&mut self, tree: DecisionTree) { + match tree { + DecisionTree::Leaf(leaf) => self.lower_leaf(leaf), + DecisionTree::Switch(switch) => self.lower_switch(switch), + } + } + + fn lower_leaf(&mut self, leaf: LeafNode) { + for (var, occurrence) in leaf.binds { + let occurrence_value = self.resolve_occurrence(&occurrence); + let ty = self.builder().value_ty(occurrence_value); + let var_value = self.declare_or_use_var(&var, ty); + + let inst = self.builder().bind(occurrence_value, SourceInfo::dummy()); + self.builder().map_result(inst, var_value.into()); + } + + let arm_body = self.lower_arm_body(leaf.arm_idx); + self.builder().jump(arm_body, SourceInfo::dummy()); + } + + fn lower_switch(&mut self, mut switch: SwitchNode) { + let current_bb = self.builder().current_block(); + let occurrence_value = self.resolve_occurrence(&switch.occurrence); + + if switch.arms.len() == 1 { + let arm = switch.arms.pop().unwrap(); + let arm_bb = self.enter_arm(&switch.occurrence, &arm.0); + self.lower_tree(arm.1); + self.builder().move_to_block(current_bb); + self.builder().jump(arm_bb, SourceInfo::dummy()); + return; + } + + let mut table = SwitchTable::default(); + let mut default_arm = None; + let occurrence_ty = self.builder().value_ty(occurrence_value); + + for (case, tree) in switch.arms { + let arm_bb = self.enter_arm(&switch.occurrence, &case); + self.lower_tree(tree); + self.leave_arm(); + + if let Some(disc) = self.case_to_disc(&case, occurrence_ty) { + table.add_arm(disc, arm_bb); + } else { + debug_assert!(default_arm.is_none()); + default_arm = Some(arm_bb); + } + } + + self.builder().move_to_block(current_bb); + let disc = self.extract_disc(occurrence_value); + self.builder() + .switch(disc, table, default_arm, SourceInfo::dummy()); + } + + fn lower_arm_body(&mut self, index: usize) -> BasicBlockId { + if let Some(block) = self.lowered_arms.get(&index) { + *block + } else { + let current_bb = self.builder().current_block(); + let body_bb = self.builder().make_block(); + + self.builder().move_to_block(body_bb); + for stmt in &self.arms[index].kind.body { + self.helper.lower_stmt(stmt); + } + + if !self.builder().is_current_block_terminated() { + let match_exit = self.match_exit; + self.builder().jump(match_exit, SourceInfo::dummy()); + } + + self.lowered_arms.insert(index, body_bb); + self.builder().move_to_block(current_bb); + body_bb + } + } + + fn enter_arm(&mut self, occurrence: &Occurrence, case: &Case) -> BasicBlockId { + self.helper.enter_scope(); + + let bb = self.builder().make_block(); + self.builder().move_to_block(bb); + + let scope = Scope::with_parent(self.current_scope); + self.current_scope = self.scopes.alloc(scope); + + self.update_occurrence(occurrence, case); + bb + } + + fn leave_arm(&mut self) { + self.current_scope = self.scopes[self.current_scope].parent.unwrap(); + self.helper.leave_scope(); + } + + fn case_to_disc(&mut self, case: &Case, occurrence_ty: TypeId) -> Option { + match case { + Case::Ctor(ConstructorKind::Enum(variant)) => { + let disc_ty = occurrence_ty.enum_disc_type(self.helper.db); + let disc = variant.disc(self.helper.db.upcast()); + Some(self.helper.make_imm(disc, disc_ty)) + } + + Case::Ctor(ConstructorKind::Literal((LiteralPattern::Bool(b), ty))) => { + let ty = self.helper.db.mir_lowered_type(*ty); + Some(self.builder().make_imm_from_bool(*b, ty)) + } + + Case::Ctor(ConstructorKind::Tuple(_)) + | Case::Ctor(ConstructorKind::Struct(_)) + | Case::Default => None, + } + } + + fn update_occurrence(&mut self, occurrence: &Occurrence, case: &Case) { + let old_value = self.resolve_occurrence(occurrence); + let old_ty = self.builder().value_ty(old_value); + + match case { + Case::Ctor(ConstructorKind::Enum(variant)) => { + let new_ty = old_ty.enum_variant_type(self.helper.db, *variant); + let cast = self + .builder() + .untag_cast(old_value, new_ty, SourceInfo::dummy()); + let value = self.helper.map_to_tmp(cast, new_ty); + self.current_scope_mut() + .register_occurrence(occurrence.clone(), value) + } + + Case::Ctor(ConstructorKind::Literal((LiteralPattern::Bool(b), _))) => { + let value = self.builder().make_imm_from_bool(*b, old_ty); + self.current_scope_mut() + .register_occurrence(occurrence.clone(), value) + } + + Case::Ctor(ConstructorKind::Tuple(_)) + | Case::Ctor(ConstructorKind::Struct(_)) + | Case::Default => {} + } + } + + fn extract_disc(&mut self, value: ValueId) -> ValueId { + let value_ty = self.builder().value_ty(value); + match value_ty { + _ if value_ty.deref(self.helper.db).is_enum(self.helper.db) => { + let disc_ty = value_ty.enum_disc_type(self.helper.db); + let disc_index = self.helper.make_u256_imm(0); + let inst = + self.builder() + .aggregate_access(value, vec![disc_index], SourceInfo::dummy()); + self.helper.map_to_tmp(inst, disc_ty) + } + + _ => value, + } + } + + fn declare_or_use_var(&mut self, var: &(SmolStr, usize), ty: TypeId) -> ValueId { + if let Some(value) = self.declared_vars.get(var) { + *value + } else { + let current_block = self.builder().current_block(); + let root_block = self.root_block; + self.builder().move_to_block_top(root_block); + let value = self.helper.declare_var(&var.0, ty, SourceInfo::dummy()); + self.builder().move_to_block(current_block); + self.declared_vars.insert(var.clone(), value); + value + } + } + + fn builder(&mut self) -> &mut BodyBuilder { + &mut self.helper.builder + } + + fn resolve_occurrence(&mut self, occurrence: &Occurrence) -> ValueId { + if let Some(value) = self + .current_scope() + .resolve_occurrence(&self.scopes, occurrence) + { + return value; + } + + let parent = occurrence.parent().unwrap(); + let parent_value = self.resolve_occurrence(&parent); + let parent_value_ty = self.builder().value_ty(parent_value); + + let index = occurrence.last_index().unwrap(); + let index_value = self.helper.make_u256_imm(occurrence.last_index().unwrap()); + let inst = + self.builder() + .aggregate_access(parent_value, vec![index_value], SourceInfo::dummy()); + + let ty = parent_value_ty.projection_ty_imm(self.helper.db, index); + let value = self.helper.map_to_tmp(inst, ty); + self.current_scope_mut() + .register_occurrence(occurrence.clone(), value); + value + } + + fn current_scope(&self) -> &Scope { + self.scopes.get(self.current_scope).unwrap() + } + + fn current_scope_mut(&mut self) -> &mut Scope { + self.scopes.get_mut(self.current_scope).unwrap() + } +} + +type ScopeId = Id; + +#[derive(Debug, Default)] +struct Scope { + parent: Option, + occurrences: FxHashMap, +} + +impl Scope { + pub fn with_parent(parent: ScopeId) -> Self { + Self { + parent: Some(parent), + ..Default::default() + } + } + + pub fn register_occurrence(&mut self, occurrence: Occurrence, value: ValueId) { + self.occurrences.insert(occurrence, value); + } + + pub fn resolve_occurrence( + &self, + arena: &Arena, + occurrence: &Occurrence, + ) -> Option { + match self.occurrences.get(occurrence) { + Some(value) => Some(*value), + None => arena[self.parent?].resolve_occurrence(arena, occurrence), + } + } +} diff --git a/crates/mir2/src/lower/pattern_match/tree_vis.rs b/crates/mir2/src/lower/pattern_match/tree_vis.rs new file mode 100644 index 0000000000..9681ecb790 --- /dev/null +++ b/crates/mir2/src/lower/pattern_match/tree_vis.rs @@ -0,0 +1,150 @@ +use std::fmt::Write; + +use dot2::{label::Text, Id}; +use fe_analyzer::{pattern_analysis::ConstructorKind, AnalyzerDb}; +use fxhash::FxHashMap; +use indexmap::IndexMap; +use smol_str::SmolStr; + +use super::decision_tree::{Case, DecisionTree, LeafNode, Occurrence, SwitchNode}; + +pub(super) struct TreeRenderer<'db> { + nodes: Vec, + edges: FxHashMap<(usize, usize), Case>, + db: &'db dyn AnalyzerDb, +} + +impl<'db> TreeRenderer<'db> { + #[allow(unused)] + pub(super) fn new(db: &'db dyn AnalyzerDb, tree: &DecisionTree) -> Self { + let mut renderer = Self { + nodes: Vec::new(), + edges: FxHashMap::default(), + db, + }; + + match tree { + DecisionTree::Leaf(leaf) => { + renderer.nodes.push(Node::from(leaf)); + } + DecisionTree::Switch(switch) => { + renderer.nodes.push(Node::from(switch)); + let node_id = renderer.nodes.len() - 1; + for arm in &switch.arms { + renderer.switch_from(&arm.1, node_id, arm.0); + } + } + } + renderer + } + + fn switch_from(&mut self, tree: &DecisionTree, node_id: usize, case: Case) { + match tree { + DecisionTree::Leaf(leaf) => { + self.nodes.push(Node::from(leaf)); + self.edges.insert((node_id, self.nodes.len() - 1), case); + } + + DecisionTree::Switch(switch) => { + self.nodes.push(Node::from(switch)); + let switch_id = self.nodes.len() - 1; + self.edges.insert((node_id, switch_id), case); + for arm in &switch.arms { + self.switch_from(&arm.1, switch_id, arm.0); + } + } + } + } +} + +impl<'db> dot2::Labeller<'db> for TreeRenderer<'db> { + type Node = usize; + type Edge = (Self::Node, Self::Node); + type Subgraph = (); + + fn graph_id(&self) -> dot2::Result> { + dot2::Id::new("DecisionTree") + } + + fn node_id(&self, n: &Self::Node) -> dot2::Result> { + dot2::Id::new(format!("N{}", *n)) + } + + fn node_label(&self, n: &Self::Node) -> dot2::Result> { + let node = &self.nodes[*n]; + let label = match node { + Node::Leaf { arm_idx, .. } => { + format!("arm_idx: {arm_idx}") + } + Node::Switch(occurrence) => { + let mut s = "expr".to_string(); + for num in occurrence.iter() { + write!(&mut s, ".{num}").unwrap(); + } + s + } + }; + + Ok(Text::LabelStr(label.into())) + } + + fn edge_label(&self, e: &Self::Edge) -> Text<'db> { + let label = match &self.edges[e] { + Case::Ctor(ConstructorKind::Enum(variant)) => { + variant.name_with_parent(self.db).to_string() + } + Case::Ctor(ConstructorKind::Tuple(_)) => "()".to_string(), + Case::Ctor(ConstructorKind::Struct(sid)) => sid.name(self.db).into(), + Case::Ctor(ConstructorKind::Literal((lit, _))) => lit.to_string(), + Case::Default => "_".into(), + }; + + Text::LabelStr(label.into()) + } +} + +impl<'db> dot2::GraphWalk<'db> for TreeRenderer<'db> { + type Node = usize; + type Edge = (Self::Node, Self::Node); + type Subgraph = (); + + fn nodes(&self) -> dot2::Nodes<'db, Self::Node> { + (0..self.nodes.len()).collect() + } + + fn edges(&self) -> dot2::Edges<'db, Self::Edge> { + self.edges.keys().cloned().collect::>().into() + } + + fn source(&self, e: &Self::Edge) -> Self::Node { + e.0 + } + + fn target(&self, e: &Self::Edge) -> Self::Node { + e.1 + } +} + +enum Node { + Leaf { + arm_idx: usize, + #[allow(unused)] + binds: IndexMap<(SmolStr, usize), Occurrence>, + }, + Switch(Occurrence), +} + +impl From<&LeafNode> for Node { + fn from(node: &LeafNode) -> Self { + Node::Leaf { + arm_idx: node.arm_idx, + binds: node.binds.clone(), + } + } +} + +impl From<&SwitchNode> for Node { + fn from(node: &SwitchNode) -> Self { + Node::Switch(node.occurrence.clone()) + } +} diff --git a/crates/mir2/src/lower/types.rs b/crates/mir2/src/lower/types.rs new file mode 100644 index 0000000000..7072eaa96b --- /dev/null +++ b/crates/mir2/src/lower/types.rs @@ -0,0 +1,194 @@ +use crate::{ + db::MirDb, + ir::{ + types::{ArrayDef, EnumDef, EnumVariant, MapDef, StructDef, TupleDef}, + Type, TypeId, TypeKind, + }, +}; + +use fe_analyzer::namespace::{ + items as analyzer_items, + types::{self as analyzer_types, TraitOrType}, +}; + +pub fn lower_type(db: &dyn MirDb, analyzer_ty: analyzer_types::TypeId) -> TypeId { + let ty_kind = match analyzer_ty.typ(db.upcast()) { + analyzer_types::Type::SPtr(inner) => TypeKind::SPtr(lower_type(db, inner)), + + // NOTE: this results in unexpected MIR TypeId inequalities + // (when different analyzer types map to the same MIR type). + // We could (should?) remove .analyzer_ty from Type. + analyzer_types::Type::Mut(inner) => match inner.typ(db.upcast()) { + analyzer_types::Type::SPtr(t) => TypeKind::SPtr(lower_type(db, t)), + analyzer_types::Type::Base(t) => lower_base(t), + analyzer_types::Type::Contract(_) => TypeKind::Address, + _ => TypeKind::MPtr(lower_type(db, inner)), + }, + analyzer_types::Type::SelfType(inner) => match inner { + TraitOrType::TypeId(id) => return lower_type(db, id), + TraitOrType::TraitId(_) => panic!("traits aren't lowered"), + }, + analyzer_types::Type::Base(base) => lower_base(base), + analyzer_types::Type::Array(arr) => lower_array(db, &arr), + analyzer_types::Type::Map(map) => lower_map(db, &map), + analyzer_types::Type::Tuple(tup) => lower_tuple(db, &tup), + analyzer_types::Type::String(string) => TypeKind::String(string.max_size), + analyzer_types::Type::Contract(_) => TypeKind::Address, + analyzer_types::Type::SelfContract(contract) => lower_contract(db, contract), + analyzer_types::Type::Struct(struct_) => lower_struct(db, struct_), + analyzer_types::Type::Enum(enum_) => lower_enum(db, enum_), + analyzer_types::Type::Generic(_) => { + panic!("should be lowered in `lower_analyzer_type`") + } + }; + + intern_type(db, ty_kind, Some(analyzer_ty.deref(db.upcast()))) +} + +fn lower_base(base: analyzer_types::Base) -> TypeKind { + use analyzer_types::{Base, Integer}; + + match base { + Base::Numeric(int_ty) => match int_ty { + Integer::I8 => TypeKind::I8, + Integer::I16 => TypeKind::I16, + Integer::I32 => TypeKind::I32, + Integer::I64 => TypeKind::I64, + Integer::I128 => TypeKind::I128, + Integer::I256 => TypeKind::I256, + Integer::U8 => TypeKind::U8, + Integer::U16 => TypeKind::U16, + Integer::U32 => TypeKind::U32, + Integer::U64 => TypeKind::U64, + Integer::U128 => TypeKind::U128, + Integer::U256 => TypeKind::U256, + }, + + Base::Bool => TypeKind::Bool, + Base::Address => TypeKind::Address, + Base::Unit => TypeKind::Unit, + } +} + +fn lower_array(db: &dyn MirDb, arr: &analyzer_types::Array) -> TypeKind { + let len = arr.size; + let elem_ty = db.mir_lowered_type(arr.inner); + + let def = ArrayDef { elem_ty, len }; + TypeKind::Array(def) +} + +fn lower_map(db: &dyn MirDb, map: &analyzer_types::Map) -> TypeKind { + let key_ty = db.mir_lowered_type(map.key); + let value_ty = db.mir_lowered_type(map.value); + + let def = MapDef { key_ty, value_ty }; + TypeKind::Map(def) +} + +fn lower_tuple(db: &dyn MirDb, tup: &analyzer_types::Tuple) -> TypeKind { + let items = tup + .items + .iter() + .map(|item| db.mir_lowered_type(*item)) + .collect(); + + let def = TupleDef { items }; + TypeKind::Tuple(def) +} + +fn lower_contract(db: &dyn MirDb, contract: analyzer_items::ContractId) -> TypeKind { + let name = contract.name(db.upcast()); + + // Note: contract field types are wrapped in SPtr in TypeId::projection_ty + let fields = contract + .fields(db.upcast()) + .iter() + .map(|(fname, fid)| { + let analyzer_type = fid.typ(db.upcast()).unwrap(); + let ty = db.mir_lowered_type(analyzer_type); + (fname.clone(), ty) + }) + .collect(); + + // Obtain span. + let span = contract.span(db.upcast()); + + let module_id = contract.module(db.upcast()); + + let def = StructDef { + name, + fields, + span, + module_id, + }; + TypeKind::Contract(def) +} + +fn lower_struct(db: &dyn MirDb, id: analyzer_items::StructId) -> TypeKind { + let name = id.name(db.upcast()); + + // Lower struct fields. + let fields = id + .fields(db.upcast()) + .iter() + .map(|(fname, fid)| { + let analyzer_types = fid.typ(db.upcast()).unwrap(); + let ty = db.mir_lowered_type(analyzer_types); + (fname.clone(), ty) + }) + .collect(); + + // obtain span. + let span = id.span(db.upcast()); + + let module_id = id.module(db.upcast()); + + let def = StructDef { + name, + fields, + span, + module_id, + }; + TypeKind::Struct(def) +} + +fn lower_enum(db: &dyn MirDb, id: analyzer_items::EnumId) -> TypeKind { + let analyzer_variants = id.variants(db.upcast()); + let mut variants = Vec::with_capacity(analyzer_variants.len()); + for variant in analyzer_variants.values() { + let variant_ty = match variant.kind(db.upcast()).unwrap() { + analyzer_items::EnumVariantKind::Tuple(elts) => { + let tuple_ty = analyzer_types::TypeId::tuple(db.upcast(), &elts); + db.mir_lowered_type(tuple_ty) + } + analyzer_items::EnumVariantKind::Unit => { + let unit_ty = analyzer_types::TypeId::unit(db.upcast()); + db.mir_lowered_type(unit_ty) + } + }; + + variants.push(EnumVariant { + name: variant.name(db.upcast()), + span: variant.span(db.upcast()), + ty: variant_ty, + }); + } + + let def = EnumDef { + name: id.name(db.upcast()), + span: id.span(db.upcast()), + variants, + module_id: id.module(db.upcast()), + }; + + TypeKind::Enum(def) +} + +fn intern_type( + db: &dyn MirDb, + ty_kind: TypeKind, + analyzer_type: Option, +) -> TypeId { + db.mir_intern_type(Type::new(ty_kind, analyzer_type).into()) +} diff --git a/crates/mir2/src/pretty_print/inst.rs b/crates/mir2/src/pretty_print/inst.rs new file mode 100644 index 0000000000..345d7946a3 --- /dev/null +++ b/crates/mir2/src/pretty_print/inst.rs @@ -0,0 +1,206 @@ +use std::fmt::{self, Write}; + +use crate::{ + db::MirDb, + ir::{function::BodyDataStore, inst::InstKind, InstId}, +}; + +use super::PrettyPrint; + +impl PrettyPrint for InstId { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + if let Some(result) = store.inst_result(*self) { + result.pretty_print(db, store, w)?; + write!(w, ": ")?; + + let result_ty = result.ty(db, store); + result_ty.pretty_print(db, store, w)?; + write!(w, " = ")?; + } + + match &store.inst_data(*self).kind { + InstKind::Declare { local } => { + write!(w, "let ")?; + local.pretty_print(db, store, w)?; + write!(w, ": ")?; + store.value_ty(*local).pretty_print(db, store, w) + } + + InstKind::Unary { op, value } => { + write!(w, "{op}")?; + value.pretty_print(db, store, w) + } + + InstKind::Binary { op, lhs, rhs } => { + lhs.pretty_print(db, store, w)?; + write!(w, " {op} ")?; + rhs.pretty_print(db, store, w) + } + + InstKind::Cast { value, to, .. } => { + value.pretty_print(db, store, w)?; + write!(w, " as ")?; + to.pretty_print(db, store, w) + } + + InstKind::AggregateConstruct { ty, args } => { + ty.pretty_print(db, store, w)?; + write!(w, "{{")?; + if args.is_empty() { + return write!(w, "}}"); + } + + let arg_len = args.len(); + for (arg_idx, arg) in args.iter().enumerate().take(arg_len - 1) { + write!(w, "<{arg_idx}>: ")?; + arg.pretty_print(db, store, w)?; + write!(w, ", ")?; + } + let arg = args[arg_len - 1]; + write!(w, "<{}>: ", arg_len - 1)?; + arg.pretty_print(db, store, w)?; + write!(w, "}}") + } + + InstKind::Bind { src } => { + write!(w, "bind ")?; + src.pretty_print(db, store, w) + } + + InstKind::MemCopy { src } => { + write!(w, "memcopy ")?; + src.pretty_print(db, store, w) + } + + InstKind::Load { src } => { + write!(w, "load ")?; + src.pretty_print(db, store, w) + } + + InstKind::AggregateAccess { value, indices } => { + value.pretty_print(db, store, w)?; + for index in indices { + write!(w, ".<")?; + index.pretty_print(db, store, w)?; + write!(w, ">")? + } + Ok(()) + } + + InstKind::MapAccess { value, key } => { + value.pretty_print(db, store, w)?; + write!(w, "{{")?; + key.pretty_print(db, store, w)?; + write!(w, "}}") + } + + InstKind::Call { + func, + args, + call_type, + } => { + let name = func.debug_name(db); + write!(w, "{name}@{call_type}(")?; + args.as_slice().pretty_print(db, store, w)?; + write!(w, ")") + } + + InstKind::Jump { dest } => { + write!(w, "jump BB{}", dest.index()) + } + + InstKind::Branch { cond, then, else_ } => { + write!(w, "branch ")?; + cond.pretty_print(db, store, w)?; + write!(w, " then: BB{} else: BB{}", then.index(), else_.index()) + } + + InstKind::Switch { + disc, + table, + default, + } => { + write!(w, "switch ")?; + disc.pretty_print(db, store, w)?; + for (value, block) in table.iter() { + write!(w, " ")?; + value.pretty_print(db, store, w)?; + write!(w, ": BB{}", block.index())?; + } + + if let Some(default) = default { + write!(w, " default: BB{}", default.index()) + } else { + Ok(()) + } + } + + InstKind::Revert { arg } => { + write!(w, "revert ")?; + if let Some(arg) = arg { + arg.pretty_print(db, store, w)?; + } + Ok(()) + } + + InstKind::Emit { arg } => { + write!(w, "emit ")?; + arg.pretty_print(db, store, w) + } + + InstKind::Return { arg } => { + if let Some(arg) = arg { + write!(w, "return ")?; + arg.pretty_print(db, store, w) + } else { + write!(w, "return") + } + } + + InstKind::Keccak256 { arg } => { + write!(w, "keccak256 ")?; + arg.pretty_print(db, store, w) + } + + InstKind::AbiEncode { arg } => { + write!(w, "abi_encode ")?; + arg.pretty_print(db, store, w) + } + + InstKind::Nop => { + write!(w, "nop") + } + + InstKind::Create { value, contract } => { + write!(w, "create ")?; + let contract_name = contract.name(db.upcast()); + write!(w, "{contract_name} ")?; + value.pretty_print(db, store, w) + } + + InstKind::Create2 { + value, + salt, + contract, + } => { + write!(w, "create2 ")?; + let contract_name = contract.name(db.upcast()); + write!(w, "{contract_name} ")?; + value.pretty_print(db, store, w)?; + write!(w, " ")?; + salt.pretty_print(db, store, w) + } + + InstKind::YulIntrinsic { op, args } => { + write!(w, "{op}(")?; + args.as_slice().pretty_print(db, store, w)?; + write!(w, ")") + } + } + } +} diff --git a/crates/mir2/src/pretty_print/mod.rs b/crates/mir2/src/pretty_print/mod.rs new file mode 100644 index 0000000000..190bb3ca7e --- /dev/null +++ b/crates/mir2/src/pretty_print/mod.rs @@ -0,0 +1,22 @@ +use std::fmt; + +use crate::{db::MirDb, ir::function::BodyDataStore}; + +mod inst; +mod types; +mod value; + +pub trait PrettyPrint { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result; + + fn pretty_string(&self, db: &dyn MirDb, store: &BodyDataStore) -> String { + let mut s = String::new(); + self.pretty_print(db, store, &mut s).unwrap(); + s + } +} diff --git a/crates/mir2/src/pretty_print/types.rs b/crates/mir2/src/pretty_print/types.rs new file mode 100644 index 0000000000..2574d8260a --- /dev/null +++ b/crates/mir2/src/pretty_print/types.rs @@ -0,0 +1,19 @@ +use std::fmt::{self, Write}; + +use crate::{ + db::MirDb, + ir::{function::BodyDataStore, TypeId}, +}; + +use super::PrettyPrint; + +impl PrettyPrint for TypeId { + fn pretty_print( + &self, + db: &dyn MirDb, + _store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + self.print(db, w) + } +} diff --git a/crates/mir2/src/pretty_print/value.rs b/crates/mir2/src/pretty_print/value.rs new file mode 100644 index 0000000000..05e1ff796d --- /dev/null +++ b/crates/mir2/src/pretty_print/value.rs @@ -0,0 +1,81 @@ +use std::fmt::{self, Write}; + +use crate::{ + db::MirDb, + ir::{ + constant::ConstantValue, function::BodyDataStore, value::AssignableValue, Value, ValueId, + }, +}; + +use super::PrettyPrint; + +impl PrettyPrint for ValueId { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + match store.value_data(*self) { + Value::Temporary { .. } | Value::Local(_) => write!(w, "_{}", self.index()), + Value::Immediate { imm, .. } => write!(w, "{imm}"), + Value::Constant { constant, .. } => { + let const_value = constant.data(db); + write!(w, "const ")?; + match &const_value.value { + ConstantValue::Immediate(num) => write!(w, "{num}"), + ConstantValue::Str(s) => write!(w, r#""{s}""#), + ConstantValue::Bool(b) => write!(w, "{b}"), + } + } + Value::Unit { .. } => write!(w, "()"), + } + } +} + +impl PrettyPrint for &[ValueId] { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + if self.is_empty() { + return Ok(()); + } + + let arg_len = self.len(); + for arg in self.iter().take(arg_len - 1) { + arg.pretty_print(db, store, w)?; + write!(w, ", ")?; + } + let arg = self[arg_len - 1]; + arg.pretty_print(db, store, w) + } +} + +impl PrettyPrint for AssignableValue { + fn pretty_print( + &self, + db: &dyn MirDb, + store: &BodyDataStore, + w: &mut W, + ) -> fmt::Result { + match self { + Self::Value(value) => value.pretty_print(db, store, w), + Self::Aggregate { lhs, idx } => { + lhs.pretty_print(db, store, w)?; + write!(w, ".<")?; + idx.pretty_print(db, store, w)?; + write!(w, ">") + } + + Self::Map { lhs, key } => { + lhs.pretty_print(db, store, w)?; + write!(w, "{{")?; + key.pretty_print(db, store, w)?; + write!(w, "}}") + } + } + } +} diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs new file mode 100644 index 0000000000..03308c7e6f --- /dev/null +++ b/crates/mir2/tests/lowering.rs @@ -0,0 +1,109 @@ +use fe_analyzer::namespace::items::{IngotId, ModuleId}; +use fe_common::{db::Upcast, files::Utf8Path}; +use fe_mir::{ + analysis::{ControlFlowGraph, DomTree, LoopTree, PostDomTree}, + db::{MirDb, NewDb}, +}; + +macro_rules! test_lowering { + ($name:ident, $path:expr) => { + #[test] + fn $name() { + let mut db = NewDb::default(); + + let file_name = Utf8Path::new($path).file_name().unwrap(); + let module = ModuleId::new_standalone(&mut db, file_name, test_files::fixture($path)); + + let diags = module.diagnostics(&db); + if !diags.is_empty() { + panic!("lowering failed") + } + + for func in db.mir_lower_module_all_functions(module).iter() { + let body = func.body(&db); + ControlFlowGraph::compute(&body); + } + } + }; +} + +#[test] +fn mir_lower_std_lib() { + let mut db = NewDb::default(); + + // Should return the same id + let std_ingot = IngotId::std_lib(&mut db); + + let diags = std_ingot.diagnostics(&db); + if !diags.is_empty() { + panic!("std lib analysis failed") + } + + for &module in std_ingot.all_modules(db.upcast()).iter() { + for func in db.mir_lower_module_all_functions(module).iter() { + let body = func.body(&db); + let cfg = ControlFlowGraph::compute(&body); + let domtree = DomTree::compute(&cfg); + LoopTree::compute(&cfg, &domtree); + PostDomTree::compute(&body); + } + } +} + +test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} +test_lowering! { mir_guest_book, "demos/guest_book.fe"} +test_lowering! { mir_uniswap, "demos/uniswap.fe"} +test_lowering! { mir_assert, "features/assert.fe"} +test_lowering! { mir_aug_assign, "features/aug_assign.fe"} +test_lowering! { mir_call_statement_with_args, "features/call_statement_with_args.fe"} +test_lowering! { mir_call_statement_with_args_2, "features/call_statement_with_args_2.fe"} +test_lowering! { mir_call_statement_without_args, "features/call_statement_without_args.fe"} +test_lowering! { mir_checked_arithmetic, "features/checked_arithmetic.fe"} +test_lowering! { mir_constructor, "features/constructor.fe"} +test_lowering! { mir_create2_contract, "features/create2_contract.fe"} +test_lowering! { mir_create_contract, "features/create_contract.fe"} +test_lowering! { mir_create_contract_from_init, "features/create_contract_from_init.fe"} +test_lowering! { mir_empty, "features/empty.fe"} +test_lowering! { mir_events, "features/events.fe"} +test_lowering! { mir_module_level_events, "features/module_level_events.fe"} +test_lowering! { mir_external_contract, "features/external_contract.fe"} +test_lowering! { mir_for_loop_with_break, "features/for_loop_with_break.fe"} +test_lowering! { mir_for_loop_with_continue, "features/for_loop_with_continue.fe"} +test_lowering! { mir_for_loop_with_static_array, "features/for_loop_with_static_array.fe"} +test_lowering! { mir_if_statement, "features/if_statement.fe"} +test_lowering! { mir_if_statement_2, "features/if_statement_2.fe"} +test_lowering! { mir_if_statement_with_block_declaration, "features/if_statement_with_block_declaration.fe"} +test_lowering! { mir_keccak, "features/keccak.fe"} +test_lowering! { mir_math, "features/math.fe"} +test_lowering! { mir_module_const, "features/module_const.fe"} +test_lowering! { mir_multi_param, "features/multi_param.fe"} +test_lowering! { mir_nested_map, "features/nested_map.fe"} +test_lowering! { mir_numeric_sizes, "features/numeric_sizes.fe"} +test_lowering! { mir_ownable, "features/ownable.fe"} +test_lowering! { mir_pure_fn_standalone, "features/pure_fn_standalone.fe"} +test_lowering! { mir_revert, "features/revert.fe"} +test_lowering! { mir_self_address, "features/self_address.fe"} +test_lowering! { mir_send_value, "features/send_value.fe"} +test_lowering! { mir_balances, "features/balances.fe"} +test_lowering! { mir_sized_vals_in_sto, "features/sized_vals_in_sto.fe"} +test_lowering! { mir_strings, "features/strings.fe"} +test_lowering! { mir_structs, "features/structs.fe"} +test_lowering! { mir_struct_fns, "features/struct_fns.fe"} +test_lowering! { mir_ternary_expression, "features/ternary_expression.fe"} +test_lowering! { mir_two_contracts, "features/two_contracts.fe"} +test_lowering! { mir_u8_u8_map, "features/u8_u8_map.fe"} +test_lowering! { mir_u16_u16_map, "features/u16_u16_map.fe"} +test_lowering! { mir_u32_u32_map, "features/u32_u32_map.fe"} +test_lowering! { mir_u64_u64_map, "features/u64_u64_map.fe"} +test_lowering! { mir_u128_u128_map, "features/u128_u128_map.fe"} +test_lowering! { mir_u256_u256_map, "features/u256_u256_map.fe"} +test_lowering! { mir_while_loop, "features/while_loop.fe"} +test_lowering! { mir_while_loop_with_break, "features/while_loop_with_break.fe"} +test_lowering! { mir_while_loop_with_break_2, "features/while_loop_with_break_2.fe"} +test_lowering! { mir_while_loop_with_continue, "features/while_loop_with_continue.fe"} +test_lowering! { mir_abi_encoding_stress, "stress/abi_encoding_stress.fe"} +test_lowering! { mir_data_copying_stress, "stress/data_copying_stress.fe"} +test_lowering! { mir_tuple_stress, "stress/tuple_stress.fe"} +test_lowering! { mir_type_aliases, "features/type_aliases.fe"} +test_lowering! { mir_const_generics, "features/const_generics.fe" } +test_lowering! { mir_const_local, "features/const_local.fe" } From 4ebb7eb688fb5b3b99f52b27c99483f5521066d7 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Sat, 23 Dec 2023 18:58:15 -0700 Subject: [PATCH 02/22] hacking --- crates/mir2/src/db.rs | 4 ++-- crates/mir2/src/db/queries/constant.rs | 2 +- crates/mir2/src/db/queries/contract.rs | 2 +- crates/mir2/src/db/queries/enums.rs | 2 +- crates/mir2/src/db/queries/function.rs | 2 +- crates/mir2/src/db/queries/module.rs | 2 +- crates/mir2/src/db/queries/structs.rs | 2 +- crates/mir2/src/db/queries/types.rs | 6 +++--- crates/mir2/src/graphviz/mod.rs | 2 +- crates/mir2/src/graphviz/module.rs | 2 +- crates/mir2/src/ir/body_builder.rs | 2 +- crates/mir2/src/ir/constant.rs | 4 ++-- crates/mir2/src/ir/function.rs | 4 ++-- crates/mir2/src/ir/inst.rs | 8 ++++---- crates/mir2/src/ir/mod.rs | 2 +- crates/mir2/src/ir/types.rs | 4 ++-- crates/mir2/src/lower/function.rs | 6 +++--- crates/mir2/src/lower/pattern_match/decision_tree.rs | 2 +- crates/mir2/src/lower/pattern_match/mod.rs | 2 +- crates/mir2/src/lower/pattern_match/tree_vis.rs | 2 +- crates/mir2/src/lower/types.rs | 2 +- crates/mir2/tests/lowering.rs | 4 ++-- 22 files changed, 34 insertions(+), 34 deletions(-) diff --git a/crates/mir2/src/db.rs b/crates/mir2/src/db.rs index fc930318de..fa40d8e142 100644 --- a/crates/mir2/src/db.rs +++ b/crates/mir2/src/db.rs @@ -1,12 +1,12 @@ #![allow(clippy::arc_with_non_send_sync)] use std::{collections::BTreeMap, rc::Rc}; -use fe_analyzer::{ +use fe_analyzer2::{ db::AnalyzerDbStorage, namespace::{items as analyzer_items, types as analyzer_types}, AnalyzerDb, }; -use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; +use fe_common2::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; use smol_str::SmolStr; use crate::ir::{self, ConstantId, TypeId}; diff --git a/crates/mir2/src/db/queries/constant.rs b/crates/mir2/src/db/queries/constant.rs index 1e012420c1..4985fd673d 100644 --- a/crates/mir2/src/db/queries/constant.rs +++ b/crates/mir2/src/db/queries/constant.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use fe_analyzer::namespace::items as analyzer_items; +use fe_analyzer2::namespace::items as analyzer_items; use crate::{ db::MirDb, diff --git a/crates/mir2/src/db/queries/contract.rs b/crates/mir2/src/db/queries/contract.rs index b36b1893e1..6fe50abfee 100644 --- a/crates/mir2/src/db/queries/contract.rs +++ b/crates/mir2/src/db/queries/contract.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use fe_analyzer::namespace::items::{self as analyzer_items}; +use fe_analyzer2::namespace::items::{self as analyzer_items}; use crate::{db::MirDb, ir::FunctionId}; diff --git a/crates/mir2/src/db/queries/enums.rs b/crates/mir2/src/db/queries/enums.rs index 2fb26cb478..7fc0384ab0 100644 --- a/crates/mir2/src/db/queries/enums.rs +++ b/crates/mir2/src/db/queries/enums.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use fe_analyzer::namespace::items::{self as analyzer_items}; +use fe_analyzer2::namespace::items::{self as analyzer_items}; use crate::{db::MirDb, ir::FunctionId}; diff --git a/crates/mir2/src/db/queries/function.rs b/crates/mir2/src/db/queries/function.rs index e9f0e9f282..211a5cadfb 100644 --- a/crates/mir2/src/db/queries/function.rs +++ b/crates/mir2/src/db/queries/function.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, rc::Rc}; -use fe_analyzer::{ +use fe_analyzer2::{ display::Displayable, namespace::{items as analyzer_items, items::Item, types as analyzer_types}, }; diff --git a/crates/mir2/src/db/queries/module.rs b/crates/mir2/src/db/queries/module.rs index 00f0ea6a3c..b7d00521ac 100644 --- a/crates/mir2/src/db/queries/module.rs +++ b/crates/mir2/src/db/queries/module.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use fe_analyzer::namespace::items::{self as analyzer_items, TypeDef}; +use fe_analyzer2::namespace::items::{self as analyzer_items, TypeDef}; use crate::{db::MirDb, ir::FunctionId}; diff --git a/crates/mir2/src/db/queries/structs.rs b/crates/mir2/src/db/queries/structs.rs index 8ca121f94a..c1a859718d 100644 --- a/crates/mir2/src/db/queries/structs.rs +++ b/crates/mir2/src/db/queries/structs.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use fe_analyzer::namespace::items::{self as analyzer_items}; +use fe_analyzer2::namespace::items::{self as analyzer_items}; use crate::{db::MirDb, ir::FunctionId}; diff --git a/crates/mir2/src/db/queries/types.rs b/crates/mir2/src/db/queries/types.rs index a0d13511d5..fe1261bfc5 100644 --- a/crates/mir2/src/db/queries/types.rs +++ b/crates/mir2/src/db/queries/types.rs @@ -1,6 +1,6 @@ use std::{fmt, rc::Rc, str::FromStr}; -use fe_analyzer::namespace::{items::EnumVariantId, types as analyzer_types}; +use fe_analyzer2::namespace::{items::EnumVariantId, types as analyzer_types}; use num_bigint::BigInt; use num_traits::ToPrimitive; @@ -470,8 +470,8 @@ fn round_up(value: usize, slot_size: usize) -> usize { #[cfg(test)] mod tests { - use fe_analyzer::namespace::items::ModuleId; - use fe_common::Span; + use fe_analyzer2::namespace::items::ModuleId; + use fe_common2::Span; use super::*; use crate::{ diff --git a/crates/mir2/src/graphviz/mod.rs b/crates/mir2/src/graphviz/mod.rs index 8ab37cd37e..c79335a04e 100644 --- a/crates/mir2/src/graphviz/mod.rs +++ b/crates/mir2/src/graphviz/mod.rs @@ -1,6 +1,6 @@ use std::io; -use fe_analyzer::namespace::items::ModuleId; +use fe_analyzer2::namespace::items::ModuleId; use crate::db::MirDb; diff --git a/crates/mir2/src/graphviz/module.rs b/crates/mir2/src/graphviz/module.rs index 4b0c395b25..8280e76c7f 100644 --- a/crates/mir2/src/graphviz/module.rs +++ b/crates/mir2/src/graphviz/module.rs @@ -1,5 +1,5 @@ use dot2::{label::Text, GraphWalk, Id, Kind, Labeller}; -use fe_analyzer::namespace::items::ModuleId; +use fe_analyzer2::namespace::items::ModuleId; use crate::{ db::MirDb, diff --git a/crates/mir2/src/ir/body_builder.rs b/crates/mir2/src/ir/body_builder.rs index 3dee893a73..0807d7d215 100644 --- a/crates/mir2/src/ir/body_builder.rs +++ b/crates/mir2/src/ir/body_builder.rs @@ -1,4 +1,4 @@ -use fe_analyzer::namespace::items::ContractId; +use fe_analyzer2::namespace::items::ContractId; use num_bigint::BigInt; use crate::ir::{ diff --git a/crates/mir2/src/ir/constant.rs b/crates/mir2/src/ir/constant.rs index 68466ff3e7..f7b892f824 100644 --- a/crates/mir2/src/ir/constant.rs +++ b/crates/mir2/src/ir/constant.rs @@ -1,8 +1,8 @@ -use fe_common::impl_intern_key; +use fe_common2::impl_intern_key; use num_bigint::BigInt; use smol_str::SmolStr; -use fe_analyzer::{context, namespace::items as analyzer_items}; +use fe_analyzer2::{context, namespace::items as analyzer_items}; use super::{SourceInfo, TypeId}; diff --git a/crates/mir2/src/ir/function.rs b/crates/mir2/src/ir/function.rs index c359f20f71..631ad78a94 100644 --- a/crates/mir2/src/ir/function.rs +++ b/crates/mir2/src/ir/function.rs @@ -1,5 +1,5 @@ -use fe_analyzer::namespace::{items as analyzer_items, types as analyzer_types}; -use fe_common::impl_intern_key; +use fe_analyzer2::namespace::{items as analyzer_items, types as analyzer_types}; +use fe_common2::impl_intern_key; use fxhash::FxHashMap; use id_arena::Arena; use num_bigint::BigInt; diff --git a/crates/mir2/src/ir/inst.rs b/crates/mir2/src/ir/inst.rs index 4ef76fa906..6948495150 100644 --- a/crates/mir2/src/ir/inst.rs +++ b/crates/mir2/src/ir/inst.rs @@ -1,6 +1,6 @@ use std::fmt; -use fe_analyzer::namespace::items::ContractId; +use fe_analyzer2::namespace::items::ContractId; use id_arena::Id; use super::{basic_block::BasicBlockId, function::FunctionId, value::ValueId, SourceInfo, TypeId}; @@ -576,9 +576,9 @@ impl fmt::Display for YulIntrinsicOp { } } -impl From for YulIntrinsicOp { - fn from(val: fe_analyzer::builtins::Intrinsic) -> Self { - use fe_analyzer::builtins::Intrinsic; +impl From for YulIntrinsicOp { + fn from(val: fe_analyzer2::builtins::Intrinsic) -> Self { + use fe_analyzer2::builtins::Intrinsic; match val { Intrinsic::__stop => Self::Stop, Intrinsic::__add => Self::Add, diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs index 0327b0348b..0416fe9619 100644 --- a/crates/mir2/src/ir/mod.rs +++ b/crates/mir2/src/ir/mod.rs @@ -1,4 +1,4 @@ -use fe_common::Span; +use fe_common2::Span; use fe_parser2::node::{Node, NodeId}; pub mod basic_block; diff --git a/crates/mir2/src/ir/types.rs b/crates/mir2/src/ir/types.rs index 8bdd9995c2..a1c9854a8c 100644 --- a/crates/mir2/src/ir/types.rs +++ b/crates/mir2/src/ir/types.rs @@ -1,5 +1,5 @@ -use fe_analyzer::namespace::{items as analyzer_items, types as analyzer_types}; -use fe_common::{impl_intern_key, Span}; +use fe_analyzer2::namespace::{items as analyzer_items, types as analyzer_types}; +use fe_common2::{impl_intern_key, Span}; use smol_str::SmolStr; #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs index a1d0212f6c..35a3a94775 100644 --- a/crates/mir2/src/lower/function.rs +++ b/crates/mir2/src/lower/function.rs @@ -9,7 +9,7 @@ use fe_analyzer2::{ types::{self as analyzer_types, Type}, }, }; -use fe_common::numeric::Literal; +use fe_common2::numeric::Literal; use fe_parser2::{ast, node::Node}; use fxhash::FxHashMap; use id_arena::{Arena, Id}; @@ -102,7 +102,7 @@ pub(super) struct BodyLowerHelper<'db, 'a> { pub(super) builder: BodyBuilder, ast: &'a Node, func: FunctionId, - analyzer_body: &'a fe_analyzer::context::FunctionBody, + analyzer_body: &'a fe_analyzer2::context::FunctionBody, scopes: Arena, current_scope: ScopeId, } @@ -565,7 +565,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { db: &'db dyn MirDb, func: FunctionId, ast: &'a Node, - analyzer_body: &'a fe_analyzer::context::FunctionBody, + analyzer_body: &'a fe_analyzer2::context::FunctionBody, ) -> Self { let mut builder = BodyBuilder::new(func, ast.into()); let mut scopes = Arena::new(); diff --git a/crates/mir2/src/lower/pattern_match/decision_tree.rs b/crates/mir2/src/lower/pattern_match/decision_tree.rs index 852dfb921a..fb9d93ed7a 100644 --- a/crates/mir2/src/lower/pattern_match/decision_tree.rs +++ b/crates/mir2/src/lower/pattern_match/decision_tree.rs @@ -3,7 +3,7 @@ //! The algorithm for efficient decision tree construction is mainly based on [Compiling pattern matching to good decision trees](https://dl.acm.org/doi/10.1145/1411304.1411311). use std::io; -use fe_analyzer::{ +use fe_analyzer2::{ pattern_analysis::{ ConstructorKind, PatternMatrix, PatternRowVec, SigmaSet, SimplifiedPattern, SimplifiedPatternKind, diff --git a/crates/mir2/src/lower/pattern_match/mod.rs b/crates/mir2/src/lower/pattern_match/mod.rs index 2172dd7e94..6f38beec0a 100644 --- a/crates/mir2/src/lower/pattern_match/mod.rs +++ b/crates/mir2/src/lower/pattern_match/mod.rs @@ -1,4 +1,4 @@ -use fe_analyzer::pattern_analysis::{ConstructorKind, PatternMatrix}; +use fe_analyzer2::pattern_analysis::{ConstructorKind, PatternMatrix}; use fe_parser2::{ ast::{Expr, LiteralPattern, MatchArm}, node::Node, diff --git a/crates/mir2/src/lower/pattern_match/tree_vis.rs b/crates/mir2/src/lower/pattern_match/tree_vis.rs index 9681ecb790..d13d0d3921 100644 --- a/crates/mir2/src/lower/pattern_match/tree_vis.rs +++ b/crates/mir2/src/lower/pattern_match/tree_vis.rs @@ -1,7 +1,7 @@ use std::fmt::Write; use dot2::{label::Text, Id}; -use fe_analyzer::{pattern_analysis::ConstructorKind, AnalyzerDb}; +use fe_analyzer2::{pattern_analysis::ConstructorKind, AnalyzerDb}; use fxhash::FxHashMap; use indexmap::IndexMap; use smol_str::SmolStr; diff --git a/crates/mir2/src/lower/types.rs b/crates/mir2/src/lower/types.rs index 7072eaa96b..b4b9697a84 100644 --- a/crates/mir2/src/lower/types.rs +++ b/crates/mir2/src/lower/types.rs @@ -6,7 +6,7 @@ use crate::{ }, }; -use fe_analyzer::namespace::{ +use fe_analyzer2::namespace::{ items as analyzer_items, types::{self as analyzer_types, TraitOrType}, }; diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index 03308c7e6f..27a0105735 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -1,5 +1,5 @@ -use fe_analyzer::namespace::items::{IngotId, ModuleId}; -use fe_common::{db::Upcast, files::Utf8Path}; +use fe_analyzer2::namespace::items::{IngotId, ModuleId}; +use fe_common2::{db::Upcast, files::Utf8Path}; use fe_mir::{ analysis::{ControlFlowGraph, DomTree, LoopTree, PostDomTree}, db::{MirDb, NewDb}, From 68e0707afd59a88ced4f07ea5e7d0e43dce1c3e2 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Fri, 5 Jan 2024 13:32:40 -0700 Subject: [PATCH 03/22] hacking --- Cargo.lock | 21 ++- crates/mir2-analysis/Cargo.toml | 22 ++++ .../{mir2 => mir2-analysis}/src/db/queries.rs | 0 .../src/db/queries/constant.rs | 2 - .../src/db/queries/contract.rs | 2 - .../src/db/queries/enums.rs | 2 - .../src/db/queries/function.rs | 0 .../src/db/queries/module.rs | 0 .../src/db/queries/structs.rs | 0 .../src/db/queries/types.rs | 0 crates/mir2-analysis/src/lib.rs | 121 ++++++++++++++++++ crates/mir2/Cargo.toml | 4 +- crates/mir2/src/db.rs | 104 --------------- crates/mir2/src/ir/body_builder.rs | 1 - crates/mir2/src/ir/constant.rs | 2 - crates/mir2/src/ir/function.rs | 2 - crates/mir2/src/ir/inst.rs | 1 - crates/mir2/src/ir/mod.rs | 3 - crates/mir2/src/ir/types.rs | 2 - crates/mir2/src/lib.rs | 20 ++- crates/mir2/src/lower/function.rs | 9 -- crates/mir2/src/lower/types.rs | 2 +- crates/mir2/tests/lowering.rs | 7 - 23 files changed, 186 insertions(+), 141 deletions(-) create mode 100644 crates/mir2-analysis/Cargo.toml rename crates/{mir2 => mir2-analysis}/src/db/queries.rs (100%) rename crates/{mir2 => mir2-analysis}/src/db/queries/constant.rs (94%) rename crates/{mir2 => mir2-analysis}/src/db/queries/contract.rs (85%) rename crates/{mir2 => mir2-analysis}/src/db/queries/enums.rs (84%) rename crates/{mir2 => mir2-analysis}/src/db/queries/function.rs (100%) rename crates/{mir2 => mir2-analysis}/src/db/queries/module.rs (100%) rename crates/{mir2 => mir2-analysis}/src/db/queries/structs.rs (100%) rename crates/{mir2 => mir2-analysis}/src/db/queries/types.rs (100%) create mode 100644 crates/mir2-analysis/src/lib.rs delete mode 100644 crates/mir2/src/db.rs diff --git a/Cargo.lock b/Cargo.lock index 559d3b5735..dba014a613 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1136,6 +1136,8 @@ version = "0.23.0" dependencies = [ "dot2", "fe-common2", + "fe-hir", + "fe-hir-analysis", "fe-library", "fe-parser2", "fe-test-files", @@ -1145,7 +1147,24 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", - "salsa", + "salsa-2022", + "smol_str", +] + +[[package]] +name = "fe-mir2-analysis" +version = "0.23.0" +dependencies = [ + "dot2", + "fe-common2", + "fe-hir", + "fe-hir-analysis", + "fe-library", + "fe-mir2", + "fe-test-files", + "fxhash", + "id-arena", + "salsa-2022", "smol_str", ] diff --git a/crates/mir2-analysis/Cargo.toml b/crates/mir2-analysis/Cargo.toml new file mode 100644 index 0000000000..24a92769dd --- /dev/null +++ b/crates/mir2-analysis/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "fe-mir2-analysis" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/ethereum/fe" + +[dependencies] +fe-common2 = { path = "../common2", version = "^0.23.0" } +fe-hir-analysis = { path = "../hir-analysis", version = "^0.23.0" } +fe-hir = { path = "../hir", version = "^0.23.0" } +fe-mir2 = { path = "../mir2", version = "^0.23.0" } +salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } +smol_str = "0.1.21" +id-arena = "2.2.1" +fxhash = "0.2.1" +dot2 = "1.0.0" + +[dev-dependencies] +test-files = { path = "../test-files", package = "fe-test-files" } +fe-library = { path = "../library" } diff --git a/crates/mir2/src/db/queries.rs b/crates/mir2-analysis/src/db/queries.rs similarity index 100% rename from crates/mir2/src/db/queries.rs rename to crates/mir2-analysis/src/db/queries.rs diff --git a/crates/mir2/src/db/queries/constant.rs b/crates/mir2-analysis/src/db/queries/constant.rs similarity index 94% rename from crates/mir2/src/db/queries/constant.rs rename to crates/mir2-analysis/src/db/queries/constant.rs index 4985fd673d..edd204d7ff 100644 --- a/crates/mir2/src/db/queries/constant.rs +++ b/crates/mir2-analysis/src/db/queries/constant.rs @@ -1,7 +1,5 @@ use std::rc::Rc; -use fe_analyzer2::namespace::items as analyzer_items; - use crate::{ db::MirDb, ir::{Constant, ConstantId, SourceInfo, TypeId}, diff --git a/crates/mir2/src/db/queries/contract.rs b/crates/mir2-analysis/src/db/queries/contract.rs similarity index 85% rename from crates/mir2/src/db/queries/contract.rs rename to crates/mir2-analysis/src/db/queries/contract.rs index 6fe50abfee..d7bcf742a4 100644 --- a/crates/mir2/src/db/queries/contract.rs +++ b/crates/mir2-analysis/src/db/queries/contract.rs @@ -1,7 +1,5 @@ use std::rc::Rc; -use fe_analyzer2::namespace::items::{self as analyzer_items}; - use crate::{db::MirDb, ir::FunctionId}; pub fn mir_lower_contract_all_functions( diff --git a/crates/mir2/src/db/queries/enums.rs b/crates/mir2-analysis/src/db/queries/enums.rs similarity index 84% rename from crates/mir2/src/db/queries/enums.rs rename to crates/mir2-analysis/src/db/queries/enums.rs index 7fc0384ab0..5082d76e42 100644 --- a/crates/mir2/src/db/queries/enums.rs +++ b/crates/mir2-analysis/src/db/queries/enums.rs @@ -1,7 +1,5 @@ use std::rc::Rc; -use fe_analyzer2::namespace::items::{self as analyzer_items}; - use crate::{db::MirDb, ir::FunctionId}; pub fn mir_lower_enum_all_functions( diff --git a/crates/mir2/src/db/queries/function.rs b/crates/mir2-analysis/src/db/queries/function.rs similarity index 100% rename from crates/mir2/src/db/queries/function.rs rename to crates/mir2-analysis/src/db/queries/function.rs diff --git a/crates/mir2/src/db/queries/module.rs b/crates/mir2-analysis/src/db/queries/module.rs similarity index 100% rename from crates/mir2/src/db/queries/module.rs rename to crates/mir2-analysis/src/db/queries/module.rs diff --git a/crates/mir2/src/db/queries/structs.rs b/crates/mir2-analysis/src/db/queries/structs.rs similarity index 100% rename from crates/mir2/src/db/queries/structs.rs rename to crates/mir2-analysis/src/db/queries/structs.rs diff --git a/crates/mir2/src/db/queries/types.rs b/crates/mir2-analysis/src/db/queries/types.rs similarity index 100% rename from crates/mir2/src/db/queries/types.rs rename to crates/mir2-analysis/src/db/queries/types.rs diff --git a/crates/mir2-analysis/src/lib.rs b/crates/mir2-analysis/src/lib.rs new file mode 100644 index 0000000000..b246e3cfac --- /dev/null +++ b/crates/mir2-analysis/src/lib.rs @@ -0,0 +1,121 @@ +use fe_mir2::{ir, MirDb}; + +#[salsa::jar(db = MirAnalysisDb)] +pub struct Jar(ir::ConstantId, ir::FunctionId); + +pub trait HirAnalysisDb: salsa::DbWithJar + HirDb { + fn as_hir_analysis_db(&self) -> &dyn HirAnalysisDb { + >::as_jar_db::<'_>(self) + } +} +impl HirAnalysisDb for DB where DB: ?Sized + salsa::DbWithJar + HirDb {} + +pub mod name_resolution; +pub mod ty; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Spanned { + pub data: T, + pub span: DynLazySpan, +} + +// old mir db.rs +// +// #![allow(clippy::arc_with_non_send_sync)] +// use std::{collections::BTreeMap, rc::Rc}; + +// use smol_str::SmolStr; + +// use crate::ir::{self, ConstantId, TypeId}; + +// mod queries; + +// #[salsa::query_group(MirDbStorage)] +// pub trait MirDb: AnalyzerDb + Upcast + UpcastMut { +// #[salsa::interned] +// fn mir_intern_const(&self, data: Rc) -> ir::ConstantId; +// #[salsa::interned] +// fn mir_intern_type(&self, data: Rc) -> ir::TypeId; +// #[salsa::interned] +// fn mir_intern_function(&self, data: Rc) -> ir::FunctionId; + +// #[salsa::invoke(queries::module::mir_lower_module_all_functions)] +// fn mir_lower_module_all_functions( +// &self, +// module: analyzer_items::ModuleId, +// ) -> Rc>; + +// #[salsa::invoke(queries::contract::mir_lower_contract_all_functions)] +// fn mir_lower_contract_all_functions( +// &self, +// contract: analyzer_items::ContractId, +// ) -> Rc>; + +// #[salsa::invoke(queries::structs::mir_lower_struct_all_functions)] +// fn mir_lower_struct_all_functions( +// &self, +// struct_: analyzer_items::StructId, +// ) -> Rc>; + +// #[salsa::invoke(queries::enums::mir_lower_enum_all_functions)] +// fn mir_lower_enum_all_functions( +// &self, +// enum_: analyzer_items::EnumId, +// ) -> Rc>; + +// #[salsa::invoke(queries::types::mir_lowered_type)] +// fn mir_lowered_type(&self, analyzer_type: analyzer_types::TypeId) -> TypeId; + +// #[salsa::invoke(queries::constant::mir_lowered_constant)] +// fn mir_lowered_constant(&self, analyzer_const: analyzer_items::ModuleConstantId) -> ConstantId; + +// #[salsa::invoke(queries::function::mir_lowered_func_signature)] +// fn mir_lowered_func_signature( +// &self, +// analyzer_func: analyzer_items::FunctionId, +// ) -> ir::FunctionId; +// #[salsa::invoke(queries::function::mir_lowered_monomorphized_func_signature)] +// fn mir_lowered_monomorphized_func_signature( +// &self, +// analyzer_func: analyzer_items::FunctionId, +// resolved_generics: BTreeMap, +// ) -> ir::FunctionId; +// #[salsa::invoke(queries::function::mir_lowered_pseudo_monomorphized_func_signature)] +// fn mir_lowered_pseudo_monomorphized_func_signature( +// &self, +// analyzer_func: analyzer_items::FunctionId, +// ) -> ir::FunctionId; +// #[salsa::invoke(queries::function::mir_lowered_func_body)] +// fn mir_lowered_func_body(&self, func: ir::FunctionId) -> Rc; +// } + +// #[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage)] +// #[derive(Default)] +// pub struct NewDb { +// storage: salsa::Storage, +// } +// impl salsa::Database for NewDb {} + +// impl Upcast for NewDb { +// fn upcast(&self) -> &(dyn SourceDb + 'static) { +// self +// } +// } + +// impl UpcastMut for NewDb { +// fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { +// &mut *self +// } +// } + +// impl Upcast for NewDb { +// fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { +// self +// } +// } + +// impl UpcastMut for NewDb { +// fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { +// &mut *self +// } +// } diff --git a/crates/mir2/Cargo.toml b/crates/mir2/Cargo.toml index 6ca3c5445d..b72749c6d3 100644 --- a/crates/mir2/Cargo.toml +++ b/crates/mir2/Cargo.toml @@ -9,7 +9,9 @@ repository = "https://github.com/ethereum/fe" [dependencies] fe-common2 = { path = "../common2", version = "^0.23.0" } fe-parser2 = { path = "../parser2", version = "^0.23.0" } -salsa = "0.16.1" +fe-hir-analysis = { path = "../hir-analysis", version = "^0.23.0" } +fe-hir = { path = "../hir", version = "^0.23.0" } +salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } smol_str = "0.1.21" num-bigint = "0.4.3" num-traits = "0.2.14" diff --git a/crates/mir2/src/db.rs b/crates/mir2/src/db.rs deleted file mode 100644 index fa40d8e142..0000000000 --- a/crates/mir2/src/db.rs +++ /dev/null @@ -1,104 +0,0 @@ -#![allow(clippy::arc_with_non_send_sync)] -use std::{collections::BTreeMap, rc::Rc}; - -use fe_analyzer2::{ - db::AnalyzerDbStorage, - namespace::{items as analyzer_items, types as analyzer_types}, - AnalyzerDb, -}; -use fe_common2::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; -use smol_str::SmolStr; - -use crate::ir::{self, ConstantId, TypeId}; - -mod queries; - -#[salsa::query_group(MirDbStorage)] -pub trait MirDb: AnalyzerDb + Upcast + UpcastMut { - #[salsa::interned] - fn mir_intern_const(&self, data: Rc) -> ir::ConstantId; - #[salsa::interned] - fn mir_intern_type(&self, data: Rc) -> ir::TypeId; - #[salsa::interned] - fn mir_intern_function(&self, data: Rc) -> ir::FunctionId; - - #[salsa::invoke(queries::module::mir_lower_module_all_functions)] - fn mir_lower_module_all_functions( - &self, - module: analyzer_items::ModuleId, - ) -> Rc>; - - #[salsa::invoke(queries::contract::mir_lower_contract_all_functions)] - fn mir_lower_contract_all_functions( - &self, - contract: analyzer_items::ContractId, - ) -> Rc>; - - #[salsa::invoke(queries::structs::mir_lower_struct_all_functions)] - fn mir_lower_struct_all_functions( - &self, - struct_: analyzer_items::StructId, - ) -> Rc>; - - #[salsa::invoke(queries::enums::mir_lower_enum_all_functions)] - fn mir_lower_enum_all_functions( - &self, - enum_: analyzer_items::EnumId, - ) -> Rc>; - - #[salsa::invoke(queries::types::mir_lowered_type)] - fn mir_lowered_type(&self, analyzer_type: analyzer_types::TypeId) -> TypeId; - - #[salsa::invoke(queries::constant::mir_lowered_constant)] - fn mir_lowered_constant(&self, analyzer_const: analyzer_items::ModuleConstantId) -> ConstantId; - - #[salsa::invoke(queries::function::mir_lowered_func_signature)] - fn mir_lowered_func_signature( - &self, - analyzer_func: analyzer_items::FunctionId, - ) -> ir::FunctionId; - #[salsa::invoke(queries::function::mir_lowered_monomorphized_func_signature)] - fn mir_lowered_monomorphized_func_signature( - &self, - analyzer_func: analyzer_items::FunctionId, - resolved_generics: BTreeMap, - ) -> ir::FunctionId; - #[salsa::invoke(queries::function::mir_lowered_pseudo_monomorphized_func_signature)] - fn mir_lowered_pseudo_monomorphized_func_signature( - &self, - analyzer_func: analyzer_items::FunctionId, - ) -> ir::FunctionId; - #[salsa::invoke(queries::function::mir_lowered_func_body)] - fn mir_lowered_func_body(&self, func: ir::FunctionId) -> Rc; -} - -#[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage)] -#[derive(Default)] -pub struct NewDb { - storage: salsa::Storage, -} -impl salsa::Database for NewDb {} - -impl Upcast for NewDb { - fn upcast(&self) -> &(dyn SourceDb + 'static) { - self - } -} - -impl UpcastMut for NewDb { - fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { - &mut *self - } -} - -impl Upcast for NewDb { - fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { - self - } -} - -impl UpcastMut for NewDb { - fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { - &mut *self - } -} diff --git a/crates/mir2/src/ir/body_builder.rs b/crates/mir2/src/ir/body_builder.rs index 0807d7d215..7190601cfd 100644 --- a/crates/mir2/src/ir/body_builder.rs +++ b/crates/mir2/src/ir/body_builder.rs @@ -1,4 +1,3 @@ -use fe_analyzer2::namespace::items::ContractId; use num_bigint::BigInt; use crate::ir::{ diff --git a/crates/mir2/src/ir/constant.rs b/crates/mir2/src/ir/constant.rs index f7b892f824..c9eac33ccf 100644 --- a/crates/mir2/src/ir/constant.rs +++ b/crates/mir2/src/ir/constant.rs @@ -2,8 +2,6 @@ use fe_common2::impl_intern_key; use num_bigint::BigInt; use smol_str::SmolStr; -use fe_analyzer2::{context, namespace::items as analyzer_items}; - use super::{SourceInfo, TypeId}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/mir2/src/ir/function.rs b/crates/mir2/src/ir/function.rs index 631ad78a94..1f6f30e915 100644 --- a/crates/mir2/src/ir/function.rs +++ b/crates/mir2/src/ir/function.rs @@ -1,5 +1,3 @@ -use fe_analyzer2::namespace::{items as analyzer_items, types as analyzer_types}; -use fe_common2::impl_intern_key; use fxhash::FxHashMap; use id_arena::Arena; use num_bigint::BigInt; diff --git a/crates/mir2/src/ir/inst.rs b/crates/mir2/src/ir/inst.rs index 6948495150..c86f2bc8fb 100644 --- a/crates/mir2/src/ir/inst.rs +++ b/crates/mir2/src/ir/inst.rs @@ -1,6 +1,5 @@ use std::fmt; -use fe_analyzer2::namespace::items::ContractId; use id_arena::Id; use super::{basic_block::BasicBlockId, function::FunctionId, value::ValueId, SourceInfo, TypeId}; diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs index 0416fe9619..ece3b03e54 100644 --- a/crates/mir2/src/ir/mod.rs +++ b/crates/mir2/src/ir/mod.rs @@ -1,6 +1,3 @@ -use fe_common2::Span; -use fe_parser2::node::{Node, NodeId}; - pub mod basic_block; pub mod body_builder; pub mod body_cursor; diff --git a/crates/mir2/src/ir/types.rs b/crates/mir2/src/ir/types.rs index a1c9854a8c..65dc60668b 100644 --- a/crates/mir2/src/ir/types.rs +++ b/crates/mir2/src/ir/types.rs @@ -1,5 +1,3 @@ -use fe_analyzer2::namespace::{items as analyzer_items, types as analyzer_types}; -use fe_common2::{impl_intern_key, Span}; use smol_str::SmolStr; #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index f938352df8..0515811fda 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -1,7 +1,25 @@ pub mod analysis; -pub mod db; pub mod graphviz; pub mod ir; pub mod pretty_print; mod lower; + +#[salsa::jar(db = MirDb)] +pub struct Jar( + ir::BasicBlock, + ir::BasicBlockId, + ir::Constant, + ir::ConstantId, + ir::FunctionBody, + ir::FunctionId, + ir::FunctionParam, + ir::FunctionSignature, + ir::Inst, + ir::InstId, + ir::Type, + ir::TypeId, + ir::TypeKind, + ir::Value, + ir::ValueId, +); diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs index 35a3a94775..9b4565aa21 100644 --- a/crates/mir2/src/lower/function.rs +++ b/crates/mir2/src/lower/function.rs @@ -1,14 +1,5 @@ use std::{collections::BTreeMap, rc::Rc, vec}; -use fe_analyzer2::{ - builtins::{ContractTypeMethod, GlobalFunction, ValueMethod}, - constants::{EMITTABLE_TRAIT_NAME, EMIT_FN_NAME}, - context::{Adjustment, AdjustmentKind, CallType as AnalyzerCallType, NamedThing}, - namespace::{ - items as analyzer_items, - types::{self as analyzer_types, Type}, - }, -}; use fe_common2::numeric::Literal; use fe_parser2::{ast, node::Node}; use fxhash::FxHashMap; diff --git a/crates/mir2/src/lower/types.rs b/crates/mir2/src/lower/types.rs index b4b9697a84..7072eaa96b 100644 --- a/crates/mir2/src/lower/types.rs +++ b/crates/mir2/src/lower/types.rs @@ -6,7 +6,7 @@ use crate::{ }, }; -use fe_analyzer2::namespace::{ +use fe_analyzer::namespace::{ items as analyzer_items, types::{self as analyzer_types, TraitOrType}, }; diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index 27a0105735..c464461b53 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -1,10 +1,3 @@ -use fe_analyzer2::namespace::items::{IngotId, ModuleId}; -use fe_common2::{db::Upcast, files::Utf8Path}; -use fe_mir::{ - analysis::{ControlFlowGraph, DomTree, LoopTree, PostDomTree}, - db::{MirDb, NewDb}, -}; - macro_rules! test_lowering { ($name:ident, $path:expr) => { #[test] From 4fe21dd063ff8b752dd1e2496cc367000d6c37f6 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Fri, 5 Jan 2024 14:11:20 -0700 Subject: [PATCH 04/22] hacking --- crates/mir2/src/ir/basic_block.rs | 10 +- crates/mir2/src/ir/body_builder.rs | 34 ----- crates/mir2/src/ir/function.rs | 5 +- crates/mir2/src/ir/mod.rs | 78 +++++------ crates/mir2/src/ir/types.rs | 210 ++++++++++++++--------------- crates/mir2/src/lib.rs | 57 +++++--- crates/mir2/tests/lowering.rs | 188 +++++++++++++------------- 7 files changed, 287 insertions(+), 295 deletions(-) diff --git a/crates/mir2/src/ir/basic_block.rs b/crates/mir2/src/ir/basic_block.rs index 359c4c76f6..73b2eab8b7 100644 --- a/crates/mir2/src/ir/basic_block.rs +++ b/crates/mir2/src/ir/basic_block.rs @@ -1,6 +1,8 @@ -use id_arena::Id; +#[salsa::interned] +pub struct BasicBlockId { + #[return_ref] + pub data: BasicBlock, +} -pub type BasicBlockId = Id; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[salsa::tracked] pub struct BasicBlock {} diff --git a/crates/mir2/src/ir/body_builder.rs b/crates/mir2/src/ir/body_builder.rs index 7190601cfd..622a24af4f 100644 --- a/crates/mir2/src/ir/body_builder.rs +++ b/crates/mir2/src/ir/body_builder.rs @@ -228,40 +228,6 @@ impl BodyBuilder { self.insert_inst(inst) } - pub fn keccak256(&mut self, arg: ValueId, source: SourceInfo) -> InstId { - let kind = InstKind::Keccak256 { arg }; - let inst = Inst::new(kind, source); - self.insert_inst(inst) - } - - pub fn abi_encode(&mut self, arg: ValueId, source: SourceInfo) -> InstId { - let kind = InstKind::AbiEncode { arg }; - let inst = Inst::new(kind, source); - self.insert_inst(inst) - } - - pub fn create(&mut self, value: ValueId, contract: ContractId, source: SourceInfo) -> InstId { - let kind = InstKind::Create { value, contract }; - let inst = Inst::new(kind, source); - self.insert_inst(inst) - } - - pub fn create2( - &mut self, - value: ValueId, - salt: ValueId, - contract: ContractId, - source: SourceInfo, - ) -> InstId { - let kind = InstKind::Create2 { - value, - salt, - contract, - }; - let inst = Inst::new(kind, source); - self.insert_inst(inst) - } - pub fn yul_intrinsic( &mut self, op: YulIntrinsicOp, diff --git a/crates/mir2/src/ir/function.rs b/crates/mir2/src/ir/function.rs index 1f6f30e915..73aed03e6e 100644 --- a/crates/mir2/src/ir/function.rs +++ b/crates/mir2/src/ir/function.rs @@ -8,9 +8,10 @@ use super::{ basic_block::BasicBlock, body_order::BodyOrder, inst::{BranchInfo, Inst, InstId, InstKind}, - types::TypeId, + // types::TypeId, value::{AssignableValue, Local, Value, ValueId}, - BasicBlockId, SourceInfo, + BasicBlockId, + SourceInfo, }; /// Represents function signature. diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs index ece3b03e54..dee098c765 100644 --- a/crates/mir2/src/ir/mod.rs +++ b/crates/mir2/src/ir/mod.rs @@ -1,46 +1,46 @@ pub mod basic_block; -pub mod body_builder; -pub mod body_cursor; -pub mod body_order; -pub mod constant; -pub mod function; -pub mod inst; -pub mod types; -pub mod value; +// pub mod body_builder; +// pub mod body_cursor; +// pub mod body_order; +// pub mod constant; +// pub mod function; +// pub mod inst; +// pub mod types; +// pub mod value; pub use basic_block::{BasicBlock, BasicBlockId}; -pub use constant::{Constant, ConstantId}; -pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; -pub use inst::{Inst, InstId}; -pub use types::{Type, TypeId, TypeKind}; -pub use value::{Value, ValueId}; +// pub use constant::{Constant, ConstantId}; +// pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; +// pub use inst::{Inst, InstId}; +// pub use types::{Type, TypeId, TypeKind}; +// pub use value::{Value, ValueId}; -/// An original source information that indicates where `mir` entities derive -/// from. `SourceInfo` is mainly used for diagnostics. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct SourceInfo { - pub span: Span, - pub id: NodeId, -} +// /// An original source information that indicates where `mir` entities derive +// /// from. `SourceInfo` is mainly used for diagnostics. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct SourceInfo { +// pub span: Span, +// pub id: NodeId, +// } -impl SourceInfo { - pub fn dummy() -> Self { - Self { - span: Span::dummy(), - id: NodeId::dummy(), - } - } +// impl SourceInfo { +// pub fn dummy() -> Self { +// Self { +// span: Span::dummy(), +// id: NodeId::dummy(), +// } +// } - pub fn is_dummy(&self) -> bool { - self == &Self::dummy() - } -} +// pub fn is_dummy(&self) -> bool { +// self == &Self::dummy() +// } +// } -impl From<&Node> for SourceInfo { - fn from(node: &Node) -> Self { - Self { - span: node.span, - id: node.id, - } - } -} +// impl From<&Node> for SourceInfo { +// fn from(node: &Node) -> Self { +// Self { +// span: node.span, +// id: node.id, +// } +// } +// } diff --git a/crates/mir2/src/ir/types.rs b/crates/mir2/src/ir/types.rs index 65dc60668b..a368cce2dd 100644 --- a/crates/mir2/src/ir/types.rs +++ b/crates/mir2/src/ir/types.rs @@ -1,117 +1,117 @@ -use smol_str::SmolStr; +// use smol_str::SmolStr; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Type { - pub kind: TypeKind, - pub analyzer_ty: Option, -} +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct Type { +// pub kind: TypeKind, +// pub analyzer_ty: Option, +// } -impl Type { - pub fn new(kind: TypeKind, analyzer_ty: Option) -> Self { - Self { kind, analyzer_ty } - } -} +// impl Type { +// pub fn new(kind: TypeKind, analyzer_ty: Option) -> Self { +// Self { kind, analyzer_ty } +// } +// } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum TypeKind { - I8, - I16, - I32, - I64, - I128, - I256, - U8, - U16, - U32, - U64, - U128, - U256, - Bool, - Address, - Unit, - Array(ArrayDef), - // TODO: we should consider whether we really need `String` type. - String(usize), - Tuple(TupleDef), - Struct(StructDef), - Enum(EnumDef), - Contract(StructDef), - Map(MapDef), - MPtr(TypeId), - SPtr(TypeId), -} +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub enum TypeKind { +// I8, +// I16, +// I32, +// I64, +// I128, +// I256, +// U8, +// U16, +// U32, +// U64, +// U128, +// U256, +// Bool, +// Address, +// Unit, +// Array(ArrayDef), +// // TODO: we should consider whether we really need `String` type. +// String(usize), +// Tuple(TupleDef), +// Struct(StructDef), +// Enum(EnumDef), +// Contract(StructDef), +// Map(MapDef), +// MPtr(TypeId), +// SPtr(TypeId), +// } -/// An interned Id for [`ArrayDef`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TypeId(pub u32); -impl_intern_key!(TypeId); +// /// An interned Id for [`ArrayDef`]. +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct TypeId(pub u32); +// impl_intern_key!(TypeId); -/// A static array type definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ArrayDef { - pub elem_ty: TypeId, - pub len: usize, -} +// /// A static array type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct ArrayDef { +// pub elem_ty: TypeId, +// pub len: usize, +// } -/// A tuple type definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct TupleDef { - pub items: Vec, -} +// /// A tuple type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct TupleDef { +// pub items: Vec, +// } -/// A user defined struct type definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct StructDef { - pub name: SmolStr, - pub fields: Vec<(SmolStr, TypeId)>, - pub span: Span, - pub module_id: analyzer_items::ModuleId, -} +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct StructDef { +// pub name: SmolStr, +// pub fields: Vec<(SmolStr, TypeId)>, +// pub span: Span, +// pub module_id: analyzer_items::ModuleId, +// } -/// A user defined struct type definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct EnumDef { - pub name: SmolStr, - pub variants: Vec, - pub span: Span, - pub module_id: analyzer_items::ModuleId, -} +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct EnumDef { +// pub name: SmolStr, +// pub variants: Vec, +// pub span: Span, +// pub module_id: analyzer_items::ModuleId, +// } -impl EnumDef { - pub fn tag_type(&self) -> TypeKind { - let variant_num = self.variants.len() as u64; - if variant_num <= u8::MAX as u64 { - TypeKind::U8 - } else if variant_num <= u16::MAX as u64 { - TypeKind::U16 - } else if variant_num <= u32::MAX as u64 { - TypeKind::U32 - } else { - TypeKind::U64 - } - } -} +// impl EnumDef { +// pub fn tag_type(&self) -> TypeKind { +// let variant_num = self.variants.len() as u64; +// if variant_num <= u8::MAX as u64 { +// TypeKind::U8 +// } else if variant_num <= u16::MAX as u64 { +// TypeKind::U16 +// } else if variant_num <= u32::MAX as u64 { +// TypeKind::U32 +// } else { +// TypeKind::U64 +// } +// } +// } -/// A user defined struct type definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct EnumVariant { - pub name: SmolStr, - pub span: Span, - pub ty: TypeId, -} +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct EnumVariant { +// pub name: SmolStr, +// pub span: Span, +// pub ty: TypeId, +// } -/// A user defined struct type definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct EventDef { - pub name: SmolStr, - pub fields: Vec<(SmolStr, TypeId, bool)>, - pub span: Span, - pub module_id: analyzer_items::ModuleId, -} +// /// A user defined struct type definition. +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub struct EventDef { +// pub name: SmolStr, +// pub fields: Vec<(SmolStr, TypeId, bool)>, +// pub span: Span, +// pub module_id: analyzer_items::ModuleId, +// } -/// A map type definition. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MapDef { - pub key_ty: TypeId, - pub value_ty: TypeId, -} +// /// A map type definition. +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct MapDef { +// pub key_ty: TypeId, +// pub value_ty: TypeId, +// } diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 0515811fda..798b9c2cb3 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -1,25 +1,48 @@ -pub mod analysis; -pub mod graphviz; +use fe_hir::HirDb; + +// pub mod analysis; +// pub mod graphviz; pub mod ir; -pub mod pretty_print; +// pub mod pretty_print; -mod lower; +// mod lower; #[salsa::jar(db = MirDb)] pub struct Jar( ir::BasicBlock, ir::BasicBlockId, - ir::Constant, - ir::ConstantId, - ir::FunctionBody, - ir::FunctionId, - ir::FunctionParam, - ir::FunctionSignature, - ir::Inst, - ir::InstId, - ir::Type, - ir::TypeId, - ir::TypeKind, - ir::Value, - ir::ValueId, + // ir::Constant, + // ir::ConstantId, + // ir::FunctionBody, + // ir::FunctionId, + // ir::FunctionParam, + // ir::FunctionSignature, + // ir::Inst, + // ir::InstId, + // ir::Value, + // ir::ValueId, ); + +#[salsa::jar(db = LowerMirDb)] +pub struct LowerJar(); + +pub trait MirDb: salsa::DbWithJar + HirDb { + // fn prefill(&self) + // where + // Self: Sized, + // { + // IdentId::prefill(self) + // } + + fn as_hir_db(&self) -> &dyn MirDb { + >::as_jar_db::<'_>(self) + } +} +impl MirDb for DB where DB: salsa::DbWithJar + HirDb {} + +pub trait LowerMirDb: salsa::DbWithJar + HirDb { + fn as_lower_hir_db(&self) -> &dyn LowerMirDb { + >::as_jar_db::<'_>(self) + } +} +impl LowerMirDb for DB where DB: salsa::DbWithJar + MirDb {} diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index c464461b53..a02b3df486 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -1,102 +1,102 @@ -macro_rules! test_lowering { - ($name:ident, $path:expr) => { - #[test] - fn $name() { - let mut db = NewDb::default(); +// macro_rules! test_lowering { +// ($name:ident, $path:expr) => { +// #[test] +// fn $name() { +// let mut db = NewDb::default(); - let file_name = Utf8Path::new($path).file_name().unwrap(); - let module = ModuleId::new_standalone(&mut db, file_name, test_files::fixture($path)); +// let file_name = Utf8Path::new($path).file_name().unwrap(); +// let module = ModuleId::new_standalone(&mut db, file_name, test_files::fixture($path)); - let diags = module.diagnostics(&db); - if !diags.is_empty() { - panic!("lowering failed") - } +// let diags = module.diagnostics(&db); +// if !diags.is_empty() { +// panic!("lowering failed") +// } - for func in db.mir_lower_module_all_functions(module).iter() { - let body = func.body(&db); - ControlFlowGraph::compute(&body); - } - } - }; -} +// for func in db.mir_lower_module_all_functions(module).iter() { +// let body = func.body(&db); +// ControlFlowGraph::compute(&body); +// } +// } +// }; +// } -#[test] -fn mir_lower_std_lib() { - let mut db = NewDb::default(); +// #[test] +// fn mir_lower_std_lib() { +// let mut db = NewDb::default(); - // Should return the same id - let std_ingot = IngotId::std_lib(&mut db); +// // Should return the same id +// let std_ingot = IngotId::std_lib(&mut db); - let diags = std_ingot.diagnostics(&db); - if !diags.is_empty() { - panic!("std lib analysis failed") - } +// let diags = std_ingot.diagnostics(&db); +// if !diags.is_empty() { +// panic!("std lib analysis failed") +// } - for &module in std_ingot.all_modules(db.upcast()).iter() { - for func in db.mir_lower_module_all_functions(module).iter() { - let body = func.body(&db); - let cfg = ControlFlowGraph::compute(&body); - let domtree = DomTree::compute(&cfg); - LoopTree::compute(&cfg, &domtree); - PostDomTree::compute(&body); - } - } -} +// for &module in std_ingot.all_modules(db.upcast()).iter() { +// for func in db.mir_lower_module_all_functions(module).iter() { +// let body = func.body(&db); +// let cfg = ControlFlowGraph::compute(&body); +// let domtree = DomTree::compute(&cfg); +// LoopTree::compute(&cfg, &domtree); +// PostDomTree::compute(&body); +// } +// } +// } -test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} -test_lowering! { mir_guest_book, "demos/guest_book.fe"} -test_lowering! { mir_uniswap, "demos/uniswap.fe"} -test_lowering! { mir_assert, "features/assert.fe"} -test_lowering! { mir_aug_assign, "features/aug_assign.fe"} -test_lowering! { mir_call_statement_with_args, "features/call_statement_with_args.fe"} -test_lowering! { mir_call_statement_with_args_2, "features/call_statement_with_args_2.fe"} -test_lowering! { mir_call_statement_without_args, "features/call_statement_without_args.fe"} -test_lowering! { mir_checked_arithmetic, "features/checked_arithmetic.fe"} -test_lowering! { mir_constructor, "features/constructor.fe"} -test_lowering! { mir_create2_contract, "features/create2_contract.fe"} -test_lowering! { mir_create_contract, "features/create_contract.fe"} -test_lowering! { mir_create_contract_from_init, "features/create_contract_from_init.fe"} -test_lowering! { mir_empty, "features/empty.fe"} -test_lowering! { mir_events, "features/events.fe"} -test_lowering! { mir_module_level_events, "features/module_level_events.fe"} -test_lowering! { mir_external_contract, "features/external_contract.fe"} -test_lowering! { mir_for_loop_with_break, "features/for_loop_with_break.fe"} -test_lowering! { mir_for_loop_with_continue, "features/for_loop_with_continue.fe"} -test_lowering! { mir_for_loop_with_static_array, "features/for_loop_with_static_array.fe"} -test_lowering! { mir_if_statement, "features/if_statement.fe"} -test_lowering! { mir_if_statement_2, "features/if_statement_2.fe"} -test_lowering! { mir_if_statement_with_block_declaration, "features/if_statement_with_block_declaration.fe"} -test_lowering! { mir_keccak, "features/keccak.fe"} -test_lowering! { mir_math, "features/math.fe"} -test_lowering! { mir_module_const, "features/module_const.fe"} -test_lowering! { mir_multi_param, "features/multi_param.fe"} -test_lowering! { mir_nested_map, "features/nested_map.fe"} -test_lowering! { mir_numeric_sizes, "features/numeric_sizes.fe"} -test_lowering! { mir_ownable, "features/ownable.fe"} -test_lowering! { mir_pure_fn_standalone, "features/pure_fn_standalone.fe"} -test_lowering! { mir_revert, "features/revert.fe"} -test_lowering! { mir_self_address, "features/self_address.fe"} -test_lowering! { mir_send_value, "features/send_value.fe"} -test_lowering! { mir_balances, "features/balances.fe"} -test_lowering! { mir_sized_vals_in_sto, "features/sized_vals_in_sto.fe"} -test_lowering! { mir_strings, "features/strings.fe"} -test_lowering! { mir_structs, "features/structs.fe"} -test_lowering! { mir_struct_fns, "features/struct_fns.fe"} -test_lowering! { mir_ternary_expression, "features/ternary_expression.fe"} -test_lowering! { mir_two_contracts, "features/two_contracts.fe"} -test_lowering! { mir_u8_u8_map, "features/u8_u8_map.fe"} -test_lowering! { mir_u16_u16_map, "features/u16_u16_map.fe"} -test_lowering! { mir_u32_u32_map, "features/u32_u32_map.fe"} -test_lowering! { mir_u64_u64_map, "features/u64_u64_map.fe"} -test_lowering! { mir_u128_u128_map, "features/u128_u128_map.fe"} -test_lowering! { mir_u256_u256_map, "features/u256_u256_map.fe"} -test_lowering! { mir_while_loop, "features/while_loop.fe"} -test_lowering! { mir_while_loop_with_break, "features/while_loop_with_break.fe"} -test_lowering! { mir_while_loop_with_break_2, "features/while_loop_with_break_2.fe"} -test_lowering! { mir_while_loop_with_continue, "features/while_loop_with_continue.fe"} -test_lowering! { mir_abi_encoding_stress, "stress/abi_encoding_stress.fe"} -test_lowering! { mir_data_copying_stress, "stress/data_copying_stress.fe"} -test_lowering! { mir_tuple_stress, "stress/tuple_stress.fe"} -test_lowering! { mir_type_aliases, "features/type_aliases.fe"} -test_lowering! { mir_const_generics, "features/const_generics.fe" } -test_lowering! { mir_const_local, "features/const_local.fe" } +// test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} +// test_lowering! { mir_guest_book, "demos/guest_book.fe"} +// test_lowering! { mir_uniswap, "demos/uniswap.fe"} +// test_lowering! { mir_assert, "features/assert.fe"} +// test_lowering! { mir_aug_assign, "features/aug_assign.fe"} +// test_lowering! { mir_call_statement_with_args, "features/call_statement_with_args.fe"} +// test_lowering! { mir_call_statement_with_args_2, "features/call_statement_with_args_2.fe"} +// test_lowering! { mir_call_statement_without_args, "features/call_statement_without_args.fe"} +// test_lowering! { mir_checked_arithmetic, "features/checked_arithmetic.fe"} +// test_lowering! { mir_constructor, "features/constructor.fe"} +// test_lowering! { mir_create2_contract, "features/create2_contract.fe"} +// test_lowering! { mir_create_contract, "features/create_contract.fe"} +// test_lowering! { mir_create_contract_from_init, "features/create_contract_from_init.fe"} +// test_lowering! { mir_empty, "features/empty.fe"} +// test_lowering! { mir_events, "features/events.fe"} +// test_lowering! { mir_module_level_events, "features/module_level_events.fe"} +// test_lowering! { mir_external_contract, "features/external_contract.fe"} +// test_lowering! { mir_for_loop_with_break, "features/for_loop_with_break.fe"} +// test_lowering! { mir_for_loop_with_continue, "features/for_loop_with_continue.fe"} +// test_lowering! { mir_for_loop_with_static_array, "features/for_loop_with_static_array.fe"} +// test_lowering! { mir_if_statement, "features/if_statement.fe"} +// test_lowering! { mir_if_statement_2, "features/if_statement_2.fe"} +// test_lowering! { mir_if_statement_with_block_declaration, "features/if_statement_with_block_declaration.fe"} +// test_lowering! { mir_keccak, "features/keccak.fe"} +// test_lowering! { mir_math, "features/math.fe"} +// test_lowering! { mir_module_const, "features/module_const.fe"} +// test_lowering! { mir_multi_param, "features/multi_param.fe"} +// test_lowering! { mir_nested_map, "features/nested_map.fe"} +// test_lowering! { mir_numeric_sizes, "features/numeric_sizes.fe"} +// test_lowering! { mir_ownable, "features/ownable.fe"} +// test_lowering! { mir_pure_fn_standalone, "features/pure_fn_standalone.fe"} +// test_lowering! { mir_revert, "features/revert.fe"} +// test_lowering! { mir_self_address, "features/self_address.fe"} +// test_lowering! { mir_send_value, "features/send_value.fe"} +// test_lowering! { mir_balances, "features/balances.fe"} +// test_lowering! { mir_sized_vals_in_sto, "features/sized_vals_in_sto.fe"} +// test_lowering! { mir_strings, "features/strings.fe"} +// test_lowering! { mir_structs, "features/structs.fe"} +// test_lowering! { mir_struct_fns, "features/struct_fns.fe"} +// test_lowering! { mir_ternary_expression, "features/ternary_expression.fe"} +// test_lowering! { mir_two_contracts, "features/two_contracts.fe"} +// test_lowering! { mir_u8_u8_map, "features/u8_u8_map.fe"} +// test_lowering! { mir_u16_u16_map, "features/u16_u16_map.fe"} +// test_lowering! { mir_u32_u32_map, "features/u32_u32_map.fe"} +// test_lowering! { mir_u64_u64_map, "features/u64_u64_map.fe"} +// test_lowering! { mir_u128_u128_map, "features/u128_u128_map.fe"} +// test_lowering! { mir_u256_u256_map, "features/u256_u256_map.fe"} +// test_lowering! { mir_while_loop, "features/while_loop.fe"} +// test_lowering! { mir_while_loop_with_break, "features/while_loop_with_break.fe"} +// test_lowering! { mir_while_loop_with_break_2, "features/while_loop_with_break_2.fe"} +// test_lowering! { mir_while_loop_with_continue, "features/while_loop_with_continue.fe"} +// test_lowering! { mir_abi_encoding_stress, "stress/abi_encoding_stress.fe"} +// test_lowering! { mir_data_copying_stress, "stress/data_copying_stress.fe"} +// test_lowering! { mir_tuple_stress, "stress/tuple_stress.fe"} +// test_lowering! { mir_type_aliases, "features/type_aliases.fe"} +// test_lowering! { mir_const_generics, "features/const_generics.fe" } +// test_lowering! { mir_const_local, "features/const_local.fe" } From 2d992e9b23fff5f99b3ddf5f97bed2de86c67713 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Tue, 9 Jan 2024 12:46:57 -0700 Subject: [PATCH 05/22] hacking --- Cargo.lock | 9 +- crates/library2/Cargo.toml | 10 + crates/library2/build.rs | 3 + crates/library2/src/lib.rs | 27 ++ crates/library2/std/src/buf.fe | 299 +++++++++++++++++++++++ crates/library2/std/src/context.fe | 174 +++++++++++++ crates/library2/std/src/error.fe | 6 + crates/library2/std/src/evm.fe | 325 +++++++++++++++++++++++++ crates/library2/std/src/lib.fe | 3 + crates/library2/std/src/math.fe | 15 ++ crates/library2/std/src/precompiles.fe | 191 +++++++++++++++ crates/library2/std/src/prelude.fe | 1 + crates/library2/std/src/traits.fe | 160 ++++++++++++ crates/mir2/Cargo.toml | 2 +- crates/mir2/src/lib.rs | 12 +- crates/mir2/tests/lowering.rs | 4 +- crates/mir2/tests/test_db.rs | 144 +++++++++++ 17 files changed, 1376 insertions(+), 9 deletions(-) create mode 100644 crates/library2/Cargo.toml create mode 100644 crates/library2/build.rs create mode 100644 crates/library2/src/lib.rs create mode 100644 crates/library2/std/src/buf.fe create mode 100644 crates/library2/std/src/context.fe create mode 100644 crates/library2/std/src/error.fe create mode 100644 crates/library2/std/src/evm.fe create mode 100644 crates/library2/std/src/lib.fe create mode 100644 crates/library2/std/src/math.fe create mode 100644 crates/library2/std/src/precompiles.fe create mode 100644 crates/library2/std/src/prelude.fe create mode 100644 crates/library2/std/src/traits.fe create mode 100644 crates/mir2/tests/test_db.rs diff --git a/Cargo.lock b/Cargo.lock index dba014a613..54f8ad9de8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1100,6 +1100,13 @@ dependencies = [ "include_dir", ] +[[package]] +name = "fe-library2" +version = "0.23.0" +dependencies = [ + "include_dir", +] + [[package]] name = "fe-macros" version = "0.23.0" @@ -1138,7 +1145,7 @@ dependencies = [ "fe-common2", "fe-hir", "fe-hir-analysis", - "fe-library", + "fe-library2", "fe-parser2", "fe-test-files", "fxhash", diff --git a/crates/library2/Cargo.toml b/crates/library2/Cargo.toml new file mode 100644 index 0000000000..1fb8ff4bc0 --- /dev/null +++ b/crates/library2/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "fe-library2" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/ethereum/fe" + +[dependencies] +include_dir = "0.7.2" diff --git a/crates/library2/build.rs b/crates/library2/build.rs new file mode 100644 index 0000000000..0ce78ee5e7 --- /dev/null +++ b/crates/library2/build.rs @@ -0,0 +1,3 @@ +fn main() { + println!("cargo:rerun-if-changed=./std"); +} diff --git a/crates/library2/src/lib.rs b/crates/library2/src/lib.rs new file mode 100644 index 0000000000..f905973fb4 --- /dev/null +++ b/crates/library2/src/lib.rs @@ -0,0 +1,27 @@ +pub use ::include_dir; +use include_dir::{include_dir, Dir}; + +pub const STD: Dir = include_dir!("$CARGO_MANIFEST_DIR/std"); + +pub fn std_src_files() -> Vec<(&'static str, &'static str)> { + static_dir_files(STD.get_dir("src").unwrap()) +} + +pub fn static_dir_files(dir: &'static Dir) -> Vec<(&'static str, &'static str)> { + fn add_files(dir: &'static Dir, accum: &mut Vec<(&'static str, &'static str)>) { + accum.extend(dir.files().map(|file| { + ( + file.path().to_str().unwrap(), + file.contents_utf8().expect("non-utf8 static file"), + ) + })); + + for sub_dir in dir.dirs() { + add_files(sub_dir, accum) + } + } + + let mut files = vec![]; + add_files(dir, &mut files); + files +} diff --git a/crates/library2/std/src/buf.fe b/crates/library2/std/src/buf.fe new file mode 100644 index 0000000000..a1d97af4e6 --- /dev/null +++ b/crates/library2/std/src/buf.fe @@ -0,0 +1,299 @@ +use ingot::evm +use ingot::math + +unsafe fn avail() -> u256 { + let ptr: u256 = evm::mload(offset: 64) + + if ptr == 0x00 { + return 96 + } else { + return ptr + } +} + +unsafe fn alloc(len: u256) -> u256 { + let ptr: u256 = avail() + evm::mstore(offset: 64, value: ptr + len) + return ptr +} + +struct Cursor { + cur: u256 + len: u256 + + pub fn new(len: u256) -> Self { + return Cursor(cur: 0, len) + } + + /// Increment the value of `cur` by `len` and return the value of `cur` before being incremented. + /// Reverts if the cursor is advanced beyond the given length. + pub fn advance(mut self, len: u256) -> u256 { + let cur: u256 = self.cur + assert cur + len < self.len + 1 + self.cur += len + return cur + } + + /// Length of the cursor remaining. + pub fn remainder(self) -> u256 { + return self.len - self.cur + } +} + +/// EVM memory buffer abstraction. +pub struct MemoryBuffer { + offset: u256 + len: u256 + + pub fn new(len: u256) -> Self { + unsafe { + return MemoryBuffer(offset: alloc(len: len + 30), len) + } + } + + pub fn from_u8(value: u8) -> Self { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 1) + let mut writer: MemoryBufferWriter = buf.writer() + writer.write(value) + return buf + } + + /// Length of the buffer in bytes. + pub fn len(self) -> u256 { + return self.len + } + + /// The start of the buffer in EVM memory. + pub fn offset(self) -> u256 { + return self.offset + } + + /// Returns a new buffer reader. + pub fn reader(self) -> MemoryBufferReader { + return MemoryBufferReader::new(buf: self) + } + + /// Returns a new buffer writer. + pub fn writer(mut self) -> MemoryBufferWriter { + return MemoryBufferWriter::new(buf: self) + } +} + +/// Memory buffer writer abstraction. +pub struct MemoryBufferWriter { + buf: MemoryBuffer + cur: Cursor + + /// Returns a new writer for the given buffer. + pub fn new(mut buf: MemoryBuffer) -> Self { + return MemoryBufferWriter( + buf, + cur: Cursor::new(len: buf.len()) + ) + } + + /// The number of bytes remaining to be written. + pub fn remainder(self) -> u256 { + return self.cur.remainder() + } + + pub fn write_offset(mut self, len: u256) -> u256 { + return self.buf.offset() + self.cur.advance(len) + } + + pub fn write_n(mut self, value: u256, len: u256) { + let offset: u256 = self.write_offset(len) + let shifted_value: u256 = evm::shl(bits: 256 - len * 8, value) + unsafe { evm::mstore(offset, value: shifted_value) } + } + + pub fn write_buf(mut self, buf: MemoryBuffer) { + let mut reader: MemoryBufferReader = buf.reader() + + while true { + let bytes_remaining: u256 = reader.remainder() + + if bytes_remaining >= 32 { + self.write(value: reader.read_u256()) + } else if bytes_remaining == 0 { + break + } else { + self.write(value: reader.read_u8()) + } + } + } + + pub fn write(mut self, value: T) { + value.write_buf(writer: self) + } +} + +pub trait MemoryBufferWrite { + fn write_buf(self, mut writer: MemoryBufferWriter); +} + +impl MemoryBufferWrite for u256 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + let offset: u256 = writer.write_offset(len: 32) + unsafe { evm::mstore(offset, value: self) } + } +} + +impl MemoryBufferWrite for u128 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 16) + } +} + +impl MemoryBufferWrite for u64 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 8) + } +} + +impl MemoryBufferWrite for u32 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 4) + } +} + +impl MemoryBufferWrite for u16 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + writer.write_n(value: u256(self), len: 2) + } +} + +impl MemoryBufferWrite for u8 { + fn write_buf(self, mut writer: MemoryBufferWriter) { + let offset: u256 = writer.write_offset(len: 1) + unsafe { evm::mstore8(offset, value: self) } + } +} + +// This is needed to prevent the `mir_lower_std_lib` to crash the compiler +impl MemoryBufferWrite for () { + fn write_buf(self, mut writer: MemoryBufferWriter) {} +} + +/// Memory buffer reader abstraction. +pub struct MemoryBufferReader { + buf: MemoryBuffer + cur: Cursor + + /// Returns a new reader for the given buffer. + pub fn new(buf: MemoryBuffer) -> Self { + return MemoryBufferReader(buf, cur: Cursor::new(len: buf.len())) + } + + /// The number of bytes remaining to be read. + pub fn remainder(self) -> u256 { + return self.cur.remainder() + } + + fn read_offset(mut self, len: u256) -> u256 { + return self.buf.offset() + self.cur.advance(len) + } + + fn read_n(mut self, len: u256) -> u256 { + let offset: u256 = self.read_offset(len) + unsafe { + let value: u256 = evm::mload(offset) + return evm::shr(bits: 256 - len * 8, value) + } + } + + pub fn read_u8(mut self) -> u8 { + return u8(self.read_n(len: 1)) + } + + pub fn read_u16(mut self) -> u16 { + return u16(self.read_n(len: 2)) + } + + pub fn read_u32(mut self) -> u32 { + return u32(self.read_n(len: 4)) + } + + pub fn read_u64(mut self) -> u64 { + return u64(self.read_n(len: 8)) + } + + pub fn read_u128(mut self) -> u128 { + return u128(self.read_n(len: 16)) + } + + pub fn read_u256(mut self) -> u256 { + let offset: u256 = self.read_offset(len: 32) + unsafe { + let value: u256 = evm::mload(offset) + return value + } + } + + pub fn read_buf(mut self, len: u256) -> MemoryBuffer { + let mut buf: MemoryBuffer = MemoryBuffer::new(len) + let mut writer: MemoryBufferWriter = buf.writer() + + while true { + let bytes_remaining: u256 = writer.remainder() + + if bytes_remaining >= 32 { + writer.write(value: self.read_u256()) + } else if bytes_remaining == 0 { + break + } else { + writer.write(value: self.read_u8()) + } + } + + return buf + } + + // `T` has not been defined + // pub fn read(mut self) -> T { + // T::read_buf(writer: self) + // } +} + +// pub trait MemoryBufferRead { +// fn read_buf(self, mut reader: MemoryBufferReader) -> Self; +// } +// +// impl MemoryBufferRead for u256 { .. } +// . +// . +// impl MemoryBufferRead for u8 { .. } + +/// `MemoryBuffer` wrapper for raw calls to other contracts. +pub struct RawCallBuffer { + input_len: u256 + output_len: u256 + buf: MemoryBuffer + + pub fn new(input_len: u256, output_len: u256) -> Self { + let len: u256 = math::max(input_len, output_len) + let buf: MemoryBuffer = MemoryBuffer::new(len) + + return RawCallBuffer(input_len, output_len, buf) + } + + pub fn input_len(self) -> u256 { + return self.input_len + } + + pub fn output_len(self) -> u256 { + return self.output_len + } + + pub fn offset(self) -> u256 { + return self.buf.offset() + } + + pub fn reader(self) -> MemoryBufferReader { + return self.buf.reader() + } + + pub fn writer(mut self) -> MemoryBufferWriter { + return self.buf.writer() + } +} diff --git a/crates/library2/std/src/context.fe b/crates/library2/std/src/context.fe new file mode 100644 index 0000000000..9b51f9ca9e --- /dev/null +++ b/crates/library2/std/src/context.fe @@ -0,0 +1,174 @@ +use ingot::evm +use ingot::error::{ + ERROR_INSUFFICIENT_FUNDS_TO_SEND_VALUE, + ERROR_FAILED_SEND_VALUE, + Error +} +use ingot::buf::{ + RawCallBuffer, + MemoryBufferReader, + MemoryBufferWriter +} + +struct OutOfReachMarker {} + +// ctx.emit(my_event) should be the only way to emit an event. We achieve this by defining the +// private `OutOfReachMarker` here to which only the `Context` has access. +// Now there is no way to call `emit` directly on an Emittable. +pub trait Emittable { + fn emit(self, _ val: OutOfReachMarker); +} + +pub struct CalldataReader { + cur_offset: u256 + len: u256 + + pub unsafe fn new(len: u256) -> CalldataReader { + return CalldataReader(cur_offset: 0, len) + } + + pub fn remainder(self) -> u256 { + return self.len - self.cur_offset + } + + pub fn advance(mut self, len: u256) -> u256 { + self.cur_offset += len + assert self.cur_offset <= self.len + return self.cur_offset + } + + fn read_n(mut self, len: u256) -> u256 { + unsafe { + let value: u256 = evm::call_data_load(offset: self.cur_offset) + self.advance(len) + return evm::shr(bits: 256 - len * 8, value) + } + } + + pub fn read_u8(mut self) -> u8 { + return u8(self.read_n(len: 1)) + } + + pub fn read_u16(mut self) -> u16 { + return u16(self.read_n(len: 2)) + } + + pub fn read_u32(mut self) -> u32 { + return u32(self.read_n(len: 4)) + } + + pub fn read_u64(mut self) -> u64 { + return u64(self.read_n(len: 8)) + } + + pub fn read_u128(mut self) -> u128 { + return u128(self.read_n(len: 16)) + } + pub fn read_u256(mut self) -> u256 { + unsafe { + let value: u256 = evm::call_data_load(offset: self.cur_offset) + self.advance(len: 32) + return value + } + } +} + +pub struct Context { + pub fn base_fee(self) -> u256 { + unsafe { return evm::base_fee() } + } + + pub fn block_coinbase(self) -> address { + unsafe { return evm::coinbase() } + } + + pub fn prevrandao(self) -> u256 { + unsafe { return evm::prevrandao() } + } + + pub fn block_number(self) -> u256 { + unsafe { return evm::block_number() } + } + + pub fn block_timestamp(self) -> u256 { + unsafe { return evm::timestamp() } + } + + pub fn chain_id(self) -> u256 { + unsafe { return evm::chain_id() } + } + + pub fn msg_sender(self) -> address { + unsafe { return evm::caller() } + } + + pub fn msg_value(self) -> u256 { + unsafe { return evm::call_value() } + } + + pub fn tx_gas_price(self) -> u256 { + unsafe { return evm::gas_price() } + } + + pub fn tx_origin(self) -> address { + unsafe { return evm::origin() } + } + + pub fn msg_sig(self) -> u256 { + unsafe { return evm::shr(bits: 224, value: evm::call_data_load(offset: 0)) } + } + + pub fn balance_of(self, _ account: address) -> u256 { + unsafe { return evm::balance_of(account) } + } + + pub fn self_balance(self) -> u256 { + unsafe { return evm::balance() } + } + + pub fn self_address(self) -> address { + unsafe { return address(__address()) } + } + + pub fn calldata_reader(self) -> CalldataReader { + unsafe { + let len: u256 = evm::call_data_size() + return CalldataReader::new(len) + } + } + + pub fn send_value(mut self, to: address, wei: u256) { + unsafe { + if evm::balance() < wei { + revert Error(code: ERROR_INSUFFICIENT_FUNDS_TO_SEND_VALUE) + } + let mut buf: RawCallBuffer = RawCallBuffer::new(input_len: 0, output_len: 0) + let success: bool = evm::call(gas: evm::gas_remaining(), addr: to, value: wei, + buf) + if not success { + revert Error(code: ERROR_FAILED_SEND_VALUE) + } + } + } + + /// Makes a call to the given address. + pub fn raw_call( + self, + addr: address, + value: u256, + mut buf: RawCallBuffer + ) -> bool { + unsafe { + return evm::call( + gas: evm::gas_remaining(), + addr, + value, + buf + ) + } + } + + pub fn emit(mut self, _ val: T) { + val.emit(OutOfReachMarker()) + } +} \ No newline at end of file diff --git a/crates/library2/std/src/error.fe b/crates/library2/std/src/error.fe new file mode 100644 index 0000000000..7ae066af4c --- /dev/null +++ b/crates/library2/std/src/error.fe @@ -0,0 +1,6 @@ +pub const ERROR_INSUFFICIENT_FUNDS_TO_SEND_VALUE: u256 = 0x100 +pub const ERROR_FAILED_SEND_VALUE: u256 = 0x101 + +pub struct Error { + pub code: u256 +} \ No newline at end of file diff --git a/crates/library2/std/src/evm.fe b/crates/library2/std/src/evm.fe new file mode 100644 index 0000000000..09c9382556 --- /dev/null +++ b/crates/library2/std/src/evm.fe @@ -0,0 +1,325 @@ +use ingot::buf::{MemoryBuffer, RawCallBuffer} + +// Basic context accessor functions. +pub unsafe fn chain_id() -> u256 { + return __chainid() +} + +pub unsafe fn base_fee() -> u256 { + return __basefee() +} + +pub unsafe fn origin() -> address { + return address(__origin()) +} + +pub unsafe fn gas_price() -> u256 { + return __gasprice() +} + +pub unsafe fn gas_limit() -> u256 { + return __gaslimit() +} + +pub unsafe fn gas_remaining() -> u256 { + return __gas() +} + +pub unsafe fn block_hash(_ b: u256) -> u256 { + return __blockhash(b) +} + +pub unsafe fn coinbase() -> address { + return address(__coinbase()) +} + +pub unsafe fn timestamp() -> u256 { + return __timestamp() +} + +pub unsafe fn block_number() -> u256 { + return __number() +} + +pub unsafe fn prevrandao() -> u256 { + return __prevrandao() +} + +pub unsafe fn self_address() -> address { + return address(__address()) +} + +pub unsafe fn balance_of(_ addr: address) -> u256 { + return __balance(u256(addr)) +} + +pub unsafe fn balance() -> u256 { + return __selfbalance() +} + +pub unsafe fn caller() -> address { + return address(__caller()) +} + +pub unsafe fn call_value() -> u256 { + return __callvalue() +} + + +// Overflowing math ops. Should these be unsafe or named +// `overflowing_add`, etc? +pub fn add(_ x: u256, _ y: u256) -> u256 { + unsafe { return __add(x, y) } +} + +pub fn sub(_ x: u256, _ y: u256) -> u256 { + unsafe { return __sub(x, y) } +} + +pub fn mul(_ x: u256, _ y: u256) -> u256 { + unsafe { return __mul(x, y) } +} + +pub fn div(_ x: u256, _ y: u256) -> u256 { + unsafe { return __div(x, y) } +} + +pub fn sdiv(_ x: u256, _ y: u256) -> u256 { + unsafe { return __sdiv(x, y) } +} + +pub fn mod(_ x: u256, _ y: u256) -> u256 { + unsafe { return __mod(x, y) } +} + +pub fn smod(_ x: u256, _ y: u256) -> u256 { + unsafe { return __smod(x, y) } +} + +pub fn exp(_ x: u256, _ y: u256) -> u256 { + unsafe { return __exp(x, y) } +} + +pub fn addmod(_ x: u256, _ y: u256, _ m: u256) -> u256 { + unsafe { return __addmod(x, y, m) } +} + +pub fn mulmod(_ x: u256, _ y: u256, _ m: u256) -> u256 { + unsafe { return __mulmod(x, y, m) } +} + +pub fn sign_extend(_ i: u256, _ x: u256) -> u256 { + unsafe { return __signextend(i, x) } +} + + +// Comparison ops +// TODO: return bool (see issue //653) +pub fn lt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __lt(x, y) } +} + +pub fn gt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __gt(x, y) } +} + +pub fn slt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __slt(x, y) } +} + +pub fn sgt(_ x: u256, _ y: u256) -> u256 { + unsafe { return __sgt(x, y) } +} + +pub fn eq(_ x: u256, _ y: u256) -> u256 { + unsafe { return __eq(x, y) } +} + +pub fn is_zero(_ x: u256) -> u256 { + unsafe { return __iszero(x) } +} + + +// Bitwise ops +pub fn bitwise_and(_ x: u256, _ y: u256) -> u256 { + unsafe { return __and(x, y) } +} + +pub fn bitwise_or(_ x: u256, _ y: u256) -> u256 { + unsafe { return __or(x, y) } +} + +pub fn bitwise_not(_ x: u256) -> u256 { + unsafe { return __not(x) } +} + +pub fn xor(_ x: u256, _ y: u256) -> u256 { + unsafe { return __xor(x, y) } +} + +pub fn byte(offset: u256, value: u256) -> u256 { + unsafe { return __byte(offset, value) } +} + +pub fn shl(bits: u256, value: u256) -> u256 { + unsafe { return __shl(bits, value) } +} + +pub fn shr(bits: u256, value: u256) -> u256 { + unsafe { return __shr(bits, value) } +} + +pub fn sar(bits: u256, value: u256) -> u256 { + unsafe { return __sar(bits, value) } +} + + +// Evm state access and control +pub fn return_mem(buf: MemoryBuffer) { + unsafe{ __return(buf.offset(), buf.len()) } +} + +pub fn revert_mem(buf: MemoryBuffer) { + unsafe { __revert(buf.offset(), buf.len()) } +} + +pub unsafe fn selfdestruct(_ addr: address) { + __selfdestruct(u256(addr)) +} + +// Invalid opcode. Equivalent to revert(0, 0), +// except that all remaining gas in the current context +// is consumed. +pub unsafe fn invalid() { + __invalid() +} + +pub unsafe fn stop() { + __stop() +} + +pub unsafe fn pc() -> u256 { + return __pc() +} + +// TODO: dunno if we should enable this +// pub unsafe fn pop(_ x: u256) { +// return __pop(x) +// } + +pub unsafe fn mload(offset p: u256) -> u256 { + return __mload(p) +} + +pub unsafe fn mstore(offset p: u256, value v: u256) { + __mstore(p, v) +} +pub unsafe fn mstore8(offset p: u256, value v: u256) { + __mstore8(p, v) +} + +pub unsafe fn sload(offset p: u256) -> u256 { + return __sload(p) +} + +pub unsafe fn sstore(offset p: u256, value v: u256) { + __sstore(p, v) +} + +pub unsafe fn msize() -> u256 { + return __msize() +} + +pub unsafe fn call_data_load(offset p: u256) -> u256 { + return __calldataload(p) +} + +pub unsafe fn call_data_size() -> u256 { + return __calldatasize() +} + +pub fn call_data_copy(buf: MemoryBuffer, from_offset f: u256) { + unsafe { __calldatacopy(buf.offset(), f, buf.len()) } +} + +pub unsafe fn code_size() -> u256 { + return __codesize() +} + +pub unsafe fn code_copy(to_offset t: u256, from_offset f: u256, len: u256) { + __codecopy(t, f, len) +} + +pub unsafe fn return_data_size() -> u256 { + return __returndatasize() +} + +pub unsafe fn return_data_copy(to_offset t: u256, from_offset f: u256, len: u256) { + __returndatacopy(t, f, len) +} + +pub unsafe fn extcodesize(_ addr: address) -> u256 { + return __extcodesize(u256(addr)) +} + +pub unsafe fn ext_code_copy(_ addr: address, to_offset t: u256, from_offset f: u256, len: u256) { + __extcodecopy(u256(addr), t, f, len) +} + +pub unsafe fn ext_code_hash(_ addr: address) -> u256 { + return __extcodehash(u256(addr)) +} + +pub fn keccak256_mem(buf: MemoryBuffer) -> u256 { + unsafe { return __keccak256(buf.offset(), buf.len()) } +} + + +// Contract creation and calling + +pub fn create(value v: u256, buf: MemoryBuffer) -> address { + unsafe { return address(__create(v, buf.offset(), buf.len())) } +} + +pub fn create2(value v: u256, buf: MemoryBuffer, salt s: u256) -> address { + unsafe { return address(__create2(v, buf.offset(), buf.len(), s)) } +} + +// TODO: return bool (success) +pub fn call(gas: u256, addr: address, value: u256, mut buf: RawCallBuffer) -> bool { + unsafe{ return __call(gas, u256(addr), value, buf.offset(), buf.input_len(), buf.offset(), buf.output_len()) == 1 } +} + +pub unsafe fn call_code(gas: u256, addr: address, value: u256, input_offset: u256, input_len: u256, output_offset: u256, output_len: u256) -> u256 { + return __callcode(gas, u256(addr), value, input_offset, input_len, output_offset, output_len) +} + +pub unsafe fn delegate_call(gas: u256, addr: address, value: u256, input_offset: u256, input_len: u256, output_offset: u256, output_len: u256) -> u256 { + return __delegatecall(gas, u256(addr), input_offset, input_len, output_offset, output_len) +} + +pub unsafe fn static_call(gas: u256, addr: address, input_offset: u256, input_len: u256, output_offset: u256, output_len: u256) -> u256 { + return __staticcall(gas, u256(addr), input_offset, input_len, output_offset, output_len) +} + +// Logging functions + +pub fn log0(buf: MemoryBuffer) { + unsafe { return __log0(buf.offset(), buf.len()) } +} + +pub fn log1(buf: MemoryBuffer, topic1 t1: u256) { + unsafe { return __log1(buf.offset(), buf.len(), t1) } +} + +pub fn log2(buf: MemoryBuffer, topic1 t1: u256, topic2 t2: u256) { + unsafe { return __log2(buf.offset(), buf.len(), t1, t2) } +} + +pub fn log3(buf: MemoryBuffer, topic1 t1: u256, topic2 t2: u256, topic3 t3: u256) { + unsafe { return __log3(buf.offset(), buf.len(), t1, t2, t3) } +} + +pub fn log4(buf: MemoryBuffer, topic1 t1: u256, topic2 t2: u256, topic3 t3: u256, topic4 t4: u256) { + unsafe { return __log4(buf.offset(), buf.len(), t1, t2, t3, t4) } +} diff --git a/crates/library2/std/src/lib.fe b/crates/library2/std/src/lib.fe new file mode 100644 index 0000000000..8a94dde71d --- /dev/null +++ b/crates/library2/std/src/lib.fe @@ -0,0 +1,3 @@ +pub fn get_42() -> u256 { + return 42 +} \ No newline at end of file diff --git a/crates/library2/std/src/math.fe b/crates/library2/std/src/math.fe new file mode 100644 index 0000000000..bc37ee6739 --- /dev/null +++ b/crates/library2/std/src/math.fe @@ -0,0 +1,15 @@ +pub fn min(_ x: u256, _ y: u256) -> u256 { + if x < y { + return x + } else { + return y + } +} + +pub fn max(_ x: u256, _ y: u256) -> u256 { + if x > y { + return x + } else { + return y + } +} \ No newline at end of file diff --git a/crates/library2/std/src/precompiles.fe b/crates/library2/std/src/precompiles.fe new file mode 100644 index 0000000000..ba9d59f138 --- /dev/null +++ b/crates/library2/std/src/precompiles.fe @@ -0,0 +1,191 @@ +use ingot::buf::{MemoryBuffer, MemoryBufferWriter, MemoryBufferReader} +use ingot::evm + +enum Precompile { + EcRecover + Sha2256 + Ripemd160 + Identity + ModExp + EcAdd + EcMul + EcPairing + Blake2f + + pub fn addr(self) -> address { + match self { + Precompile::EcRecover => { return 0x01 } + Precompile::Sha2256 => { return 0x02 } + Precompile::Ripemd160 => { return 0x03 } + Precompile::Identity => { return 0x04 } + Precompile::ModExp => { return 0x05 } + Precompile::EcAdd => { return 0x06 } + Precompile::EcMul => { return 0x07 } + Precompile::EcPairing => { return 0x08 } + Precompile::Blake2f => { return 0x09 } + } + } + + pub fn single_buf_call(self, mut buf: MemoryBuffer) { + unsafe { + assert evm::static_call( + gas: evm::gas_remaining(), + addr: self.addr(), + input_offset: buf.offset(), + input_len: buf.len(), + output_offset: buf.offset(), + output_len: buf.len() + ) == 1 + } + } + + pub fn call(self, input: MemoryBuffer, mut output: MemoryBuffer) { + unsafe { + assert evm::static_call( + gas: evm::gas_remaining(), + addr: self.addr(), + input_offset: input.offset(), + input_len: input.len(), + output_offset: output.offset(), + output_len: output.len() + ) == 1 + } + } +} + +/// EC Recover precompile call. +pub fn ec_recover(hash: u256, v: u256, r: u256, s: u256) -> address { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 128) + + let mut writer: MemoryBufferWriter = buf.writer() + writer.write(value: hash) + writer.write(value: v) + writer.write(value: r) + writer.write(value: s) + + Precompile::EcRecover.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return address(reader.read_u256()) +} + +/// SHA2 256 precompile call. +pub fn sha2_256(buf: MemoryBuffer) -> u256 { + let mut output: MemoryBuffer = MemoryBuffer::new(len: 32) + let mut reader: MemoryBufferReader = output.reader() + Precompile::Sha2256.call(input: buf, output) + return reader.read_u256() +} + +/// Ripemd 160 precompile call. +pub fn ripemd_160(buf: MemoryBuffer) -> u256 { + let mut output: MemoryBuffer = MemoryBuffer::new(len: 32) + let mut reader: MemoryBufferReader = output.reader() + Precompile::Ripemd160.call(input: buf, output) + return reader.read_u256() +} + +/// Identity precompile call. +pub fn identity(buf: MemoryBuffer) -> MemoryBuffer { + let mut output: MemoryBuffer = MemoryBuffer::new(len: buf.len()) + Precompile::Identity.call(input: buf, output) + return output +} + +/// Mod exp preocmpile call. +pub fn mod_exp( + b_size: u256, + e_size: u256, + m_size: u256, + b: MemoryBuffer, + e: MemoryBuffer, + m: MemoryBuffer, +) -> MemoryBuffer { + let mut buf: MemoryBuffer = MemoryBuffer::new( + len: 96 + b_size + e_size + m_size + ) + + let mut writer: MemoryBufferWriter = buf.writer() + writer.write(value: b_size) + writer.write(value: e_size) + writer.write(value: m_size) + writer.write_buf(buf: b) + writer.write_buf(buf: e) + writer.write_buf(buf: m) + + Precompile::ModExp.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return reader.read_buf(len: m_size) +} + +/// EC add precompile call. +pub fn ec_add(x1: u256, y1: u256, x2: u256, y2: u256) -> (u256, u256) { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 128) + let mut writer: MemoryBufferWriter = buf.writer() + + writer.write(value: x1) + writer.write(value: y1) + writer.write(value: x2) + writer.write(value: y2) + + Precompile::EcAdd.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return (reader.read_u256(), reader.read_u256()) +} + +/// EC mul precompile call. +pub fn ec_mul(x: u256, y: u256, s: u256) -> (u256, u256) { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 128) + let mut writer: MemoryBufferWriter = buf.writer() + + writer.write(value: x) + writer.write(value: y) + writer.write(value: s) + + Precompile::EcMul.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return (reader.read_u256(), reader.read_u256()) +} + +/// EC pairing precompile call. +pub fn ec_pairing(buf: MemoryBuffer) -> bool { + let mut output: MemoryBuffer = MemoryBuffer::new(len: 32) + let mut reader: MemoryBufferReader = output.reader() + Precompile::EcPairing.call(input: buf, output) + return reader.read_u256() == 1 +} + +/// Blake 2f precompile call. +pub fn blake_2f( + rounds: u32, + h: Array, + m: Array, + t: Array, + f: bool +) -> Array { + let mut buf: MemoryBuffer = MemoryBuffer::new(len: 213) + let mut writer: MemoryBufferWriter = buf.writer() + + writer.write(value: rounds) + for value in h { writer.write(value) } + for value in m { writer.write(value) } + for value in t { writer.write(value) } + writer.write(value: u8(1) if f else u8(0)) + + Precompile::Blake2f.single_buf_call(buf) + + let mut reader: MemoryBufferReader = buf.reader() + return [ + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64(), + reader.read_u64() + ] +} diff --git a/crates/library2/std/src/prelude.fe b/crates/library2/std/src/prelude.fe new file mode 100644 index 0000000000..715ec70cf9 --- /dev/null +++ b/crates/library2/std/src/prelude.fe @@ -0,0 +1 @@ +use ingot::context::Context \ No newline at end of file diff --git a/crates/library2/std/src/traits.fe b/crates/library2/std/src/traits.fe new file mode 100644 index 0000000000..43a0f44727 --- /dev/null +++ b/crates/library2/std/src/traits.fe @@ -0,0 +1,160 @@ +// Dummy trait used in testing. We can remove this once we have more useful traits + +pub trait Dummy {} + +pub trait Min { + fn min() -> Self; +} + +impl Min for u8 { + fn min() -> Self { + return 0 + } +} + +impl Min for u16 { + fn min() -> Self { + return 0 + } +} + +impl Min for u32 { + fn min() -> Self { + return 0 + } +} + +impl Min for u64 { + fn min() -> Self { + return 0 + } +} + +impl Min for u128 { + fn min() -> Self { + return 0 + } +} + +impl Min for u256 { + fn min() -> Self { + return 0 + } +} + +impl Min for i8 { + fn min() -> Self { + return -128 + } +} + +impl Min for i16 { + fn min() -> Self { + return -32768 + } +} + +impl Min for i32 { + fn min() -> Self { + return -2147483648 + } +} + +impl Min for i64 { + fn min() -> Self { + return -9223372036854775808 + } +} + +impl Min for i128 { + fn min() -> Self { + return -170141183460469231731687303715884105728 + } +} + +impl Min for i256 { + fn min() -> Self { + return -57896044618658097711785492504343953926634992332820282019728792003956564819968 + } +} + + + + + + +pub trait Max { + fn max() -> Self; +} + +impl Max for u8 { + fn max() -> Self { + return 255 + } +} + +impl Max for u16 { + fn max() -> Self { + return 65535 + } +} + +impl Max for u32 { + fn max() -> Self { + return 4294967295 + } +} + +impl Max for u64 { + fn max() -> Self { + return 18446744073709551615 + } +} + +impl Max for u128 { + fn max() -> Self { + return 340282366920938463463374607431768211455 + } +} + +impl Max for u256 { + fn max() -> Self { + return 115792089237316195423570985008687907853269984665640564039457584007913129639935 + } +} + +impl Max for i8 { + fn max() -> Self { + return 127 + } +} + +impl Max for i16 { + fn max() -> Self { + return 32767 + } +} + +impl Max for i32 { + fn max() -> Self { + return 2147483647 + } +} + +impl Max for i64 { + fn max() -> Self { + return 9223372036854775807 + } +} + +impl Max for i128 { + fn max() -> Self { + return 170141183460469231731687303715884105727 + } +} + +impl Max for i256 { + fn max() -> Self { + return 57896044618658097711785492504343953926634992332820282019728792003956564819967 + } +} diff --git a/crates/mir2/Cargo.toml b/crates/mir2/Cargo.toml index b72749c6d3..c34010647b 100644 --- a/crates/mir2/Cargo.toml +++ b/crates/mir2/Cargo.toml @@ -23,4 +23,4 @@ indexmap = "1.6.2" [dev-dependencies] test-files = { path = "../test-files", package = "fe-test-files" } -fe-library = { path = "../library" } +fe-library2 = { path = "../library2" } diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 798b9c2cb3..e51d9a0a0e 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -27,12 +27,12 @@ pub struct Jar( pub struct LowerJar(); pub trait MirDb: salsa::DbWithJar + HirDb { - // fn prefill(&self) - // where - // Self: Sized, - // { - // IdentId::prefill(self) - // } + fn prefill(&self) + where + Self: Sized, + { + // IdentId::prefill(self) + } fn as_hir_db(&self) -> &dyn MirDb { >::as_jar_db::<'_>(self) diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index a02b3df486..3cfa020f5c 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -20,9 +20,11 @@ // }; // } +use fe_mir2::LowerMirDb; + // #[test] // fn mir_lower_std_lib() { -// let mut db = NewDb::default(); +// let mut db = LowerMirDb::default(); // // Should return the same id // let std_ingot = IngotId::std_lib(&mut db); diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs new file mode 100644 index 0000000000..f9063d8c66 --- /dev/null +++ b/crates/mir2/tests/test_db.rs @@ -0,0 +1,144 @@ +use std::collections::{BTreeMap, BTreeSet}; + +// use codespan_reporting::{ +// diagnostic::{Diagnostic, Label}, +// files::SimpleFiles, +// term::{ +// self, +// termcolor::{BufferWriter, ColorChoice}, +// }, +// }; +use fe_common2::{ + diagnostics::Span, + input::{IngotKind, Version}, + InputFile, InputIngot, +}; +use fe_hir::hir_def::TopLevelMod; +// use hir::{ +// hir_def::TopLevelMod, +// lower, +// span::{DynLazySpan, LazySpan}, +// HirDb, SpannedHirDb, +// }; +// use rustc_hash::FxHashMap; + +type CodeSpanFileId = usize; + +#[salsa::db( + fe_common2::Jar, + fe_hir::Jar, + fe_hir::SpannedJar, + fe_hir::LowerJar, + fe_hir_analysis::Jar +)] +pub struct LowerMirTestDb { + storage: salsa::Storage, +} + +impl LowerMirTestDb { + pub fn new_stand_alone(&mut self, file_name: &str, text: &str) -> TopLevelMod { + let kind = IngotKind::StandAlone; + let version = Version::new(0, 0, 1); + let ingot = InputIngot::new(self, file_name, kind, version, BTreeSet::default()); + let root = InputFile::new(self, ingot, "test_file.fe".into(), text.to_string()); + ingot.set_root_file(self, root); + ingot.set_files(self, [root].into()); + + // let mut prop_formatter = HirPropertyFormatter::default(); + // let top_mod = self.register_file(&mut prop_formatter, root); + let top_mod = self.register_file(root); + top_mod + } + + fn register_file(&self, input_file: InputFile) -> TopLevelMod { + let top_mod = lower::map_file_to_mod(self, input_file); + let path = input_file.path(self); + let text = input_file.text(self); + // prop_formatter.register_top_mod(path.as_str(), text, top_mod); + top_mod + } +} + +impl Default for LowerMirTestDb { + fn default() -> Self { + let db = Self { + storage: Default::default(), + }; + // db.prefill(); + db + } +} + +// pub struct HirPropertyFormatter { +// properties: BTreeMap>, +// top_mod_to_file: FxHashMap, +// code_span_files: SimpleFiles, +// } + +// impl HirPropertyFormatter { +// pub fn push_prop(&mut self, top_mod: TopLevelMod, span: DynLazySpan, prop: String) { +// self.properties +// .entry(top_mod) +// .or_default() +// .push((prop, span)); +// } + +// pub fn finish(&mut self, db: &dyn SpannedHirDb) -> String { +// let writer = BufferWriter::stderr(ColorChoice::Never); +// let mut buffer = writer.buffer(); +// let config = term::Config::default(); + +// for top_mod in self.top_mod_to_file.keys() { +// if !self.properties.contains_key(top_mod) { +// continue; +// } + +// let diags = self.properties[top_mod] +// .iter() +// .map(|(prop, span)| { +// let (span, diag) = self.property_to_diag(db, *top_mod, prop, span.clone()); +// ((span.file, span.range.start()), diag) +// }) +// .collect::>(); + +// for diag in diags.values() { +// term::emit(&mut buffer, &config, &self.code_span_files, diag).unwrap(); +// } +// } + +// std::str::from_utf8(buffer.as_slice()).unwrap().to_string() +// } + +// fn property_to_diag( +// &self, +// db: &dyn SpannedHirDb, +// top_mod: TopLevelMod, +// prop: &str, +// span: DynLazySpan, +// ) -> (Span, Diagnostic) { +// let file_id = self.top_mod_to_file[&top_mod]; +// let span = span.resolve(db).unwrap(); +// let diag = Diagnostic::note() +// .with_labels(vec![Label::primary(file_id, span.range).with_message(prop)]); +// (span, diag) +// } + +// fn register_top_mod(&mut self, path: &str, text: &str, top_mod: TopLevelMod) { +// let file_id = self.code_span_files.add(path.to_string(), text.to_string()); +// self.top_mod_to_file.insert(top_mod, file_id); +// } +// } + +// impl Default for HirPropertyFormatter { +// fn default() -> Self { +// Self { +// properties: Default::default(), +// top_mod_to_file: Default::default(), +// code_span_files: SimpleFiles::new(), +// } +// } +// } + +impl salsa::Database for LowerMirTestDb { + fn salsa_event(&self, _: salsa::Event) {} +} From 52229c9d273de023a047000aff2ae0e8efa4d8d5 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Thu, 11 Jan 2024 09:55:22 -0700 Subject: [PATCH 06/22] hacking --- crates/mir2/src/lib.rs | 2 +- crates/mir2/src/lower/mod.rs | 8 ++++--- crates/mir2/tests/lowering.rs | 44 ++++++++++++++++++----------------- crates/mir2/tests/test_db.rs | 16 ++++++------- 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index e51d9a0a0e..f9e76669b9 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -5,7 +5,7 @@ use fe_hir::HirDb; pub mod ir; // pub mod pretty_print; -// mod lower; +mod lower; #[salsa::jar(db = MirDb)] pub struct Jar( diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs index 36e43653a6..acfc941b85 100644 --- a/crates/mir2/src/lower/mod.rs +++ b/crates/mir2/src/lower/mod.rs @@ -1,4 +1,6 @@ -pub mod function; -pub mod types; +// pub mod function; +// pub mod types; -mod pattern_match; +// mod pattern_match; + +pub fn lower() {} diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index 3cfa020f5c..4853ade4b8 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -1,3 +1,7 @@ +use test_db::LowerMirTestDb; + +mod test_db; + // macro_rules! test_lowering { // ($name:ident, $path:expr) => { // #[test] @@ -20,30 +24,28 @@ // }; // } -use fe_mir2::LowerMirDb; - -// #[test] -// fn mir_lower_std_lib() { -// let mut db = LowerMirDb::default(); +#[test] +fn mir_lower_std_lib() { + let mut db = LowerMirTestDb::default(); -// // Should return the same id -// let std_ingot = IngotId::std_lib(&mut db); + // Should return the same id + let std_ingot = IngotId::std_lib(&mut db); -// let diags = std_ingot.diagnostics(&db); -// if !diags.is_empty() { -// panic!("std lib analysis failed") -// } + // let diags = std_ingot.diagnostics(&db); + // if !diags.is_empty() { + // panic!("std lib analysis failed") + // } -// for &module in std_ingot.all_modules(db.upcast()).iter() { -// for func in db.mir_lower_module_all_functions(module).iter() { -// let body = func.body(&db); -// let cfg = ControlFlowGraph::compute(&body); -// let domtree = DomTree::compute(&cfg); -// LoopTree::compute(&cfg, &domtree); -// PostDomTree::compute(&body); -// } -// } -// } + // for &module in std_ingot.all_modules(db.upcast()).iter() { + // for func in db.mir_lower_module_all_functions(module).iter() { + // let body = func.body(&db); + // let cfg = ControlFlowGraph::compute(&body); + // let domtree = DomTree::compute(&cfg); + // LoopTree::compute(&cfg, &domtree); + // PostDomTree::compute(&body); + // } + // } +} // test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} // test_lowering! { mir_guest_book, "demos/guest_book.fe"} diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs index f9063d8c66..730f048334 100644 --- a/crates/mir2/tests/test_db.rs +++ b/crates/mir2/tests/test_db.rs @@ -36,7 +36,7 @@ pub struct LowerMirTestDb { } impl LowerMirTestDb { - pub fn new_stand_alone(&mut self, file_name: &str, text: &str) -> TopLevelMod { + pub fn new_stand_alone(&mut self, file_name: &str, text: &str) { let kind = IngotKind::StandAlone; let version = Version::new(0, 0, 1); let ingot = InputIngot::new(self, file_name, kind, version, BTreeSet::default()); @@ -46,16 +46,16 @@ impl LowerMirTestDb { // let mut prop_formatter = HirPropertyFormatter::default(); // let top_mod = self.register_file(&mut prop_formatter, root); - let top_mod = self.register_file(root); - top_mod + // let top_mod = self.register_file(root); + // top_mod } - fn register_file(&self, input_file: InputFile) -> TopLevelMod { - let top_mod = lower::map_file_to_mod(self, input_file); - let path = input_file.path(self); - let text = input_file.text(self); + fn register_file(&self, input_file: InputFile) { + // let top_mod = lower::map_file_to_mod(self, input_file); + // let path = input_file.path(self); + // let text = input_file.text(self); // prop_formatter.register_top_mod(path.as_str(), text, top_mod); - top_mod + // top_mod } } From 20221326df4a05d764c0d752ddb4ac995247c2c8 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Sat, 13 Jan 2024 15:57:31 -0700 Subject: [PATCH 07/22] hacking --- crates/mir2/tests/lowering.rs | 15 +++++++++------ crates/mir2/tests/test_db.rs | 25 ++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index 4853ade4b8..b0e6386c5c 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -1,4 +1,5 @@ -use test_db::LowerMirTestDb; +use fe_hir::hir_def::IngotId; +use test_db::{initialize_analysis_pass, LowerMirTestDb}; mod test_db; @@ -29,12 +30,14 @@ fn mir_lower_std_lib() { let mut db = LowerMirTestDb::default(); // Should return the same id - let std_ingot = IngotId::std_lib(&mut db); + let std_ingot = IngotId::dummy(); - // let diags = std_ingot.diagnostics(&db); - // if !diags.is_empty() { - // panic!("std lib analysis failed") - // } + let mut pm = initialize_analysis_pass(&db); + let diags = pm.run_on_module(std_ingot.root_mod(&db)); + + if !diags.is_empty() { + panic!("std lib analysis failed") + } // for &module in std_ingot.all_modules(db.upcast()).iter() { // for func in db.mir_lower_module_all_functions(module).iter() { diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs index 730f048334..9d8dae22ba 100644 --- a/crates/mir2/tests/test_db.rs +++ b/crates/mir2/tests/test_db.rs @@ -13,7 +13,15 @@ use fe_common2::{ input::{IngotKind, Version}, InputFile, InputIngot, }; -use fe_hir::hir_def::TopLevelMod; +use fe_hir::{analysis_pass::AnalysisPassManager, hir_def::TopLevelMod, ParsingPass}; +use fe_hir_analysis::{ + name_resolution::{DefConflictAnalysisPass, ImportAnalysisPass, PathAnalysisPass}, + ty::{ + FuncAnalysisPass, ImplAnalysisPass, ImplTraitAnalysisPass, TraitAnalysisPass, + TypeAliasAnalysisPass, TypeDefAnalysisPass, + }, +}; +use fe_mir2::LowerMirDb; // use hir::{ // hir_def::TopLevelMod, // lower, @@ -142,3 +150,18 @@ impl Default for LowerMirTestDb { impl salsa::Database for LowerMirTestDb { fn salsa_event(&self, _: salsa::Event) {} } + +pub fn initialize_analysis_pass(db: &LowerMirTestDb) -> AnalysisPassManager<'_> { + let mut pass_manager = AnalysisPassManager::new(); + pass_manager.add_module_pass(Box::new(ParsingPass::new(db))); + pass_manager.add_module_pass(Box::new(DefConflictAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(ImportAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(PathAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(TypeDefAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(TypeAliasAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(TraitAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(ImplAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(ImplTraitAnalysisPass::new(db))); + pass_manager.add_module_pass(Box::new(FuncAnalysisPass::new(db))); + pass_manager +} From 63713bc31a7d880ec4048d7d8fe586a37bad6733 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Mon, 15 Jan 2024 09:46:26 -0700 Subject: [PATCH 08/22] hacking --- Cargo.lock | 1 + crates/library2/Cargo.toml | 1 + crates/library2/src/lib.rs | 58 +++++++++++++++++++++++------------ crates/mir2/tests/lowering.rs | 18 +++++------ 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 54f8ad9de8..13345c2fc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1104,6 +1104,7 @@ dependencies = [ name = "fe-library2" version = "0.23.0" dependencies = [ + "fe-common2", "include_dir", ] diff --git a/crates/library2/Cargo.toml b/crates/library2/Cargo.toml index 1fb8ff4bc0..d2ddd0ed7d 100644 --- a/crates/library2/Cargo.toml +++ b/crates/library2/Cargo.toml @@ -8,3 +8,4 @@ repository = "https://github.com/ethereum/fe" [dependencies] include_dir = "0.7.2" +common = { path = "../common2", package = "fe-common2" } diff --git a/crates/library2/src/lib.rs b/crates/library2/src/lib.rs index f905973fb4..761ba8986c 100644 --- a/crates/library2/src/lib.rs +++ b/crates/library2/src/lib.rs @@ -1,27 +1,47 @@ +use std::collections::{BTreeMap, BTreeSet}; + pub use ::include_dir; +use common::{ + input::{IngotKind, Version}, + InputDb, InputFile, InputIngot, +}; use include_dir::{include_dir, Dir}; pub const STD: Dir = include_dir!("$CARGO_MANIFEST_DIR/std"); -pub fn std_src_files() -> Vec<(&'static str, &'static str)> { - static_dir_files(STD.get_dir("src").unwrap()) +fn std_src_input_files() -> BTreeSet { + if let Some(dir) = STD.get_dir("src") {} } -pub fn static_dir_files(dir: &'static Dir) -> Vec<(&'static str, &'static str)> { - fn add_files(dir: &'static Dir, accum: &mut Vec<(&'static str, &'static str)>) { - accum.extend(dir.files().map(|file| { - ( - file.path().to_str().unwrap(), - file.contents_utf8().expect("non-utf8 static file"), - ) - })); - - for sub_dir in dir.dirs() { - add_files(sub_dir, accum) - } - } - - let mut files = vec![]; - add_files(dir, &mut files); - files +pub fn std_lib_input_ingot(db: &dyn InputDb) -> InputIngot { + InputIngot::new( + db, + "std", + IngotKind::Std, + Version::new(0, 0, 0), + BTreeSet::default(), + ) } + +// pub fn std_src_files() -> Vec<(&'static str, &'static str)> { +// static_dir_files(STD.get_dir("src").unwrap()) +// } + +// pub fn static_dir_files(dir: &'static Dir) -> Vec<(&'static str, &'static str)> { +// fn add_files(dir: &'static Dir, accum: &mut Vec<(&'static str, &'static str)>) { +// accum.extend(dir.files().map(|file| { +// ( +// file.path().to_str().unwrap(), +// file.contents_utf8().expect("non-utf8 static file"), +// ) +// })); + +// for sub_dir in dir.dirs() { +// add_files(sub_dir, accum) +// } +// } + +// let mut files = vec![]; +// add_files(dir, &mut files); +// files +// } diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index b0e6386c5c..e7ffe3b44a 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -39,15 +39,15 @@ fn mir_lower_std_lib() { panic!("std lib analysis failed") } - // for &module in std_ingot.all_modules(db.upcast()).iter() { - // for func in db.mir_lower_module_all_functions(module).iter() { - // let body = func.body(&db); - // let cfg = ControlFlowGraph::compute(&body); - // let domtree = DomTree::compute(&cfg); - // LoopTree::compute(&cfg, &domtree); - // PostDomTree::compute(&body); - // } - // } + for &module in std_ingot.all_modules(&db).iter() { + // for func in db.mir_lower_module_all_functions(module).iter() { + // let body = func.body(&db); + // let cfg = ControlFlowGraph::compute(&body); + // let domtree = DomTree::compute(&cfg); + // LoopTree::compute(&cfg, &domtree); + // PostDomTree::compute(&body); + // } + } } // test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} From 3d3849e01fb09df27818c60378d6e236bf2d1d05 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Wed, 17 Jan 2024 13:18:08 -0700 Subject: [PATCH 09/22] hacking --- crates/library2/src/lib.rs | 53 +++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/crates/library2/src/lib.rs b/crates/library2/src/lib.rs index 761ba8986c..d222b5a70d 100644 --- a/crates/library2/src/lib.rs +++ b/crates/library2/src/lib.rs @@ -9,39 +9,46 @@ use include_dir::{include_dir, Dir}; pub const STD: Dir = include_dir!("$CARGO_MANIFEST_DIR/std"); -fn std_src_input_files() -> BTreeSet { - if let Some(dir) = STD.get_dir("src") {} +fn std_src_input_files(db: &mut dyn InputDb, ingot: InputIngot) -> BTreeSet { + static_dir_files(&STD) + .into_iter() + .map(|(path, content)| InputFile::new(db, ingot, path.into(), content.into())) + .collect() } -pub fn std_lib_input_ingot(db: &dyn InputDb) -> InputIngot { - InputIngot::new( +pub fn std_lib_input_ingot(db: &mut dyn InputDb) -> InputIngot { + let ingot = InputIngot::new( db, "std", IngotKind::Std, Version::new(0, 0, 0), BTreeSet::default(), - ) + ); + + let input_files = std_src_input_files(db, ingot); + ingot.set_files(db, input_files); + ingot } // pub fn std_src_files() -> Vec<(&'static str, &'static str)> { // static_dir_files(STD.get_dir("src").unwrap()) // } -// pub fn static_dir_files(dir: &'static Dir) -> Vec<(&'static str, &'static str)> { -// fn add_files(dir: &'static Dir, accum: &mut Vec<(&'static str, &'static str)>) { -// accum.extend(dir.files().map(|file| { -// ( -// file.path().to_str().unwrap(), -// file.contents_utf8().expect("non-utf8 static file"), -// ) -// })); - -// for sub_dir in dir.dirs() { -// add_files(sub_dir, accum) -// } -// } - -// let mut files = vec![]; -// add_files(dir, &mut files); -// files -// } +pub fn static_dir_files(dir: &'static Dir) -> Vec<(&'static str, &'static str)> { + fn add_files(dir: &'static Dir, accum: &mut Vec<(&'static str, &'static str)>) { + accum.extend(dir.files().map(|file| { + ( + file.path().to_str().unwrap(), + file.contents_utf8().expect("non-utf8 static file"), + ) + })); + + for sub_dir in dir.dirs() { + add_files(sub_dir, accum) + } + } + + let mut files = vec![]; + add_files(dir, &mut files); + files +} From 5f3c31380b5630eeab0b0e7e915a39b342fe6ec3 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Fri, 19 Jan 2024 14:00:01 -0700 Subject: [PATCH 10/22] hacking --- crates/mir2/Cargo.toml | 10 +++++----- crates/mir2/src/lib.rs | 2 +- crates/mir2/src/lower/mod.rs | 5 ----- crates/mir2/tests/lowering.rs | 35 +++++++++++++++++------------------ crates/mir2/tests/test_db.rs | 23 ++++++++++++++--------- 5 files changed, 37 insertions(+), 38 deletions(-) diff --git a/crates/mir2/Cargo.toml b/crates/mir2/Cargo.toml index c34010647b..941bf8c549 100644 --- a/crates/mir2/Cargo.toml +++ b/crates/mir2/Cargo.toml @@ -7,10 +7,10 @@ license = "Apache-2.0" repository = "https://github.com/ethereum/fe" [dependencies] -fe-common2 = { path = "../common2", version = "^0.23.0" } -fe-parser2 = { path = "../parser2", version = "^0.23.0" } -fe-hir-analysis = { path = "../hir-analysis", version = "^0.23.0" } -fe-hir = { path = "../hir", version = "^0.23.0" } +common = { path = "../common2", package = "fe-common2" } +parser = { path = "../parser2", package = "fe-parser2" } +hir-analysis = { path = "../hir-analysis", package = "fe-hir-analysis" } +hir = { path = "../hir", package = "fe-hir" } salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } smol_str = "0.1.21" num-bigint = "0.4.3" @@ -23,4 +23,4 @@ indexmap = "1.6.2" [dev-dependencies] test-files = { path = "../test-files", package = "fe-test-files" } -fe-library2 = { path = "../library2" } +library = { path = "../library2" , package = "fe-library2"} diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index f9e76669b9..562d59cad8 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -1,4 +1,4 @@ -use fe_hir::HirDb; +use hir::HirDb; // pub mod analysis; // pub mod graphviz; diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs index acfc941b85..8b13789179 100644 --- a/crates/mir2/src/lower/mod.rs +++ b/crates/mir2/src/lower/mod.rs @@ -1,6 +1 @@ -// pub mod function; -// pub mod types; -// mod pattern_match; - -pub fn lower() {} diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index e7ffe3b44a..bab77eff85 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -1,4 +1,5 @@ -use fe_hir::hir_def::IngotId; +use common::InputDb; +use hir::hir_def::IngotId; use test_db::{initialize_analysis_pass, LowerMirTestDb}; mod test_db; @@ -28,26 +29,24 @@ mod test_db; #[test] fn mir_lower_std_lib() { let mut db = LowerMirTestDb::default(); + let top_mod = db.new_std_lib(); - // Should return the same id - let std_ingot = IngotId::dummy(); + // let mut pm = initialize_analysis_pass(&db); + // let diags = pm.run_on_module(std_ingot.root_mod(&db)); - let mut pm = initialize_analysis_pass(&db); - let diags = pm.run_on_module(std_ingot.root_mod(&db)); + // if !diags.is_empty() { + // panic!("std lib analysis failed") + // } - if !diags.is_empty() { - panic!("std lib analysis failed") - } - - for &module in std_ingot.all_modules(&db).iter() { - // for func in db.mir_lower_module_all_functions(module).iter() { - // let body = func.body(&db); - // let cfg = ControlFlowGraph::compute(&body); - // let domtree = DomTree::compute(&cfg); - // LoopTree::compute(&cfg, &domtree); - // PostDomTree::compute(&body); - // } - } + // for &module in std_ingot.all_modules(&db).iter() { + // for func in db.mir_lower_module_all_functions(module).iter() { + // let body = func.body(&db); + // let cfg = ControlFlowGraph::compute(&body); + // let domtree = DomTree::compute(&cfg); + // LoopTree::compute(&cfg, &domtree); + // PostDomTree::compute(&body); + // } + // } } // test_lowering! { mir_erc20_token, "demos/erc20_token.fe"} diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs index 9d8dae22ba..3b64ec0fe4 100644 --- a/crates/mir2/tests/test_db.rs +++ b/crates/mir2/tests/test_db.rs @@ -8,20 +8,19 @@ use std::collections::{BTreeMap, BTreeSet}; // termcolor::{BufferWriter, ColorChoice}, // }, // }; -use fe_common2::{ +use common::{ diagnostics::Span, input::{IngotKind, Version}, InputFile, InputIngot, }; -use fe_hir::{analysis_pass::AnalysisPassManager, hir_def::TopLevelMod, ParsingPass}; -use fe_hir_analysis::{ +use hir::{analysis_pass::AnalysisPassManager, hir_def::TopLevelMod, ParsingPass}; +use hir_analysis::{ name_resolution::{DefConflictAnalysisPass, ImportAnalysisPass, PathAnalysisPass}, ty::{ FuncAnalysisPass, ImplAnalysisPass, ImplTraitAnalysisPass, TraitAnalysisPass, TypeAliasAnalysisPass, TypeDefAnalysisPass, }, }; -use fe_mir2::LowerMirDb; // use hir::{ // hir_def::TopLevelMod, // lower, @@ -33,11 +32,11 @@ use fe_mir2::LowerMirDb; type CodeSpanFileId = usize; #[salsa::db( - fe_common2::Jar, - fe_hir::Jar, - fe_hir::SpannedJar, - fe_hir::LowerJar, - fe_hir_analysis::Jar + common::Jar, + hir::Jar, + hir::SpannedJar, + hir::LowerJar, + hir_analysis::Jar )] pub struct LowerMirTestDb { storage: salsa::Storage, @@ -52,12 +51,18 @@ impl LowerMirTestDb { ingot.set_root_file(self, root); ingot.set_files(self, [root].into()); + // let top_mod = lower::map_file_to_mod(self, input_file); + // let mut prop_formatter = HirPropertyFormatter::default(); // let top_mod = self.register_file(&mut prop_formatter, root); // let top_mod = self.register_file(root); // top_mod } + pub fn new_std_lib(&mut self) { + library::std_lib_input_ingot(self); + } + fn register_file(&self, input_file: InputFile) { // let top_mod = lower::map_file_to_mod(self, input_file); // let path = input_file.path(self); From 181679229552d240c702f08294dc3c6d9b14de69 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Mon, 22 Jan 2024 09:14:48 -0700 Subject: [PATCH 11/22] hacking --- crates/mir2/tests/lowering.rs | 4 ++-- crates/mir2/tests/test_db.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index bab77eff85..6063e7f4ea 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -31,8 +31,8 @@ fn mir_lower_std_lib() { let mut db = LowerMirTestDb::default(); let top_mod = db.new_std_lib(); - // let mut pm = initialize_analysis_pass(&db); - // let diags = pm.run_on_module(std_ingot.root_mod(&db)); + let mut pm = initialize_analysis_pass(&db); + let diags = pm.run_on_module(std_ingot.root_mod(&db)); // if !diags.is_empty() { // panic!("std lib analysis failed") diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs index 3b64ec0fe4..7c2bdb5623 100644 --- a/crates/mir2/tests/test_db.rs +++ b/crates/mir2/tests/test_db.rs @@ -51,7 +51,7 @@ impl LowerMirTestDb { ingot.set_root_file(self, root); ingot.set_files(self, [root].into()); - // let top_mod = lower::map_file_to_mod(self, input_file); + let top_mod = lower::map_file_to_mod(self, input_file); // let mut prop_formatter = HirPropertyFormatter::default(); // let top_mod = self.register_file(&mut prop_formatter, root); From 8f7eb448a5b8880d4fda9997c9f6bcd9db1eeb8e Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Mon, 22 Jan 2024 21:07:55 -0700 Subject: [PATCH 12/22] hacking --- crates/codegen2/Cargo.toml | 18 + crates/codegen2/src/db.rs | 94 ++ crates/codegen2/src/db/queries.rs | 5 + crates/codegen2/src/db/queries/abi.rs | 273 ++++ crates/codegen2/src/db/queries/constant.rs | 12 + crates/codegen2/src/db/queries/contract.rs | 20 + crates/codegen2/src/db/queries/function.rs | 76 + crates/codegen2/src/db/queries/types.rs | 102 ++ crates/codegen2/src/lib.rs | 2 + crates/codegen2/src/yul/isel/context.rs | 81 + crates/codegen2/src/yul/isel/contract.rs | 289 ++++ crates/codegen2/src/yul/isel/function.rs | 978 ++++++++++++ crates/codegen2/src/yul/isel/inst_order.rs | 1368 +++++++++++++++++ crates/codegen2/src/yul/isel/mod.rs | 9 + crates/codegen2/src/yul/isel/test.rs | 70 + crates/codegen2/src/yul/legalize/body.rs | 219 +++ .../src/yul/legalize/critical_edge.rs | 121 ++ crates/codegen2/src/yul/legalize/mod.rs | 6 + crates/codegen2/src/yul/legalize/signature.rs | 27 + crates/codegen2/src/yul/mod.rs | 26 + crates/codegen2/src/yul/runtime/abi.rs | 950 ++++++++++++ crates/codegen2/src/yul/runtime/contract.rs | 127 ++ crates/codegen2/src/yul/runtime/data.rs | 461 ++++++ crates/codegen2/src/yul/runtime/emit.rs | 74 + crates/codegen2/src/yul/runtime/mod.rs | 828 ++++++++++ crates/codegen2/src/yul/runtime/revert.rs | 91 ++ crates/codegen2/src/yul/runtime/safe_math.rs | 628 ++++++++ crates/codegen2/src/yul/slot_size.rs | 16 + 28 files changed, 6971 insertions(+) create mode 100644 crates/codegen2/Cargo.toml create mode 100644 crates/codegen2/src/db.rs create mode 100644 crates/codegen2/src/db/queries.rs create mode 100644 crates/codegen2/src/db/queries/abi.rs create mode 100644 crates/codegen2/src/db/queries/constant.rs create mode 100644 crates/codegen2/src/db/queries/contract.rs create mode 100644 crates/codegen2/src/db/queries/function.rs create mode 100644 crates/codegen2/src/db/queries/types.rs create mode 100644 crates/codegen2/src/lib.rs create mode 100644 crates/codegen2/src/yul/isel/context.rs create mode 100644 crates/codegen2/src/yul/isel/contract.rs create mode 100644 crates/codegen2/src/yul/isel/function.rs create mode 100644 crates/codegen2/src/yul/isel/inst_order.rs create mode 100644 crates/codegen2/src/yul/isel/mod.rs create mode 100644 crates/codegen2/src/yul/isel/test.rs create mode 100644 crates/codegen2/src/yul/legalize/body.rs create mode 100644 crates/codegen2/src/yul/legalize/critical_edge.rs create mode 100644 crates/codegen2/src/yul/legalize/mod.rs create mode 100644 crates/codegen2/src/yul/legalize/signature.rs create mode 100644 crates/codegen2/src/yul/mod.rs create mode 100644 crates/codegen2/src/yul/runtime/abi.rs create mode 100644 crates/codegen2/src/yul/runtime/contract.rs create mode 100644 crates/codegen2/src/yul/runtime/data.rs create mode 100644 crates/codegen2/src/yul/runtime/emit.rs create mode 100644 crates/codegen2/src/yul/runtime/mod.rs create mode 100644 crates/codegen2/src/yul/runtime/revert.rs create mode 100644 crates/codegen2/src/yul/runtime/safe_math.rs create mode 100644 crates/codegen2/src/yul/slot_size.rs diff --git a/crates/codegen2/Cargo.toml b/crates/codegen2/Cargo.toml new file mode 100644 index 0000000000..d6b46fb289 --- /dev/null +++ b/crates/codegen2/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "fe-codegen2" +version = "0.23.0" +authors = ["The Fe Developers "] +edition = "2021" + +[dependencies] +hir-analysis = { path = "../hir-analysis", package = "fe-hir-analysis" } +hir = { path = "../hir", package = "fe-hir" } +fe-mir = { path = "../mir", version = "^0.23.0"} +fe-common = { path = "../common", version = "^0.23.0"} +fe-abi = { path = "../abi", version = "^0.23.0"} +salsa = "0.16.1" +num-bigint = "0.4.3" +fxhash = "0.2.1" +indexmap = "1.6.2" +smol_str = "0.1.21" +yultsur = { git = "https://github.com/fe-lang/yultsur", rev = "ae85470" } diff --git a/crates/codegen2/src/db.rs b/crates/codegen2/src/db.rs new file mode 100644 index 0000000000..ce036795a3 --- /dev/null +++ b/crates/codegen2/src/db.rs @@ -0,0 +1,94 @@ +#![allow(clippy::arc_with_non_send_sync)] +use std::rc::Rc; + +use fe_abi::{contract::AbiContract, event::AbiEvent, function::AbiFunction, types::AbiType}; +use fe_analyzer::{db::AnalyzerDbStorage, namespace::items::ContractId, AnalyzerDb}; +use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; +use fe_mir::{ + db::{MirDb, MirDbStorage}, + ir::{FunctionBody, FunctionId, FunctionSignature, TypeId}, +}; + +mod queries; + +#[salsa::query_group(CodegenDbStorage)] +pub trait CodegenDb: MirDb + Upcast + UpcastMut { + #[salsa::invoke(queries::function::legalized_signature)] + fn codegen_legalized_signature(&self, function_id: FunctionId) -> Rc; + #[salsa::invoke(queries::function::legalized_body)] + fn codegen_legalized_body(&self, function_id: FunctionId) -> Rc; + #[salsa::invoke(queries::function::symbol_name)] + fn codegen_function_symbol_name(&self, function_id: FunctionId) -> Rc; + + #[salsa::invoke(queries::types::legalized_type)] + fn codegen_legalized_type(&self, ty: TypeId) -> TypeId; + + #[salsa::invoke(queries::abi::abi_type)] + fn codegen_abi_type(&self, ty: TypeId) -> AbiType; + #[salsa::invoke(queries::abi::abi_function)] + fn codegen_abi_function(&self, function_id: FunctionId) -> AbiFunction; + #[salsa::invoke(queries::abi::abi_event)] + fn codegen_abi_event(&self, ty: TypeId) -> AbiEvent; + #[salsa::invoke(queries::abi::abi_contract)] + fn codegen_abi_contract(&self, contract: ContractId) -> AbiContract; + #[salsa::invoke(queries::abi::abi_type_maximum_size)] + fn codegen_abi_type_maximum_size(&self, ty: TypeId) -> usize; + #[salsa::invoke(queries::abi::abi_type_minimum_size)] + fn codegen_abi_type_minimum_size(&self, ty: TypeId) -> usize; + #[salsa::invoke(queries::abi::abi_function_argument_maximum_size)] + fn codegen_abi_function_argument_maximum_size(&self, contract: FunctionId) -> usize; + #[salsa::invoke(queries::abi::abi_function_return_maximum_size)] + fn codegen_abi_function_return_maximum_size(&self, function: FunctionId) -> usize; + + #[salsa::invoke(queries::contract::symbol_name)] + fn codegen_contract_symbol_name(&self, contract: ContractId) -> Rc; + #[salsa::invoke(queries::contract::deployer_symbol_name)] + fn codegen_contract_deployer_symbol_name(&self, contract: ContractId) -> Rc; + + #[salsa::invoke(queries::constant::string_symbol_name)] + fn codegen_constant_string_symbol_name(&self, data: String) -> Rc; +} + +// TODO: Move this to driver. +#[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage, CodegenDbStorage)] +#[derive(Default)] +pub struct Db { + storage: salsa::Storage, +} +impl salsa::Database for Db {} + +impl Upcast for Db { + fn upcast(&self) -> &(dyn MirDb + 'static) { + self + } +} + +impl UpcastMut for Db { + fn upcast_mut(&mut self) -> &mut (dyn MirDb + 'static) { + &mut *self + } +} + +impl Upcast for Db { + fn upcast(&self) -> &(dyn SourceDb + 'static) { + self + } +} + +impl UpcastMut for Db { + fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { + &mut *self + } +} + +impl Upcast for Db { + fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { + self + } +} + +impl UpcastMut for Db { + fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { + &mut *self + } +} diff --git a/crates/codegen2/src/db/queries.rs b/crates/codegen2/src/db/queries.rs new file mode 100644 index 0000000000..31cca43870 --- /dev/null +++ b/crates/codegen2/src/db/queries.rs @@ -0,0 +1,5 @@ +pub mod abi; +pub mod constant; +pub mod contract; +pub mod function; +pub mod types; diff --git a/crates/codegen2/src/db/queries/abi.rs b/crates/codegen2/src/db/queries/abi.rs new file mode 100644 index 0000000000..e492166e86 --- /dev/null +++ b/crates/codegen2/src/db/queries/abi.rs @@ -0,0 +1,273 @@ +use fe_abi::{ + contract::AbiContract, + event::{AbiEvent, AbiEventField}, + function::{AbiFunction, AbiFunctionType, CtxParam, SelfParam, StateMutability}, + types::{AbiTupleField, AbiType}, +}; +use fe_analyzer::{ + constants::INDEXED, + namespace::{ + items::ContractId, + types::{CtxDecl, SelfDecl}, + }, +}; +use fe_mir::ir::{self, FunctionId, TypeId}; + +use crate::db::CodegenDb; + +pub fn abi_contract(db: &dyn CodegenDb, contract: ContractId) -> AbiContract { + let mut funcs = vec![]; + + if let Some(init) = contract.init_function(db.upcast()) { + let init_func = db.mir_lowered_func_signature(init); + let init_abi = db.codegen_abi_function(init_func); + funcs.push(init_abi); + } + + for &func in contract.all_functions(db.upcast()).as_ref() { + let mir_func = db.mir_lowered_func_signature(func); + if mir_func.linkage(db.upcast()).is_exported() { + let func_abi = db.codegen_abi_function(mir_func); + funcs.push(func_abi); + } + } + + let mut events = vec![]; + for &s in db.module_structs(contract.module(db.upcast())).as_ref() { + let struct_ty = s.as_type(db.upcast()); + // TODO: This is a hack to avoid generating an ABI for non-`emittable` structs. + if struct_ty.is_emittable(db.upcast()) { + let mir_event = db.mir_lowered_type(struct_ty); + let event = db.codegen_abi_event(mir_event); + events.push(event); + } + } + + AbiContract::new(funcs, events) +} + +pub fn abi_function(db: &dyn CodegenDb, function: FunctionId) -> AbiFunction { + // We use a legalized signature. + let sig = db.codegen_legalized_signature(function); + + let name = function.name(db.upcast()); + let args = sig + .params + .iter() + .map(|param| (param.name.to_string(), db.codegen_abi_type(param.ty))) + .collect(); + let ret_ty = sig.return_type.map(|ty| db.codegen_abi_type(ty)); + + let func_type = if function.is_contract_init(db.upcast()) { + AbiFunctionType::Constructor + } else { + AbiFunctionType::Function + }; + + // The "stateMutability" field is derived from the presence & mutability of + // `self` and `ctx` params in the analyzer fn sig. + let analyzer_sig = sig.analyzer_func_id.signature(db.upcast()); + let self_param = match analyzer_sig.self_decl { + None => SelfParam::None, + Some(SelfDecl { mut_: None, .. }) => SelfParam::Imm, + Some(SelfDecl { mut_: Some(_), .. }) => SelfParam::Mut, + }; + let ctx_param = match analyzer_sig.ctx_decl { + None => CtxParam::None, + Some(CtxDecl { mut_: None, .. }) => CtxParam::Imm, + Some(CtxDecl { mut_: Some(_), .. }) => CtxParam::Mut, + }; + + let state_mutability = if name == "__init__" { + StateMutability::Payable + } else { + StateMutability::from_self_and_ctx_params(self_param, ctx_param) + }; + + AbiFunction::new(func_type, name.to_string(), args, ret_ty, state_mutability) +} + +pub fn abi_function_argument_maximum_size(db: &dyn CodegenDb, function: FunctionId) -> usize { + let sig = db.codegen_legalized_signature(function); + sig.params.iter().fold(0, |acc, param| { + acc + db.codegen_abi_type_maximum_size(param.ty) + }) +} + +pub fn abi_function_return_maximum_size(db: &dyn CodegenDb, function: FunctionId) -> usize { + let sig = db.codegen_legalized_signature(function); + sig.return_type + .map(|ty| db.codegen_abi_type_maximum_size(ty)) + .unwrap_or_default() +} + +pub fn abi_type_maximum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { + let abi_type = db.codegen_abi_type(ty); + if abi_type.is_static() { + abi_type.header_size() + } else { + match &ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) if def.elem_ty.data(db.upcast()).kind == ir::TypeKind::U8 => { + debug_assert_eq!(abi_type, AbiType::Bytes); + 64 + ceil_32(def.len) + } + + ir::TypeKind::Array(def) => { + db.codegen_abi_type_maximum_size(def.elem_ty) * def.len + 32 + } + + ir::TypeKind::String(len) => abi_type.header_size() + 32 + ceil_32(*len), + _ if ty.is_aggregate(db.upcast()) => { + let mut maximum = 0; + for i in 0..ty.aggregate_field_num(db.upcast()) { + let field_ty = ty.projection_ty_imm(db.upcast(), i); + maximum += db.codegen_abi_type_maximum_size(field_ty) + } + maximum + 32 + } + ir::TypeKind::MPtr(ty) => abi_type_maximum_size(db, ty.deref(db.upcast())), + + _ => unreachable!(), + } + } +} + +pub fn abi_type_minimum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { + let abi_type = db.codegen_abi_type(ty); + if abi_type.is_static() { + abi_type.header_size() + } else { + match &ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) if def.elem_ty.data(db.upcast()).kind == ir::TypeKind::U8 => { + debug_assert_eq!(abi_type, AbiType::Bytes); + 64 + } + ir::TypeKind::Array(def) => { + db.codegen_abi_type_minimum_size(def.elem_ty) * def.len + 32 + } + + ir::TypeKind::String(_) => abi_type.header_size() + 32, + + _ if ty.is_aggregate(db.upcast()) => { + let mut minimum = 0; + for i in 0..ty.aggregate_field_num(db.upcast()) { + let field_ty = ty.projection_ty_imm(db.upcast(), i); + minimum += db.codegen_abi_type_minimum_size(field_ty) + } + minimum + 32 + } + ir::TypeKind::MPtr(ty) => abi_type_minimum_size(db, ty.deref(db.upcast())), + _ => unreachable!(), + } + } +} + +pub fn abi_type(db: &dyn CodegenDb, ty: TypeId) -> AbiType { + let legalized_ty = db.codegen_legalized_type(ty); + + if legalized_ty.is_zero_sized(db.upcast()) { + unreachable!("zero-sized type must be removed in legalization"); + } + + let ty_data = legalized_ty.data(db.upcast()); + + match &ty_data.kind { + ir::TypeKind::I8 => AbiType::Int(8), + ir::TypeKind::I16 => AbiType::Int(16), + ir::TypeKind::I32 => AbiType::Int(32), + ir::TypeKind::I64 => AbiType::Int(64), + ir::TypeKind::I128 => AbiType::Int(128), + ir::TypeKind::I256 => AbiType::Int(256), + ir::TypeKind::U8 => AbiType::UInt(8), + ir::TypeKind::U16 => AbiType::UInt(16), + ir::TypeKind::U32 => AbiType::UInt(32), + ir::TypeKind::U64 => AbiType::UInt(64), + ir::TypeKind::U128 => AbiType::UInt(128), + ir::TypeKind::U256 => AbiType::UInt(256), + ir::TypeKind::Bool => AbiType::Bool, + ir::TypeKind::Address => AbiType::Address, + ir::TypeKind::String(_) => AbiType::String, + ir::TypeKind::Unit => unreachable!("zero-sized type must be removed in legalization"), + ir::TypeKind::Array(def) => { + let elem_ty_data = &def.elem_ty.data(db.upcast()); + match &elem_ty_data.kind { + ir::TypeKind::U8 => AbiType::Bytes, + _ => { + let elem_ty = db.codegen_abi_type(def.elem_ty); + let len = def.len; + AbiType::Array { + elem_ty: elem_ty.into(), + len, + } + } + } + } + ir::TypeKind::Tuple(def) => { + let fields = def + .items + .iter() + .enumerate() + .map(|(i, item)| { + let field_ty = db.codegen_abi_type(*item); + AbiTupleField::new(format!("{i}"), field_ty) + }) + .collect(); + + AbiType::Tuple(fields) + } + ir::TypeKind::Struct(def) => { + let fields = def + .fields + .iter() + .map(|(name, ty)| { + let ty = db.codegen_abi_type(*ty); + AbiTupleField::new(name.to_string(), ty) + }) + .collect(); + + AbiType::Tuple(fields) + } + ir::TypeKind::MPtr(inner) => db.codegen_abi_type(*inner), + + ir::TypeKind::Contract(_) + | ir::TypeKind::Map(_) + | ir::TypeKind::Enum(_) + | ir::TypeKind::SPtr(_) => unreachable!(), + } +} + +pub fn abi_event(db: &dyn CodegenDb, ty: TypeId) -> AbiEvent { + debug_assert!(ty.is_struct(db.upcast())); + + let legalized_ty = db.codegen_legalized_type(ty); + let analyzer_struct = ty + .analyzer_ty(db.upcast()) + .and_then(|val| val.as_struct(db.upcast())) + .unwrap(); + let legalized_ty_data = legalized_ty.data(db.upcast()); + let event_def = match &legalized_ty_data.kind { + ir::TypeKind::Struct(def) => def, + _ => unreachable!(), + }; + + let fields = event_def + .fields + .iter() + .map(|(name, ty)| { + let attr = analyzer_struct + .field(db.upcast(), name) + .unwrap() + .attributes(db.upcast()); + + let ty = db.codegen_abi_type(*ty); + let indexed = attr.iter().any(|attr| attr == INDEXED); + AbiEventField::new(name.to_string(), ty, indexed) + }) + .collect(); + + AbiEvent::new(event_def.name.to_string(), fields, false) +} + +fn ceil_32(value: usize) -> usize { + ((value + 31) / 32) * 32 +} diff --git a/crates/codegen2/src/db/queries/constant.rs b/crates/codegen2/src/db/queries/constant.rs new file mode 100644 index 0000000000..2a78aba9b4 --- /dev/null +++ b/crates/codegen2/src/db/queries/constant.rs @@ -0,0 +1,12 @@ +use crate::db::CodegenDb; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + rc::Rc, +}; + +pub fn string_symbol_name(_db: &dyn CodegenDb, data: String) -> Rc { + let mut hasher = DefaultHasher::new(); + data.hash(&mut hasher); + format! {"{}", hasher.finish()}.into() +} diff --git a/crates/codegen2/src/db/queries/contract.rs b/crates/codegen2/src/db/queries/contract.rs new file mode 100644 index 0000000000..be0002a371 --- /dev/null +++ b/crates/codegen2/src/db/queries/contract.rs @@ -0,0 +1,20 @@ +use std::rc::Rc; + +use fe_analyzer::namespace::items::ContractId; + +use crate::db::CodegenDb; + +pub fn symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { + let module = contract.module(db.upcast()); + + format!( + "{}${}", + module.name(db.upcast()), + contract.name(db.upcast()) + ) + .into() +} + +pub fn deployer_symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { + format!("deploy_{}", symbol_name(db, contract).as_ref()).into() +} diff --git a/crates/codegen2/src/db/queries/function.rs b/crates/codegen2/src/db/queries/function.rs new file mode 100644 index 0000000000..d4527271e4 --- /dev/null +++ b/crates/codegen2/src/db/queries/function.rs @@ -0,0 +1,76 @@ +use std::rc::Rc; + +use fe_analyzer::{ + display::Displayable, + namespace::{ + items::Item, + types::{Type, TypeId}, + }, +}; +use fe_mir::ir::{FunctionBody, FunctionId, FunctionSignature}; +use salsa::InternKey; +use smol_str::SmolStr; + +use crate::{db::CodegenDb, yul::legalize}; + +pub fn legalized_signature(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let mut sig = function.signature(db.upcast()).as_ref().clone(); + legalize::legalize_func_signature(db, &mut sig); + sig.into() +} + +pub fn legalized_body(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let mut body = function.body(db.upcast()).as_ref().clone(); + legalize::legalize_func_body(db, &mut body); + body.into() +} + +pub fn symbol_name(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let module = function.signature(db.upcast()).module_id; + let module_name = module.name(db.upcast()); + + let analyzer_func = function.analyzer_func(db.upcast()); + let func_name = format!( + "{}{}", + analyzer_func.name(db.upcast()), + type_suffix(function, db) + ); + + let func_name = match analyzer_func.sig(db.upcast()).self_item(db.upcast()) { + Some(Item::Impl(id)) => { + let class_name = format!( + "{}${}", + id.trait_id(db.upcast()).name(db.upcast()), + safe_name(db, id.receiver(db.upcast())) + ); + format!("{class_name}${func_name}") + } + Some(class) => { + let class_name = class.name(db.upcast()); + format!("{class_name}${func_name}") + } + _ => func_name, + }; + + format!("{module_name}${func_name}").into() +} + +fn type_suffix(function: FunctionId, db: &dyn CodegenDb) -> SmolStr { + function + .signature(db.upcast()) + .resolved_generics + .values() + .fold(String::new(), |acc, param| { + format!("{}_{}", acc, safe_name(db, *param)) + }) + .into() +} + +fn safe_name(db: &dyn CodegenDb, ty: TypeId) -> SmolStr { + match ty.typ(db.upcast()) { + // TODO: Would be nice to get more human friendly names here + Type::Array(_) => format!("array_{:?}", ty.as_intern_id()).into(), + Type::Tuple(_) => format!("tuple_{:?}", ty.as_intern_id()).into(), + _ => format!("{}", ty.display(db.upcast())).into(), + } +} diff --git a/crates/codegen2/src/db/queries/types.rs b/crates/codegen2/src/db/queries/types.rs new file mode 100644 index 0000000000..5bb4df9883 --- /dev/null +++ b/crates/codegen2/src/db/queries/types.rs @@ -0,0 +1,102 @@ +use fe_mir::ir::{ + types::{ArrayDef, MapDef, StructDef, TupleDef}, + Type, TypeId, TypeKind, +}; + +use crate::db::CodegenDb; + +pub fn legalized_type(db: &dyn CodegenDb, ty: TypeId) -> TypeId { + let ty_data = ty.data(db.upcast()); + let ty_kind = match &ty.data(db.upcast()).kind { + TypeKind::Tuple(def) => { + let items = def + .items + .iter() + .filter_map(|item| { + if item.is_zero_sized(db.upcast()) { + None + } else { + Some(legalized_type(db, *item)) + } + }) + .collect(); + let new_def = TupleDef { items }; + TypeKind::Tuple(new_def) + } + + TypeKind::Array(def) => { + let new_def = ArrayDef { + elem_ty: legalized_type(db, def.elem_ty), + len: def.len, + }; + TypeKind::Array(new_def) + } + + TypeKind::Struct(def) => { + let fields = def + .fields + .iter() + .cloned() + .filter_map(|(name, ty)| { + if ty.is_zero_sized(db.upcast()) { + None + } else { + Some((name, legalized_type(db, ty))) + } + }) + .collect(); + let new_def = StructDef { + name: def.name.clone(), + fields, + span: def.span, + module_id: def.module_id, + }; + TypeKind::Struct(new_def) + } + + TypeKind::Contract(def) => { + let fields = def + .fields + .iter() + .cloned() + .filter_map(|(name, ty)| { + if ty.is_zero_sized(db.upcast()) { + None + } else { + Some((name, legalized_type(db, ty))) + } + }) + .collect(); + let new_def = StructDef { + name: def.name.clone(), + fields, + span: def.span, + module_id: def.module_id, + }; + TypeKind::Contract(new_def) + } + + TypeKind::Map(def) => { + let new_def = MapDef { + key_ty: legalized_type(db, def.key_ty), + value_ty: legalized_type(db, def.value_ty), + }; + TypeKind::Map(new_def) + } + + TypeKind::MPtr(ty) => { + let new_ty = legalized_type(db, *ty); + TypeKind::MPtr(new_ty) + } + + TypeKind::SPtr(ty) => { + let new_ty = legalized_type(db, *ty); + TypeKind::SPtr(new_ty) + } + + _ => return ty, + }; + + let analyzer_ty = ty_data.analyzer_ty; + db.mir_intern_type(Type::new(ty_kind, analyzer_ty).into()) +} diff --git a/crates/codegen2/src/lib.rs b/crates/codegen2/src/lib.rs new file mode 100644 index 0000000000..37ec962db2 --- /dev/null +++ b/crates/codegen2/src/lib.rs @@ -0,0 +1,2 @@ +pub mod db; +pub mod yul; diff --git a/crates/codegen2/src/yul/isel/context.rs b/crates/codegen2/src/yul/isel/context.rs new file mode 100644 index 0000000000..4ca2840cc3 --- /dev/null +++ b/crates/codegen2/src/yul/isel/context.rs @@ -0,0 +1,81 @@ +use indexmap::IndexSet; + +use fe_analyzer::namespace::items::ContractId; +use fe_mir::ir::FunctionId; +use fxhash::FxHashSet; +use yultsur::yul; + +use crate::{ + db::CodegenDb, + yul::runtime::{DefaultRuntimeProvider, RuntimeProvider}, +}; + +use super::{lower_contract_deployable, lower_function}; + +pub struct Context { + pub runtime: Box, + pub(super) contract_dependency: IndexSet, + pub(super) function_dependency: IndexSet, + pub(super) string_constants: IndexSet, + pub(super) lowered_functions: FxHashSet, +} + +// Currently, `clippy::derivable_impls` causes false positive result, +// see https://github.com/rust-lang/rust-clippy/issues/10158 for more details. +#[allow(clippy::derivable_impls)] +impl Default for Context { + fn default() -> Self { + Self { + runtime: Box::::default(), + contract_dependency: IndexSet::default(), + function_dependency: IndexSet::default(), + string_constants: IndexSet::default(), + lowered_functions: FxHashSet::default(), + } + } +} + +impl Context { + pub(super) fn resolve_function_dependency( + &mut self, + db: &dyn CodegenDb, + ) -> Vec { + let mut funcs = vec![]; + loop { + let dependencies = std::mem::take(&mut self.function_dependency); + if dependencies.is_empty() { + break; + } + for dependency in dependencies { + if self.lowered_functions.contains(&dependency) { + // Ignore dependency if it's already lowered. + continue; + } else { + funcs.push(lower_function(db, self, dependency)) + } + } + } + + funcs + } + + pub(super) fn resolve_constant_dependency(&self, db: &dyn CodegenDb) -> Vec { + self.string_constants + .iter() + .map(|s| { + let symbol = db.codegen_constant_string_symbol_name(s.to_string()); + yul::Data { + name: symbol.as_ref().clone(), + value: s.to_string(), + } + }) + .collect() + } + + pub(super) fn resolve_contract_dependency(&self, db: &dyn CodegenDb) -> Vec { + self.contract_dependency + .iter() + .map(|cid| lower_contract_deployable(db, *cid)) + .collect() + } +} diff --git a/crates/codegen2/src/yul/isel/contract.rs b/crates/codegen2/src/yul/isel/contract.rs new file mode 100644 index 0000000000..3703381d8e --- /dev/null +++ b/crates/codegen2/src/yul/isel/contract.rs @@ -0,0 +1,289 @@ +use fe_analyzer::namespace::items::ContractId; +use fe_mir::ir::{function::Linkage, FunctionId}; +use yultsur::{yul, *}; + +use crate::{ + db::CodegenDb, + yul::{runtime::AbiSrcLocation, YulVariable}, +}; + +use super::context::Context; + +pub fn lower_contract_deployable(db: &dyn CodegenDb, contract: ContractId) -> yul::Object { + let mut context = Context::default(); + + let constructor = if let Some(init) = contract.init_function(db.upcast()) { + let init = db.mir_lowered_func_signature(init); + make_init(db, &mut context, contract, init) + } else { + statements! {} + }; + + let deploy_code = make_deploy(db, contract); + + let dep_functions: Vec<_> = context + .resolve_function_dependency(db) + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let runtime_funcs: Vec<_> = context + .runtime + .collect_definitions() + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + + let deploy_block = block_statement! { + [constructor...] + [deploy_code...] + }; + + let code = code! { + [deploy_block] + [dep_functions...] + [runtime_funcs...] + }; + + let mut dep_contracts = context.resolve_contract_dependency(db); + dep_contracts.push(lower_contract(db, contract)); + let dep_constants = context.resolve_constant_dependency(db); + + let name = identifier! {( + db.codegen_contract_deployer_symbol_name(contract).as_ref() + )}; + let object = yul::Object { + name, + code, + objects: dep_contracts, + data: dep_constants, + }; + + normalize_object(object) +} + +pub fn lower_contract(db: &dyn CodegenDb, contract: ContractId) -> yul::Object { + let exported_funcs: Vec<_> = db + .mir_lower_contract_all_functions(contract) + .iter() + .filter_map(|fid| { + if fid.signature(db.upcast()).linkage == Linkage::Export { + Some(*fid) + } else { + None + } + }) + .collect(); + + let mut context = Context::default(); + let dispatcher = if let Some(call_fn) = contract.call_function(db.upcast()) { + let call_fn = db.mir_lowered_func_signature(call_fn); + context.function_dependency.insert(call_fn); + let call_symbol = identifier! { (db.codegen_function_symbol_name(call_fn)) }; + statement! { + ([call_symbol]()) + } + } else { + make_dispatcher(db, &mut context, &exported_funcs) + }; + + let dep_functions: Vec<_> = context + .resolve_function_dependency(db) + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let runtime_funcs: Vec<_> = context + .runtime + .collect_definitions() + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + + let code = code! { + ([dispatcher]) + [dep_functions...] + [runtime_funcs...] + }; + + // Lower dependant contracts. + let dep_contracts = context.resolve_contract_dependency(db); + + // Collect string constants. + let dep_constants = context.resolve_constant_dependency(db); + let contract_symbol = identifier! { (db.codegen_contract_symbol_name(contract)) }; + + yul::Object { + name: contract_symbol, + code, + objects: dep_contracts, + data: dep_constants, + } +} + +fn make_dispatcher( + db: &dyn CodegenDb, + context: &mut Context, + funcs: &[FunctionId], +) -> yul::Statement { + let arms = funcs + .iter() + .map(|func| dispatch_arm(db, context, *func)) + .collect::>(); + + if arms.is_empty() { + statement! { return(0, 0) } + } else { + let selector = expression! { + and((shr((sub(256, 32)), (calldataload(0)))), 0xffffffff) + }; + switch! { + switch ([selector]) + [arms...] + (default { (return(0, 0)) }) + } + } +} + +fn dispatch_arm(db: &dyn CodegenDb, context: &mut Context, func: FunctionId) -> yul::Case { + context.function_dependency.insert(func); + let func_sig = db.codegen_legalized_signature(func); + let mut param_vars = Vec::with_capacity(func_sig.params.len()); + let mut param_tys = Vec::with_capacity(func_sig.params.len()); + func_sig.params.iter().for_each(|param| { + param_vars.push(YulVariable::new(param.name.as_str())); + param_tys.push(param.ty); + }); + + let decode_params = if func_sig.params.is_empty() { + statements! {} + } else { + let ident_params: Vec<_> = param_vars.iter().map(YulVariable::ident).collect(); + let param_size = YulVariable::new("param_size"); + statements! { + (let [param_size.ident()] := sub((calldatasize()), 4)) + (let [ident_params...] := [context.runtime.abi_decode(db, expression! { 4 }, param_size.expr(), ¶m_tys, AbiSrcLocation::CallData)]) + } + }; + + let call_and_encode_return = { + let name = identifier! { (db.codegen_function_symbol_name(func)) }; + // we pass in a `0` for the expected `Context` argument + let call = expression! {[name]([(param_vars.iter().map(YulVariable::expr).collect::>())...])}; + if let Some(mut return_type) = func_sig.return_type { + if return_type.is_aggregate(db.upcast()) { + return_type = return_type.make_mptr(db.upcast()); + } + + let ret = YulVariable::new("ret"); + let enc_start = YulVariable::new("enc_start"); + let enc_size = YulVariable::new("enc_size"); + let abi_encode = context.runtime.abi_encode_seq( + db, + &[ret.expr()], + enc_start.expr(), + &[return_type], + false, + ); + statements! { + (let [ret.ident()] := [call]) + (let [enc_start.ident()] := [context.runtime.avail(db)]) + (let [enc_size.ident()] := [abi_encode]) + (return([enc_start.expr()], [enc_size.expr()])) + } + } else { + statements! { + ([yul::Statement::Expression(call)]) + (return(0, 0)) + } + } + }; + + let abi_sig = db.codegen_abi_function(func); + let selector = literal! { (format!("0x{}", abi_sig.selector().hex())) }; + case! { + case [selector] { + [decode_params...] + [call_and_encode_return...] + } + } +} + +fn make_init( + db: &dyn CodegenDb, + context: &mut Context, + contract: ContractId, + init: FunctionId, +) -> Vec { + context.function_dependency.insert(init); + let init_func_name = identifier! { (db.codegen_function_symbol_name(init)) }; + let contract_name = identifier_expression! { (format!{r#""{}""#, db.codegen_contract_deployer_symbol_name(contract)}) }; + + let func_sig = db.codegen_legalized_signature(init); + let mut param_vars = Vec::with_capacity(func_sig.params.len()); + let mut param_tys = Vec::with_capacity(func_sig.params.len()); + let program_size = YulVariable::new("$program_size"); + let arg_size = YulVariable::new("$arg_size"); + let code_size = YulVariable::new("$code_size"); + let memory_data_offset = YulVariable::new("$memory_data_offset"); + func_sig.params.iter().for_each(|param| { + param_vars.push(YulVariable::new(param.name.as_str())); + param_tys.push(param.ty); + }); + + let decode_params = if func_sig.params.is_empty() { + statements! {} + } else { + let ident_params: Vec<_> = param_vars.iter().map(YulVariable::ident).collect(); + statements! { + (let [ident_params...] := [context.runtime.abi_decode(db, memory_data_offset.expr(), arg_size.expr(), ¶m_tys, AbiSrcLocation::Memory)]) + } + }; + + let call = expression! {[init_func_name]([(param_vars.iter().map(YulVariable::expr).collect::>())...])}; + statements! { + (let [program_size.ident()] := datasize([contract_name])) + (let [code_size.ident()] := codesize()) + (let [arg_size.ident()] := sub([code_size.expr()], [program_size.expr()])) + (let [memory_data_offset.ident()] := [context.runtime.alloc(db, arg_size.expr())]) + (codecopy([memory_data_offset.expr()], [program_size.expr()], [arg_size.expr()])) + [decode_params...] + ([yul::Statement::Expression(call)]) + } +} + +fn make_deploy(db: &dyn CodegenDb, contract: ContractId) -> Vec { + let contract_symbol = + identifier_expression! { (format!{r#""{}""#, db.codegen_contract_symbol_name(contract)}) }; + let size = YulVariable::new("$$size"); + statements! { + (let [size.ident()] := (datasize([contract_symbol.clone()]))) + (datacopy(0, (dataoffset([contract_symbol])), [size.expr()])) + (return (0, [size.expr()])) + } +} + +fn normalize_object(obj: yul::Object) -> yul::Object { + let data = obj + .data + .into_iter() + .map(|data| yul::Data { + name: data.name, + value: data + .value + .replace('\\', "\\\\\\\\") + .replace('\n', "\\\\n") + .replace('"', "\\\\\"") + .replace('\r', "\\\\r") + .replace('\t', "\\\\t"), + }) + .collect::>(); + yul::Object { + name: obj.name, + code: obj.code, + objects: obj + .objects + .into_iter() + .map(normalize_object) + .collect::>(), + data, + } +} diff --git a/crates/codegen2/src/yul/isel/function.rs b/crates/codegen2/src/yul/isel/function.rs new file mode 100644 index 0000000000..78eaecce2a --- /dev/null +++ b/crates/codegen2/src/yul/isel/function.rs @@ -0,0 +1,978 @@ +#![allow(unused)] +use std::thread::Scope; + +use super::{context::Context, inst_order::InstSerializer}; +use fe_common::numeric::to_hex_str; + +use fe_abi::function::{AbiFunction, AbiFunctionType}; +use fe_common::db::Upcast; +use fe_mir::{ + ir::{ + self, + constant::ConstantValue, + inst::{BinOp, CallType, CastKind, InstKind, UnOp}, + value::AssignableValue, + Constant, FunctionBody, FunctionId, FunctionSignature, InstId, Type, TypeId, TypeKind, + Value, ValueId, + }, + pretty_print::PrettyPrint, +}; +use fxhash::FxHashMap; +use smol_str::SmolStr; +use yultsur::{ + yul::{self, Statement}, + *, +}; + +use crate::{ + db::CodegenDb, + yul::{ + isel::inst_order::StructuralInst, + runtime::{self, RuntimeProvider}, + slot_size::{function_hash_type, yul_primitive_type, SLOT_SIZE}, + YulVariable, + }, +}; + +pub fn lower_function( + db: &dyn CodegenDb, + ctx: &mut Context, + function: FunctionId, +) -> yul::FunctionDefinition { + debug_assert!(!ctx.lowered_functions.contains(&function)); + ctx.lowered_functions.insert(function); + let sig = &db.codegen_legalized_signature(function); + let body = &db.codegen_legalized_body(function); + FuncLowerHelper::new(db, ctx, function, sig, body).lower_func() +} + +struct FuncLowerHelper<'db, 'a> { + db: &'db dyn CodegenDb, + ctx: &'a mut Context, + value_map: ScopedValueMap, + func: FunctionId, + sig: &'a FunctionSignature, + body: &'a FunctionBody, + ret_value: Option, + sink: Vec, +} + +impl<'db, 'a> FuncLowerHelper<'db, 'a> { + fn new( + db: &'db dyn CodegenDb, + ctx: &'a mut Context, + func: FunctionId, + sig: &'a FunctionSignature, + body: &'a FunctionBody, + ) -> Self { + let mut value_map = ScopedValueMap::default(); + // Register arguments to value_map. + for &value in body.store.locals() { + match body.store.value_data(value) { + Value::Local(local) if local.is_arg => { + let ident = YulVariable::new(local.name.as_str()).ident(); + value_map.insert(value, ident); + } + _ => {} + } + } + + let ret_value = if sig.return_type.is_some() { + Some(YulVariable::new("$ret").ident()) + } else { + None + }; + + Self { + db, + ctx, + value_map, + func, + sig, + body, + ret_value, + sink: Vec::new(), + } + } + + fn lower_func(mut self) -> yul::FunctionDefinition { + let name = identifier! { (self.db.codegen_function_symbol_name(self.func)) }; + + let parameters = self + .sig + .params + .iter() + .map(|param| YulVariable::new(param.name.as_str()).ident()) + .collect(); + + let ret = self + .ret_value + .clone() + .map(|value| vec![value]) + .unwrap_or_default(); + + let body = self.lower_body(); + + yul::FunctionDefinition { + name, + parameters, + returns: ret, + block: body, + } + } + + fn lower_body(mut self) -> yul::Block { + let inst_order = InstSerializer::new(self.body).serialize(); + + for inst in inst_order { + self.lower_structural_inst(inst) + } + + yul::Block { + statements: self.sink, + } + } + + fn lower_structural_inst(&mut self, inst: StructuralInst) { + match inst { + StructuralInst::Inst(inst) => self.lower_inst(inst), + StructuralInst::If { cond, then, else_ } => { + let if_block = self.lower_if(cond, then, else_); + self.sink.push(if_block) + } + StructuralInst::Switch { + scrutinee, + table, + default, + } => { + let switch_block = self.lower_switch(scrutinee, table, default); + self.sink.push(switch_block) + } + StructuralInst::For { body } => { + let for_block = self.lower_for(body); + self.sink.push(for_block) + } + StructuralInst::Break => self.sink.push(yul::Statement::Break), + StructuralInst::Continue => self.sink.push(yul::Statement::Continue), + }; + } + + fn lower_inst(&mut self, inst: InstId) { + if let Some(lhs) = self.body.store.inst_result(inst) { + self.declare_assignable_value(lhs) + } + + match &self.body.store.inst_data(inst).kind { + InstKind::Declare { local } => self.declare_value(*local), + + InstKind::Unary { op, value } => { + let inst_result = self.body.store.inst_result(inst).unwrap(); + let inst_result_ty = inst_result.ty(self.db.upcast(), &self.body.store); + let result = self.lower_unary(*op, *value); + self.assign_inst_result(inst, result, inst_result_ty.deref(self.db.upcast())) + } + + InstKind::Binary { op, lhs, rhs } => { + let inst_result = self.body.store.inst_result(inst).unwrap(); + let inst_result_ty = inst_result.ty(self.db.upcast(), &self.body.store); + let result = self.lower_binary(*op, *lhs, *rhs, inst); + self.assign_inst_result(inst, result, inst_result_ty.deref(self.db.upcast())) + } + + InstKind::Cast { kind, value, to } => { + let from_ty = self.body.store.value_ty(*value); + let result = match kind { + CastKind::Primitive => { + debug_assert!( + from_ty.is_primitive(self.db.upcast()) + && to.is_primitive(self.db.upcast()) + ); + let value = self.value_expr(*value); + self.ctx.runtime.primitive_cast(self.db, value, from_ty) + } + CastKind::Untag => { + let from_ty = from_ty.deref(self.db.upcast()); + debug_assert!(from_ty.is_enum(self.db.upcast())); + let value = self.value_expr(*value); + let offset = literal_expression! {(from_ty.enum_data_offset(self.db.upcast(), SLOT_SIZE))}; + expression! {add([value], [offset])} + } + }; + + self.assign_inst_result(inst, result, *to) + } + + InstKind::AggregateConstruct { ty, args } => { + let lhs = self.body.store.inst_result(inst).unwrap(); + let ptr = self.lower_assignable_value(lhs); + let ptr_ty = lhs.ty(self.db.upcast(), &self.body.store); + let arg_values = args.iter().map(|arg| self.value_expr(*arg)).collect(); + let arg_tys = args + .iter() + .map(|arg| self.body.store.value_ty(*arg)) + .collect(); + self.sink.push(yul::Statement::Expression( + self.ctx + .runtime + .aggregate_init(self.db, ptr, arg_values, ptr_ty, arg_tys), + )) + } + + InstKind::Bind { src } => { + match self.body.store.value_data(*src) { + Value::Constant { constant, .. } => { + // Need special handling when rhs is the string literal because it needs ptr + // copy. + if let ConstantValue::Str(s) = &constant.data(self.db.upcast()).value { + self.ctx.string_constants.insert(s.to_string()); + let size = self.value_ty_size_deref(*src); + let lhs = self.body.store.inst_result(inst).unwrap(); + let ptr = self.lower_assignable_value(lhs); + let inst_result_ty = lhs.ty(self.db.upcast(), &self.body.store); + self.sink.push(yul::Statement::Expression( + self.ctx.runtime.string_copy( + self.db, + ptr, + s, + inst_result_ty.is_sptr(self.db.upcast()), + ), + )) + } else { + let src_ty = self.body.store.value_ty(*src); + let src = self.value_expr(*src); + self.assign_inst_result(inst, src, src_ty) + } + } + _ => { + let src_ty = self.body.store.value_ty(*src); + let src = self.value_expr(*src); + self.assign_inst_result(inst, src, src_ty) + } + } + } + + InstKind::MemCopy { src } => { + let lhs = self.body.store.inst_result(inst).unwrap(); + let dst_ptr = self.lower_assignable_value(lhs); + let dst_ptr_ty = lhs.ty(self.db.upcast(), &self.body.store); + let src_ptr = self.value_expr(*src); + let src_ptr_ty = self.body.store.value_ty(*src); + let ty_size = literal_expression! { (self.value_ty_size_deref(*src)) }; + self.sink + .push(yul::Statement::Expression(self.ctx.runtime.ptr_copy( + self.db, + src_ptr, + dst_ptr, + ty_size, + src_ptr_ty.is_sptr(self.db.upcast()), + dst_ptr_ty.is_sptr(self.db.upcast()), + ))) + } + + InstKind::Load { src } => { + let src_ty = self.body.store.value_ty(*src); + let src = self.value_expr(*src); + debug_assert!(src_ty.is_ptr(self.db.upcast())); + + let result = self.body.store.inst_result(inst).unwrap(); + debug_assert!(!result + .ty(self.db.upcast(), &self.body.store) + .is_ptr(self.db.upcast())); + self.assign_inst_result(inst, src, src_ty) + } + + InstKind::AggregateAccess { value, indices } => { + let base = self.value_expr(*value); + let mut ptr = base; + let mut inner_ty = self.body.store.value_ty(*value); + for &idx in indices { + ptr = self.aggregate_elem_ptr(ptr, idx, inner_ty.deref(self.db.upcast())); + inner_ty = + inner_ty.projection_ty(self.db.upcast(), self.body.store.value_data(idx)); + } + + let result = self.body.store.inst_result(inst).unwrap(); + self.assign_inst_result(inst, ptr, inner_ty) + } + + InstKind::MapAccess { value, key } => { + let map_ty = self.body.store.value_ty(*value).deref(self.db.upcast()); + let value_expr = self.value_expr(*value); + let key_expr = self.value_expr(*key); + let key_ty = self.body.store.value_ty(*key); + let ptr = self + .ctx + .runtime + .map_value_ptr(self.db, value_expr, key_expr, key_ty); + let value_ty = match &map_ty.data(self.db.upcast()).kind { + TypeKind::Map(def) => def.value_ty, + _ => unreachable!(), + }; + + self.assign_inst_result(inst, ptr, value_ty.make_sptr(self.db.upcast())); + } + + InstKind::Call { + func, + args, + call_type, + } => { + let args: Vec<_> = args.iter().map(|arg| self.value_expr(*arg)).collect(); + let result = match call_type { + CallType::Internal => { + self.ctx.function_dependency.insert(*func); + let func_name = identifier! {(self.db.codegen_function_symbol_name(*func))}; + expression! {[func_name]([args...])} + } + CallType::External => self.ctx.runtime.external_call(self.db, *func, args), + }; + match self.db.codegen_legalized_signature(*func).return_type { + Some(mut result_ty) => { + if result_ty.is_aggregate(self.db.upcast()) + | result_ty.is_string(self.db.upcast()) + { + result_ty = result_ty.make_mptr(self.db.upcast()); + } + self.assign_inst_result(inst, result, result_ty) + } + _ => self.sink.push(Statement::Expression(result)), + } + } + + InstKind::Revert { arg } => match arg { + Some(arg) => { + let arg_ty = self.body.store.value_ty(*arg); + let deref_ty = arg_ty.deref(self.db.upcast()); + let ty_data = deref_ty.data(self.db.upcast()); + let arg_expr = if deref_ty.is_zero_sized(self.db.upcast()) { + None + } else { + Some(self.value_expr(*arg)) + }; + let name = match &ty_data.kind { + ir::TypeKind::Struct(def) => &def.name, + ir::TypeKind::String(def) => "Error", + _ => "Panic", + }; + self.sink.push(yul::Statement::Expression( + self.ctx.runtime.revert(self.db, arg_expr, name, arg_ty), + )); + } + None => self.sink.push(statement! {revert(0, 0)}), + }, + + InstKind::Emit { arg } => { + let event = self.value_expr(*arg); + let event_ty = self.body.store.value_ty(*arg); + let result = self.ctx.runtime.emit(self.db, event, event_ty); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty); + } + + InstKind::Return { arg } => { + if let Some(arg) = arg { + let arg = self.value_expr(*arg); + let ret_value = self.ret_value.clone().unwrap(); + self.sink.push(statement! {[ret_value] := [arg]}); + } + self.sink.push(yul::Statement::Leave) + } + + InstKind::Keccak256 { arg } => { + let result = self.keccak256(*arg); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty); + } + + InstKind::AbiEncode { arg } => { + let lhs = self.body.store.inst_result(inst).unwrap(); + let ptr = self.lower_assignable_value(lhs); + let ptr_ty = lhs.ty(self.db.upcast(), &self.body.store); + let src_expr = self.value_expr(*arg); + let src_ty = self.body.store.value_ty(*arg); + + let abi_encode = self.ctx.runtime.abi_encode( + self.db, + src_expr, + ptr, + src_ty, + ptr_ty.is_sptr(self.db.upcast()), + ); + self.sink.push(statement! { + pop([abi_encode]) + }); + } + + InstKind::Create { value, contract } => { + self.ctx.contract_dependency.insert(*contract); + + let value_expr = self.value_expr(*value); + let result = self.ctx.runtime.create(self.db, *contract, value_expr); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty) + } + + InstKind::Create2 { + value, + salt, + contract, + } => { + self.ctx.contract_dependency.insert(*contract); + + let value_expr = self.value_expr(*value); + let salt_expr = self.value_expr(*salt); + let result = self + .ctx + .runtime + .create2(self.db, *contract, value_expr, salt_expr); + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty) + } + + InstKind::YulIntrinsic { op, args } => { + let args: Vec<_> = args.iter().map(|arg| self.value_expr(*arg)).collect(); + let op_name = identifier! { (format!("{op}").strip_prefix("__").unwrap()) }; + let result = expression! { [op_name]([args...]) }; + // Intrinsic operation never returns ptr type, so we can use u256_ty as a dummy + // type for the result. + let u256_ty = yul_primitive_type(self.db); + self.assign_inst_result(inst, result, u256_ty) + } + + InstKind::Nop => {} + + // These flow control instructions are already legalized. + InstKind::Jump { .. } | InstKind::Branch { .. } | InstKind::Switch { .. } => { + unreachable!() + } + } + } + + fn lower_if( + &mut self, + cond: ValueId, + then: Vec, + else_: Vec, + ) -> yul::Statement { + let cond = self.value_expr(cond); + + self.enter_scope(); + let then_body = self.lower_branch_body(then); + self.leave_scope(); + + self.enter_scope(); + let else_body = self.lower_branch_body(else_); + self.leave_scope(); + + switch! { + switch ([cond]) + (case 1 {[then_body...]}) + (case 0 {[else_body...]}) + } + } + + fn lower_switch( + &mut self, + scrutinee: ValueId, + table: Vec<(ValueId, Vec)>, + default: Option>, + ) -> yul::Statement { + let scrutinee = self.value_expr(scrutinee); + + let mut cases = vec![]; + for (value, insts) in table { + let value = self.value_expr(value); + let value = match value { + yul::Expression::Literal(lit) => lit, + _ => panic!("switch table values must be literal"), + }; + + self.enter_scope(); + let body = self.lower_branch_body(insts); + self.leave_scope(); + cases.push(yul::Case { + literal: Some(value), + block: block! { [body...] }, + }) + } + + if let Some(insts) = default { + let block = self.lower_branch_body(insts); + cases.push(case! { + default {[block...]} + }); + } + + switch! { + switch ([scrutinee]) + [cases...] + } + } + + fn lower_branch_body(&mut self, insts: Vec) -> Vec { + let mut body = vec![]; + std::mem::swap(&mut self.sink, &mut body); + for inst in insts { + self.lower_structural_inst(inst); + } + std::mem::swap(&mut self.sink, &mut body); + body + } + + fn lower_for(&mut self, body: Vec) -> yul::Statement { + let mut body_stmts = vec![]; + std::mem::swap(&mut self.sink, &mut body_stmts); + for inst in body { + self.lower_structural_inst(inst); + } + std::mem::swap(&mut self.sink, &mut body_stmts); + + block_statement! {( + for {} (1) {} + { + [body_stmts...] + } + )} + } + + fn lower_assign(&mut self, lhs: &AssignableValue, rhs: ValueId) -> yul::Statement { + match lhs { + AssignableValue::Value(value) => { + let lhs = self.value_ident(*value); + let rhs = self.value_expr(rhs); + statement! { [lhs] := [rhs] } + } + AssignableValue::Aggregate { .. } | AssignableValue::Map { .. } => { + let dst_ty = lhs.ty(self.db.upcast(), &self.body.store); + let src_ty = self.body.store.value_ty(rhs); + debug_assert_eq!( + dst_ty.deref(self.db.upcast()), + src_ty.deref(self.db.upcast()) + ); + + let dst = self.lower_assignable_value(lhs); + let src = self.value_expr(rhs); + + if src_ty.is_ptr(self.db.upcast()) { + let ty_size = literal_expression! { (self.value_ty_size_deref(rhs)) }; + + let expr = self.ctx.runtime.ptr_copy( + self.db, + src, + dst, + ty_size, + src_ty.is_sptr(self.db.upcast()), + dst_ty.is_sptr(self.db.upcast()), + ); + yul::Statement::Expression(expr) + } else { + let expr = self.ctx.runtime.ptr_store(self.db, dst, src, dst_ty); + yul::Statement::Expression(expr) + } + } + } + } + + fn lower_unary(&mut self, op: UnOp, value: ValueId) -> yul::Expression { + let value_expr = self.value_expr(value); + match op { + UnOp::Not => expression! { iszero([value_expr])}, + UnOp::Neg => { + let zero = literal_expression! {0}; + if self.body.store.value_data(value).is_imm() { + // Literals are checked at compile time (e.g. -128) so there's no point + // in adding a runtime check. + expression! {sub([zero], [value_expr])} + } else { + let value_ty = self.body.store.value_ty(value); + self.ctx + .runtime + .safe_sub(self.db, zero, value_expr, value_ty) + } + } + UnOp::Inv => expression! { not([value_expr])}, + } + } + + fn lower_binary( + &mut self, + op: BinOp, + lhs: ValueId, + rhs: ValueId, + inst: InstId, + ) -> yul::Expression { + let lhs_expr = self.value_expr(lhs); + let rhs_expr = self.value_expr(rhs); + let is_result_signed = self + .body + .store + .inst_result(inst) + .map(|val| { + let ty = val.ty(self.db.upcast(), &self.body.store); + ty.is_signed(self.db.upcast()) + }) + .unwrap_or(false); + let is_lhs_signed = self.body.store.value_ty(lhs).is_signed(self.db.upcast()); + + let inst_result = self.body.store.inst_result(inst).unwrap(); + let inst_result_ty = inst_result + .ty(self.db.upcast(), &self.body.store) + .deref(self.db.upcast()); + match op { + BinOp::Add => self + .ctx + .runtime + .safe_add(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Sub => self + .ctx + .runtime + .safe_sub(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Mul => self + .ctx + .runtime + .safe_mul(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Div => self + .ctx + .runtime + .safe_div(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Mod => self + .ctx + .runtime + .safe_mod(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Pow => self + .ctx + .runtime + .safe_pow(self.db, lhs_expr, rhs_expr, inst_result_ty), + BinOp::Shl => expression! {shl([rhs_expr], [lhs_expr])}, + BinOp::Shr if is_result_signed => expression! {sar([rhs_expr], [lhs_expr])}, + BinOp::Shr => expression! {shr([rhs_expr], [lhs_expr])}, + BinOp::BitOr | BinOp::LogicalOr => expression! {or([lhs_expr], [rhs_expr])}, + BinOp::BitXor => expression! {xor([lhs_expr], [rhs_expr])}, + BinOp::BitAnd | BinOp::LogicalAnd => expression! {and([lhs_expr], [rhs_expr])}, + BinOp::Eq => expression! {eq([lhs_expr], [rhs_expr])}, + BinOp::Ne => expression! {iszero((eq([lhs_expr], [rhs_expr])))}, + BinOp::Ge if is_lhs_signed => expression! {iszero((slt([lhs_expr], [rhs_expr])))}, + BinOp::Ge => expression! {iszero((lt([lhs_expr], [rhs_expr])))}, + BinOp::Gt if is_lhs_signed => expression! {sgt([lhs_expr], [rhs_expr])}, + BinOp::Gt => expression! {gt([lhs_expr], [rhs_expr])}, + BinOp::Le if is_lhs_signed => expression! {iszero((sgt([lhs_expr], [rhs_expr])))}, + BinOp::Le => expression! {iszero((gt([lhs_expr], [rhs_expr])))}, + BinOp::Lt if is_lhs_signed => expression! {slt([lhs_expr], [rhs_expr])}, + BinOp::Lt => expression! {lt([lhs_expr], [rhs_expr])}, + } + } + + fn lower_cast(&mut self, value: ValueId, to: TypeId) -> yul::Expression { + let from_ty = self.body.store.value_ty(value); + debug_assert!(from_ty.is_primitive(self.db.upcast())); + debug_assert!(to.is_primitive(self.db.upcast())); + + let value = self.value_expr(value); + self.ctx.runtime.primitive_cast(self.db, value, from_ty) + } + + fn assign_inst_result(&mut self, inst: InstId, rhs: yul::Expression, rhs_ty: TypeId) { + // NOTE: We don't have `deref` feature yet, so need a heuristics for an + // assignment. + let stmt = if let Some(result) = self.body.store.inst_result(inst) { + let lhs = self.lower_assignable_value(result); + let lhs_ty = result.ty(self.db.upcast(), &self.body.store); + match result { + AssignableValue::Value(value) => { + match ( + lhs_ty.is_ptr(self.db.upcast()), + rhs_ty.is_ptr(self.db.upcast()), + ) { + (true, true) => { + if lhs_ty.is_mptr(self.db.upcast()) == rhs_ty.is_mptr(self.db.upcast()) + { + let rhs = self.extend_value(rhs, lhs_ty); + let lhs_ident = self.value_ident(*value); + statement! { [lhs_ident] := [rhs] } + } else { + let ty_size = rhs_ty + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE); + yul::Statement::Expression(self.ctx.runtime.ptr_copy( + self.db, + rhs, + lhs, + literal_expression! { (ty_size) }, + rhs_ty.is_sptr(self.db.upcast()), + lhs_ty.is_sptr(self.db.upcast()), + )) + } + } + (true, false) => yul::Statement::Expression( + self.ctx.runtime.ptr_store(self.db, lhs, rhs, lhs_ty), + ), + + (false, true) => { + let rhs = self.ctx.runtime.ptr_load(self.db, rhs, rhs_ty); + let rhs = self.extend_value(rhs, lhs_ty); + let lhs_ident = self.value_ident(*value); + statement! { [lhs_ident] := [rhs] } + } + (false, false) => { + let rhs = self.extend_value(rhs, lhs_ty); + let lhs_ident = self.value_ident(*value); + statement! { [lhs_ident] := [rhs] } + } + } + } + AssignableValue::Aggregate { .. } | AssignableValue::Map { .. } => { + let expr = if rhs_ty.is_ptr(self.db.upcast()) { + let ty_size = rhs_ty + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE); + self.ctx.runtime.ptr_copy( + self.db, + rhs, + lhs, + literal_expression! { (ty_size) }, + rhs_ty.is_sptr(self.db.upcast()), + lhs_ty.is_sptr(self.db.upcast()), + ) + } else { + self.ctx.runtime.ptr_store(self.db, lhs, rhs, lhs_ty) + }; + yul::Statement::Expression(expr) + } + } + } else { + yul::Statement::Expression(rhs) + }; + + self.sink.push(stmt); + } + + /// Extend a value to 256 bits. + fn extend_value(&mut self, value: yul::Expression, ty: TypeId) -> yul::Expression { + if ty.is_primitive(self.db.upcast()) { + self.ctx.runtime.primitive_cast(self.db, value, ty) + } else { + value + } + } + + fn declare_assignable_value(&mut self, value: &AssignableValue) { + match value { + AssignableValue::Value(value) if !self.value_map.contains(*value) => { + self.declare_value(*value); + } + _ => {} + } + } + + fn declare_value(&mut self, value: ValueId) { + let var = YulVariable::new(format!("$tmp_{}", value.index())); + self.value_map.insert(value, var.ident()); + let value_ty = self.body.store.value_ty(value); + + // Allocate memory for a value if a value is a pointer type. + let init = if value_ty.is_mptr(self.db.upcast()) { + let deref_ty = value_ty.deref(self.db.upcast()); + let ty_size = deref_ty.size_of(self.db.upcast(), SLOT_SIZE); + let size = literal_expression! { (ty_size) }; + Some(self.ctx.runtime.alloc(self.db, size)) + } else { + None + }; + + self.sink.push(yul::Statement::VariableDeclaration( + yul::VariableDeclaration { + identifiers: vec![var.ident()], + expression: init, + }, + )) + } + + fn value_expr(&mut self, value: ValueId) -> yul::Expression { + match self.body.store.value_data(value) { + Value::Local(_) | Value::Temporary { .. } => { + let ident = self.value_map.lookup(value).unwrap(); + literal_expression! {(ident)} + } + Value::Immediate { imm, .. } => { + literal_expression! {(imm)} + } + Value::Constant { constant, .. } => match &constant.data(self.db.upcast()).value { + ConstantValue::Immediate(imm) => { + // YUL does not support representing negative integers with leading minus (e.g. + // `-1` in YUL would lead to an ICE). To mitigate that we + // convert all numeric values into hexadecimal representation. + literal_expression! {(to_hex_str(imm))} + } + ConstantValue::Str(s) => { + self.ctx.string_constants.insert(s.to_string()); + self.ctx.runtime.string_construct(self.db, s, s.len()) + } + ConstantValue::Bool(true) => { + literal_expression! {1} + } + ConstantValue::Bool(false) => { + literal_expression! {0} + } + }, + Value::Unit { .. } => unreachable!(), + } + } + + fn value_ident(&self, value: ValueId) -> yul::Identifier { + self.value_map.lookup(value).unwrap().clone() + } + + fn make_tmp(&mut self, tmp: ValueId) -> yul::Identifier { + let ident = YulVariable::new(format! {"$tmp_{}", tmp.index()}).ident(); + self.value_map.insert(tmp, ident.clone()); + ident + } + + fn keccak256(&mut self, value: ValueId) -> yul::Expression { + let value_ty = self.body.store.value_ty(value); + debug_assert!(value_ty.is_mptr(self.db.upcast())); + + let value_size = value_ty + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE); + let value_size_expr = literal_expression! {(value_size)}; + let value_expr = self.value_expr(value); + expression! {keccak256([value_expr], [value_size_expr])} + } + + fn lower_assignable_value(&mut self, value: &AssignableValue) -> yul::Expression { + match value { + AssignableValue::Value(value) => self.value_expr(*value), + + AssignableValue::Aggregate { lhs, idx } => { + let base_ptr = self.lower_assignable_value(lhs); + let ty = lhs + .ty(self.db.upcast(), &self.body.store) + .deref(self.db.upcast()); + self.aggregate_elem_ptr(base_ptr, *idx, ty) + } + AssignableValue::Map { lhs, key } => { + let map_ptr = self.lower_assignable_value(lhs); + let key_ty = self.body.store.value_ty(*key); + let key = self.value_expr(*key); + self.ctx + .runtime + .map_value_ptr(self.db, map_ptr, key, key_ty) + } + } + } + + fn aggregate_elem_ptr( + &mut self, + base_ptr: yul::Expression, + idx: ValueId, + base_ty: TypeId, + ) -> yul::Expression { + debug_assert!(base_ty.is_aggregate(self.db.upcast())); + + match &base_ty.data(self.db.upcast()).kind { + TypeKind::Array(def) => { + let elem_size = + literal_expression! {(base_ty.array_elem_size(self.db.upcast(), SLOT_SIZE))}; + self.validate_array_indexing(def.len, idx); + let idx = self.value_expr(idx); + let offset = expression! {mul([elem_size], [idx])}; + expression! { add([base_ptr], [offset]) } + } + _ => { + let elem_idx = match self.body.store.value_data(idx) { + Value::Immediate { imm, .. } => imm, + _ => panic!("only array type can use dynamic value indexing"), + }; + let offset = literal_expression! {(base_ty.aggregate_elem_offset(self.db.upcast(), elem_idx.clone(), SLOT_SIZE))}; + expression! {add([base_ptr], [offset])} + } + } + } + + fn validate_array_indexing(&mut self, array_len: usize, idx: ValueId) { + const PANIC_OUT_OF_BOUNDS: usize = 0x32; + + if let Value::Immediate { .. } = self.body.store.value_data(idx) { + return; + } + + let idx = self.value_expr(idx); + let max_idx = literal_expression! {(array_len - 1)}; + self.sink.push(statement!(if (gt([idx], [max_idx])) { + ([runtime::panic_revert_numeric( + self.ctx.runtime.as_mut(), + self.db, + literal_expression! {(PANIC_OUT_OF_BOUNDS)}, + )]) + })); + } + + fn value_ty_size(&self, value: ValueId) -> usize { + self.body + .store + .value_ty(value) + .size_of(self.db.upcast(), SLOT_SIZE) + } + + fn value_ty_size_deref(&self, value: ValueId) -> usize { + self.body + .store + .value_ty(value) + .deref(self.db.upcast()) + .size_of(self.db.upcast(), SLOT_SIZE) + } + + fn enter_scope(&mut self) { + let value_map = std::mem::take(&mut self.value_map); + self.value_map = ScopedValueMap::with_parent(value_map); + } + + fn leave_scope(&mut self) { + let value_map = std::mem::take(&mut self.value_map); + self.value_map = value_map.into_parent(); + } +} + +#[derive(Debug, Default)] +struct ScopedValueMap { + parent: Option>, + map: FxHashMap, +} + +impl ScopedValueMap { + fn lookup(&self, value: ValueId) -> Option<&yul::Identifier> { + match self.map.get(&value) { + Some(ident) => Some(ident), + None => self.parent.as_ref().and_then(|p| p.lookup(value)), + } + } + + fn with_parent(parent: ScopedValueMap) -> Self { + Self { + parent: Some(parent.into()), + ..Self::default() + } + } + + fn into_parent(self) -> Self { + *self.parent.unwrap() + } + + fn insert(&mut self, value: ValueId, ident: yul::Identifier) { + self.map.insert(value, ident); + } + + fn contains(&self, value: ValueId) -> bool { + self.lookup(value).is_some() + } +} + +fn bit_mask(byte_size: usize) -> usize { + (1 << (byte_size * 8)) - 1 +} + +fn bit_mask_expr(byte_size: usize) -> yul::Expression { + let mask = format!("{:#x}", bit_mask(byte_size)); + literal_expression! {(mask)} +} diff --git a/crates/codegen2/src/yul/isel/inst_order.rs b/crates/codegen2/src/yul/isel/inst_order.rs new file mode 100644 index 0000000000..afc82f0016 --- /dev/null +++ b/crates/codegen2/src/yul/isel/inst_order.rs @@ -0,0 +1,1368 @@ +use fe_mir::{ + analysis::{ + domtree::DFSet, loop_tree::LoopId, post_domtree::PostIDom, ControlFlowGraph, DomTree, + LoopTree, PostDomTree, + }, + ir::{ + inst::{BranchInfo, SwitchTable}, + BasicBlockId, FunctionBody, InstId, ValueId, + }, +}; +use indexmap::{IndexMap, IndexSet}; + +#[derive(Debug, Clone)] +pub(super) enum StructuralInst { + Inst(InstId), + If { + cond: ValueId, + then: Vec, + else_: Vec, + }, + + Switch { + scrutinee: ValueId, + table: Vec<(ValueId, Vec)>, + default: Option>, + }, + + For { + body: Vec, + }, + + Break, + + Continue, +} + +pub(super) struct InstSerializer<'a> { + body: &'a FunctionBody, + cfg: ControlFlowGraph, + loop_tree: LoopTree, + df: DFSet, + domtree: DomTree, + pd_tree: PostDomTree, + scope: Option, +} + +impl<'a> InstSerializer<'a> { + pub(super) fn new(body: &'a FunctionBody) -> Self { + let cfg = ControlFlowGraph::compute(body); + let domtree = DomTree::compute(&cfg); + let df = domtree.compute_df(&cfg); + let pd_tree = PostDomTree::compute(body); + let loop_tree = LoopTree::compute(&cfg, &domtree); + + Self { + body, + cfg, + loop_tree, + df, + domtree, + pd_tree, + scope: None, + } + } + + pub(super) fn serialize(&mut self) -> Vec { + self.scope = None; + let entry = self.cfg.entry(); + let mut order = vec![]; + self.serialize_block(entry, &mut order); + order + } + + fn serialize_block(&mut self, block: BasicBlockId, order: &mut Vec) { + match self.loop_tree.loop_of_block(block) { + Some(lp) + if block == self.loop_tree.loop_header(lp) + && Some(block) != self.scope.as_ref().and_then(Scope::loop_header) => + { + let loop_exit = self.find_loop_exit(lp); + self.enter_loop_scope(lp, block, loop_exit); + let mut body = vec![]; + self.serialize_block(block, &mut body); + self.exit_scope(); + order.push(StructuralInst::For { body }); + + match loop_exit { + Some(exit) + if self + .scope + .as_ref() + .map(|scope| scope.branch_merge_block() != Some(exit)) + .unwrap_or(true) => + { + self.serialize_block(exit, order); + } + _ => {} + } + + return; + } + _ => {} + }; + + for inst in self.body.order.iter_inst(block) { + if self.body.store.is_terminator(inst) { + break; + } + if !self.body.store.is_nop(inst) { + order.push(StructuralInst::Inst(inst)); + } + } + + let terminator = self.body.order.terminator(&self.body.store, block).unwrap(); + match self.analyze_terminator(terminator) { + TerminatorInfo::If { + cond, + then, + else_, + merge_block, + } => self.serialize_if_terminator(cond, *then, *else_, merge_block, order), + + TerminatorInfo::Switch { + scrutinee, + table, + default, + merge_block, + } => self.serialize_switch_terminator( + scrutinee, + table, + default.map(|value| *value), + merge_block, + order, + ), + + TerminatorInfo::ToMergeBlock => {} + TerminatorInfo::Continue => order.push(StructuralInst::Continue), + TerminatorInfo::Break => order.push(StructuralInst::Break), + TerminatorInfo::FallThrough(next) => self.serialize_block(next, order), + TerminatorInfo::NormalInst(inst) => order.push(StructuralInst::Inst(inst)), + } + } + + fn serialize_if_terminator( + &mut self, + cond: ValueId, + then: TerminatorInfo, + else_: TerminatorInfo, + merge_block: Option, + order: &mut Vec, + ) { + let mut then_body = vec![]; + let mut else_body = vec![]; + + self.enter_branch_scope(merge_block); + self.serialize_branch_dest(then, &mut then_body, merge_block); + self.serialize_branch_dest(else_, &mut else_body, merge_block); + self.exit_scope(); + + order.push(StructuralInst::If { + cond, + then: then_body, + else_: else_body, + }); + if let Some(merge_block) = merge_block { + self.serialize_block(merge_block, order); + } + } + + fn serialize_switch_terminator( + &mut self, + scrutinee: ValueId, + table: Vec<(ValueId, TerminatorInfo)>, + default: Option, + merge_block: Option, + order: &mut Vec, + ) { + self.enter_branch_scope(merge_block); + + let mut serialized_table = Vec::with_capacity(table.len()); + for (value, dest) in table { + let mut body = vec![]; + self.serialize_branch_dest(dest, &mut body, merge_block); + serialized_table.push((value, body)); + } + + let serialized_default = default.map(|dest| { + let mut body = vec![]; + self.serialize_branch_dest(dest, &mut body, merge_block); + body + }); + + order.push(StructuralInst::Switch { + scrutinee, + table: serialized_table, + default: serialized_default, + }); + + self.exit_scope(); + + if let Some(merge_block) = merge_block { + self.serialize_block(merge_block, order); + } + } + + fn serialize_branch_dest( + &mut self, + dest: TerminatorInfo, + body: &mut Vec, + merge_block: Option, + ) { + match dest { + TerminatorInfo::Break => body.push(StructuralInst::Break), + TerminatorInfo::Continue => body.push(StructuralInst::Continue), + TerminatorInfo::ToMergeBlock => {} + TerminatorInfo::FallThrough(dest) => { + if Some(dest) != merge_block { + self.serialize_block(dest, body); + } + } + _ => unreachable!(), + }; + } + + fn enter_loop_scope(&mut self, lp: LoopId, header: BasicBlockId, exit: Option) { + let kind = ScopeKind::Loop { lp, header, exit }; + let current_scope = std::mem::take(&mut self.scope); + self.scope = Some(Scope { + kind, + parent: current_scope.map(Into::into), + }); + } + + fn enter_branch_scope(&mut self, merge_block: Option) { + let kind = ScopeKind::Branch { merge_block }; + let current_scope = std::mem::take(&mut self.scope); + self.scope = Some(Scope { + kind, + parent: current_scope.map(Into::into), + }); + } + + fn exit_scope(&mut self) { + let current_scope = std::mem::take(&mut self.scope); + self.scope = current_scope.unwrap().parent.map(|parent| *parent); + } + + // NOTE: We assume loop has at most one canonical loop exit. + fn find_loop_exit(&self, lp: LoopId) -> Option { + let mut exit_candidates = vec![]; + for block_in_loop in self.loop_tree.iter_blocks_post_order(&self.cfg, lp) { + for &succ in self.cfg.succs(block_in_loop) { + if !self.loop_tree.is_block_in_loop(succ, lp) { + exit_candidates.push(succ); + } + } + } + + if exit_candidates.is_empty() { + return None; + } + + if exit_candidates.len() == 1 { + let candidate = exit_candidates[0]; + let exit = if let Some(mut df) = self.df.frontiers(candidate) { + debug_assert_eq!(self.df.frontier_num(candidate), 1); + df.next() + } else { + Some(candidate) + }; + return exit; + } + + // If a candidate is a dominance frontier of all other nodes, then the candidate + // is a loop exit. + for &cand in &exit_candidates { + if exit_candidates.iter().all(|&block| { + if block == cand { + true + } else if let Some(mut df) = self.df.frontiers(block) { + df.any(|frontier| frontier == cand) + } else { + true + } + }) { + return Some(cand); + } + } + + // If all candidates have the same dominance frontier, then the frontier block + // is the canonicalized loop exit. + let mut frontier: IndexSet<_> = self + .df + .frontiers(exit_candidates.pop().unwrap()) + .map(std::iter::Iterator::collect) + .unwrap_or_default(); + for cand in exit_candidates { + for cand_frontier in self.df.frontiers(cand).unwrap() { + if !frontier.contains(&cand_frontier) { + frontier.remove(&cand_frontier); + } + } + } + debug_assert!(frontier.len() < 2); + frontier.iter().next().copied() + } + + fn analyze_terminator(&self, inst: InstId) -> TerminatorInfo { + debug_assert!(self.body.store.is_terminator(inst)); + + let inst_block = self.body.order.inst_block(inst); + match self.body.store.branch_info(inst) { + BranchInfo::Jump(dest) => self.analyze_jump(dest), + + BranchInfo::Branch(cond, then, else_) => self.analyze_if(inst_block, cond, then, else_), + + BranchInfo::Switch(scrutinee, table, default) => { + self.analyze_switch(inst_block, scrutinee, table, default) + } + + BranchInfo::NotBranch => TerminatorInfo::NormalInst(inst), + } + } + + fn analyze_if( + &self, + block: BasicBlockId, + cond: ValueId, + then_bb: BasicBlockId, + else_bb: BasicBlockId, + ) -> TerminatorInfo { + let then = Box::new(self.analyze_dest(then_bb)); + let else_ = Box::new(self.analyze_dest(else_bb)); + + let then_cands = self.find_merge_block_candidates(block, then_bb); + let else_cands = self.find_merge_block_candidates(block, else_bb); + debug_assert!(then_cands.len() < 2); + debug_assert!(else_cands.len() < 2); + + let merge_block = match (then_cands.as_slice(), else_cands.as_slice()) { + (&[then_cand], &[else_cand]) => { + if then_cand == else_cand { + Some(then_cand) + } else { + None + } + } + + (&[cand], []) => { + if cand == else_bb { + Some(cand) + } else { + None + } + } + + ([], &[cand]) => { + if cand == then_bb { + Some(cand) + } else { + None + } + } + + ([], []) => match self.pd_tree.post_idom(block) { + PostIDom::Block(block) => { + if let Some(lp) = self.scope.as_ref().and_then(Scope::loop_recursive) { + if self.loop_tree.is_block_in_loop(block, lp) { + Some(block) + } else { + None + } + } else { + Some(block) + } + } + _ => None, + }, + + (_, _) => unreachable!(), + }; + + TerminatorInfo::If { + cond, + then, + else_, + merge_block, + } + } + + fn analyze_switch( + &self, + block: BasicBlockId, + scrutinee: ValueId, + table: &SwitchTable, + default: Option, + ) -> TerminatorInfo { + let mut analyzed_table = Vec::with_capacity(table.len()); + + let mut merge_block_cands = IndexSet::default(); + for (value, dest) in table.iter() { + analyzed_table.push((value, self.analyze_dest(dest))); + merge_block_cands.extend(self.find_merge_block_candidates(block, dest)); + } + + let analyzed_default = default.map(|dest| { + merge_block_cands.extend(self.find_merge_block_candidates(block, dest)); + Box::new(self.analyze_dest(dest)) + }); + + TerminatorInfo::Switch { + scrutinee, + table: analyzed_table, + default: analyzed_default, + merge_block: self.select_switch_merge_block( + &merge_block_cands, + table.iter().map(|(_, d)| d).chain(default), + ), + } + } + + fn find_merge_block_candidates( + &self, + branch_inst_bb: BasicBlockId, + branch_dest_bb: BasicBlockId, + ) -> Vec { + if self.domtree.dominates(branch_dest_bb, branch_inst_bb) { + return vec![]; + } + + // a block `cand` can be a candidate of a `merge` block iff + // 1. `cand` is a dominance frontier of `branch_dest_bb`. + // 2. `cand` is NOT a dominator of `branch_dest_bb`. + // 3. `cand` is NOT a "merge" block of parent `if` or `switch`. + // 4. `cand` is NOT a "loop_exit" block of parent `loop`. + match self.df.frontiers(branch_dest_bb) { + Some(cands) => cands + .filter(|cand| { + !self.domtree.dominates(*cand, branch_dest_bb) + && Some(*cand) + != self + .scope + .as_ref() + .and_then(Scope::branch_merge_block_recursive) + && Some(*cand) != self.scope.as_ref().and_then(Scope::loop_exit_recursive) + }) + .collect(), + None => vec![], + } + } + + /// Each destination block of `switch` instruction could have multiple + /// candidates for the merge block because arm bodies can have multiple + /// predecessors, e.g., `default` arm. + /// So we need a heuristic to select the merge block from candidates. + /// + /// First, if one of the dominance frontiers of switch dests is a parent + /// merge block, then we stop searching the merge block because the parent + /// merge block should be the subsequent codes after the switch in terms of + /// high-level flow structure like Fe or yul. + /// + /// If no parent merge block is found, we start scoring the candidates by + /// the following function. + /// + /// The scoring function `F` is defined as follows: + /// 1. The initial score of each candidate('cand_bb`) is number of + /// predecessors of the candidate. + /// + /// 2. Find the `top_cand` of each `cand_bb`. `top_cand` can be found by + /// [`Self::try_find_top_cand`] method, see the method for details. + /// + /// 3. If `top_cand` is found, then add the `cand_bb` score to the + /// `top_cand` score, then set 0 to the `cand_bb` score. + /// + /// After the scoring, the candidates with the highest score will be + /// selected. + fn select_switch_merge_block( + &self, + cands: &IndexSet, + dests: impl Iterator, + ) -> Option { + let parent_merge = self + .scope + .as_ref() + .and_then(Scope::branch_merge_block_recursive); + for dest in dests { + if self + .df + .frontiers(dest) + .map(|mut frontieres| frontieres.any(|frontier| Some(frontier) == parent_merge)) + .unwrap_or_default() + { + return None; + } + } + + let mut cands_with_score = cands + .iter() + .map(|cand| (*cand, self.cfg.preds(*cand).len())) + .collect::>(); + + for cand_bb in cands_with_score.keys().copied().collect::>() { + if let Some(top_cand) = self.try_find_top_cand(&cands_with_score, cand_bb) { + let score = std::mem::take(cands_with_score.get_mut(&cand_bb).unwrap()); + *cands_with_score.get_mut(&top_cand).unwrap() += score; + } + } + + cands_with_score + .iter() + .max_by_key(|(_, score)| *score) + .map(|(&cand, _)| cand) + } + + /// Try to find the `top_cand` of the `cand_bb`. + /// A `top_cand` can be found by the following rules: + /// + /// 1. Find the block which is contained in DF of `cand_bb` and in + /// `cands_with_score`. + /// + /// 2. If a block is found in 1., and the score of the block is positive, + /// then the block is `top_cand`. + /// + /// 2'. If a block is found in 1., and the score of the block is 0, then the + /// `top_cand` of the block is `top_cand` of `cand_bb`. + /// + /// 2''. If a block is NOT found in 1., then there is no `top_cand` for + /// `cand_bb`. + fn try_find_top_cand( + &self, + cands_with_score: &IndexMap, + cand_bb: BasicBlockId, + ) -> Option { + let mut frontiers = match self.df.frontiers(cand_bb) { + Some(frontiers) => frontiers, + _ => return None, + }; + + while let Some(frontier_bb) = frontiers.next() { + if cands_with_score.contains_key(&frontier_bb) { + debug_assert!(frontiers.all(|bb| !cands_with_score.contains_key(&bb))); + if cands_with_score[&frontier_bb] != 0 { + return Some(frontier_bb); + } else { + return self.try_find_top_cand(cands_with_score, frontier_bb); + } + } + } + + None + } + + fn analyze_jump(&self, dest: BasicBlockId) -> TerminatorInfo { + self.analyze_dest(dest) + } + + fn analyze_dest(&self, dest: BasicBlockId) -> TerminatorInfo { + match &self.scope { + Some(scope) => { + if Some(dest) == scope.loop_header_recursive() { + TerminatorInfo::Continue + } else if Some(dest) == scope.loop_exit_recursive() { + TerminatorInfo::Break + } else if Some(dest) == scope.branch_merge_block_recursive() { + TerminatorInfo::ToMergeBlock + } else { + TerminatorInfo::FallThrough(dest) + } + } + + None => TerminatorInfo::FallThrough(dest), + } + } +} + +struct Scope { + kind: ScopeKind, + parent: Option>, +} + +#[derive(Debug, Clone, Copy)] +enum ScopeKind { + Loop { + lp: LoopId, + header: BasicBlockId, + exit: Option, + }, + Branch { + merge_block: Option, + }, +} + +impl Scope { + fn loop_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { lp, .. } => Some(lp), + _ => self.parent.as_ref()?.loop_recursive(), + } + } + + fn loop_header(&self) -> Option { + match self.kind { + ScopeKind::Loop { header, .. } => Some(header), + _ => None, + } + } + + fn loop_header_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { header, .. } => Some(header), + _ => self.parent.as_ref()?.loop_header_recursive(), + } + } + + fn loop_exit_recursive(&self) -> Option { + match self.kind { + ScopeKind::Loop { exit, .. } => exit, + _ => self.parent.as_ref()?.loop_exit_recursive(), + } + } + + fn branch_merge_block(&self) -> Option { + match self.kind { + ScopeKind::Branch { merge_block } => merge_block, + _ => None, + } + } + + fn branch_merge_block_recursive(&self) -> Option { + match self.kind { + ScopeKind::Branch { + merge_block: Some(merge_block), + } => Some(merge_block), + _ => self.parent.as_ref()?.branch_merge_block_recursive(), + } + } +} + +#[derive(Debug, Clone)] +enum TerminatorInfo { + If { + cond: ValueId, + then: Box, + else_: Box, + merge_block: Option, + }, + + Switch { + scrutinee: ValueId, + table: Vec<(ValueId, TerminatorInfo)>, + default: Option>, + merge_block: Option, + }, + + ToMergeBlock, + Continue, + Break, + FallThrough(BasicBlockId), + NormalInst(InstId), +} + +#[cfg(test)] +mod tests { + use fe_mir::ir::{body_builder::BodyBuilder, inst::InstKind, FunctionId, SourceInfo, TypeId}; + + use super::*; + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + fn serialize_func_body(func: &mut FunctionBody) -> impl Iterator { + InstSerializer::new(func).serialize().into_iter() + } + + fn expect_if( + insts: &mut impl Iterator, + ) -> ( + impl Iterator, + impl Iterator, + ) { + match insts.next().unwrap() { + StructuralInst::If { then, else_, .. } => (then.into_iter(), else_.into_iter()), + _ => panic!("expect if inst"), + } + } + + fn expect_switch( + insts: &mut impl Iterator, + ) -> Vec> { + match insts.next().unwrap() { + StructuralInst::Switch { table, default, .. } => { + let mut arms: Vec<_> = table + .into_iter() + .map(|(_, insts)| insts.into_iter()) + .collect(); + if let Some(default) = default { + arms.push(default.into_iter()); + } + + arms + } + + _ => panic!("expect if inst"), + } + } + + fn expect_for( + insts: &mut impl Iterator, + ) -> impl Iterator { + match insts.next().unwrap() { + StructuralInst::For { body } => body.into_iter(), + _ => panic!("expect if inst"), + } + } + + fn expect_break(insts: &mut impl Iterator) { + assert!(matches!(insts.next().unwrap(), StructuralInst::Break)) + } + + fn expect_continue(insts: &mut impl Iterator) { + assert!(matches!(insts.next().unwrap(), StructuralInst::Continue)) + } + + fn expect_return(func: &FunctionBody, insts: &mut impl Iterator) { + let inst = insts.next().unwrap(); + match inst { + StructuralInst::Inst(inst) => { + assert!(matches!( + func.store.inst_data(inst).kind, + InstKind::Return { .. } + )) + } + _ => panic!("expect return"), + } + } + + fn expect_end(insts: &mut impl Iterator) { + assert!(insts.next().is_none()) + } + + #[test] + fn if_non_merge() { + // +------+ +-------+ + // | then | <-- | bb0 | + // +------+ +-------+ + // | + // | + // v + // +-------+ + // | else_ | + // +-------+ + let mut builder = body_builder(); + + let then = builder.make_block(); + let else_ = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, then, else_, SourceInfo::dummy()); + + builder.move_to_block(then); + builder.ret(unit, SourceInfo::dummy()); + + builder.move_to_block(else_); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then, mut else_) = expect_if(&mut order); + expect_return(&func, &mut then); + expect_end(&mut then); + expect_return(&func, &mut else_); + expect_end(&mut else_); + + expect_end(&mut order); + } + + #[test] + fn if_merge() { + // +------+ +-------+ + // | then | <-- | bb0 | + // +------+ +-------+ + // | | + // | | + // | v + // | +-------+ + // | | else_ | + // | +-------+ + // | | + // | | + // | v + // | +-------+ + // +--------> | merge | + // +-------+ + let mut builder = body_builder(); + + let then = builder.make_block(); + let else_ = builder.make_block(); + let merge = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, then, else_, SourceInfo::dummy()); + + builder.move_to_block(then); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(else_); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then, mut else_) = expect_if(&mut order); + expect_end(&mut then); + expect_end(&mut else_); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn nested_if() { + // +-----+ + // | bb0 | -+ + // +-----+ | + // | | + // | | + // v | + // +-----+ +-----+ | + // | bb3 | <-- | bb1 | | + // +-----+ +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb4 | | + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | <+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.ret(unit, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then1, mut else2) = expect_if(&mut order); + expect_end(&mut else2); + + let (mut then3, mut else4) = expect_if(&mut then1); + expect_end(&mut then1); + expect_return(&func, &mut then3); + expect_end(&mut then3); + expect_end(&mut else4); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn simple_loop() { + // +--------+ + // | bb0 | -+ + // +--------+ | + // | | + // | | + // v | + // +--------+ | + // +> | header | | + // | +--------+ | + // | | | + // | | | + // | v | + // | +--------+ | + // +- | latch | | + // +--------+ | + // | | + // | | + // v | + // +--------+ | + // | exit | <+ + // +--------+ + let mut builder = body_builder(); + + let header = builder.make_block(); + let latch = builder.make_block(); + let exit = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, header, exit, SourceInfo::dummy()); + + builder.move_to_block(header); + builder.jump(latch, SourceInfo::dummy()); + + builder.move_to_block(latch); + builder.branch(v0, header, exit, SourceInfo::dummy()); + + builder.move_to_block(exit); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(&mut order); + + let mut body = expect_for(&mut lp); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_end(&mut body); + + expect_continue(&mut continue_); + expect_end(&mut continue_); + + expect_break(&mut break_); + expect_end(&mut break_); + + expect_end(&mut empty); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn loop_with_continue() { + // +-----+ + // +- | bb0 | + // | +-----+ + // | | + // | | + // | v + // | +---------------+ +-----+ + // | | bb1 | --> | bb3 | + // | +---------------+ +-----+ + // | | ^ ^ | + // | | | +---------+ + // | v | + // | +-----+ | + // | | bb4 | -+ + // | +-----+ + // | | + // | | + // | v + // | +-----+ + // +> | bb2 | + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(&mut order); + expect_end(&mut empty); + + let mut body = expect_for(&mut lp); + + let (mut continue_, mut empty) = expect_if(&mut body); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_end(&mut empty); + + let (mut continue_, mut break_) = expect_if(&mut body); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_break(&mut break_); + expect_end(&mut break_); + + expect_end(&mut body); + expect_end(&mut lp); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn loop_with_break() { + // +-----+ + // +- | bb0 | + // | +-----+ + // | | + // | | +---------+ + // | v v | + // | +---------------+ +-----+ + // | | bb1 | --> | bb4 | + // | +---------------+ +-----+ + // | | | + // | | | + // | v | + // | +-----+ | + // | | bb3 | | + // | +-----+ | + // | | | + // | | | + // | v | + // | +-----+ | + // +> | bb2 | <---------------+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut lp, mut empty) = expect_if(&mut order); + expect_end(&mut empty); + + let mut body = expect_for(&mut lp); + + let (mut break_, mut latch) = expect_if(&mut body); + expect_break(&mut break_); + expect_end(&mut break_); + + let (mut continue_, mut break_) = expect_if(&mut latch); + expect_end(&mut latch); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_break(&mut break_); + expect_end(&mut break_); + + expect_end(&mut body); + expect_end(&mut lp); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn loop_no_guard() { + // +-----+ + // | bb0 | + // +-----+ + // | + // | + // v + // +-----+ + // | bb1 | <+ + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | -+ + // +-----+ + // | + // | + // v + // +-----+ + // | bb3 | + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.branch(v0, bb1, bb3, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut body = expect_for(&mut order); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_end(&mut body); + + expect_continue(&mut continue_); + expect_end(&mut continue_); + + expect_break(&mut break_); + expect_end(&mut break_); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn infinite_loop() { + // +-----+ + // | bb0 | + // +-----+ + // | + // | + // v + // +-----+ + // | bb1 | <+ + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | -+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + + builder.jump(bb1, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.jump(bb1, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut body = expect_for(&mut order); + expect_continue(&mut body); + expect_end(&mut body); + + expect_end(&mut order); + } + + #[test] + fn switch_basic() { + // +-----+ +-------+ +-----+ + // | bb2 | <-- | bb0 | --> | bb3 | + // +-----+ +-------+ +-----+ + // | | | + // | | | + // | v | + // | +-------+ | + // | | bb1 | | + // | +-------+ | + // | | | + // | | | + // | v | + // | +-------+ | + // +-------> | merge | <-----+ + // +-------+ + let mut builder = body_builder(); + let dummy_ty = TypeId(0); + let dummy_value = builder.make_unit(dummy_ty); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let merge = builder.make_block(); + + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb1); + table.add_arm(dummy_value, bb2); + table.add_arm(dummy_value, bb3); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.jump(merge, SourceInfo::dummy()); + builder.move_to_block(bb3); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(dummy_value, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let arms = expect_switch(&mut order); + assert_eq!(arms.len(), 3); + for mut arm in arms { + expect_end(&mut arm); + } + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn switch_default() { + // +-----------+ + // | | + // | | + // | +----+--------+ + // v | | | + // +-----+ | +-------+ | +---------+ + // | bb2 | -+ | bb0 | -+> | bb3 | + // +-----+ +-------+ | +---------+ + // | | | | + // | | | | + // v v | v + // +-----+ +-------+ | +---------+ + // | bb5 | +- | bb1 | +> | default | <+ + // +-----+ | +-------+ +---------+ | + // | | | | | + // | | | | | + // | | v | | + // | | +-------+ | | + // | | | bb4 | | | + // | | +-------+ | | + // | | | | | + // +----+------+ | | | + // | | v | | + // | | +-------+ | | + // | +-------> | merge | <-----+ | + // | +-------+ | + // | | + // +-----------------------------------------+ + let mut builder = body_builder(); + let dummy_ty = TypeId(0); + let dummy_value = builder.make_unit(dummy_ty); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + let bb5 = builder.make_block(); + let default = builder.make_block(); + let merge = builder.make_block(); + + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb1); + table.add_arm(dummy_value, bb2); + table.add_arm(dummy_value, bb3); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb1); + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb4); + table.add_arm(dummy_value, default); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb2); + let mut table = SwitchTable::default(); + table.add_arm(dummy_value, bb5); + table.add_arm(dummy_value, default); + builder.switch(dummy_value, table, None, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.jump(default, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(bb5); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(default); + builder.jump(merge, SourceInfo::dummy()); + + builder.move_to_block(merge); + builder.ret(dummy_value, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let mut arms = expect_switch(&mut order); + assert_eq!(arms.len(), 3); + + let mut bb3_jump = arms.pop().unwrap(); + expect_end(&mut bb3_jump); + + let mut bb2_switch = arms.pop().unwrap(); + let bb2_switch_arms = expect_switch(&mut bb2_switch); + assert_eq!(bb2_switch_arms.len(), 2); + for mut bb2_switch_arm in bb2_switch_arms { + expect_end(&mut bb2_switch_arm); + } + expect_end(&mut bb2_switch); + + let mut bb1_switch = arms.pop().unwrap(); + let bb1_switch_arms = expect_switch(&mut bb1_switch); + assert_eq!(bb1_switch_arms.len(), 2); + for mut bb1_switch_arm in bb1_switch_arms { + expect_end(&mut bb1_switch_arm); + } + expect_end(&mut bb1_switch); + + expect_return(&func, &mut order); + expect_end(&mut order); + } +} diff --git a/crates/codegen2/src/yul/isel/mod.rs b/crates/codegen2/src/yul/isel/mod.rs new file mode 100644 index 0000000000..2507774ff8 --- /dev/null +++ b/crates/codegen2/src/yul/isel/mod.rs @@ -0,0 +1,9 @@ +pub mod context; +mod contract; +mod function; +mod inst_order; +mod test; + +pub use contract::{lower_contract, lower_contract_deployable}; +pub use function::lower_function; +pub use test::lower_test; diff --git a/crates/codegen2/src/yul/isel/test.rs b/crates/codegen2/src/yul/isel/test.rs new file mode 100644 index 0000000000..9fe6186933 --- /dev/null +++ b/crates/codegen2/src/yul/isel/test.rs @@ -0,0 +1,70 @@ +use super::context::Context; +use crate::db::CodegenDb; +use fe_analyzer::namespace::items::FunctionId; +use yultsur::{yul, *}; + +pub fn lower_test(db: &dyn CodegenDb, test: FunctionId) -> yul::Object { + let mut context = Context::default(); + let test = db.mir_lowered_func_signature(test); + context.function_dependency.insert(test); + + let dep_constants = context.resolve_constant_dependency(db); + let dep_functions: Vec<_> = context + .resolve_function_dependency(db) + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let dep_contracts = context.resolve_contract_dependency(db); + let runtime_funcs: Vec<_> = context + .runtime + .collect_definitions() + .into_iter() + .map(yul::Statement::FunctionDefinition) + .collect(); + let test_func_name = identifier! { (db.codegen_function_symbol_name(test)) }; + let call = function_call_statement! {[test_func_name]()}; + + let code = code! { + [dep_functions...] + [runtime_funcs...] + [call] + (stop()) + }; + + let name = identifier! { test }; + let object = yul::Object { + name, + code, + objects: dep_contracts, + data: dep_constants, + }; + + normalize_object(object) +} + +fn normalize_object(obj: yul::Object) -> yul::Object { + let data = obj + .data + .into_iter() + .map(|data| yul::Data { + name: data.name, + value: data + .value + .replace('\\', "\\\\\\\\") + .replace('\n', "\\\\n") + .replace('"', "\\\\\"") + .replace('\r', "\\\\r") + .replace('\t', "\\\\t"), + }) + .collect::>(); + yul::Object { + name: obj.name, + code: obj.code, + objects: obj + .objects + .into_iter() + .map(normalize_object) + .collect::>(), + data, + } +} diff --git a/crates/codegen2/src/yul/legalize/body.rs b/crates/codegen2/src/yul/legalize/body.rs new file mode 100644 index 0000000000..5c1b361b1e --- /dev/null +++ b/crates/codegen2/src/yul/legalize/body.rs @@ -0,0 +1,219 @@ +use fe_mir::ir::{ + body_cursor::{BodyCursor, CursorLocation}, + inst::InstKind, + value::AssignableValue, + FunctionBody, Inst, InstId, TypeId, TypeKind, Value, ValueId, +}; + +use crate::db::CodegenDb; + +use super::critical_edge::CriticalEdgeSplitter; + +pub fn legalize_func_body(db: &dyn CodegenDb, body: &mut FunctionBody) { + CriticalEdgeSplitter::new().run(body); + legalize_func_arg(db, body); + + let mut cursor = BodyCursor::new_at_entry(body); + loop { + match cursor.loc() { + CursorLocation::BlockTop(_) | CursorLocation::BlockBottom(_) => cursor.proceed(), + CursorLocation::Inst(inst) => { + legalize_inst(db, &mut cursor, inst); + } + CursorLocation::NoWhere => break, + } + } +} + +fn legalize_func_arg(db: &dyn CodegenDb, body: &mut FunctionBody) { + for value in body.store.func_args_mut() { + let ty = value.ty(); + if ty.is_contract(db.upcast()) { + let slot_ptr = make_storage_ptr(db, ty); + *value = slot_ptr; + } else if (ty.is_aggregate(db.upcast()) || ty.is_string(db.upcast())) + && !ty.is_zero_sized(db.upcast()) + { + change_ty(value, ty.make_mptr(db.upcast())) + } + } +} + +fn legalize_inst(db: &dyn CodegenDb, cursor: &mut BodyCursor, inst: InstId) { + if legalize_unit_construct(db, cursor, inst) { + return; + } + legalize_declared_ty(db, cursor.body_mut(), inst); + legalize_inst_arg(db, cursor.body_mut(), inst); + legalize_inst_result(db, cursor.body_mut(), inst); + cursor.proceed(); +} + +fn legalize_unit_construct(db: &dyn CodegenDb, cursor: &mut BodyCursor, inst: InstId) -> bool { + let should_remove = match &cursor.body().store.inst_data(inst).kind { + InstKind::Declare { local } => is_value_zst(db, cursor.body(), *local), + InstKind::AggregateConstruct { ty, .. } => ty.deref(db.upcast()).is_zero_sized(db.upcast()), + InstKind::AggregateAccess { .. } | InstKind::MapAccess { .. } | InstKind::Cast { .. } => { + let result_value = cursor.body().store.inst_result(inst).unwrap(); + is_lvalue_zst(db, cursor.body(), result_value) + } + + _ => false, + }; + + if should_remove { + cursor.remove_inst() + } + + should_remove +} + +fn legalize_declared_ty(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstId) { + let value = match &body.store.inst_data(inst_id).kind { + InstKind::Declare { local } => *local, + _ => return, + }; + + let value_ty = body.store.value_ty(value); + if value_ty.is_aggregate(db.upcast()) { + let new_ty = value_ty.make_mptr(db.upcast()); + let value_data = body.store.value_data_mut(value); + change_ty(value_data, new_ty) + } +} + +fn legalize_inst_arg(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstId) { + // Replace inst with dummy inst to avoid borrow checker complaining. + let dummy_inst = Inst::nop(); + let mut inst = body.store.replace_inst(inst_id, dummy_inst); + + for arg in inst.args() { + let ty = body.store.value_ty(arg); + if ty.is_string(db.upcast()) { + let string_ptr = ty.make_mptr(db.upcast()); + change_ty(body.store.value_data_mut(arg), string_ptr) + } + } + + match &mut inst.kind { + InstKind::AggregateConstruct { args, .. } => { + args.retain(|arg| !is_value_zst(db, body, *arg)); + } + + InstKind::Call { args, .. } => { + args.retain(|arg| !is_value_zst(db, body, *arg) && !is_value_contract(db, body, *arg)) + } + + InstKind::Return { arg } => { + if arg.map(|arg| is_value_zst(db, body, arg)).unwrap_or(false) { + *arg = None; + } + } + + InstKind::MapAccess { key: arg, .. } | InstKind::Emit { arg } => { + let arg_ty = body.store.value_ty(*arg); + if arg_ty.is_zero_sized(db.upcast()) { + *arg = body.store.store_value(make_zst_ptr(db, arg_ty)); + } + } + + InstKind::Cast { value, to, .. } => { + if to.is_aggregate(db.upcast()) && !to.is_zero_sized(db.upcast()) { + let value_ty = body.store.value_ty(*value); + if value_ty.is_mptr(db.upcast()) { + *to = to.make_mptr(db.upcast()); + } else if value_ty.is_sptr(db.upcast()) { + *to = to.make_sptr(db.upcast()); + } else { + unreachable!() + } + } + } + + _ => {} + } + + body.store.replace_inst(inst_id, inst); +} + +fn legalize_inst_result(db: &dyn CodegenDb, body: &mut FunctionBody, inst_id: InstId) { + let result_value = if let Some(result) = body.store.inst_result(inst_id) { + result + } else { + return; + }; + + if is_lvalue_zst(db, body, result_value) { + body.store.remove_inst_result(inst_id); + return; + }; + + let value_id = if let Some(value_id) = result_value.value_id() { + value_id + } else { + return; + }; + let result_ty = body.store.value_ty(value_id); + let new_ty = if result_ty.is_aggregate(db.upcast()) || result_ty.is_string(db.upcast()) { + match &body.store.inst_data(inst_id).kind { + InstKind::AggregateAccess { value, .. } => { + let value_ty = body.store.value_ty(*value); + match &value_ty.data(db.upcast()).kind { + TypeKind::MPtr(..) => result_ty.make_mptr(db.upcast()), + // Note: All SPtr aggregate access results should be SPtr already + _ => unreachable!(), + } + } + _ => result_ty.make_mptr(db.upcast()), + } + } else { + return; + }; + + let value = body.store.value_data_mut(value_id); + change_ty(value, new_ty); +} + +fn change_ty(value: &mut Value, new_ty: TypeId) { + match value { + Value::Local(val) => val.ty = new_ty, + Value::Immediate { ty, .. } + | Value::Temporary { ty, .. } + | Value::Unit { ty } + | Value::Constant { ty, .. } => *ty = new_ty, + } +} + +fn make_storage_ptr(db: &dyn CodegenDb, ty: TypeId) -> Value { + debug_assert!(ty.is_contract(db.upcast())); + let ty = ty.make_sptr(db.upcast()); + + Value::Immediate { imm: 0.into(), ty } +} + +fn make_zst_ptr(db: &dyn CodegenDb, ty: TypeId) -> Value { + debug_assert!(ty.is_zero_sized(db.upcast())); + let ty = ty.make_mptr(db.upcast()); + + Value::Immediate { imm: 0.into(), ty } +} + +/// Returns `true` if a value has a zero sized type. +fn is_value_zst(db: &dyn CodegenDb, body: &FunctionBody, value: ValueId) -> bool { + body.store + .value_ty(value) + .deref(db.upcast()) + .is_zero_sized(db.upcast()) +} + +fn is_value_contract(db: &dyn CodegenDb, body: &FunctionBody, value: ValueId) -> bool { + let ty = body.store.value_ty(value); + ty.deref(db.upcast()).is_contract(db.upcast()) +} + +fn is_lvalue_zst(db: &dyn CodegenDb, body: &FunctionBody, lvalue: &AssignableValue) -> bool { + lvalue + .ty(db.upcast(), &body.store) + .deref(db.upcast()) + .is_zero_sized(db.upcast()) +} diff --git a/crates/codegen2/src/yul/legalize/critical_edge.rs b/crates/codegen2/src/yul/legalize/critical_edge.rs new file mode 100644 index 0000000000..3e3d689ab3 --- /dev/null +++ b/crates/codegen2/src/yul/legalize/critical_edge.rs @@ -0,0 +1,121 @@ +use fe_mir::{ + analysis::ControlFlowGraph, + ir::{ + body_cursor::{BodyCursor, CursorLocation}, + inst::InstKind, + BasicBlock, BasicBlockId, FunctionBody, Inst, InstId, SourceInfo, + }, +}; + +#[derive(Debug)] +pub struct CriticalEdgeSplitter { + critical_edges: Vec, +} + +impl CriticalEdgeSplitter { + pub fn new() -> Self { + Self { + critical_edges: Vec::default(), + } + } + + pub fn run(&mut self, func: &mut FunctionBody) { + let cfg = ControlFlowGraph::compute(func); + + for block in cfg.post_order() { + let terminator = func.order.terminator(&func.store, block).unwrap(); + self.add_critical_edges(terminator, func, &cfg); + } + + self.split_edges(func); + } + + fn add_critical_edges( + &mut self, + terminator: InstId, + func: &FunctionBody, + cfg: &ControlFlowGraph, + ) { + for to in func.store.branch_info(terminator).block_iter() { + if cfg.preds(to).len() > 1 { + self.critical_edges.push(CriticalEdge { terminator, to }); + } + } + } + + fn split_edges(&mut self, func: &mut FunctionBody) { + for edge in std::mem::take(&mut self.critical_edges) { + let terminator = edge.terminator; + let source_block = func.order.inst_block(terminator); + let original_dest = edge.to; + + // Create new block that contains only jump inst. + let new_dest = func.store.store_block(BasicBlock {}); + let mut cursor = BodyCursor::new(func, CursorLocation::BlockTop(source_block)); + cursor.insert_block(new_dest); + cursor.set_loc(CursorLocation::BlockTop(new_dest)); + cursor.store_and_insert_inst(Inst::new( + InstKind::Jump { + dest: original_dest, + }, + SourceInfo::dummy(), + )); + + // Rewrite branch destination to the new dest. + func.store + .rewrite_branch_dest(terminator, original_dest, new_dest); + } + } +} + +#[derive(Debug)] +struct CriticalEdge { + terminator: InstId, + to: BasicBlockId, +} + +#[cfg(test)] +mod tests { + use fe_mir::ir::{body_builder::BodyBuilder, FunctionId, TypeId}; + + use super::*; + + fn body_builder() -> BodyBuilder { + BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) + } + + #[test] + fn critical_edge_remove() { + let mut builder = body_builder(); + let lp_header = builder.make_block(); + let lp_body = builder.make_block(); + let exit = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(false, dummy_ty); + builder.branch(v0, lp_header, exit, SourceInfo::dummy()); + + builder.move_to_block(lp_header); + builder.jump(lp_body, SourceInfo::dummy()); + + builder.move_to_block(lp_body); + builder.branch(v0, lp_header, exit, SourceInfo::dummy()); + + builder.move_to_block(exit); + builder.ret(v0, SourceInfo::dummy()); + + let mut func = builder.build(); + CriticalEdgeSplitter::new().run(&mut func); + let cfg = ControlFlowGraph::compute(&func); + + for &header_pred in cfg.preds(lp_header) { + debug_assert_eq!(cfg.succs(header_pred).len(), 1); + debug_assert_eq!(cfg.succs(header_pred)[0], lp_header); + } + + for &exit_pred in cfg.preds(exit) { + debug_assert_eq!(cfg.succs(exit_pred).len(), 1); + debug_assert_eq!(cfg.succs(exit_pred)[0], exit); + } + } +} diff --git a/crates/codegen2/src/yul/legalize/mod.rs b/crates/codegen2/src/yul/legalize/mod.rs new file mode 100644 index 0000000000..62e82f78fe --- /dev/null +++ b/crates/codegen2/src/yul/legalize/mod.rs @@ -0,0 +1,6 @@ +mod body; +mod critical_edge; +mod signature; + +pub use body::legalize_func_body; +pub use signature::legalize_func_signature; diff --git a/crates/codegen2/src/yul/legalize/signature.rs b/crates/codegen2/src/yul/legalize/signature.rs new file mode 100644 index 0000000000..134bc10ae6 --- /dev/null +++ b/crates/codegen2/src/yul/legalize/signature.rs @@ -0,0 +1,27 @@ +use fe_mir::ir::{FunctionSignature, TypeKind}; + +use crate::db::CodegenDb; + +pub fn legalize_func_signature(db: &dyn CodegenDb, sig: &mut FunctionSignature) { + // Remove param if the type is contract or zero-sized. + let params = &mut sig.params; + params.retain(|param| match param.ty.data(db.upcast()).kind { + TypeKind::Contract(_) => false, + _ => !param.ty.deref(db.upcast()).is_zero_sized(db.upcast()), + }); + + // Legalize param types. + for param in params.iter_mut() { + param.ty = db.codegen_legalized_type(param.ty); + } + + if let Some(ret_ty) = sig.return_type { + // Remove return type if the type is contract or zero-sized. + if ret_ty.is_contract(db.upcast()) || ret_ty.deref(db.upcast()).is_zero_sized(db.upcast()) { + sig.return_type = None; + } else { + // Legalize param types. + sig.return_type = Some(db.codegen_legalized_type(ret_ty)); + } + } +} diff --git a/crates/codegen2/src/yul/mod.rs b/crates/codegen2/src/yul/mod.rs new file mode 100644 index 0000000000..6e7e95457e --- /dev/null +++ b/crates/codegen2/src/yul/mod.rs @@ -0,0 +1,26 @@ +use std::borrow::Cow; + +pub mod isel; +pub mod legalize; +pub mod runtime; + +mod slot_size; + +use yultsur::*; + +/// A helper struct to abstract ident and expr. +struct YulVariable<'a>(Cow<'a, str>); + +impl<'a> YulVariable<'a> { + fn expr(&self) -> yul::Expression { + identifier_expression! {(format!{"${}", self.0})} + } + + fn ident(&self) -> yul::Identifier { + identifier! {(format!{"${}", self.0})} + } + + fn new(name: impl Into>) -> Self { + Self(name.into()) + } +} diff --git a/crates/codegen2/src/yul/runtime/abi.rs b/crates/codegen2/src/yul/runtime/abi.rs new file mode 100644 index 0000000000..565465a18e --- /dev/null +++ b/crates/codegen2/src/yul/runtime/abi.rs @@ -0,0 +1,950 @@ +use crate::{ + db::CodegenDb, + yul::{ + runtime::{error_revert_numeric, make_ptr}, + slot_size::{yul_primitive_type, SLOT_SIZE}, + YulVariable, + }, +}; + +use super::{AbiSrcLocation, DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_abi::types::AbiType; +use fe_mir::ir::{self, types::ArrayDef, TypeId, TypeKind}; +use yultsur::*; + +pub(super) fn make_abi_encode_primitive_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + ([src.ident()] := [provider.primitive_cast(db, src.expr(), legalized_ty)]) + ([yul::Statement::Expression(provider.ptr_store( + db, + dst.expr(), + src.expr(), + make_ptr(db, yul_primitive_type(db), is_dst_storage), + ))]) + ([enc_size.ident()] := 32) + } + }; + + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_abi_encode_static_array_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, +) -> RuntimeFunction { + let is_dst_storage = legalized_ty.is_sptr(db.upcast()); + let deref_ty = legalized_ty.deref(db.upcast()); + let (elem_ty, len) = match &deref_ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) => (def.elem_ty, def.len), + _ => unreachable!(), + }; + let elem_abi_ty = db.codegen_abi_type(elem_ty); + let elem_ptr_ty = make_ptr(db, elem_ty, false); + let elem_ty_size = deref_ty.array_elem_size(db.upcast(), SLOT_SIZE); + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let header_size = elem_abi_ty.header_size(); + let iter_count = literal_expression! {(len)}; + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + (for {(let i := 0)} (lt(i, [iter_count])) {(i := (add(i, 1)))} + { + + (pop([provider.abi_encode(db, src.expr(), dst.expr(), elem_ptr_ty, is_dst_storage)])) + ([src.ident()] := add([src.expr()], [literal_expression!{(elem_ty_size)}])) + ([dst.ident()] := add([dst.expr()], [literal_expression!{(header_size)}])) + }) + ([enc_size.ident()] := [literal_expression! {(header_size * len)}]) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_abi_encode_dynamic_array_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, +) -> RuntimeFunction { + let is_dst_storage = legalized_ty.is_sptr(db.upcast()); + let deref_ty = legalized_ty.deref(db.upcast()); + let (elem_ty, len) = match &deref_ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) => (def.elem_ty, def.len), + _ => unreachable!(), + }; + let elem_header_size = 32; + let total_header_size = elem_header_size * len; + let elem_ptr_ty = make_ptr(db, elem_ty, false); + let elem_ty_size = deref_ty.array_elem_size(db.upcast(), SLOT_SIZE); + let header_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let header_ptr = YulVariable::new("header_ptr"); + let data_ptr = YulVariable::new("data_ptr"); + let enc_size = YulVariable::new("enc_size"); + let iter_count = literal_expression! {(len)}; + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + (let [header_ptr.ident()] := [dst.expr()]) + (let [data_ptr.ident()] := add([dst.expr()], [literal_expression!{(total_header_size)}])) + ([enc_size.ident()] := [literal_expression!{(total_header_size)}]) + (for {(let i := 0)} (lt(i, [iter_count])) {(i := (add(i, 1)))} + { + + ([yul::Statement::Expression(provider.ptr_store(db, header_ptr.expr(), enc_size.expr(), header_ty))]) + ([enc_size.ident()] := add([provider.abi_encode(db, src.expr(), data_ptr.expr(), elem_ptr_ty, is_dst_storage)], [enc_size.expr()])) + ([header_ptr.ident()] := add([header_ptr.expr()], [literal_expression!{(elem_header_size)}])) + ([data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + ([src.ident()] := add([src.expr()], [literal_expression!{(elem_ty_size)}])) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_abi_encode_static_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let deref_ty = legalized_ty.deref(db.upcast()); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let field_enc_size = YulVariable::new("field_enc_size"); + let mut body = vec![ + statement! {[enc_size.ident()] := 0 }, + statement! {let [field_enc_size.ident()] := 0 }, + ]; + let field_num = deref_ty.aggregate_field_num(db.upcast()); + + for idx in 0..field_num { + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let field_ty_ptr = make_ptr(db, field_ty, false); + let field_offset = deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE); + let src_offset = expression! { add([src.expr()], [literal_expression!{(field_offset)}]) }; + body.push(statement!{ + [field_enc_size.ident()] := [provider.abi_encode(db, src_offset, dst.expr(), field_ty_ptr, is_dst_storage)] + }); + body.push(statement! { + [enc_size.ident()] := add([enc_size.expr()], [field_enc_size.expr()]) + }); + + if idx < field_num - 1 { + body.push(assignment! {[dst.ident()] := add([dst.expr()], [field_enc_size.expr()])}); + } + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident(), dst.ident()], + returns: vec![enc_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_abi_encode_dynamic_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let is_src_storage = legalized_ty.is_sptr(db.upcast()); + let deref_ty = legalized_ty.deref(db.upcast()); + let field_num = deref_ty.aggregate_field_num(db.upcast()); + + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let header_ptr = YulVariable::new("header_ptr"); + let enc_size = YulVariable::new("enc_size"); + let data_ptr = YulVariable::new("data_ptr"); + + let total_header_size = literal_expression! { ((0..field_num).fold(0, |acc, idx| { + let ty = deref_ty.projection_ty_imm(db.upcast(), idx); + acc + db.codegen_abi_type(ty).header_size() + })) }; + let mut body = statements! { + (let [header_ptr.ident()] := [dst.expr()]) + ([enc_size.ident()] := [total_header_size]) + (let [data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + }; + + for idx in 0..field_num { + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let field_abi_ty = db.codegen_abi_type(field_ty); + let field_offset = + literal_expression! { (deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE)) }; + let field_ptr = expression! { add([src.expr()], [field_offset]) }; + let field_ptr_ty = make_ptr(db, field_ty, is_src_storage); + + let stmts = if field_abi_ty.is_static() { + statements! { + (pop([provider.abi_encode(db, field_ptr, header_ptr.expr(), field_ptr_ty, is_dst_storage)])) + ([header_ptr.ident()] := add([header_ptr.expr()], [literal_expression! {(field_abi_ty.header_size())}])) + } + } else { + let header_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + statements! { + ([yul::Statement::Expression(provider.ptr_store(db, header_ptr.expr(), enc_size.expr(), header_ty))]) + ([enc_size.ident()] := add([provider.abi_encode(db, field_ptr, data_ptr.expr(), field_ptr_ty, is_dst_storage)], [enc_size.expr()])) + ([header_ptr.ident()] := add([header_ptr.expr()], 32)) + ([data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + } + }; + body.extend_from_slice(&stmts); + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident(), dst.ident()], + returns: vec![enc_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_abi_encode_string_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let string_len = YulVariable::new("string_len"); + let enc_size = YulVariable::new("enc_size"); + + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + (let [string_len.ident()] := mload([src.expr()])) + (let data_size := add(32, [string_len.expr()])) + ([enc_size.ident()] := mul((div((add(data_size, 31)), 32)), 32)) + (let padding_word_ptr := add([dst.expr()], (sub([enc_size.expr()], 32)))) + (mstore(padding_word_ptr, 0)) + ([yul::Statement::Expression(provider.ptr_copy(db, src.expr(), dst.expr(), literal_expression!{data_size}, false, is_dst_storage))]) + } + }; + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_abi_encode_bytes_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + len: usize, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let enc_size = YulVariable::new("enc_size"); + let dst_len_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()]) -> [enc_size.ident()] { + ([enc_size.ident()] := [literal_expression!{ (ceil_32(32 + len)) }]) + (if (gt([enc_size.expr()], 0)) { + (let padding_word_ptr := add([dst.expr()], (sub([enc_size.expr()], 32)))) + (mstore(padding_word_ptr, 0)) + }) + ([yul::Statement::Expression(provider.ptr_store(db, dst.expr(), literal_expression!{ (len) }, dst_len_ty))]) + ([dst.ident()] := add(32, [dst.expr()])) + ([yul::Statement::Expression(provider.ptr_copy(db, src.expr(), dst.expr(), literal_expression!{(len)}, false, is_dst_storage))]) + } + }; + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_abi_encode_seq( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + value_tys: &[TypeId], + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let value_num = value_tys.len(); + let abi_tys: Vec<_> = value_tys + .iter() + .map(|ty| db.codegen_abi_type(ty.deref(db.upcast()))) + .collect(); + let dst = YulVariable::new("dst"); + let header_ptr = YulVariable::new("header_ptr"); + let enc_size = YulVariable::new("enc_size"); + let data_ptr = YulVariable::new("data_ptr"); + let values: Vec<_> = (0..value_num) + .map(|idx| YulVariable::new(format!("value{idx}"))) + .collect(); + + let total_header_size = + literal_expression! { (abi_tys.iter().fold(0, |acc, ty| acc + ty.header_size())) }; + let mut body = statements! { + (let [header_ptr.ident()] := [dst.expr()]) + ([enc_size.ident()] := [total_header_size]) + (let [data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + }; + + for i in 0..value_num { + let ty = value_tys[i]; + let abi_ty = &abi_tys[i]; + let value = &values[i]; + let stmts = if abi_ty.is_static() { + statements! { + (pop([provider.abi_encode(db, value.expr(), header_ptr.expr(), ty, is_dst_storage)])) + ([header_ptr.ident()] := add([header_ptr.expr()], [literal_expression!{ (abi_ty.header_size()) }])) + } + } else { + let header_ty = make_ptr(db, yul_primitive_type(db), is_dst_storage); + statements! { + ([yul::Statement::Expression(provider.ptr_store(db, header_ptr.expr(), enc_size.expr(), header_ty))]) + ([enc_size.ident()] := add([provider.abi_encode(db, value.expr(), data_ptr.expr(), ty, is_dst_storage)], [enc_size.expr()])) + ([header_ptr.ident()] := add([header_ptr.expr()], 32)) + ([data_ptr.ident()] := add([dst.expr()], [enc_size.expr()])) + } + }; + body.extend_from_slice(&stmts); + } + + let mut parameters = vec![dst.ident()]; + for value in values { + parameters.push(value.ident()); + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters, + returns: vec![enc_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_abi_decode( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + types: &[TypeId], + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let header_size = types + .iter() + .fold(0, |acc, ty| acc + db.codegen_abi_type(*ty).header_size()); + let src = YulVariable::new("$src"); + let enc_size = YulVariable::new("$enc_size"); + let header_ptr = YulVariable::new("header_ptr"); + let data_offset = YulVariable::new("data_offset"); + let tmp_offset = YulVariable::new("tmp_offset"); + let returns: Vec<_> = (0..types.len()) + .map(|i| YulVariable::new(format!("$ret{i}"))) + .collect(); + + let abi_enc_size = abi_enc_size(db, types); + let size_check = match abi_enc_size { + AbiEncodingSize::Static(size) => statements! { + (if (iszero((eq([enc_size.expr()], [literal_expression!{(size)}])))) + { [revert_with_invalid_abi_data(provider, db)] + }) + }, + AbiEncodingSize::Bounded { min, max } => statements! { + (if (or( + (lt([enc_size.expr()], [literal_expression!{(min)}])), + (gt([enc_size.expr()], [literal_expression!{(max)}])) + )) { + [revert_with_invalid_abi_data(provider, db)] + }) + }, + }; + + let mut body = statements! { + (let [header_ptr.ident()] := [src.expr()]) + (let [data_offset.ident()] := [literal_expression!{ (header_size) }]) + (let [tmp_offset.ident()] := 0) + }; + for i in 0..returns.len() { + let ret_value = &returns[i]; + let field_ty = types[i]; + let field_abi_ty = db.codegen_abi_type(field_ty.deref(db.upcast())); + if field_abi_ty.is_static() { + body.push(statement!{ [ret_value.ident()] := [provider.abi_decode_static(db, header_ptr.expr(), field_ty, abi_loc)] }); + } else { + let identifiers = identifiers! { + [ret_value.ident()] + [tmp_offset.ident()] + }; + body.push(yul::Statement::Assignment(yul::Assignment { + identifiers, + expression: provider.abi_decode_dynamic( + db, + expression! {add([src.expr()], [data_offset.expr()])}, + field_ty, + abi_loc, + ), + })); + body.push(statement! { ([data_offset.ident()] := add([data_offset.expr()], [tmp_offset.expr()])) }); + }; + + let field_header_size = literal_expression! { (field_abi_ty.header_size()) }; + body.push( + statement! { [header_ptr.ident()] := add([header_ptr.expr()], [field_header_size]) }, + ); + } + + let offset_check = match abi_enc_size { + AbiEncodingSize::Static(_) => vec![], + AbiEncodingSize::Bounded { .. } => statements! { + (if (iszero((eq([enc_size.expr()], [data_offset.expr()])))) { [revert_with_invalid_abi_data(provider, db)] }) + }, + }; + + let returns: Vec<_> = returns.iter().map(YulVariable::ident).collect(); + let func_def = function_definition! { + function [func_name.ident()]([src.ident()], [enc_size.ident()]) -> [returns...] { + [size_check...] + [body...] + [offset_check...] + } + }; + RuntimeFunction::from_statement(func_def) +} + +impl DefaultRuntimeProvider { + fn abi_decode_static( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + ty: TypeId, + abi_loc: AbiSrcLocation, + ) -> yul::Expression { + let ty = db.codegen_legalized_type(ty).deref(db.upcast()); + let abi_ty = db.codegen_abi_type(ty.deref(db.upcast())); + debug_assert!(abi_ty.is_static()); + + let func_name_postfix = match abi_loc { + AbiSrcLocation::CallData => "calldata", + AbiSrcLocation::Memory => "memory", + }; + + let args = vec![src]; + if ty.is_primitive(db.upcast()) { + let name = format! { + "$abi_decode_primitive_type_{}_from_{}", + ty.0, func_name_postfix, + }; + return self.create_then_call(&name, args, |provider| { + make_abi_decode_primitive_type(provider, db, &name, ty, abi_loc) + }); + } + + let name = format! { + "$abi_decode_static_aggregate_type_{}_from_{}", + ty.0, func_name_postfix, + }; + self.create_then_call(&name, args, |provider| { + make_abi_decode_static_aggregate_type(provider, db, &name, ty, abi_loc) + }) + } + + fn abi_decode_dynamic( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + ty: TypeId, + abi_loc: AbiSrcLocation, + ) -> yul::Expression { + let ty = db.codegen_legalized_type(ty).deref(db.upcast()); + let abi_ty = db.codegen_abi_type(ty.deref(db.upcast())); + debug_assert!(!abi_ty.is_static()); + + let func_name_postfix = match abi_loc { + AbiSrcLocation::CallData => "calldata", + AbiSrcLocation::Memory => "memory", + }; + + let mut args = vec![src]; + match abi_ty { + AbiType::String => { + let len = match &ty.data(db.upcast()).kind { + TypeKind::String(len) => *len, + _ => unreachable!(), + }; + args.push(literal_expression! {(len)}); + let name = format! {"$abi_decode_string_from_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_string_type(provider, db, &name, abi_loc) + }) + } + + AbiType::Bytes => { + let len = match &ty.data(db.upcast()).kind { + TypeKind::Array(ArrayDef { len, .. }) => *len, + _ => unreachable!(), + }; + args.push(literal_expression! {(len)}); + let name = format! {"$abi_decode_bytes_from_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_bytes_type(provider, db, &name, abi_loc) + }) + } + + AbiType::Array { .. } => { + let name = + format! {"$abi_decode_dynamic_array_{}_from_{}", ty.0, func_name_postfix}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_dynamic_elem_array_type(provider, db, &name, ty, abi_loc) + }) + } + + AbiType::Tuple(_) => { + let name = + format! {"$abi_decode_dynamic_aggregate_{}_from_{}", ty.0, func_name_postfix}; + self.create_then_call(&name, args, |provider| { + make_abi_decode_dynamic_aggregate_type(provider, db, &name, ty, abi_loc) + }) + } + + _ => unreachable!(), + } + } +} + +fn make_abi_decode_primitive_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + debug_assert! {ty.is_primitive(db.upcast())} + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let ret = YulVariable::new("ret"); + + let decode = match abi_loc { + AbiSrcLocation::CallData => { + statement! { [ret.ident()] := calldataload([src.expr()]) } + } + AbiSrcLocation::Memory => { + statement! { [ret.ident()] := mload([src.expr()]) } + } + }; + + let ty_size_bits = ty.size_of(db.upcast(), SLOT_SIZE) * 8; + let validation = if ty_size_bits == 256 { + statements! {} + } else if ty.is_signed(db.upcast()) { + let shift_num = literal_expression! { ( ty_size_bits - 1) }; + let tmp1 = YulVariable::new("tmp1"); + let tmp2 = YulVariable::new("tmp2"); + statements! { + (let [tmp1.ident()] := iszero((shr([shift_num.clone()], [ret.expr()])))) + (let [tmp2.ident()] := iszero((shr([shift_num], (not([ret.expr()])))))) + (if (iszero((or([tmp1.expr()], [tmp2.expr()])))) { + [revert_with_invalid_abi_data(provider, db)] + }) + } + } else { + let shift_num = literal_expression! { ( ty_size_bits) }; + let tmp = YulVariable::new("tmp"); + statements! { + (let [tmp.ident()] := iszero((shr([shift_num], [ret.expr()])))) + (if (iszero([tmp.expr()])) { + [revert_with_invalid_abi_data(provider, db)] + }) + } + }; + + let func = function_definition! { + function [func_name.ident()]([src.ident()]) -> [ret.ident()] { + ([decode]) + [validation...] + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_static_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + debug_assert!(ty.is_aggregate(db.upcast())); + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let ret = YulVariable::new("ret"); + let field_data = YulVariable::new("field_data"); + let type_size = literal_expression! { (ty.size_of(db.upcast(), SLOT_SIZE)) }; + + let mut body = statements! { + (let [field_data.ident()] := 0) + ([ret.ident()] := [provider.alloc(db, type_size)]) + }; + + let field_num = ty.aggregate_field_num(db.upcast()); + for idx in 0..field_num { + let field_ty = ty.projection_ty_imm(db.upcast(), idx); + let field_ty_size = field_ty.size_of(db.upcast(), SLOT_SIZE); + body.push(statement! { [field_data.ident()] := [provider.abi_decode_static(db, src.expr(), field_ty, abi_loc)] }); + + let dst_offset = + literal_expression! { (ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE)) }; + let field_ty_ptr = make_ptr(db, field_ty, false); + if field_ty.is_primitive(db.upcast()) { + body.push(yul::Statement::Expression(provider.ptr_store( + db, + expression! {add([ret.expr()], [dst_offset])}, + field_data.expr(), + field_ty_ptr, + ))); + } else { + body.push(yul::Statement::Expression(provider.ptr_copy( + db, + field_data.expr(), + expression! {add([ret.expr()], [dst_offset])}, + literal_expression! { (field_ty_size) }, + false, + false, + ))); + } + + if idx < field_num - 1 { + let abi_field_ty = db.codegen_abi_type(field_ty); + let field_abi_ty_size = literal_expression! { (abi_field_ty.header_size()) }; + body.push(assignment! {[src.ident()] := add([src.expr()], [field_abi_ty_size])}); + } + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident()], + returns: vec![ret.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +fn make_abi_decode_string_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let max_len = YulVariable::new("max_len"); + let string_size = YulVariable::new("string_size"); + let dst_size = YulVariable::new("dst_size"); + let end_word = YulVariable::new("end_word"); + let end_word_ptr = YulVariable::new("end_word_ptr"); + let padding_size_bits = YulVariable::new("padding_size_bits"); + let primitive_ty_ptr = make_ptr(db, yul_primitive_type(db), false); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [max_len.ident()]) -> [(vec![decoded_data.ident(), decoded_size.ident()])...] { + (let string_len := [provider.abi_decode_static(db, src.expr(), primitive_ty_ptr, abi_loc)]) + (if (gt(string_len, [max_len.expr()])) { [revert_with_invalid_abi_data(provider, db)] } ) + (let [string_size.ident()] := add(string_len, 32)) + ([decoded_size.ident()] := mul((div((add([string_size.expr()], 31)), 32)), 32)) + (let [end_word_ptr.ident()] := sub((add([src.expr()], [decoded_size.expr()])), 32)) + (let [end_word.ident()] := [provider.abi_decode_static(db, end_word_ptr.expr(), primitive_ty_ptr, abi_loc)]) + (let [padding_size_bits.ident()] := mul((sub([decoded_size.expr()], [string_size.expr()])), 8)) + [(check_right_padding(provider, db, end_word.expr(), padding_size_bits.expr()))...] + (let [dst_size.ident()] := add([max_len.expr()], 32)) + ([decoded_data.ident()] := [provider.alloc(db, dst_size.expr())]) + ([ptr_copy_decode(provider, db, src.expr(), decoded_data.expr(), string_size.expr(), abi_loc)]) + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_bytes_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let max_len = YulVariable::new("max_len"); + let bytes_size = YulVariable::new("bytes_size"); + let end_word = YulVariable::new("end_word"); + let end_word_ptr = YulVariable::new("end_word_ptr"); + let padding_size_bits = YulVariable::new("padding_size_bits"); + let primitive_ty_ptr = make_ptr(db, yul_primitive_type(db), false); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [max_len.ident()]) -> [(vec![decoded_data.ident(),decoded_size.ident()])...] { + (let [bytes_size.ident()] := [provider.abi_decode_static(db, src.expr(), primitive_ty_ptr, abi_loc)]) + (if (iszero((eq([bytes_size.expr()], [max_len.expr()])))) { [revert_with_invalid_abi_data(provider, db)] } ) + ([src.ident()] := add([src.expr()], 32)) + (let padded_data_size := mul((div((add([bytes_size.expr()], 31)), 32)), 32)) + ([decoded_size.ident()] := add(padded_data_size, 32)) + (let [end_word_ptr.ident()] := sub((add([src.expr()], padded_data_size)), 32)) + (let [end_word.ident()] := [provider.abi_decode_static(db, end_word_ptr.expr(), primitive_ty_ptr, abi_loc)]) + (let [padding_size_bits.ident()] := mul((sub(padded_data_size, [bytes_size.expr()])), 8)) + [(check_right_padding(provider, db, end_word.expr(), padding_size_bits.expr()))...] + ([decoded_data.ident()] := [provider.alloc(db, max_len.expr())]) + ([ptr_copy_decode(provider, db, src.expr(), decoded_data.expr(), bytes_size.expr(), abi_loc)]) + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_dynamic_elem_array_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let deref_ty = legalized_ty.deref(db.upcast()); + let (elem_ty, len) = match &deref_ty.data(db.upcast()).kind { + ir::TypeKind::Array(def) => (def.elem_ty, def.len), + _ => unreachable!(), + }; + let elem_ty_size = literal_expression! { (deref_ty.array_elem_size(db.upcast(), SLOT_SIZE)) }; + let total_header_size = literal_expression! { (32 * len) }; + let iter_count = literal_expression! { (len) }; + let ret_size = literal_expression! { (deref_ty.size_of(db.upcast(), SLOT_SIZE)) }; + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let header_ptr = YulVariable::new("header_ptr"); + let data_ptr = YulVariable::new("data_ptr"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let decoded_size_tmp = YulVariable::new("decoded_size_tmp"); + let ret_elem_ptr = YulVariable::new("ret_elem_ptr"); + let elem_data = YulVariable::new("elem_data"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()]) -> [decoded_data.ident()], [decoded_size.ident()] { + ([decoded_data.ident()] := [provider.alloc(db, ret_size)]) + ([decoded_size.ident()] := [total_header_size]) + (let [decoded_size_tmp.ident()] := 0) + (let [header_ptr.ident()] := [src.expr()]) + (let [data_ptr.ident()] := 0) + (let [elem_data.ident()] := 0) + (let [ret_elem_ptr.ident()] := [decoded_data.expr()]) + + (for {(let i := 0)} (lt(i, [iter_count])) {(i := (add(i, 1)))} + { + ([data_ptr.ident()] := add([src.expr()], [provider.abi_decode_static(db, header_ptr.expr(), yul_primitive_type(db), abi_loc)])) + ([assignment! {[elem_data.ident()], [decoded_size_tmp.ident()] := [provider.abi_decode_dynamic(db, data_ptr.expr(), elem_ty, abi_loc)] }]) + ([decoded_size.ident()] := add([decoded_size.expr()], [decoded_size_tmp.expr()])) + ([yul::Statement::Expression(provider.ptr_copy(db, elem_data.expr(), ret_elem_ptr.expr(), elem_ty_size.clone(), false, false))]) + ([header_ptr.ident()] := add([header_ptr.expr()], 32)) + ([ret_elem_ptr.ident()] := add([ret_elem_ptr.expr()], [elem_ty_size])) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +fn make_abi_decode_dynamic_aggregate_type( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + abi_loc: AbiSrcLocation, +) -> RuntimeFunction { + let deref_ty = legalized_ty.deref(db.upcast()); + let type_size = literal_expression! { (deref_ty.size_of(db.upcast(), SLOT_SIZE)) }; + + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let header_ptr = YulVariable::new("header_ptr"); + let data_offset = YulVariable::new("data_offset"); + let decoded_data = YulVariable::new("decoded_data"); + let decoded_size = YulVariable::new("decoded_size"); + let decoded_size_tmp = YulVariable::new("decoded_size_tmp"); + let ret_field_ptr = YulVariable::new("ret_field_ptr"); + let field_data = YulVariable::new("field_data"); + + let mut body = statements! { + ([decoded_data.ident()] := [provider.alloc(db, type_size)]) + ([decoded_size.ident()] := 0) + (let [decoded_size_tmp.ident()] := 0) + (let [header_ptr.ident()] := [src.expr()]) + (let [data_offset.ident()] := 0) + (let [field_data.ident()] := 0) + (let [ret_field_ptr.ident()] := 0) + }; + + for i in 0..deref_ty.aggregate_field_num(db.upcast()) { + let field_ty = deref_ty.projection_ty_imm(db.upcast(), i); + let field_size = field_ty.size_of(db.upcast(), SLOT_SIZE); + let field_abi_ty = db.codegen_abi_type(field_ty); + let field_offset = deref_ty.aggregate_elem_offset(db.upcast(), i, SLOT_SIZE); + + let decode_data = if field_abi_ty.is_static() { + statements! { + ([field_data.ident()] := [provider.abi_decode_static(db, header_ptr.expr(), field_ty, abi_loc)]) + ([decoded_size_tmp.ident()] := [literal_expression!{ (field_abi_ty.header_size()) }]) + } + } else { + statements! { + ([data_offset.ident()] := [provider.abi_decode_static(db, header_ptr.expr(), yul_primitive_type(db), abi_loc)]) + ([assignment! { + [field_data.ident()], [decoded_size_tmp.ident()] := + [provider.abi_decode_dynamic( + db, + expression!{ add([src.expr()], [data_offset.expr()]) }, + field_ty, + abi_loc + )] + }]) + ([decoded_size_tmp.ident()] := add([decoded_size_tmp.expr()], 32)) + } + }; + body.extend_from_slice(&decode_data); + body.push(assignment!{ [decoded_size.ident()] := add([decoded_size.expr()], [decoded_size_tmp.expr()]) }); + + body.push(assignment! { [ret_field_ptr.ident()] := add([decoded_data.expr()], [literal_expression!{ (field_offset) }])}); + let copy_to_ret = if field_ty.is_primitive(db.upcast()) { + let field_ptr_ty = make_ptr(db, field_ty, false); + yul::Statement::Expression(provider.ptr_store( + db, + ret_field_ptr.expr(), + field_data.expr(), + field_ptr_ty, + )) + } else { + yul::Statement::Expression(provider.ptr_copy( + db, + field_data.expr(), + ret_field_ptr.expr(), + literal_expression! { (field_size) }, + false, + false, + )) + }; + body.push(copy_to_ret); + + let header_size = literal_expression! { (field_abi_ty.header_size()) }; + body.push(statement! { + [header_ptr.ident()] := add([header_ptr.expr()], [header_size]) + }); + } + + let func_def = yul::FunctionDefinition { + name: func_name.ident(), + parameters: vec![src.ident()], + returns: vec![decoded_data.ident(), decoded_size.ident()], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +enum AbiEncodingSize { + Static(usize), + Bounded { min: usize, max: usize }, +} + +fn abi_enc_size(db: &dyn CodegenDb, types: &[TypeId]) -> AbiEncodingSize { + let mut min = 0; + let mut max = 0; + for &ty in types { + let legalized_ty = db.codegen_legalized_type(ty); + min += db.codegen_abi_type_minimum_size(legalized_ty); + max += db.codegen_abi_type_maximum_size(legalized_ty); + } + + if min == max { + AbiEncodingSize::Static(min) + } else { + AbiEncodingSize::Bounded { min, max } + } +} + +fn revert_with_invalid_abi_data( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, +) -> yul::Statement { + const ERROR_INVALID_ABI_DATA: usize = 0x103; + let error_code = literal_expression! { (ERROR_INVALID_ABI_DATA) }; + error_revert_numeric(provider, db, error_code) +} + +fn check_right_padding( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, + val: yul::Expression, + size_bits: yul::Expression, +) -> Vec { + statements! { + (let bits_shifted := sub(256, [size_bits])) + (let is_ok := iszero((shl(bits_shifted, [val])))) + (if (iszero((is_ok))) { + [revert_with_invalid_abi_data(provider, db)] + }) + } +} + +fn ptr_copy_decode( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + len: yul::Expression, + loc: AbiSrcLocation, +) -> yul::Statement { + match loc { + AbiSrcLocation::CallData => { + statement! { calldatacopy([dst], [src], [len]) } + } + AbiSrcLocation::Memory => { + yul::Statement::Expression(provider.ptr_copy(db, src, dst, len, false, false)) + } + } +} + +fn ceil_32(len: usize) -> usize { + ((len + 31) / 32) * 32 +} diff --git a/crates/codegen2/src/yul/runtime/contract.rs b/crates/codegen2/src/yul/runtime/contract.rs new file mode 100644 index 0000000000..d85321b377 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/contract.rs @@ -0,0 +1,127 @@ +use crate::{ + db::CodegenDb, + yul::{runtime::AbiSrcLocation, YulVariable}, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_analyzer::namespace::items::ContractId; +use fe_mir::ir::{FunctionId, Type, TypeKind}; + +use yultsur::*; + +pub(super) fn make_create( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + contract: ContractId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let contract_symbol = literal_expression! { + (format!(r#""{}""#, db.codegen_contract_deployer_symbol_name(contract))) + }; + + let size = YulVariable::new("size"); + let value = YulVariable::new("value"); + let func = function_definition! { + function [func_name.ident()]([value.ident()]) -> addr { + (let [size.ident()] := datasize([contract_symbol.clone()])) + (let mem_ptr := [provider.avail(db)]) + (let contract_ptr := dataoffset([contract_symbol])) + (datacopy(mem_ptr, contract_ptr, [size.expr()])) + (addr := create([value.expr()], mem_ptr, [size.expr()])) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_create2( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + contract: ContractId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let contract_symbol = literal_expression! { + (format!(r#""{}""#, db.codegen_contract_deployer_symbol_name(contract))) + }; + + let size = YulVariable::new("size"); + let value = YulVariable::new("value"); + let func = function_definition! { + function [func_name.ident()]([value.ident()], salt) -> addr { + (let [size.ident()] := datasize([contract_symbol.clone()])) + (let mem_ptr := [provider.avail(db)]) + (let contract_ptr := dataoffset([contract_symbol])) + (datacopy(mem_ptr, contract_ptr, [size.expr()])) + (addr := create2([value.expr()], mem_ptr, [size.expr()], salt)) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_external_call( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + function: FunctionId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let sig = db.codegen_legalized_signature(function); + let param_num = sig.params.len(); + + let mut args = Vec::with_capacity(param_num); + let mut arg_tys = Vec::with_capacity(param_num); + for param in &sig.params { + args.push(YulVariable::new(param.name.as_str())); + arg_tys.push(param.ty); + } + let ret_ty = sig.return_type; + + let func_addr = YulVariable::new("func_addr"); + let params: Vec<_> = args.iter().map(YulVariable::ident).collect(); + let params_expr: Vec<_> = args.iter().map(YulVariable::expr).collect(); + let input = YulVariable::new("input"); + let input_size = YulVariable::new("input_size"); + let output_size = YulVariable::new("output_size"); + let output = YulVariable::new("output"); + + let func_selector = literal_expression! { (format!{"0x{}", db.codegen_abi_function(function).selector().hex()}) }; + let selector_ty = db.mir_intern_type(Type::new(TypeKind::U32, None).into()); + + let mut body = statements! { + (let [input.ident()] := [provider.avail(db)]) + [yul::Statement::Expression(provider.ptr_store(db, input.expr(), func_selector, selector_ty.make_mptr(db.upcast())))] + (let [input_size.ident()] := add(4, [provider.abi_encode_seq(db, ¶ms_expr, expression!{ add([input.expr()], 4) }, &arg_tys, false)])) + (let [output.ident()] := add([provider.avail(db)], [input_size.expr()])) + (let success := call((gas()), [func_addr.expr()], 0, [input.expr()], [input_size.expr()], 0, 0)) + (let [output_size.ident()] := returndatasize()) + (returndatacopy([output.expr()], 0, [output_size.expr()])) + (if (iszero(success)) { + (revert([output.expr()], [output_size.expr()])) + }) + }; + let func = if let Some(ret_ty) = ret_ty { + let ret = YulVariable::new("$ret"); + body.push( + statement!{ + [ret.ident()] := [provider.abi_decode(db, output.expr(), output_size.expr(), &[ret_ty], AbiSrcLocation::Memory)] + } + ); + function_definition! { + function [func_name.ident()]([func_addr.ident()], [params...]) -> [ret.ident()] { + [body...] + } + } + } else { + function_definition! { + function [func_name.ident()]([func_addr.ident()], [params...]) { + [body...] + } + } + }; + + RuntimeFunction::from_statement(func) +} diff --git a/crates/codegen2/src/yul/runtime/data.rs b/crates/codegen2/src/yul/runtime/data.rs new file mode 100644 index 0000000000..02509d7129 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/data.rs @@ -0,0 +1,461 @@ +use crate::{ + db::CodegenDb, + yul::{ + runtime::{make_ptr, BitMask}, + slot_size::{yul_primitive_type, SLOT_SIZE}, + YulVariable, + }, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::{types::TupleDef, Type, TypeId, TypeKind}; + +use yultsur::*; + +const HASH_SCRATCH_SPACE_START: usize = 0x00; +const HASH_SCRATCH_SPACE_SIZE: usize = 64; +const FREE_MEMORY_ADDRESS_STORE: usize = HASH_SCRATCH_SPACE_START + HASH_SCRATCH_SPACE_SIZE; +const FREE_MEMORY_START: usize = FREE_MEMORY_ADDRESS_STORE + 32; + +pub(super) fn make_alloc(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let free_address_ptr = literal_expression! {(FREE_MEMORY_ADDRESS_STORE)}; + let free_memory_start = literal_expression! {(FREE_MEMORY_START)}; + let func = function_definition! { + function [func_name.ident()](size) -> ptr { + (ptr := mload([free_address_ptr.clone()])) + (if (eq(ptr, 0x00)) { (ptr := [free_memory_start]) }) + (mstore([free_address_ptr], (add(ptr, size)))) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_avail(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let free_address_ptr = literal_expression! {(FREE_MEMORY_ADDRESS_STORE)}; + let free_memory_start = literal_expression! {(FREE_MEMORY_START)}; + let func = function_definition! { + function [func_name.ident()]() -> ptr { + (ptr := mload([free_address_ptr])) + (if (eq(ptr, 0x00)) { (ptr := [free_memory_start]) }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mcopym(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + (let iter_count := div([size.expr()], 32)) + (let original_src := [src.expr()]) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (mstore([dst.expr()], (mload([src.expr()])))) + ([src.ident()] := add([src.expr()], 32)) + ([dst.ident()] := add([dst.expr()], 32)) + }) + + (let rem := sub([size.expr()], (sub([src.expr()], original_src)))) + (if (gt(rem, 0)) { + (let rem_bits := mul(rem, 8)) + (let dst_mask := sub((shl((sub(256, rem_bits)), 1)), 1)) + (let src_mask := not(dst_mask)) + (let src_value := and((mload([src.expr()])), src_mask)) + (let dst_value := and((mload([dst.expr()])), dst_mask)) + (mstore([dst.expr()], (or(src_value, dst_value)))) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mcopys(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + ([dst.ident()] := div([dst.expr()], 32)) + (let iter_count := div([size.expr()], 32)) + (let original_src := [src.expr()]) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (sstore([dst.expr()], (mload([src.expr()])))) + ([src.ident()] := add([src.expr()], 32)) + ([dst.ident()] := add([dst.expr()], 1)) + }) + + (let rem := sub([size.expr()], (sub([src.expr()], original_src)))) + (if (gt(rem, 0)) { + (let rem_bits := mul(rem, 8)) + (let dst_mask := sub((shl((sub(256, rem_bits)), 1)), 1)) + (let src_mask := not(dst_mask)) + (let src_value := and((mload([src.expr()])), src_mask)) + (let dst_value := and((sload([dst.expr()])), dst_mask)) + (sstore([dst.expr()], (or(src_value, dst_value)))) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_scopym(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + ([src.ident()] := div([src.expr()], 32)) + (let iter_count := div([size.expr()], 32)) + (let original_dst := [dst.expr()]) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (mstore([dst.expr()], (sload([src.expr()])))) + ([src.ident()] := add([src.expr()], 1)) + ([dst.ident()] := add([dst.expr()], 32)) + }) + + (let rem := sub([size.expr()], (sub([dst.expr()], original_dst)))) + (if (gt(rem, 0)) { + (let rem_bits := mul(rem, 8)) + (let dst_mask := sub((shl((sub(256, rem_bits)), 1)), 1)) + (let src_mask := not(dst_mask)) + (let src_value := and((sload([src.expr()])), src_mask)) + (let dst_value := and((mload([dst.expr()])), dst_mask)) + (mstore([dst.expr()], (or(src_value, dst_value)))) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_scopys(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let src = YulVariable::new("src"); + let dst = YulVariable::new("dst"); + let size = YulVariable::new("size"); + let func = function_definition! { + function [func_name.ident()]([src.ident()], [dst.ident()], [size.ident()]) { + ([src.ident()] := div([src.expr()], 32)) + ([dst.ident()] := div([dst.expr()], 32)) + (let iter_count := div((add([size.expr()], 31)), 32)) + (for {(let i := 0)} (lt(i, iter_count)) {(i := (add(i, 1)))} + { + (sstore([dst.expr()], (sload([src.expr()])))) + ([src.ident()] := add([src.expr()], 1)) + ([dst.ident()] := add([dst.expr()], 1)) + }) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_sptr_store(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, value, size_bits) { + (let rem_bits := mul((mod(ptr, 32)), 8)) + (let shift_bits := sub(256, (add(rem_bits, size_bits)))) + (let mask := (shl(shift_bits, (sub((shl(size_bits, 1)), 1))))) + (let inv_mask := not(mask)) + (let slot := div(ptr, 32)) + (let new_value := or((and((sload(slot)), inv_mask)), (and((shl(shift_bits, value)), mask)))) + (sstore(slot, new_value)) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mptr_store(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, value, shift_num, mask) { + (value := shl(shift_num, value)) + (let ptr_value := and((mload(ptr)), mask)) + (value := or(value, ptr_value)) + (mstore(ptr, value)) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_sptr_load(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, size_bits) -> ret { + (let rem_bits := mul((mod(ptr, 32)), 8)) + (let shift_num := sub(256, (add(rem_bits, size_bits)))) + (let slot := div(ptr, 32)) + (ret := shr(shift_num, (sload(slot)))) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_mptr_load(func_name: &str) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let func = function_definition! { + function [func_name.ident()](ptr, shift_num) -> ret { + (ret := shr(shift_num, (mload(ptr)))) + } + }; + + RuntimeFunction::from_statement(func) +} + +// TODO: We can optimize aggregate initialization by combining multiple +// `ptr_store` operations into single `ptr_store` operation. +pub(super) fn make_aggregate_init( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + arg_tys: Vec, +) -> RuntimeFunction { + debug_assert!(legalized_ty.is_ptr(db.upcast())); + let is_sptr = legalized_ty.is_sptr(db.upcast()); + let inner_ty = legalized_ty.deref(db.upcast()); + let ptr = YulVariable::new("ptr"); + let field_num = inner_ty.aggregate_field_num(db.upcast()); + + let iter_field_args = || (0..field_num).map(|i| YulVariable::new(format! {"arg{i}"})); + + let mut body = vec![]; + for (idx, field_arg) in iter_field_args().enumerate() { + let field_arg_ty = arg_tys[idx]; + let field_ty = inner_ty + .projection_ty_imm(db.upcast(), idx) + .deref(db.upcast()); + let field_ty_size = field_ty.size_of(db.upcast(), SLOT_SIZE); + let field_ptr_ty = make_ptr(db, field_ty, is_sptr); + let field_offset = + literal_expression! {(inner_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE))}; + + let field_ptr = expression! { add([ptr.expr()], [field_offset] )}; + let copy_expr = if field_ty.is_aggregate(db.upcast()) || field_ty.is_string(db.upcast()) { + // Call ptr copy function if field type is aggregate. + debug_assert!(field_arg_ty.is_ptr(db.upcast())); + provider.ptr_copy( + db, + field_arg.expr(), + field_ptr, + literal_expression! {(field_ty_size)}, + field_arg_ty.is_sptr(db.upcast()), + is_sptr, + ) + } else { + // Call store function if field type is not aggregate. + provider.ptr_store(db, field_ptr, field_arg.expr(), field_ptr_ty) + }; + body.push(yul::Statement::Expression(copy_expr)); + } + + let func_name = identifier! {(func_name)}; + let parameters = std::iter::once(ptr) + .chain(iter_field_args()) + .map(|var| var.ident()) + .collect(); + let func_def = yul::FunctionDefinition { + name: func_name, + parameters, + returns: vec![], + block: yul::Block { statements: body }, + }; + + RuntimeFunction(func_def) +} + +pub(super) fn make_enum_init( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, + arg_tys: Vec, +) -> RuntimeFunction { + debug_assert!(arg_tys.len() > 1); + + let func_name = YulVariable::new(func_name); + let is_sptr = legalized_ty.is_sptr(db.upcast()); + let ptr = YulVariable::new("ptr"); + let disc = YulVariable::new("disc"); + let disc_ty = arg_tys[0]; + let enum_data = || (0..arg_tys.len() - 1).map(|i| YulVariable::new(format! {"arg{i}"})); + + let tuple_def = TupleDef { + items: arg_tys + .iter() + .map(|ty| ty.deref(db.upcast())) + .skip(1) + .collect(), + }; + let tuple_ty = db.mir_intern_type( + Type { + kind: TypeKind::Tuple(tuple_def), + analyzer_ty: None, + } + .into(), + ); + let data_ptr_ty = make_ptr(db, tuple_ty, is_sptr); + let data_offset = legalized_ty + .deref(db.upcast()) + .enum_data_offset(db.upcast(), SLOT_SIZE); + let enum_data_init = statements! { + [statement! {[ptr.ident()] := add([ptr.expr()], [literal_expression!{(data_offset)}])}] + [yul::Statement::Expression(provider.aggregate_init( + db, + ptr.expr(), + enum_data().map(|arg| arg.expr()).collect(), + data_ptr_ty, arg_tys.iter().skip(1).copied().collect()))] + }; + + let enum_data_args: Vec<_> = enum_data().map(|var| var.ident()).collect(); + let func_def = function_definition! { + function [func_name.ident()]([ptr.ident()], [disc.ident()], [enum_data_args...]) { + [yul::Statement::Expression(provider.ptr_store(db, ptr.expr(), disc.expr(), make_ptr(db, disc_ty, is_sptr)))] + [enum_data_init...] + } + }; + RuntimeFunction::from_statement(func_def) +} + +pub(super) fn make_string_copy( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + data: &str, + is_dst_storage: bool, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let dst_ptr = YulVariable::new("dst_ptr"); + let symbol_name = literal_expression! { (format!(r#""{}""#, db.codegen_constant_string_symbol_name(data.to_string()))) }; + + let func = if is_dst_storage { + let tmp_ptr = YulVariable::new("tmp_ptr"); + let data_size = YulVariable::new("data_size"); + function_definition! { + function [func_name.ident()]([dst_ptr.ident()]) { + (let [tmp_ptr.ident()] := [provider.avail(db)]) + (let data_offset := dataoffset([symbol_name.clone()])) + (let [data_size.ident()] := datasize([symbol_name])) + (let len_slot := div([dst_ptr.expr()], 32)) + (sstore(len_slot, [data_size.expr()])) + (datacopy([tmp_ptr.expr()], data_offset, [data_size.expr()])) + ([dst_ptr.ident()] := add([dst_ptr.expr()], 32)) + ([yul::Statement::Expression( + provider.ptr_copy(db, tmp_ptr.expr(), dst_ptr.expr(), data_size.expr(), false, true)) + ]) + } + } + } else { + function_definition! { + function [func_name.ident()]([dst_ptr.ident()]) { + (let data_offset := dataoffset([symbol_name.clone()])) + (let data_size := datasize([symbol_name])) + (mstore([dst_ptr.expr()], data_size)) + ([dst_ptr.ident()] := add([dst_ptr.expr()], 32)) + (datacopy([dst_ptr.expr()], data_offset, data_size)) + } + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_string_construct( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + data: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let ptr_size = YulVariable::new("ptr_size"); + let string_ptr = YulVariable::new("string_ptr"); + + let func = function_definition! { + function [func_name.ident()]([ptr_size.ident()]) -> [string_ptr.ident()] { + ([string_ptr.ident()] := [provider.alloc(db, ptr_size.expr())]) + ([yul::Statement::Expression(provider.string_copy(db, string_ptr.expr(), data, false))]) + } + }; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_map_value_ptr_with_primitive_key( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + key_ty: TypeId, +) -> RuntimeFunction { + debug_assert!(key_ty.is_primitive(db.upcast())); + let scratch_space = literal_expression! {(HASH_SCRATCH_SPACE_START)}; + let scratch_size = literal_expression! {(HASH_SCRATCH_SPACE_SIZE)}; + let func_name = YulVariable::new(func_name); + let map_ptr = YulVariable::new("map_ptr"); + let key = YulVariable::new("key"); + let yul_primitive_type = yul_primitive_type(db); + + let mask = BitMask::new(1).not(); + + let func = function_definition! { + function [func_name.ident()]([map_ptr.ident()], [key.ident()]) -> ret { + ([yul::Statement::Expression(provider.ptr_store( + db, + scratch_space.clone(), + key.expr(), + yul_primitive_type.make_mptr(db.upcast()), + ))]) + ([yul::Statement::Expression(provider.ptr_store( + db, + expression!(add([scratch_space.clone()], 32)), + map_ptr.expr(), + yul_primitive_type.make_mptr(db.upcast()), + ))]) + (ret := and([mask.as_expr()], (keccak256([scratch_space], [scratch_size])))) + }}; + + RuntimeFunction::from_statement(func) +} + +pub(super) fn make_map_value_ptr_with_ptr_key( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + key_ty: TypeId, +) -> RuntimeFunction { + debug_assert!(key_ty.is_ptr(db.upcast())); + + let func_name = YulVariable::new(func_name); + let size = literal_expression! {(key_ty.deref(db.upcast()).size_of(db.upcast(), SLOT_SIZE))}; + let map_ptr = YulVariable::new("map_ptr"); + let key = YulVariable::new("key"); + + let key_hash = expression! { keccak256([key.expr()], [size]) }; + let u256_ty = yul_primitive_type(db); + let def = function_definition! { + function [func_name.ident()]([map_ptr.ident()], [key.ident()]) -> ret { + (ret := [provider.map_value_ptr(db, map_ptr.expr(), key_hash, u256_ty)]) + } + }; + RuntimeFunction::from_statement(def) +} diff --git a/crates/codegen2/src/yul/runtime/emit.rs b/crates/codegen2/src/yul/runtime/emit.rs new file mode 100644 index 0000000000..7e1f8ab8c6 --- /dev/null +++ b/crates/codegen2/src/yul/runtime/emit.rs @@ -0,0 +1,74 @@ +use crate::{ + db::CodegenDb, + yul::{runtime::make_ptr, slot_size::SLOT_SIZE, YulVariable}, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::TypeId; + +use yultsur::*; + +pub(super) fn make_emit( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + legalized_ty: TypeId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let event_ptr = YulVariable::new("event_ptr"); + let deref_ty = legalized_ty.deref(db.upcast()); + + let abi = db.codegen_abi_event(deref_ty); + let mut topics = vec![literal_expression! {(format!("0x{}", abi.signature().hash_hex()))}]; + for (idx, field) in abi.inputs.iter().enumerate() { + if !field.indexed { + continue; + } + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let offset = + literal_expression! {(deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE))}; + let elem_ptr = expression! { add([event_ptr.expr()], [offset]) }; + let topic = if field_ty.is_aggregate(db.upcast()) { + todo!() + } else { + let topic = provider.ptr_load( + db, + elem_ptr, + make_ptr(db, field_ty, legalized_ty.is_sptr(db.upcast())), + ); + provider.primitive_cast(db, topic, field_ty) + }; + + topics.push(topic) + } + + let mut event_data_tys = vec![]; + let mut event_data_values = vec![]; + for (idx, field) in abi.inputs.iter().enumerate() { + if field.indexed { + continue; + } + + let field_ty = deref_ty.projection_ty_imm(db.upcast(), idx); + let field_offset = + literal_expression! { (deref_ty.aggregate_elem_offset(db.upcast(), idx, SLOT_SIZE)) }; + event_data_tys.push(make_ptr(db, field_ty, legalized_ty.is_sptr(db.upcast()))); + event_data_values.push(expression! { add([event_ptr.expr()], [field_offset]) }); + } + + debug_assert!(topics.len() < 5); + let log_func = identifier! { (format!("log{}", topics.len()))}; + + let event_data_ptr = YulVariable::new("event_data_ptr"); + let event_enc_size = YulVariable::new("event_enc_size"); + let func = function_definition! { + function [func_name.ident()]([event_ptr.ident()]) { + (let [event_data_ptr.ident()] := [provider.avail(db)]) + (let [event_enc_size.ident()] := [provider.abi_encode_seq(db, &event_data_values, event_data_ptr.expr(), &event_data_tys, false )]) + ([log_func]([event_data_ptr.expr()], [event_enc_size.expr()], [topics...])) + } + }; + + RuntimeFunction::from_statement(func) +} diff --git a/crates/codegen2/src/yul/runtime/mod.rs b/crates/codegen2/src/yul/runtime/mod.rs new file mode 100644 index 0000000000..a658b4b4aa --- /dev/null +++ b/crates/codegen2/src/yul/runtime/mod.rs @@ -0,0 +1,828 @@ +mod abi; +mod contract; +mod data; +mod emit; +mod revert; +mod safe_math; + +use std::fmt::Write; + +use fe_abi::types::AbiType; +use fe_analyzer::namespace::items::ContractId; +use fe_mir::ir::{types::ArrayDef, FunctionId, TypeId, TypeKind}; +use indexmap::IndexMap; +use yultsur::*; + +use num_bigint::BigInt; + +use crate::{db::CodegenDb, yul::slot_size::SLOT_SIZE}; + +use super::slot_size::yul_primitive_type; + +pub trait RuntimeProvider { + fn collect_definitions(&self) -> Vec; + + fn alloc(&mut self, db: &dyn CodegenDb, size: yul::Expression) -> yul::Expression; + + fn avail(&mut self, db: &dyn CodegenDb) -> yul::Expression; + + fn create( + &mut self, + db: &dyn CodegenDb, + contract: ContractId, + value: yul::Expression, + ) -> yul::Expression; + + fn create2( + &mut self, + db: &dyn CodegenDb, + contract: ContractId, + value: yul::Expression, + salt: yul::Expression, + ) -> yul::Expression; + + fn emit( + &mut self, + db: &dyn CodegenDb, + event: yul::Expression, + event_ty: TypeId, + ) -> yul::Expression; + + fn revert( + &mut self, + db: &dyn CodegenDb, + arg: Option, + arg_name: &str, + arg_ty: TypeId, + ) -> yul::Expression; + + fn external_call( + &mut self, + db: &dyn CodegenDb, + function: FunctionId, + args: Vec, + ) -> yul::Expression; + + fn map_value_ptr( + &mut self, + db: &dyn CodegenDb, + map_ptr: yul::Expression, + key: yul::Expression, + key_ty: TypeId, + ) -> yul::Expression; + + fn aggregate_init( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + args: Vec, + ptr_ty: TypeId, + arg_tys: Vec, + ) -> yul::Expression; + + fn string_copy( + &mut self, + db: &dyn CodegenDb, + dst: yul::Expression, + data: &str, + is_dst_storage: bool, + ) -> yul::Expression; + + fn string_construct( + &mut self, + db: &dyn CodegenDb, + data: &str, + string_len: usize, + ) -> yul::Expression; + + /// Copy data from `src` to `dst`. + /// NOTE: src and dst must be aligned by 32 when a ptr is storage ptr. + fn ptr_copy( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + size: yul::Expression, + is_src_storage: bool, + is_dst_storage: bool, + ) -> yul::Expression; + + fn ptr_store( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + imm: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression; + + fn ptr_load( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression; + + fn abi_encode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + src_ty: TypeId, + is_dst_storage: bool, + ) -> yul::Expression; + + fn abi_encode_seq( + &mut self, + db: &dyn CodegenDb, + src: &[yul::Expression], + dst: yul::Expression, + src_tys: &[TypeId], + is_dst_storage: bool, + ) -> yul::Expression; + + fn abi_decode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + size: yul::Expression, + types: &[TypeId], + abi_loc: AbiSrcLocation, + ) -> yul::Expression; + + fn primitive_cast( + &mut self, + db: &dyn CodegenDb, + value: yul::Expression, + from_ty: TypeId, + ) -> yul::Expression { + debug_assert!(from_ty.is_primitive(db.upcast())); + let from_size = from_ty.size_of(db.upcast(), SLOT_SIZE); + + if from_ty.is_signed(db.upcast()) { + let significant = literal_expression! {(from_size-1)}; + expression! { signextend([significant], [value]) } + } else { + let mask = BitMask::new(from_size); + expression! { and([value], [mask.as_expr()]) } + } + } + + // TODO: The all functions below will be reimplemented in `std`. + fn safe_add( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_sub( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_mul( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_div( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_mod( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; + + fn safe_pow( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression; +} + +#[derive(Clone, Copy, Debug)] +pub enum AbiSrcLocation { + CallData, + Memory, +} + +#[derive(Debug, Default)] +pub struct DefaultRuntimeProvider { + functions: IndexMap, +} + +impl DefaultRuntimeProvider { + fn create_then_call( + &mut self, + name: &str, + args: Vec, + func_builder: F, + ) -> yul::Expression + where + F: FnOnce(&mut Self) -> RuntimeFunction, + { + if let Some(func) = self.functions.get(name) { + func.call(args) + } else { + let func = func_builder(self); + let result = func.call(args); + self.functions.insert(name.to_string(), func); + result + } + } +} + +impl RuntimeProvider for DefaultRuntimeProvider { + fn collect_definitions(&self) -> Vec { + self.functions + .values() + .map(RuntimeFunction::definition) + .collect() + } + + fn alloc(&mut self, _db: &dyn CodegenDb, bytes: yul::Expression) -> yul::Expression { + let name = "$alloc"; + let arg = vec![bytes]; + self.create_then_call(name, arg, |_| data::make_alloc(name)) + } + + fn avail(&mut self, _db: &dyn CodegenDb) -> yul::Expression { + let name = "$avail"; + let arg = vec![]; + self.create_then_call(name, arg, |_| data::make_avail(name)) + } + + fn create( + &mut self, + db: &dyn CodegenDb, + contract: ContractId, + value: yul::Expression, + ) -> yul::Expression { + let name = format!("$create_{}", db.codegen_contract_symbol_name(contract)); + let arg = vec![value]; + self.create_then_call(&name, arg, |provider| { + contract::make_create(provider, db, &name, contract) + }) + } + + fn create2( + &mut self, + db: &dyn CodegenDb, + contract: ContractId, + value: yul::Expression, + salt: yul::Expression, + ) -> yul::Expression { + let name = format!("$create2_{}", db.codegen_contract_symbol_name(contract)); + let arg = vec![value, salt]; + self.create_then_call(&name, arg, |provider| { + contract::make_create2(provider, db, &name, contract) + }) + } + + fn emit( + &mut self, + db: &dyn CodegenDb, + event: yul::Expression, + event_ty: TypeId, + ) -> yul::Expression { + let name = format!("$emit_{}", event_ty.0); + let legalized_ty = db.codegen_legalized_type(event_ty); + self.create_then_call(&name, vec![event], |provider| { + emit::make_emit(provider, db, &name, legalized_ty) + }) + } + + fn revert( + &mut self, + db: &dyn CodegenDb, + arg: Option, + arg_name: &str, + arg_ty: TypeId, + ) -> yul::Expression { + let func_name = format! {"$revert_{}_{}", arg_name, arg_ty.0}; + let args = match arg { + Some(arg) => vec![arg], + None => vec![], + }; + self.create_then_call(&func_name, args, |provider| { + revert::make_revert(provider, db, &func_name, arg_name, arg_ty) + }) + } + + fn external_call( + &mut self, + db: &dyn CodegenDb, + function: FunctionId, + args: Vec, + ) -> yul::Expression { + let name = format!( + "$call_external__{}", + db.codegen_function_symbol_name(function) + ); + self.create_then_call(&name, args, |provider| { + contract::make_external_call(provider, db, &name, function) + }) + } + + fn map_value_ptr( + &mut self, + db: &dyn CodegenDb, + map_ptr: yul::Expression, + key: yul::Expression, + key_ty: TypeId, + ) -> yul::Expression { + if key_ty.is_primitive(db.upcast()) { + let name = "$map_value_ptr_with_primitive_key"; + self.create_then_call(name, vec![map_ptr, key], |provider| { + data::make_map_value_ptr_with_primitive_key(provider, db, name, key_ty) + }) + } else if key_ty.is_mptr(db.upcast()) { + let name = "$map_value_ptr_with_ptr_key"; + self.create_then_call(name, vec![map_ptr, key], |provider| { + data::make_map_value_ptr_with_ptr_key(provider, db, name, key_ty) + }) + } else { + unreachable!() + } + } + + fn aggregate_init( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + mut args: Vec, + ptr_ty: TypeId, + arg_tys: Vec, + ) -> yul::Expression { + debug_assert!(ptr_ty.is_ptr(db.upcast())); + let deref_ty = ptr_ty.deref(db.upcast()); + + // Handle unit enum variant. + if args.len() == 1 && deref_ty.is_enum(db.upcast()) { + let tag = args.pop().unwrap(); + let tag_ty = arg_tys[0]; + let is_sptr = ptr_ty.is_sptr(db.upcast()); + return self.ptr_store(db, ptr, tag, make_ptr(db, tag_ty, is_sptr)); + } + + let deref_ty = ptr_ty.deref(db.upcast()); + let args = std::iter::once(ptr).chain(args).collect(); + let legalized_ty = db.codegen_legalized_type(ptr_ty); + if deref_ty.is_enum(db.upcast()) { + let mut name = format!("enum_init_{}", ptr_ty.0); + for ty in &arg_tys { + write!(&mut name, "_{}", ty.0).unwrap(); + } + self.create_then_call(&name, args, |provider| { + data::make_enum_init(provider, db, &name, legalized_ty, arg_tys) + }) + } else { + let name = format!("$aggregate_init_{}", ptr_ty.0); + self.create_then_call(&name, args, |provider| { + data::make_aggregate_init(provider, db, &name, legalized_ty, arg_tys) + }) + } + } + + fn string_copy( + &mut self, + db: &dyn CodegenDb, + dst: yul::Expression, + data: &str, + is_dst_storage: bool, + ) -> yul::Expression { + debug_assert!(data.is_ascii()); + let symbol_name = db.codegen_constant_string_symbol_name(data.to_string()); + + let name = if is_dst_storage { + format!("$string_copy_{symbol_name}_storage") + } else { + format!("$string_copy_{symbol_name}_memory") + }; + + self.create_then_call(&name, vec![dst], |provider| { + data::make_string_copy(provider, db, &name, data, is_dst_storage) + }) + } + + fn string_construct( + &mut self, + db: &dyn CodegenDb, + data: &str, + string_len: usize, + ) -> yul::Expression { + debug_assert!(data.is_ascii()); + debug_assert!(string_len >= data.len()); + let symbol_name = db.codegen_constant_string_symbol_name(data.to_string()); + + let name = format!("$string_construct_{symbol_name}"); + let arg = literal_expression!((32 + string_len)); + self.create_then_call(&name, vec![arg], |provider| { + data::make_string_construct(provider, db, &name, data) + }) + } + + fn ptr_copy( + &mut self, + _db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + size: yul::Expression, + is_src_storage: bool, + is_dst_storage: bool, + ) -> yul::Expression { + let args = vec![src, dst, size]; + match (is_src_storage, is_dst_storage) { + (true, true) => { + let name = "scopys"; + self.create_then_call(name, args, |_| data::make_scopys(name)) + } + (true, false) => { + let name = "scopym"; + self.create_then_call(name, args, |_| data::make_scopym(name)) + } + (false, true) => { + let name = "mcopys"; + self.create_then_call(name, args, |_| data::make_mcopys(name)) + } + (false, false) => { + let name = "mcopym"; + self.create_then_call(name, args, |_| data::make_mcopym(name)) + } + } + } + + fn ptr_store( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + imm: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression { + debug_assert!(ptr_ty.is_ptr(db.upcast())); + let size = ptr_ty.deref(db.upcast()).size_of(db.upcast(), SLOT_SIZE); + debug_assert!(size <= 32); + + let size_bits = size * 8; + if ptr_ty.is_sptr(db.upcast()) { + let name = "$sptr_store"; + let args = vec![ptr, imm, literal_expression! {(size_bits)}]; + self.create_then_call(name, args, |_| data::make_sptr_store(name)) + } else if ptr_ty.is_mptr(db.upcast()) { + let name = "$mptr_store"; + let shift_num = literal_expression! {(256 - size_bits)}; + let mask = BitMask::new(32 - size); + let args = vec![ptr, imm, shift_num, mask.as_expr()]; + self.create_then_call(name, args, |_| data::make_mptr_store(name)) + } else { + unreachable!() + } + } + + fn ptr_load( + &mut self, + db: &dyn CodegenDb, + ptr: yul::Expression, + ptr_ty: TypeId, + ) -> yul::Expression { + debug_assert!(ptr_ty.is_ptr(db.upcast())); + let size = ptr_ty.deref(db.upcast()).size_of(db.upcast(), SLOT_SIZE); + debug_assert!(size <= 32); + + let size_bits = size * 8; + if ptr_ty.is_sptr(db.upcast()) { + let name = "$sptr_load"; + let args = vec![ptr, literal_expression! {(size_bits)}]; + self.create_then_call(name, args, |_| data::make_sptr_load(name)) + } else if ptr_ty.is_mptr(db.upcast()) { + let name = "$mptr_load"; + let shift_num = literal_expression! {(256 - size_bits)}; + let args = vec![ptr, shift_num]; + self.create_then_call(name, args, |_| data::make_mptr_load(name)) + } else { + unreachable!() + } + } + + fn abi_encode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + dst: yul::Expression, + src_ty: TypeId, + is_dst_storage: bool, + ) -> yul::Expression { + let legalized_ty = db.codegen_legalized_type(src_ty); + let args = vec![src.clone(), dst.clone()]; + + let func_name_postfix = if is_dst_storage { "storage" } else { "memory" }; + + if legalized_ty.is_primitive(db.upcast()) { + let name = format!( + "$abi_encode_primitive_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + return self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_primitive_type( + provider, + db, + &name, + legalized_ty, + is_dst_storage, + ) + }); + } + + let deref_ty = legalized_ty.deref(db.upcast()); + let abi_ty = db.codegen_abi_type(deref_ty); + match abi_ty { + AbiType::UInt(_) | AbiType::Int(_) | AbiType::Bool | AbiType::Address => { + let value = self.ptr_load(db, src, src_ty); + let extended_value = self.primitive_cast(db, value, deref_ty); + self.abi_encode(db, extended_value, dst, deref_ty, is_dst_storage) + } + AbiType::Array { elem_ty, .. } => { + if elem_ty.is_static() { + let name = format!( + "$abi_encode_static_array_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_static_array_type(provider, db, &name, legalized_ty) + }) + } else { + let name = format! { + "$abi_encode_dynamic_array_type{}_to_{}", src_ty.0, func_name_postfix + }; + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_dynamic_array_type(provider, db, &name, legalized_ty) + }) + } + } + AbiType::Tuple(_) => { + if abi_ty.is_static() { + let name = format!( + "$abi_encode_static_aggregate_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_static_aggregate_type( + provider, + db, + &name, + legalized_ty, + is_dst_storage, + ) + }) + } else { + let name = format!( + "$abi_encode_dynamic_aggregate_type_{}_to_{}", + src_ty.0, func_name_postfix + ); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_dynamic_aggregate_type( + provider, + db, + &name, + legalized_ty, + is_dst_storage, + ) + }) + } + } + AbiType::Bytes => { + let len = match &deref_ty.data(db.upcast()).kind { + TypeKind::Array(ArrayDef { len, .. }) => *len, + _ => unreachable!(), + }; + let name = format! {"$abi_encode_bytes{len}_type_to_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_bytes_type(provider, db, &name, len, is_dst_storage) + }) + } + AbiType::String => { + let name = format! {"$abi_encode_string_type_to_{func_name_postfix}"}; + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_string_type(provider, db, &name, is_dst_storage) + }) + } + AbiType::Function => unreachable!(), + } + } + + fn abi_encode_seq( + &mut self, + db: &dyn CodegenDb, + src: &[yul::Expression], + dst: yul::Expression, + src_tys: &[TypeId], + is_dst_storage: bool, + ) -> yul::Expression { + let mut name = "$abi_encode_value_seq".to_string(); + for ty in src_tys { + write!(&mut name, "_{}", ty.0).unwrap(); + } + + let mut args = vec![dst]; + args.extend(src.iter().cloned()); + self.create_then_call(&name, args, |provider| { + abi::make_abi_encode_seq(provider, db, &name, src_tys, is_dst_storage) + }) + } + + fn abi_decode( + &mut self, + db: &dyn CodegenDb, + src: yul::Expression, + size: yul::Expression, + types: &[TypeId], + abi_loc: AbiSrcLocation, + ) -> yul::Expression { + let mut name = "$abi_decode".to_string(); + for ty in types { + write!(name, "_{}", ty.0).unwrap(); + } + + match abi_loc { + AbiSrcLocation::CallData => write!(name, "_from_calldata").unwrap(), + AbiSrcLocation::Memory => write!(name, "_from_memory").unwrap(), + }; + + self.create_then_call(&name, vec![src, size], |provider| { + abi::make_abi_decode(provider, db, &name, types, abi_loc) + }) + } + + fn safe_add( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_add(self, db, lhs, rhs, ty) + } + + fn safe_sub( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_sub(self, db, lhs, rhs, ty) + } + + fn safe_mul( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_mul(self, db, lhs, rhs, ty) + } + + fn safe_div( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_div(self, db, lhs, rhs, ty) + } + + fn safe_mod( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_mod(self, db, lhs, rhs, ty) + } + + fn safe_pow( + &mut self, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, + ) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + safe_math::dispatch_safe_pow(self, db, lhs, rhs, ty) + } +} + +#[derive(Debug)] +struct RuntimeFunction(yul::FunctionDefinition); + +impl RuntimeFunction { + fn arg_num(&self) -> usize { + self.0.parameters.len() + } + + fn definition(&self) -> yul::FunctionDefinition { + self.0.clone() + } + + /// # Panics + /// Panics if a number of arguments doesn't match the definition. + fn call(&self, args: Vec) -> yul::Expression { + debug_assert_eq!(self.arg_num(), args.len()); + + yul::Expression::FunctionCall(yul::FunctionCall { + identifier: self.0.name.clone(), + arguments: args, + }) + } + + /// Remove this when `yultsur::function_definition!` becomes to return + /// `FunctionDefinition`. + fn from_statement(func: yul::Statement) -> Self { + match func { + yul::Statement::FunctionDefinition(def) => Self(def), + _ => unreachable!(), + } + } +} + +fn make_ptr(db: &dyn CodegenDb, inner: TypeId, is_sptr: bool) -> TypeId { + if is_sptr { + inner.make_sptr(db.upcast()) + } else { + inner.make_mptr(db.upcast()) + } +} + +struct BitMask(BigInt); + +impl BitMask { + fn new(byte_size: usize) -> Self { + debug_assert!(byte_size <= 32); + let one: BigInt = 1usize.into(); + Self((one << (byte_size * 8)) - 1) + } + + fn not(&self) -> Self { + // Bigint is variable length integer, so we need special handling for `not` + // operation. + let one: BigInt = 1usize.into(); + let u256_max = (one << 256) - 1; + Self(u256_max ^ &self.0) + } + + fn as_expr(&self) -> yul::Expression { + let mask = format!("{:#x}", self.0); + literal_expression! {(mask)} + } +} + +pub(super) fn error_revert_numeric( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, + error_code: yul::Expression, +) -> yul::Statement { + yul::Statement::Expression(provider.revert( + db, + Some(error_code), + "Error", + yul_primitive_type(db), + )) +} + +pub(super) fn panic_revert_numeric( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, + error_code: yul::Expression, +) -> yul::Statement { + yul::Statement::Expression(provider.revert( + db, + Some(error_code), + "Panic", + yul_primitive_type(db), + )) +} diff --git a/crates/codegen2/src/yul/runtime/revert.rs b/crates/codegen2/src/yul/runtime/revert.rs new file mode 100644 index 0000000000..1f8fa9eb7b --- /dev/null +++ b/crates/codegen2/src/yul/runtime/revert.rs @@ -0,0 +1,91 @@ +use crate::{ + db::CodegenDb, + yul::{slot_size::function_hash_type, YulVariable}, +}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_abi::function::{AbiFunction, AbiFunctionType, StateMutability}; +use fe_mir::ir::{self, TypeId}; +use yultsur::*; + +pub(super) fn make_revert( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, + arg_name: &str, + arg_ty: TypeId, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let arg = YulVariable::new("arg"); + + let abi_size = YulVariable::new("abi_size"); + let abi_tmp_ptr = YulVariable::new("$abi_tmp_ptr"); + let signature = type_signature_for_revert(db, arg_name, arg_ty); + + let signature_store = yul::Statement::Expression(provider.ptr_store( + db, + abi_tmp_ptr.expr(), + signature, + function_hash_type(db).make_mptr(db.upcast()), + )); + + let func = if arg_ty.deref(db.upcast()).is_zero_sized(db.upcast()) { + function_definition! { + function [func_name.ident()]() { + (let [abi_tmp_ptr.ident()] := [provider.avail(db)]) + ([signature_store]) + (revert([abi_tmp_ptr.expr()], [literal_expression!{4}])) + } + } + } else { + let encode = provider.abi_encode_seq( + db, + &[arg.expr()], + expression! { add([abi_tmp_ptr.expr()], 4) }, + &[arg_ty], + false, + ); + + function_definition! { + function [func_name.ident()]([arg.ident()]) { + (let [abi_tmp_ptr.ident()] := [provider.avail(db)]) + ([signature_store]) + (let [abi_size.ident()] := add([encode], 4)) + (revert([abi_tmp_ptr.expr()], [abi_size.expr()])) + } + } + }; + + RuntimeFunction::from_statement(func) +} + +/// Returns signature hash of the type. +fn type_signature_for_revert(db: &dyn CodegenDb, name: &str, ty: TypeId) -> yul::Expression { + let deref_ty = ty.deref(db.upcast()); + let ty_data = deref_ty.data(db.upcast()); + let args = match &ty_data.kind { + ir::TypeKind::Struct(def) => def + .fields + .iter() + .map(|(_, ty)| ("".to_string(), db.codegen_abi_type(*ty))) + .collect(), + _ => { + let abi_ty = db.codegen_abi_type(deref_ty); + vec![("_".to_string(), abi_ty)] + } + }; + + // selector and state mutability is independent we can set has_self and has_ctx + // any value. + let selector = AbiFunction::new( + AbiFunctionType::Function, + name.to_string(), + args, + None, + StateMutability::Pure, + ) + .selector(); + let type_sig = selector.hex(); + literal_expression! {(format!{"0x{type_sig}" })} +} diff --git a/crates/codegen2/src/yul/runtime/safe_math.rs b/crates/codegen2/src/yul/runtime/safe_math.rs new file mode 100644 index 0000000000..5bdfe8b04a --- /dev/null +++ b/crates/codegen2/src/yul/runtime/safe_math.rs @@ -0,0 +1,628 @@ +use crate::{db::CodegenDb, yul::YulVariable}; + +use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; + +use fe_mir::ir::{TypeId, TypeKind}; + +use yultsur::*; + +pub(super) fn dispatch_safe_add( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_add_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_add_signed(provider, db, name) + }) + } else { + let name = "$safe_add_unsigned"; + let args = vec![lhs, rhs, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_add_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_sub( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_sub_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_sub_signed(provider, db, name) + }) + } else { + let name = "$safe_sub_unsigned"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_sub_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_mul( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_mul_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_mul_signed(provider, db, name) + }) + } else { + let name = "$safe_mul_unsigned"; + let args = vec![lhs, rhs, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_mul_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_div( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_div_signed"; + let args = vec![lhs, rhs, min_value]; + provider.create_then_call(name, args, |provider| { + make_safe_div_signed(provider, db, name) + }) + } else { + let name = "$safe_div_unsigned"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_div_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_mod( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + if ty.is_signed(db.upcast()) { + let name = "$safe_mod_signed"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_mod_signed(provider, db, name) + }) + } else { + let name = "$safe_mod_unsigned"; + let args = vec![lhs, rhs]; + provider.create_then_call(name, args, |provider| { + make_safe_mod_unsigned(provider, db, name) + }) + } +} + +pub(super) fn dispatch_safe_pow( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + lhs: yul::Expression, + rhs: yul::Expression, + ty: TypeId, +) -> yul::Expression { + debug_assert!(ty.is_integral(db.upcast())); + let min_value = get_min_value(db, ty); + let max_value = get_max_value(db, ty); + + if ty.is_signed(db.upcast()) { + let name = "$safe_pow_signed"; + let args = vec![lhs, rhs, min_value, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_pow_signed(provider, db, name) + }) + } else { + let name = "$safe_pow_unsigned"; + let args = vec![lhs, rhs, max_value]; + provider.create_then_call(name, args, |provider| { + make_safe_pow_unsigned(provider, db, name) + }) + } +} + +fn make_safe_add_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()], [max_value.ident()]) -> [ret.ident()] { + (if (and((iszero((slt([lhs.expr()], 0)))), (sgt([rhs.expr()], (sub([max_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + (if (and((slt([lhs.expr()], 0)), (slt([rhs.expr()], (sub([min_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := add([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_add_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [max_value.ident()]) -> [ret.ident()] { + (if (gt([lhs.expr()], (sub([max_value.expr()], [rhs.expr()])))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := add([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_sub_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()], [max_value.ident()]) -> [ret.ident()] { + (if (and((iszero((slt([rhs.expr()], 0)))), (slt([lhs.expr()], (add([min_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + (if (and((slt([rhs.expr()], 0)), (sgt([lhs.expr()], (add([max_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := sub([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_sub_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (lt([lhs.expr()], [rhs.expr()])) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := sub([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mul_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()], [max_value.ident()]) -> [ret.ident()] { + // overflow, if lhs > 0, rhs > 0 and lhs > (max_value / rhs) + (if (and((and((sgt([lhs.expr()], 0)), (sgt([rhs.expr()], 0)))), (gt([lhs.expr()], (div([max_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + // underflow, if lhs > 0, rhs < 0 and rhs < (min_value / lhs) + (if (and((and((sgt([lhs.expr()], 0)), (slt([rhs.expr()], 0)))), (slt([rhs.expr()], (sdiv([min_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + // underflow, if lhs < 0, rhs > 0 and lhs < (min_value / rhs) + (if (and((and((slt([lhs.expr()], 0)), (sgt([rhs.expr()], 0)))), (slt([lhs.expr()], (sdiv([min_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + // overflow, if lhs < 0, rhs < 0 and lhs < (max_value / rhs) + (if (and((and((slt([lhs.expr()], 0)), (slt([rhs.expr()], 0)))), (slt([lhs.expr()], (sdiv([max_value.expr()], [rhs.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := mul([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mul_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let max_value = YulVariable::new("$max_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [max_value.ident()]) -> [ret.ident()] { + // overflow, if lhs != 0 and rhs > (max_value / lhs) + (if (and((iszero((iszero([lhs.expr()])))), (gt([rhs.expr()], (div([max_value.expr()], [lhs.expr()])))))) { [revert_with_overflow(provider ,db)] }) + ([ret.ident()] := mul([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_div_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let min_value = YulVariable::new("$min_value"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()], [min_value.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + (if (and( (eq([lhs.expr()], [min_value.expr()])), (eq([rhs.expr()], (sub(0, 1))))) ) { [revert_with_overflow(provider, db)] }) + ([ret.ident()] := sdiv([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_div_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + ([ret.ident()] := div([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mod_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + ([ret.ident()] := smod([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_mod_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let lhs = YulVariable::new("$lhs"); + let rhs = YulVariable::new("$rhs"); + let ret = YulVariable::new("$ret"); + + let func = function_definition! { + function [func_name.ident()]([lhs.ident()], [rhs.ident()]) -> [ret.ident()] { + (if (iszero([rhs.expr()])) { [revert_with_zero_division(provider, db)] }) + ([ret.ident()] := mod([lhs.expr()], [rhs.expr()])) + } + }; + RuntimeFunction::from_statement(func) +} + +const SAFE_POW_HELPER_NAME: &str = "safe_pow_helper"; + +fn make_safe_pow_unsigned( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let base = YulVariable::new("base"); + let exponent = YulVariable::new("exponent"); + let max_value = YulVariable::new("max_value"); + let power = YulVariable::new("power"); + + let safe_pow_helper_call = yul::Statement::Assignment(yul::Assignment { + identifiers: vec![base.ident(), power.ident()], + expression: { + let args = vec![ + base.expr(), + exponent.expr(), + literal_expression! {1}, + max_value.expr(), + ]; + provider.create_then_call(SAFE_POW_HELPER_NAME, args, |provider| { + make_safe_exp_helper(provider, db, SAFE_POW_HELPER_NAME) + }) + }, + }); + + let func = function_definition! { + function [func_name.ident()]([base.ident()], [exponent.ident()], [max_value.ident()]) -> [power.ident()] { + // Currently, `leave` avoids this function being inlined. + // YUL team is working on optimizer improvements to fix that. + + // Note that 0**0 == 1 + (if (iszero([exponent.expr()])) { + ([power.ident()] := 1 ) + (leave) + }) + (if (iszero([base.expr()])) { + ([power.ident()] := 0 ) + (leave) + }) + // Specializations for small bases + ([switch! { + switch [base.expr()] + // 0 is handled above + (case 1 { + ([power.ident()] := 1 ) + (leave) + }) + (case 2 { + (if (gt([exponent.expr()], 255)) { + [revert_with_overflow(provider, db)] + }) + ([power.ident()] := (exp(2, [exponent.expr()]))) + (if (gt([power.expr()], [max_value.expr()])) { + [revert_with_overflow(provider, db)] + }) + (leave) + }) + }]) + (if (and((sgt([power.expr()], 0)), (gt([power.expr()], (div([max_value.expr()], [base.expr()])))))) { [revert_with_overflow(provider, db)] }) + + (if (or((and((lt([base.expr()], 11)), (lt([exponent.expr()], 78)))), (and((lt([base.expr()], 307)), (lt([exponent.expr()], 32)))))) { + ([power.ident()] := (exp([base.expr()], [exponent.expr()]))) + (if (gt([power.expr()], [max_value.expr()])) { + [revert_with_overflow(provider, db)] + }) + (leave) + }) + + ([safe_pow_helper_call]) + (if (gt([power.expr()], (div([max_value.expr()], [base.expr()])))) { + [revert_with_overflow(provider, db)] + }) + ([power.ident()] := (mul([power.expr()], [base.expr()]))) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_pow_signed( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let base = YulVariable::new("base"); + let exponent = YulVariable::new("exponent"); + let min_value = YulVariable::new("min_value"); + let max_value = YulVariable::new("max_value"); + let power = YulVariable::new("power"); + + let safe_pow_helper_call = yul::Statement::Assignment(yul::Assignment { + identifiers: vec![base.ident(), power.ident()], + expression: { + let args = vec![base.expr(), exponent.expr(), power.expr(), max_value.expr()]; + provider.create_then_call(SAFE_POW_HELPER_NAME, args, |provider| { + make_safe_exp_helper(provider, db, SAFE_POW_HELPER_NAME) + }) + }, + }); + + let func = function_definition! { + function [func_name.ident()]([base.ident()], [exponent.ident()], [min_value.ident()], [max_value.ident()]) -> [power.ident()] { + // Currently, `leave` avoids this function being inlined. + // YUL team is working on optimizer improvements to fix that. + + // Note that 0**0 == 1 + ([switch! { + switch [exponent.expr()] + (case 0 { + ([power.ident()] := 1 ) + (leave) + }) + (case 1 { + ([power.ident()] := [base.expr()] ) + (leave) + }) + }]) + (if (iszero([base.expr()])) { + ([power.ident()] := 0 ) + (leave) + }) + ([power.ident()] := 1 ) + // We pull out the first iteration because it is the only one in which + // base can be negative. + // Exponent is at least 2 here. + // overflow check for base * base + ([switch! { + switch (sgt([base.expr()], 0)) + (case 1 { + (if (gt([base.expr()], (div([max_value.expr()], [base.expr()])))) { + [revert_with_overflow(provider, db)] + }) + }) + (case 0 { + (if (slt([base.expr()], (sdiv([max_value.expr()], [base.expr()])))) { + [revert_with_overflow(provider, db)] + }) + }) + }]) + (if (and([exponent.expr()], 1)) { + ([power.ident()] := [base.expr()] ) + }) + ([base.ident()] := (mul([base.expr()], [base.expr()]))) + ([exponent.ident()] := shr(1, [exponent.expr()])) + // // Below this point, base is always positive. + ([safe_pow_helper_call]) // power = 1, base = 16 which is wrong + (if (and((sgt([power.expr()], 0)), (gt([power.expr()], (div([max_value.expr()], [base.expr()])))))) { [revert_with_overflow(provider , db)] }) + (if (and((slt([power.expr()], 0)), (slt([power.expr()], (sdiv([min_value.expr()], [base.expr()])))))) { [revert_with_overflow(provider, db)] }) + ([power.ident()] := (mul([power.expr()], [base.expr()]))) + } + }; + RuntimeFunction::from_statement(func) +} + +fn make_safe_exp_helper( + provider: &mut DefaultRuntimeProvider, + db: &dyn CodegenDb, + func_name: &str, +) -> RuntimeFunction { + let func_name = YulVariable::new(func_name); + let base = YulVariable::new("base"); + let exponent = YulVariable::new("exponent"); + let power = YulVariable::new("power"); + let max_value = YulVariable::new("max_value"); + let ret_power = YulVariable::new("ret_power"); + let ret_base = YulVariable::new("ret_base"); + + let func = function_definition! { + function [func_name.ident()]([base.ident()], [exponent.ident()], [power.ident()], [max_value.ident()]) + -> [(vec![ret_base.ident(), ret_power.ident()])...] { + ([ret_base.ident()] := [base.expr()]) + ([ret_power.ident()] := [power.expr()]) + (for {} (gt([exponent.expr()], 1)) {} + { + // overflow check for base * base + (if (gt([ret_base.expr()], (div([max_value.expr()], [ret_base.expr()])))) { [revert_with_overflow(provider, db)] }) + (if (and([exponent.expr()], 1)) { + // No checks for power := mul(power, base) needed, because the check + // for base * base above is sufficient, since: + // |power| <= base (proof by induction) and thus: + // |power * base| <= base * base <= max <= |min| (for signed) + // (this is equally true for signed and unsigned exp) + ([ret_power.ident()] := (mul([ret_power.expr()], [ret_base.expr()]))) + }) + ([ret_base.ident()] := mul([ret_base.expr()], [ret_base.expr()])) + ([exponent.ident()] := shr(1, [exponent.expr()])) + }) + } + }; + RuntimeFunction::from_statement(func) +} + +fn revert_with_overflow(provider: &mut dyn RuntimeProvider, db: &dyn CodegenDb) -> yul::Statement { + const PANIC_OVERFLOW: usize = 0x11; + + super::panic_revert_numeric(provider, db, literal_expression! {(PANIC_OVERFLOW)}) +} + +fn revert_with_zero_division( + provider: &mut dyn RuntimeProvider, + db: &dyn CodegenDb, +) -> yul::Statement { + pub const PANIC_ZERO_DIVISION: usize = 0x12; + + super::panic_revert_numeric(provider, db, literal_expression! {(PANIC_ZERO_DIVISION)}) +} + +fn get_max_value(db: &dyn CodegenDb, ty: TypeId) -> yul::Expression { + match &ty.data(db.upcast()).kind { + TypeKind::I8 => literal_expression! {0x7f}, + TypeKind::I16 => literal_expression! {0x7fff}, + TypeKind::I32 => literal_expression! {0x7fffffff}, + TypeKind::I64 => literal_expression! {0x7fffffffffffffff}, + TypeKind::I128 => literal_expression! {0x7fffffffffffffffffffffffffffffff}, + TypeKind::I256 => { + literal_expression! {0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff} + } + TypeKind::U8 => literal_expression! {0xff}, + TypeKind::U16 => literal_expression! {0xffff}, + TypeKind::U32 => literal_expression! {0xffffffff}, + TypeKind::U64 => literal_expression! {0xffffffffffffffff}, + TypeKind::U128 => literal_expression! {0xffffffffffffffffffffffffffffffff}, + TypeKind::U256 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff} + } + _ => unreachable!(), + } +} + +fn get_min_value(db: &dyn CodegenDb, ty: TypeId) -> yul::Expression { + debug_assert! {ty.is_integral(db.upcast())}; + + match &ty.data(db.upcast()).kind { + TypeKind::I8 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff80} + } + TypeKind::I16 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff8000} + } + TypeKind::I32 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffffffffffff80000000} + } + TypeKind::I64 => { + literal_expression! {0xffffffffffffffffffffffffffffffffffffffffffffffff8000000000000000} + } + TypeKind::I128 => { + literal_expression! {0xffffffffffffffffffffffffffffffff80000000000000000000000000000000} + } + TypeKind::I256 => { + literal_expression! {0x8000000000000000000000000000000000000000000000000000000000000000} + } + + _ => literal_expression! {0x0}, + } +} diff --git a/crates/codegen2/src/yul/slot_size.rs b/crates/codegen2/src/yul/slot_size.rs new file mode 100644 index 0000000000..aa931132c0 --- /dev/null +++ b/crates/codegen2/src/yul/slot_size.rs @@ -0,0 +1,16 @@ +use fe_mir::ir::{Type, TypeId, TypeKind}; + +use crate::db::CodegenDb; + +// We use the same slot size between memory and storage to simplify the +// implementation and minimize gas consumption in memory <-> storage copy +// instructions. +pub(crate) const SLOT_SIZE: usize = 32; + +pub(crate) fn yul_primitive_type(db: &dyn CodegenDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::U256, None).into()) +} + +pub(crate) fn function_hash_type(db: &dyn CodegenDb) -> TypeId { + db.mir_intern_type(Type::new(TypeKind::U32, None).into()) +} From b1d67628f7f8ef228f4fb4940e59243aa4ebaec2 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Tue, 23 Jan 2024 13:48:40 -0700 Subject: [PATCH 13/22] hacking --- Cargo.lock | 18 ++++ crates/codegen2/Cargo.toml | 3 +- crates/codegen2/src/{db/queries => }/abi.rs | 10 +- .../codegen2/src/{db/queries => }/constant.rs | 2 +- .../codegen2/src/{db/queries => }/contract.rs | 6 +- crates/codegen2/src/db.rs | 94 ------------------ crates/codegen2/src/db/queries.rs | 5 - crates/codegen2/src/db/queries/function.rs | 76 --------------- crates/codegen2/src/function.rs | 74 ++++++++++++++ crates/codegen2/src/lib.rs | 96 ++++++++++++++++++- crates/codegen2/src/{db/queries => }/types.rs | 3 +- crates/codegen2/src/yul/isel/context.rs | 6 +- crates/codegen2/src/yul/isel/contract.rs | 12 +-- crates/codegen2/src/yul/isel/function.rs | 2 +- crates/codegen2/src/yul/isel/test.rs | 5 +- crates/codegen2/src/yul/legalize/body.rs | 2 +- crates/codegen2/src/yul/legalize/signature.rs | 2 +- crates/codegen2/src/yul/runtime/abi.rs | 2 +- crates/codegen2/src/yul/runtime/contract.rs | 8 +- crates/codegen2/src/yul/runtime/data.rs | 2 +- crates/codegen2/src/yul/runtime/emit.rs | 2 +- crates/codegen2/src/yul/runtime/mod.rs | 12 +-- crates/codegen2/src/yul/runtime/revert.rs | 2 +- crates/codegen2/src/yul/runtime/safe_math.rs | 2 +- crates/codegen2/src/yul/slot_size.rs | 2 +- crates/mir2/src/lower/mod.rs | 2 +- crates/mir2/tests/lowering.rs | 8 +- crates/mir2/tests/test_db.rs | 7 +- 28 files changed, 241 insertions(+), 224 deletions(-) rename crates/codegen2/src/{db/queries => }/abi.rs (97%) rename crates/codegen2/src/{db/queries => }/constant.rs (90%) rename crates/codegen2/src/{db/queries => }/contract.rs (84%) delete mode 100644 crates/codegen2/src/db.rs delete mode 100644 crates/codegen2/src/db/queries.rs delete mode 100644 crates/codegen2/src/db/queries/function.rs create mode 100644 crates/codegen2/src/function.rs rename crates/codegen2/src/{db/queries => }/types.rs (98%) diff --git a/Cargo.lock b/Cargo.lock index 13345c2fc6..fbbc1935cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -932,6 +932,24 @@ dependencies = [ "yultsur", ] +[[package]] +name = "fe-codegen2" +version = "0.23.0" +dependencies = [ + "fe-abi", + "fe-common", + "fe-hir", + "fe-hir-analysis", + "fe-mir", + "fe-mir2", + "fxhash", + "indexmap", + "num-bigint", + "salsa-2022", + "smol_str", + "yultsur", +] + [[package]] name = "fe-common" version = "0.23.0" diff --git a/crates/codegen2/Cargo.toml b/crates/codegen2/Cargo.toml index d6b46fb289..ad8af1b77e 100644 --- a/crates/codegen2/Cargo.toml +++ b/crates/codegen2/Cargo.toml @@ -7,10 +7,11 @@ edition = "2021" [dependencies] hir-analysis = { path = "../hir-analysis", package = "fe-hir-analysis" } hir = { path = "../hir", package = "fe-hir" } +mir = { path = "../mir2", package = "fe-mir2" } fe-mir = { path = "../mir", version = "^0.23.0"} fe-common = { path = "../common", version = "^0.23.0"} fe-abi = { path = "../abi", version = "^0.23.0"} -salsa = "0.16.1" +salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } num-bigint = "0.4.3" fxhash = "0.2.1" indexmap = "1.6.2" diff --git a/crates/codegen2/src/db/queries/abi.rs b/crates/codegen2/src/abi.rs similarity index 97% rename from crates/codegen2/src/db/queries/abi.rs rename to crates/codegen2/src/abi.rs index e492166e86..8331be96e7 100644 --- a/crates/codegen2/src/db/queries/abi.rs +++ b/crates/codegen2/src/abi.rs @@ -13,8 +13,7 @@ use fe_analyzer::{ }; use fe_mir::ir::{self, FunctionId, TypeId}; -use crate::db::CodegenDb; - +#[salsa::tracked(return_ref)] pub fn abi_contract(db: &dyn CodegenDb, contract: ContractId) -> AbiContract { let mut funcs = vec![]; @@ -46,6 +45,7 @@ pub fn abi_contract(db: &dyn CodegenDb, contract: ContractId) -> AbiContract { AbiContract::new(funcs, events) } +#[salsa::tracked(return_ref)] pub fn abi_function(db: &dyn CodegenDb, function: FunctionId) -> AbiFunction { // We use a legalized signature. let sig = db.codegen_legalized_signature(function); @@ -87,6 +87,7 @@ pub fn abi_function(db: &dyn CodegenDb, function: FunctionId) -> AbiFunction { AbiFunction::new(func_type, name.to_string(), args, ret_ty, state_mutability) } +#[salsa::tracked(return_ref)] pub fn abi_function_argument_maximum_size(db: &dyn CodegenDb, function: FunctionId) -> usize { let sig = db.codegen_legalized_signature(function); sig.params.iter().fold(0, |acc, param| { @@ -94,6 +95,7 @@ pub fn abi_function_argument_maximum_size(db: &dyn CodegenDb, function: Function }) } +#[salsa::tracked(return_ref)] pub fn abi_function_return_maximum_size(db: &dyn CodegenDb, function: FunctionId) -> usize { let sig = db.codegen_legalized_signature(function); sig.return_type @@ -101,6 +103,7 @@ pub fn abi_function_return_maximum_size(db: &dyn CodegenDb, function: FunctionId .unwrap_or_default() } +#[salsa::tracked(return_ref)] pub fn abi_type_maximum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { let abi_type = db.codegen_abi_type(ty); if abi_type.is_static() { @@ -132,6 +135,7 @@ pub fn abi_type_maximum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { } } +#[salsa::tracked(return_ref)] pub fn abi_type_minimum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { let abi_type = db.codegen_abi_type(ty); if abi_type.is_static() { @@ -162,6 +166,7 @@ pub fn abi_type_minimum_size(db: &dyn CodegenDb, ty: TypeId) -> usize { } } +#[salsa::tracked(return_ref)] pub fn abi_type(db: &dyn CodegenDb, ty: TypeId) -> AbiType { let legalized_ty = db.codegen_legalized_type(ty); @@ -236,6 +241,7 @@ pub fn abi_type(db: &dyn CodegenDb, ty: TypeId) -> AbiType { } } +#[salsa::tracked(return_ref)] pub fn abi_event(db: &dyn CodegenDb, ty: TypeId) -> AbiEvent { debug_assert!(ty.is_struct(db.upcast())); diff --git a/crates/codegen2/src/db/queries/constant.rs b/crates/codegen2/src/constant.rs similarity index 90% rename from crates/codegen2/src/db/queries/constant.rs rename to crates/codegen2/src/constant.rs index 2a78aba9b4..b848084e8b 100644 --- a/crates/codegen2/src/db/queries/constant.rs +++ b/crates/codegen2/src/constant.rs @@ -1,10 +1,10 @@ -use crate::db::CodegenDb; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, rc::Rc, }; +#[salsa::tracked(return_ref)] pub fn string_symbol_name(_db: &dyn CodegenDb, data: String) -> Rc { let mut hasher = DefaultHasher::new(); data.hash(&mut hasher); diff --git a/crates/codegen2/src/db/queries/contract.rs b/crates/codegen2/src/contract.rs similarity index 84% rename from crates/codegen2/src/db/queries/contract.rs rename to crates/codegen2/src/contract.rs index be0002a371..b39fc140b0 100644 --- a/crates/codegen2/src/db/queries/contract.rs +++ b/crates/codegen2/src/contract.rs @@ -1,9 +1,6 @@ use std::rc::Rc; -use fe_analyzer::namespace::items::ContractId; - -use crate::db::CodegenDb; - +#[salsa::tracked(return_ref)] pub fn symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { let module = contract.module(db.upcast()); @@ -15,6 +12,7 @@ pub fn symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { .into() } +#[salsa::tracked(return_ref)] pub fn deployer_symbol_name(db: &dyn CodegenDb, contract: ContractId) -> Rc { format!("deploy_{}", symbol_name(db, contract).as_ref()).into() } diff --git a/crates/codegen2/src/db.rs b/crates/codegen2/src/db.rs deleted file mode 100644 index ce036795a3..0000000000 --- a/crates/codegen2/src/db.rs +++ /dev/null @@ -1,94 +0,0 @@ -#![allow(clippy::arc_with_non_send_sync)] -use std::rc::Rc; - -use fe_abi::{contract::AbiContract, event::AbiEvent, function::AbiFunction, types::AbiType}; -use fe_analyzer::{db::AnalyzerDbStorage, namespace::items::ContractId, AnalyzerDb}; -use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; -use fe_mir::{ - db::{MirDb, MirDbStorage}, - ir::{FunctionBody, FunctionId, FunctionSignature, TypeId}, -}; - -mod queries; - -#[salsa::query_group(CodegenDbStorage)] -pub trait CodegenDb: MirDb + Upcast + UpcastMut { - #[salsa::invoke(queries::function::legalized_signature)] - fn codegen_legalized_signature(&self, function_id: FunctionId) -> Rc; - #[salsa::invoke(queries::function::legalized_body)] - fn codegen_legalized_body(&self, function_id: FunctionId) -> Rc; - #[salsa::invoke(queries::function::symbol_name)] - fn codegen_function_symbol_name(&self, function_id: FunctionId) -> Rc; - - #[salsa::invoke(queries::types::legalized_type)] - fn codegen_legalized_type(&self, ty: TypeId) -> TypeId; - - #[salsa::invoke(queries::abi::abi_type)] - fn codegen_abi_type(&self, ty: TypeId) -> AbiType; - #[salsa::invoke(queries::abi::abi_function)] - fn codegen_abi_function(&self, function_id: FunctionId) -> AbiFunction; - #[salsa::invoke(queries::abi::abi_event)] - fn codegen_abi_event(&self, ty: TypeId) -> AbiEvent; - #[salsa::invoke(queries::abi::abi_contract)] - fn codegen_abi_contract(&self, contract: ContractId) -> AbiContract; - #[salsa::invoke(queries::abi::abi_type_maximum_size)] - fn codegen_abi_type_maximum_size(&self, ty: TypeId) -> usize; - #[salsa::invoke(queries::abi::abi_type_minimum_size)] - fn codegen_abi_type_minimum_size(&self, ty: TypeId) -> usize; - #[salsa::invoke(queries::abi::abi_function_argument_maximum_size)] - fn codegen_abi_function_argument_maximum_size(&self, contract: FunctionId) -> usize; - #[salsa::invoke(queries::abi::abi_function_return_maximum_size)] - fn codegen_abi_function_return_maximum_size(&self, function: FunctionId) -> usize; - - #[salsa::invoke(queries::contract::symbol_name)] - fn codegen_contract_symbol_name(&self, contract: ContractId) -> Rc; - #[salsa::invoke(queries::contract::deployer_symbol_name)] - fn codegen_contract_deployer_symbol_name(&self, contract: ContractId) -> Rc; - - #[salsa::invoke(queries::constant::string_symbol_name)] - fn codegen_constant_string_symbol_name(&self, data: String) -> Rc; -} - -// TODO: Move this to driver. -#[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage, CodegenDbStorage)] -#[derive(Default)] -pub struct Db { - storage: salsa::Storage, -} -impl salsa::Database for Db {} - -impl Upcast for Db { - fn upcast(&self) -> &(dyn MirDb + 'static) { - self - } -} - -impl UpcastMut for Db { - fn upcast_mut(&mut self) -> &mut (dyn MirDb + 'static) { - &mut *self - } -} - -impl Upcast for Db { - fn upcast(&self) -> &(dyn SourceDb + 'static) { - self - } -} - -impl UpcastMut for Db { - fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { - &mut *self - } -} - -impl Upcast for Db { - fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { - self - } -} - -impl UpcastMut for Db { - fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { - &mut *self - } -} diff --git a/crates/codegen2/src/db/queries.rs b/crates/codegen2/src/db/queries.rs deleted file mode 100644 index 31cca43870..0000000000 --- a/crates/codegen2/src/db/queries.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod abi; -pub mod constant; -pub mod contract; -pub mod function; -pub mod types; diff --git a/crates/codegen2/src/db/queries/function.rs b/crates/codegen2/src/db/queries/function.rs deleted file mode 100644 index d4527271e4..0000000000 --- a/crates/codegen2/src/db/queries/function.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::rc::Rc; - -use fe_analyzer::{ - display::Displayable, - namespace::{ - items::Item, - types::{Type, TypeId}, - }, -}; -use fe_mir::ir::{FunctionBody, FunctionId, FunctionSignature}; -use salsa::InternKey; -use smol_str::SmolStr; - -use crate::{db::CodegenDb, yul::legalize}; - -pub fn legalized_signature(db: &dyn CodegenDb, function: FunctionId) -> Rc { - let mut sig = function.signature(db.upcast()).as_ref().clone(); - legalize::legalize_func_signature(db, &mut sig); - sig.into() -} - -pub fn legalized_body(db: &dyn CodegenDb, function: FunctionId) -> Rc { - let mut body = function.body(db.upcast()).as_ref().clone(); - legalize::legalize_func_body(db, &mut body); - body.into() -} - -pub fn symbol_name(db: &dyn CodegenDb, function: FunctionId) -> Rc { - let module = function.signature(db.upcast()).module_id; - let module_name = module.name(db.upcast()); - - let analyzer_func = function.analyzer_func(db.upcast()); - let func_name = format!( - "{}{}", - analyzer_func.name(db.upcast()), - type_suffix(function, db) - ); - - let func_name = match analyzer_func.sig(db.upcast()).self_item(db.upcast()) { - Some(Item::Impl(id)) => { - let class_name = format!( - "{}${}", - id.trait_id(db.upcast()).name(db.upcast()), - safe_name(db, id.receiver(db.upcast())) - ); - format!("{class_name}${func_name}") - } - Some(class) => { - let class_name = class.name(db.upcast()); - format!("{class_name}${func_name}") - } - _ => func_name, - }; - - format!("{module_name}${func_name}").into() -} - -fn type_suffix(function: FunctionId, db: &dyn CodegenDb) -> SmolStr { - function - .signature(db.upcast()) - .resolved_generics - .values() - .fold(String::new(), |acc, param| { - format!("{}_{}", acc, safe_name(db, *param)) - }) - .into() -} - -fn safe_name(db: &dyn CodegenDb, ty: TypeId) -> SmolStr { - match ty.typ(db.upcast()) { - // TODO: Would be nice to get more human friendly names here - Type::Array(_) => format!("array_{:?}", ty.as_intern_id()).into(), - Type::Tuple(_) => format!("tuple_{:?}", ty.as_intern_id()).into(), - _ => format!("{}", ty.display(db.upcast())).into(), - } -} diff --git a/crates/codegen2/src/function.rs b/crates/codegen2/src/function.rs new file mode 100644 index 0000000000..89f91c799e --- /dev/null +++ b/crates/codegen2/src/function.rs @@ -0,0 +1,74 @@ +use std::rc::Rc; + +use fe_mir::ir::{FunctionBody, FunctionId, FunctionSignature}; +use hir::hir_def::TypeId; +use smol_str::SmolStr; + +use crate::CodegenDb; + +#[salsa::tracked(return_ref)] +pub fn legalized_signature(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let mut sig = function.signature(db.upcast()).as_ref().clone(); + db.legalize_func_signature(&mut sig); + sig.into() +} + +#[salsa::tracked(return_ref)] +pub fn legalized_body(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let mut body = function.body(db.upcast()).as_ref().clone(); + db.legalize_func_body(&mut body); + body.into() +} + +#[salsa::tracked(return_ref)] +pub fn symbol_name(db: &dyn CodegenDb, function: FunctionId) -> Rc { + let module = function.signature(db.upcast()).module_id; + let module_name = module.name(db.upcast()); + + let analyzer_func = function.analyzer_func(db.upcast()); + let func_name = format!( + "{}{}", + analyzer_func.name(db.upcast()), + type_suffix(function, db) + ); + + // let func_name = match analyzer_func.sig(db.upcast()).self_item(db.upcast()) { + // Some(Item::Impl(id)) => { + // let class_name = format!( + // "{}${}", + // id.trait_id(db.upcast()).name(db.upcast()), + // safe_name(db, id.receiver(db.upcast())) + // ); + // format!("{class_name}${func_name}") + // } + // Some(class) => { + // let class_name = class.name(db.upcast()); + // format!("{class_name}${func_name}") + // } + // _ => func_name, + // }; + + // format!("{module_name}${func_name}").into() + "".into() +} + +fn type_suffix(function: FunctionId, db: &dyn CodegenDb) -> SmolStr { + function + .signature(db.upcast()) + .resolved_generics + .values() + .fold(String::new(), |acc, param| { + format!("{}_{}", acc, safe_name(db, *param)) + }) + .into() +} + +fn safe_name(db: &dyn CodegenDb, ty: TypeId) -> SmolStr { + // match ty.typ(db.upcast()) { + // // TODO: Would be nice to get more human friendly names here + // Type::Array(_) => format!("array_{:?}", ty.as_intern_id()).into(), + // Type::Tuple(_) => format!("tuple_{:?}", ty.as_intern_id()).into(), + // _ => format!("{}", ty.display(db.upcast())).into(), + // } + "".into() +} diff --git a/crates/codegen2/src/lib.rs b/crates/codegen2/src/lib.rs index 37ec962db2..d1dbe89e9f 100644 --- a/crates/codegen2/src/lib.rs +++ b/crates/codegen2/src/lib.rs @@ -1,2 +1,96 @@ -pub mod db; +use mir::MirDb; + pub mod yul; + +// mod abi; +// mod constant; +// mod contract; +mod function; +// mod types; + +#[salsa::jar(db = CodegenDb)] +pub struct Jar( + function::legalized_signature, + function::legalized_body, + function::symbol_name, + // types::legalized_type, + // abi::abi_type, + // abi::abi_function, + // abi::abi_event, + // abi::abi_contract, + // abi::abi_type_maximum_size, + // abi::abi_type_minimum_size, + // abi::abi_function_argument_maximum_size, + // abi::abi_function_return_maximum_size, + // contract::symbol_name, + // contract::deployer_symbol_name, + // constant::string_symbol_name, +); + +pub trait CodegenDb: salsa::DbWithJar + MirDb { + fn as_hir_db(&self) -> &dyn CodegenDb { + >::as_jar_db::<'_>(self) + } +} +impl CodegenDb for DB where DB: salsa::DbWithJar + MirDb {} + +// #![allow(clippy::arc_with_non_send_sync)] +// use std::rc::Rc; + +// use fe_abi::{contract::AbiContract, event::AbiEvent, function::AbiFunction, types::AbiType}; +// use fe_analyzer::{db::AnalyzerDbStorage, namespace::items::ContractId, AnalyzerDb}; +// use fe_common::db::{SourceDb, SourceDbStorage, Upcast, UpcastMut}; +// use fe_mir::{ +// db::{MirDb, MirDbStorage}, +// ir::{FunctionBody, FunctionId, FunctionSignature, TypeId}, +// }; + +// mod queries; + +// #[salsa::query_group(CodegenDbStorage)] +// pub trait CodegenDb: MirDb + Upcast + UpcastMut { +// } + +// // TODO: Move this to driver. +// #[salsa::database(SourceDbStorage, AnalyzerDbStorage, MirDbStorage, CodegenDbStorage)] +// #[derive(Default)] +// pub struct Db { +// storage: salsa::Storage, +// } +// impl salsa::Database for Db {} + +// impl Upcast for Db { +// fn upcast(&self) -> &(dyn MirDb + 'static) { +// self +// } +// } + +// impl UpcastMut for Db { +// fn upcast_mut(&mut self) -> &mut (dyn MirDb + 'static) { +// &mut *self +// } +// } + +// impl Upcast for Db { +// fn upcast(&self) -> &(dyn SourceDb + 'static) { +// self +// } +// } + +// impl UpcastMut for Db { +// fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) { +// &mut *self +// } +// } + +// impl Upcast for Db { +// fn upcast(&self) -> &(dyn AnalyzerDb + 'static) { +// self +// } +// } + +// impl UpcastMut for Db { +// fn upcast_mut(&mut self) -> &mut (dyn AnalyzerDb + 'static) { +// &mut *self +// } +// } diff --git a/crates/codegen2/src/db/queries/types.rs b/crates/codegen2/src/types.rs similarity index 98% rename from crates/codegen2/src/db/queries/types.rs rename to crates/codegen2/src/types.rs index 5bb4df9883..6789857719 100644 --- a/crates/codegen2/src/db/queries/types.rs +++ b/crates/codegen2/src/types.rs @@ -3,8 +3,7 @@ use fe_mir::ir::{ Type, TypeId, TypeKind, }; -use crate::db::CodegenDb; - +#[salsa::tracked(return_ref)] pub fn legalized_type(db: &dyn CodegenDb, ty: TypeId) -> TypeId { let ty_data = ty.data(db.upcast()); let ty_kind = match &ty.data(db.upcast()).kind { diff --git a/crates/codegen2/src/yul/isel/context.rs b/crates/codegen2/src/yul/isel/context.rs index 4ca2840cc3..be2748e92e 100644 --- a/crates/codegen2/src/yul/isel/context.rs +++ b/crates/codegen2/src/yul/isel/context.rs @@ -1,20 +1,20 @@ +use hir::hir_def::Contract; use indexmap::IndexSet; -use fe_analyzer::namespace::items::ContractId; use fe_mir::ir::FunctionId; use fxhash::FxHashSet; use yultsur::yul; use crate::{ - db::CodegenDb, yul::runtime::{DefaultRuntimeProvider, RuntimeProvider}, + CodegenDb, }; use super::{lower_contract_deployable, lower_function}; pub struct Context { pub runtime: Box, - pub(super) contract_dependency: IndexSet, + pub(super) contract_dependency: IndexSet, pub(super) function_dependency: IndexSet, pub(super) string_constants: IndexSet, pub(super) lowered_functions: FxHashSet, diff --git a/crates/codegen2/src/yul/isel/contract.rs b/crates/codegen2/src/yul/isel/contract.rs index 3703381d8e..e218404d55 100644 --- a/crates/codegen2/src/yul/isel/contract.rs +++ b/crates/codegen2/src/yul/isel/contract.rs @@ -1,15 +1,15 @@ -use fe_analyzer::namespace::items::ContractId; use fe_mir::ir::{function::Linkage, FunctionId}; +use hir::hir_def::Contract; use yultsur::{yul, *}; use crate::{ - db::CodegenDb, yul::{runtime::AbiSrcLocation, YulVariable}, + CodegenDb, }; use super::context::Context; -pub fn lower_contract_deployable(db: &dyn CodegenDb, contract: ContractId) -> yul::Object { +pub fn lower_contract_deployable(db: &dyn CodegenDb, contract: Contract) -> yul::Object { let mut context = Context::default(); let constructor = if let Some(init) = contract.init_function(db.upcast()) { @@ -61,7 +61,7 @@ pub fn lower_contract_deployable(db: &dyn CodegenDb, contract: ContractId) -> yu normalize_object(object) } -pub fn lower_contract(db: &dyn CodegenDb, contract: ContractId) -> yul::Object { +pub fn lower_contract(db: &dyn CodegenDb, contract: Contract) -> yul::Object { let exported_funcs: Vec<_> = db .mir_lower_contract_all_functions(contract) .iter() @@ -210,7 +210,7 @@ fn dispatch_arm(db: &dyn CodegenDb, context: &mut Context, func: FunctionId) -> fn make_init( db: &dyn CodegenDb, context: &mut Context, - contract: ContractId, + contract: Contract, init: FunctionId, ) -> Vec { context.function_dependency.insert(init); @@ -250,7 +250,7 @@ fn make_init( } } -fn make_deploy(db: &dyn CodegenDb, contract: ContractId) -> Vec { +fn make_deploy(db: &dyn CodegenDb, contract: Contract) -> Vec { let contract_symbol = identifier_expression! { (format!{r#""{}""#, db.codegen_contract_symbol_name(contract)}) }; let size = YulVariable::new("$$size"); diff --git a/crates/codegen2/src/yul/isel/function.rs b/crates/codegen2/src/yul/isel/function.rs index 78eaecce2a..b7adb5ae24 100644 --- a/crates/codegen2/src/yul/isel/function.rs +++ b/crates/codegen2/src/yul/isel/function.rs @@ -25,13 +25,13 @@ use yultsur::{ }; use crate::{ - db::CodegenDb, yul::{ isel::inst_order::StructuralInst, runtime::{self, RuntimeProvider}, slot_size::{function_hash_type, yul_primitive_type, SLOT_SIZE}, YulVariable, }, + CodegenDb, }; pub fn lower_function( diff --git a/crates/codegen2/src/yul/isel/test.rs b/crates/codegen2/src/yul/isel/test.rs index 9fe6186933..d4246e4edb 100644 --- a/crates/codegen2/src/yul/isel/test.rs +++ b/crates/codegen2/src/yul/isel/test.rs @@ -1,6 +1,7 @@ +use crate::CodegenDb; + use super::context::Context; -use crate::db::CodegenDb; -use fe_analyzer::namespace::items::FunctionId; +use fe_mir::ir::FunctionId; use yultsur::{yul, *}; pub fn lower_test(db: &dyn CodegenDb, test: FunctionId) -> yul::Object { diff --git a/crates/codegen2/src/yul/legalize/body.rs b/crates/codegen2/src/yul/legalize/body.rs index 5c1b361b1e..d3a556fa61 100644 --- a/crates/codegen2/src/yul/legalize/body.rs +++ b/crates/codegen2/src/yul/legalize/body.rs @@ -5,7 +5,7 @@ use fe_mir::ir::{ FunctionBody, Inst, InstId, TypeId, TypeKind, Value, ValueId, }; -use crate::db::CodegenDb; +use crate::CodegenDb; use super::critical_edge::CriticalEdgeSplitter; diff --git a/crates/codegen2/src/yul/legalize/signature.rs b/crates/codegen2/src/yul/legalize/signature.rs index 134bc10ae6..ff032d751b 100644 --- a/crates/codegen2/src/yul/legalize/signature.rs +++ b/crates/codegen2/src/yul/legalize/signature.rs @@ -1,6 +1,6 @@ use fe_mir::ir::{FunctionSignature, TypeKind}; -use crate::db::CodegenDb; +use crate::CodegenDb; pub fn legalize_func_signature(db: &dyn CodegenDb, sig: &mut FunctionSignature) { // Remove param if the type is contract or zero-sized. diff --git a/crates/codegen2/src/yul/runtime/abi.rs b/crates/codegen2/src/yul/runtime/abi.rs index 565465a18e..68def8d5e2 100644 --- a/crates/codegen2/src/yul/runtime/abi.rs +++ b/crates/codegen2/src/yul/runtime/abi.rs @@ -1,10 +1,10 @@ use crate::{ - db::CodegenDb, yul::{ runtime::{error_revert_numeric, make_ptr}, slot_size::{yul_primitive_type, SLOT_SIZE}, YulVariable, }, + CodegenDb, }; use super::{AbiSrcLocation, DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; diff --git a/crates/codegen2/src/yul/runtime/contract.rs b/crates/codegen2/src/yul/runtime/contract.rs index d85321b377..194679dcba 100644 --- a/crates/codegen2/src/yul/runtime/contract.rs +++ b/crates/codegen2/src/yul/runtime/contract.rs @@ -1,20 +1,20 @@ use crate::{ - db::CodegenDb, yul::{runtime::AbiSrcLocation, YulVariable}, + CodegenDb, }; use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; -use fe_analyzer::namespace::items::ContractId; use fe_mir::ir::{FunctionId, Type, TypeKind}; +use hir::hir_def::Contract; use yultsur::*; pub(super) fn make_create( provider: &mut DefaultRuntimeProvider, db: &dyn CodegenDb, func_name: &str, - contract: ContractId, + contract: Contract, ) -> RuntimeFunction { let func_name = YulVariable::new(func_name); let contract_symbol = literal_expression! { @@ -40,7 +40,7 @@ pub(super) fn make_create2( provider: &mut DefaultRuntimeProvider, db: &dyn CodegenDb, func_name: &str, - contract: ContractId, + contract: Contract, ) -> RuntimeFunction { let func_name = YulVariable::new(func_name); let contract_symbol = literal_expression! { diff --git a/crates/codegen2/src/yul/runtime/data.rs b/crates/codegen2/src/yul/runtime/data.rs index 02509d7129..85eccd5704 100644 --- a/crates/codegen2/src/yul/runtime/data.rs +++ b/crates/codegen2/src/yul/runtime/data.rs @@ -1,10 +1,10 @@ use crate::{ - db::CodegenDb, yul::{ runtime::{make_ptr, BitMask}, slot_size::{yul_primitive_type, SLOT_SIZE}, YulVariable, }, + CodegenDb, }; use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; diff --git a/crates/codegen2/src/yul/runtime/emit.rs b/crates/codegen2/src/yul/runtime/emit.rs index 7e1f8ab8c6..cfe0920ee4 100644 --- a/crates/codegen2/src/yul/runtime/emit.rs +++ b/crates/codegen2/src/yul/runtime/emit.rs @@ -1,6 +1,6 @@ use crate::{ - db::CodegenDb, yul::{runtime::make_ptr, slot_size::SLOT_SIZE, YulVariable}, + CodegenDb, }; use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; diff --git a/crates/codegen2/src/yul/runtime/mod.rs b/crates/codegen2/src/yul/runtime/mod.rs index a658b4b4aa..e5b7856d91 100644 --- a/crates/codegen2/src/yul/runtime/mod.rs +++ b/crates/codegen2/src/yul/runtime/mod.rs @@ -8,14 +8,14 @@ mod safe_math; use std::fmt::Write; use fe_abi::types::AbiType; -use fe_analyzer::namespace::items::ContractId; use fe_mir::ir::{types::ArrayDef, FunctionId, TypeId, TypeKind}; +use hir::hir_def::Contract; use indexmap::IndexMap; use yultsur::*; use num_bigint::BigInt; -use crate::{db::CodegenDb, yul::slot_size::SLOT_SIZE}; +use crate::{yul::slot_size::SLOT_SIZE, CodegenDb}; use super::slot_size::yul_primitive_type; @@ -29,14 +29,14 @@ pub trait RuntimeProvider { fn create( &mut self, db: &dyn CodegenDb, - contract: ContractId, + contract: Contract, value: yul::Expression, ) -> yul::Expression; fn create2( &mut self, db: &dyn CodegenDb, - contract: ContractId, + contract: Contract, value: yul::Expression, salt: yul::Expression, ) -> yul::Expression; @@ -272,7 +272,7 @@ impl RuntimeProvider for DefaultRuntimeProvider { fn create( &mut self, db: &dyn CodegenDb, - contract: ContractId, + contract: Contract, value: yul::Expression, ) -> yul::Expression { let name = format!("$create_{}", db.codegen_contract_symbol_name(contract)); @@ -285,7 +285,7 @@ impl RuntimeProvider for DefaultRuntimeProvider { fn create2( &mut self, db: &dyn CodegenDb, - contract: ContractId, + contract: Contract, value: yul::Expression, salt: yul::Expression, ) -> yul::Expression { diff --git a/crates/codegen2/src/yul/runtime/revert.rs b/crates/codegen2/src/yul/runtime/revert.rs index 1f8fa9eb7b..396e07e76d 100644 --- a/crates/codegen2/src/yul/runtime/revert.rs +++ b/crates/codegen2/src/yul/runtime/revert.rs @@ -1,6 +1,6 @@ use crate::{ - db::CodegenDb, yul::{slot_size::function_hash_type, YulVariable}, + CodegenDb, }; use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; diff --git a/crates/codegen2/src/yul/runtime/safe_math.rs b/crates/codegen2/src/yul/runtime/safe_math.rs index 5bdfe8b04a..12eb87b1f8 100644 --- a/crates/codegen2/src/yul/runtime/safe_math.rs +++ b/crates/codegen2/src/yul/runtime/safe_math.rs @@ -1,4 +1,4 @@ -use crate::{db::CodegenDb, yul::YulVariable}; +use crate::{yul::YulVariable, CodegenDb}; use super::{DefaultRuntimeProvider, RuntimeFunction, RuntimeProvider}; diff --git a/crates/codegen2/src/yul/slot_size.rs b/crates/codegen2/src/yul/slot_size.rs index aa931132c0..fdf4964b09 100644 --- a/crates/codegen2/src/yul/slot_size.rs +++ b/crates/codegen2/src/yul/slot_size.rs @@ -1,6 +1,6 @@ use fe_mir::ir::{Type, TypeId, TypeKind}; -use crate::db::CodegenDb; +use crate::CodegenDb; // We use the same slot size between memory and storage to simplify the // implementation and minimize gas consumption in memory <-> storage copy diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs index 8b13789179..8382f804c1 100644 --- a/crates/mir2/src/lower/mod.rs +++ b/crates/mir2/src/lower/mod.rs @@ -1 +1 @@ - +use hir::hir_def::Contract; diff --git a/crates/mir2/tests/lowering.rs b/crates/mir2/tests/lowering.rs index 6063e7f4ea..be4354be3f 100644 --- a/crates/mir2/tests/lowering.rs +++ b/crates/mir2/tests/lowering.rs @@ -32,11 +32,11 @@ fn mir_lower_std_lib() { let top_mod = db.new_std_lib(); let mut pm = initialize_analysis_pass(&db); - let diags = pm.run_on_module(std_ingot.root_mod(&db)); + let diags = pm.run_on_module(top_mod); - // if !diags.is_empty() { - // panic!("std lib analysis failed") - // } + if !diags.is_empty() { + panic!("std lib analysis failed") + } // for &module in std_ingot.all_modules(&db).iter() { // for func in db.mir_lower_module_all_functions(module).iter() { diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs index 7c2bdb5623..1947c68dc5 100644 --- a/crates/mir2/tests/test_db.rs +++ b/crates/mir2/tests/test_db.rs @@ -51,7 +51,7 @@ impl LowerMirTestDb { ingot.set_root_file(self, root); ingot.set_files(self, [root].into()); - let top_mod = lower::map_file_to_mod(self, input_file); + // let top_mod = lower::map_file_to_mod(self, input_file); // let mut prop_formatter = HirPropertyFormatter::default(); // let top_mod = self.register_file(&mut prop_formatter, root); @@ -59,8 +59,9 @@ impl LowerMirTestDb { // top_mod } - pub fn new_std_lib(&mut self) { - library::std_lib_input_ingot(self); + pub fn new_std_lib(&mut self) -> TopLevelMod { + let input = library::std_lib_input_ingot(self); + lower::map_file_to_mod(self, input_file) } fn register_file(&self, input_file: InputFile) { From 58c620bb685948cd5b9913fb20a6b4ad7ce58ef2 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Wed, 24 Jan 2024 22:21:39 -0700 Subject: [PATCH 14/22] hacking --- Cargo.lock | 2 -- crates/codegen2/Cargo.toml | 2 -- 2 files changed, 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fbbc1935cf..22e8567e48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -937,10 +937,8 @@ name = "fe-codegen2" version = "0.23.0" dependencies = [ "fe-abi", - "fe-common", "fe-hir", "fe-hir-analysis", - "fe-mir", "fe-mir2", "fxhash", "indexmap", diff --git a/crates/codegen2/Cargo.toml b/crates/codegen2/Cargo.toml index ad8af1b77e..8d18bca5e7 100644 --- a/crates/codegen2/Cargo.toml +++ b/crates/codegen2/Cargo.toml @@ -8,8 +8,6 @@ edition = "2021" hir-analysis = { path = "../hir-analysis", package = "fe-hir-analysis" } hir = { path = "../hir", package = "fe-hir" } mir = { path = "../mir2", package = "fe-mir2" } -fe-mir = { path = "../mir", version = "^0.23.0"} -fe-common = { path = "../common", version = "^0.23.0"} fe-abi = { path = "../abi", version = "^0.23.0"} salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } num-bigint = "0.4.3" From 7ecb20e8d0daeb75dbed573696fe11d7e9303a89 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Fri, 26 Jan 2024 16:30:59 -0700 Subject: [PATCH 15/22] hacking --- crates/mir2/src/ir/basic_block.rs | 10 ++++------ crates/mir2/src/lib.rs | 2 -- crates/mir2/tests/test_db.rs | 3 ++- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/crates/mir2/src/ir/basic_block.rs b/crates/mir2/src/ir/basic_block.rs index 73b2eab8b7..359c4c76f6 100644 --- a/crates/mir2/src/ir/basic_block.rs +++ b/crates/mir2/src/ir/basic_block.rs @@ -1,8 +1,6 @@ -#[salsa::interned] -pub struct BasicBlockId { - #[return_ref] - pub data: BasicBlock, -} +use id_arena::Id; -#[salsa::tracked] +pub type BasicBlockId = Id; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct BasicBlock {} diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 562d59cad8..550fdb9262 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -9,8 +9,6 @@ mod lower; #[salsa::jar(db = MirDb)] pub struct Jar( - ir::BasicBlock, - ir::BasicBlockId, // ir::Constant, // ir::ConstantId, // ir::FunctionBody, diff --git a/crates/mir2/tests/test_db.rs b/crates/mir2/tests/test_db.rs index 1947c68dc5..724c0a1d2b 100644 --- a/crates/mir2/tests/test_db.rs +++ b/crates/mir2/tests/test_db.rs @@ -61,7 +61,8 @@ impl LowerMirTestDb { pub fn new_std_lib(&mut self) -> TopLevelMod { let input = library::std_lib_input_ingot(self); - lower::map_file_to_mod(self, input_file) + panic!(""); + // lower::map_file_to_mod(self, input_file) } fn register_file(&self, input_file: InputFile) { From 162bda3f310d132111450ab01c5a85cafe9919c6 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Tue, 30 Jan 2024 16:50:54 -0700 Subject: [PATCH 16/22] hacking --- crates/mir2/src/ir/constant.rs | 43 +++++++++++++++++++--------------- crates/mir2/src/ir/mod.rs | 4 ++-- crates/mir2/src/lib.rs | 2 +- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/crates/mir2/src/ir/constant.rs b/crates/mir2/src/ir/constant.rs index c9eac33ccf..5180efa0b8 100644 --- a/crates/mir2/src/ir/constant.rs +++ b/crates/mir2/src/ir/constant.rs @@ -1,8 +1,14 @@ -use fe_common2::impl_intern_key; +use hir::hir_def::{ModuleTreeNodeId, TypeId}; use num_bigint::BigInt; use smol_str::SmolStr; -use super::{SourceInfo, TypeId}; +// use super::SourceInfo; + +#[salsa::interned] +pub struct ConstantId { + #[return_ref] + pub data: Constant, +} #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Constant { @@ -16,16 +22,15 @@ pub struct Constant { pub ty: TypeId, /// A module where a constant is declared. - pub module_id: analyzer_items::ModuleId, - - /// A span where a constant is declared. - pub source: SourceInfo, + pub module_id: ModuleTreeNodeId, + // /// A span where a constant is declared. + // pub source: SourceInfo, } -/// An interned Id for [`Constant`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ConstantId(pub(crate) u32); -impl_intern_key!(ConstantId); +// /// An interned Id for [`Constant`]. +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct ConstantId(pub(crate) u32); +// impl_intern_key!(ConstantId); #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ConstantValue { @@ -34,12 +39,12 @@ pub enum ConstantValue { Bool(bool), } -impl From for ConstantValue { - fn from(value: context::Constant) -> Self { - match value { - context::Constant::Int(num) | context::Constant::Address(num) => Self::Immediate(num), - context::Constant::Str(s) => Self::Str(s), - context::Constant::Bool(b) => Self::Bool(b), - } - } -} +// impl From for ConstantValue { +// fn from(value: context::Constant) -> Self { +// match value { +// context::Constant::Int(num) | context::Constant::Address(num) => Self::Immediate(num), +// context::Constant::Str(s) => Self::Str(s), +// context::Constant::Bool(b) => Self::Bool(b), +// } +// } +// } diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs index dee098c765..6193b61250 100644 --- a/crates/mir2/src/ir/mod.rs +++ b/crates/mir2/src/ir/mod.rs @@ -2,14 +2,14 @@ pub mod basic_block; // pub mod body_builder; // pub mod body_cursor; // pub mod body_order; -// pub mod constant; +pub mod constant; // pub mod function; // pub mod inst; // pub mod types; // pub mod value; pub use basic_block::{BasicBlock, BasicBlockId}; -// pub use constant::{Constant, ConstantId}; +pub use constant::{Constant, ConstantId}; // pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; // pub use inst::{Inst, InstId}; // pub use types::{Type, TypeId, TypeKind}; diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 550fdb9262..8fc03c0d0a 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -10,7 +10,7 @@ mod lower; #[salsa::jar(db = MirDb)] pub struct Jar( // ir::Constant, - // ir::ConstantId, + ir::ConstantId, // ir::FunctionBody, // ir::FunctionId, // ir::FunctionParam, From f29516a16e788e7c57076a6001fd053041e2f629 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Wed, 31 Jan 2024 17:36:33 -0700 Subject: [PATCH 17/22] hacking --- crates/mir2/src/ir/body_builder.rs | 121 ++++++++++------ crates/mir2/src/ir/function.rs | 20 +-- crates/mir2/src/ir/inst.rs | 201 ++++++++++++++------------ crates/mir2/src/ir/mod.rs | 18 +-- crates/mir2/src/ir/value.rs | 58 ++++---- crates/mir2/src/lib.rs | 2 +- crates/mir2/src/pretty_print/inst.rs | 2 +- crates/mir2/src/pretty_print/mod.rs | 2 +- crates/mir2/src/pretty_print/types.rs | 7 +- crates/mir2/src/pretty_print/value.rs | 2 +- 10 files changed, 236 insertions(+), 197 deletions(-) diff --git a/crates/mir2/src/ir/body_builder.rs b/crates/mir2/src/ir/body_builder.rs index 622a24af4f..602ccdf496 100644 --- a/crates/mir2/src/ir/body_builder.rs +++ b/crates/mir2/src/ir/body_builder.rs @@ -1,10 +1,11 @@ +use hir::hir_def::TypeId; use num_bigint::BigInt; use crate::ir::{ body_cursor::{BodyCursor, CursorLocation}, inst::{BinOp, Inst, InstKind, UnOp}, value::{AssignableValue, Local}, - BasicBlock, BasicBlockId, FunctionBody, FunctionId, InstId, SourceInfo, TypeId, + BasicBlock, BasicBlockId, FunctionBody, FunctionId, InstId, }; use super::{ @@ -20,8 +21,10 @@ pub struct BodyBuilder { macro_rules! impl_unary_inst { ($name:ident, $code:path) => { - pub fn $name(&mut self, value: ValueId, source: SourceInfo) -> InstId { - let inst = Inst::unary($code, value, source); + // pub fn $name(&mut self, value: ValueId, source: SourceInfo) -> InstId { + pub fn $name(&mut self, value: ValueId) -> InstId { + // let inst = Inst::unary($code, value, source); + let inst = Inst::unary($code, value); self.insert_inst(inst) } }; @@ -29,16 +32,20 @@ macro_rules! impl_unary_inst { macro_rules! impl_binary_inst { ($name:ident, $code:path) => { - pub fn $name(&mut self, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> InstId { - let inst = Inst::binary($code, lhs, rhs, source); + // pub fn $name(&mut self, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> InstId { + pub fn $name(&mut self, lhs: ValueId, rhs: ValueId) -> InstId { + // let inst = Inst::binary($code, lhs, rhs, source); + let inst = Inst::binary($code, lhs, rhs); self.insert_inst(inst) } }; } impl BodyBuilder { - pub fn new(fid: FunctionId, source: SourceInfo) -> Self { - let body = FunctionBody::new(fid, source); + // pub fn new(fid: FunctionId, source: SourceInfo) -> Self { + pub fn new(fid: FunctionId) -> Self { + // let body = FunctionBody::new(fid, source); + let body = FunctionBody::new(fid); let entry_block = body.order.entry(); Self { body, @@ -104,11 +111,12 @@ impl BodyBuilder { } pub fn declare(&mut self, local: Local) -> ValueId { - let source = local.source.clone(); + // let source = local.source.clone(); let local_id = self.body.store.store_value(Value::Local(local)); let kind = InstKind::Declare { local: local_id }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst); local_id } @@ -145,24 +153,27 @@ impl BodyBuilder { &mut self, value: ValueId, result_ty: TypeId, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { let kind = InstKind::Cast { kind: CastKind::Primitive, value, to: result_ty, }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn untag_cast(&mut self, value: ValueId, result_ty: TypeId, source: SourceInfo) -> InstId { + // pub fn untag_cast(&mut self, value: ValueId, result_ty: TypeId, source: SourceInfo) -> InstId { + pub fn untag_cast(&mut self, value: ValueId, result_ty: TypeId) -> InstId { let kind = InstKind::Cast { kind: CastKind::Untag, value, to: result_ty, }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } @@ -170,28 +181,35 @@ impl BodyBuilder { &mut self, ty: TypeId, args: Vec, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { let kind = InstKind::AggregateConstruct { ty, args }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn bind(&mut self, src: ValueId, source: SourceInfo) -> InstId { + // pub fn bind(&mut self, src: ValueId, source: SourceInfo) -> InstId { + pub fn bind(&mut self, src: ValueId) -> InstId { let kind = InstKind::Bind { src }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn mem_copy(&mut self, src: ValueId, source: SourceInfo) -> InstId { + // pub fn mem_copy(&mut self, src: ValueId, source: SourceInfo) -> InstId { + pub fn mem_copy(&mut self, src: ValueId) -> InstId { let kind = InstKind::MemCopy { src }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn load(&mut self, src: ValueId, source: SourceInfo) -> InstId { + // pub fn load(&mut self, src: ValueId, source: SourceInfo) -> InstId { + pub fn load(&mut self, src: ValueId) -> InstId { let kind = InstKind::Load { src }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } @@ -199,16 +217,19 @@ impl BodyBuilder { &mut self, value: ValueId, indices: Vec, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { let kind = InstKind::AggregateAccess { value, indices }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn map_access(&mut self, value: ValueId, key: ValueId, source: SourceInfo) -> InstId { + // pub fn map_access(&mut self, value: ValueId, key: ValueId, source: SourceInfo) -> InstId { + pub fn map_access(&mut self, value: ValueId, key: ValueId) -> InstId { let kind = InstKind::MapAccess { value, key }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } @@ -217,14 +238,15 @@ impl BodyBuilder { func: FunctionId, args: Vec, call_type: CallType, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { let kind = InstKind::Call { func, args, call_type, }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } @@ -232,15 +254,18 @@ impl BodyBuilder { &mut self, op: YulIntrinsicOp, args: Vec, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { - let inst = Inst::intrinsic(op, args, source); + // let inst = Inst::intrinsic(op, args, source); + let inst = Inst::intrinsic(op, args); self.insert_inst(inst) } - pub fn jump(&mut self, dest: BasicBlockId, source: SourceInfo) -> InstId { + // pub fn jump(&mut self, dest: BasicBlockId, source: SourceInfo) -> InstId { + pub fn jump(&mut self, dest: BasicBlockId) -> InstId { let kind = InstKind::Jump { dest }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } @@ -249,10 +274,11 @@ impl BodyBuilder { cond: ValueId, then: BasicBlockId, else_: BasicBlockId, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { let kind = InstKind::Branch { cond, then, else_ }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } @@ -261,38 +287,47 @@ impl BodyBuilder { disc: ValueId, table: SwitchTable, default: Option, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { let kind = InstKind::Switch { disc, table, default, }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn revert(&mut self, arg: Option, source: SourceInfo) -> InstId { + // pub fn revert(&mut self, arg: Option, source: SourceInfo) -> InstId { + pub fn revert(&mut self, arg: Option) -> InstId { let kind = InstKind::Revert { arg }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn emit(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + // pub fn emit(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + pub fn emit(&mut self, arg: ValueId) -> InstId { let kind = InstKind::Emit { arg }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn ret(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + // pub fn ret(&mut self, arg: ValueId, source: SourceInfo) -> InstId { + pub fn ret(&mut self, arg: ValueId) -> InstId { let kind = InstKind::Return { arg: arg.into() }; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } - pub fn nop(&mut self, source: SourceInfo) -> InstId { + // pub fn nop(&mut self, source: SourceInfo) -> InstId { + pub fn nop(&mut self) -> InstId { let kind = InstKind::Nop; - let inst = Inst::new(kind, source); + // let inst = Inst::new(kind, source); + let inst = Inst::new(kind); self.insert_inst(inst) } diff --git a/crates/mir2/src/ir/function.rs b/crates/mir2/src/ir/function.rs index 73aed03e6e..89eeabac1e 100644 --- a/crates/mir2/src/ir/function.rs +++ b/crates/mir2/src/ir/function.rs @@ -1,4 +1,5 @@ use fxhash::FxHashMap; +use hir::hir_def::{ModuleTreeNodeId, TypeId}; use id_arena::Arena; use num_bigint::BigInt; use smol_str::SmolStr; @@ -11,17 +12,16 @@ use super::{ // types::TypeId, value::{AssignableValue, Local, Value, ValueId}, BasicBlockId, - SourceInfo, }; /// Represents function signature. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct FunctionSignature { pub params: Vec, - pub resolved_generics: BTreeMap, + pub resolved_generics: BTreeMap, pub return_type: Option, - pub module_id: analyzer_items::ModuleId, - pub analyzer_func_id: analyzer_items::FunctionId, + pub module_id: ModuleTreeNodeId, + pub analyzer_func_id: FunctionId, pub linkage: Linkage, } @@ -29,12 +29,12 @@ pub struct FunctionSignature { pub struct FunctionParam { pub name: SmolStr, pub ty: TypeId, - pub source: SourceInfo, + // pub source: SourceInfo, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct FunctionId(pub u32); -impl_intern_key!(FunctionId); +// impl_intern_key!(FunctionId); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Linkage { @@ -66,19 +66,19 @@ pub struct FunctionBody { /// Tracks order of basic blocks and instructions in a function body. pub order: BodyOrder, - - pub source: SourceInfo, + // pub source: SourceInfo, } impl FunctionBody { - pub fn new(fid: FunctionId, source: SourceInfo) -> Self { + pub fn new(fid: FunctionId) -> Self { + // pub fn new(fid: FunctionId, source: SourceInfo) -> Self { let mut store = BodyDataStore::default(); let entry_bb = store.store_block(BasicBlock {}); Self { fid, store, order: BodyOrder::new(entry_bb), - source, + // source, } } } diff --git a/crates/mir2/src/ir/inst.rs b/crates/mir2/src/ir/inst.rs index c86f2bc8fb..9925d20ded 100644 --- a/crates/mir2/src/ir/inst.rs +++ b/crates/mir2/src/ir/inst.rs @@ -1,15 +1,16 @@ use std::fmt; +use hir::hir_def::{Contract, TypeId}; use id_arena::Id; -use super::{basic_block::BasicBlockId, function::FunctionId, value::ValueId, SourceInfo, TypeId}; +use super::{basic_block::BasicBlockId, function::FunctionId, value::ValueId}; pub type InstId = Id; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Inst { pub kind: InstKind, - pub source: SourceInfo, + // pub source: SourceInfo, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -125,13 +126,13 @@ pub enum InstKind { Create { value: ValueId, - contract: ContractId, + contract: Contract, }, Create2 { value: ValueId, salt: ValueId, - contract: ContractId, + contract: Contract, }, YulIntrinsic { @@ -168,29 +169,37 @@ impl SwitchTable { } impl Inst { - pub fn new(kind: InstKind, source: SourceInfo) -> Self { - Self { kind, source } + // pub fn new(kind: InstKind, source: SourceInfo) -> Self { + pub fn new(kind: InstKind) -> Self { + // Self { kind, source } + Self { kind } } - pub fn unary(op: UnOp, value: ValueId, source: SourceInfo) -> Self { + // pub fn unary(op: UnOp, value: ValueId, source: SourceInfo) -> Self { + pub fn unary(op: UnOp, value: ValueId) -> Self { let kind = InstKind::Unary { op, value }; - Self::new(kind, source) + // Self::new(kind, source) + Self::new(kind) } - pub fn binary(op: BinOp, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> Self { + // pub fn binary(op: BinOp, lhs: ValueId, rhs: ValueId, source: SourceInfo) -> Self { + pub fn binary(op: BinOp, lhs: ValueId, rhs: ValueId) -> Self { let kind = InstKind::Binary { op, lhs, rhs }; - Self::new(kind, source) + // Self::new(kind, source) + Self::new(kind) } - pub fn intrinsic(op: YulIntrinsicOp, args: Vec, source: SourceInfo) -> Self { + // pub fn intrinsic(op: YulIntrinsicOp, args: Vec, source: SourceInfo) -> Self { + pub fn intrinsic(op: YulIntrinsicOp, args: Vec) -> Self { let kind = InstKind::YulIntrinsic { op, args }; - Self::new(kind, source) + // Self::new(kind, source) + Self::new(kind) } pub fn nop() -> Self { Self { kind: InstKind::Nop, - source: SourceInfo::dummy(), + // source: SourceInfo::dummy(), } } @@ -575,89 +584,89 @@ impl fmt::Display for YulIntrinsicOp { } } -impl From for YulIntrinsicOp { - fn from(val: fe_analyzer2::builtins::Intrinsic) -> Self { - use fe_analyzer2::builtins::Intrinsic; - match val { - Intrinsic::__stop => Self::Stop, - Intrinsic::__add => Self::Add, - Intrinsic::__sub => Self::Sub, - Intrinsic::__mul => Self::Mul, - Intrinsic::__div => Self::Div, - Intrinsic::__sdiv => Self::Sdiv, - Intrinsic::__mod => Self::Mod, - Intrinsic::__smod => Self::Smod, - Intrinsic::__exp => Self::Exp, - Intrinsic::__not => Self::Not, - Intrinsic::__lt => Self::Lt, - Intrinsic::__gt => Self::Gt, - Intrinsic::__slt => Self::Slt, - Intrinsic::__sgt => Self::Sgt, - Intrinsic::__eq => Self::Eq, - Intrinsic::__iszero => Self::Iszero, - Intrinsic::__and => Self::And, - Intrinsic::__or => Self::Or, - Intrinsic::__xor => Self::Xor, - Intrinsic::__byte => Self::Byte, - Intrinsic::__shl => Self::Shl, - Intrinsic::__shr => Self::Shr, - Intrinsic::__sar => Self::Sar, - Intrinsic::__addmod => Self::Addmod, - Intrinsic::__mulmod => Self::Mulmod, - Intrinsic::__signextend => Self::Signextend, - Intrinsic::__keccak256 => Self::Keccak256, - Intrinsic::__pc => Self::Pc, - Intrinsic::__pop => Self::Pop, - Intrinsic::__mload => Self::Mload, - Intrinsic::__mstore => Self::Mstore, - Intrinsic::__mstore8 => Self::Mstore8, - Intrinsic::__sload => Self::Sload, - Intrinsic::__sstore => Self::Sstore, - Intrinsic::__msize => Self::Msize, - Intrinsic::__gas => Self::Gas, - Intrinsic::__address => Self::Address, - Intrinsic::__balance => Self::Balance, - Intrinsic::__selfbalance => Self::Selfbalance, - Intrinsic::__caller => Self::Caller, - Intrinsic::__callvalue => Self::Callvalue, - Intrinsic::__calldataload => Self::Calldataload, - Intrinsic::__calldatasize => Self::Calldatasize, - Intrinsic::__calldatacopy => Self::Calldatacopy, - Intrinsic::__codesize => Self::Codesize, - Intrinsic::__codecopy => Self::Codecopy, - Intrinsic::__extcodesize => Self::Extcodesize, - Intrinsic::__extcodecopy => Self::Extcodecopy, - Intrinsic::__returndatasize => Self::Returndatasize, - Intrinsic::__returndatacopy => Self::Returndatacopy, - Intrinsic::__extcodehash => Self::Extcodehash, - Intrinsic::__create => Self::Create, - Intrinsic::__create2 => Self::Create2, - Intrinsic::__call => Self::Call, - Intrinsic::__callcode => Self::Callcode, - Intrinsic::__delegatecall => Self::Delegatecall, - Intrinsic::__staticcall => Self::Staticcall, - Intrinsic::__return => Self::Return, - Intrinsic::__revert => Self::Revert, - Intrinsic::__selfdestruct => Self::Selfdestruct, - Intrinsic::__invalid => Self::Invalid, - Intrinsic::__log0 => Self::Log0, - Intrinsic::__log1 => Self::Log1, - Intrinsic::__log2 => Self::Log2, - Intrinsic::__log3 => Self::Log3, - Intrinsic::__log4 => Self::Log4, - Intrinsic::__chainid => Self::Chainid, - Intrinsic::__basefee => Self::Basefee, - Intrinsic::__origin => Self::Origin, - Intrinsic::__gasprice => Self::Gasprice, - Intrinsic::__blockhash => Self::Blockhash, - Intrinsic::__coinbase => Self::Coinbase, - Intrinsic::__timestamp => Self::Timestamp, - Intrinsic::__number => Self::Number, - Intrinsic::__prevrandao => Self::Prevrandao, - Intrinsic::__gaslimit => Self::Gaslimit, - } - } -} +// impl From for YulIntrinsicOp { +// fn from(val: fe_analyzer2::builtins::Intrinsic) -> Self { +// use fe_analyzer2::builtins::Intrinsic; +// match val { +// Intrinsic::__stop => Self::Stop, +// Intrinsic::__add => Self::Add, +// Intrinsic::__sub => Self::Sub, +// Intrinsic::__mul => Self::Mul, +// Intrinsic::__div => Self::Div, +// Intrinsic::__sdiv => Self::Sdiv, +// Intrinsic::__mod => Self::Mod, +// Intrinsic::__smod => Self::Smod, +// Intrinsic::__exp => Self::Exp, +// Intrinsic::__not => Self::Not, +// Intrinsic::__lt => Self::Lt, +// Intrinsic::__gt => Self::Gt, +// Intrinsic::__slt => Self::Slt, +// Intrinsic::__sgt => Self::Sgt, +// Intrinsic::__eq => Self::Eq, +// Intrinsic::__iszero => Self::Iszero, +// Intrinsic::__and => Self::And, +// Intrinsic::__or => Self::Or, +// Intrinsic::__xor => Self::Xor, +// Intrinsic::__byte => Self::Byte, +// Intrinsic::__shl => Self::Shl, +// Intrinsic::__shr => Self::Shr, +// Intrinsic::__sar => Self::Sar, +// Intrinsic::__addmod => Self::Addmod, +// Intrinsic::__mulmod => Self::Mulmod, +// Intrinsic::__signextend => Self::Signextend, +// Intrinsic::__keccak256 => Self::Keccak256, +// Intrinsic::__pc => Self::Pc, +// Intrinsic::__pop => Self::Pop, +// Intrinsic::__mload => Self::Mload, +// Intrinsic::__mstore => Self::Mstore, +// Intrinsic::__mstore8 => Self::Mstore8, +// Intrinsic::__sload => Self::Sload, +// Intrinsic::__sstore => Self::Sstore, +// Intrinsic::__msize => Self::Msize, +// Intrinsic::__gas => Self::Gas, +// Intrinsic::__address => Self::Address, +// Intrinsic::__balance => Self::Balance, +// Intrinsic::__selfbalance => Self::Selfbalance, +// Intrinsic::__caller => Self::Caller, +// Intrinsic::__callvalue => Self::Callvalue, +// Intrinsic::__calldataload => Self::Calldataload, +// Intrinsic::__calldatasize => Self::Calldatasize, +// Intrinsic::__calldatacopy => Self::Calldatacopy, +// Intrinsic::__codesize => Self::Codesize, +// Intrinsic::__codecopy => Self::Codecopy, +// Intrinsic::__extcodesize => Self::Extcodesize, +// Intrinsic::__extcodecopy => Self::Extcodecopy, +// Intrinsic::__returndatasize => Self::Returndatasize, +// Intrinsic::__returndatacopy => Self::Returndatacopy, +// Intrinsic::__extcodehash => Self::Extcodehash, +// Intrinsic::__create => Self::Create, +// Intrinsic::__create2 => Self::Create2, +// Intrinsic::__call => Self::Call, +// Intrinsic::__callcode => Self::Callcode, +// Intrinsic::__delegatecall => Self::Delegatecall, +// Intrinsic::__staticcall => Self::Staticcall, +// Intrinsic::__return => Self::Return, +// Intrinsic::__revert => Self::Revert, +// Intrinsic::__selfdestruct => Self::Selfdestruct, +// Intrinsic::__invalid => Self::Invalid, +// Intrinsic::__log0 => Self::Log0, +// Intrinsic::__log1 => Self::Log1, +// Intrinsic::__log2 => Self::Log2, +// Intrinsic::__log3 => Self::Log3, +// Intrinsic::__log4 => Self::Log4, +// Intrinsic::__chainid => Self::Chainid, +// Intrinsic::__basefee => Self::Basefee, +// Intrinsic::__origin => Self::Origin, +// Intrinsic::__gasprice => Self::Gasprice, +// Intrinsic::__blockhash => Self::Blockhash, +// Intrinsic::__coinbase => Self::Coinbase, +// Intrinsic::__timestamp => Self::Timestamp, +// Intrinsic::__number => Self::Number, +// Intrinsic::__prevrandao => Self::Prevrandao, +// Intrinsic::__gaslimit => Self::Gaslimit, +// } +// } +// } pub enum BranchInfo<'a> { NotBranch, diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs index 6193b61250..d6fef7916a 100644 --- a/crates/mir2/src/ir/mod.rs +++ b/crates/mir2/src/ir/mod.rs @@ -1,19 +1,19 @@ pub mod basic_block; -// pub mod body_builder; -// pub mod body_cursor; -// pub mod body_order; +pub mod body_builder; +pub mod body_cursor; +pub mod body_order; pub mod constant; -// pub mod function; -// pub mod inst; +pub mod function; +pub mod inst; // pub mod types; -// pub mod value; +pub mod value; pub use basic_block::{BasicBlock, BasicBlockId}; pub use constant::{Constant, ConstantId}; -// pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; -// pub use inst::{Inst, InstId}; +pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; +pub use inst::{Inst, InstId}; // pub use types::{Type, TypeId, TypeKind}; -// pub use value::{Value, ValueId}; +pub use value::{Value, ValueId}; // /// An original source information that indicates where `mir` entities derive // /// from. `SourceInfo` is mainly used for diagnostics. diff --git a/crates/mir2/src/ir/value.rs b/crates/mir2/src/ir/value.rs index f4aad28b63..26ed4063a1 100644 --- a/crates/mir2/src/ir/value.rs +++ b/crates/mir2/src/ir/value.rs @@ -1,16 +1,11 @@ +use hir::hir_def::{TypeId, TypeKind}; use id_arena::Id; use num_bigint::BigInt; use smol_str::SmolStr; -use crate::db::MirDb; +// use crate::db::MirDb; -use super::{ - constant::ConstantId, - function::BodyDataStore, - inst::InstId, - types::{TypeId, TypeKind}, - SourceInfo, -}; +use super::{constant::ConstantId, function::BodyDataStore, inst::InstId}; pub type ValueId = Id; @@ -68,22 +63,22 @@ impl From for AssignableValue { } impl AssignableValue { - pub fn ty(&self, db: &dyn MirDb, store: &BodyDataStore) -> TypeId { - match self { - Self::Value(value) => store.value_ty(*value), - Self::Aggregate { lhs, idx } => { - let lhs_ty = lhs.ty(db, store); - lhs_ty.projection_ty(db, store.value_data(*idx)) - } - Self::Map { lhs, .. } => { - let lhs_ty = lhs.ty(db, store).deref(db); - match lhs_ty.data(db).kind { - TypeKind::Map(def) => def.value_ty.make_sptr(db), - _ => unreachable!(), - } - } - } - } + // pub fn ty(&self, db: &dyn MirDb, store: &BodyDataStore) -> TypeId { + // match self { + // Self::Value(value) => store.value_ty(*value), + // Self::Aggregate { lhs, idx } => { + // let lhs_ty = lhs.ty(db, store); + // lhs_ty.projection_ty(db, store.value_data(*idx)) + // } + // Self::Map { lhs, .. } => { + // let lhs_ty = lhs.ty(db, store).deref(db); + // match lhs_ty.data(db).kind { + // TypeKind::Map(def) => def.value_ty.make_sptr(db), + // _ => unreachable!(), + // } + // } + // } + // } pub fn value_id(&self) -> Option { match self { @@ -105,28 +100,29 @@ pub struct Local { /// `true` if a local is introduced in MIR. pub is_tmp: bool, - - pub source: SourceInfo, + // pub source: SourceInfo, } impl Local { - pub fn user_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + // pub fn user_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + pub fn user_local(name: SmolStr, ty: TypeId) -> Local { Self { name, ty, is_arg: false, is_tmp: false, - source, + // source, } } - pub fn arg_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + // pub fn arg_local(name: SmolStr, ty: TypeId, source: SourceInfo) -> Local { + pub fn arg_local(name: SmolStr, ty: TypeId) -> Local { Self { name, ty, is_arg: true, is_tmp: false, - source, + // source, } } @@ -136,7 +132,7 @@ impl Local { ty, is_arg: false, is_tmp: true, - source: SourceInfo::dummy(), + // source: SourceInfo::dummy(), } } } diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 8fc03c0d0a..97df2a829a 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -3,7 +3,7 @@ use hir::HirDb; // pub mod analysis; // pub mod graphviz; pub mod ir; -// pub mod pretty_print; +pub mod pretty_print; mod lower; diff --git a/crates/mir2/src/pretty_print/inst.rs b/crates/mir2/src/pretty_print/inst.rs index 345d7946a3..f3612e7c50 100644 --- a/crates/mir2/src/pretty_print/inst.rs +++ b/crates/mir2/src/pretty_print/inst.rs @@ -1,8 +1,8 @@ use std::fmt::{self, Write}; use crate::{ - db::MirDb, ir::{function::BodyDataStore, inst::InstKind, InstId}, + MirDb, }; use super::PrettyPrint; diff --git a/crates/mir2/src/pretty_print/mod.rs b/crates/mir2/src/pretty_print/mod.rs index 190bb3ca7e..853b8cfa24 100644 --- a/crates/mir2/src/pretty_print/mod.rs +++ b/crates/mir2/src/pretty_print/mod.rs @@ -1,6 +1,6 @@ use std::fmt; -use crate::{db::MirDb, ir::function::BodyDataStore}; +use crate::{ir::function::BodyDataStore, MirDb}; mod inst; mod types; diff --git a/crates/mir2/src/pretty_print/types.rs b/crates/mir2/src/pretty_print/types.rs index 2574d8260a..1d146c8292 100644 --- a/crates/mir2/src/pretty_print/types.rs +++ b/crates/mir2/src/pretty_print/types.rs @@ -1,9 +1,8 @@ use std::fmt::{self, Write}; -use crate::{ - db::MirDb, - ir::{function::BodyDataStore, TypeId}, -}; +use hir::hir_def::TypeId; + +use crate::{ir::function::BodyDataStore, MirDb}; use super::PrettyPrint; diff --git a/crates/mir2/src/pretty_print/value.rs b/crates/mir2/src/pretty_print/value.rs index 05e1ff796d..3a1650b72b 100644 --- a/crates/mir2/src/pretty_print/value.rs +++ b/crates/mir2/src/pretty_print/value.rs @@ -1,10 +1,10 @@ use std::fmt::{self, Write}; use crate::{ - db::MirDb, ir::{ constant::ConstantValue, function::BodyDataStore, value::AssignableValue, Value, ValueId, }, + MirDb, }; use super::PrettyPrint; From abd5312cd94b76d283bbd7067e262da5437c8f48 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Wed, 31 Jan 2024 17:50:30 -0700 Subject: [PATCH 18/22] hacking --- crates/mir2/src/analysis/domtree.rs | 352 +++++++++---------- crates/mir2/src/analysis/loop_tree.rs | 163 ++++----- crates/mir2/src/analysis/post_domtree.rs | 422 +++++++++++------------ crates/mir2/src/lib.rs | 4 +- 4 files changed, 473 insertions(+), 468 deletions(-) diff --git a/crates/mir2/src/analysis/domtree.rs b/crates/mir2/src/analysis/domtree.rs index 9775db6335..cf30381e46 100644 --- a/crates/mir2/src/analysis/domtree.rs +++ b/crates/mir2/src/analysis/domtree.rs @@ -165,179 +165,179 @@ impl DFSet { } } -#[cfg(test)] -mod tests { - use super::*; - - use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; - - fn calc_dom(func: &FunctionBody) -> (DomTree, DFSet) { - let cfg = ControlFlowGraph::compute(func); - let domtree = DomTree::compute(&cfg); - let df = domtree.compute_df(&cfg); - (domtree, df) - } - - fn body_builder() -> BodyBuilder { - BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) - } - - #[test] - fn dom_tree_if_else() { - let mut builder = body_builder(); - - let then_block = builder.make_block(); - let else_block = builder.make_block(); - let merge_block = builder.make_block(); - - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - builder.branch(v0, then_block, else_block, SourceInfo::dummy()); - - builder.move_to_block(then_block); - builder.jump(merge_block, SourceInfo::dummy()); - - builder.move_to_block(else_block); - builder.jump(merge_block, SourceInfo::dummy()); - - builder.move_to_block(merge_block); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let (dom_tree, df) = calc_dom(&func); - let entry_block = func.order.entry(); - assert_eq!(dom_tree.idom(entry_block), None); - assert_eq!(dom_tree.idom(then_block), Some(entry_block)); - assert_eq!(dom_tree.idom(else_block), Some(entry_block)); - assert_eq!(dom_tree.idom(merge_block), Some(entry_block)); - - assert_eq!(df.frontier_num(entry_block), 0); - assert_eq!(df.frontier_num(then_block), 1); - assert_eq!( - df.frontiers(then_block).unwrap().next().unwrap(), - merge_block - ); - assert_eq!( - df.frontiers(else_block).unwrap().next().unwrap(), - merge_block - ); - assert_eq!(df.frontier_num(merge_block), 0); - } - - #[test] - fn unreachable_edge() { - let mut builder = body_builder(); - - let block1 = builder.make_block(); - let block2 = builder.make_block(); - let block3 = builder.make_block(); - let block4 = builder.make_block(); - - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - builder.branch(v0, block1, block2, SourceInfo::dummy()); - - builder.move_to_block(block1); - builder.jump(block4, SourceInfo::dummy()); - - builder.move_to_block(block2); - builder.jump(block4, SourceInfo::dummy()); - - builder.move_to_block(block3); - builder.jump(block4, SourceInfo::dummy()); - - builder.move_to_block(block4); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let (dom_tree, _) = calc_dom(&func); - let entry_block = func.order.entry(); - assert_eq!(dom_tree.idom(entry_block), None); - assert_eq!(dom_tree.idom(block1), Some(entry_block)); - assert_eq!(dom_tree.idom(block2), Some(entry_block)); - assert_eq!(dom_tree.idom(block3), None); - assert!(!dom_tree.is_reachable(block3)); - assert_eq!(dom_tree.idom(block4), Some(entry_block)); - } - - #[test] - fn dom_tree_complex() { - let mut builder = body_builder(); - - let block1 = builder.make_block(); - let block2 = builder.make_block(); - let block3 = builder.make_block(); - let block4 = builder.make_block(); - let block5 = builder.make_block(); - let block6 = builder.make_block(); - let block7 = builder.make_block(); - let block8 = builder.make_block(); - let block9 = builder.make_block(); - let block10 = builder.make_block(); - let block11 = builder.make_block(); - let block12 = builder.make_block(); - - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - builder.branch(v0, block2, block1, SourceInfo::dummy()); - - builder.move_to_block(block1); - builder.branch(v0, block6, block3, SourceInfo::dummy()); - - builder.move_to_block(block2); - builder.branch(v0, block7, block4, SourceInfo::dummy()); - - builder.move_to_block(block3); - builder.branch(v0, block6, block5, SourceInfo::dummy()); - - builder.move_to_block(block4); - builder.branch(v0, block7, block2, SourceInfo::dummy()); - - builder.move_to_block(block5); - builder.branch(v0, block10, block8, SourceInfo::dummy()); - - builder.move_to_block(block6); - builder.jump(block9, SourceInfo::dummy()); - - builder.move_to_block(block7); - builder.jump(block12, SourceInfo::dummy()); - - builder.move_to_block(block8); - builder.jump(block11, SourceInfo::dummy()); - - builder.move_to_block(block9); - builder.jump(block8, SourceInfo::dummy()); - - builder.move_to_block(block10); - builder.jump(block11, SourceInfo::dummy()); - - builder.move_to_block(block11); - builder.branch(v0, block12, block2, SourceInfo::dummy()); - - builder.move_to_block(block12); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let (dom_tree, _) = calc_dom(&func); - let entry_block = func.order.entry(); - assert_eq!(dom_tree.idom(entry_block), None); - assert_eq!(dom_tree.idom(block1), Some(entry_block)); - assert_eq!(dom_tree.idom(block2), Some(entry_block)); - assert_eq!(dom_tree.idom(block3), Some(block1)); - assert_eq!(dom_tree.idom(block4), Some(block2)); - assert_eq!(dom_tree.idom(block5), Some(block3)); - assert_eq!(dom_tree.idom(block6), Some(block1)); - assert_eq!(dom_tree.idom(block7), Some(block2)); - assert_eq!(dom_tree.idom(block8), Some(block1)); - assert_eq!(dom_tree.idom(block9), Some(block6)); - assert_eq!(dom_tree.idom(block10), Some(block5)); - assert_eq!(dom_tree.idom(block11), Some(block1)); - assert_eq!(dom_tree.idom(block12), Some(entry_block)); - } -} +// #[cfg(test)] +// mod tests { +// use super::*; + +// use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; + +// fn calc_dom(func: &FunctionBody) -> (DomTree, DFSet) { +// let cfg = ControlFlowGraph::compute(func); +// let domtree = DomTree::compute(&cfg); +// let df = domtree.compute_df(&cfg); +// (domtree, df) +// } + +// fn body_builder() -> BodyBuilder { +// BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) +// } + +// #[test] +// fn dom_tree_if_else() { +// let mut builder = body_builder(); + +// let then_block = builder.make_block(); +// let else_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(else_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let (dom_tree, df) = calc_dom(&func); +// let entry_block = func.order.entry(); +// assert_eq!(dom_tree.idom(entry_block), None); +// assert_eq!(dom_tree.idom(then_block), Some(entry_block)); +// assert_eq!(dom_tree.idom(else_block), Some(entry_block)); +// assert_eq!(dom_tree.idom(merge_block), Some(entry_block)); + +// assert_eq!(df.frontier_num(entry_block), 0); +// assert_eq!(df.frontier_num(then_block), 1); +// assert_eq!( +// df.frontiers(then_block).unwrap().next().unwrap(), +// merge_block +// ); +// assert_eq!( +// df.frontiers(else_block).unwrap().next().unwrap(), +// merge_block +// ); +// assert_eq!(df.frontier_num(merge_block), 0); +// } + +// #[test] +// fn unreachable_edge() { +// let mut builder = body_builder(); + +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, block1, block2, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.jump(block4, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.jump(block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// builder.jump(block4, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let (dom_tree, _) = calc_dom(&func); +// let entry_block = func.order.entry(); +// assert_eq!(dom_tree.idom(entry_block), None); +// assert_eq!(dom_tree.idom(block1), Some(entry_block)); +// assert_eq!(dom_tree.idom(block2), Some(entry_block)); +// assert_eq!(dom_tree.idom(block3), None); +// assert!(!dom_tree.is_reachable(block3)); +// assert_eq!(dom_tree.idom(block4), Some(entry_block)); +// } + +// #[test] +// fn dom_tree_complex() { +// let mut builder = body_builder(); + +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); +// let block5 = builder.make_block(); +// let block6 = builder.make_block(); +// let block7 = builder.make_block(); +// let block8 = builder.make_block(); +// let block9 = builder.make_block(); +// let block10 = builder.make_block(); +// let block11 = builder.make_block(); +// let block12 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, block2, block1, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.branch(v0, block6, block3, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.branch(v0, block7, block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// builder.branch(v0, block6, block5, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// builder.branch(v0, block7, block2, SourceInfo::dummy()); + +// builder.move_to_block(block5); +// builder.branch(v0, block10, block8, SourceInfo::dummy()); + +// builder.move_to_block(block6); +// builder.jump(block9, SourceInfo::dummy()); + +// builder.move_to_block(block7); +// builder.jump(block12, SourceInfo::dummy()); + +// builder.move_to_block(block8); +// builder.jump(block11, SourceInfo::dummy()); + +// builder.move_to_block(block9); +// builder.jump(block8, SourceInfo::dummy()); + +// builder.move_to_block(block10); +// builder.jump(block11, SourceInfo::dummy()); + +// builder.move_to_block(block11); +// builder.branch(v0, block12, block2, SourceInfo::dummy()); + +// builder.move_to_block(block12); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let (dom_tree, _) = calc_dom(&func); +// let entry_block = func.order.entry(); +// assert_eq!(dom_tree.idom(entry_block), None); +// assert_eq!(dom_tree.idom(block1), Some(entry_block)); +// assert_eq!(dom_tree.idom(block2), Some(entry_block)); +// assert_eq!(dom_tree.idom(block3), Some(block1)); +// assert_eq!(dom_tree.idom(block4), Some(block2)); +// assert_eq!(dom_tree.idom(block5), Some(block3)); +// assert_eq!(dom_tree.idom(block6), Some(block1)); +// assert_eq!(dom_tree.idom(block7), Some(block2)); +// assert_eq!(dom_tree.idom(block8), Some(block1)); +// assert_eq!(dom_tree.idom(block9), Some(block6)); +// assert_eq!(dom_tree.idom(block10), Some(block5)); +// assert_eq!(dom_tree.idom(block11), Some(block1)); +// assert_eq!(dom_tree.idom(block12), Some(entry_block)); +// } +// } diff --git a/crates/mir2/src/analysis/loop_tree.rs b/crates/mir2/src/analysis/loop_tree.rs index ca13db5dcb..c818019b02 100644 --- a/crates/mir2/src/analysis/loop_tree.rs +++ b/crates/mir2/src/analysis/loop_tree.rs @@ -236,112 +236,117 @@ enum BlockState { Finished, } -#[cfg(test)] -mod tests { - use super::*; +// #[cfg(test)] +// mod tests { +// use super::*; - use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; +// // use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId}; +// use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId}; - fn compute_loop(func: &FunctionBody) -> LoopTree { - let cfg = ControlFlowGraph::compute(func); - let domtree = DomTree::compute(&cfg); - LoopTree::compute(&cfg, &domtree) - } +// fn compute_loop(func: &FunctionBody) -> LoopTree { +// let cfg = ControlFlowGraph::compute(func); +// let domtree = DomTree::compute(&cfg); +// LoopTree::compute(&cfg, &domtree) +// } - fn body_builder() -> BodyBuilder { - BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) - } +// fn body_builder() -> BodyBuilder { +// // BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) +// BodyBuilder::new(FunctionId(0)) +// } - #[test] - fn simple_loop() { - let mut builder = body_builder(); +// #[test] +// fn simple_loop() { +// let mut builder = body_builder(); - let entry = builder.current_block(); - let block1 = builder.make_block(); - let block2 = builder.make_block(); +// let entry = builder.current_block(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(false, dummy_ty); - builder.branch(v0, block1, block2, SourceInfo::dummy()); +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(false, dummy_ty); +// // builder.branch(v0, block1, block2, SourceInfo::dummy()); +// builder.branch(v0, block1, block2); - builder.move_to_block(block1); - builder.jump(entry, SourceInfo::dummy()); +// builder.move_to_block(block1); +// // builder.jump(entry, SourceInfo::dummy()); +// builder.jump(entry); - builder.move_to_block(block2); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); +// builder.move_to_block(block2); +// let dummy_value = builder.make_unit(dummy_ty); +// // builder.ret(dummy_value, SourceInfo::dummy()); +// builder.ret(dummy_value); - let func = builder.build(); +// let func = builder.build(); - let lpt = compute_loop(&func); +// let lpt = compute_loop(&func); - assert_eq!(lpt.loop_num(), 1); - let lp = lpt.loops().next().unwrap(); +// assert_eq!(lpt.loop_num(), 1); +// let lp = lpt.loops().next().unwrap(); - assert!(lpt.is_block_in_loop(entry, lp)); - assert_eq!(lpt.loop_of_block(entry), Some(lp)); +// assert!(lpt.is_block_in_loop(entry, lp)); +// assert_eq!(lpt.loop_of_block(entry), Some(lp)); - assert!(lpt.is_block_in_loop(block1, lp)); - assert_eq!(lpt.loop_of_block(block1), Some(lp)); +// assert!(lpt.is_block_in_loop(block1, lp)); +// assert_eq!(lpt.loop_of_block(block1), Some(lp)); - assert!(!lpt.is_block_in_loop(block2, lp)); - assert!(lpt.loop_of_block(block2).is_none()); +// assert!(!lpt.is_block_in_loop(block2, lp)); +// assert!(lpt.loop_of_block(block2).is_none()); - assert_eq!(lpt.loop_header(lp), entry); - } +// assert_eq!(lpt.loop_header(lp), entry); +// } - #[test] - fn nested_loop() { - let mut builder = body_builder(); +// #[test] +// fn nested_loop() { +// let mut builder = body_builder(); - let entry = builder.current_block(); - let block1 = builder.make_block(); - let block2 = builder.make_block(); - let block3 = builder.make_block(); +// let entry = builder.current_block(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(false, dummy_ty); - builder.branch(v0, block1, block3, SourceInfo::dummy()); +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(false, dummy_ty); +// builder.branch(v0, block1, block3, SourceInfo::dummy()); - builder.move_to_block(block1); - builder.branch(v0, entry, block2, SourceInfo::dummy()); +// builder.move_to_block(block1); +// builder.branch(v0, entry, block2, SourceInfo::dummy()); - builder.move_to_block(block2); - builder.jump(block1, SourceInfo::dummy()); +// builder.move_to_block(block2); +// builder.jump(block1, SourceInfo::dummy()); - builder.move_to_block(block3); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); +// builder.move_to_block(block3); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); - let func = builder.build(); +// let func = builder.build(); - let lpt = compute_loop(&func); +// let lpt = compute_loop(&func); - assert_eq!(lpt.loop_num(), 2); - let mut loops = lpt.loops(); - let outer_lp = loops.next().unwrap(); - let inner_lp = loops.next().unwrap(); +// assert_eq!(lpt.loop_num(), 2); +// let mut loops = lpt.loops(); +// let outer_lp = loops.next().unwrap(); +// let inner_lp = loops.next().unwrap(); - assert!(lpt.is_block_in_loop(entry, outer_lp)); - assert!(!lpt.is_block_in_loop(entry, inner_lp)); - assert_eq!(lpt.loop_of_block(entry), Some(outer_lp)); +// assert!(lpt.is_block_in_loop(entry, outer_lp)); +// assert!(!lpt.is_block_in_loop(entry, inner_lp)); +// assert_eq!(lpt.loop_of_block(entry), Some(outer_lp)); - assert!(lpt.is_block_in_loop(block1, outer_lp)); - assert!(lpt.is_block_in_loop(block1, inner_lp)); - assert_eq!(lpt.loop_of_block(block1), Some(inner_lp)); +// assert!(lpt.is_block_in_loop(block1, outer_lp)); +// assert!(lpt.is_block_in_loop(block1, inner_lp)); +// assert_eq!(lpt.loop_of_block(block1), Some(inner_lp)); - assert!(lpt.is_block_in_loop(block2, outer_lp)); - assert!(lpt.is_block_in_loop(block2, inner_lp)); - assert_eq!(lpt.loop_of_block(block2), Some(inner_lp)); +// assert!(lpt.is_block_in_loop(block2, outer_lp)); +// assert!(lpt.is_block_in_loop(block2, inner_lp)); +// assert_eq!(lpt.loop_of_block(block2), Some(inner_lp)); - assert!(!lpt.is_block_in_loop(block3, outer_lp)); - assert!(!lpt.is_block_in_loop(block3, inner_lp)); - assert!(lpt.loop_of_block(block3).is_none()); +// assert!(!lpt.is_block_in_loop(block3, outer_lp)); +// assert!(!lpt.is_block_in_loop(block3, inner_lp)); +// assert!(lpt.loop_of_block(block3).is_none()); - assert!(lpt.parent_loop(outer_lp).is_none()); - assert_eq!(lpt.parent_loop(inner_lp), Some(outer_lp)); +// assert!(lpt.parent_loop(outer_lp).is_none()); +// assert_eq!(lpt.parent_loop(inner_lp), Some(outer_lp)); - assert_eq!(lpt.loop_header(outer_lp), entry); - assert_eq!(lpt.loop_header(inner_lp), block1); - } -} +// assert_eq!(lpt.loop_header(outer_lp), entry); +// assert_eq!(lpt.loop_header(inner_lp), block1); +// } +// } diff --git a/crates/mir2/src/analysis/post_domtree.rs b/crates/mir2/src/analysis/post_domtree.rs index ba33aab5f0..9d034d2bb6 100644 --- a/crates/mir2/src/analysis/post_domtree.rs +++ b/crates/mir2/src/analysis/post_domtree.rs @@ -71,214 +71,214 @@ pub enum PostIDom { Block(BasicBlockId), } -#[cfg(test)] -mod tests { - use super::*; - - use crate::ir::{body_builder::BodyBuilder, FunctionId, SourceInfo, TypeId}; - - fn body_builder() -> BodyBuilder { - BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) - } - - #[test] - fn test_if_else_merge() { - let mut builder = body_builder(); - let then_block = builder.make_block(); - let else_block = builder.make_block(); - let merge_block = builder.make_block(); - - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - builder.branch(v0, then_block, else_block, SourceInfo::dummy()); - - builder.move_to_block(then_block); - builder.jump(merge_block, SourceInfo::dummy()); - - builder.move_to_block(else_block); - builder.jump(merge_block, SourceInfo::dummy()); - - builder.move_to_block(merge_block); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let post_dom_tree = PostDomTree::compute(&func); - let entry_block = func.order.entry(); - assert_eq!( - post_dom_tree.post_idom(entry_block), - PostIDom::Block(merge_block) - ); - assert_eq!( - post_dom_tree.post_idom(then_block), - PostIDom::Block(merge_block) - ); - assert_eq!( - post_dom_tree.post_idom(else_block), - PostIDom::Block(merge_block) - ); - assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); - } - - #[test] - fn test_if_else_return() { - let mut builder = body_builder(); - let then_block = builder.make_block(); - let else_block = builder.make_block(); - let merge_block = builder.make_block(); - - let dummy_ty = TypeId(0); - let dummy_value = builder.make_unit(dummy_ty); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - builder.branch(v0, then_block, else_block, SourceInfo::dummy()); - - builder.move_to_block(then_block); - builder.jump(merge_block, SourceInfo::dummy()); - - builder.move_to_block(else_block); - builder.ret(dummy_value, SourceInfo::dummy()); - - builder.move_to_block(merge_block); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let post_dom_tree = PostDomTree::compute(&func); - let entry_block = func.order.entry(); - assert_eq!(post_dom_tree.post_idom(entry_block), PostIDom::DummyExit,); - assert_eq!( - post_dom_tree.post_idom(then_block), - PostIDom::Block(merge_block), - ); - assert_eq!(post_dom_tree.post_idom(else_block), PostIDom::DummyExit); - assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); - } - - #[test] - fn test_if_non_else() { - let mut builder = body_builder(); - let then_block = builder.make_block(); - let merge_block = builder.make_block(); - - let dummy_ty = TypeId(0); - let dummy_value = builder.make_unit(dummy_ty); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - builder.branch(v0, then_block, merge_block, SourceInfo::dummy()); - - builder.move_to_block(then_block); - builder.jump(merge_block, SourceInfo::dummy()); - - builder.move_to_block(merge_block); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let post_dom_tree = PostDomTree::compute(&func); - let entry_block = func.order.entry(); - assert_eq!( - post_dom_tree.post_idom(entry_block), - PostIDom::Block(merge_block), - ); - assert_eq!( - post_dom_tree.post_idom(then_block), - PostIDom::Block(merge_block), - ); - assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); - } - - #[test] - fn test_loop() { - let mut builder = body_builder(); - let block1 = builder.make_block(); - let block2 = builder.make_block(); - let block3 = builder.make_block(); - let block4 = builder.make_block(); - - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - - builder.branch(v0, block1, block2, SourceInfo::dummy()); - - builder.move_to_block(block1); - builder.jump(block3, SourceInfo::dummy()); - - builder.move_to_block(block2); - builder.branch(v0, block3, block4, SourceInfo::dummy()); - - builder.move_to_block(block3); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); - - builder.move_to_block(block4); - builder.jump(block2, SourceInfo::dummy()); - - let func = builder.build(); - - let post_dom_tree = PostDomTree::compute(&func); - let entry_block = func.order.entry(); - assert_eq!( - post_dom_tree.post_idom(entry_block), - PostIDom::Block(block3), - ); - assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block3)); - assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block3)); - assert_eq!(post_dom_tree.post_idom(block3), PostIDom::DummyExit); - assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block2)); - } - - #[test] - fn test_pd_complex() { - let mut builder = body_builder(); - let block1 = builder.make_block(); - let block2 = builder.make_block(); - let block3 = builder.make_block(); - let block4 = builder.make_block(); - let block5 = builder.make_block(); - let block6 = builder.make_block(); - let block7 = builder.make_block(); - - let dummy_ty = TypeId(0); - let v0 = builder.make_imm_from_bool(true, dummy_ty); - - builder.branch(v0, block1, block2, SourceInfo::dummy()); - - builder.move_to_block(block1); - builder.jump(block6, SourceInfo::dummy()); - - builder.move_to_block(block2); - builder.branch(v0, block3, block4, SourceInfo::dummy()); - - builder.move_to_block(block3); - builder.jump(block5, SourceInfo::dummy()); - - builder.move_to_block(block4); - builder.jump(block5, SourceInfo::dummy()); - - builder.move_to_block(block5); - builder.jump(block6, SourceInfo::dummy()); - - builder.move_to_block(block6); - builder.jump(block7, SourceInfo::dummy()); - - builder.move_to_block(block7); - let dummy_value = builder.make_unit(dummy_ty); - builder.ret(dummy_value, SourceInfo::dummy()); - - let func = builder.build(); - - let post_dom_tree = PostDomTree::compute(&func); - let entry_block = func.order.entry(); - assert_eq!( - post_dom_tree.post_idom(entry_block), - PostIDom::Block(block6), - ); - assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block6)); - assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block5)); - assert_eq!(post_dom_tree.post_idom(block3), PostIDom::Block(block5)); - assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block5)); - assert_eq!(post_dom_tree.post_idom(block5), PostIDom::Block(block6)); - assert_eq!(post_dom_tree.post_idom(block6), PostIDom::Block(block7)); - assert_eq!(post_dom_tree.post_idom(block7), PostIDom::DummyExit); - } -} +// #[cfg(test)] +// mod tests { +// use super::*; + +// use crate::ir::{body_builder::BodyBuilder, FunctionId, SourceInfo, TypeId}; + +// fn body_builder() -> BodyBuilder { +// BodyBuilder::new(FunctionId(0), SourceInfo::dummy()) +// } + +// #[test] +// fn test_if_else_merge() { +// let mut builder = body_builder(); +// let then_block = builder.make_block(); +// let else_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(else_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(merge_block) +// ); +// assert_eq!( +// post_dom_tree.post_idom(then_block), +// PostIDom::Block(merge_block) +// ); +// assert_eq!( +// post_dom_tree.post_idom(else_block), +// PostIDom::Block(merge_block) +// ); +// assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); +// } + +// #[test] +// fn test_if_else_return() { +// let mut builder = body_builder(); +// let then_block = builder.make_block(); +// let else_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let dummy_value = builder.make_unit(dummy_ty); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, else_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(else_block); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!(post_dom_tree.post_idom(entry_block), PostIDom::DummyExit,); +// assert_eq!( +// post_dom_tree.post_idom(then_block), +// PostIDom::Block(merge_block), +// ); +// assert_eq!(post_dom_tree.post_idom(else_block), PostIDom::DummyExit); +// assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); +// } + +// #[test] +// fn test_if_non_else() { +// let mut builder = body_builder(); +// let then_block = builder.make_block(); +// let merge_block = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let dummy_value = builder.make_unit(dummy_ty); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); +// builder.branch(v0, then_block, merge_block, SourceInfo::dummy()); + +// builder.move_to_block(then_block); +// builder.jump(merge_block, SourceInfo::dummy()); + +// builder.move_to_block(merge_block); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(merge_block), +// ); +// assert_eq!( +// post_dom_tree.post_idom(then_block), +// PostIDom::Block(merge_block), +// ); +// assert_eq!(post_dom_tree.post_idom(merge_block), PostIDom::DummyExit); +// } + +// #[test] +// fn test_loop() { +// let mut builder = body_builder(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); + +// builder.branch(v0, block1, block2, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.jump(block3, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.branch(v0, block3, block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// builder.jump(block2, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(block3), +// ); +// assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block3)); +// assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block3)); +// assert_eq!(post_dom_tree.post_idom(block3), PostIDom::DummyExit); +// assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block2)); +// } + +// #[test] +// fn test_pd_complex() { +// let mut builder = body_builder(); +// let block1 = builder.make_block(); +// let block2 = builder.make_block(); +// let block3 = builder.make_block(); +// let block4 = builder.make_block(); +// let block5 = builder.make_block(); +// let block6 = builder.make_block(); +// let block7 = builder.make_block(); + +// let dummy_ty = TypeId(0); +// let v0 = builder.make_imm_from_bool(true, dummy_ty); + +// builder.branch(v0, block1, block2, SourceInfo::dummy()); + +// builder.move_to_block(block1); +// builder.jump(block6, SourceInfo::dummy()); + +// builder.move_to_block(block2); +// builder.branch(v0, block3, block4, SourceInfo::dummy()); + +// builder.move_to_block(block3); +// builder.jump(block5, SourceInfo::dummy()); + +// builder.move_to_block(block4); +// builder.jump(block5, SourceInfo::dummy()); + +// builder.move_to_block(block5); +// builder.jump(block6, SourceInfo::dummy()); + +// builder.move_to_block(block6); +// builder.jump(block7, SourceInfo::dummy()); + +// builder.move_to_block(block7); +// let dummy_value = builder.make_unit(dummy_ty); +// builder.ret(dummy_value, SourceInfo::dummy()); + +// let func = builder.build(); + +// let post_dom_tree = PostDomTree::compute(&func); +// let entry_block = func.order.entry(); +// assert_eq!( +// post_dom_tree.post_idom(entry_block), +// PostIDom::Block(block6), +// ); +// assert_eq!(post_dom_tree.post_idom(block1), PostIDom::Block(block6)); +// assert_eq!(post_dom_tree.post_idom(block2), PostIDom::Block(block5)); +// assert_eq!(post_dom_tree.post_idom(block3), PostIDom::Block(block5)); +// assert_eq!(post_dom_tree.post_idom(block4), PostIDom::Block(block5)); +// assert_eq!(post_dom_tree.post_idom(block5), PostIDom::Block(block6)); +// assert_eq!(post_dom_tree.post_idom(block6), PostIDom::Block(block7)); +// assert_eq!(post_dom_tree.post_idom(block7), PostIDom::DummyExit); +// } +// } diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 97df2a829a..18acee1e8a 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -1,9 +1,9 @@ use hir::HirDb; -// pub mod analysis; +pub mod analysis; // pub mod graphviz; pub mod ir; -pub mod pretty_print; +// pub mod pretty_print; mod lower; From 2ae7a4f5eb5a8f992cce5288c09655ccb6d37731 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Thu, 1 Feb 2024 09:22:54 -0700 Subject: [PATCH 19/22] hacking --- crates/mir2/src/lib.rs | 13 ++ .../db/queries => mir2/src/lower}/constant.rs | 0 crates/mir2/src/lower/function.rs | 208 +++++++++--------- crates/mir2/src/lower/mod.rs | 6 +- 4 files changed, 123 insertions(+), 104 deletions(-) rename crates/{mir2-analysis/src/db/queries => mir2/src/lower}/constant.rs (100%) diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 18acee1e8a..e482e67287 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -19,6 +19,19 @@ pub struct Jar( // ir::InstId, // ir::Value, // ir::ValueId, + // mir_intern_const, + // mir_intern_type, + // mir_intern_function, + // mir_lower_module_all_functions, + // mir_lower_contract_all_functions, + // mir_lower_struct_all_functions, + // mir_lower_enum_all_functions, + // mir_lowered_type, + mir_lowered_constant, + // mir_lowered_func_signature, + // mir_lowered_monomorphized_func_signature, + // mir_lowered_pseudo_monomorphized_func_signature, + // mir_lowered_func_body, ); #[salsa::jar(db = LowerMirDb)] diff --git a/crates/mir2-analysis/src/db/queries/constant.rs b/crates/mir2/src/lower/constant.rs similarity index 100% rename from crates/mir2-analysis/src/db/queries/constant.rs rename to crates/mir2/src/lower/constant.rs diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs index 9b4565aa21..f3abdeb36c 100644 --- a/crates/mir2/src/lower/function.rs +++ b/crates/mir2/src/lower/function.rs @@ -1,14 +1,12 @@ use std::{collections::BTreeMap, rc::Rc, vec}; -use fe_common2::numeric::Literal; -use fe_parser2::{ast, node::Node}; use fxhash::FxHashMap; +use hir::hir_def::{self, TypeId}; use id_arena::{Arena, Id}; use num_bigint::BigInt; use smol_str::SmolStr; use crate::{ - db::MirDb, ir::{ self, body_builder::BodyBuilder, @@ -17,8 +15,9 @@ use crate::{ inst::{CallType, InstKind}, value::{AssignableValue, Local}, BasicBlockId, Constant, FunctionBody, FunctionId, FunctionParam, FunctionSignature, InstId, - SourceInfo, TypeId, Value, ValueId, + Value, ValueId, }, + MirDb, }; type ScopeId = Id; @@ -91,7 +90,7 @@ pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { pub(super) struct BodyLowerHelper<'db, 'a> { pub(super) db: &'db dyn MirDb, pub(super) builder: BodyBuilder, - ast: &'a Node, + ast: &'a Node, func: FunctionId, analyzer_body: &'a fe_analyzer2::context::FunctionBody, scopes: Arena, @@ -99,9 +98,9 @@ pub(super) struct BodyLowerHelper<'db, 'a> { } impl<'db, 'a> BodyLowerHelper<'db, 'a> { - pub(super) fn lower_stmt(&mut self, stmt: &Node) { + pub(super) fn lower_stmt(&mut self, stmt: &Node) { match &stmt.kind { - ast::FuncStmt::Return { value } => { + hir_def::FuncStmt::Return { value } => { let value = if let Some(expr) = value { self.lower_expr_to_value(expr) } else { @@ -112,11 +111,11 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(next_block); } - ast::FuncStmt::VarDecl { target, value, .. } => { + hir_def::FuncStmt::VarDecl { target, value, .. } => { self.lower_var_decl(target, value.as_ref(), stmt.into()); } - ast::FuncStmt::ConstantDecl { name, value, .. } => { + hir_def::FuncStmt::ConstantDecl { name, value, .. } => { let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&name.id]); let value = self.analyzer_body.expressions[&value.id] @@ -129,13 +128,13 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.scope_mut().declare_var(&name.kind, constant); } - ast::FuncStmt::Assign { target, value } => { + hir_def::FuncStmt::Assign { target, value } => { let result = self.lower_assignable_value(target); let (expr, _ty) = self.lower_expr(value); self.builder.map_result(expr, result) } - ast::FuncStmt::AugAssign { target, op, value } => { + hir_def::FuncStmt::AugAssign { target, op, value } => { let result = self.lower_assignable_value(target); let lhs = self.lower_expr_to_value(target); let rhs = self.lower_expr_to_value(value); @@ -144,9 +143,11 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.map_result(inst, result) } - ast::FuncStmt::For { target, iter, body } => self.lower_for_loop(target, iter, body), + hir_def::FuncStmt::For { target, iter, body } => { + self.lower_for_loop(target, iter, body) + } - ast::FuncStmt::While { test, body } => { + hir_def::FuncStmt::While { test, body } => { let header_bb = self.builder.make_block(); let exit_bb = self.builder.make_block(); @@ -170,18 +171,18 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(exit_bb); } - ast::FuncStmt::If { + hir_def::FuncStmt::If { test, body, or_else, } => self.lower_if(test, body, or_else), - ast::FuncStmt::Match { expr, arms } => { + hir_def::FuncStmt::Match { expr, arms } => { let matrix = &self.analyzer_body.matches[&stmt.id]; super::pattern_match::lower_match(self, matrix, expr, arms); } - ast::FuncStmt::Assert { test, msg } => { + hir_def::FuncStmt::Assert { test, msg } => { let then_bb = self.builder.make_block(); let false_bb = self.builder.make_block(); @@ -199,18 +200,18 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(then_bb); } - ast::FuncStmt::Expr { value } => { + hir_def::FuncStmt::Expr { value } => { self.lower_expr_to_value(value); } - ast::FuncStmt::Break => { + hir_def::FuncStmt::Break => { let exit = self.scope().loop_exit(&self.scopes); self.builder.jump(exit, stmt.into()); let next_block = self.builder.make_block(); self.builder.move_to_block(next_block); } - ast::FuncStmt::Continue => { + hir_def::FuncStmt::Continue => { let entry = self.scope().loop_entry(&self.scopes); if let Some(loop_idx) = self.scope().loop_idx(&self.scopes) { let imm_one = self.make_u256_imm(1u32); @@ -226,14 +227,14 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(next_block); } - ast::FuncStmt::Revert { error } => { + hir_def::FuncStmt::Revert { error } => { let error = error.as_ref().map(|err| self.lower_expr_to_value(err)); self.builder.revert(error, stmt.into()); let next_block = self.builder.make_block(); self.builder.move_to_block(next_block); } - ast::FuncStmt::Unsafe(stmts) => { + hir_def::FuncStmt::Unsafe(stmts) => { self.enter_scope(); for stmt in stmts { self.lower_stmt(stmt) @@ -245,12 +246,12 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { pub(super) fn lower_var_decl( &mut self, - var: &Node, - init: Option<&Node>, + var: &Node, + init: Option<&Node>, source: SourceInfo, ) { match &var.kind { - ast::VarDeclTarget::Name(name) => { + hir_def::VarDeclTarget::Name(name) => { let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); let value = self.declare_var(name, ty, var.into()); if let Some(init) = init { @@ -262,9 +263,9 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } - ast::VarDeclTarget::Tuple(decls) => { + hir_def::VarDeclTarget::Tuple(decls) => { if let Some(init) = init { - if let ast::Expr::Tuple { elts } = &init.kind { + if let hir_def::Expr::Tuple { elts } = &init.kind { debug_assert_eq!(decls.len(), elts.len()); for (decl, init_elem) in decls.iter().zip(elts.iter()) { self.lower_var_decl(decl, Some(init_elem), source.clone()); @@ -287,9 +288,10 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { &mut self, name: &SmolStr, ty: TypeId, - source: SourceInfo, + // source: SourceInfo, ) -> ValueId { - let local = Local::user_local(name.clone(), ty, source); + // let local = Local::user_local(name.clone(), ty, source); + let local = Local::user_local(name.clone(), ty); let value = self.builder.declare(local); self.scope_mut().declare_var(name, value); value @@ -297,13 +299,13 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { pub(super) fn lower_var_decl_unpack( &mut self, - var: &Node, + var: &Node, init: ValueId, init_ty: TypeId, - source: SourceInfo, + // source: SourceInfo, ) { match &var.kind { - ast::VarDeclTarget::Name(name) => { + hir_def::VarDeclTarget::Name(name) => { let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); let local = Local::user_local(name.clone(), ty, var.into()); @@ -313,7 +315,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.map_result(bind, lhs.into()); } - ast::VarDeclTarget::Tuple(decls) => { + hir_def::VarDeclTarget::Tuple(decls) => { for (index, decl) in decls.iter().enumerate() { let elem_ty = init_ty.projection_ty_imm(self.db, index); let index_value = self.make_u256_imm(index); @@ -327,10 +329,10 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } - pub(super) fn lower_expr(&mut self, expr: &Node) -> (InstId, TypeId) { + pub(super) fn lower_expr(&mut self, expr: &Node) -> (InstId, TypeId) { let mut ty = self.expr_ty(expr); let mut inst = match &expr.kind { - ast::Expr::Ternary { + hir_def::Expr::Ternary { if_expr, test, else_expr, @@ -361,38 +363,38 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.bind(tmp, SourceInfo::dummy()) } - ast::Expr::BoolOperation { left, op, right } => { + hir_def::Expr::BoolOperation { left, op, right } => { self.lower_bool_op(op.kind, left, right, ty) } - ast::Expr::BinOperation { left, op, right } => { + hir_def::Expr::BinOperation { left, op, right } => { let lhs = self.lower_expr_to_value(left); let rhs = self.lower_expr_to_value(right); self.lower_binop(op.kind, lhs, rhs, expr.into()) } - ast::Expr::UnaryOperation { op, operand } => { + hir_def::Expr::UnaryOperation { op, operand } => { let value = self.lower_expr_to_value(operand); match op.kind { - ast::UnaryOperator::Invert => self.builder.inv(value, expr.into()), - ast::UnaryOperator::Not => self.builder.not(value, expr.into()), - ast::UnaryOperator::USub => self.builder.neg(value, expr.into()), + hir_def::UnaryOperator::Invert => self.builder.inv(value, expr.into()), + hir_def::UnaryOperator::Not => self.builder.not(value, expr.into()), + hir_def::UnaryOperator::USub => self.builder.neg(value, expr.into()), } } - ast::Expr::CompOperation { left, op, right } => { + hir_def::Expr::CompOperation { left, op, right } => { let lhs = self.lower_expr_to_value(left); let rhs = self.lower_expr_to_value(right); self.lower_comp_op(op.kind, lhs, rhs, expr.into()) } - ast::Expr::Attribute { .. } => { + hir_def::Expr::Attribute { .. } => { let mut indices = vec![]; let value = self.lower_aggregate_access(expr, &mut indices); self.builder.aggregate_access(value, indices, expr.into()) } - ast::Expr::Subscript { value, index } => { + hir_def::Expr::Subscript { value, index } => { let value_ty = self.expr_ty(value).deref(self.db); if value_ty.is_aggregate(self.db) { let mut indices = vec![]; @@ -407,7 +409,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } - ast::Expr::Call { + hir_def::Expr::Call { func, generic_args, args, @@ -416,7 +418,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.lower_call(func, generic_args, &args.kind, ty, expr.into()) } - ast::Expr::List { elts } | ast::Expr::Tuple { elts } => { + hir_def::Expr::List { elts } | hir_def::Expr::Tuple { elts } => { let args = elts .iter() .map(|elem| self.lower_expr_to_value(elem)) @@ -425,7 +427,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.aggregate_construct(ty, args, expr.into()) } - ast::Expr::Repeat { value, len: _ } => { + hir_def::Expr::Repeat { value, len: _ } => { let array_type = if let Type::Array(array_type) = self.analyzer_body.expressions [&expr.id] .typ @@ -441,28 +443,28 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.aggregate_construct(ty, args, expr.into()) } - ast::Expr::Bool(b) => { + hir_def::Expr::Bool(b) => { let imm = self.builder.make_imm_from_bool(*b, ty); self.builder.bind(imm, expr.into()) } - ast::Expr::Name(name) => { + hir_def::Expr::Name(name) => { let value = self.resolve_name(name); self.builder.bind(value, expr.into()) } - ast::Expr::Path(path) => { + hir_def::Expr::Path(path) => { let value = self.resolve_path(path, expr.into()); self.builder.bind(value, expr.into()) } - ast::Expr::Num(num) => { + hir_def::Expr::Num(num) => { let imm = Literal::new(num).parse().unwrap(); let imm = self.builder.make_imm(imm, ty); self.builder.bind(imm, expr.into()) } - ast::Expr::Str(s) => { + hir_def::Expr::Str(s) => { let ty = self.expr_ty(expr); let const_value = self.make_local_constant( "str_in_func".into(), @@ -473,7 +475,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.bind(const_value, expr.into()) } - ast::Expr::Unit => { + hir_def::Expr::Unit => { let value = self.make_unit(); self.builder.bind(value, expr.into()) } @@ -510,7 +512,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { .unwrap_or_else(|| self.map_to_tmp(inst, ty)) } - pub(super) fn lower_expr_to_value(&mut self, expr: &Node) -> ValueId { + pub(super) fn lower_expr_to_value(&mut self, expr: &Node) -> ValueId { let (inst, ty) = self.lower_expr(expr); self.map_to_tmp(inst, ty) } @@ -555,7 +557,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { fn new( db: &'db dyn MirDb, func: FunctionId, - ast: &'a Node, + ast: &'a Node, analyzer_body: &'a fe_analyzer2::context::FunctionBody, ) -> Self { let mut builder = BodyBuilder::new(func, ast.into()); @@ -624,9 +626,9 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { fn lower_if( &mut self, - cond: &Node, - then: &[Node], - else_: &[Node], + cond: &Node, + then: &[Node], + else_: &[Node], ) { let cond = self.lower_expr_to_value(cond); @@ -689,8 +691,8 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { fn lower_for_loop( &mut self, loop_variable: &Node, - iter: &Node, - body: &[Node], + iter: &Node, + body: &[Node], ) { let preheader_bb = self.builder.make_block(); let entry_bb = self.builder.make_block(); @@ -775,15 +777,15 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.move_to_block(exit_bb); } - fn lower_assignable_value(&mut self, expr: &Node) -> AssignableValue { + fn lower_assignable_value(&mut self, expr: &Node) -> AssignableValue { match &expr.kind { - ast::Expr::Attribute { value, attr } => { + hir_def::Expr::Attribute { value, attr } => { let idx = self.expr_ty(value).index_from_fname(self.db, &attr.kind); let idx = self.make_u256_imm(idx); let lhs = self.lower_assignable_value(value).into(); AssignableValue::Aggregate { lhs, idx } } - ast::Expr::Subscript { value, index } => { + hir_def::Expr::Subscript { value, index } => { let lhs = self.lower_assignable_value(value).into(); let attr = self.lower_expr_to_value(index); let value_ty = self.expr_ty(value).deref(self.db); @@ -795,23 +797,23 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { unreachable!() } } - ast::Expr::Name(name) => self.resolve_name(name).into(), - ast::Expr::Path(path) => self.resolve_path(path, expr.into()).into(), + hir_def::Expr::Name(name) => self.resolve_name(name).into(), + hir_def::Expr::Path(path) => self.resolve_path(path, expr.into()).into(), _ => self.lower_expr_to_value(expr).into(), } } /// Returns the pre-adjustment type of the given `Expr` - fn expr_ty(&self, expr: &Node) -> TypeId { + fn expr_ty(&self, expr: &Node) -> TypeId { let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ; self.lower_analyzer_type(analyzer_ty) } fn lower_bool_op( &mut self, - op: ast::BoolOperator, - lhs: &Node, - rhs: &Node, + op: hir_def::BoolOperator, + lhs: &Node, + rhs: &Node, ty: TypeId, ) -> InstId { let true_bb = self.builder.make_block(); @@ -824,7 +826,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { .declare(Local::tmp_local(format!("${op}_tmp").into(), ty)); match op { - ast::BoolOperator::And => { + hir_def::BoolOperator::And => { self.builder .branch(lhs, true_bb, false_bb, SourceInfo::dummy()); @@ -840,7 +842,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { self.builder.jump(merge_bb, SourceInfo::dummy()); } - ast::BoolOperator::Or => { + hir_def::BoolOperator::Or => { self.builder .branch(lhs, true_bb, false_bb, SourceInfo::dummy()); @@ -863,40 +865,40 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { fn lower_binop( &mut self, - op: ast::BinOperator, + op: hir_def::BinOp, lhs: ValueId, rhs: ValueId, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { match op { - ast::BinOperator::Add => self.builder.add(lhs, rhs, source), - ast::BinOperator::Sub => self.builder.sub(lhs, rhs, source), - ast::BinOperator::Mult => self.builder.mul(lhs, rhs, source), - ast::BinOperator::Div => self.builder.div(lhs, rhs, source), - ast::BinOperator::Mod => self.builder.modulo(lhs, rhs, source), - ast::BinOperator::Pow => self.builder.pow(lhs, rhs, source), - ast::BinOperator::LShift => self.builder.shl(lhs, rhs, source), - ast::BinOperator::RShift => self.builder.shr(lhs, rhs, source), - ast::BinOperator::BitOr => self.builder.bit_or(lhs, rhs, source), - ast::BinOperator::BitXor => self.builder.bit_xor(lhs, rhs, source), - ast::BinOperator::BitAnd => self.builder.bit_and(lhs, rhs, source), + hir_def::BinOp::Add => self.builder.add(lhs, rhs), + hir_def::BinOp::Sub => self.builder.sub(lhs, rhs), + hir_def::BinOp::Mult => self.builder.mul(lhs, rhs), + hir_def::BinOp::Div => self.builder.div(lhs, rhs), + hir_def::BinOp::Mod => self.builder.modulo(lhs, rhs), + hir_def::BinOp::Pow => self.builder.pow(lhs, rhs), + hir_def::BinOp::LShift => self.builder.shl(lhs, rhs), + hir_def::BinOp::RShift => self.builder.shr(lhs, rhs), + hir_def::BinOp::BitOr => self.builder.bit_or(lhs, rhs), + hir_def::BinOp::BitXor => self.builder.bit_xor(lhs, rhs), + hir_def::BinOp::BitAnd => self.builder.bit_and(lhs, rhs), } } fn lower_comp_op( &mut self, - op: ast::CompOperator, + op: hir_def::CompBinOp, lhs: ValueId, rhs: ValueId, - source: SourceInfo, + // source: SourceInfo, ) -> InstId { match op { - ast::CompOperator::Eq => self.builder.eq(lhs, rhs, source), - ast::CompOperator::NotEq => self.builder.ne(lhs, rhs, source), - ast::CompOperator::Lt => self.builder.lt(lhs, rhs, source), - ast::CompOperator::LtE => self.builder.le(lhs, rhs, source), - ast::CompOperator::Gt => self.builder.gt(lhs, rhs, source), - ast::CompOperator::GtE => self.builder.ge(lhs, rhs, source), + hir_def::CompBinOp::Eq => self.builder.eq(lhs, rhs), + hir_def::CompBinOp::NotEq => self.builder.ne(lhs, rhs), + hir_def::CompBinOp::Lt => self.builder.lt(lhs, rhs), + hir_def::CompBinOp::LtE => self.builder.le(lhs, rhs), + hir_def::CompBinOp::Gt => self.builder.gt(lhs, rhs), + hir_def::CompBinOp::GtE => self.builder.ge(lhs, rhs), } } @@ -943,9 +945,9 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { fn lower_call( &mut self, - func: &Node, - _generic_args: &Option>>, - args: &[Node], + func: &Node, + _generic_args: &Option>>, + args: &[Node], ty: TypeId, source: SourceInfo, ) -> InstId { @@ -1081,27 +1083,27 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } // FIXME: This is ugly hack to properly analyze method call. Remove this when https://github.com/ethereum/fe/issues/670 is resolved. - fn lower_method_receiver(&mut self, receiver: &Node) -> ValueId { + fn lower_method_receiver(&mut self, receiver: &Node) -> ValueId { match &receiver.kind { - ast::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), + hir_def::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), _ => unreachable!(), } } fn lower_aggregate_access( &mut self, - expr: &Node, + expr: &Node, indices: &mut Vec, ) -> ValueId { match &expr.kind { - ast::Expr::Attribute { value, attr } => { + hir_def::Expr::Attribute { value, attr } => { let index = self.expr_ty(value).index_from_fname(self.db, &attr.kind); let value = self.lower_aggregate_access(value, indices); indices.push(self.make_u256_imm(index)); value } - ast::Expr::Subscript { value, index } + hir_def::Expr::Subscript { value, index } if self.expr_ty(value).deref(self.db).is_aggregate(self.db) => { let value = self.lower_aggregate_access(value, indices); @@ -1181,7 +1183,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { /// Resolve a path appeared in an expression. /// NOTE: Don't call this to resolve method receiver. - fn resolve_path(&mut self, path: &ast::Path, source: SourceInfo) -> ValueId { + fn resolve_path(&mut self, path: &hir_def::Path, source: SourceInfo) -> ValueId { let func_id = self.builder.func_id(); let module = func_id.module(self.db); match module.resolve_path(self.db.upcast(), path).value.unwrap() { @@ -1318,7 +1320,7 @@ fn self_arg_source(db: &dyn MirDb, func: analyzer_items::FunctionId) -> SourceIn .kind .args .iter() - .find(|arg| matches!(arg.kind, ast::FunctionArg::Self_ { .. })) + .find(|arg| matches!(arg.kind, hir_def::FunctionArg::Self_ { .. })) .unwrap() .into() } @@ -1332,14 +1334,14 @@ fn arg_source(db: &dyn MirDb, func: analyzer_items::FunctionId, arg_name: &str) .args .iter() .find_map(|arg| match &arg.kind { - ast::FunctionArg::Regular { name, .. } => { + hir_def::FunctionArg::Regular { name, .. } => { if name.kind == arg_name { Some(name.into()) } else { None } } - ast::FunctionArg::Self_ { .. } => None, + hir_def::FunctionArg::Self_ { .. } => None, }) .unwrap() } diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs index 8382f804c1..7a5e6e50a3 100644 --- a/crates/mir2/src/lower/mod.rs +++ b/crates/mir2/src/lower/mod.rs @@ -1 +1,5 @@ -use hir::hir_def::Contract; +pub mod constant; +pub mod function; +pub mod types; + +mod pattern_match; From 518d7e5972b7d5180fb88bf6bd90d4c3be59b1c8 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Mon, 5 Feb 2024 12:36:36 -0700 Subject: [PATCH 20/22] hacking --- crates/library2/src/lib.rs | 2 +- crates/mir2/src/ir/value.rs | 4 +- crates/mir2/src/lib.rs | 2 +- crates/mir2/src/lower/constant.rs | 20 +- crates/mir2/src/lower/function.rs | 1749 ++++++++--------- crates/mir2/src/lower/mod.rs | 2 +- .../src/lower/pattern_match/decision_tree.rs | 7 - .../mir2/src/lower/pattern_match/tree_vis.rs | 32 +- crates/mir2/src/lower/types.rs | 388 ++-- 9 files changed, 1025 insertions(+), 1181 deletions(-) diff --git a/crates/library2/src/lib.rs b/crates/library2/src/lib.rs index d222b5a70d..48310e1eda 100644 --- a/crates/library2/src/lib.rs +++ b/crates/library2/src/lib.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; pub use ::include_dir; use common::{ diff --git a/crates/mir2/src/ir/value.rs b/crates/mir2/src/ir/value.rs index 26ed4063a1..4c31129a6e 100644 --- a/crates/mir2/src/ir/value.rs +++ b/crates/mir2/src/ir/value.rs @@ -1,11 +1,11 @@ -use hir::hir_def::{TypeId, TypeKind}; +use hir::hir_def::TypeId; use id_arena::Id; use num_bigint::BigInt; use smol_str::SmolStr; // use crate::db::MirDb; -use super::{constant::ConstantId, function::BodyDataStore, inst::InstId}; +use super::{constant::ConstantId, inst::InstId}; pub type ValueId = Id; diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index e482e67287..49ef940708 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -27,7 +27,7 @@ pub struct Jar( // mir_lower_struct_all_functions, // mir_lower_enum_all_functions, // mir_lowered_type, - mir_lowered_constant, + lower::constant::mir_lowered_constant, // mir_lowered_func_signature, // mir_lowered_monomorphized_func_signature, // mir_lowered_pseudo_monomorphized_func_signature, diff --git a/crates/mir2/src/lower/constant.rs b/crates/mir2/src/lower/constant.rs index edd204d7ff..9ba92820b6 100644 --- a/crates/mir2/src/lower/constant.rs +++ b/crates/mir2/src/lower/constant.rs @@ -1,14 +1,14 @@ use std::rc::Rc; +use hir::hir_def::TypeId; + use crate::{ - db::MirDb, - ir::{Constant, ConstantId, SourceInfo, TypeId}, + ir::{Constant, ConstantId}, + MirDb, }; -pub fn mir_lowered_constant( - db: &dyn MirDb, - analyzer_const: analyzer_items::ModuleConstantId, -) -> ConstantId { +#[salsa::tracked] +pub fn mir_lowered_constant(db: &dyn MirDb, analyzer_const: hir::hir_def::Const) -> ConstantId { let name = analyzer_const.name(db.upcast()); let value = analyzer_const.constant_value(db.upcast()).unwrap(); let ty = analyzer_const.typ(db.upcast()).unwrap(); @@ -17,23 +17,21 @@ pub fn mir_lowered_constant( let id = analyzer_const.node_id(db.upcast()); let ty = db.mir_lowered_type(ty); - let source = SourceInfo { span, id }; let constant = Constant { name, value: value.into(), ty, module_id, - source, }; db.mir_intern_const(constant.into()) } impl ConstantId { - pub fn data(self, db: &dyn MirDb) -> Rc { - db.lookup_mir_intern_const(self) - } + // pub fn data(self, db: &dyn MirDb) -> Rc { + // db.lookup_mir_intern_const(self) + // } pub fn ty(self, db: &dyn MirDb) -> TypeId { self.data(db).ty diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs index f3abdeb36c..7a0c8fb784 100644 --- a/crates/mir2/src/lower/function.rs +++ b/crates/mir2/src/lower/function.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, rc::Rc, vec}; +use std::rc::Rc; use fxhash::FxHashMap; use hir::hir_def::{self, TypeId}; @@ -8,75 +8,14 @@ use smol_str::SmolStr; use crate::{ ir::{ - self, - body_builder::BodyBuilder, - constant::ConstantValue, - function::Linkage, - inst::{CallType, InstKind}, - value::{AssignableValue, Local}, - BasicBlockId, Constant, FunctionBody, FunctionId, FunctionParam, FunctionSignature, InstId, - Value, ValueId, + body_builder::BodyBuilder, inst::InstKind, value::Local, BasicBlockId, FunctionBody, + FunctionId, FunctionParam, InstId, Value, ValueId, }, MirDb, }; type ScopeId = Id; -pub fn lower_func_signature(db: &dyn MirDb, func: analyzer_items::FunctionId) -> FunctionId { - lower_monomorphized_func_signature(db, func, BTreeMap::new()) -} -pub fn lower_monomorphized_func_signature( - db: &dyn MirDb, - func: analyzer_items::FunctionId, - resolved_generics: BTreeMap, -) -> FunctionId { - // TODO: Remove this when an analyzer's function signature contains `self` type. - let mut params = vec![]; - - if func.takes_self(db.upcast()) { - let self_ty = func.self_type(db.upcast()).unwrap(); - let source = self_arg_source(db, func); - params.push(make_param(db, "self", self_ty, source)); - } - let analyzer_signature = func.signature(db.upcast()); - - for param in analyzer_signature.params.iter() { - let source = arg_source(db, func, ¶m.name); - - let param_type = - if let Type::Generic(generic) = param.typ.clone().unwrap().deref_typ(db.upcast()) { - *resolved_generics.get(&generic.name).unwrap() - } else { - param.typ.clone().unwrap() - }; - - params.push(make_param(db, param.clone().name, param_type, source)) - } - - let return_type = db.mir_lowered_type(analyzer_signature.return_type.clone().unwrap()); - - let linkage = if func.is_public(db.upcast()) { - if func.is_contract_func(db.upcast()) && !func.is_constructor(db.upcast()) { - Linkage::Export - } else { - Linkage::Public - } - } else { - Linkage::Private - }; - - let sig = FunctionSignature { - params, - resolved_generics, - return_type: Some(return_type), - module_id: func.module(db.upcast()), - analyzer_func_id: func, - linkage, - }; - - db.mir_intern_function(sig.into()) -} - pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { let analyzer_func = func.analyzer_func(db); let ast = &analyzer_func.data(db.upcast()).ast; @@ -90,207 +29,193 @@ pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { pub(super) struct BodyLowerHelper<'db, 'a> { pub(super) db: &'db dyn MirDb, pub(super) builder: BodyBuilder, - ast: &'a Node, + ast: &'a hir_def::Func, func: FunctionId, - analyzer_body: &'a fe_analyzer2::context::FunctionBody, + // analyzer_body: &'a fe_analyzer2::context::FunctionBody, scopes: Arena, current_scope: ScopeId, } impl<'db, 'a> BodyLowerHelper<'db, 'a> { - pub(super) fn lower_stmt(&mut self, stmt: &Node) { - match &stmt.kind { - hir_def::FuncStmt::Return { value } => { - let value = if let Some(expr) = value { - self.lower_expr_to_value(expr) - } else { - self.make_unit() - }; - self.builder.ret(value, stmt.into()); - let next_block = self.builder.make_block(); - self.builder.move_to_block(next_block); - } - - hir_def::FuncStmt::VarDecl { target, value, .. } => { - self.lower_var_decl(target, value.as_ref(), stmt.into()); - } - - hir_def::FuncStmt::ConstantDecl { name, value, .. } => { - let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&name.id]); - - let value = self.analyzer_body.expressions[&value.id] - .const_value - .clone() - .unwrap(); - - let constant = - self.make_local_constant(name.kind.clone(), ty, value.into(), stmt.into()); - self.scope_mut().declare_var(&name.kind, constant); - } - - hir_def::FuncStmt::Assign { target, value } => { - let result = self.lower_assignable_value(target); - let (expr, _ty) = self.lower_expr(value); - self.builder.map_result(expr, result) - } - - hir_def::FuncStmt::AugAssign { target, op, value } => { - let result = self.lower_assignable_value(target); - let lhs = self.lower_expr_to_value(target); - let rhs = self.lower_expr_to_value(value); - - let inst = self.lower_binop(op.kind, lhs, rhs, stmt.into()); - self.builder.map_result(inst, result) - } - - hir_def::FuncStmt::For { target, iter, body } => { - self.lower_for_loop(target, iter, body) - } - - hir_def::FuncStmt::While { test, body } => { - let header_bb = self.builder.make_block(); - let exit_bb = self.builder.make_block(); - - let cond = self.lower_expr_to_value(test); - self.builder - .branch(cond, header_bb, exit_bb, SourceInfo::dummy()); - - // Lower while body. - self.builder.move_to_block(header_bb); - self.enter_loop_scope(header_bb, exit_bb); - for stmt in body { - self.lower_stmt(stmt); - } - let cond = self.lower_expr_to_value(test); - self.builder - .branch(cond, header_bb, exit_bb, SourceInfo::dummy()); - - self.leave_scope(); - - // Move to while exit bb. - self.builder.move_to_block(exit_bb); - } - - hir_def::FuncStmt::If { - test, - body, - or_else, - } => self.lower_if(test, body, or_else), - - hir_def::FuncStmt::Match { expr, arms } => { - let matrix = &self.analyzer_body.matches[&stmt.id]; - super::pattern_match::lower_match(self, matrix, expr, arms); - } - - hir_def::FuncStmt::Assert { test, msg } => { - let then_bb = self.builder.make_block(); - let false_bb = self.builder.make_block(); - - let cond = self.lower_expr_to_value(test); - self.builder - .branch(cond, then_bb, false_bb, SourceInfo::dummy()); - - self.builder.move_to_block(false_bb); - - let msg = match msg { - Some(msg) => self.lower_expr_to_value(msg), - None => self.make_u256_imm(1), - }; - self.builder.revert(Some(msg), stmt.into()); - self.builder.move_to_block(then_bb); - } - - hir_def::FuncStmt::Expr { value } => { - self.lower_expr_to_value(value); - } - - hir_def::FuncStmt::Break => { - let exit = self.scope().loop_exit(&self.scopes); - self.builder.jump(exit, stmt.into()); - let next_block = self.builder.make_block(); - self.builder.move_to_block(next_block); - } - - hir_def::FuncStmt::Continue => { - let entry = self.scope().loop_entry(&self.scopes); - if let Some(loop_idx) = self.scope().loop_idx(&self.scopes) { - let imm_one = self.make_u256_imm(1u32); - let inc = self.builder.add(loop_idx, imm_one, SourceInfo::dummy()); - self.builder.map_result(inc, loop_idx.into()); - let maximum_iter_count = self.scope().maximum_iter_count(&self.scopes).unwrap(); - let exit = self.scope().loop_exit(&self.scopes); - self.branch_eq(loop_idx, maximum_iter_count, exit, entry, stmt.into()); - } else { - self.builder.jump(entry, stmt.into()); - } - let next_block = self.builder.make_block(); - self.builder.move_to_block(next_block); - } - - hir_def::FuncStmt::Revert { error } => { - let error = error.as_ref().map(|err| self.lower_expr_to_value(err)); - self.builder.revert(error, stmt.into()); - let next_block = self.builder.make_block(); - self.builder.move_to_block(next_block); - } - - hir_def::FuncStmt::Unsafe(stmts) => { - self.enter_scope(); - for stmt in stmts { - self.lower_stmt(stmt) - } - self.leave_scope() - } - } + pub(super) fn lower_stmt(&mut self, stmt: &hir_def::Stmt) { + // match &stmt.kind { + // hir_def::Stmt::Return(value) => { + // let value = if let Some(expr) = value { + // self.lower_expr_to_value(expr) + // } else { + // self.make_unit() + // }; + // self.builder.ret(value, stmt.into()); + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::VarDecl { target, value, .. } => { + // self.lower_var_decl(target, value.as_ref(), stmt.into()); + // } + + // hir_def::Stmt::ConstantDecl { name, value, .. } => { + // let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&name.id]); + + // let value = self.analyzer_body.expressions[&value.id] + // .const_value + // .clone() + // .unwrap(); + + // let constant = + // self.make_local_constant(name.kind.clone(), ty, value.into(), stmt.into()); + // self.scope_mut().declare_var(&name.kind, constant); + // } + + // hir_def::Stmt::Assign(target, value) => { + // let result = self.lower_assignable_value(target); + // let (expr, _ty) = self.lower_expr(value); + // self.builder.map_result(expr, result) + // } + + // hir_def::Stmt::AugAssign { target, op, value } => { + // let result = self.lower_assignable_value(target); + // let lhs = self.lower_expr_to_value(target); + // let rhs = self.lower_expr_to_value(value); + + // let inst = self.lower_binop(op.kind, lhs, rhs, stmt.into()); + // self.builder.map_result(inst, result) + // } + + // hir_def::Stmt::For(target, iter, body) => self.lower_for_loop(target, iter, body), + + // hir_def::Stmt::While(test, body) => { + // let header_bb = self.builder.make_block(); + // let exit_bb = self.builder.make_block(); + + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, header_bb, exit_bb()); + + // // Lower while body. + // self.builder.move_to_block(header_bb); + // self.enter_loop_scope(header_bb, exit_bb); + // for stmt in body { + // self.lower_stmt(stmt); + // } + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, header_bb, exit_bb()); + + // self.leave_scope(); + + // // Move to while exit bb. + // self.builder.move_to_block(exit_bb); + // } + + // hir_def::Stmt::If { + // test, + // body, + // or_else, + // } => self.lower_if(test, body, or_else), + + // hir_def::Stmt::Match { expr, arms } => { + // let matrix = &self.analyzer_body.matches[&stmt.id]; + // super::pattern_match::lower_match(self, matrix, expr, arms); + // } + + // hir_def::Stmt::Assert { test, msg } => { + // let then_bb = self.builder.make_block(); + // let false_bb = self.builder.make_block(); + + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, then_bb, false_bb()); + + // self.builder.move_to_block(false_bb); + + // let msg = match msg { + // Some(msg) => self.lower_expr_to_value(msg), + // None => self.make_u256_imm(1), + // }; + // self.builder.revert(Some(msg), stmt.into()); + // self.builder.move_to_block(then_bb); + // } + + // hir_def::Stmt::Expr(value) => { + // self.lower_expr_to_value(value); + // } + + // hir_def::Stmt::Break => { + // let exit = self.scope().loop_exit(&self.scopes); + // self.builder.jump(exit, stmt.into()); + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::Continue => { + // let entry = self.scope().loop_entry(&self.scopes); + // if let Some(loop_idx) = self.scope().loop_idx(&self.scopes) { + // let imm_one = self.make_u256_imm(1u32); + // let inc = self.builder.add(loop_idx, imm_one()); + // self.builder.map_result(inc, loop_idx.into()); + // let maximum_iter_count = self.scope().maximum_iter_count(&self.scopes).unwrap(); + // let exit = self.scope().loop_exit(&self.scopes); + // self.branch_eq(loop_idx, maximum_iter_count, exit, entry, stmt.into()); + // } else { + // self.builder.jump(entry, stmt.into()); + // } + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::Revert { error } => { + // let error = error.as_ref().map(|err| self.lower_expr_to_value(err)); + // self.builder.revert(error, stmt.into()); + // let next_block = self.builder.make_block(); + // self.builder.move_to_block(next_block); + // } + + // hir_def::Stmt::Unsafe(stmts) => { + // self.enter_scope(); + // for stmt in stmts { + // self.lower_stmt(stmt) + // } + // self.leave_scope() + // } + // } + panic!() } - pub(super) fn lower_var_decl( - &mut self, - var: &Node, - init: Option<&Node>, - source: SourceInfo, - ) { - match &var.kind { - hir_def::VarDeclTarget::Name(name) => { - let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); - let value = self.declare_var(name, ty, var.into()); - if let Some(init) = init { - let (init, _init_ty) = self.lower_expr(init); - // debug_assert_eq!(ty.deref(self.db), init_ty, "vardecl init type mismatch: {} - // != {}", ty.as_string(self.db), - // init_ty.as_string(self.db)); - self.builder.map_result(init, value.into()); - } - } - - hir_def::VarDeclTarget::Tuple(decls) => { - if let Some(init) = init { - if let hir_def::Expr::Tuple { elts } = &init.kind { - debug_assert_eq!(decls.len(), elts.len()); - for (decl, init_elem) in decls.iter().zip(elts.iter()) { - self.lower_var_decl(decl, Some(init_elem), source.clone()); - } - } else { - let init_ty = self.expr_ty(init); - let init_value = self.lower_expr_to_value(init); - self.lower_var_decl_unpack(var, init_value, init_ty, source); - }; - } else { - for decl in decls { - self.lower_var_decl(decl, None, source.clone()) - } - } - } - } + pub(super) fn lower_var_decl(&mut self, var: &hir_def::PatId, init: Option<&hir_def::Expr>) { + // match &var.kind { + // hir_def::VarDeclTarget::Name(name) => { + // let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); + // let value = self.declare_var(name, ty, var.into()); + // if let Some(init) = init { + // let (init, _init_ty) = self.lower_expr(init); + // // debug_assert_eq!(ty.deref(self.db), init_ty, "vardecl init type mismatch: {} + // // != {}", ty.as_string(self.db), + // // init_ty.as_string(self.db)); + // self.builder.map_result(init, value.into()); + // } + // } + + // hir_def::VarDeclTarget::Tuple(decls) => { + // if let Some(init) = init { + // if let hir_def::Expr::Tuple(elts) = &init.kind { + // debug_assert_eq!(decls.len(), elts.len()); + // for (decl, init_elem) in decls.iter().zip(elts.iter()) { + // self.lower_var_decl(decl, Some(init_elem), source.clone()); + // } + // } else { + // let init_ty = self.expr_ty(init); + // let init_value = self.lower_expr_to_value(init); + // self.lower_var_decl_unpack(var, init_value, init_ty, source); + // }; + // } else { + // for decl in decls { + // self.lower_var_decl(decl, None, source.clone()) + // } + // } + // } + // } + panic!() } - pub(super) fn declare_var( - &mut self, - name: &SmolStr, - ty: TypeId, - // source: SourceInfo, - ) -> ValueId { - // let local = Local::user_local(name.clone(), ty, source); + pub(super) fn declare_var(&mut self, name: &SmolStr, ty: TypeId) -> ValueId { let local = Local::user_local(name.clone(), ty); let value = self.builder.declare(local); self.scope_mut().declare_var(name, value); @@ -299,210 +224,206 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { pub(super) fn lower_var_decl_unpack( &mut self, - var: &Node, + var: &hir_def::PatId, init: ValueId, init_ty: TypeId, - // source: SourceInfo, ) { - match &var.kind { - hir_def::VarDeclTarget::Name(name) => { - let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); - let local = Local::user_local(name.clone(), ty, var.into()); - - let lhs = self.builder.declare(local); - self.scope_mut().declare_var(name, lhs); - let bind = self.builder.bind(init, source); - self.builder.map_result(bind, lhs.into()); - } - - hir_def::VarDeclTarget::Tuple(decls) => { - for (index, decl) in decls.iter().enumerate() { - let elem_ty = init_ty.projection_ty_imm(self.db, index); - let index_value = self.make_u256_imm(index); - let elem_inst = - self.builder - .aggregate_access(init, vec![index_value], source.clone()); - let elem_value = self.map_to_tmp(elem_inst, elem_ty); - self.lower_var_decl_unpack(decl, elem_value, elem_ty, source.clone()) - } - } - } + // match &var.kind { + // hir_def::VarDeclTarget::Name(name) => { + // let ty = self.lower_analyzer_type(self.analyzer_body.var_types[&var.id]); + // let local = Local::user_local(name.clone(), ty, var.into()); + + // let lhs = self.builder.declare(local); + // self.scope_mut().declare_var(name, lhs); + // let bind = self.builder.bind(init, source); + // self.builder.map_result(bind, lhs.into()); + // } + + // hir_def::VarDeclTarget::Tuple(decls) => { + // for (index, decl) in decls.iter().enumerate() { + // let elem_ty = init_ty.projection_ty_imm(self.db, index); + // let index_value = self.make_u256_imm(index); + // let elem_inst = + // self.builder + // .aggregate_access(init, vec![index_value], source.clone()); + // let elem_value = self.map_to_tmp(elem_inst, elem_ty); + // self.lower_var_decl_unpack(decl, elem_value, elem_ty, source.clone()) + // } + // } + // } + panic!() } - pub(super) fn lower_expr(&mut self, expr: &Node) -> (InstId, TypeId) { - let mut ty = self.expr_ty(expr); - let mut inst = match &expr.kind { - hir_def::Expr::Ternary { - if_expr, - test, - else_expr, - } => { - let true_bb = self.builder.make_block(); - let false_bb = self.builder.make_block(); - let merge_bb = self.builder.make_block(); - - let tmp = self - .builder - .declare(Local::tmp_local("$ternary_tmp".into(), ty)); - - let cond = self.lower_expr_to_value(test); - self.builder - .branch(cond, true_bb, false_bb, SourceInfo::dummy()); - - self.builder.move_to_block(true_bb); - let (value, _) = self.lower_expr(if_expr); - self.builder.map_result(value, tmp.into()); - self.builder.jump(merge_bb, SourceInfo::dummy()); - - self.builder.move_to_block(false_bb); - let (value, _) = self.lower_expr(else_expr); - self.builder.map_result(value, tmp.into()); - self.builder.jump(merge_bb, SourceInfo::dummy()); - - self.builder.move_to_block(merge_bb); - self.builder.bind(tmp, SourceInfo::dummy()) - } - - hir_def::Expr::BoolOperation { left, op, right } => { - self.lower_bool_op(op.kind, left, right, ty) - } - - hir_def::Expr::BinOperation { left, op, right } => { - let lhs = self.lower_expr_to_value(left); - let rhs = self.lower_expr_to_value(right); - self.lower_binop(op.kind, lhs, rhs, expr.into()) - } - - hir_def::Expr::UnaryOperation { op, operand } => { - let value = self.lower_expr_to_value(operand); - match op.kind { - hir_def::UnaryOperator::Invert => self.builder.inv(value, expr.into()), - hir_def::UnaryOperator::Not => self.builder.not(value, expr.into()), - hir_def::UnaryOperator::USub => self.builder.neg(value, expr.into()), - } - } - - hir_def::Expr::CompOperation { left, op, right } => { - let lhs = self.lower_expr_to_value(left); - let rhs = self.lower_expr_to_value(right); - self.lower_comp_op(op.kind, lhs, rhs, expr.into()) - } - - hir_def::Expr::Attribute { .. } => { - let mut indices = vec![]; - let value = self.lower_aggregate_access(expr, &mut indices); - self.builder.aggregate_access(value, indices, expr.into()) - } - - hir_def::Expr::Subscript { value, index } => { - let value_ty = self.expr_ty(value).deref(self.db); - if value_ty.is_aggregate(self.db) { - let mut indices = vec![]; - let value = self.lower_aggregate_access(expr, &mut indices); - self.builder.aggregate_access(value, indices, expr.into()) - } else if value_ty.is_map(self.db) { - let value = self.lower_expr_to_value(value); - let key = self.lower_expr_to_value(index); - self.builder.map_access(value, key, expr.into()) - } else { - unreachable!() - } - } - - hir_def::Expr::Call { - func, - generic_args, - args, - } => { - let ty = self.expr_ty(expr); - self.lower_call(func, generic_args, &args.kind, ty, expr.into()) - } - - hir_def::Expr::List { elts } | hir_def::Expr::Tuple { elts } => { - let args = elts - .iter() - .map(|elem| self.lower_expr_to_value(elem)) - .collect(); - let ty = self.expr_ty(expr); - self.builder.aggregate_construct(ty, args, expr.into()) - } - - hir_def::Expr::Repeat { value, len: _ } => { - let array_type = if let Type::Array(array_type) = self.analyzer_body.expressions - [&expr.id] - .typ - .typ(self.db.upcast()) - { - array_type - } else { - panic!("not an array"); - }; - - let args = vec![self.lower_expr_to_value(value); array_type.size]; - let ty = self.expr_ty(expr); - self.builder.aggregate_construct(ty, args, expr.into()) - } - - hir_def::Expr::Bool(b) => { - let imm = self.builder.make_imm_from_bool(*b, ty); - self.builder.bind(imm, expr.into()) - } - - hir_def::Expr::Name(name) => { - let value = self.resolve_name(name); - self.builder.bind(value, expr.into()) - } - - hir_def::Expr::Path(path) => { - let value = self.resolve_path(path, expr.into()); - self.builder.bind(value, expr.into()) - } - - hir_def::Expr::Num(num) => { - let imm = Literal::new(num).parse().unwrap(); - let imm = self.builder.make_imm(imm, ty); - self.builder.bind(imm, expr.into()) - } - - hir_def::Expr::Str(s) => { - let ty = self.expr_ty(expr); - let const_value = self.make_local_constant( - "str_in_func".into(), - ty, - ConstantValue::Str(s.clone()), - expr.into(), - ); - self.builder.bind(const_value, expr.into()) - } - - hir_def::Expr::Unit => { - let value = self.make_unit(); - self.builder.bind(value, expr.into()) - } - }; - - for Adjustment { into, kind } in &self.analyzer_body.expressions[&expr.id].type_adjustments - { - let into_ty = self.lower_analyzer_type(*into); - - match kind { - AdjustmentKind::Copy => { - let val = self.inst_result_or_tmp(inst, ty); - inst = self.builder.mem_copy(val, expr.into()); - } - AdjustmentKind::Load => { - let val = self.inst_result_or_tmp(inst, ty); - inst = self.builder.load(val, expr.into()); - } - AdjustmentKind::IntSizeIncrease => { - let val = self.inst_result_or_tmp(inst, ty); - inst = self.builder.primitive_cast(val, into_ty, expr.into()) - } - AdjustmentKind::StringSizeIncrease => {} // XXX - } - ty = into_ty; - } - (inst, ty) + pub(super) fn lower_expr(&mut self, expr: &hir_def::Expr) -> (InstId, TypeId) { + // let mut ty = self.expr_ty(expr); + // let mut inst = match &expr.kind { + // hir_def::Expr::Ternary { + // if_expr, + // test, + // else_expr, + // } => { + // let true_bb = self.builder.make_block(); + // let false_bb = self.builder.make_block(); + // let merge_bb = self.builder.make_block(); + + // let tmp = self + // .builder + // .declare(Local::tmp_local("$ternary_tmp".into(), ty)); + + // let cond = self.lower_expr_to_value(test); + // self.builder.branch(cond, true_bb, false_bb()); + + // self.builder.move_to_block(true_bb); + // let (value, _) = self.lower_expr(if_expr); + // self.builder.map_result(value, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(false_bb); + // let (value, _) = self.lower_expr(else_expr); + // self.builder.map_result(value, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(merge_bb); + // self.builder.bind(tmp()) + // } + + // hir_def::Expr::BoolOperation { left, op, right } => { + // self.lower_bool_op(op.kind, left, right, ty) + // } + + // hir_def::Expr::BinOperation { left, op, right } => { + // let lhs = self.lower_expr_to_value(left); + // let rhs = self.lower_expr_to_value(right); + // self.lower_binop(op.kind, lhs, rhs, expr.into()) + // } + + // hir_def::Expr::UnaryOperation { op, operand } => { + // let value = self.lower_expr_to_value(operand); + // match op.kind { + // hir_def::UnOp::Invert => self.builder.inv(value, expr.into()), + // hir_def::UnOp::Not => self.builder.not(value, expr.into()), + // hir_def::UnOp::USub => self.builder.neg(value, expr.into()), + // } + // } + + // hir_def::Expr::CompOperation { left, op, right } => { + // let lhs = self.lower_expr_to_value(left); + // let rhs = self.lower_expr_to_value(right); + // self.lower_comp_op(op.kind, lhs, rhs, expr.into()) + // } + + // hir_def::Expr::Attribute { .. } => { + // let mut indices = vec![]; + // let value = self.lower_aggregate_access(expr, &mut indices); + // self.builder.aggregate_access(value, indices, expr.into()) + // } + + // hir_def::Expr::Subscript { value, index } => { + // let value_ty = self.expr_ty(value).deref(self.db); + // if value_ty.is_aggregate(self.db) { + // let mut indices = vec![]; + // let value = self.lower_aggregate_access(expr, &mut indices); + // self.builder.aggregate_access(value, indices, expr.into()) + // } else if value_ty.is_map(self.db) { + // let value = self.lower_expr_to_value(value); + // let key = self.lower_expr_to_value(index); + // self.builder.map_access(value, key, expr.into()) + // } else { + // unreachable!() + // } + // } + + // hir_def::Expr::Call(func, generic_args, args) => { + // let ty = self.expr_ty(expr); + // self.lower_call(func, generic_args, &args.kind, ty, expr.into()) + // } + + // hir_def::Expr::List { elts } | hir_def::Expr::Tuple { elts } => { + // let args = elts + // .iter() + // .map(|elem| self.lower_expr_to_value(elem)) + // .collect(); + // let ty = self.expr_ty(expr); + // self.builder.aggregate_construct(ty, args, expr.into()) + // } + + // hir_def::Expr::Repeat { value, len: _ } => { + // let array_type = if let Type::Array(array_type) = self.analyzer_body.expressions + // [&expr.id] + // .typ + // .typ(self.db.upcast()) + // { + // array_type + // } else { + // panic!("not an array"); + // }; + + // let args = vec![self.lower_expr_to_value(value); array_type.size]; + // let ty = self.expr_ty(expr); + // self.builder.aggregate_construct(ty, args, expr.into()) + // } + + // hir_def::Expr::Bool(b) => { + // let imm = self.builder.make_imm_from_bool(*b, ty); + // self.builder.bind(imm, expr.into()) + // } + + // hir_def::Expr::Name(name) => { + // let value = self.resolve_name(name); + // self.builder.bind(value, expr.into()) + // } + + // hir_def::Expr::Path(path) => { + // let value = self.resolve_path(path, expr.into()); + // self.builder.bind(value, expr.into()) + // } + + // hir_def::Expr::Num(num) => { + // let imm = Literal::new(num).parse().unwrap(); + // let imm = self.builder.make_imm(imm, ty); + // self.builder.bind(imm, expr.into()) + // } + + // hir_def::Expr::Str(s) => { + // let ty = self.expr_ty(expr); + // let const_value = self.make_local_constant( + // "str_in_func".into(), + // ty, + // ConstantValue::Str(s.clone()), + // expr.into(), + // ); + // self.builder.bind(const_value, expr.into()) + // } + + // hir_def::Expr::Unit => { + // let value = self.make_unit(); + // self.builder.bind(value, expr.into()) + // } + // }; + + // for Adjustment { into, kind } in &self.analyzer_body.expressions[&expr.id].type_adjustments + // { + // let into_ty = self.lower_analyzer_type(*into); + + // match kind { + // AdjustmentKind::Copy => { + // let val = self.inst_result_or_tmp(inst, ty); + // inst = self.builder.mem_copy(val, expr.into()); + // } + // AdjustmentKind::Load => { + // let val = self.inst_result_or_tmp(inst, ty); + // inst = self.builder.load(val, expr.into()); + // } + // AdjustmentKind::IntSizeIncrease => { + // let val = self.inst_result_or_tmp(inst, ty); + // inst = self.builder.primitive_cast(val, into_ty, expr.into()) + // } + // AdjustmentKind::StringSizeIncrease => {} // XXX + // } + // ty = into_ty; + // } + // (inst, ty) + panic!() } fn inst_result_or_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { @@ -512,7 +433,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { .unwrap_or_else(|| self.map_to_tmp(inst, ty)) } - pub(super) fn lower_expr_to_value(&mut self, expr: &Node) -> ValueId { + pub(super) fn lower_expr_to_value(&mut self, expr: &hir_def::Expr) -> ValueId { let (inst, ty) = self.lower_expr(expr); self.map_to_tmp(inst, ty) } @@ -540,7 +461,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { pub(super) fn map_to_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { match &self.builder.inst_data(inst).kind { - InstKind::Bind { src } => { + &InstKind::Bind { src } => { let value = *src; self.builder.remove_inst(inst); value @@ -554,47 +475,47 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } - fn new( - db: &'db dyn MirDb, - func: FunctionId, - ast: &'a Node, - analyzer_body: &'a fe_analyzer2::context::FunctionBody, - ) -> Self { - let mut builder = BodyBuilder::new(func, ast.into()); - let mut scopes = Arena::new(); - - // Make a root scope. A root scope collects function parameters and module - // constants. - let root = Scope::root(db, func, &mut builder); - let current_scope = scopes.alloc(root); - Self { - db, - builder, - ast, - func, - analyzer_body, - scopes, - current_scope, - } - } - - fn lower_analyzer_type(&self, analyzer_ty: analyzer_types::TypeId) -> TypeId { - // If the analyzer type is generic we first need to resolve it to its concrete - // type before lowering to a MIR type - if let analyzer_types::Type::Generic(generic) = analyzer_ty.deref_typ(self.db.upcast()) { - let resolved_type = self - .func - .signature(self.db) - .resolved_generics - .get(&generic.name) - .cloned() - .expect("expected generic to be resolved"); - - return self.db.mir_lowered_type(resolved_type); - } - - self.db.mir_lowered_type(analyzer_ty) - } + // fn new( + // db: &'db dyn MirDb, + // func: FunctionId, + // ast: &'a Node, + // analyzer_body: &'a fe_analyzer2::context::FunctionBody, + // ) -> Self { + // let mut builder = BodyBuilder::new(func, ast.into()); + // let mut scopes = Arena::new(); + + // // Make a root scope. A root scope collects function parameters and module + // // constants. + // let root = Scope::root(db, func, &mut builder); + // let current_scope = scopes.alloc(root); + // Self { + // db, + // builder, + // ast, + // func, + // analyzer_body, + // scopes, + // current_scope, + // } + // } + + // fn lower_analyzer_type(&self, analyzer_ty: analyzer_types::TypeId) -> TypeId { + // // If the analyzer type is generic we first need to resolve it to its concrete + // // type before lowering to a MIR type + // if let analyzer_types::Type::Generic(generic) = analyzer_ty.deref_typ(self.db.upcast()) { + // let resolved_type = self + // .func + // .signature(self.db) + // .resolved_generics + // .get(&generic.name) + // .cloned() + // .expect("expected generic to be resolved"); + + // return self.db.mir_lowered_type(resolved_type); + // } + + // self.db.mir_lowered_type(analyzer_ty) + // } fn lower(mut self) -> FunctionBody { for stmt in &self.ast.kind.body { @@ -604,7 +525,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let last_block = self.builder.current_block(); if !self.builder.is_block_terminated(last_block) { let unit = self.make_unit(); - self.builder.ret(unit, SourceInfo::dummy()); + self.builder.ret(unit()); } self.builder.build() @@ -616,28 +537,21 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { v2: ValueId, true_bb: BasicBlockId, false_bb: BasicBlockId, - source: SourceInfo, ) { - let cond = self.builder.eq(v1, v2, source.clone()); + let cond = self.builder.eq(v1, v2); let bool_ty = self.bool_ty(); let cond = self.map_to_tmp(cond, bool_ty); - self.builder.branch(cond, true_bb, false_bb, source); + self.builder.branch(cond, true_bb, false_bb); } - fn lower_if( - &mut self, - cond: &Node, - then: &[Node], - else_: &[Node], - ) { + fn lower_if(&mut self, cond: &hir_def::Expr, then: &[hir_def::Stmt], else_: &[hir_def::Stmt]) { let cond = self.lower_expr_to_value(cond); if else_.is_empty() { let then_bb = self.builder.make_block(); let merge_bb = self.builder.make_block(); - self.builder - .branch(cond, then_bb, merge_bb, SourceInfo::dummy()); + self.builder.branch(cond, then_bb, merge_bb()); // Lower then block. self.builder.move_to_block(then_bb); @@ -645,15 +559,14 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { for stmt in then { self.lower_stmt(stmt); } - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); self.builder.move_to_block(merge_bb); self.leave_scope(); } else { let then_bb = self.builder.make_block(); let else_bb = self.builder.make_block(); - self.builder - .branch(cond, then_bb, else_bb, SourceInfo::dummy()); + self.builder.branch(cond, then_bb, else_bb()); // Lower then block. self.builder.move_to_block(then_bb); @@ -676,11 +589,11 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { let merge_bb = self.builder.make_block(); if !self.builder.is_block_terminated(then_block_end_bb) { self.builder.move_to_block(then_block_end_bb); - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); } if !self.builder.is_block_terminated(else_block_end_bb) { self.builder.move_to_block(else_block_end_bb); - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); } self.builder.move_to_block(merge_bb); } @@ -690,130 +603,116 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { // TODO: Desugar to `loop` + `match` like rustc in HIR to generate better MIR. fn lower_for_loop( &mut self, - loop_variable: &Node, - iter: &Node, - body: &[Node], + loop_variable: &hir_def::IdentId, + iter: &hir_def::Expr, + body: &[hir_def::Stmt], ) { - let preheader_bb = self.builder.make_block(); - let entry_bb = self.builder.make_block(); - let exit_bb = self.builder.make_block(); - - let iter_elem_ty = self.analyzer_body.var_types[&loop_variable.id]; - let iter_elem_ty = self.lower_analyzer_type(iter_elem_ty); - - self.builder.jump(preheader_bb, SourceInfo::dummy()); - - // `For` has its scope from preheader block. - self.enter_loop_scope(entry_bb, exit_bb); - - /* Lower preheader. */ - self.builder.move_to_block(preheader_bb); - - // Declare loop_variable. - let loop_value = self.builder.declare(Local::user_local( - loop_variable.kind.clone(), - iter_elem_ty, - loop_variable.into(), - )); - self.scope_mut() - .declare_var(&loop_variable.kind, loop_value); - - // Declare and initialize `loop_idx` to 0. - let loop_idx = Local::tmp_local("$loop_idx_tmp".into(), self.u256_ty()); - let loop_idx = self.builder.declare(loop_idx); - let imm_zero = self.make_u256_imm(0u32); - let imm_zero = self.builder.bind(imm_zero, SourceInfo::dummy()); - self.builder.map_result(imm_zero, loop_idx.into()); - - // Evaluates loop variable. - let iter_ty = self.expr_ty(iter); - let iter = self.lower_expr_to_value(iter); - - // Create maximum loop count. - let maximum_iter_count = match &iter_ty.deref(self.db).data(self.db).kind { - ir::TypeKind::Array(ir::types::ArrayDef { len, .. }) => *len, - _ => unreachable!(), - }; - let maximum_iter_count = self.make_u256_imm(maximum_iter_count); - self.branch_eq( - loop_idx, - maximum_iter_count, - exit_bb, - entry_bb, - SourceInfo::dummy(), - ); - self.scope_mut().loop_idx = Some(loop_idx); - self.scope_mut().maximum_iter_count = Some(maximum_iter_count); - - /* Lower body. */ - self.builder.move_to_block(entry_bb); - - // loop_variable = array[loop_idx] - let iter_elem = self - .builder - .aggregate_access(iter, vec![loop_idx], SourceInfo::dummy()); - self.builder - .map_result(iter_elem, AssignableValue::Value(loop_value)); - - for stmt in body { - self.lower_stmt(stmt); - } - - // loop_idx += 1 - let imm_one = self.make_u256_imm(1u32); - let inc = self.builder.add(loop_idx, imm_one, SourceInfo::dummy()); - self.builder - .map_result(inc, AssignableValue::Value(loop_idx)); - self.branch_eq( - loop_idx, - maximum_iter_count, - exit_bb, - entry_bb, - SourceInfo::dummy(), - ); - - /* Move to exit bb */ - self.leave_scope(); - self.builder.move_to_block(exit_bb); + // let preheader_bb = self.builder.make_block(); + // let entry_bb = self.builder.make_block(); + // let exit_bb = self.builder.make_block(); + + // let iter_elem_ty = self.analyzer_body.var_types[&loop_variable.id]; + // let iter_elem_ty = self.lower_analyzer_type(iter_elem_ty); + + // self.builder.jump(preheader_bb()); + + // // `For` has its scope from preheader block. + // self.enter_loop_scope(entry_bb, exit_bb); + + // /* Lower preheader. */ + // self.builder.move_to_block(preheader_bb); + + // // Declare loop_variable. + // let loop_value = self.builder.declare(Local::user_local( + // loop_variable.kind.clone(), + // iter_elem_ty, + // loop_variable.into(), + // )); + // self.scope_mut() + // .declare_var(&loop_variable.kind, loop_value); + + // // Declare and initialize `loop_idx` to 0. + // let loop_idx = Local::tmp_local("$loop_idx_tmp".into(), self.u256_ty()); + // let loop_idx = self.builder.declare(loop_idx); + // let imm_zero = self.make_u256_imm(0u32); + // let imm_zero = self.builder.bind(imm_zero()); + // self.builder.map_result(imm_zero, loop_idx.into()); + + // // Evaluates loop variable. + // let iter_ty = self.expr_ty(iter); + // let iter = self.lower_expr_to_value(iter); + + // // Create maximum loop count. + // let maximum_iter_count = match &iter_ty.deref(self.db).data(self.db).kind { + // ir::TypeKind::Array(ir::types::ArrayDef { len, .. }) => *len, + // _ => unreachable!(), + // }; + // let maximum_iter_count = self.make_u256_imm(maximum_iter_count); + // self.branch_eq(loop_idx, maximum_iter_count, exit_bb, entry_bb); + // self.scope_mut().loop_idx = Some(loop_idx); + // self.scope_mut().maximum_iter_count = Some(maximum_iter_count); + + // /* Lower body. */ + // self.builder.move_to_block(entry_bb); + + // // loop_variable = array[loop_idx] + // let iter_elem = self.builder.aggregate_access(iter, vec![loop_idx]()); + // self.builder + // .map_result(iter_elem, AssignableValue::Value(loop_value)); + + // for stmt in body { + // self.lower_stmt(stmt); + // } + + // // loop_idx += 1 + // let imm_one = self.make_u256_imm(1u32); + // let inc = self.builder.add(loop_idx, imm_one()); + // self.builder + // .map_result(inc, AssignableValue::Value(loop_idx)); + // self.branch_eq(loop_idx, maximum_iter_count, exit_bb, entry_bb); + + // /* Move to exit bb */ + // self.leave_scope(); + // self.builder.move_to_block(exit_bb); } - fn lower_assignable_value(&mut self, expr: &Node) -> AssignableValue { - match &expr.kind { - hir_def::Expr::Attribute { value, attr } => { - let idx = self.expr_ty(value).index_from_fname(self.db, &attr.kind); - let idx = self.make_u256_imm(idx); - let lhs = self.lower_assignable_value(value).into(); - AssignableValue::Aggregate { lhs, idx } - } - hir_def::Expr::Subscript { value, index } => { - let lhs = self.lower_assignable_value(value).into(); - let attr = self.lower_expr_to_value(index); - let value_ty = self.expr_ty(value).deref(self.db); - if value_ty.is_aggregate(self.db) { - AssignableValue::Aggregate { lhs, idx: attr } - } else if value_ty.is_map(self.db) { - AssignableValue::Map { lhs, key: attr } - } else { - unreachable!() - } - } - hir_def::Expr::Name(name) => self.resolve_name(name).into(), - hir_def::Expr::Path(path) => self.resolve_path(path, expr.into()).into(), - _ => self.lower_expr_to_value(expr).into(), - } - } + // fn lower_assignable_value(&mut self, expr: &hir_def::Expr) -> AssignableValue { + // match &expr.kind { + // hir_def::Expr::Attribute { value, attr } => { + // let idx = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + // let idx = self.make_u256_imm(idx); + // let lhs = self.lower_assignable_value(value).into(); + // AssignableValue::Aggregate { lhs, idx } + // } + // hir_def::Expr::Subscript { value, index } => { + // let lhs = self.lower_assignable_value(value).into(); + // let attr = self.lower_expr_to_value(index); + // let value_ty = self.expr_ty(value).deref(self.db); + // if value_ty.is_aggregate(self.db) { + // AssignableValue::Aggregate { lhs, idx: attr } + // } else if value_ty.is_map(self.db) { + // AssignableValue::Map { lhs, key: attr } + // } else { + // unreachable!() + // } + // } + // hir_def::Expr::Name(name) => self.resolve_name(name).into(), + // hir_def::Expr::Path(path) => self.resolve_path(path, expr.into()).into(), + // _ => self.lower_expr_to_value(expr).into(), + // } + // } /// Returns the pre-adjustment type of the given `Expr` - fn expr_ty(&self, expr: &Node) -> TypeId { + fn expr_ty(&self, expr: &hir_def::Expr) -> TypeId { let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ; self.lower_analyzer_type(analyzer_ty) } fn lower_bool_op( &mut self, - op: hir_def::BoolOperator, - lhs: &Node, - rhs: &Node, + op: hir_def::LogicalBinOp, + lhs: &hir_def::Expr, + rhs: &hir_def::Expr, ty: TypeId, ) -> InstId { let true_bb = self.builder.make_block(); @@ -826,41 +725,39 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { .declare(Local::tmp_local(format!("${op}_tmp").into(), ty)); match op { - hir_def::BoolOperator::And => { - self.builder - .branch(lhs, true_bb, false_bb, SourceInfo::dummy()); + hir_def::LogicalBinOp::And => { + self.builder.branch(lhs, true_bb, false_bb()); self.builder.move_to_block(true_bb); let (rhs, _rhs_ty) = self.lower_expr(rhs); self.builder.map_result(rhs, tmp.into()); - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); self.builder.move_to_block(false_bb); let false_imm = self.builder.make_imm_from_bool(false, ty); - let false_imm_copy = self.builder.bind(false_imm, SourceInfo::dummy()); + let false_imm_copy = self.builder.bind(false_imm()); self.builder.map_result(false_imm_copy, tmp.into()); - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); } - hir_def::BoolOperator::Or => { - self.builder - .branch(lhs, true_bb, false_bb, SourceInfo::dummy()); + hir_def::LogicalBinOp::Or => { + self.builder.branch(lhs, true_bb, false_bb()); self.builder.move_to_block(true_bb); let true_imm = self.builder.make_imm_from_bool(true, ty); - let true_imm_copy = self.builder.bind(true_imm, SourceInfo::dummy()); + let true_imm_copy = self.builder.bind(true_imm()); self.builder.map_result(true_imm_copy, tmp.into()); - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); self.builder.move_to_block(false_bb); let (rhs, _rhs_ty) = self.lower_expr(rhs); self.builder.map_result(rhs, tmp.into()); - self.builder.jump(merge_bb, SourceInfo::dummy()); + self.builder.jump(merge_bb()); } } self.builder.move_to_block(merge_bb); - self.builder.bind(tmp, SourceInfo::dummy()) + self.builder.bind(tmp()) } fn lower_binop( @@ -902,188 +799,184 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } - fn resolve_generics_args( - &mut self, - method: &analyzer_items::FunctionId, - args: &[Id], - ) -> BTreeMap { - method - .signature(self.db.upcast()) - .params - .iter() - .zip(args.iter().map(|val| { - self.builder - .value_ty(*val) - .analyzer_ty(self.db) - .expect("invalid parameter") - })) - .filter_map(|(param, typ)| { - if let Type::Generic(generic) = - param.typ.clone().unwrap().deref_typ(self.db.upcast()) - { - Some((generic.name, typ)) - } else { - None - } - }) - .collect::>() - } - - fn lower_function_id( - &mut self, - function: &analyzer_items::FunctionId, - args: &[Id], - ) -> FunctionId { - let resolved_generics = self.resolve_generics_args(function, args); - if function.is_generic(self.db.upcast()) { - self.db - .mir_lowered_monomorphized_func_signature(*function, resolved_generics) - } else { - self.db.mir_lowered_func_signature(*function) - } - } + // fn resolve_generics_args( + // &mut self, + // method: &analyzer_items::FunctionId, + // args: &[Id], + // ) -> BTreeMap { + // method + // .signature(self.db.upcast()) + // .params + // .iter() + // .zip(args.iter().map(|val| { + // self.builder + // .value_ty(*val) + // .analyzer_ty(self.db) + // .expect("invalid parameter") + // })) + // .filter_map(|(param, typ)| { + // if let Type::Generic(generic) = + // param.typ.clone().unwrap().deref_typ(self.db.upcast()) + // { + // Some((generic.name, typ)) + // } else { + // None + // } + // }) + // .collect::>() + // } + + // fn lower_function_id(&mut self, function: &hir_def::Func, args: &[Id]) -> FunctionId { + // let resolved_generics = self.resolve_generics_args(function, args); + // if function.is_generic(self.db.upcast()) { + // self.db + // .mir_lowered_monomorphized_func_signature(*function, resolved_generics) + // } else { + // self.db.mir_lowered_func_signature(*function) + // } + // } fn lower_call( &mut self, - func: &Node, - _generic_args: &Option>>, - args: &[Node], + func: &hir_def::Expr, + _generic_args: &Option>, + args: &[hir_def::CallArg], ty: TypeId, - source: SourceInfo, ) -> InstId { - let call_type = &self.analyzer_body.calls[&func.id]; - - let mut args: Vec<_> = args - .iter() - .map(|arg| self.lower_expr_to_value(&arg.kind.value)) - .collect(); - - match call_type { - AnalyzerCallType::BuiltinFunction(GlobalFunction::Keccak256) => { - self.builder.keccak256(args[0], source) - } - - AnalyzerCallType::Intrinsic(intrinsic) => { - self.builder - .yul_intrinsic((*intrinsic).into(), args, source) - } - - AnalyzerCallType::BuiltinValueMethod { method, .. } => { - let arg = self.lower_method_receiver(func); - match method { - ValueMethod::ToMem => self.builder.mem_copy(arg, source), - ValueMethod::AbiEncode => self.builder.abi_encode(arg, source), - } - } - - // We ignores `args[0]', which represents `context` and not used for now. - AnalyzerCallType::BuiltinAssociatedFunction { contract, function } => match function { - ContractTypeMethod::Create => self.builder.create(args[1], *contract, source), - ContractTypeMethod::Create2 => { - self.builder.create2(args[1], args[2], *contract, source) - } - }, - - AnalyzerCallType::AssociatedFunction { function, .. } - | AnalyzerCallType::Pure(function) => { - let func_id = self.lower_function_id(function, &args); - self.builder.call(func_id, args, CallType::Internal, source) - } - - AnalyzerCallType::ValueMethod { method, .. } => { - let mut method_args = vec![self.lower_method_receiver(func)]; - let func_id = self.lower_function_id(method, &args); - - method_args.append(&mut args); - - self.builder - .call(func_id, method_args, CallType::Internal, source) - } - AnalyzerCallType::TraitValueMethod { - trait_id, method, .. - } if trait_id.is_std_trait(self.db.upcast(), EMITTABLE_TRAIT_NAME) - && method.name(self.db.upcast()) == EMIT_FN_NAME => - { - let event = self.lower_method_receiver(func); - self.builder.emit(event, source) - } - AnalyzerCallType::TraitValueMethod { - method, - trait_id, - generic_type, - .. - } => { - let mut method_args = vec![self.lower_method_receiver(func)]; - method_args.append(&mut args); - - let concrete_type = self - .func - .signature(self.db) - .resolved_generics - .get(&generic_type.name) - .cloned() - .expect("unresolved generic type"); - - let impl_ = concrete_type - .get_impl_for(self.db.upcast(), *trait_id) - .expect("missing impl"); - - let function = impl_ - .function(self.db.upcast(), &method.name(self.db.upcast())) - .expect("missing function"); - - let func_id = self.db.mir_lowered_func_signature(function); - self.builder - .call(func_id, method_args, CallType::Internal, source) - } - AnalyzerCallType::External { function, .. } => { - let receiver = self.lower_method_receiver(func); - debug_assert!(self.builder.value_ty(receiver).is_address(self.db)); - - let mut method_args = vec![receiver]; - method_args.append(&mut args); - let func_id = self.db.mir_lowered_func_signature(*function); - self.builder - .call(func_id, method_args, CallType::External, source) - } - - AnalyzerCallType::TypeConstructor(to_ty) => { - if to_ty.is_string(self.db.upcast()) { - let arg = *args.last().unwrap(); - self.builder.mem_copy(arg, source) - } else if ty.is_primitive(self.db) { - // TODO: Ignore `ctx` for now. - let arg = *args.last().unwrap(); - let arg_ty = self.builder.value_ty(arg); - if arg_ty == ty { - self.builder.bind(arg, source) - } else { - debug_assert!(!arg_ty.is_ptr(self.db)); // Should be explicitly `Load`ed - self.builder.primitive_cast(arg, ty, source) - } - } else if ty.is_aggregate(self.db) { - self.builder.aggregate_construct(ty, args, source) - } else { - unreachable!() - } - } - - AnalyzerCallType::EnumConstructor(variant) => { - let tag_type = ty.enum_disc_type(self.db); - let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); - let data_ty = ty.enum_variant_type(self.db, *variant); - let enum_args = if data_ty.is_unit(self.db) { - vec![tag, self.make_unit()] - } else { - std::iter::once(tag).chain(args).collect() - }; - self.builder.aggregate_construct(ty, enum_args, source) - } - } + // let call_type = &self.analyzer_body.calls[&func.id]; + + // let mut args: Vec<_> = args + // .iter() + // .map(|arg| self.lower_expr_to_value(&arg.kind.value)) + // .collect(); + + // match call_type { + // AnalyzerCallType::BuiltinFunction(GlobalFunction::Keccak256) => { + // self.builder.keccak256(args[0], source) + // } + + // AnalyzerCallType::Intrinsic(intrinsic) => { + // self.builder + // .yul_intrinsic((*intrinsic).into(), args, source) + // } + + // AnalyzerCallType::BuiltinValueMethod { method, .. } => { + // let arg = self.lower_method_receiver(func); + // match method { + // ValueMethod::ToMem => self.builder.mem_copy(arg, source), + // ValueMethod::AbiEncode => self.builder.abi_encode(arg, source), + // } + // } + + // // We ignores `args[0]', which represents `context` and not used for now. + // AnalyzerCallType::BuiltinAssociatedFunction { contract, function } => match function { + // ContractTypeMethod::Create => self.builder.create(args[1], *contract, source), + // ContractTypeMethod::Create2 => { + // self.builder.create2(args[1], args[2], *contract, source) + // } + // }, + + // AnalyzerCallType::AssociatedFunction { function, .. } + // | AnalyzerCallType::Pure(function) => { + // let func_id = self.lower_function_id(function, &args); + // self.builder.call(func_id, args, CallType::Internal, source) + // } + + // AnalyzerCallType::ValueMethod { method, .. } => { + // let mut method_args = vec![self.lower_method_receiver(func)]; + // let func_id = self.lower_function_id(method, &args); + + // method_args.append(&mut args); + + // self.builder + // .call(func_id, method_args, CallType::Internal, source) + // } + // AnalyzerCallType::TraitValueMethod { + // trait_id, method, .. + // } if trait_id.is_std_trait(self.db.upcast(), EMITTABLE_TRAIT_NAME) + // && method.name(self.db.upcast()) == EMIT_FN_NAME => + // { + // let event = self.lower_method_receiver(func); + // self.builder.emit(event, source) + // } + // AnalyzerCallType::TraitValueMethod { + // method, + // trait_id, + // generic_type, + // .. + // } => { + // let mut method_args = vec![self.lower_method_receiver(func)]; + // method_args.append(&mut args); + + // let concrete_type = self + // .func + // .signature(self.db) + // .resolved_generics + // .get(&generic_type.name) + // .cloned() + // .expect("unresolved generic type"); + + // let impl_ = concrete_type + // .get_impl_for(self.db.upcast(), *trait_id) + // .expect("missing impl"); + + // let function = impl_ + // .function(self.db.upcast(), &method.name(self.db.upcast())) + // .expect("missing function"); + + // let func_id = self.db.mir_lowered_func_signature(function); + // self.builder + // .call(func_id, method_args, CallType::Internal, source) + // } + // AnalyzerCallType::External { function, .. } => { + // let receiver = self.lower_method_receiver(func); + // debug_assert!(self.builder.value_ty(receiver).is_address(self.db)); + + // let mut method_args = vec![receiver]; + // method_args.append(&mut args); + // let func_id = self.db.mir_lowered_func_signature(*function); + // self.builder + // .call(func_id, method_args, CallType::External, source) + // } + + // AnalyzerCallType::TypeConstructor(to_ty) => { + // if to_ty.is_string(self.db.upcast()) { + // let arg = *args.last().unwrap(); + // self.builder.mem_copy(arg, source) + // } else if ty.is_primitive(self.db) { + // // TODO: Ignore `ctx` for now. + // let arg = *args.last().unwrap(); + // let arg_ty = self.builder.value_ty(arg); + // if arg_ty == ty { + // self.builder.bind(arg, source) + // } else { + // debug_assert!(!arg_ty.is_ptr(self.db)); // Should be explicitly `Load`ed + // self.builder.primitive_cast(arg, ty, source) + // } + // } else if ty.is_aggregate(self.db) { + // self.builder.aggregate_construct(ty, args, source) + // } else { + // unreachable!() + // } + // } + + // AnalyzerCallType::EnumConstructor(variant) => { + // let tag_type = ty.enum_disc_type(self.db); + // let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); + // let data_ty = ty.enum_variant_type(self.db, *variant); + // let enum_args = if data_ty.is_unit(self.db) { + // vec![tag, self.make_unit()] + // } else { + // std::iter::once(tag).chain(args).collect() + // }; + // self.builder.aggregate_construct(ty, enum_args, source) + // } + // } + todo!(); } // FIXME: This is ugly hack to properly analyze method call. Remove this when https://github.com/ethereum/fe/issues/670 is resolved. - fn lower_method_receiver(&mut self, receiver: &Node) -> ValueId { + fn lower_method_receiver(&mut self, receiver: &hir_def::Expr) -> ValueId { match &receiver.kind { hir_def::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), _ => unreachable!(), @@ -1092,7 +985,7 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { fn lower_aggregate_access( &mut self, - expr: &Node, + expr: &hir_def::Expr, indices: &mut Vec, ) -> ValueId { match &expr.kind { @@ -1115,97 +1008,97 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { } } - fn make_unit(&mut self) -> ValueId { - let unit_ty = analyzer_types::TypeId::unit(self.db.upcast()); - let unit_ty = self.db.mir_lowered_type(unit_ty); - self.builder.make_unit(unit_ty) - } - - fn make_local_constant( - &mut self, - name: SmolStr, - ty: TypeId, - value: ConstantValue, - source: SourceInfo, - ) -> ValueId { - let function_id = self.builder.func_id(); - let constant = Constant { - name, - value, - ty, - module_id: function_id.module(self.db), - source, - }; - - let constant_id = self.db.mir_intern_const(constant.into()); - self.builder.make_constant(constant_id, ty) - } - - fn u256_ty(&mut self) -> TypeId { - self.db - .mir_intern_type(ir::Type::new(ir::TypeKind::U256, None).into()) - } - - fn bool_ty(&mut self) -> TypeId { - self.db - .mir_intern_type(ir::Type::new(ir::TypeKind::Bool, None).into()) - } - - fn enter_loop_scope(&mut self, entry: BasicBlockId, exit: BasicBlockId) { - let new_scope = Scope::loop_scope(self.current_scope, entry, exit); - self.current_scope = self.scopes.alloc(new_scope); - } - - /// Resolve a name appeared in an expression. - /// NOTE: Don't call this to resolve method receiver. - fn resolve_name(&mut self, name: &str) -> ValueId { - if let Some(value) = self.scopes[self.current_scope].resolve_name(&self.scopes, name) { - // Name is defined in local. - value - } else { - // Name is defined in global. - let func_id = self.builder.func_id(); - let module = func_id.module(self.db); - let constant = match module - .resolve_name(self.db.upcast(), name) - .unwrap() - .unwrap() - { - NamedThing::Item(analyzer_items::Item::Constant(id)) => { - self.db.mir_lowered_constant(id) - } - _ => panic!("name defined in global must be constant"), - }; - let ty = constant.ty(self.db); - self.builder.make_constant(constant, ty) - } - } - - /// Resolve a path appeared in an expression. - /// NOTE: Don't call this to resolve method receiver. - fn resolve_path(&mut self, path: &hir_def::Path, source: SourceInfo) -> ValueId { - let func_id = self.builder.func_id(); - let module = func_id.module(self.db); - match module.resolve_path(self.db.upcast(), path).value.unwrap() { - NamedThing::Item(analyzer_items::Item::Constant(id)) => { - let constant = self.db.mir_lowered_constant(id); - let ty = constant.ty(self.db); - self.builder.make_constant(constant, ty) - } - NamedThing::EnumVariant(variant) => { - let enum_ty = self - .db - .mir_lowered_type(variant.parent(self.db.upcast()).as_type(self.db.upcast())); - let tag_type = enum_ty.enum_disc_type(self.db); - let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); - let data = self.make_unit(); - let enum_args = vec![tag, data]; - let inst = self.builder.aggregate_construct(enum_ty, enum_args, source); - self.map_to_tmp(inst, enum_ty) - } - _ => panic!("path defined in global must be constant"), - } - } + // fn make_unit(&mut self) -> ValueId { + // let unit_ty = analyzer_types::TypeId::unit(self.db.upcast()); + // let unit_ty = self.db.mir_lowered_type(unit_ty); + // self.builder.make_unit(unit_ty) + // } + + // fn make_local_constant( + // &mut self, + // name: SmolStr, + // ty: TypeId, + // value: ConstantValue, + // source: SourceInfo, + // ) -> ValueId { + // let function_id = self.builder.func_id(); + // let constant = Constant { + // name, + // value, + // ty, + // module_id: function_id.module(self.db), + // source, + // }; + + // let constant_id = self.db.mir_intern_const(constant.into()); + // self.builder.make_constant(constant_id, ty) + // } + + // fn u256_ty(&mut self) -> TypeId { + // self.db + // .mir_intern_type(ir::Type::new(ir::TypeKind::U256, None).into()) + // } + + // fn bool_ty(&mut self) -> TypeId { + // self.db + // .mir_intern_type(ir::Type::new(ir::TypeKind::Bool, None).into()) + // } + + // fn enter_loop_scope(&mut self, entry: BasicBlockId, exit: BasicBlockId) { + // let new_scope = Scope::loop_scope(self.current_scope, entry, exit); + // self.current_scope = self.scopes.alloc(new_scope); + // } + + // /// Resolve a name appeared in an expression. + // /// NOTE: Don't call this to resolve method receiver. + // fn resolve_name(&mut self, name: &str) -> ValueId { + // if let Some(value) = self.scopes[self.current_scope].resolve_name(&self.scopes, name) { + // // Name is defined in local. + // value + // } else { + // // Name is defined in global. + // let func_id = self.builder.func_id(); + // let module = func_id.module(self.db); + // let constant = match module + // .resolve_name(self.db.upcast(), name) + // .unwrap() + // .unwrap() + // { + // NamedThing::Item(analyzer_items::Item::Constant(id)) => { + // self.db.mir_lowered_constant(id) + // } + // _ => panic!("name defined in global must be constant"), + // }; + // let ty = constant.ty(self.db); + // self.builder.make_constant(constant, ty) + // } + // } + + // /// Resolve a path appeared in an expression. + // /// NOTE: Don't call this to resolve method receiver. + // fn resolve_path(&mut self, path: &hir_def::Path, source: SourceInfo) -> ValueId { + // let func_id = self.builder.func_id(); + // let module = func_id.module(self.db); + // match module.resolve_path(self.db.upcast(), path).value.unwrap() { + // NamedThing::Item(analyzer_items::Item::Constant(id)) => { + // let constant = self.db.mir_lowered_constant(id); + // let ty = constant.ty(self.db); + // self.builder.make_constant(constant, ty) + // } + // NamedThing::EnumVariant(variant) => { + // let enum_ty = self + // .db + // .mir_lowered_type(variant.parent(self.db.upcast()).as_type(self.db.upcast())); + // let tag_type = enum_ty.enum_disc_type(self.db); + // let tag = self.make_imm(variant.disc(self.db.upcast()), tag_type); + // let data = self.make_unit(); + // let enum_args = vec![tag, data]; + // let inst = self.builder.aggregate_construct(enum_ty, enum_args, source); + // self.map_to_tmp(inst, enum_ty) + // } + // _ => panic!("path defined in global must be constant"), + // } + // } fn scope(&self) -> &Scope { &self.scopes[self.current_scope] @@ -1240,7 +1133,7 @@ impl Scope { // Declare function parameters. for param in &func.signature(db).params { - let local = Local::arg_local(param.name.clone(), param.ty, param.source.clone()); + let local = Local::arg_local(param.name.clone(), param.ty); let value_id = builder.store_func_arg(local); root.declare_var(¶m.name, value_id) } @@ -1312,49 +1205,9 @@ impl Scope { } } -fn self_arg_source(db: &dyn MirDb, func: analyzer_items::FunctionId) -> SourceInfo { - func.data(db.upcast()) - .ast - .kind - .sig - .kind - .args - .iter() - .find(|arg| matches!(arg.kind, hir_def::FunctionArg::Self_ { .. })) - .unwrap() - .into() -} - -fn arg_source(db: &dyn MirDb, func: analyzer_items::FunctionId, arg_name: &str) -> SourceInfo { - func.data(db.upcast()) - .ast - .kind - .sig - .kind - .args - .iter() - .find_map(|arg| match &arg.kind { - hir_def::FunctionArg::Regular { name, .. } => { - if name.kind == arg_name { - Some(name.into()) - } else { - None - } - } - hir_def::FunctionArg::Self_ { .. } => None, - }) - .unwrap() -} - -fn make_param( - db: &dyn MirDb, - name: impl Into, - ty: analyzer_types::TypeId, - source: SourceInfo, -) -> FunctionParam { +fn make_param(db: &dyn MirDb, name: impl Into, ty: TypeId) -> FunctionParam { FunctionParam { name: name.into(), ty: db.mir_lowered_type(ty), - source, } } diff --git a/crates/mir2/src/lower/mod.rs b/crates/mir2/src/lower/mod.rs index 7a5e6e50a3..7161774caf 100644 --- a/crates/mir2/src/lower/mod.rs +++ b/crates/mir2/src/lower/mod.rs @@ -2,4 +2,4 @@ pub mod constant; pub mod function; pub mod types; -mod pattern_match; +// mod pattern_match; diff --git a/crates/mir2/src/lower/pattern_match/decision_tree.rs b/crates/mir2/src/lower/pattern_match/decision_tree.rs index fb9d93ed7a..5ab0bca5ef 100644 --- a/crates/mir2/src/lower/pattern_match/decision_tree.rs +++ b/crates/mir2/src/lower/pattern_match/decision_tree.rs @@ -3,13 +3,6 @@ //! The algorithm for efficient decision tree construction is mainly based on [Compiling pattern matching to good decision trees](https://dl.acm.org/doi/10.1145/1411304.1411311). use std::io; -use fe_analyzer2::{ - pattern_analysis::{ - ConstructorKind, PatternMatrix, PatternRowVec, SigmaSet, SimplifiedPattern, - SimplifiedPatternKind, - }, - AnalyzerDb, -}; use indexmap::IndexMap; use smol_str::SmolStr; diff --git a/crates/mir2/src/lower/pattern_match/tree_vis.rs b/crates/mir2/src/lower/pattern_match/tree_vis.rs index d13d0d3921..30c94c3494 100644 --- a/crates/mir2/src/lower/pattern_match/tree_vis.rs +++ b/crates/mir2/src/lower/pattern_match/tree_vis.rs @@ -1,8 +1,8 @@ use std::fmt::Write; use dot2::{label::Text, Id}; -use fe_analyzer2::{pattern_analysis::ConstructorKind, AnalyzerDb}; use fxhash::FxHashMap; +use hir::HirDb; use indexmap::IndexMap; use smol_str::SmolStr; @@ -11,12 +11,12 @@ use super::decision_tree::{Case, DecisionTree, LeafNode, Occurrence, SwitchNode} pub(super) struct TreeRenderer<'db> { nodes: Vec, edges: FxHashMap<(usize, usize), Case>, - db: &'db dyn AnalyzerDb, + db: &'db dyn HirDb, } impl<'db> TreeRenderer<'db> { #[allow(unused)] - pub(super) fn new(db: &'db dyn AnalyzerDb, tree: &DecisionTree) -> Self { + pub(super) fn new(db: &'db dyn HirDb, tree: &DecisionTree) -> Self { let mut renderer = Self { nodes: Vec::new(), edges: FxHashMap::default(), @@ -88,19 +88,19 @@ impl<'db> dot2::Labeller<'db> for TreeRenderer<'db> { Ok(Text::LabelStr(label.into())) } - fn edge_label(&self, e: &Self::Edge) -> Text<'db> { - let label = match &self.edges[e] { - Case::Ctor(ConstructorKind::Enum(variant)) => { - variant.name_with_parent(self.db).to_string() - } - Case::Ctor(ConstructorKind::Tuple(_)) => "()".to_string(), - Case::Ctor(ConstructorKind::Struct(sid)) => sid.name(self.db).into(), - Case::Ctor(ConstructorKind::Literal((lit, _))) => lit.to_string(), - Case::Default => "_".into(), - }; - - Text::LabelStr(label.into()) - } + // fn edge_label(&self, e: &Self::Edge) -> Text<'db> { + // let label = match &self.edges[e] { + // Case::Ctor(ConstructorKind ::Enum(variant)) => { + // variant.name_with_parent(self.db).to_string() + // } + // Case::Ctor(ConstructorKind::Tuple(_)) => "()".to_string(), + // Case::Ctor(ConstructorKind::Struct(sid)) => sid.name(self.db).into(), + // Case::Ctor(ConstructorKind::Literal((lit, _))) => lit.to_string(), + // Case::Default => "_".into(), + // }; + + // Text::LabelStr(label.into()) + // } } impl<'db> dot2::GraphWalk<'db> for TreeRenderer<'db> { diff --git a/crates/mir2/src/lower/types.rs b/crates/mir2/src/lower/types.rs index 7072eaa96b..924841aa1c 100644 --- a/crates/mir2/src/lower/types.rs +++ b/crates/mir2/src/lower/types.rs @@ -1,194 +1,194 @@ -use crate::{ - db::MirDb, - ir::{ - types::{ArrayDef, EnumDef, EnumVariant, MapDef, StructDef, TupleDef}, - Type, TypeId, TypeKind, - }, -}; - -use fe_analyzer::namespace::{ - items as analyzer_items, - types::{self as analyzer_types, TraitOrType}, -}; - -pub fn lower_type(db: &dyn MirDb, analyzer_ty: analyzer_types::TypeId) -> TypeId { - let ty_kind = match analyzer_ty.typ(db.upcast()) { - analyzer_types::Type::SPtr(inner) => TypeKind::SPtr(lower_type(db, inner)), - - // NOTE: this results in unexpected MIR TypeId inequalities - // (when different analyzer types map to the same MIR type). - // We could (should?) remove .analyzer_ty from Type. - analyzer_types::Type::Mut(inner) => match inner.typ(db.upcast()) { - analyzer_types::Type::SPtr(t) => TypeKind::SPtr(lower_type(db, t)), - analyzer_types::Type::Base(t) => lower_base(t), - analyzer_types::Type::Contract(_) => TypeKind::Address, - _ => TypeKind::MPtr(lower_type(db, inner)), - }, - analyzer_types::Type::SelfType(inner) => match inner { - TraitOrType::TypeId(id) => return lower_type(db, id), - TraitOrType::TraitId(_) => panic!("traits aren't lowered"), - }, - analyzer_types::Type::Base(base) => lower_base(base), - analyzer_types::Type::Array(arr) => lower_array(db, &arr), - analyzer_types::Type::Map(map) => lower_map(db, &map), - analyzer_types::Type::Tuple(tup) => lower_tuple(db, &tup), - analyzer_types::Type::String(string) => TypeKind::String(string.max_size), - analyzer_types::Type::Contract(_) => TypeKind::Address, - analyzer_types::Type::SelfContract(contract) => lower_contract(db, contract), - analyzer_types::Type::Struct(struct_) => lower_struct(db, struct_), - analyzer_types::Type::Enum(enum_) => lower_enum(db, enum_), - analyzer_types::Type::Generic(_) => { - panic!("should be lowered in `lower_analyzer_type`") - } - }; - - intern_type(db, ty_kind, Some(analyzer_ty.deref(db.upcast()))) -} - -fn lower_base(base: analyzer_types::Base) -> TypeKind { - use analyzer_types::{Base, Integer}; - - match base { - Base::Numeric(int_ty) => match int_ty { - Integer::I8 => TypeKind::I8, - Integer::I16 => TypeKind::I16, - Integer::I32 => TypeKind::I32, - Integer::I64 => TypeKind::I64, - Integer::I128 => TypeKind::I128, - Integer::I256 => TypeKind::I256, - Integer::U8 => TypeKind::U8, - Integer::U16 => TypeKind::U16, - Integer::U32 => TypeKind::U32, - Integer::U64 => TypeKind::U64, - Integer::U128 => TypeKind::U128, - Integer::U256 => TypeKind::U256, - }, - - Base::Bool => TypeKind::Bool, - Base::Address => TypeKind::Address, - Base::Unit => TypeKind::Unit, - } -} - -fn lower_array(db: &dyn MirDb, arr: &analyzer_types::Array) -> TypeKind { - let len = arr.size; - let elem_ty = db.mir_lowered_type(arr.inner); - - let def = ArrayDef { elem_ty, len }; - TypeKind::Array(def) -} - -fn lower_map(db: &dyn MirDb, map: &analyzer_types::Map) -> TypeKind { - let key_ty = db.mir_lowered_type(map.key); - let value_ty = db.mir_lowered_type(map.value); - - let def = MapDef { key_ty, value_ty }; - TypeKind::Map(def) -} - -fn lower_tuple(db: &dyn MirDb, tup: &analyzer_types::Tuple) -> TypeKind { - let items = tup - .items - .iter() - .map(|item| db.mir_lowered_type(*item)) - .collect(); - - let def = TupleDef { items }; - TypeKind::Tuple(def) -} - -fn lower_contract(db: &dyn MirDb, contract: analyzer_items::ContractId) -> TypeKind { - let name = contract.name(db.upcast()); - - // Note: contract field types are wrapped in SPtr in TypeId::projection_ty - let fields = contract - .fields(db.upcast()) - .iter() - .map(|(fname, fid)| { - let analyzer_type = fid.typ(db.upcast()).unwrap(); - let ty = db.mir_lowered_type(analyzer_type); - (fname.clone(), ty) - }) - .collect(); - - // Obtain span. - let span = contract.span(db.upcast()); - - let module_id = contract.module(db.upcast()); - - let def = StructDef { - name, - fields, - span, - module_id, - }; - TypeKind::Contract(def) -} - -fn lower_struct(db: &dyn MirDb, id: analyzer_items::StructId) -> TypeKind { - let name = id.name(db.upcast()); - - // Lower struct fields. - let fields = id - .fields(db.upcast()) - .iter() - .map(|(fname, fid)| { - let analyzer_types = fid.typ(db.upcast()).unwrap(); - let ty = db.mir_lowered_type(analyzer_types); - (fname.clone(), ty) - }) - .collect(); - - // obtain span. - let span = id.span(db.upcast()); - - let module_id = id.module(db.upcast()); - - let def = StructDef { - name, - fields, - span, - module_id, - }; - TypeKind::Struct(def) -} - -fn lower_enum(db: &dyn MirDb, id: analyzer_items::EnumId) -> TypeKind { - let analyzer_variants = id.variants(db.upcast()); - let mut variants = Vec::with_capacity(analyzer_variants.len()); - for variant in analyzer_variants.values() { - let variant_ty = match variant.kind(db.upcast()).unwrap() { - analyzer_items::EnumVariantKind::Tuple(elts) => { - let tuple_ty = analyzer_types::TypeId::tuple(db.upcast(), &elts); - db.mir_lowered_type(tuple_ty) - } - analyzer_items::EnumVariantKind::Unit => { - let unit_ty = analyzer_types::TypeId::unit(db.upcast()); - db.mir_lowered_type(unit_ty) - } - }; - - variants.push(EnumVariant { - name: variant.name(db.upcast()), - span: variant.span(db.upcast()), - ty: variant_ty, - }); - } - - let def = EnumDef { - name: id.name(db.upcast()), - span: id.span(db.upcast()), - variants, - module_id: id.module(db.upcast()), - }; - - TypeKind::Enum(def) -} - -fn intern_type( - db: &dyn MirDb, - ty_kind: TypeKind, - analyzer_type: Option, -) -> TypeId { - db.mir_intern_type(Type::new(ty_kind, analyzer_type).into()) -} +// use crate::{ +// db::MirDb, +// ir::{ +// types::{ArrayDef, EnumDef, EnumVariant, MapDef, StructDef, TupleDef}, +// Type, TypeId, TypeKind, +// }, +// }; + +// use fe_analyzer::namespace::{ +// items as analyzer_items, +// types::{self as analyzer_types, TraitOrType}, +// }; + +// pub fn lower_type(db: &dyn MirDb, analyzer_ty: analyzer_types::TypeId) -> TypeId { +// let ty_kind = match analyzer_ty.typ(db.upcast()) { +// analyzer_types::Type::SPtr(inner) => TypeKind::SPtr(lower_type(db, inner)), + +// // NOTE: this results in unexpected MIR TypeId inequalities +// // (when different analyzer types map to the same MIR type). +// // We could (should?) remove .analyzer_ty from Type. +// analyzer_types::Type::Mut(inner) => match inner.typ(db.upcast()) { +// analyzer_types::Type::SPtr(t) => TypeKind::SPtr(lower_type(db, t)), +// analyzer_types::Type::Base(t) => lower_base(t), +// analyzer_types::Type::Contract(_) => TypeKind::Address, +// _ => TypeKind::MPtr(lower_type(db, inner)), +// }, +// analyzer_types::Type::SelfType(inner) => match inner { +// TraitOrType::TypeId(id) => return lower_type(db, id), +// TraitOrType::TraitId(_) => panic!("traits aren't lowered"), +// }, +// analyzer_types::Type::Base(base) => lower_base(base), +// analyzer_types::Type::Array(arr) => lower_array(db, &arr), +// analyzer_types::Type::Map(map) => lower_map(db, &map), +// analyzer_types::Type::Tuple(tup) => lower_tuple(db, &tup), +// analyzer_types::Type::String(string) => TypeKind::String(string.max_size), +// analyzer_types::Type::Contract(_) => TypeKind::Address, +// analyzer_types::Type::SelfContract(contract) => lower_contract(db, contract), +// analyzer_types::Type::Struct(struct_) => lower_struct(db, struct_), +// analyzer_types::Type::Enum(enum_) => lower_enum(db, enum_), +// analyzer_types::Type::Generic(_) => { +// panic!("should be lowered in `lower_analyzer_type`") +// } +// }; + +// intern_type(db, ty_kind, Some(analyzer_ty.deref(db.upcast()))) +// } + +// fn lower_base(base: analyzer_types::Base) -> TypeKind { +// use analyzer_types::{Base, Integer}; + +// match base { +// Base::Numeric(int_ty) => match int_ty { +// Integer::I8 => TypeKind::I8, +// Integer::I16 => TypeKind::I16, +// Integer::I32 => TypeKind::I32, +// Integer::I64 => TypeKind::I64, +// Integer::I128 => TypeKind::I128, +// Integer::I256 => TypeKind::I256, +// Integer::U8 => TypeKind::U8, +// Integer::U16 => TypeKind::U16, +// Integer::U32 => TypeKind::U32, +// Integer::U64 => TypeKind::U64, +// Integer::U128 => TypeKind::U128, +// Integer::U256 => TypeKind::U256, +// }, + +// Base::Bool => TypeKind::Bool, +// Base::Address => TypeKind::Address, +// Base::Unit => TypeKind::Unit, +// } +// } + +// fn lower_array(db: &dyn MirDb, arr: &analyzer_types::Array) -> TypeKind { +// let len = arr.size; +// let elem_ty = db.mir_lowered_type(arr.inner); + +// let def = ArrayDef { elem_ty, len }; +// TypeKind::Array(def) +// } + +// fn lower_map(db: &dyn MirDb, map: &analyzer_types::Map) -> TypeKind { +// let key_ty = db.mir_lowered_type(map.key); +// let value_ty = db.mir_lowered_type(map.value); + +// let def = MapDef { key_ty, value_ty }; +// TypeKind::Map(def) +// } + +// fn lower_tuple(db: &dyn MirDb, tup: &analyzer_types::Tuple) -> TypeKind { +// let items = tup +// .items +// .iter() +// .map(|item| db.mir_lowered_type(*item)) +// .collect(); + +// let def = TupleDef { items }; +// TypeKind::Tuple(def) +// } + +// fn lower_contract(db: &dyn MirDb, contract: analyzer_items::ContractId) -> TypeKind { +// let name = contract.name(db.upcast()); + +// // Note: contract field types are wrapped in SPtr in TypeId::projection_ty +// let fields = contract +// .fields(db.upcast()) +// .iter() +// .map(|(fname, fid)| { +// let analyzer_type = fid.typ(db.upcast()).unwrap(); +// let ty = db.mir_lowered_type(analyzer_type); +// (fname.clone(), ty) +// }) +// .collect(); + +// // Obtain span. +// let span = contract.span(db.upcast()); + +// let module_id = contract.module(db.upcast()); + +// let def = StructDef { +// name, +// fields, +// span, +// module_id, +// }; +// TypeKind::Contract(def) +// } + +// fn lower_struct(db: &dyn MirDb, id: analyzer_items::StructId) -> TypeKind { +// let name = id.name(db.upcast()); + +// // Lower struct fields. +// let fields = id +// .fields(db.upcast()) +// .iter() +// .map(|(fname, fid)| { +// let analyzer_types = fid.typ(db.upcast()).unwrap(); +// let ty = db.mir_lowered_type(analyzer_types); +// (fname.clone(), ty) +// }) +// .collect(); + +// // obtain span. +// let span = id.span(db.upcast()); + +// let module_id = id.module(db.upcast()); + +// let def = StructDef { +// name, +// fields, +// span, +// module_id, +// }; +// TypeKind::Struct(def) +// } + +// fn lower_enum(db: &dyn MirDb, id: analyzer_items::EnumId) -> TypeKind { +// let analyzer_variants = id.variants(db.upcast()); +// let mut variants = Vec::with_capacity(analyzer_variants.len()); +// for variant in analyzer_variants.values() { +// let variant_ty = match variant.kind(db.upcast()).unwrap() { +// analyzer_items::EnumVariantKind::Tuple(elts) => { +// let tuple_ty = analyzer_types::TypeId::tuple(db.upcast(), &elts); +// db.mir_lowered_type(tuple_ty) +// } +// analyzer_items::EnumVariantKind::Unit => { +// let unit_ty = analyzer_types::TypeId::unit(db.upcast()); +// db.mir_lowered_type(unit_ty) +// } +// }; + +// variants.push(EnumVariant { +// name: variant.name(db.upcast()), +// span: variant.span(db.upcast()), +// ty: variant_ty, +// }); +// } + +// let def = EnumDef { +// name: id.name(db.upcast()), +// span: id.span(db.upcast()), +// variants, +// module_id: id.module(db.upcast()), +// }; + +// TypeKind::Enum(def) +// } + +// fn intern_type( +// db: &dyn MirDb, +// ty_kind: TypeKind, +// analyzer_type: Option, +// ) -> TypeId { +// db.mir_intern_type(Type::new(ty_kind, analyzer_type).into()) +// } From bfc2a7d350a6375b1b72e02fbce07dc6350cdfd4 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Mon, 5 Feb 2024 16:50:35 -0700 Subject: [PATCH 21/22] hacking --- crates/mir2/src/lower/constant.rs | 35 ++- crates/mir2/src/lower/function.rs | 500 +++++++++++++++--------------- 2 files changed, 268 insertions(+), 267 deletions(-) diff --git a/crates/mir2/src/lower/constant.rs b/crates/mir2/src/lower/constant.rs index 9ba92820b6..7574ea0b98 100644 --- a/crates/mir2/src/lower/constant.rs +++ b/crates/mir2/src/lower/constant.rs @@ -9,23 +9,24 @@ use crate::{ #[salsa::tracked] pub fn mir_lowered_constant(db: &dyn MirDb, analyzer_const: hir::hir_def::Const) -> ConstantId { - let name = analyzer_const.name(db.upcast()); - let value = analyzer_const.constant_value(db.upcast()).unwrap(); - let ty = analyzer_const.typ(db.upcast()).unwrap(); - let module_id = analyzer_const.module(db.upcast()); - let span = analyzer_const.span(db.upcast()); - let id = analyzer_const.node_id(db.upcast()); - - let ty = db.mir_lowered_type(ty); - - let constant = Constant { - name, - value: value.into(), - ty, - module_id, - }; - - db.mir_intern_const(constant.into()) + // let name = analyzer_const.name(db.upcast()); + // let value = analyzer_const.constant_value(db.upcast()).unwrap(); + // let ty = analyzer_const.typ(db.upcast()).unwrap(); + // let module_id = analyzer_const.module(db.upcast()); + // let span = analyzer_const.span(db.upcast()); + // let id = analyzer_const.node_id(db.upcast()); + + // let ty = db.mir_lowered_type(ty); + + // let constant = Constant { + // name, + // value: value.into(), + // ty, + // module_id, + // }; + + // db.mir_intern_const(constant.into()) + panic!() } impl ConstantId { diff --git a/crates/mir2/src/lower/function.rs b/crates/mir2/src/lower/function.rs index 7a0c8fb784..567eb47984 100644 --- a/crates/mir2/src/lower/function.rs +++ b/crates/mir2/src/lower/function.rs @@ -16,15 +16,15 @@ use crate::{ type ScopeId = Id; -pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { - let analyzer_func = func.analyzer_func(db); - let ast = &analyzer_func.data(db.upcast()).ast; - let analyzer_body = analyzer_func.body(db.upcast()); - - BodyLowerHelper::new(db, func, ast, analyzer_body.as_ref()) - .lower() - .into() -} +// pub fn lower_func_body(db: &dyn MirDb, func: FunctionId) -> Rc { +// let analyzer_func = func.analyzer_func(db); +// let ast = &analyzer_func.data(db.upcast()).ast; +// let analyzer_body = analyzer_func.body(db.upcast()); + +// BodyLowerHelper::new(db, func, ast, analyzer_body.as_ref()) +// .lower() +// .into() +// } pub(super) struct BodyLowerHelper<'db, 'a> { pub(super) db: &'db dyn MirDb, @@ -426,17 +426,17 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { panic!() } - fn inst_result_or_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { - self.builder - .inst_result(inst) - .and_then(|r| r.value_id()) - .unwrap_or_else(|| self.map_to_tmp(inst, ty)) - } + // fn inst_result_or_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { + // self.builder + // .inst_result(inst) + // .and_then(|r| r.value_id()) + // .unwrap_or_else(|| self.map_to_tmp(inst, ty)) + // } - pub(super) fn lower_expr_to_value(&mut self, expr: &hir_def::Expr) -> ValueId { - let (inst, ty) = self.lower_expr(expr); - self.map_to_tmp(inst, ty) - } + // pub(super) fn lower_expr_to_value(&mut self, expr: &hir_def::Expr) -> ValueId { + // let (inst, ty) = self.lower_expr(expr); + // self.map_to_tmp(inst, ty) + // } pub(super) fn enter_scope(&mut self) { let new_scope = Scope::with_parent(self.current_scope); @@ -454,26 +454,26 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { }) } - pub(super) fn make_u256_imm(&mut self, value: impl Into) -> ValueId { - let u256_ty = self.u256_ty(); - self.make_imm(value, u256_ty) - } + // pub(super) fn make_u256_imm(&mut self, value: impl Into) -> ValueId { + // let u256_ty = self.u256_ty(); + // self.make_imm(value, u256_ty) + // } - pub(super) fn map_to_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { - match &self.builder.inst_data(inst).kind { - &InstKind::Bind { src } => { - let value = *src; - self.builder.remove_inst(inst); - value - } - _ => { - let tmp = Value::Temporary { inst, ty }; - let result = self.builder.make_value(tmp); - self.builder.map_result(inst, result.into()); - result - } - } - } + // pub(super) fn map_to_tmp(&mut self, inst: InstId, ty: TypeId) -> ValueId { + // match &self.builder.inst_data(inst).kind { + // &InstKind::Bind { src } => { + // let value = *src; + // self.builder.remove_inst(inst); + // value + // } + // _ => { + // let tmp = Value::Temporary { inst, ty }; + // let result = self.builder.make_value(tmp); + // self.builder.map_result(inst, result.into()); + // result + // } + // } + // } // fn new( // db: &'db dyn MirDb, @@ -517,87 +517,87 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { // self.db.mir_lowered_type(analyzer_ty) // } - fn lower(mut self) -> FunctionBody { - for stmt in &self.ast.kind.body { - self.lower_stmt(stmt) - } + // fn lower(mut self) -> FunctionBody { + // for stmt in &self.ast.kind.body { + // self.lower_stmt(stmt) + // } - let last_block = self.builder.current_block(); - if !self.builder.is_block_terminated(last_block) { - let unit = self.make_unit(); - self.builder.ret(unit()); - } + // let last_block = self.builder.current_block(); + // if !self.builder.is_block_terminated(last_block) { + // let unit = self.make_unit(); + // self.builder.ret(unit()); + // } - self.builder.build() - } + // self.builder.build() + // } - fn branch_eq( - &mut self, - v1: ValueId, - v2: ValueId, - true_bb: BasicBlockId, - false_bb: BasicBlockId, - ) { - let cond = self.builder.eq(v1, v2); - let bool_ty = self.bool_ty(); - let cond = self.map_to_tmp(cond, bool_ty); - self.builder.branch(cond, true_bb, false_bb); - } + // fn branch_eq( + // &mut self, + // v1: ValueId, + // v2: ValueId, + // true_bb: BasicBlockId, + // false_bb: BasicBlockId, + // ) { + // let cond = self.builder.eq(v1, v2); + // let bool_ty = self.bool_ty(); + // let cond = self.map_to_tmp(cond, bool_ty); + // self.builder.branch(cond, true_bb, false_bb); + // } - fn lower_if(&mut self, cond: &hir_def::Expr, then: &[hir_def::Stmt], else_: &[hir_def::Stmt]) { - let cond = self.lower_expr_to_value(cond); - - if else_.is_empty() { - let then_bb = self.builder.make_block(); - let merge_bb = self.builder.make_block(); - - self.builder.branch(cond, then_bb, merge_bb()); - - // Lower then block. - self.builder.move_to_block(then_bb); - self.enter_scope(); - for stmt in then { - self.lower_stmt(stmt); - } - self.builder.jump(merge_bb()); - self.builder.move_to_block(merge_bb); - self.leave_scope(); - } else { - let then_bb = self.builder.make_block(); - let else_bb = self.builder.make_block(); - - self.builder.branch(cond, then_bb, else_bb()); - - // Lower then block. - self.builder.move_to_block(then_bb); - self.enter_scope(); - for stmt in then { - self.lower_stmt(stmt); - } - self.leave_scope(); - let then_block_end_bb = self.builder.current_block(); - - // Lower else_block. - self.builder.move_to_block(else_bb); - self.enter_scope(); - for stmt in else_ { - self.lower_stmt(stmt); - } - self.leave_scope(); - let else_block_end_bb = self.builder.current_block(); - - let merge_bb = self.builder.make_block(); - if !self.builder.is_block_terminated(then_block_end_bb) { - self.builder.move_to_block(then_block_end_bb); - self.builder.jump(merge_bb()); - } - if !self.builder.is_block_terminated(else_block_end_bb) { - self.builder.move_to_block(else_block_end_bb); - self.builder.jump(merge_bb()); - } - self.builder.move_to_block(merge_bb); - } - } + // fn lower_if(&mut self, cond: &hir_def::Expr, then: &[hir_def::Stmt], else_: &[hir_def::Stmt]) { + // let cond = self.lower_expr_to_value(cond); + + // if else_.is_empty() { + // let then_bb = self.builder.make_block(); + // let merge_bb = self.builder.make_block(); + + // self.builder.branch(cond, then_bb, merge_bb()); + + // // Lower then block. + // self.builder.move_to_block(then_bb); + // self.enter_scope(); + // for stmt in then { + // self.lower_stmt(stmt); + // } + // self.builder.jump(merge_bb()); + // self.builder.move_to_block(merge_bb); + // self.leave_scope(); + // } else { + // let then_bb = self.builder.make_block(); + // let else_bb = self.builder.make_block(); + + // self.builder.branch(cond, then_bb, else_bb()); + + // // Lower then block. + // self.builder.move_to_block(then_bb); + // self.enter_scope(); + // for stmt in then { + // self.lower_stmt(stmt); + // } + // self.leave_scope(); + // let then_block_end_bb = self.builder.current_block(); + + // // Lower else_block. + // self.builder.move_to_block(else_bb); + // self.enter_scope(); + // for stmt in else_ { + // self.lower_stmt(stmt); + // } + // self.leave_scope(); + // let else_block_end_bb = self.builder.current_block(); + + // let merge_bb = self.builder.make_block(); + // if !self.builder.is_block_terminated(then_block_end_bb) { + // self.builder.move_to_block(then_block_end_bb); + // self.builder.jump(merge_bb()); + // } + // if !self.builder.is_block_terminated(else_block_end_bb) { + // self.builder.move_to_block(else_block_end_bb); + // self.builder.jump(merge_bb()); + // } + // self.builder.move_to_block(merge_bb); + // } + // } // NOTE: we assume a type of `iter` is array. // TODO: Desugar to `loop` + `match` like rustc in HIR to generate better MIR. @@ -702,102 +702,102 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { // } // } - /// Returns the pre-adjustment type of the given `Expr` - fn expr_ty(&self, expr: &hir_def::Expr) -> TypeId { - let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ; - self.lower_analyzer_type(analyzer_ty) - } + // /// Returns the pre-adjustment type of the given `Expr` + // fn expr_ty(&self, expr: &hir_def::Expr) -> TypeId { + // let analyzer_ty = self.analyzer_body.expressions[&expr.id].typ; + // self.lower_analyzer_type(analyzer_ty) + // } - fn lower_bool_op( - &mut self, - op: hir_def::LogicalBinOp, - lhs: &hir_def::Expr, - rhs: &hir_def::Expr, - ty: TypeId, - ) -> InstId { - let true_bb = self.builder.make_block(); - let false_bb = self.builder.make_block(); - let merge_bb = self.builder.make_block(); - - let lhs = self.lower_expr_to_value(lhs); - let tmp = self - .builder - .declare(Local::tmp_local(format!("${op}_tmp").into(), ty)); - - match op { - hir_def::LogicalBinOp::And => { - self.builder.branch(lhs, true_bb, false_bb()); - - self.builder.move_to_block(true_bb); - let (rhs, _rhs_ty) = self.lower_expr(rhs); - self.builder.map_result(rhs, tmp.into()); - self.builder.jump(merge_bb()); - - self.builder.move_to_block(false_bb); - let false_imm = self.builder.make_imm_from_bool(false, ty); - let false_imm_copy = self.builder.bind(false_imm()); - self.builder.map_result(false_imm_copy, tmp.into()); - self.builder.jump(merge_bb()); - } - - hir_def::LogicalBinOp::Or => { - self.builder.branch(lhs, true_bb, false_bb()); - - self.builder.move_to_block(true_bb); - let true_imm = self.builder.make_imm_from_bool(true, ty); - let true_imm_copy = self.builder.bind(true_imm()); - self.builder.map_result(true_imm_copy, tmp.into()); - self.builder.jump(merge_bb()); - - self.builder.move_to_block(false_bb); - let (rhs, _rhs_ty) = self.lower_expr(rhs); - self.builder.map_result(rhs, tmp.into()); - self.builder.jump(merge_bb()); - } - } + // fn lower_bool_op( + // &mut self, + // op: hir_def::LogicalBinOp, + // lhs: &hir_def::Expr, + // rhs: &hir_def::Expr, + // ty: TypeId, + // ) -> InstId { + // let true_bb = self.builder.make_block(); + // let false_bb = self.builder.make_block(); + // let merge_bb = self.builder.make_block(); + + // let lhs = self.lower_expr_to_value(lhs); + // let tmp = self + // .builder + // .declare(Local::tmp_local(format!("${op}_tmp").into(), ty)); + + // match op { + // hir_def::LogicalBinOp::And => { + // self.builder.branch(lhs, true_bb, false_bb()); + + // self.builder.move_to_block(true_bb); + // let (rhs, _rhs_ty) = self.lower_expr(rhs); + // self.builder.map_result(rhs, tmp.into()); + // self.builder.jump(merge_bb()); + + // self.builder.move_to_block(false_bb); + // let false_imm = self.builder.make_imm_from_bool(false, ty); + // let false_imm_copy = self.builder.bind(false_imm()); + // self.builder.map_result(false_imm_copy, tmp.into()); + // self.builder.jump(merge_bb()); + // } - self.builder.move_to_block(merge_bb); - self.builder.bind(tmp()) - } + // hir_def::LogicalBinOp::Or => { + // self.builder.branch(lhs, true_bb, false_bb()); - fn lower_binop( - &mut self, - op: hir_def::BinOp, - lhs: ValueId, - rhs: ValueId, - // source: SourceInfo, - ) -> InstId { - match op { - hir_def::BinOp::Add => self.builder.add(lhs, rhs), - hir_def::BinOp::Sub => self.builder.sub(lhs, rhs), - hir_def::BinOp::Mult => self.builder.mul(lhs, rhs), - hir_def::BinOp::Div => self.builder.div(lhs, rhs), - hir_def::BinOp::Mod => self.builder.modulo(lhs, rhs), - hir_def::BinOp::Pow => self.builder.pow(lhs, rhs), - hir_def::BinOp::LShift => self.builder.shl(lhs, rhs), - hir_def::BinOp::RShift => self.builder.shr(lhs, rhs), - hir_def::BinOp::BitOr => self.builder.bit_or(lhs, rhs), - hir_def::BinOp::BitXor => self.builder.bit_xor(lhs, rhs), - hir_def::BinOp::BitAnd => self.builder.bit_and(lhs, rhs), - } - } + // self.builder.move_to_block(true_bb); + // let true_imm = self.builder.make_imm_from_bool(true, ty); + // let true_imm_copy = self.builder.bind(true_imm()); + // self.builder.map_result(true_imm_copy, tmp.into()); + // self.builder.jump(merge_bb()); - fn lower_comp_op( - &mut self, - op: hir_def::CompBinOp, - lhs: ValueId, - rhs: ValueId, - // source: SourceInfo, - ) -> InstId { - match op { - hir_def::CompBinOp::Eq => self.builder.eq(lhs, rhs), - hir_def::CompBinOp::NotEq => self.builder.ne(lhs, rhs), - hir_def::CompBinOp::Lt => self.builder.lt(lhs, rhs), - hir_def::CompBinOp::LtE => self.builder.le(lhs, rhs), - hir_def::CompBinOp::Gt => self.builder.gt(lhs, rhs), - hir_def::CompBinOp::GtE => self.builder.ge(lhs, rhs), - } - } + // self.builder.move_to_block(false_bb); + // let (rhs, _rhs_ty) = self.lower_expr(rhs); + // self.builder.map_result(rhs, tmp.into()); + // self.builder.jump(merge_bb()); + // } + // } + + // self.builder.move_to_block(merge_bb); + // self.builder.bind(tmp()) + // } + + // fn lower_binop( + // &mut self, + // op: hir_def::BinOp, + // lhs: ValueId, + // rhs: ValueId, + // // source: SourceInfo, + // ) -> InstId { + // match op { + // hir_def::BinOp::Add => self.builder.add(lhs, rhs), + // hir_def::BinOp::Sub => self.builder.sub(lhs, rhs), + // hir_def::BinOp::Mult => self.builder.mul(lhs, rhs), + // hir_def::BinOp::Div => self.builder.div(lhs, rhs), + // hir_def::BinOp::Mod => self.builder.modulo(lhs, rhs), + // hir_def::BinOp::Pow => self.builder.pow(lhs, rhs), + // hir_def::BinOp::LShift => self.builder.shl(lhs, rhs), + // hir_def::BinOp::RShift => self.builder.shr(lhs, rhs), + // hir_def::BinOp::BitOr => self.builder.bit_or(lhs, rhs), + // hir_def::BinOp::BitXor => self.builder.bit_xor(lhs, rhs), + // hir_def::BinOp::BitAnd => self.builder.bit_and(lhs, rhs), + // } + // } + + // fn lower_comp_op( + // &mut self, + // op: hir_def::CompBinOp, + // lhs: ValueId, + // rhs: ValueId, + // // source: SourceInfo, + // ) -> InstId { + // match op { + // hir_def::CompBinOp::Eq => self.builder.eq(lhs, rhs), + // hir_def::CompBinOp::NotEq => self.builder.ne(lhs, rhs), + // hir_def::CompBinOp::Lt => self.builder.lt(lhs, rhs), + // hir_def::CompBinOp::LtE => self.builder.le(lhs, rhs), + // hir_def::CompBinOp::Gt => self.builder.gt(lhs, rhs), + // hir_def::CompBinOp::GtE => self.builder.ge(lhs, rhs), + // } + // } // fn resolve_generics_args( // &mut self, @@ -975,38 +975,38 @@ impl<'db, 'a> BodyLowerHelper<'db, 'a> { todo!(); } - // FIXME: This is ugly hack to properly analyze method call. Remove this when https://github.com/ethereum/fe/issues/670 is resolved. - fn lower_method_receiver(&mut self, receiver: &hir_def::Expr) -> ValueId { - match &receiver.kind { - hir_def::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), - _ => unreachable!(), - } - } + // // FIXME: This is ugly hack to properly analyze method call. Remove this when https://github.com/ethereum/fe/issues/670 is resolved. + // fn lower_method_receiver(&mut self, receiver: &hir_def::Expr) -> ValueId { + // match &receiver.kind { + // hir_def::Expr::Attribute { value, .. } => self.lower_expr_to_value(value), + // _ => unreachable!(), + // } + // } - fn lower_aggregate_access( - &mut self, - expr: &hir_def::Expr, - indices: &mut Vec, - ) -> ValueId { - match &expr.kind { - hir_def::Expr::Attribute { value, attr } => { - let index = self.expr_ty(value).index_from_fname(self.db, &attr.kind); - let value = self.lower_aggregate_access(value, indices); - indices.push(self.make_u256_imm(index)); - value - } - - hir_def::Expr::Subscript { value, index } - if self.expr_ty(value).deref(self.db).is_aggregate(self.db) => - { - let value = self.lower_aggregate_access(value, indices); - indices.push(self.lower_expr_to_value(index)); - value - } - - _ => self.lower_expr_to_value(expr), - } - } + // fn lower_aggregate_access( + // &mut self, + // expr: &hir_def::Expr, + // indices: &mut Vec, + // ) -> ValueId { + // match &expr.kind { + // hir_def::Expr::Attribute { value, attr } => { + // let index = self.expr_ty(value).index_from_fname(self.db, &attr.kind); + // let value = self.lower_aggregate_access(value, indices); + // indices.push(self.make_u256_imm(index)); + // value + // } + + // hir_def::Expr::Subscript { value, index } + // if self.expr_ty(value).deref(self.db).is_aggregate(self.db) => + // { + // let value = self.lower_aggregate_access(value, indices); + // indices.push(self.lower_expr_to_value(index)); + // value + // } + + // _ => self.lower_expr_to_value(expr), + // } + // } // fn make_unit(&mut self) -> ValueId { // let unit_ty = analyzer_types::TypeId::unit(self.db.upcast()); @@ -1131,12 +1131,12 @@ impl Scope { maximum_iter_count: None, }; - // Declare function parameters. - for param in &func.signature(db).params { - let local = Local::arg_local(param.name.clone(), param.ty); - let value_id = builder.store_func_arg(local); - root.declare_var(¶m.name, value_id) - } + // // Declare function parameters. + // for param in &func.signature(db).params { + // let local = Local::arg_local(param.name.clone(), param.ty); + // let value_id = builder.store_func_arg(local); + // root.declare_var(¶m.name, value_id) + // } root } @@ -1205,9 +1205,9 @@ impl Scope { } } -fn make_param(db: &dyn MirDb, name: impl Into, ty: TypeId) -> FunctionParam { - FunctionParam { - name: name.into(), - ty: db.mir_lowered_type(ty), - } -} +// fn make_param(db: &dyn MirDb, name: impl Into, ty: TypeId) -> FunctionParam { +// FunctionParam { +// name: name.into(), +// ty: db.mir_lowered_type(ty), +// } +// } From a2bc54cb33c1937cd7f4b4a67aa2c3b39c2e581f Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Tue, 6 Feb 2024 14:40:06 -0700 Subject: [PATCH 22/22] hacking --- crates/mir2/src/ir/constant.rs | 23 ++++++-------------- crates/mir2/src/ir/mod.rs | 2 +- crates/mir2/src/lib.rs | 8 +++---- crates/mir2/src/lower/constant.rs | 36 ++++++------------------------- 4 files changed, 18 insertions(+), 51 deletions(-) diff --git a/crates/mir2/src/ir/constant.rs b/crates/mir2/src/ir/constant.rs index 5180efa0b8..34ab0e5961 100644 --- a/crates/mir2/src/ir/constant.rs +++ b/crates/mir2/src/ir/constant.rs @@ -1,30 +1,19 @@ -use hir::hir_def::{ModuleTreeNodeId, TypeId}; +use hir::hir_def; use num_bigint::BigInt; use smol_str::SmolStr; -// use super::SourceInfo; - #[salsa::interned] -pub struct ConstantId { +pub struct ConstId { #[return_ref] - pub data: Constant, + pub data: Const, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Constant { - /// A name of a constant. - pub name: SmolStr, - - /// A value of a constant. +pub struct Const { pub value: ConstantValue, - /// A type of a constant. - pub ty: TypeId, - - /// A module where a constant is declared. - pub module_id: ModuleTreeNodeId, - // /// A span where a constant is declared. - // pub source: SourceInfo, + #[return_ref] + pub(crate) origin: hir_def::Const, } // /// An interned Id for [`Constant`]. diff --git a/crates/mir2/src/ir/mod.rs b/crates/mir2/src/ir/mod.rs index d6fef7916a..bf27440d41 100644 --- a/crates/mir2/src/ir/mod.rs +++ b/crates/mir2/src/ir/mod.rs @@ -9,7 +9,7 @@ pub mod inst; pub mod value; pub use basic_block::{BasicBlock, BasicBlockId}; -pub use constant::{Constant, ConstantId}; +pub use constant::{Const, ConstId}; pub use function::{FunctionBody, FunctionId, FunctionParam, FunctionSignature}; pub use inst::{Inst, InstId}; // pub use types::{Type, TypeId, TypeKind}; diff --git a/crates/mir2/src/lib.rs b/crates/mir2/src/lib.rs index 49ef940708..2ee69d287f 100644 --- a/crates/mir2/src/lib.rs +++ b/crates/mir2/src/lib.rs @@ -10,7 +10,7 @@ mod lower; #[salsa::jar(db = MirDb)] pub struct Jar( // ir::Constant, - ir::ConstantId, + ir::ConstId, // ir::FunctionBody, // ir::FunctionId, // ir::FunctionParam, @@ -45,9 +45,9 @@ pub trait MirDb: salsa::DbWithJar + HirDb { // IdentId::prefill(self) } - fn as_hir_db(&self) -> &dyn MirDb { - >::as_jar_db::<'_>(self) - } + // fn as_hir_db(&self) -> &dyn MirDb { + // >::as_jar_db::<'_>(self) + // } } impl MirDb for DB where DB: salsa::DbWithJar + HirDb {} diff --git a/crates/mir2/src/lower/constant.rs b/crates/mir2/src/lower/constant.rs index 7574ea0b98..6f6f0decfa 100644 --- a/crates/mir2/src/lower/constant.rs +++ b/crates/mir2/src/lower/constant.rs @@ -1,40 +1,18 @@ use std::rc::Rc; -use hir::hir_def::TypeId; +use hir::hir_def::{Const, TypeId}; use crate::{ - ir::{Constant, ConstantId}, + ir::{Const, ConstId}, MirDb, }; #[salsa::tracked] -pub fn mir_lowered_constant(db: &dyn MirDb, analyzer_const: hir::hir_def::Const) -> ConstantId { - // let name = analyzer_const.name(db.upcast()); - // let value = analyzer_const.constant_value(db.upcast()).unwrap(); - // let ty = analyzer_const.typ(db.upcast()).unwrap(); - // let module_id = analyzer_const.module(db.upcast()); - // let span = analyzer_const.span(db.upcast()); - // let id = analyzer_const.node_id(db.upcast()); +pub fn mir_lowered_constant(db: &dyn MirDb, hir_const: Const) -> ConstId { + let value = hir_const.constant_value(db.as_hir_db()).unwrap(); - // let ty = db.mir_lowered_type(ty); - - // let constant = Constant { - // name, - // value: value.into(), - // ty, - // module_id, - // }; - - // db.mir_intern_const(constant.into()) - panic!() -} - -impl ConstantId { - // pub fn data(self, db: &dyn MirDb) -> Rc { - // db.lookup_mir_intern_const(self) - // } - - pub fn ty(self, db: &dyn MirDb) -> TypeId { - self.data(db).ty + let constant = Const { + value: value.into(), + origin: hir_const, } }