diff --git a/crates/codegen/src/yul/isel/inst_order.rs b/crates/codegen/src/yul/isel/inst_order.rs index b6e5317e77..4475007124 100644 --- a/crates/codegen/src/yul/isel/inst_order.rs +++ b/crates/codegen/src/yul/isel/inst_order.rs @@ -27,6 +27,7 @@ pub(super) struct InstSerializer<'a> { cfg: ControlFlowGraph, loop_tree: LoopTree, df: DFSet, + domtree: DomTree, pd_tree: PostDomTree, scope: Option, } @@ -44,6 +45,7 @@ impl<'a> InstSerializer<'a> { cfg, loop_tree, df, + domtree, pd_tree, scope: None, } @@ -125,7 +127,6 @@ impl<'a> InstSerializer<'a> { let mut else_body = vec![]; self.enter_if_scope(merge_block); - let mut serialize_dest = |dest_info, body: &mut Vec, merge_block| match dest_info { TerminatorInfo::Break => body.push(StructuralInst::Break), @@ -251,25 +252,81 @@ impl<'a> InstSerializer<'a> { &self, block: BasicBlockId, cond: ValueId, - then: BasicBlockId, - else_: BasicBlockId, + then_bb: BasicBlockId, + else_bb: BasicBlockId, ) -> TerminatorInfo { - let then = Box::new(self.analyze_dest(then)); - let else_ = Box::new(self.analyze_dest(else_)); + let then = Box::new(self.analyze_dest(then_bb)); + let else_ = Box::new(self.analyze_dest(else_bb)); - let merge_block = match self.pd_tree.post_idom(block) { - PostIDom::Block(block) => { - if let Some(lp) = self.scope.as_ref().and_then(Scope::loop_recursive) { - if self.loop_tree.is_block_in_loop(block, lp) { - Some(block) - } else { - None - } + let cand_for_merge_bb = |bb| { + if self.domtree.dominates(bb, block) { + return None; + } + + // a block `cand` can be a candidate of a `merge` block iff + // 1. `cand` is a dominance frontier of `bb`. + // 2. `cand` is NOT a dominator of `bb`. + // 3. `cand` is NOT a "merge" block of parent `if`. + // 4. `cand` is NOT a "loop_exit" block of parent `loop`. + let mut cands = self.df.frontiers(bb)?.filter(|cand| { + !self.domtree.dominates(*cand, bb) + && Some(*cand) + != self + .scope + .as_ref() + .and_then(|scope| scope.if_merge_block_recursive()) + && Some(*cand) + != self + .scope + .as_ref() + .and_then(|scope| scope.loop_exit_recursive()) + }); + + let cand = cands.next(); + // Assert the number of candidates is at most one. + debug_assert!(cands.next().is_none()); + cand + }; + + let merge_block = match (cand_for_merge_bb(then_bb), cand_for_merge_bb(else_bb)) { + (Some(then_cand), Some(else_cand)) => { + if then_cand == else_cand { + Some(then_cand) } else { - Some(block) + None } } - _ => None, + + (Some(cand), None) => { + if cand == else_bb { + Some(cand) + } else { + None + } + } + + (None, Some(cand)) => { + if cand == then_bb { + Some(cand) + } else { + None + } + } + + (None, None) => match self.pd_tree.post_idom(block) { + PostIDom::Block(block) => { + if let Some(lp) = self.scope.as_ref().and_then(Scope::loop_recursive) { + if self.loop_tree.is_block_in_loop(block, lp) { + Some(block) + } else { + None + } + } else { + Some(block) + } + } + _ => None, + }, }; TerminatorInfo::If { @@ -291,7 +348,7 @@ impl<'a> InstSerializer<'a> { TerminatorInfo::Continue } else if Some(dest) == scope.loop_exit_recursive() { TerminatorInfo::Break - } else if Some(dest) == scope.if_merge_block() { + } else if Some(dest) == scope.if_merge_block_recursive() { TerminatorInfo::ToMergeBlock } else { TerminatorInfo::FallThrough(dest) @@ -355,6 +412,15 @@ impl Scope { _ => None, } } + + fn if_merge_block_recursive(&self) -> Option { + match self.kind { + ScopeKind::If { + merge_block: Some(merge_block), + } => Some(merge_block), + _ => self.parent.as_ref()?.if_merge_block_recursive(), + } + } } #[derive(Debug, Clone)] @@ -387,37 +453,40 @@ mod tests { } fn expect_if( - inst: StructuralInst, + insts: &mut impl Iterator, ) -> ( impl Iterator, impl Iterator, ) { - match inst { + match insts.next().unwrap() { StructuralInst::If { then, else_, .. } => (then.into_iter(), else_.into_iter()), _ => panic!("expect if inst"), } } - fn expect_for(inst: StructuralInst) -> impl Iterator { - match inst { + fn expect_for( + insts: &mut impl Iterator, + ) -> impl Iterator { + match insts.next().unwrap() { StructuralInst::For { body } => body.into_iter(), _ => panic!("expect if inst"), } } - fn expect_break(inst: StructuralInst) { - assert!(matches!(inst, StructuralInst::Break)) + fn expect_break(insts: &mut impl Iterator) { + assert!(matches!(insts.next().unwrap(), StructuralInst::Break)) } - fn expect_continue(inst: StructuralInst) { - assert!(matches!(inst, StructuralInst::Continue)) + fn expect_continue(insts: &mut impl Iterator) { + assert!(matches!(insts.next().unwrap(), StructuralInst::Continue)) } - fn expect_return(func: &FunctionBody, inst: &StructuralInst) { + fn expect_return(func: &FunctionBody, insts: &mut impl Iterator) { + let inst = insts.next().unwrap(); match inst { StructuralInst::Inst(inst) => { assert!(matches!( - func.store.inst_data(*inst).kind, + func.store.inst_data(inst).kind, InstKind::Return { .. } )) } @@ -425,6 +494,10 @@ mod tests { } } + fn expect_end(insts: &mut impl Iterator) { + assert!(insts.next().is_none()) + } + #[test] fn if_non_merge() { // +------+ +-------+ @@ -456,13 +529,13 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let (mut then, mut else_) = expect_if(order.next().unwrap()); - expect_return(&func, &then.next().unwrap()); - assert!(then.next().is_none()); - expect_return(&func, &else_.next().unwrap()); - assert!(else_.next().is_none()); + let (mut then, mut else_) = expect_if(&mut order); + expect_return(&func, &mut then); + expect_end(&mut then); + expect_return(&func, &mut else_); + expect_end(&mut else_); - assert!(order.next().is_none()); + expect_end(&mut order); } #[test] @@ -506,12 +579,76 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let (mut then, mut else_) = expect_if(order.next().unwrap()); - assert!(then.next().is_none()); - assert!(else_.next().is_none()); + let (mut then, mut else_) = expect_if(&mut order); + expect_end(&mut then); + expect_end(&mut else_); + + expect_return(&func, &mut order); + expect_end(&mut order); + } + + #[test] + fn nested_if() { + // +-----+ + // | bb0 | -+ + // +-----+ | + // | | + // | | + // v | + // +-----+ +-----+ | + // | bb3 | <-- | bb1 | | + // +-----+ +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb4 | | + // +-----+ | + // | | + // | | + // v | + // +-----+ | + // | bb2 | <+ + // +-----+ + let mut builder = body_builder(); + + let bb1 = builder.make_block(); + let bb2 = builder.make_block(); + let bb3 = builder.make_block(); + let bb4 = builder.make_block(); + + let dummy_ty = TypeId(0); + let v0 = builder.make_imm_from_bool(true, dummy_ty); + let unit = builder.make_unit(dummy_ty); + + builder.branch(v0, bb1, bb2, SourceInfo::dummy()); + + builder.move_to_block(bb1); + builder.branch(v0, bb3, bb4, SourceInfo::dummy()); + + builder.move_to_block(bb3); + builder.ret(unit, SourceInfo::dummy()); + + builder.move_to_block(bb4); + builder.jump(bb2, SourceInfo::dummy()); + + builder.move_to_block(bb2); + builder.ret(unit, SourceInfo::dummy()); + + let mut func = builder.build(); + let mut order = serialize_func_body(&mut func); + + let (mut then1, mut else2) = expect_if(&mut order); + expect_end(&mut else2); + + let (mut then3, mut else4) = expect_if(&mut then1); + expect_end(&mut then1); + expect_return(&func, &mut then3); + expect_end(&mut then3); + expect_end(&mut else4); - expect_return(&func, &order.next().unwrap()); - assert!(order.next().is_none()); + expect_return(&func, &mut order); + expect_end(&mut order); } #[test] @@ -561,22 +698,22 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let (mut lp, mut empty) = expect_if(order.next().unwrap()); + let (mut lp, mut empty) = expect_if(&mut order); - let mut body = expect_for(lp.next().unwrap()); - let (mut continue_, mut break_) = expect_if(body.next().unwrap()); - assert!(body.next().is_none()); + let mut body = expect_for(&mut lp); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_end(&mut body); - expect_continue(continue_.next().unwrap()); - assert!(continue_.next().is_none()); + expect_continue(&mut continue_); + expect_end(&mut continue_); - expect_break(break_.next().unwrap()); - assert!(break_.next().is_none()); + expect_break(&mut break_); + expect_end(&mut break_); - assert!(empty.next().is_none()); + expect_end(&mut empty); - expect_return(&func, &order.next().unwrap()); - assert!(order.next().is_none()); + expect_return(&func, &mut order); + expect_end(&mut order); } #[test] @@ -630,27 +767,27 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let (mut lp, mut empty) = expect_if(order.next().unwrap()); - assert!(empty.next().is_none()); + let (mut lp, mut empty) = expect_if(&mut order); + expect_end(&mut empty); - let mut body = expect_for(lp.next().unwrap()); + let mut body = expect_for(&mut lp); - let (mut continue_, mut empty) = expect_if(body.next().unwrap()); - expect_continue(continue_.next().unwrap()); - assert!(continue_.next().is_none()); - assert!(empty.next().is_none()); + let (mut continue_, mut empty) = expect_if(&mut body); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_end(&mut empty); - let (mut continue_, mut break_) = expect_if(body.next().unwrap()); - expect_continue(continue_.next().unwrap()); - assert!(continue_.next().is_none()); - expect_break(break_.next().unwrap()); - assert!(break_.next().is_none()); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_break(&mut break_); + expect_end(&mut break_); - assert!(body.next().is_none()); - assert!(lp.next().is_none()); + expect_end(&mut body); + expect_end(&mut lp); - expect_return(&func, &order.next().unwrap()); - assert!(order.next().is_none()); + expect_return(&func, &mut order); + expect_end(&mut order); } #[test] @@ -704,27 +841,27 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let (mut lp, mut empty) = expect_if(order.next().unwrap()); - assert!(empty.next().is_none()); + let (mut lp, mut empty) = expect_if(&mut order); + expect_end(&mut empty); - let mut body = expect_for(lp.next().unwrap()); + let mut body = expect_for(&mut lp); - let (mut break_, mut latch) = expect_if(body.next().unwrap()); - expect_break(break_.next().unwrap()); - assert!(break_.next().is_none()); + let (mut break_, mut latch) = expect_if(&mut body); + expect_break(&mut break_); + expect_end(&mut break_); - let (mut continue_, mut break_) = expect_if(latch.next().unwrap()); - assert!(latch.next().is_none()); - expect_continue(continue_.next().unwrap()); - assert!(continue_.next().is_none()); - expect_break(break_.next().unwrap()); - assert!(break_.next().is_none()); + let (mut continue_, mut break_) = expect_if(&mut latch); + expect_end(&mut latch); + expect_continue(&mut continue_); + expect_end(&mut continue_); + expect_break(&mut break_); + expect_end(&mut break_); - assert!(body.next().is_none()); - assert!(lp.next().is_none()); + expect_end(&mut body); + expect_end(&mut lp); - expect_return(&func, &order.next().unwrap()); - assert!(order.next().is_none()); + expect_return(&func, &mut order); + expect_end(&mut order); } #[test] @@ -774,18 +911,18 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let mut body = expect_for(order.next().unwrap()); - let (mut continue_, mut break_) = expect_if(body.next().unwrap()); - assert!(body.next().is_none()); + let mut body = expect_for(&mut order); + let (mut continue_, mut break_) = expect_if(&mut body); + expect_end(&mut body); - expect_continue(continue_.next().unwrap()); - assert!(continue_.next().is_none()); + expect_continue(&mut continue_); + expect_end(&mut continue_); - expect_break(break_.next().unwrap()); - assert!(break_.next().is_none()); + expect_break(&mut break_); + expect_end(&mut break_); - expect_return(&func, &order.next().unwrap()); - assert!(order.next().is_none()); + expect_return(&func, &mut order); + expect_end(&mut order); } #[test] @@ -821,10 +958,10 @@ mod tests { let mut func = builder.build(); let mut order = serialize_func_body(&mut func); - let mut body = expect_for(order.next().unwrap()); - expect_continue(body.next().unwrap()); - assert!(body.next().is_none()); + let mut body = expect_for(&mut order); + expect_continue(&mut body); + expect_end(&mut body); - assert!(order.next().is_none()); + expect_end(&mut order); } } diff --git a/newsfragments/749.bugfix.md b/newsfragments/749.bugfix.md new file mode 100644 index 0000000000..0ecd0884d5 --- /dev/null +++ b/newsfragments/749.bugfix.md @@ -0,0 +1,17 @@ +Fix a bug that causes ICE when nested if-statement has multiple exit point. + +E.g. the following code would previously crash the compiler but shouldn't: +```fe + pub fn foo(self) { + if true { + if self.something { + return + } + } + if true { + if self.something { + return + } + } +} +```