From 58dcf6b1a73235d3aecb1951d14c935566700b6b Mon Sep 17 00:00:00 2001 From: SynodicMonth Date: Sun, 20 Aug 2023 03:26:23 +0800 Subject: [PATCH] feat(store_fuse): implement store_fuse --- src/backend/instruction.cpp | 3 + src/backend/instruction.h | 2 + src/main.cpp | 2 + src/passes/asm/addr_simplification.cpp | 4 +- src/passes/asm/fast_divmod.cpp | 4 +- src/passes/asm/peephole.cpp | 2 +- src/passes/asm/store_fuse.cpp | 230 +++++++++++++++++++++++++ src/passes/asm/store_fuse.h | 20 +++ src/passes/ir/gvn.cpp | 2 +- src/passes/ir/loop_unrolling.cpp | 6 +- 10 files changed, 266 insertions(+), 9 deletions(-) create mode 100644 src/passes/asm/store_fuse.cpp create mode 100644 src/passes/asm/store_fuse.h diff --git a/src/backend/instruction.cpp b/src/backend/instruction.cpp index e5b3eae..cf60c0f 100644 --- a/src/backend/instruction.cpp +++ b/src/backend/instruction.cpp @@ -1312,5 +1312,8 @@ bool Instruction::is_li() const { return std::holds_alternative(this->kind); } +bool Instruction::is_lui() const { + return std::holds_alternative(this->kind); +} } // namespace backend } // namespace syc \ No newline at end of file diff --git a/src/backend/instruction.h b/src/backend/instruction.h index c08b543..997bbe0 100644 --- a/src/backend/instruction.h +++ b/src/backend/instruction.h @@ -433,6 +433,8 @@ struct Instruction : std::enable_shared_from_this { bool is_li() const; + bool is_lui() const; + template std::optional as() { if (std::holds_alternative(this->kind)) { diff --git a/src/main.cpp b/src/main.cpp index c201f69..ed70c21 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -15,6 +15,7 @@ #include "passes/asm/peephole_final.h" #include "passes/asm/peephole_second.h" #include "passes/asm/phi_elim.h" +#include "passes/asm/store_fuse.h" #include "passes/ir/auto_inline.h" #include "passes/ir/copyprop.h" #include "passes/ir/dce.h" @@ -134,6 +135,7 @@ int main(int argc, char* argv[]) { backend::peephole_second(asm_builder); backend::addr_simplification(asm_builder); backend::fast_divmod(asm_builder); + backend::store_fuse(asm_builder); if (aggressive_opt) { backend::instr_fuse(asm_builder); backend::dce(asm_builder); diff --git a/src/passes/asm/addr_simplification.cpp b/src/passes/asm/addr_simplification.cpp index 032147d..f6461e2 100644 --- a/src/passes/asm/addr_simplification.cpp +++ b/src/passes/asm/addr_simplification.cpp @@ -103,7 +103,7 @@ void addr_simplification_basic_block( use_inst->replace_operand(use_st_rs1_id, new_rs_operand, builder.context); use_inst->replace_operand(use_st_imm_id, new_imm, builder.context); } else if (use_inst->is_float_load()) { - std::cout << "replace float load" << std::endl; + // std::cout << "replace float load" << std::endl; auto use_ld_inst = std::get(use_inst->kind); auto use_ld_rd = use_ld_inst.rd_id; auto use_ld_rs = use_ld_inst.rs_id; @@ -125,7 +125,7 @@ void addr_simplification_basic_block( use_inst->replace_operand(use_ld_rs, new_rs_operand, builder.context); use_inst->replace_operand(use_ld_imm, new_imm, builder.context); } else if (use_inst->is_float_store()) { - std::cout << "replace float store" << std::endl; + // std::cout << "replace float store" << std::endl; auto use_st_inst = std::get(use_inst->kind); auto use_st_rs1_id = use_st_inst.rs1_id; auto use_st_imm_id = use_st_inst.imm_id; diff --git a/src/passes/asm/fast_divmod.cpp b/src/passes/asm/fast_divmod.cpp index 9920410..a12d232 100644 --- a/src/passes/asm/fast_divmod.cpp +++ b/src/passes/asm/fast_divmod.cpp @@ -90,7 +90,7 @@ void fast_divmod_basic_block(BasicBlockPtr basic_block, Builder& builder) { auto next_rs1 = builder.context.get_operand(next_binary->rs1_id); auto next_rs2 = builder.context.get_operand(next_binary->rs2_id); ms magic_number = magic(immediate_value); - std::cout << "magic number: " << magic_number.M << ", " << magic_number.s << std::endl; + // std::cout << "magic number: " << magic_number.M << ", " << magic_number.s << std::endl; if (immediate_value >= 2) { if (magic_number.s > 0) { // if magic_number.s > 0 @@ -297,7 +297,7 @@ void fast_divmod_basic_block(BasicBlockPtr basic_block, Builder& builder) { auto next_rs1 = builder.context.get_operand(next_binary->rs1_id); auto next_rs2 = builder.context.get_operand(next_binary->rs2_id); ms magic_number = magic(immediate_value); - std::cout << "magic number: " << magic_number.M << ", " << magic_number.s << std::endl; + // std::cout << "magic number: " << magic_number.M << ", " << magic_number.s << std::endl; if (immediate_value >= 2) { if (magic_number.s > 0) { // if magic_number.s > 0 diff --git a/src/passes/asm/peephole.cpp b/src/passes/asm/peephole.cpp index 89e2f40..47b3ade 100644 --- a/src/passes/asm/peephole.cpp +++ b/src/passes/asm/peephole.cpp @@ -194,7 +194,7 @@ void peephole_basic_block(BasicBlockPtr basic_block, Builder& builder) { ); } curr_instruction->remove(builder.context); - std::cout << "remove li" << std::endl; + // std::cout << "remove li" << std::endl; } else { auto next1_instr = curr_instruction->next; auto next2_instr = next1_instr->next; diff --git a/src/passes/asm/store_fuse.cpp b/src/passes/asm/store_fuse.cpp new file mode 100644 index 0000000..5ce3abd --- /dev/null +++ b/src/passes/asm/store_fuse.cpp @@ -0,0 +1,230 @@ +#include "passes/asm/store_fuse.h" +#include "backend/basic_block.h" +#include "backend/builder.h" +#include "backend/context.h" +#include "backend/function.h" +#include "backend/instruction.h" +#include "backend/operand.h" + +namespace syc { +namespace backend { + +void store_fuse(Builder& builder) { + for (auto [function_name, function] : builder.context.function_table) { + store_fuse_function(function, builder); + } +} + +void store_fuse_function(FunctionPtr function, Builder& builder) { + auto curr_basic_block = function->head_basic_block->next; + while (curr_basic_block != function->tail_basic_block) { + store_fuse_simple_basic_block(curr_basic_block, builder); + curr_basic_block = curr_basic_block->next; + } + + curr_basic_block = function->head_basic_block->next; + while (curr_basic_block != function->tail_basic_block) { + store_fuse_compl_basic_block(curr_basic_block, builder); + curr_basic_block = curr_basic_block->next; + } +} + +void store_fuse_simple_basic_block(BasicBlockPtr basic_block, Builder& builder) { + using namespace instruction; + + auto first_instruction = basic_block->head_instruction->next; + while (first_instruction != basic_block->tail_instruction) { + auto second_instruction = first_instruction->next; + if (second_instruction == basic_block->tail_instruction) { + break; + } + + if (first_instruction->is_store() && second_instruction->is_store()) { + auto first_store = first_instruction->as(); + auto second_store = second_instruction->as(); + auto first_op = first_store->op; + auto second_op = second_store->op; + + // sw zero, c(v1) + // sw zero, c+4(v1) + // -> + // sd zero, c(v1) + + if (first_op == Store::Op::SW && second_op == Store::Op::SW) { + auto first_rs1_id = first_store->rs1_id; + auto second_rs1_id = second_store->rs1_id; + auto first_rs2_id = first_store->rs2_id; + auto second_rs2_id = second_store->rs2_id; + auto first_imm_id = first_store->imm_id; + auto second_imm_id = second_store->imm_id; + auto first_rs1 = builder.context.get_operand(first_rs1_id); + auto second_rs1 = builder.context.get_operand(second_rs1_id); + auto first_rs2 = builder.context.get_operand(first_rs2_id); + auto second_rs2 = builder.context.get_operand(second_rs2_id); + auto first_imm = builder.context.get_operand(first_store->imm_id); + auto first_imm_val = + std::get(std::get(first_imm->kind).value); + auto second_imm = builder.context.get_operand(second_store->imm_id); + auto second_imm_val = + std::get(std::get(second_imm->kind).value); + if ( + first_rs1_id == second_rs1_id && + first_rs2->is_zero() && + second_rs2->is_zero() && + first_imm_val + 4 == second_imm_val + ) { + auto new_store = builder.fetch_store_instruction( + Store::Op::SD, + first_rs1_id, + first_rs2_id, + first_imm_id + ); + second_instruction->insert_next(new_store); + first_instruction->remove(builder.context); + second_instruction->remove(builder.context); + second_instruction = new_store; + } + } + } + + first_instruction = second_instruction; + } +} + +void store_fuse_compl_basic_block(BasicBlockPtr basic_block, Builder& builder) { + using namespace instruction; + + auto first_instruction = basic_block->head_instruction->next; + while (first_instruction != basic_block->tail_instruction) { + auto second_instruction = first_instruction->next; + if (second_instruction == basic_block->tail_instruction) { + break; + } + auto third_instruction = second_instruction->next; + if (third_instruction == basic_block->tail_instruction) { + break; + } + auto fourth_instruction = third_instruction->next; + if (fourth_instruction == basic_block->tail_instruction) { + break; + } + if (first_instruction->is_li() && + second_instruction->is_store() && + third_instruction->is_li() && + fourth_instruction->is_store()) { + auto first_li = first_instruction->as
  • (); + auto second_store = second_instruction->as(); + auto third_li = third_instruction->as
  • (); + auto fourth_store = fourth_instruction->as(); + auto first_rd_id = first_li->rd_id; + auto first_imm_id = first_li->imm_id; + auto second_rs1_id = second_store->rs1_id; + auto second_rs2_id = second_store->rs2_id; + auto second_imm_id = second_store->imm_id; + auto third_rd_id = third_li->rd_id; + auto third_imm_id = third_li->imm_id; + auto fourth_rs1_id = fourth_store->rs1_id; + auto fourth_rs2_id = fourth_store->rs2_id; + auto fourth_imm_id = fourth_store->imm_id; + auto second_imm = builder.context.get_operand(second_imm_id); + auto second_imm_val = + std::get(std::get(second_imm->kind).value); + auto fourth_imm = builder.context.get_operand(fourth_imm_id); + auto fourth_imm_val = + std::get(std::get(fourth_imm->kind).value); + if ( + first_rd_id == second_rs2_id && + third_rd_id == fourth_rs2_id && + second_rs1_id == fourth_rs1_id && + second_imm_val + 4 == fourth_imm_val + ) { + auto first_imm = builder.context.get_operand(first_imm_id); + auto first_imm_val = + std::get(std::get(first_imm->kind).value); + auto third_imm = builder.context.get_operand(third_imm_id); + auto third_imm_val = + std::get(std::get(third_imm->kind).value); + auto new_imm_val = (int64_t)third_imm_val << 32 | (int64_t)first_imm_val; + auto new_imm_id = builder.fetch_immediate(new_imm_val); + auto new_li = builder.fetch_li_instruction( + first_rd_id, + new_imm_id + ); + auto new_store = builder.fetch_store_instruction( + Store::Op::SD, + second_rs1_id, + second_rs2_id, + second_imm_id + ); + fourth_instruction->insert_next(new_li); + new_li->insert_next(new_store); + first_instruction->remove(builder.context); + second_instruction->remove(builder.context); + third_instruction->remove(builder.context); + fourth_instruction->remove(builder.context); + second_instruction = new_store; + } + } else if (first_instruction->is_lui() && + second_instruction->is_store() && + third_instruction->is_lui() && + fourth_instruction->is_store()) { + auto first_li = first_instruction->as(); + auto second_store = second_instruction->as(); + auto third_li = third_instruction->as(); + auto fourth_store = fourth_instruction->as(); + auto first_rd_id = first_li->rd_id; + auto first_imm_id = first_li->imm_id; + auto second_rs1_id = second_store->rs1_id; + auto second_rs2_id = second_store->rs2_id; + auto second_imm_id = second_store->imm_id; + auto third_rd_id = third_li->rd_id; + auto third_imm_id = third_li->imm_id; + auto fourth_rs1_id = fourth_store->rs1_id; + auto fourth_rs2_id = fourth_store->rs2_id; + auto fourth_imm_id = fourth_store->imm_id; + auto second_imm = builder.context.get_operand(second_imm_id); + auto second_imm_val = + std::get(std::get(second_imm->kind).value); + auto fourth_imm = builder.context.get_operand(fourth_imm_id); + auto fourth_imm_val = + std::get(std::get(fourth_imm->kind).value); + if ( + first_rd_id == second_rs2_id && + third_rd_id == fourth_rs2_id && + second_rs1_id == fourth_rs1_id && + second_imm_val + 4 == fourth_imm_val + ) { + auto first_imm = builder.context.get_operand(first_imm_id); + auto first_imm_val = + std::get(std::get(first_imm->kind).value); + auto third_imm = builder.context.get_operand(third_imm_id); + auto third_imm_val = + std::get(std::get(third_imm->kind).value); + auto new_imm_val = (int64_t)third_imm_val << 44 | (int64_t)first_imm_val << 12; + auto new_imm_id = builder.fetch_immediate(new_imm_val); + auto new_li = builder.fetch_li_instruction( + first_rd_id, + new_imm_id + ); + auto new_store = builder.fetch_store_instruction( + Store::Op::SD, + second_rs1_id, + second_rs2_id, + second_imm_id + ); + fourth_instruction->insert_next(new_li); + new_li->insert_next(new_store); + first_instruction->remove(builder.context); + second_instruction->remove(builder.context); + third_instruction->remove(builder.context); + fourth_instruction->remove(builder.context); + second_instruction = new_store; + } + } + + first_instruction = second_instruction; + } +} + +} // namespace backend +} // namespace syc \ No newline at end of file diff --git a/src/passes/asm/store_fuse.h b/src/passes/asm/store_fuse.h new file mode 100644 index 0000000..dbea0f6 --- /dev/null +++ b/src/passes/asm/store_fuse.h @@ -0,0 +1,20 @@ +#ifndef SYC_PASSES_ASM_STORE_FUSE_H_ +#define SYC_PASSES_ASM_STORE_FUSE_H_ + +#include "common.h" + +namespace syc { +namespace backend { + +void store_fuse(Builder& builder); + +void store_fuse_function(FunctionPtr function, Builder& builder); + +void store_fuse_simple_basic_block(BasicBlockPtr basic_block, Builder& builder); + +void store_fuse_compl_basic_block(BasicBlockPtr basic_block, Builder& builder); + +} // namespace backend +} // namespace syc + +#endif // SYC_PASSES_ASM_STORE_FUSE_H_ \ No newline at end of file diff --git a/src/passes/ir/gvn.cpp b/src/passes/ir/gvn.cpp index d27c6c8..97c9169 100644 --- a/src/passes/ir/gvn.cpp +++ b/src/passes/ir/gvn.cpp @@ -453,7 +453,7 @@ void gvn_basic_block(FunctionPtr function, BasicBlockPtr basic_block, Builder& b if (!is_aggressive) { auto related_gep_operand = builder.context.get_operand(std::get<1>(related_gep_expr)); if (related_gep_operand->is_global()) { - std::cout << "global variable load" << related_gep_operand->to_string() << std::endl; + // std::cout << "global variable load" << related_gep_operand->to_string() << std::endl; curr_instruction = next_instruction; continue; } diff --git a/src/passes/ir/loop_unrolling.cpp b/src/passes/ir/loop_unrolling.cpp index 9ff4c6b..7d4e950 100644 --- a/src/passes/ir/loop_unrolling.cpp +++ b/src/passes/ir/loop_unrolling.cpp @@ -210,9 +210,9 @@ bool loop_unrolling_helper( } // DEBUG - std::cout << "iv_ed: " << iv_ed << std::endl; - std::cout << "iv_st: " << iv_st << std::endl; - std::cout << "iv_stride: " << iv_stride << std::endl; + // std::cout << "iv_ed: " << iv_ed << std::endl; + // std::cout << "iv_st: " << iv_st << std::endl; + // std::cout << "iv_stride: " << iv_stride << std::endl; if ((iv_ed - iv_st) / iv_stride > 300 || iv_ed - iv_st <= 0) { return false;