Skip to content

Commit

Permalink
refactor(passes): loop indvar
Browse files Browse the repository at this point in the history
  • Loading branch information
JuniMay committed Aug 16, 2023
1 parent ae33fa2 commit 33f4707
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 140 deletions.
1 change: 1 addition & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ int main(int argc, char* argv[]) {
ir::math_opt(ir_builder);
ir::dce(ir_builder);
ir::copyprop(ir_builder);
ir::peephole(ir_builder);
ir::loop_indvar_simplify(ir_builder);
ir::loop_unrolling(ir_builder);
ir::copyprop(ir_builder);
Expand Down
245 changes: 105 additions & 140 deletions src/passes/ir/loop_indvar_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,28 @@ void loop_indvar_simplify_helper(LoopInfo& loop_info, Builder& builder) {
instr != header_bb->tail_instruction && instr->is_phi();
instr = instr->next) {
auto phi = instr->as<Phi>().value();

// Only accept phi with two incoming values
if (phi.incoming_list.size() != 2) {
continue;
}

std::optional<OperandID> maybe_alternate_id;
std::optional<OperandID> maybe_start_id;
std::optional<OperandID> maybe_step_id;
std::optional<OperandID> maybe_alternative_id;

for (auto [operand_id, block_id] : phi.incoming_list) {
if (loop_info.body_id_set.count(block_id)) {
maybe_alternate_id = operand_id;
maybe_alternative_id = operand_id;
} else {
maybe_start_id = operand_id;
}
}

if (!maybe_alternate_id || !maybe_start_id) {
if (!maybe_alternative_id || !maybe_start_id) {
continue;
}

auto alternative_id = maybe_alternate_id.value();
auto alternative_id = maybe_alternative_id.value();
auto start_id = maybe_start_id.value();

ivset.make_set(phi.dst_id);
Expand All @@ -100,11 +101,16 @@ void loop_indvar_simplify_helper(LoopInfo& loop_info, Builder& builder) {
builder.context.get_instruction(alternative_operand->maybe_def_id.value()
);

// Indvar be like:
// %t0 = phi [%st, preheader], [%t1, body]
// ...
// %t1 = add %t0, 1

auto maybe_binary = def_instr->as<Binary>();

if (maybe_binary.has_value()) {
auto binary = maybe_binary.value();

// Only support add yet
if (binary.op == BinaryOp::Add) {
auto dst_id = binary.dst_id;
auto lhs_id = binary.lhs_id;
Expand All @@ -115,16 +121,16 @@ void loop_indvar_simplify_helper(LoopInfo& loop_info, Builder& builder) {
if (lhs_id == phi.dst_id && dst_id == alternative_id && rhs->is_constant()) {
auto representative = ivset.find_set(phi.dst_id).value();
ivrecord_map[representative] =
IvRecord{phi.dst_id, start_id, binary.op, rhs_id};
IvRecord{representative, start_id, binary.op, rhs_id};
}
}
}
}

// Strength reduction

std::unordered_set<OperandID> ivappeared_set;

// The sequence of basic blocks in the loop body
std::vector<BasicBlockID> ordered_bb_id_list;
auto curr_bb = header_bb;
while (true) {
Expand Down Expand Up @@ -155,165 +161,124 @@ void loop_indvar_simplify_helper(LoopInfo& loop_info, Builder& builder) {
for (auto instr = bb->head_instruction->next; instr != bb->tail_instruction;
instr = instr->next) {
auto maybe_def_id = instr->maybe_def_id;

if (maybe_def_id.has_value()) {
auto def_id = maybe_def_id.value();
auto maybe_iv = ivset.find_set(def_id);
if (maybe_iv.has_value() && !instr->is_phi()) {
// Skip indvar def instructions
// If the instruction defines an indvar and is not a phi instruction,
// mark the indvar as appeared (defined or updated)
ivappeared_set.insert(maybe_iv.value());
continue;
}
}

bool has_indvar = false;
bool can_reduce = true;

// Check if all the uses satisfy the condition
// 1. all the uses are either indvar or defined outside the loop (or
// constant)

for (auto use_id : instr->use_id_list) {
auto use = builder.context.get_operand(use_id);
if (ivset.find_set(use->id).has_value()) {
has_indvar = true;
} else if (use->maybe_def_id.has_value()) {
auto def_instr =
builder.context.get_instruction(use->maybe_def_id.value());
if (loop_info.body_id_set.count(def_instr->parent_block_id)) {
can_reduce = false;
break;
}
}
}

if (!has_indvar || !can_reduce) {
continue;
}

auto maybe_gep = instr->as<GetElementPtr>();

if (maybe_gep.has_value()) {
// Check condition
// %t = gep [c * i32], [c * i32]* P, i32 0, i32 indvar
// load %v, %t
// %t = gep [c * i32], ptr P, i32 0, i32 indvar
// load %v, %t
// ->
// phi %t, P
// ...
// load %v, %t
// %t = gep i32, ptr, i32 indvar
// phi %t, P
// ...
// load %v, %t
// ...
// %t = gep i32, ptr, i32 step
// Note that if the use instr appears after the update instr, then use
// new %t, other wise use P

auto gep = maybe_gep.value();
auto dst = builder.context.get_operand(gep.dst_id);
auto ptr = builder.context.get_operand(gep.ptr_id);
auto basis_type = gep.basis_type;

// All the use of dst should be inside the current block
bool can_reduce = true;
for (auto use_id : dst->use_id_list) {
auto use_instr = builder.context.get_instruction(use_id);
if (use_instr->parent_block_id != bb_id) {
can_reduce = false;
break;
// Check ptr outside the loop
if (ptr->maybe_def_id.has_value()) {
auto ptr_def_instr =
builder.context.get_instruction(ptr->maybe_def_id.value());
if (loop_info.body_id_set.count(ptr_def_instr->parent_block_id)) {
continue;
}
}

// Check index list and types
std::optional<OperandID> maybe_iv_id = std::nullopt;

for (auto indexer_id : gep.index_id_list) {
auto indexer = builder.context.get_operand(indexer_id);
if (basis_type->as<type::Array>().has_value()) {
// Simple cases
if (!indexer->is_zero()) {
break;
}
basis_type = basis_type->as<type::Array>().value().element_type;
} else {
// Check if the indexer is iv
maybe_iv_id = ivset.find_set(indexer_id);
}
}

if (!can_reduce) {
if (!maybe_iv_id.has_value()) {
continue;
}

if (basis_type->as<type::Array>().has_value()) {
auto element_type =
basis_type->as<type::Array>().value().element_type;
if (element_type->as<type::Integer>().has_value()) {
auto integer_type = element_type->as<type::Integer>().value();

bool is_first_zero =
builder.context.get_operand(gep.index_id_list[0])->is_zero();

auto indvar = builder.context.get_operand(gep.index_id_list[1]);

if (is_first_zero &&
integer_type.size == 32 &&
ivset.find_set(indvar->id).has_value()) {
bool ivappeared =
ivappeared_set.count(ivset.find_set(indvar->id).value());

auto ivrecord = ivrecord_map[ivset.find_set(indvar->id).value()];

auto ivstart = builder.context.get_operand(ivrecord.start_id);

if (!ivstart->is_zero()) {
continue;
}

auto old_ptr = builder.context.get_operand(gep.ptr_id);
auto new_ptr_id =
builder.fetch_arbitrary_operand(builder.fetch_pointer_type());
auto phi_ptr_id =
builder.fetch_arbitrary_operand(builder.fetch_pointer_type());

builder.set_curr_basic_block(bb);

auto step_id =
ivrecord_map[ivset.find_set(indvar->id).value()].step_id;

auto new_gep_instr = builder.fetch_getelementptr_instruction(
new_ptr_id, builder.fetch_i32_type(), phi_ptr_id, {step_id}
);

if (ivappeared) {
instr->insert_next(new_gep_instr);
} else {
bb->tail_instruction->prev.lock()->insert_prev(new_gep_instr);
}

auto use_id_list_copy = dst->use_id_list;

for (auto use_id : use_id_list_copy) {
auto use_instr = builder.context.get_instruction(use_id);
if (!ivappeared) {
use_instr->replace_operand(
dst->id, phi_ptr_id, builder.context
);
} else {
use_instr->replace_operand(
dst->id, new_ptr_id, builder.context
);
}
}

auto next_instr = instr->next;
instr->remove(builder.context);
instr = next_instr;

auto incoming_list =
std::vector<std::tuple<OperandID, BasicBlockID>>();
for (auto pred_id : header_bb->pred_list) {
auto pred_bb = builder.context.get_basic_block(pred_id);
if (loop_info.body_id_set.count(pred_id)) {
incoming_list.push_back(std::make_tuple(new_ptr_id, pred_id));
continue;
}
// Add a bitcast to i32* to the pred bb
auto bitcast_ptr_id =
builder.fetch_arbitrary_operand(builder.fetch_pointer_type());
auto bitcast_instr = builder.fetch_cast_instruction(
CastOp::BitCast, bitcast_ptr_id, old_ptr->id
);

pred_bb->tail_instruction->prev.lock()->insert_prev(
bitcast_instr
);

incoming_list.push_back(std::make_tuple(bitcast_ptr_id, pred_id)
);
}
builder.set_curr_basic_block(header_bb);
auto phi_instr =
builder.fetch_phi_instruction(phi_ptr_id, incoming_list);
builder.prepend_instruction_to_curr_basic_block(phi_instr);
}
auto ivrecord = ivrecord_map[maybe_iv_id.value()];

auto ivstart = builder.context.get_operand(ivrecord.start_id);

// Simple case
if (!ivstart->is_zero()) {
continue;
}

bool ivappeared = ivappeared_set.count(maybe_iv_id.value());

// Phi dst
auto phi_dst_ptr_id =
builder.fetch_arbitrary_operand(builder.fetch_pointer_type());
// New gep dst
auto new_gep_dst_id =
builder.fetch_arbitrary_operand(builder.fetch_pointer_type());

// Construct phi instruction
auto incoming_list = std::vector<std::tuple<OperandID, BasicBlockID>>();
for (auto pred_id : header_bb->pred_list) {
if (loop_info.body_id_set.count(pred_id)) {
incoming_list.push_back(std::make_tuple(new_gep_dst_id, pred_id));
} else {
incoming_list.push_back(std::make_tuple(ptr->id, pred_id));
}
}
builder.set_curr_basic_block(header_bb);
auto phi_instr =
builder.fetch_phi_instruction(phi_dst_ptr_id, incoming_list);
builder.prepend_instruction_to_curr_basic_block(phi_instr);

// Construct new gep instruction
builder.set_curr_basic_block(bb);
auto new_gep_instr = builder.fetch_getelementptr_instruction(
new_gep_dst_id, basis_type, phi_dst_ptr_id, {ivrecord.step_id}
);
instr->insert_next(new_gep_instr);

// Replace use of dst
auto use_id_list_copy = dst->use_id_list;
for (auto use_id : use_id_list_copy) {
auto use_instr = builder.context.get_instruction(use_id);
if (ivappeared) {
use_instr->replace_operand(
dst->id, new_gep_dst_id, builder.context
);
} else {
use_instr->replace_operand(
dst->id, phi_dst_ptr_id, builder.context
);
}
}

instr->remove(builder.context);
instr = new_gep_instr;
}
}
}
Expand Down

0 comments on commit 33f4707

Please sign in to comment.