diff --git a/Cargo.lock b/Cargo.lock
index 10eb8e6d90..bb085c77f0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1982,6 +1982,7 @@ name = "rustc_codegen_spirv"
 version = "0.1.0"
 dependencies = [
  "bimap",
+ "indexmap",
  "pipe",
  "pretty_assertions",
  "rspirv",
diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml
index 187ae51673..a7b7491366 100644
--- a/crates/rustc_codegen_spirv/Cargo.toml
+++ b/crates/rustc_codegen_spirv/Cargo.toml
@@ -28,6 +28,7 @@ use-compiled-tools = ["spirv-tools/use-compiled-tools"]
 
 [dependencies]
 bimap = "0.5"
+indexmap = "1.6.0"
 rspirv = { git = "https://github.com/gfx-rs/rspirv.git", rev = "01ca0d2e5b667a0e4ff1bc1804511e38f9a08759" }
 rustc-demangle = "0.1.18"
 spirv-tools = { version = "0.1.0", default-features = false }
diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
index f20d09617d..24be8eb22b 100644
--- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
+++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
@@ -482,11 +482,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
         else_llbb: Self::BasicBlock,
         cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock)>,
     ) {
-        if !self.kernel_mode {
-            // TODO: Remove once structurizer is done.
-            self.zombie(else_llbb, "OpSwitch before structurizer is done");
-        }
-
         fn construct_8(self_: &Builder<'_, '_>, signed: bool, v: u128) -> Operand {
             if v > u8::MAX as u128 {
                 self_.fatal(&format!(
@@ -700,7 +695,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
             OverflowOp::Mul => (self.mul(lhs, rhs), fals),
         };
         self.zombie(
-            result.1.def(self),
+            result.0.def(self),
             match oop {
                 OverflowOp::Add => "checked add is not supported yet",
                 OverflowOp::Sub => "checked sub is not supported yet",
diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs
index 0fee842937..391b11c4d1 100644
--- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs
+++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs
@@ -38,7 +38,17 @@ pub fn mem2reg(
 pub fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
     let mut result = vec![vec![]; blocks.len()];
     for (source_idx, source) in blocks.iter().enumerate() {
-        for dest_id in outgoing_edges(source) {
+        let mut edges = outgoing_edges(source);
+        // HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points
+        // to an otherwise-unreachable block.
+        if let Some(before_last_idx) = source.instructions.len().checked_sub(2) {
+            if let Some(before_last) = source.instructions.get(before_last_idx) {
+                if before_last.class.opcode == Op::SelectionMerge {
+                    edges.push(before_last.operands[0].unwrap_id_ref());
+                }
+            }
+        }
+        for dest_id in edges {
             let dest_idx = blocks
                 .iter()
                 .position(|b| b.label_id().unwrap() == dest_id)
diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs
index 061e1c459a..92e3fc98e7 100644
--- a/crates/rustc_codegen_spirv/src/linker/mod.rs
+++ b/crates/rustc_codegen_spirv/src/linker/mod.rs
@@ -7,6 +7,7 @@ mod duplicates;
 mod import_export_link;
 mod inline;
 mod mem2reg;
+mod new_structurizer;
 mod simple_passes;
 mod structurizer;
 mod zombies;
@@ -136,7 +137,11 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result {
diff --git a/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs b/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs
new file mode 100644
--- /dev/null
+++ b/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs
+use indexmap::{indexmap, IndexMap};
+use rspirv::dr::{Block, Builder, Function, InsertPoint, Module, Operand};
+use rspirv::spirv::{LoopControl, Op, SelectionControl, Word};
+use std::collections::HashMap;
+use std::{iter, mem};
+
+struct FuncBuilder<'a> {
+    builder: &'a mut Builder,
+}
+
+impl FuncBuilder<'_> {
+    fn function(&self) -> &Function {
+        let func_idx = self.builder.selected_function().unwrap();
+        &self.builder.module_ref().functions[func_idx]
+    }
+
+    fn function_mut(&mut self) -> &mut Function {
+        let func_idx = self.builder.selected_function().unwrap();
+        &mut self.builder.module_mut().functions[func_idx]
+    }
+
+    fn blocks(&self) -> &[Block] {
+        &self.function().blocks
+    }
+
+    fn blocks_mut(&mut self) -> &mut [Block] {
+        &mut self.function_mut().blocks
+    }
+}
+
+pub fn structurize(module: Module) -> Module {
+    let mut builder = Builder::new_from_module(module);
+
+    for func_idx in 0..builder.module_ref().functions.len() {
+        builder.select_function(Some(func_idx)).unwrap();
+        let func = FuncBuilder {
+            builder: &mut builder,
+        };
+
+        let block_id_to_idx = func
+            .blocks()
+            .iter()
+            .enumerate()
+            .map(|(i, block)| (block.label_id().unwrap(), i))
+            .collect();
+
+        Structurizer {
+            func,
+            block_id_to_idx,
+            incoming_edge_count: vec![],
+            regions: HashMap::new(),
+        }
+        .structurize_func();
+    }
+
+    builder.module()
+}
+
+// FIXME(eddyb) use newtyped indices and `IndexVec`.
+type BlockIdx = usize;
+type BlockId = Word;
+
+/// Regions are made up of their entry block and all other blocks dominated
+/// by that block. All edges leaving a region are considered "exits".
+struct Region {
+    /// After structurizing a region, all paths through it must lead to a single
+    /// "merge" block (i.e. `merge` post-dominates the entire region).
+    /// The `merge` block must be terminated by one of `OpReturn`, `OpReturnValue`,
+    /// `OpKill`, or `OpUnreachable`. If `exits` isn't empty, `merge` will
+    /// receive an `OpBranch` from its parent region (to an outer merge block).
+    merge: BlockIdx,
+    merge_id: BlockId,
+
+    exits: IndexMap<BlockIdx, Exit>,
+}
+
+#[derive(Default)]
+struct Exit {
+    /// Number of total edges to this target (a subset of the target's predecessors).
+    edge_count: usize,
+
+    /// If this is a deferred exit, `condition` is a boolean value which must
+    /// be `true` in order to execute this exit.
+    condition: Option<Word>,
+}
+
+struct Structurizer<'a> {
+    func: FuncBuilder<'a>,
+    block_id_to_idx: HashMap<BlockId, BlockIdx>,
+
+    /// Number of edges pointing to each block.
+    /// Computed by `post_order` and updated when structuring loops
+    /// (backedge count is subtracted to hide them from outer regions).
+    incoming_edge_count: Vec<usize>,
+
+    regions: HashMap<BlockIdx, Region>,
+}
+
+impl Structurizer<'_> {
+    fn structurize_func(&mut self) {
+        // By iterating in post-order, we are guaranteed to visit "inner" regions
+        // before "outer" ones.
+        for block in self.post_order() {
+            let block_id = self.func.blocks()[block].label_id().unwrap();
+            let terminator = self.func.blocks()[block].instructions.last().unwrap();
+            let mut region = match terminator.class.opcode {
+                Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Region {
+                    merge: block,
+                    merge_id: block_id,
+                    exits: indexmap! {},
+                },
+
+                Op::Branch => {
+                    let target = self.block_id_to_idx[&terminator.operands[0].unwrap_id_ref()];
+                    self.child_region(target).unwrap_or_else(|| {
+                        self.func.builder.select_block(Some(block)).unwrap();
+                        self.func.builder.pop_instruction().unwrap();
+                        // Default all merges to `OpUnreachable`, in case they're unused.
+                        self.func.builder.unreachable().unwrap();
+                        Region {
+                            merge: block,
+                            merge_id: block_id,
+                            exits: indexmap! {
+                                target => Exit { edge_count: 1, condition: None }
+                            },
+                        }
+                    })
+                }
+
+                Op::BranchConditional | Op::Switch => {
+                    let target_operand_indices = match terminator.class.opcode {
+                        Op::BranchConditional => (1..3).step_by(1),
+                        Op::Switch => (1..terminator.operands.len()).step_by(2),
+                        _ => unreachable!(),
+                    };
+
+                    // FIXME(eddyb) avoid wasteful allocation.
+                    let child_regions: Vec<_> = target_operand_indices
+                        .map(|i| {
+                            let target_id = self.func.blocks()[block]
+                                .instructions
+                                .last()
+                                .unwrap()
+                                .operands[i]
+                                .unwrap_id_ref();
+                            let target = self.block_id_to_idx[&target_id];
+                            self.child_region(target).unwrap_or_else(|| {
+                                // Synthesize a single-block region for every edge that
+                                // doesn't already enter a child region, so that the
+                                // merge block we later generate has a unique source for
+                                // every single arm of this conditional branch or switch,
+                                // to attach per-exit condition phis to.
+                                let new_block_id = self.func.builder.begin_block(None).unwrap();
+                                let new_block = self.func.builder.selected_block().unwrap();
+                                // Default all merges to `OpUnreachable`, in case they're unused.
+                                self.func.builder.unreachable().unwrap();
+                                self.func.blocks_mut()[block]
+                                    .instructions
+                                    .last_mut()
+                                    .unwrap()
+                                    .operands[i] = Operand::IdRef(new_block_id);
+                                Region {
+                                    merge: new_block,
+                                    merge_id: new_block_id,
+                                    exits: indexmap! {
+                                        target => Exit { edge_count: 1, condition: None }
+                                    },
+                                }
+                            })
+                        })
+                        .collect();
+
+                    self.selection_merge_regions(block, &child_regions)
+                }
+                _ => panic!("Invalid block terminator: {:?}", terminator),
+            };
+
+            // Peel off deferred exits which have all their edges accounted for
+            // already, within this region. Repeat until no such exits are left.
+            while let Some((&target, _)) = region
+                .exits
+                .iter()
+                .find(|&(&target, exit)| exit.edge_count == self.incoming_edge_count[target])
+            {
+                let taken_block_id = self.func.blocks()[target].label_id().unwrap();
+                let exit = region.exits.remove(&target).unwrap();
+
+                // Create a new block for the "`exit` not taken" path.
+                let not_taken_block_id = self.func.builder.begin_block(None).unwrap();
+                let not_taken_block = self.func.builder.selected_block().unwrap();
+                // Default all merges to `OpUnreachable`, in case they're unused.
+                self.func.builder.unreachable().unwrap();
+
+                // Choose whether to take this `exit`, in the previous merge block.
+                let branch_block = region.merge;
+                self.func.builder.select_block(Some(branch_block)).unwrap();
+                assert_eq!(
+                    self.func.builder.pop_instruction().unwrap().class.opcode,
+                    Op::Unreachable
+                );
+                self.func
+                    .builder
+                    .branch_conditional(
+                        exit.condition.unwrap(),
+                        taken_block_id,
+                        not_taken_block_id,
+                        iter::empty(),
+                    )
+                    .unwrap();
+
+                // Merge the "taken" and "not taken" paths.
+                let taken_region = self.regions.remove(&target).unwrap();
+                let not_taken_region = Region {
+                    merge: not_taken_block,
+                    merge_id: not_taken_block_id,
+                    exits: region.exits,
+                };
+                region =
+                    self.selection_merge_regions(branch_block, &[taken_region, not_taken_region]);
+            }
+
+            // Peel off a backedge exit, which indicates this region is a loop.
+            if let Some(mut backedge_exit) = region.exits.remove(&block) {
+                // Inject a `while`-like loop header just before the start of the
+                // loop body. This is needed because our "`break` vs `continue`"
+                // choice is *after* the loop body, like in a `do`-`while` loop,
+                // but SPIR-V requires it at the start, like in a `while` loop.
+                let while_header_block_id = self.func.builder.begin_block(None).unwrap();
+                let while_header_block = self.func.builder.selected_block().unwrap();
+                self.func.builder.select_block(None).unwrap();
+                let while_exit_block_id = self.func.builder.begin_block(None).unwrap();
+                let while_exit_block = self.func.builder.selected_block().unwrap();
+                // Default all merges to `OpUnreachable`, in case they're unused.
+                self.func.builder.unreachable().unwrap();
+                let while_body_block_id = self.func.builder.begin_block(None).unwrap();
+                let while_body_block = self.func.builder.selected_block().unwrap();
+                self.func.builder.select_block(None).unwrap();
+
+                // Move all of the contents of the original `block` into the
+                // new loop body, but keep labels and indices intact.
+                self.func.blocks_mut()[while_body_block].instructions =
+                    mem::replace(&mut self.func.blocks_mut()[block].instructions, vec![]);
+
+                // Create a separate merge block for the loop body, as the original
+                // one might be used by an `OpSelectionMerge` and cannot be reused.
+                let while_body_merge_id = self.func.builder.begin_block(None).unwrap();
+                let while_body_merge = self.func.builder.selected_block().unwrap();
+                self.func.builder.select_block(None).unwrap();
+                self.func.builder.select_block(Some(region.merge)).unwrap();
+                assert_eq!(
+                    self.func.builder.pop_instruction().unwrap().class.opcode,
+                    Op::Unreachable
+                );
+                self.func.builder.branch(while_body_merge_id).unwrap();
+
+                // Point both the original block and the merge of the loop body
+                // at the new loop header, and compute phis for all the exit
+                // conditions (including the backedge, which indicates "continue").
+                self.func.builder.select_block(Some(block)).unwrap();
+                self.func.builder.branch(while_header_block_id).unwrap();
+                self.func
+                    .builder
+                    .select_block(Some(while_body_merge))
+                    .unwrap();
+                self.func.builder.branch(while_header_block_id).unwrap();
+                self.func
+                    .builder
+                    .select_block(Some(while_header_block))
+                    .unwrap();
+
+                // FIXME(eddyb) deduplicate/cache these constants?
+                let type_bool = self.func.builder.type_bool();
+                let const_false = self.func.builder.constant_false(type_bool);
+                let const_true = self.func.builder.constant_true(type_bool);
+                for (&target, exit) in region
+                    .exits
+                    .iter_mut()
+                    .chain(iter::once((&while_body_block, &mut backedge_exit)))
+                {
+                    let first_entry_case = (
+                        if target == while_body_block {
+                            const_true
+                        } else {
+                            const_false
+                        },
+                        block_id,
+                    );
+                    let repeat_case = (exit.condition.unwrap_or(const_true), while_body_merge_id);
+                    let phi_cases = [first_entry_case, repeat_case];
+                    exit.condition = Some(
+                        self.func
+                            .builder
+                            .phi(type_bool, None, phi_cases.iter().copied())
+                            .unwrap(),
+                    );
+                }
+
+                // Choose whether to keep looping, in the `while`-like loop header.
+                self.func
+                    .builder
+                    .select_block(Some(while_header_block))
+                    .unwrap();
+                self.func
+                    .builder
+                    .loop_merge(
+                        while_exit_block_id,
+                        while_body_merge_id,
+                        LoopControl::NONE,
+                        iter::empty(),
+                    )
+                    .unwrap();
+                self.func
+                    .builder
+                    .select_block(Some(while_header_block))
+                    .unwrap();
+                self.func
+                    .builder
+                    .branch_conditional(
+                        backedge_exit.condition.unwrap(),
+                        while_body_block_id,
+                        while_exit_block_id,
+                        iter::empty(),
+                    )
+                    .unwrap();
+                region.merge = while_exit_block;
+                region.merge_id = while_exit_block_id;
+
+                // Remove the backedge count from the total incoming count of `block`.
+                // This will allow outer regions to treat the loop opaquely.
+                self.incoming_edge_count[block] -= backedge_exit.edge_count;
+            }
+
+            self.regions.insert(block, region);
+        }
+
+        assert_eq!(self.regions.len(), 1);
+        assert_eq!(self.regions.values().next().unwrap().exits.len(), 0);
+    }
+
+    fn child_region(&mut self, target: BlockIdx) -> Option<Region> {
+        // An "entry" edge is the unique edge into a region.
+        if self.incoming_edge_count[target] == 1 {
+            Some(self.regions.remove(&target).unwrap())
+        } else {
+            None
+        }
+    }
+
+    fn selection_merge_regions(&mut self, block: BlockIdx, child_regions: &[Region]) -> Region {
+        // HACK(eddyb) this special-cases the easy case where we can
+        // just reuse a merge block, and don't have to create our own.
+        let structural_merge = if child_regions.iter().all(|region| {
+            region.exits.len() == 1 && region.exits.get_index(0).unwrap().1.condition.is_none()
+        }) {
+            let merge = *child_regions[0].exits.get_index(0).unwrap().0;
+            if child_regions
+                .iter()
+                .all(|region| *region.exits.get_index(0).unwrap().0 == merge)
+                && child_regions
+                    .iter()
+                    .map(|region| region.exits.get_index(0).unwrap().1.edge_count)
+                    .sum::<usize>()
+                    == self.incoming_edge_count[merge]
+            {
+                Some(merge)
+            } else {
+                None
+            }
+        } else {
+            None
+        };
+
+        // Reuse or create a merge block, and use it as the selection merge.
+        let merge = structural_merge.unwrap_or_else(|| {
+            self.func.builder.begin_block(None).unwrap();
+            self.func.builder.selected_block().unwrap()
+        });
+        let merge_id = self.func.blocks()[merge].label_id().unwrap();
+        self.func.builder.select_block(Some(block)).unwrap();
+        self.func
+            .builder
+            .insert_selection_merge(InsertPoint::FromEnd(1), merge_id, SelectionControl::NONE)
+            .unwrap();
+
+        // Branch all the child regions into our merge block.
+        for region in child_regions {
+            // HACK(eddyb) empty `region.exits` indicates diverging control-flow,
+            // and that we should ignore `region.merge`.
+            if !region.exits.is_empty() {
+                self.func.builder.select_block(Some(region.merge)).unwrap();
+                assert_eq!(
+                    self.func.builder.pop_instruction().unwrap().class.opcode,
+                    Op::Unreachable
+                );
+                self.func.builder.branch(merge_id).unwrap();
+            }
+        }
+
+        if let Some(merge) = structural_merge {
+            self.regions.remove(&merge).unwrap()
+        } else {
+            self.func.builder.select_block(Some(merge)).unwrap();
+
+            // Gather all the potential exits.
+            let mut exits: IndexMap<BlockIdx, Exit> = indexmap! {};
+            for region in child_regions {
+                for (&target, exit) in &region.exits {
+                    exits.entry(target).or_default().edge_count += exit.edge_count;
+                }
+            }
+
+            // Update conditions using phis.
+            // FIXME(eddyb) deduplicate/cache these constants?
+            let type_bool = self.func.builder.type_bool();
+            let const_false = self.func.builder.constant_false(type_bool);
+            let const_true = self.func.builder.constant_true(type_bool);
+            for (&target, exit) in &mut exits {
+                let phi_cases = child_regions
+                    .iter()
+                    .filter(|region| {
+                        // HACK(eddyb) empty `region.exits` indicates diverging control-flow,
+                        // and that we should ignore `region.merge`.
+                        !region.exits.is_empty()
+                    })
+                    .map(|region| {
+                        (
+                            match region.exits.get(&target) {
+                                Some(exit) => exit.condition.unwrap_or(const_true),
+                                None => const_false,
+                            },
+                            region.merge_id,
+                        )
+                    });
+                exit.condition = Some(self.func.builder.phi(type_bool, None, phi_cases).unwrap());
+            }
+
+            // Default all merges to `OpUnreachable`, in case they're unused.
+            self.func.builder.unreachable().unwrap();
+
+            Region {
+                merge,
+                merge_id,
+                exits,
+            }
+        }
+    }
+
+    // FIXME(eddyb) replace this with `rustc_data_structures::graph::iterate`
+    // (or similar).
+    fn post_order(&mut self) -> Vec<BlockIdx> {
+        let blocks = self.func.blocks();
+
+        // HACK(eddyb) compute edge counts through the post-order traversal.
+        assert!(self.incoming_edge_count.is_empty());
+        self.incoming_edge_count = vec![0; blocks.len()];
+
+        // FIXME(eddyb) use a proper bitset.
+        let mut visited = vec![false; blocks.len()];
+        let mut post_order = Vec::with_capacity(blocks.len());
+
+        self.post_order_step(0, &mut visited, &mut post_order);
+
+        post_order
+    }
+
+    fn post_order_step(
+        &mut self,
+        block: BlockIdx,
+        visited: &mut [bool],
+        post_order: &mut Vec<BlockIdx>,
+    ) {
+        self.incoming_edge_count[block] += 1;
+
+        if visited[block] {
+            return;
+        }
+        visited[block] = true;
+
+        for target in super::simple_passes::outgoing_edges(&self.func.blocks()[block]) {
+            self.post_order_step(self.block_id_to_idx[&target], visited, post_order)
+        }
+
+        post_order.push(block);
+    }
+}
diff --git a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs
index 43747f6f01..70eed5d48a 100644
--- a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs
+++ b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs
@@ -44,8 +44,18 @@ pub fn block_ordering_pass(func: &mut Function) {
         .iter()
         .find(|b| b.label_id().unwrap() == current)
         .unwrap();
+    let mut edges = outgoing_edges(current_block);
+    // HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points
+    // to an otherwise-unreachable block.
+    if let Some(before_last_idx) = current_block.instructions.len().checked_sub(2) {
+        if let Some(before_last) = current_block.instructions.get(before_last_idx) {
+            if before_last.class.opcode == Op::SelectionMerge {
+                edges.push(before_last.operands[0].unwrap_id_ref());
+            }
+        }
+    }
     // Reverse the order, so reverse-postorder keeps things tidy
-    for &outgoing in outgoing_edges(current_block).iter().rev() {
+    for &outgoing in edges.iter().rev() {
         visit_postorder(func, visited, postorder, outgoing);
     }
     postorder.push(current);
@@ -70,6 +80,7 @@ pub fn block_ordering_pass(func: &mut Function) {
     assert_eq!(func.blocks[0].label_id().unwrap(), entry_label);
 }
 
+// FIXME(eddyb) use `Either`, `Cow`, and/or `SmallVec`.
 pub fn outgoing_edges(block: &Block) -> Vec<Word> {
     let terminator = block.instructions.last().unwrap();
     // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination
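
For reference, a minimal standalone sketch (outside the patch; the adjacency-list graph, function names, and `main` below are illustrative only) of the edge-counting depth-first walk that `Structurizer::post_order`/`post_order_step` perform above: every edge bumps the target's incoming-edge count, including the implicit entry edge into block 0 and edges into already-visited blocks, but each block's successors are only descended into once, and blocks are emitted in post-order so inner regions are visited before outer ones.

// Sketch only: block indices into a plain adjacency list stand in for `BlockIdx`.
fn post_order(succs: &[Vec<usize>]) -> (Vec<usize>, Vec<usize>) {
    let mut incoming_edge_count = vec![0; succs.len()];
    let mut visited = vec![false; succs.len()];
    let mut order = Vec::with_capacity(succs.len());
    step(0, succs, &mut incoming_edge_count, &mut visited, &mut order);
    (order, incoming_edge_count)
}

fn step(
    node: usize,
    succs: &[Vec<usize>],
    incoming_edge_count: &mut [usize],
    visited: &mut [bool],
    order: &mut Vec<usize>,
) {
    // Count every edge into `node`, even edges into already-visited blocks;
    // this is how backedges and shared merge targets show up in the counts.
    incoming_edge_count[node] += 1;
    if visited[node] {
        return;
    }
    visited[node] = true;
    for &succ in &succs[node] {
        step(succ, succs, incoming_edge_count, visited, order);
    }
    order.push(node);
}

fn main() {
    // 0 -> {1, 2}, 1 -> 2: a conditional whose two paths rejoin at block 2.
    let succs = vec![vec![1, 2], vec![2], vec![]];
    let (order, counts) = post_order(&succs);
    assert_eq!(order, vec![2, 1, 0]); // innermost block first
    assert_eq!(counts, vec![1, 1, 2]); // block 2 has two incoming edges
}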